# torch.nn.functional.grid_sample with zero padding
#
# A loop-based, pure-Python reference implementation of
# ``torch.nn.functional.grid_sample(input, grid, mode="bilinear",
# padding_mode="zeros", align_corners=...)`` for 4-D inputs. It is meant for
# reading and for cross-checking against the real operator, not for speed.
import torch


def grid_sampler_unnormalize(coord, side, align_corners):
    """Map a normalized coordinate in [-1, 1] to a pixel-space coordinate.

    With ``align_corners=True``, -1 and 1 refer to the centres of the first
    and last pixels. Otherwise they refer to the outer edges of those pixels.

    Args:
        coord: Normalized coordinate (Python number or tensor).
        side: Size of the image along the same axis, in pixels.
        align_corners: Same meaning as in ``grid_sample``.

    Returns:
        The unnormalized (pixel-space) coordinate.
    """
    if align_corners:
        return ((coord + 1) / 2) * (side - 1)
    return ((coord + 1) * side - 1) / 2


def grid_sampler_compute_source_index(coord, size, align_corners):
    """Return the source pixel coordinate for a normalized grid coordinate.

    With zero padding no clipping or reflection is applied here.
    Out-of-range samples are handled later, when ``safe_get`` returns 0.
    """
    return grid_sampler_unnormalize(coord, size, align_corners)


def safe_get(image, n, c, x, y, H, W):
    """Return ``image[n, c, y, x]``, or 0 when (x, y) lies outside the image.

    Args:
        image: Tensor of shape [N, C, H, W].
        n: Batch index.
        c: Channel index.
        x: Column index (checked against ``W``).
        y: Row index (checked against ``H``).
        H: Image height.
        W: Image width.

    Returns:
        A 0-dim tensor with the dtype and device of ``image``.
    """
    if 0 <= x < W and 0 <= y < H:
        return image[n, c, y, x]
    # Zero padding. The zero is built from ``image`` so its dtype, device and
    # (0-dim) shape match the in-bounds case.
    return image.new_zeros(())


def bilinear_interpolate_torch_2D(image, grid, align_corners=False):
    """Bilinearly sample ``image`` at the ``grid`` locations with zero padding.

    Intended to match ``torch.nn.functional.grid_sample(image, grid,
    mode="bilinear", padding_mode="zeros", align_corners=align_corners)``.

    Args:
        image: Input tensor of shape [N, C, H_in, W_in].
        grid: Sampling grid of shape [N, H_out, W_out, 2] holding normalized
            coordinates. As in ``grid_sample``, ``grid[..., 0]`` is x (along
            W_in) and ``grid[..., 1]`` is y (along H_in).
        align_corners: See ``grid_sampler_unnormalize``.

    Returns:
        Tensor of shape [N, C, H_out, W_out] with the dtype and device of
        ``image``.
    """
    N, C, H, W = image.shape
    grid_H, grid_W = grid.shape[1], grid.shape[2]

    # The output takes its spatial size from the grid, not from the input.
    output_tensor = image.new_zeros((N, C, grid_H, grid_W))
    for n in range(N):
        for h in range(grid_H):
            for w in range(grid_W):
                # x is scaled by the width and y by the height.
                ix = grid_sampler_compute_source_index(
                    grid[n, h, w, 0], W, align_corners)
                iy = grid_sampler_compute_source_index(
                    grid[n, h, w, 1], H, align_corners)

                # Integer corners around (ix, iy). They stay tensors (no
                # .type(torch.LongTensor)) so nothing is forced onto the CPU.
                x0 = torch.floor(ix).long()
                y0 = torch.floor(iy).long()
                x1 = x0 + 1
                y1 = y0 + 1

                # Bilinear weights: each corner is weighted by the distance
                # from the sample point to the opposite corner.
                dx1 = ix - x0
                dx0 = x1 - ix
                dy1 = iy - y0
                dy0 = y1 - iy
                w_nw = dx0 * dy0  # corner (x0, y0)
                w_ne = dx1 * dy0  # corner (x1, y0)
                w_sw = dx0 * dy1  # corner (x0, y1)
                w_se = dx1 * dy1  # corner (x1, y1)

                for c in range(C):
                    output_tensor[n, c, h, w] = (
                        safe_get(image, n, c, x0, y0, H, W) * w_nw
                        + safe_get(image, n, c, x1, y0, H, W) * w_ne
                        + safe_get(image, n, c, x0, y1, H, W) * w_sw
                        + safe_get(image, n, c, x1, y1, H, W) * w_se
                    )
    return output_tensor