├── .gitignore
├── LICENSE
├── environment.yml
├── gradient_visualizer
│   └── gradient_visualizer.py
├── linearized sampler tutorial.ipynb
├── notebook_data
│   └── cute.jpg
├── readme.md
├── utils
│   └── utils.py
└── warp
    ├── linearized.py
    ├── perturbation_helper.py
    └── sampling_helper.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.vscode
.idea

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Yi Research Group @ The University of British Columbia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: linearized
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - attrs=19.1.0=py37_1
  - backcall=0.1.0=py37_0
  - blas=1.0=mkl
  - bleach=3.1.0=py37_0
  - ca-certificates=2019.5.15=1
  - certifi=2019.6.16=py37_1
  - cffi=1.12.3=py37h2e261b9_0
  - cudatoolkit=10.0.130=0
  - cycler=0.10.0=py37_0
  - dbus=1.13.6=h746ee38_0
  - decorator=4.4.0=py37_1
  - defusedxml=0.6.0=py_0
  - entrypoints=0.3=py37_0
  - expat=2.2.6=he6710b0_0
  - fontconfig=2.13.0=h9420a91_0
  - freetype=2.9.1=h8a8886c_1
  - glib=2.56.2=hd408876_0
  - gmp=6.1.2=h6c8ec71_1
  - gst-plugins-base=1.14.0=hbbd80ab_1
  - gstreamer=1.14.0=hb453b48_1
  - icu=58.2=h9c2bf20_1
  - imageio=2.5.0=py37_0
  - intel-openmp=2019.4=243
  - ipykernel=5.1.2=py37h39e3cac_0
  - ipython=7.8.0=py37h39e3cac_0
  - ipython_genutils=0.2.0=py37_0
  - ipywidgets=7.5.1=py_0
  - jedi=0.15.1=py37_0
  - jinja2=2.10.1=py37_0
  - jpeg=9b=h024ee3a_2
  - jsonschema=3.0.2=py37_0
  - jupyter=1.0.0=py37_7
  - jupyter_client=5.3.1=py_0
  - jupyter_console=6.0.0=py37_0
  - jupyter_core=4.5.0=py_0
  - kiwisolver=1.1.0=py37he6710b0_0
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.16=h1bed415_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.0.10=h2733197_2
  - libuuid=1.0.3=h1bed415_2
  - libxcb=1.13=h1bed415_1
  - libxml2=2.9.9=hea5a465_1
  - markupsafe=1.1.1=py37h7b6447c_0
  - matplotlib=3.1.1=py37h5429711_0
  - mistune=0.8.4=py37h7b6447c_0
  - mkl=2019.4=243
  - mkl-service=2.3.0=py37he904b0f_0
  - mkl_fft=1.0.14=py37ha843d7b_0
  - mkl_random=1.0.2=py37hd81dba3_0
  - nbconvert=5.5.0=py_0
  - nbformat=4.4.0=py37_0
  - ncurses=6.1=he6710b0_1
  - ninja=1.9.0=py37hfd86e86_0
  - notebook=6.0.0=py37_0
  - numpy=1.16.4=py37h7e9f1db_0
  - numpy-base=1.16.4=py37hde5b4d6_0
  - olefile=0.46=py37_0
  - openssl=1.1.1c=h7b6447c_1
  - pandoc=2.2.3.2=0
  - pandocfilters=1.4.2=py37_1
  - parso=0.5.1=py_0
  - pcre=8.43=he6710b0_0
  - pexpect=4.7.0=py37_0
  - pickleshare=0.7.5=py37_0
  - pillow=6.1.0=py37h34e0f95_0
  - pip=19.2.2=py37_0
  - prometheus_client=0.7.1=py_0
  - prompt_toolkit=2.0.9=py37_0
  - ptyprocess=0.6.0=py37_0
  - pycparser=2.19=py37_0
  - pygments=2.4.2=py_0
  - pyparsing=2.4.2=py_0
  - pyqt=5.9.2=py37h05f1152_2
  - pyrsistent=0.14.11=py37h7b6447c_0
  - python=3.7.4=h265db76_1
  - python-dateutil=2.8.0=py37_0
  - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0
  - pytz=2019.2=py_0
  - pyzmq=18.1.0=py37he6710b0_0
  - qt=5.9.7=h5867ecd_1
  - qtconsole=4.5.4=py_0
  - readline=7.0=h7b6447c_5
  - send2trash=1.5.0=py37_0
  - setuptools=41.0.1=py37_0
  - sip=4.19.8=py37hf484d3e_0
  - six=1.12.0=py37_0
  - sqlite=3.29.0=h7b6447c_0
  - terminado=0.8.2=py37_0
  - testpath=0.4.2=py37_0
  - tk=8.6.8=hbc83047_0
  - torchvision=0.4.0=py37_cu100
  - tornado=6.0.3=py37h7b6447c_0
  - traitlets=4.3.2=py37_0
  - wcwidth=0.1.7=py37_0
  - webencodings=0.5.1=py37_1
  - wheel=0.33.4=py37_0
  - widgetsnbextension=3.5.1=py37_0
  - xz=5.2.4=h14c3975_4
  - zeromq=4.3.1=he6710b0_3
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0
prefix: /home/jiangwei/miniconda3/envs/linearized

--------------------------------------------------------------------------------
/gradient_visualizer/gradient_visualizer.py:
--------------------------------------------------------------------------------
import colorsys

import numpy as np
import torch
import matplotlib.pyplot as plt

from utils import utils
from warp import perturbation_helper, sampling_helper


class GradientVisualizer(object):
    def __init__(self, opt):
        self.opt = opt
        self.warper = sampling_helper.DifferentiableImageSampler('bilinear', 'zeros')

    def build_criterion(self):
        if self.opt.optim_criterion == 'l1loss':
            criterion = torch.nn.L1Loss()
        elif self.opt.optim_criterion == 'mse':
            criterion = torch.nn.MSELoss()
        else:
            raise ValueError('unknown optimization criterion: {0}'.format(self.opt.optim_criterion))
        return criterion

    def build_gd_optimizer(self, params):
        optim_list = [{"params": params, "lr": self.opt.optim_lr}]
        optimizer = torch.optim.SGD(optim_list)
        return optimizer

    def create_translation_grid(self, resolution=None):
        if resolution is None:
            resolution = self.opt.grid_size
        results = []
        x_steps = torch.linspace(-1, 1, steps=resolution)
        y_steps = torch.linspace(-1, 1, steps=resolution)
        for x in x_steps:
            for y in y_steps:
                translation_vec = torch.stack([x, y], dim=0)[None]
                results.append(translation_vec)
        return results

    def get_next_translation_vec(self, data_pack, image_warper):
        translation_vec = data_pack['translation_vec'].clone().detach().requires_grad_(True)
        translation_mat = perturbation_helper.vec2mat_for_translation(translation_vec)
        orig_image = data_pack['original_image']
        criterion = self.build_criterion()
        optimizer = self.build_gd_optimizer(params=translation_vec)
        ident_mat = perturbation_helper.gen_identity_mat(1)
        down_sampled_orig_image = self.warper.warp_image(orig_image, ident_mat, self.opt.out_shape).detach()
        warped_image = image_warper.warp_image(orig_image, translation_mat, self.opt.out_shape)
        loss = criterion(warped_image, down_sampled_orig_image)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return translation_vec

    def get_gradient_over_translation_vec(self, data_pack, image_warper):
        translation_vec = data_pack['translation_vec']
        next_translation_vec = self.get_next_translation_vec(data_pack, image_warper)
        moving_dir = next_translation_vec - translation_vec
        return moving_dir

    def get_gradient_grid(self, orig_image, image_warper):
        gradient_grid = []
        translation_grid = self.create_translation_grid()
        for translation_vec in translation_grid:
            data_pack = {}
            data_pack['translation_vec'] = translation_vec
            data_pack['original_image'] = orig_image
            cur_gradient = self.get_gradient_over_translation_vec(data_pack, image_warper)
            gradient_pack = {'translation_vec': translation_vec, 'gradient': cur_gradient}
            gradient_grid.append(gradient_pack)
        return gradient_grid
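
    # One gradient-descent step is taken from every location of the translation grid
    # below; the arrow drawn at each location shows the resulting update direction,
    # and its color encodes the angle between that direction and the ground-truth
    # direction toward zero translation (green = aligned, red = opposite).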
    def draw_gradient_grid(self, orig_image, image_warper):
        gradient_grid = self.get_gradient_grid(orig_image, image_warper)

        _, ax = plt.subplots()
        ax.axis('equal')
        ax.set_xlim([-1.5, 1.5])
        ax.set_ylim([-1.5, 1.5])
        orig_image_show = utils.torch_img_to_np_img(orig_image)[0]
        ax.imshow(orig_image_show, extent=[-1, 1, -1, 1])

        for gradient in gradient_grid:
            orig_point = np.zeros([2], dtype=np.float32)
            base_loc = 0 - (gradient['translation_vec'])[0].data.cpu().numpy()
            gradient_dir = (gradient['gradient'])[0].data.cpu().numpy()
            gradient_dir = 0 - utils.unit_vector(gradient_dir)
            gt_dir = orig_point - base_loc
            gt_dir = utils.unit_vector(gt_dir)

            angle = utils.angle_between(gradient_dir, gt_dir)
            try:
                cur_color = self.angle_to_color(angle)
            except ValueError:
                cur_color = [0., 0., 0.]
            gradient_dir = gradient_dir / 10
            ax.arrow(base_loc[0], base_loc[1], gradient_dir[0], gradient_dir[1], head_width=0.05, head_length=0.1, color=cur_color)
        plt.show()

    def angle_to_color(self, angle):
        red_hue, _, _ = colorsys.rgb_to_hsv(1, 0, 0)
        green_hue, _, _ = colorsys.rgb_to_hsv(0, 1, 0)
        cur_hue = np.interp(angle, (0, np.pi), (green_hue, red_hue))
        cur_color = colorsys.hsv_to_rgb(cur_hue, 1, 1)
        return cur_color

--------------------------------------------------------------------------------
/notebook_data/cute.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vcg-uvic/linearized_multisampling_release/329c3c62655583882c6f586c436ae79da7e4fab0/notebook_data/cute.jpg

--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# Linearized Multi-Sampling for Differentiable Image Transformation (ICCV 2019)

This repository is a reference implementation for "Linearized Multi-Sampling for Differentiable Image Transformation", ICCV 2019. If you use this code in your research, please cite the paper.

[ArXiv](https://arxiv.org/abs/1901.07124)

### Installation

This implementation is based on Python 3 and PyTorch.

Create the environment with: ```conda env create -f environment.yml```

Activate it with: ```conda activate linearized```

### Tutorial

A tutorial is provided in `linearized sampler tutorial.ipynb`. We built the method to have the same function prototype as `torch.nn.functional.grid_sample`, so you can replace bilinear sampling with linearized multi-sampling with minimal modification.

### Direct plug-in

Copy `./warp/linearized.py` into your project folder, and replace `torch.nn.functional.grid_sample` in your code with `linearized.grid_sample`, as shown in the sketch below.

We made `linearized.py` depend only on PyTorch, so a few extra utility methods live in that file. You can move those utility methods elsewhere to keep things cleaner.

### Notes

If you find linearized multi-sampling useful in your project, please feel free to let us know by opening an issue on this repository or sending an email to jiangwei1993@gmail.com.
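
For reference, here is a minimal sketch of the swap (the image and grid below are dummy data for illustration; `'linearized'` is the only new mode, and all other modes fall through to `torch.nn.functional.grid_sample`):

```python
import torch

from warp import linearized

# dummy 1 x 3 x 32 x 32 image and an identity sampling grid (illustration only)
image = torch.rand(1, 3, 32, 32)
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])  # identity affine transform
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 32, 32))

# drop-in replacement: same call as grid_sample, plus the extra 'linearized' mode
warped = linearized.grid_sample(image, grid, mode='linearized', padding_mode='zeros')
```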

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def embed_breakpoint(terminate=True):
    embedding = ('import IPython\n'
                 'import matplotlib.pyplot as plt\n'
                 'IPython.embed()\n'
                 )
    if terminate:
        embedding += (
            'assert 0, \'force termination\'\n'
        )

    return embedding


def torch_img_to_np_img(torch_img):
    '''convert a torch image to a matplotlib-compatible numpy image
    torch uses Channels x Height x Width
    numpy uses Height x Width x Channels
    Arguments:
        torch_img {torch.Tensor} -- the image tensor to convert
    '''
    assert isinstance(torch_img, torch.Tensor), 'cannot process data type: {0}'.format(type(torch_img))
    if len(torch_img.shape) == 4 and (torch_img.shape[1] == 3 or torch_img.shape[1] == 1):
        return np.transpose(torch_img.detach().cpu().numpy(), (0, 2, 3, 1))
    if len(torch_img.shape) == 3 and (torch_img.shape[0] == 3 or torch_img.shape[0] == 1):
        return np.transpose(torch_img.detach().cpu().numpy(), (1, 2, 0))
    elif len(torch_img.shape) == 2:
        return torch_img.detach().cpu().numpy()
    else:
        raise ValueError('cannot process this image')


def np_img_to_torch_img(np_img):
    """convert a numpy image to a torch image
    numpy uses Height x Width x Channels
    torch uses Channels x Height x Width

    Arguments:
        np_img {np.ndarray} -- the image array to convert
    """
    assert isinstance(np_img, np.ndarray), 'cannot process data type: {0}'.format(type(np_img))
    if len(np_img.shape) == 4 and (np_img.shape[3] == 3 or np_img.shape[3] == 1):
        return torch.from_numpy(np.transpose(np_img, (0, 3, 1, 2)))
    if len(np_img.shape) == 3 and (np_img.shape[2] == 3 or np_img.shape[2] == 1):
        return torch.from_numpy(np.transpose(np_img, (2, 0, 1)))
    elif len(np_img.shape) == 2:
        return torch.from_numpy(np_img)
    else:
        raise ValueError('cannot process this image with shape: {0}'.format(np_img.shape))


def to_torch(np_array):
    '''convert a numpy array to a torch tensor (referenced by warp/perturbation_helper.py)'''
    assert isinstance(np_array, np.ndarray), 'cannot process data type: {0}'.format(type(np_array))
    return torch.from_numpy(np_array)


def unit_vector(vector):
    """ Returns the unit vector of the vector. """
    return vector / np.linalg.norm(vector)


def angle_between(v1, v2):
    """ Returns the angle in radians between vectors 'v1' and 'v2'::

            >>> angle_between((1, 0, 0), (0, 1, 0))
            1.5707963267948966
            >>> angle_between((1, 0, 0), (1, 0, 0))
            0.0
            >>> angle_between((1, 0, 0), (-1, 0, 0))
            3.141592653589793
    """
    v1_u = unit_vector(v1)
    v2_u = unit_vector(v2)
    return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))

--------------------------------------------------------------------------------
/warp/linearized.py:
--------------------------------------------------------------------------------
'''
Linearized multi-sampling, core part.
All methods are encapsulated in the class LinearizedMutilSampler.
Hyperparameters are stored as static variables.
The main entry point is linearized_grid_sample.
'''

import torch


######### Utils to minimize dependencies #########
# Move these utils to another file if you want
def print_notification(content_list, notification_type='NOTIFICATION'):
    print('---------------------- {0} ----------------------'.format(notification_type))
    print()
    for content in content_list:
        print(content)
    print()
    print('----------------------------------------------------')


def is_nan(x):
    '''
    get a mask of the NaN values.
    :param x: torch or numpy var.
    :return: an N-D array of bool. True -> nan, False -> ok.
    '''
    return x != x


def has_nan(x) -> bool:
    '''
    check whether x contains NaN.
    :param x: torch or numpy var.
    :return: single bool, True -> x contains nan, False -> ok.
    '''
    return is_nan(x).any()


def mat_3x3_inv(mat):
    '''
    calculate the inverse of a 3x3 matrix, supports batching.
    :param mat: torch.Tensor -- [input matrix, shape: (B, 3, 3)]
    :return: mat_inv: torch.Tensor -- [inverse matrix, shape: (B, 3, 3)]
    '''
    if len(mat.shape) < 3:
        mat = mat[None]
    assert mat.shape[1:] == (3, 3)

    # Divide the matrix by its maximum element for numerical stability
    max_vals = mat.max(1)[0].max(1)[0].view((-1, 1, 1))
    mat = mat / max_vals

    det = mat_3x3_det(mat)
    inv_det = 1.0 / det

    # adjugate matrix divided by the determinant
    mat_inv = torch.zeros(mat.shape, device=mat.device)
    mat_inv[:, 0, 0] = (mat[:, 1, 1] * mat[:, 2, 2] - mat[:, 2, 1] * mat[:, 1, 2]) * inv_det
    mat_inv[:, 0, 1] = (mat[:, 0, 2] * mat[:, 2, 1] - mat[:, 0, 1] * mat[:, 2, 2]) * inv_det
    mat_inv[:, 0, 2] = (mat[:, 0, 1] * mat[:, 1, 2] - mat[:, 0, 2] * mat[:, 1, 1]) * inv_det
    mat_inv[:, 1, 0] = (mat[:, 1, 2] * mat[:, 2, 0] - mat[:, 1, 0] * mat[:, 2, 2]) * inv_det
    mat_inv[:, 1, 1] = (mat[:, 0, 0] * mat[:, 2, 2] - mat[:, 0, 2] * mat[:, 2, 0]) * inv_det
    mat_inv[:, 1, 2] = (mat[:, 1, 0] * mat[:, 0, 2] - mat[:, 0, 0] * mat[:, 1, 2]) * inv_det
    mat_inv[:, 2, 0] = (mat[:, 1, 0] * mat[:, 2, 1] - mat[:, 2, 0] * mat[:, 1, 1]) * inv_det
    mat_inv[:, 2, 1] = (mat[:, 2, 0] * mat[:, 0, 1] - mat[:, 0, 0] * mat[:, 2, 1]) * inv_det
    mat_inv[:, 2, 2] = (mat[:, 0, 0] * mat[:, 1, 1] - mat[:, 1, 0] * mat[:, 0, 1]) * inv_det

    # Divide by the maximum value once more to undo the scaling
    mat_inv = mat_inv / max_vals
    return mat_inv
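
# Example usage (hypothetical): batched inverse of identity matrices
#   eye = torch.eye(3).repeat(4, 1, 1)   # shape: (4, 3, 3)
#   inv = mat_3x3_inv(eye)               # equals eye, up to float precision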

def mat_3x3_det(mat):
    '''
    calculate the determinant of a 3x3 matrix, supports batching.
    '''
    if len(mat.shape) < 3:
        mat = mat[None]
    assert mat.shape[1:] == (3, 3)

    det = mat[:, 0, 0] * (mat[:, 1, 1] * mat[:, 2, 2] - mat[:, 2, 1] * mat[:, 1, 2]) \
        - mat[:, 0, 1] * (mat[:, 1, 0] * mat[:, 2, 2] - mat[:, 1, 2] * mat[:, 2, 0]) \
        + mat[:, 0, 2] * (mat[:, 1, 0] * mat[:, 2, 1] - mat[:, 1, 1] * mat[:, 2, 0])
    return det


######### Linearized multi-sampling #########
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
    '''
    same function prototype as the original
    torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros'),
    copied from the PyTorch 1.2.0 source code
    '''
    if mode == 'linearized':
        return LinearizedMutilSampler.linearized_grid_sample(input, grid, padding_mode)
    else:
        return torch.nn.functional.grid_sample(input, grid, mode, padding_mode)


class LinearizedMutilSampler():

    num_grid = 8
    NUM_XY = 2
    noise_strength = 0.5
    need_push_away = True
    fixed_bias = False
    is_hyperparameters_set = False

    @classmethod
    def set_hyperparameters(cls, opt):
        cls.num_grid = opt.num_grid
        cls.noise_strength = opt.noise_strength
        cls.need_push_away = opt.need_push_away
        cls.fixed_bias = opt.fixed_bias
        if cls.is_hyperparameters_set:
            raise RuntimeError('Trying to reset the hyperparameters of the linearized multi-sampler, which is currently not allowed')
        else:
            cls.is_hyperparameters_set = True
        notification = []
        notification.append('Hyperparameters are set')
        notification.append('num_grid: {0}'.format(cls.num_grid))
        notification.append('noise_strength: {0}'.format(cls.noise_strength))
        notification.append('need_push_away: {0}'.format(cls.need_push_away))
        notification.append('fixed_bias: {0}'.format(cls.fixed_bias))
        print_notification(notification)

    @classmethod
    def linearized_grid_sample(cls, input, grid, padding_mode):
        # assert cls.is_hyperparameters_set, 'linearized sampler hyperparameters are not set'
        assert isinstance(input, torch.Tensor), 'cannot process data type: {0}'.format(type(input))
        assert isinstance(grid, torch.Tensor), 'cannot process data type: {0}'.format(type(grid))

        batch_size, source_channels, source_height, source_width = input.shape
        least_offset = torch.tensor([2.0 / source_width, 2.0 / source_height], device=grid.device)
        auxiliary_grid = cls.create_auxiliary_grid(grid, least_offset)
        warped_input = cls.warp_input_with_auxiliary_grid(input, auxiliary_grid, padding_mode)
        out = cls.linearized_fitting(warped_input, auxiliary_grid)
        return out

    @classmethod
    def linearized_fitting(cls, input, grid):
        def defensive_assert(input, grid):
            assert len(input.shape) == 5, 'shape should be: B x Grid x C x H x W'
            assert len(grid.shape) == 5, 'shape should be: B x Grid x H x W x XY'
            assert input.shape[0] == grid.shape[0]
            assert input.shape[1] == grid.shape[1]
            assert input.shape[1] > 1, 'num of grid should be larger than 1'

        def get_center_and_other(input, grid):
            center_image = input[:, 0:1]
            center_grid = grid[:, 0:1]
            other_image = input[:, 1:]
            other_grid = grid[:, 1:]
            result = {'center_image': center_image,
                      'center_grid': center_grid,
                      'other_image': other_image,
                      'other_grid': other_grid,
                      }
            return result

        defensive_assert(input, grid)
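        # Least-squares fit, over the Grid-1 auxiliary samples, of intensity as a
        # linear function of grid position: solve (X^T X) g = X^T dI for g ≈ dI/dX
        # (equation (7)), then express the center sample as I(x) ≈ I(x0) + g^T (x - x0).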
        extracted_dict = get_center_and_other(input, grid)
        delta_vals = cls.get_delta_vals(extracted_dict)
        center_image = extracted_dict['center_image']
        center_grid = extracted_dict['center_grid']
        delta_intensity = delta_vals['delta_intensity']
        delta_grid = delta_vals['delta_grid']
        # permute to [B, H, W, Grid-1, XY1]
        delta_grid = delta_grid.permute(0, 2, 3, 1, 4)
        # permute to [B, H, W, Grid-1, C]
        delta_intensity = delta_intensity.permute(0, 3, 4, 1, 2)
        # calculate dI/dX, equation (7) in the paper
        xTx = torch.matmul(torch.transpose(delta_grid, 3, 4), delta_grid)

        # take the inverse
        xTx_inv = mat_3x3_inv(xTx.view(-1, 3, 3))
        xTx_inv = xTx_inv.view(xTx.shape)
        xTx_inv_xT = torch.matmul(xTx_inv, torch.transpose(delta_grid, 3, 4))
        # gradient_intensity shape: [B, H, W, XY1, C]
        gradient_intensity = torch.matmul(xTx_inv_xT, delta_intensity)

        if has_nan(gradient_intensity):
            print('nan val in gradient_intensity')
            nan_idx = is_nan(gradient_intensity)
            gradient_intensity[nan_idx] = torch.zeros(gradient_intensity[nan_idx].shape,
                                                      device=gradient_intensity.device).detach()

        # stop gradient
        gradient_intensity_stop = gradient_intensity.detach()
        center_grid_stop = center_grid.detach()

        # center_grid_xyz shape: [B, H, W, XY1, 1]
        center_grid_xyz = torch.cat([center_grid, torch.ones(center_grid[..., 0:1].shape, device=center_grid.device)],
                                    dim=4).permute(0, 2, 3, 4, 1)
        if cls.fixed_bias:
            center_grid_xyz_stop = torch.cat([center_grid_stop, torch.ones(center_grid_stop[..., 0:1].shape, device=center_grid_stop.device)],
                                             dim=4).permute(0, 2, 3, 4, 1)
        else:
            center_grid_xyz_stop = torch.cat([center_grid_stop, torch.zeros(center_grid_stop[..., 0:1].shape, device=center_grid_stop.device)],
                                             dim=4).permute(0, 2, 3, 4, 1)

        # map to the linearized output, equation (2) in the paper
        image_linearized = torch.matmul(gradient_intensity_stop.transpose(3, 4), (center_grid_xyz - center_grid_xyz_stop))[..., 0].permute(0, 3, 1, 2) + center_image[:, 0]
        return image_linearized

    @staticmethod
    def get_delta_vals(data_dict):
        def defensive_assert(center_image, other_image):
            assert len(center_image.shape) == 5, 'shape should be: B x Grid x C x H x W'
            assert len(other_image.shape) == 5, 'shape should be: B x Grid x C x H x W'
            assert center_image.shape[0] == other_image.shape[0]
            assert center_image.shape[1] == 1, 'num of center_image per single sample should be 1'
            assert other_image.shape[1] >= 1, ('num of other_image per single sample should be greater'
                                               ' than or equal to 1, got shape {0} for {1}'.format(other_image.shape,
                                                                                                   'other_image'))

        center_image = data_dict['center_image']
        center_grid = data_dict['center_grid']
        other_image = data_dict['other_image']
        other_grid = data_dict['other_grid']
        defensive_assert(center_image, other_image)

        batch_size = other_image.shape[0]
        num_other_image = other_image.shape[1]
        center_image_batch = center_image.repeat([1, num_other_image, 1, 1, 1])
        center_grid_batch = center_grid.repeat([1, num_other_image, 1, 1, 1])
        delta_intensity = other_image - center_image_batch
        delta_grid = other_grid - center_grid_batch
        delta_mask = (delta_grid[..., 0:1] >= -1.0) * (delta_grid[..., 0:1] <= 1.0) * (delta_grid[..., 1:2] >= -1.0) * (delta_grid[..., 1:2] <= 1.0)
        delta_mask = delta_mask.float()
        delta_grid = torch.cat([delta_grid, torch.ones(delta_grid[..., 0:1].shape, device=delta_grid.device)], dim=4)
        delta_grid *= delta_mask
        delta_vals = {'delta_intensity': delta_intensity,
                      'delta_grid': delta_grid,
                      }
        return delta_vals

    @staticmethod
    def warp_input_with_auxiliary_grid(input, grid, padding_mode):
        assert len(input.shape) == 4
        assert len(grid.shape) == 5
        assert input.shape[0] == grid.shape[0]

        batch_size, num_grid, height, width, num_xy = grid.shape
        grid = grid.reshape([-1, height, width, num_xy])
        grid = grid.detach()
        input = input.repeat_interleave(num_grid, 0)
        warped_input = torch.nn.functional.grid_sample(input, grid, mode='bilinear',
                                                       padding_mode=padding_mode)
        warped_input = warped_input.reshape(batch_size, num_grid, -1, height, width)
        return warped_input

    @classmethod
    def create_auxiliary_grid(cls, grid, least_offset):
        batch_size, height, width, num_xy = grid.shape
        grid = grid.repeat(1, cls.num_grid, 1, 1).reshape(batch_size, cls.num_grid, height, width, num_xy)
        grid = cls.add_noise_to_grid(grid, least_offset)
        return grid

    @classmethod
    def add_noise_to_grid(cls, grid, least_offset):
        grid_shape = grid.shape
        assert len(grid_shape) == 5
        batch_size, num_grid, height, width, num_xy = grid_shape
        assert num_xy == cls.NUM_XY
        assert num_grid == cls.num_grid

        grid_noise = torch.randn([batch_size, cls.num_grid - 1, height, width, num_xy], device=grid.device) / torch.tensor([[width, height]], dtype=torch.float32, device=grid.device) * cls.noise_strength
        grid[:, 1:] += grid_noise
        if cls.need_push_away:
            grid = cls.push_away_samples(grid, least_offset)
        return grid

    @classmethod
    def push_away_samples(cls, grid, least_offset):
        grid_shape = grid.shape
        assert len(grid_shape) == 5
        batch_size, num_grid, height, width, num_xy = grid_shape
        assert num_xy == cls.NUM_XY
        assert num_grid == cls.num_grid
        assert cls.need_push_away

        noise = torch.randn(grid[:, 1:].shape, device=grid.device)
        noise = noise * least_offset
        grid[:, 1:] = grid[:, 1:] + noise
        return grid

--------------------------------------------------------------------------------
/warp/perturbation_helper.py:
--------------------------------------------------------------------------------
"""
functions for generating perturbations
modified from: https://github.com/chenhsuanlin/inverse-compositional-STN
"""
import numpy as np
import torch

from utils import utils


def gen_perturbation_vec(opt, num_pert: int):
    """generate a homography perturbation vector

    Arguments:
        opt -- [user defined options]
        num_pert -- [number of perturbations to generate]
    Returns:
        perturbation vector, shape is (B, warp_dim)
    """
    # TODO: remove np, use torch
    assert opt.need_pert, 'please enable perturbation'
    if opt.warp_type == 'translation':
        perturbation_vec = gen_pert_for_translation(opt, num_pert)
    elif opt.warp_type == 'trans+rot':
        perturbation_vec = gen_pert_for_trans_rot(opt, num_pert)
    elif opt.warp_type == 'similarity':
        perturbation_vec = gen_pert_for_similarity(opt, num_pert)
    else:
        raise ValueError('unknown warping method')
    return perturbation_vec
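
# Example usage (hypothetical options object; any attribute-style config works):
#   from types import SimpleNamespace
#   opt = SimpleNamespace(need_pert=True, warp_type='translation',
#                         pert_distribution='uniform', translation_range=0.2)
#   vecs = gen_perturbation_vec(opt, num_pert=4)  # shape: (4, 2)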

def gen_perturbation_mat(opt, num_pert: int):
    """generate a homography perturbation matrix

    Arguments:
        opt -- [user defined options]
        num_pert -- [number of perturbations to generate]
    Returns:
        transformation matrix, shape is (B, 3, 3)
    """
    perturbation_vec = gen_perturbation_vec(opt, num_pert)
    perturbation_mat = vec2mat(opt, perturbation_vec)
    return perturbation_mat


def gen_identity_mat(num_ident: int):
    """
    :param num_ident: number of 3x3 identity matrices
    :return: identity matrices, shape is (B, 3, 3)
    """
    identity = torch.eye(3)
    identity = identity.repeat(num_ident, 1, 1)
    return identity


def gen_random_rotation(opt, num_pert: int):
    rad = float(opt.rotation_range) / 180.0 * np.pi
    if opt.pert_distribution == 'normal':
        theta = np.clip(np.random.normal(size=(num_pert,)) * rad, -2.0 * rad, 2.0 * rad)
    elif opt.pert_distribution == 'uniform':
        theta = np.random.uniform(low=-1, high=1, size=(num_pert,)) * rad
    else:
        raise NotImplementedError('unknown sampling distribution')
    return theta


def gen_random_translation(opt, num_pert):
    if opt.pert_distribution == 'normal':
        dx = np.clip(np.random.normal(size=(num_pert,)) * opt.translation_range,
                     -2.0 * opt.translation_range,
                     2.0 * opt.translation_range)
    elif opt.pert_distribution == 'uniform':
        dx = np.random.uniform(low=-1, high=1, size=(num_pert,)) * opt.translation_range
    else:
        raise NotImplementedError('unknown sampling distribution')
    return dx


def gen_random_scaling(opt, num_pert):
    if opt.pert_distribution == 'normal':
        sx = np.clip(np.random.normal(size=(num_pert,)) * opt.scaling_range,
                     -2.0 * opt.scaling_range,
                     2.0 * opt.scaling_range)
    elif opt.pert_distribution == 'uniform':
        sx = np.random.uniform(low=-1, high=1, size=(num_pert,)) * opt.scaling_range
    else:
        raise NotImplementedError('unknown sampling distribution')
    return sx


def gen_pert_for_translation(opt, num_pert):
    dx = gen_random_translation(opt, num_pert)
    dy = gen_random_translation(opt, num_pert)
    # make it a torch vector
    perturbation_vec = utils.to_torch(np.stack([dx, dy], axis=-1).astype(np.float32))
    return perturbation_vec


def gen_pert_for_trans_rot(opt, num_pert):
    theta = gen_random_rotation(opt, num_pert)
    dx = gen_random_translation(opt, num_pert)
    dy = gen_random_translation(opt, num_pert)
    # make it a torch vector
    perturbation_vec = utils.to_torch(np.stack([theta, dx, dy], axis=-1).astype(np.float32))
    return perturbation_vec


def gen_pert_for_similarity(opt, num_pert):
    theta = gen_random_rotation(opt, num_pert)
    s = gen_random_scaling(opt, num_pert)
    dx = gen_random_translation(opt, num_pert)
    dy = gen_random_translation(opt, num_pert)
    # make it a torch vector
    perturbation_vec = utils.to_torch(np.stack([theta, s, dx, dy], axis=-1).astype(np.float32))
    return perturbation_vec


def vec2mat(opt, vec):
    """convert a transformation vector to a transformation matrix

    Arguments:
        vec -- [transformation vector, shape: (B, n)], where n is the number of warping parameters
    Returns:
        mat -- [transformation matrix, shape: (B, 3, 3)]
    """
    assert isinstance(vec, torch.Tensor), 'cannot process data type: {0}'.format(type(vec))
    if len(vec.shape) == 1:
        vec = vec[None]
    assert len(vec.shape) == 2
    if opt.warp_type == 'translation':
        transformation_mat = vec2mat_for_translation(vec)
    elif opt.warp_type == 'trans+rot':
        transformation_mat = vec2mat_for_trans_rot(vec)
    elif opt.warp_type == 'similarity':
        transformation_mat = vec2mat_for_similarity(vec)
    else:
        raise NotImplementedError('unknown warping method')
    return transformation_mat


def vec2mat_for_translation(vec):
    assert vec.shape[1] == 2
    _len = vec.shape[0]
    O = torch.zeros([_len], dtype=torch.float32, device=vec.device)
    I = torch.ones([_len], dtype=torch.float32, device=vec.device)

    dx, dy = torch.unbind(vec, dim=1)
    transformation_mat = torch.stack([torch.stack([I, O, dx], dim=-1),
                                      torch.stack([O, I, dy], dim=-1),
                                      torch.stack([O, O, I], dim=-1)], dim=1)
    return transformation_mat


def vec2mat_for_trans_rot(vec):
    assert vec.shape[1] == 3
    _len = vec.shape[0]
    O = torch.zeros([_len], dtype=torch.float32, device=vec.device)
    I = torch.ones([_len], dtype=torch.float32, device=vec.device)

    p1, p2, p3 = torch.unbind(vec, dim=1)
    theta = p1
    cos = torch.cos(theta)
    sin = torch.sin(theta)
    dx = p2
    dy = p3
    R = torch.stack([torch.stack([cos, -sin, O], dim=-1),
                     torch.stack([sin, cos, O], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    S = torch.stack([torch.stack([I, O, O], dim=-1),
                     torch.stack([O, I, O], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    T = torch.stack([torch.stack([I, O, dx], dim=-1),
                     torch.stack([O, I, dy], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    transformation_mat = torch.bmm(R, torch.bmm(S, T))

    return transformation_mat


def vec2mat_for_similarity(vec):
    assert vec.shape[1] == 4
    _len = vec.shape[0]
    O = torch.zeros([_len], dtype=torch.float32, device=vec.device)
    I = torch.ones([_len], dtype=torch.float32, device=vec.device)

    p1, p2, p3, p4 = torch.unbind(vec, dim=1)
    theta = p1
    cos = torch.cos(theta)
    sin = torch.sin(theta)
    s = 2.0 ** (p2)
    dx = p3
    dy = p4
    R = torch.stack([torch.stack([cos, -sin, O], dim=-1),
                     torch.stack([sin, cos, O], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    S = torch.stack([torch.stack([s, O, O], dim=-1),
                     torch.stack([O, s, O], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    T = torch.stack([torch.stack([I, O, dx], dim=-1),
                     torch.stack([O, I, dy], dim=-1),
                     torch.stack([O, O, I], dim=-1)], dim=1)
    transformation_mat = torch.bmm(R, torch.bmm(S, T))

    return transformation_mat

--------------------------------------------------------------------------------
/warp/sampling_helper.py:
--------------------------------------------------------------------------------
import torch
from warp import linearized


class DifferentiableImageSampler():
    '''
    a differentiable image sampler which works with homographies
    '''

    def __init__(self, sampling_mode, padding_mode):
        self.sampling_mode = sampling_mode
        self.padding_mode = padding_mode
        self.flatten_xy_cache = {}

    def warp_image(self, image, homography, out_shape=None):
        assert isinstance(image, torch.Tensor), 'cannot process data type: {0}'.format(type(image))
        assert isinstance(homography, torch.Tensor), 'cannot process data type: {0}'.format(type(homography))

        if out_shape is None:
            out_shape = image.shape[-2:]
        if len(image.shape) < 4:
            image = image[None]
        if len(homography.shape) < 3:
            homography = homography[None]
        assert image.shape[0] == homography.shape[0], 'batch size of the images does not match the batch size of the homographies'
        # create grid for interpolation (in frame coordinates)
        x, y = self.create_flatten_xy(x_steps=out_shape[-1], y_steps=out_shape[-2], device=homography.device)
        grid = self.flatten_xy_to_warped_grid(x, y, homography, out_shape)
        # sample warped image
        warped_img = linearized.grid_sample(image, grid, mode=self.sampling_mode, padding_mode=self.padding_mode)

        if linearized.has_nan(warped_img):
            print('NaN values in warped image! Setting them to zero')
            warped_img[linearized.is_nan(warped_img)] = 0

        return warped_img

    def create_flatten_xy(self, x_steps: int, y_steps: int, device):
        if (x_steps, y_steps) in self.flatten_xy_cache:
            x, y = self.flatten_xy_cache[(x_steps, y_steps)]
            return x.clone(), y.clone()
        y, x = torch.meshgrid([
            torch.linspace(-1.0, 1.0, steps=y_steps, device=device),
            torch.linspace(-1.0, 1.0, steps=x_steps, device=device)
        ])
        x, y = x.flatten(), y.flatten()
        self.flatten_xy_cache[(x_steps, y_steps)] = x, y
        return x.clone(), y.clone()

    def flatten_xy_to_warped_grid(self, x, y, homography, out_shape):
        batch_size = homography.shape[0]
        # append ones for homogeneous coordinates
        xy = torch.stack([x, y, torch.ones_like(x)])
        xy = xy.repeat([batch_size, 1, 1])  # shape: (B, 3, N)

        xy_warped = torch.matmul(homography, xy)
        xy_warped, z_warped = xy_warped.split(2, dim=1)
        xy_warped = xy_warped / (z_warped + 1e-8)
        x_warped, y_warped = torch.unbind(xy_warped, dim=1)
        # build grid
        grid = torch.stack([
            x_warped.view(batch_size, *out_shape[-2:]),
            y_warped.view(batch_size, *out_shape[-2:])
        ], dim=-1)
        return grid

--------------------------------------------------------------------------------
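
For reference, a minimal end-to-end sketch of `DifferentiableImageSampler` (the input image here is random dummy data and the shapes are illustrative; assumes the repo root is on `PYTHONPATH`):

```python
import numpy as np
import torch

from utils import utils
from warp import perturbation_helper, sampling_helper

# any H x W x 3 float image in [0, 1] works; random dummy data for illustration
np_image = np.random.rand(64, 64, 3).astype(np.float32)
image = utils.np_img_to_torch_img(np_image)[None]  # shape: (1, 3, 64, 64)

# translate by (0.2, 0.1) in normalized [-1, 1] coordinates
translation_vec = torch.tensor([[0.2, 0.1]])
homography = perturbation_helper.vec2mat_for_translation(translation_vec)

# 'linearized' routes through LinearizedMutilSampler; 'bilinear' uses plain grid_sample
sampler = sampling_helper.DifferentiableImageSampler('linearized', 'zeros')
warped = sampler.warp_image(image, homography, out_shape=(32, 32))
print(warped.shape)  # torch.Size([1, 3, 32, 32])
```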