├── README.md
├── acc_benchmark.py
├── assets
│   └── test_res.png
├── check_deterministic.py
├── cost_benchmark.py
├── grid_sample1d
│   ├── __init__.py
│   ├── grid_sample1d_cuda.cpp
│   ├── grid_sample1d_cuda_kernel.cu
│   ├── op.py
│   └── setup.py
└── note.md

/README.md:
--------------------------------------------------------------------------------
# Grid Sample 1d

A PyTorch CUDA extension for 1d grid sample. Since PyTorch only provides grid sample 2d/3d, I implement the 1d version for efficiency. The forward pass is 2~3x faster than emulating 1d grid sample with PyTorch's 2d version.

## Setup

* PyTorch == 1.7.1
* CUDA == 10.1

Other versions of PyTorch or CUDA may work, but I haven't tested them.

You can either build the extension manually or load it with JIT.

### Build

```bash
python setup.py install
```

### JIT

Comment out `import grid_sample1d_cuda as grid_sample1d` in `op.py` and uncomment

```python
grid_sample1d = load(
    'grid_sample1d_cuda', ['grid_sample1d_cuda.cpp', 'grid_sample1d_cuda_kernel.cu'], verbose=True)
```

in `op.py`.

## Usage

```python
import torch
from grid_sample1d import GridSample1d

grid_sample1d = GridSample1d(padding_mode=True, align_corners=True)
N = 16
C = 256
L_in = 64
L_out = 128
input = torch.randn((N, C, L_in)).cuda()
grids = torch.randn((N, L_out)).cuda()
output = grid_sample1d(input, grids)
```

Options are

* `padding_mode`: True for border padding, False for zero padding
* `align_corners`: same as `align_corners` in `torch.nn.functional.grid_sample`

## Difference

In the forward pass, the computation over the channel dimension `C` is parallel, whereas it is serial in `torch.nn.functional.grid_sample`. Parallel accumulation over `C` may cause round-off error in the backward pass, but so far I have found that it does not affect the forward pass.

## Test

### Accuracy Test

Grid sample 1d is, in most cases, a special case of grid sample 2d (this no longer holds when padding_mode and align_corners are both False), so I test the accuracy of the implementation against `torch.nn.functional.grid_sample`.

```python
import torch
import torch.nn.functional as F


def gridsample1d_by2d(input, grid, padding_mode, align_corners):
    shape = grid.shape
    input = input.unsqueeze(-1)  # batch_size * C * L_in * 1
    grid = grid.unsqueeze(1)  # batch_size * 1 * L_out
    # prepend a constant -1 coordinate for the dummy singleton dimension
    grid = torch.stack([-torch.ones_like(grid), grid], dim=-1)
    z = F.grid_sample(input, grid, padding_mode=padding_mode, align_corners=align_corners)
    C = input.shape[1]
    out_shape = [shape[0], C, shape[1]]
    z = z.view(*out_shape)  # batch_size * C * L_out
    return z
```

It is recommended to run the tests on your own machine, because I have only tested on CUDA 10.1 with a GTX 1080Ti.

```bash
python test/acc_benchmark.py
```

Both the forward and backward results are identical, except for `align_corners=True, padding_mode=False`. The discrepancy is likely caused by round-off error when a series of floating-point numbers is summed in different orders.

### Deterministic Test

A determinism test is important because the associative law no longer holds for floating-point arithmetic on computers: the same sum, accumulated in a different order, can give a different result.
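For example, merely regrouping a floating-point sum changes the result:

```python
>>> (0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3)
False
```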
```bash
python test/check_deterministic.py
```

## Note

When padding_mode and align_corners are both `False`, grid sample 1d cannot be regarded as a special case of grid sample 2d in PyTorch. I have checked the CUDA kernel of grid_sample in PyTorch: in this configuration, the output of `torch.nn.functional.grid_sample` is half of the expected value. I hope it gets fixed one day.

## CPU support

Too lazy to support it.

## Speed & memory cost

Here are the speed test results for different input sizes:

![](https://raw.githubusercontent.com/luo3300612/grid_sample1d/master/assets/test_res.png)

## References

* [grid sample pytorch](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample)
* [grid sample cuda](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu)
* [pytorch C++ doc](https://pytorch.org/cppdocs/notes/tensor_creation.html)
* [cuda doc](https://docs.nvidia.com/cuda/)
--------------------------------------------------------------------------------
/acc_benchmark.py:
--------------------------------------------------------------------------------
import torch
from grid_sample1d import GridSample1d
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import random


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


args_groups = [
    {'original': {'padding_mode': 'zeros', 'align_corners': True},
     'mine': {'padding_mode': False, 'align_corners': True}},
    {'original': {'padding_mode': 'zeros', 'align_corners': False},
     'mine': {'padding_mode': False, 'align_corners': False}},
    {'original': {'padding_mode': 'border', 'align_corners': True},
     'mine': {'padding_mode': True, 'align_corners': True}},
    {'original': {'padding_mode': 'border', 'align_corners': False},
     'mine': {'padding_mode': True, 'align_corners': False}}
]


def original(input, grid, padding_mode, align_corners):
    """Emulate grid sample 1d with the built-in 2d grid sample."""
    shape = grid.shape
    grid = grid.sin()  # batch_size * L_out
    input = input.unsqueeze(-1)  # batch_size * C * L_in * 1
    grid = grid.unsqueeze(1)  # batch_size * 1 * L_out
    # prepend a constant -1 coordinate for the dummy singleton dimension
    grid = torch.stack([-torch.ones_like(grid), grid], dim=-1)
    z = F.grid_sample(input, grid, padding_mode=padding_mode, align_corners=align_corners)
    C = input.shape[1]
    out_shape = [shape[0], C, shape[1]]
    z = z.view(*out_shape)  # batch_size * C * L_out
    return z


def mine(input, grid, module):
    shape = grid.shape
    grid = grid.sin()
    z = module(input, grid)
    C = input.shape[1]
    out_shape = [shape[0], C, shape[1]]
    z = z.view(*out_shape)
    return z


def inspect(output, output_origin, verbose_matrix=False, verbose=False, eps=1e-6, eps_r=1e-5):
    err = torch.abs(output - output_origin)

    rela_err = err / torch.abs(output_origin)
    max_err_rela = torch.max(rela_err)
    pos_rela = torch.argmax(rela_err)

    N_err = torch.sum(err > eps).item()
    N_rela_err = torch.sum(rela_err > eps_r).item()

    if max_err_rela > eps_r:
        if verbose:
            if verbose_matrix:
                print('output')
                print(output)
                print('origin')
                print(output_origin)
                print(output - output_origin)
            print('different!')
            print(f'max_err_rela={max_err_rela}')
            print(f'where origin={output_origin.view(-1)[pos_rela]}')
            print(f'mine={output.view(-1)[pos_rela]}')
            print(f'N err > eps={N_err}')
            print(f'err% = {N_rela_err / torch.numel(output) * 100:.2f}')
    return N_rela_err


if __name__ == '__main__':
    setup_seed(0)

    batch_size = 20
    C = 256
    L_in = 16
    L_out = 32

    eps = 1e-6
    eps_r = 1e-5
    N_samples = 100

    print('forward')

    for args in args_groups:
        print('testing')
        print(args)

        module = GridSample1d(**args['mine'])
        running_err_forward = 0.
        running_err_backward_input = 0.
        running_err_backward_grid = 0.
        try:
            with torch.no_grad():
                for i in tqdm(range(N_samples)):
                    input = torch.randn((batch_size, C, L_in)).cuda()
                    grid = torch.randn(batch_size, L_out).cuda()
                    output = mine(input, grid, module).cpu()
                    output_origin = original(input, grid, **args['original']).cpu()
                    try:
                        if (not args['mine']['padding_mode']) and (not args['mine']['align_corners']):
                            # the 2d emulation returns half of the expected value here (see README)
                            assert torch.allclose(output, output_origin * 2, atol=eps, rtol=eps_r)
                        else:
                            assert torch.allclose(output, output_origin, atol=eps, rtol=eps_r)
                    except AssertionError:
                        N_err = inspect(output, output_origin)
                        running_err_forward += N_err / torch.numel(output)
                        if N_err / torch.numel(output) >= 0.05:
                            raise
            print(f'Forward ACC test done on {N_samples} samples with eps={eps}')

            print('backward')
            for i in tqdm(range(N_samples)):
                setup_seed(i)
                grid_original = torch.randn((batch_size, L_out), requires_grad=True).cuda()
                input_original = torch.randn((batch_size, C, L_in), requires_grad=True).cuda()
                # .cuda() returns non-leaf tensors, so gradients must be retained explicitly
                grid_original.retain_grad()
                input_original.retain_grad()

                setup_seed(i)
                grid_mine = torch.randn((batch_size, L_out), requires_grad=True).cuda()
                input_mine = torch.randn((batch_size, C, L_in), requires_grad=True).cuda()
                grid_mine.retain_grad()
                input_mine.retain_grad()

                output_origin = original(input_original, grid_original, **args['original'])
                output = mine(input_mine, grid_mine, module)

                if (not args['mine']['padding_mode']) and (not args['mine']['align_corners']):
                    assert torch.allclose(output, output_origin * 2, atol=eps, rtol=eps_r)
                else:
                    assert torch.allclose(output, output_origin, atol=eps, rtol=eps_r)

                output_origin = torch.sum(output_origin.view(-1))
                output = torch.sum(output.view(-1))

                output.backward()
                output_origin.backward()

                grad_grid_original = grid_original.grad
                grad_input_original = input_original.grad

                grad_grid_mine = grid_mine.grad
                grad_input_mine = input_mine.grad

                try:
                    if (not args['mine']['padding_mode']) and (not args['mine']['align_corners']):
                        assert torch.allclose(2 * grad_grid_original, grad_grid_mine, atol=eps, rtol=eps_r)
                        assert torch.allclose(2 * grad_input_original, grad_input_mine, atol=eps, rtol=eps_r)
                    else:
                        assert torch.allclose(grad_grid_original, grad_grid_mine, atol=eps, rtol=eps_r)
                        assert torch.allclose(grad_input_original, grad_input_mine, atol=eps, rtol=eps_r)
                except AssertionError:
                    N_err_grid = inspect(grad_grid_mine, grad_grid_original, verbose=True)
                    N_err_input = inspect(grad_input_mine, grad_input_original, verbose=True)

                    running_err_backward_grid += N_err_grid / torch.numel(grad_grid_mine)
                    running_err_backward_input += N_err_input / torch.numel(grad_input_mine)
                    if N_err_grid / torch.numel(grad_grid_mine) >= 0.05 or N_err_input / torch.numel(
                            grad_input_mine) >= 0.05:
                        raise
            print(f'Backward ACC test done on {N_samples} samples with eps={eps}')

            print(f'running err forward: {running_err_forward * 100:.2f}%')
            print(f'running err backward input: {running_err_backward_input * 100:.2f}%')
            print(f'running err backward grid: {running_err_backward_grid * 100:.2f}%')
        except AssertionError:
            raise
    print('Done')
--------------------------------------------------------------------------------
/assets/test_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luo3300612/grid_sample1d/cda0dfa311daeb36b415c30fa397b714affb20d9/assets/test_res.png
--------------------------------------------------------------------------------
/check_deterministic.py:
--------------------------------------------------------------------------------
import torch
from grid_sample1d.op import GridSample1d
from tqdm import tqdm
from acc_benchmark import args_groups, mine, inspect, setup_seed

if __name__ == '__main__':
    setup_seed(0)

    batch_size = 20
    C = 512
    L_in = 16
    L_out = 32

    eps = 1e-6
    eps_r = 1e-5
    N_samples = 100

    print('forward')

    for args in args_groups:
        print('testing')
        print(args)
        module = GridSample1d(**args['mine'])
        running_err_forward = 0.
        running_err_backward_input = 0.
        running_err_backward_grid = 0.
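        # Determinism check strategy: keep the input and grid fixed across all
        # iterations; if the kernel is deterministic, every run must reproduce
        # bit-identical outputs (and, below, bit-identical gradients).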
        prev_output = None
        input = torch.randn((batch_size, C, L_in)).cuda()
        grid = torch.randn(batch_size, L_out).cuda()
        try:
            with torch.no_grad():
                for i in tqdm(range(N_samples)):
                    output = mine(input, grid, module).cpu()

                    if prev_output is not None:
                        try:
                            assert torch.allclose(output, prev_output, atol=eps, rtol=eps_r)
                        except AssertionError:
                            N_err = inspect(output, prev_output)
                            running_err_forward += N_err / torch.numel(output)
                            if N_err / torch.numel(output) >= 0.05:
                                raise
                    prev_output = output
            print(f'Forward Det test done on {N_samples} samples with eps={eps}')

            print('backward')
            prev_grad_grid = None
            prev_grad_input = None
            for i in tqdm(range(N_samples)):
                setup_seed(0)
                grid_mine = torch.randn((batch_size, L_out), requires_grad=True).cuda()
                input_mine = torch.randn((batch_size, C, L_in), requires_grad=True).cuda()
                grid_mine.retain_grad()
                input_mine.retain_grad()

                output = mine(input_mine, grid_mine, module)

                output = torch.sum(output.view(-1))

                output.backward()

                grad_grid_mine = grid_mine.grad
                grad_input_mine = input_mine.grad

                if prev_grad_grid is not None:
                    try:
                        assert torch.allclose(prev_grad_grid, grad_grid_mine, atol=eps, rtol=eps_r)
                        assert torch.allclose(prev_grad_input, grad_input_mine, atol=eps, rtol=eps_r)
                    except AssertionError:
                        N_err_grid = inspect(grad_grid_mine, prev_grad_grid, verbose=True)
                        N_err_input = inspect(grad_input_mine, prev_grad_input, verbose=True)

                        running_err_backward_grid += N_err_grid / torch.numel(grad_grid_mine)
                        running_err_backward_input += N_err_input / torch.numel(grad_input_mine)
                        if N_err_grid / torch.numel(grad_grid_mine) >= 0.05 or N_err_input / torch.numel(
                                grad_input_mine) >= 0.05:
                            raise
                # keep the current gradients for comparison in the next iteration
                prev_grad_grid = grad_grid_mine
                prev_grad_input = grad_input_mine
            print(f'Backward Det test done on {N_samples} samples with eps={eps}')

            print(f'running err forward: {running_err_forward * 100:.2f}%')
            print(f'running err backward input: {running_err_backward_input * 100:.2f}%')
            print(f'running err backward grid: {running_err_backward_grid * 100:.2f}%')
        except AssertionError:
            raise
    print('Done')
--------------------------------------------------------------------------------
/cost_benchmark.py:
--------------------------------------------------------------------------------
import time
import torch
from grid_sample1d import GridSample1d
from acc_benchmark import original, mine, args_groups
import json

assert torch.cuda.is_available()
cuda_device = torch.device("cuda")  # device object representing the GPU

Ns = [4, 8, 16, 32, 64, 128, 256]
Cs = [64, 128, 256, 512, 1024]
L_ins = [16, 32, 64, 128, 256, 512]
L_outs = [32, 64, 128, 256, 512]

N_fix = 32
C_fix = 256
L_in_fix = 64
L_out_fix = 128


def get_forward_backward_speed(func, inputs, N_iters):
    forward = 0
    backward = 0
    for _ in range(N_iters):
        start = time.time()
        output = func(*inputs)
        torch.cuda.synchronize()
        forward += time.time() - start

        start = time.time()
        output.sum().backward()
        torch.cuda.synchronize()
        backward += time.time() - start
    print(func)
    print('Forward: {:.3f} ms | Backward {:.3f} ms'.format(forward * 1e3 / N_iters, backward * 1e3 / N_iters))
    return forward * 1e3 / N_iters, backward * 1e3 / N_iters


if __name__ == '__main__':
    N = 4
    C = 256
    L_in = 16
    L_out = 32
    N_iters = 1000

    N_ys = [[] for _ in range(len(args_groups))]
    N_yso = [[] for _ in range(len(args_groups))]
    for N in Ns:
        print('N=', N)
        input = torch.randn((N, C_fix, L_in_fix), requires_grad=True).cuda()
        grids = torch.randn((N, L_out_fix), requires_grad=True).cuda()

        # time torch.sin as a baseline; both `mine` and `original` apply sin to
        # the grid, so its cost is subtracted from their timings below
        func = torch.sin

        sin_forward, sin_backward = get_forward_backward_speed(func, [grids], N_iters=N_iters)

        for i, args in enumerate(args_groups):
            func = mine
            module = GridSample1d(**args['mine'])
            inputs = [input, grids, module]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)

            real_forward = forward - sin_forward
            real_backward = backward - sin_backward

            N_ys[i].append((real_forward, real_backward))

            print('-' * 50)
            print('start test original')
            func = original
            inputs = [input, grids, args['original']['padding_mode'], args['original']['align_corners']]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)
            real_forward = forward - sin_forward
            real_backward = backward - sin_backward
            N_yso[i].append((real_forward, real_backward))

    C_ys = [[] for _ in range(len(args_groups))]
    C_yso = [[] for _ in range(len(args_groups))]
    for C in Cs:
        print('C=', C)
        input = torch.randn((N_fix, C, L_in_fix), requires_grad=True).cuda()
        grids = torch.randn((N_fix, L_out_fix), requires_grad=True).cuda()

        func = torch.sin

        sin_forward, sin_backward = get_forward_backward_speed(func, [grids], N_iters=N_iters)

        for i, args in enumerate(args_groups):
            func = mine
            module = GridSample1d(**args['mine'])
            inputs = [input, grids, module]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)

            real_forward = forward - sin_forward
            real_backward = backward - sin_backward

            C_ys[i].append((real_forward, real_backward))

            print('-' * 50)
            print('start test original')
            func = original
            inputs = [input, grids, args['original']['padding_mode'], args['original']['align_corners']]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)
            real_forward = forward - sin_forward
            real_backward = backward - sin_backward
            C_yso[i].append((real_forward, real_backward))

    L_in_ys = [[] for _ in range(len(args_groups))]
    L_in_yso = [[] for _ in range(len(args_groups))]
    for L_in in L_ins:
        print('L_in=', L_in)
        input = torch.randn((N_fix, C_fix, L_in), requires_grad=True).cuda()
        grids = torch.randn((N_fix, L_out_fix), requires_grad=True).cuda()

        func = torch.sin

        sin_forward, sin_backward = get_forward_backward_speed(func, [grids], N_iters=N_iters)

        for i, args in enumerate(args_groups):
            func = mine
            module = GridSample1d(**args['mine'])
            inputs = [input, grids, module]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)

            real_forward = forward - sin_forward
            real_backward = backward - sin_backward

            L_in_ys[i].append((real_forward, real_backward))

            print('-' * 50)
            print('start test original')
            func = original
            inputs = [input, grids, args['original']['padding_mode'], args['original']['align_corners']]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)
            real_forward = forward - sin_forward
            real_backward = backward - sin_backward
            L_in_yso[i].append((real_forward, real_backward))

    L_out_ys = [[] for _ in range(len(args_groups))]
    L_out_yso = [[] for _ in range(len(args_groups))]
    for L_out in L_outs:
        print('L_out=', L_out)
        input = torch.randn((N_fix, C_fix, L_in_fix), requires_grad=True).cuda()
        grids = torch.randn((N_fix, L_out), requires_grad=True).cuda()

        func = torch.sin

        sin_forward, sin_backward = get_forward_backward_speed(func, [grids], N_iters=N_iters)

        for i, args in enumerate(args_groups):
            func = mine
            module = GridSample1d(**args['mine'])
            inputs = [input, grids, module]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)

            real_forward = forward - sin_forward
            real_backward = backward - sin_backward

            L_out_ys[i].append((real_forward, real_backward))

            print('-' * 50)
            print('start test original')
            func = original
            inputs = [input, grids, args['original']['padding_mode'], args['original']['align_corners']]
            forward, backward = get_forward_backward_speed(func, inputs, N_iters=N_iters)
            real_forward = forward - sin_forward
            real_backward = backward - sin_backward
            L_out_yso[i].append((real_forward, real_backward))

    res = {
        'N': [Ns, N_ys, N_yso],
        'C': [Cs, C_ys, C_yso],
        'L_in': [L_ins, L_in_ys, L_in_yso],
        'L_out': [L_outs, L_out_ys, L_out_yso]
    }
    with open('speed_res.json', 'w') as f:
        json.dump(res, f)
--------------------------------------------------------------------------------
/grid_sample1d/__init__.py:
--------------------------------------------------------------------------------
from .op import GridSample1d
--------------------------------------------------------------------------------
/grid_sample1d/grid_sample1d_cuda.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#include <vector>

// CUDA forward declarations

torch::Tensor grid_sample1d_cuda_forward(
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners);

std::vector<torch::Tensor> grid_sample1d_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
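// On newer PyTorch releases both are superseded by TORCH_CHECK, so these
// macros would become e.g. TORCH_CHECK(x.is_cuda(), #x, " must be a CUDA tensor").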
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor grid_sample1d_forward(
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners) {
  CHECK_INPUT(input);
  CHECK_INPUT(grid);
  return grid_sample1d_cuda_forward(input, grid, padding_mode, align_corners);
}

std::vector<torch::Tensor> grid_sample1d_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(input);
  CHECK_INPUT(grid);
  return grid_sample1d_cuda_backward(grad_output, input, grid, padding_mode, align_corners);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &grid_sample1d_forward, "grid sample forward (CUDA)");
  m.def("backward", &grid_sample1d_backward, "grid sample backward (CUDA)");
}
--------------------------------------------------------------------------------
/grid_sample1d/grid_sample1d_cuda_kernel.cu:
--------------------------------------------------------------------------------
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    return ((coord + 1.f) * size - 1) / 2;
  }
}

static __forceinline__ __device__
bool within_bounds(int h, int H) {
  return h >= 0 && h < H;
}

// Clips coordinates to between 0 and clip_limit - 1
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}

// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template <typename scalar_t>
static __forceinline__ __device__
scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
  if (twice_low == twice_high) {
    return static_cast<scalar_t>(0);
  }
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = ::fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {
    return extra + min;
  } else {
    return span - extra + min;
  }
}

template <typename scalar_t>
static __forceinline__ __device__
scalar_t safe_downgrade_to_int_range(scalar_t x){
  // -100.0 does not have special meaning. This is just to make sure
  // it's not within_bounds_2d or within_bounds_3d, and does not cause
  // undefined behavior. See #35506.
  if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
    return static_cast<scalar_t>(-100.0);
  return x;
}


template <typename scalar_t>
static __forceinline__ __device__
scalar_t compute_coordinates(scalar_t coord, int size,
                             bool padding_mode,
                             bool align_corners) {
  if (padding_mode) { // True for border padding
    // clip coordinates to image borders
    coord = clip_coordinates(coord, size);
  }
  coord = safe_downgrade_to_int_range(coord);
  return coord;
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index(
    scalar_t coord,
    int size,
    bool padding_mode,
    bool align_corners) {
  coord = grid_sampler_unnormalize(coord, size, align_corners);
  coord = compute_coordinates(coord, size, padding_mode, align_corners);
  return coord;
}

template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize_set_grad(scalar_t coord, int size,
                                           bool align_corners, scalar_t *grad_in) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    *grad_in = static_cast<scalar_t>(size - 1) / 2;
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    *grad_in = static_cast<scalar_t>(size) / 2;
    return ((coord + 1.f) * size - 1) / 2;
  }
}

template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates_set_grad(scalar_t in, int clip_limit, scalar_t *grad_in) {
  // Note that it is important for the gradient calculation that borders
  // are considered out of bounds.
  if (in <= static_cast<scalar_t>(0)) {
    *grad_in = static_cast<scalar_t>(0);
    return static_cast<scalar_t>(0);
  } else {
    scalar_t max = static_cast<scalar_t>(clip_limit - 1);
    if (in >= max) {
      *grad_in = static_cast<scalar_t>(0);
      return max;
    } else {
      *grad_in = static_cast<scalar_t>(1);
      return in;
    }
  }
}

template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index_set_grad(
    scalar_t coord,
    int size,
    bool padding_mode,
    bool align_corners,
    scalar_t *grad_in) {
  scalar_t grad_clip;
  coord = grid_sampler_unnormalize_set_grad(coord, size, align_corners, grad_in);
  if (padding_mode) { // true for border padding
    // clip coordinates to image borders
    coord = clip_coordinates_set_grad(coord, size, &grad_clip);
    *grad_in = (*grad_in) * grad_clip;
  }
  coord = safe_downgrade_to_int_range(coord);
  return coord;
}

template <typename scalar_t>
__global__ void grid_sample1d_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ grid,
    scalar_t* __restrict__ output,
    bool padding_mode,
    bool align_corners,
    const int N,
    const int L_in,
    const int batch_size,
    const int C,
    const int L_out) {

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < N){
    // one thread per output element: decompose index into (n, c, l)
    const int l = index % L_out;
    const int c = (index / L_out) % C;
    const int n = index / (C * L_out);

    const int grid_offset = n * L_out + l;

    scalar_t x = grid[grid_offset];
    scalar_t ix = grid_sampler_compute_source_index(x, L_in, padding_mode, align_corners);

    const int index_left = ::floor(ix);
    const int index_right = index_left + 1;

    const int output_offset = l + c * L_out + n * C * L_out;
    // linear interpolation weights of the two neighbors
    scalar_t surface_left = index_right - ix;
    scalar_t surface_right = ix - index_left;

    const int input_left_offset = index_left + c * L_in + n * L_in * C;
    const int input_right_offset = index_right + c * L_in + n * L_in * C;
    output[output_offset] = static_cast<scalar_t>(0);
    if (within_bounds(index_left, L_in)){
      output[output_offset] += input[input_left_offset] * surface_left;
    }
    if (within_bounds(index_right, L_in)){
      output[output_offset] += input[input_right_offset] * surface_right;
    }
  }
}

template <typename scalar_t>
__global__ void grid_sample1d_cuda_backward_kernel(
    const scalar_t* __restrict__ grad_output,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ grid,
    scalar_t* __restrict__ grad_input,
    scalar_t* __restrict__ grad_grid,
    bool padding_mode,
    bool align_corners,
    const int N,
    const int L_in,
    const int batch_size,
    const int C,
    const int L_out) {

  const int index = blockIdx.x * blockDim.x + threadIdx.x;

  if (index < N){
    // one thread per grid location: decompose index into (n, l) and loop
    // over the channels serially (see note.md on determinism)
    const int l = index % L_out;
    const int n = index / L_out;

    const int grid_offset = n * L_out + l;

    scalar_t x = grid[grid_offset];
    scalar_t gix_mult;
    scalar_t ix = grid_sampler_compute_source_index_set_grad(x, L_in, padding_mode, align_corners, &gix_mult);

    const int index_left = ::floor(ix);
    const int index_right = index_left + 1;

    scalar_t surface_left = index_right - ix;
    scalar_t surface_right = ix - index_left;

    scalar_t gix = static_cast<scalar_t>(0);

    for (int c = 0; c < C; ++c) {
      const scalar_t gOut = grad_output[l + c * L_out + n * C * L_out];
      const int input_left_offset = index_left + c * L_in + n * L_in * C;
      const int input_right_offset = index_right + c * L_in + n * L_in * C;

      // d output / d input is the interpolation weight; threads with a
      // different l may hit the same input element, hence atomicAdd
      if (within_bounds(index_left, L_in)) {
        atomicAdd(grad_input + input_left_offset, surface_left * gOut);
        gix -= input[input_left_offset] * gOut;
      }
      // d output / d ix = input_right - input_left
      if (within_bounds(index_right, L_in)) {
        atomicAdd(grad_input + input_right_offset, surface_right * gOut);
        gix += input[input_right_offset] * gOut;
      }
    }

    grad_grid[grid_offset] = gix * gix_mult;
  }
}
} // namespace

torch::Tensor grid_sample1d_cuda_forward(
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners) {

  const auto batch_size = input.size(0);
  const auto C = input.size(1);
  const auto L_in = input.size(2);

  const auto L_out = grid.size(1);

  auto output = torch::zeros({batch_size, C, L_out}, input.options());

  const int threads = 1024;
  const int N = batch_size * C * L_out;
  const int blocks = (N + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sample1d_forward_cuda", ([&] {
    grid_sample1d_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        input.data<scalar_t>(),
        grid.data<scalar_t>(),
        output.data<scalar_t>(),
        padding_mode,
        align_corners,
        N,
        L_in,
        batch_size,
        C,
        L_out);
  }));

  return output;
}


std::vector<torch::Tensor> grid_sample1d_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor input,
    torch::Tensor grid,
    bool padding_mode,
    bool align_corners) {

  const auto batch_size = input.size(0);
  const auto C = input.size(1);
  const auto L_in = input.size(2);

  const auto L_out = grid.size(1);

  torch::Tensor grad_input = torch::zeros_like(input);
  torch::Tensor grad_grid = torch::zeros_like(grid);

  const int threads = 1024;
  const int N = L_out * batch_size;
  const int blocks = (N + threads - 1) / threads;

  AT_DISPATCH_FLOATING_TYPES(input.type(), "grid_sample1d_backward_cuda", ([&] {
    grid_sample1d_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
        grad_output.data<scalar_t>(),
        input.data<scalar_t>(),
        grid.data<scalar_t>(),
        grad_input.data<scalar_t>(),
        grad_grid.data<scalar_t>(),
        padding_mode,
        align_corners,
        N,
        L_in,
        batch_size,
        C,
        L_out);
  }));
  return {grad_input, grad_grid};
}
--------------------------------------------------------------------------------
/grid_sample1d/op.py:
--------------------------------------------------------------------------------
from torch import nn
from torch.autograd import Function
import torch
from torch.utils.cpp_extension import load

# build with `python setup.py install`, then use:
# import grid_sample1d_cuda as grid_sample1d

# JIT compilation
grid_sample1d = load(
    'grid_sample1d_cuda', ['grid_sample1d_cuda.cpp', 'grid_sample1d_cuda_kernel.cu'], verbose=True)


class GridSample1dFunction(Function):
    @staticmethod
    def forward(ctx, input, grid, padding_mode, align_corners):
        outputs = grid_sample1d.forward(input, grid, padding_mode, align_corners)
        ctx.save_for_backward(input, grid)
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        d_input, d_grid = grid_sample1d.backward(grad_output.contiguous(), *ctx.saved_tensors,
                                                 ctx.padding_mode, ctx.align_corners)
        # padding_mode and align_corners receive no gradient
        return d_input, d_grid, None, None


class GridSample1d(nn.Module):
    def __init__(self, padding_mode, align_corners):
        '''
        :param padding_mode: True for border padding, False for zero padding
        :param align_corners: same as in torch.nn.functional.grid_sample
        '''
        super(GridSample1d, self).__init__()
        self.padding_mode = padding_mode
        self.align_corners = align_corners

    def forward(self, input, grid):
        return GridSample1dFunction.apply(input, grid, self.padding_mode, self.align_corners)
--------------------------------------------------------------------------------
/grid_sample1d/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='grid_sample1d_cuda',
    ext_modules=[
        CUDAExtension('grid_sample1d_cuda', [
            'grid_sample1d_cuda.cpp',
            'grid_sample1d_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
--------------------------------------------------------------------------------
/note.md:
--------------------------------------------------------------------------------
# Ref

* [grid sample pytorch](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html?highlight=grid_sample#torch.nn.functional.grid_sample)
* [grid sample cuda](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/GridSampler.cu)
* [pytorch C++ doc](https://pytorch.org/cppdocs/notes/tensor_creation.html)
* [cuda doc](https://docs.nvidia.com/cuda/)

## Creating a tensor in the same place

```c++
auto output = torch::zeros({batch_size, C, L_out}, input.options());
```

This creates `output` on the same device (and with the same dtype) as `input`.

## Accessors are inefficient

The official docs say that even though accessors improve readability, they are very inefficient. The fact that PyTorch's own kernels do not use accessors also shows that they are not really practical.

## Benchmark

Average over 100 iterations:

```python
grid  # 4 * 64**3 * 256
input  # 4 * 256 * 16
```

| item                        | speed | RAM     |
|-----------------------------|-------|---------|
| lxy original                | 552ms | 10825MB |
| 2d grid sample forward      | 198ms | 9800MB  |
| lxy dog grid sample forward | 137ms | 5705MB  |

## Progress

Forward and backward are done.

However, the backward error grows as `C` grows.

After switching to serial accumulation over `C` in backward, this problem no longer occurs, except when padding_mode=zeros and align_corners=True.

I tried swapping the order of the two computation blocks for the grid gradient in backward; the result then differs from the 2d grid sample, which again demonstrates that floating-point accumulation is not associative.

Fixed the issue where the result was twice the expected value under align_corners=False with zero padding.

## An issue

PyTorch's 2d grid sample with bilinear interpolation, zero padding, and align_corners=False does not handle the case where one of H/W is 1: the final output is 0.5x the true value. The main reason is that with align_corners=False, the coordinate along the singleton dimension becomes -0.5, so one of the two neighbors required for interpolation always fails the within_bounds check, and exactly half of the value is discarded.
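A minimal repro of the halving (tensor sizes chosen arbitrarily), using the same constant -1 dummy coordinate as `gridsample1d_by2d` in the README: with align_corners=False and a singleton dimension of size 1, the dummy coordinate -1 unnormalizes to ((-1 + 1) * 1 - 1) / 2 = -0.5, its left neighbor (index -1) is zeroed out by zero padding, and only the 0.5-weighted right neighbor survives.

```python
import torch
import torch.nn.functional as F

inp = torch.ones(1, 1, 8, 1)    # N * C * H * W with a singleton W
grid = torch.zeros(1, 1, 5, 2)  # y = 0 samples the middle of H
grid[..., 0] = -1.0             # dummy x coordinate, as in gridsample1d_by2d

out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
print(out)  # all values are 0.5 although the input is all ones

out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='border', align_corners=False)
print(out)  # all ones: border padding clips -0.5 back to index 0
```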