├── data └── Optical-SAR │ ├── 1_36.jpg │ ├── 1_9_1.jpg │ ├── 1_9_3.jpg │ └── 1_36_warp.jpg ├── tool ├── CFOG.cp38-win_amd64.pyd ├── model_tools.cp38-win_amd64.pyd ├── __pycache__ │ ├── loss_tools.cpython-38.pyc │ └── preprocess_tools.cpython-38.pyc ├── loss_tools.py └── preprocess_tools.py ├── README.md └── code └── demo.py /data/Optical-SAR/1_36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/data/Optical-SAR/1_36.jpg -------------------------------------------------------------------------------- /data/Optical-SAR/1_9_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/data/Optical-SAR/1_9_1.jpg -------------------------------------------------------------------------------- /data/Optical-SAR/1_9_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/data/Optical-SAR/1_9_3.jpg -------------------------------------------------------------------------------- /tool/CFOG.cp38-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/tool/CFOG.cp38-win_amd64.pyd -------------------------------------------------------------------------------- /data/Optical-SAR/1_36_warp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/data/Optical-SAR/1_36_warp.jpg -------------------------------------------------------------------------------- /tool/model_tools.cp38-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/tool/model_tools.cp38-win_amd64.pyd 
-------------------------------------------------------------------------------- /tool/__pycache__/loss_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/tool/__pycache__/loss_tools.cpython-38.pyc -------------------------------------------------------------------------------- /tool/__pycache__/preprocess_tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yeyuanxin110/MU-Net/HEAD/tool/__pycache__/preprocess_tools.cpython-38.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MU-Net 2 | 3 | This is an implementation of our paper: A multiscale framework with unsupervised learning for remote sensing image registration. 4 | 5 | The main file is demo.py, which takes the input image paths and produces the registration result. 6 | 7 | Some *.py files in tool are compiled as *.pyd files, which can be invoked normally under Python 3.8 on 64-bit Windows.
8 | 9 | Any question pls contact: ttf@my.swjtu.edu.cn 10 | -------------------------------------------------------------------------------- /tool/loss_tools.py: -------------------------------------------------------------------------------- 1 | from tool.CFOG import CFOG 2 | import torch 3 | 4 | def NCC_loss(i, j): 5 | x = torch.ge(i.squeeze(0).squeeze(0), 1) 6 | x = torch.tensor(x, dtype=torch.float32) 7 | y = torch.ge(j.squeeze(0).squeeze(0), 1) 8 | y = torch.tensor(y, dtype=torch.float32) 9 | z = torch.mul(x, y) 10 | num = z[z.ge(1)].size()[0] 11 | CFOG_sar = torch.mul(CFOG(i), z) 12 | CFOG_optical = torch.mul(CFOG(j), z) 13 | loss = gncc_loss(CFOG_sar, CFOG_optical)*10000/num 14 | return loss 15 | 16 | def ComputeLoss(reference, sensed_tran, sensed, reference_inv_tran): 17 | loss_1 = NCC_loss(reference, sensed_tran) 18 | loss_2 = NCC_loss(sensed, reference_inv_tran) 19 | loss = loss_1 + loss_2 20 | return loss 21 | 22 | def gncc_loss(I, J, eps=1e-5): 23 | I2 = I.pow(2) 24 | J2 = J.pow(2) 25 | IJ = I*J 26 | I_ave, J_ave = I.mean(), J.mean() 27 | I2_ave, J2_ave = I2.mean(), J2.mean() 28 | IJ_ave = IJ.mean() 29 | cross = IJ_ave - I_ave * J_ave 30 | I_var = I2_ave - I_ave.pow(2) 31 | J_var = J2_ave - J_ave.pow(2) 32 | cc = cross / (I_var.sqrt() * J_var.sqrt() + eps) # 1e-5 33 | return -1.0 * cc + 1 -------------------------------------------------------------------------------- /tool/preprocess_tools.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def save_tensor_tf(T, path): 8 | T = T.squeeze(0) 9 | # T_numpy = torch.tensor(T, dtype=torch.uint8).permute([1, 2, 0]).detach().cpu().numpy() 10 | T_numpy = torch.tensor(T).squeeze(0).detach().cpu().numpy() 11 | T_numpy = float2uint8_tf(T_numpy) 12 | cv.imwrite(path, T_numpy) 13 | 14 | def affine_transform_tf(im1, im2, H12, device='cpu'): 15 | 
im1_grid = F.affine_grid(H12, im1.size()) 16 | im1_resample = F.grid_sample(im1, im1_grid) 17 | a = torch.tensor([[[0, 0, 1]]], dtype=torch.float).to(device) 18 | a = a.repeat(im1.size()[0], 1, 1) 19 | affine_matrix = torch.cat([H12, a], dim=1) 20 | inv_affine_matrix = torch.inverse(affine_matrix) 21 | H21 = inv_affine_matrix[:, 0:2, :] 22 | im2_grid = F.affine_grid(H21, im2.size()) 23 | im2_resample = F.grid_sample(im2, im2_grid) 24 | return im1_resample, im2_resample, H21 25 | 26 | def float2uint8_tf(M): 27 | a = np.max(M) 28 | b = np.min(M) 29 | M = (M - b) / (a - b) * 255 30 | M = M.astype(np.uint8) 31 | return M 32 | 33 | def show_registration_result(ref, sen, sen_tran_T): 34 | sen_tran_T = sen_tran_T.squeeze(0) 35 | T_numpy = torch.tensor(sen_tran_T, dtype=torch.uint8).squeeze(0).detach().cpu().numpy() 36 | plt.subplot(1,3,1) 37 | plt.imshow(ref, cmap='gray') 38 | plt.title('input ref') 39 | plt.subplot(1, 3, 2) 40 | plt.imshow(sen, cmap='gray') 41 | plt.title('input sen') 42 | plt.subplot(1, 3, 3) 43 | plt.imshow(T_numpy, cmap='gray') 44 | plt.title('output sen correct') 45 | plt.show() -------------------------------------------------------------------------------- /code/demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.optim as optim 4 | import numpy as np 5 | from PIL import Image 6 | from tool.loss_tools import ComputeLoss 7 | from tool.model_tools import net, check_model_path 8 | from tool.preprocess_tools import affine_transform_tf, show_registration_result, save_tensor_tf 9 | import os 10 | from torch.autograd import Variable 11 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | import warnings 14 | warnings.filterwarnings('ignore') 15 | import time 16 | 17 | def ite(ref_img, sen_img, pretrained_model=None): 18 | time_start = time.time() 19 | model = 
net().to(device) 20 | model = check_model_path(model, pretrained_model) 21 | print('Using device: ' + str(device)) 22 | print('Registration Waiting...') 23 | save_sen_tran_name = 'save.jpg' 24 | parameter_learn_rate = 0.0004 25 | max_iter = 800 26 | stop_iter = 200 27 | ref_tensor = Variable(torch.tensor(np.float32(np.array(ref_img))).to(device)).unsqueeze(0).unsqueeze(0) 28 | sen_tensor = Variable(torch.tensor(np.float32(np.array(sen_img))).to(device)).unsqueeze(0).unsqueeze(0) 29 | optimizer = optim.SGD(model.parameters(), lr=parameter_learn_rate) 30 | model.train() 31 | Epoch = [] 32 | Loss = [] 33 | loss_0 = 1000000 34 | count = 0 35 | for epoch in range(max_iter): 36 | count = count + 1 37 | Epoch.append(epoch) 38 | optimizer.zero_grad() 39 | affine_parameter = model(torch.cat([ref_tensor, sen_tensor], dim=1)) 40 | sen_tran_tensor, ref_inv_tensor, inv_affine_parameter = affine_transform_tf(sen_tensor, ref_tensor, 41 | affine_parameter, device) 42 | loss = ComputeLoss(ref_tensor, sen_tran_tensor, sen_tensor, ref_inv_tensor) 43 | Loss.append(loss) 44 | loss.backward() 45 | if loss < loss_0 and torch.isnan(loss).any() == False: 46 | count = 0 47 | sen_tran_tensor_save = sen_tran_tensor 48 | save_tensor_tf(sen_tran_tensor_save, save_sen_tran_name) 49 | loss_0 = loss 50 | if epoch>100: 51 | parameter_learn_rate = parameter_learn_rate*0.975 52 | optimizer = optim.SGD(model.parameters(), lr=parameter_learn_rate) 53 | 54 | if count > stop_iter: 55 | break 56 | optimizer.step() 57 | time_end = time.time() 58 | print('time cost', time_end - time_start, 's') 59 | show_registration_result(ref_img, sen_img, sen_tran_tensor_save) 60 | 61 | if __name__ == "__main__": 62 | ref_img = Image.open('../data/Optical-SAR/1_9_1.jpg') 63 | sen_img = Image.open('../data/Optical-SAR/1_9_3.jpg') 64 | ite(ref_img, sen_img) --------------------------------------------------------------------------------