├── README.md
├── code
│   ├── dataset.py
│   ├── homography_CNN_synthetic.py
│   ├── homography_model.py
│   └── utils
│       ├── __init__.py
│       ├── gen_synthetic_data.py
│       ├── numpy_spatial_transformer.py
│       ├── torch_spatial_transformer.py
│       └── utils.py
└── results
    └── synthetic
        └── report
            ├── 0025.jpg
            ├── 0026.jpg
            ├── 0027.jpg
            ├── 0028.jpg
            └── 0029.jpg
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Deep Homography - PyTorch Implementation
2 |
3 | [**Unsupervised Deep Homography: A Fast and Robust Homography Estimation
4 | Model**](https://arxiv.org/abs/1709.03966)
5 | Ty Nguyen, Steven W. Chen, Shreyas S. Shivakumar, Camillo J. Taylor, Vijay
6 | Kumar
7 |
8 | ```bash
9 | cd code/
10 | ```
11 | In the `code/` folder:
12 |
13 | - `dataset.py`: `SyntheticDataset(torch.utils.data.Dataset)` implementation
14 | - `homography_model.py`: unsupervised deep homography model implementation
15 | - `homography_CNN_synthetic.py`: training and testing entry point
16 |
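A minimal usage sketch, assuming the synthetic data has already been generated under `../data/synthetic/45/` (see below):

```python
from torch.utils.data import DataLoader
from dataset import SyntheticDataset

# Mirrors the __main__ block of dataset.py
dataset = SyntheticDataset(data_path="../data/synthetic/45/", mode="train",
                           img_h=240, img_w=320, patch_size=128, do_augment=0.5)
loader = DataLoader(dataset, batch_size=128, shuffle=True)
I1, I2, I1_aug, I2_aug, I_aug, I_prime_aug, pts1, gt, patch_indices = next(iter(loader))
```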
17 | ## Preparing training dataset (synthetic)
18 | Download the MS-COCO 2014 dataset.
19 | Store the train and test sets in RAW_DATA_PATH and TEST_RAW_DATA_PATH respectively.
20 | ### Generate training dataset
21 | Generating the 100,000 training samples takes a few hours.
22 | ```bash
23 | python utils/gen_synthetic_data.py --mode train
24 | ```
25 | ### Generate test dataset
26 | ```bash
27 | python utils/gen_synthetic_data.py --mode test
28 | ```
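Both commands write under `--data_path` (default `../data/synthetic/45/`); following the code in `gen_synthetic_data.py`, the resulting layout is:

```
data/synthetic/45/
├── train
│   ├── I                     # large source images
│   ├── I_prime               # warped counterparts
│   ├── gt.txt                # ground-truth 4-point offsets
│   ├── pts1.txt              # patch corners in I
│   └── train_synthetic.txt   # image filename pairs
└── test                      # same structure, with test_-prefixed text files
```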
29 |
30 | ## Train model with synthetic dataset
31 | ```bash
32 | python homography_CNN_synthetic.py --mode train
33 | ```
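Training logs are written to `../logs/` by default (`--log_dir`) and can be monitored with TensorBoard:

```bash
tensorboard --logdir ../logs
```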
34 |
35 | ## Test model with synthetic dataset
36 | Download the pre-trained weights:
37 |
38 | Link: https://pan.baidu.com/s/102ilb5HJGydpeHtYelx_Xw (extraction code: boq9)
39 |
40 | Store the model in the `models/synthetic_models` folder, then run:
41 | ```bash
42 | python homography_CNN_synthetic.py --mode test
43 | ```
44 |
45 | results |
46 | --- |
47 |  |
48 |  |
49 |  |
50 |  |
51 |  |
52 |
53 | ## Release History
54 | * **2021.4.5**
55 | * Added TensorBoard visualization and some metrics.
56 |
57 | ## Reference
58 | [https://github.com/tynguyen/unsupervisedDeepHomographyRAL2018](https://github.com/tynguyen/unsupervisedDeepHomographyRAL2018)
59 |
--------------------------------------------------------------------------------
/code/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2 as cv
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | def read_img_and_gt(filenames_file, pts1_file, gt_file):
8 | with open(filenames_file, 'r') as img_f:
9 | img_array = img_f.readlines()
10 | img_array = [x.strip() for x in img_array]
11 | img_array = [x.split() for x in img_array] # Use x.split()[0] if assuming image left and right have same name
12 |
13 | with open(pts1_file, 'r') as pts1_f:
14 | pts1_array = pts1_f.readlines()
15 | pts1_array = [x.strip() for x in pts1_array]
16 | pts1_array = [x.split() for x in pts1_array]
17 | pts1_array = np.array(pts1_array).astype('float64')
18 |
19 | # In case there is no ground truth
20 | if not gt_file:
21 | return img_array, pts1_array, None
22 |
23 | with open(gt_file, 'r') as gt_f:
24 | gt_array = gt_f.readlines()
25 | gt_array = [x.strip() for x in gt_array]
26 | gt_array = [x.split() for x in gt_array]
27 | gt_array = np.array(gt_array).astype('float64')
28 |
29 | return img_array, pts1_array, gt_array
30 |
31 | def get_mesh_grid_per_img(patch_w, patch_h):
32 | x_flat = np.arange(0, patch_w)
33 | x_flat = x_flat[np.newaxis, :]
34 | y_one = np.ones(patch_h)
35 | y_one = y_one[:, np.newaxis]
36 | x_mesh = np.matmul(y_one, x_flat)
37 |
38 | y_flat = np.arange(0, patch_h)
39 | y_flat = y_flat[:, np.newaxis]
40 | x_one = np.ones(patch_w)
41 | x_one = x_one[np.newaxis, :]
42 | y_mesh = np.matmul(y_flat, x_one)
43 | x_t_flat = np.reshape(x_mesh, (-1))
44 | y_t_flat = np.reshape(y_mesh, (-1))
45 | return x_t_flat, y_t_flat
46 |
47 | class SyntheticDataset(Dataset):
48 | """Load synthetic data"""
49 | def __init__(self, data_path, mode, img_h, img_w, patch_size, do_augment):
50 | self.mode = mode
51 | if self.mode == "train":
52 | self.data_path = data_path + "train/"
53 | self.pts1_file = os.path.join(self.data_path, 'pts1.txt')
54 | self.filenames_file = os.path.join(self.data_path, 'train_synthetic.txt')
55 | self.gt_file = os.path.join(self.data_path, 'gt.txt')
56 | elif self.mode == "test":
57 | self.data_path = data_path + "test/"
58 | self.pts1_file = os.path.join(self.data_path, 'test_pts1.txt')
59 | self.filenames_file = os.path.join(self.data_path, 'test_synthetic.txt')
60 | self.gt_file = os.path.join(self.data_path, 'test_gt.txt')
61 | self.img_h = img_h
62 | self.img_w = img_w
63 | self.patch_size = patch_size
64 | self.do_augment = do_augment
65 | self.mean_I = np.reshape(np.array([118.93, 113.97, 102.60]), (1, 1, 3))
66 | self.std_I = np.reshape(np.array([69.85, 68.81, 72.45]), (1, 1, 3))
67 |
68 | # Read to arrays
69 | self.img_np, self.pts1_np, self.gt_np = read_img_and_gt(self.filenames_file, self.pts1_file, self.gt_file)
70 |
71 | # Find indices of the pixels in the patch w.r.t the large image
72 | # All patches have the same size so their pixels have the same base indices
73 | self.x_t_flat, self.y_t_flat = get_mesh_grid_per_img(patch_size, patch_size)
74 |
75 | def __len__(self):
76 | return len(self.img_np)
77 |
78 | def __getitem__(self, index):
79 | pts1_index = self.pts1_np[index]
80 | gt_index = self.gt_np[index]
81 | split_line = self.img_np[index]
82 |
83 | I_path = self.data_path + 'I/' + split_line[0]
84 | I_prime_path = self.data_path + 'I_prime/' + split_line[1]
85 |
86 | I = self.read_image(I_path, self.img_h, self.img_w)
87 | I_prime = self.read_image(I_prime_path, self.img_h, self.img_w)
88 |
89 | # Data Augmentation
90 | do_augment = np.random.uniform(0, 1)
91 | # Training: use joint augmentation (both images in a pair receive the same noise)
92 | # Test: use disjoint augmentation (each image in a pair receives different noise)
93 | if self.mode == 'train':
94 | I_aug, I_prime_aug = self.joint_augment_image_pair(I, I_prime, 0, 255) \
95 | if do_augment > (1 - self.do_augment) else (I, I_prime)
96 | else:
97 | I_aug, I_prime_aug = self.disjoint_augment_image_pair(I, I_prime, 0, 255) \
98 | if do_augment > (1 - self.do_augment) else (I, I_prime)
99 |
100 | # Standardize images
101 | I = self.norm_img(I, self.mean_I, self.std_I)
102 | I_prime = self.norm_img(I_prime, self.mean_I, self.std_I)
103 | # The augmented large images are the ones fed to the model
104 | I_aug = self.norm_img(I_aug, self.mean_I, self.std_I)
105 | I_prime_aug = self.norm_img(I_prime_aug, self.mean_I, self.std_I)
106 |
107 | # Read patch_indices
108 | x_start = pts1_index[0] # x
109 | y_start = pts1_index[1] # y
110 | patch_indices = (self.y_t_flat + y_start) * self.img_w + (self.x_t_flat + x_start)
111 |
112 | # Convert to tensor
113 | I = torch.tensor(I)
114 | I_prime = torch.tensor(I_prime)
115 | I_aug = torch.tensor(I_aug)
116 | I_prime_aug = torch.tensor(I_prime_aug)
117 | pts1_tensor = torch.tensor(pts1_index)
118 | gt_tensor = torch.tensor(gt_index)
119 | patch_indices = torch.tensor(patch_indices)
120 |
121 | # Obtain I1, I2, I1_aug and I2_aug
122 | I_flat = torch.reshape(torch.mean(I, 0), [-1]) # I: 3xHxW
123 | I_prime_flat = torch.reshape(torch.mean(I_prime, 0), [-1]) # I_prime: 3xHxW
124 | I_aug_flat = torch.reshape(torch.mean(I_aug, 0), [-1]) # I_aug: 3xHxW
125 | I_prime_aug_flat = torch.reshape(torch.mean(I_prime_aug, 0), [-1]) # I_prime_aug: 3xHxW
126 |
127 | patch_indices_flat = torch.reshape(patch_indices, [-1])
128 | pixel_indices = patch_indices_flat.long()
129 |
130 | I1_flat = torch.gather(I_flat, 0, pixel_indices)
131 | I2_flat = torch.gather(I_prime_flat, 0, pixel_indices)
132 | I1_aug_flat = torch.gather(I_aug_flat, 0, pixel_indices)
133 | I2_aug_flat = torch.gather(I_prime_aug_flat, 0, pixel_indices)
134 |
135 | I1 = torch.reshape(I1_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1)
136 | I2 = torch.reshape(I2_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1)
137 | I1_aug = torch.reshape(I1_aug_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1)
138 | I2_aug = torch.reshape(I2_aug_flat, [self.patch_size, self.patch_size, 1]).permute(2, 0, 1)
139 |
140 | return I1, I2, I1_aug, I2_aug, I_aug, I_prime_aug, pts1_tensor, gt_tensor, patch_indices
141 |
142 | def read_image(self, image_path, img_h, img_w):
143 | image = cv.imread(image_path)
144 | height, width = image.shape[:2]
145 | if height != img_h or width != img_w:
146 | image = cv.resize(image, (img_w, img_h), interpolation=cv.INTER_AREA)
147 | return image
148 |
149 | def norm_img(self, img, mean, std):
150 | img = (img - mean) / std
151 | img = np.transpose(img, [2, 0, 1]) # torch [C,H,W]
152 | return img
153 |
154 | def disjoint_augment_image_pair(self, img1, img2, min_val=0, max_val=255):
155 | # Randomly shift gamma
156 | random_gamma = np.random.uniform(0.8, 1.2)
157 | img1_aug = img1 ** random_gamma
158 | random_gamma = np.random.uniform(0.8, 1.2)
159 | img2_aug = img2 ** random_gamma
160 |
161 | # Randomly shift brightness
162 | random_brightness = np.random.uniform(0.5, 2.0)
163 | img1_aug = img1_aug * random_brightness
164 | random_brightness = np.random.uniform(0.5, 2.0)
165 | img2_aug = img2_aug * random_brightness
166 |
167 | # Randomly shift color
168 | random_colors = np.random.uniform(0.8, 1.2, 3)
169 | white = np.ones([img1.shape[0], img1.shape[1], 1])
170 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2)
171 | img1_aug *= color_image
172 |
173 | random_colors = np.random.uniform(0.8, 1.2, 3)
174 | white = np.ones([img1.shape[0], img1.shape[1], 1])
175 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2)
176 | img2_aug *= color_image
177 |
178 | # Saturate
179 | img1_aug = np.clip(img1_aug, min_val, max_val)
180 | img2_aug = np.clip(img2_aug, min_val, max_val)
181 |
182 | return img1_aug, img2_aug
183 |
184 | def joint_augment_image_pair(self, img1, img2, min_val=0, max_val=255):
185 | # Randomly shift gamma
186 | random_gamma = np.random.uniform(0.8, 1.2)
187 | img1_aug = img1 ** random_gamma
188 | img2_aug = img2 ** random_gamma
189 |
190 | # Randomly shift brightness
191 | random_brightness = np.random.uniform(0.5, 2.0)
192 | img1_aug = img1_aug * random_brightness
193 | img2_aug = img2_aug * random_brightness
194 |
195 | # Randomly shift color
196 | random_colors = np.random.uniform(0.8, 1.2, 3)
197 | white = np.ones([img1.shape[0], img1.shape[1], 1])
198 | color_image = np.concatenate([white * random_colors[i] for i in range(3)], axis=2)
199 | img1_aug *= color_image
200 | img2_aug *= color_image
201 |
202 | # Saturate
203 | img1_aug = np.clip(img1_aug, min_val, max_val)
204 | img2_aug = np.clip(img2_aug, min_val, max_val)
205 |
206 | return img1_aug, img2_aug
207 |
208 |
209 | if __name__ == "__main__":
210 | TrainDataset = SyntheticDataset(data_path="../data/synthetic/45/",
211 | mode='train',
212 | img_h=240,
213 | img_w=320,
214 | patch_size=128,
215 | do_augment=0.5)
216 | print(len(TrainDataset))
217 | sample = TrainDataset[np.random.randint(0, len(TrainDataset))]
218 | for attr in sample:
219 | print(attr.shape, attr.dtype)
--------------------------------------------------------------------------------
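A small self-contained sketch (illustrative sizes, not from the repo) of how `get_mesh_grid_per_img` plus the `torch.gather` in `__getitem__` crop a patch out of a flattened image:

```python
import numpy as np
import torch

# Toy sizes: a 6x8 "image" and a 3x3 patch whose top-left corner is (x=2, y=1).
img_h, img_w, patch = 6, 8, 3
x_t = np.tile(np.arange(patch), patch)    # x offsets: 0,1,2,0,1,2,...
y_t = np.repeat(np.arange(patch), patch)  # y offsets: 0,0,0,1,1,1,...
x_start, y_start = 2, 1
indices = (y_t + y_start) * img_w + (x_t + x_start)  # same formula as __getitem__

I_flat = torch.arange(img_h * img_w, dtype=torch.float32)  # stand-in flattened image
crop = torch.gather(I_flat, 0, torch.from_numpy(indices).long())
print(crop.reshape(patch, patch))  # rows 1..3, cols 2..4 of the 6x8 grid
```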
/code/homography_CNN_synthetic.py:
--------------------------------------------------------------------------------
1 | import os, shutil
2 | import argparse
3 | import torch
4 | from torch import optim
5 | from torch.utils.data import DataLoader
6 | from tensorboardX import SummaryWriter
7 | from dataset import SyntheticDataset
8 | from homography_model import HomographyModel
9 | from utils import utils
10 | import numpy as np
11 | import math
12 | import time
13 |
14 | os.environ['CUDA_VISIBLE_DEVICES'] = "0"
15 |
16 |
17 | def train(args):
18 | # Load data
19 | TrainDataset = SyntheticDataset(data_path=args.data_path,
20 | mode=args.mode,
21 | img_h=args.img_h,
22 | img_w=args.img_w,
23 | patch_size=args.patch_size,
24 | do_augment=args.do_augment)
25 | train_loader = DataLoader(TrainDataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
26 | print('===> Train: There are totally {} training files'.format(len(TrainDataset)))
27 |
28 | net = HomographyModel(args.use_batch_norm)
29 | if args.resume:
30 | model_path = os.path.join(args.model_dir, args.model_name)
31 | ckpt = torch.load(model_path)
32 | net.load_state_dict(ckpt.state_dict())
33 | if torch.cuda.is_available():
34 | net = net.cuda()
35 |
36 | optimizer = optim.Adam(net.parameters(), lr=args.lr) # default as 0.0001
37 | decay_rate = 0.96
38 | step_size = (math.log(decay_rate) * args.max_epochs) / math.log(args.min_lr * 1.0 / args.lr)
39 | print('args lr:', args.lr, args.min_lr)
40 | print('===> Decay steps:', step_size)
41 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(step_size), gamma=decay_rate)
42 |
43 | print("start training")
44 | writer = SummaryWriter(logdir=args.log_dir, flush_secs=60)
45 | score_print_fre = 100
46 | summary_fre = 1000
47 | model_save_fre = 4000
48 | glob_iter = 0
49 | t0 = time.time()
50 |
51 | for epoch in range(args.max_epochs):
52 | net.train()
53 | epoch_start = time.time()
54 | train_l1_loss = 0.0
55 | train_l1_smooth_loss = 0.0
56 | train_h_loss = 0.0
57 |
58 | for i, batch_value in enumerate(train_loader):
59 | I1_batch = batch_value[0].float()
60 | I2_batch = batch_value[1].float()
61 | I1_aug_batch = batch_value[2].float()
62 | I2_aug_batch = batch_value[3].float()
63 | I_batch = batch_value[4].float()
64 | I_prime_batch = batch_value[5].float()
65 | pts1_batch = batch_value[6].float()
66 | gt_batch = batch_value[7].float()
67 | patch_indices_batch = batch_value[8].float()
68 |
69 | if torch.cuda.is_available():
70 | I1_aug_batch = I1_aug_batch.cuda()
71 | I2_aug_batch = I2_aug_batch.cuda()
72 | I_batch = I_batch.cuda()
73 | pts1_batch = pts1_batch.cuda()
74 | gt_batch = gt_batch.cuda()
75 | patch_indices_batch = patch_indices_batch.cuda()
76 |
77 | # forward, backward, update weights
78 | optimizer.zero_grad()
79 | batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch)
80 | h_loss = batch_out['h_loss']
81 | rec_loss = batch_out['rec_loss']
82 | ssim_loss = batch_out['ssim_loss']
83 | l1_loss = batch_out['l1_loss']
84 | l1_smooth_loss = batch_out['l1_smooth_loss']
85 | ncc_loss = batch_out['ncc_loss']
86 | pred_I2 = batch_out['pred_I2']
87 |
88 | loss = l1_loss
89 | loss.backward()
90 | optimizer.step()
91 |
92 | train_l1_loss += loss.item()
93 | train_l1_smooth_loss += l1_smooth_loss.item()
94 | train_h_loss += h_loss.item()
95 | if (i + 1) % score_print_fre == 0 or (i + 1) == len(train_loader):
96 | print(
97 | "Training: Epoch[{:0>3}/{:0>3}] Iter[{:0>3}]/[{:0>3}] l1 loss: {:.4f} "
98 | "l1 smooth loss: {:.4f} h loss: {:.4f} lr={:.8f}".format(
99 | epoch + 1, args.max_epochs, i + 1, len(train_loader), train_l1_loss / score_print_fre,
100 | train_l1_smooth_loss / score_print_fre, train_h_loss / score_print_fre, scheduler.get_lr()[0]))
101 | train_l1_loss = 0.0
102 | train_l1_smooth_loss = 0.0
103 | train_h_loss = 0.0
104 |
105 | if glob_iter % summary_fre == 0:
106 | writer.add_scalar('learning_rate', scheduler.get_lr()[0], glob_iter)
107 | writer.add_scalar('h_loss', h_loss, glob_iter)
108 | writer.add_scalar('rec_loss', rec_loss, glob_iter)
109 | writer.add_scalar('ssim_loss', ssim_loss, glob_iter)
110 | writer.add_scalar('l1_loss', l1_loss, glob_iter)
111 | writer.add_scalar('l1_smooth_loss', l1_smooth_loss, glob_iter)
112 | writer.add_scalar('ncc_loss', ncc_loss, glob_iter)
113 |
114 | writer.add_image('I', utils.denorm_img(I_batch[0, ...].cpu().numpy()).astype(np.uint8)[:, :, ::-1],
115 | glob_iter, dataformats='HWC')
116 | writer.add_image('I_prime',
117 | utils.denorm_img(I_prime_batch[0, ...].numpy()).astype(np.uint8)[:, :, ::-1],
118 | glob_iter, dataformats='HWC')
119 |
120 | writer.add_image('I1_aug', utils.denorm_img(I1_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
121 | glob_iter, dataformats='HW')
122 | writer.add_image('I2_aug', utils.denorm_img(I2_aug_batch[0, 0, ...].cpu().numpy()).astype(np.uint8),
123 | glob_iter, dataformats='HW')
124 | writer.add_image('pred_I2',
125 | utils.denorm_img(pred_I2[0, 0, ...].cpu().detach().numpy()).astype(np.uint8),
126 | glob_iter, dataformats='HW')
127 |
128 | writer.add_image('I2', utils.denorm_img(I2_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
129 | dataformats='HW')
130 | writer.add_image('I1', utils.denorm_img(I1_batch[0, 0, ...].numpy()).astype(np.uint8), glob_iter,
131 | dataformats='HW')
132 |
133 | # save model
134 | if glob_iter % model_save_fre == 0 and glob_iter != 0:
135 | filename = 'model' + '_iter_' + str(glob_iter) + '.pth'
136 | model_save_path = os.path.join(args.model_dir, filename)
137 | torch.save(net, model_save_path)
138 |
139 | glob_iter += 1
140 | scheduler.step()
141 | print("Epoch: {} epoch time: {:.1f}s".format(epoch, time.time() - epoch_start))
142 |
143 | elapsed_time = time.time() - t0
144 | print("Finished Training in {:.0f}h {:.0f}m {:.0f}s.".format(
145 | elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60))
146 |
147 |
148 | def test(args):
149 | # Load data
150 | TestDataset = SyntheticDataset(data_path=args.data_path,
151 | mode=args.mode,
152 | img_h=args.img_h,
153 | img_w=args.img_w,
154 | patch_size=args.patch_size,
155 | do_augment=args.do_augment)
156 | test_loader = DataLoader(TestDataset, batch_size=1)
157 | print('===> Test: There are totally {} testing files'.format(len(TestDataset)))
158 |
159 | # Load model
160 | net = HomographyModel()
161 | model_path = os.path.join(args.model_dir, args.model_name)
162 | state = torch.load(model_path)
163 | net.load_state_dict(state.state_dict())
164 | if torch.cuda.is_available():
165 | net = net.cuda()
166 |
167 | print("start testing")
168 |
169 | with torch.no_grad():
170 | net.eval()
171 | test_l1_loss = 0.0
172 | test_h_loss = 0.0
173 | h_losses_array = []
174 | for i, batch_value in enumerate(test_loader):
175 | I1_aug_batch = batch_value[2].float()
176 | I2_aug_batch = batch_value[3].float()
177 | I_batch = batch_value[4].float()
178 | I_prime_batch = batch_value[5].float()
179 | pts1_batch = batch_value[6].float()
180 | gt_batch = batch_value[7].float()
181 | patch_indices_batch = batch_value[8].float()
182 |
183 | if torch.cuda.is_available():
184 | I1_aug_batch = I1_aug_batch.cuda()
185 | I2_aug_batch = I2_aug_batch.cuda()
186 | I_batch = I_batch.cuda()
187 | pts1_batch = pts1_batch.cuda()
188 | gt_batch = gt_batch.cuda()
189 | patch_indices_batch = patch_indices_batch.cuda()
190 |
191 | batch_out = net(I1_aug_batch, I2_aug_batch, I_batch, pts1_batch, gt_batch, patch_indices_batch)
192 | h_loss = batch_out['h_loss']
193 | rec_loss = batch_out['rec_loss']
194 | ssim_loss = batch_out['ssim_loss']
195 | l1_loss = batch_out['l1_loss']
196 | pred_h4p_value = batch_out['pred_h4p']
197 |
198 | test_h_loss += h_loss.item()
199 | test_l1_loss += l1_loss.item()
200 | h_losses_array.append(h_loss.item())
201 |
202 | if args.save_visual:
203 | I_sample = utils.denorm_img(I_batch[0].cpu().numpy()).astype(np.uint8)
204 | I_prime_sample = utils.denorm_img(I_prime_batch[0].numpy()).astype(np.uint8)
205 | pts1_sample = pts1_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32)
206 | gt_h4p_sample = gt_batch[0].cpu().numpy().reshape([4, 2]).astype(np.float32)
207 |
208 | pts2_sample = pts1_sample + gt_h4p_sample
209 |
210 | pred_h4p_sample = pred_h4p_value[0].cpu().numpy().reshape([4, 2]).astype(np.float32)
211 | pred_pts2_sample = pts1_sample + pred_h4p_sample
212 |
213 | # Save
214 | visual_file_name = ('%s' % i).zfill(4) + '.jpg'
215 | utils.save_correspondences_img(I_prime_sample, I_sample, pts1_sample, pts2_sample, pred_pts2_sample,
216 | args.results_dir, visual_file_name)
217 |
218 | print("Testing: h_loss: {:4.3f}, rec_loss: {:4.3f}, ssim_loss: {:4.3f}, l1_loss: {:4.3f}".format(
219 | h_loss.item(), rec_loss.item(), ssim_loss.item(), l1_loss.item()
220 | ))
221 |
222 | print('|Test size | h_loss | l1_loss |')
223 | print('|{} | {:.4f} | {:.4f} |'.format(len(test_loader), test_h_loss / len(test_loader), test_l1_loss / len(test_loader)))
224 |
225 | tops_list = utils.find_percentile(h_losses_array)
226 | print('===> Percentile values (0-30, 30-60, 60-100):')
227 | print(tops_list)
228 | print('======> End! ====================================')
229 |
230 |
231 | def main():
232 | # Size of the synthetic images and the perturbation range (RHO)
233 | HEIGHT = 240
234 | WIDTH = 320
235 | RHO = 45
236 | PATCH_SIZE = 128
237 |
238 | # Synthetic data directories
239 | DATA_PATH = "../data/synthetic/" + str(RHO) + '/'
240 |
241 | # Log and model directories
242 | MAIN_LOG_PATH = '../'
243 | LOG_DIR = MAIN_LOG_PATH + "logs/"
244 | MODEL_DIR = MAIN_LOG_PATH + "models/synthetic_models"
245 |
246 | # Where to save visualization images (for report)
247 | RESULTS_DIR = MAIN_LOG_PATH + "results/synthetic/report/"
248 |
249 | def str2bool(s):
250 | return s.lower() == 'true'
251 |
252 | parser = argparse.ArgumentParser()
253 | parser.add_argument('--mode', type=str, default='train', help='Train or test', choices=['train', 'test'])
254 | parser.add_argument('--use_batch_norm', type=str2bool, default='False', help='Use batch_norm?')
255 | parser.add_argument('--do_augment', type=float, default=0.5,
256 | help='Possibility of augmenting image: color shift, brightness shift...')
257 |
258 | parser.add_argument('--data_path', type=str, default=DATA_PATH, help='The raw data path.')
259 | parser.add_argument('--log_dir', type=str, default=LOG_DIR, help='The log path')
260 | parser.add_argument('--results_dir', type=str, default=RESULTS_DIR, help='Store visualization for report')
261 | parser.add_argument('--model_dir', type=str, default=MODEL_DIR, help='The models path')
262 | parser.add_argument('--model_name', type=str, default='model.pth', help='The model name')
263 |
264 | parser.add_argument('--save_visual', type=str2bool, default='True', help='Save visual images for report')
265 |
266 | parser.add_argument('--img_w', type=int, default=WIDTH)
267 | parser.add_argument('--img_h', type=int, default=HEIGHT)
268 | parser.add_argument('--patch_size', type=int, default=PATCH_SIZE)
269 | parser.add_argument('--batch_size', type=int, default=128)
270 | parser.add_argument('--max_epochs', type=int, default=150)
271 | parser.add_argument('--lr', type=float, default=1e-4, help='Max learning rate')
272 | parser.add_argument('--min_lr', type=float, default=.9e-4, help='Min learning rate')
273 |
274 | parser.add_argument('--resume', type=str2bool, default='False',
275 | help='True: restore the existing model. False: retrain')
276 |
277 | args = parser.parse_args()
278 | print('<==================== Loading data ===================>\n')
279 |
280 | if not args.resume:
281 | try:
282 | shutil.rmtree(args.log_dir)
283 | except:
284 | pass
285 |
286 | if not os.path.exists(args.model_dir):
287 | os.makedirs(args.model_dir)
288 | if not os.path.exists(args.log_dir):
289 | os.makedirs(args.log_dir)
290 | if not os.path.exists(args.results_dir):
291 | os.makedirs(args.results_dir)
292 |
293 | if args.mode == "train":
294 | train(args)
295 | else:
296 | test(args)
297 |
298 |
299 | if __name__ == "__main__":
300 | main()
301 |
--------------------------------------------------------------------------------
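The `StepLR` step size computed in `train()` is chosen so that the learning rate decays from `--lr` to `--min_lr` over `--max_epochs` with a factor of 0.96 per step; a quick check of the formula with the default arguments:

```python
import math

# Same expression as in train(): solve lr * gamma**(max_epochs / step_size) == min_lr
lr, min_lr, max_epochs, gamma = 1e-4, 0.9e-4, 150, 0.96
step_size = (math.log(gamma) * max_epochs) / math.log(min_lr / lr)
print(int(step_size))  # ~58 epochs between learning-rate decays
```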
/code/homography_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils.torch_spatial_transformer import transformer
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | def L1_smooth_loss(x, y):
8 | abs_diff = torch.abs(x - y)
9 | abs_diff_lt_1 = torch.le(abs_diff, 1)
10 | return torch.mean(torch.where(abs_diff_lt_1, 0.5 * abs_diff ** 2, abs_diff - 0.5))
11 |
12 | def SSIM_loss(x, y, size=3):
13 | # SSIM constants: Ci = (Ki * L)^2 with K1 = 0.01, K2 = 0.03, L the dynamic range
14 | C1 = 0.01 ** 2
15 | C2 = 0.03 ** 2
16 |
17 | mu_x = F.avg_pool2d(x, size, 1, padding=0)
18 | mu_y = F.avg_pool2d(y, size, 1, padding=0)
19 |
20 | sigma_x = F.avg_pool2d(x ** 2, size, 1, padding=0) - mu_x ** 2
21 | sigma_y = F.avg_pool2d(y ** 2, size, 1, padding=0) - mu_y ** 2
22 | sigma_xy = F.avg_pool2d(x * y, size, 1, padding=0) - mu_x * mu_y
23 |
24 | SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
25 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
26 |
27 | SSIM = SSIM_n / SSIM_d
28 |
29 | return torch.clamp((1 - SSIM) / 2, 0, 1)
30 |
31 | def NCC_loss(x, y):
32 | """Consider x, y are vectors. Take L2 of the difference
33 | of the them after being normalized by their length"""
34 | len_x = torch.sqrt(torch.sum(x ** 2))
35 | len_y = torch.sqrt(torch.sum(y ** 2))
36 | return torch.sqrt(torch.sum((x / len_x - y / len_y) ** 2))
37 |
38 | class ConvBlock(nn.Module):
39 | def __init__(self, inchannels, outchannels, batch_norm=False, pool=True):
40 | super(ConvBlock, self).__init__()
41 | layers = []
42 | layers.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1))
43 | layers.append(nn.ReLU(inplace=True))
44 | if batch_norm:
45 | layers.append(nn.BatchNorm2d(outchannels))
46 | layers.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1))
47 | layers.append(nn.ReLU(inplace=True))
48 | if batch_norm:
49 | layers.append(nn.BatchNorm2d(outchannels))
50 | if pool:
51 | layers.append(nn.MaxPool2d(2, 2))
52 | self.layers = nn.Sequential(*layers)
53 |
54 | def forward(self, x):
55 | return self.layers(x)
56 |
57 | class HomographyModel(nn.Module):
58 | def __init__(self, batch_norm=False):
59 | super(HomographyModel, self).__init__()
60 | self.feature = nn.Sequential(
61 | ConvBlock(2, 64, batch_norm),
62 | ConvBlock(64, 64, batch_norm),
63 | ConvBlock(64, 128, batch_norm),
64 | ConvBlock(128, 128, batch_norm, pool=False),
65 | )
66 | self.fc = nn.Sequential(
67 | nn.Dropout(0.5),
68 | nn.Linear(128 * 16 * 16, 1024),
69 | nn.ReLU(inplace=True),
70 | nn.Dropout(0.5),
71 | nn.Linear(1024, 8)
72 | )
73 |
74 | for m in self.modules():
75 | if isinstance(m, nn.Conv2d):
76 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
77 | if m.bias is not None:
78 | nn.init.constant_(m.bias, 0)
79 | elif isinstance(m, nn.BatchNorm2d):
80 | nn.init.constant_(m.weight, 1)
81 | nn.init.constant_(m.bias, 0)
82 | elif isinstance(m, nn.Linear):
83 | nn.init.normal_(m.weight, 0, 0.01)
84 | nn.init.constant_(m.bias, 0)
85 |
86 | def forward(self, I1_aug, I2_aug, I_aug, h4p, gt, patch_indices):
87 | batch_size, _, img_h, img_w = I_aug.size()
88 | _, _, patch_size, patch_size = I1_aug.size()
89 |
90 | y_t = torch.arange(0, batch_size * img_w * img_h,
91 | img_w * img_h)
92 | batch_indices_tensor = y_t.unsqueeze(1).expand(y_t.shape[0], patch_size * patch_size).reshape(-1)
93 |
94 | M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],
95 | [0., img_h / 2.0, img_h / 2.0],
96 | [0., 0., 1.]])
97 |
98 | if torch.cuda.is_available():
99 | M_tensor = M_tensor.cuda()
100 | batch_indices_tensor = batch_indices_tensor.cuda()
101 |
102 | M_tile = M_tensor.unsqueeze(0).expand(batch_size, M_tensor.shape[-2], M_tensor.shape[-1])
103 |
104 | # Inverse of M
105 | M_tensor_inv = torch.inverse(M_tensor)
106 | M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, M_tensor_inv.shape[-2],
107 | M_tensor_inv.shape[-1])
108 |
109 | pred_h4p = self.build_model(I1_aug, I2_aug)
110 |
111 | H_mat = self.solve_DLT(h4p, pred_h4p).squeeze(1)
112 |
113 | pred_I2 = self.transform(patch_size, M_tile_inv, H_mat, M_tile,
114 | I_aug, patch_indices, batch_indices_tensor)
115 |
116 | h_loss = torch.sqrt(torch.mean((pred_h4p - gt) ** 2))
117 | rec_loss, ssim_loss, l1_loss, l1_smooth_loss, ncc_loss = self.build_losses(pred_I2, I2_aug)
118 |
119 | out_dict = {}
120 | out_dict.update(h_loss=h_loss, rec_loss=rec_loss, ssim_loss=ssim_loss, l1_loss=l1_loss,
121 | l1_smooth_loss=l1_smooth_loss, ncc_loss=ncc_loss,
122 | pred_h4p=pred_h4p, H_mat=H_mat, pred_I2=pred_I2)
123 |
124 | return out_dict
125 |
126 | def build_model(self, I1_aug, I2_aug):
127 | model_input = torch.cat([I1_aug, I2_aug], dim=1)
128 | x = self.feature(model_input)
129 | x = x.view(x.size(0), -1)
130 | x = self.fc(x)
131 | return x
132 |
133 | def solve_DLT(self, src_p, off_set):
134 | # src_p: flattened corner coordinates, shape (bs, 8) for a single homography
135 | # off_set: corresponding corner offsets, same shape
136 | # can be used to compute mesh points (multi-H)
137 |
138 | bs, _ = src_p.shape
139 | divide = int(np.sqrt(len(src_p[0]) / 2) - 1)
140 | row_num = (divide + 1) * 2
141 |
142 | for i in range(divide):
143 | for j in range(divide):
144 |
145 | h4p = src_p[:, [2 * j + row_num * i, 2 * j + row_num * i + 1,
146 | 2 * (j + 1) + row_num * i, 2 * (j + 1) + row_num * i + 1,
147 | 2 * (j + 1) + row_num * i + row_num, 2 * (j + 1) + row_num * i + row_num + 1,
148 | 2 * j + row_num * i + row_num, 2 * j + row_num * i + row_num + 1]].reshape(bs, 1, 4, 2)
149 |
150 | pred_h4p = off_set[:, [2 * j + row_num * i, 2 * j + row_num * i + 1,
151 | 2 * (j + 1) + row_num * i, 2 * (j + 1) + row_num * i + 1,
152 | 2 * (j + 1) + row_num * i + row_num, 2 * (j + 1) + row_num * i + row_num + 1,
153 | 2 * j + row_num * i + row_num, 2 * j + row_num * i + row_num + 1]].reshape(bs, 1,
154 | 4, 2)
155 |
156 | if i + j == 0:
157 | src_ps = h4p
158 | off_sets = pred_h4p
159 | else:
160 | src_ps = torch.cat((src_ps, h4p), axis=1)
161 | off_sets = torch.cat((off_sets, pred_h4p), axis=1)
162 |
163 | bs, n, h, w = src_ps.shape
164 |
165 | N = bs * n
166 |
167 | src_ps = src_ps.reshape(N, h, w)
168 | off_sets = off_sets.reshape(N, h, w)
169 |
170 | dst_p = src_ps + off_sets
171 |
172 | ones = torch.ones(N, 4, 1)
173 | if torch.cuda.is_available():
174 | ones = ones.cuda()
175 | xy1 = torch.cat((src_ps, ones), 2)
176 | zeros = torch.zeros_like(xy1)
177 | if torch.cuda.is_available():
178 | zeros = zeros.cuda()
179 |
180 | xyu, xyd = torch.cat((xy1, zeros), 2), torch.cat((zeros, xy1), 2)
181 | M1 = torch.cat((xyu, xyd), 2).reshape(N, -1, 6)
182 | M2 = torch.matmul(
183 | dst_p.reshape(-1, 2, 1),
184 | src_ps.reshape(-1, 1, 2),
185 | ).reshape(N, -1, 2)
186 |
187 | A = torch.cat((M1, -M2), 2)
188 | b = dst_p.reshape(N, -1, 1)
189 |
190 | Ainv = torch.inverse(A)
191 | h8 = torch.matmul(Ainv, b).reshape(N, 8)
192 |
193 | H = torch.cat((h8, ones[:, 0, :]), 1).reshape(N, 3, 3)
194 | H = H.reshape(bs, n, 3, 3)
195 | return H
196 |
197 | def transform(self, patch_size, M_tile_inv, H_mat, M_tile, I, patch_indices, batch_indices_tensor):
198 | # Transform H_mat since we scale image indices in transformer
199 | batch_size, num_channels, img_h, img_w = I.size()
200 | # if torch.cuda.is_available():
201 | # M_tile_inv = M_tile_inv.cuda()
202 | H_mat = torch.matmul(torch.matmul(M_tile_inv, H_mat), M_tile)
203 | # Transform image 1 (large image) to image 2
204 | out_size = (img_h, img_w)
205 | warped_images, _ = transformer(I, H_mat, out_size)
206 |
207 | # Extract the warped patch from warped_images by flatting the whole batch before using indices
208 | # Note that input I is 3 channels so we reduce to gray
209 | warped_gray_images = torch.mean(warped_images, dim=3)
210 | warped_images_flat = torch.reshape(warped_gray_images, [-1])
211 | patch_indices_flat = torch.reshape(patch_indices, [-1])
212 | pixel_indices = patch_indices_flat.long() + batch_indices_tensor
213 | pred_I2_flat = torch.gather(warped_images_flat, 0, pixel_indices)
214 |
215 | pred_I2 = torch.reshape(pred_I2_flat, [batch_size, patch_size, patch_size, 1])
216 |
217 | return pred_I2.permute(0, 3, 1, 2)
218 |
219 | def build_losses(self, pred_I2, I2_aug):
220 | rec_loss = torch.sqrt(torch.mean((pred_I2 - I2_aug) ** 2))
221 | ssim_loss = torch.mean(SSIM_loss(pred_I2, I2_aug))
222 | l1_loss = torch.mean(torch.abs(pred_I2 - I2_aug))
223 | l1_smooth_loss = L1_smooth_loss(pred_I2, I2_aug)
224 | ncc_loss = NCC_loss(I2_aug, pred_I2)
225 | return rec_loss, ssim_loss, l1_loss, l1_smooth_loss, ncc_loss
226 |
--------------------------------------------------------------------------------
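`solve_DLT` assembles the standard 4-point DLT system `A·h = b` (8 equations, 8 unknowns, with `H[2,2]` fixed to 1). A sanity-check sketch against OpenCV, using hypothetical corner offsets:

```python
import cv2
import numpy as np
import torch
from homography_model import HomographyModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # solve_DLT allocates on GPU when one is available
src = torch.tensor([[0., 0., 128., 0., 128., 128., 0., 128.]], device=device)  # corners x1 y1 ... x4 y4
off = torch.tensor([[3., -2., -5., 4., 1., 6., -4., -1.]], device=device)      # hypothetical offsets

H = HomographyModel().solve_DLT(src, off).squeeze().cpu().numpy()
H_cv = cv2.getPerspectiveTransform(src.cpu().numpy().reshape(4, 2),
                                   (src + off).cpu().numpy().reshape(4, 2))
print(np.allclose(H, H_cv, atol=1e-4))  # True: both fix H[2,2] = 1
```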
/code/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/code/utils/__init__.py
--------------------------------------------------------------------------------
/code/utils/gen_synthetic_data.py:
--------------------------------------------------------------------------------
1 | import os, shutil, argparse
2 | import glob
3 | import cv2
4 | import random
5 | import numpy as np
6 | from numpy.linalg import inv
7 | from numpy_spatial_transformer import numpy_transformer
8 |
9 |
10 | def homographyGeneration(args, raw_image_path, index, I_dir, I_prime_dir, gt_file, pts1_file, filenames_file):
11 | rho = args.rho
12 | patch_size = args.patch_size
13 | height = args.img_h
14 | width = args.img_w
15 |
16 | try:
17 | color_image = cv2.imread(raw_image_path)
18 | color_image = cv2.resize(color_image, (width, height))
19 | gray_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2GRAY)
20 | except:
21 | print('Error with image:', raw_image_path)
22 | return index, -1
23 |
24 | # Randomly pick the top left point of the patch on the real image
25 | y = random.randint(rho, height - rho - patch_size) # row
26 | x = random.randint(rho, width - rho - patch_size) # col
27 |
28 | # define corners of image patch
29 | top_left_point = (x, y)
30 | top_right_point = (patch_size + x, y)
31 | bottom_right_point = (patch_size + x, patch_size + y)
32 | bottom_left_point = (x, patch_size + y)
33 | four_points = [top_left_point, top_right_point, bottom_right_point, bottom_left_point]
34 | perturbed_four_points = []
35 | for point in four_points:
36 | perturbed_four_points.append((point[0] + random.randint(-rho, rho), point[1] + random.randint(-rho, rho)))
37 |
38 | # compute Homography
39 | H = cv2.getPerspectiveTransform(np.float32(four_points), np.float32(perturbed_four_points))
40 | try:
41 | H_inverse = inv(H)
42 | except:
43 | print("singular Error!")
44 | return index, -1
45 |
46 | inv_warped_color_image = None
47 | inv_warped_image = None
48 | if args.color:
49 | inv_warped_color_image = numpy_transformer(color_image, H_inverse, (width, height))
50 | else:
51 | inv_warped_image = numpy_transformer(gray_image, H_inverse, (width, height))
52 |
53 | # Extract image patches (not used)
54 | if args.color:
55 | original_patch = gray_image[y:y + patch_size, x:x + patch_size]
56 | else:
57 | warped_patch = inv_warped_image[y:y + patch_size, x:x + patch_size]
58 |
59 | ######################################################################################
60 | # Save synthetic data I_dir I_prime_dir gt pts1
61 | large_img_path = os.path.join(I_dir, str(index) + '.jpg')
62 | if args.mode == 'train' and not args.color:
63 | cv2.imwrite(large_img_path, gray_image)
64 | else:
65 | cv2.imwrite(large_img_path, color_image)
66 |
67 | if I_prime_dir is not None:
68 | img_prime_path = os.path.join(I_prime_dir, str(index) + '.jpg')
69 | if args.mode == 'train' and not args.color:
70 | cv2.imwrite(img_prime_path, inv_warped_image)
71 | else:
72 | cv2.imwrite(img_prime_path, inv_warped_color_image)
73 |
74 | # Text files to store homography parameters (4 corners)
75 | f_pts1 = open(pts1_file, 'ab')
76 | f_gt = open(gt_file, 'ab')
77 | f_file_list = open(filenames_file, 'ab')
78 |
79 | # Ground truth is delta displacement
80 | gt = np.subtract(np.array(perturbed_four_points), np.array(four_points))
81 | gt = np.array(gt).flatten().astype(np.float32)
82 | # Four corners in the first image
83 | pts1 = np.array(four_points).flatten().astype(np.float32)
84 |
85 | np.savetxt(f_gt, [gt], fmt='%.1f', delimiter=' ')
86 | np.savetxt(f_pts1, [pts1], fmt='%.1f', delimiter=' ')
87 | f_file_list.write(('%s %s\n' % (str(index) + '.jpg', str(index) + '.jpg')).encode())
88 |
89 | index += 1
90 | if index % 1000 == 0:
91 | print('--image number ', index)
92 |
93 | f_gt.close()
94 | f_pts1.close()
95 | f_file_list.close()
96 | return index, 0
97 |
98 |
99 | def dataCollection(args):
100 | # Default folders and files for storage
101 | DATA_PATH = None
102 | gt_file = None
103 | pts1_file = None
104 | filenames_file = None
105 | if args.mode == 'train':
106 | DATA_PATH = args.data_path + 'train/'
107 | pts1_file = os.path.join(DATA_PATH, 'pts1.txt')
108 | filenames_file = os.path.join(DATA_PATH, 'train_synthetic.txt')
109 | gt_file = os.path.join(DATA_PATH, 'gt.txt')
110 | elif args.mode == 'test':
111 | DATA_PATH = args.data_path + 'test/'
112 | pts1_file = os.path.join(DATA_PATH, 'test_pts1.txt')
113 | filenames_file = os.path.join(DATA_PATH, 'test_synthetic.txt')
114 | gt_file = os.path.join(DATA_PATH, 'test_gt.txt')
115 | I_dir = DATA_PATH + 'I/' # Large image
116 | I_prime_dir = DATA_PATH + 'I_prime/' # Large image
117 |
118 | try:
119 | os.remove(gt_file)
120 | os.remove(pts1_file)
121 | os.remove(filenames_file)
122 | print('-- Current {} existed. Deleting..!'.format(gt_file))
123 | shutil.rmtree(I_dir, ignore_errors=True)
124 | if I_prime_dir is not None:
125 | shutil.rmtree(I_prime_dir, ignore_errors=True)
126 | except:
127 | print('-- Current {} not existed yet!'.format(gt_file))
128 |
129 | if not os.path.exists(I_dir):
130 | os.makedirs(I_dir)
131 | if I_prime_dir is not None and not os.path.exists(I_prime_dir):
132 | os.makedirs(I_prime_dir)
133 |
134 | raw_image_list = glob.glob(os.path.join(args.raw_data_path, '*.jpg'))
135 | print("Generate {} {} files from {} raw data.".format(args.num_data, args.mode, len(raw_image_list)))
136 |
137 | index = 0
138 | while True:
139 | raw_img_name = random.choice(raw_image_list)
140 | raw_image_path = os.path.join(args.raw_data_path, raw_img_name)
141 | index, error = homographyGeneration(args, raw_image_path, index,
142 | I_dir, I_prime_dir, gt_file, pts1_file, filenames_file)
143 | if error == -1:
144 | continue
145 | if index >= args.num_data:
146 | break
147 |
148 |
149 | def main():
150 | RHO = 45 # The maximum value of perturbation
151 |
152 | DATA_NUMBER = 100000
153 | TEST_DATA_NUMBER = 5000
154 |
155 | # Size of synthetic image
156 | HEIGHT = 240
157 | WIDTH = 320
158 | PATCH_SIZE = 128
159 |
160 | # Directories to files
161 | RAW_DATA_PATH = "D:/Workspace/Datasets/coco2014/train2014/" # Real images used for generating synthetic data
162 | TEST_RAW_DATA_PATH = "D:/Workspace/Datasets/coco2014/test2014/" # Real images used for generating test synthetic data
163 |
164 | # Synthetic data directories
165 | DATA_PATH = "../data/synthetic/" + str(RHO) + '/'
166 | if not os.path.exists(DATA_PATH):
167 | os.makedirs(DATA_PATH)
168 |
169 | def str2bool(s):
170 | return s.lower() == 'true'
171 |
172 | parser = argparse.ArgumentParser()
173 | parser.add_argument('--mode', type=str, default='test', help='Train or test', choices=['train', 'test'])
174 | parser.add_argument('--color', type=str2bool, default='true', help='Generate color or gray images')
175 |
176 | parser.add_argument('--raw_data_path', type=str, default=RAW_DATA_PATH, help='The raw data path.')
177 | parser.add_argument('--test_raw_data_path', type=str, default=TEST_RAW_DATA_PATH, help='The test raw data path.')
178 | parser.add_argument('--data_path', type=str, default=DATA_PATH, help='The raw data path.')
179 | parser.add_argument('--num_data', type=int, default=DATA_NUMBER, help='The data size for training')
180 | parser.add_argument('--test_num_data', type=int, default=TEST_DATA_NUMBER, help='The data size for test')
181 |
182 | parser.add_argument('--img_w', type=int, default=WIDTH)
183 | parser.add_argument('--img_h', type=int, default=HEIGHT)
184 | parser.add_argument('--rho', type=int, default=RHO)
185 | parser.add_argument('--patch_size', type=int, default=PATCH_SIZE)
186 |
187 | args = parser.parse_args()
188 | print('<==================== Loading raw data ===================>\n')
189 | if args.mode == 'test':
190 | args.num_data = args.test_num_data
191 | args.raw_data_path = args.test_raw_data_path
192 |
193 | print('<================= Generating Data .... =================>\n')
194 |
195 | dataCollection(args)
196 |
197 |
198 | if __name__ == '__main__':
199 | main()
200 |
--------------------------------------------------------------------------------
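The core of `homographyGeneration` in one self-contained sketch (illustrative sizes matching the defaults): pick a patch, perturb its corners by up to `rho` pixels, and the corner offsets become the 8-value ground truth written to `gt.txt`:

```python
import random
import cv2
import numpy as np

rho, patch, h, w = 45, 128, 240, 320
x = random.randint(rho, w - rho - patch)  # patch top-left corner, kept rho away from borders
y = random.randint(rho, h - rho - patch)
corners = np.float32([[x, y], [x + patch, y], [x + patch, y + patch], [x, y + patch]])
perturbed = corners + np.float32(np.random.randint(-rho, rho + 1, (4, 2)))
H = cv2.getPerspectiveTransform(corners, perturbed)
gt = (perturbed - corners).flatten()  # the 8 offsets stored per line in gt.txt
```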
/code/utils/numpy_spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Simple version of spatial_transformer.py; works on a single image with multiple channels
2 | import numpy as np
3 | import cv2
5 | import matplotlib.pyplot as plt
6 | from skimage import io
7 | ###############################################################
8 | # Configurable parameter
9 | SCALE_H = True
10 | # SCALE_H: the indices of the target output grid are
11 | # scaled to [-1, 1]. Set False to stay in pixel coordinates.
12 | def _meshgrid(height, width, scale_H = SCALE_H):
13 | if scale_H:
14 | x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
15 | np.linspace(-1, 1, height))
16 | ones = np.ones(np.prod(x_t.shape))
17 | grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
18 |
19 | else:
20 | x_t, y_t = np.meshgrid(range(0,width), range(0,height))
21 | ones = np.ones(np.prod(x_t.shape))
22 | grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
23 | # print('--grid size:', grid.shape)
24 | return grid
25 |
26 |
27 | def _interpolate(im, x, y, out_size, scale_H = SCALE_H):
28 | # constants
29 | height = im.shape[0]
30 | width = im.shape[1]
31 |
32 |
33 | height_f = float(height)
34 | width_f = float(width)
35 | out_height = out_size[0]
36 | out_width = out_size[1]
37 | zero = np.zeros([], dtype='int32')
38 | max_y = im.shape[0] - 1
39 | max_x = im.shape[1] - 1
40 |
41 | if scale_H:
42 | # # scale indices from [-1, 1] to [0, width/height]
43 | x = (x + 1.0)*(width_f) / 2.0
44 | y = (y + 1.0)*(height_f) / 2.0
45 |
46 | # do sampling
47 | x0 = np.floor(x).astype(int)
48 | x1 = x0 + 1
49 | y0 = np.floor(y).astype(int)
50 | y1 = y0 + 1
51 |
52 | # print('x0:', x0)
53 | # print('y0:', y0)
54 | # Limit the size of the output image
55 | x0 = np.clip(x0, zero, max_x)
56 | x1 = np.clip(x1, zero, max_x)
57 | y0 = np.clip(y0, zero, max_y)
58 | y1 = np.clip(y1, zero, max_y)
59 | # print('x0:', x0)
60 | # print('y0:', y0)
61 |
62 | Ia = im[ y0, x0, ... ]
63 | Ib = im[ y1, x0, ... ]
64 | Ic = im[ y0, x1, ... ]
65 | Id = im[ y1, x1, ... ]
66 | # print(Ia.shape)
67 |
68 | # print
69 | # plt.figure(2)
70 | # plt.subplot(221)
71 | # plt.imshow(Ia)
72 | # plt.subplot(222)
73 | # plt.imshow(Ib)
74 | # plt.subplot(223)
75 | # plt.imshow(Ic)
76 | # plt.subplot(224)
77 | # plt.imshow(Id)
78 | # plt.show()
79 |
80 | wa = (x1 -x) * (y1-y)
81 | wb = (x1-x) * (y-y0)
82 | wc = (x-x0) * (y1-y)
83 | wd = (x-x0) * (y-y0)
85 |
86 | # Handle multi channel image
87 | if im.ndim == 3:
88 | num_channels = im.shape[2]
89 | # wa = np.expand_dims(wa, 2)
90 | # wb = np.expand_dims(wb, 2)
91 | # wc = np.expand_dims(wc, 2)
92 | # wd = np.expand_dims(wd, 2)
93 | wa = np.tile(wa.reshape(-1, 1), num_channels)
94 | wb = np.tile(wb.reshape(-1, 1), num_channels)
95 | wc = np.tile(wc.reshape(-1, 1), num_channels)
96 | wd = np.tile(wd.reshape(-1, 1), num_channels)
97 | out = wa*Ia + wb*Ib + wc*Ic + wd*Id
98 | # print('--shape of out:', out.shape)
99 | return out
100 |
101 | def _transform(theta, input_dim, out_size):
102 | height, width = input_dim.shape[0], input_dim.shape[1]
103 | theta = np.reshape(theta, (3, 3))
106 |
107 | # grid of (x_t, y_t, 1), eq (1) in ref [1]
108 | out_height = out_size[0]
109 | out_width = out_size[1]
110 | grid = _meshgrid(out_height, out_width)
111 |
112 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
113 | T_g = np.dot(theta, grid)
114 | x_s = T_g[0,:]
115 | y_s = T_g[1,:]
116 | t_s = T_g[2,:]
121 |
122 | t_s_flat = np.reshape(t_s, [-1])
123 | # Ty changed
124 | # x_s_flat = np.reshape(x_s, [-1])
125 | # y_s_flat = np.reshape(y_s, [-1])
126 | x_s_flat = np.reshape(x_s, [-1])/t_s_flat
127 | y_s_flat = np.reshape(y_s, [-1])/t_s_flat
128 |
129 |
130 | input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)
131 | if input_dim.ndim == 3:
132 | output = np.reshape(input_transformed, [out_height, out_width, -1])
133 | else:
134 | output = np.reshape(input_transformed, [out_height, out_width])
135 |
136 | output = output.astype(np.uint8)
137 | return output
138 |
139 |
140 | def numpy_transformer(img, H, out_size, scale_H = SCALE_H):
141 | h, w = img.shape[0], img.shape[1]
142 | # Matrix M
143 | M = np.array([[w/2.0, 0, w/2.0], [0, h/2.0, h/2.0], [0, 0, 1.]]).astype(np.float32)
144 |
145 | if scale_H:
146 | H_transformed = np.dot(np.dot(np.linalg.inv(M), np.linalg.inv(H)), M)
148 | img2 = _transform(H_transformed, img, [h,w])
149 | else:
150 | img2 = _transform(np.linalg.inv(H), img, [h,w])
151 | return img2
152 |
153 |
154 | def test_transformer(scale_H = SCALE_H):
155 | img = io.imread('D:/Workspace/Datasets/ms_coco_test_images/COCO_test2014_000000000001.jpg')
156 | h, w = img.shape[0], img.shape[1]
157 | print( '-- h, w:', h, w )
158 |
159 |
160 | # Apply homography transformation
161 |
162 | H = np.array([[2., 0.3, 5], [0.3, 2., 10.], [0.0001, 0.0002, 1.]]).astype(np.float32)
163 | img2 = cv2.warpPerspective(img, H, (w, h))
164 |
165 |
166 | # # Matrix M
167 | M = np.array([[w/2.0, 0, w/2.0], [0, h/2.0, h/2.0], [0, 0, 1.]]).astype(np.float32)
168 |
169 | if scale_H:
170 | H_transformed = np.dot(np.dot(np.linalg.inv(M), np.linalg.inv(H)), M)
171 | print('H_transformed:', H_transformed)
172 | img3 = _transform(H_transformed, img, [h,w])
173 | else:
174 | img3 = _transform(np.linalg.inv(H), img, [h,w])
175 |
176 | print('-- Reprojection error:', np.mean(np.abs(img3.astype(int) - img2.astype(int))))
177 | Reprojection = np.abs(img3.astype(int) - img2.astype(int))
178 | # Test on real image
179 | count = 0
180 | amount = 0
181 | for i in range(48):
182 | for j in range(48):
183 | for k in range(2):
184 | if Reprojection[i, j, k] > 10:
185 | print(i, j, k, 'value', Reprojection[i, j, k])
186 | count += 1
187 | amount += Reprojection[i, j, k]
188 | print('%d of %d sampled values exceed 10; mean excess %.3f' % (count, 48 * 48 * 2, amount * 1.0 / max(count, 1)))
189 |
190 | #io.imshow('img3', img3)
191 | try:
192 | plt.subplot(221)
193 | plt.imshow(img)
194 | plt.title('Original image')
195 |
196 | plt.subplot(222)
197 | plt.imshow(img2)
198 | plt.title('cv2.warpPerspective')
199 |
200 | plt.subplot(223)
201 | plt.imshow(img3)
202 | plt.title('Transformer')
203 |
204 | plt.subplot(224)
205 | plt.imshow(Reprojection)
206 | plt.title('Reprojection Error')
207 | plt.show()
208 | except KeyboardInterrupt:
209 | plt.close()
210 | exit(1)
211 |
212 |
213 | if __name__ == "__main__":
214 | test_transformer()
215 |
--------------------------------------------------------------------------------
/code/utils/torch_spatial_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def transformer(U, theta, out_size, **kwargs):
5 | """Spatial Transformer Layer
6 |
7 | Implements a spatial transformer layer as described in [1]_.
8 | Based on [2]_ and edited by David Dao for Tensorflow.
9 |
10 | Parameters
11 | ----------
12 | U : float
13 | The input images; in this PyTorch port they should have the
14 | shape [num_batch, num_channels, height, width].
15 | theta: float
16 | The (batch of) 3x3 homography matrices, shaped
17 | [num_batch, 3, 3] or flattened to [num_batch, 9].
18 | out_size: tuple of two ints
19 | The size of the output of the network (height, width)
20 |
21 | References
22 | ----------
23 | .. [1] Spatial Transformer Networks
24 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
25 | Submitted on 5 Jun 2015
26 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
27 |
28 | Notes
29 | -----
30 | To initialize the network to the identity transform, init
31 | ``theta`` to the flattened 3x3 identity matrix:
32 | identity = torch.eye(3)
33 | identity = identity.flatten()
34 | theta = torch.nn.Parameter(identity)
36 |
37 | """
38 |
39 | def _repeat(x, n_repeats):
40 |
41 | rep = torch.ones([n_repeats, ]).unsqueeze(0)
42 | rep = rep.int()
43 | x = x.int()
44 |
45 | x = torch.matmul(x.reshape([-1,1]), rep)
46 | return x.reshape([-1])
47 |
48 | def _interpolate(im, x, y, out_size, scale_h):
49 |
50 | num_batch, num_channels , height, width = im.size()
51 |
52 | height_f = height
53 | width_f = width
54 | out_height, out_width = out_size[0], out_size[1]
55 |
56 | zero = 0
57 | max_y = height - 1
58 | max_x = width - 1
59 | if scale_h:
60 |
61 | x = (x + 1.0)*(width_f) / 2.0
62 | y = (y + 1.0) * (height_f) / 2.0
63 |
64 | # do sampling
65 | x0 = torch.floor(x).int()
66 | x1 = x0 + 1
67 | y0 = torch.floor(y).int()
68 | y1 = y0 + 1
69 |
70 | x0 = torch.clamp(x0, zero, max_x)
71 | x1 = torch.clamp(x1, zero, max_x)
72 | y0 = torch.clamp(y0, zero, max_y)
73 | y1 = torch.clamp(y1, zero, max_y)
74 | dim2 = torch.from_numpy( np.array(width) )
75 | dim1 = torch.from_numpy( np.array(width * height) )
76 |
77 | base = _repeat(torch.arange(0,num_batch) * dim1, out_height * out_width)
78 | if torch.cuda.is_available():
79 | dim2 = dim2.cuda()
80 | dim1 = dim1.cuda()
81 | y0 = y0.cuda()
82 | y1 = y1.cuda()
83 | x0 = x0.cuda()
84 | x1 = x1.cuda()
85 | base = base.cuda()
86 | base_y0 = base + y0 * dim2
87 | base_y1 = base + y1 * dim2
88 | idx_a = base_y0 + x0
89 | idx_b = base_y1 + x0
90 | idx_c = base_y0 + x1
91 | idx_d = base_y1 + x1
92 |
93 | # channels dim
94 | im = im.permute(0,2,3,1)
95 | im_flat = im.reshape([-1, num_channels]).float()
96 |
97 | idx_a = idx_a.unsqueeze(-1).long()
98 | idx_a = idx_a.expand(height * width * num_batch,num_channels)
99 | Ia = torch.gather(im_flat, 0, idx_a)
100 |
101 | idx_b = idx_b.unsqueeze(-1).long()
102 | idx_b = idx_b.expand(height * width * num_batch, num_channels)
103 | Ib = torch.gather(im_flat, 0, idx_b)
104 |
105 | idx_c = idx_c.unsqueeze(-1).long()
106 | idx_c = idx_c.expand(height * width * num_batch, num_channels)
107 | Ic = torch.gather(im_flat, 0, idx_c)
108 |
109 | idx_d = idx_d.unsqueeze(-1).long()
110 | idx_d = idx_d.expand(height * width * num_batch, num_channels)
111 | Id = torch.gather(im_flat, 0, idx_d)
112 |
113 | x0_f = x0.float()
114 | x1_f = x1.float()
115 | y0_f = y0.float()
116 | y1_f = y1.float()
117 |
118 | wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
119 | wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
120 | wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
121 | wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
122 | output = wa*Ia+wb*Ib+wc*Ic+wd*Id
123 |
124 | return output
125 |
126 | def _meshgrid(height, width, scale_h):
127 |
128 | if scale_h:
129 | x_t = torch.matmul(torch.ones([height, 1]),
130 | torch.transpose(torch.unsqueeze(torch.linspace(-1.0, 1.0, width), 1), 1, 0))
131 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(-1.0, 1.0, height), 1),
132 | torch.ones([1, width]))
133 | else:
134 | x_t = torch.matmul(torch.ones([height, 1]),
135 | torch.transpose(torch.unsqueeze(torch.linspace(0.0, width.float(), width), 1), 1, 0))
136 | y_t = torch.matmul(torch.unsqueeze(torch.linspace(0.0, height.float(), height), 1),
137 | torch.ones([1, width]))
138 |
139 |
140 | x_t_flat = x_t.reshape((1, -1)).float()
141 | y_t_flat = y_t.reshape((1, -1)).float()
142 |
143 | ones = torch.ones_like(x_t_flat)
144 | grid = torch.cat([x_t_flat, y_t_flat, ones], 0)
145 | if torch.cuda.is_available():
146 | grid = grid.cuda()
147 | return grid
148 |
149 | def _transform(theta, input_dim, out_size, scale_h):
150 | num_batch, num_channels , height, width = input_dim.size()
151 | # Changed
152 | theta = theta.reshape([-1, 3, 3]).float()
153 |
154 | out_height, out_width = out_size[0], out_size[1]
155 | grid = _meshgrid(out_height, out_width, scale_h)
156 | grid = grid.unsqueeze(0).reshape([1,-1])
157 | shape = grid.size()
158 | grid = grid.expand(num_batch,shape[1])
159 | grid = grid.reshape([num_batch, 3, -1])
160 |
161 | T_g = torch.matmul(theta, grid)
162 | x_s = T_g[:,0,:]
163 | y_s = T_g[:,1,:]
164 | t_s = T_g[:,2,:]
165 |
166 | t_s_flat = t_s.reshape([-1])
167 |
168 | # avoid division by (near-)zero denominators
169 | small = 1e-7
170 | smallers = 1e-6*(1.0 - torch.ge(torch.abs(t_s_flat), small).float())
171 |
172 | t_s_flat = t_s_flat + smallers
173 | condition = torch.sum(torch.gt(torch.abs(t_s_flat), small).float())
174 | # Ty changed
175 | x_s_flat = x_s.reshape([-1]) / t_s_flat
176 | y_s_flat = y_s.reshape([-1]) / t_s_flat
177 |
178 | input_transformed = _interpolate( input_dim, x_s_flat, y_s_flat,out_size,scale_h)
179 |
180 | output = input_transformed.reshape([num_batch, out_height, out_width, num_channels ])
181 | return output, condition
182 |
185 |
186 | scale_h = True
187 | output, condition = _transform(theta, U, out_size, scale_h)
188 | return output, condition
--------------------------------------------------------------------------------
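A minimal usage sketch of `transformer`: warping a batch with identity homographies returns (approximately) the input, and shows the NCHW-in / NHWC-out convention:

```python
import torch
from utils.torch_spatial_transformer import transformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'  # the transformer moves its grid to GPU when available
U = torch.rand(2, 3, 240, 320, device=device)         # NCHW, as passed by HomographyModel.transform
theta = torch.eye(3, device=device).repeat(2, 1, 1)   # identity homography per batch element
warped, _ = transformer(U, theta, (240, 320))
print(warped.shape)                                   # torch.Size([2, 240, 320, 3]) -- NHWC output
```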
/code/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def denorm_img(img):
8 | if len(img.shape) == 3:
9 | img = np.transpose(img, [1, 2, 0]) # torch [C,H,W]
10 | mean = np.array([118.93, 113.97, 102.60]).reshape([1, 1, 3])
11 | std = np.array([69.85, 68.81, 72.45]).reshape([1, 1, 3])
12 | elif len(img.shape) == 2:
13 | mean = np.mean([118.93, 113.97, 102.60])
14 | std = np.mean([69.85, 68.81, 72.45])
15 | return img * std + mean
16 |
17 |
18 | def save_correspondences_img(img1, img2, corr1, corr2, pred_corr2, results_dir, img_name):
19 | """ Save pair of images with their correspondences into a single image. Used for report"""
20 | # Draw prediction
21 | copy_img2 = img2.copy()
22 | copy_img1 = img1.copy()
23 | cv2.polylines(copy_img2, np.int32([pred_corr2]), 1, (5, 225, 225), 3)
24 |
25 | point_color = (0, 255, 255)
26 | line_color_set = [(255, 102, 255), (51, 153, 255), (102, 255, 255), (255, 255, 0), (102, 102, 244), (150, 202, 178),
27 | (153, 240, 142), (102, 0, 51), (51, 51, 0)]
28 | # Draw 4 points (ground truth)
29 | full_stack_images = draw_matches(copy_img1, corr1, copy_img2, corr2, '/tmp/tmp.jpg', color_set=line_color_set,
30 | show=False)
31 | # Save image
32 | visual_file_name = os.path.join(results_dir, img_name)
33 | # cv2.putText(full_stack_images, 'RMSE %.2f'%h_loss,(800, 100), cv2.FONT_HERSHEY_SIMPLEX, 1,(0,0,255),2)
34 | cv2.imwrite(visual_file_name, full_stack_images)
35 | print('Wrote file %s' % visual_file_name)
36 |
37 |
38 | def draw_matches(img1, kp1, img2, kp2, output_img_file=None, color_set=None, show=True):
39 | """Draws lines between matching keypoints of two images without matches.
40 | This is a replacement for cv2.drawMatches
41 | Places the images side by side in a new image and draws circles
42 | around each keypoint, with line segments connecting matching pairs.
43 | You can tweak the r, thickness, and figsize values as needed.
44 | Args:
45 | img1: An openCV image ndarray in a grayscale or color format.
46 | kp1: A list of cv2.KeyPoint objects for img1.
47 | img2: An openCV image ndarray of the same format and with the same
48 | element type as img1.
49 | kp2: A list of cv2.KeyPoint objects for img2.
50 | color_set: The colors of the circles and connecting lines drawn on the images.
51 | A 3-tuple for color images, a scalar for grayscale images. If None, these
52 | values are randomly generated.
53 | """
54 | # We're drawing them side by side. Get dimensions accordingly.
55 | # Handle both color and grayscale images.
56 | if len(img1.shape) == 3:
57 | new_shape = (max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1], img1.shape[2])
58 | elif len(img1.shape) == 2:
59 | new_shape = (max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1])
60 | new_img = np.zeros(new_shape, img1.dtype)
61 | # Place images onto the new image.
62 | new_img[0:img1.shape[0], 0:img1.shape[1]] = img1
63 | new_img[0:img2.shape[0], img1.shape[1]:img1.shape[1] + img2.shape[1]] = img2
64 |
65 | # Draw lines between points
66 |
67 | kp2_on_stack_image = (kp2 + np.array([img1.shape[1], 0])).astype(np.int32)
68 |
69 | kp1 = kp1.astype(np.int32)
70 | # kp2_on_stack_image[0:4,0:2]
71 | line_color1 = (2, 10, 240)
72 | line_color2 = (2, 10, 240)
73 | # We want to make connections between points to make a square grid so first count the number of rows in the square grid.
74 | grid_num_rows = int(np.sqrt(kp1.shape[0]))
75 |
76 | if output_img_file is not None and grid_num_rows >= 3:
77 | for i in range(grid_num_rows):
78 | cv2.line(new_img, tuple(kp1[i * grid_num_rows]), tuple(kp1[i * grid_num_rows + (grid_num_rows - 1)]),
79 | line_color1, 1, cv2.LINE_AA)
80 | cv2.line(new_img, tuple(kp1[i]), tuple(kp1[i + (grid_num_rows - 1) * grid_num_rows]), line_color1, 1,
81 | cv2.LINE_AA)
82 | cv2.line(new_img, tuple(kp2_on_stack_image[i * grid_num_rows]),
83 | tuple(kp2_on_stack_image[i * grid_num_rows + (grid_num_rows - 1)]), line_color2, 1, cv2.LINE_AA)
84 | cv2.line(new_img, tuple(kp2_on_stack_image[i]),
85 | tuple(kp2_on_stack_image[i + (grid_num_rows - 1) * grid_num_rows]), line_color2, 1, cv2.LINE_AA)
86 |
87 | if output_img_file is not None and grid_num_rows == 2:
88 | cv2.polylines(new_img, np.int32([kp2_on_stack_image]), 1, line_color2, 3)
89 | cv2.polylines(new_img, np.int32([kp1]), 1, line_color1, 3)
90 | # Draw lines between matches. Make sure to offset kp coords in second image appropriately.
91 | r = 7
92 | thickness = 1
93 |
94 | for i in range(len(kp1)):
95 | key1 = kp1[i]
96 | key2 = kp2[i]
97 | # Generate random color for RGB/BGR and grayscale images as needed.
98 | try:
99 | c = color_set[i]
100 | except:
101 | c = np.random.randint(0, 256, 3) if len(img1.shape) == 3 else np.random.randint(0, 256)
102 | # So the keypoint locs are stored as a tuple of floats. cv2.line(), like most other things,
103 | # wants locs as a tuple of ints.
104 | end1 = tuple(np.round(key1).astype(int))
105 | end2 = tuple(np.round(key2).astype(int) + np.array([img1.shape[1], 0]))
106 | cv2.line(new_img, end1, end2, c, thickness, cv2.LINE_AA)
107 | cv2.circle(new_img, end1, r, c, thickness, cv2.LINE_AA)
108 | cv2.circle(new_img, end2, r, c, thickness, cv2.LINE_AA)
110 | if show:
111 | plt.figure(figsize=(15, 15))
112 | if len(img1.shape) == 3:
113 | plt.imshow(new_img)
114 | else:
115 | plt.imshow(new_img, cmap='gray')
116 | plt.axis('off')
117 | plt.show()
118 | if output_img_file is not None:
119 | cv2.imwrite(output_img_file, new_img)
120 |
121 | return new_img
122 |
123 |
124 | def find_percentile(x):
125 | x_sorted = np.sort(x)
126 | len_x = len(x_sorted)
127 |
128 | # Mean/std over the 0-30%, 30-60% and 60-100% slices
129 | tops_list = [0.3, 0.6, 1]
130 | return_list = []
131 | start_index = 0
132 | for i in range(len(tops_list)):
133 | print('>> Top %.0f - %.0f %%' % (tops_list[i - 1] * 100 if i >= 1 else 0, tops_list[i] * 100))
134 | stop_index = int(tops_list[i] * len_x)
135 |
136 | interval = x_sorted[start_index:stop_index]
137 | interval_mu = np.mean(interval)
138 | interval_std = np.std(interval)
139 |
140 | start_index = stop_index
141 | return_list.append([interval_mu, interval_std])
142 | return np.array(return_list)
143 |
--------------------------------------------------------------------------------
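`find_percentile` reports the mean and standard deviation of a sorted metric over the 0-30%, 30-60% and 60-100% slices; a usage sketch with stand-in values:

```python
import numpy as np
from utils.utils import find_percentile

h_losses = np.random.rand(1000) * 20   # stand-in per-sample h_loss values
stats = find_percentile(h_losses)      # prints each slice header as a side effect
print(stats)                           # rows of [mean, std], one per slice
```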
/results/synthetic/report/0025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0025.jpg
--------------------------------------------------------------------------------
/results/synthetic/report/0026.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0026.jpg
--------------------------------------------------------------------------------
/results/synthetic/report/0027.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0027.jpg
--------------------------------------------------------------------------------
/results/synthetic/report/0028.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0028.jpg
--------------------------------------------------------------------------------
/results/synthetic/report/0029.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/unsupervisedDeepHomography-pytorch/11d207b05ed522938082b0f1255b8c8d4d49fb49/results/synthetic/report/0029.jpg
--------------------------------------------------------------------------------