├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data_gen.py
├── export.py
├── extract.py
├── images
│   └── result.jpg
├── mobilenet_v2.py
├── optimizer.py
├── pre_process.py
├── test.py
├── test
│   ├── 000000523955.jpg
│   ├── test_gen_test.py
│   └── test_gen_train.py
├── test_orb.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 刘杨

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# HomographyNet

HomographyNet is a deep convolutional neural network that estimates the relative homography between a pair of image patches.
This is a PyTorch implementation of Deep Image Homography Estimation ([paper](https://arxiv.org/abs/1606.03798)).

## Features

- Backbone: MobileNetV2
- Dataset: MS-COCO 2017 training set

## Dataset

- Train/valid: 500,000/41,435 generated pairs of image patches sized 128x128 (rho=32).
- Test: 10,000 generated pairs of image patches sized 256x256 (rho=64).
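
Instead of the nine entries of a 3x3 homography, the network regresses the paper's 4-point parameterization: the (dx, dy) offsets of the four patch corners, flattened into 8 values. The full matrix can then be recovered with OpenCV. A minimal sketch (the helper `offsets_to_homography` is illustrative, not part of this repo; the corner order matches `get_datum` in `pre_process.py`):

```python
import cv2
import numpy as np


def offsets_to_homography(offsets, patch_size=128, top_left=(0, 0)):
    """Turn the 8 predicted corner offsets back into a 3x3 homography."""
    x, y = top_left
    # Corner order: top-left, bottom-left, bottom-right, top-right.
    corners = np.float32([(x, y), (x, y + patch_size),
                          (x + patch_size, y + patch_size), (x + patch_size, y)])
    perturbed = corners + np.float32(offsets).reshape(4, 2)
    return cv2.getPerspectiveTransform(corners, perturbed)
```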

## Dependencies

- Python 3.6.8
- PyTorch 1.3.0

## Usage

### Data Pre-processing

Extract the training images and generate the train/valid/test patch pairs:

```bash
$ python3 extract.py
$ python3 pre_process.py
```
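
`extract.py` assumes the MS-COCO 2017 training images have already been downloaded as `data/train2017.zip`, for example (download URL per cocodataset.org):

```bash
$ wget -P data http://images.cocodataset.org/zips/train2017.zip
```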

### Train

```bash
$ python3 train.py --lr 0.005 --batch-size 64
```

If you want to visualize during training, run in your terminal:

```bash
$ tensorboard --logdir runs
```

## Test

Homography estimation comparison on the warped MS-COCO test set (generated as described above):

```bash
$ python3 test.py
$ python3 test_orb.py --type surf
$ python3 test_orb.py --type identity
```
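
To run the exported model on a single patch pair, a minimal sketch (assuming a `homonet.pt` produced by `export.py`; `patch0.png`/`patch1.png` are hypothetical 128x128 grayscale patches, stacked into the first two input channels as in `data_gen.py`):

```python
import cv2
import numpy as np
import torch

from mobilenet_v2 import MobileNetV2

model = MobileNetV2()
model.load_state_dict(torch.load('homonet.pt'))
model.eval()

patch0 = cv2.imread('patch0.png', 0)
patch1 = cv2.imread('patch1.png', 0)
img = np.zeros((128, 128, 3), np.float32)
img[:, :, 0] = patch0 / 255.
img[:, :, 1] = patch1 / 255.
x = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0)  # [1, 3, 128, 128]

with torch.no_grad():
    offsets = model(x)[0].numpy()  # 8 corner offsets in the 4-point parameterization
```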

### Result

|Method|Mean Average Corner Error (pixels)|
|---|---|
|HomographyNet|3.53|
|SURF + RANSAC|8.83|
|Identity Homography|32.13|

### Graph

![image](https://gitee.com/foamliu/HomographyNet/raw/master/images/result.jpg)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # sets device for model and PyTorch tensors

im_size = 128
batch_size = 64

num_samples = 118287
num_train = 500000
num_valid = 41435
num_test = 10000
image_folder = 'data/train2017'
train_file = 'data/train.pkl'
valid_file = 'data/valid.pkl'
test_file = 'data/test.pkl'

# Training parameters
num_workers = 8  # for data-loading
grad_clip = 5.  # clip gradients at this absolute value
print_freq = 100  # print training/validation stats every __ batches
checkpoint = None  # path to checkpoint, None if none
--------------------------------------------------------------------------------
/data_gen.py:
--------------------------------------------------------------------------------
import pickle

import cv2 as cv
import numpy as np
from torch.utils.data import Dataset

from config import im_size


class DeepHNDataset(Dataset):
    def __init__(self, split):
        filename = 'data/{}.pkl'.format(split)
        print('loading {}...'.format(filename))
        with open(filename, 'rb') as file:
            samples = pickle.load(file)
        np.random.shuffle(samples)
        self.split = split
        self.samples = samples

    def __getitem__(self, i):
        sample = self.samples[i]
        image, four_points, perturbed_four_points = sample
        img0 = image[:, :, 0]
        img0 = cv.resize(img0, (im_size, im_size))
        img1 = image[:, :, 1]
        img1 = cv.resize(img1, (im_size, im_size))
        # The two grayscale patches fill the first two channels; the third stays zero.
        img = np.zeros((im_size, im_size, 3), np.float32)
        img[:, :, 0] = img0 / 255.
        img[:, :, 1] = img1 / 255.
        img = np.transpose(img, (2, 0, 1))  # HxWxC array to CxHxW
        # Target: offsets of the four perturbed corners, flattened to 8 values.
        H_four_points = np.subtract(np.array(perturbed_four_points), np.array(four_points))
        target = np.reshape(H_four_points, (8,))
        return img, target

    def __len__(self):
        return len(self.samples)


if __name__ == "__main__":
    train = DeepHNDataset('train')
    print('num_train: ' + str(len(train)))
    valid = DeepHNDataset('valid')
    print('num_valid: ' + str(len(valid)))

    print(train[0])
    print(valid[0])
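
    # A quick batching sanity check (illustrative; mirrors the loader settings
    # used in train.py):
    from torch.utils.data import DataLoader

    loader = DataLoader(train, batch_size=4, shuffle=True)
    img, target = next(iter(loader))
    print('img batch: ' + str(img.shape))  # torch.Size([4, 3, 128, 128])
    print('target batch: ' + str(target.shape))  # torch.Size([4, 8])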
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
import time

import torch

from mobilenet_v2 import MobileNetV2

if __name__ == '__main__':
    checkpoint = 'BEST_checkpoint.tar'
    print('loading {}...'.format(checkpoint))
    start = time.time()
    checkpoint = torch.load(checkpoint)
    print('elapsed {} sec'.format(time.time() - start))
    model = checkpoint['model'].module
    # print(model)
    # print(type(model))

    # model.eval()
    filename = 'homonet.pt'
    print('saving {}...'.format(filename))
    start = time.time()
    torch.save(model.state_dict(), filename)
    print('elapsed {} sec'.format(time.time() - start))

    print('loading {}...'.format(filename))
    start = time.time()
    model = MobileNetV2()
    model.load_state_dict(torch.load(filename))
    print('elapsed {} sec'.format(time.time() - start))
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
import zipfile


def extract(filename):
    print('Extracting {}...'.format(filename))
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall('data')


if __name__ == "__main__":
    extract('data/train2017.zip')
--------------------------------------------------------------------------------
/images/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/HomographyNet/fbe7f2271fd9686f641df49656e16002e6c1f9d1/images/result.jpg
--------------------------------------------------------------------------------
/mobilenet_v2.py:
--------------------------------------------------------------------------------
import os

import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub
from torchsummary import summary

from config import device


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v: number of channels
    :param divisor: channel divisor
    :param min_value: lower bound on the result
    :return: rounded number of channels
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        padding = (kernel_size - 1) // 2
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            nn.BatchNorm2d(out_planes, momentum=0.1),
            # ReLU rather than ReLU6, so Conv+BN+ReLU can be fused for quantization
            nn.ReLU(inplace=False)
        )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup, momentum=0.1),
        ])
        self.conv = nn.Sequential(*layers)
        # Replace torch.add with FloatFunctional so the skip-add can be quantized
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        """
        MobileNet V2 main class

        Args:
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
        """
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "and each element should have 4 values, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        # building the regression head: 8 corner offsets
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1280, 8),
            # nn.Sigmoid()
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = x.mean([2, 3])  # global average pooling
        x = self.classifier(x)
        # x = (x - 0.5) * 64  # (0, 1) -> (-32, 32)
        x = self.dequant(x)
        return x

    # Fuse Conv+BN and Conv+BN+ReLU modules prior to quantization
    # This operation does not change the numerics
    def fuse_model(self):
        for m in self.modules():
            if type(m) == ConvBNReLU:
                torch.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
            if type(m) == InvertedResidual:
                for idx in range(len(m.conv)):
                    if type(m.conv[idx]) == nn.Conv2d:
                        torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)


def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove('temp.p')


if __name__ == "__main__":
    model = MobileNetV2().to(device)
    print(model)
    summary(model, input_size=(3, 128, 128))
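
    # Post-training static quantization sketch (illustrative; the QuantStub/
    # DeQuantStub and fuse_model() above exist for this standard eager-mode
    # workflow -- real calibration patches should replace the random input):
    model_q = MobileNetV2().eval()
    model_q.fuse_model()
    model_q.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model_q, inplace=True)
    model_q(torch.randn(1, 3, 128, 128))  # calibration pass
    torch.quantization.convert(model_q, inplace=True)
    print_size_of_model(model_q)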
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
class HNetOptimizer(object):
    """A simple wrapper class for learning rate scheduling"""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.lr = 0.005
        self.step_num = 0

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self._update_lr()
        self.optimizer.step()

    def _update_lr(self):
        self.step_num += 1
        # Divide the learning rate by 10 every 50,000 steps while it stays above 1e-5.
        if self.step_num % 50000 == 0 and self.lr > 1e-5:
            self.lr = self.lr / 10
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
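

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): wrap any torch.optim optimizer to
    # get the step-based schedule above.
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = HNetOptimizer(torch.optim.Adam([param], lr=0.005))
    optimizer.zero_grad()
    param.sum().backward()
    optimizer.step()
    print('lr after 1 step: ' + str(optimizer.lr))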
--------------------------------------------------------------------------------
/pre_process.py:
--------------------------------------------------------------------------------
import os
import pickle
import random

import cv2 as cv
import numpy as np
from numpy.linalg import inv
from tqdm import tqdm

from config import image_folder
from config import train_file, valid_file, test_file


def get_datum(img, test_image, size, rho, top_point, patch_size):
    # Patch corners in order: top-left, bottom-left, bottom-right, top-right.
    left_point = (top_point[0], patch_size + top_point[1])
    bottom_point = (patch_size + top_point[0], patch_size + top_point[1])
    right_point = (patch_size + top_point[0], top_point[1])
    four_points = [top_point, left_point, bottom_point, right_point]
    # print('top_point: ' + str(top_point))
    # print('left_point: ' + str(left_point))
    # print('bottom_point: ' + str(bottom_point))
    # print('right_point: ' + str(right_point))
    # print('four_points: ' + str(four_points))

    # Perturb each corner by up to +/- rho pixels in x and y.
    perturbed_four_points = []
    for point in four_points:
        perturbed_four_points.append((point[0] + random.randint(-rho, rho), point[1] + random.randint(-rho, rho)))

    H = cv.getPerspectiveTransform(np.float32(four_points), np.float32(perturbed_four_points))
    H_inverse = inv(H)

    # Warp the whole image with H^-1, then crop both patches from the same window.
    warped_image = cv.warpPerspective(img, H_inverse, size)

    # print('test_image.shape: ' + str(test_image.shape))
    # print('warped_image.shape: ' + str(warped_image.shape))

    Ip1 = test_image[top_point[1]:bottom_point[1], top_point[0]:bottom_point[0]]
    Ip2 = warped_image[top_point[1]:bottom_point[1], top_point[0]:bottom_point[0]]

    training_image = np.dstack((Ip1, Ip2))
    # H_four_points = np.subtract(np.array(perturbed_four_points), np.array(four_points))
    datum = (training_image, np.array(four_points), np.array(perturbed_four_points))
    return datum


# This function is adapted from Mez Gebre's repository "deep_homography_estimation":
# https://github.com/mez/deep_homography_estimation
# Dataset_Generation_Visualization.ipynb
def process(files, is_test):
    if is_test:
        size = (640, 480)
        # Data gen parameters
        rho = 64
        patch_size = 256
    else:
        size = (320, 240)
        # Data gen parameters
        rho = 32
        patch_size = 128

    samples = []
    for f in tqdm(files):
        fullpath = os.path.join(image_folder, f)
        img = cv.imread(fullpath, 0)
        img = cv.resize(img, size)
        test_image = img.copy()

        if not is_test:
            # Five overlapping patch locations per training image.
            for top_point in [(0 + 32, 0 + 32), (128 + 32, 0 + 32), (0 + 32, 48 + 32), (128 + 32, 48 + 32),
                              (64 + 32, 24 + 32)]:
                # top_point = (rho, rho)
                datum = get_datum(img, test_image, size, rho, top_point, patch_size)
                samples.append(datum)
        else:
            top_point = (rho, rho)
            datum = get_datum(img, test_image, size, rho, top_point, patch_size)
            samples.append(datum)

    return samples


if __name__ == "__main__":
    files = [f for f in os.listdir(image_folder) if f.lower().endswith('.jpg')]
    np.random.shuffle(files)

    num_files = len(files)
    print('num_files: ' + str(num_files))

    num_train_files = 100000
    num_valid_files = 8287
    num_test_files = 10000

    train_files = files[:num_train_files]
    valid_files = files[num_train_files:num_train_files + num_valid_files]
    test_files = files[num_train_files + num_valid_files:num_train_files + num_valid_files + num_test_files]

    train = process(train_files, False)
    valid = process(valid_files, False)
    test = process(test_files, True)

    print('num_train: ' + str(len(train)))
    print('num_valid: ' + str(len(valid)))
    print('num_test: ' + str(len(test)))

    with open(train_file, 'wb') as f:
        pickle.dump(train, f)
    with open(valid_file, 'wb') as f:
        pickle.dump(valid, f)
    with open(test_file, 'wb') as f:
        pickle.dump(test, f)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import time

import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from config import batch_size, num_workers
from data_gen import DeepHNDataset
from mobilenet_v2 import MobileNetV2
from utils import AverageMeter

device = torch.device('cpu')

if __name__ == '__main__':
    filename = 'homonet.pt'

    print('loading {}...'.format(filename))
    model = MobileNetV2()
    model.load_state_dict(torch.load(filename))
    model.eval()

    test_dataset = DeepHNDataset('test')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True,
                                              num_workers=num_workers)

    num_samples = len(test_dataset)

    # Loss function
    criterion = nn.L1Loss().to(device)
    losses = AverageMeter()
    elapsed = 0

    # Batches
    for (img, target) in tqdm(test_loader):
        # Move to CPU, if available
        # img = F.interpolate(img, size=(img.size(2) // 2, img.size(3) // 2), mode='bicubic', align_corners=False)
        img = img.to(device)  # [N, 3, 128, 128]
        target = target.float().to(device)  # [N, 8]

        # Forward prop.
        with torch.no_grad():
            start = time.time()
            out = model(img)  # [N, 8]
            end = time.time()
            elapsed = elapsed + (end - start)

        # Calculate loss. Test patches are 256x256 with rho=64 -- twice the
        # training scale -- so the predicted offsets are doubled before being
        # compared with the target. The L1 loss is then the mean average corner
        # error in pixels, the same metric as in test_orb.py.
        out = out.squeeze(dim=1)
        loss = criterion(out * 2, target)

        losses.update(loss.item(), img.size(0))

    print('Elapsed: {0:.5f} ms'.format(elapsed / num_samples * 1000))
    print('Loss: {0:.2f}'.format(losses.avg))
--------------------------------------------------------------------------------
/test/000000523955.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foamliu/HomographyNet/fbe7f2271fd9686f641df49656e16002e6c1f9d1/test/000000523955.jpg
--------------------------------------------------------------------------------
/test/test_gen_test.py:
--------------------------------------------------------------------------------
import random

import cv2 as cv
import numpy as np
from numpy.linalg import inv

from test_orb import compute_homo

rho = 64
patch_size = 256
# Corner order here: top-left, top-right, bottom-right, bottom-left
# (the variable names are historical and do not match the geometry).
top_point = (rho, rho)
left_point = (patch_size + rho, rho)
bottom_point = (patch_size + rho, patch_size + rho)
right_point = (rho, patch_size + rho)
four_points = [top_point, left_point, bottom_point, right_point]

if __name__ == "__main__":
    # np.random.seed(7)
    # random.seed(7)
    fullpath = '000000523955.jpg'
    img = cv.imread(fullpath, 0)
    img = cv.resize(img, (640, 480))
    test_image = img.copy()
    perturbed_four_points = []
    for point in four_points:
        perturbed_four_points.append((point[0] + random.randint(-rho, rho), point[1] + random.randint(-rho, rho)))
    print('perturbed_four_points: ' + str(np.float32(perturbed_four_points)))
    H_four_points = np.subtract(np.array(perturbed_four_points), np.array(four_points))
    print('H_four_points: ' + str(H_four_points))

    H = cv.getPerspectiveTransform(np.float32(four_points), np.float32(perturbed_four_points))
    print('H: ' + str(H))
    H_inverse = inv(H)
    print('H_inverse: ' + str(H_inverse))

    warped_image = cv.warpPerspective(img, H_inverse, (640, 480))
    # warped_image = cv.warpPerspective(img, H, (640, 480))

    Ip1 = test_image[top_point[1]:bottom_point[1], top_point[0]:bottom_point[0]]
    Ip2 = warped_image[top_point[1]:bottom_point[1], top_point[0]:bottom_point[0]]

    test_image = cv.polylines(test_image, [np.int32(perturbed_four_points)], True, 255, 3, cv.LINE_AA)
    warped_image = cv.polylines(warped_image, [np.int32(four_points)], True, 255, 3, cv.LINE_AA)

    # Paste each patch back onto an empty (rows, cols) canvas at its original location.
    Ip1_new = np.zeros((480, 640), np.uint8)
    Ip1_new[64:320, 64:320] = Ip1
    Ip2_new = np.zeros((480, 640), np.uint8)
    Ip2_new[64:320, 64:320] = Ip2

    # pred_H = compute_homo(Ip1, Ip2)
    pred_H = compute_homo(Ip2_new, Ip1_new)
    print('pred_H: ' + str(pred_H))
    # inv_pred_H = inv(pred_H)
    # print('inv_pred_H: ' + str(inv_pred_H))

    four_points = np.float32([four_points])
    print('four_points.shape: ' + str(four_points.shape))

    pred_four_points = cv.perspectiveTransform(np.float32(four_points), pred_H)
    # pred_four_points = np.dot(pred_H, np.float32(four_points))
    # print('pred_four_points: ' + str(np.float32(pred_four_points)))

    Ip3 = cv.warpPerspective(Ip1_new, pred_H, (640, 480))

    error = np.subtract(np.array(pred_four_points), np.array(four_points))
    error = np.abs(error).mean()
    print('MACE: ' + str(error))

    cv.imshow('test_image', test_image)
    cv.imshow('warped_image', warped_image)
    cv.imshow('Ip1', Ip1)
    cv.imshow('Ip2', Ip2)
    cv.imshow('Ip3', Ip3)
    cv.waitKey(0)

    H_four_points = np.subtract(np.array(perturbed_four_points), np.array(four_points))
    print(H_four_points)
--------------------------------------------------------------------------------
/test/test_gen_train.py:
--------------------------------------------------------------------------------
import cv2 as cv

from pre_process import get_datum

if __name__ == "__main__":
    fullpath = '000000523955.jpg'
    size = (320, 240)
    rho = 32
    patch_size = 128

    img = cv.imread(fullpath, 0)
    img = cv.resize(img, size)
    test_image = img.copy()

    for top_point in [(0 + 32, 0 + 32), (128 + 32, 0 + 32), (0 + 32, 48 + 32), (128 + 32, 48 + 32),
                      (64 + 32, 24 + 32)]:
        # top_point = (rho, rho)
        datum = get_datum(img, test_image, size, rho, top_point, patch_size)
        img1 = datum[0][:, :, 0]
        img2 = datum[0][:, :, 1]
        print('img1.shape: ' + str(img1.shape))
        print('img2.shape: ' + str(img2.shape))
        cv.imshow('img1', img1)
        cv.imshow('img2', img2)
        cv.waitKey(0)
--------------------------------------------------------------------------------
/test_orb.py:
--------------------------------------------------------------------------------
import argparse
import pickle

import cv2
import numpy as np

from config import print_freq
from utils import AverageMeter

MIN_MATCH_COUNT = 10


def compute_homo(img1, img2, args=None):
    # Defaults to the SURF + RANSAC estimate; pass args with type 'identity'
    # for the identity-homography baseline. Falls back to the identity matrix
    # whenever estimation fails.
    H = np.identity(3)
    if args is None or args.type == 'surf':
        try:
            # Initiate SURF detector
            # surf = cv2.xfeatures2d.SIFT_create()
            surf = cv2.xfeatures2d.SURF_create()

            # find the keypoints and descriptors with SURF
            kp1, des1 = surf.detectAndCompute(img1, None)
            kp2, des2 = surf.detectAndCompute(img2, None)

            FLANN_INDEX_KDTREE = 0
            index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
            search_params = dict(checks=50)

            flann = cv2.FlannBasedMatcher(index_params, search_params)

            matches = flann.knnMatch(des1, des2, k=2)

            # store all the good matches as per Lowe's ratio test.
            good = []
            for m, n in matches:
                if m.distance < 0.7 * n.distance:
                    good.append(m)

            if len(good) > MIN_MATCH_COUNT:
                src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
                dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)

                H, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)

        except cv2.error as err:
            print(err)

    elif args.type == 'identity':
        pass

    return H


def compute_mace(H, perturbed_four_points):
    # Patch corners in the test set's order: top-left, bottom-left,
    # bottom-right, top-right (patch at (64, 64), size 256).
    four_points = np.float32([[64, 64], [64, 320], [320, 320], [320, 64]])
    # print('four_points: ' + str(four_points))
    # print('perturbed_four_points: ' + str(perturbed_four_points))
    pred_four_points = cv2.perspectiveTransform(np.array([four_points]), H)
    # print('pred_four_points: ' + str(pred_four_points))
    error = np.subtract(pred_four_points, perturbed_four_points)
    # print('error: ' + str(error))
    mace = (np.abs(error)).mean()
    return mace
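

# Quick worked example (illustrative): with the identity homography, the MACE
# is just the mean absolute corner perturbation, e.g.
#   compute_mace(np.identity(3), np.float32([[70, 60], [60, 325], [326, 315], [315, 70]]))
# returns (6 + 4 + 4 + 5 + 6 + 5 + 5 + 6) / 8 = 5.125.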


def test(args):
    filename = 'data/test.pkl'
    with open(filename, 'rb') as file:
        samples = pickle.load(file)

    mace_list = []
    maces = AverageMeter()
    for i, sample in enumerate(samples):
        image, four_points, perturbed_four_points = sample
        # Paste each 256x256 patch onto an empty (rows, cols) canvas at its
        # original location, so keypoint coordinates stay in the image frame.
        img1 = np.zeros((480, 640), np.uint8)
        img1[64:320, 64:320] = image[:, :, 0]
        img2 = np.zeros((480, 640), np.uint8)
        img2[64:320, 64:320] = image[:, :, 1]

        H = compute_homo(img2, img1, args)
        try:
            mace = compute_mace(H, perturbed_four_points)
            mace_list.append(mace)
            maces.update(mace)
        except cv2.error as err:
            print(err)
        if i % print_freq == 0:
            print('[{0}/{1}]\tMean Average Corner Error {mace.val:.5f} ({mace.avg:.5f})'.format(i, len(samples),
                                                                                                mace=maces))

    print('MACE: {:5f}'.format(np.mean(mace_list)))
    print('len(mace_list): ' + str(len(mace_list)))


def parse_args():
    parser = argparse.ArgumentParser(description='Test with SURF+RANSAC')
    # general
    parser.add_argument('--type', type=str, default='surf', help='surf or identity')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    test(args)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from config import device, grad_clip, print_freq, num_workers
from data_gen import DeepHNDataset
from mobilenet_v2 import MobileNetV2
from optimizer import HNetOptimizer
from utils import parse_args, save_checkpoint, AverageMeter, clip_gradient, get_logger


def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = MobileNetV2()
        model = nn.DataParallel(model)

        optimizer = HNetOptimizer(torch.optim.Adam(model.parameters(), lr=args.lr))

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Loss function
    criterion = nn.MSELoss().to(device)

    # Custom dataloaders
    train_dataset = DeepHNDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=num_workers)
    valid_dataset = DeepHNDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
                                               num_workers=num_workers)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        model.zero_grad()
        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)

        writer.add_scalar('model/train_loss', train_loss, epoch)
        writer.add_scalar('model/learning_rate', optimizer.lr, epoch)
        print('\nCurrent effective learning rate: {}\n'.format(optimizer.lr))

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           criterion=criterion,
                           logger=logger)

        writer.add_scalar('model/valid_loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)


def train(train_loader, model, criterion, optimizer, epoch, logger):
    model.train()  # train mode (dropout and batchnorm is used)

    losses = AverageMeter()

    # Batches
    for i, (img, target) in enumerate(train_loader):
        # Move to GPU, if available
        img = img.to(device)
        target = target.float().to(device)  # [N, 8]

        # Forward prop.
        out = model(img)  # [N, 8]
        out = out.squeeze(dim=1)

        # Calculate loss
        loss = criterion(out, target)

        # Back prop.
        optimizer.zero_grad()
        loss.backward()

        # Clip gradients
        # clip_gradient(optimizer, grad_clip)

        # Update weights
        optimizer.step()

        # Keep track of metrics
        losses.update(loss.item())

        # Print status
        if i % print_freq == 0:
            status = 'Epoch: [{0}][{1}/{2}]\t' \
                     'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), loss=losses)
            logger.info(status)

    return losses.avg


def valid(valid_loader, model, criterion, logger):
    model.eval()  # eval mode (dropout and batchnorm is NOT used)

    losses = AverageMeter()

    # Batches
    for i, (img, target) in enumerate(valid_loader):
        # Move to GPU, if available
        img = img.to(device)
        target = target.float().to(device)

        # Forward prop. (no gradients needed for validation)
        with torch.no_grad():
            out = model(img)
            out = out.squeeze(dim=1)

            # Calculate loss
            loss = criterion(out, target)

        # Keep track of metrics
        losses.update(loss.item())

    # Print status
    status = 'Validation\t Loss {loss.avg:.5f}\n'.format(loss=losses)
    logger.info(status)

    return losses.avg


def main():
    args = parse_args()
    train_net(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import argparse
import logging
import math
import os

import torch


def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)


def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, loss, is_best):
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'loss': loss,
             'model': model,
             'optimizer': optimizer}
    # filename = 'checkpoint_' + str(epoch) + '_' + str(loss) + '.tar'
    filename = 'checkpoint.tar'
    torch.save(state, filename)
    # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint
    if is_best:
        torch.save(state, 'BEST_checkpoint.tar')


class AverageMeter(object):
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class LossMeterBag(object):

    def __init__(self, name_list):
        self.meter_dict = dict()
        self.name_list = name_list
        for name in self.name_list:
            self.meter_dict[name] = AverageMeter()

    def update(self, val_list):
        for i, name in enumerate(self.name_list):
            val = val_list[i]
            self.meter_dict[name].update(val)

    def __str__(self):
        ret = ''
        for name in self.name_list:
            ret += '{0}:\t {1:.4f}({2:.4f})\t'.format(name, self.meter_dict[name].val, self.meter_dict[name].avg)

        return ret


def adjust_learning_rate(optimizer, shrink_factor):
    """
    Shrinks learning rate by a specified factor.
    :param optimizer: optimizer whose learning rate must be shrunk.
    :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
81 | """ 82 | 83 | print("\nDECAYING learning rate.") 84 | for param_group in optimizer.param_groups: 85 | param_group['lr'] = param_group['lr'] * shrink_factor 86 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 87 | 88 | 89 | def get_learning_rate(optimizer): 90 | return optimizer.param_groups[0]['lr'] 91 | 92 | 93 | def accuracy(pred, target): 94 | batch_size = pred.size(0) 95 | correct = [] 96 | for i in range(batch_size): 97 | if math.fabs(pred[i].item() - target[i].item()) < 0.5: 98 | correct += [1.0] 99 | # correct = torch.abs(pred - target).lt(0.5) 100 | # correct_total = correct.view(-1).float().sum() # 0D tensor 101 | correct_total = sum(correct) 102 | # return correct_total.item() * (100.0 / batch_size) 103 | return correct_total * (100.0 / batch_size) 104 | 105 | 106 | def parse_args(): 107 | parser = argparse.ArgumentParser(description='Train face network') 108 | # general 109 | parser.add_argument('--end-epoch', type=int, default=1000, help='training epoch size.') 110 | parser.add_argument('--lr', type=float, default=0.005, help='start learning rate') 111 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 112 | parser.add_argument('--batch-size', type=int, default=64, help='batch size in each context') 113 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 114 | args = parser.parse_args() 115 | return args 116 | 117 | 118 | def get_logger(): 119 | logger = logging.getLogger() 120 | handler = logging.StreamHandler() 121 | formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s") 122 | handler.setFormatter(formatter) 123 | logger.addHandler(handler) 124 | logger.setLevel(logging.INFO) 125 | return logger 126 | 127 | 128 | def ensure_folder(folder): 129 | import os 130 | if not os.path.isdir(folder): 131 | os.mkdir(folder) 132 | --------------------------------------------------------------------------------