├── README.md
├── code
│   ├── dataset.py
│   ├── homography_CNN_synthetic.py
│   ├── homography_model.py
│   └── utils
│       ├── __init__.py
│       ├── gen_synthetic_data.py
│       ├── numpy_spatial_transformer.py
│       ├── torch_spatial_transformer.py
│       └── utils.py
└── results
    └── synthetic
        └── report
            ├── 0025.jpg
            ├── 0026.jpg
            ├── 0027.jpg
            ├── 0028.jpg
            └── 0029.jpg
/README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Deep Homography - PyTorch Implementation 2 | 3 | [**Unsupervised Deep Homography: A Fast and Robust Homography Estimation 4 | Model**](https://arxiv.org/abs/1709.03966)
5 | Ty Nguyen, Steven W. Chen, Shreyas S. Shivakumar, Camillo J. Taylor, Vijay 6 | Kumar
7 | 8 | ```bash 9 | cd code/ 10 | ``` 11 | In the code/ folder: 12 | 13 | `dataset.py`: the SyntheticDataset(torch.utils.data.Dataset) implementation<br>
14 | `homography_model.py`: Unsupervised deep homography model implementation
15 | `homography_CNN_synthetic.py`: training and testing entry point 16 | 17 | ## Preparing the training dataset (synthetic) 18 | Download the MS-COCO 2014 dataset<br>
19 | Store the train and test sets in RAW_DATA_PATH and TEST_RAW_DATA_PATH, respectively. 20 | ### Generate training dataset 21 | It will take a few hours to generate 100,000 data samples. 22 | ```bash 23 | python utils/gen_synthetic_data.py --mode train 24 | ``` 25 | ### Generate test dataset 26 | ```bash 27 | python utils/gen_synthetic_data.py --mode test 28 | ``` 29 | 30 | ## Train model with synthetic dataset 31 | ```bash 32 | python homography_CNN_synthetic.py --mode train 33 | ``` 34 | 35 | ## Test model with synthetic dataset 36 | Download the pre-trained weights 37 | ```bash 38 | Link: https://pan.baidu.com/s/102ilb5HJGydpeHtYelx_Xw (extraction code: boq9) 39 | ``` 40 | Store the model in the models/synthetic_models folder 41 | ```bash 42 | python homography_CNN_synthetic.py --mode test 43 | ``` 44 | 45 | Results | 46 | --- | 47 | ![result](https://img-blog.csdnimg.cn/20210322175425747.png?x-oss-process) | 48 | ![result](https://img-blog.csdnimg.cn/20210322175643842.png?x-oss-process) | 49 | ![result](https://img-blog.csdnimg.cn/20210322180132270.png?x-oss-process) | 50 | ![result](https://img-blog.csdnimg.cn/2021032218020122.png?x-oss-process) | 51 | ![result](https://img-blog.csdnimg.cn/20210322180502181.png?x-oss-process) | 52 | 53 | ## Release History 54 | * **2021.4.5** 55 | * Add TensorBoard visualization and some metrics. 56 | 57 | ## Reference 58 | [https://github.com/tynguyen/unsupervisedDeepHomographyRAL2018](https://github.com/tynguyen/unsupervisedDeepHomographyRAL2018) 59 | -------------------------------------------------------------------------------- /code/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 as cv 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | def read_img_and_gt(filenames_file, pts1_file, gt_file): 8 | with open(filenames_file, 'r') as img_f: 9 | img_array = img_f.readlines() 10 | img_array = [x.strip() for x in img_array] 11 | img_array = [x.split() for x in img_array] # Use x.split()[0] if the left and right images share the same name 12 | 13 | with open(pts1_file, 'r') as pts1_f: 14 | pts1_array = pts1_f.readlines() 15 | pts1_array = [x.strip() for x in pts1_array] 16 | pts1_array = [x.split() for x in pts1_array] 17 | pts1_array = np.array(pts1_array).astype('float64') 18 | 19 | # In case there is no ground truth 20 | if not gt_file: 21 | return img_array, pts1_array, None 22 | 23 | with open(gt_file, 'r') as gt_f: 24 | gt_array = gt_f.readlines() 25 | gt_array = [x.strip() for x in gt_array] 26 | gt_array = [x.split() for x in gt_array] 27 | gt_array = np.array(gt_array).astype('float64') 28 | 29 | return img_array, pts1_array, gt_array 30 | 31 | def get_mesh_grid_per_img(patch_w, patch_h): 32 | x_flat = np.arange(0, patch_w) 33 | x_flat = x_flat[np.newaxis, :] 34 | y_one = np.ones(patch_h) 35 | y_one = y_one[:, np.newaxis] 36 | x_mesh = np.matmul(y_one, x_flat) 37 | 38 | y_flat = np.arange(0, patch_h) 39 | y_flat = y_flat[:, np.newaxis] 40 | x_one = np.ones(patch_w) 41 | x_one = x_one[np.newaxis, :] 42 | y_mesh = np.matmul(y_flat, x_one) 43 | x_t_flat = np.reshape(x_mesh, (-1)) 44 | y_t_flat = np.reshape(y_mesh, (-1)) 45 | return x_t_flat, y_t_flat 46 | 47 | class SyntheticDataset(Dataset): 48 | """Load synthetic data""" 49 | def __init__(self, data_path, mode, img_h, img_w, patch_size, do_augment): 50 | self.mode = mode 51 | if self.mode == "train": 52 | self.data_path = data_path + "train/" 53 | self.pts1_file = os.path.join(self.data_path,
'pts1.txt') 54 | self.filenames_file = os.path.join(self.data_path, 'train_synthetic.txt') 55 | self.gt_file = os.path.join(self.data_path, 'gt.txt') 56 | elif self.mode == "test": 57 | self.data_path = data_path + "test/" 58 | self.pts1_file = os.path.join(self.data_path, 'test_pts1.txt') 59 | self.filenames_file = os.path.join(self.data_path, 'test_synthetic.txt') 60 | self.gt_file = os.path.join(self.data_path, 'test_gt.txt') 61 | self.img_h = img_h 62 | self.img_w = img_w 63 | self.patch_size = patch_size 64 | self.do_augment = do_augment 65 | self.mean_I = np.reshape(np.array([118.93, 113.97, 102.60]), (1, 1, 3)) 66 | self.std_I = np.reshape(np.array([69.85, 68.81, 72.45]), (1, 1, 3)) 67 | 68 | # Read to arrays 69 | self.img_np, self.pts1_np, self.gt_np = read_img_and_gt(self.filenames_file, self.pts1_file, self.gt_file) 70 | 71 | # Find indices of the pixels in the patch w.r.t the large image 72 | # All patches have the same size so their pixels have the same base indices 73 | self.x_t_flat, self.y_t_flat = get_mesh_grid_per_img(patch_size, patch_size) 74 | 75 | def __len__(self): 76 | return len(self.img_np) 77 | 78 | def __getitem__(self, index): 79 | pts1_index = self.pts1_np[index] 80 | gt_index = self.gt_np[index] 81 | split_line = self.img_np[index] 82 | 83 | I_path = self.data_path + 'I/' + split_line[0] 84 | I_prime_path = self.data_path + 'I_prime/' + split_line[1] 85 | 86 | I = self.read_image(I_path, self.img_h, self.img_w) 87 | I_prime = self.read_image(I_prime_path, self.img_h, self.img_w) 88 | 89 | # Data Augmentation 90 | do_augment = np.random.uniform(0, 1) 91 | # Training: use joint augmentation (the two images in a pair receive the same noise) 92 | # Test: use disjoint augmentation (the two images in a pair receive different noise) 93 | if self.mode == 'train': 94 | I_aug, I_prime_aug = self.joint_augment_image_pair(I, I_prime, 0, 255) \ 95 | if do_augment > (1 - self.do_augment) else (I, I_prime) 96 | else: 97 | I_aug, I_prime_aug = self.disjoint_augment_image_pair(I, I_prime, 0, 255) \ 98 | if do_augment > (1 - self.do_augment) else (I, I_prime) 99 | 100 | # Standardize images 101 | I = self.norm_img(I, self.mean_I, self.std_I) 102 | I_prime = self.norm_img(I_prime, self.mean_I, self.std_I) 103 | # These augmented large images are the ones later fed to the model 104 | I_aug = self.norm_img(I_aug, self.mean_I, self.std_I) 105 | I_prime_aug = self.norm_img(I_prime_aug, self.mean_I, self.std_I) 106 | 107 | # Read patch_indices 108 | x_start = pts1_index[0] # x 109 | y_start = pts1_index[1] # y 110 | patch_indices = (self.y_t_flat + y_start) * self.img_w + (self.x_t_flat + x_start) 111 | 112 | # Convert to tensor 113 | I = torch.tensor(I) 114 | I_prime = torch.tensor(I_prime) 115 | I_aug = torch.tensor(I_aug) 116 | I_prime_aug = torch.tensor(I_prime_aug) 117 | pts1_tensor = torch.tensor(pts1_index) 118 | gt_tensor = torch.tensor(gt_index) 119 | patch_indices = torch.tensor(patch_indices) 120 | 121 | # Obtain I1, I2, I1_aug and I2_aug 122 | I_flat = torch.reshape(torch.mean(I, 0), [-1]) # I: 3xHxW 123 | I_prime_flat = torch.reshape(torch.mean(I_prime, 0), [-1]) # I_prime: 3xHxW 124 | I_aug_flat = torch.reshape(torch.mean(I_aug, 0), [-1]) # I_aug: 3xHxW 125 | I_prime_aug_flat = torch.reshape(torch.mean(I_prime_aug, 0), [-1]) # I_prime_aug: 3xHxW 126 | 127 | patch_indices_flat = torch.reshape(patch_indices, [-1]) 128 | pixel_indices = patch_indices_flat.long() 129 | 130 | I1_flat = torch.gather(I_flat, 0, pixel_indices) 131 | I2_flat = torch.gather(I_prime_flat, 0, pixel_indices) 132 |
I1_aug_flat = torch.gather(I_aug_flat, 0, pixel_indices) 133 | I2_aug_flat = torch.gather(I_prime_aug_flat, 0, pixel_indices) 134 | 135 | I1 = torch.reshape(I1_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1) 136 | I2 = torch.reshape(I2_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1) 137 | I1_aug = torch.reshape(I1_aug_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1) 138 | I2_aug = torch.reshape(I2_aug_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1) 139 | 140 | return I1, I2, I1_aug, I2_aug, I_aug, I_prime_aug, pts1_tensor, gt_tensor, patch_indices 141 | 142 | def read_image(self, image_path, img_h, img_w): 143 | image = cv.imread(image_path) 144 | height, width = image.shape[:2] 145 | if height != img_h or width != img_w: 146 | image = cv.resize(image, (img_w, img_h), interpolation=cv.INTER_AREA) 147 | return image 148 | 149 | def norm_img(self, img, mean, std): 150 | img = (img - mean) / std 151 | img = np.transpose(img, [2, 0, 1]) # torch [C,H,W] 152 | return img 153 | 154 | def disjoint_augment_image_pair(self, img1, img2, min_val=0, max_val=255): 155 | # Randomly shift gamma 156 | random_gamma = np.random.uniform(0.8, 1.2) 157 | img1_aug = img1 ** random_gamma 158 | random_gamma = np.random.uniform(0.8, 1.2) 159 | img2_aug = img2 ** random_gamma 160 | 161 | # Randomly shift brightness 162 | random_brightness = np.random.uniform(0.5, 2.0) 163 | img1_aug = img1_aug * random_brightness 164 | random_brightness = np.random.uniform(0.5, 2.0) 165 | img2_aug = img2_aug * random_brightness 166 | 167 | # Randomly shift color 168 | random_colors = np.random.uniform(0.8, 1.2, 3) 169 | white = np.ones([img1.shape[0], img1.shape[1], 1]) 170 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2) 171 | img1_aug *= color_image 172 | 173 | random_colors = np.random.uniform(0.8, 1.2, 3) 174 | white = np.ones([img1.shape[0], img1.shape[1], 1]) 175 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2) 176 | img2_aug *= color_image 177 | 178 | # Saturate 179 | img1_aug = np.clip(img1_aug, min_val, max_val) 180 | img2_aug = np.clip(img2_aug, min_val, max_val) 181 | 182 | return img1_aug, img2_aug 183 | 184 | def joint_augment_image_pair(self, img1, img2, min_val=0, max_val=255): 185 | # Randomly shift gamma 186 | random_gamma = np.random.uniform(0.8, 1.2) 187 | img1_aug = img1 ** random_gamma 188 | img2_aug = img2 ** random_gamma 189 | 190 | # Randomly shift brightness 191 | random_brightness = np.random.uniform(0.5, 2.0) 192 | img1_aug = img1_aug * random_brightness 193 | img2_aug = img2_aug * random_brightness 194 | 195 | # Randomly shift color 196 | random_colors = np.random.uniform(0.8, 1.2, 3) 197 | white = np.ones([img1.shape[0], img1.shape[1], 1]) 198 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2) 199 | img1_aug *= color_image 200 | img2_aug *= color_image 201 | 202 | # Saturate 203 | img1_aug = np.clip(img1_aug, min_val, max_val) 204 | img2_aug = np.clip(img2_aug, min_val, max_val) 205 | 206 | return img1_aug, img2_aug 207 | 208 | 209 | if __name__ == "__main__": 210 | TrainDataset = SyntheticDataset(data_path="../data/synthetic/45/", 211 | mode='train', 212 | img_h=240, 213 | img_w=320, 214 | patch_size=128, 215 | do_augment=0.5) 216 | print(len(TrainDataset)) 217 | sample = TrainDataset[np.random.randint(0, len(TrainDataset))] 218 | for attr in sample: 219 | print(attr.shape, attr.dtype) 
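For reference, the patch extraction in `__getitem__` above never crops: `get_mesh_grid_per_img` builds the flat pixel offsets of a patch anchored at the origin, the pts1 top-left corner shifts them, and `torch.gather` pulls those pixels out of the flattened grayscale image. Below is a minimal, self-contained sketch of that indexing trick; the toy sizes and start coordinates are hypothetical, chosen only for illustration (the real code uses a 240x320 image and a 128x128 patch):

```python
import numpy as np
import torch

# Toy sizes for illustration only.
img_h, img_w, patch = 6, 8, 3

# Flat (x, y) offsets of a patch anchored at (0, 0), as in get_mesh_grid_per_img.
y_t, x_t = np.meshgrid(np.arange(patch), np.arange(patch), indexing='ij')
x_t_flat, y_t_flat = x_t.reshape(-1), y_t.reshape(-1)

# Shift by the patch's top-left corner (x_start, y_start), flattened row-major.
x_start, y_start = 2, 1
patch_indices = (y_t_flat + y_start) * img_w + (x_t_flat + x_start)

# Gather those pixels from the flattened image and reshape into the patch.
I = torch.arange(img_h * img_w, dtype=torch.float32)  # stand-in for a gray image
patch_pixels = torch.gather(I, 0, torch.from_numpy(patch_indices).long())
print(patch_pixels.reshape(patch, patch))
# Identical to I.reshape(img_h, img_w)[y_start:y_start + patch, x_start:x_start + patch]
```

The same indices, offset per batch element by `batch_indices_tensor`, are what `HomographyModel.transform` uses to cut the predicted patch out of the warped image.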
-------------------------------------------------------------------------------- /code/homography_CNN_synthetic.py: -------------------------------------------------------------------------------- 1 | import os, shutil 2 | import argparse 3 | import torch 4 | from torch import optim 5 | from torch.utils.data import DataLoader 6 | from tensorboardX import SummaryWriter 7 | from dataset import SyntheticDataset 8 | from homography_model import HomographyModel 9 | from utils import utils 10 | import numpy as np 11 | import math 12 | import time 13 | 14 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 15 | 16 | 17 | def train(args): 18 | # Load data 19 | TrainDataset = SyntheticDataset(data_path=args.data_path, 20 | mode=args.mode, 21 | img_h=args.img_h, 22 | img_w=args.img_w, 23 | patch_size=args.patch_size, 24 | do_augment=args.do_augment) 25 | train_loader = DataLoader(TrainDataset, batch_size=args.batch_size, shuffle=True, num_workers=4) 26 | print('===> Train: {} training files in total'.format(len(TrainDataset))) 27 | 28 | net = HomographyModel(args.use_batch_norm) 29 | if args.resume: 30 | model_path = os.path.join(args.model_dir, args.model_name) 31 | ckpt = torch.load(model_path) 32 | net.load_state_dict(ckpt.state_dict()) 33 | if torch.cuda.is_available(): 34 | net = net.cuda() 35 | 36 | optimizer = optim.Adam(net.parameters(), lr=args.lr) # default 1e-4 37 | decay_rate = 0.96 38 | step_size = (math.log(decay_rate) * args.max_epochs) / math.log(args.min_lr * 1.0 / args.lr) 39 | print('args lr:', args.lr, args.min_lr) 40 | print('===> Decay steps:', step_size) 41 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(step_size), gamma=decay_rate) 42 | 43 | print("start training") 44 | writer = SummaryWriter(logdir=args.log_dir, flush_secs=60) 45 | score_print_fre = 100 46 | summary_fre = 1000 47 | model_save_fre = 4000 48 | glob_iter = 0 49 | t0 = time.time() 50 | 51 | for epoch in range(args.max_epochs): 52 | net.train() 53 | epoch_start = time.time() 54 | train_l1_loss = 0.0 55 | train_l1_smooth_loss = 0.0 56 | train_h_loss = 0.0 57 | 58 | for i, batch_value in enumerate(train_loader): 59 | I1_batch = batch_value[0].float() 60 | I2_batch = batch_value[1].float() 61 | I1_aug_batch = batch_value[2].float() 62 | I2_aug_batch = batch_value[3].float() 63 | I_batch = batch_value[4].float() 64 | I_prime_batch = batch_value[5].float() 65 | pts1_batch = batch_value[6].float() 66 | gt_batch = batch_value[7].float() 67 | patch_indices_batch = batch_value[8].float() 68 | 69 | if torch.cuda.is_available(): 70 | I1_aug_batch = I1_aug_batch.cuda() 71 | I2_aug_batch = I2_aug_batch.cuda() 72 | I_batch = I_batch.cuda() 73 | pts1_batch = pts1_batch.cuda() 74 | gt_batch = gt_batch.cuda() 75 | patch_indices_batch = patch_indices_batch.cuda() 76 | 77 | # forward, backward, update weights 78 | optimizer.zero_grad() 79 | batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch) 80 | h_loss = batch_out['h_loss'] 81 | rec_loss = batch_out['rec_loss'] 82 | ssim_loss = batch_out['ssim_loss'] 83 | l1_loss = batch_out['l1_loss'] 84 | l1_smooth_loss = batch_out['l1_smooth_loss'] 85 | ncc_loss = batch_out['ncc_loss'] 86 | pred_I2 = batch_out['pred_I2'] 87 | 88 | loss = l1_loss 89 | loss.backward() 90 | optimizer.step() 91 | 92 | train_l1_loss += loss.item() 93 | train_l1_smooth_loss += l1_smooth_loss.item() 94 | train_h_loss += h_loss.item() 95 | if (i + 1) % score_print_fre == 0 or (i + 1) == len(train_loader): 96 | print( 97 | "Training: Epoch[{:0>3}/{:0>3}]
Iter[{:0>3}]/[{:0>3}] l1 loss: {:.4f} " 98 | "l1 smooth loss: {:.4f} h loss: {:.4f} lr={:.8f}".format( 99 | epoch + 1, args.max_epochs, i + 1, len(train_loader), train_l1_loss / score_print_fre, 100 | train_l1_smooth_loss / score_print_fre, train_h_loss / score_print_fre, scheduler.get_lr()[0])) 101 | train_l1_loss = 0.0 102 | train_l1_smooth_loss = 0.0 103 | train_h_loss = 0.0 104 | 105 | if glob_iter % summary_fre == 0: 106 | writer.add_scalar('learning_rate', scheduler.get_lr()[0], glob_iter) 107 | writer.add_scalar('h_loss', h_loss, glob_iter) 108 | writer.add_scalar('rec_loss', rec_loss, glob_iter) 109 | writer.add_scalar('ssim_loss', ssim_loss, glob_iter) 110 | writer.add_scalar('l1_loss', l1_loss, glob_iter) 111 | writer.add_scalar('l1_smooth_loss', l1_smooth_loss, glob_iter) 112 | writer.add_scalar('ncc_loss', ncc_loss, glob_iter) 113 | 114 | writer.add_image('I', utils.denorm_img(I_batch[0, ...].cpu().numpy()).astype(np.uint8)[:, :, ::-1], 115 | glob_iter, dataformats='HWC') 116 | writer.add_image('I_prime', 117 | utils.denorm_img(I_prime_batch[0, ...].numpy()).astype(np.uint8)[:, :, ::-1], 118 | glob_iter, dataformats='HWC') 119 | 120 | writer.add_image('I1_aug', utils.denorm_img(I1_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8), 121 | glob_iter, dataformats='HW') 122 | writer.add_image('I2_aug', utils.denorm_img(I2_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8), 123 | glob_iter, dataformats='HW') 124 | writer.add_image('pred_I2', 125 | utils.denorm_img(pred_I2[0, 0, ...].cpu().detach().numpy()).astype(np.uint8), 126 | glob_iter, dataformats='HW') 127 | 128 | writer.add_image('I2', utils.denorm_img(I2_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter, 129 | dataformats='HW') 130 | writer.add_image('I1', utils.denorm_img(I1_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter, 131 | dataformats='HW') 132 | 133 | # save model 134 | if glob_iter % model_save_fre == 0 and glob_iter != 0: 135 | filename = 'model' + '_iter_' + str(glob_iter) + '.pth' 136 | model_save_path = os.path.join(args.model_dir, filename) 137 | torch.save(net, model_save_path) 138 | 139 | glob_iter += 1 140 | scheduler.step() 141 | print("Epoch: {} epoch time: {:.1f}s".format(epoch, time.time() - epoch_start)) 142 | 143 | elapsed_time = time.time() - t0 144 | print("Finished Training in {:.0f}h {:.0f}m {:.0f}s.".format( 145 | elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60)) 146 | 147 | 148 | def test(args): 149 | # Load data 150 | TestDataset = SyntheticDataset(data_path=args.data_path, 151 | mode=args.mode, 152 | img_h=args.img_h, 153 | img_w=args.img_w, 154 | patch_size=args.patch_size, 155 | do_augment=args.do_augment) 156 | test_loader = DataLoader(TestDataset, batch_size=1) 157 | print('===> Test: {} testing files in total'.format(len(TestDataset))) 158 | 159 | # Load model 160 | net = HomographyModel() 161 | model_path = os.path.join(args.model_dir, args.model_name) 162 | state = torch.load(model_path) 163 | net.load_state_dict(state.state_dict()) 164 | if torch.cuda.is_available(): 165 | net = net.cuda() 166 | 167 | print("start testing") 168 | 169 | with torch.no_grad(): 170 | net.eval() 171 | test_l1_loss = 0.0 172 | test_h_loss = 0.0 173 | h_losses_array = [] 174 | for i, batch_value in enumerate(test_loader): 175 | I1_aug_batch = batch_value[2].float() 176 | I2_aug_batch = batch_value[3].float() 177 | I_batch = batch_value[4].float() 178 | I_prime_batch = batch_value[5].float() 179 | pts1_batch = batch_value[6].float() 180 | gt_batch =
batch_value[7].float() 181 | patch_indices_batch = batch_value[8].float() 182 | 183 | if torch.cuda.is_available(): 184 | I1_aug_batch = I1_aug_batch.cuda() 185 | I2_aug_batch = I2_aug_batch.cuda() 186 | I_batch = I_batch.cuda() 187 | pts1_batch = pts1_batch.cuda() 188 | gt_batch = gt_batch.cuda() 189 | patch_indices_batch = patch_indices_batch.cuda() 190 | 191 | batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch) 192 | h_loss = batch_out['h_loss'] 193 | rec_loss = batch_out['rec_loss'] 194 | ssim_loss = batch_out['ssim_loss'] 195 | l1_loss = batch_out['l1_loss'] 196 | pred_h4p_value = batch_out['pred_h4p'] 197 | 198 | test_h_loss += h_loss.item() 199 | test_l1_loss += l1_loss.item() 200 | h_losses_array.append(h_loss.item()) 201 | 202 | if args.save_visual: 203 | I_sample = utils.denorm_img(I_batch[0].cpu().numpy()).astype(np.uint8) 204 | I_prime_sample = utils.denorm_img(I_prime_batch[0].numpy()).astype(np.uint8) 205 | pts1_sample = pts1_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32) 206 | gt_h4p_sample = gt_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32) 207 | 208 | pts2_sample = pts1_sample + gt_h4p_sample 209 | 210 | pred_h4p_sample = pred_h4p_value[0].cpu().numpy().reshape([4, 2]).astype(np.float32) 211 | pred_pts2_sample = pts1_sample + pred_h4p_sample 212 | 213 | # Save 214 | visual_file_name = ('%s' % i).zfill(4) + '.jpg' 215 | utils.save_correspondences_img(I_prime_sample, I_sample, pts1_sample, pts2_sample, pred_pts2_sample, 216 | args.results_dir, visual_file_name) 217 | 218 | print("Testing: h_loss: {:4.3f}, rec_loss: {:4.3f}, ssim_loss: {:4.3f}, l1_loss: {:4.3f}".format( 219 | h_loss.item(), rec_loss.item(), ssim_loss.item(), l1_loss.item() 220 | )) 221 | 222 | print('|Test size | h_loss | l1_loss |') 223 | print(len(test_loader), test_h_loss / len(test_loader), test_l1_loss / len(test_loader)) 224 | 225 | tops_list = utils.find_percentile(h_losses_array) 226 | print('===> Percentile values (mean, std) over sorted h_loss intervals 0-30%, 30-60%, 60-100%:') 227 | print(tops_list) 228 | print('======> End!
====================================') 229 | 230 | 231 | def main(): 232 | # Size of the synthetic images and the perturbation range (RHO) 233 | HEIGHT = 240 234 | WIDTH = 320 235 | RHO = 45 236 | PATCH_SIZE = 128 237 | 238 | # Synthetic data directories 239 | DATA_PATH = "../data/synthetic/" + str(RHO) + '/' 240 | 241 | # Log and model directories 242 | MAIN_LOG_PATH = '../' 243 | LOG_DIR = MAIN_LOG_PATH + "logs/" 244 | MODEL_DIR = MAIN_LOG_PATH + "models/synthetic_models" 245 | 246 | # Where to save visualization images (for report) 247 | RESULTS_DIR = MAIN_LOG_PATH + "results/synthetic/report/" 248 | 249 | def str2bool(s): 250 | return s.lower() == 'true' 251 | 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument('--mode', type=str, default='train', help='Train or test', choices=['train', 'test']) 254 | parser.add_argument('--use_batch_norm', type=str2bool, default='False', help='Use batch_norm?') 255 | parser.add_argument('--do_augment', type=float, default=0.5, 256 | help='Probability of augmenting an image: color shift, brightness shift...') 257 | 258 | parser.add_argument('--data_path', type=str, default=DATA_PATH, help='The synthetic data path.') 259 | parser.add_argument('--log_dir', type=str, default=LOG_DIR, help='The log path') 260 | parser.add_argument('--results_dir', type=str, default=RESULTS_DIR, help='Store visualization for report') 261 | parser.add_argument('--model_dir', type=str, default=MODEL_DIR, help='The models path') 262 | parser.add_argument('--model_name', type=str, default='model.pth', help='The model name') 263 | 264 | parser.add_argument('--save_visual', type=str2bool, default='True', help='Save visual images for report') 265 | 266 | parser.add_argument('--img_w', type=int, default=WIDTH) 267 | parser.add_argument('--img_h', type=int, default=HEIGHT) 268 | parser.add_argument('--patch_size', type=int, default=PATCH_SIZE) 269 | parser.add_argument('--batch_size', type=int, default=128) 270 | parser.add_argument('--max_epochs', type=int, default=150) 271 | parser.add_argument('--lr', type=float, default=1e-4, help='Max learning rate') 272 | parser.add_argument('--min_lr', type=float, default=.9e-4, help='Min learning rate') 273 | 274 | parser.add_argument('--resume', type=str2bool, default='False', 275 | help='True: restore the existing model.
False: retrain') 276 | 277 | args = parser.parse_args() 278 | print('<==================== Loading data ===================>\n') 279 | 280 | if not args.resume: 281 | try: 282 | shutil.rmtree(args.log_dir) 283 | except: 284 | pass 285 | 286 | if not os.path.exists(args.model_dir): 287 | os.makedirs(args.model_dir) 288 | if not os.path.exists(args.log_dir): 289 | os.makedirs(args.log_dir) 290 | if not os.path.exists(args.results_dir): 291 | os.makedirs(args.results_dir) 292 | 293 | if args.mode == "train": 294 | train(args) 295 | else: 296 | test(args) 297 | 298 | 299 | if __name__ == "__main__": 300 | main() 301 | -------------------------------------------------------------------------------- /code/homography_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils.torch_spatial_transformer import transformer 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | def L1_smooth_loss(x, y): 8 | abs_diff = torch.abs(x - y) 9 | abs_diff_lt_1 = torch.le(abs_diff, 1) 10 | return torch.mean(torch.where(abs_diff_lt_1, 0.5 * abs_diff ** 2, abs_diff - 0.5)) 11 | 12 | def SSIM_loss(x, y, size=3): 13 | # C_i = (K_i * L)^2 with L = the dynamic range of the pixel values and K_i << 1 (here K1 = 0.01, K2 = 0.03) 14 | C1 = 0.01 ** 2 15 | C2 = 0.03 ** 2 16 | 17 | mu_x = F.avg_pool2d(x, size, 1, padding=0) 18 | mu_y = F.avg_pool2d(y, size, 1, padding=0) 19 | 20 | sigma_x = F.avg_pool2d(x ** 2, size, 1, padding=0) - mu_x ** 2 21 | sigma_y = F.avg_pool2d(y ** 2, size, 1, padding=0) - mu_y ** 2 22 | sigma_xy = F.avg_pool2d(x * y, size, 1, padding=0) - mu_x * mu_y 23 | 24 | SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) 25 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2) 26 | 27 | SSIM = SSIM_n / SSIM_d 28 | 29 | return torch.clamp((1 - SSIM) / 2, 0, 1) 30 | 31 | def NCC_loss(x, y): 32 | """Treat x and y as vectors; take the L2 norm
of the difference 33 | of them after each is normalized by its length""" 34 | len_x = torch.sqrt(torch.sum(x ** 2)) 35 | len_y = torch.sqrt(torch.sum(y ** 2)) 36 | return torch.sqrt(torch.sum((x / len_x - y / len_y) ** 2)) 37 | 38 | class ConvBlock(nn.Module): 39 | def __init__(self, inchannels, outchannels, batch_norm=False, pool=True): 40 | super(ConvBlock, self).__init__() 41 | layers = [] 42 | layers.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1)) 43 | layers.append(nn.ReLU(inplace=True)) 44 | if batch_norm: 45 | layers.append(nn.BatchNorm2d(outchannels)) 46 | layers.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1)) 47 | layers.append(nn.ReLU(inplace=True)) 48 | if batch_norm: 49 | layers.append(nn.BatchNorm2d(outchannels)) 50 | if pool: 51 | layers.append(nn.MaxPool2d(2, 2)) 52 | self.layers = nn.Sequential(*layers) 53 | 54 | def forward(self, x): 55 | return self.layers(x) 56 | 57 | class HomographyModel(nn.Module): 58 | def __init__(self, batch_norm=False): 59 | super(HomographyModel, self).__init__() 60 | self.feature = nn.Sequential( 61 | ConvBlock(2, 64, batch_norm), 62 | ConvBlock(64, 64, batch_norm), 63 | ConvBlock(64, 128, batch_norm), 64 | ConvBlock(128, 128, batch_norm, pool=False), 65 | ) 66 | self.fc = nn.Sequential( 67 | nn.Dropout(0.5), 68 | nn.Linear(128 * 16 * 16, 1024), 69 | nn.ReLU(inplace=True), 70 | nn.Dropout(0.5), 71 | nn.Linear(1024, 8) 72 | ) 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 77 | if m.bias is not None: 78 | nn.init.constant_(m.bias, 0) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | nn.init.constant_(m.weight, 1) 81 | nn.init.constant_(m.bias, 0) 82 | elif isinstance(m, nn.Linear): 83 | nn.init.normal_(m.weight, 0, 0.01) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | def forward(self, I1_aug, I2_aug, I_aug, h4p, gt, patch_indices): 87 | batch_size, _, img_h, img_w = I_aug.size() 88 | _, _, patch_size, patch_size = I1_aug.size() 89 | 90 | y_t = torch.arange(0, batch_size * img_w * img_h, 91 | img_w * img_h) 92 | batch_indices_tensor = y_t.unsqueeze(1).expand(y_t.shape[0], patch_size * patch_size).reshape(-1) 93 | 94 | M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0], 95 | [0., img_h / 2.0, img_h / 2.0], 96 | [0., 0., 1.]]) 97 | 98 | if torch.cuda.is_available(): 99 | M_tensor = M_tensor.cuda() 100 | batch_indices_tensor = batch_indices_tensor.cuda() 101 | 102 | M_tile = M_tensor.unsqueeze(0).expand(batch_size, M_tensor.shape[-2], M_tensor.shape[-1]) 103 | 104 | # Inverse of M 105 | M_tensor_inv = torch.inverse(M_tensor) 106 | M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, M_tensor_inv.shape[-2], 107 | M_tensor_inv.shape[-1]) 108 | 109 | pred_h4p = self.build_model(I1_aug, I2_aug) 110 | 111 | H_mat = self.solve_DLT(h4p, pred_h4p).squeeze(1) 112 | 113 | pred_I2 = self.transform(patch_size, M_tile_inv, H_mat, M_tile, 114 | I_aug, patch_indices, batch_indices_tensor) 115 | 116 | h_loss = torch.sqrt(torch.mean((pred_h4p - gt) ** 2)) 117 | rec_loss, ssim_loss, l1_loss, l1_smooth_loss, ncc_loss = self.build_losses(pred_I2, I2_aug) 118 | 119 | out_dict = {} 120 | out_dict.update(h_loss=h_loss, rec_loss=rec_loss, ssim_loss=ssim_loss, l1_loss=l1_loss, 121 | l1_smooth_loss=l1_smooth_loss, ncc_loss=ncc_loss, 122 | pred_h4p=pred_h4p, H_mat=H_mat, pred_I2=pred_I2) 123 | 124 | return out_dict 125 | 126 | def build_model(self, I1_aug, I2_aug): 127 | model_input = torch.cat([I1_aug, I2_aug], dim=1) 128 |
x = self.feature(model_input) 129 | x = x.view(x.size(0), -1) 130 | x = self.fc(x) 131 | return x 132 | 133 | def solve_DLT(self, src_p, off_set): 134 | # src_p: shape=(bs, 8), the four source corners flattened 135 | # off_set: shape=(bs, 8), the predicted corner offsets 136 | # the tiling below generalizes to mesh points (multi-H) 137 | 138 | bs, _ = src_p.shape 139 | divide = int(np.sqrt(len(src_p[0]) / 2) - 1) 140 | row_num = (divide + 1) * 2 141 | 142 | for i in range(divide): 143 | for j in range(divide): 144 | 145 | h4p = src_p[:, [2 * j + row_num * i, 2 * j + row_num * i + 1, 146 | 2 * (j + 1) + row_num * i, 2 * (j + 1) + row_num * i + 1, 147 | 2 * (j + 1) + row_num * i + row_num, 2 * (j + 1) + row_num * i + row_num + 1, 148 | 2 * j + row_num * i + row_num, 2 * j + row_num * i + row_num + 1]].reshape(bs, 1, 4, 2) 149 | 150 | pred_h4p = off_set[:, [2 * j + row_num * i, 2 * j + row_num * i + 1, 151 | 2 * (j + 1) + row_num * i, 2 * (j + 1) + row_num * i + 1, 152 | 2 * (j + 1) + row_num * i + row_num, 2 * (j + 1) + row_num * i + row_num + 1, 153 | 2 * j + row_num * i + row_num, 2 * j + row_num * i + row_num + 1]].reshape(bs, 1, 154 | 4, 2) 155 | 156 | if i + j == 0: 157 | src_ps = h4p 158 | off_sets = pred_h4p 159 | else: 160 | src_ps = torch.cat((src_ps, h4p), axis=1) 161 | off_sets = torch.cat((off_sets, pred_h4p), axis=1) 162 | 163 | bs, n, h, w = src_ps.shape 164 | 165 | N = bs * n 166 | 167 | src_ps = src_ps.reshape(N, h, w) 168 | off_sets = off_sets.reshape(N, h, w) 169 | 170 | dst_p = src_ps + off_sets 171 | 172 | ones = torch.ones(N, 4, 1) 173 | if torch.cuda.is_available(): 174 | ones = ones.cuda() 175 | xy1 = torch.cat((src_ps, ones), 2) 176 | zeros = torch.zeros_like(xy1) 177 | if torch.cuda.is_available(): 178 | zeros = zeros.cuda() 179 | 180 | xyu, xyd = torch.cat((xy1, zeros), 2), torch.cat((zeros, xy1), 2) 181 | M1 = torch.cat((xyu, xyd), 2).reshape(N, -1, 6) 182 | M2 = torch.matmul( 183 | dst_p.reshape(-1, 2, 1), 184 | src_ps.reshape(-1, 1, 2), 185 | ).reshape(N, -1, 2) 186 | 187 | A = torch.cat((M1, -M2), 2) 188 | b = dst_p.reshape(N, -1, 1) 189 | 190 | Ainv = torch.inverse(A) 191 | h8 = torch.matmul(Ainv, b).reshape(N, 8) 192 | 193 | H = torch.cat((h8, ones[:, 0, :]), 1).reshape(N, 3, 3) 194 | H = H.reshape(bs, n, 3, 3) 195 | return H 196 | 197 | def transform(self, patch_size, M_tile_inv, H_mat, M_tile, I, patch_indices, batch_indices_tensor): 198 | # Transform H_mat since we scale image indices in transformer 199 | batch_size, num_channels, img_h, img_w = I.size() 200 | # if torch.cuda.is_available(): 201 | # M_tile_inv = M_tile_inv.cuda() 202 | H_mat = torch.matmul(torch.matmul(M_tile_inv, H_mat), M_tile) 203 | # Transform image 1 (large image) to image 2 204 | out_size = (img_h, img_w) 205 | warped_images, _ = transformer(I, H_mat, out_size) 206 | 207 | # Extract the warped patch from warped_images by flattening the whole batch before using indices 208 | # Note that input I is 3 channels so we reduce to gray 209 | warped_gray_images = torch.mean(warped_images, dim=3) 210 | warped_images_flat = torch.reshape(warped_gray_images, [-1]) 211 | patch_indices_flat = torch.reshape(patch_indices, [-1]) 212 | pixel_indices = patch_indices_flat.long() + batch_indices_tensor 213 | pred_I2_flat = torch.gather(warped_images_flat, 0, pixel_indices) 214 | 215 | pred_I2 = torch.reshape(pred_I2_flat, [batch_size, patch_size, patch_size, 1]) 216 | 217 | return pred_I2.permute(0, 3, 1, 2) 218 | 219 | def build_losses(self, pred_I2, I2_aug): 220 | rec_loss = torch.sqrt(torch.mean((pred_I2 - I2_aug) ** 2)) 221 | ssim_loss =
torch.mean(SSIM_loss(pred_I2, I2_aug)) 222 | l1_loss = torch.mean(torch.abs(pred_I2 - I2_aug)) 223 | l1_smooth_loss = L1_smooth_loss(pred_I2, I2_aug) 224 | ncc_loss = NCC_loss(I2_aug, pred_I2) 225 | return rec_loss, ssim_loss, l1_loss, l1_smooth_loss, ncc_loss 226 | -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/code/utils/__init__.py -------------------------------------------------------------------------------- /code/utils/gen_synthetic_data.py: -------------------------------------------------------------------------------- 1 | import os, shutil, argparse 2 | import glob 3 | import cv2 4 | import random 5 | import numpy as np 6 | from numpy.linalg import inv 7 | from numpy_spatial_transformer import numpy_transformer 8 | 9 | 10 | def homographyGeneration(args, raw_image_path, index, I_dir, I_prime_dir, gt_file, pts1_file, filenames_file): 11 | rho = args.rho 12 | patch_size = args.patch_size 13 | height = args.img_h 14 | width = args.img_w 15 | 16 | try: 17 | color_image = cv2.imread(raw_image_path) 18 | color_image = cv2.resize(color_image, (width, height)) 19 | gray_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY) 20 | except: 21 | print('Error with image:', raw_image_path) 22 | return index, -1 23 | 24 | # Randomly pick the top left point of the patch on the real image 25 | y = random.randint(rho, height - rho - patch_size) # row (y) 26 | x = random.randint(rho, width - rho - patch_size) # col (x) 27 | 28 | # define corners of image patch 29 | top_left_point = (x, y) 30 | top_right_point = (patch_size + x, y) 31 | bottom_right_point = (patch_size + x, patch_size + y) 32 | bottom_left_point = (x, patch_size + y) 33 | four_points = [top_left_point, top_right_point, bottom_right_point, bottom_left_point] 34 | perturbed_four_points = [] 35 | for point in four_points: 36 | perturbed_four_points.append((point[0] + random.randint(-rho, rho), point[1] + random.randint(-rho, rho))) 37 | 38 | # compute Homography 39 | H = cv2.getPerspectiveTransform(np.float32(four_points), np.float32(perturbed_four_points)) 40 | try: 41 | H_inverse = inv(H) 42 | except: 43 | print("singular Error!") 44 | return index, -1 45 | 46 | inv_warped_color_image = None 47 | inv_warped_image = None 48 | if args.color: 49 | inv_warped_color_image = numpy_transformer(color_image, H_inverse, (width, height)) 50 | else: 51 | inv_warped_image = numpy_transformer(gray_image, H_inverse, (width, height)) 52 | 53 | # Extract image patches (not used) 54 | if args.color: 55 | original_patch = gray_image[y:y + patch_size, x:x + patch_size] 56 | else: 57 | warped_patch = inv_warped_image[y:y + patch_size, x:x + patch_size] 58 | 59 | ###################################################################################### 60 | # Save synthetic data I_dir I_prime_dir gt pts1 61 | large_img_path = os.path.join(I_dir, str(index) + '.jpg') 62 | if args.mode == 'train' and args.color == False: 63 | cv2.imwrite(large_img_path, gray_image) 64 | else: 65 | cv2.imwrite(large_img_path, color_image) 66 | 67 | if I_prime_dir is not None: 68 | img_prime_path = os.path.join(I_prime_dir, str(index) + '.jpg') 69 | if args.mode == 'train' and args.color == False: 70 | cv2.imwrite(img_prime_path, inv_warped_image) 71 | else: 72 | cv2.imwrite(img_prime_path, inv_warped_color_image) 73 | 74 | # Text
files to store homography parameters (4 corners) 75 | f_pts1 = open(pts1_file, 'ab') 76 | f_gt = open(gt_file, 'ab') 77 | f_file_list = open(filenames_file, 'ab') 78 | 79 | # Ground truth is delta displacement 80 | gt = np.subtract(np.array(perturbed_four_points), np.array(four_points)) 81 | gt = np.array(gt).flatten().astype(np.float32) 82 | # Four corners in the first image 83 | pts1 = np.array(four_points).flatten().astype(np.float32) 84 | 85 | np.savetxt(f_gt, [gt], fmt='%.1f', delimiter=' ') 86 | np.savetxt(f_pts1, [pts1], fmt='%.1f', delimiter=' ') 87 | f_file_list.write(('%s %s\n' % (str(index) + '.jpg', str(index) + '.jpg')).encode()) 88 | 89 | index += 1 90 | if index % 1000 == 0: 91 | print('--image number ', index) 92 | 93 | f_gt.close() 94 | f_pts1.close() 95 | f_file_list.close() 96 | return index, 0 97 | 98 | 99 | def dataCollection(args): 100 | # Default folders and files for storage 101 | DATA_PATH = None 102 | gt_file = None 103 | pts1_file = None 104 | filenames_file = None 105 | if args.mode == 'train': 106 | DATA_PATH = args.data_path + 'train/' 107 | pts1_file = os.path.join(DATA_PATH, 'pts1.txt') 108 | filenames_file = os.path.join(DATA_PATH, 'train_synthetic.txt') 109 | gt_file = os.path.join(DATA_PATH, 'gt.txt') 110 | elif args.mode == 'test': 111 | DATA_PATH = args.data_path + 'test/' 112 | pts1_file = os.path.join(DATA_PATH, 'test_pts1.txt') 113 | filenames_file = os.path.join(DATA_PATH, 'test_synthetic.txt') 114 | gt_file = os.path.join(DATA_PATH, 'test_gt.txt') 115 | I_dir = DATA_PATH + 'I/' # Large image 116 | I_prime_dir = DATA_PATH + 'I_prime/' # Large image 117 | 118 | try: 119 | os.remove(gt_file) 120 | os.remove(pts1_file) 121 | os.remove(filenames_file) 122 | print('-- {} already exists. Deleting!'.format(gt_file)) 123 | shutil.rmtree(I_dir, ignore_errors=True) 124 | if I_prime_dir is not None: 125 | shutil.rmtree(I_prime_dir, ignore_errors=True) 126 | except: 127 | print('-- {} does not exist yet!'.format(gt_file)) 128 | 129 | if not os.path.exists(I_dir): 130 | os.makedirs(I_dir) 131 | if I_prime_dir is not None and not os.path.exists(I_prime_dir): 132 | os.makedirs(I_prime_dir) 133 | 134 | raw_image_list = glob.glob(os.path.join(args.raw_data_path, '*.jpg')) 135 | print("Generating {} {} samples from {} raw images.".format(args.num_data, args.mode, len(raw_image_list))) 136 | 137 | index = 0 138 | while True: 139 | raw_img_name = random.choice(raw_image_list) 140 | raw_image_path = raw_img_name # glob already returns the full path 141 | index, error = homographyGeneration(args, raw_image_path, index, 142 | I_dir, I_prime_dir, gt_file, pts1_file, filenames_file) 143 | if error == -1: 144 | continue 145 | if index >= args.num_data: 146 | break 147 | 148 | 149 | def main(): 150 | RHO = 45 # The maximum perturbation value 151 | 152 | DATA_NUMBER = 100000 153 | TEST_DATA_NUMBER = 5000 154 | 155 | # Size of synthetic image 156 | HEIGHT = 240 157 | WIDTH = 320 158 | PATCH_SIZE = 128 159 | 160 | # Directories to files 161 | RAW_DATA_PATH = "D:/Workspace/Datasets/coco2014/train2014/" # Real images used for generating synthetic data 162 | TEST_RAW_DATA_PATH = "D:/Workspace/Datasets/coco2014/test2014/" # Real images used for generating test synthetic data 163 | 164 | # Synthetic data directories 165 | DATA_PATH = "../data/synthetic/" + str(RHO) + '/' 166 | if not os.path.exists(DATA_PATH): 167 | os.makedirs(DATA_PATH) 168 | 169 | def str2bool(s): 170 | return s.lower() == 'true' 171 | 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument('--mode',
type=str, default='test', help='Train or test', choices=['train', 'test']) 174 | parser.add_argument('--color', type=str2bool, default='true', help='Generate color or gray images') 175 | 176 | parser.add_argument('--raw_data_path', type=str, default=RAW_DATA_PATH, help='The raw data path.') 177 | parser.add_argument('--test_raw_data_path', type=str, default=TEST_RAW_DATA_PATH, help='The test raw data path.') 178 | parser.add_argument('--data_path', type=str, default=DATA_PATH, help='The synthetic data output path.') 179 | parser.add_argument('--num_data', type=int, default=DATA_NUMBER, help='The data size for training') 180 | parser.add_argument('--test_num_data', type=int, default=TEST_DATA_NUMBER, help='The data size for test') 181 | 182 | parser.add_argument('--img_w', type=int, default=WIDTH) 183 | parser.add_argument('--img_h', type=int, default=HEIGHT) 184 | parser.add_argument('--rho', type=int, default=RHO) 185 | parser.add_argument('--patch_size', type=int, default=PATCH_SIZE) 186 | 187 | args = parser.parse_args() 188 | print('<==================== Loading raw data ===================>\n') 189 | if args.mode == 'test': 190 | args.num_data = args.test_num_data 191 | args.raw_data_path = args.test_raw_data_path 192 | 193 | print('<================= Generating Data .... =================>\n') 194 | 195 | dataCollection(args) 196 | 197 | 198 | if __name__ == '__main__': 199 | main() 200 | -------------------------------------------------------------------------------- /code/utils/numpy_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | # Simple version of spatial_transformer.py that works on a single image with multiple channels 2 | import numpy as np 3 | import cv2 4 | import pdb 5 | import matplotlib.pyplot as plt 6 | from skimage import io 7 | ############################################################### 8 | # Changeable parameter 9 | SCALE_H = True 10 | # scale_H: the indices of the target output grid are 11 | # scaled to [-1, 1]. Set False to stay in pixel coordinates 12 | def _meshgrid(height, width, scale_H = SCALE_H): 13 | if scale_H: 14 | x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), 15 | np.linspace(-1, 1, height)) 16 | ones = np.ones(np.prod(x_t.shape)) 17 | grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 18 | 19 | else: 20 | x_t, y_t = np.meshgrid(range(0,width), range(0,height)) 21 | ones = np.ones(np.prod(x_t.shape)) 22 | grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 23 | # print('--grid size:', grid.shape) 24 | return grid 25 | 26 | 27 | def _interpolate(im, x, y, out_size, scale_H = SCALE_H): 28 | # constants 29 | height = im.shape[0] 30 | width = im.shape[1] 31 | 32 | 33 | height_f = float(height) 34 | width_f = float(width) 35 | out_height = out_size[0] 36 | out_width = out_size[1] 37 | zero = np.zeros([], dtype='int32') 38 | max_y = im.shape[0] - 1 39 | max_x = im.shape[1] - 1 40 | 41 | if scale_H: 42 | # # scale indices from [-1, 1] to [0, width/height] 43 | x = (x + 1.0)*(width_f) / 2.0 44 | y = (y + 1.0)*(height_f) / 2.0 45 | 46 | # do sampling 47 | x0 = np.floor(x).astype(int) 48 | x1 = x0 + 1 49 | y0 = np.floor(y).astype(int) 50 | y1 = y0 + 1 51 | 52 | # print('x0:', x0) 53 | # print('y0:', y0) 54 | # Limit the size of the output image 55 | x0 = np.clip(x0, zero, max_x) 56 | x1 = np.clip(x1, zero, max_x) 57 | y0 = np.clip(y0, zero, max_y) 58 | y1 = np.clip(y1, zero, max_y) 59 | # print('x0:', x0) 60 | # print('y0:', y0) 61 | 62 | Ia = im[ y0, x0, ... ] 63 | Ib = im[ y1, x0, ...
] 64 | Ic = im[ y0, x1, ... ] 65 | Id = im[ y1, x1, ... ] 66 | # print(Ia.shape) 67 | 68 | # print 69 | # plt.figure(2) 70 | # plt.subplot(221) 71 | # plt.imshow(Ia) 72 | # plt.subplot(222) 73 | # plt.imshow(Ib) 74 | # plt.subplot(223) 75 | # plt.imshow(Ic) 76 | # plt.subplot(224) 77 | # plt.imshow(Id) 78 | # plt.show() 79 | 80 | wa = (x1 -x) * (y1-y) 81 | wb = (x1-x) * (y-y0) 82 | wc = (x-x0) * (y1-y) 83 | wd = (x-x0) * (y-y0) 84 | # print 'wabcd...', wa,wb, wc,wd 85 | 86 | # Handle multi channel image 87 | if im.ndim == 3: 88 | num_channels = im.shape[2] 89 | # wa = np.expand_dims(wa, 2) 90 | # wb = np.expand_dims(wb, 2) 91 | # wc = np.expand_dims(wc, 2) 92 | # wd = np.expand_dims(wd, 2) 93 | wa = np.tile(wa.reshape(-1, 1), num_channels) 94 | wb = np.tile(wb.reshape(-1, 1), num_channels) 95 | wc = np.tile(wc.reshape(-1, 1), num_channels) 96 | wd = np.tile(wd.reshape(-1, 1), num_channels) 97 | out = wa*Ia + wb*Ib + wc*Ic + wd*Id 98 | # print('--shape of out:', out.shape) 99 | return out 100 | 101 | def _transform(theta, input_dim, out_size): 102 | height, width = input_dim.shape[0], input_dim.shape[1] 103 | theta = np.reshape(theta, (3, 3)) 104 | # print '--Theta:', theta 105 | # print '-- Theta shape:', theta.shape 106 | 107 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 108 | out_height = out_size[0] 109 | out_width = out_size[1] 110 | grid = _meshgrid(out_height, out_width) 111 | 112 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) 113 | T_g = np.dot(theta, grid) 114 | x_s = T_g[0,:] 115 | y_s = T_g[1,:] 116 | t_s = T_g[2,:] 117 | # print '-- T_g:', T_g 118 | # print '-- x_s:', x_s 119 | # print '-- y_s:', y_s 120 | # print '-- t_s:', t_s 121 | 122 | t_s_flat = np.reshape(t_s, [-1]) 123 | # Ty changed 124 | # x_s_flat = np.reshape(x_s, [-1]) 125 | # y_s_flat = np.reshape(y_s, [-1]) 126 | x_s_flat = np.reshape(x_s, [-1])/t_s_flat 127 | y_s_flat = np.reshape(y_s, [-1])/t_s_flat 128 | 129 | 130 | input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size) 131 | if input_dim.ndim == 3: 132 | output = np.reshape(input_transformed, [out_height, out_width, -1]) 133 | else: 134 | output = np.reshape(input_transformed, [out_height, out_width]) 135 | 136 | output = output.astype(np.uint8) 137 | return output 138 | 139 | 140 | def numpy_transformer(img, H, out_size, scale_H = SCALE_H): 141 | h, w = img.shape[0], img.shape[1] 142 | # Matrix M 143 | M = np.array([[w/2.0, 0, w/2.0], [0, h/2.0, h/2.0], [0, 0, 1.]]).astype(np.float32) 144 | 145 | if scale_H: 146 | H_transformed = np.dot(np.dot(np.linalg.inv(M), np.linalg.inv(H)), M) 147 | # print 'H_transformed:', H_transformed 148 | img2 = _transform(H_transformed, img, [h,w]) 149 | else: 150 | img2 = _transform(np.linalg.inv(H), img, [h,w]) 151 | return img2 152 | 153 | 154 | def test_transformer(scale_H = SCALE_H): 155 | img = io.imread('D:/Workspace/Datasets/ms_coco_test_images/COCO_test2014_000000000001.jpg') 156 | h, w = img.shape[0], img.shape[1] 157 | print( '-- h, w:', h, w ) 158 | 159 | 160 | # Apply homography transformation 161 | 162 | H = np.array([[2., 0.3, 5], [0.3, 2., 10.], [0.0001, 0.0002, 1.]]).astype(np.float32) 163 | img2 = cv2.warpPerspective(img, H, (w, h)) 164 | 165 | 166 | # # Matrix M 167 | M = np.array([[w/2.0, 0, w/2.0], [0, h/2.0, h/2.0], [0, 0, 1.]]).astype(np.float32) 168 | 169 | if scale_H: 170 | H_transformed = np.dot(np.dot(np.linalg.inv(M), np.linalg.inv(H)), M) 171 | print('H_transformed:', H_transformed) 172 | img3 = _transform(H_transformed, img, [h,w]) 173 | else: 174 | img3 = 
_transform(np.linalg.inv(H), img, [h,w]) 175 | 176 | print ( '-- Reprojection error:', np.mean(np.abs(img3.astype(int) - img2.astype(int)))) 177 | Reprojection = np.abs(img3.astype(int) - img2.astype(int)) # cast to int to avoid uint8 wrap-around 178 | # Test on real image 179 | count = 0 180 | amount = 0 181 | for i in range(48): 182 | for j in range(48): 183 | for k in range(3): 184 | if Reprojection[i, j, k] > 10: 185 | print(i, j, k, 'value', Reprojection[i, j, k]) 186 | count += 1 187 | amount += Reprojection[i, j, k] 188 | print('There are %d values > 10 out of %d; mean of those values: %.3f' % (count, 48*48*3, amount * 1.0 / max(count, 1))) 189 | 190 | #io.imshow('img3', img3) 191 | try: 192 | plt.subplot(221) 193 | plt.imshow(img) 194 | plt.title('Original image') 195 | 196 | plt.subplot(222) 197 | plt.imshow(img2) 198 | plt.title('cv2.warpPerspective') 199 | 200 | plt.subplot(223) 201 | plt.imshow(img3) 202 | plt.title('Transformer') 203 | 204 | plt.subplot(224) 205 | plt.imshow(Reprojection) 206 | plt.title('Reprojection Error') 207 | plt.show() 208 | except KeyboardInterrupt: 209 | plt.close() 210 | exit(1) 211 | 212 | 213 | if __name__ == "__main__": 214 | test_transformer() 215 | -------------------------------------------------------------------------------- /code/utils/torch_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def transformer(U, theta, out_size, **kwargs): 5 | """Spatial Transformer Layer 6 | 7 | Implements a spatial transformer layer as described in [1]_. 8 | Based on [2]_ and edited by David Dao for Tensorflow. 9 | 10 | Parameters 11 | ---------- 12 | U : float 13 | The output of a convolutional net should have the 14 | shape [num_batch, num_channels, height, width] (NCHW in this PyTorch port). 15 | theta: float 16 | The output of the 17 | localisation network should be [num_batch, 9] (a flattened 3x3 homography). 18 | out_size: tuple of two ints 19 | The size of the output of the network (height, width) 20 | 21 | References 22 | ---------- 23 | .. [1] Spatial Transformer Networks 24 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 25 | Submitted on 5 Jun 2015 26 | ..
[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 27 | 28 | Notes 29 | ----- 30 | To initialize the network to the identity transform, init 31 | ``theta`` to the flattened 3x3 identity: 32 | identity = np.eye(3) 33 | identity = identity.flatten() 34 | theta = torch.tensor(identity) 35 | 36 | 37 | """ 38 | 39 | def _repeat(x, n_repeats): 40 | 41 | rep = torch.ones([n_repeats, ]).unsqueeze(0) 42 | rep = rep.int() 43 | x = x.int() 44 | 45 | x = torch.matmul(x.reshape([-1,1]), rep) 46 | return x.reshape([-1]) 47 | 48 | def _interpolate(im, x, y, out_size, scale_h): 49 | 50 | num_batch, num_channels , height, width = im.size() 51 | 52 | height_f = height 53 | width_f = width 54 | out_height, out_width = out_size[0], out_size[1] 55 | 56 | zero = 0 57 | max_y = height - 1 58 | max_x = width - 1 59 | if scale_h: 60 | 61 | x = (x + 1.0)*(width_f) / 2.0 62 | y = (y + 1.0) * (height_f) / 2.0 63 | 64 | # do sampling 65 | x0 = torch.floor(x).int() 66 | x1 = x0 + 1 67 | y0 = torch.floor(y).int() 68 | y1 = y0 + 1 69 | 70 | x0 = torch.clamp(x0, zero, max_x) 71 | x1 = torch.clamp(x1, zero, max_x) 72 | y0 = torch.clamp(y0, zero, max_y) 73 | y1 = torch.clamp(y1, zero, max_y) 74 | dim2 = torch.from_numpy( np.array(width) ) 75 | dim1 = torch.from_numpy( np.array(width * height) ) 76 | 77 | base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width) 78 | if torch.cuda.is_available(): 79 | dim2 = dim2.cuda() 80 | dim1 = dim1.cuda() 81 | y0 = y0.cuda() 82 | y1 = y1.cuda() 83 | x0 = x0.cuda() 84 | x1 = x1.cuda() 85 | base = base.cuda() 86 | base_y0 = base + y0 * dim2 87 | base_y1 = base + y1 * dim2 88 | idx_a = base_y0 + x0 89 | idx_b = base_y1 + x0 90 | idx_c = base_y0 + x1 91 | idx_d = base_y1 + x1 92 | 93 | # channels dim 94 | im = im.permute(0,2,3,1) 95 | im_flat = im.reshape([-1, num_channels]).float() 96 | 97 | idx_a = idx_a.unsqueeze(-1).long() 98 | idx_a = idx_a.expand(height * width * num_batch,num_channels) 99 | Ia = torch.gather(im_flat, 0, idx_a) 100 | 101 | idx_b = idx_b.unsqueeze(-1).long() 102 | idx_b = idx_b.expand(height * width * num_batch, num_channels) 103 | Ib = torch.gather(im_flat, 0, idx_b) 104 | 105 | idx_c = idx_c.unsqueeze(-1).long() 106 | idx_c = idx_c.expand(height * width * num_batch, num_channels) 107 | Ic = torch.gather(im_flat, 0, idx_c) 108 | 109 | idx_d = idx_d.unsqueeze(-1).long() 110 | idx_d = idx_d.expand(height * width * num_batch, num_channels) 111 | Id = torch.gather(im_flat, 0, idx_d) 112 | 113 | x0_f = x0.float() 114 | x1_f = x1.float() 115 | y0_f = y0.float() 116 | y1_f = y1.float() 117 | 118 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1) 119 | wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1) 120 | wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1) 121 | wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1) 122 | output = wa*Ia+wb*Ib+wc*Ic+wd*Id 123 | 124 | return output 125 | 126 | def _meshgrid(height, width, scale_h): 127 | 128 | if scale_h: 129 | x_t = torch.matmul(torch.ones([height, 1]), 130 | torch.transpose(torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 1), 1, 0)) 131 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1), 132 | torch.ones([1, width])) 133 | else: 134 | x_t = torch.matmul(torch.ones([height, 1]), 135 | torch.transpose(torch.unsqueeze(torch.linspace(0.0, width.float(), width), 1), 1, 0)) 136 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(0.0, height.float(), height), 1), 137 | torch.ones([1, width])) 138 | 139 | 140 | x_t_flat = x_t.reshape((1, -1)).float()
141 | y_t_flat = y_t.reshape((1, -1)).float() 142 | 143 | ones = torch.ones_like(x_t_flat) 144 | grid = torch.cat([x_t_flat, y_t_flat, ones], 0) 145 | if torch.cuda.is_available(): 146 | grid = grid.cuda() 147 | return grid 148 | 149 | def _transform(theta, input_dim, out_size, scale_h): 150 | num_batch, num_channels , height, width = input_dim.size() 151 | # Changed 152 | theta = theta.reshape([-1, 3, 3]).float() 153 | 154 | out_height, out_width = out_size[0], out_size[1] 155 | grid = _meshgrid(out_height, out_width, scale_h) 156 | grid = grid.unsqueeze(0).reshape([1,-1]) 157 | shape = grid.size() 158 | grid = grid.expand(num_batch,shape[1]) 159 | grid = grid.reshape([num_batch, 3, -1]) 160 | 161 | T_g = torch.matmul(theta, grid) 162 | x_s = T_g[:,0,:] 163 | y_s = T_g[:,1,:] 164 | t_s = T_g[:,2,:] 165 | 166 | t_s_flat = t_s.reshape([-1]) 167 | 168 | # smaller 169 | small = 1e-7 170 | smallers = 1e-6*(1.0 - torch.ge(torch.abs(t_s_flat), small).float()) 171 | 172 | t_s_flat = t_s_flat + smallers 173 | condition = torch.sum(torch.gt(torch.abs(t_s_flat), small).float()) 174 | # Ty changed 175 | x_s_flat = x_s.reshape([-1]) / t_s_flat 176 | y_s_flat = y_s.reshape([-1]) / t_s_flat 177 | 178 | input_transformed = _interpolate( input_dim, x_s_flat, y_s_flat,out_size,scale_h) 179 | 180 | output = input_transformed.reshape([num_batch, out_height, out_width, num_channels ]) 181 | return output, condition 182 | 183 | img_h = U.size()[2] # NCHW input; these two are unused 184 | img_w = U.size()[3] 185 | 186 | scale_h = True 187 | output, condition = _transform(theta, U, out_size, scale_h) 188 | return output, condition -------------------------------------------------------------------------------- /code/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def denorm_img(img): 8 | if len(img.shape) == 3: 9 | img = np.transpose(img, [1, 2, 0]) # torch [C,H,W] 10 | mean = np.array([118.93, 113.97, 102.60]).reshape([1, 1, 3]) 11 | std = np.array([69.85, 68.81, 72.45]).reshape([1, 1, 3]) 12 | elif len(img.shape) == 2: 13 | mean = np.mean([118.93, 113.97, 102.60]) 14 | std = np.mean([69.85, 68.81, 72.45]) 15 | return img * std + mean 16 | 17 | 18 | def save_correspondences_img(img1, img2, corr1, corr2, pred_corr2, results_dir, img_name): 19 | """ Save pair of images with their correspondences into a single image. Used for report""" 20 | # Draw prediction 21 | copy_img2 = img2.copy() 22 | copy_img1 = img1.copy() 23 | cv2.polylines(copy_img2, np.int32([pred_corr2]), 1, (5, 225, 225), 3) 24 | 25 | point_color = (0, 255, 255) 26 | line_color_set = [(255, 102, 255), (51, 153, 255), (102, 255, 255), (255, 255, 0), (102, 102, 244), (150, 202, 178), 27 | (153, 240, 142), (102, 0, 51), (51, 51, 0)] 28 | # Draw 4 points (ground truth) 29 | full_stack_images = draw_matches(copy_img1, corr1, copy_img2, corr2, '/tmp/tmp.jpg', color_set=line_color_set, 30 | show=False) 31 | # Save image 32 | visual_file_name = os.path.join(results_dir, img_name) 33 | # cv2.putText(full_stack_images, 'RMSE %.2f'%h_loss,(800, 100), cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2) 34 | cv2.imwrite(visual_file_name, full_stack_images) 35 | print('Wrote file %s' % visual_file_name) 36 | 37 | 38 | def draw_matches(img1, kp1, img2, kp2, output_img_file=None, color_set=None, show=True): 39 | """Draws lines between matching keypoints of two images without needing cv2.DMatch inputs.
40 | This is a replacement for cv2.drawMatches. 41 | Places the images side by side in a new image and draws circles 42 | around each keypoint, with line segments connecting matching pairs. 43 | You can tweak the r, thickness, and figsize values as needed. 44 | Args: 45 | img1: An openCV image ndarray in a grayscale or color format. 46 | kp1: An ndarray of (x, y) keypoint coordinates for img1. 47 | img2: An openCV image ndarray of the same format and with the same 48 | element type as img1. 49 | kp2: An ndarray of (x, y) keypoint coordinates for img2. 50 | color_set: The colors of the circles and connecting lines drawn on the images. 51 | A 3-tuple for color images, a scalar for grayscale images. If None, these 52 | values are randomly generated. 53 | """ 54 | # We're drawing them side by side. Get dimensions accordingly. 55 | # Handle both color and grayscale images. 56 | if len(img1.shape) == 3: 57 | new_shape = (max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1], img1.shape[2]) 58 | elif len(img1.shape) == 2: 59 | new_shape = (max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]) 60 | new_img = np.zeros(new_shape, type(img1.flat[0])) 61 | # Place images onto the new image. 62 | new_img[0:img1.shape[0], 0:img1.shape[1]] = img1 63 | new_img[0:img2.shape[0], img1.shape[1]:img1.shape[1] + img2.shape[1]] = img2 64 | 65 | # Draw lines between points 66 | 67 | kp2_on_stack_image = (kp2 + np.array([img1.shape[1], 0])).astype(np.int32) 68 | 69 | kp1 = kp1.astype(np.int32) 70 | # kp2_on_stack_image[0:4,0:2] 71 | line_color1 = (2, 10, 240) 72 | line_color2 = (2, 10, 240) 73 | # We want to make connections between points to make a square grid so first count the number of rows in the square grid. 74 | grid_num_rows = int(np.sqrt(kp1.shape[0])) 75 | 76 | if output_img_file is not None and grid_num_rows >= 3: 77 | for i in range(grid_num_rows): 78 | cv2.line(new_img, tuple(kp1[i * grid_num_rows]), tuple(kp1[i * grid_num_rows + (grid_num_rows - 1)]), 79 | line_color1, 1, cv2.LINE_AA) 80 | cv2.line(new_img, tuple(kp1[i]), tuple(kp1[i + (grid_num_rows - 1) * grid_num_rows]), line_color1, 1, 81 | cv2.LINE_AA) 82 | cv2.line(new_img, tuple(kp2_on_stack_image[i * grid_num_rows]), 83 | tuple(kp2_on_stack_image[i * grid_num_rows + (grid_num_rows - 1)]), line_color2, 1, cv2.LINE_AA) 84 | cv2.line(new_img, tuple(kp2_on_stack_image[i]), 85 | tuple(kp2_on_stack_image[i + (grid_num_rows - 1) * grid_num_rows]), line_color2, 1, cv2.LINE_AA) 86 | 87 | if output_img_file is not None and grid_num_rows == 2: 88 | cv2.polylines(new_img, np.int32([kp2_on_stack_image]), 1, line_color2, 3) 89 | cv2.polylines(new_img, np.int32([kp1]), 1, line_color1, 3) 90 | # Draw lines between matches. Make sure to offset kp coords in second image appropriately. 91 | r = 7 92 | thickness = 1 93 | 94 | for i in range(len(kp1)): 95 | key1 = kp1[i] 96 | key2 = kp2[i] 97 | # Generate random color for RGB/BGR and grayscale images as needed. 98 | try: 99 | c = color_set[i] 100 | except: 101 | c = np.random.randint(0, 256, 3) if len(img1.shape) == 3 else np.random.randint(0, 256) 102 | # So the keypoint locs are stored as a tuple of floats. cv2.line(), like most other things, 103 | # wants locs as a tuple of ints.
104 | end1 = tuple(np.round(key1).astype(int)) 105 | end2 = tuple(np.round(key2).astype(int) + np.array([img1.shape[1], 0])) 106 | cv2.line(new_img, end1, end2, c, thickness, cv2.LINE_AA) 107 | cv2.circle(new_img, end1, r, c, thickness, cv2.LINE_AA) 108 | cv2.circle(new_img, end2, r, c, thickness, cv2.LINE_AA) 109 | # pdb.set_trace() 110 | if show: 111 | plt.figure(figsize=(15, 15)) 112 | if len(img1.shape) == 3: 113 | plt.imshow(new_img) 114 | else: 115 | plt.imshow(new_img, cmap='gray') 116 | plt.axis('off') 117 | plt.show() 118 | if output_img_file is not None: 119 | cv2.imwrite(output_img_file, new_img) 120 | 121 | return new_img 122 | 123 | 124 | def find_percentile(x): 125 | x_sorted = np.sort(x) 126 | len_x = len(x_sorted) 127 | 128 | # Mean and std of the sorted values over the 0-30%, 30-60%, 60-100% intervals 129 | tops_list = [0.3, 0.6, 1] 130 | return_list = [] 131 | start_index = 0 132 | for i in range(len(tops_list)): 133 | print('>> Top %.0f - %.0f %%' % (tops_list[i - 1] * 100 if i >= 1 else 0, tops_list[i] * 100)) 134 | stop_index = int(tops_list[i] * len_x) 135 | 136 | interval = x_sorted[start_index:stop_index] 137 | interval_mu = np.mean(interval) 138 | interval_std = np.std(interval) 139 | 140 | start_index = stop_index 141 | return_list.append([interval_mu, interval_std]) 142 | return np.array(return_list) 143 | -------------------------------------------------------------------------------- /results/synthetic/report/0025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0025.jpg -------------------------------------------------------------------------------- /results/synthetic/report/0026.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0026.jpg -------------------------------------------------------------------------------- /results/synthetic/report/0027.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0027.jpg -------------------------------------------------------------------------------- /results/synthetic/report/0028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0028.jpg -------------------------------------------------------------------------------- /results/synthetic/report/0029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0029.jpg --------------------------------------------------------------------------------
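A closing note on the 4-point parameterization used throughout this repository: gen_synthetic_data.py writes the ground truth as the 8 corner offsets, and solve_DLT recovers the 3x3 matrix from exactly those offsets, so the two representations are interchangeable. Below is a minimal sketch that checks this round trip with OpenCV; the corner and offset values are hypothetical, chosen only for illustration.

```python
import numpy as np
import cv2

# Four source corners of a 128x128 patch and hypothetical perturbations
# (the same delta-displacement format that gt.txt stores).
pts1 = np.float32([[50, 40], [178, 40], [178, 168], [50, 168]])
gt = np.float32([[3, -7], [-5, 2], [8, 4], [-6, -1]])
pts2 = pts1 + gt

# 3x3 homography from the 4-point correspondence, as gen_synthetic_data.py does.
H = cv2.getPerspectiveTransform(pts1, pts2)

# Round trip: H maps pts1 back onto pts2, so predicting the 8 corner offsets
# (the network's pred_h4p) is equivalent to predicting H itself (up to scale).
proj = cv2.perspectiveTransform(pts1.reshape(-1, 1, 2), H).reshape(-1, 2)
print(np.allclose(proj, pts2, atol=1e-3))  # True
```

This equivalence is why the model can regress only 8 numbers per pair while the photometric loss, which needs a full warp, is still computable through solve_DLT and the spatial transformer.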