├── .gitignore ├── README.md ├── setup.py ├── setup_utils.py ├── shifts.png ├── tests └── shifts_test.py ├── torch_patch.py └── torchshifts ├── __init__.py ├── csrc ├── macros.h ├── ops │ ├── autograd │ │ └── shifts_autograd.cpp │ ├── cpu │ │ └── shifts_cpu.cpp │ ├── cuda │ │ └── shifts_cuda.cu │ ├── global_scope.h │ ├── kernels │ │ ├── interpolation.h │ │ └── shifts_kernels.h │ ├── ops.h │ ├── quantized │ │ └── shifts_quantized.cpp │ ├── shifts.cpp │ └── shifts.h ├── torchshifts.cpp └── torchshifts.h ├── extension.py ├── functional.py ├── modules ├── __init__.py └── shifts.py └── quantized ├── __init__.py ├── functional.py └── modules ├── __init__.py └── shifts.py /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | torchshifts.egg-info/ 4 | */**/__pycache__ 5 | */__pycache__ 6 | */*.pyc 7 | **/*.ipynb_checkpoints/ 8 | *.txt 9 | */**/*.pyc 10 | */**/**/*.pyc 11 | */**/*~ 12 | *~ 13 | docs/build 14 | .coverage 15 | htmlcov 16 | .*.swp 17 | *.so* 18 | *.dylib* 19 | */*.so* 20 | */*.dylib* 21 | *.swp 22 | *.swo 23 | gen.yml 24 | .mypy_cache 25 | .vscode/ 26 | *.orig 27 | *-checkpoint.ipynb 28 | torchshifts/version.py 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PyTorch implementation of Sparse Shift Layer(SSL) for 3D, 4D and 5D tensors from "All You Need is a Few Shifts: Designing Efficient Convolutional Neural Networks 2 | for Image Classification" (https://arxiv.org/pdf/1903.05285.pdf) 3 | 4 | (**I am not the author** any of mentioned articles, I just implement this for my own purposes) 5 | ## !FOR PYTORCH >= 1.7! ## 6 | 7 | ## Theory 8 | 9 | ### [Shift operation](https://en.wikipedia.org/wiki/Shift_operator): 10 | 11 | shifts tensor data(in memory) by indexes. Value and direction of shift are learnable and different between channels. 12 | It might be considered as Zero-FLOP replacement of DepthWise Convolution, with 4.5x less memory consumption(in compare with 3x3 DepthWise Convolution). 13 | 14 | ### Articles summary: 15 | * [GroupedShift](https://arxiv.org/pdf/1711.08141.pdf): First known application of shifts operator as replace of depthwise convolution. It utilize shifts as their exact form on forward and backward, hence the shifts values (weights) are not learnable (and for simplicity applied to group of channels, see article for detail) and act like hyperparams. 16 | 17 | (Officially we have not support this kind of shifts here, but for exact 18 | 19 | * [Active Shift](https://arxiv.org/pdf/1806.07370.pdf): Replacing shift operation on linear(bi-,tri- for 2D,3D cases) interpolation on both forward and backward pass. "Shifts" values became learnable (because they are floats) and moreover shifts defined for each channel. 20 | 21 | * [Sparse Shift Layer(SSL)]( (https://arxiv.org/pdf/1903.05285.pdf)): The combination of two above articles. "Shifts" values are still learnable vi interpolation(on backward pass), and use exact shift operator on forward pass ("shift" values just rounded during forward pass). So we have simple Zero-FLOP shift operation (which is also native quantized, because shift operator require integer values), instead of DepthWise convolution! 22 | 23 | Sparse - stands to L1 regularization on weights, this obviously sparsifying the shifts values among channel axis! 24 | 25 | 26 | ![alt text](https://github.com/DeadAt0m/ActiveSparseShifts-PyTorch/raw/master/shifts.png "Shifts evolution") 27 | 28 | ## Implementation details: 29 | 30 | * By default all Shift modules are Sparse Shift Layers! The module is always returns ```output``` and ```loss```, where is last is L1 regularization loss(see theory), which should be added to general loss for take an effect! 31 | 32 | * Active Shift can be enabled by setting ```active_flag=True```, and ```sparsity_term=0```, because we do not need to compute regularization term(at least in original article). 33 | 34 | * Grouped Shifts are not officially supported here, however technically it possible: set ```active_flag=False``` and ```sparsity_term=0```, freeze ```.weights``` params from gradient computation like ```shift_layer.weights.requires_grad = False``` (inside C function the gradient for weights will be always computed, so you will not gain in performance) and don't forget properly re-initialize ```.weights``` values(including channels groups, etc.) 35 | 36 | * We implement several padding variants for filling empty values after shifts: 37 | Zeros (by default), Border, Periodic(stands for circular shifts!), Reflect and Symmetric. See [here](https://pywavelets.readthedocs.io/en/latest/ref/signal-extension-modes.html) for details.(This paddings is also used during interpolation calculation) 38 | 39 | 40 | ## Requirements: 41 | C++17 must be supported by your compiler! (due to constexpr in code) 42 | PyTorch >= 1.7.0; 43 | 44 | ## Instalation: 45 | 1. Clone this repo and ```cd ActiveSparseShifts-PyTorch``` 46 | 2. (optional)If you compile with CUDA, please pass path to nvcc to CUDA_HOME env variable! 47 | 3. **Important!** There is bug in PyTorch which can lead to crash during build under CUDA. 48 | This bug was fixed in PyTorch 1.8. However it easy to fix it in previous versions. 49 | Run ```python torch_patch.py```(anyway it will automatically run during step 3) to fix it. 50 | This script change a few lines of code in single C++ header file, however doing this directly in python dist-package folder. 51 | Please, be sure that you have rights for changing files inside this folder! 52 | Anyway, you should do it only once for each python environment(PyTorch package). 53 | (If something will going wrong, please inspect ```torch_patch.py``` first (it very simple) and try to reproduce patch manually.) 54 | 4. Run ```python setup.py install``` or ```python setup.py bdist_wheel``` - to install/build package 55 | 56 | 57 | ## Using: 58 | 59 | Example: 60 | 61 | from torchshifts import Shift1d, Shift2d, Shift3d 62 | shift_layer = Shift1d(in_channels=3) 63 | 64 | Additional options for shift layer: 65 | 66 | padding(str) - Padding for filling empty values. 67 | Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'. 68 | init_shift(float) - Border for uniform initialization of weights(shifts): [-init_shift; init_shift]. Default: 1. 69 | sparsity_term(float) - Strength of sparsity. Default: 5e-4. 70 | active_flag(bool) - Enable Active Shift instead of SSL. Default: False 71 | emulate_dw(dict) - Just pass params of depthwise conv, that you trying replace with shift layer. 72 | It applies a heuristic and try to emulate their properties(including output shape) 73 | init_thumb_rule(int) - Type of thumb rule for shifts initialization. Allowed: Type 1(default): uniform(-init_shift, init_shift), 74 | Type 2: uniform(0,init_shift) * random_sign 75 | 76 | 77 | 78 | ## Additionals: 79 | 1. Depthwise Convolution Emulation: 80 | Provides a heuristic rules for emulation of DepthWise Convolution via Shift layer 81 | in terms of output shape and shift kernel behaviour. 82 | 83 | 1. This directly influence on proper shift param initialization. 84 | 2. Output shape via cutting the output and pooling(depending on stride) 85 | 3. Automatically using AveragePooling for emulation stride > 1 86 | 87 | 2. Pytorch Quantization: SSL shifts can be used in quantized pipeline! 88 | Shifts do not needed the activation tracking and so model with shift module can be easily converted by following: 89 | ``` 90 | from torchshifts import quant_mapping 91 | torch.quantization.convert(, ..., mapping=quant_mapping) 92 | ``` 93 | 3. Pytorch JIT: We support it out-of-box: 94 | ``` torch.jit.trace_module() ``` 95 | 96 | 97 | ## Update Notes: 98 | 1. (05.05.2021) Compatibility with Pytorch 1.8.1 99 | 100 | ## TO DO: 101 | 1. Add unit tests(yes I still make testing in some strange manners) 102 | 2. Speed up the ops on CUDA, still slower than Pytorch's 3x3 DW Convolution 103 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | MODULE_NAME = 'torchshifts' 2 | MODULE_VERSION = '3.1' 3 | #DO NOT CHANGE ON EARLIER STANDARDS PLEASE 4 | #(We use c++17 for using "constexpr" in our code) 5 | STD_VERSION = "c++17" 6 | PYTORCH_VERSION = "1.7" 7 | 8 | 9 | import sys, os, copy 10 | from setuptools import setup, find_packages 11 | from torch.utils.cpp_extension import BuildExtension, CUDA_HOME 12 | from torch.utils.cpp_extension import CppExtension, CUDAExtension 13 | from torch.cuda import is_available as cuda_available 14 | from torch import version as torch_version 15 | from pathlib import Path 16 | import torch 17 | from setup_utils import check_for_openmp, clean 18 | import subprocess 19 | cwd = Path.cwd() 20 | torch_ver = torch.__version__ 21 | torch_ver = torch_ver.split('+')[0] if '+' in torch_ver else torch_ver 22 | 23 | requirements = [f'torch >= {PYTORCH_VERSION}'] 24 | 25 | #cuda 26 | cuda_avail = (cuda_available() and (CUDA_HOME is not None)) or os.getenv('FORCE_CUDA', '0') == '1' 27 | cu_ver = '' 28 | if cuda_avail: 29 | if CUDA_HOME is not None: 30 | cu_ver = Path(CUDA_HOME).resolve().name.strip('cuda-') 31 | elif cuda_available(): 32 | cu_ver = copy(torch_version.cuda) 33 | if cu_ver: 34 | cu_ver = 'cu' + cu_ver 35 | 36 | if torch_ver < '1.8': 37 | from torch_patch import patch_torch_infer_schema_h 38 | __SUCC = patch_torch_infer_schema_h() 39 | if not __SUCC: 40 | print('Something went wrong during patching! The CUDA build have chance to fail!') 41 | MODULE_VERSION += f'+{cu_ver if cu_ver else ""}torch{torch_ver.strip("0").strip(".")}' 42 | 43 | version = copy.copy(MODULE_VERSION) 44 | sha = 'Unknown' 45 | try: 46 | sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=str(cwd)).decode('ascii').strip() 47 | except Exception: 48 | pass 49 | if sha != 'Unknown': 50 | version += '+' + sha[:7] 51 | print(f'Building wheel {MODULE_NAME}-{version}') 52 | 53 | version_path = cwd / MODULE_NAME / 'version.py' 54 | if version_path.exists(): 55 | version_path.unlink() 56 | version_path.touch() 57 | version_path = version_path.open("a") 58 | version_path.write(f"__version__ = '{version}'\n") 59 | version_path.write(f"git_version = {repr(sha)}\n") 60 | version_path.write(f"from {MODULE_NAME}.extension import _check_cuda_version\n") 61 | version_path.write("if _check_cuda_version() > 0:\n") 62 | version_path.write(" cuda = _check_cuda_version()\n") 63 | version_path.close() 64 | 65 | 66 | def get_extensions(): 67 | extensions_dir = cwd / MODULE_NAME / 'csrc' 68 | 69 | sources = list(extensions_dir.glob('*.cpp')) 70 | sources += list((extensions_dir / 'ops').glob('*.cpp')) 71 | sources += list((extensions_dir / 'ops' / 'autograd').glob('*.cpp')) 72 | sources += list((extensions_dir / 'ops' / 'cpu').glob('*.cpp')) 73 | sources += list((extensions_dir / 'ops' / 'quantized').glob('*.cpp')) 74 | 75 | extension = CppExtension 76 | 77 | define_macros = [] 78 | extra_compile_args = {'cxx':[f'-std={STD_VERSION}', '-O3']} 79 | 80 | parallel_method = ['-DAT_PARALLEL_NATIVE=1'] 81 | if sys.platform == 'win32': 82 | parallel_method = ['-DAT_PARALLEL_NATIVE_TBB=1'] 83 | extra_compile_args['cxx'].append('/MP') 84 | define_macros += [('TORCHSHIFTS_EXPORTS', None)] 85 | if sys.platform == 'linux': 86 | extra_compile_args['cxx'].append('-Wno-unused-but-set-variable') 87 | extra_compile_args['cxx'].append('-Wno-unused-variable') 88 | extra_compile_args['cxx'].append('-Wno-sign-compare') 89 | extra_compile_args['cxx'].extend(['-Wno-unknown-pragmas','-Wno-unused-function']) 90 | if check_for_openmp(): 91 | parallel_method = ['-fopenmp','-DAT_PARALLEL_OPENMP=1'] 92 | extra_compile_args['cxx'].extend(parallel_method) 93 | 94 | if cuda_avail: 95 | print('Building with CUDA') 96 | extension = CUDAExtension 97 | sources += list((extensions_dir / 'ops' / 'cuda').glob('*.cu')) 98 | define_macros += [('WITH_CUDA', None)] 99 | extra_compile_args['nvcc'] = ['-O3', '-DNDEBUG', '--expt-extended-lambda'] 100 | if os.getenv('NVCC_FLAGS', '') != '': 101 | extra_compile_args['nvcc'].extend(os.getenv('NVCC_FLAGS', '').split(' ')) 102 | 103 | # time to dirty things 104 | if torch_ver >= '1.8': 105 | define_macros += [('TORCH18', None)] 106 | elif torch_ver >= '1.7': 107 | define_macros += [('TORCH17', None)] 108 | else: 109 | print(f'PyTorch Version <1.7 is not supported! Your version is {torch_ver}. Please update and try again') 110 | quit() 111 | 112 | sources = list(set(map(lambda x: str(x.resolve()), sources))) 113 | include_dirs = [str(extensions_dir)] 114 | ext_modules = [ 115 | extension( 116 | f'{MODULE_NAME}._C', 117 | sources, 118 | include_dirs=include_dirs, 119 | define_macros=define_macros, 120 | extra_compile_args=extra_compile_args, 121 | ) 122 | ] 123 | return ext_modules 124 | 125 | setup( 126 | # Metadata 127 | name=MODULE_NAME, 128 | version=MODULE_VERSION, 129 | description='Implementation of Sparse Active Shift https://arxiv.org/pdf/1903.05285.pdf for PyTorch', 130 | keywords=['shifts','activeshifts', 'shiftspytorch'], 131 | author='Ignatii Dubyshkin aka DeadAt0m', 132 | author_email='kheldi@yandex.ru', 133 | url='https://github.com/DeadAt0m/ActiveSparseShifts-PyTorch', 134 | license='BSD', 135 | 136 | # Package info 137 | packages=find_packages(), 138 | package_dir={MODULE_NAME : MODULE_NAME}, 139 | package_data={ MODULE_NAME:['*.dll', '*.dylib', '*.so'] }, 140 | zip_safe=False, 141 | install_requires=requirements, 142 | ext_modules=get_extensions(), 143 | cmdclass={ 144 | 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True), 145 | 'clean': clean, 146 | } 147 | ) 148 | -------------------------------------------------------------------------------- /setup_utils.py: -------------------------------------------------------------------------------- 1 | import os, tempfile, subprocess, shutil, glob 2 | import distutils.command.clean 3 | 4 | 5 | def check_for_openmp(): 6 | omp_test = \ 7 | r""" 8 | #include 9 | #include 10 | int main() { 11 | #pragma omp parallel 12 | printf("Hello from thread %d, nthreads %d\n", omp_get_thread_num(), omp_get_num_threads()); 13 | } 14 | """ 15 | tmpdir = tempfile.mkdtemp() 16 | curdir = os.getcwd() 17 | os.chdir(tmpdir) 18 | filename = 'test.c' 19 | with open(filename, 'w', buffering=1) as file: 20 | file.write(omp_test) 21 | with open(os.devnull, 'w') as fnull: 22 | result = subprocess.call(['cc', '-fopenmp', filename], stdout=fnull, stderr=fnull) 23 | os.chdir(curdir) 24 | shutil.rmtree(tmpdir) 25 | return not bool(result) 26 | 27 | class clean(distutils.command.clean.clean): 28 | def run(self): 29 | this_dir = os.path.dirname(os.path.abspath(__file__)) 30 | with open(os.path.join(this_dir,'.gitignore'), 'r') as f: 31 | ignores = f.read() 32 | for wildcard in filter(None, ignores.split('\n')): 33 | for filename in glob.glob(wildcard): 34 | try: 35 | os.remove(filename) 36 | except OSError: 37 | shutil.rmtree(filename, ignore_errors=True) 38 | distutils.command.clean.clean.run(self) -------------------------------------------------------------------------------- /shifts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeadAt0m/ActiveSparseShifts-PyTorch/218c6beb83956c8be53b07fe7919f6de27cebb74/shifts.png -------------------------------------------------------------------------------- /tests/shifts_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchshifts import Shift2d 3 | from torchshifts.quantized.modules import Shift2d as QShift2d 4 | from torchshifts.functional import shift2d_func 5 | 6 | if __name__ == '__main__': 7 | d = torch.device('cpu') 8 | channels = 16 9 | ia = torch.rand(512,channels, 64, 64, requires_grad=True).to(d) 10 | ib = ia.detach().clone().to(d) 11 | ib.requires_grad = True 12 | ta = 10*torch.rand(512,channels,62,62).to(d) 13 | tb = ta.detach().clone() 14 | args = {'kernel_size':3, 'stride':1, 'padding':(0,0)} 15 | # args = None 16 | a = Shift2d(channels, init_shift=1, sparsity_term=0., active_flag=False,emulate_dw=args, 17 | init_thumb_rule=2).to(d) 18 | b = Shift2d(channels, init_shift=1, sparsity_term=0., active_flag=True,emulate_dw=args, 19 | init_thumb_rule=2).to(d) 20 | a.weight = b.weight 21 | la = torch.nn.MSELoss().to(d) 22 | lb = torch.nn.MSELoss().to(d) 23 | oa,_ = a(ia) 24 | ob,_ = b(ib) 25 | jac = la(oa, ta) 26 | jbc = lb(ob, tb) 27 | jac.backward() 28 | jbc.backward() 29 | assert a.weight.grad is not None, 'FAIL' 30 | assert b.weight.grad is not None, 'FAIL' 31 | print(a.weight.grad) 32 | import random 33 | c = random.randrange(channels) 34 | iq = torch.quantize_per_tensor(ia, 1/255.,0, torch.quint8) 35 | aq = QShift2d.from_float(a) 36 | oq = aq(iq) 37 | print('Channel:', c) 38 | print('Weights:', a.weight[c]) 39 | print('Forward pass (input, output(active=False), quantized, output(active=True)):', ia[0,c], oa[0,c], oq[0,c], ob[0,c]) 40 | #### Test interpolation 41 | i,j = 3,3 42 | i_lb, j_lb = 1,1 43 | import math 44 | w1,w2 = a.weight[c] 45 | dw = lambda k: k.item() - math.floor(k) if k>0 else k.item() - math.ceil(k) 46 | iw = lambda k: math.floor(k) if k>0 else math.ceil(k) 47 | dw1 = dw(w1) 48 | dw2 = dw(w2) 49 | si = iw(w1) 50 | sj = iw(w2) 51 | a00 = ia[0,c,i-si, j-sj].item() 52 | a10 = ia[0,c,i+1-si, j-sj].item() 53 | a01 = ia[0,c,i-si,j+1-sj].item() 54 | a11 = ia[0,c,i+1-si,j+1-sj].item() 55 | def interp1D(v1, v2, x): 56 | return v1*(1 - x) + v2*x 57 | def interp2D(v1, v2, v3, v4, x, y): 58 | return interp1D(interp1D(v1, v2, x), interp1D(v3, v4, x), y) 59 | from math import floor 60 | print(abs(float(interp2D(a00,a10,a01,a11,dw1,dw2)) - float(ob[0,c,i-i_lb,j-j_lb].item()))) 61 | #test grad 62 | print('Backward pass (shape, SSL grad, Active grad):', ia.grad.shape, ia.grad[0,c], ib.grad[0,c]) 63 | #CUDA 64 | d = torch.device('cuda:0') 65 | iac = ia.detach().clone().to(d) 66 | ibc = ib.detach().clone().to(d) 67 | iac.requires_grad = True 68 | ibc.requires_grad = True 69 | tac = ta.detach().clone().to(d) 70 | tbc = tb.detach().clone().to(d) 71 | from copy import deepcopy as dp 72 | ac = dp(a).to(d) 73 | bc = dp(b).to(d) 74 | ac.weight =bc.weight 75 | lac = torch.nn.MSELoss().to(d) 76 | lbc = torch.nn.MSELoss().to(d) 77 | oac,_ = ac(iac) 78 | obc,_ = bc(ibc) 79 | jac = lac(oac, tac) 80 | jbc = lbc(obc, tbc) 81 | jac.backward() 82 | jbc.backward() 83 | print('CUDAvsCPU difference: Forward pass SSL:', abs(oac.cpu()[0,c] - oa[0,c])) 84 | print('CUDAvsCPU difference: Forward pass Active:', abs(obc.cpu()[0,c] - ob[0,c])) 85 | print('CUDAvsCPU difference: Backward pass SSL:', abs(iac.grad.cpu()[0,c] - ia.grad[0,c])) 86 | print('CUDAvsCPU difference: Backward pass Active:', abs(ibc.grad.cpu()[0,c] - ib.grad[0,c])) -------------------------------------------------------------------------------- /torch_patch.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import re 3 | import sys 4 | from copy import deepcopy as copy 5 | import subprocess 6 | import torch 7 | 8 | python_exec = sys.executable 9 | torch_path = Path(torch.__file__).parent.resolve() 10 | 11 | find_str_1 = ''' 12 | struct ArgumentDef final { 13 | using GetTypeFn = TypePtr(); 14 | GetTypeFn* getTypeFn; 15 | }; 16 | ''' 17 | 18 | patch_str_1 = ''' 19 | struct ArgumentDef final { 20 | using GetTypeFn = TypePtr(); 21 | GetTypeFn* getTypeFn; 22 | constexpr ArgumentDef(): getTypeFn(nullptr) {} 23 | explicit constexpr ArgumentDef(GetTypeFn *getTypeFn): getTypeFn(getTypeFn) {} 24 | }; 25 | ''' 26 | 27 | find_str_2 = r"std::array{{ArgumentDef{&getTypePtr_>::call}...}}" 28 | patch_str_2 = r"std::array{ArgumentDef(&getTypePtr_>::call)...}" 29 | 30 | 31 | def patch_torch_infer_schema_h(): 32 | infer_schema_header = torch_path / 'include' / 'ATen' / 'core' / 'op_registration' / 'infer_schema.h' 33 | if not infer_schema_header.exists(): 34 | print(f'{str(infer_schema_header)} not found') 35 | return False 36 | content = infer_schema_header.read_text() 37 | orig_content = copy(content) 38 | ret = True 39 | content = content.replace(find_str_1, patch_str_1) 40 | ret *= (content.find(find_str_1) == -1) 41 | content = content.replace(find_str_2, patch_str_2) 42 | ret *= (content.find(find_str_2) == -1) 43 | if content != orig_content: 44 | print(f'Try writing into file: {str(infer_schema_header)}...') 45 | try: 46 | infer_schema_header.unlink() 47 | infer_schema_header.write_text(content) 48 | except: 49 | print('You need to execute this as root for proper patching!') 50 | subprocess.call(['sudo', python_exec, *sys.argv]) 51 | sys.exit() 52 | print('Success!') 53 | return ret 54 | 55 | if __name__ == '__main__': 56 | print(patch_torch_infer_schema_h()) -------------------------------------------------------------------------------- /torchshifts/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from pathlib import Path 3 | from .extension import _HAS_OPS 4 | 5 | try: 6 | from .version import __version__ 7 | except ImportError: 8 | pass 9 | 10 | # Check if torchshifts is being imported within the root folder 11 | if (not _HAS_OPS and Path(__file__).parent.resolve() == (Path.cwd() / 'torchshifts')): 12 | message = (f'You are importing torchshifts within its own root folder ({Path.cwd() / "torchshifts"}). ' 13 | 'This is not expected to work and may give errors. Please exit the ' 14 | 'torchshifts project source and relaunch your python interpreter.') 15 | warnings.warn(message) 16 | 17 | from torchshifts.modules import Shift1d, Shift2d, Shift3d 18 | from torchshifts.quantized import quant_mapping -------------------------------------------------------------------------------- /torchshifts/csrc/macros.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(TORCH18) 4 | #define TS_TORCH_LIBRARY_FRAGMENT(ns,m) TORCH_LIBRARY_FRAGMENT(ns, m) 5 | #elif defined(TORCH17) 6 | #define TS_TORCH_LIBRARY_FRAGMENT(ns,m) TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) 7 | #else 8 | #define TS_TORCH_LIBRARY_FRAGMENT(ns,m) TORCH_LIBRARY(ns, m) 9 | #endif 10 | 11 | 12 | #ifdef _WIN32 13 | #if defined(TORCHSHIFTS_EXPORTS) 14 | #define API_EXPORT __declspec(dllexport) 15 | #else 16 | #define API_EXPORT __declspec(dllimport) 17 | #endif 18 | #else 19 | #define API_EXPORT 20 | #endif 21 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/autograd/shifts_autograd.cpp: -------------------------------------------------------------------------------- 1 | #ifndef SHIFTS_CPU 2 | #define SHIFTS_CPU 3 | 4 | #include "../shifts.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace shifts { 10 | namespace ops { 11 | 12 | namespace { 13 | 14 | 15 | class Shift1dFunction : public torch::autograd::Function { 16 | public: 17 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 18 | const torch::Tensor& input, 19 | const torch::Tensor& weight, 20 | const torch::Tensor& borders, 21 | const std::vector& new_size, 22 | int64_t padding_mode, bool active_flag){ 23 | at::AutoNonVariableTypeMode g; 24 | auto output = detail::_shift1d_forward(input, weight, borders, new_size, padding_mode, active_flag); 25 | ctx->saved_data["padding_mode"] = padding_mode; 26 | ctx->saved_data["active_flag"] = active_flag; 27 | ctx->save_for_backward({input, weight, borders}); 28 | return {output}; 29 | } 30 | 31 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 32 | const torch::autograd::variable_list& grad_output) { 33 | auto saved = ctx->get_saved_variables(); 34 | auto input = saved[0]; 35 | auto weight = saved[1]; 36 | auto borders = saved[2]; 37 | auto padding_mode = ctx->saved_data["padding_mode"].toInt(); 38 | auto active_flag = ctx->saved_data["active_flag"].toBool(); 39 | 40 | auto result = detail::_shift1d_backward(grad_output[0], weight, input, borders, 41 | padding_mode, active_flag); 42 | auto grad_input = std::get<0>(result); 43 | auto grad_weight = std::get<1>(result); 44 | return {grad_input, grad_weight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; 45 | 46 | } 47 | }; 48 | 49 | // Hack for backward working during dispatch 50 | class Shift1dBackwardFunction: public torch::autograd::Function { 51 | public: 52 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 53 | const torch::Tensor& grad, 54 | const torch::Tensor& weights, 55 | const torch::Tensor& input, 56 | const torch::Tensor& borders, 57 | int64_t padding_mode, 58 | bool active_flag) { 59 | at::AutoNonVariableTypeMode g; 60 | auto result = detail::_shift1d_backward(grad, weights, input, borders, 61 | padding_mode, active_flag); 62 | auto grad_input = std::get<0>(result); 63 | auto grad_weight = std::get<1>(result); 64 | 65 | return { grad_input, grad_weight }; 66 | } 67 | 68 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 69 | const torch::autograd::variable_list& grad_output) { 70 | TORCH_CHECK(0, "double backwards on shift1d not supported"); 71 | } 72 | }; 73 | 74 | 75 | 76 | 77 | class Shift2dFunction : public torch::autograd::Function { 78 | public: 79 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 80 | const torch::Tensor& input, 81 | const torch::Tensor& weight, 82 | const torch::Tensor& borders, 83 | const std::vector& new_size, 84 | int64_t padding_mode, bool active_flag){ 85 | at::AutoNonVariableTypeMode g; 86 | auto output = detail::_shift2d_forward(input, weight, borders, new_size, padding_mode, active_flag); 87 | ctx->saved_data["padding_mode"] = padding_mode; 88 | ctx->saved_data["active_flag"] = active_flag; 89 | ctx->save_for_backward({input, weight, borders}); 90 | return {output}; 91 | } 92 | 93 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 94 | const torch::autograd::variable_list& grad_output) { 95 | auto saved = ctx->get_saved_variables(); 96 | auto input = saved[0]; 97 | auto weight = saved[1]; 98 | auto borders = saved[2]; 99 | auto padding_mode = ctx->saved_data["padding_mode"].toInt(); 100 | auto active_flag = ctx->saved_data["active_flag"].toBool(); 101 | 102 | auto result = detail::_shift2d_backward(grad_output[0], weight, input, borders, 103 | padding_mode, active_flag); 104 | auto grad_input = std::get<0>(result); 105 | auto grad_weight = std::get<1>(result); 106 | return {grad_input, grad_weight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; 107 | 108 | } 109 | }; 110 | 111 | 112 | // Hack for backward working during dispatch 113 | class Shift2dBackwardFunction: public torch::autograd::Function { 114 | public: 115 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 116 | const torch::Tensor& grad, 117 | const torch::Tensor& weights, 118 | const torch::Tensor& input, 119 | const torch::Tensor& borders, 120 | int64_t padding_mode, 121 | bool active_flag) { 122 | at::AutoNonVariableTypeMode g; 123 | auto result = detail::_shift2d_backward(grad, weights, input, borders, 124 | padding_mode, active_flag); 125 | auto grad_input = std::get<0>(result); 126 | auto grad_weight = std::get<1>(result); 127 | 128 | return { grad_input, grad_weight }; 129 | } 130 | 131 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 132 | const torch::autograd::variable_list& grad_output) { 133 | TORCH_CHECK(0, "double backwards on shift2d not supported"); 134 | } 135 | }; 136 | 137 | 138 | 139 | 140 | class Shift3dFunction : public torch::autograd::Function { 141 | public: 142 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 143 | const torch::Tensor& input, 144 | const torch::Tensor& weight, 145 | const torch::Tensor& borders, 146 | const std::vector& new_size, 147 | int64_t padding_mode, bool active_flag){ 148 | at::AutoNonVariableTypeMode g; 149 | auto output = detail::_shift3d_forward(input, weight, borders, new_size, padding_mode, active_flag); 150 | ctx->saved_data["padding_mode"] = padding_mode; 151 | ctx->saved_data["active_flag"] = active_flag; 152 | ctx->save_for_backward({input, weight, borders}); 153 | return { output }; 154 | } 155 | 156 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 157 | const torch::autograd::variable_list& grad_output) { 158 | auto saved = ctx->get_saved_variables(); 159 | auto input = saved[0]; 160 | auto weight = saved[1]; 161 | auto borders = saved[2]; 162 | auto padding_mode = ctx->saved_data["padding_mode"].toInt(); 163 | auto active_flag = ctx->saved_data["active_flag"].toBool(); 164 | 165 | auto result = detail::_shift3d_backward(grad_output[0], weight, input, borders, 166 | padding_mode, active_flag); 167 | auto grad_in = std::get<0>(result); 168 | auto grad_weight = std::get<1>(result); 169 | return {grad_in, grad_weight, torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()}; 170 | 171 | } 172 | }; 173 | 174 | 175 | // Hack for backward working during dispatch 176 | class Shift3dBackwardFunction: public torch::autograd::Function { 177 | public: 178 | static torch::autograd::variable_list forward(torch::autograd::AutogradContext* ctx, 179 | const torch::Tensor& grad, 180 | const torch::Tensor& weights, 181 | const torch::Tensor& input, 182 | const torch::Tensor& borders, 183 | int64_t padding_mode, 184 | bool active_flag) { 185 | at::AutoNonVariableTypeMode g; 186 | auto result = detail::_shift3d_backward(grad, weights, input, borders, 187 | padding_mode, active_flag); 188 | auto grad_input = std::get<0>(result); 189 | auto grad_weight = std::get<1>(result); 190 | 191 | return { grad_input, grad_weight }; 192 | } 193 | 194 | static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, 195 | const torch::autograd::variable_list& grad_output) { 196 | TORCH_CHECK(0, "double backwards on shift3d not supported"); 197 | } 198 | }; 199 | 200 | 201 | torch::Tensor shift1d_autograd(const torch::Tensor& input, 202 | const torch::Tensor& weights, 203 | const torch::Tensor& borders, 204 | const std::vector& new_size, 205 | int64_t padding_mode, 206 | bool active_flag){ 207 | return Shift1dFunction::apply(input, weights, borders, new_size, padding_mode, active_flag)[0]; 208 | } 209 | 210 | torch::Tensor shift2d_autograd(const torch::Tensor& input, 211 | const torch::Tensor& weights, 212 | const torch::Tensor& borders, 213 | const std::vector& new_size, 214 | int64_t padding_mode, 215 | bool active_flag){ 216 | return Shift2dFunction::apply(input, weights, borders, new_size, padding_mode, active_flag)[0]; 217 | } 218 | 219 | torch::Tensor shift3d_autograd(const torch::Tensor& input, 220 | const torch::Tensor& weights, 221 | const torch::Tensor& borders, 222 | const std::vector& new_size, 223 | int64_t padding_mode, 224 | bool active_flag){ 225 | return Shift3dFunction::apply(input, weights, borders, new_size, padding_mode, active_flag)[0]; 226 | } 227 | 228 | std::tuple shift1d_autograd_backward(const torch::Tensor& grad, 229 | const torch::Tensor& weights, 230 | const torch::Tensor& input, 231 | const torch::Tensor& borders, 232 | int64_t padding_mode, 233 | bool active_flag){ 234 | auto result = Shift1dBackwardFunction::apply(grad, weights, input, borders, padding_mode, active_flag); 235 | return std::make_tuple(result[0], result[1]); 236 | } 237 | 238 | std::tuple shift2d_autograd_backward(const torch::Tensor& grad, 239 | const torch::Tensor& weights, 240 | const torch::Tensor& input, 241 | const torch::Tensor& borders, 242 | int64_t padding_mode, 243 | bool active_flag){ 244 | auto result = Shift2dBackwardFunction::apply(grad, weights, input, borders, padding_mode, active_flag); 245 | return std::make_tuple(result[0], result[1]); 246 | } 247 | 248 | std::tuple shift3d_autograd_backward(const torch::Tensor& grad, 249 | const torch::Tensor& weights, 250 | const torch::Tensor& input, 251 | const torch::Tensor& borders, 252 | int64_t padding_mode, 253 | bool active_flag){ 254 | auto result = Shift3dBackwardFunction::apply(grad, weights, input, borders, padding_mode, active_flag); 255 | return std::make_tuple(result[0], result[1]); 256 | } 257 | 258 | 259 | 260 | } // end of anonymous namespace 261 | 262 | TORCH_LIBRARY_IMPL(torchshifts, Autograd, m) { 263 | m.impl( 264 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_forward"), 265 | TORCH_FN(shift1d_autograd)); 266 | m.impl( 267 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_backward"), 268 | TORCH_FN(shift1d_autograd_backward)); 269 | m.impl( 270 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_forward"), 271 | TORCH_FN(shift2d_autograd)); 272 | m.impl( 273 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_backward"), 274 | TORCH_FN(shift2d_autograd_backward)); 275 | m.impl( 276 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_forward"), 277 | TORCH_FN(shift3d_autograd)); 278 | m.impl( 279 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_backward"), 280 | TORCH_FN(shift3d_autograd_backward)); 281 | } 282 | 283 | } // namespace ops 284 | } // namespace shifts 285 | 286 | #endif -------------------------------------------------------------------------------- /torchshifts/csrc/ops/cpu/shifts_cpu.cpp: -------------------------------------------------------------------------------- 1 | #ifndef SHIFTS_CPU 2 | #define SHIFTS_CPU 3 | 4 | #include 5 | #include "../global_scope.h" 6 | #include "../kernels/shifts_kernels.h" 7 | #include 8 | 9 | namespace shifts { 10 | namespace ops { 11 | 12 | namespace { 13 | 14 | 15 | template 18 | API_INLINE void shiftnd_forward_kernel(const torch::Tensor& input, const torch::Tensor& iweights, 19 | const torch::Tensor& dweights, 20 | const torch::Tensor& borders, 21 | torch::Tensor& output){ 22 | const int64_t sizeN = input.size(0); 23 | const int64_t sizeC = input.size(1); 24 | const int64_t sizeH = input.size(2); 25 | const int64_t sizeW = kSpatialDim < 2 ? 1 : input.size(3); 26 | const int64_t sizeD = kSpatialDim < 3 ? 1 : input.size(4); 27 | const int64_t input_sN = input.stride(0); 28 | const int64_t input_sC = input.stride(1); 29 | const int64_t input_sH = input.stride(2); 30 | const int64_t input_sW = kSpatialDim < 2 ? 0 : input.stride(3); 31 | const int64_t input_sD = kSpatialDim < 3 ? 0 : input.stride(4); 32 | const int64_t output_sN = output.stride(0); 33 | const int64_t output_sC = output.stride(1); 34 | const int64_t output_sH = output.stride(2); 35 | const int64_t output_sW = kSpatialDim < 2 ? 0 : output.stride(3); 36 | const int64_t output_sD = kSpatialDim < 3 ? 0 : output.stride(4); 37 | scalar_t* input_ptr = input.data_ptr(); 38 | scalar_t* output_ptr = output.data_ptr(); 39 | int64_t* weights_ptr = iweights.data_ptr(); 40 | const int64_t weights_sC = iweights.stride(0); 41 | const int64_t weights_sS = iweights.stride(1); 42 | scalar_t* dweights_ptr = dweights.data_ptr(); 43 | const int64_t dweights_sC = dweights.stride(0); 44 | const int64_t dweights_sS = dweights.stride(1); 45 | 46 | int64_t* borders_data = borders.data_ptr(); 47 | const int64_t i_left_border = borders_data[0]; 48 | const int64_t i_right_border = borders_data[1]; 49 | const int64_t j_left_border = kSpatialDim < 2 ? 0 : borders_data[2]; 50 | const int64_t j_right_border = kSpatialDim < 2 ? 1 : borders_data[3]; 51 | const int64_t k_left_border = kSpatialDim < 3 ? 0 : borders_data[4]; 52 | const int64_t k_right_border = kSpatialDim < 3 ? 1 : borders_data[5]; 53 | 54 | 55 | if (input.is_contiguous(c10::MemoryFormat::ChannelsLast) || input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) 56 | {// Path for NDHWC 57 | at::parallel_for(0, sizeN, 0, [&](int64_t start, int64_t end){ 58 | for (int64_t n = start; n < end; ++n) { 59 | for (int64_t i = 0; i < sizeH; ++i){ 60 | for (int64_t j = 0; j < sizeW; ++j){ 61 | for (int64_t k = 0; k < sizeD; ++k){ 62 | shift_forward_kernel_nhwdc( 64 | input_ptr, output_ptr, weights_ptr, dweights_ptr, 65 | n, i, j, k, sizeC, sizeH, sizeW, sizeD, 66 | input_sN, input_sC, input_sH, input_sW, input_sD, 67 | output_sN, output_sC, output_sH, output_sW, output_sD, 68 | weights_sC, weights_sS, dweights_sC, dweights_sS, 69 | i_left_border, j_left_border, k_left_border, 70 | i_right_border, j_right_border, k_right_border); 71 | } 72 | } 73 | } 74 | } 75 | }); 76 | } else 77 | { 78 | at::parallel_for(0, sizeN*sizeC, 0, [&](int64_t start, int64_t end){ 79 | for (int64_t index = start; index < end; ++index) { 80 | const int64_t c = index % sizeC; 81 | const int64_t n = index / sizeC; 82 | for (int64_t i = 0; i < sizeH; ++i){ 83 | for (int64_t j = 0; j < sizeW; ++j){ 84 | for (int64_t k = 0; k < sizeD; ++k){ 85 | shift_forward_kernel_nchwd( 87 | input_ptr, output_ptr, weights_ptr, dweights_ptr, 88 | n, c, i, j, k, sizeH, sizeW, sizeD, 89 | input_sN, input_sC, input_sH, input_sW, input_sD, 90 | output_sN, output_sC, output_sH, output_sW, output_sD, 91 | weights_sC, weights_sS, dweights_sC, dweights_sS, 92 | i_left_border, j_left_border, k_left_border, 93 | i_right_border, j_right_border, k_right_border); 94 | } 95 | } 96 | } 97 | } 98 | }); 99 | } 100 | } 101 | 102 | 103 | template 106 | API_INLINE void shiftnd_backward_kernel(const torch::Tensor& grad_input, 107 | const torch::Tensor& iweights, 108 | const torch::Tensor& dweights, 109 | const torch::Tensor& input, 110 | const torch::Tensor& borders, 111 | torch::Tensor& grad_output, 112 | torch::Tensor& grad_weights) 113 | { 114 | const int64_t sizeN = input.size(0); 115 | const int64_t sizeC = input.size(1); 116 | const int64_t sizeH = input.size(2); 117 | const int64_t sizeW = kSpatialDim < 2 ? 1 : input.size(3); 118 | const int64_t sizeD = kSpatialDim < 3 ? 1 : input.size(4); 119 | const int64_t grad_input_sN = grad_input.stride(0); 120 | const int64_t grad_input_sC = grad_input.stride(1); 121 | const int64_t grad_input_sH = grad_input.stride(2); 122 | const int64_t grad_input_sW = kSpatialDim < 2 ? 0 : grad_input.stride(3); 123 | const int64_t grad_input_sD = kSpatialDim < 3 ? 0 : grad_input.stride(4); 124 | const int64_t input_sN = input.stride(0); 125 | const int64_t input_sC = input.stride(1); 126 | const int64_t input_sH = input.stride(2); 127 | const int64_t input_sW = kSpatialDim < 2 ? 0 : input.stride(3); 128 | const int64_t input_sD = kSpatialDim < 3 ? 0 : input.stride(4); 129 | const int64_t grad_output_sN = grad_output.stride(0); 130 | const int64_t grad_output_sC = grad_output.stride(1); 131 | const int64_t grad_output_sH = grad_output.stride(2); 132 | const int64_t grad_output_sW = kSpatialDim < 2 ? 0 : grad_output.stride(3); 133 | const int64_t grad_output_sD = kSpatialDim < 3 ? 0 : grad_output.stride(4); 134 | int64_t* weights_ptr = iweights.data_ptr(); 135 | const int64_t weights_sC = iweights.stride(0); 136 | const int64_t weights_sS = iweights.stride(1); 137 | scalar_t* dweights_ptr = dweights.data_ptr(); 138 | const int64_t dweights_sC = dweights.stride(0); 139 | const int64_t dweights_sS = dweights.stride(1); 140 | const int64_t grad_weights_sC = grad_weights.stride(0); 141 | const int64_t grad_weights_sS = grad_weights.stride(1); 142 | scalar_t* grad_weights_ptr = grad_weights.data_ptr(); 143 | scalar_t* grad_input_ptr = grad_input.data_ptr(); 144 | scalar_t* input_ptr = input.data_ptr(); 145 | scalar_t* grad_output_ptr = grad_output.data_ptr(); 146 | 147 | int64_t* borders_data = borders.data_ptr(); 148 | const int64_t i_left_border = borders_data[0]; 149 | const int64_t i_right_border = borders_data[1]; 150 | const int64_t j_left_border = kSpatialDim < 2 ? 0 : borders_data[2]; 151 | const int64_t j_right_border = kSpatialDim < 2 ? 1 : borders_data[3]; 152 | const int64_t k_left_border = kSpatialDim < 3 ? 0 : borders_data[4]; 153 | const int64_t k_right_border = kSpatialDim < 3 ? 1 : borders_data[5]; 154 | 155 | 156 | if (input.is_contiguous(c10::MemoryFormat::ChannelsLast) || input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) 157 | {// Path for NDHWC 158 | at::parallel_for(0, sizeN, 0, [&](int64_t start, int64_t end){ 159 | for (int64_t n = start; n < end; ++n) { 160 | for (int64_t i = 0; i < sizeH; ++i){ 161 | for (int64_t j = 0; j < sizeW; ++j){ 162 | for (int64_t k = 0; k < sizeD; ++k){ 163 | shift_backward_kernel_nhwdc( 165 | grad_input_ptr, input_ptr, grad_output_ptr, 166 | weights_ptr, dweights_ptr, grad_weights_ptr, 167 | n, i, j, k, sizeC, sizeH, sizeW, sizeD, 168 | grad_input_sN, grad_input_sC, grad_input_sH, 169 | grad_input_sW, grad_input_sD, 170 | input_sN, input_sC, input_sH, input_sW, input_sD, 171 | grad_output_sN, grad_output_sC, grad_output_sH, 172 | grad_output_sW, grad_output_sD, 173 | weights_sC, weights_sS, dweights_sC, dweights_sS, 174 | grad_weights_sC, grad_weights_sS, 175 | i_left_border, j_left_border, k_left_border, 176 | i_right_border, j_right_border, k_right_border); 177 | } 178 | } 179 | } 180 | } 181 | }); 182 | } else 183 | { 184 | at::parallel_for(0, sizeN*sizeC, 0, [&](int64_t start, int64_t end){ 185 | for (int64_t index = start; index < end; ++index) { 186 | const int64_t c = index % sizeC; 187 | const int64_t n = index / sizeC; 188 | for (int64_t i = 0; i < sizeH; ++i){ 189 | for (int64_t j = 0; j < sizeW; ++j){ 190 | for (int64_t k = 0; k < sizeD; ++k){ 191 | shift_backward_kernel_nchwd( 193 | grad_input_ptr, input_ptr, grad_output_ptr, 194 | weights_ptr, dweights_ptr, grad_weights_ptr, 195 | n, c, i, j, k, sizeC, sizeH, sizeW, sizeD, 196 | grad_input_sN, grad_input_sC, grad_input_sH, 197 | grad_input_sW, grad_input_sD, 198 | input_sN, input_sC, input_sH, input_sW, input_sD, 199 | grad_output_sN, grad_output_sC, grad_output_sH, 200 | grad_output_sW, grad_output_sD, 201 | weights_sC, weights_sS, dweights_sC, dweights_sS, 202 | grad_weights_sC, grad_weights_sS, 203 | i_left_border, j_left_border, k_left_border, 204 | i_right_border, j_right_border, k_right_border); 205 | } 206 | } 207 | } 208 | } 209 | }); 210 | } 211 | } 212 | 213 | 214 | template 216 | torch::Tensor shiftnd_forward(const torch::Tensor& input, 217 | const torch::Tensor& weights, 218 | const torch::Tensor& borders, 219 | const std::vector& new_size){ 220 | 221 | torch::Tensor output = torch::empty(new_size, input.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); 222 | 223 | torch::Tensor iweights = (active?torch::floor(weights):torch::round(weights)).to(torch::kLong); 224 | torch::Tensor dweights = active?(weights - iweights):torch::zeros_like(weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 225 | 226 | torch::Tensor _borders = borders.to(torch::kLong); 227 | 228 | AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "shiftnd_forward_cpu", [&] { 229 | shiftnd_forward_kernel(input, iweights, dweights, _borders, output); 230 | }); 231 | return output; 232 | } 233 | 234 | 235 | template 237 | std::tuple shiftnd_backward(const torch::Tensor& grad, 238 | const torch::Tensor& weights, 239 | const torch::Tensor& input, 240 | const torch::Tensor& borders) { 241 | 242 | torch::Tensor dweights = active?(weights - torch::floor(weights)):torch::where(weights>0,weights - torch::floor(weights), 243 | torch::ceil(weights) - weights); 244 | torch::Tensor iweights = (active?(weights - dweights):torch::round(weights)).to(torch::kLong); 245 | 246 | torch::Tensor out_grad = torch::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 247 | torch::Tensor weights_grad = torch::zeros_like(weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 248 | 249 | torch::Tensor _borders = borders.to(torch::kLong); 250 | 251 | AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "shiftnd_backward_cpu", [&] { 252 | shiftnd_backward_kernel(grad, iweights, dweights, input, _borders, out_grad, weights_grad); 253 | }); 254 | return std::make_tuple(out_grad, weights_grad); 255 | } 256 | 257 | 258 | // TEMPLATE DISPATCHERS 259 | 260 | torch::Tensor shift1d_forward(const torch::Tensor& input, 261 | const torch::Tensor& weights, 262 | const torch::Tensor& borders, 263 | const std::vector& new_size, 264 | int64_t padding_mode, 265 | bool active_flag){ 266 | torch::Tensor ret; 267 | switch (padding_mode){ 268 | case 0: 269 | ret = active_flag?shiftnd_forward<1,BIPadding::Zeros,true>(input, weights, borders, new_size): 270 | shiftnd_forward<1,BIPadding::Zeros,false>(input, weights, borders, new_size); 271 | break; 272 | case 1: 273 | ret = active_flag?shiftnd_forward<1,BIPadding::Border,true>(input, weights, borders, new_size): 274 | shiftnd_forward<1,BIPadding::Border,false>(input, weights, borders, new_size); 275 | break; 276 | case 2: 277 | ret = active_flag?shiftnd_forward<1,BIPadding::Periodic,true>(input, weights, borders, new_size): 278 | shiftnd_forward<1,BIPadding::Periodic,false>(input, weights, borders, new_size); 279 | break; 280 | case 3: 281 | ret = active_flag?shiftnd_forward<1,BIPadding::Reflect,true>(input, weights, borders, new_size): 282 | shiftnd_forward<1,BIPadding::Reflect,false>(input, weights, borders, new_size); 283 | break; 284 | case 4: 285 | ret = active_flag?shiftnd_forward<1,BIPadding::Symmetric,true>(input, weights, borders, new_size): 286 | shiftnd_forward<1,BIPadding::Symmetric,false>(input, weights, borders, new_size); 287 | break; 288 | } 289 | return ret; 290 | } 291 | 292 | torch::Tensor shift2d_forward(const torch::Tensor& input, 293 | const torch::Tensor& weights, 294 | const torch::Tensor& borders, 295 | const std::vector& new_size, 296 | int64_t padding_mode, 297 | bool active_flag){ 298 | torch::Tensor ret; 299 | switch (padding_mode){ 300 | case 0: 301 | ret = active_flag?shiftnd_forward<2,BIPadding::Zeros,true>(input, weights, borders, new_size): 302 | shiftnd_forward<2,BIPadding::Zeros,false>(input, weights, borders, new_size); 303 | break; 304 | case 1: 305 | ret = active_flag?shiftnd_forward<2,BIPadding::Border,true>(input, weights, borders, new_size): 306 | shiftnd_forward<2,BIPadding::Border,false>(input, weights, borders, new_size); 307 | break; 308 | case 2: 309 | ret = active_flag?shiftnd_forward<2,BIPadding::Periodic,true>(input, weights, borders, new_size): 310 | shiftnd_forward<2,BIPadding::Periodic,false>(input, weights, borders, new_size); 311 | break; 312 | case 3: 313 | ret = active_flag?shiftnd_forward<2,BIPadding::Reflect,true>(input, weights, borders, new_size): 314 | shiftnd_forward<2,BIPadding::Reflect,false>(input, weights, borders, new_size); 315 | break; 316 | case 4: 317 | ret = active_flag?shiftnd_forward<2,BIPadding::Symmetric,true>(input, weights, borders, new_size): 318 | shiftnd_forward<2,BIPadding::Symmetric,false>(input, weights, borders, new_size); 319 | break; 320 | } 321 | return ret; 322 | } 323 | 324 | torch::Tensor shift3d_forward(const torch::Tensor& input, 325 | const torch::Tensor& weights, 326 | const torch::Tensor& borders, 327 | const std::vector& new_size, 328 | int64_t padding_mode, 329 | bool active_flag){ 330 | torch::Tensor ret; 331 | switch (padding_mode){ 332 | case 0: 333 | ret = active_flag?shiftnd_forward<3,BIPadding::Zeros,true>(input, weights, borders, new_size): 334 | shiftnd_forward<3,BIPadding::Zeros,false>(input, weights, borders, new_size); 335 | break; 336 | case 1: 337 | ret = active_flag?shiftnd_forward<3,BIPadding::Border,true>(input, weights, borders, new_size): 338 | shiftnd_forward<3,BIPadding::Border,false>(input, weights, borders, new_size); 339 | break; 340 | case 2: 341 | ret = active_flag?shiftnd_forward<3,BIPadding::Periodic,true>(input, weights, borders, new_size): 342 | shiftnd_forward<3,BIPadding::Periodic,false>(input, weights, borders, new_size); 343 | break; 344 | case 3: 345 | ret = active_flag?shiftnd_forward<3,BIPadding::Reflect,true>(input, weights, borders, new_size): 346 | shiftnd_forward<3,BIPadding::Reflect,false>(input, weights, borders, new_size); 347 | break; 348 | case 4: 349 | ret = active_flag?shiftnd_forward<3,BIPadding::Symmetric,true>(input, weights, borders, new_size): 350 | shiftnd_forward<3,BIPadding::Symmetric,false>(input, weights, borders, new_size); 351 | break; 352 | } 353 | return ret; 354 | } 355 | 356 | 357 | std::tuple shift1d_backward(const torch::Tensor& grad, 358 | const torch::Tensor& weights, 359 | const torch::Tensor& input, 360 | const torch::Tensor& borders, 361 | int64_t padding_mode, 362 | bool active_flag){ 363 | std::tuple ret; 364 | switch (padding_mode){ 365 | case 0: 366 | ret = active_flag?shiftnd_backward<1,BIPadding::Zeros,true>(grad, weights, input, borders): 367 | shiftnd_backward<1,BIPadding::Zeros,false>(grad, weights, input, borders); 368 | break; 369 | case 1: 370 | ret = active_flag?shiftnd_backward<1,BIPadding::Border,true>(grad, weights, input, borders): 371 | shiftnd_backward<1,BIPadding::Border,false>(grad, weights, input, borders); 372 | break; 373 | case 2: 374 | ret = active_flag?shiftnd_backward<1,BIPadding::Periodic,true>(grad, weights, input, borders): 375 | shiftnd_backward<1,BIPadding::Periodic,false>(grad, weights, input, borders); 376 | break; 377 | case 3: 378 | ret = active_flag?shiftnd_backward<1,BIPadding::Reflect,true>(grad, weights, input, borders): 379 | shiftnd_backward<1,BIPadding::Reflect,false>(grad, weights, input, borders); 380 | break; 381 | case 4: 382 | ret = active_flag?shiftnd_backward<1,BIPadding::Symmetric,true>(grad, weights, input, borders): 383 | shiftnd_backward<1,BIPadding::Symmetric,false>(grad, weights, input, borders); 384 | break; 385 | } 386 | return ret; 387 | } 388 | 389 | std::tuple shift2d_backward(const torch::Tensor& grad, 390 | const torch::Tensor& weights, 391 | const torch::Tensor& input, 392 | const torch::Tensor& borders, 393 | int64_t padding_mode, 394 | bool active_flag){ 395 | std::tuple ret; 396 | switch (padding_mode){ 397 | case 0: 398 | ret = active_flag?shiftnd_backward<2,BIPadding::Zeros,true>(grad, weights, input, borders): 399 | shiftnd_backward<2,BIPadding::Zeros,false>(grad, weights, input, borders); 400 | break; 401 | case 1: 402 | ret = active_flag?shiftnd_backward<2,BIPadding::Border,true>(grad, weights, input, borders): 403 | shiftnd_backward<2,BIPadding::Border,false>(grad, weights, input, borders); 404 | break; 405 | case 2: 406 | ret = active_flag?shiftnd_backward<2,BIPadding::Periodic,true>(grad, weights, input, borders): 407 | shiftnd_backward<2,BIPadding::Periodic,false>(grad, weights, input, borders); 408 | break; 409 | case 3: 410 | ret = active_flag?shiftnd_backward<2,BIPadding::Reflect,true>(grad, weights, input, borders): 411 | shiftnd_backward<2,BIPadding::Reflect,false>(grad, weights, input, borders); 412 | break; 413 | case 4: 414 | ret = active_flag?shiftnd_backward<2,BIPadding::Symmetric,true>(grad, weights, input, borders): 415 | shiftnd_backward<2,BIPadding::Symmetric,false>(grad, weights, input, borders); 416 | break; 417 | } 418 | return ret; 419 | } 420 | 421 | std::tuple shift3d_backward(const torch::Tensor& grad, 422 | const torch::Tensor& weights, 423 | const torch::Tensor& input, 424 | const torch::Tensor& borders, 425 | int64_t padding_mode, 426 | bool active_flag){ 427 | std::tuple ret; 428 | switch (padding_mode){ 429 | case 0: 430 | ret = active_flag?shiftnd_backward<3,BIPadding::Zeros,true>(grad, weights, input, borders): 431 | shiftnd_backward<3,BIPadding::Zeros,false>(grad, weights, input, borders); 432 | break; 433 | case 1: 434 | ret = active_flag?shiftnd_backward<3,BIPadding::Border,true>(grad, weights, input, borders): 435 | shiftnd_backward<3,BIPadding::Border,false>(grad, weights, input, borders); 436 | break; 437 | case 2: 438 | ret = active_flag?shiftnd_backward<3,BIPadding::Periodic,true>(grad, weights, input, borders): 439 | shiftnd_backward<3,BIPadding::Periodic,false>(grad, weights, input, borders); 440 | break; 441 | case 3: 442 | ret = active_flag?shiftnd_backward<3,BIPadding::Reflect,true>(grad, weights, input, borders): 443 | shiftnd_backward<3,BIPadding::Reflect,false>(grad, weights, input, borders); 444 | break; 445 | case 4: 446 | ret = active_flag?shiftnd_backward<3,BIPadding::Symmetric,true>(grad, weights, input, borders): 447 | shiftnd_backward<3,BIPadding::Symmetric,false>(grad, weights, input, borders); 448 | break; 449 | } 450 | return ret; 451 | } 452 | 453 | 454 | 455 | } // end of anonymous namespace 456 | 457 | 458 | TORCH_LIBRARY_IMPL(torchshifts, CPU, m) { 459 | m.impl( 460 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_forward"), 461 | TORCH_FN(shift1d_forward)); 462 | m.impl( 463 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_backward"), 464 | TORCH_FN(shift1d_backward)); 465 | m.impl( 466 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_forward"), 467 | TORCH_FN(shift2d_forward)); 468 | m.impl( 469 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_backward"), 470 | TORCH_FN(shift2d_backward)); 471 | m.impl( 472 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_forward"), 473 | TORCH_FN(shift3d_forward)); 474 | m.impl( 475 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_backward"), 476 | TORCH_FN(shift3d_backward)); 477 | } 478 | 479 | } // namespace ops 480 | } // namespace shifts 481 | 482 | 483 | #endif 484 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/cuda/shifts_cuda.cu: -------------------------------------------------------------------------------- 1 | #ifndef SHIFTS_CUDA 2 | #define SHIFTS_CUDA 3 | 4 | 5 | #include 6 | #include "../global_scope.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | using namespace at::cuda::detail; 20 | 21 | namespace shifts { 22 | namespace ops { 23 | 24 | namespace { 25 | #include "../kernels/shifts_kernels.h" 26 | 27 | 28 | template 31 | C10_LAUNCH_BOUNDS_1(CUDA_THREADS) 32 | __global__ void shiftnd_forward_kernel(const idx_t n_threads, 33 | TensorInfo input, 34 | TensorInfo iweights, 35 | TensorInfo dweights, 36 | TensorInfo borders, 37 | TensorInfo output){ 38 | const idx_t sizeC = input.sizes[1]; 39 | const idx_t sizeH = input.sizes[2]; 40 | const idx_t sizeW = kSpatialDim < 2 ? 1 : input.sizes[3]; 41 | const idx_t sizeD = kSpatialDim < 3 ? 1 : input.sizes[4]; 42 | const idx_t input_sN = input.strides[0]; 43 | const idx_t input_sC = input.strides[1]; 44 | const idx_t input_sH = input.strides[2]; 45 | const idx_t input_sW = kSpatialDim < 2 ? 0 : input.strides[3]; 46 | const idx_t input_sD = kSpatialDim < 3 ? 0 : input.strides[4]; 47 | const idx_t output_sN = output.strides[0]; 48 | const idx_t output_sC = output.strides[1]; 49 | const idx_t output_sH = output.strides[2]; 50 | const idx_t output_sW = kSpatialDim < 2 ? 0 : output.strides[3]; 51 | const idx_t output_sD = kSpatialDim < 3 ? 0 : output.strides[4]; 52 | scalar_t* input_ptr = input.data; 53 | scalar_t* output_ptr = output.data; 54 | idx_t* weights_ptr = iweights.data; 55 | const idx_t weights_sC = iweights.strides[0]; 56 | const idx_t weights_sS = iweights.strides[1]; 57 | scalar_t* dweights_ptr = dweights.data; 58 | const idx_t dweights_sC = dweights.strides[0]; 59 | const idx_t dweights_sS = dweights.strides[1]; 60 | 61 | idx_t* borders_data = borders.data; 62 | const idx_t i_left_border = borders_data[0]; 63 | const idx_t i_right_border = borders_data[1]; 64 | const idx_t j_left_border = kSpatialDim < 2 ? 0 : borders_data[2]; 65 | const idx_t j_right_border = kSpatialDim < 2 ? 1 : borders_data[3]; 66 | const idx_t k_left_border = kSpatialDim < 3 ? 0 : borders_data[4]; 67 | const idx_t k_right_border = kSpatialDim < 3 ? 1 : borders_data[5]; 68 | 69 | const idx_t sizeDW = (kSpatialDim > 1)?(sizeD*sizeW):1; 70 | const idx_t sizeDWH = sizeDW*sizeH; 71 | const idx_t sizeDWHC = sizeDWH*sizeC; 72 | 73 | CUDA_KERNEL_LOOP(index, n_threads){ 74 | const int k = (kSpatialDim > 2)? (index % sizeD):0; 75 | const int j = (kSpatialDim > 1)? ((index / sizeD) % sizeW): 0; 76 | const int i = (index / sizeDW) % sizeH; 77 | const int c = (index / sizeDWH) % sizeC; 78 | const int n = (index / sizeDWHC); 79 | shift_forward_kernel_nchwd( 80 | input_ptr, output_ptr, weights_ptr, dweights_ptr, 81 | n, c, i, j, k, sizeH, sizeW, sizeD, 82 | input_sN, input_sC, input_sH, input_sW, input_sD, 83 | output_sN, output_sC, output_sH, output_sW, output_sD, 84 | weights_sC, weights_sS, dweights_sC, dweights_sS, 85 | i_left_border, j_left_border, k_left_border, 86 | i_right_border, j_right_border, k_right_border); 87 | } 88 | } 89 | 90 | template 93 | C10_LAUNCH_BOUNDS_1(CUDA_THREADS) 94 | __global__ void shiftnd_backward_kernel(const idx_t n_threads, 95 | TensorInfo grad_input, 96 | TensorInfo iweights, 97 | TensorInfo dweights, 98 | TensorInfo input, 99 | TensorInfo borders, 100 | TensorInfo grad_output, 101 | TensorInfo grad_weights) 102 | { 103 | const idx_t sizeC = input.sizes[1]; 104 | const idx_t sizeH = input.sizes[2]; 105 | const idx_t sizeW = kSpatialDim < 2 ? 1 : input.sizes[3]; 106 | const idx_t sizeD = kSpatialDim < 3 ? 1 : input.sizes[4]; 107 | const idx_t grad_input_sN = grad_input.strides[0]; 108 | const idx_t grad_input_sC = grad_input.strides[1]; 109 | const idx_t grad_input_sH = grad_input.strides[2]; 110 | const idx_t grad_input_sW = kSpatialDim < 2 ? 0 : grad_input.strides[3]; 111 | const idx_t grad_input_sD = kSpatialDim < 3 ? 0 : grad_input.strides[4]; 112 | const idx_t input_sN = input.strides[0]; 113 | const idx_t input_sC = input.strides[1]; 114 | const idx_t input_sH = input.strides[2]; 115 | const idx_t input_sW = kSpatialDim < 2 ? 0 : input.strides[3]; 116 | const idx_t input_sD = kSpatialDim < 3 ? 0 : input.strides[4]; 117 | const idx_t grad_output_sN = grad_output.strides[0]; 118 | const idx_t grad_output_sC = grad_output.strides[1]; 119 | const idx_t grad_output_sH = grad_output.strides[2]; 120 | const idx_t grad_output_sW = kSpatialDim < 2 ? 0 : grad_output.strides[3]; 121 | const idx_t grad_output_sD = kSpatialDim < 3 ? 0 : grad_output.strides[4]; 122 | const idx_t grad_weights_sC = grad_weights.strides[0]; 123 | const idx_t grad_weights_sS = grad_weights.strides[1]; 124 | scalar_t* grad_input_ptr = grad_input.data; 125 | scalar_t* input_ptr = input.data; 126 | scalar_t* grad_output_ptr = grad_output.data; 127 | scalar_t* grad_weights_ptr = grad_weights.data; 128 | idx_t* weights_ptr = iweights.data; 129 | const idx_t weights_sC = iweights.strides[0]; 130 | const idx_t weights_sS = iweights.strides[1]; 131 | scalar_t* dweights_ptr = dweights.data; 132 | const idx_t dweights_sC = dweights.strides[0]; 133 | const idx_t dweights_sS = dweights.strides[1]; 134 | 135 | 136 | idx_t* borders_data = borders.data; 137 | const idx_t i_left_border = borders_data[0]; 138 | const idx_t i_right_border = borders_data[1]; 139 | const idx_t j_left_border = kSpatialDim < 2 ? 0 : borders_data[2]; 140 | const idx_t j_right_border = kSpatialDim < 2 ? 1 : borders_data[3]; 141 | const idx_t k_left_border = kSpatialDim < 3 ? 0 : borders_data[4]; 142 | const idx_t k_right_border = kSpatialDim < 3 ? 1 : borders_data[5]; 143 | 144 | const idx_t sizeDW = (kSpatialDim > 1)?(sizeD*sizeW):1; 145 | const idx_t sizeDWH = sizeDW*sizeH; 146 | const idx_t sizeDWHC = sizeDWH*sizeC; 147 | 148 | CUDA_KERNEL_LOOP(index, n_threads){ 149 | const int k = (kSpatialDim > 2)? (index % sizeD):0; 150 | const int j = (kSpatialDim > 1)? ((index / sizeD) % sizeW): 0; 151 | const int i = (index / sizeDW) % sizeH; 152 | const int c = (index / sizeDWH) % sizeC; 153 | const int n = (index / sizeDWHC); 154 | shift_backward_kernel_nchwd( 155 | grad_input_ptr, input_ptr, grad_output_ptr, 156 | weights_ptr, dweights_ptr, grad_weights_ptr, 157 | n, c, i, j, k, sizeC, sizeH, sizeW, sizeD, 158 | grad_input_sN, grad_input_sC, grad_input_sH, grad_input_sW, grad_input_sD, 159 | input_sN, input_sC, input_sH, input_sW, input_sD, 160 | grad_output_sN, grad_output_sC, grad_output_sH, grad_output_sW, grad_output_sD, 161 | weights_sC, weights_sS, dweights_sC, dweights_sS, grad_weights_sC, grad_weights_sS, 162 | i_left_border, j_left_border, k_left_border, 163 | i_right_border, j_right_border, k_right_border); 164 | } 165 | } 166 | 167 | 168 | template 169 | inline void weights_init_forward(const torch::Tensor& weights, 170 | torch::Tensor iweights, 171 | torch::Tensor dweights){ 172 | 173 | torch::TensorIterator iter = torch::TensorIteratorConfig().add_output(iweights) 174 | .add_output(dweights) 175 | .add_input(weights).build(); 176 | 177 | at::native::gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t src) -> thrust::tuple { 178 | scalar_t iw = active?static_cast(FLOOR(src)):static_cast(ROUND(src)); 179 | scalar_t dw = active?(src - iw):static_cast(0); 180 | 181 | return {iw, dw}; 182 | }); 183 | } 184 | 185 | template 186 | inline void weights_init_backward(const torch::Tensor& weights, 187 | torch::Tensor iweights, 188 | torch::Tensor dweights){ 189 | torch::TensorIterator iter = torch::TensorIteratorConfig().add_output(iweights) 190 | .add_output(dweights) 191 | .add_input(weights).build(); 192 | 193 | at::native::gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t src) -> thrust::tuple { 194 | scalar_t dw = active?(src - static_castFLOOR(src)):(src>0)?(src - static_castFLOOR(src)): 195 | (static_castCEIL(src) - src); 196 | scalar_t iw = active?(src-dw):static_cast(ROUND(src)); 197 | return {iw, dw}; 198 | }); 199 | } 200 | 201 | 202 | template 204 | torch::Tensor shiftnd_forward(const torch::Tensor& input, 205 | const torch::Tensor& weights, 206 | const torch::Tensor& borders, 207 | const std::vector& new_size){ 208 | TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); 209 | TORCH_CHECK(weights.is_cuda(), "weights must be a CUDA tensor"); 210 | torch::TensorArg input_t{input, "input", 1}, weights_t{weights, "weights", 2}; 211 | torch::CheckedFrom c = "shiftnd_forward_cuda"; 212 | 213 | torch::checkAllSameGPU(c, {input_t, weights_t}); 214 | torch::checkAllSameType(c, {input_t, weights_t}); 215 | at::cuda::CUDAGuard device_guard(input.device()); 216 | 217 | bool int32bit_cond = canUse32BitIndexMath(input) && canUse32BitIndexMath(weights); 218 | 219 | torch::Tensor _weights = weights.contiguous(LEGACY_CONTIGUOUS_MEMORY_FORMAT); 220 | torch::Tensor _borders = int32bit_cond?borders.to(torch::kInt):borders.to(torch::kLong); 221 | 222 | torch::Tensor output = torch::empty(new_size, input.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); 223 | torch::Tensor _iweights = torch::empty_like(_weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 224 | torch::Tensor dweights = torch::empty_like(_weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 225 | 226 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(_weights.scalar_type(), "weights_init_cuda_forward", [&] { 227 | weights_init_forward(_weights, _iweights, dweights); 228 | }); 229 | torch::Tensor iweights = int32bit_cond?_iweights.to(torch::kInt):_iweights.to(torch::kLong); 230 | 231 | const int64_t N = input.size(0); 232 | const int64_t C = input.size(1); 233 | const int64_t H = input.size(2); 234 | const int64_t W = (nD<2)?1:input.size(3); 235 | const int64_t D = (nD<3)?1:input.size(4); 236 | 237 | const int64_t count = N*C*H*W*D; 238 | const dim3 blocks(CUDA_BLOCKS(count, LOCAL_CUDA_NUM_THREADS), 1, 1); 239 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 240 | 241 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "shiftnd_forward_cuda", [&] { 242 | if (int32bit_cond){ 243 | shiftnd_forward_kernel 244 | <<>>( 245 | static_cast(count), 246 | getTensorInfo(input), 247 | getTensorInfo(iweights), 248 | getTensorInfo(dweights), 249 | getTensorInfo(_borders), 250 | getTensorInfo(output)); 251 | } else { 252 | shiftnd_forward_kernel 253 | <<>>( 254 | count, 255 | getTensorInfo(input), 256 | getTensorInfo(iweights), 257 | getTensorInfo(dweights), 258 | getTensorInfo(_borders), 259 | getTensorInfo(output)); 260 | } 261 | }); 262 | 263 | AT_CUDA_CHECK(cudaGetLastError()); 264 | 265 | return output; 266 | } 267 | 268 | template 270 | std::tuple shiftnd_backward(const torch::Tensor& grad, 271 | const torch::Tensor& weights, 272 | const torch::Tensor& input, 273 | const torch::Tensor& borders) { 274 | at::globalContext().alertNotDeterministic("shiftnd_backward_cuda"); 275 | 276 | TORCH_CHECK(grad.is_cuda(), "grad must be a CUDA tensor"); 277 | TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); 278 | TORCH_CHECK(weights.is_cuda(), "weights must be a CUDA tensor"); 279 | torch::TensorArg grad_t{grad, "grad", 1}, weights_t{weights, "weights", 2}, input_t{input, "input", 3}; 280 | torch::CheckedFrom c = "shiftnd_backward_cuda"; 281 | 282 | torch::checkAllSameGPU(c, {grad_t, input_t, weights_t}); 283 | torch::checkAllSameType(c, {grad_t, input_t, weights_t}); 284 | at::cuda::CUDAGuard device_guard(grad.device()); 285 | 286 | 287 | bool int32bit_cond = canUse32BitIndexMath(grad) && canUse32BitIndexMath(weights) && 288 | canUse32BitIndexMath(input); 289 | 290 | torch::Tensor _weights = weights.contiguous(LEGACY_CONTIGUOUS_MEMORY_FORMAT); 291 | torch::Tensor _borders = int32bit_cond?borders.to(torch::kInt):borders.to(torch::kLong); 292 | 293 | torch::Tensor out_grad = torch::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 294 | torch::Tensor weights_grad = torch::zeros_like(_weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 295 | torch::Tensor _iweights = torch::empty_like(_weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 296 | torch::Tensor dweights = torch::empty_like(_weights, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 297 | 298 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(_weights.scalar_type(), "weights_init_cuda_backward", [&] { 299 | weights_init_backward(_weights, _iweights, dweights); 300 | }); 301 | torch::Tensor iweights = int32bit_cond?_iweights.to(torch::kInt):_iweights.to(torch::kLong); 302 | 303 | 304 | //Yes it's not a mistake, iteration happens under input size 305 | const int64_t N = input.size(0); 306 | const int64_t C = input.size(1); 307 | const int64_t H = input.size(2); 308 | const int64_t W = (nD<2)?1:input.size(3); 309 | const int64_t D = (nD<3)?1:input.size(4); 310 | 311 | const int64_t count = N*C*H*W*D; 312 | const dim3 blocks(CUDA_BLOCKS(count, LOCAL_CUDA_NUM_THREADS), 1, 1); 313 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 314 | 315 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.scalar_type(), "shiftnd_backward_cuda", [&] { 316 | if (int32bit_cond){ 317 | shiftnd_backward_kernel 318 | <<>>( 319 | static_cast(count), 320 | getTensorInfo(grad), 321 | getTensorInfo(iweights), 322 | getTensorInfo(dweights), 323 | getTensorInfo(input), 324 | getTensorInfo(_borders), 325 | getTensorInfo(out_grad), 326 | getTensorInfo(weights_grad)); 327 | } else { 328 | shiftnd_backward_kernel 329 | <<>>( 330 | count, 331 | getTensorInfo(grad), 332 | getTensorInfo(iweights), 333 | getTensorInfo(dweights), 334 | getTensorInfo(input), 335 | getTensorInfo(_borders), 336 | getTensorInfo(out_grad), 337 | getTensorInfo(weights_grad)); 338 | } 339 | }); 340 | 341 | AT_CUDA_CHECK(cudaGetLastError()); 342 | 343 | return std::make_tuple(out_grad, weights_grad); 344 | 345 | } 346 | // TEMPLATE DISPATCHERS 347 | 348 | torch::Tensor shift1d_forward(const torch::Tensor& input, 349 | const torch::Tensor& weights, 350 | const torch::Tensor& borders, 351 | const std::vector& new_size, 352 | int64_t padding_mode, 353 | bool active_flag){ 354 | torch::Tensor ret; 355 | switch (padding_mode){ 356 | case 0: 357 | ret = active_flag?shiftnd_forward<1,BIPadding::Zeros,true>(input, weights, borders, new_size): 358 | shiftnd_forward<1,BIPadding::Zeros,false>(input, weights, borders, new_size); 359 | break; 360 | case 1: 361 | ret = active_flag?shiftnd_forward<1,BIPadding::Border,true>(input, weights, borders, new_size): 362 | shiftnd_forward<1,BIPadding::Border,false>(input, weights, borders, new_size); 363 | break; 364 | case 2: 365 | ret = active_flag?shiftnd_forward<1,BIPadding::Periodic,true>(input, weights, borders, new_size): 366 | shiftnd_forward<1,BIPadding::Periodic,false>(input, weights, borders, new_size); 367 | break; 368 | case 3: 369 | ret = active_flag?shiftnd_forward<1,BIPadding::Reflect,true>(input, weights, borders, new_size): 370 | shiftnd_forward<1,BIPadding::Reflect,false>(input, weights, borders, new_size); 371 | break; 372 | case 4: 373 | ret = active_flag?shiftnd_forward<1,BIPadding::Symmetric,true>(input, weights, borders, new_size): 374 | shiftnd_forward<1,BIPadding::Symmetric,false>(input, weights, borders, new_size); 375 | break; 376 | } 377 | return ret; 378 | } 379 | 380 | torch::Tensor shift2d_forward(const torch::Tensor& input, 381 | const torch::Tensor& weights, 382 | const torch::Tensor& borders, 383 | const std::vector& new_size, 384 | int64_t padding_mode, 385 | bool active_flag){ 386 | torch::Tensor ret; 387 | switch (padding_mode){ 388 | case 0: 389 | ret = active_flag?shiftnd_forward<2,BIPadding::Zeros,true>(input, weights, borders, new_size): 390 | shiftnd_forward<2,BIPadding::Zeros,false>(input, weights, borders, new_size); 391 | break; 392 | case 1: 393 | ret = active_flag?shiftnd_forward<2,BIPadding::Border,true>(input, weights, borders, new_size): 394 | shiftnd_forward<2,BIPadding::Border,false>(input, weights, borders, new_size); 395 | break; 396 | case 2: 397 | ret = active_flag?shiftnd_forward<2,BIPadding::Periodic,true>(input, weights, borders, new_size): 398 | shiftnd_forward<2,BIPadding::Periodic,false>(input, weights, borders, new_size); 399 | break; 400 | case 3: 401 | ret = active_flag?shiftnd_forward<2,BIPadding::Reflect,true>(input, weights, borders, new_size): 402 | shiftnd_forward<2,BIPadding::Reflect,false>(input, weights, borders, new_size); 403 | break; 404 | case 4: 405 | ret = active_flag?shiftnd_forward<2,BIPadding::Symmetric,true>(input, weights, borders, new_size): 406 | shiftnd_forward<2,BIPadding::Symmetric,false>(input, weights, borders, new_size); 407 | break; 408 | } 409 | return ret; 410 | } 411 | 412 | torch::Tensor shift3d_forward(const torch::Tensor& input, 413 | const torch::Tensor& weights, 414 | const torch::Tensor& borders, 415 | const std::vector& new_size, 416 | int64_t padding_mode, 417 | bool active_flag){ 418 | torch::Tensor ret; 419 | switch (padding_mode){ 420 | case 0: 421 | ret = active_flag?shiftnd_forward<3,BIPadding::Zeros,true>(input, weights, borders, new_size): 422 | shiftnd_forward<3,BIPadding::Zeros,false>(input, weights, borders, new_size); 423 | break; 424 | case 1: 425 | ret = active_flag?shiftnd_forward<3,BIPadding::Border,true>(input, weights, borders, new_size): 426 | shiftnd_forward<3,BIPadding::Border,false>(input, weights, borders, new_size); 427 | break; 428 | case 2: 429 | ret = active_flag?shiftnd_forward<3,BIPadding::Periodic,true>(input, weights, borders, new_size): 430 | shiftnd_forward<3,BIPadding::Periodic,false>(input, weights, borders, new_size); 431 | break; 432 | case 3: 433 | ret = active_flag?shiftnd_forward<3,BIPadding::Reflect,true>(input, weights, borders, new_size): 434 | shiftnd_forward<3,BIPadding::Reflect,false>(input, weights, borders, new_size); 435 | break; 436 | case 4: 437 | ret = active_flag?shiftnd_forward<3,BIPadding::Symmetric,true>(input, weights, borders, new_size): 438 | shiftnd_forward<3,BIPadding::Symmetric,false>(input, weights, borders, new_size); 439 | break; 440 | } 441 | return ret; 442 | } 443 | 444 | 445 | std::tuple shift1d_backward(const torch::Tensor& grad, 446 | const torch::Tensor& weights, 447 | const torch::Tensor& input, 448 | const torch::Tensor& borders, 449 | int64_t padding_mode, 450 | bool active_flag){ 451 | std::tuple ret; 452 | switch (padding_mode){ 453 | case 0: 454 | ret = active_flag?shiftnd_backward<1,BIPadding::Zeros,true>(grad, weights, input, borders): 455 | shiftnd_backward<1,BIPadding::Zeros,false>(grad, weights, input, borders); 456 | break; 457 | case 1: 458 | ret = active_flag?shiftnd_backward<1,BIPadding::Border,true>(grad, weights, input, borders): 459 | shiftnd_backward<1,BIPadding::Border,false>(grad, weights, input, borders); 460 | break; 461 | case 2: 462 | ret = active_flag?shiftnd_backward<1,BIPadding::Periodic,true>(grad, weights, input, borders): 463 | shiftnd_backward<1,BIPadding::Periodic,false>(grad, weights, input, borders); 464 | break; 465 | case 3: 466 | ret = active_flag?shiftnd_backward<1,BIPadding::Reflect,true>(grad, weights, input, borders): 467 | shiftnd_backward<1,BIPadding::Reflect,false>(grad, weights, input, borders); 468 | break; 469 | case 4: 470 | ret = active_flag?shiftnd_backward<1,BIPadding::Symmetric,true>(grad, weights, input, borders): 471 | shiftnd_backward<1,BIPadding::Symmetric,false>(grad, weights, input, borders); 472 | break; 473 | } 474 | return ret; 475 | } 476 | 477 | std::tuple shift2d_backward(const torch::Tensor& grad, 478 | const torch::Tensor& weights, 479 | const torch::Tensor& input, 480 | const torch::Tensor& borders, 481 | int64_t padding_mode, 482 | bool active_flag){ 483 | std::tuple ret; 484 | switch (padding_mode){ 485 | case 0: 486 | ret = active_flag?shiftnd_backward<2,BIPadding::Zeros,true>(grad, weights, input, borders): 487 | shiftnd_backward<2,BIPadding::Zeros,false>(grad, weights, input, borders); 488 | break; 489 | case 1: 490 | ret = active_flag?shiftnd_backward<2,BIPadding::Border,true>(grad, weights, input, borders): 491 | shiftnd_backward<2,BIPadding::Border,false>(grad, weights, input, borders); 492 | break; 493 | case 2: 494 | ret = active_flag?shiftnd_backward<2,BIPadding::Periodic,true>(grad, weights, input, borders): 495 | shiftnd_backward<2,BIPadding::Periodic,false>(grad, weights, input, borders); 496 | break; 497 | case 3: 498 | ret = active_flag?shiftnd_backward<2,BIPadding::Reflect,true>(grad, weights, input, borders): 499 | shiftnd_backward<2,BIPadding::Reflect,false>(grad, weights, input, borders); 500 | break; 501 | case 4: 502 | ret = active_flag?shiftnd_backward<2,BIPadding::Symmetric,true>(grad, weights, input, borders): 503 | shiftnd_backward<2,BIPadding::Symmetric,false>(grad, weights, input, borders); 504 | break; 505 | } 506 | return ret; 507 | } 508 | 509 | std::tuple shift3d_backward(const torch::Tensor& grad, 510 | const torch::Tensor& weights, 511 | const torch::Tensor& input, 512 | const torch::Tensor& borders, 513 | int64_t padding_mode, 514 | bool active_flag){ 515 | std::tuple ret; 516 | switch (padding_mode){ 517 | case 0: 518 | ret = active_flag?shiftnd_backward<3,BIPadding::Zeros,true>(grad, weights, input, borders): 519 | shiftnd_backward<3,BIPadding::Zeros,false>(grad, weights, input, borders); 520 | break; 521 | case 1: 522 | ret = active_flag?shiftnd_backward<3,BIPadding::Border,true>(grad, weights, input, borders): 523 | shiftnd_backward<3,BIPadding::Border,false>(grad, weights, input, borders); 524 | break; 525 | case 2: 526 | ret = active_flag?shiftnd_backward<3,BIPadding::Periodic,true>(grad, weights, input, borders): 527 | shiftnd_backward<3,BIPadding::Periodic,false>(grad, weights, input, borders); 528 | break; 529 | case 3: 530 | ret = active_flag?shiftnd_backward<3,BIPadding::Reflect,true>(grad, weights, input, borders): 531 | shiftnd_backward<3,BIPadding::Reflect,false>(grad, weights, input, borders); 532 | break; 533 | case 4: 534 | ret = active_flag?shiftnd_backward<3,BIPadding::Symmetric,true>(grad, weights, input, borders): 535 | shiftnd_backward<3,BIPadding::Symmetric,false>(grad, weights, input, borders); 536 | break; 537 | } 538 | return ret; 539 | } 540 | 541 | 542 | } // end of anonymous namespace 543 | 544 | 545 | TORCH_LIBRARY_IMPL(torchshifts, CUDA, m) { 546 | m.impl( 547 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_forward"), 548 | TORCH_FN(shift1d_forward)); 549 | m.impl( 550 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_backward"), 551 | TORCH_FN(shift1d_backward)); 552 | m.impl( 553 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_forward"), 554 | TORCH_FN(shift2d_forward)); 555 | m.impl( 556 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_backward"), 557 | TORCH_FN(shift2d_backward)); 558 | m.impl( 559 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_forward"), 560 | TORCH_FN(shift3d_forward)); 561 | m.impl( 562 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_backward"), 563 | TORCH_FN(shift3d_backward)); 564 | } 565 | 566 | } // namespace ops 567 | } // namespace shifts 568 | 569 | #endif -------------------------------------------------------------------------------- /torchshifts/csrc/ops/global_scope.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHSHIFTS_GLOBAL_SCOPE 2 | #define TORCHSHIFTS_GLOBAL_SCOPE 3 | 4 | 5 | #define CUDA_THREADS 1024 6 | 7 | 8 | 9 | #include 10 | #ifdef SHIFTS_CPU 11 | 12 | #define ROUND(a) (std::round(a)) 13 | #define FASTROUND(a) (std::nearbyint(a)) 14 | #define FLOOR(a) (std::floor(a)) 15 | #define CEIL(a) (std::ceil(a)) 16 | #define MIN(a,b) (std::min(a,b)) 17 | #define MAX(a,b) (std::max(a,b)) 18 | #define ABS(a) (std::abs(a)) 19 | #define STDLIB std 20 | #define FMIN(a,b) (std::fmin(a, b)) 21 | #define FMAX(a,b) (std::fmax(a, b)) 22 | #define ADD(tensor, idx, numel, val) ( *(tensor+idx)+=val ) 23 | #define API_DEVICE 24 | #define API_HOST 25 | #if (defined __cpp_inline_variables) || __cplusplus >= 201703L 26 | #define API_INLINE inline 27 | #else 28 | #ifdef _MSC_VER 29 | #define API_INLINE __inline 30 | #else 31 | #define API_INLINE __attribute__((weak)) 32 | #endif 33 | #endif 34 | 35 | #endif 36 | 37 | #ifdef SHIFTS_CUDA 38 | 39 | #define ROUND(a) (::round(a)) 40 | #define FASTROUND(a) (::nearbyint(a)) 41 | #define FLOOR(a) (::floor(a)) 42 | #define CEIL(a) (::ceil(a)) 43 | #define MIN(a,b) (::min(a,b)) 44 | #define MAX(a,b) (::max(a,b)) 45 | #define ABS(a) (::abs(a)) 46 | #define STDLIB thrust 47 | #include 48 | #define ADD(tensor, idx, numel, val) ( at::native::fastSpecializedAtomicAdd(tensor, idx, numel, val)) 49 | #define API_INLINE __forceinline__ 50 | #define API_DEVICE __device__ 51 | #define API_HOST __host__ 52 | #define FMIN(a,b) (::fminf(a, b)) 53 | #define FMAX(a,b) (::fmaxf(a, b)) 54 | const int LOCAL_CUDA_NUM_THREADS = CUDA_THREADS; 55 | // taken from PyTorch 56 | inline int CUDA_BLOCKS(const int64_t N, const int64_t NUMTHREADS) 57 | { 58 | return static_cast(N / NUMTHREADS + ((N % NUMTHREADS) == 0 ? 0 : 1)); 59 | } 60 | 61 | #endif 62 | 63 | 64 | 65 | #endif 66 | 67 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/kernels/interpolation.h: -------------------------------------------------------------------------------- 1 | #include "../global_scope.h" 2 | 3 | template 4 | API_DEVICE API_INLINE scalar_t interp1D(scalar_t v1, scalar_t v2, scalar_t x) 5 | { 6 | return v1*(1 - x) + v2*x; 7 | } 8 | 9 | template 10 | API_DEVICE API_INLINE scalar_t interp1D_dx(scalar_t v1, scalar_t v2) 11 | { 12 | return v2 - v1; 13 | } 14 | 15 | template 16 | API_DEVICE API_INLINE scalar_t interp2D(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, scalar_t x, scalar_t y) 17 | { 18 | return interp1D(interp1D(v1, v2, x), interp1D(v3, v4, x), y); 19 | } 20 | 21 | template 22 | API_DEVICE API_INLINE scalar_t interp2D_dx(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, scalar_t y) 23 | { 24 | return interp1D(interp1D_dx(v1, v3), interp1D_dx(v2, v4), y); 25 | } 26 | 27 | template 28 | API_DEVICE API_INLINE scalar_t interp2D_dy(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, scalar_t x) 29 | { 30 | return interp1D_dx(interp1D(v1, v2, x), interp1D(v3, v4, x)); 31 | } 32 | 33 | template 34 | API_DEVICE API_INLINE scalar_t interp3D(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, 35 | scalar_t v5, scalar_t v6, scalar_t v7, scalar_t v8, 36 | scalar_t x, scalar_t y, scalar_t z){ 37 | return interp1D(interp2D(v1, v2, v3, v4, x, y), interp2D(v5, v6, v7, v8, x, y), z); 38 | } 39 | 40 | template 41 | API_DEVICE API_INLINE scalar_t interp3D_dx(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, 42 | scalar_t v5, scalar_t v6, scalar_t v7, scalar_t v8, 43 | scalar_t y, scalar_t z) 44 | { 45 | return interp1D(interp2D_dx(v1, v2, v3, v4, y), interp2D_dx(v5, v6, v7, v8, y), z); 46 | } 47 | 48 | template 49 | API_DEVICE API_INLINE scalar_t interp3D_dy(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, 50 | scalar_t v5, scalar_t v6, scalar_t v7, scalar_t v8, 51 | scalar_t x, scalar_t z) 52 | { 53 | return interp1D(interp2D_dy(v1, v2, v3, v4, x), interp2D_dy(v5, v6, v7, v8, x), z); 54 | } 55 | 56 | template 57 | API_DEVICE API_INLINE scalar_t interp3D_dz(scalar_t v1, scalar_t v2, scalar_t v3, scalar_t v4, 58 | scalar_t v5, scalar_t v6, scalar_t v7, scalar_t v8, 59 | scalar_t x, scalar_t y) 60 | { 61 | return interp1D_dx(interp2D(v1, v2, v3, v4, x, y), interp2D(v5, v6, v7, v8, x, y)); 62 | } -------------------------------------------------------------------------------- /torchshifts/csrc/ops/kernels/shifts_kernels.h: -------------------------------------------------------------------------------- 1 | #include "../global_scope.h" 2 | #include "interpolation.h" 3 | 4 | 5 | enum class BIPadding {Zeros, Border, Periodic, Reflect, Symmetric}; 6 | 7 | template 8 | API_DEVICE API_INLINE T mod(const T a, const T b){return (b + (a % b)) % b;} 9 | 10 | template 11 | API_DEVICE API_INLINE idx_t infer_index(const idx_t index, const idx_t len){ 12 | bool odd_seq; 13 | switch (padding_mode){ 14 | case BIPadding::Zeros: 15 | return (index > len - 1)?-1:index; 16 | case BIPadding::Border: 17 | return MIN(len-1,MAX(index,static_cast(0))); 18 | case BIPadding::Periodic: 19 | return mod(index, len); 20 | case BIPadding::Reflect: 21 | odd_seq = static_cast((static_cast(index<0) + (ABS(index)- static_cast(index<0))/ (len-1)) & 1); 22 | return odd_seq?(len - 1 - mod(index, len - 1)):mod(index, len - 1); 23 | case BIPadding::Symmetric: 24 | odd_seq = static_cast((static_cast(index<0) + (ABS(index)- static_cast(index<0))/ len) & 1); 25 | return odd_seq?(len - 1 - mod(index, len)):mod(index, len); 26 | default: 27 | return (index > len - 1)?-1:index; 28 | } 29 | } 30 | 31 | 32 | template 35 | API_DEVICE API_INLINE scalar_t get_shifted_value(const idx_t i_shifted, const idx_t sizeH, const idx_t strideH, 36 | const idx_t j_shifted, const idx_t sizeW, const idx_t strideW, 37 | const idx_t k_shifted, const idx_t sizeD, const idx_t strideD, 38 | const idx_t c, const idx_t strideC, const bool out_passcond, 39 | const scalar_t* const array, const scalar_t zero_point){ 40 | const idx_t tidx_i = (sizeH==1)?0:infer_index(i_shifted, sizeH); 41 | const idx_t pass_cond_i = static_cast(tidx_i>=0); 42 | const idx_t isH = tidx_i*strideH*pass_cond_i; 43 | 44 | const idx_t tidx_j = (kSpatialDim > 1)?((sizeW==1)?0:infer_index(j_shifted, sizeW)):0; 45 | const idx_t pass_cond_j = (kSpatialDim > 1)?(static_cast(tidx_j>=0)*pass_cond_i):pass_cond_i; 46 | const idx_t isW = (kSpatialDim > 1)?tidx_j*strideW*pass_cond_j:0; 47 | 48 | const idx_t tidx_k = (kSpatialDim > 2)?((sizeD==1)?0:infer_index(k_shifted, sizeD)):0; 49 | const idx_t pass_cond_k = (kSpatialDim > 2)?(static_cast(tidx_k>=0)*pass_cond_j):pass_cond_j; 50 | const idx_t isD = (kSpatialDim > 2)?tidx_k*strideD*pass_cond_k:0; 51 | 52 | const bool pass_cond = static_cast(pass_cond_k)&&out_passcond; 53 | return pass_cond?array[isH+isW+isD+c*strideC]:zero_point; 54 | } 55 | 56 | 57 | 58 | template 61 | API_DEVICE API_INLINE void get_shifted_values(const idx_t i_shifted, const idx_t sizeH, const idx_t strideH, 62 | const idx_t j_shifted, const idx_t sizeW, const idx_t strideW, 63 | const idx_t k_shifted, const idx_t sizeD, const idx_t strideD, 64 | const idx_t c, const idx_t strideC, const bool out_passcond, 65 | const scalar_t* const array, const scalar_t zero_point, 66 | scalar_t* const output_values){ 67 | output_values[0] = get_shifted_value( 68 | i_shifted, sizeH, strideH, j_shifted, sizeW, strideW, 69 | k_shifted, sizeD, strideD, c, strideC, 70 | out_passcond, array, zero_point); 71 | output_values[1] = get_shifted_value( 72 | i_shifted+1, sizeH, strideH, j_shifted, sizeW, strideW, 73 | k_shifted, sizeD, strideD, c, strideC, 74 | out_passcond, array, zero_point); 75 | if (kSpatialDim > 1){ 76 | output_values[2] = get_shifted_value( 77 | i_shifted, sizeH, strideH, j_shifted+1, sizeW, strideW, 78 | k_shifted, sizeD, strideD, c, strideC, 79 | out_passcond, array, zero_point); 80 | output_values[3] = get_shifted_value( 81 | i_shifted+1, sizeH, strideH, j_shifted+1, sizeW, strideW, 82 | k_shifted, sizeD, strideD, c, strideC, 83 | out_passcond, array, zero_point); 84 | } 85 | if (kSpatialDim > 2){ 86 | output_values[4] = get_shifted_value( 87 | i_shifted, sizeH, strideH, j_shifted, sizeW, strideW, 88 | k_shifted+1, sizeD, strideD, c, strideC, 89 | out_passcond, array, zero_point); 90 | output_values[5] = get_shifted_value( 91 | i_shifted+1, sizeH, strideH, j_shifted, sizeW, strideW, 92 | k_shifted+1, sizeD, strideD, c, strideC, 93 | out_passcond, array, zero_point); 94 | output_values[6] = get_shifted_value( 95 | i_shifted, sizeH, strideH, j_shifted+1, sizeW, strideW, 96 | k_shifted+1, sizeD, strideD, c, strideC, 97 | out_passcond, array, zero_point); 98 | output_values[7] = get_shifted_value( 99 | i_shifted+1, sizeH, strideH, j_shifted+1, sizeW, strideW, 100 | k_shifted+1, sizeD, strideD, c, strideC, 101 | out_passcond, array, zero_point); 102 | } 103 | } 104 | 105 | template 106 | API_DEVICE API_INLINE scalar_t rev_shift(const scalar_t diff_shift){ 107 | return (reverse)?(static_cast(1)-diff_shift):diff_shift; 108 | } 109 | 110 | template 111 | API_DEVICE API_INLINE scalar_t compute_interpolated(const scalar_t* const v, const scalar_t diff_shiftH, 112 | const scalar_t diff_shiftW, const scalar_t diff_shiftD, 113 | const bool pass_cond, const scalar_t zp){ 114 | switch (kSpatialDim){ 115 | case 3: 116 | return pass_cond?interp3D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], 117 | rev_shift(diff_shiftH), 118 | rev_shift(diff_shiftW), 119 | rev_shift(diff_shiftD)): 120 | zp; 121 | case 2: 122 | return pass_cond?interp2D(v[0], v[1], v[2], v[3], 123 | rev_shift(diff_shiftH), 124 | rev_shift(diff_shiftW)): 125 | zp; 126 | default: 127 | return pass_cond?interp1D(v[0], v[1], rev_shift(diff_shiftH)): 128 | zp; 129 | } 130 | } 131 | 132 | template 133 | API_DEVICE API_INLINE void compute_weight_gradients(const scalar_t* const v, const scalar_t diff_shiftH, const scalar_t diff_shiftW, const scalar_t diff_shiftD, 134 | scalar_t* const output_grad, const bool pass_cond, const scalar_t zp){ 135 | switch (kSpatialDim){ 136 | case 3: 137 | output_grad[0]=pass_cond?interp3D_dx(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], 138 | diff_shiftW, diff_shiftD):zp; 139 | output_grad[1]=pass_cond?interp3D_dy(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], 140 | diff_shiftH, diff_shiftD):zp; 141 | output_grad[2]=pass_cond?interp3D_dz(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7], 142 | diff_shiftH, diff_shiftW):zp; 143 | break; 144 | case 2: 145 | output_grad[0]=pass_cond?interp2D_dx(v[0], v[1], v[2], v[3], 146 | diff_shiftW):zp; 147 | output_grad[1]=pass_cond?interp2D_dy(v[0], v[1], v[2], v[3], 148 | diff_shiftH):zp; 149 | break; 150 | case 1: 151 | output_grad[0]=pass_cond?interp1D_dx(v[0], v[1]):zp; 152 | break; 153 | } 154 | } 155 | 156 | template 160 | API_DEVICE API_INLINE void shift_forward_kernel_nchwd(const scalar_t* const input, scalar_t* const output, 161 | const idx_t* const weights, const scalar_t* const dweights, 162 | const idx_t n, const idx_t c, const idx_t i, const idx_t j, const idx_t k, 163 | const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 164 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, const idx_t input_sW, const idx_t input_sD, 165 | const idx_t output_sN, const idx_t output_sC, const idx_t output_sH, const idx_t output_sW, const idx_t output_sD, 166 | const idx_t weights_sC, const idx_t weights_sS, const idx_t dweights_sC, const idx_t dweights_sS, 167 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 168 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border){ 169 | const scalar_t* const input_NC = input + n*input_sN + c*input_sC; 170 | const scalar_t zp = static_cast(0); 171 | 172 | const idx_t oi = i - i_left_border; 173 | const idx_t oj = (kSpatialDim > 1) ? (j - j_left_border) : j; 174 | const idx_t ok = (kSpatialDim > 2) ? (k - k_left_border) : k; 175 | 176 | const idx_t si = i - *(weights+c*weights_sC); 177 | const idx_t sj = (kSpatialDim > 1) ? (j - *(weights+c*weights_sC+weights_sS)) : j; 178 | const idx_t sk = (kSpatialDim > 2) ? (k - *(weights+c*weights_sC+2*weights_sS)) : k; 179 | 180 | 181 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 182 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 183 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 184 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 185 | 186 | if (pass_cond){ 187 | if (active) 188 | { 189 | const scalar_t di = *(dweights + c*dweights_sC); 190 | const scalar_t dj = (kSpatialDim > 1) ? *(dweights + c*dweights_sC + dweights_sS) : zp; 191 | const scalar_t dk = (kSpatialDim > 2) ? *(dweights + c*dweights_sC + 2*dweights_sS): zp; 192 | scalar_t vals_array[8] = {zp, zp, zp, zp, zp, zp, zp, zp}; 193 | get_shifted_values( 194 | si, sizeH, input_sH, 195 | sj, sizeW, input_sW, 196 | sk, sizeD, input_sD, 197 | 0, 0, true, 198 | input_NC, zp, vals_array); 199 | *(output + n*output_sN + 200 | c*output_sC + 201 | oi*output_sH + 202 | oj*output_sW + 203 | ok*output_sD) = compute_interpolated( 204 | vals_array, di, dj, dk, 205 | true, zp); 206 | } 207 | else { 208 | *(output + n*output_sN + 209 | c*output_sC + 210 | oi*output_sH + 211 | oj*output_sW + 212 | ok*output_sD) = get_shifted_value( 213 | si, sizeH, input_sH, 214 | sj, sizeW, input_sW, 215 | sk, sizeD, input_sD, 216 | 0, 0, true, 217 | input_NC, zp); 218 | } 219 | } 220 | } 221 | 222 | template 226 | API_DEVICE API_INLINE void shift_backward_kernel_nchwd(const scalar_t* const input_grad, const scalar_t* const input, scalar_t* const output_grad, 227 | const idx_t* const weights, const scalar_t* const dweights, scalar_t* const weights_grad, 228 | const idx_t n, const idx_t c, const idx_t i, const idx_t j, const idx_t k, 229 | const idx_t sizeC, const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 230 | const idx_t input_grad_sN, const idx_t input_grad_sC, const idx_t input_grad_sH, 231 | const idx_t input_grad_sW, const idx_t input_grad_sD, 232 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, 233 | const idx_t input_sW, const idx_t input_sD, 234 | const idx_t output_grad_sN, const idx_t output_grad_sC, const idx_t output_grad_sH, 235 | const idx_t output_grad_sW, const idx_t output_grad_sD, 236 | const idx_t weights_sC, const idx_t weights_sS, 237 | const idx_t dweights_sC, const idx_t dweights_sS, 238 | const idx_t weights_grad_sC, const idx_t weights_grad_sS, 239 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 240 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border){ 241 | // i,j,k - from input 242 | const scalar_t* const input_grad_NC = input_grad + n*input_grad_sN + c*input_grad_sC; 243 | const scalar_t* const input_NC = input + n*input_sN + c*input_sC; 244 | const idx_t weights_numel = kSpatialDim * sizeC; 245 | const scalar_t zp = static_cast(0); 246 | 247 | const idx_t shifti = *(weights+c*weights_sC); 248 | const idx_t shiftj = (kSpatialDim > 1)?*(weights+c*weights_sC + weights_sS):0; 249 | const idx_t shiftk = (kSpatialDim > 2)?*(weights + c*weights_sC + 2*weights_sS):0; 250 | 251 | const scalar_t di = *(dweights + c*dweights_sC); 252 | const scalar_t dj = (kSpatialDim > 1)?*(dweights + c*dweights_sC + dweights_sS):zp; 253 | const scalar_t dk = (kSpatialDim > 2)?*(dweights + c*dweights_sC + 2*dweights_sS):zp; 254 | 255 | const idx_t si = i - shifti; 256 | const idx_t sj = (kSpatialDim > 1) ? (j - shiftj) : j; 257 | const idx_t sk = (kSpatialDim > 2) ? (k - shiftk) : k; 258 | 259 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 260 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 261 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 262 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 263 | 264 | 265 | const idx_t oi = i - i_left_border; 266 | const idx_t oj = (kSpatialDim > 1) ? (j - j_left_border) : j; 267 | const idx_t ok = (kSpatialDim > 2) ? (k - k_left_border) : k; 268 | 269 | scalar_t vals_array[8] = {zp, zp, zp, zp, zp, zp, zp, zp}; 270 | scalar_t new_weights_grad[3] = {zp, zp, zp}; 271 | const scalar_t input_grad_NCHWD_val = pass_cond?input_grad_NC[oi*input_grad_sH + oj*input_grad_sW + ok*input_grad_sD]:zp; 272 | 273 | // weight gradients 274 | get_shifted_values( 275 | si, sizeH, input_sH, 276 | sj, sizeW, input_sW, 277 | sk, sizeD, input_sD, 278 | 0, 0, pass_cond, 279 | input_NC, zp, vals_array); 280 | compute_weight_gradients(vals_array, di, dj, dk, new_weights_grad, pass_cond, zp); 281 | ADD(weights_grad, c*weights_grad_sC, weights_numel, input_grad_NCHWD_val * new_weights_grad[0]); 282 | if (kSpatialDim > 1){ ADD(weights_grad, c*weights_grad_sC + weights_grad_sS, weights_numel, input_grad_NCHWD_val * new_weights_grad[1]); } 283 | if (kSpatialDim > 2){ ADD(weights_grad, c*weights_grad_sC + 2*weights_grad_sS, weights_numel, input_grad_NCHWD_val * new_weights_grad[2]); } 284 | 285 | // input gradient 286 | 287 | const idx_t rsi = oi + shifti; 288 | const idx_t rsj = (kSpatialDim > 1)?(oj + shiftj):oj; 289 | const idx_t rsk = (kSpatialDim > 2)?(ok + shiftk):ok; 290 | 291 | const idx_t osi = oi - shifti; 292 | const idx_t osj = (kSpatialDim > 1) ? (oj - shiftj) : oj; 293 | const idx_t osk = (kSpatialDim > 2) ? (ok - shiftk) : ok; 294 | 295 | const idx_t osizeH = i_right_border - i_left_border; 296 | const idx_t osizeW = j_right_border - j_left_border; 297 | const idx_t osizeD = k_right_border - k_left_border; 298 | 299 | if (active) 300 | { 301 | get_shifted_values( 302 | osi, osizeH, input_grad_sH, 303 | osj, osizeW, input_grad_sW, 304 | osk, osizeD, input_grad_sD, 305 | 0, 0, pass_cond, 306 | input_grad_NC, zp, vals_array); 307 | *(output_grad + n*output_grad_sN + 308 | c*output_grad_sC + 309 | i*output_grad_sH + 310 | j*output_grad_sW + 311 | k*output_grad_sD) = compute_interpolated( 312 | vals_array, di, dj, dk, pass_cond, zp); 313 | } 314 | else { 315 | *(output_grad + n*output_grad_sN + 316 | c*output_grad_sC + 317 | i*output_grad_sH + 318 | j*output_grad_sW + 319 | k*output_grad_sD) = get_shifted_value( 320 | rsi, osizeH, input_grad_sH, 321 | rsj, osizeW, input_grad_sW, 322 | rsk, osizeD, input_grad_sD, 323 | 0, 0, pass_cond, 324 | input_grad_NC, zp); 325 | } 326 | 327 | } 328 | 329 | 330 | template 334 | API_DEVICE API_INLINE void shift_forward_kernel_nhwdc(const scalar_t* const input, scalar_t* const output, 335 | const idx_t* const weights, const scalar_t* const dweights, 336 | const idx_t n, const idx_t i, const idx_t j, const idx_t k, 337 | const idx_t sizeC, const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 338 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, const idx_t input_sW, const idx_t input_sD, 339 | const idx_t output_sN, const idx_t output_sC, const idx_t output_sH, const idx_t output_sW, const idx_t output_sD, 340 | const idx_t weights_sC, const idx_t weights_sS, const idx_t dweights_sC, const idx_t dweights_sS, 341 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 342 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border){ 343 | const scalar_t* input_N = input + n*input_sN; 344 | const scalar_t zp = static_cast(0); 345 | 346 | const idx_t oi = i - i_left_border; 347 | const idx_t oj = (kSpatialDim > 1) ? j - j_left_border : j; 348 | const idx_t ok = (kSpatialDim > 2) ? k - k_left_border : k; 349 | 350 | const idx_t* w_S = (kSpatialDim > 1) ? (weights+weights_sS) : nullptr; 351 | const idx_t* w_2S = (kSpatialDim > 2) ? (weights+2*weights_sS) : nullptr; 352 | const scalar_t* dw_S = (kSpatialDim > 1) ? (dweights+dweights_sS) : nullptr; 353 | const scalar_t* dw_2S = (kSpatialDim > 2) ? (dweights+2*dweights_sS) : nullptr; 354 | 355 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 356 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 357 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 358 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 359 | 360 | if (pass_cond){ 361 | scalar_t val; 362 | idx_t si = i; 363 | idx_t sj = j; 364 | idx_t sk = k; 365 | scalar_t di = zp; 366 | scalar_t dj = zp; 367 | scalar_t dk = zp; 368 | scalar_t *output_NHWD = output + n*output_sN + oi*output_sH + oj*output_sW + ok*output_sD; 369 | for (idx_t c = 0; c < sizeC; c++) 370 | { 371 | si = i - *(weights+c*weights_sC); 372 | if (kSpatialDim > 1) { sj = j - *(w_S+c*weights_sC); } 373 | if (kSpatialDim > 2) { sk = k - *(w_2S+c*weights_sC); } 374 | if (active) 375 | { 376 | di = *(dweights + c*dweights_sC); 377 | if (kSpatialDim > 1) { dj = *(dw_S+c*dweights_sC); } 378 | if (kSpatialDim > 2) { dk = *(dw_2S+c*dweights_sC); } 379 | // define array here to avoid unnessary warnings, Hope the compiler can optimize it itself 380 | scalar_t vals_array[8] = {zp, zp, zp, zp, zp, zp, zp, zp}; 381 | get_shifted_values( 382 | si, sizeH, input_sH, 383 | sj, sizeW, input_sW, 384 | sk, sizeD, input_sD, 385 | c, input_sC, true, 386 | input_N, zp, vals_array); 387 | val = compute_interpolated(vals_array, di, dj, dk, true, zp); 388 | } 389 | else { 390 | val = get_shifted_value( 391 | si, sizeH, input_sH, 392 | sj, sizeW, input_sW, 393 | sk, sizeD, input_sD, 394 | c, input_sC, true, 395 | input_N, zp); 396 | } 397 | output_NHWD[c*output_sC] = val; 398 | } 399 | } 400 | } 401 | 402 | template 406 | API_DEVICE API_INLINE void shift_backward_kernel_nhwdc(const scalar_t* const input_grad, const scalar_t* const input, 407 | scalar_t* const output_grad, 408 | const idx_t* const weights, const scalar_t* const dweights, 409 | scalar_t* const weights_grad, 410 | const idx_t n, const idx_t i, const idx_t j, const idx_t k, 411 | const idx_t sizeC, const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 412 | const idx_t input_grad_sN, const idx_t input_grad_sC, const idx_t input_grad_sH, 413 | const idx_t input_grad_sW, const idx_t input_grad_sD, 414 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, 415 | const idx_t input_sW, const idx_t input_sD, 416 | const idx_t output_grad_sN, const idx_t output_grad_sC, const idx_t output_grad_sH, 417 | const idx_t output_grad_sW, const idx_t output_grad_sD, 418 | const idx_t weights_sC, const idx_t weights_sS, const idx_t dweights_sC, 419 | const idx_t dweights_sS, const idx_t weights_grad_sC, const idx_t weights_grad_sS, 420 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 421 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border){ 422 | const scalar_t* const input_grad_N = input_grad + n*input_grad_sN; 423 | const scalar_t* const input_N = input + n*input_sN; 424 | const idx_t weights_numel = kSpatialDim * sizeC; 425 | const scalar_t zp = static_cast(0); 426 | scalar_t vals_array[8] = {zp, zp, zp, zp, zp, zp, zp, zp}; 427 | scalar_t new_weights_grad[3] = {zp, zp, zp}; 428 | scalar_t input_grad_NHWDC_val; 429 | 430 | idx_t shifti = 0; 431 | idx_t shiftj = 0; 432 | idx_t shiftk = 0; 433 | 434 | scalar_t di = zp; 435 | scalar_t dj = zp; 436 | scalar_t dk = zp; 437 | 438 | idx_t si = i; 439 | idx_t sj = j; 440 | idx_t sk = k; 441 | 442 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 443 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 444 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 445 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 446 | 447 | const idx_t oi = i - i_left_border; 448 | const idx_t oj = (kSpatialDim > 1) ? (j - j_left_border) : j; 449 | const idx_t ok = (kSpatialDim > 2) ? (k - k_left_border) : k; 450 | const scalar_t* input_grad_NHWD = input_grad_N + (pass_cond?(oi*input_grad_sH + oj*input_grad_sW + ok*input_grad_sD):0); 451 | 452 | idx_t rsi = oi; 453 | idx_t rsj = oj; 454 | idx_t rsk = ok; 455 | idx_t osi = oi; 456 | idx_t osj = oj; 457 | idx_t osk = ok; 458 | 459 | const idx_t osizeH = i_right_border - i_left_border; 460 | const idx_t osizeW = j_right_border - j_left_border; 461 | const idx_t osizeD = k_right_border - k_left_border; 462 | 463 | const idx_t* w_S = (kSpatialDim > 1) ? (weights+weights_sS) : nullptr; 464 | const idx_t* w_2S = (kSpatialDim > 2) ? (weights+2*weights_sS) : nullptr; 465 | const scalar_t* dw_S = (kSpatialDim > 1) ? (dweights+dweights_sS) : nullptr; 466 | const scalar_t* dw_2S = (kSpatialDim > 2) ? (dweights+2*dweights_sS) : nullptr; 467 | 468 | scalar_t* output_grad_NHWD= output_grad + n*output_grad_sN + i*output_grad_sH + j*output_grad_sW + k*output_grad_sD; 469 | 470 | for (idx_t c = 0; c < sizeC; c++) 471 | { 472 | shifti = *(weights+c*weights_sC); 473 | di = *(dweights+c*dweights_sC); 474 | si = i - shifti; 475 | rsi = oi + shifti; 476 | osi = oi - shifti; 477 | if (kSpatialDim > 1) { 478 | shiftj = *(w_S+c*weights_sC); 479 | dj = *(dw_S+c*dweights_sC); 480 | sj = j - shiftj; 481 | rsj = oj + shiftj; 482 | osj = oj - shiftj; 483 | } 484 | if (kSpatialDim > 2) { 485 | shiftk = *(w_2S+c*weights_sC); 486 | dk = *(dw_2S+c*dweights_sC); 487 | sk = k - shiftk; 488 | rsk = ok + shiftk; 489 | osk = ok - shiftk; 490 | } 491 | // weight gradients 492 | get_shifted_values( 493 | si, sizeH, input_sH, 494 | sj, sizeW, input_sW, 495 | sk, sizeD, input_sD, 496 | c, input_sC, pass_cond, 497 | input_N, zp, vals_array); 498 | compute_weight_gradients(vals_array, di, dj, dk, new_weights_grad, pass_cond, zp); 499 | input_grad_NHWDC_val = input_grad_NHWD[c*input_grad_sC]; 500 | ADD(weights_grad, c*weights_grad_sC, weights_numel, input_grad_NHWDC_val * new_weights_grad[0]); 501 | if (kSpatialDim > 1){ ADD(weights_grad, c*weights_grad_sC + weights_grad_sS, weights_numel, input_grad_NHWDC_val * new_weights_grad[1]); } 502 | if (kSpatialDim > 2){ ADD(weights_grad, c*weights_grad_sC + 2*weights_grad_sS, weights_numel, input_grad_NHWDC_val * new_weights_grad[2]); } 503 | 504 | 505 | 506 | // input gradient 507 | if (active) 508 | { 509 | get_shifted_values( 510 | osi, osizeH, input_grad_sH, 511 | osj, osizeW, input_grad_sW, 512 | osk, osizeD, input_grad_sD, 513 | c, input_grad_sC, pass_cond, 514 | input_grad_N, zp, vals_array); 515 | *(output_grad_NHWD+c*output_grad_sC) = compute_interpolated( 516 | vals_array, di, dj, dk, pass_cond, zp); 517 | } 518 | else { 519 | *(output_grad_NHWD+c*output_grad_sC) = get_shifted_value( 520 | rsi, osizeH, input_grad_sH, 521 | rsj, osizeW, input_grad_sW, 522 | rsk, osizeD, input_grad_sD, 523 | c, input_grad_sC, pass_cond, 524 | input_grad_N, zp); 525 | } 526 | } 527 | } 528 | 529 | 530 | 531 | /////////QUANTIZED 532 | template 535 | API_DEVICE API_INLINE void shift_forward_kernel_nchwd_q(const scalar_t* const input, scalar_t* const output, 536 | const idx_t* const weights, 537 | const idx_t n, const idx_t c, const idx_t i, const idx_t j, const idx_t k, 538 | const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 539 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, 540 | const idx_t input_sW, const idx_t input_sD, 541 | const idx_t output_sN, const idx_t output_sC, const idx_t output_sH, 542 | const idx_t output_sW, const idx_t output_sD, 543 | const idx_t weights_sC, const idx_t weights_sS, 544 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 545 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border, 546 | const scalar_t zero_point, const idx_t weights_zero_point){ 547 | const scalar_t* const input_NC = input + n*input_sN + c*input_sC; 548 | 549 | const idx_t oi = i - i_left_border; 550 | const idx_t oj = (kSpatialDim > 1) ? (j - j_left_border) : j; 551 | const idx_t ok = (kSpatialDim > 2) ? (k - k_left_border) : k; 552 | 553 | const idx_t si = i - *(weights+c*weights_sC) + weights_zero_point; 554 | const idx_t sj = (kSpatialDim > 1) ? (j - *(weights+c*weights_sC+weights_sS) + weights_zero_point) : j; 555 | const idx_t sk = (kSpatialDim > 2) ? (k - *(weights+c*weights_sC+2*weights_sS) + weights_zero_point) : k; 556 | 557 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 558 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 559 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 560 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 561 | 562 | if (pass_cond) { 563 | scalar_t* output_NCHWD = output + n*output_sN + c*output_sC + oi*output_sH + oj*output_sW + ok*output_sD; 564 | *output_NCHWD = get_shifted_value( 565 | si, sizeH, input_sH, 566 | sj, sizeW, input_sW, 567 | sk, sizeD, input_sD, 568 | 0, 0, true, 569 | input_NC, zero_point); 570 | } 571 | } 572 | 573 | 574 | template 577 | API_DEVICE API_INLINE void shift_forward_kernel_nhwdc_q(const scalar_t* const input, scalar_t* const output, 578 | const idx_t* const weights, 579 | const idx_t n, const idx_t i, const idx_t j, const idx_t k, 580 | const idx_t sizeC, const idx_t sizeH, const idx_t sizeW, const idx_t sizeD, 581 | const idx_t input_sN, const idx_t input_sC, const idx_t input_sH, 582 | const idx_t input_sW, const idx_t input_sD, 583 | const idx_t output_sN, const idx_t output_sC, const idx_t output_sH, 584 | const idx_t output_sW, const idx_t output_sD, 585 | const idx_t weights_sC, const idx_t weights_sS, 586 | const idx_t i_left_border, const idx_t j_left_border, const idx_t k_left_border, 587 | const idx_t i_right_border, const idx_t j_right_border, const idx_t k_right_border, 588 | const scalar_t zero_point, const idx_t weights_zero_point){ 589 | const scalar_t* input_N = input + n*input_sN; 590 | 591 | const idx_t oi = i - i_left_border; 592 | const idx_t oj = (kSpatialDim > 1) ? j - j_left_border : j; 593 | const idx_t ok = (kSpatialDim > 2) ? k - k_left_border : k; 594 | 595 | const idx_t* w_S = (kSpatialDim > 1) ? (weights+weights_sS) : nullptr; 596 | const idx_t* w_2S = (kSpatialDim > 2) ? (weights+2*weights_sS) : nullptr; 597 | 598 | 599 | const bool pass_cond_i = (i >= i_left_border)&&(i < i_right_border); 600 | const bool pass_cond_j = (j >= j_left_border)&&(j < j_right_border); 601 | const bool pass_cond_k = (k >= k_left_border)&&(k < k_right_border); 602 | const bool pass_cond = pass_cond_i&&pass_cond_j&&pass_cond_k; 603 | 604 | if (pass_cond) { 605 | scalar_t val; 606 | idx_t si = i; 607 | idx_t sj = j; 608 | idx_t sk = k; 609 | scalar_t *output_NHWD = output + n*output_sN + oi*output_sH + oj*output_sW + ok*output_sD; 610 | for (idx_t c = 0; c < sizeC; c++) 611 | { 612 | si = i - *(weights+c*weights_sC) + weights_zero_point; 613 | if (kSpatialDim > 1){ sj = j - *(w_S+c*weights_sC) + weights_zero_point; } 614 | if (kSpatialDim > 2){ sk = k - *(w_2S+c*weights_sC) + weights_zero_point; } 615 | val = get_shifted_value( 616 | si, sizeH, input_sH, 617 | sj, sizeW, input_sW, 618 | sk, sizeD, input_sD, 619 | c, input_sC, true, 620 | input_N, zero_point); 621 | output_NHWD[c*output_sC] = val; 622 | } 623 | } 624 | } -------------------------------------------------------------------------------- /torchshifts/csrc/ops/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "shifts.h" 4 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/quantized/shifts_quantized.cpp: -------------------------------------------------------------------------------- 1 | #ifndef SHIFTS_CPU 2 | #define SHIFTS_CPU 3 | 4 | 5 | #include 6 | #include "../global_scope.h" 7 | #include "../kernels/shifts_kernels.h" 8 | 9 | 10 | namespace shifts { 11 | namespace ops { 12 | 13 | namespace { 14 | 15 | 16 | template 18 | API_INLINE void qshiftnd_kernel(const torch::Tensor& input, 19 | const torch::Tensor& weights, 20 | const torch::Tensor& borders, 21 | torch::Tensor& output, 22 | int64_t weights_zero_point){ 23 | const int64_t sizeN = input.size(0); 24 | const int64_t sizeC = input.size(1); 25 | const int64_t sizeH = input.size(2); 26 | const int64_t sizeW = kSpatialDim < 2 ? 1 : input.size(3); 27 | const int64_t sizeD = kSpatialDim < 3 ? 1 : input.size(4); 28 | const int64_t input_sN = input.stride(0); 29 | const int64_t input_sC = input.stride(1); 30 | const int64_t input_sH = input.stride(2); 31 | const int64_t input_sW = kSpatialDim < 2 ? 0 : input.stride(3); 32 | const int64_t input_sD = kSpatialDim < 3 ? 0 : input.stride(4); 33 | const int64_t output_sN = output.stride(0); 34 | const int64_t output_sC = output.stride(1); 35 | const int64_t output_sH = output.stride(2); 36 | const int64_t output_sW = kSpatialDim < 2 ? 0 : output.stride(3); 37 | const int64_t output_sD = kSpatialDim < 3 ? 0 : output.stride(4); 38 | scalar_t* input_ptr = input.data_ptr(); 39 | const scalar_t zero_point = static_cast(input.q_zero_point()); 40 | scalar_t* output_ptr = output.data_ptr(); 41 | int64_t* weights_ptr = weights.data_ptr(); 42 | const int64_t weights_sC = weights.stride(0); 43 | const int64_t weights_sS = weights.stride(1); 44 | 45 | 46 | 47 | int64_t* borders_data = borders.data_ptr(); 48 | const int64_t i_left_border = borders_data[0]; 49 | const int64_t i_right_border = borders_data[1]; 50 | const int64_t j_left_border = kSpatialDim < 2 ? 0 : borders_data[2]; 51 | const int64_t j_right_border = kSpatialDim < 2 ? 1 : borders_data[3]; 52 | const int64_t k_left_border = kSpatialDim < 3 ? 0 : borders_data[4]; 53 | const int64_t k_right_border = kSpatialDim < 3 ? 1 : borders_data[5]; 54 | 55 | if (input.is_contiguous(c10::MemoryFormat::ChannelsLast) || input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) 56 | {// Path for NDHWC 57 | at::parallel_for(0, sizeN, 0, [&](int64_t start, int64_t end){ 58 | for (int64_t n = start; n < end; ++n) { 59 | for (int64_t i = 0; i < sizeH; ++i){ 60 | for (int64_t j = 0; j < sizeW; ++j){ 61 | for (int64_t k = 0; k < sizeD; ++k){ 62 | shift_forward_kernel_nhwdc_q( 63 | input_ptr, output_ptr, weights_ptr, 64 | n, i, j, k, sizeC, sizeH, sizeW, sizeD, 65 | input_sN, input_sC, input_sH, input_sW, input_sD, 66 | output_sN, output_sC, output_sH, output_sW, output_sD, 67 | weights_sC, weights_sS, 68 | i_left_border, j_left_border, k_left_border, 69 | i_right_border, j_right_border, k_right_border, 70 | zero_point, weights_zero_point); 71 | } 72 | } 73 | } 74 | } 75 | }); 76 | } else 77 | { 78 | at::parallel_for(0, sizeN*sizeC, 0, [&](int64_t start, int64_t end){ 79 | for (int64_t index = start; index < end; ++index) { 80 | const int64_t c = index % sizeC; 81 | const int64_t n = index / sizeC; 82 | for (int64_t i = 0; i < sizeH; ++i){ 83 | for (int64_t j = 0; j < sizeW; ++j){ 84 | for (int64_t k = 0; k < sizeD; ++k){ 85 | shift_forward_kernel_nchwd_q( 86 | input_ptr, output_ptr, weights_ptr, 87 | n, c, i, j, k, sizeH, sizeW, sizeD, 88 | input_sN, input_sC, input_sH, input_sW, input_sD, 89 | output_sN, output_sC, output_sH, output_sW, output_sD, 90 | weights_sC, weights_sS, 91 | i_left_border, j_left_border, k_left_border, 92 | i_right_border, j_right_border, k_right_border, 93 | zero_point, weights_zero_point); 94 | } 95 | } 96 | } 97 | } 98 | }); 99 | } 100 | } 101 | 102 | 103 | 104 | 105 | 106 | 107 | template 108 | torch::Tensor qshiftnd(const torch::Tensor& input, 109 | const torch::Tensor& weights, 110 | const torch::Tensor& borders, 111 | const std::vector& new_size){ 112 | std::string name = "q_shift"+std::to_string(nD)+"d_cpu"; 113 | torch::Tensor output; 114 | int64_t weights_zero_point = static_cast(weights.q_zero_point()); 115 | torch::Tensor iweights = weights.int_repr().to(torch::kLong); 116 | 117 | torch::Tensor _borders = borders.to(torch::kLong); 118 | 119 | if (input.is_contiguous(c10::MemoryFormat::ChannelsLast) || input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) { 120 | output = at::_empty_affine_quantized(new_size, input.options().memory_format(input.suggest_memory_format()), 121 | input.q_scale(), input.q_zero_point(), c10::nullopt); 122 | } 123 | else { 124 | output = at::_empty_affine_quantized(new_size, input.options(), input.q_scale(), input.q_zero_point()); 125 | } 126 | AT_DISPATCH_QINT_TYPES(input.scalar_type(), name, [&] { 127 | qshiftnd_kernel(input, iweights, _borders, output, weights_zero_point); 128 | }); 129 | return output; 130 | } 131 | 132 | 133 | // TEMPLATE DISPATCHERS 134 | 135 | 136 | torch::Tensor qshift1d(const torch::Tensor& input, 137 | const torch::Tensor& weights, 138 | const torch::Tensor& borders, 139 | const std::vector& new_size, 140 | int64_t padding_mode, 141 | bool active_flag){ //active_flag not used here, but needs for API compatibility 142 | torch::Tensor ret; 143 | switch (padding_mode){ 144 | case 0: 145 | ret = qshiftnd<1, BIPadding::Zeros>(input, weights, borders, new_size); 146 | break; 147 | case 1: 148 | ret = qshiftnd<1, BIPadding::Border>(input, weights, borders, new_size); 149 | break; 150 | case 2: 151 | ret = qshiftnd<1, BIPadding::Periodic>(input, weights, borders, new_size); 152 | break; 153 | case 3: 154 | ret = qshiftnd<1, BIPadding::Reflect>(input, weights, borders, new_size); 155 | break; 156 | case 4: 157 | ret = qshiftnd<1, BIPadding::Symmetric>(input, weights, borders, new_size); 158 | break; 159 | } 160 | return ret; 161 | } 162 | 163 | 164 | torch::Tensor qshift2d(const torch::Tensor& input, 165 | const torch::Tensor& weights, 166 | const torch::Tensor& borders, 167 | const std::vector& new_size, 168 | int64_t padding_mode, 169 | bool active_flag){ //active_flag not used here, but needs for API compatibility 170 | torch::Tensor ret; 171 | switch (padding_mode){ 172 | case 0: 173 | ret = qshiftnd<2, BIPadding::Zeros>(input, weights, borders, new_size); 174 | break; 175 | case 1: 176 | ret = qshiftnd<2, BIPadding::Border>(input, weights, borders, new_size); 177 | break; 178 | case 2: 179 | ret = qshiftnd<2, BIPadding::Periodic>(input, weights, borders, new_size); 180 | break; 181 | case 3: 182 | ret = qshiftnd<2, BIPadding::Reflect>(input, weights, borders, new_size); 183 | break; 184 | case 4: 185 | ret = qshiftnd<2, BIPadding::Symmetric>(input, weights, borders, new_size); 186 | break; 187 | } 188 | return ret; 189 | } 190 | 191 | torch::Tensor qshift3d(const torch::Tensor& input, 192 | const torch::Tensor& weights, 193 | const torch::Tensor& borders, 194 | const std::vector& new_size, 195 | int64_t padding_mode, 196 | bool active_flag){ //active_flag not used here, but needs for API compatibility 197 | torch::Tensor ret; 198 | switch (padding_mode){ 199 | case 0: 200 | ret = qshiftnd<3, BIPadding::Zeros>(input, weights, borders, new_size); 201 | break; 202 | case 1: 203 | ret = qshiftnd<3, BIPadding::Border>(input, weights, borders, new_size); 204 | break; 205 | case 2: 206 | ret = qshiftnd<3, BIPadding::Periodic>(input, weights, borders, new_size); 207 | break; 208 | case 3: 209 | ret = qshiftnd<3, BIPadding::Reflect>(input, weights, borders, new_size); 210 | break; 211 | case 4: 212 | ret = qshiftnd<3, BIPadding::Symmetric>(input, weights, borders, new_size); 213 | break; 214 | } 215 | return ret; 216 | } 217 | 218 | std::tuple qshiftnd_backward(const torch::Tensor& grad, 219 | const torch::Tensor& weights, 220 | const torch::Tensor& input, 221 | const torch::Tensor& borders, 222 | int64_t padding_mode, 223 | bool active_flag){ 224 | TORCH_CHECK(0, "backwards on quantized tensor are not supported"); 225 | } 226 | 227 | } // end of anonymous namespace 228 | 229 | 230 | TORCH_LIBRARY_IMPL(torchshifts, QuantizedCPU, m) { 231 | m.impl( 232 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_forward"), 233 | TORCH_FN(qshift1d)); 234 | m.impl( 235 | TORCH_SELECTIVE_NAME("torchshifts::_shift1d_backward"), 236 | TORCH_FN(qshiftnd_backward)); 237 | m.impl( 238 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_forward"), 239 | TORCH_FN(qshift2d)); 240 | m.impl( 241 | TORCH_SELECTIVE_NAME("torchshifts::_shift2d_backward"), 242 | TORCH_FN(qshiftnd_backward)); 243 | m.impl( 244 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_forward"), 245 | TORCH_FN(qshift3d)); 246 | m.impl( 247 | TORCH_SELECTIVE_NAME("torchshifts::_shift3d_backward"), 248 | TORCH_FN(qshiftnd_backward)); 249 | } 250 | 251 | } // namespace ops 252 | } // namespace shifts 253 | 254 | 255 | #endif 256 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/shifts.cpp: -------------------------------------------------------------------------------- 1 | #include "shifts.h" 2 | 3 | #include 4 | 5 | namespace shifts { 6 | namespace ops { 7 | 8 | namespace detail { 9 | 10 | torch::Tensor _shift1d_forward(const torch::Tensor& input, 11 | const torch::Tensor& weights, 12 | const torch::Tensor& borders, 13 | const std::vector& new_size, 14 | int64_t padding_mode, 15 | bool active_flag){ 16 | static auto op = 17 | c10::Dispatcher::singleton() 18 | .findSchemaOrThrow("torchshifts::_shift1d_forward", "") 19 | .typed(); 20 | return op.call(input, weights, borders, new_size, padding_mode, active_flag); 21 | } 22 | 23 | std::tuple _shift1d_backward(const torch::Tensor& grad, 24 | const torch::Tensor& weights, 25 | const torch::Tensor& input, 26 | const torch::Tensor& borders, 27 | int64_t padding_mode, 28 | bool active_flag){ 29 | static auto op = 30 | c10::Dispatcher::singleton() 31 | .findSchemaOrThrow("torchshifts::_shift1d_backward", "") 32 | .typed(); 33 | return op.call(grad, weights, input, borders, padding_mode, active_flag); 34 | } 35 | 36 | 37 | torch::Tensor _shift2d_forward(const torch::Tensor& input, 38 | const torch::Tensor& weights, 39 | const torch::Tensor& borders, 40 | const std::vector& new_size, 41 | int64_t padding_mode, 42 | bool active_flag){ 43 | static auto op = 44 | c10::Dispatcher::singleton() 45 | .findSchemaOrThrow("torchshifts::_shift2d_forward", "") 46 | .typed(); 47 | return op.call(input, weights, borders, new_size, padding_mode, active_flag); 48 | } 49 | 50 | std::tuple _shift2d_backward(const torch::Tensor& grad, 51 | const torch::Tensor& weights, 52 | const torch::Tensor& input, 53 | const torch::Tensor& borders, 54 | int64_t padding_mode, 55 | bool active_flag){ 56 | static auto op = 57 | c10::Dispatcher::singleton() 58 | .findSchemaOrThrow("torchshifts::_shift2d_backward", "") 59 | .typed(); 60 | return op.call(grad, weights, input, borders, padding_mode, active_flag); 61 | } 62 | 63 | 64 | torch::Tensor _shift3d_forward(const torch::Tensor& input, 65 | const torch::Tensor& weights, 66 | const torch::Tensor& borders, 67 | const std::vector& new_size, 68 | int64_t padding_mode, 69 | bool active_flag){ 70 | static auto op = 71 | c10::Dispatcher::singleton() 72 | .findSchemaOrThrow("torchshifts::_shift3d_forward", "") 73 | .typed(); 74 | return op.call(input, weights, borders, new_size, padding_mode, active_flag); 75 | } 76 | 77 | std::tuple _shift3d_backward(const torch::Tensor& grad, 78 | const torch::Tensor& weights, 79 | const torch::Tensor& input, 80 | const torch::Tensor& borders, 81 | int64_t padding_mode, 82 | bool active_flag){ 83 | static auto op = 84 | c10::Dispatcher::singleton() 85 | .findSchemaOrThrow("torchshifts::_shift3d_backward", "") 86 | .typed(); 87 | return op.call(grad, weights, input, borders, padding_mode, active_flag); 88 | } 89 | 90 | } // namespace detail 91 | 92 | using namespace torch::indexing; 93 | std::tuple> check_borders(const torch::Tensor& input, 94 | const torch::Tensor& borders, 95 | const int64_t idim){ 96 | auto sizes = input.sizes(); 97 | const int dim = static_cast(idim); 98 | const int shift = (((dim + 1) == (int)sizes.size())?1:2); 99 | const int hdim = 3; // hardcoded for pass no more than, 5D tensor 100 | const int _dim = std::min(hdim,dim); 101 | auto dev = input.device(); 102 | torch::Tensor std_borders = torch::empty({hdim*2}, borders.options().dtype(torch::kInt).device(torch::kCPU)); 103 | int* std_borders_data = std_borders.data_ptr(); 104 | for (int i=0 ; i < hdim; ++i){ 105 | std_borders_data[i*2] = 0; 106 | std_borders_data[i*2+1] = ((i+1)>dim)?1:sizes[i+shift]; 107 | } 108 | if (borders.numel() != 0){ 109 | auto _borders = borders.to(torch::kInt).to(torch::kCPU); 110 | int* borders_data = _borders.data_ptr(); 111 | for (int i=0 ; i < _dim; ++i){ 112 | std_borders_data[i*2+1] -= borders_data[i*2+1]; 113 | std_borders_data[i*2] = borders_data[i*2]; 114 | if ((std_borders_data[i*2+1] - std_borders_data[i*2]) < 1){ 115 | std_borders_data[i*2+1] = std_borders_data[i*2] + 1; 116 | } 117 | if (std_borders_data[i*2] == static_cast(sizes[i+shift])){ 118 | std_borders_data[i*2] = static_cast(sizes[i+shift]) - 1; 119 | std_borders_data[i*2+1] = std_borders_data[i*2] + 1; 120 | } 121 | if (std_borders_data[i*2+1] == 0){ 122 | std_borders_data[i*2] = 0; 123 | std_borders_data[i*2+1] = 1; 124 | } 125 | std_borders_data[i*2] = std::max(static_cast(0), std_borders_data[i*2]); 126 | std_borders_data[i*2+1] = std::min(static_cast(sizes[i+shift]), std_borders_data[i*2+1]); 127 | } 128 | } 129 | std::vector new_sizes(shift+_dim); 130 | std::copy(sizes.begin(), sizes.begin()+shift, new_sizes.begin()); 131 | for (int i=0 ; i < _dim; ++i){ 132 | new_sizes[i+shift] = static_cast(std_borders_data[i*2+1] - std_borders_data[i*2]); 133 | } 134 | return std::make_tuple(std_borders.to(dev), new_sizes); 135 | } 136 | 137 | 138 | torch::Tensor shift1d(const torch::Tensor& input, 139 | const torch::Tensor& weights, 140 | const torch::Tensor& borders, 141 | int64_t padding_mode, bool active_flag){ 142 | auto bands = check_borders(input, borders, 1); 143 | auto _borders = std::get<0>(bands); 144 | auto new_size = std::get<1>(bands); 145 | return detail::_shift1d_forward(input, weights, _borders, new_size, padding_mode, active_flag); 146 | } 147 | 148 | torch::Tensor shift2d(const torch::Tensor& input, 149 | const torch::Tensor& weights, 150 | const torch::Tensor& borders, 151 | int64_t padding_mode, bool active_flag){ 152 | auto bands = check_borders(input, borders, 2); 153 | auto _borders = std::get<0>(bands); 154 | auto new_size = std::get<1>(bands); 155 | return detail::_shift2d_forward(input, weights, _borders, new_size, padding_mode, active_flag); 156 | } 157 | 158 | torch::Tensor shift3d(const torch::Tensor& input, 159 | const torch::Tensor& weights, 160 | const torch::Tensor& borders, 161 | int64_t padding_mode, bool active_flag){ 162 | auto bands = check_borders(input, borders, 3); 163 | auto _borders = std::get<0>(bands); 164 | auto new_size = std::get<1>(bands); 165 | return detail::_shift3d_forward(input, weights, _borders, new_size, padding_mode, active_flag); 166 | } 167 | 168 | TS_TORCH_LIBRARY_FRAGMENT(torchshifts, m) { 169 | m.def(TORCH_SELECTIVE_SCHEMA( 170 | "torchshifts::_shift1d_forward(Tensor input, Tensor weights, Tensor borders, int[] new_size, int padding_mode, bool active_flag) -> Tensor")); 171 | m.def(TORCH_SELECTIVE_SCHEMA( 172 | "torchshifts::_shift1d_backward(Tensor grad, Tensor weights, Tensor input, Tensor borders, int padding_mode, bool active_flag) -> (Tensor, Tensor)")); 173 | m.def(TORCH_SELECTIVE_SCHEMA( 174 | "torchshifts::_shift2d_forward(Tensor input, Tensor weights, Tensor borders, int[] new_size, int padding_mode, bool active_flag) -> Tensor")); 175 | m.def(TORCH_SELECTIVE_SCHEMA( 176 | "torchshifts::_shift2d_backward(Tensor grad, Tensor weights, Tensor input, Tensor borders, int padding_mode, bool active_flag) -> (Tensor, Tensor)")); 177 | m.def(TORCH_SELECTIVE_SCHEMA( 178 | "torchshifts::_shift3d_forward(Tensor input, Tensor weights, Tensor borders, int[] new_size, int padding_mode, bool active_flag) -> Tensor")); 179 | m.def(TORCH_SELECTIVE_SCHEMA( 180 | "torchshifts::_shift3d_backward(Tensor grad, Tensor weights, Tensor input, Tensor borders, int padding_mode, bool active_flag) -> (Tensor, Tensor)")); 181 | } 182 | 183 | 184 | } // namespace ops 185 | } // namespace shifts 186 | -------------------------------------------------------------------------------- /torchshifts/csrc/ops/shifts.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../macros.h" 6 | 7 | namespace shifts { 8 | namespace ops { 9 | 10 | 11 | 12 | API_EXPORT torch::Tensor shift1d(const torch::Tensor& input, 13 | const torch::Tensor& weights, 14 | const torch::Tensor& borders, 15 | int64_t padding_mode, 16 | bool active_flag); 17 | 18 | API_EXPORT torch::Tensor shift2d(const torch::Tensor& input, 19 | const torch::Tensor& weights, 20 | const torch::Tensor& borders, 21 | int64_t padding_mode, 22 | bool active_flag); 23 | 24 | API_EXPORT torch::Tensor shift3d(const torch::Tensor& input, 25 | const torch::Tensor& weights, 26 | const torch::Tensor& borders, 27 | int64_t padding_mode, 28 | bool active_flag); 29 | 30 | 31 | 32 | 33 | namespace detail { 34 | 35 | torch::Tensor _shift1d_forward(const torch::Tensor& input, 36 | const torch::Tensor& weights, 37 | const torch::Tensor& borders, 38 | const std::vector& new_size, 39 | int64_t padding_mode, 40 | bool active_flag); 41 | 42 | torch::Tensor _shift2d_forward(const torch::Tensor& input, 43 | const torch::Tensor& weights, 44 | const torch::Tensor& borders, 45 | const std::vector& new_size, 46 | int64_t padding_mode, 47 | bool active_flag); 48 | 49 | torch::Tensor _shift3d_forward(const torch::Tensor& input, 50 | const torch::Tensor& weights, 51 | const torch::Tensor& borders, 52 | const std::vector& new_size, 53 | int64_t padding_mode, 54 | bool active_flag); 55 | 56 | std::tuple _shift1d_backward(const torch::Tensor& grad, 57 | const torch::Tensor& weights, 58 | const torch::Tensor& input, 59 | const torch::Tensor& borders, 60 | int64_t padding_mode, 61 | bool active_flag); 62 | 63 | std::tuple _shift2d_backward(const torch::Tensor& grad, 64 | const torch::Tensor& weights, 65 | const torch::Tensor& input, 66 | const torch::Tensor& borders, 67 | int64_t padding_mode, 68 | bool active_flag); 69 | 70 | std::tuple _shift3d_backward(const torch::Tensor& grad, 71 | const torch::Tensor& weights, 72 | const torch::Tensor& input, 73 | const torch::Tensor& borders, 74 | int64_t padding_mode, 75 | bool active_flag); 76 | 77 | } // namespace detail 78 | 79 | } // namespace ops 80 | } // namespace shifts -------------------------------------------------------------------------------- /torchshifts/csrc/torchshifts.cpp: -------------------------------------------------------------------------------- 1 | #include "torchshifts.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "ops/ops.h" 8 | 9 | #ifdef WITH_CUDA 10 | #include 11 | #endif 12 | 13 | 14 | 15 | #ifdef _WIN32 16 | #if PY_MAJOR_VERSION < 3 17 | PyMODINIT_FUNC init_C(void) {return NULL;} 18 | #else 19 | PyMODINIT_FUNC PyInit__C(void) {return NULL;} 20 | #endif 21 | #endif 22 | 23 | namespace shifts { 24 | 25 | int64_t cuda_version() { 26 | #ifdef WITH_CUDA 27 | return CUDA_VERSION; 28 | #else 29 | return -1; 30 | #endif 31 | } 32 | 33 | } 34 | 35 | TS_TORCH_LIBRARY_FRAGMENT(torchshifts, m) { 36 | m.def("_cuda_version", &shifts::cuda_version); 37 | m.def("shift1d", &shifts::ops::shift1d); 38 | m.def("shift2d", &shifts::ops::shift2d); 39 | m.def("shift3d", &shifts::ops::shift3d); 40 | } 41 | -------------------------------------------------------------------------------- /torchshifts/csrc/torchshifts.h: -------------------------------------------------------------------------------- 1 | #ifndef TORCHSHIFT 2 | #define TORCHSHIFT 3 | 4 | 5 | #include 6 | #include "macros.h" 7 | 8 | namespace shifts { 9 | API_EXPORT int64_t cuda_version(); 10 | 11 | namespace detail { 12 | //(Taken from torchvision) 13 | int64_t _cuda_version = cuda_version(); 14 | 15 | } 16 | } 17 | 18 | #endif 19 | -------------------------------------------------------------------------------- /torchshifts/extension.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file was partly taken from torchvision 3 | ''' 4 | 5 | _HAS_OPS = False 6 | error_str = '' 7 | 8 | 9 | def _has_ops(): 10 | return False 11 | 12 | def _register_extensions(): 13 | from pathlib import Path 14 | import sys 15 | import os 16 | import importlib 17 | import torch 18 | 19 | # load the custom_op_library and register the custom ops 20 | lib_dir = Path(__file__).resolve().parent 21 | if os.name == 'nt': 22 | # Register the main torchsrrops library location on the default DLL path 23 | import ctypes 24 | kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) 25 | with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') 26 | prev_error_mode = kernel32.SetErrorMode(0x0001) 27 | if with_load_library_flags: 28 | kernel32.AddDllDirectory.restype = ctypes.c_void_p 29 | if sys.version_info >= (3, 8): 30 | os.add_dll_directory(str(lib_dir)) 31 | elif with_load_library_flags: 32 | res = kernel32.AddDllDirectory(str(lib_dir)) 33 | if res is None: 34 | err = ctypes.WinError(ctypes.get_last_error()) 35 | err.strerror += f' Error adding "{str(lib_dir)}" to the DLL directories.' 36 | raise err 37 | kernel32.SetErrorMode(prev_error_mode) 38 | 39 | loader_details = (importlib.machinery.ExtensionFileLoader, 40 | importlib.machinery.EXTENSION_SUFFIXES) 41 | extfinder = importlib.machinery.FileFinder(str(lib_dir), loader_details) 42 | ext_specs = extfinder.find_spec("_C") 43 | if ext_specs is None: 44 | raise ImportError 45 | torch.ops.load_library(ext_specs.origin) 46 | 47 | 48 | try: 49 | _register_extensions() 50 | _HAS_OPS = True 51 | 52 | def _has_ops(): 53 | return True 54 | except (ImportError, OSError) as e: 55 | error_str = str(e) 56 | pass 57 | 58 | def _assert_has_ops(): 59 | if not _has_ops(): 60 | raise RuntimeError( 61 | "Couldn't load custom C++ ops. This can happen if your PyTorch and " 62 | "torchshifts versions are incompatible, or if you had errors while compiling " 63 | "torchshifts from source. For further information on the compatible versions, check " 64 | "https://github.com/DeadAt0m/ActiveSparseShifts-PyTorch/blob/master/README.md for the compatibility matrix. " 65 | "Please check your PyTorch version with torch.__version__ and verify if it is compatible, and if not " 66 | "please reinstall your PyTorch." 67 | f"\n\nImport error details:\n\t{error_str}" 68 | ) 69 | 70 | 71 | def _check_cuda_version(): 72 | """ 73 | Make sure that CUDA versions match between the pytorch install and torchshifts install 74 | """ 75 | if not _HAS_OPS: 76 | return -1 77 | import torch 78 | _version = torch.ops.torchshifts._cuda_version() 79 | if _version != -1 and torch.version.cuda is not None: 80 | ts_version = str(_version) 81 | if int(ts_version) < 10000: 82 | ts_major = int(ts_version[0]) 83 | ts_minor = int(ts_version[2]) 84 | else: 85 | ts_major = int(ts_version[0:2]) 86 | ts_minor = int(ts_version[3]) 87 | t_version = torch.version.cuda 88 | t_version = t_version.split('.') 89 | t_major = int(t_version[0]) 90 | t_minor = int(t_version[1]) 91 | if t_major != ts_major or t_minor > ts_minor: 92 | raise RuntimeError("Detected that PyTorch and torchshifts were compiled with different CUDA versions. " 93 | "PyTorch has CUDA Version={}.{} and torchshifts has CUDA Version={}.{}. " 94 | "Please reinstall the torchshifts that matches your PyTorch install." 95 | .format(t_major, t_minor, ts_major, ts_minor)) 96 | return _version 97 | 98 | _check_cuda_version() 99 | -------------------------------------------------------------------------------- /torchshifts/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .extension import _assert_has_ops 3 | from typing import Optional 4 | 5 | Tensor = torch.Tensor 6 | 7 | def shift1d_func(input: Tensor, weights: Tensor, 8 | padding_mode: int, active_flag: bool, 9 | borders: Optional[Tensor] = None) -> Tensor: 10 | """ 11 | Performs shift operation on 1D tensor 12 | Arguments: 13 | input (Tensor[N, C, H]): input 3D tensor 14 | weights (Tensor[C, 1]): tensor contained shift(amount(abs) and direction(sign)) value for each channel of 1D tensor 15 | padding_mode (int): padding applyed during shift. Allowed following modes: 0 - zeros, 16 | 1 - border, 17 | 2 - periodic, 18 | 3 - reflective, 19 | 4 - symmetric 20 | active_flag (bool): if true - the active shift(via billinear interpolation) will used on forward pass. 21 | This option has no effect if input is Quantized tensor. 22 | borders (Tensor[1,2]): dim x (left_border, right_border) output tensor will be cut off proportional to borders 23 | Returns: 24 | output (Tensor[N, C, H]) 25 | """ 26 | _assert_has_ops() 27 | assert padding_mode in [0,1,2,3,4], f'shift1d_func() expected padding_mode can be 0 - zeros, 1 - border, 2 - periodic, 3 - reflect, 4 - symmetric' 28 | assert len(input.shape) == 3, f'shift1d_func(): expected 3D tensor as input, but it is shape is {input.shape}' 29 | assert weights.shape[-1] == 1, f'shift1d_func(): expected [n_channels,1] tensor as weight, but it is shape is {weights.shape}' 30 | assert input.shape[1] == weights.shape[0], f'shift1d_func(): expected that input and weight have equal number of channels, but input have {input.shape[1]} and weight have {weights.shape[0]} channels.' 31 | assert input.device == weights.device, f'shift1d_func(): expected input and weights to be on same device, but input is on {input.device} and weights is on {weights.device}' 32 | if borders is not None: 33 | assert (len(borders.shape) == 2) and (borders.shape[1] == 2) and (borders.shape[0] == 1), f'borders must have shape [1, 2]' 34 | else: 35 | borders = torch.Tensor() 36 | return torch.ops.torchshifts.shift1d(input, weights, borders, padding_mode, active_flag) 37 | 38 | 39 | def shift2d_func(input: Tensor, weights: Tensor, 40 | padding_mode: int, active_flag: bool, 41 | borders: Optional[Tensor] = None) -> Tensor: 42 | """ 43 | Performs shift operation on 2D tensor 44 | Arguments: 45 | input (Tensor[N, C, H, W]): input 4D tensor 46 | weights (Tensor[C, 2]): tensor contained 2 shift(amount(abs) and direction(sign)) values(for H and W axes) for each channel of 2D tensor. 47 | padding_mode (int): padding applyed during shift. Allowed following modes: 0 - zeros, 48 | 1 - border, 49 | 2 - periodic, 50 | 3 - reflective, 51 | 4 - symmetric 52 | active_flag (bool): if true - the active shift(via billinear interpolation) will used on forward pass. 53 | This option has no effect if input is Quantized tensor. 54 | borders (Tensor[2,2]): dim x (left_border, right_border) output tensor will be cut off proportional to borders 55 | Returns: 56 | output (Tensor[N, C, H. W]) 57 | """ 58 | _assert_has_ops() 59 | assert padding_mode in [0,1,2,3,4], f'shift2d_func() expected padding_mode can be 0 - zeros, 1 - border, 2 - periodic, 3 - reflect, 4 - symmetric' 60 | assert len(input.shape) == 4, f'shift2d_func(): expected 4D tensor as input, but it is shape is {input.shape}' 61 | assert weights.shape[-1] == 2, f'shift2d_func(): expected [n_channels,2] tensor as weight, but it is shape is {weights.shape}' 62 | assert input.shape[1] == weights.shape[0], f'shift2d_func(): expected that input and weight have equal number of channels, but input have {input.shape[1]} and weight have {weights.shape[0]} channels.' 63 | assert input.device == weights.device, f'shift2d_func(): expected input and weights to be on same device, but input is on {input.device} and weights is on {weights.device}' 64 | if borders is not None: 65 | assert (len(borders.shape) == 2) and (borders.shape[1] == 2) and (borders.shape[0] == 2), f'borders must have shape [2, 2]' 66 | else: 67 | borders = torch.Tensor() 68 | return torch.ops.torchshifts.shift2d(input, weights, borders, padding_mode, active_flag) 69 | 70 | def shift3d_func(input: Tensor, weights: Tensor, 71 | padding_mode: int, active_flag: bool, 72 | borders: Optional[Tensor] = None) -> Tensor: 73 | """ 74 | Performs shift operation on 3D tensor 75 | Arguments: 76 | input (Tensor[N, C, H, W, D]): input 5D tensor 77 | weights (Tensor[C, 3]): tensor contained 3 shift(amount(abs) and direction(sign)) values(for H,W and D axes) for each channel of 3D tensor. 78 | padding_mode (int): padding applyed during shift. Allowed following modes: 0 - zeros, 79 | 1 - border, 80 | 2 - periodic, 81 | 3 - reflective, 82 | 4 - symmetric 83 | active_flag (bool): if true - the active shift(via billinear interpolation) will used on forward pass. 84 | This option has no effect if input is Quantized tensor. 85 | borders (Tensor[3,2]): dim x (left_border, right_border) output tensor will be cut off proportional to borders 86 | Returns: 87 | output (Tensor[N, C, H, W, D]) 88 | """ 89 | _assert_has_ops() 90 | assert padding_mode in [0,1,2,3,4], f'shift3d_func() expected padding_mode can be 0 - zeros, 1 - border, 2 - periodic, 3 - reflect, 4 - symmetric' 91 | assert len(input.shape) == 5, f'shift3d_func(): expected 5D tensor as input, but it is shape is {input.shape}' 92 | assert weights.shape[-1] == 3, f'shift3d_func(): expected [n_channels,3] tensor as weight, but it is shape is {weights.shape}' 93 | assert input.shape[1] == weights.shape[0], f'shift3d_func(): expected that input and weight have equal number of channels, but input have {input.shape[1]} and weight have {weights.shape[0]} channels.' 94 | assert input.device == weights.device, f'shift3d_func(): expected input and weights to be on same device, but input is on {input.device} and weights is on {weights.device}' 95 | if borders is not None: 96 | assert (len(borders.shape) == 2) and (borders.shape[1] == 2) and (borders.shape[0] == 3), f'borders must have shape [3, 2]' 97 | else: 98 | borders = torch.Tensor() 99 | return torch.ops.torchshifts.shift3d(input, weights, borders, padding_mode, active_flag) 100 | -------------------------------------------------------------------------------- /torchshifts/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .shifts import Shift1d, Shift2d, Shift3d -------------------------------------------------------------------------------- /torchshifts/modules/shifts.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torchshifts.functional import shift1d_func, shift2d_func, shift3d_func 4 | import random 5 | from functools import partial 6 | 7 | paddings_dict = {'zeros':0, 'border':1, 'periodic':2, 'reflect':3, 'symmetric':4} 8 | 9 | 10 | def _wrap_dim(val, dim, name): 11 | if isinstance(val, tuple): 12 | val = list(val) 13 | if not isinstance(val, list): 14 | val = [val] * dim 15 | if len(val) != dim: 16 | print(f'{name} params has different kernel sizes, but length of list do not corresponds to dim: {dim}, and was reduced') 17 | val = val[:dim] 18 | return val 19 | 20 | 21 | def _create_dw_emulation(args, dim): 22 | """ 23 | Heuristic rules for emulation of DepthWise Conv via Shift layer 24 | in terms of output shape and and shift kernel behaviour. 25 | 26 | This directly influence on proper shift param initialization. 27 | Output shape via cutting the output and pooling(depending on stride) 28 | 29 | 30 | """ 31 | assert isinstance(args,dict), f'args must be dict' 32 | assert 'kernel_size' in args, f'args must contains at least the kernel_size inside' 33 | if 'dilation' in args: 34 | print('Warning! Found the dilation param which is not supported and will be ignored') 35 | kernel_size = _wrap_dim(args['kernel_size'], dim, 'kernel_size') 36 | padding = _wrap_dim(args.get('padding', 0), dim, 'padding') 37 | stride = _wrap_dim(args.get('stride', 1), dim, 'stride') 38 | itrt_scale = 2 if args['init_thumb_rule_type'] == 1 else 1 39 | 40 | borders = None 41 | padding = torch.tensor(padding, requires_grad=False) 42 | kernel_size = torch.tensor(kernel_size, requires_grad=False) 43 | tmp = 2*padding - kernel_size + 1 44 | if (tmp < 0).any(): 45 | borders = torch.zeros(dim, 2, dtype=torch.long, requires_grad=False) 46 | borders[tmp<0, 0] = abs(tmp[tmp<0]) // 2 47 | borders[tmp<0, 1] = abs(tmp[tmp<0]) - borders[tmp<0, 0] 48 | 49 | init_shift = kernel_size // itrt_scale 50 | scales = torch.tensor(stride,requires_grad=False).unsqueeze(0) 51 | 52 | pad_conv = {'zeros':0, 'replicate': 1, 'circular': 2, 'reflect':3,} 53 | padding = args.get('padding_mode', -1) 54 | if isinstance(padding,str): 55 | padding = pad_conv[padding] 56 | 57 | return init_shift, scales, borders, padding 58 | 59 | 60 | class _Shiftnd(nn.Module): 61 | """ 62 | Base module for all shifts. 63 | 64 | Arguments: 65 | in_channels(int) – Number of channels in the input image. 66 | padding(str) - Padding added to the input during shift. 67 | Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'. 68 | init_shift(float/Tuple[float]) - Border for uniform initialization of weights(shifts). Default: 1. 69 | sparsity_term(float) - Strength of sparsity. Default: 5e-4. 70 | active_shift(bool) - Compute forward pass via bilinear interpolation. Default: False. 71 | emulate_dw(dict) - Just pass params of depthwise conv, that you trying replace with shift layer. 72 | It applies a heuristic and try to emulate their properties(including output shape) 73 | init_thumb_rule(int) - Type of thumb rule for shifts initialization. Allowed: Type 1(default): uniform(-init_shift, init_shift), 74 | Type 2: uniform(0,init_shift) * random_sign 75 | """ 76 | @staticmethod 77 | def _identity(x): 78 | return x 79 | 80 | @staticmethod 81 | def _pooling(ks, dim): 82 | if isinstance(ks, torch.Tensor): 83 | ks = ks.squeeze().cpu().numpy().tolist() 84 | if dim == 1: 85 | return partial(torch.nn.functional.avg_pool1d, kernel_size=ks, stride=ks, ceil_mode=True) 86 | elif dim == 2: 87 | return partial(torch.nn.functional.avg_pool2d, kernel_size=ks, stride=ks, ceil_mode=True) 88 | else: 89 | return partial(torch.nn.functional.avg_pool3d, kernel_size=ks, stride=ks, ceil_mode=True) 90 | 91 | @staticmethod 92 | def _init_thumb_rule_1(size, shape): 93 | return 2*size*torch.rand(shape) - size 94 | 95 | @staticmethod 96 | def _init_thumb_rule_2(size, shape): 97 | return size*torch.rand(shape) * (1 if random.random() < 0.5 else -1) 98 | 99 | 100 | def __init__(self, in_channels, padding='zeros', 101 | init_shift=1, 102 | sparsity_term=5e-4, 103 | active_flag=False, 104 | emulate_dw=None, 105 | init_thumb_rule=1): 106 | super(_Shiftnd, self).__init__() 107 | assert padding.lower() in paddings_dict.keys(), f'incorrect padding option: {padding}' 108 | self.padding = paddings_dict[padding] 109 | self.sparsity_term = sparsity_term 110 | self.in_channels = in_channels 111 | self._active_flag = active_flag 112 | self._shift_func = self._init_shift_fn() 113 | self.cut_borders = None 114 | self._reduction_fn = self._identity 115 | # init weights 116 | self._w_init_func = self._init_thumb_rule_1 117 | if init_thumb_rule == 2: 118 | self._w_init_func == self._init_thumb_rule_2 119 | # init hyper params 120 | self.init_shift = torch.tensor(_wrap_dim(init_shift, self.dim, 'init_shift'), 121 | requires_grad=False) 122 | self._w_post_init_scale = torch.ones(1, self.dim, requires_grad=False) 123 | 124 | if emulate_dw is not None: 125 | emulate_dw['init_thumb_rule_type'] = init_thumb_rule 126 | out = _create_dw_emulation(emulate_dw, self.dim) 127 | self.init_shift, self._w_post_init_scale, self.cut_borders, padding = out 128 | if padding != -1: 129 | self.padding == padding 130 | if not (self._w_post_init_scale == 1).all(): 131 | self._reduction_fn = self._pooling(self._w_post_init_scale, self.dim) 132 | self._init_weights() 133 | 134 | 135 | def _init_shift_fn(self): 136 | raise NotImplemented 137 | 138 | def _init_weights(self): 139 | self.weight = nn.Parameter(torch.Tensor(self.in_channels, self.dim)) 140 | self.reset_parameters() 141 | 142 | def reset_parameters(self): 143 | for i in range(self.dim): 144 | self.weight.data[:,i] = self._w_init_func(self.init_shift[i], self.in_channels) 145 | self.weight.data *= self._w_post_init_scale 146 | 147 | def _compute_weight_loss(self): 148 | return self.sparsity_term * torch.sum(torch.abs(self.weight)) 149 | 150 | def forward(self, input): 151 | loss = self._compute_weight_loss() if bool(self.sparsity_term) else None 152 | out = self._shift_func(input, self.weight, self.padding, self._active_flag, self.cut_borders) 153 | return self._reduction_fn(out), loss 154 | 155 | def extra_repr(self): 156 | pad = dict(zip(paddings_dict.values(), paddings_dict.keys()))[self.padding] 157 | active = f'Active shift on forward pass: {"Yes" if self._active_flag else "No"}' 158 | sp = f'Sparse shift: {"Yes - sparsity strength: {}".format(self.sparsity_term) if bool(self.sparsity_term) else "No"}' 159 | return f'in_channels={self.in_channels}, padding_method={pad}, {active}, {sp}' 160 | 161 | 162 | 163 | class Shift1d(_Shiftnd): 164 | """ 165 | Performs (index)shift operation under 3D tensor. Zero-FLOPs replacement of Depth-Wise convolution. 166 | 167 | 168 | Notes: 169 | - Shift values and directions is learnable for each channel. 170 | - Forward method is always return the two terms: output and loss 171 | - loss is None if sparsity_term is greater than zero 172 | 173 | 174 | Arguments: 175 | in_channels(int) – Number of channels in the input image. 176 | padding(str) - Padding added to the input during shift. 177 | Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'. 178 | init_shift(float) - Border for uniform initialization of weights(shifts). Default: 1. 179 | sparsity_term(float) - Strength of sparsity. Default: 5e-4. 180 | active_shift(bool) - Compute forward pass via bilinear interpolation. Default: False. 181 | emulate_dw(dict) - Just pass params of depthwise conv, that you trying replace with shift layer. 182 | It applies a heuristic and try to emulate their properties(including output shape) 183 | init_thumb_rule(int) - Type of thumb rule for shifts initialization. Allowed: Type 1(default): uniform(-init_shift, init_shift), 184 | Type 2: uniform(0,init_shift) * random_sign 185 | """ 186 | def __init__(self, in_channels, padding='zeros', 187 | init_shift = 1, sparsity_term=5e-4, active_flag=False, 188 | emulate_dw=None, 189 | init_thumb_rule=1): 190 | self.dim = 1 191 | super(Shift1d, self).__init__(in_channels, padding, init_shift, sparsity_term, 192 | active_flag, emulate_dw, init_thumb_rule) 193 | 194 | def _init_shift_fn(self): 195 | return shift1d_func 196 | 197 | class Shift2d(_Shiftnd): 198 | """ 199 | Performs (index)shift operation under 4D(by h and w axes) tensor. Zero-FLOPs replacement of Depth-Wise convolution. 200 | 201 | 202 | Notes: 203 | - Shift values and directions is learnable for each channel. 204 | - Forward method is always return the two terms: output and loss 205 | - loss is None if sparsity_term is greater than zero 206 | 207 | 208 | Arguments: 209 | in_channels(int) – Number of channels in the input image. 210 | padding(str) - Padding added to the input during shift. 211 | Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'. 212 | init_stride(float) - Border for uniform initialization of weights(shifts). Default: 1. 213 | sparsity_term(float) - Strength of sparsity. Default: 5e-4. 214 | active_shift(bool) - Compute forward pass via bilinear interpolation. Default: False. 215 | emulate_dw(dict) - Just pass params of depthwise conv, that you trying replace with shift layer. 216 | It applies a heuristic and try to emulate their properties(including output shape) 217 | init_thumb_rule(int) - Type of thumb rule for shifts initialization. Allowed: Type 1(default): uniform(-init_shift, init_shift), 218 | Type 2: uniform(0,init_shift) * random_sign 219 | """ 220 | def __init__(self, in_channels, padding='zeros', 221 | init_shift = 1, sparsity_term=5e-4, active_flag=False, 222 | emulate_dw=None, 223 | init_thumb_rule=1): 224 | self.dim = 2 225 | super(Shift2d, self).__init__(in_channels, padding, init_shift, sparsity_term, 226 | active_flag, emulate_dw, init_thumb_rule) 227 | 228 | def _init_shift_fn(self): 229 | return shift2d_func 230 | 231 | 232 | class Shift3d(_Shiftnd): 233 | """ 234 | Performs (index)shift operation under 5D(by h, w and d axes) tensor. Zero-FLOPs replacement of Depth-Wise convolution. 235 | 236 | 237 | Notes: 238 | - Shift values and directions is learnable for each channel. 239 | - Forward method is always return the two terms: output and loss 240 | - loss is None if sparsity_term is greater than zero 241 | 242 | 243 | Arguments: 244 | in_channels(int) – Number of channels in the input image. 245 | padding(str) - Padding added to the input during shift. 246 | Allowed: ['zeros', 'border', 'periodic', 'reflect', 'symmetric']. Default: 'zeros'. 247 | init_stride(float) - Border for uniform initialization of weights(shifts). Default: 1. 248 | sparsity_term(float) - Strength of sparsity. Default: 5e-4. 249 | active_shift(bool) - Compute forward pass via bilinear interpolation. Default: False. 250 | emulate_dw(dict) - Just pass params of depthwise conv, that you trying replace with shift layer. 251 | It applies a heuristic and try to emulate their properties(including output shape) 252 | init_thumb_rule(int) - Type of thumb rule for shifts initialization. Allowed: Type 1(default): uniform(-init_shift, init_shift), 253 | Type 2: uniform(0,init_shift) * random_sign 254 | """ 255 | def __init__(self, in_channels, padding='zeros', 256 | init_shift = 1, sparsity_term=5e-4, active_flag=False, 257 | emulate_dw=None, 258 | init_thumb_rule=1): 259 | self.dim = 3 260 | super(Shift3d, self).__init__(in_channels, padding, init_shift, sparsity_term, 261 | active_flag, emulate_dw, init_thumb_rule) 262 | 263 | def _init_shift_fn(self): 264 | return shift3d_func 265 | 266 | -------------------------------------------------------------------------------- /torchshifts/quantized/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | tv = torch.__version__ 4 | 5 | 6 | if tv >= '1.8': 7 | tndm_prop_list = torch.quantization.quantization_mappings.get_default_qconfig_propagation_list() 8 | tndm_mapping = torch.quantization.quantization_mappings.get_default_static_quant_module_mappings() 9 | tndm_qat_mapping = torch.quantization.quantization_mappings.get_default_qat_module_mappings() 10 | tndm_fuse_mapping = copy.deepcopy(torch.quantization.fuser_method_mappings.DEFAULT_OP_LIST_TO_FUSER_METHOD) 11 | elif tv >= '1.7': 12 | tndm_prop_list = copy.deepcopy(torch.quantization.quantization_mappings.get_qconfig_propagation_list()) 13 | tndm_mapping = copy.deepcopy(torch.quantization.quantization_mappings.get_static_quant_module_mappings()) 14 | tndm_qat_mapping = copy.deepcopy(torch.quantization.quantization_mappings.get_qat_module_mappings()) 15 | tndm_fuse_mapping = copy.deepcopy(torch.quantization.fuser_method_mappings.OP_LIST_TO_FUSER_METHOD) 16 | else: 17 | raise RuntimeError('The torch earlier than 1.7 is not supported. Please Update!') 18 | 19 | 20 | from .modules import Shift1d, Shift2d, Shift3d, new_quant_mapping 21 | quant_mapping = {**tndm_mapping, **new_quant_mapping} 22 | 23 | -------------------------------------------------------------------------------- /torchshifts/quantized/functional.py: -------------------------------------------------------------------------------- 1 | from torchshifts.functional import shift1d_func, shift2d_func, shift3d_func 2 | 3 | def shift1d_quantized(input, weight, padding_mode, cut_borders=None): 4 | if not input.is_quantized: 5 | raise ValueError("Input to 'shift1d_quantized' must be quantized!") 6 | return shift1d_func(input, weight, padding_mode, False, cut_borders) 7 | 8 | def shift2d_quantized(input, weight, padding_mode, cut_borders=None): 9 | if not input.is_quantized: 10 | raise ValueError("Input to 'shift2d_quantized' must be quantized!") 11 | return shift2d_func(input, weight, padding_mode, False, cut_borders) 12 | 13 | def shift3d_quantized(input, weight, padding_mode, cut_borders=None): 14 | if not input.is_quantized: 15 | raise ValueError("Input to 'shift3d_quantized' must be quantized!") 16 | return shift3d_func(input, weight, padding_mode, False, cut_borders) -------------------------------------------------------------------------------- /torchshifts/quantized/modules/__init__.py: -------------------------------------------------------------------------------- 1 | new_quant_mapping = {} 2 | 3 | from .shifts import Shift1d, Shift2d, Shift3d 4 | import torchshifts.modules.shifts as shifts 5 | 6 | new_quant_mapping.update({shifts.Shift1d: Shift1d, shifts.Shift2d: Shift2d, shifts.Shift3d: Shift3d}) 7 | -------------------------------------------------------------------------------- /torchshifts/quantized/modules/shifts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import copy 4 | from torchshifts.quantized.functional import shift1d_quantized, shift2d_quantized, shift3d_quantized 5 | import torchshifts.modules.shifts as shifts 6 | 7 | 8 | rp_dict = {v: k for k, v in shifts.paddings_dict.items()} 9 | 10 | def quantize_shift_weights(weight): 11 | scale = math.ceil((weight.max().item() - weight.min().item()) / 255.) 12 | return torch.quantize_per_tensor(weight, scale, 128, torch.quint8) 13 | 14 | class Shift1d(shifts.Shift1d): 15 | def __init__(self, in_channels, padding='zeros'): 16 | super(Shift1d, self).__init__(in_channels, padding, 1, 0, False) 17 | self.qweight = quantize_shift_weights(self.weight.float()) 18 | 19 | def forward(self, input): 20 | return self._reduction_fn(shift1d_quantized(input, self.qweight, self.padding, self.cut_borders)) 21 | 22 | def _get_name(self): 23 | return 'QuantizedShift1D' 24 | 25 | @staticmethod 26 | def from_float(mod): 27 | qshift = Shift1d(mod.in_channels, rp_dict[mod.padding]) 28 | qshift.cut_borders = mod.cut_borders 29 | qshift._reduction_fn = mod._reduction_fn 30 | qshift.weight = mod.weight 31 | qshift.qweight = quantize_shift_weights(mod.weight.float()) 32 | return qshift 33 | 34 | 35 | class Shift2d(shifts.Shift2d): 36 | def __init__(self, in_channels, padding='zeros'): 37 | super(Shift2d, self).__init__(in_channels, padding, 1, 0, False) 38 | self.qweight = quantize_shift_weights(self.weight.float()) 39 | 40 | def forward(self, input): 41 | return self._reduction_fn(shift2d_quantized(input, self.qweight, self.padding, self.cut_borders)) 42 | 43 | def _get_name(self): 44 | return 'QuantizedShift2D' 45 | 46 | @staticmethod 47 | def from_float(mod): 48 | qshift = Shift2d(mod.in_channels, rp_dict[mod.padding]) 49 | qshift.cut_borders = mod.cut_borders 50 | qshift._reduction_fn = mod._reduction_fn 51 | qshift.weight = mod.weight 52 | qshift.qweight = quantize_shift_weights(mod.weight.float()) 53 | return qshift 54 | 55 | class Shift3d(shifts.Shift3d): 56 | def __init__(self, in_channels, padding='zeros'): 57 | super(Shift3d, self).__init__(in_channels, padding, 1, 0, False) 58 | self.qweight = quantize_shift_weights(self.weight.float()) 59 | 60 | def forward(self, input): 61 | return self._reduction_fn(shift3d_quantized(input, self.qweight, self.padding, self.cut_borders)) 62 | 63 | def _get_name(self): 64 | return 'QuantizedShift3D' 65 | 66 | @staticmethod 67 | def from_float(mod): 68 | qshift = Shift3d(mod.in_channels, rp_dict[mod.padding]) 69 | qshift.cut_borders = mod.cut_borders 70 | qshift._reduction_fn = mod._reduction_fn 71 | qshift.weight = mod.weight 72 | qshift.qweight = quantize_shift_weights(mod.weight.float()) 73 | return qshift 74 | --------------------------------------------------------------------------------