├── README.md
├── __pycache__
│   └── matching.cpython-36.pyc
├── data
│   └── synth_dataset.py
├── datasets
│   ├── src145.jpg
│   ├── src15.jpg
│   ├── src221.jpg
│   ├── src339.jpg
│   ├── src440.jpg
│   ├── src45.jpg
│   ├── src61.jpg
│   ├── src76.jpg
│   ├── src80.jpg
│   ├── tgt145.jpg
│   ├── tgt15.jpg
│   ├── tgt221.jpg
│   ├── tgt339.jpg
│   ├── tgt440.jpg
│   ├── tgt45.jpg
│   ├── tgt61.jpg
│   ├── tgt76.jpg
│   └── tgt80.jpg
├── demo.py
├── geotnf
│   ├── __pycache__
│   │   └── transformation.cpython-36.pyc
│   └── transformation.py
├── gui.py
├── image
│   ├── __pycache__
│   │   └── normalization.cpython-36.pyc
│   └── normalization.py
├── matching.py
├── model
│   ├── __pycache__
│   │   └── cnn_geometric_model.cpython-36.pyc
│   └── cnn_geometric_model.py
└── util
    ├── __pycache__
    │   └── torch_util.cpython-36.pyc
    └── torch_util.py

/README.md:
--------------------------------------------------------------------------------
1 | ## A Robust Matching Network for Gradually Estimating Geometric Transformation on Remote Sensing Imagery
2 | 
3 | We propose a matching network for gradually estimating the geometric transformation parameters between two aerial images taken over the same area but in different environments. To match two aerial images precisely, several factors must be considered, such as differences in acquisition time, viewpoint, scale, and rotation.
4 | 
5 | This paper was accepted at SMC 2019. [[Paper](https://github.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/files/7631074/A.Robust.Matching.Network.for.Gradually.Estimating.Geometric.Transformation.on.Remote.Sensing.Imagery.pdf)]
6 | [[Online info](https://ieeexplore.ieee.org/document/8913881)]
7 | 
8 | 
9 | ## Requirements
10 | ```
11 | python==3.6.8
12 | torch==1.0.1
13 | torchvision==0.2.2
14 | PyQt5==5.14 (for gui)
15 | opencv==3.4.1 (for gui)
16 | ```
17 | ## Run
18 | ```
19 | python demo.py
20 | python gui.py (recommended)
21 | ```
22 | ## Trained models
23 | 
24 | Save the files in the ./trained_models folder.
25 | 
26 | Download link: [ResNet models](https://drive.google.com/file/d/1au049oWWxio9Pgowo4Rias9knL_yiNth/view?usp=sharing "trained models link")
27 | 
28 | ## Citation
29 | ```
30 | @inproceedings{kim2019robust,
31 |   title={A Robust Matching Network for Gradually Estimating Geometric Transformation on Remote Sensing Imagery},
32 |   author={Kim, Dong-Geon and Nam, Woo-Jeoung and Lee, Seong-Whan},
33 |   booktitle={2019 IEEE International Conference on Systems, Man and Cybernetics (SMC)},
34 |   pages={3889--3894},
35 |   year={2019},
36 |   organization={IEEE}
37 | }
38 | ```
39 | 
40 | ## Contact
41 | 
42 | If you need better performance, are interested in further research, or run into issues, please feel free to contact me anytime.
43 | 
44 | I am currently researching NAS (Neural Architecture Search) to create networks suited to specific image domains.
45 | 
46 | E-mail :
47 | 
48 | 
49 | ## Screenshot for running gui.py
50 | 
51 | Supports real-time matching and overlay.
52 | 
53 | 

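## Minimal usage sketch

A minimal sketch (an editorial illustration, not part of the original release) of how an estimated 2x3 affine `theta` is applied with `geotnf.transformation.GeometricTnf`, following the flow of `demo.py`. The identity `theta` below is a stand-in for the parameters predicted by `CNNGeometricPearson`:

```
import torch
from skimage import io
from geotnf.transformation import GeometricTnf

# Stand-in for the theta predicted by the trained model in demo.py
theta = torch.Tensor([[1, 0, 0], [0, 1, 0]]).unsqueeze(0)  # identity [A|t]

affTnf = GeometricTnf(geometric_model='affine', out_h=240, out_w=240, use_cuda=False)
image = torch.Tensor(io.imread('datasets/src15.jpg').astype('float32') / 255.0)
image = image.permute(2, 0, 1).unsqueeze(0)   # HWC -> 1xCxHxW
warped = affTnf(image, theta.view(-1, 2, 3))  # bilinear warp to 1x3x240x240
```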
54 | 55 | 56 | -------------------------------------------------------------------------------- /__pycache__/matching.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/__pycache__/matching.cpython-36.pyc -------------------------------------------------------------------------------- /data/synth_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import os 4 | from os.path import exists, join, basename 5 | from skimage import io 6 | import pandas as pd 7 | import numpy as np 8 | from numpy.linalg import inv 9 | from torch.utils.data import Dataset 10 | from geotnf.transformation import GeometricTnf 11 | from torch.autograd import Variable 12 | 13 | class SynthDataset(Dataset): 14 | """ 15 | 16 | Synthetically transformed pairs dataset for training with strong supervision 17 | 18 | Args: 19 | csv_file (string): Path to the csv file with image names and transformations. 20 | training_image_path (string): Directory with all the images. 21 | transform (callable): Transformation for post-processing the training pair (eg. image normalization) 22 | 23 | Returns: 24 | Dict: {'image': full dataset image, 'theta': desired transformation} 25 | 26 | """ 27 | # Test : 240 28 | # Training : 1080 29 | # def __init__(self, csv_file, training_image_path, output_size=(240,240), geometric_model='affine', transform=None, 30 | # def __init__(self, csv_file, training_image_path, output_size=(1080,1080), geometric_model='affine', transform=None, 31 | def __init__(self, csv_file, training_image_path, output_size=(540,540), geometric_model='affine', transform=None, 32 | # def __init__(self, csv_file, training_image_path, output_size=(2160,2160), geometric_model='affine', transform=None, 33 | random_sample=False, random_t=0.5, random_s=0.5, random_alpha=1/6, random_t_tps=0.4): # Original random_s=0.5, random_alpha = 1/6 34 | self.random_sample = random_sample 35 | self.random_t = random_t 36 | self.random_t_tps = random_t_tps 37 | self.random_alpha = random_alpha 38 | self.random_s = random_s 39 | self.out_h, self.out_w = output_size 40 | # read csv file 41 | self.train_data = pd.read_csv(csv_file) 42 | self.img_names = self.train_data.iloc[:,0] 43 | self.img_names2 = self.train_data.iloc[:,1] 44 | self.theta_array = self.train_data.iloc[:, 2:].as_matrix().astype('float') 45 | # copy arguments 46 | self.training_image_path = training_image_path 47 | self.transform = transform 48 | self.geometric_model = geometric_model 49 | self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False) 50 | # Ready for distance 51 | grid_size = 20 52 | axis_coords = np.linspace(-1, 1, grid_size) 53 | self.N = grid_size * grid_size 54 | X, Y = np.meshgrid(axis_coords, axis_coords) 55 | X = np.reshape(X, (1, self.N)) 56 | Y = np.reshape(Y, (1, self.N)) 57 | self.P = np.concatenate((X, Y)) 58 | # Check tilt 59 | self.cosVal = [1, 0, -1, 0] 60 | self.sinVal = [0, 1, 0, -1] 61 | self.tilt = [1, 1 / np.cos(7/18 * np.pi)] 62 | 63 | def __len__(self): 64 | return len(self.train_data) 65 | 66 | def __getitem__(self, idx): 67 | # read image 68 | img_name = os.path.join(self.training_image_path, self.img_names[idx]) 69 | image = io.imread(img_name) 70 | 71 | img_name2 = os.path.join(self.training_image_path, 
self.img_names2[idx]) 72 | image2 = io.imread(img_name2) 73 | 74 | # read theta 75 | if self.random_sample==False: 76 | theta = self.theta_array[idx, :] 77 | 78 | if self.geometric_model=='affine': 79 | # reshape theta to 2x3 matrix [A|t] where 80 | # first row corresponds to X and second to Y 81 | theta = theta[[3,2,5,1,0,4]].reshape(2,3) 82 | elif self.geometric_model=='tps': 83 | theta = np.expand_dims(np.expand_dims(theta,1),2) 84 | else: 85 | if self.geometric_model=='affine': 86 | alpha = (np.random.rand(1)-0.5)*2*np.pi*self.random_alpha 87 | theta = np.random.rand(6) 88 | theta[[2,5]]=(theta[[2,5]]-0.5)*2*self.random_t 89 | theta[0]=(1+(theta[0]-0.5)*2*self.random_s)*np.cos(alpha) 90 | theta[1]=(1+(theta[1]-0.5)*2*self.random_s)*(-np.sin(alpha)) 91 | theta[3]=(1+(theta[3]-0.5)*2*self.random_s)*np.sin(alpha) 92 | theta[4]=(1+(theta[4]-0.5)*2*self.random_s)*np.cos(alpha) 93 | theta = theta.reshape(2,3) 94 | if self.geometric_model=='tps': 95 | theta = np.array([-1 , -1 , -1 , 0 , 0 , 0 , 1 , 1 , 1 , -1 , 0 , 1 , -1 , 0 , 1 , -1 , 0 , 1]) 96 | theta = theta+(np.random.rand(18)-0.5)*2*self.random_t_tps 97 | 98 | theta_Discrete = {} 99 | cnt = 1 100 | 101 | for i in range(len(self.tilt)): 102 | for k in range(len(self.cosVal)): 103 | value = np.dot(np.array([[self.tilt[i], 0], [0, 1]]), 104 | np.array([[self.cosVal[k], -self.sinVal[k]], [self.sinVal[k], self.cosVal[k]]])) 105 | theta_Discrete['case' + str(cnt)] = value 106 | cnt += 1 107 | 108 | min_error = 0 109 | flag = 0 110 | case_cnt = 1 111 | for k in theta_Discrete.keys(): 112 | warped_points = np.dot(theta[:, :2], self.P) 113 | warped_points2 = np.dot(theta_Discrete[k][:, :2], self.P) 114 | 115 | error = np.sum((warped_points[0, :] - warped_points2[0, :]) ** 2 + ( 116 | warped_points[1, :] - warped_points2[1, :]) ** 2) / len(self.P[1]) 117 | if case_cnt == 1: 118 | min_error = error 119 | flag = case_cnt 120 | else: 121 | if min_error > error: 122 | min_error = error 123 | flag = case_cnt 124 | case_cnt += 1 125 | 126 | # make arrays float tensor for subsequent processing 127 | image = torch.Tensor(image.astype(np.float32)) 128 | image2 = torch.Tensor(image2.astype(np.float32)) 129 | theta = torch.Tensor(theta.astype(np.float32)) 130 | 131 | # permute order of image to CHW 132 | image = image.transpose(1,2).transpose(0,1) 133 | image2 = image2.transpose(1,2).transpose(0,1) 134 | 135 | # Resize image using bilinear sampling with identity affine tnf 136 | if image.size()[0]!=self.out_h or image.size()[1]!=self.out_w: 137 | image = self.affineTnf(Variable(image.unsqueeze(0),requires_grad=False),None).data.squeeze(0) 138 | image2 = self.affineTnf(Variable(image2.unsqueeze(0),requires_grad=False),None).data.squeeze(0) 139 | 140 | sample = {'image': image, 'image2': image2, 'theta': theta} 141 | if self.transform: 142 | sample = self.transform(sample) 143 | 144 | return sample 145 | 146 | 147 | class SynthDataset2(Dataset): 148 | """ 149 | 150 | Synthetically transformed pairs dataset for training with strong supervision 151 | 152 | Args: 153 | csv_file (string): Path to the csv file with image names and transformations. 154 | training_image_path (string): Directory with all the images. 155 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 156 | 157 | Returns: 158 | Dict: {'image': full dataset image, 'theta': desired transformation} 159 | 160 | """ 161 | 162 | # Test : 240 163 | # Training : 1080 164 | def __init__(self, csv_file, training_image_path, output_size=(240, 240), geometric_model='affine', transform=None, # 1080, 1080 165 | random_sample=False, random_t=0.5, random_s=0.5, random_alpha=1 / 6, 166 | random_t_tps=0.4): 167 | self.random_sample = random_sample 168 | self.random_t = random_t 169 | self.random_t_tps = random_t_tps 170 | self.random_alpha = random_alpha 171 | self.random_s = random_s 172 | self.out_h, self.out_w = output_size 173 | # read csv file 174 | self.train_data = pd.read_csv(csv_file) 175 | self.img_names = self.train_data.iloc[:, 0] 176 | self.img_names2 = self.train_data.iloc[:, 1] 177 | self.theta_array = self.train_data.iloc[:, 2:].as_matrix().astype('float') 178 | # copy arguments 179 | self.training_image_path = training_image_path 180 | self.transform = transform 181 | self.geometric_model = geometric_model 182 | self.affineTnf = GeometricTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False) 183 | # Ready for distance 184 | grid_size = 20 185 | axis_coords = np.linspace(-1, 1, grid_size) 186 | self.N = grid_size * grid_size 187 | X, Y = np.meshgrid(axis_coords, axis_coords) 188 | X = np.reshape(X, (1, self.N)) 189 | Y = np.reshape(Y, (1, self.N)) 190 | self.P = np.concatenate((X, Y)) 191 | # Check tilt 192 | self.cosVal = [1, 0, -1, 0] 193 | self.sinVal = [0, 1, 0, -1] 194 | self.tilt = [1, 1 / np.cos(7 / 18 * np.pi)] 195 | 196 | def __len__(self): 197 | return len(self.train_data) 198 | 199 | def __getitem__(self, idx): 200 | # read image 201 | # img_name = os.path.join() 202 | image = io.imread(self.training_image_path) 203 | # img_name2 = os.path.join() 204 | image2 = io.imread(self.training_image_path) 205 | 206 | # read theta 207 | if self.random_sample == False: 208 | theta = self.theta_array[idx, :] 209 | 210 | if self.geometric_model == 'affine': 211 | # reshape theta to 2x3 matrix [A|t] where 212 | # first row corresponds to X and second to Y 213 | theta = theta[[3, 2, 5, 1, 0, 4]].reshape(2, 3) 214 | elif self.geometric_model == 'tps': 215 | theta = np.expand_dims(np.expand_dims(theta, 1), 2) 216 | else: 217 | if self.geometric_model == 'affine': 218 | alpha = (np.random.rand(1) - 0.5) * 2 * np.pi * self.random_alpha 219 | theta = np.random.rand(6) 220 | theta[[2, 5]] = (theta[[2, 5]] - 0.5) * 2 * self.random_t 221 | theta[0] = (1 + (theta[0] - 0.5) * 2 * self.random_s) * np.cos(alpha) 222 | theta[1] = (1 + (theta[1] - 0.5) * 2 * self.random_s) * (-np.sin(alpha)) 223 | theta[3] = (1 + (theta[3] - 0.5) * 2 * self.random_s) * np.sin(alpha) 224 | theta[4] = (1 + (theta[4] - 0.5) * 2 * self.random_s) * np.cos(alpha) 225 | theta = theta.reshape(2, 3) 226 | if self.geometric_model == 'tps': 227 | theta = np.array([-1, -1, -1, 0, 0, 0, 1, 1, 1, -1, 0, 1, -1, 0, 1, -1, 0, 1]) 228 | theta = theta + (np.random.rand(18) - 0.5) * 2 * self.random_t_tps 229 | 230 | theta_Discrete = {} 231 | cnt = 1 232 | 233 | for i in range(len(self.tilt)): 234 | for k in range(len(self.cosVal)): 235 | value = np.dot(np.array([[self.tilt[i], 0], [0, 1]]), 236 | np.array([[self.cosVal[k], -self.sinVal[k]], [self.sinVal[k], self.cosVal[k]]])) 237 | theta_Discrete['case' + str(cnt)] = value 238 | cnt += 1 239 | 240 | min_error = 0 241 | flag = 0 242 | case_cnt = 1 243 | for k in theta_Discrete.keys(): 244 | warped_points = np.dot(theta[:, :2], self.P) 245 | warped_points2 = 
np.dot(theta_Discrete[k][:, :2], self.P) 246 | 247 | error = np.sum((warped_points[0, :] - warped_points2[0, :]) ** 2 + ( 248 | warped_points[1, :] - warped_points2[1, :]) ** 2) / len(self.P[1]) 249 | if case_cnt == 1: 250 | min_error = error 251 | flag = case_cnt 252 | else: 253 | if min_error > error: 254 | min_error = error 255 | flag = case_cnt 256 | case_cnt += 1 257 | 258 | # make arrays float tensor for subsequent processing 259 | image = torch.Tensor(image.astype(np.float32)) 260 | image2 = torch.Tensor(image2.astype(np.float32)) 261 | theta = torch.Tensor(theta.astype(np.float32)) 262 | 263 | # permute order of image to CHW 264 | image = image.transpose(1, 2).transpose(0, 1) 265 | image2 = image2.transpose(1, 2).transpose(0, 1) 266 | 267 | # Resize image using bilinear sampling with identity affine tnf 268 | if image.size()[0] != self.out_h or image.size()[1] != self.out_w: 269 | image = self.affineTnf(Variable(image.unsqueeze(0), requires_grad=False), None).data.squeeze(0) 270 | image2 = self.affineTnf(Variable(image2.unsqueeze(0), requires_grad=False), None).data.squeeze(0) 271 | 272 | sample = {'image': image, 'image2': image2, 'theta': theta} 273 | 274 | if self.transform: 275 | sample = self.transform(sample) 276 | 277 | return sample -------------------------------------------------------------------------------- /datasets/src145.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src145.jpg -------------------------------------------------------------------------------- /datasets/src15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src15.jpg -------------------------------------------------------------------------------- /datasets/src221.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src221.jpg -------------------------------------------------------------------------------- /datasets/src339.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src339.jpg -------------------------------------------------------------------------------- /datasets/src440.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src440.jpg -------------------------------------------------------------------------------- /datasets/src45.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src45.jpg -------------------------------------------------------------------------------- /datasets/src61.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src61.jpg -------------------------------------------------------------------------------- /datasets/src76.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src76.jpg -------------------------------------------------------------------------------- /datasets/src80.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/src80.jpg -------------------------------------------------------------------------------- /datasets/tgt145.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt145.jpg -------------------------------------------------------------------------------- /datasets/tgt15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt15.jpg -------------------------------------------------------------------------------- /datasets/tgt221.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt221.jpg -------------------------------------------------------------------------------- /datasets/tgt339.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt339.jpg -------------------------------------------------------------------------------- /datasets/tgt440.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt440.jpg -------------------------------------------------------------------------------- /datasets/tgt45.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt45.jpg -------------------------------------------------------------------------------- /datasets/tgt61.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt61.jpg -------------------------------------------------------------------------------- /datasets/tgt76.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt76.jpg
--------------------------------------------------------------------------------
/datasets/tgt80.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/datasets/tgt80.jpg
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | 
3 | import torch
4 | from torch.autograd import Variable
5 | from torchvision.transforms import Normalize
6 | 
7 | from model.cnn_geometric_model import CNNGeometricPearson
8 | from image.normalization import NormalizeImageDict, normalize_image
9 | from util.torch_util import BatchTensorToVars, str_to_bool
10 | # from util.checkboard import createCheckBoard
11 | from geotnf.transformation import GeometricTnf
12 | # from geotnf.point_tnf import *
13 | import matplotlib.pyplot as plt
14 | from skimage import io
15 | import cv2
16 | import numpy as np
17 | import warnings
18 | from collections import OrderedDict
19 | 
20 | import pickle
21 | from functools import partial
22 | 
23 | import time
24 | start_time = time.time()
25 | 
26 | warnings.filterwarnings('ignore')
27 | 
28 | # torch.cuda.set_device(1)
29 | 
30 | ### Parameter
31 | feature_extraction_cnn = 'resnet101'
32 | 
33 | if feature_extraction_cnn=='vgg':
34 |     model_homo_path = ''
35 | elif feature_extraction_cnn=='resnet101':
36 |     model_aff_path = 'trained_models/resnet36_myproc_1_new_cor_fefr_4p5.pth.tar'
37 |     model_aff_path2 = 'trained_models/resnet101_epo81_lr4p4_rm11.pth.tar'
38 | 
39 | target_image_path='datasets/tgt15.jpg'
40 | source_image_path='datasets/src15.jpg'
41 | 
42 | ### Load models
43 | use_cuda = torch.cuda.is_available()
44 | do_aff = not model_aff_path2 == ''
45 | 
46 | # Create model
47 | print('Creating CNN model...')
48 | if do_aff:
49 |     model_aff = CNNGeometricPearson(use_cuda=use_cuda, geometric_model='affine', feature_extraction_cnn=feature_extraction_cnn)
50 | 
51 | pickle.load = partial(pickle.load, encoding="latin1")
52 | pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
53 | 
54 | # Load trained weights
55 | print('Loading trained model weights...')
56 | if do_aff:
57 |     checkpoint = torch.load(model_aff_path, map_location=lambda storage, loc: storage)
58 |     checkpoint2 = torch.load(model_aff_path2, map_location=lambda storage, loc: storage)
59 |     model_dict = model_aff.FeatureExtraction.state_dict()
60 |     for name, param in model_dict.items():
61 |         model_dict[name].copy_(checkpoint['state_dict'][
62 |             'FeatureExtraction.' + name])
63 |     model_dict = model_aff.FeatureClassification.state_dict()
64 |     for name, param in model_dict.items():
65 |         model_dict[name].copy_(checkpoint['state_dict'][
66 |             'FeatureClassification.' + name])
67 |     model_dict = model_aff.FeatureExtraction2.state_dict()
68 |     for name, param in model_dict.items():
69 |         model_dict[name].copy_(checkpoint2['state_dict'][
70 |             'FeatureExtraction.' + name])
71 |     model_dict = model_aff.FeatureRegression.state_dict()
72 |     for name, param in model_dict.items():
73 |         model_dict[name].copy_(checkpoint2['state_dict'][
74 |             'FeatureRegression.'
+ name]) 75 | ### Create image transformers 76 | affTnf = GeometricTnf(geometric_model='affine', use_cuda=use_cuda) 77 | 78 | ### Load and preprocess images 79 | resizeCNN = GeometricTnf(out_h=240, out_w=240, use_cuda=False) 80 | affTnf_origin = GeometricTnf(out_h=1080, out_w=1080, use_cuda=False) 81 | affTnf_Demo = GeometricTnf(out_h=540, out_w=540, use_cuda=False) 82 | normalizeTnf = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 83 | 84 | def Im2Tensor(image): 85 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 86 | image = torch.Tensor(image.astype(np.float32) / 255.0) 87 | image_var = Variable(image, requires_grad=False) 88 | 89 | if use_cuda: 90 | image_var = image_var.cuda() 91 | return image_var 92 | 93 | def preprocess_image(image): 94 | # convert to torch Variable 95 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 96 | image = torch.Tensor(image.astype(np.float32) / 255.0) 97 | image_var = Variable(image, requires_grad=False) 98 | 99 | # Resize image using bilinear sampling with identity affine tnf 100 | image_var = resizeCNN(image_var) 101 | 102 | # Normalize image 103 | image_var = normalize_image(image_var) 104 | 105 | return image_var 106 | 107 | def preprocess_image_Demo(image): 108 | # convert to torch Variable 109 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 110 | image = torch.Tensor(image.astype(np.float32) / 255.0) 111 | image_var = Variable(image, requires_grad=False) 112 | 113 | # Resize image using bilinear sampling with identity affine tnf 114 | image_var = affTnf_Demo(image_var) 115 | 116 | # Normalize image 117 | image_var = normalize_image(image_var) 118 | 119 | return image_var 120 | 121 | def preprocess_image_Origin(image): 122 | # convert to torch Variable 123 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 124 | image = torch.Tensor(image.astype(np.float32) / 255.0) 125 | image_var = Variable(image, requires_grad=False) 126 | 127 | # Resize image using bilinear sampling with identity affine tnf 128 | image_var = affTnf_origin(image_var) 129 | 130 | # Normalize image 131 | image_var = normalize_image(image_var) 132 | 133 | return image_var 134 | 135 | source_image = io.imread(source_image_path) 136 | target_image = io.imread(target_image_path) 137 | 138 | source_image_var = preprocess_image(source_image) 139 | source_image_var_orgin = preprocess_image_Origin(source_image) 140 | source_image_var_demo = preprocess_image_Demo(source_image) 141 | target_image_var = preprocess_image(target_image) 142 | target_image = np.float32(target_image/255.) 
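# NOTE: the three preprocess helpers above (preprocess_image, preprocess_image_Demo,
# preprocess_image_Origin) differ only in output resolution -- 240x240 for the CNN
# input (resizeCNN), 540x540 for the demo overlay (affTnf_Demo), and 1080x1080 for a
# near-original copy (affTnf_origin) -- and all apply the same ImageNet mean/std
# normalization via normalize_image.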
143 | 
144 | if use_cuda:
145 |     source_image_var = source_image_var.cuda()
146 |     source_image_var_demo = source_image_var_demo.cuda()
147 |     source_image_var_orgin = source_image_var_orgin.cuda()
148 |     target_image_var = target_image_var.cuda()
149 | 
150 | batch = {'source_image': source_image_var, 'target_image':target_image_var, 'source_image_demo':source_image_var_demo, 'origin_image':source_image_var_orgin}
151 | 
152 | resizeTgt = GeometricTnf(out_h=target_image.shape[0], out_w=target_image.shape[1], use_cuda = use_cuda)
153 | resizeTgt_demo = GeometricTnf(out_h=540, out_w=540, use_cuda = use_cuda)
154 | 
155 | ### Evaluate model
156 | if do_aff:
157 |     model_aff.eval()
158 | 
159 | # Evaluate models
160 | if do_aff:
161 |     theta_aff = model_aff(batch)
162 |     warped_image_aff = affTnf(batch['source_image'], theta_aff.view(-1, 2, 3))
163 | 
164 | ### Process result
165 | if do_aff:
166 |     result_aff = affTnf(Im2Tensor(source_image), theta_aff.view(-1,2,3))
167 |     warped_image_aff_np = resizeTgt(result_aff).squeeze(0).transpose(0,1).transpose(1,2).cpu().detach().numpy()
168 |     # io.imsave('results/aff.jpg', warped_image_aff_np)
169 |     result_aff_demo = affTnf_Demo(Im2Tensor(source_image), theta_aff.view(-1,2,3))
170 |     warped_image_aff_np_demo = resizeTgt_demo(result_aff_demo).squeeze(0).transpose(0,1).transpose(1,2).cpu().detach().numpy()
171 |     io.imsave('aff_demo.jpg', warped_image_aff_np_demo)
172 | 
173 | print()
174 | print("# ====================================== #")
175 | print("# #")
176 | print("# - %.4s seconds - #" %(time.time() - start_time))
177 | print("# ====================================== #")
178 | 
179 | # Create checkboard (disabled: util/checkboard.py is not included in this repository, so its import is commented out above and calling createCheckBoard would raise a NameError)
180 | # if do_aff:
181 | #     aff_checkboard = createCheckBoard(warped_image_aff_np, target_image)
182 | #     io.imsave('aff_checkboard.jpg', aff_checkboard)
183 | 
184 | 
185 | N_subplots = 3
186 | fig, axs = plt.subplots(1, N_subplots)
187 | axs[0].imshow(source_image)
188 | axs[0].set_title('src')
189 | axs[1].imshow(target_image)
190 | axs[1].set_title('tgt')
191 | axs[2].imshow(warped_image_aff_np)
192 | axs[2].set_title('aff')
193 | for i in range(N_subplots):
194 |     axs[i].axis('off')
195 | plt.show()
196 | 
--------------------------------------------------------------------------------
/geotnf/__pycache__/transformation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/geotnf/__pycache__/transformation.cpython-36.pyc
--------------------------------------------------------------------------------
/geotnf/transformation.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | from skimage import io
5 | import pandas as pd
6 | import numpy as np
7 | import torch
8 | from torch.nn.modules.module import Module
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | 
12 | class GeometricTnf(object):
13 |     """
14 | 
15 |     Geometric transformation applied to an image batch (wrapped in a PyTorch Variable)
16 |     ( can be used with no transformation to perform bilinear resizing )
17 | 
18 |     """
19 |     def __init__(self, geometric_model='affine', out_h=240, out_w=240, use_cuda=True):
20 |         self.out_h = out_h
21 |         self.out_w = out_w
22 |         self.use_cuda = use_cuda
23 |         if geometric_model=='affine':
24 |             self.gridGen = AffineGridGen(out_h, out_w)
25 |         elif geometric_model=='tps':
26 | self.gridGen = TpsGridGen(out_h, out_w, use_cuda=use_cuda) 27 | self.theta_identity = torch.Tensor(np.expand_dims(np.array([[1,0,0],[0,1,0]]),0).astype(np.float32)) 28 | if use_cuda: 29 | self.theta_identity = self.theta_identity.cuda() 30 | 31 | def __call__(self, image_batch, theta_batch=None, padding_factor=1.0, crop_factor=1.0): 32 | b, c, h, w = image_batch.size() 33 | if theta_batch is None: 34 | theta_batch = self.theta_identity 35 | theta_batch = theta_batch.expand(b,2,3) 36 | theta_batch = Variable(theta_batch,requires_grad=False) 37 | 38 | sampling_grid = self.gridGen(theta_batch) 39 | 40 | # rescale grid according to crop_factor and padding_factor 41 | sampling_grid.data = sampling_grid.data*padding_factor*crop_factor 42 | # sample transformed image, shape 1, 3, 1080, 1080 43 | warped_image_batch = F.grid_sample(image_batch, sampling_grid) 44 | 45 | return warped_image_batch 46 | 47 | 48 | class SynthPairTnf(object): 49 | """ 50 | 51 | Generate a synthetically warped training pair using an affine transformation. 52 | 53 | """ 54 | def __init__(self, use_cuda=True, geometric_model='affine', crop_factor=9/16, output_size=(240,240), padding_factor = 0.5): 55 | assert isinstance(use_cuda, (bool)) 56 | assert isinstance(crop_factor, (float)) 57 | assert isinstance(output_size, (tuple)) 58 | assert isinstance(padding_factor, (float)) 59 | self.use_cuda=use_cuda 60 | self.crop_factor = crop_factor 61 | self.padding_factor = padding_factor 62 | self.out_h, self.out_w = output_size 63 | self.rescalingTnf = GeometricTnf('affine', self.out_h, self.out_w, 64 | use_cuda = self.use_cuda) 65 | self.geometricTnf = GeometricTnf(geometric_model, self.out_h, self.out_w, 66 | use_cuda = self.use_cuda) 67 | self.tilt = GeometricTnf(geometric_model, 240, 240, 68 | use_cuda = self.use_cuda) 69 | 70 | def __call__(self, batch): 71 | image_batch, image_batch2, theta_batch= batch['image'], batch['image2'], batch['theta'] 72 | 73 | if self.use_cuda: 74 | image_batch = image_batch.cuda() 75 | image_batch2 = image_batch2.cuda() 76 | theta_batch = theta_batch.cuda() 77 | 78 | # convert to variables 79 | image_batch = Variable(image_batch,requires_grad=False) 80 | image_batch2 = Variable(image_batch2,requires_grad=False) 81 | theta_batch = Variable(theta_batch,requires_grad=False) 82 | 83 | # # get cropped image 84 | cropped_image_batch = self.rescalingTnf(image_batch,None,self.padding_factor,self.crop_factor) 85 | # # # get transformed image 86 | warped_image_batch = self.geometricTnf(image_batch2,theta_batch, 87 | self.padding_factor,self.crop_factor) 88 | 89 | # Origin 90 | return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch, 'origin_image': image_batch} 91 | 92 | def symmetricImagePad(self,image_batch, padding_factor): 93 | b, c, h, w = image_batch.size() 94 | pad_h, pad_w = int(h*padding_factor), int(w*padding_factor) 95 | idx_pad_left = torch.LongTensor(range(pad_w-1,-1,-1)) 96 | idx_pad_right = torch.LongTensor(range(w-1,w-pad_w-1,-1)) 97 | idx_pad_top = torch.LongTensor(range(pad_h-1,-1,-1)) 98 | idx_pad_bottom = torch.LongTensor(range(h-1,h-pad_h-1,-1)) 99 | if self.use_cuda: 100 | idx_pad_left = idx_pad_left.cuda() 101 | idx_pad_right = idx_pad_right.cuda() 102 | idx_pad_top = idx_pad_top.cuda() 103 | idx_pad_bottom = idx_pad_bottom.cuda() 104 | image_batch = torch.cat((image_batch.index_select(3,idx_pad_left),image_batch, 105 | image_batch.index_select(3,idx_pad_right)),3) 106 | image_batch = 
torch.cat((image_batch.index_select(2,idx_pad_top),image_batch,
107 |                                  image_batch.index_select(2,idx_pad_bottom)),2)
108 |         return image_batch
109 | 
110 | 
111 | class SynthSingleTnf(object):
112 |     """
113 | 
114 |     Generate a synthetically warped image for the second (Session 2) estimation stage using an affine transformation.
115 | 
116 |     """
117 | 
118 |     def __init__(self, use_cuda=True, geometric_model='affine', crop_factor=9/16, output_size=(240,240), padding_factor=0.5):  # Original
119 |         assert isinstance(use_cuda, (bool))
120 |         assert isinstance(crop_factor, (float))
121 |         assert isinstance(output_size, (tuple))
122 |         assert isinstance(padding_factor, (float))
123 |         self.use_cuda = use_cuda
124 |         self.crop_factor = crop_factor
125 |         self.padding_factor = padding_factor
126 |         self.out_h, self.out_w = output_size
127 | 
128 |         self.geometricTnf = GeometricTnf(geometric_model, self.out_h, self.out_w,  # Original
129 |                                          use_cuda=self.use_cuda)
130 | 
131 |     # Note: implemented as __call__ on a plain object rather than forward on an nn.Module
132 |     def __call__(self, batch, theta1):
133 |         image_batch, theta_batch = batch, theta1
134 | 
135 |         if self.use_cuda:
136 |             theta_batch = theta_batch.cuda()
137 | 
138 |         # generate symmetrically padded image for bigger sampling region
139 |         # convert to variables
140 |         theta_batch = Variable(theta_batch, requires_grad=False)
141 | 
142 |         # For Demo
143 |         warped_image_batch = self.geometricTnf(image_batch,
144 |                                                theta_batch)  # theta is given, so the identity transform is not used # Original
145 | 
146 |         return warped_image_batch  # Original
147 | 
148 |     def symmetricImagePad(self, image_batch, padding_factor):
149 |         b, c, h, w = image_batch.size()
150 |         pad_h, pad_w = int(h * padding_factor), int(w * padding_factor)
151 |         idx_pad_left = torch.LongTensor(range(pad_w - 1, -1, -1))
152 |         idx_pad_right = torch.LongTensor(range(w - 1, w - pad_w - 1, -1))
153 |         idx_pad_top = torch.LongTensor(range(pad_h - 1, -1, -1))
154 |         idx_pad_bottom = torch.LongTensor(range(h - 1, h - pad_h - 1, -1))
155 |         if self.use_cuda:
156 |             idx_pad_left = idx_pad_left.cuda()
157 |             idx_pad_right = idx_pad_right.cuda()
158 |             idx_pad_top = idx_pad_top.cuda()
159 |             idx_pad_bottom = idx_pad_bottom.cuda()
160 |         image_batch = torch.cat((image_batch.index_select(3, idx_pad_left), image_batch,
161 |                                  image_batch.index_select(3, idx_pad_right)), 3)
162 |         image_batch = torch.cat((image_batch.index_select(2, idx_pad_top), image_batch,
163 |                                  image_batch.index_select(2, idx_pad_bottom)), 2)
164 |         return image_batch
165 | 
166 | class AffineGridGen(Module):
167 |     def __init__(self, out_h=240, out_w=240, out_ch = 3):
168 |         super(AffineGridGen, self).__init__()
169 |         self.out_h = out_h
170 |         self.out_w = out_w
171 |         self.out_ch = out_ch
172 | 
173 |     def forward(self, theta):
174 |         theta = theta.contiguous()
175 |         batch_size = theta.size()[0]
176 |         out_size = torch.Size((batch_size,self.out_ch,self.out_h,self.out_w))
177 |         return F.affine_grid(theta, out_size)
178 | 
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | # Form implementation generated from reading ui file 'mine_matching.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.9.2
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | from PyQt5.QtGui import QIcon, QPixmap, QImage 11 | from matching import matching_demo 12 | import os 13 | import cv2 14 | 15 | img_path= '' 16 | img2_path = '' 17 | 18 | class Ui_MainWindow(object): 19 | def pushButtonClicked(self): 20 | self.label_4.setText("Waiting") 21 | self.verticalSlider.setSliderPosition(50) 22 | self.label_5.clear() 23 | self.label_6.clear() 24 | self.label_7.clear() 25 | self.label_8.clear() 26 | fname = QtWidgets.QFileDialog.getOpenFileName() 27 | global img_path 28 | img_path = fname[0] 29 | img = QPixmap(fname[0]) 30 | img = img.scaled(self.label_5.width(),self.label_5.height()) 31 | self.label_5.setPixmap(img) 32 | 33 | def pushButton2Clicked(self): 34 | fname2 = QtWidgets.QFileDialog.getOpenFileName() 35 | global img2_path 36 | img2_path = fname2[0] 37 | img2 = QPixmap(fname2[0]) 38 | img2 = img2.scaled(self.label_6.width(),self.label_6.height()) 39 | self.label_6.setPixmap(img2) 40 | 41 | def pushButton3Clicked(self): 42 | self.label_4.setText("Proceeding...") 43 | img3 = demo(img_path, img2_path) # 1080,1080, 3 44 | self.label_4.setText("Done...") 45 | imge = QPixmap(os.getcwd()+'/result.jpg') 46 | 47 | imge = imge.scaled(self.label_7.width(),self.label_7.height()) 48 | self.label_7.setPixmap(imge) 49 | tar = cv2.imread(img2_path) 50 | result = cv2.imread(os.getcwd() + '/result.jpg') 51 | alpha = 0.5 52 | beta = (1.0 - alpha) 53 | shrink = cv2.resize(tar, None, fx=result.shape[0] / tar.shape[0], fy=result.shape[0] / tar.shape[0], 54 | interpolation=cv2.INTER_AREA) 55 | dst = cv2.addWeighted(shrink, alpha, result, beta, 0.0, None) 56 | cv2.imwrite('overlay.jpg', dst) 57 | imge2 = QPixmap(os.getcwd()+'/overlay.jpg') 58 | imge2 = imge2.scaled(self.label_8.width(),self.label_8.height()) 59 | self.label_8.setPixmap(imge2) 60 | 61 | def detectChange(self): 62 | tar = cv2.imread(img2_path) 63 | result = cv2.imread(os.getcwd() + '/result.jpg') 64 | alpha = 0.01 * self.verticalSlider.value() 65 | beta = (1.0 - alpha) 66 | shrink = cv2.resize(tar, None, fx=result.shape[0] / tar.shape[0], fy=result.shape[0] / tar.shape[0], 67 | interpolation=cv2.INTER_AREA) 68 | dst = cv2.addWeighted(shrink, alpha, result, beta, 0.0, None) 69 | cv2.imwrite('overlay.jpg', dst) 70 | imge2 = QPixmap(os.getcwd()+'/overlay.jpg') 71 | imge2 = imge2.scaled(self.label_8.width(),self.label_8.height()) 72 | self.label_8.setPixmap(imge2) 73 | 74 | def setupUi(self, MainWindow): 75 | MainWindow.setObjectName("MainWindow") 76 | MainWindow.resize(731, 921) 77 | self.centralwidget = QtWidgets.QWidget(MainWindow) 78 | self.centralwidget.setObjectName("centralwidget") 79 | self.pushButton = QtWidgets.QPushButton(self.centralwidget) 80 | self.pushButton.setGeometry(QtCore.QRect(10, 70, 191, 31)) 81 | font = QtGui.QFont() 82 | font.setFamily("Agency FB") 83 | font.setPointSize(12) 84 | font.setBold(True) 85 | font.setWeight(75) 86 | self.pushButton.setFont(font) 87 | self.pushButton.setObjectName("pushButton") 88 | self.pushButton.clicked.connect(self.pushButtonClicked) 89 | self.label = QtWidgets.QLabel(self.centralwidget) 90 | self.label.setGeometry(QtCore.QRect(50, 0, 331, 31)) 91 | font = QtGui.QFont() 92 | font.setFamily("Agency FB") 93 | font.setPointSize(14) 94 | font.setBold(True) 95 | font.setWeight(75) 96 | self.label.setFont(font) 97 | self.label.setObjectName("label") 98 | self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget) 99 | self.pushButton_2.setGeometry(QtCore.QRect(210, 70, 191, 31)) 100 | font = QtGui.QFont() 101 | 
font.setFamily("Agency FB") 102 | font.setPointSize(12) 103 | font.setBold(True) 104 | font.setWeight(75) 105 | self.pushButton_2.setFont(font) 106 | self.pushButton_2.setObjectName("pushButton_2") 107 | self.pushButton_2.clicked.connect(self.pushButton2Clicked) 108 | self.pushButton_3 = QtWidgets.QPushButton(self.centralwidget) 109 | self.pushButton_3.setGeometry(QtCore.QRect(10, 120, 191, 31)) 110 | font = QtGui.QFont() 111 | font.setFamily("Agency FB") 112 | font.setPointSize(12) 113 | font.setBold(True) 114 | font.setWeight(75) 115 | self.pushButton_3.setFont(font) 116 | self.pushButton_3.setObjectName("pushButton_3") 117 | self.pushButton_3.clicked.connect(self.pushButton3Clicked) 118 | self.label_2 = QtWidgets.QLabel(self.centralwidget) 119 | self.label_2.setGeometry(QtCore.QRect(410, 20, 121, 121)) 120 | self.label_2.setText("") 121 | self.label_2.setObjectName("label_2") 122 | addlogo = QPixmap(os.getcwd()+'/ADD_Logo.jpg') 123 | logo1 = addlogo.scaled(self.label_2.width(),self.label_2.height()) 124 | self.label_2.setPixmap(logo1) 125 | self.label_3 = QtWidgets.QLabel(self.centralwidget) 126 | self.label_3.setGeometry(QtCore.QRect(560, 20, 121, 121)) 127 | self.label_3.setText("") 128 | self.label_3.setObjectName("label_3") 129 | kulogo = QPixmap(os.getcwd()+'/KU_Logo.jpg') 130 | logo2 = kulogo.scaled(self.label_3.width(),self.label_3.height()) 131 | self.label_3.setPixmap(logo2) 132 | self.groupBox = QtWidgets.QGroupBox(self.centralwidget) 133 | self.groupBox.setGeometry(QtCore.QRect(210, 110, 191, 41)) 134 | font = QtGui.QFont() 135 | font.setFamily("Agency FB") 136 | font.setPointSize(10) 137 | font.setBold(True) 138 | font.setWeight(75) 139 | self.groupBox.setFont(font) 140 | self.groupBox.setObjectName("groupBox") 141 | self.label_4 = QtWidgets.QLabel(self.groupBox) 142 | self.label_4.setGeometry(QtCore.QRect(10, 10, 171, 31)) 143 | font = QtGui.QFont() 144 | font.setFamily("Agency FB") 145 | font.setPointSize(8) 146 | font.setBold(False) 147 | font.setWeight(50) 148 | self.label_4.setFont(font) 149 | self.label_4.setTextFormat(QtCore.Qt.AutoText) 150 | self.label_4.setAlignment(QtCore.Qt.AlignCenter) 151 | self.label_4.setObjectName("label_4") 152 | self.label_5 = QtWidgets.QLabel(self.centralwidget) 153 | self.label_5.setGeometry(QtCore.QRect(10, 160, 337, 337)) 154 | self.label_5.setFrameShape(QtWidgets.QFrame.NoFrame) 155 | self.label_5.setText("") 156 | self.label_5.setObjectName("label_5") 157 | self.label_6 = QtWidgets.QLabel(self.centralwidget) 158 | self.label_6.setGeometry(QtCore.QRect(360, 160, 337, 337)) 159 | self.label_6.setText("") 160 | self.label_6.setObjectName("label_6") 161 | self.label_7 = QtWidgets.QLabel(self.centralwidget) 162 | self.label_7.setGeometry(QtCore.QRect(10, 520, 337, 337)) 163 | self.label_7.setFrameShape(QtWidgets.QFrame.NoFrame) 164 | self.label_7.setText("") 165 | self.label_7.setObjectName("label_7") 166 | self.label_8 = QtWidgets.QLabel(self.centralwidget) 167 | self.label_8.setGeometry(QtCore.QRect(360, 520, 337, 337)) 168 | self.label_8.setFrameShape(QtWidgets.QFrame.NoFrame) 169 | self.label_8.setText("") 170 | self.label_8.setObjectName("label_8") 171 | self.label_9 = QtWidgets.QLabel(self.centralwidget) 172 | self.label_9.setGeometry(QtCore.QRect(150, 500, 111, 16)) 173 | font = QtGui.QFont() 174 | font.setFamily("Agency FB") 175 | font.setPointSize(10) 176 | font.setBold(True) 177 | font.setWeight(75) 178 | self.label_9.setFont(font) 179 | self.label_9.setObjectName("label_9") 180 | self.label_10 = 
QtWidgets.QLabel(self.centralwidget)
181 |         self.label_10.setGeometry(QtCore.QRect(500, 500, 111, 16))
182 |         font = QtGui.QFont()
183 |         font.setFamily("Agency FB")
184 |         font.setPointSize(10)
185 |         font.setBold(True)
186 |         font.setWeight(75)
187 |         self.label_10.setFont(font)
188 |         self.label_10.setObjectName("label_10")
189 |         self.label_11 = QtWidgets.QLabel(self.centralwidget)
190 |         self.label_11.setGeometry(QtCore.QRect(470, 860, 111, 16))
191 |         font = QtGui.QFont()
192 |         font.setFamily("Agency FB")
193 |         font.setPointSize(10)
194 |         font.setBold(True)
195 |         font.setWeight(75)
196 |         self.label_11.setFont(font)
197 |         self.label_11.setAlignment(QtCore.Qt.AlignCenter)
198 |         self.label_11.setObjectName("label_11")
199 |         self.label_12 = QtWidgets.QLabel(self.centralwidget)
200 |         self.label_12.setGeometry(QtCore.QRect(140, 860, 111, 16))
201 |         font = QtGui.QFont()
202 |         font.setFamily("Agency FB")
203 |         font.setPointSize(10)
204 |         font.setBold(True)
205 |         font.setWeight(75)
206 |         self.label_12.setFont(font)
207 |         self.label_12.setAlignment(QtCore.Qt.AlignCenter)
208 |         self.label_12.setObjectName("label_12")
209 |         self.verticalSlider = QtWidgets.QSlider(self.centralwidget)
210 |         self.verticalSlider.setGeometry(QtCore.QRect(700, 520, 22, 341))
211 |         self.verticalSlider.setMaximum(100)
212 |         self.verticalSlider.setSliderPosition(50)
213 |         self.verticalSlider.setOrientation(QtCore.Qt.Vertical)
214 |         self.verticalSlider.setInvertedAppearance(False)
215 |         self.verticalSlider.setTickPosition(QtWidgets.QSlider.TicksAbove)
216 |         self.verticalSlider.setTickInterval(10)
217 |         self.verticalSlider.setObjectName("verticalSlider")
218 |         self.verticalSlider.valueChanged.connect(self.detectChange)
219 |         MainWindow.setCentralWidget(self.centralwidget)
220 |         self.menubar = QtWidgets.QMenuBar(MainWindow)
221 |         self.menubar.setGeometry(QtCore.QRect(0, 0, 731, 21))
222 |         self.menubar.setObjectName("menubar")
223 |         MainWindow.setMenuBar(self.menubar)
224 |         self.statusbar = QtWidgets.QStatusBar(MainWindow)
225 |         self.statusbar.setObjectName("statusbar")
226 |         MainWindow.setStatusBar(self.statusbar)
227 | 
228 |         self.retranslateUi(MainWindow)
229 |         QtCore.QMetaObject.connectSlotsByName(MainWindow)
230 | 
231 |     def retranslateUi(self, MainWindow):
232 |         _translate = QtCore.QCoreApplication.translate
233 |         MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
234 |         self.pushButton.setText(_translate("MainWindow", "Loading source image"))  # Load the input image
235 |         self.label.setText(_translate("MainWindow", "Aerial image matching program"))  # Deep-learning-based aerial image matching program
236 |         self.pushButton_2.setText(_translate("MainWindow", "Loading target image"))  # Load the reference image
237 |         self.pushButton_3.setText(_translate("MainWindow", "Executing image matching"))  # Perform image matching
238 |         self.groupBox.setTitle(_translate("MainWindow", "Current progress"))  # Matching progress status
239 |         self.label_4.setText(_translate("MainWindow", "Waiting"))  # Waiting
240 |         self.label_9.setText(_translate("MainWindow", "Source image"))  # Input image
241 |         self.label_10.setText(_translate("MainWindow", "Target image"))  # Reference image
242 |         self.label_11.setText(_translate("MainWindow", "Overlay check"))  # Matching result overlay
243 |         self.label_12.setText(_translate("MainWindow", "Matching result"))  # Matching result
244 | 
245 | if __name__ == "__main__":
246 |     import sys
247 |     demo = matching_demo()
248 |     app = QtWidgets.QApplication(sys.argv)
249 |     MainWindow = QtWidgets.QMainWindow()
250 |     ui = Ui_MainWindow()
251 |     ui.setupUi(MainWindow)
252 |     MainWindow.show()
253 |     sys.exit(app.exec_())
254 | 
255 | 
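# NOTE: gui.py expects 'ADD_Logo.jpg' and 'KU_Logo.jpg' in the working directory
# for the header logos, and writes 'result.jpg' and 'overlay.jpg' next to the
# script. A minimal standalone sketch of the slider-driven blend performed in
# detectChange (file names below are illustrative):
#
#     import cv2
#     target = cv2.imread('datasets/tgt15.jpg')
#     result = cv2.imread('result.jpg')
#     alpha = 0.01 * 50                      # slider value 0..100 -> weight 0..1
#     target = cv2.resize(target, (result.shape[1], result.shape[0]),
#                         interpolation=cv2.INTER_AREA)
#     overlay = cv2.addWeighted(target, alpha, result, 1.0 - alpha, 0.0)
#     cv2.imwrite('overlay.jpg', overlay)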
--------------------------------------------------------------------------------
/image/__pycache__/normalization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/image/__pycache__/normalization.cpython-36.pyc
--------------------------------------------------------------------------------
/image/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import transforms
3 | from torch.autograd import Variable
4 | 
5 | class NormalizeImageDict(object):
6 |     """
7 | 
8 |     Normalizes Tensor images in dictionary
9 | 
10 |     Args:
11 |         image_keys (list): dict. keys of the images to be normalized
12 |         normalizeRange (bool): if True the image is divided by 255.0
13 | 
14 |     """
15 | 
16 |     def __init__(self,image_keys,normalizeRange=True):
17 |         self.image_keys = image_keys
18 |         self.normalizeRange=normalizeRange
19 |         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
20 |                                               std=[0.229, 0.224, 0.225])
21 | 
22 |     def __call__(self, sample):
23 |         for key in self.image_keys:
24 |             if self.normalizeRange:
25 |                 sample[key] /= 255.0
26 |             sample[key] = self.normalize(sample[key])
27 |         return sample
28 | 
29 | 
30 | def normalize_image(image, forward=True, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
31 |     im_size = image.size()
32 |     mean=torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2)
33 |     std=torch.FloatTensor(std).unsqueeze(1).unsqueeze(2)
34 |     if image.is_cuda:
35 |         mean = mean.cuda()
36 |         std = std.cuda()
37 |     if isinstance(image,torch.autograd.Variable):
38 |         mean = Variable(mean,requires_grad=False)
39 |         std = Variable(std,requires_grad=False)
40 |     if forward:
41 |         if len(im_size)==3:
42 |             result = image.sub(mean.expand(im_size)).div(std.expand(im_size))
43 |         elif len(im_size)==4:
44 |             result = image.sub(mean.unsqueeze(0).expand(im_size)).div(std.unsqueeze(0).expand(im_size))
45 |     else:
46 |         if len(im_size)==3:
47 |             result = image.mul(std.expand(im_size)).add(mean.expand(im_size))
48 |         elif len(im_size)==4:
49 |             result = image.mul(std.unsqueeze(0).expand(im_size)).add(mean.unsqueeze(0).expand(im_size))
50 | 
51 |     return result
--------------------------------------------------------------------------------
/matching.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import argparse
3 | from scipy.misc import imsave
4 | from model.cnn_geometric_model import CNNGeometricPearson
5 | from image.normalization import NormalizeImageDict, normalize_image
6 | from util.torch_util import BatchTensorToVars
7 | from geotnf.transformation import GeometricTnf
8 | from skimage import io
9 | import torch
10 | from torch.autograd import Variable
11 | import numpy as np
12 | 
13 | class matching_demo(object):
14 |     def __init__(self, geometric_model='affine'):
15 |         # Argument parsing
16 |         parser = argparse.ArgumentParser(description='Gradual Estimation for Aerial Image Matching demo script')
17 |         # Paths
18 |         parser.add_argument('--model-aff', type=str,
19 |                             default='trained_models/resnet36_myproc_1_new_cor_fefr_4p5.pth.tar',
20 |                             help='Trained affine model filename')
21 |         parser.add_argument('--model-aff2', type=str,
22 |                             default='trained_models/resnet101_epo81_lr4p4_rm11.pth.tar',
23 |                             help='Trained second affine model filename')
24 | 
parser.add_argument('--feature-extraction-cnn', type=str, default='resnet101', 25 | help='Feature extraction architecture: vgg/resnet101') 26 | 27 | self.args = parser.parse_args() 28 | self.use_cuda = torch.cuda.is_available() 29 | 30 | self.do_aff = not self.args.model_aff2 == '' 31 | # Create model 32 | print('Creating CNN model...') 33 | if self.do_aff: 34 | self.model_aff = CNNGeometricPearson(use_cuda=self.use_cuda, geometric_model=geometric_model, 35 | feature_extraction_cnn=self.args.feature_extraction_cnn) 36 | 37 | # Load trained weights 38 | print('Loading trained model weights...') 39 | if self.do_aff: 40 | checkpoint = torch.load(self.args.model_aff, map_location=lambda storage, loc: storage) 41 | checkpoint2 = torch.load(self.args.model_aff2, map_location=lambda storage, loc: storage) 42 | model_dict = self.model_aff.FeatureExtraction.state_dict() 43 | for name, param in model_dict.items(): 44 | model_dict[name].copy_(checkpoint['state_dict'][ 45 | 'FeatureExtraction.' + name]) 46 | model_dict = self.model_aff.FeatureClassification.state_dict() 47 | for name, param in model_dict.items(): 48 | model_dict[name].copy_(checkpoint['state_dict'][ 49 | 'FeatureClassification.' + name]) 50 | model_dict = self.model_aff.FeatureExtraction2.state_dict() 51 | for name, param in model_dict.items(): 52 | model_dict[name].copy_(checkpoint2['state_dict'][ 53 | 'FeatureExtraction.' + name]) 54 | model_dict = self.model_aff.FeatureRegression.state_dict() 55 | for name, param in model_dict.items(): 56 | model_dict[name].copy_(checkpoint2['state_dict'][ 57 | 'FeatureRegression.' + name]) 58 | self.affTnf = GeometricTnf(geometric_model='affine', out_h=240, out_w=240, use_cuda=False) 59 | self.affTnf_demo = GeometricTnf(geometric_model='affine', out_h=338, out_w=338, use_cuda=False) 60 | self.affTnf_origin = GeometricTnf(geometric_model='affine', out_h=480, out_w=480, use_cuda=False) 61 | 62 | self.transform = NormalizeImageDict(['source_image', 'target_image', 'demo', 'origin_image']) 63 | self.rescalingTnf = GeometricTnf('affine', 240, 240, 64 | use_cuda=True) 65 | self.geometricTnf = GeometricTnf(geometric_model, 240, 240, 66 | use_cuda=True) 67 | 68 | 69 | def __call__(self, fname, fname2): 70 | 71 | image = io.imread(fname) 72 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 73 | image = torch.Tensor(image.astype(np.float32)) 74 | image_var = Variable(image, requires_grad=False) 75 | image_A = self.affTnf(image_var).data.squeeze(0) 76 | image_A_demo = self.affTnf_demo(image_var).data.squeeze(0) 77 | image_A_origin = self.affTnf_origin(image_var).data.squeeze(0) 78 | 79 | image2 = io.imread(fname2) 80 | image2 = np.expand_dims(image2.transpose((2, 0, 1)), 0) 81 | image2 = torch.Tensor(image2.astype(np.float32)) 82 | image_var2 = Variable(image2, requires_grad=False) 83 | image_B = self.affTnf(image_var2).data.squeeze(0) 84 | 85 | sample = {'source_image': image_A, 'target_image': image_B, 'demo': image_A_demo, 'origin_image': image_A_origin} 86 | 87 | sample = self.transform(sample) 88 | 89 | batchTensorToVars = BatchTensorToVars(use_cuda=self.use_cuda) 90 | 91 | batch = batchTensorToVars(sample) 92 | batch['source_image'] = torch.unsqueeze(batch['source_image'],0) 93 | batch['target_image'] = torch.unsqueeze(batch['target_image'],0) 94 | batch['origin_image'] = torch.unsqueeze(batch['origin_image'],0) 95 | batch['demo'] = torch.unsqueeze(batch['demo'],0) 96 | 97 | if self.do_aff: 98 | self.model_aff.eval() 99 | 100 | # Evaluate models 101 | if self.do_aff: 102 | theta_aff = 
self.model_aff(batch) 103 | warped_image_aff_demo = self.affTnf_demo(batch['demo'], theta_aff.view(-1, 2, 3)) 104 | 105 | if self.do_aff: 106 | warped_image_aff_demo = normalize_image(warped_image_aff_demo, forward=False) 107 | warped_image_aff_demo = warped_image_aff_demo.data.squeeze(0).transpose(0, 1).transpose(1, 2).cpu().numpy() 108 | 109 | print("Done") 110 | imsave('result.jpg', warped_image_aff_demo) 111 | 112 | return warped_image_aff_demo 113 | 114 | -------------------------------------------------------------------------------- /model/__pycache__/cnn_geometric_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/model/__pycache__/cnn_geometric_model.cpython-36.pyc -------------------------------------------------------------------------------- /model/cnn_geometric_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torchvision.models as models 6 | from geotnf.transformation import SynthSingleTnf 7 | import numpy as np 8 | 9 | 10 | class FeatureExtraction(torch.nn.Module): 11 | def __init__(self, use_cuda=True, feature_extraction_cnn='vgg', last_layer=''): 12 | super(FeatureExtraction, self).__init__() 13 | if feature_extraction_cnn == 'vgg': 14 | self.model = models.vgg16(pretrained=True) 15 | # keep feature extraction network up to indicated layer 16 | vgg_feature_layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 17 | 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 18 | 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 19 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 20 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5'] 21 | if last_layer == '': 22 | last_layer = 'pool4' 23 | last_layer_idx = vgg_feature_layers.index(last_layer) 24 | self.model = nn.Sequential(*list(self.model.features.children())[:last_layer_idx + 1]) 25 | if feature_extraction_cnn == 'resnet101': 26 | self.model = models.resnet101(pretrained=True) 27 | resnet_feature_layers = ['conv1', 28 | 'bn1', 29 | 'relu', 30 | 'maxpool', 31 | 'layer1', 32 | 'layer2', 33 | 'layer3', 34 | 'layer4'] 35 | if last_layer == '': 36 | last_layer = 'layer3' 37 | last_layer_idx = resnet_feature_layers.index(last_layer) 38 | resnet_module_list = [self.model.conv1, 39 | self.model.bn1, 40 | self.model.relu, 41 | self.model.maxpool, 42 | self.model.layer1, 43 | self.model.layer2, 44 | self.model.layer3, 45 | self.model.layer4] 46 | 47 | self.model = nn.Sequential(*resnet_module_list[:last_layer_idx + 1]) 48 | # freeze parameters 49 | for param in self.model.parameters(): 50 | # param.requires_grad = False 51 | param.requires_grad = True 52 | # move to GPU 53 | if use_cuda: 54 | self.model.cuda() 55 | 56 | def forward(self, image_batch): 57 | return self.model(image_batch) 58 | 59 | 60 | class FeatureL2Norm(torch.nn.Module): 61 | def __init__(self): 62 | super(FeatureL2Norm, self).__init__() 63 | 64 | def forward(self, feature): 65 | epsilon = 1e-6 66 | norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature) 67 | return torch.div(feature, norm) 68 | 69 | 70 | class Feature2Pearson(torch.nn.Module): 71 | def __init__(self): 
72 |         super(Feature2Pearson, self).__init__()
73 | 
74 |     def forward(self, feature):
75 |         epsilon = 1e-6
76 |         feature_mean = torch.mean(feature, 1, True)
77 |         pearson = feature - feature_mean  # mean-center each descriptor, then L2-normalize (Pearson-style normalization)
78 |         norm = torch.pow(torch.sum(torch.pow(pearson, 2), 1) + epsilon, 0.5).unsqueeze(1).expand_as(feature)
79 |         return torch.div(pearson, norm)
80 | 
81 | class FeatureCorrelation(torch.nn.Module):
82 |     def __init__(self):
83 |         super(FeatureCorrelation, self).__init__()
84 | 
85 |     def forward(self, feature_A, feature_B):
86 |         b, c, h, w = feature_A.size()
87 |         # reshape features for matrix multiplication:
88 |         # feature_A -> (b, c, h*w), feature_B -> (b, h*w, c)
89 |         feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)
90 |         feature_B = feature_B.view(b, c, h * w).transpose(1, 2)
91 |         # batched matrix multiplication: all-pairs match scores
92 |         feature_mul = torch.bmm(feature_B, feature_A)
93 | 
94 |         # reshape to (b, h*w, h, w): one h*w-channel score map per spatial location
95 |         correlation_tensor = feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)
96 | 
97 |         return correlation_tensor
98 | 
99 | 
100 | class FeatureMasking(torch.nn.Module):
101 |     def __init__(self):
102 |         super(FeatureMasking, self).__init__()
103 | 
104 |     def forward(self, correlation_tensor):
105 |         correlation_tensor = correlation_tensor.transpose(1, 2).transpose(2, 3)
106 |         l = 11  # half-width of the allowed displacement window (in grid cells)
107 |         h = 15
108 |         w = 15
109 |         limit_region = np.zeros((w, h, w * h))
110 |         for i in range(h):
111 |             for j in range(w):
112 |                 for r_h in range(-1 * l, l + 1):
113 |                     for r_w in range(-1 * l, l + 1):
114 |                         temp_col = j + r_w
115 |                         temp_row = i + r_h
116 |                         if temp_col in range(w) and temp_row in range(h):
117 |                             limit_region[i][j][w * temp_col + temp_row] = 1
118 |         cor_mask = torch.unsqueeze(Variable(torch.FloatTensor(limit_region), requires_grad=False), 0)
119 |         cor_mask = cor_mask.cuda()
120 |         correlation_tensor = correlation_tensor * cor_mask
121 |         correlation_tensor = correlation_tensor.transpose(2, 3).transpose(1, 2)
122 | 
123 |         return correlation_tensor
124 | 
125 | 
126 | 
127 | class FeatureClassification(nn.Module):
128 |     def __init__(self, output_dim=4, use_cuda=True):
129 |         super(FeatureClassification, self).__init__()
130 |         self.conv = nn.Sequential(
131 |             nn.Conv2d(225, 128, kernel_size=7, padding=0),  # 225 = 15*15 correlation channels (see the sketch after this file)
132 |             nn.BatchNorm2d(128),
133 |             nn.ReLU(inplace=True),
134 |             nn.Conv2d(128, 64, kernel_size=5, padding=0),
135 |             nn.BatchNorm2d(64),
136 |             nn.ReLU(inplace=True),
137 |         )
138 |         self.linear = nn.Linear(64 * 5 * 5, output_dim)
139 |         if use_cuda:
140 |             self.conv.cuda()
141 |             self.linear.cuda()
142 | 
143 |     def forward(self, x):
144 |         x = self.conv(x)
145 |         x = x.view(x.size(0), -1)
146 |         x = self.linear(x)
147 | 
148 |         return x
149 | 
150 | 
151 | class FeatureRegression(nn.Module):
152 |     def __init__(self, output_dim=6, use_cuda=True):
153 |         super(FeatureRegression, self).__init__()
154 |         self.conv = nn.Sequential(
155 |             nn.Conv2d(225, 128, kernel_size=7, padding=0),
156 |             nn.BatchNorm2d(128),
157 |             nn.ReLU(inplace=True),
158 |             nn.Conv2d(128, 64, kernel_size=5, padding=0),
159 |             nn.BatchNorm2d(64),
160 |             nn.ReLU(inplace=True),
161 |         )
162 |         self.linear = nn.Linear(64 * 5 * 5, output_dim)
163 |         if use_cuda:
164 |             self.conv.cuda()
165 |             self.linear.cuda()
166 | 
167 |     def forward(self, x):
168 |         x = self.conv(x)
169 |         x = x.view(x.size(0), -1)
170 |         x = self.linear(x)
171 |         return x
172 | 
173 | class FeatureRegression2(nn.Module):  # identical architecture to FeatureRegression, kept as a separate module
174 |     def __init__(self, output_dim=6, use_cuda=True):
175 |         super(FeatureRegression2, self).__init__()
176 |         self.conv = nn.Sequential(
177 |             nn.Conv2d(225, 128, kernel_size=7, padding=0),
178 |             nn.BatchNorm2d(128),
179 |             nn.ReLU(inplace=True),
180 |             nn.Conv2d(128, 64, kernel_size=5, padding=0),
181 |             nn.BatchNorm2d(64),
182 |             nn.ReLU(inplace=True),
183 |         )
184 |         self.linear = nn.Linear(64 * 5 * 5, output_dim)
185 |         if use_cuda:
186 |             self.conv.cuda()
187 |             self.linear.cuda()
188 | 
189 |     def forward(self, x):
190 |         x = self.conv(x)
191 |         x = x.view(x.size(0), -1)
192 |         x = self.linear(x)
193 |         return x
194 | 
195 | class CNNGeometricPearson(nn.Module):
196 |     def __init__(self, geometric_model='affine', normalize_features=True, normalize_matches=True,
197 |                  batch_normalization=True, use_cuda=True, feature_extraction_cnn='vgg'):
198 |         super(CNNGeometricPearson, self).__init__()
199 |         self.use_cuda = use_cuda
200 |         self.normalize_features = normalize_features
201 |         self.normalize_matches = normalize_matches
202 |         self.FeatureExtraction = FeatureExtraction(use_cuda=self.use_cuda,
203 |                                                    feature_extraction_cnn=feature_extraction_cnn)
204 |         self.FeatureExtraction2 = FeatureExtraction(use_cuda=self.use_cuda,
205 |                                                     feature_extraction_cnn=feature_extraction_cnn)
206 |         self.Feature2Pearson = Feature2Pearson()
207 |         self.FeatureL2Norm = FeatureL2Norm()
208 |         self.FeatureMasking = FeatureMasking()
209 |         self.FeatureCorrelation = FeatureCorrelation()
210 |         if geometric_model == 'affine':
211 |             output_dim = 6
212 |             self.FeatureClassification = FeatureClassification(8, use_cuda=self.use_cuda)  # 8 rotation classes (45-degree bins)
213 |             self.FeatureRegression = FeatureRegression(output_dim, use_cuda=self.use_cuda)
214 |         self.ReLU = nn.ReLU(inplace=True)
215 | 
216 |         self.single_generation_tnf = SynthSingleTnf(use_cuda=self.use_cuda, geometric_model=geometric_model, output_size=(240, 240))
217 | 
218 |     def forward(self, tnf_batch):
219 |         # do feature extraction
220 |         feature_A = self.FeatureExtraction(tnf_batch['source_image'])
221 |         feature_B = self.FeatureExtraction(tnf_batch['target_image'])
222 |         # normalize
223 |         if self.normalize_features:
224 |             feature_A = self.Feature2Pearson(feature_A)
225 |             feature_B = self.Feature2Pearson(feature_B)
226 |         correlation = self.FeatureCorrelation(feature_A, feature_B)
227 |         if self.normalize_matches:
228 |             correlation = self.FeatureL2Norm(self.ReLU(correlation))
229 |         theta = self.FeatureClassification(correlation)
230 | 
231 |         _, predicted = torch.max(theta, 1)
232 |         predicted = predicted.cpu().numpy()  # scalar comparisons below assume batch size 1
233 |         # map the predicted class to a fixed rotation (45-degree increments)
234 |         if predicted == 0:  # 0 deg (identity)
235 |             theta = torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float).cuda()
236 |             theta = Variable(theta, requires_grad=False)
237 |             # print('middle result', theta)
238 |         elif predicted == 1:  # 45 deg
239 |             theta = torch.tensor([0.70710678118, -0.70710678118, 0, 0.70710678118, 0.70710678118, 0], dtype=torch.float).cuda()
240 |             theta = Variable(theta, requires_grad=False)
241 |             # print('middle result', theta)
242 |         elif predicted == 2:  # 90 deg
243 |             theta = torch.tensor([0, -1, 0, 1, 0, 0], dtype=torch.float).cuda()
244 |             theta = Variable(theta, requires_grad=False)
245 |             # print('middle result', theta)
246 |         elif predicted == 3:  # 135 deg
247 |             theta = torch.tensor([-0.70710678118, -0.70710678118, 0, 0.70710678118, -0.70710678118, 0], dtype=torch.float).cuda()
248 |             theta = Variable(theta, requires_grad=False)
249 |             # print('middle result', theta)
250 |         elif predicted == 4:  # 180 deg
251 |             theta = torch.tensor([-1, 0, 0, 0, -1, 0], dtype=torch.float).cuda()
252 |             theta = Variable(theta, requires_grad=False)
253 |             # print('middle result', theta)
254 |         elif predicted == 5:  # 225 deg
255 |             theta = torch.tensor([-0.70710678118, 0.70710678118, 0, -0.70710678118, -0.70710678118, 0], dtype=torch.float).cuda()
256 |             theta = Variable(theta, requires_grad=False)
257 |             # print('middle result', theta)
258 |         elif predicted == 6:  # 270 deg
259 |             theta = torch.tensor([0, 1, 0, -1, 0, 0], dtype=torch.float).cuda()
260 |             theta = Variable(theta, requires_grad=False)
261 |             # print('middle result', theta)
262 |         else:  # 315 deg
263 |             theta = torch.tensor([0.70710678118, 0.70710678118, 0, -0.70710678118, 0.70710678118, 0], dtype=torch.float).cuda()
264 |             theta = Variable(theta, requires_grad=False)
265 | 
266 |         # stage 2: warp the source by the coarse rotation, then regress a fine affine refinement
267 |         theta1 = theta.view(-1, 2, 3)
268 |         warped_image_batch2 = self.single_generation_tnf(tnf_batch['origin_image'], theta1)
269 |         feature_A2 = self.FeatureExtraction2(warped_image_batch2)
270 |         feature_B2 = self.FeatureExtraction2(tnf_batch['target_image'])
271 | 
272 |         feature_A2 = self.Feature2Pearson(feature_A2)
273 |         feature_B2 = self.Feature2Pearson(feature_B2)
274 | 
275 |         correlation2 = self.FeatureCorrelation(feature_A2, feature_B2)
276 |         if self.normalize_matches:
277 |             correlation2 = self.FeatureL2Norm(self.ReLU(correlation2))
278 |         correlation2 = self.FeatureMasking(correlation2)
279 |         theta2 = self.FeatureRegression(correlation2)
280 | 
281 |         theta = theta.view(-1, 2, 3)
282 |         theta2 = theta2.view(-1, 2, 3)
283 |         theta_r1, theta_r2 = torch.chunk(theta, 2, dim=1)
284 |         theta_r1 = theta_r1.type(torch.FloatTensor).cuda()
285 |         theta_r2 = theta_r2.type(torch.FloatTensor).cuda()
286 |         ho_last = torch.tensor([0, 0, 1]).type(torch.FloatTensor).cuda()
287 |         theta_f = torch.cat([theta_r1, theta_r2, ho_last.expand_as(theta_r1)], dim=1)  # 3x3 homogeneous form of the coarse rotation
288 | 
289 |         theta2_r1, theta2_r2 = torch.chunk(theta2, 2, dim=1)
290 |         theta2_r1 = theta2_r1.type(torch.FloatTensor).cuda()
291 |         theta2_r2 = theta2_r2.type(torch.FloatTensor).cuda()
292 |         ho_last = torch.tensor([0, 0, 1]).type(torch.FloatTensor).cuda()
293 |         gt_f = torch.cat([theta2_r1, theta2_r2, ho_last.expand_as(theta2_r1)], dim=1)  # 3x3 homogeneous form of the regressed refinement
294 | 
295 |         result = torch.bmm(theta_f, gt_f)  # compose the two transformations (see the sketch after this file)
296 |         result = result[:, :2, :]  # back to a 2x3 affine
297 | 
298 |         return result
299 | 
--------------------------------------------------------------------------------
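For reference, a minimal shape check, not part of the repository: it assumes the repo root is on `PYTHONPATH` and that the ImageNet weights can be downloaded (the `FeatureExtraction` constructor hardcodes `pretrained=True`). It confirms that a 240x240 input yields 15x15 features at ResNet-101 `layer3` (downsampling by 16), which is why `FeatureCorrelation` produces the 225 (= 15*15) input channels expected by the classification and regression heads.

```python
# Hedged sketch: verify the feature and correlation shapes assumed by this model.
import torch
from model.cnn_geometric_model import FeatureExtraction, FeatureCorrelation

extractor = FeatureExtraction(use_cuda=False, feature_extraction_cnn='resnet101')
correlate = FeatureCorrelation()

src = torch.randn(1, 3, 240, 240)   # one 240x240 RGB image
tgt = torch.randn(1, 3, 240, 240)
with torch.no_grad():
    f_src = extractor(src)          # (1, 1024, 15, 15)
    f_tgt = extractor(tgt)
    corr = correlate(f_src, f_tgt)  # (1, 225, 15, 15): a 15*15 score map per target cell
print(f_src.shape, corr.shape)
```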
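The tail of `forward` lifts the coarse 2x3 rotation (`theta_f`) and the regressed refinement (`gt_f`, a misleading name; it is not a ground truth) to 3x3 homogeneous matrices and composes them with `torch.bmm`. A standalone CPU-only sketch of that composition, with a hand-picked refinement standing in for the regressed one:

```python
# A minimal sketch of the homogeneous affine composition above (not part of the repo).
import torch

def compose_affine(theta, theta2):
    # lift each 2x3 affine parameter vector to a 3x3 homogeneous matrix, then multiply
    bottom = torch.tensor([[0., 0., 1.]]).expand(theta.size(0), 1, 3)
    t1 = torch.cat([theta.view(-1, 2, 3), bottom], dim=1)
    t2 = torch.cat([theta2.view(-1, 2, 3), bottom], dim=1)
    return torch.bmm(t1, t2)[:, :2, :]  # drop the homogeneous row again

rot45 = torch.tensor([[0.70710678118, -0.70710678118, 0., 0.70710678118, 0.70710678118, 0.]])
identity = torch.tensor([[1., 0., 0., 0., 1., 0.]])
print(compose_affine(rot45, identity))  # an identity refinement returns the 45-degree theta unchanged
```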
/util/__pycache__/torch_util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mrk1992/robust_matching_network_on_remote_sensing_imagery_pytorch/15b442a6f1d7753945b9323294911a2780989155/util/__pycache__/torch_util.cpython-36.pyc
--------------------------------------------------------------------------------
/util/torch_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import shutil
3 | import torch
4 | from torch.autograd import Variable
5 | from os import makedirs, remove
6 | from os.path import exists, join, basename, dirname
7 | 
8 | class BatchTensorToVars(object):
9 |     """Convert tensors in a dict batch to Variables
10 |     """
11 |     def __init__(self, use_cuda=True):
12 |         self.use_cuda = use_cuda
13 | 
14 |     def __call__(self, batch):
15 |         batch_var = {}
16 |         for key, value in batch.items():
17 |             batch_var[key] = Variable(value, requires_grad=False)
18 |             if self.use_cuda:
19 |                 batch_var[key] = batch_var[key].cuda()
20 | 
21 |         return batch_var
22 | 
23 | def save_checkpoint(state, is_best, file):
24 |     model_dir = dirname(file)
25 |     model_fn = basename(file)
26 |     # create the checkpoint directory if it does not yet exist
27 |     if model_dir != '' and not exists(model_dir):
28 |         makedirs(model_dir)
29 |     torch.save(state, file)
30 |     if is_best:
31 |         shutil.copyfile(file, join(model_dir, 'best_' + model_fn))
32 | 
33 | def str_to_bool(v):
34 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
35 |         return True
36 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
37 |         return False
38 |     else:
39 |         raise argparse.ArgumentTypeError('Boolean value expected.')
40 | 
--------------------------------------------------------------------------------
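A minimal usage sketch for these helpers, not from the repository: it assumes the repo root is on `PYTHONPATH` and CPU-only execution. `str_to_bool` plugs into argparse as a `type` converter, and `BatchTensorToVars` wraps every tensor in a batch dict as a no-grad `Variable`.

```python
# Hedged sketch: wiring str_to_bool into argparse and converting a dummy batch.
import argparse
import torch
from util.torch_util import BatchTensorToVars, str_to_bool

parser = argparse.ArgumentParser()
parser.add_argument('--use-cuda', type=str_to_bool, default=False)
args = parser.parse_args([])  # e.g. parse_args(['--use-cuda', 'yes']) on the command line

batch = {'source_image': torch.zeros(1, 3, 240, 240),
         'target_image': torch.zeros(1, 3, 240, 240)}
to_vars = BatchTensorToVars(use_cuda=args.use_cuda)
batch_var = to_vars(batch)    # same dict, values wrapped as requires_grad=False Variables
print(type(batch_var['source_image']))
```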