├── .gitignore ├── Forward_Warp ├── __init__.py ├── cuda │ ├── forward_warp.h │ ├── forward_warp_cuda.cpp │ ├── forward_warp_cuda_kernel.cu │ └── setup.py ├── forward_warp.py └── python │ ├── __init__.py │ └── forward_warp_python.py ├── LICENSE ├── Readme.md ├── install.sh ├── setup.py └── test ├── flow.pkl ├── im0.png ├── im1.png └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | -------------------------------------------------------------------------------- /Forward_Warp/__init__.py: -------------------------------------------------------------------------------- 1 | from .forward_warp import forward_warp -------------------------------------------------------------------------------- /Forward_Warp/cuda/forward_warp.h: -------------------------------------------------------------------------------- 1 | #ifndef FORWARD_WARP_H 2 | #define FORWARD_WARP_H 3 | 4 | // Define GridSamplerInterpolation 5 | namespace at { namespace native { namespace detail { 6 | enum class GridSamplerInterpolation {Bilinear, Nearest}; 7 | enum class GridSamplerPadding {Zeros, Border, Reflection}; 8 | }}} 9 | 10 | // Define CUDA_NUM_THREAS and GET_BLOCKS 11 | const int CUDA_NUM_THREADS = 1024; 12 | inline int GET_BLOCKS(const int N){ 13 | return 
(N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; 14 | } 15 | 16 | // Define CUDA_KERNEL_LOOP 17 | #define CUDA_KERNEL_LOOP(i, n) \ 18 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) 19 | 20 | #endif 21 | -------------------------------------------------------------------------------- /Forward_Warp/cuda/forward_warp_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "forward_warp.h" 5 | using at::native::detail::GridSamplerInterpolation; 6 | 7 | at::Tensor forward_warp_cuda_forward( 8 | const at::Tensor im0, 9 | const at::Tensor flow, 10 | const GridSamplerInterpolation interpolation_mode); 11 | std::vector forward_warp_cuda_backward( 12 | const at::Tensor grad_output, 13 | const at::Tensor im0, 14 | const at::Tensor flow, 15 | const GridSamplerInterpolation interpolation_mode); 16 | 17 | // Because of the incompatible of Pytorch 1.0 && Pytorch 0.4, we have to annotation this. 
18 | #define CHECK_CUDA(x) AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor") 19 | #define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous") 20 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 21 | 22 | at::Tensor forward_warp_forward( 23 | const at::Tensor im0, 24 | const at::Tensor flow, 25 | const int interpolation_mode){ 26 | // CHECK_INPUT(im0); 27 | // CHECK_INPUT(flow); 28 | return forward_warp_cuda_forward(im0, flow, (GridSamplerInterpolation)interpolation_mode); 29 | } 30 | 31 | std::vector forward_warp_backward( 32 | const at::Tensor grad_output, 33 | const at::Tensor im0, 34 | const at::Tensor flow, 35 | const int interpolation_mode){ 36 | // CHECK_INPUT(grad_output); 37 | // CHECK_INPUT(im0); 38 | // CHECK_INPUT(flow); 39 | return forward_warp_cuda_backward(grad_output, im0, flow, (GridSamplerInterpolation)interpolation_mode); 40 | } 41 | 42 | PYBIND11_MODULE( 43 | TORCH_EXTENSION_NAME, 44 | m){ 45 | m.def("forward", &forward_warp_forward, "forward warp forward (CUDA)"); 46 | m.def("backward", &forward_warp_backward, "forward warp backward (CUDA)"); 47 | } 48 | -------------------------------------------------------------------------------- /Forward_Warp/cuda/forward_warp_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | #include "forward_warp.h" 10 | using at::native::detail::GridSamplerInterpolation; 11 | 12 | static __forceinline__ __device__ 13 | int get_im_index( 14 | const int b, 15 | const int c, 16 | const int h, 17 | const int w, 18 | const size_t C, 19 | const size_t H, 20 | const size_t W) { 21 | return b*C*H*W + c*H*W + h*W + w; 22 | } 23 | 24 | template 25 | __global__ void forward_warp_cuda_forward_kernel( 26 | const int total_step, 27 | const scalar_t* im0, 28 | const scalar_t* flow, 29 | scalar_t* im1, 30 | const int B, 31 | const int C, 32 | const int H, 33 | const int W, 34 
| const GridSamplerInterpolation interpolation_mode) { 35 | // CUDA_KERNEL_LOOP(index, total_step-1) { 36 | // bug fix, thx to @tkkcc 37 | CUDA_KERNEL_LOOP(index, total_step) { 38 | const int b = index / (H * W); 39 | const int h = (index-b*H*W) / W; 40 | const int w = index % W; 41 | const scalar_t x = (scalar_t)w + flow[index*2+0]; 42 | const scalar_t y = (scalar_t)h + flow[index*2+1]; 43 | if (interpolation_mode == GridSamplerInterpolation::Bilinear) { 44 | const int x_f = static_cast(::floor(x)); 45 | const int y_f = static_cast(::floor(y)); 46 | const int x_c = x_f + 1; 47 | const int y_c = y_f + 1; 48 | if(x_f>=0 && x_c=0 && y_c(::round(x)); 65 | const int y_nearest = static_cast(::round(y)); 66 | if(x_nearest>=0 && x_nearest=0 && y_nearest 78 | __global__ void forward_warp_cuda_backward_kernel( 79 | const int total_step, 80 | const scalar_t* grad_output, 81 | const scalar_t* im0, 82 | const scalar_t* flow, 83 | scalar_t* im0_grad, 84 | scalar_t* flow_grad, 85 | const int B, 86 | const int C, 87 | const int H, 88 | const int W, 89 | const GridSamplerInterpolation interpolation_mode) { 90 | CUDA_KERNEL_LOOP(index, total_step) { 91 | const int b = index / (H * W); 92 | const int h = (index-b*H*W) / W; 93 | const int w = index % W; 94 | const scalar_t x = (scalar_t)w + flow[index*2+0]; 95 | const scalar_t y = (scalar_t)h + flow[index*2+1]; 96 | if (interpolation_mode == GridSamplerInterpolation::Bilinear) { 97 | const int x_f = static_cast(::floor(x)); 98 | const int y_f = static_cast(::floor(y)); 99 | const int x_c = x_f + 1; 100 | const int y_c = y_f + 1; 101 | if(x_f>=0 && x_c=0 && y_c(::round(x)); 134 | const int y_nearest = static_cast(::round(y)); 135 | if(x_nearest>=0 && x_nearest=0 && y_nearest 158 | <<>>( 159 | total_step, 160 | im0.data(), 161 | flow.data(), 162 | im1.data(), 163 | B, C, H, W, 164 | interpolation_mode); 165 | })); 166 | 167 | return im1; 168 | } 169 | 170 | std::vector forward_warp_cuda_backward( 171 | const at::Tensor grad_output, 172 
import torch
from torch.nn import Module, Parameter
from torch.autograd import Function


class Forward_Warp_Python:
    """Pure-PyTorch reference implementation of forward warping.

    Every source pixel ``im0[b, :, h, w]`` is pushed to the target position
    ``(x, y) = (w + flow[b, h, w, 0], h + flow[b, h, w, 1])`` of the output.
    The implementation loops over pixels in Python, so it is slow; it serves
    as the non-CUDA fallback of ``forward_warp_function``.

    ``interpolation_mode == 0`` selects bilinear splatting, any other value
    selects nearest-neighbour.
    """

    @staticmethod
    def forward(im0, flow, interpolation_mode):
        """Warp ``im0`` ([B, C, H, W]) with ``flow`` ([B, H, W, 2]).

        Returns a tensor shaped like ``im0``. Target positions that receive
        no source pixel stay zero.
        """
        im1 = torch.zeros_like(im0)
        B, _, H, W = im0.shape
        if interpolation_mode == 0:
            for b in range(B):
                for h in range(H):
                    for w in range(W):
                        x = w + flow[b, h, w, 0]
                        y = h + flow[b, h, w, 1]
                        x_f = int(torch.floor(x))
                        y_f = int(torch.floor(y))
                        x_c = x_f + 1
                        y_c = y_f + 1
                        # A pixel is dropped unless its whole 2x2 footprint
                        # lies inside the image.
                        if x_f < 0 or x_c >= W or y_f < 0 or y_c >= H:
                            continue
                        p = im0[b, :, h, w]
                        # Bilinear weights: each corner gets the area of the
                        # opposite sub-rectangle.
                        im1[b, :, y_f, x_f] += (x_c - x) * (y_c - y) * p
                        im1[b, :, y_f, x_c] += (x - x_f) * (y_c - y) * p
                        im1[b, :, y_c, x_f] += (x_c - x) * (y - y_f) * p
                        im1[b, :, y_c, x_c] += (x - x_f) * (y - y_f) * p
        else:
            round_flow = torch.round(flow)
            for b in range(B):
                for h in range(H):
                    for w in range(W):
                        x = w + int(round_flow[b, h, w, 0])
                        y = h + int(round_flow[b, h, w, 1])
                        if 0 <= x < W and 0 <= y < H:
                            # Plain assignment: when several source pixels
                            # land on one target, the last one visited wins.
                            im1[b, :, y, x] = im0[b, :, h, w]
        return im1

    @staticmethod
    def backward(grad_output, im0, flow, interpolation_mode):
        """Return ``(im0_grad, flow_grad)`` for the matching ``forward`` call.

        ``im0_grad`` is shaped like ``grad_output`` and ``flow_grad`` like
        ``flow``. Both start at zero, so source pixels that were dropped in
        ``forward`` get a zero gradient. In nearest mode the flow gradient is
        zero everywhere, because rounding has no useful derivative.
        """
        B, _, H, W = grad_output.shape
        im0_grad = torch.zeros_like(grad_output)
        # zeros_like(flow) keeps flow's dtype/device and, unlike torch.empty,
        # never leaks uninitialised memory into the gradient.
        flow_grad = torch.zeros_like(flow)
        if interpolation_mode == 0:
            for b in range(B):
                for h in range(H):
                    for w in range(W):
                        x = w + flow[b, h, w, 0]
                        y = h + flow[b, h, w, 1]
                        x_f = int(torch.floor(x))
                        y_f = int(torch.floor(y))
                        x_c = x_f + 1
                        y_c = y_f + 1
                        if x_f < 0 or x_c >= W or y_f < 0 or y_c >= H:
                            continue
                        p = im0[b, :, h, w]
                        nw_grad = grad_output[b, :, y_f, x_f]
                        ne_grad = grad_output[b, :, y_f, x_c]
                        sw_grad = grad_output[b, :, y_c, x_f]
                        se_grad = grad_output[b, :, y_c, x_c]
                        # Distances of (x, y) to the cell borders.
                        dx_f = x - x_f
                        dx_c = x_c - x
                        dy_f = y - y_f
                        dy_c = y_c - y
                        # d(out)/d(im0): the four splatting weights.
                        im0_grad[b, :, h, w] = (
                            dx_c * dy_c * nw_grad
                            + dx_f * dy_c * ne_grad
                            + dx_c * dy_f * sw_grad
                            + dx_f * dy_f * se_grad
                        )
                        # d(out)/d(flow): derivative of the weights w.r.t.
                        # x and y, summed over the channels.
                        flow_grad[b, h, w, 0] = torch.sum(
                            p * (dy_c * (ne_grad - nw_grad)
                                 + dy_f * (se_grad - sw_grad)))
                        flow_grad[b, h, w, 1] = torch.sum(
                            p * (dx_c * (sw_grad - nw_grad)
                                 + dx_f * (se_grad - ne_grad)))
        else:
            round_flow = torch.round(flow)
            for b in range(B):
                for h in range(H):
                    for w in range(W):
                        x = w + int(round_flow[b, h, w, 0])
                        y = h + int(round_flow[b, h, w, 1])
                        if 0 <= x < W and 0 <= y < H:
                            im0_grad[b, :, h, w] = grad_output[b, :, y, x]
        return im0_grad, flow_grad


class forward_warp_function(Function):
    """Autograd function that forward-warps an image with an optical flow.

    CUDA tensors are handled by the compiled ``forward_warp_cuda`` extension,
    everything else by ``Forward_Warp_Python``. The extension is imported
    only when a CUDA tensor is seen, so the CPU path works without it.
    """

    @staticmethod
    def forward(ctx, im0, flow, interpolation_mode):
        '''
        im0: the first image with shape [B, C, H, W]
        flow: the optical flow in pixels with shape [B, H, W, 2];
              flow[..., 0] is the displacement along the width and
              flow[..., 1] along the height (unlike grid_sample it is not
              normalised to [-1, 1])
        interpolation_mode: 0 is Bilinear, 1 is Nearest
        '''
        assert len(im0.shape) == len(flow.shape) == 4, \
            "im0 and flow must both be 4-dimensional"
        assert interpolation_mode in (0, 1), \
            "interpolation_mode must be 0 (Bilinear) or 1 (Nearest)"
        assert im0.shape[0] == flow.shape[0], \
            "im0 and flow must have the same batch size"
        assert im0.shape[-2:] == flow.shape[1:3], \
            "im0 and flow must have the same spatial size"
        assert flow.shape[3] == 2, "the last dimension of flow must be 2"

        ctx.interpolation_mode = interpolation_mode
        ctx.save_for_backward(im0, flow)
        if im0.is_cuda:
            import forward_warp_cuda
            # The kernels index raw pointers assuming a dense layout, so
            # hand them contiguous tensors.
            im1 = forward_warp_cuda.forward(
                im0.contiguous(), flow.contiguous(), interpolation_mode)
        else:
            im1 = Forward_Warp_Python.forward(im0, flow, interpolation_mode)

        return im1

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(im0, flow, interpolation_mode)``."""
        # saved_tensors replaces the deprecated saved_variables accessor.
        im0, flow = ctx.saved_tensors
        if grad_output.is_cuda:
            import forward_warp_cuda
            # grad_output is often an expanded (stride-0) view, e.g. after
            # sum().backward(); the kernel needs real dense memory.
            im0_grad, flow_grad = forward_warp_cuda.backward(
                grad_output.contiguous(), im0.contiguous(),
                flow.contiguous(), ctx.interpolation_mode)
        else:
            im0_grad, flow_grad = Forward_Warp_Python.backward(
                grad_output, im0, flow, ctx.interpolation_mode)
        # interpolation_mode is not a tensor, hence no gradient for it.
        return im0_grad, flow_grad, None


class forward_warp(Module):
    """``torch.nn.Module`` wrapper around ``forward_warp_function``."""

    def __init__(self, interpolation_mode="Bilinear"):
        '''
        Support interpolation mode with Bilinear and Nearest.
        '''
        super(forward_warp, self).__init__()
        assert interpolation_mode in ("Bilinear", "Nearest")
        # Compare by value: `is` only worked for interned string literals
        # and silently selected Nearest for any other equal string.
        if interpolation_mode == "Bilinear":
            self.interpolation_mode = 0
        else:
            self.interpolation_mode = 1

    def forward(self, im0, flow):
        """Warp ``im0`` ([B, C, H, W]) with ``flow`` ([B, H, W, 2])."""
        return forward_warp_function.apply(im0, flow, self.interpolation_mode)
22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | ## Forward Warp PyTorch Version 2 | 3 | Tested with pytorch=0.4.0, python=3.6, CUDA=9.0 4 | 5 | ### Install 6 | 7 | ```bash 8 | export CUDA_HOME=/usr/local/cuda # use your own CUDA path instead 9 | chmod a+x install.sh 10 | ./install.sh 11 | ``` 12 | 13 | ### Test 14 | 15 | ```bash 16 | cd test 17 | python test.py 18 | ``` 19 | 20 | ### Usage 21 | 22 | ```python 23 | from Forward_Warp import forward_warp 24 | 25 | # default interpolation mode is Bilinear 26 | fw = forward_warp() 27 | im2_bilinear = fw(im0, flow) 28 | # use interpolation mode Nearest 29 | # Note: in Nearest mode the gradient with respect to the input flow is zero in the backward pass. 30 | fw = forward_warp(interpolation_mode="Nearest") 31 | im2_nearest = fw(im0, flow) 32 | ``` 33 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | work_path=$(dirname $(readlink -f $0)) 3 | cd ${work_path}/Forward_Warp/cuda/ 4 | conda activate pytorch 5 | python setup.py install | grep "error" 6 | cd ../../ 7 | python setup.py install 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='Forward_Warp', 5 | version='0.0.1', 6 | packages=find_packages(), 7 | author = "lizhihao6", 8 | author_email = "lizhihao6@outlook.com", 9 | ) -------------------------------------------------------------------------------- /test/flow.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lizhihao6/Forward-Warp/91877f3dbb119a45b3008d720358e87ad99844ad/test/flow.pkl 
def _load_image(path):
    """Read ``path`` with OpenCV and return a float tensor of shape [1, C, H, W]."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("cannot read image: {}".format(path))
    image = image[np.newaxis, :, :, :]
    return torch.FloatTensor(image).permute(0, 3, 1, 2).contiguous()


def _save_image(path, image):
    """Write the first sample of a [B, C, H, W] tensor as an 8-bit image."""
    hwc = image.permute(0, 2, 3, 1)[0]
    cv2.imwrite(path, hwc.cpu().numpy().astype(np.uint8))


def main():
    """Warp im0.png with flow.pkl on CPU (and GPU if present) and compare to im1.png.

    Prints the forward time and the MSE against im1.png for each backend and
    writes the warped images to im1_python.png / im1_cuda.png.
    """
    im0 = _load_image("im0.png")
    im1 = _load_image("im1.png")
    # Read-only mode is enough. NOTE: pickle must only be used on trusted
    # files such as this fixture shipped with the repository.
    with open("flow.pkl", "rb") as f:
        flow = pickle.load(f)
    flow = torch.FloatTensor(flow)

    fw = forward_warp()

    since = time.time()
    im1_python = fw(im0, flow)
    print("python version forward cost time: {}".format(time.time()-since))

    im1_cuda = None
    if torch.cuda.is_available():
        im0 = im0.cuda()
        flow = flow.cuda()
        # Kernel launches are asynchronous; synchronise so the measured
        # interval covers the actual computation.
        torch.cuda.synchronize()
        since = time.time()
        im1_cuda = fw(im0, flow)
        torch.cuda.synchronize()
        print("cuda version forward cost time: {}".format(time.time()-since))
    else:
        print("CUDA is not available, skipping the cuda version")

    loss_fn = torch.nn.MSELoss()
    python_loss = loss_fn(im1_python, im1)
    print("python loss: {}".format(python_loss))
    if im1_cuda is not None:
        cuda_loss = loss_fn(im1_cuda, im1.cuda())
        print("cuda loss: {}".format(cuda_loss))

    _save_image("im1_python.png", im1_python)
    if im1_cuda is not None:
        _save_image("im1_cuda.png", im1_cuda)


if __name__ == "__main__":
    main()