├── LICENCE
├── README.md
├── data
│   └── gen_data.py
├── dataset.py
├── model.py
├── results
│   ├── 0040.jpg
│   ├── 0041.jpg
│   ├── 0042.jpg
│   ├── 0043.jpg
│   └── 0044.jpg
├── test.py
├── train.py
└── utils.py

/LICENCE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 breadcake

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Image Homography Estimation - PyTorch Implementation
[**Deep Image Homography Estimation**](https://arxiv.org/pdf/1606.03798.pdf)<br>
Daniel DeTone, Tomasz Malisiewicz, and Andrew Rabinovich

## Generate training dataset
```bash
cd data/
python gen_data.py
```
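Each generated `.npy` file stores a tuple `(training_image, four_points, H_four_points)`: the stacked gray image pair of shape (240, 320, 2), the four patch corners, and the corner offsets (the 4-point parameterization of the homography). A minimal sketch of how to inspect one sample and recover the full 3x3 matrix from it (the file name below is only an example):

```python
import numpy as np
import cv2

# One sample as saved by data/gen_data.py (any generated file works)
training_image, pts1, h4p = np.load('data/training/000000.npy', allow_pickle=True)
print(training_image.shape)  # (240, 320, 2): original and warped gray images, stacked

# The full homography follows from the four corner correspondences
pts2 = pts1 + h4p
H = cv2.getPerspectiveTransform(np.float32(pts1), np.float32(pts2))
print(H)
```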

## Training
```bash
python train.py
```
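The network is trained on 128x128 patch pairs stacked along the channel axis. A quick sketch of the tensors one batch provides, assuming the data generated above (shapes follow the default settings):

```python
from torch.utils.data import DataLoader
from dataset import CocoDataset

loader = DataLoader(CocoDataset('data/training/'), batch_size=4, shuffle=True)
ori_images, input_patch, pts1, delta = next(iter(loader))
print(ori_images.shape)   # (4, 2, 240, 320): full image pair, normalized to [-1, 1]
print(input_patch.shape)  # (4, 2, 128, 128): stacked patch pair fed to the network
print(pts1.shape)         # (4, 4, 2): patch corners in the original image
print(delta.shape)        # (4, 4, 2): corner offsets, scaled by 1/rho
```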

## Test
Download the pre-trained weights from Baidu Netdisk:

Link: https://pan.baidu.com/s/10HXNthOBhlZbrtvIkolxKw (extraction code: l9l8)

Place the downloaded model in the `checkpoints/` folder, then run:
```bash
python test.py
```
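To run the trained network on a single generated sample without the full test script, a minimal sketch (assuming the checkpoint name and data layout above):

```python
import numpy as np
import torch
from model import HomographyNet

model = HomographyNet()
state = torch.load('checkpoints/homographymodel.pth', map_location='cpu')
model.load_state_dict(state['state_dict'])
model.eval()

# Load one sample and prepare the 128x128 patch pair, exactly as dataset.py does
ori_images, pts1, delta = np.load('data/testing/000000.npy', allow_pickle=True)
patch = (ori_images.astype(np.float32) - 127.5) / 127.5
patch = np.transpose(patch, [2, 0, 1])[:, pts1[0, 1]:pts1[2, 1], pts1[0, 0]:pts1[2, 0]]

with torch.no_grad():
    pred_h4p = model(torch.from_numpy(patch).unsqueeze(0))[0].numpy().reshape(4, 2) * 32

print('predicted corner offsets:', pred_h4p)  # compare against the ground truth, delta
print('ground-truth corner offsets:', delta)
```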

Results |
--- |
![result](https://img-blog.csdnimg.cn/20210323211344844.png?x-oss-process) |
![result](https://img-blog.csdnimg.cn/20210323211415816.png?x-oss-process) |
![result](https://img-blog.csdnimg.cn/20210323211439899.png?x-oss-process) |
![result](https://img-blog.csdnimg.cn/20210323211457964.png?x-oss-process) |
![result](https://img-blog.csdnimg.cn/20210323211530847.png?x-oss-process) |

## Reference
[https://github.com/mazenmel/Deep-homography-estimation-Pytorch](https://github.com/mazenmel/Deep-homography-estimation-Pytorch)
--------------------------------------------------------------------------------
/data/gen_data.py:
--------------------------------------------------------------------------------
import os
import cv2
import random
import numpy as np
from numpy.linalg import inv
import time

# Adjust these to your local COCO 2014 paths
train_path = 'D:/Workspace/Datasets/coco2014/train2014'
val_path = 'D:/Workspace/Datasets/coco2014/val2014'
test_path = 'D:/Workspace/Datasets/coco2014/test2014'


def ImagePreProcessing(image_path, rho, patch_size, imsize):
    img = cv2.imread(image_path, 0)
    img = cv2.resize(img, imsize)

    # Pick a random top-left corner so the patch stays at least rho pixels
    # away from the image border, leaving room for the perturbation
    position_p = (random.randint(rho, imsize[0] - rho - patch_size),
                  random.randint(rho, imsize[1] - rho - patch_size))
    tl_point = position_p
    tr_point = (patch_size + position_p[0], position_p[1])
    br_point = (patch_size + position_p[0], patch_size + position_p[1])
    bl_point = (position_p[0], patch_size + position_p[1])

    four_points = [tl_point, tr_point, br_point, bl_point]

    # Perturb each corner independently by up to +/-rho pixels
    perturbed_four_points = []
    for point in four_points:
        perturbed_four_points.append((point[0] + random.randint(-rho, rho),
                                      point[1] + random.randint(-rho, rho)))

    H = cv2.getPerspectiveTransform(np.float32(four_points), np.float32(perturbed_four_points))
    H_inverse = inv(H)

    warped_image = cv2.warpPerspective(img, H_inverse, imsize)

    # The image patches themselves are not stored; dataset.py crops them on the fly:
    # Ip1 = img[tl_point[1]:br_point[1], tl_point[0]:br_point[0]]
    # Ip2 = warped_image[tl_point[1]:br_point[1], tl_point[0]:br_point[0]]

    training_image = np.dstack((img, warped_image))
    H_four_points = np.subtract(np.array(perturbed_four_points), np.array(four_points))
    datum = (training_image, np.array(four_points), H_four_points)

    return datum


# Save .npy files
def savedata(source_path, new_path, rho, patch_size, imsize, data_size):
    lst = os.listdir(source_path)
    filenames = [os.path.join(source_path, f) for f in lst if f.endswith('.jpg')]
    print("Generating {} files in {} from {} raw images...".format(data_size, new_path, len(filenames)))
    if not os.path.exists(new_path):
        os.makedirs(new_path)
    for i in range(data_size):
        image_path = random.choice(filenames)
        np.save(os.path.join(new_path, str(i).zfill(6)), ImagePreProcessing(image_path, rho, patch_size, imsize))
        if (i + 1) % 1000 == 0:
            print('--image number ', i + 1)


if __name__ == "__main__":
    start = time.time()
    rho = 32
    patch_size = 128
    imsize = (320, 240)
    savedata(train_path, './training/', rho, patch_size, imsize, data_size=500000)
    savedata(val_path, './validation/', rho, patch_size, imsize, data_size=5000)
    savedata(test_path, './testing/', rho, patch_size, imsize, data_size=5000)
    elapsed_time = time.time() - start
    print("Generated dataset in {:.0f}h {:.0f}m {:.0f}s.".format(
        elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60))

    # # Show a random sample
    # from matplotlib import pyplot as plt
    # npy = random.choice([os.path.join('./training/', f) for f in os.listdir('./training/')])
    # ori_images, pts1, delta = np.load(npy, allow_pickle=True)
    # pts2 = pts1 + delta
    # patch1 = ori_images[:, :, 0].copy()
    # patch2 = ori_images[:, :, 1].copy()
    # patch1 = cv2.cvtColor(patch1, cv2.COLOR_GRAY2RGB)
    # patch2 = cv2.cvtColor(patch2, cv2.COLOR_GRAY2RGB)
    # cv2.polylines(patch1, [pts1], True, (81, 167, 249), 2, cv2.LINE_AA)
    # cv2.polylines(patch1, [pts2], True, (111, 191, 64), 2, cv2.LINE_AA)
    # cv2.polylines(patch2, [pts1], True, (111, 191, 64), 2, cv2.LINE_AA)
    # plt.subplot(121), plt.imshow(patch1)
    # plt.subplot(122), plt.imshow(patch2)
    # plt.show()
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
import os
import numpy as np


# Customized PyTorch dataset over the .npy files written by data/gen_data.py
class CocoDataset(Dataset):
    def __init__(self, path, rho=32):
        lst = os.listdir(path)
        self.data = [os.path.join(path, f) for f in lst]
        self.rho = rho

    def __getitem__(self, index):
        ori_images, pts1, delta = np.load(self.data[index], allow_pickle=True)

        ori_images = (ori_images.astype(float) - 127.5) / 127.5  # normalize to [-1, 1]
        ori_images = np.transpose(ori_images, [2, 0, 1])  # HWC -> torch [C, H, W]

        # Crop the stacked 128x128 patch pair between the top-left and bottom-right corners
        input_patch = ori_images[:, pts1[0, 1]: pts1[2, 1], pts1[0, 0]: pts1[2, 0]]

        delta = delta.astype(float) / self.rho  # scale corner offsets to [-1, 1]

        ori_images = torch.from_numpy(ori_images)
        input_patch = torch.from_numpy(input_patch)
        pts1 = torch.from_numpy(pts1)
        delta = torch.from_numpy(delta)
        return ori_images, input_patch, pts1, delta

    def __len__(self):
        return len(self.data)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class HomographyNet(nn.Module):
    def __init__(self):
        super(HomographyNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(2, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 128 -> 64
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 64 -> 32
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 32 -> 16
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 16 * 16, 1024),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1024, 8)  # 4-point parameterization: 4 corners x (dx, dy)
        )

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.contiguous().view(x.size(0), -1)
        out = self.fc(out)
        return out


if __name__ == "__main__":
    from torchsummary import summary
    model = HomographyNet().cuda()
    print(model)
    summary(model, (2, 128, 128))
--------------------------------------------------------------------------------
/results/0040.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/Deep-homography-estimation-pytorch/8f271bb2108a7cc3fbb32cddba2ebbb7c9f1c735/results/0040.jpg
--------------------------------------------------------------------------------
/results/0041.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/Deep-homography-estimation-pytorch/8f271bb2108a7cc3fbb32cddba2ebbb7c9f1c735/results/0041.jpg
--------------------------------------------------------------------------------
/results/0042.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/Deep-homography-estimation-pytorch/8f271bb2108a7cc3fbb32cddba2ebbb7c9f1c735/results/0042.jpg
--------------------------------------------------------------------------------
/results/0043.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/Deep-homography-estimation-pytorch/8f271bb2108a7cc3fbb32cddba2ebbb7c9f1c735/results/0043.jpg
--------------------------------------------------------------------------------
/results/0044.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/breadcake/Deep-homography-estimation-pytorch/8f271bb2108a7cc3fbb32cddba2ebbb7c9f1c735/results/0044.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import DataLoader
from dataset import CocoDataset
from model import HomographyNet
import argparse
import os
import numpy as np
import cv2
import utils


def denorm_img(img):
    img = img * 127.5 + 127.5
    img = np.clip(img, 0, 255)
    return np.uint8(img)


def warp_pts(H, src_pts):
    # Apply a homography to the four corners (homogeneous coordinates)
    src_homo = np.hstack((src_pts, np.ones((4, 1)))).T
    dst_pts = np.matmul(H, src_homo)
    dst_pts = dst_pts / dst_pts[-1]
    return dst_pts.T[:, :2]


def test(args):
    MODEL_SAVE_DIR = 'checkpoints/'
    model_path = os.path.join(MODEL_SAVE_DIR, args.checkpoint)
    result_dir = 'results/'
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    model = HomographyNet()
    state = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    if torch.cuda.is_available():
        model = model.cuda()

    TestingData = CocoDataset(args.test_path)
    test_loader = DataLoader(TestingData, batch_size=1)

    print("start testing")
    with torch.no_grad():
        model.eval()
        error = np.zeros(len(TestingData))
        for i, batch_value in enumerate(test_loader):
            ori_images = batch_value[0].float()
            inputs = batch_value[1].float()
            pts1 = batch_value[2]
            target = batch_value[3].float()
            if torch.cuda.is_available():
                inputs = inputs.cuda()

            outputs = model(inputs)
            # Undo the rho normalization to get pixel offsets
            outputs = outputs * 32
            target = target * 32

            # Visualization
            I_A = denorm_img(ori_images[0, 0, ...].numpy())
            I_B = denorm_img(ori_images[0, 1, ...].numpy())
            pts1 = pts1[0].numpy()

            gt_h4p = target[0].numpy()
            pts2 = pts1 + gt_h4p
            gt_h = cv2.getPerspectiveTransform(np.float32(pts1), np.float32(pts2))
            gt_h_inv = np.linalg.inv(gt_h)
            pts1_ = warp_pts(gt_h_inv, pts1)

            pred_h4p = outputs[0].cpu().numpy().reshape([4, 2])
            pred_pts2 = pts1 + pred_h4p
            pred_h = cv2.getPerspectiveTransform(np.float32(pts1), np.float32(pred_pts2))
            pred_h_inv = np.linalg.inv(pred_h)
            pred_pts1_ = warp_pts(pred_h_inv, pts1)

            visual_file_name = str(i).zfill(4) + '.jpg'
            utils.save_correspondences_img(I_A, I_B, pts1, pts1_, pred_pts1_,
                                           result_dir, visual_file_name)

            # Corner error: mean Euclidean distance between ground-truth and
            # predicted corner offsets
            error[i] = np.mean(np.sqrt(np.sum((gt_h4p - pred_h4p) ** 2, axis=-1)))
            print('Mean Corner Error: ', error[i])

    print('Mean Average Corner Error over the test set: ', np.mean(error))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="homographymodel.pth")
    parser.add_argument("--test_path", type=str, default="data/testing/", help="path to test images")
    args = parser.parse_args()
    test(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, optim
from dataset import CocoDataset
from model import HomographyNet
from torch.utils.data import DataLoader
import argparse
import time
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "0"


def train(args):
    MODEL_SAVE_DIR = 'checkpoints/'
    if not os.path.exists(MODEL_SAVE_DIR):
        os.makedirs(MODEL_SAVE_DIR)

    model = HomographyNet()

    TrainingData = CocoDataset(args.train_path)
    ValidationData = CocoDataset(args.val_path)
    print('Found {} training files and {} validation files'.format(len(TrainingData), len(ValidationData)))
    train_loader = DataLoader(TrainingData, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(ValidationData, batch_size=args.batch_size, num_workers=4)

    if torch.cuda.is_available():
        model = model.cuda()

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
    # Decrease the learning rate by 10x after every third of the total epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(args.epochs / 3), gamma=0.1)

    print("start training")
    glob_iter = 0
    t0 = time.time()
    for epoch in range(args.epochs):
        epoch_start = time.time()
        # Training
        model.train()
        train_loss = 0.0
        for i, batch_value in enumerate(train_loader):
            # Save a checkpoint every 4000 iterations
            if glob_iter % 4000 == 0 and glob_iter != 0:
                filename = 'homographymodel' + '_iter_' + str(glob_iter) + '.pth'
                model_save_path = os.path.join(MODEL_SAVE_DIR, filename)
                state = {'epoch': args.epochs, 'state_dict': model.state_dict(),
                         'optimizer': optimizer.state_dict()}
                torch.save(state, model_save_path)

            ori_images = batch_value[0].float()
            inputs = batch_value[1].float()
            pts1 = batch_value[2]
            target = batch_value[3].float()

            if torch.cuda.is_available():
                inputs = inputs.cuda()
                target = target.cuda()

            optimizer.zero_grad()  # zero the gradients
            outputs = model(inputs)
            loss = criterion(outputs, target.view(-1, 8))
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if (i + 1) % 200 == 0 or (i + 1) == len(train_loader):
                # Average over the steps accumulated since the last print
                window = 200 if (i + 1) % 200 == 0 else (i + 1) % 200
                print("Training: Epoch[{:0>3}/{:0>3}] Iter[{:0>3}/{:0>3}] Mean Squared Error: {:.4f} lr={:.6f}".format(
                    epoch + 1, args.epochs, i + 1, len(train_loader), train_loss / window, scheduler.get_last_lr()[0]))
                train_loss = 0.0

            glob_iter += 1
        scheduler.step()

        # Validation
        with torch.no_grad():
            model.eval()
            val_loss = 0.0
            for i, batch_value in enumerate(val_loader):
                ori_images = batch_value[0].float()
                inputs = batch_value[1].float()
                pts1 = batch_value[2]
                target = batch_value[3].float()
                if torch.cuda.is_available():
                    inputs, target = inputs.cuda(), target.cuda()
                outputs = model(inputs)
                loss = criterion(outputs, target.view(-1, 8))
                val_loss += loss.item()
            print("Validation: Epoch[{:0>3}/{:0>3}] Mean Squared Error: {:.4f}, epoch time: {:.1f}s".format(
                epoch + 1, args.epochs, val_loss / len(val_loader), time.time() - epoch_start))

    elapsed_time = time.time() - t0
    print("Finished training in {:.0f}h {:.0f}m {:.0f}s.".format(
        elapsed_time // 3600, (elapsed_time % 3600) // 60, (elapsed_time % 3600) % 60))


if __name__ == "__main__":
    train_path = 'data/training/'
    val_path = 'data/validation/'

    # Roughly match the paper's 90,000 training iterations given the dataset size
    total_iteration = 90000
    batch_size = 64
    num_samples = 500000
    steps_per_epoch = num_samples // batch_size
    epochs = int(total_iteration / steps_per_epoch)

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=batch_size, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=0.005, help="learning rate")
    parser.add_argument("--epochs", type=int, default=epochs, help="number of epochs")

    parser.add_argument("--train_path", type=str, default=train_path, help="path to training imgs")
    parser.add_argument("--val_path", type=str, default=val_path, help="path to validation imgs")
    args = parser.parse_args()
    train(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os


def save_correspondences_img(img1, img2, corr1, corr2, pred_corr2, results_dir, img_name):
    """Save a pair of images with their correspondences into a single image.

    Used for the visual results in the report.
    """
    # Place the two gray images side by side
    new_img = np.zeros((max(img1.shape[0], img2.shape[0]), img1.shape[1] + img2.shape[1]), np.uint8)
    new_img[0:img1.shape[0], 0:img1.shape[1]] = img1.copy()
    new_img[0:img2.shape[0], img1.shape[1]:img1.shape[1] + img2.shape[1]] = img2.copy()
    new_img = cv2.cvtColor(new_img, cv2.COLOR_GRAY2RGB)

    # Patch corners in the first image (blue, since OpenCV uses BGR)
    cv2.polylines(new_img, np.int32([corr1]), 1, (255, 0, 0), 2, cv2.LINE_AA)

    # Shift the second image's points right by the width of the first image
    corr2_ = (corr2 + np.array([img1.shape[1], 0])).astype(np.int32)
    pred_corr2_ = (pred_corr2 + np.array([img1.shape[1], 0])).astype(np.int32)

    # Ground-truth corners in blue, predicted corners in green
    cv2.polylines(new_img, np.int32([corr2_]), 1, (255, 0, 0), 2, cv2.LINE_AA)
    cv2.polylines(new_img, np.int32([pred_corr2_]), 1, (0, 255, 0), 2, cv2.LINE_AA)

    # Save the visualization
    visual_file_name = os.path.join(results_dir, img_name)
    cv2.imwrite(visual_file_name, new_img)
    print('Wrote file %s' % visual_file_name)
--------------------------------------------------------------------------------