├── pyqt
├── __init__.py
├── register_main.py
└── register_v1.py
├── datasets
├── __init__.py
├── dataset_process
│   ├── process_nirscene.py
│   ├── process_jgp_image.py
│   ├── process_mat_img.py
│   ├── Generator.py
│   └── export_image_pairs.py
└── provider
│   ├── pf_dataset.py
│   ├── singlechannelData.py
│   ├── test_dataset.py
│   ├── nirrgbData.py
│   └── randomTnsData.py
├── traditional_ntg
├── __init__.py
├── image_util.py
├── compute_image_pyramid.py
├── test_script.py
├── loss_function.py
└── estimate_affine_param.py
├── poster-compressed-public.jpg
├── visualization
├── mutual_info_cave_dict.mat
├── visual_table.py
├── matplot_tool.py
├── visual_mutual_info.py
├── train_visual.py
└── visual_result.py
├── util
├── time_util.py
├── matplot_util.py
├── eval_util.py
├── csv_opeartor.py
├── multi_gpu_util.py
├── interp.py
├── pytorchTcv.py
├── torch_util.py
└── train_test_fn.py
├── README.md
├── tnf_transform
├── point_tnf.py
└── transformation.py
├── evluate
├── eval_grid_loss.py
├── mutual_info_loss.py
├── cv2_ecc.py
├── evaluate_result.py
└── lossfunc.py
├── traditional_methods
└── orb
│   └── orb_alignment.py
├── ntg_pytorch
├── multispectral_pytorch_test.py
├── register_pyramid.py
├── register_loss.py
└── register_func.py
├── multispectral_pytorch_batch.py
├── main
├── test.py
├── eval_harvard_images.py
└── eval_cave_images_singlechannel.py
├── demo_video.ipynb
├── cnn_geometric
└── cnn_geometric_model.py
├── model
└── cnn_registration_model.py
└── eval_pf.py

/pyqt/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/traditional_ntg/__init__.py:
--------------------------------------------------------------------------------
 1 | 
--------------------------------------------------------------------------------
/poster-compressed-public.jpg:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/zhangliukun/registration_cnn_ntg/HEAD/poster-compressed-public.jpg
--------------------------------------------------------------------------------
/visualization/mutual_info_cave_dict.mat:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/zhangliukun/registration_cnn_ntg/HEAD/visualization/mutual_info_cave_dict.mat
--------------------------------------------------------------------------------
/util/time_util.py:
--------------------------------------------------------------------------------
 1 | import time
 2 | 
 3 | 
 4 | def calculate_diff_time(previous_time):
 5 |     '''
 6 |     :param previous_time: the earlier timestamp, as returned by time.time()
 7 |     :return: the elapsed time between then and now, in seconds
 8 |     '''
 9 |     elapsed = time.time() - previous_time
10 |     return elapsed
11 | 
--------------------------------------------------------------------------------
/util/matplot_util.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | 
 3 | 
 4 | def plot_line_chart(X,Y,title='Mutual_Information_Chart',color='b',label = 'data'):
 5 |     type = plt.plot(X,Y,color = color,label = label)
 6 |     # plt.xticks(X,X)
 7 | 
 8 |     # plt.legend(loc='best')
 9 |     return type
10 | 
11 | 
--------------------------------------------------------------------------------
/util/eval_util.py: 
--------------------------------------------------------------------------------
 1 | 
 2 | from sklearn import metrics
 3 | import torch
 4 | 
 5 | def calculate_mutual_info_batch(source_image_batch,target_image_batch):
 6 |     batch,c,h,w = source_image_batch.shape
 7 |     mutual_info_list = []
 8 |     for i in range(batch):
 9 |         mutual_info_list.append(metrics.mutual_info_score(
10 |             torch.flatten(source_image_batch[i]).cpu().detach().numpy(),
11 |             torch.flatten(target_image_batch[i]).cpu().detach().numpy()))
12 | 
13 |     return mutual_info_list
14 | 
--------------------------------------------------------------------------------
/util/csv_opeartor.py:
--------------------------------------------------------------------------------
 1 | import pandas as pd
 2 | import os
 3 | 
 4 | def read_csv_file(file_path):
 5 |     df = pd.read_csv(file_path)
 6 | 
 7 |     return df
 8 | 
 9 | 
10 | def extract_image_files(path):
11 |     img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif']
12 |     vid_formats = ['.mov', '.avi', '.mp4']
13 | 
14 |     with open(path,'r') as f:
15 |         img_files = [x.replace('/',os.sep) for x in f.read().splitlines()
16 |                      if os.path.splitext(x)[-1].lower() in img_formats]
17 | 
18 |     return img_files
--------------------------------------------------------------------------------
/pyqt/register_main.py:
--------------------------------------------------------------------------------
 1 | 
 2 | # import the modules required to run the program
 3 | import sys
 4 | # the basic PyQt5 widgets live in the PyQt5.QtWidgets module
 5 | from PyQt5.QtWidgets import QApplication, QMainWindow
 6 | # import the login module generated by the Qt Designer tool
 7 | from pyqt.register import Ui_Dialog
 8 | 
 9 | class MyMainForm(QMainWindow, Ui_Dialog):
10 |     def __init__(self, parent=None):
11 |         super(MyMainForm, self).__init__(parent)
12 |         self.setupUi(self)
13 | 
14 | if __name__ == "__main__":
15 |     # boilerplate: every PyQt5 program needs a QApplication object; sys.argv is the command-line argument list and makes the program launchable by double click
16 |     app = QApplication(sys.argv)
17 |     # initialization
18 |     myWin = MyMainForm()
19 |     # show the window widgets on screen
20 |     myWin.show()
21 |     # run the app; sys.exit makes sure the program exits cleanly
22 |     sys.exit(app.exec_())
23 | 
--------------------------------------------------------------------------------
/util/multi_gpu_util.py:
--------------------------------------------------------------------------------
 1 | 
 2 | #!/usr/bin/env python
 3 | import os
 4 | import torch
 5 | import torch.distributed as dist
 6 | from torch.multiprocessing import Process
 7 | 
 8 | def run(rank, size):
 9 |     """ Distributed function to be implemented later. """
10 |     pass
11 | 
12 | '''
13 | This function makes sure every process can locate the master through the same IP address and port. In essence it
14 | allows the processes to communicate by sharing a rendezvous location.
15 | '''
16 | def init_process(rank, size, fn, backend='gloo'):
17 |     """ Initialize the distributed environment. """
""" 18 | os.environ['MASTER_ADDR'] = '127.0.0.1' 19 | os.environ['MASTER_PORT'] = '29500' 20 | dist.init_process_group(backend, rank=rank, world_size=size) 21 | fn(rank, size) 22 | 23 | 24 | 25 | ''' 26 | 这个脚本产生了两个进程,每个设置了分布式的环境,初始化进程组,最后执行run函数 27 | ''' 28 | if __name__ == "__main__": 29 | size = 2 30 | processes = [] 31 | for rank in range(size): 32 | p = Process(target=init_process, args=(rank, size, run)) 33 | p.start() 34 | processes.append(p) 35 | 36 | for p in processes: 37 | p.join() -------------------------------------------------------------------------------- /util/interp.py: -------------------------------------------------------------------------------- 1 | from scipy import interpolate 2 | import skimage.io as io 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from PIL import Image 6 | 7 | from ntg_pytorch.register_func import scale_image 8 | 9 | def readImages(u1,v1): 10 | try: 11 | u = np.asarray(Image.open(u1).convert('L'), dtype=np.float64) 12 | u=u/np.max(u) 13 | v = np.asarray(Image.open(v1).convert('L'), dtype=np.float64) 14 | v=v/np.max(v) 15 | return (u,v) 16 | except IOError: 17 | return -1 18 | 19 | def interpolation(v): 20 | k=np.shape(v)[0] 21 | k1=np.shape(v)[1] 22 | return interpolate.interp2d(np.arange(k),np.arange(k1),v.T,kind='cubic',fill_value=0) 23 | 24 | 25 | if __name__ == '__main__': 26 | # img1 = io.imread('../datasets/row_data/multispectral/mul_1s_s.png') 27 | # img2 = io.imread('../datasets/row_data/multispectral/mul_1t_s.png') 28 | 29 | img1 = '../datasets/row_data/multispectral/mul_1s_s.png' 30 | img2 = '../datasets/row_data/multispectral/mul_1t_s.png' 31 | 32 | u,v = readImages(img1,img2) 33 | 34 | plt.figure() 35 | plt.imshow(v,cmap='gray') 36 | 37 | image_data = interpolation(v) 38 | 39 | image_data_z = np.reshape(image_data.z,u.shape) 40 | 41 | plt.figure() 42 | plt.imshow(image_data_z,cmap='gray') 43 | plt.show() -------------------------------------------------------------------------------- /datasets/dataset_process/process_nirscene.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | 5 | def copy_file(source_path,target_path,filter): 6 | 7 | if not os.path.exists(target_path): 8 | os.makedirs(target_path) 9 | 10 | num = 0 11 | 12 | class_dir_list = sorted(os.listdir(source_path)) 13 | 14 | for class_dir in class_dir_list: 15 | class_dir_path = os.path.join(source_path,class_dir) 16 | class_image_list = sorted(os.listdir(class_dir_path)) 17 | print(class_dir) 18 | for class_image in class_image_list: 19 | if filter in class_image: 20 | fp = os.path.join(class_dir_path,class_image) 21 | if num % 100 == 0: 22 | print(class_image) 23 | newfp = os.path.join(target_path,'%04d'%num+".tiff") 24 | shutil.copy(fp,newfp) 25 | num += 1 26 | 27 | print("移动完成") 28 | 29 | if __name__ == '__main__': 30 | 31 | nirscene_path = "/mnt/4T/zlk/datasets/mulitspectral/nirscene1" 32 | nir_target_path = "/mnt/4T/zlk/datasets/mulitspectral/nirscene_total/nir_image" 33 | rgb_target_path = "/mnt/4T/zlk/datasets/mulitspectral/nirscene_total/rgb_image" 34 | 35 | copy_file(nirscene_path,nir_target_path,'nir') 36 | copy_file(nirscene_path,rgb_target_path,'rgb') 37 | 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Registration_CNN_NTG 2 | 3 | ## Paper 4 | - "A Multispectral Image Registration Method Based on Unsupervised Learning"(基于无监督学习的多光谱图像配准) 
 5 | 
 6 | ## Poster
 7 | To keep the page loading quickly, the poster image is provided as a separate file (poster-compressed-public.jpg).
 8 | 
 9 | ## Train
10 | An average training run takes about 6 hours.
11 | - Train.py: The geometric model we use is the affine model, and the training data is generated online through random affine params.
12 | - model/cnn_registration_model.py: the core network of our method, which contains the feature extraction, feature matching, and feature regression stages.
13 | - ntg_pytorch/register_loss.py: This file contains our unsupervised loss function, which was first proposed in [this paper](https://www.researchgate.net/publication/321231034_Normalized_Total_Gradient_A_New_Measure_for_Multispectral_Image_Registration).
14 | 
15 | ## Test
16 | - multispectral_pytorch_batch.py: We use a two-stage registration process to achieve sub-pixel accuracy. First, the deep model estimates the rough affine params; then the traditional NTG method refines them.
17 | 
18 | ## Visualization
19 | For visualization purposes, we also provide a PyQt client so our method can be tried quickly.
20 | 
21 | ## FAQ
22 | If you have other questions, welcome to submit issues.
--------------------------------------------------------------------------------
/datasets/dataset_process/process_jgp_image.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import torch
 4 | from skimage import io
 5 | import matplotlib.pyplot as plt
 6 | import numpy as np
 7 | 
 8 | from traditional_ntg.image_util import scale_image
 9 | 
10 | if __name__ == '__main__':
11 | 
12 |     dir_path = '/Users/zale/Downloads/complete_ms_data/balloons_ms/balloons_ms/'
13 |     # dir_path = '/Users/zale/project/datasets/VOC_parts/'
14 |     name_list = os.listdir(dir_path)
15 |     print(name_list)
16 | 
17 |     for name in name_list:
18 |         if str(name.split('.')[-1]) != 'png':
19 |             continue
20 |         image_path = os.path.join(dir_path,name)
21 |         image_data = io.imread(image_path)
22 |         # print(image_path)
23 |         # print(image_data)
24 |         #np.histogram(image_data[-1],10)
25 |         IMIN = np.min(image_data)
26 |         IMAX = np.max(image_data)
27 |         # source_batch_max = torch.max(source_batch.view(batch_size, 1, -1), 2)[0].unsqueeze(2).unsqueeze(2)
28 |         # source_batch_min = torch.min(source_batch.view(batch_size, 1, -1), 2)[0].unsqueeze(2).unsqueeze(2)
29 | 
30 |         image_data = scale_image(image_data, IMIN, IMAX)
31 | 
32 | 
33 |         plt.figure()
34 |         plt.hist(image_data[-1],bins=20)
35 |         # plt.figure()
36 |         # plt.imshow(image_data,cmap='gray')
37 | 
38 | 
39 |         plt.show()
--------------------------------------------------------------------------------
/visualization/visual_table.py:
--------------------------------------------------------------------------------
 1 | from visualization.train_visual import VisdomHelper
 2 | import numpy as np
 3 | 
 4 | 
 5 | def test_visdom_line(vis):
 6 |     x_list = [i for i in range(10)]
 7 | 
 8 |     A_list = [i + 14 for i in range(10)]
 9 |     B_list = [0 + 15 for i in range(10)]
10 |     C_list = [0 + 16 for i in range(10)]
11 |     D_list = [0 + 11 for i in range(10)]
12 | 
13 |     # print(x_list,y_list)
14 | 
15 |     vis.drawGridlossGroup(x_list, A_list, B_list, C_list, D_list, layout_title='line')
16 |     vis.getVisdom().line(X=np.column_stack((x_list,x_list)),
17 |                          Y =np.column_stack((B_list,A_list)))
18 | 
19 | def test_bar(vis):
20 |     x_list = [i for i in range(10)]
21 | 
22 |     A_list = [i + 14 for i in range(10)]
23 |     B_list = [0 + 15 for i in range(10)]
24 |     C_list = [0 + 16 for i in range(10)]
25 |     D_list = [0 + 11 for i in range(10)]
26 | 
27 | 
28 |     # vis.getVisdom().bar(
29 |     #     X=np.column_stack((A_list,B_list, C_list,D_list)),
30 |     #     opts=dict(
31 |     #         stacked=False,
32 |     #         legend=['The Netherlands', 'France', 'United States','sdfsd']
33 |     #     )
34 |     # )
35 |     vis.drawGridlossBar(x_list,A_list,B_list,C_list,D_list,layout_title='Grid_loss_histogram')
36 | 
37 | 
38 | if __name__ == '__main__':
39 | 
40 |     env = 'DMN_test'
41 |     vis = VisdomHelper(env)
42 | 
43 |     test_bar(vis)
44 | 
45 | 
46 | 
47 | 
--------------------------------------------------------------------------------
/traditional_ntg/image_util.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | # convert an image from RGB to grayscale
 4 | import torch
 5 | 
 6 | 
 7 | def rgb2gray(rgb):
 8 |     return np.dot(rgb[...,:3],[0.299, 0.587, 0.114]) # weights for the R, G, B channels respectively
 9 | 
10 | # normalize an image
11 | def scale_image(img,IMIN,IMAX):
12 |     return (img-IMIN)/(IMAX-IMIN)
13 | 
14 | 
15 | def symmetricImagePad(image_batch,padding_factor=0.6,use_cuda = False):
16 |     '''
17 |     Pad the image by mirror reflection of its borders, first left/right then top/bottom: select the border strips and join them with torch.cat
18 |     :param image_batch: image batch of shape b, c, h, w
19 |     :param padding_factor: fraction of h/w to mirror-pad on each side
20 |     :param use_cuda:
21 |     :return:
22 |     '''
23 |     b, c, h, w = image_batch.size()
24 |     pad_h, pad_w = int(h * padding_factor), int(w * padding_factor)
25 |     idx_pad_left = torch.LongTensor(range(pad_w - 1, -1, -1))
26 |     idx_pad_right = torch.LongTensor(range(w - 1, w - pad_w - 1, -1))
27 |     idx_pad_top = torch.LongTensor(range(pad_h - 1, -1, -1))
28 |     idx_pad_bottom = torch.LongTensor(range(h - 1, h - pad_h - 1, -1))
29 |     if use_cuda:
30 |         idx_pad_left = idx_pad_left.cuda()
31 |         idx_pad_right = idx_pad_right.cuda()
32 |         idx_pad_top = idx_pad_top.cuda()
33 |         idx_pad_bottom = idx_pad_bottom.cuda()
34 |     image_batch = torch.cat((image_batch.index_select(3,idx_pad_left),image_batch,image_batch.index_select(3,idx_pad_right)),3)
35 |     image_batch = torch.cat((image_batch.index_select(2, idx_pad_top), image_batch,image_batch.index_select(2, idx_pad_bottom)), 2)
36 |     return image_batch
--------------------------------------------------------------------------------
/traditional_ntg/compute_image_pyramid.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import cv2
 3 | import scipy.ndimage
 4 | import matplotlib.pyplot as plt
 5 | import scipy.misc as smi
 6 | import torch
 7 | import torch.nn.functional as F
 8 | from PIL import Image
 9 | from PIL.Image import BICUBIC
10 | 
11 | 
12 | def compute_image_pyramid(image,f,nL,ration):
13 | 
14 |     P = []
15 |     tmp = image
16 |     dstImg = ''
17 |     P.append(tmp)
18 |     #f = np.stack((f,f),2)
19 | 
20 |     for m in range(1,nL):
21 |         # dstSize = np.round([tmp.shape[0]*ration,tmp.shape[1]*ration])
22 |         # dstSize = tuple((int(dstSize[0]),int(dstSize[1])))
23 |         # dstImg = np.zeros(dstSize)
24 |         #tmp = cv2.pyrDown(tmp,dstsize=dstSize)
25 |         #tmp = cv2.pyrDown(tmp)
26 | 
27 |         # gaussian_kernel = f.expand(1,1,3,3)
28 |         # gaussian_kernel = torch.nn.Parameter(data=gaussian_kernel,requires_grad=False)
29 |         # I_gauss = F.conv2d(tmp.unsqueeze(0),gaussian_kernel)
30 | 
31 |         tmp = cv2.filter2D(tmp,-1,f)
32 | 
33 |         # resize images with skimage: https://scikit-image.org/docs/stable/auto_examples/transform/plot_rescale.html
34 | 
35 |         im1 = np.array(Image.fromarray(tmp[:,:,0]).resize((int(tmp[:,:,0].shape[0]*ration),int(tmp[:,:,0].shape[1]*ration)),resample=BICUBIC))
36 |         im2 = np.array(Image.fromarray(tmp[:,:,1]).resize((int(tmp[:,:,1].shape[0]*ration),int(tmp[:,:,1].shape[1]*ration)),resample=BICUBIC))
37 | 
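        # (note) each pyramid level is built by Gaussian smoothing (cv2.filter2D with kernel f)
        # followed by per-channel bicubic resizing by `ration`; the commented smi.imresize
        # lines below are the older scipy.misc equivalent of the same step.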
38 | # im1 = smi.imresize(tmp[:,:,0],size=ration)/255.0 39 | # im2 = smi.imresize(tmp[:,:,1],size=ration)/255.0 40 | tmp = np.stack((im1,im2),2) 41 | 42 | P.append(tmp) 43 | 44 | 45 | 46 | return P -------------------------------------------------------------------------------- /tnf_transform/point_tnf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class PointTnf: 6 | def __init__(self,use_cuda = True): 7 | self.use_cuda = use_cuda 8 | 9 | def tpsPointTnf(self,theta,points): 10 | # points are expected in [B,2,N], where first row is X and second row is Y 11 | # reshape points for applying Tps transformation 12 | points=points.unsqueeze(3).transpose(1,3) 13 | # apply transformation 14 | warped_points = self.tpsTnf.apply_transformation(theta,points) 15 | # undo reshaping 16 | warped_points=warped_points.transpose(3,1).squeeze(3) 17 | return warped_points 18 | 19 | def affPointTnf(self,theta,points): 20 | theta_mat = theta.view(-1,2,3) 21 | warped_points = torch.bmm(theta_mat[:,:,:2],points) 22 | warped_points += theta_mat[:,:,2].unsqueeze(2).expand_as(warped_points) 23 | return warped_points 24 | 25 | 26 | def PointsToUnitCoords(P,im_size): 27 | h,w = im_size[:,0],im_size[:,1] 28 | NormAxis = lambda x,L: (x-1-(L-1)/2)*2/(L-1) 29 | P_norm = P.clone() 30 | # normalize Y 31 | P_norm[:,0,:] = NormAxis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 32 | # normalize X 33 | P_norm[:,1,:] = NormAxis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 34 | return P_norm 35 | 36 | def PointsToPixelCoords(P,im_size): 37 | h,w = im_size[:,0],im_size[:,1] 38 | NormAxis = lambda x,L: x*(L-1)/2+1+(L-1)/2 39 | P_norm = P.clone() 40 | # normalize Y 41 | P_norm[:,0,:] = NormAxis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 42 | # normalize X 43 | P_norm[:,1,:] = NormAxis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 44 | return P_norm -------------------------------------------------------------------------------- /evluate/eval_grid_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from evluate.lossfunc import GridLoss 4 | import numpy as np 5 | 6 | from util.pytorchTcv import param2theta 7 | import scipy.io as scio 8 | 9 | if __name__ == '__main__': 10 | # grid_fun = GridLoss(use_cuda=False,grid_size=512) 11 | # 12 | # theta_gt = np.array([[[ 1.0833,0.1910,-80.2212], 13 | # [ -0.1910,1.0833,37.5775]]]) 14 | # theta_gt_big = np.array([[[1.08253175,0.625,-201.12812921], 15 | # [-0.625,1.08253175,158.87187079]]]) 16 | # 17 | # theta_target1 = np.array([[[1.0443,0.1159, -46.4946], 18 | # [-0.1619, 1.1374, 12.6129]]]) 19 | # 20 | # theta_target1 = np.array([[[1.0863,0.2191,-90.8348], 21 | # [-0.1884,1.0819,36.5699]]]) 22 | # 23 | # 24 | # theta_gt = torch.from_numpy(theta_gt).float() 25 | # theta_gt_big = torch.from_numpy(theta_gt_big).float() 26 | # 27 | # theta_gt_pytorch = param2theta(theta_gt, 512, 512, use_cuda=False) 28 | # theta_gt_big_pytorch = param2theta(theta_gt_big, 512, 512, use_cuda=False) 29 | # 30 | # # theta_target = torch.from_numpy(theta_target2).float() 31 | # theta_target1 = torch.from_numpy(theta_target1).float() 32 | # 33 | # theta_target1_pytorch = param2theta(theta_target1, 512, 512, use_cuda=False) 34 | # 35 | # grid_loss = grid_fun.compute_grid_loss(theta_target1,theta_gt) 36 | # grid_loss_pytorch = grid_fun.compute_grid_loss(theta_target1_pytorch,theta_gt_pytorch) 37 | # # 38 | # print(grid_loss) 39 | # print(grid_loss_pytorch) 40 | dict = {} 
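    # quick demo (separate from the commented grid-loss experiment above):
    # pack two lists into a dict and save them to a .mat file with scipy.io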
41 |     alist = [1,2,3,4,5]
42 |     blist = [11,2,23,23,23]
43 |     dict['a'] = alist
44 |     dict['b'] = blist
45 |     scio.savemat('test.mat',dict)
46 | 
47 | 
48 | 
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/datasets/dataset_process/process_mat_img.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import scipy.io as scio
 3 | import skimage.io as io
 4 | from PIL import Image
 5 | import numpy as np
 6 | import matplotlib.pyplot as plt
 7 | import os
 8 | 
 9 | from ntg_pytorch.register_func import scale_image
10 | 
11 | def process_mat_img():
12 |     image_folder = '/Users/zale/project/datasets/Harvard/'
13 |     mat_image_name_list = sorted(os.listdir(image_folder))
14 | 
15 |     count = 0
16 |     for item in mat_image_name_list:
17 |         mat_image_path = os.path.join(image_folder,item)
18 |         print(mat_image_path)
19 |         array_struct = scio.loadmat(mat_image_path)
20 |         array_data = array_struct['ms_image_denoised']
21 |         array_data_1 = array_data[:,:,16]
22 |         IMAX = np.max(array_data_1)
23 |         IMIN = np.min(array_data_1)
24 |         I_mean = scale_image(array_data_1,IMIN,IMAX)
25 |         count += 1
26 | 
27 |         if count%10 == 0:
28 |             break
29 |         # plt.figure()
30 |         # plt.imshow(I_mean,cmap='gray')
31 |         plt.figure()
32 |         plt.hist(I_mean[-1],bins=20)
33 |     plt.show()
34 | 
35 | def process_cave_image():
36 | 
37 |     output_folder = '/Users/zale/project/datasets/complete_ms_data_mat/'
38 | 
39 |     file_folder = '/Users/zale/project/datasets/complete_ms_data/'
40 |     category_name_list = sorted(os.listdir(file_folder))
41 |     for i,item in enumerate(category_name_list):
42 |         if 'ms' not in str(item):
43 |             continue
44 |         image_list = []
45 |         category_folder = os.path.join(file_folder,item+'/',item+'/')
46 |         print(category_folder)
47 |         image_name_list = sorted(os.listdir(category_folder))
48 |         for image_name in image_name_list:
49 |             if image_name.split('.')[1] == 'png':
50 |                 image_array = io.imread(os.path.join(category_folder,image_name))
51 |                 # the last category (watercolors) has 4 channels, so handle that case for compatibility
52 |                 if len(image_array.shape)>2:
53 |                     image_array = image_array[:,:,0]
54 |                 image_list.append(image_array)
55 |         image_batch = np.array(image_list).transpose((1,2,0))
56 |         print(i,image_batch.shape)
57 |         scio.savemat(os.path.join(output_folder,str(item)+'.mat'),{'cave_mat':image_batch})
58 | 
59 | 
60 | 
61 | 
62 | 
63 | if __name__ == '__main__':
64 | 
65 |     process_cave_image()
66 | 
--------------------------------------------------------------------------------
/util/pytorchTcv.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import torch
 3 | 
 4 | 
 5 | # convert OpenCV affine parameters to PyTorch (affine_grid) parameters
 6 | def param2theta(param, h, w,use_cuda=True):
 7 |     '''
 8 |     :param param: [batch,2,3]
 9 |     :param w:
10 |     :param h:
11 |     :param use_cuda:
12 |     :return: theta [batch,2,3]
13 |     '''
14 |     if use_cuda:
15 |         param = param.cuda()
16 |     third_row = torch.zeros((param.shape[0],1,3))
17 |     theta = torch.zeros((param.shape[0], 2, 3))
18 |     if use_cuda:
19 |         third_row = third_row.cuda()
20 |         theta = theta.cuda()
21 | 
22 |     third_row[:,:,2] = 1
23 |     square_matrix = torch.cat((param,third_row),1)
24 | 
25 |     inverse_matrix = torch.inverse(square_matrix)
26 | 
27 |     theta[:,0, 0] = inverse_matrix[:,0, 0]
28 |     theta[:,0, 1] = inverse_matrix[:,0, 1] * h / w
29 |     theta[:,0, 2] = inverse_matrix[:,0, 2] * 2 / w + theta[:,0, 0] + theta[:,0, 1] - 1
30 |     theta[:,1, 0] = inverse_matrix[:,1, 0] * w / h
31 |     theta[:,1, 1] = inverse_matrix[:,1, 1]
32 |     theta[:,1, 2] = inverse_matrix[:,1, 2] * 2 / h + theta[:,1, 0] + theta[:,1, 1] - 1
33 |     return theta
34 | 
35 | # convert PyTorch affine parameters back to OpenCV parameters
36 | def theta2param(theta,w,h,use_cuda=True):
37 |     '''
38 |     :param theta: [batch,2,3]
39 |     :param w:
40 |     :param h:
41 |     :param use_cuda:
42 |     :return: opencv_param [batch,2,3]
43 |     '''
44 |     param = torch.zeros((theta.shape[0],2,3))
45 |     third_row = torch.zeros((theta.shape[0],1,3))
46 |     if use_cuda:
47 |         third_row = third_row.cuda()
48 |         param = param.cuda()
49 | 
50 |     third_row[:,:,2] = 1
51 |     param[:,0,0] = theta[:,0,0]
52 |     param[:,0,1] = theta[:,0,1] * w / h
53 |     param[:,0,2] = (-theta[:,0,0]-theta[:,0,1]+theta[:,0,2]+1) * w / 2
54 |     param[:,1,0] = theta[:,1,0] * h / w
55 |     param[:,1,1] = theta[:,1,1]
56 |     param[:,1,2] = (-theta[:,1,0]-theta[:,1,1]+theta[:,1,2]+1) * h / 2
57 | 
58 |     square_matrix = torch.cat((param,third_row),1)
59 |     opencv_param = torch.inverse(square_matrix)[:,0:2,:]
60 |     return opencv_param
61 | 
62 | def inverse_theta(theta,use_cuda=True):
63 |     if not isinstance(theta, torch.Tensor):
64 |         theta = torch.from_numpy(theta).float()
65 |     third_row = torch.zeros((theta.shape[0], 1, 3))
66 |     if use_cuda:
67 |         third_row = third_row.cuda()
68 | 
69 |     third_row[:, :, 2] = 1
70 |     square_matrix = torch.cat((theta,third_row),1)
71 |     opencv_param = torch.inverse(square_matrix)[:,0:2,:]
72 |     return opencv_param
73 | 
--------------------------------------------------------------------------------
/evluate/mutual_info_loss.py:
--------------------------------------------------------------------------------
 1 | # reference: https://matthew-brett.github.io/teaching/mutual_information.html
 2 | 
 3 | # compatibility with python2
 4 | from __future__ import print_function  # the print function
 5 | from __future__ import division  # so that 1/2 == 0.5 instead of 0
 6 | 
 7 | # import common modules
 8 | import numpy as np
 9 | import matplotlib.pyplot as plt
10 | from skimage import io
11 | 
12 | # set gray colormap and nearest neighbor interpolation by default
13 | plt.rcParams['image.cmap'] = 'gray'
14 | plt.rcParams['image.interpolation'] = 'nearest'
15 | 
16 | # note: when running in IPython, use %matplotlib to enable interactive plots; when
17 | # running in a Jupyter Notebook, use %matplotlib inline
18 | 
19 | image1 = io.imread("/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/multispectral/It.jpg")
20 | image2 = io.imread("/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/multispectral/It180.jpg")
21 | 
22 | # t1_slice = np.array([11,21,31,41,54,61,71,81,91,101,112,123,131,145,159])
23 | # t2_slice = np.array([15,25,34,45,56,67,77,86,95,108,111,125,136,145,155])
24 | 
25 | t1_slice = image1[:,:,0]
26 | t2_slice = image2[:,:,1]
27 | #t2_slice = image1[:,:,1]
28 | 
29 | 
30 | plt.imshow(np.hstack((t1_slice,t2_slice)))
31 | 
32 | # plt.figure()
33 | fig,axes = plt.subplots(1,2)
34 | axes[0].hist(t1_slice.ravel(),bins=20)
35 | axes[0].set_title('t1 hist')
36 | 
37 | axes[1].hist(t2_slice.ravel(),bins=20)
38 | axes[1].set_title('t2 hist')
39 | # plt.show()
40 | 
41 | plt.figure()
42 | # plt.plot(t1_slice.ravel(),t2_slice.ravel(),'.')
43 | # plt.xlabel('t1 signal')
44 | # plt.ylabel('t2 signal')
45 | # plt.title('t1 vs t2 signal')
46 | hist_2d,x_edges,y_edges = np.histogram2d(t1_slice.ravel(),t2_slice.ravel(),bins=20)
47 | plt.imshow(hist_2d.T,origin='lower')
48 | plt.xlabel('T1 signal bin')
49 | plt.ylabel('T2 signal bin')
50 | 
51 | 
52 | # Show log histogram, avoiding divide by 0
53 | plt.figure()
54 | hist_2d_log = np.zeros(hist_2d.shape)
55 | non_zeros = hist_2d != 0
56 | hist_2d_log[non_zeros] = np.log(hist_2d[non_zeros])
57 | plt.imshow(hist_2d_log.T, origin='lower')
58 | plt.xlabel('T1 signal nonzero bin')
59 | plt.ylabel('T2 signal nonzero bin')
60 | 
61 | plt.show()
62 | 
63 | def mutual_information(hgram):
64 |     """ Mutual information for joint histogram
65 |     """
66 |     # Convert bins counts to probability values
67 |     pxy = hgram / float(np.sum(hgram))
68 |     px = np.sum(pxy, axis=1) # marginal for x over y
69 |     py = np.sum(pxy, axis=0) # marginal for y over x
70 |     px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals
71 |     # Now we can do the calculation using the pxy, px_py 2D arrays
72 |     nzs = pxy > 0 # Only non-zero pxy values contribute to the sum
73 |     return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs]))
74 | 
75 | print(mutual_information(hist_2d))
--------------------------------------------------------------------------------
/traditional_methods/orb/orb_alignment.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | 
 4 | MAX_FEATURES = 500
 5 | GOOD_MATCH_PERCENT = 0.15
 6 | 
 7 | 
 8 | def alignImages(im1, im2):
 9 |     # Convert images to grayscale
10 |     im1Gray = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
11 |     im2Gray = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)
12 | 
13 |     # Detect ORB features and compute descriptors.
14 |     orb = cv2.ORB_create(MAX_FEATURES)
15 |     keypoints1, descriptors1 = orb.detectAndCompute(im1Gray, None)
16 |     keypoints2, descriptors2 = orb.detectAndCompute(im2Gray, None)
17 | 
18 |     # Match features.
19 |     matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING)
20 |     matches = matcher.match(descriptors1, descriptors2, None)
21 | 
22 |     # Sort matches by score
23 |     matches.sort(key=lambda x: x.distance, reverse=False)
24 | 
25 |     # Remove not so good matches
26 |     numGoodMatches = int(len(matches) * GOOD_MATCH_PERCENT)
27 |     matches = matches[:numGoodMatches]
28 | 
29 |     # Draw top matches
30 |     imMatches = cv2.drawMatches(im1, keypoints1, im2, keypoints2, matches, None)
31 |     cv2.imwrite("matches.jpg", imMatches)
32 | 
33 |     # Extract location of good matches
34 |     points1 = np.zeros((len(matches), 2), dtype=np.float32)
35 |     points2 = np.zeros((len(matches), 2), dtype=np.float32)
36 | 
37 |     for i, match in enumerate(matches):
38 |         points1[i, :] = keypoints1[match.queryIdx].pt
39 |         points2[i, :] = keypoints2[match.trainIdx].pt
40 | 
41 |     # Find homography
42 |     # h, mask = cv2.findHomography(points1, points2, cv2.RANSAC)
43 |     # M = cv2.getAffineTransform(points1, points2)  # needs exactly 3 point pairs
44 |     M = cv2.estimateAffinePartial2D(points1, points2)
45 | 
46 |     # Apply the estimated transform
47 |     height, width, channels = im2.shape
48 |     # im1Reg = cv2.warpPerspective(im1, h, (width, height))
49 |     im1Reg = cv2.warpAffine(im1,M[0],(width,height))
50 | 
51 |     # return im1Reg, h
52 |     return im1Reg, M
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     # Read reference image
57 |     refFilename = "../images/keypoint.jpg"
58 |     print("Reading reference image : ", refFilename)
59 |     imReference = cv2.imread(refFilename, cv2.IMREAD_COLOR)
60 | 
61 |     # Read image to be aligned
62 |     imFilename = "../images/keypoint2.jpg"
63 |     print("Reading image to align : ", imFilename)
64 |     im = cv2.imread(imFilename, cv2.IMREAD_COLOR)
65 | 
66 |     print("Aligning images ...")
67 |     # Registered image will be stored in imReg.
68 |     # The estimated affine transform will be stored in h.
69 |     imReg, h = alignImages(im, imReference)
70 | 
71 |     # Write aligned image to disk.
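    # (note) h here is the raw return value of cv2.estimateAffinePartial2D, i.e. a
    # (2x3 affine matrix, inlier mask) tuple; that is why alignImages applies
    # M[0] in cv2.warpAffine above.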
72 |     outFilename = "aligned.jpg"
73 |     print("Saving aligned image : ", outFilename)
74 |     cv2.imwrite(outFilename, imReg)
75 | 
76 |     # Print estimated affine transform
77 |     print("Estimated affine transform : \n", h)
--------------------------------------------------------------------------------
/evluate/cv2_ecc.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/python
 2 | """
 3 | OpenCV Image Alignment Example
 4 | Copyright 2015 by Satya Mallick
 5 | """
 6 | 
 7 | import cv2
 8 | import numpy as np
 9 | import time
10 | import skimage.io as io
11 | 
12 | def estimate_affine_ecc(im1,im2):
13 |     # Read the images to be aligned
14 |     # im1 = cv2.imread("image1.jpg")
15 |     # im2 = cv2.imread("image2.jpg")
16 | 
17 |     # im2 = cv2.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_31.png')
18 |     # im1 = cv2.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_17.png')
19 | 
20 |     im1_gray = cv2.cvtColor(np.asarray(im1), cv2.COLOR_RGB2GRAY)
21 |     im2_gray = cv2.cvtColor(np.asarray(im2), cv2.COLOR_RGB2GRAY)
22 | 
23 |     # Convert images to grayscale
24 |     # im1_gray = cv2.cvtColor(im1, cv2.COLOR_BGR2GRAY)
25 |     # im2_gray = cv2.cvtColor(im2, cv2.COLOR_BGR2GRAY)
26 | 
27 |     # Find size of image1
28 |     im1_size = im1.shape
29 | 
30 |     # Define the motion model
31 |     # warp_mode = cv2.MOTION_TRANSLATION
32 |     warp_mode = cv2.MOTION_AFFINE
33 |     # warp_mode = cv2.MOTION_EUCLIDEAN
34 | 
35 |     # Define 2x3 or 3x3 matrices and initialize the matrix to identity
36 |     if warp_mode == cv2.MOTION_HOMOGRAPHY:
37 |         warp_matrix = np.eye(3, 3, dtype=np.float32)
38 |     else:
39 |         warp_matrix = np.eye(2, 3, dtype=np.float32)
40 | 
41 |     # Specify the number of iterations.
42 |     number_of_iterations = 5000
43 | 
44 |     # Specify the threshold of the increment
45 |     # in the correlation coefficient between two iterations
46 |     termination_eps = 1e-10
47 | 
48 |     # Define termination criteria
49 |     criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
50 | 
51 |     # Enhanced Correlation Coefficient (ECC)
52 |     # Run the ECC algorithm. The results are stored in warp_matrix.
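    # (note) the trailing arguments None, 5 in the findTransformECC call below are
    # inputMask and gaussFiltSize; this seven-argument form needs OpenCV 4.x.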
53 | start = time.time() 54 | (cc, warp_matrix) = cv2.findTransformECC(im1_gray, im2_gray, warp_matrix, warp_mode, criteria, None, 5) 55 | 56 | if warp_mode == cv2.MOTION_HOMOGRAPHY: 57 | # Use warpPerspective for Homography 58 | im2_aligned = cv2.warpPerspective(im2, warp_matrix, (im1_size[1], im1_size[0]), 59 | flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP) 60 | else: 61 | # Use warpAffine for Translation, Euclidean and Affine 62 | im2_aligned = cv2.warpAffine(im2, warp_matrix, (im1_size[1], im1_size[0]), 63 | flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP) 64 | 65 | end = time.time() 66 | print('Alignment time (s): ', end - start) 67 | # print(warp_matrix) 68 | # # Show final results 69 | # cv2.imshow("Image 1", im1) 70 | # cv2.imshow("Image 2", im2) 71 | # cv2.imshow("Aligned Image 2", im2_aligned) 72 | # cv2.waitKey(0) 73 | 74 | 75 | if __name__ == '__main__': 76 | im1 = io.imread('../datasets/row_data/multispectral/mul_1t_s.png') 77 | im2 = io.imread('../datasets/row_data/multispectral/mul_1s_s.png') 78 | 79 | estimate_affine_ecc(im1,im2) 80 | -------------------------------------------------------------------------------- /util/torch_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from os import makedirs, remove 4 | from os.path import exists, join, basename, dirname 5 | 6 | import numpy as np 7 | import shutil 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.distributed as dist 11 | 12 | 13 | def init_seeds(seed=0): 14 | torch.cuda.empty_cache() 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | # torch.backends.cudnn.deterministic = True # https://pytorch.org/docs/stable/notes/randomness.html 19 | 20 | def select_device(multi_process= False,force_cpu=False, apex=False): 21 | # apex if mixed precision training https://github.com/NVIDIA/apex 22 | cuda = False if force_cpu else torch.cuda.is_available() 23 | 24 | local_rank = 0 25 | 26 | if multi_process: 27 | local_rank = dist.get_rank() 28 | torch.cuda.set_device(local_rank) 29 | device = torch.device('cuda', local_rank) 30 | else: 31 | device = torch.device('cuda:0' if cuda else 'cpu') 32 | 33 | #device = torch.device('cuda:0' if cuda else 'cpu') 34 | 35 | if not cuda: 36 | print('Using CPU') 37 | if cuda: 38 | torch.backends.cudnn.benchmark = True # set False for reproducible results 39 | c = 1024 ** 2 # bytes to MB 40 | ng = torch.cuda.device_count() 41 | x = [torch.cuda.get_device_properties(i) for i in range(ng)] 42 | cuda_str = 'Using CUDA ' + ('Apex ' if apex else '') 43 | for i in range(0, ng): 44 | if i == 1: 45 | # torch.cuda.set_device(0) # OPTIONAL: Set GPU ID 46 | cuda_str = ' ' * len(cuda_str) 47 | print("%sdevice%g _CudaDeviceProperties(name='%s', total_memory=%dMB)" % 48 | (cuda_str, i, x[i].name, x[i].total_memory / c)) 49 | 50 | print('') # skip a line 51 | return device,local_rank 52 | 53 | def save_checkpoint(state, is_best, file): 54 | model_dir = dirname(file) 55 | model_fn = basename(file) 56 | # make dir if needed (should be non-empty) 57 | if model_dir!='' and not exists(model_dir): 58 | makedirs(model_dir) 59 | # 保存模型 60 | torch.save(state, file) 61 | # 如果模型损失是最低的,则拷贝一份 62 | if is_best: 63 | shutil.copyfile(file, join(model_dir,'best_'+model_fn)) 64 | 65 | 66 | def str_to_bool(v): 67 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 68 | return True 69 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 70 | return False 71 | else: 72 | raise 
argparse.ArgumentTypeError('Boolean value expected.') 73 | 74 | 75 | class BatchTensorToVars(object): 76 | """Convert tensors in dict batch to vars 77 | """ 78 | 79 | def __init__(self, use_cuda=True): 80 | self.use_cuda = use_cuda 81 | 82 | def __call__(self, batch): 83 | batch_var = {} 84 | for key, value in batch.items(): 85 | batch_var[key] = Variable(value, requires_grad=False) 86 | if self.use_cuda: 87 | batch_var[key] = batch_var[key].cuda() 88 | 89 | return batch_var 90 | -------------------------------------------------------------------------------- /visualization/matplot_tool.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # def plot_batch_result(source_image_list,target_image_list, 4 | # warped_image_list,warped_image_GT_list): 5 | import torch 6 | import numpy as np 7 | 8 | from util.matplot_util import plot_line_chart 9 | 10 | 11 | def plot_batch_result(*image_list,plot_title): 12 | 13 | #plt.figure(figsize=(20,10)) 14 | plt.figure(figsize=(40,20)) 15 | plt.suptitle('compare images') 16 | 17 | #assert len(image_list) == 5 18 | assert len(image_list) == len(plot_title) 19 | 20 | plot_row = len(image_list) 21 | for i in range(plot_row): 22 | plot_col = len(image_list[i]) 23 | for j in range(plot_col): 24 | plt.subplot(plot_row,plot_col,i*plot_col+j+1), plt.title(plot_title[i]+str(j+1)) 25 | plt.imshow(image_list[i][j].squeeze().detach().numpy(),cmap='gray') 26 | 27 | print('prepare to show') 28 | plt.show() 29 | print('done') 30 | 31 | 32 | def plot_matual_information_batch_result(*image_list,plot_title,matual_info_list_batch,matual_info_traditional_list_batch,iter_list): 33 | 34 | plt.figure(figsize=(20,10)) 35 | plt.suptitle('compare images') 36 | 37 | #assert len(image_list) == 5 38 | #assert len(image_list) == len(plot_title) 39 | 40 | matual_info_list_batch = np.array(matual_info_list_batch) 41 | matual_info_traditional_list_batch = np.array(matual_info_traditional_list_batch) 42 | 43 | plot_row = len(image_list) + 1 44 | plot_col = len(image_list[0]) 45 | for i in range(plot_row): 46 | for j in range(plot_col): 47 | if i == (plot_row - 1): 48 | plt.subplot(plot_row,plot_col,i*plot_col+j+1) 49 | plot_line_chart(iter_list, matual_info_list_batch[:,j].tolist(), title='cnn_ntg', color='r', label='cnn_ntg') 50 | plot_line_chart(iter_list, matual_info_traditional_list_batch[:,j].tolist(), title='ntg', color='b', label='ntg') 51 | else: 52 | plt.subplot(plot_row,plot_col,i*plot_col+j+1) 53 | plt.title(plot_title[i]+str(j+1)) 54 | plt.imshow(image_list[i][j].squeeze().detach().numpy(),cmap='gray') 55 | 56 | plt.show() 57 | 58 | def plot_grid_loss_batch(*image_list,plot_title,grid_loss_batch,grid_loss_trditional_batch,iter_list): 59 | 60 | plt.figure(figsize=(20,10)) 61 | plt.suptitle('compare images') 62 | 63 | #assert len(image_list) == 5 64 | #assert len(image_list) == len(plot_title) 65 | 66 | grid_loss_batch = np.array(grid_loss_batch) 67 | grid_loss_trditional_batch = np.array(grid_loss_trditional_batch) 68 | 69 | plot_row = len(image_list) + 1 70 | plot_col = len(image_list[0]) 71 | for i in range(plot_row): 72 | for j in range(plot_col): 73 | if i == (plot_row - 1): 74 | plt.subplot(plot_row,plot_col,i*plot_col+j+1) 75 | plot_line_chart(iter_list, grid_loss_batch[:,j].tolist(), title='cnn_ntg', color='r', label='cnn_ntg') 76 | plot_line_chart(iter_list, grid_loss_trditional_batch[:,j].tolist(), title='ntg', color='b', label='ntg') 77 | else: 78 | plt.subplot(plot_row,plot_col,i*plot_col+j+1) 
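                # non-final rows show the image grids; the last row (the if-branch
                # above) plots the per-image grid-loss curves instead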
79 | plt.title(plot_title[i]+str(j+1)) 80 | plt.imshow(image_list[i][j].squeeze().detach().numpy(),cmap='gray') 81 | 82 | plt.show() 83 | -------------------------------------------------------------------------------- /visualization/visual_mutual_info.py: -------------------------------------------------------------------------------- 1 | 2 | import scipy.io as scio 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from util.matplot_util import plot_line_chart 7 | 8 | 9 | def convert_array_list(numpy_array): 10 | list = [] 11 | for i in range(np.size(numpy_array)): 12 | mutual_temp = numpy_array[i][0] 13 | for j in range(np.size(mutual_temp)): 14 | list.append(mutual_temp[j]) 15 | return list 16 | 17 | def convert_cave_list(numpy_array): 18 | list = [] 19 | for i in range(numpy_array.shape[0]): 20 | for j in range(numpy_array.shape[1]): 21 | list.append(numpy_array[i][j]) 22 | return list 23 | 24 | def visual_mutual_coco(): 25 | mutual_info_coco_dict = scio.loadmat('mutual_info_coco_dict.mat') 26 | 27 | mutual_info_cvpr_list = mutual_info_coco_dict['mutual_info_cvpr_list'][0] 28 | mutual_info_cnn_list = mutual_info_coco_dict['mutual_info_cnn_list'][0] 29 | mutual_info_ntg_list = mutual_info_coco_dict['mutual_info_ntg_list'][0] 30 | mutual_info_comb_list = mutual_info_coco_dict['mutual_info_comb_list'][0] 31 | 32 | cnn_list = convert_array_list(mutual_info_cnn_list) 33 | cvpr_list = convert_array_list(mutual_info_cvpr_list) 34 | ntg_list = convert_array_list(mutual_info_ntg_list) 35 | comb_list = convert_array_list(mutual_info_comb_list) 36 | 37 | start_index = 100 38 | end_index = -1 39 | 40 | plot_line_chart(range(len(cvpr_list))[start_index:end_index],sorted(cvpr_list)[start_index:end_index],color='b',label = 'CNNGeometric') 41 | plot_line_chart(range(len(ntg_list))[start_index:end_index],sorted(ntg_list)[start_index:end_index],color='g',label = 'NTG') 42 | plot_line_chart(range(len(cnn_list))[start_index:end_index],sorted(cnn_list)[start_index:end_index],color='r',label = 'Ours') 43 | plot_line_chart(range(len(comb_list))[start_index:end_index],sorted(comb_list)[start_index:end_index],color='y',label = 'Ours&NTG') 44 | 45 | plt.grid() 46 | plt.show() 47 | 48 | pass 49 | 50 | def visual_mutual_cave(): 51 | cave_info_dict = scio.loadmat('mutual_info_cave_dict.mat') 52 | mutual_info_cnn_list = cave_info_dict['mutual_info_cnn_list'] 53 | mutual_info_cvpr_list = cave_info_dict['mutual_info_cvpr_list'] 54 | mutual_info_ntg_list = cave_info_dict['mutual_info_ntg_list'] 55 | mutual_info_comb_list = cave_info_dict['mutual_info_comb_list'] 56 | 57 | cnn_list = convert_cave_list(mutual_info_cnn_list) 58 | cvpr_list = convert_cave_list(mutual_info_cvpr_list) 59 | ntg_list = convert_cave_list(mutual_info_ntg_list) 60 | comb_list = convert_cave_list(mutual_info_comb_list) 61 | 62 | start_index = 100 63 | end_index = -1 64 | 65 | plot_line_chart(range(len(cvpr_list))[start_index:end_index], sorted(cvpr_list)[start_index:end_index], color='b', label='CNNGeometric') 66 | plot_line_chart(range(len(ntg_list))[start_index:end_index], sorted(ntg_list)[start_index:end_index], color='g', label='NTG') 67 | plot_line_chart(range(len(cnn_list))[start_index:end_index], sorted(cnn_list)[start_index:end_index], color='r', label='Ours') 68 | plot_line_chart(range(len(comb_list))[start_index:end_index], sorted(comb_list)[start_index:end_index], color='y', label='Ours&NTG') 69 | 70 | plt.grid() 71 | plt.show() 72 | pass 73 | 74 | if __name__ == '__main__': 75 | # visual_mutual_coco() 76 | 
visual_mutual_cave() 77 | pass -------------------------------------------------------------------------------- /ntg_pytorch/multispectral_pytorch_test.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from skimage import io 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | from datasets.provider.randomTnsData import RandomTnsPair 11 | from datasets.provider.singlechannelData import SingleChannelPairTnf 12 | from datasets.provider.test_dataset import TestDataset 13 | from evluate.lossfunc import GridLoss 14 | from ntg_pytorch.register_func import estimate_aff_param_iterator, affine_transform, scale_image 15 | from tnf_transform.img_process import NormalizeImageDict, generate_affine_param, NormalizeCAVEDict 16 | from traditional_ntg.estimate_affine_param import estimate_affine_param 17 | from util.pytorchTcv import inverse_theta, param2theta 18 | from visualization.train_visual import VisdomHelper 19 | from sklearn import metrics 20 | 21 | def use_torch_ntg(img1,img2): 22 | img1 = img1[np.newaxis, np.newaxis, :, :] 23 | img2 = img2[np.newaxis, np.newaxis, :, :] 24 | 25 | source_batch = torch.from_numpy(img1).float() 26 | target_batch = torch.from_numpy(img2).float() 27 | 28 | # normalize_func = NormalizeCAVEDict(["image"]) 29 | p = estimate_aff_param_iterator(source_batch, target_batch, use_cuda=use_cuda, itermax=600) 30 | p = p[0].cpu().numpy() 31 | return p 32 | 33 | def use_cv2_ntg(img1,img2): 34 | p = estimate_affine_param(img2,img1,itermax=600) 35 | return p 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | small = True 41 | 42 | # img1 = io.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_31.png') * 1.0 43 | # img2 = io.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_17.png') * 1.0 44 | 45 | # img1 = io.imread('../datasets/row_data/multispectral/mul_1s_s.png') 46 | # img2 = io.imread('../datasets/row_data/multispectral/mul_1t_s.png') 47 | 48 | img1 = io.imread('../datasets/row_data/texs1.jpeg') 49 | img2 = io.imread('../datasets/row_data/test2.jpeg') 50 | 51 | center = (img1.shape[0]/2,img1.shape[1]/2) 52 | center = (0,0) 53 | if small: 54 | theta = generate_affine_param(scale=1.1, degree=10, translate_x=-10, translate_y=10,center=center) 55 | else: 56 | theta = generate_affine_param(scale=1.25, degree=30, translate_x=-20, translate_y=20,center=center) 57 | 58 | use_cuda = torch.cuda.is_available() 59 | 60 | env = "ntg_pytorch" 61 | vis = VisdomHelper(env) 62 | 63 | p = use_torch_ntg(img1,img2) 64 | # p = use_cv2_ntg(img1,img2) 65 | 66 | 67 | # im2warped = affine_transform(img2,p) 68 | im2warped = affine_transform(img1,p) 69 | 70 | imgGt = affine_transform(img1,theta) 71 | 72 | print(metrics.normalized_mutual_info_score(im2warped.flatten()/255.0, imgGt.flatten()/255.0)) 73 | 74 | 75 | p = torch.from_numpy(p).unsqueeze(0).float() 76 | 77 | p_pytorch = param2theta(p,img1.shape[0],img1.shape[1],use_cuda=False) 78 | 79 | theta_GT = torch.from_numpy(theta).unsqueeze(0).float() 80 | theta_GT = param2theta(theta_GT,img1.shape[0],img1.shape[1],use_cuda=False) 81 | fn_grid_loss = GridLoss(use_cuda=False, grid_size=512) 82 | 83 | 84 | print(p_pytorch) 85 | print(theta_GT) 86 | 87 | grid_loss = fn_grid_loss.compute_grid_loss(p_pytorch, theta_GT) 88 | print('grid_loss',grid_loss) 89 | 90 | 91 | plt.imshow(img1, cmap='gray') # 目标图片 92 | plt.figure() 93 | plt.imshow(img2, cmap='gray') # 
待变换图片 94 | plt.figure() 95 | plt.imshow(im2warped, cmap='gray') 96 | plt.figure() 97 | plt.imshow(imgGt, cmap='gray') 98 | plt.show() -------------------------------------------------------------------------------- /multispectral_pytorch_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from skimage import io 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch.nn.functional as F 10 | 11 | from datasets.provider.test_dataset import TestDataset, NtgTestPair 12 | from evluate.lossfunc import GridLoss 13 | from main.test_mulit_images import compute_average_grid_loss, compute_correct_rate 14 | from ntg_pytorch.register_func import estimate_aff_param_iterator, affine_transform 15 | from tnf_transform.img_process import NormalizeImageDict 16 | from tnf_transform.transformation import affine_transform_opencv 17 | from traditional_ntg.estimate_affine_param import estimate_param_batch 18 | from util.pytorchTcv import param2theta 19 | from visualization.train_visual import VisdomHelper 20 | 21 | if __name__ == '__main__': 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 24 | 25 | print('使用传统NTG批量测试') 26 | 27 | use_cuda = torch.cuda.is_available() 28 | 29 | env = "ntg_pytorch" 30 | vis = VisdomHelper(env) 31 | test_image_path = '/home/zlk/datasets/coco_test2017_n2000' 32 | label_path = 'datasets/row_data/label_file/coco_test2017_n2000_custom_20r_param.csv' 33 | 34 | threshold = 3 35 | batch_size = 164 36 | 37 | # dataset = TestDataset(test_image_path,label_path,transform=NormalizeImageDict(["image"])) 38 | dataset = TestDataset(test_image_path,label_path) 39 | dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True) 40 | # pair_generator = NtgTestPair(use_cuda=use_cuda,output_size=(480, 640)) 41 | pair_generator = NtgTestPair(use_cuda=use_cuda) 42 | 43 | fn_grid_loss = GridLoss(use_cuda=use_cuda) 44 | grid_loss_ntg_list = [] 45 | 46 | for batch_idx,batch in enumerate(dataloader): 47 | if batch_idx % 5 == 0: 48 | print('test batch: [{}/{} ({:.0f}%)]'.format( 49 | batch_idx, len(dataloader), 50 | 100. 
* batch_idx / len(dataloader))) 51 | 52 | pair_batch = pair_generator(batch) 53 | 54 | 55 | source_image_batch = pair_batch['source_image'] 56 | target_image_batch = pair_batch['target_image'] 57 | theta_GT_batch = pair_batch['theta_GT'] 58 | image_name = pair_batch['name'] 59 | 60 | if use_cuda: 61 | source_image_batch = source_image_batch.cuda() 62 | target_image_batch = target_image_batch.cuda() 63 | 64 | with torch.no_grad(): 65 | ntg_param_batch = estimate_aff_param_iterator(source_image_batch,target_image_batch,use_cuda=use_cuda) 66 | 67 | ntg_param_pytorch_batch = param2theta(ntg_param_batch, 240, 240, use_cuda=use_cuda) 68 | # ntg_param_pytorch_batch = param2theta(ntg_param_batch, 480, , use_cuda=use_cuda) 69 | 70 | loss_ntg = fn_grid_loss.compute_grid_loss(ntg_param_pytorch_batch.detach(), theta_GT_batch) 71 | 72 | grid_loss_ntg_list.append(loss_ntg.detach().cpu()) 73 | 74 | ntg_image_warped_batch = affine_transform_opencv(source_image_batch, ntg_param_batch.cpu()) 75 | 76 | vis.showImageBatch(source_image_batch,normailze=True,win='source_image_batch',title='source_image_batch') 77 | vis.showImageBatch(target_image_batch,normailze=True,win='target_image_batch',title='target_image_batch') 78 | vis.showImageBatch(ntg_image_warped_batch, normailze=True, win='ntg_wraped_image', title='ntg_pytorch') 79 | break 80 | 81 | 82 | print('ntg网格点损失') 83 | ntg_group_list = compute_average_grid_loss(grid_loss_ntg_list,threshold=threshold) 84 | 85 | print('ntg正确率') 86 | compute_correct_rate(grid_loss_ntg_list, threshold=threshold) 87 | -------------------------------------------------------------------------------- /datasets/dataset_process/Generator.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import csv 8 | import cv2 9 | import scipy.io as scio 10 | 11 | 12 | #row_data_dir_path = '/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/VOC/' 13 | from tnf_transform.img_process import random_affine, generator_affine_param, generate_affine_param 14 | from train import init_seeds 15 | 16 | 17 | def read_row_data(data_path): 18 | image_name_list = os.listdir(data_path) 19 | #print(image_name_list) 20 | return image_name_list 21 | 22 | 23 | 24 | def affine_transform(image,param): 25 | height = image.shape[0] 26 | width = image.shape[1] 27 | image = cv2.warpAffine(image,param,(width,height)) 28 | return image 29 | 30 | 31 | 32 | def generate_result_dict(row_data_dir_path,output_path,use_custom_random_aff = False): 33 | image_name_list = read_row_data(row_data_dir_path) 34 | param_list = [] 35 | for i in range(len(image_name_list)): 36 | if use_custom_random_aff: 37 | random_param_dict = random_affine(to_dict= True) 38 | else: 39 | random_param_dict = generator_affine_param(to_dict=True) 40 | 41 | random_param_dict['image'] = image_name_list[i] 42 | param_list.append(random_param_dict) 43 | if i % 5000 == 0: 44 | print('第',i,"张") 45 | #print(image_name_list[i],param_list[i]) 46 | 47 | write_csv(output_path,param_list) 48 | print("写入完成") 49 | 50 | def write_csv(output_path,datadicts): 51 | with open(output_path,mode='w') as csv_file: 52 | # 使用这个的话就直接write_row 53 | #employee_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 54 | 55 | fieldnames = ['image','p0','p1','p2','p3','p4','p5'] 56 | writer = csv.DictWriter(csv_file,fieldnames=fieldnames) 57 | 58 | writer.writeheader() 59 | 60 | # datadict = 
{'name':'hell','number': 'Accounting', 'age': 'November'} 61 | # writer.writerow(datadict) 62 | writer.writerows(datadicts) 63 | 64 | # def test_affine_image(): 65 | # image_name_list = read_row_data(row_data_dir_path) 66 | # for i in range(len(image_name_list)): 67 | # 68 | # img_row = cv2.imread(row_data_dir_path+image_name_list[i]) 69 | # random_param = generator_affine_param() 70 | # img_aff = affine_transform(img_row,random_param) 71 | # 72 | # result = np.hstack([img_row,img_aff]) 73 | # 74 | # cv2.imshow('compare',result) 75 | # cv2.waitKey(0) 76 | 77 | 78 | 79 | 80 | 81 | 82 | if __name__ == '__main__': 83 | init_seeds(seed= 46763) 84 | #row_data_dir_path = '/home/zlk/datasets/coco_test2017' 85 | row_data_dir_path = '../row_data/COCO/' 86 | # row_data_dir_path = '/home/zlk/datasets/coco_test2017_n2000' 87 | # row_data_dir_path = '/mnt/4T/zlk/datasets/mulitspectral/nirscene_total/nir_image' 88 | use_custom_random_aff = False 89 | if use_custom_random_aff: 90 | # output_path = '../row_data/label_file/coco_test2017_custom_param.csv' 91 | # output_path = '../row_data/label_file/coco_test2017_n2000_custom_20r_param.csv' 92 | output_path = '../row_data/label_file/nir_rgb_custom_20r_param.csv' 93 | else: 94 | output_path = '../row_data/label_file/coco_test2017_paper_param_n2000.csv' 95 | # output_path = '../row_data/label_file/nir_rgb_paper_affine_param.csv' 96 | 97 | #generate_result_dict(row_data_dir_path,output_path,use_custom_random_aff=use_custom_random_aff) 98 | 99 | #print(random_affine(to_dict= True)) 100 | 101 | -------------------------------------------------------------------------------- /traditional_ntg/test_script.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 4 | import torch 5 | from PIL.Image import BILINEAR 6 | from skimage import io 7 | import matplotlib.pyplot as plt 8 | 9 | from evluate.lossfunc import GridLoss 10 | from ntg_pytorch.register_func import affine_transform 11 | from ntg_pytorch.register_loss import deriv_filt_pytorch 12 | from tnf_transform.img_process import generate_affine_param 13 | from traditional_ntg.estimate_affine_param import estimate_affine_param 14 | from traditional_ntg.loss_function import deriv_filt 15 | import scipy.misc as smi 16 | import numpy as np 17 | from PIL import Image 18 | 19 | # img1 = io.imread('../datasets/row_data/multispectral/mul_1t_s.png') 20 | # img2 = io.imread('../datasets/row_data/multispectral/mul_1s_s.png') 21 | 22 | img1 = io.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_17.png') 23 | img2 = io.imread('../datasets/row_data/multispectral/fake_and_real_tomatoes_ms_28.png') 24 | 25 | # img1 = io.imread('../datasets/row_data/texs1.jpeg') 26 | # img2 = io.imread('../datasets/row_data/test2.jpeg') 27 | 28 | fn_grid_loss = GridLoss(use_cuda=False,grid_size=512) 29 | 30 | center = (256,256) 31 | theta_GT = generate_affine_param(scale=1.1, degree=10, translate_x=-10, translate_y=10, center=center) 32 | theta_GT = torch.from_numpy(theta_GT).unsqueeze(0).float() 33 | 34 | 35 | 36 | # 第一个是target,第二个是source 37 | p = estimate_affine_param(img1,img2,itermax=1000) 38 | 39 | im2warped = affine_transform(img2, p) 40 | print(p) 41 | print(theta_GT) 42 | 43 | p = torch.from_numpy(p).unsqueeze(0).float() 44 | 45 | grid_loss = fn_grid_loss.compute_grid_loss(p,theta_GT) 46 | print(grid_loss) 47 | 48 | 49 | 50 | plt.imshow(img1, cmap='gray') # 目标图片 51 | plt.figure() 52 | plt.imshow(img2, cmap='gray') # 待变换图片 53 | plt.figure() 54 | 
plt.imshow(im2warped, cmap='gray') 55 | plt.show() 56 | 57 | # # ration = (1/1.5)**6 58 | # img1 = (img1/255.0) 59 | # # 60 | # # img1 = np.array(Image.fromarray(img1).resize((int(img1.shape[0] * ration), int(img1.shape[1] * ration)))) 61 | # # 62 | # # Ix_cv,Iy_cv = deriv_filt(img1,False) 63 | # 64 | # # img1_tensor = torch.from_numpy(img1).unsqueeze(0).unsqueeze(0).float() 65 | # # Ix,Iy = deriv_filt_pytorch(img1_tensor,False,use_cuda=False) 66 | # # Ix = Ix.squeeze().numpy() 67 | # # Iy = Iy.squeeze().numpy() 68 | # 69 | # ration = (1/1.5) 70 | # smooth_sigma = np.sqrt(1.5) / np.sqrt(3) 71 | # kx = cv2.getGaussianKernel(int(2 * round(1.5 * smooth_sigma)) + 1, smooth_sigma) 72 | # ky = cv2.getGaussianKernel(int(2 * round(1.5 * smooth_sigma)) + 1, smooth_sigma) 73 | # 74 | # hg = np.multiply(kx, np.transpose(ky)) 75 | # # tmp = cv2.filter2D(img1, -1, hg, borderType=cv2.BORDER_REFLECT) 76 | # 77 | # multi_level_pyramid = [] 78 | # for i in range(7): 79 | # # multi_level_pyramid.append(img1) 80 | # # Ix_cv, Iy_cv = deriv_filt(img1, False) 81 | # 82 | # # plt.figure() 83 | # # plt.imshow(Iy_cv, cmap='gray') 84 | # if i == 6: 85 | # plt.figure() 86 | # plt.imshow(img1, cmap='gray') 87 | # break 88 | # # 默认的BORDER_REFLECT_101 89 | # img1 = cv2.filter2D(img1, -1, hg, borderType=cv2.BORDER_REFLECT_101) 90 | # img1 = np.array(Image.fromarray(img1).resize((math.ceil(img1.shape[0] * ration), math.ceil(img1.shape[1] * ration)),resample=BILINEAR)) 91 | # 92 | # 93 | # # plt.figure() 94 | # # plt.imshow(multi_level_pyramid[0],cmap='gray') 95 | # 96 | # # plt.figure() 97 | # # plt.imshow(Ix,cmap='gray') 98 | # # plt.figure() 99 | # # plt.imshow(Iy,cmap='gray') 100 | # # plt.figure() 101 | # # plt.imshow(Ix_cv,cmap='gray') 102 | # # plt.figure() 103 | # # plt.imshow(Iy_cv,cmap='gray') 104 | # plt.show() 105 | 106 | -------------------------------------------------------------------------------- /evluate/evaluate_result.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import torch 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | from tnf_transform.img_process import random_affine 10 | from tnf_transform.transformation import AffineTnf 11 | from traditional_ntg.estimate_affine_param import estimate_param_batch 12 | from util.pytorchTcv import theta2param, param2theta 13 | from skimage import io 14 | 15 | from util.time_util import calculate_diff_time 16 | 17 | 18 | def evaluate(theta_estimate_batch,theta_GT_batch,source_image_batch,target_image_batch,use_cuda = True): 19 | # 将pytorch的变换参数转为opencv的变换参数 20 | theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda) 21 | 22 | # P5使用传统NTG方法进行优化cnn的结果 23 | ntg_param = estimate_param_batch(source_image_batch, target_image_batch, None, itermax=600) 24 | ntg_param_pytorch = param2theta(ntg_param, 240, 240, use_cuda=use_cuda) 25 | cnn_ntg_param_batch = estimate_param_batch(source_image_batch, target_image_batch, theta_opencv, itermax=800) 26 | cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch, 240, 240, use_cuda=use_cuda) 27 | 28 | # loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch, theta_GT_batch) 29 | # loss_ntg = grid_loss.compute_grid_loss(ntg_param_pytorch, theta_GT_batch) 30 | # loss_cnn_ntg = grid_loss.compute_grid_loss(cnn_ntg_param_pytorch_batch, theta_GT_batch) 31 | # 32 | # grid_loss_list.append(loss_cnn) 33 | # grid_loss_ntg_list.append(loss_ntg) 34 | # grid_loss_comb_list.append(loss_cnn_ntg) 35 | 
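# (illustrative) a typical call pattern for evaluate() above, assuming a trained
# network that outputs rough affine params for 240x240 image batches on GPU:
#   theta_cnn = model(source_image_batch, target_image_batch)   # hypothetical CNN forward pass
#   evaluate(theta_cnn, theta_GT_batch, source_image_batch, target_image_batch)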
36 | def compare_img_resize(): 37 | img_path = '/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/multispectral/It.jpg' 38 | 39 | h,w = 600,800 40 | 41 | opencv_start_time = time.time() 42 | img = cv2.imread(img_path) 43 | print('imread_time',calculate_diff_time(opencv_start_time)) 44 | img = cv2.resize(img,(w,h),interpolation=cv2.INTER_CUBIC) 45 | start_time = time.time() 46 | # img_t = img.transpose(2,0,1) 47 | img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 48 | img = np.ascontiguousarray(img, dtype=np.float32) # uint8 to float32 49 | img /= 255.0 # 0 - 255 to 0.0 - 1.0 50 | img = torch.from_numpy(img).unsqueeze(0) 51 | 52 | elpased = calculate_diff_time(opencv_start_time) 53 | print('opencv time',img.shape,elpased) 54 | 55 | torch_start_time = time.time() 56 | img2 = io.imread(img_path) 57 | print('torch_read_time',calculate_diff_time(torch_start_time)) 58 | affineTnf = AffineTnf(h, w, use_cuda=False) 59 | image = torch.Tensor(img2.astype(np.float32)) 60 | image = image.transpose(1, 2).transpose(0, 1) 61 | img2 = affineTnf(image.unsqueeze(0)) 62 | elpased = calculate_diff_time(torch_start_time) 63 | print('torch time,',img2.shape,elpased) 64 | 65 | def compare_affine_param_generator(random_t=0.2,random_s=0.2, 66 | random_alpha = 1/4): 67 | start_time = time.time() 68 | 69 | alpha = (torch.rand(1) - 0.5) * 2 * np.pi * random_alpha 70 | alpha = alpha.numpy() 71 | theta = torch.rand(6).numpy() 72 | 73 | theta[[2, 5]] = (theta[[2, 5]] - 0.5) * 2 * random_t 74 | theta[0] = (1 + (theta[0] - 0.5) * 2 * random_s) * np.cos(alpha) 75 | theta[1] = (1 + (theta[1] - 0.5) * 2 * random_s) * (-np.sin(alpha)) 76 | theta[3] = (1 + (theta[3] - 0.5) * 2 * random_s) * np.sin(alpha) 77 | theta[4] = (1 + (theta[4] - 0.5) * 2 * random_s) * np.cos(alpha) 78 | theta = theta.reshape(2, 3) 79 | 80 | elpased = calculate_diff_time(start_time) 81 | print('Time to generate random transform parameters:', elpased) # 0.0004s 82 | 83 | start_time = time.time() 84 | theta_m = random_affine() 85 | theta_m = torch.Tensor(theta_m) 86 | elpased = calculate_diff_time(start_time) 87 | print('Time for random affine generation:', elpased) 88 | 89 | #compare_img_resize() 90 | #compare_affine_param_generator() 91 | 92 | -------------------------------------------------------------------------------- /datasets/provider/pf_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from skimage import io 6 | import pandas as pd 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | from tnf_transform.transformation import AffineTnf 11 | 12 | 13 | class PFDataset(Dataset): 14 | 15 | """ 16 | 17 | Proposal Flow image pair dataset 18 | 19 | 20 | Args: 21 | csv_file (string): Path to the csv file with image names and transformations. 22 | training_image_path (string): Directory with the images. 23 | output_size (2-tuple): Desired output size 24 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 25 | 26 | """ 27 | 28 | def __init__(self, csv_file, training_image_path,output_size=(240,240),transform=None): 29 | 30 | self.out_h, self.out_w = output_size 31 | self.train_data = pd.read_csv(csv_file) 32 | self.img_A_names = self.train_data.iloc[:,0] 33 | self.img_B_names = self.train_data.iloc[:,1] 34 | self.point_A_coords = self.train_data.iloc[:, 2:22].to_numpy().astype('float') # .as_matrix() was removed in recent pandas 35 | self.point_B_coords = self.train_data.iloc[:, 22:].to_numpy().astype('float') 36 | self.training_image_path = training_image_path 37 | self.transform = transform 38 | # no cuda as dataset is called from CPU threads in dataloader and produces conflicts 39 | self.affineTnf = AffineTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False) 40 | 41 | def __len__(self): 42 | return len(self.train_data) 43 | 44 | def __getitem__(self, idx): 45 | # get pre-processed images 46 | image_A,im_size_A = self.get_image(self.img_A_names,idx,single=False) 47 | image_B,im_size_B = self.get_image(self.img_B_names,idx,single=False) 48 | 49 | # get pre-processed point coords 50 | point_A_coords = self.get_points(self.point_A_coords,idx) 51 | point_B_coords = self.get_points(self.point_B_coords,idx) 52 | 53 | # compute PCK reference length L_pck (equal to max bounding box side in image_A) 54 | L_pck = torch.FloatTensor([torch.max(point_A_coords.max(1)[0]-point_A_coords.min(1)[0])]) 55 | 56 | sample = {'source_image': image_A, 'target_image': image_B, 'source_im_size': im_size_A, 'target_im_size': im_size_B, 'source_points': point_A_coords, 'target_points': point_B_coords, 'L_pck': L_pck} 57 | 58 | if self.transform: 59 | sample = self.transform(sample) 60 | 61 | return sample 62 | 63 | def get_image(self,img_name_list,idx,single=False): 64 | img_name = os.path.join(self.training_image_path, img_name_list[idx]) 65 | image = io.imread(img_name) 66 | 67 | if single: 68 | image = image[:,:,0:1] 69 | 70 | # get image size 71 | im_size = np.asarray(image.shape) 72 | 73 | # convert to torch Variable 74 | image = np.expand_dims(image.transpose((2,0,1)),0) 75 | image = torch.Tensor(image.astype(np.float32)) 76 | image_var = Variable(image,requires_grad=False) 77 | 78 | # Resize image using bilinear sampling with identity affine tnf 79 | image = self.affineTnf(image_var).data.squeeze(0) 80 | 81 | im_size = torch.Tensor(im_size.astype(np.float32)) 82 | 83 | return (image, im_size) 84 | 85 | def get_points(self,point_coords_list,idx): 86 | point_coords = point_coords_list[idx, :].reshape(2,10) 87 | 88 | # # swap X,Y coords, as the row,col order (Y,X) is used for computations 89 | # point_coords = point_coords[[1,0],:] 90 | 91 | # make arrays float tensor for subsequent processing 92 | point_coords = torch.Tensor(point_coords.astype(np.float32)) 93 | return point_coords 94 | 
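# --- Editor's note: minimal usage sketch, added for illustration. The CSV path and
# image directory below are hypothetical placeholders, not files shipped with this repo.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = PFDataset(csv_file='../row_data/label_file/pf_pairs.csv',
                        training_image_path='../row_data/PF-dataset/',
                        output_size=(240, 240))
    loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2)
    sample = next(iter(loader))
    # each sample carries the image pair, original sizes, keypoints and the PCK length
    print(sample['source_image'].shape, sample['L_pck'].shape)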
-------------------------------------------------------------------------------- /traditional_ntg/loss_function.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | # isconj is True when computing the weight 6 | def deriv_filt(I,isconj): 7 | 8 | # Either filter kernel works, but experiments show the 1x3 kernel performs better and gives a lower residual error 9 | # if not isconj: 10 | # h1 = np.array([[-0.5, 0, 0.5],[-0.5,0,0.5],[-0.5,0,0.5]]) 11 | # h2 = np.array([[-0.5,-0.5,-0.5],[0,0,0],[0.5,0.5,0.5]]) 12 | # else: 13 | # h1 = np.array([[0.5, 0, -0.5],[0.5, 0, -0.5],[0.5, 0, -0.5]]) 14 | # h2 = np.array([[0.5,0.5,0.5], [0,0,0], [-0.5,-0.5,-0.5]]) 15 | # 16 | # Ix = cv2.filter2D(I,-1,h1) 17 | # Iy = cv2.filter2D(I,-1,h2) 18 | 19 | if not isconj: 20 | h1 = np.array([-0.5, 0, 0.5]) 21 | h2 = np.array([[-0.5],[0],[0.5]]) 22 | else: 23 | h1 = np.array([0.5, 0, -0.5]) 24 | h2 = np.array([[0.5], [0], [-0.5]]) 25 | 26 | # Ix is the up-down difference, giving horizontal edges (after left/right padding the vertical difference is 0); Iy is the left-right difference, giving vertical edges 27 | # This automatically clamps negative values to 0 and rounds decimals up 28 | # Floats in [0,1] are not rounded down to 0; floating-point inputs are convolved as-is 29 | # print('reflect')  the borderType here does not affect the final result 30 | # Iy = cv2.filter2D(I,-1,h1, borderType=cv2.BORDER_REFLECT) 31 | # Ix = cv2.filter2D(I,-1,np.transpose(h2), borderType=cv2.BORDER_REFLECT) 32 | # 33 | Iy = cv2.filter2D(I, -1, h1) 34 | Ix = cv2.filter2D(I, -1, np.transpose(h2)) 35 | 36 | return Ix,Iy 37 | 38 | 39 | # def inverse_affine_param(p): 40 | # try: 41 | # A = np.array([[p[0,0],p[0,1],p[0,2]],[p[1,0],p[1,1],p[1,2]],[0,0,1]]) 42 | # B = np.linalg.inv(A) 43 | # return B 44 | # except TypeError: 45 | # print("type Error") 46 | 47 | 48 | def affine_transform(im,p): 49 | height = im.shape[0] 50 | width = im.shape[1] 51 | # im = cv2.warpAffine(im,p,(width,height),flags=cv2.INTER_CUBIC) 52 | im = cv2.warpAffine(im,p,(width,height)) 53 | return im 54 | 55 | 56 | def partial_deriv_affine(I1,I2,p,h): 57 | 58 | [H,W] = I1.shape 59 | # x,y = np.meshgrid(range(0,W),range(0,H)) 60 | # x2 = p[0,0] * x + p[0,1]*y + p[0,2] 61 | # y2 = p[1,0] * x + p[1,1]*y + p[1,2] 62 | # B = (x2 > W-1) | (x2 < 0) | (y2 > H-1) | (y2 < 0) 63 | 64 | warpI2 = cv2.warpAffine(I2,p,(I2.shape[1],I2.shape[0]),flags=cv2.INTER_CUBIC) 65 | I2x,I2y = deriv_filt(I2,False) 66 | 67 | Ipx = cv2.warpAffine(I2x,p,(I2x.shape[1],I2x.shape[0]),flags=cv2.INTER_CUBIC) 68 | Ipy = cv2.warpAffine(I2y,p,(I2y.shape[1],I2y.shape[0]),flags=cv2.INTER_CUBIC) 69 | It = warpI2 - I1 70 | 71 | # It[B] = 0 72 | # Ipx[B] = 0 73 | # Ipy[B] = 0 74 | return It,Ipx,Ipy 75 | 76 | def func_rho(x,order,epsilon=0.01): 77 | if order == 0: 78 | y = np.sqrt(x*x + epsilon*epsilon) 79 | y = np.sum(y) 80 | elif order == 1: 81 | y = x/np.sqrt(x*x + epsilon*epsilon) 82 | else: 83 | print("Tag | wrong order") 84 | return y 85 | 86 | 87 | def ntg(img1,img2): 88 | [g1x,g1y] = deriv_filt(img1,False) 89 | [g2x,g2y] = deriv_filt(img2,False) 90 | 91 | m = func_rho(g1x - g2x,0) + func_rho(g1y - g2y,0) 92 | n = func_rho(g1x,0) + func_rho(g2x,0) + func_rho(g1y,0) + func_rho(g2y,0) 93 | #y = m/(n+0.01)## TODO 94 | y = m/(n+1e-16)## TODO 95 | return y 96 | 97 | 98 | 99 | # Returns the gradient of NTG with respect to the affine parameters p 100 | # this.images(:,:,1) = fr1; 101 | # this.images(:,:,2) = fr2; 102 | # this.deriv_filter: derivative kernel in x direction 103 | def ntg_gradient(objdict,p): 104 | options = objdict['options'] 105 | images = objdict['images'] 106 | #warpI = cv2.warpAffine(images[:,:,1],p,(images[:,:,1].shape[1],images[:,:,1].shape[0])) 107 | warpI = affine_transform(images[:,:,1],p) 108 | 109 | [It,Ipx,Ipy] = partial_deriv_affine(images[:,:,0],images[:,:,1],p,options['deriv_filter']) # It: difference between the warp and I1. Ipx and Ipy are the horizontal/vertical gradients of I2 110 | 111 | # print(p) 112 | # plt.figure() 113 | # plt.imshow(warpI, cmap=plt.cm.gray_r) 114 | # plt.figure() 115 | # plt.imshow(It, cmap=plt.cm.gray_r) 116 | # plt.figure() 117 | # plt.imshow(Ipx, cmap=plt.cm.gray_r) 118 | # plt.figure() 119 | # plt.imshow(Ipy, cmap=plt.cm.gray_r) 120 | # plt.show() 121 | 122 | J = ntg(warpI,images[:,:,0]) 123 | 124 | [Itx,Ity] = deriv_filt(It,False) 125 | rho_x = func_rho(Itx,1) - J*func_rho(Ipx,1) 126 | rho_y = func_rho(Ity,1) - J*func_rho(Ipy,1) 127 | 128 | [wxx,wxy] = deriv_filt(rho_x,True) 129 | [wyx,wyy] = deriv_filt(rho_y,True) 130 | w = wxx + wyy 131 | 132 | g = np.zeros((6, 1)) 133 | g[0] = np.mean(w * objdict['X'] * 
Ipx) 134 | g[1] = np.mean(w * objdict['Y'] * Ipx) 135 | g[2] = np.mean(w * Ipx) 136 | g[3] = np.mean(w * objdict['X'] * Ipy) 137 | g[4] = np.mean(w * objdict['Y'] * Ipy) 138 | g[5] = np.mean(w * Ipy) 139 | 140 | g = g.reshape(2,3) 141 | 142 | return g 143 | 144 | 
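# --- Editor's note: a small self-contained check of the NTG measure above, added for
# illustration. NTG is the Charbonnier-smoothed gradient difference normalized by the
# total gradient mass, so it stays close to 0 for an aligned pair and grows when the
# pair is misaligned.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    img = rng.rand(64, 64)
    print(ntg(img, img.copy()))           # identical images -> near 0
    print(ntg(img, np.roll(img, 3, 1)))   # 3-pixel shift -> clearly larger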
-------------------------------------------------------------------------------- /datasets/provider/singlechannelData.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from skimage import io 5 | import numpy as np 6 | 7 | from torch.utils.data import Dataset 8 | 9 | from tnf_transform.transformation import AffineTnf 10 | from util.csv_opeartor import read_csv_file 11 | from traditional_ntg.image_util import symmetricImagePad 12 | 13 | ''' 14 | Data provider used at test time: reads the affine parameters and the original image info, returns image, theta, name 15 | 16 | ''' 17 | class SinglechannelData(Dataset): 18 | 19 | def __init__(self,image_path,label_path,output_size=(480,640),transform=None, 20 | use_cuda = False): 21 | ''' 22 | :param image_path: 23 | :param label_path: 24 | :param output_size: 25 | :param normalize_range: 26 | :param use_cuda: using CUDA while loading data with multiple workers causes synchronization problems, so CUDA is not used here 27 | ''' 28 | self.transform = transform 29 | self.image_path = image_path 30 | self.label_path = label_path 31 | self.image_list = os.listdir(self.image_path) 32 | self.out_h,self.out_w = output_size 33 | self.csv_data = read_csv_file(label_path) # a DataFrame, essentially a table; if an index column is set as the key, values will not contain that key 34 | self.resizeTnf = AffineTnf(self.out_h,self.out_w,use_cuda=use_cuda) 35 | 36 | def __len__(self): 37 | return len(self.image_list) 38 | 39 | 40 | def __getitem__(self, idx): 41 | image_name = self.image_list[idx] 42 | image_path = os.path.join(self.image_path,image_name) 43 | 44 | image_array = io.imread(image_path) 45 | #label_row_param = self.csv_data.ix[idx,:] 46 | #label_row_index = self.csv_data.index[idx] 47 | #label_row_param = self.csv_data.values[idx] 48 | label_row_param = self.csv_data.loc[self.csv_data['image'] == image_name].values 49 | label_row_param = np.squeeze(label_row_param) 50 | if image_name != label_row_param[0]: 51 | raise ValueError("image file name does not match the label file name") 52 | 53 | theta_aff = label_row_param[1:].reshape(2,3) 54 | 55 | image_tensor = torch.Tensor(image_array.astype(np.float32)) 56 | theta_aff_tensor = torch.Tensor(theta_aff.astype(np.float32)) 57 | 58 | # permute order of image to CHW 59 | try: 60 | image_tensor = image_tensor.transpose(1, 2).transpose(0, 1) 61 | except RuntimeError: 62 | one = image_tensor.unsqueeze(0) 63 | image_tensor = torch.cat((one,one,one),0) 64 | 65 | # Resize image using bilinear sampling with identity affine tnf 66 | # Image sizes in the dataset must match here, otherwise an error is raised; the original code appears to cat tensors, which fails for mismatched dimensions 67 | if image_tensor.size()[0] != self.out_h or image_tensor.size()[1] != self.out_w: 68 | image_tensor = self.resizeTnf(image_tensor.unsqueeze(0)).squeeze(0) 69 | 70 | 71 | sample = {'image':image_tensor,'theta':theta_aff_tensor,'name':image_name} 72 | 73 | if self.transform: 74 | sample = self.transform(sample) 75 | 76 | return sample 77 | 78 | ''' 79 | Generate image pairs from the affine parameters 80 | Returns {"source_image, target_image, theta_GT, name"} 81 | ''' 82 | class SingleChannelPairTnf(object): 83 | def __init__(self,use_cuda=True,output_size=(240,240),crop_factor = 9/16,padding_factor = 0.6): 84 | self.use_cuda = use_cuda 85 | self.out_h,self.out_w = output_size 86 | self.crop_factor = crop_factor 87 | self.padding_factor = padding_factor 88 | self.affineTnf = AffineTnf(self.out_h,self.out_w,use_cuda=use_cuda) 89 | 90 | def __call__(self, batch): 91 | image_batch,theta_batch,image_name = batch['image'], batch['theta'],batch['name'] 92 | indices_R = torch.tensor([0]) 93 | indices_G = torch.tensor([2]) 94 | if self.use_cuda: 95 | image_batch = image_batch.cuda() 96 | theta_batch = theta_batch.cuda() 97 | indices_R = indices_R.cuda() 98 | indices_G = indices_G.cuda() 99 | 100 | # Mirror-pad the image borders 101 | image_batch = symmetricImagePad(image_batch,self.padding_factor,use_cuda = self.use_cuda) 102 | 103 | # Take single-channel images: slice the input along the given dimension, picking the entries specified by index (a LongTensor); 104 | # this returns a new tensor with the same number of dimensions as the original along the chosen axis. 105 | 106 | image_batch_R = torch.index_select(image_batch,1,indices_R) 107 | image_batch_G = torch.index_select(image_batch,1,indices_G) 108 | 109 | # Rescale the R channel of the original image 110 | original_image_batch = self.affineTnf(image_batch_R,None,self.padding_factor,self.crop_factor) 111 | wraped_image_batch = self.affineTnf(image_batch_G,theta_batch,self.padding_factor,self.crop_factor) 112 | 113 | pair_result = {'source_image': original_image_batch, 'target_image': wraped_image_batch, 'theta_GT': theta_batch, 114 | 'name':image_name} 115 | 116 | return pair_result 117 | 118 | 119 | 
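# --- Editor's note: end-to-end sketch of this provider, added for illustration; the
# image and label paths are hypothetical placeholders. SinglechannelData yields whole
# images with their ground-truth theta; SingleChannelPairTnf then pads the batch,
# splits the R and G channels and warps G, producing a pseudo multispectral pair.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = SinglechannelData(image_path='../row_data/COCO/',
                                label_path='../row_data/label_file/coco_test2017_paper_param_n2000.csv',
                                output_size=(480, 640))
    loader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=4)
    pair_tnf = SingleChannelPairTnf(use_cuda=False)
    pair = pair_tnf(next(iter(loader)))  # {'source_image','target_image','theta_GT','name'}
    print(pair['source_image'].shape, pair['target_image'].shape)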
-------------------------------------------------------------------------------- /util/train_test_fn.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import time 9 | import cv2 10 | 11 | from tnf_transform.transformation import AffineTnf, AffineGridGen 12 | from util import utils 13 | from util.time_util import calculate_diff_time 14 | 15 | 16 | def train(epoch,model,loss_fn,optimizer,dataloader,pair_generation_tnf,gridGen,vis,use_cuda=True, 17 | gpu_id = None,log_interval=50,lr_scheduler= False,rank=0): 18 | 19 | model.train() 20 | train_loss = 0 21 | 22 | for batch_idx, batch in enumerate(dataloader): 23 | 24 | # total_start_time = time.time() 25 | 26 | # Timing this is not useful; what matters is the time spent inside getitem 27 | # batch_end_time = time.time() 28 | # if batch_start_time != 0: 29 | # elpased = batch_end_time - batch_start_time 30 | # print("time per batch:",elpased) 31 | 32 | optimizer.zero_grad() 33 | 34 | # Estimate the affine transform parameters 35 | # start_time = time.time() 36 | tnf_batch = pair_generation_tnf(batch) 37 | theta = model(tnf_batch).view(-1,2,3) # [16,6] 38 | # elpased = calculate_diff_time(start_time) 39 | # print("CNN parameter estimation",str(elpased)) #0.09s 40 | 41 | # start_time = time.time() 42 | # Generate the sampling grid, the native PyTorch way 43 | sampling_grid = gridGen(theta) 44 | 45 | # elpased = calculate_diff_time(start_time) 46 | # print("sampling grid generation", str(elpased)) # 0.0002s 47 | 48 | # start_time = time.time() 49 | # Produce the source, target and warped images 50 | source_image_batch = tnf_batch['source_image'] 51 | target_image_batch = tnf_batch['target_image'] 52 | warped_image_batch = F.grid_sample(source_image_batch, sampling_grid) 53 | 54 | # elpased = calculate_diff_time(start_time) 55 | # print("warping, three images", str(elpased)) # 0.00008s 56 | 57 | # start_time = time.time() 58 | loss = loss_fn(target_image_batch, warped_image_batch) 59 | 60 | loss_reduced = utils.reduce_loss(loss) 61 | 62 | # elpased = calculate_diff_time(start_time) 63 | # print("loss computation",str(elpased)) # 0.11s 64 | 65 | # start_time = time.time() 66 | loss.backward() 67 | optimizer.step() 68 | 69 | if lr_scheduler: 70 | lr_scheduler.step() 71 | 72 | train_loss += loss_reduced 73 | 74 | # elpased = calculate_diff_time(start_time) 75 | # print("backward pass", str(elpased)) # 0.009s 76 | 77 | # start_time = time.time() 78 | if batch_idx % log_interval == 0 and rank == 0: 79 | 80 | vis.drawImage((source_image_batch).detach(), 81 | (warped_image_batch).detach(), 82 | (target_image_batch).detach(),False) 83 | 84 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format( 85 | epoch, batch_idx , len(dataloader), 86 | 100. * batch_idx / len(dataloader), loss.data.item()),'average loss',loss_reduced.item()) 87 | # elpased = calculate_diff_time(start_time) 88 | # print('drawing figures',str(elpased)) # 0.3s 89 | 90 | # elpased = calculate_diff_time(total_start_time) 91 | # print('one batch total time',elpased) 92 | 93 | 94 | train_loss /= len(dataloader) 95 | # train_loss = train_loss 96 | if lr_scheduler: 97 | print('learning rate:',lr_scheduler.get_lr()[-1]) 98 | print('Train set: Average loss: {:.4f}'.format(train_loss)) 99 | print('Time:',time.asctime(time.localtime(time.time()))) 100 | 101 | return train_loss 102 | 103 | 104 | def test(model,loss_fn,dataloader,pair_generation_tnf,gridGen,use_cuda=True): 105 | model.eval() 106 | test_loss = 0 107 | for batch_idx, batch in enumerate(dataloader): 108 | 109 | with torch.no_grad(): 110 | tnf_batch = pair_generation_tnf(batch) 111 | theta = model(tnf_batch).view(-1,2,3) 112 | 113 | sampling_grid = gridGen(theta) 114 | 115 | # Produce the source, target and warped images 116 | source_image_batch = tnf_batch['source_image'] 117 | target_image_batch = tnf_batch['target_image'] 118 | warped_image_batch = F.grid_sample(source_image_batch, sampling_grid) 119 | loss, g1xy, g2xy = loss_fn(target_image_batch, warped_image_batch) 120 | 121 | test_loss += loss.data 122 | 123 | # if batch_idx % 10 == 0: 124 | # print('test Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format( 125 | # epoch, batch_idx, len(dataloader), 126 | # 100. 
* batch_idx / len(dataloader), loss.data)) 127 | 128 | test_loss /= len(dataloader) 129 | test_loss = test_loss.item() 130 | print('Test set: Average loss: {:.4f}'.format(test_loss)) 131 | return test_loss 132 | 133 | 134 | 
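# --- Editor's note: sketch of how these loops are typically driven, added for
# illustration. The model, dataloaders, pair generator and visdom helper are built
# in the training scripts, not in this file, so the names below are placeholders:
#
#   loss_fn = NTGLoss()                          # from evluate.lossfunc
#   gridGen = AffineGridGen()                    # imported above
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   for epoch in range(1, num_epochs + 1):
#       train_loss = train(epoch, model, loss_fn, optimizer, train_loader,
#                          pair_generation_tnf, gridGen, vis)
#       test_loss = test(model, loss_fn, test_loader, pair_generation_tnf, gridGen)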
-------------------------------------------------------------------------------- /main/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io 3 | import numpy as np 4 | 5 | import torch 6 | from collections import OrderedDict 7 | import cv2 8 | 9 | from model.cnn_registration_model import CNNRegistration 10 | from ntg_pytorch.register_func import estimate_aff_param_iterator 11 | from tnf_transform.img_process import preprocess_image 12 | from tnf_transform.transformation import AffineTnf, affine_transform_opencv, affine_transform_pytorch, \ 13 | affine_transform_opencv_2 14 | from util.pytorchTcv import theta2param 15 | from traditional_ntg.estimate_affine_param import estimate_param_batch 16 | from visualization.matplot_tool import plot_batch_result 17 | from visualization.train_visual import VisdomHelper 18 | 19 | 20 | def register_images(source_image_path,target_image_path,use_cuda=True): 21 | 22 | env_name = 'compare_ntg_realize' 23 | vis = VisdomHelper(env_name) 24 | 25 | # Create the model 26 | ntg_model = CNNRegistration(single_channel=True,use_cuda=use_cuda) 27 | 28 | print("Loading trained model weights") 29 | print("ntg_checkpoint_path:",ntg_checkpoint_path) 30 | 31 | # Load all tensors onto the CPU (GPU ==> CPU) 32 | ntg_checkpoint = torch.load(ntg_checkpoint_path,map_location=lambda storage,loc: storage) 33 | ntg_checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in ntg_checkpoint['state_dict'].items()]) 34 | ntg_model.load_state_dict(ntg_checkpoint['state_dict']) 35 | 36 | source_image_raw = io.imread(source_image_path) 37 | 38 | target_image_raw = io.imread(target_image_path) 39 | 40 | source_image = source_image_raw 41 | target_image = target_image_raw 42 | 43 | source_image_var = preprocess_image(source_image,resize=True,use_cuda=use_cuda) 44 | target_image_var = preprocess_image(target_image,resize=True,use_cuda=use_cuda) 45 | 46 | # source_image_var = source_image_var[:,0,:,:][:,np.newaxis,:,:] 47 | # target_image_var = target_image_var[:,0,:,:][:,np.newaxis,:,:] 48 | 49 | batch = {'source_image': source_image_var, 'target_image':target_image_var} 50 | 51 | affine_tnf = AffineTnf(use_cuda=use_cuda) 52 | 53 | ntg_model.eval() 54 | theta = ntg_model(batch) 55 | 56 | ntg_param_batch = estimate_param_batch(source_image_var[:,0,:,:], target_image_var[:,2,:,:], None) 57 | ntg_image_warped_batch = affine_transform_opencv_2(source_image_var, ntg_param_batch) 58 | 59 | theta_opencv = theta2param(theta.view(-1, 2, 3), 240, 240, use_cuda=use_cuda) 60 | cnn_ntg_param_batch = estimate_param_batch(source_image_var[:,0,:,:], target_image_var[:,2,:,:], theta_opencv) 61 | 62 | cnn_image_warped_batch = affine_transform_pytorch(source_image_var,theta) 63 | cnn_ntg_image_warped_batch = affine_transform_opencv_2(source_image_var, cnn_ntg_param_batch) 64 | 65 | cnn_ntg_param_multi_batch = estimate_aff_param_iterator(source_image_var[:, 0, :, :].unsqueeze(1), 66 | target_image_var[:, 0, :, :].unsqueeze(1), 67 | theta_opencv, use_cuda=use_cuda, itermax=800) 68 | 69 | cnn_ntg_image_warped_mulit_batch = affine_transform_opencv_2(source_image_var, cnn_ntg_param_multi_batch.detach().cpu().numpy()) 70 | # cnn_ntg_image_warped_mulit_batch = affine_transform_opencv_2(source_image_var, theta_opencv.detach().cpu().numpy()) 71 | 72 | vis.showImageBatch(source_image_var, normailze=True, win='source_image_batch', title='source_image_batch') 73 | vis.showImageBatch(target_image_var, normailze=True, win='target_image_batch', title='target_image_batch') 74 | vis.showImageBatch(cnn_image_warped_batch, normailze=True, win='cnn_image_warped_batch', title='cnn_image_warped_batch') 75 | # Running NTG directly across different channels may simply fail 76 | # vis.showImageBatch(ntg_image_warped_batch, normailze=True, win='warped_image_batch', title='warped_image_batch') 77 | vis.showImageBatch(cnn_ntg_image_warped_mulit_batch, normailze=True, win='cnn_ntg_param_multi_batch', title='cnn_ntg_param_multi_batch') 78 | # vis.showImageBatch(cnn_ntg_image_warped_batch, normailze=True, win='cnn_ntg_wraped_image_old', 79 | # title='cnn_ntg_wraped_image_old') 80 | 81 | #plot_title = ["source","target",'ntg','cnn'] 82 | #plot_title = ["source","target","ntg",'cnn','cnn_ntg'] 83 | #plot_batch_result(source_image_var,target_image_var,ntg_image_warped_batch,cnn_ntg_image_warped_batch,plot_title=plot_title) 84 | #plot_batch_result(source_image_var,target_image_var,ntg_image_warped_batch,cnn_image_warped_batch,cnn_ntg_image_warped_batch,plot_title=plot_title) 85 | 86 | if __name__ == '__main__': 87 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 88 | ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/output/voc2012_coco2014_NTG_resnet101.pth.tar' 89 | 90 | use_cuda = torch.cuda.is_available() 91 | # source_image_path = '../datasets/row_data/multispectral/Ir.jpg' 92 | # source_image_path = '../datasets/row_data/multispectral/source.jpg' 93 | # target_image_path = '../datasets/row_data/multispectral/It.jpg' 94 | # target_image_path = '../datasets/row_data/multispectral/target.jpg' 95 | 96 | source_image_path = '../datasets/row_data/texs1.jpeg' 97 | target_image_path = '../datasets/row_data/test2.jpeg' 98 | 99 | register_images(source_image_path,target_image_path,use_cuda) 100 | -------------------------------------------------------------------------------- /evluate/lossfunc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from collections import Counter 5 | 6 | from tnf_transform.point_tnf import PointTnf 7 | import torch.nn.functional as F 8 | # from sklearn import metrics 9 | # 10 | # class MatualInfoLoss(nn.Module): 11 | # def __init__(self,use_cuda = True): 12 | # super(MatualInfoLoss,self).__init__() 13 | # 14 | # 15 | # def compute_matual_info(self,labels_pred,labels_true): 16 | # return metrics.mutual_info_score(labels_pred,labels_true) 17 | # 18 | # def entropy(self,labels): 19 | # prob_dict = Counter(labels) 20 | # s = sum(prob_dict.values()) 21 | # probs = np.array([i/s for i in prob_dict.values()]) 22 | # return -probs.dot(np.log(probs)) 23 | # 24 | # def get_sig_label(self,labels_pred,labels_true): 25 | # sig_label = ['%s%s'%(i,j) for i,j in zip(labels_pred,labels_true)] 26 | # return sig_label 27 | # 28 | # def matual_info_score(self,labels_pred,labels_true): 29 | # HA = self.entropy(labels_pred) 30 | # HB = self.entropy(labels_true) 31 | # HAB = self.entropy(self.get_sig_label(labels_pred,labels_true)) 32 | # MI = HA + HB - HAB 33 | # return MI 34 | # 35 | # def normalized_mutual_info_score(self,labels_pred,labels_true): 36 | # HA = self.entropy(labels_pred) 37 | # HB = self.entropy(labels_true) 38 | # HAB = self.entropy(self.get_sig_label(labels_pred, labels_true)) 39 | # MI = HA + HB - HAB 40 | # 
NMI = MI / (HA * HB)**0.5 41 | # return NMI 42 | # 43 | # def entropy_torch(self,labels): 44 | # prob_dict = Counter(labels) 45 | # s = sum(prob_dict.values()) 46 | # 47 | # 48 | 49 | #def matual_info_score_torch(self,labels_pred,labels_true): 50 | 51 | 52 | 53 | 54 | 55 | # matual_info_loss = MatualInfoLoss() 56 | # labels1 = [1,1,0,0,0] 57 | # labels2 = ['a','a','s','s','s'] 58 | # print(metrics.mutual_info_score(labels1,labels2)) 59 | # print(metrics.normalized_mutual_info_score(labels1,labels2)) 60 | # 61 | # print(matual_info_loss.matual_info_score(labels1,labels2)) 62 | # print(matual_info_loss.normalized_mutual_info_score(labels1,labels2)) 63 | 64 | 65 | 66 | # Computes the grid-point loss between two affine parameter sets, used to score the final result 67 | class GridLoss: 68 | def __init__(self,use_cuda,grid_size=240): 69 | # Define the virtual grid that will be transformed 70 | #self.axis_coords = np.linspace(-1,1,grid_size) 71 | self.axis_coords = np.linspace(0,grid_size,grid_size) 72 | self.N = grid_size*grid_size 73 | X,Y = np.meshgrid(self.axis_coords,self.axis_coords) 74 | X = np.reshape(X,(1,1,self.N)) 75 | Y = np.reshape(Y,(1,1,self.N)) 76 | P = np.concatenate((X,Y),1) 77 | self.pointTnf = PointTnf(use_cuda=use_cuda) 78 | self.P = torch.Tensor(P) 79 | if use_cuda: 80 | self.P = self.P.cuda() 81 | 82 | def compute_grid_loss(self,theta_estimate,theta_GT): 83 | # Expand the grid to the batch size 84 | batch_size = theta_estimate.size()[0] 85 | P = self.P.expand(batch_size,2,self.N) 86 | 87 | # Compute the loss between the estimated and ground-truth grid points 88 | P_estimate = self.pointTnf.affPointTnf(theta_estimate,P) 89 | P_GT = self.pointTnf.affPointTnf(theta_GT,P) 90 | 91 | # Mean point-wise Euclidean distance over the grid points 92 | P_diff = P_estimate - P_GT 93 | P_diff = torch.pow(P_diff[:,0,:],2) + torch.pow(P_diff[:,1,:],2) 94 | loss = torch.mean(torch.pow(P_diff,0.5),1) 95 | 96 | return loss 97 | 98 | def test_Grid_loss(): 99 | grid_loss = GridLoss(use_cuda=False,grid_size=10) 100 | theta1 = torch.Tensor(np.array([[0.4,0,0],[0,0.6,0]])).unsqueeze(0) 101 | #theta2 = torch.Tensor(np.array([[1.1,0,1],[0,1.1,1]])).unsqueeze(0) 102 | theta2 = torch.Tensor(np.array([[1.5,0,1],[0,1.5,1]])).unsqueeze(0) 103 | loss_value = grid_loss.compute_grid_loss(theta1,theta2) 104 | print(loss_value) 105 | 106 | 107 | # NTG loss function, used for training 108 | class NTGLoss(nn.Module): 109 | def __init__(self,use_cuda = True): 110 | super(NTGLoss,self).__init__() 111 | self.use_cuda = use_cuda 112 | 113 | def forward(self, *input): 114 | loss_batch = compute_ntg_pytorch(input[0],input[1]) 115 | #return torch.mean(loss_batch) 116 | return loss_batch 117 | 118 | 119 | def compute_ntg_pytorch(img1,img2): 120 | g1x, g1y = gradient_1order(img1) 121 | g2x, g2y = gradient_1order(img2) 122 | 123 | g1xy = torch.sqrt(torch.pow(g1x,2)+torch.pow(g1y,2)) 124 | g2xy = torch.sqrt(torch.pow(g2x,2)+torch.pow(g2y,2)) 125 | 126 | m1 = func_rho_torch(g1x - g2x, 0) + func_rho_torch(g1y - g2y, 0) 127 | n1 = func_rho_torch(g1x, 0) + func_rho_torch(g2x, 0) + func_rho_torch(g1y, 0) + func_rho_torch(g2y, 0) 128 | #y1 = m1 / (n1 + 0.01) 129 | y1 = m1 / (n1 + 1e-16) 130 | 131 | #print(y1) 132 | return y1 133 | 134 | def gradient_1order(x,h_x=None,w_x=None): 135 | if h_x is None and w_x is None: 136 | h_x = x.size()[2] 137 | w_x = x.size()[3] 138 | r = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:] 139 | l = F.pad(x, (1, 0, 0, 0))[:, :, :, :w_x] 140 | t = F.pad(x, (0, 0, 1, 0))[:, :, :h_x, :] 141 | b = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :] 142 | #xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5) 143 | # xgrad = (r - l).float() * 0.5 144 | # ygrad = (t - b).float() * 0.5 145 | xgrad = (r - l).float() 146 | ygrad = (t - 
b).float() 147 | return xgrad,ygrad 148 | 149 | def func_rho_torch(x,order,epsilon = 0.01,use_cuda = True): 150 | if use_cuda: 151 | epsilon = torch.Tensor([epsilon]).float().cuda() 152 | else: 153 | epsilon = torch.Tensor([epsilon]).float() 154 | if order == 0: 155 | y = torch.sqrt(torch.pow(x,2) + torch.pow(epsilon,2)) 156 | y = torch.sum(y) 157 | elif order == 1: 158 | y = x/torch.sqrt(torch.pow(x,2) + torch.pow(epsilon,2)) 159 | 160 | return y 161 | -------------------------------------------------------------------------------- /demo_video.ipynb: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import matplotlib.pyplot as plt 3 | import os 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset, DataLoader 8 | from model.cnn_geometric_model import CNNGeometric 9 | from data.pf_dataset import PFDataset 10 | from data.download_datasets import download_PF_willow 11 | from image.normalization import NormalizeImageDict, normalize_image 12 | from util.torch_util import BatchTensorToVars, str_to_bool 13 | from geotnf.transformation import GeometricTnf 14 | from geotnf.point_tnf import * 15 | import matplotlib.pyplot as plt 16 | from skimage import io 17 | import warnings 18 | from torchvision.transforms import Normalize 19 | from collections import OrderedDict 20 | import numpy as np 21 | from IPython.display import clear_output 22 | 23 | def visualize(image): 24 | # plt.figure(figsize=(4, 4)) 25 | plt.axis('off') 26 | plt.imshow(image) 27 | plt.show() 28 | 29 | resizeCNN = GeometricTnf(out_h=240, out_w=240, use_cuda = False) 30 | 31 | def preprocess_image(image): 32 | # convert to torch Variable 33 | image = np.expand_dims(image.transpose((2,0,1)),0) 34 | image = torch.Tensor(image.astype(np.float32)/255.0) 35 | image_var = Variable(image,requires_grad=False) 36 | 37 | # Resize image using bilinear sampling with identity affine tnf 38 | image_var = resizeCNN(image_var) 39 | 40 | # Normalize image 41 | image_var = normalize_image(image_var) 42 | 43 | return image_var 44 | 45 | 46 | use_cuda = torch.cuda.is_available() 47 | def build_model(): 48 | feature_extraction_cnn = 'vgg' 49 | 50 | print(use_cuda) 51 | 52 | # model_aff_path = '/home/41875/zlk_236150/AIstation_back/best_checkpoint_adam_offset_grid_lossvgg.pth.tar' 53 | 54 | # server best epoch 108: validation average loss:0.0001 train_loss:0.00006 55 | model_aff_path = '/home/41875/zlk_236150/openProject/cnngeometric_pytorch/trained_models/best_checkpoint_adam_offset_grid_lossvgg.pth.tar' 56 | 57 | # Create model 58 | print('Creating CNN model...') 59 | 60 | model_aff = CNNGeometric(use_cuda=use_cuda,feature_extraction_cnn=feature_extraction_cnn,output_dim=2) 61 | 62 | checkpoint = torch.load(model_aff_path, map_location=lambda storage, loc: storage) 63 | checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()]) 64 | model_aff.load_state_dict(checkpoint['state_dict']) 65 | 66 | model_aff.eval() 67 | print("create model succeed") 68 | 69 | return model_aff 70 | 71 | affine_matrix_p1 = torch.tensor([[1.0, 0.0],[0.0, 1.0]]) 72 | affine_matrix_p1 = affine_matrix_p1.cuda() 73 | affine_matrix_p1 = affine_matrix_p1.repeat(1,1,1) 74 | 75 | def registerImages(model_aff,source_image,target_image): 76 | 77 | source_image_var = preprocess_image(source_image) 78 | target_image_var = preprocess_image(target_image) 79 | if use_cuda: 80 | source_image_var = 
source_image_var.cuda() 81 | target_image_var = target_image_var.cuda() 82 | 83 | batch = {'source_image': source_image_var, 'target_image':target_image_var} 84 | 85 | theta_aff = model_aff(batch) 86 | theta_aff = theta_aff.unsqueeze(2) 87 | theta_aff = torch.cat((affine_matrix_p1,theta_aff),2) 88 | 89 | return theta_aff,batch 90 | 
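# --- Editor's note (illustration, not part of the original notebook): build_model()
# is created with output_dim=2, so the network regresses only a translation offset
# (tx, ty). The torch.cat above pastes the fixed identity block in front of the
# predicted column to form a full 2x3 affine matrix:
#
#   [[1, 0],        [[tx],          [[1, 0, tx],
#    [0, 1]]  cat    [ty]]   -->     [0, 1, ty]]
#
# so the 'offset' GeometricTnf below warps the source purely by the predicted shift.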
91 | affTnf = GeometricTnf(geometric_model='offset', use_cuda=use_cuda) 92 | normalizeTnf = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 93 | resizeTgt = GeometricTnf(out_h=240, out_w=240, use_cuda = use_cuda) 94 | 95 | def aff_images(batch,theta_aff): 96 | warped_image_aff = affTnf(batch['source_image'],theta_aff.view(-1,2,3)) 97 | warped_image_aff_np = normalize_image(resizeTgt(warped_image_aff),forward=False).data.squeeze(0).transpose(0,1).transpose(1,2).cpu().numpy() 98 | return warped_image_aff_np 99 | 100 | model_aff = build_model() 101 | 102 | 103 | def show_compares(source_image,target_image,warped_image_aff_np): 104 | source_image = np.clip(source_image,a_min=0,a_max=1) 105 | target_image = np.clip(target_image,a_min=0,a_max=1) 106 | warped_image_aff_np = np.clip(warped_image_aff_np,a_min=0,a_max=1) 107 | N_subplots = 3 108 | fig, axs = plt.subplots(1,N_subplots) 109 | axs[0].imshow(source_image) 110 | axs[0].set_title('origin') 111 | axs[1].imshow(target_image) 112 | axs[1].set_title('target') 113 | subplot_idx = 2 114 | 115 | axs[subplot_idx].imshow(warped_image_aff_np) 116 | axs[subplot_idx].set_title('aff') 117 | subplot_idx +=1 118 | 119 | for i in range(N_subplots): 120 | axs[i].axis('off') 121 | 122 | fig.set_dpi(150) 123 | plt.show() 124 | 125 | def plot_xy_curve(theta_xy): 126 | plt.figure() 127 | # print(theta_xy.shape) 128 | offset_x = theta_xy[:,0] 129 | offset_y = theta_xy[:,1] 130 | # print(offset_x.shape) 131 | 132 | plt.plot([i for i in range(len(offset_x))], offset_x,label="offset_x") 133 | plt.plot([i for i in range(len(offset_x))], offset_y,label="offset_y") 134 | 135 | plt.title('offset-frame') 136 | plt.xlabel('frame') 137 | plt.ylabel('offset') 138 | 139 | plt.show() 140 | 141 | import time 142 | import cv2 143 | warnings.filterwarnings("ignore") 144 | current_time = 0 145 | 146 | # Image processing function 147 | def processImg(img): 148 | # Draw a rectangle 149 | #cv2.rectangle(img, (500, 300), (800, 400), (0, 0, 255), 5, 1, 0) 150 | # Flip vertically 151 | # img= cv2.flip(img, 0) 152 | 153 | # Display FPS 154 | global current_time 155 | if current_time == 0: 156 | current_time = time.time() 157 | else: 158 | last_time = current_time 159 | current_time = time.time() 160 | fps = 1. / (current_time - last_time) 161 | text = "FPS: %d" % int(fps) 162 | cv2.putText(img, text , (0,100), cv2.FONT_HERSHEY_TRIPLEX, 3.65, (255, 0, 0), 2) 163 | 164 | return img 165 | 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /ntg_pytorch/register_pyramid.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import cv2 4 | import scipy.misc as smi 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | from PIL.Image import BICUBIC, BILINEAR 11 | 12 | 13 | def compute_pyramid(image_batch,f,nL,ration,use_cuda=False): 14 | ''' 15 | For now this way of building the image pyramid appears to give slightly better accuracy 16 | :param image_batch: 17 | :param f: 18 | :param nL: 19 | :param ration: 20 | :return: 21 | ''' 22 | 23 | image_batch = image_batch.transpose(1,2).transpose(2,3) 24 | 25 | if use_cuda: 26 | image_batch = image_batch.cpu().numpy() 27 | else: 28 | image_batch = image_batch.numpy() 29 | multi_level_image_batch = [] 30 | multi_level_image_batch.append(torch.from_numpy(image_batch.transpose((0,3,1,2)))) 31 | current_ration = ration 32 | for level in range(1,nL): 33 | 34 | level_image_batch = [] 35 | for image_item in image_batch: 36 | # image_item = cv2.filter2D(image_item,-1,f) 37 | # image_item = image_item[:,:,np.newaxis] 38 | # level_image = smi.imresize(tmp,size=current_ration,interp='cubic')/255.0 39 | # level_image = smi.imresize(image_item[:,:,0],size=current_ration)/255.0 40 | level_image = np.array(Image.fromarray(image_item[:,:,0]).resize( 41 | (int(image_item.shape[0] * current_ration), int(image_item.shape[1] * current_ration)), resample=BICUBIC)) 42 | 43 | if len(level_image.shape) == 2: 44 | level_image = level_image[:,:,np.newaxis] 45 | level_image_batch.append(level_image) 46 | 47 | level_image_batch = np.array(level_image_batch).transpose((0,3,1,2)) 48 | level_image_batch = torch.from_numpy(level_image_batch).float() 49 | 50 | if use_cuda: 51 | level_image_batch = level_image_batch.cuda() 52 | 53 | 54 | multi_level_image_batch.append(level_image_batch) 55 | current_ration = current_ration * ration 56 | 57 | # plt.figure() 58 | # plt.imshow(level_image_batch[0].squeeze()) 59 | 60 | 61 | return multi_level_image_batch 62 | 63 | def compute_pyramid_iter(image_batch,f,nL,ration,use_cuda=False): 64 | ''' 65 | For now this way of building the image pyramid appears to give slightly better accuracy 66 | :param image_batch: [batch,channel,h,w] 67 | :param f: 68 | :param nL: 69 | :param ration: 70 | :return: 71 | ''' 72 | 73 | image_batch = image_batch.transpose(1,2).transpose(2,3) 74 | 75 | if use_cuda: 76 | image_batch = image_batch.cpu().numpy() 77 | else: 78 | image_batch = image_batch.numpy() 79 | multi_level_image_batch = [] 80 | multi_level_image_batch.append(torch.from_numpy(image_batch.transpose((0,3,1,2)))) 81 | for level in range(1,nL): 82 | 83 | level_image_batch = [] 84 | for i in range(len(image_batch)): 85 | # temp_image = np.squeeze(image_batch[i]) 86 | temp_image = cv2.filter2D(image_batch[i], -1, f) 87 | temp_image = np.array(Image.fromarray(temp_image).resize( 88 | (math.ceil(temp_image.shape[0] * ration), math.ceil(temp_image.shape[1] * ration)),resample=BICUBIC)) 89 | 90 | if len(temp_image.shape) == 2: 91 | temp_image = temp_image[:,:,np.newaxis] 92 | level_image_batch.append(temp_image) 93 | 94 | image_batch = level_image_batch 95 | 96 | level_image_batch = np.array(level_image_batch).transpose((0,3,1,2)) 97 | level_image_batch = torch.from_numpy(level_image_batch).float() 98 | 
99 | if use_cuda: 100 | level_image_batch = level_image_batch.cuda() 101 | 102 | 103 | multi_level_image_batch.append(level_image_batch) 104 | 105 | # plt.figure() 106 | # plt.imshow(level_image_batch[0].squeeze()) 107 | 108 | 109 | return multi_level_image_batch 110 | 111 | 112 | def compute_pyramid_pytorch(image_batch,scaleTnf,filter,nL,ration,use_cuda=False): 113 | ''' 114 | :param image_batch: [batch,channel,h,w] Tensor 115 | :param filter: 116 | :param nL: 117 | :param ration: 118 | :return: 119 | ''' 120 | 121 | kernel = torch.Tensor(filter).unsqueeze(0).unsqueeze(0) 122 | 123 | if use_cuda: 124 | kernel = kernel.cuda() 125 | 126 | multi_level_image_batch = [] 127 | multi_level_image_batch.append(image_batch) 128 | for level in range(1,nL): 129 | 130 | image_batch = scaleTnf(image_batch,ration) 131 | 132 | # Gaussian-filter the images; in practice Gaussian filtering was found to reduce accuracy 133 | # image_batch = F.conv2d(image_batch, kernel, padding=0) 134 | # image_batch = F.pad(image_batch, (1, 1, 1, 1), mode='reflect') 135 | 136 | # plt.figure() 137 | # plt.imshow(image_batch[0].squeeze()) 138 | 139 | multi_level_image_batch.append(image_batch) 140 | 141 | return multi_level_image_batch 142 | 143 | 144 | def compute_image_pyramid_single(image, f, nL, ration): 145 | P = [] 146 | tmp = image 147 | P.append(tmp) 148 | 149 | for m in range(1, nL): 150 | tmp = cv2.filter2D(tmp, -1, f) 151 | # Use skimage to resize images: https://scikit-image.org/docs/stable/auto_examples/transform/plot_rescale.html 152 | img = np.array(Image.fromarray(tmp).resize((int(tmp.shape[0] * ration), int(tmp.shape[1] * ration)))) 153 | tmp = img 154 | P.append(tmp[np.newaxis, :, :]) 155 | 156 | return P 157 | 158 | 159 | class ScaleTnf: 160 | 161 | def __init__(self,use_cuda=False): 162 | self.theta_identity = torch.tensor([ 163 | [1, 0, 0], 164 | [0, 1, 0] 165 | ], dtype=torch.float).unsqueeze(0) #[batch,2,3] 166 | 167 | if use_cuda: 168 | self.theta_identity = self.theta_identity.cuda() 169 | def __call__(self,image_batch, ration): 170 | batch_size, c, h, w = image_batch.size() 171 | out_size = torch.Size((batch_size, c, int(h * ration), int(w * ration))) 172 | grid = F.affine_grid(self.theta_identity.repeat(batch_size,1,1), out_size) 173 | output = F.grid_sample(image_batch, grid) 174 | return output 175 | 176 | 177 | 178 | 179 | 
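# --- Editor's note: minimal sketch of the pure-PyTorch pyramid path, added for
# illustration. A 1x1x64x64 batch is shrunk by 1/1.5 per level through ScaleTnf's
# identity affine_grid resampling; the Gaussian kernel argument goes unused here
# because the filtering lines inside compute_pyramid_pytorch are commented out.
if __name__ == '__main__':
    x = torch.rand(1, 1, 64, 64)
    pyramid = compute_pyramid_pytorch(x, ScaleTnf(use_cuda=False),
                                      filter=np.ones((3, 3)) / 9.0,
                                      nL=4, ration=1 / 1.5)
    print([tuple(p.shape) for p in pyramid])  # 64x64 -> 42x42 -> 28x28 -> 18x18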
-------------------------------------------------------------------------------- /pyqt/register_v1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'register.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.1 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | import time 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | from PyQt5.QtGui import QPalette, QColor 13 | from PyQt5.QtWidgets import QFileDialog, QApplication 14 | from main.paper_test import RegisterHelper 15 | 16 | class Ui_Dialog(object): 17 | 18 | def __init__(self): 19 | # self.source_imagePath = '../datasets/row_data/multispectral/door1.jpg' 20 | # self.target_imagePath = '../datasets/row_data/multispectral/door1.jpg' 21 | self.source_imagePath = None 22 | self.target_imagePath = None 23 | self.itermax = 800 24 | self.registerHelper = RegisterHelper() 25 | 26 | 27 | def setupUi(self, Dialog): 28 | Dialog.setObjectName("Dialog") 29 | Dialog.resize(607, 525) 30 | self.selectimg1 = QtWidgets.QPushButton(Dialog) 31 | # self.selectimg1.setGeometry(QtCore.QRect(50, 130, 75, 23)) 32 | self.selectimg1.setGeometry(QtCore.QRect(50, 150, 95, 43)) 33 | self.selectimg1.setObjectName("selectimg1") 34 | self.selectimg2 = QtWidgets.QPushButton(Dialog) 35 | self.selectimg2.setGeometry(QtCore.QRect(260, 150, 95, 43)) 36 | self.selectimg2.setObjectName("selectimg2") 37 | self.label = QtWidgets.QLabel(Dialog) 38 | self.label.setGeometry(QtCore.QRect(70, 380, 51, 16)) 39 | self.label.setObjectName("label") 40 | self.label_2 = QtWidgets.QLabel(Dialog) 41 | self.label_2.setGeometry(QtCore.QRect(280, 380, 54, 12)) 42 | self.label_2.setObjectName("label_2") 43 | self.label_3 = QtWidgets.QLabel(Dialog) 44 | self.label_3.setGeometry(QtCore.QRect(490, 380, 54, 12)) 45 | self.label_3.setObjectName("label_3") 46 | self.showimg1 = QtWidgets.QLabel(Dialog) 47 | self.showimg1.setGeometry(QtCore.QRect(10, 190, 181, 161)) 48 | self.showimg1.setObjectName("showimg1") 49 | self.showimg2 = QtWidgets.QLabel(Dialog) 50 | self.showimg2.setGeometry(QtCore.QRect(200, 190, 181, 161)) 51 | self.showimg2.setObjectName("showimg2") 52 | self.showimg3 = QtWidgets.QLabel(Dialog) 53 | self.showimg3.setGeometry(QtCore.QRect(410, 190, 181, 161)) 54 | self.showimg3.setObjectName("showimg3") 55 | self.registerButton = QtWidgets.QPushButton(Dialog) 56 | self.registerButton.setGeometry(QtCore.QRect(450, 450, 85, 33)) 57 | self.registerButton.setObjectName("registerButton") 58 | 59 | self.showimg1.setAutoFillBackground(True) 60 | self.showimg2.setAutoFillBackground(True) 61 | self.showimg3.setAutoFillBackground(True) 62 | background_color = QColor() 63 | background_color.setNamedColor('#FFFFFF') 64 | palette = QPalette() 65 | palette.setColor(QPalette.Window, background_color) 66 | self.showimg1.setPalette(palette) 67 | self.showimg2.setPalette(palette) 68 | self.showimg3.setPalette(palette) 69 | 70 | self.selectimg1.clicked.connect(self.openImage1) 71 | self.selectimg2.clicked.connect(self.openImage2) 72 | 73 | self.registerButton.clicked.connect(self.register) 74 | 75 | self.retranslateUi(Dialog) 76 | QtCore.QMetaObject.connectSlotsByName(Dialog) 77 | 78 | # self.retranslateUi(Dialog)  # duplicate of the call above, left disabled 79 | # QtCore.QMetaObject.connectSlotsByName(Dialog) 80 | 81 | # self.register() 82 | 83 | def retranslateUi(self, Dialog): 84 | _translate = QtCore.QCoreApplication.translate 85 | Dialog.setWindowTitle(_translate("Dialog", "Dialog")) 86 | self.selectimg1.setText(_translate("Dialog", "Select image to register")) 87 | self.selectimg2.setText(_translate("Dialog", "Select template image")) 88 | self.label.setText(_translate("Dialog", "source")) 89 | self.label_2.setText(_translate("Dialog", "target")) 90 | self.label_3.setText(_translate("Dialog", "warped")) 91 | self.showimg1.setText(_translate("Dialog", "source")) 92 | self.showimg2.setText(_translate("Dialog", "target")) 93 | 
self.showimg3.setText(_translate("Dialog", "warped")) 94 | self.registerButton.setText(_translate("Dialog", "register")) 95 | 96 | def openImage1(self): 97 | imgName, imgType = QFileDialog.getOpenFileName(self, "Open image", "", "*.jpg;;*.png;;All Files(*)") 98 | jpg = QtGui.QPixmap(imgName).scaled(self.showimg1.width(), self.showimg1.height()) 99 | self.source_imagePath = imgName 100 | self.showimg1.setPixmap(jpg) 101 | 102 | def openImage2(self): 103 | imgName, imgType = QFileDialog.getOpenFileName(self, "Open image", "", "*.jpg;;*.png;;All Files(*)") 104 | jpg = QtGui.QPixmap(imgName).scaled(self.showimg2.width(), self.showimg2.height()) 105 | self.target_imagePath = imgName 106 | self.showimg2.setPixmap(jpg) 107 | 108 | def register(self,type='cnn'): 109 | ''' 110 | :param type: cnn | ntg | cnn_ntg 111 | :return: 112 | ''' 113 | print(self.source_imagePath,self.target_imagePath," type:"+type) 114 | 115 | if type == 'cnn': 116 | image = self.registerHelper.register_CNN(self.source_imagePath,self.target_imagePath) 117 | elif type == 'ntg': 118 | image = self.registerHelper.register_NTG(self.source_imagePath,self.target_imagePath,itermax=self.itermax) 119 | else: 120 | image = self.registerHelper.register_CNN_NTG(self.source_imagePath,self.target_imagePath,itermax=self.itermax) 121 | print("Registration finished, displaying image") 122 | image = QtGui.QImage(image,image.shape[1],image.shape[0],QtGui.QImage.Format_RGB888) 123 | print("QImage done") 124 | pix = QtGui.QPixmap(image).scaled(self.showimg3.width(), self.showimg3.height()) 125 | print("Pixmap filled") 126 | self.showimg3.setPixmap(pix) 127 | print("Display done") 128 | time.sleep(0.5) 129 | QApplication.processEvents() # notify() needs (receiver, event); processEvents() flushes the pending UI updates intended here 130 | -------------------------------------------------------------------------------- /ntg_pytorch/register_loss.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | from tnf_transform.transformation import affine_transform_pytorch, affine_transform_opencv, \ 9 | affine_transform_opencv_batch 10 | from traditional_ntg.loss_function import deriv_filt 11 | from util.pytorchTcv import param2theta 12 | 13 | def ntg_gradient_torch(objdict,p,use_cuda = False): 14 | options = objdict['parser'] 15 | source_image_batch = objdict['source_images'] 16 | target_image_batch = objdict['target_images'] 17 | 18 | p_pytorch = param2theta(p, source_image_batch.shape[2], source_image_batch.shape[3], use_cuda=use_cuda) 19 | warpI = affine_transform_pytorch(source_image_batch, p_pytorch) 20 | 21 | # Results are worse than the PyTorch path 22 | # warpI = affine_transform_opencv_batch(source_image_batch,p,use_cuda=use_cuda) 23 | 24 | batch, c, h, w = source_image_batch.shape 25 | 26 | # I_source_x,I_source_y = deriv_filt_pytorch(source_image_batch, False, use_cuda) 27 | # Ipx = affine_transform_opencv_batch(I_source_x,p,use_cuda=use_cuda) 28 | # Ipy = affine_transform_opencv_batch(I_source_y,p,use_cuda=use_cuda) 29 | 30 | # Ipx = affine_transform_pytorch(I_source_x,p_pytorch) 31 | # Ipy = affine_transform_pytorch(I_source_y,p_pytorch) 32 | 33 | # Directly using the gradients of the warped image is less accurate than warping the gradient maps 34 | Ipx, Ipy = deriv_filt_pytorch(warpI, False, use_cuda) 35 | It = warpI - target_image_batch 36 | 37 | J = compute_ntg_pytorch(target_image_batch, warpI, use_cuda) 38 | 39 | [Itx, Ity] = deriv_filt_pytorch(It, False, use_cuda) 40 | 41 | rho_x = func_rho_pytorch(Itx, 1,use_cuda= use_cuda) - J.reshape(batch,1,1,1) * func_rho_pytorch(Ipx, 1,use_cuda= use_cuda) 42 | 
rho_y = func_rho_pytorch(Ity, 1,use_cuda= use_cuda) - J.reshape(batch,1,1,1) * func_rho_pytorch(Ipy, 1,use_cuda= use_cuda) 43 | 44 | [wxx, wxy] = deriv_filt_pytorch(rho_x, True, use_cuda) 45 | [wyx, wyy] = deriv_filt_pytorch(rho_y, True, use_cuda) 46 | 47 | w = wxx + wyy 48 | 49 | # w = w.squeeze() 50 | g = torch.zeros((p.shape[0],6, 1)) 51 | if use_cuda: 52 | g = g.cuda() 53 | g[:,0] = torch.mean(w * objdict['X_array'] * Ipx,(2,3)) 54 | g[:,1] = torch.mean(w * objdict['Y_array'] * Ipx,(2,3)) 55 | g[:,2] = torch.mean(w * Ipx,(2,3)) 56 | g[:,3] = torch.mean(w * objdict['X_array'] * Ipy,(2,3)) 57 | g[:,4] = torch.mean(w * objdict['Y_array'] * Ipy,(2,3)) 58 | g[:,5] = torch.mean(w * Ipy,(2,3)) 59 | 60 | g = g.reshape(-1, 2, 3) 61 | 62 | return g 63 | 64 | def compute_ntg_pytorch(img1,img2,use_cuda=True): 65 | g1x, g1y = deriv_filt_pytorch(img1,False,use_cuda= use_cuda) 66 | g2x, g2y = deriv_filt_pytorch(img2,False,use_cuda= use_cuda) 67 | 68 | # g1xy = torch.sqrt(torch.pow(g1x,2)+torch.pow(g1y,2)) 69 | # g2xy = torch.sqrt(torch.pow(g2x,2)+torch.pow(g2y,2)) 70 | 71 | m1 = func_rho_pytorch(g1x - g2x, 0,use_cuda= use_cuda) + func_rho_pytorch(g1y - g2y, 0,use_cuda= use_cuda) 72 | n1 = func_rho_pytorch(g1x, 0,use_cuda= use_cuda) + func_rho_pytorch(g2x, 0,use_cuda= use_cuda) + \ 73 | func_rho_pytorch(g1y, 0,use_cuda= use_cuda) + func_rho_pytorch(g2y, 0,use_cuda= use_cuda) 74 | y1 = m1 / (n1 + 1e-16) 75 | 76 | #print(y1) 77 | return y1 78 | 79 | def deriv_filt_pytorch(I,isconj,use_cuda=False): 80 | ''' 81 | :param I: input Tensor of shape [batch_size,channel,h,w] 82 | :param isconj: 83 | :return: 84 | ''' 85 | # I = I.squeeze().numpy() 86 | # Ix,Iy = deriv_filt(I,isconj) 87 | # Ix = torch.Tensor(Ix).unsqueeze(0).unsqueeze(0) 88 | # Iy = torch.Tensor(Iy).unsqueeze(0).unsqueeze(0) 89 | # 90 | # return Ix,Iy 91 | 92 | batch,channel,h,w = I.shape 93 | if not isconj: 94 | kernel_x = [[-0.5,0,0.5]] 95 | kernel_y = [[-0.5],[0],[0.5]] 96 | else: 97 | kernel_x = [[0.5,0,-0.5]] 98 | kernel_y = [[0.5],[0],[-0.5]] 99 | 100 | kernel_x = torch.Tensor(kernel_x).unsqueeze(0) 101 | kernel_y = torch.Tensor(kernel_y).unsqueeze(0) 102 | 103 | kernel_x = kernel_x.expand(channel,1,kernel_x.shape[1],kernel_x.shape[2]) 104 | kernel_y = kernel_y.expand(channel,1,kernel_y.shape[1],kernel_y.shape[2]) 105 | 106 | if use_cuda: 107 | kernel_x = kernel_x.cuda() 108 | kernel_y = kernel_y.cuda() 109 | 110 | ## Note: the Ix and Iy here are not the same as in the cv2 filter version 111 | # Ix = F.conv2d(I,kernel_x,padding=1)[:,:,1:-1,:] # top/bottom are 0 112 | # Iy = F.conv2d(I,kernel_y,padding=1)[:,:,:,1:-1] # left/right are 0 113 | 114 | # I = F.pad(I,(1,1,1,1),mode='reflect') 115 | # Ix = F.conv2d(I,kernel_x,padding=0)[:,:,1:-1,:] # top/bottom are 0 116 | # Iy = F.conv2d(I,kernel_y,padding=0)[:,:,:,1:-1] # left/right are 0 117 | 118 | 119 | ## Note! Different padding modes lead to different results 120 | # groups enables grouped convolution, applying the kernel per channel, so multi-channel inputs such as RGB images work 121 | # Unlike cv2.filter2D, this is a plain correlation with no post-processing. 122 | # Zero padding, the same treatment as OpenCV's filter2D 123 | # Ix_pad = F.pad(I,(1,1,0,0),mode='replicate') 124 | # Iy_pad = F.pad(I,(0,0,1,1),mode='replicate') 125 | 126 | Ix_pad = F.pad(I,(1,1,0,0),mode='reflect') 127 | Iy_pad = F.pad(I,(0,0,1,1),mode='reflect') 128 | 129 | 130 | Ix = F.conv2d(Ix_pad,kernel_x,padding=0,groups=channel) # top/bottom are 0 131 | Iy = F.conv2d(Iy_pad,kernel_y,padding=0,groups=channel) # left/right are 0 132 | 133 | 134 | # This padding step strongly affects the final result 135 | # Ix = F.pad(Ix,(1,1,0,0),mode='reflect') 136 | # Iy = F.pad(Iy,(0,0,1,1),mode='reflect') 137 | 138 | 139 | # Ix = torch.max(Ix,torch.Tensor([0])) 140 | # Iy = torch.max(Iy,torch.Tensor([0])) 141 | 142 | # Ix = 
F.pad(Ix,(1,1,0,0),mode='circular') 143 | # Iy = F.pad(Iy,(0,0,1,1),mode='circular') 144 | 145 | return Ix,Iy 146 | 147 | 148 | def func_rho_pytorch(x,order,epsilon=0.01,use_cuda=False): 149 | if use_cuda: 150 | epsilon = torch.Tensor([epsilon]).float().cuda() 151 | else: 152 | epsilon = torch.Tensor([epsilon]).float() 153 | if order == 0: 154 | y = torch.sqrt(torch.pow(x, 2) + torch.pow(epsilon, 2)) 155 | y = torch.sum(y.reshape(x.shape[0], -1), 1) 156 | elif order == 1: 157 | y = x / torch.sqrt(torch.pow(x, 2) + torch.pow(epsilon, 2)) 158 | 159 | return y 160 | 
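# --- Editor's note: quick parity check between the OpenCV and PyTorch derivative
# filters, added for illustration. The comments above warn that border handling and
# the Ix/Iy naming differ between the two implementations, so both pairings are
# reported and the printed numbers identify which one matches.
if __name__ == '__main__':
    from traditional_ntg.loss_function import deriv_filt

    img = np.random.rand(32, 32).astype(np.float32)
    Ix_cv, Iy_cv = deriv_filt(img, False)
    Ix_t, Iy_t = deriv_filt_pytorch(torch.from_numpy(img)[None, None], False, use_cuda=False)
    Ix_t, Iy_t = Ix_t.squeeze().numpy(), Iy_t.squeeze().numpy()
    print('same-name ', np.abs(Ix_cv - Ix_t).max(), np.abs(Iy_cv - Iy_t).max())
    print('cross-name', np.abs(Ix_cv - Iy_t).max(), np.abs(Iy_cv - Ix_t).max())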
-------------------------------------------------------------------------------- /datasets/provider/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from skimage import io 4 | import numpy as np 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.utils.data import Dataset 8 | import time 9 | import cv2 10 | 11 | from random import choice 12 | from pathlib import Path 13 | 14 | from tqdm import tqdm 15 | 16 | from tnf_transform.img_process import random_affine 17 | from tnf_transform.transformation import AffineTnf 18 | from traditional_ntg.image_util import symmetricImagePad 19 | from util.csv_opeartor import read_csv_file 20 | from util.time_util import calculate_diff_time 21 | 22 | ''' 23 | For testing: randomly generates affine transform parameters and feeds them to the network. 24 | Passed to the dataloader; the custom getitem returns Sample{image,theta,name} 25 | ''' 26 | class TestDataset(Dataset): 27 | 28 | 29 | def __init__(self,training_image_path,label_path,output_size=(480,640),transform=None,cache_images = False,use_cuda = True): 30 | ''' 31 | :param training_image_path: 32 | :param output_size: 33 | :param transform: 34 | :param cache_images: if the dataset is not especially large, cache it in memory to speed up reading 35 | :param use_cuda: 36 | ''' 37 | self.out_h, self.out_w = output_size 38 | self.use_cuda = use_cuda 39 | self.cache_images = cache_images 40 | # read image file 41 | self.training_image_path = training_image_path 42 | self.train_data = sorted(os.listdir(self.training_image_path)) 43 | self.image_count = len(self.train_data) 44 | # bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index 45 | # nb = bi[-1] + 1 # number of batches 46 | if self.cache_images: 47 | self.imgs = [None]*self.image_count 48 | else: 49 | self.imgs = [] 50 | self.transform = transform 51 | self.control_size = 1000000000 52 | self.label_path = label_path 53 | self.csv_data = read_csv_file(self.label_path) # a DataFrame, essentially a table; if an index column is set as the key, values will not contain that key 54 | 55 | # cache images into memory for faster training(~5GB) 56 | if self.cache_images: 57 | for i in tqdm(range(min(self.image_count,self.control_size)),desc='Reading images'): # at most 10k images 58 | image_name = self.train_data[i] 59 | image_path = os.path.join(self.training_image_path, image_name) 60 | image = cv2.imread(image_path) # shape [h,w,c] BGR 61 | assert image is not None, 'Image Not Found' + image_path 62 | self.imgs[i] = image 63 | 64 | def __len__(self): 65 | return len(self.train_data) 66 | 67 | def __getitem__(self, idx): 68 | 69 | #total_start_time = time.time() 70 | 71 | image_name = self.train_data[idx] 72 | 73 | if self.cache_images: 74 | image = self.imgs[idx] 75 | else: 76 | image_path = os.path.join(self.training_image_path, image_name) 77 | image = cv2.imread(image_path) # shape [h,w,c] 78 | 79 | if image.shape[0]!= self.out_h or image.shape[1] != self.out_w: 80 | image = cv2.resize(image, (self.out_w, self.out_h), interpolation=cv2.INTER_LINEAR) 81 | 82 | image = image[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 83 | image = np.ascontiguousarray(image, dtype=np.float32) # uint8 to float32 84 | image = torch.from_numpy(image) 85 | 86 | label_row_param = self.csv_data.loc[self.csv_data['image'] == image_name].values 87 | label_row_param = np.squeeze(label_row_param) 88 | if image_name != label_row_param[0]: 89 | raise ValueError("image file name does not match the label file name") 90 | 91 | theta = label_row_param[1:].reshape(2,3) 92 | theta = torch.from_numpy(theta.astype(np.float32)) 93 | 94 | sample = {'image': image, 'theta': theta, 'name': image_name} 95 | 96 | if self.transform: 97 | sample = self.transform(sample) 98 | 99 | # elpased = calculate_diff_time(total_start_time) 100 | # print('getitem time:',elpased) # 0.011s 101 | 102 | return sample 103 | 104 | ''' 105 | Generate image pairs from the affine transform parameters 106 | Returns {"source_image, target_image, theta_GT, name"} 107 | ''' 108 | class NtgTestPair(object): 109 | 110 | def __init__(self, use_cuda=True, crop_factor=9 / 16, output_size=(240, 240), 111 | padding_factor=0.6): 112 | self.use_cuda = use_cuda 113 | self.crop_factor = crop_factor 114 | self.padding_factor = padding_factor 115 | self.out_h, self.out_w = output_size 116 | self.channel_choicelist = [0,2] 117 | self.rescalingTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda) 118 | self.geometricTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda) 119 | 120 | def __call__(self, batch): 121 | image_batch, theta_batch,image_name = batch['image'], batch['theta'],batch['name'] 122 | if self.use_cuda: 123 | image_batch = image_batch.cuda() 124 | theta_batch = theta_batch.cuda() 125 | 126 | b, c, h, w = image_batch.size() 127 | 128 | # Generate symmetrically padded images for a larger sampling region 129 | # image_batch = symmetricImagePad(image_batch, self.padding_factor,use_cuda=self.use_cuda) 130 | self.crop_factor = 1.0 131 | self.padding_factor = 1.0 132 | 133 | indices_R = torch.tensor([0]) 134 | indices_G = torch.tensor([2]) 135 | 136 | if self.use_cuda: 137 | indices_R = indices_R.cuda() 138 | indices_G = indices_G.cuda() 139 | 140 | image_batch_R = torch.index_select(image_batch, 1, indices_R) 141 | image_batch_G = torch.index_select(image_batch, 1, indices_G) 142 | 143 | # Get the cropped image 144 | cropped_image_batch = self.rescalingTnf(image_batch_R, None, self.padding_factor, 145 | self.crop_factor) # Identity is used as no theta given 146 | # Get the cropped and transformed image 147 | warped_image_batch = self.geometricTnf(image_batch_G, theta_batch, 148 | self.padding_factor, 149 | self.crop_factor) # Identity is used as no theta given 150 | 151 | return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch, 152 | 'name': image_name}
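# --- Editor's note (illustration): the pair generator above fakes a multispectral
# pair from a single RGB image, taking channel R as the source and a warped channel
# G as the target. A typical wiring, with a placeholder label CSV produced by the
# dataset_process generators (an 'image' column plus six affine parameters):
#
#   dataset = TestDataset('../row_data/COCO/',
#                         '../row_data/label_file/coco_test2017_paper_param_n2000.csv')
#   loader = DataLoader(dataset, batch_size=16)
#   pair_tnf = NtgTestPair(use_cuda=False)
#   pair = pair_tnf(next(iter(loader)))  # {'source_image','target_image','theta_GT','name'}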
--------------------------------------------------------------------------------
/traditional_ntg/estimate_affine_param.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import traditional_ntg.image_util as imgUtil
7 | import traditional_ntg.compute_image_pyramid as pyramid
8 | import math
9 | import traditional_ntg.loss_function as lossfn
10 | import matplotlib.pyplot as plt
11 | from skimage import io
12 | 
13 | # Batch wrapper around the traditional method
14 | from util.time_util import calculate_diff_time
15 | 
16 | # Optimize using the traditional (iterative NTG) method
17 | def estimate_param_batch(source_image_batch,target_image_batch,theta_opencv_batch=None,itermax = 800):
18 |     '''
19 |     :param source_image_batch: Tensor[batch_size,C,h,w]
20 |     :param target_image_batch: Tensor[batch_size,C,h,w]
21 |     :param theta_opencv_batch: Tensor[batch_size,3,3]
22 |     :return: transform parameters in OpenCV convention
23 |     '''
24 |     ntg_param_batch = []
25 | 
26 |     for i in range(len(source_image_batch)):
27 |         if theta_opencv_batch is None:
28 |             ntg_param = estimate_affine_param(target_image_batch[i].squeeze().detach().cpu().numpy(),
29 |                                               source_image_batch[i].squeeze().detach().cpu().numpy(),
30 |                                               None, itermax=itermax)
31 |         else:
32 |             ntg_param = estimate_affine_param(target_image_batch[i].squeeze().detach().cpu().numpy(),
33 |                                               source_image_batch[i].squeeze().detach().cpu().numpy(),
34 |                                               theta_opencv_batch[i].detach().cpu().numpy(), itermax=itermax)
35 | 
36 |         ntg_param_batch.append(ntg_param)
37 | 
38 |     ntg_param_batch = torch.Tensor(ntg_param_batch)
39 |     return ntg_param_batch
40 | 
41 | # Single-pair run of the traditional method
42 | # img1 is the target_image, img2 is the source_image
43 | def estimate_affine_param(img1,img2,p=None,itermax = 800,custom_pyramid_level=-1):
44 |     '''
45 |     :param img1: the target_image [h,w]
46 |     :param img2: the source_image
47 |     :param p: if None, initialized to the identity matrix; otherwise iteration continues from p
48 |     :param itermax:
49 |     :return: parameters in OpenCV convention
50 |     '''
51 | 
52 |     if len(img1.shape) == 3:
53 |         img1 = img1[0,:,:]
54 | 
55 |     if len(img2.shape) == 3:
56 |         img2 = img2[0,:,:]
57 | 
58 |     options = {}
59 |     options['tol'] = 1e-6
60 |     options['itermax'] = itermax
61 |     #options['itermax'] = 100
62 |     options['minSize'] = 16
63 |     options['pyramid_spacing'] = 1.5
64 |     options['display'] = True
65 |     options['deriv_filter'] = np.array([-0.5, 0, 0.5])
66 |     options['deriv_filter_conj'] = np.array([0.5, 0, -0.5])
67 |     if p is None:
68 |         options['initial_affine_param'] = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
69 |     else:
70 |         options['initial_affine_param'] = np.copy(p)
71 |     #options['initial_affine_param'] = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
72 |     pyramid_level1 = 1 + np.floor(np.log(img1.shape[0] / options["minSize"]) / np.log(options["pyramid_spacing"]))
73 |     pyramid_level2 = 1 + np.floor(np.log(img1.shape[1] / options["minSize"]) / np.log(options["pyramid_spacing"]))
74 |     options['pyramid_levels'] = np.min((int(pyramid_level1),int(pyramid_level2)))
75 |     #options['pyramid_levels'] = 6
76 |     if custom_pyramid_level > 0:
77 |         options['pyramid_levels'] = custom_pyramid_level
78 | 
79 |     IMAX = np.max([np.max(img1),np.max(img2)])
80 |     IMIN = np.min([np.min(img1),np.min(img2)])
81 | 
82 |     images1 = imgUtil.scale_image(img1,IMIN,IMAX)
83 |     images2 = imgUtil.scale_image(img2,IMIN,IMAX)
84 | 
85 |     images = np.stack((images1,images2),axis=2)
86 |     smooth_sigma = np.sqrt(options['pyramid_spacing']) / np.sqrt(3)
87 | 
88 |     # plt.figure()
89 |     # plt.imshow(img1,cmap='gray')
90 |     # plt.figure()
91 |     # plt.imshow(images1,cmap=plt.cm.gray_r)
92 |     # plt.show()
93 | 
94 |     kx = cv2.getGaussianKernel(int(2*round(1.5*smooth_sigma))+1,smooth_sigma)
95 |     ky = cv2.getGaussianKernel(int(2*round(1.5*smooth_sigma))+1,smooth_sigma)
96 |     hg = np.multiply(kx,np.transpose(ky))
97 | 
98 |     start_time = time.time()
99 |     pyramid_images = pyramid.compute_image_pyramid(images,hg,int(options['pyramid_levels']),1/options['pyramid_spacing'])
100 |     elpased = calculate_diff_time(start_time)
101 |     #print('image pyramid time:',elpased)
102 | 
103 |     if p is not None:
104 |         # NOTE: verify that the translation components are rescaled correctly here
105 |         options['initial_affine_param'][0, 2] = options['initial_affine_param'][0, 2]/240 * pyramid_images[-1].shape[0]
106 |         options['initial_affine_param'][1, 2] = options['initial_affine_param'][1, 2]/240 * pyramid_images[-1].shape[0]
107 | 
108 |     start_time = time.time()
109 |     for k in range(options['pyramid_levels']-1,-1,-1):
110 |         if k == (options['pyramid_levels']-1):
111 |             p = options['initial_affine_param']
112 |             #p = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
113 |         else:
114 |             options['itermax'] = math.ceil(options['itermax']/options['pyramid_spacing'])
115 |             p[0,2] = p[0,2] *pyramid_images[k].shape[1]/pyramid_images[k+1].shape[1]
116 |             p[1,2] = p[1,2] *pyramid_images[k].shape[0]/pyramid_images[k+1].shape[0]
117 | 
118 |         # copy of the settings for the current level
119 |         small = {}
120 |         small['options'] = options
121 |         small['images'] = pyramid_images[k]
122 | 
123 |         # estimate the affine parameters at the current level
124 |         sz = [pyramid_images[k].shape[0],pyramid_images[k].shape[1]]
125 |         xlist = range(0,sz[1])
126 |         ylist = range(0,sz[0])
127 | 
128 |         [X,Y] = np.meshgrid(xlist,ylist)
129 | 
130 |         small['X'] = X/np.max(X)
131 |         small['Y'] = Y/np.max(Y)
132 | 
133 |         converged = False
134 |         iter = 0
135 |         #steplength = 0.5/np.max(sz)
136 |         steplength = 0.5/np.max(sz)
137 | 
138 |         while not converged:
139 |             g = lossfn.ntg_gradient(small,p)
140 |             if p is None:
141 |                 print("p is None")
142 |             p = p + steplength*g/np.max(np.abs(g+1e-16))
143 |             residualError = np.max(np.abs(g))
144 |             iter = iter + 1
145 |             converged = (iter>=options['itermax']) or (residualError < options['tol'])
146 |             #print(converged)
147 |             # if converged:
148 |             #     print(str(k)+" "+str(iter)+" "+ str(residualError))
149 | 
150 |     elpased = calculate_diff_time(start_time)
151 |     #print("iterative optimization time:",elpased)
152 | 
153 |     return p
154 | 
155 | 
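Illustrative call (a sketch, not repository code; the file names are placeholders): running the coarse-to-fine NTG estimation above on two single-channel images and warping with OpenCV.

    import cv2
    import numpy as np
    from traditional_ntg.estimate_affine_param import estimate_affine_param

    # Placeholder inputs: two spectrally different views of the same scene.
    target = cv2.imread('band_a.png', cv2.IMREAD_GRAYSCALE).astype(np.float64)
    source = cv2.imread('band_b.png', cv2.IMREAD_GRAYSCALE).astype(np.float64)

    p = estimate_affine_param(target, source, p=None, itermax=800)  # 2x3 matrix in OpenCV convention
    aligned = cv2.warpAffine(source, p.astype(np.float32), (target.shape[1], target.shape[0]))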
--------------------------------------------------------------------------------
/cnn_geometric/cnn_geometric_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torchvision.models as models
6 | 
7 | class FeatureExtraction(torch.nn.Module):
8 |     def __init__(self, use_cuda=True, feature_extraction_cnn='vgg', last_layer=''):
9 |         super(FeatureExtraction, self).__init__()
10 |         if feature_extraction_cnn == 'vgg':
11 |             self.model = models.vgg16(pretrained=True)
12 |             # keep feature extraction network up to indicated layer
13 |             vgg_feature_layers=['conv1_1','relu1_1','conv1_2','relu1_2','pool1','conv2_1',
14 |                                 'relu2_1','conv2_2','relu2_2','pool2','conv3_1','relu3_1',
15 |                                 'conv3_2','relu3_2','conv3_3','relu3_3','pool3','conv4_1',
16 |                                 'relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4',
17 |                                 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5']
18 |             if last_layer=='':
19 |                 last_layer = 'pool4'
20 |             last_layer_idx = vgg_feature_layers.index(last_layer)
21 |             self.model = nn.Sequential(*list(self.model.features.children())[:last_layer_idx+1])
22 |         if feature_extraction_cnn == 'resnet101':
23 |             self.model = models.resnet101(pretrained=True)
24 |             resnet_feature_layers = ['conv1',
25 |                                      'bn1',
26 |                                      'relu',
27 |                                      'maxpool',
28 |                                      'layer1',
29 |                                      'layer2',
30 |                                      'layer3',
31 |                                      'layer4']
32 |             if last_layer=='':
33 |                 last_layer = 'layer3'
34 |             last_layer_idx = resnet_feature_layers.index(last_layer)
35 |             resnet_module_list = [self.model.conv1,
36 |                                   self.model.bn1,
37 |                                   self.model.relu,
38 |                                   self.model.maxpool,
39 |                                   self.model.layer1,
40 |                                   self.model.layer2,
41 |                                   self.model.layer3,
42 |                                   self.model.layer4]
43 | 
44 |             self.model = nn.Sequential(*resnet_module_list[:last_layer_idx+1])
45 |         # freeze parameters
46 |         for param in self.model.parameters():
47 |             param.requires_grad = False
48 |         # move to GPU
49 |         if use_cuda:
50 |             self.model.cuda()
51 | 
52 |     def forward(self, image_batch):
53 |         return self.model(image_batch)
54 | 
55 | class FeatureL2Norm(torch.nn.Module):
56 | def __init__(self): 57 | super(FeatureL2Norm, self).__init__() 58 | 59 | def forward(self, feature): 60 | epsilon = 1e-6 61 | # print(feature.size()) 62 | # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size()) 63 | norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature) 64 | return torch.div(feature,norm) 65 | 66 | class FeatureCorrelation(torch.nn.Module): 67 | def __init__(self): 68 | super(FeatureCorrelation, self).__init__() 69 | 70 | def forward(self, feature_A, feature_B): 71 | b,c,h,w = feature_A.size() 72 | # reshape features for matrix multiplication 73 | feature_A = feature_A.transpose(2,3).contiguous().view(b,c,h*w) 74 | feature_B = feature_B.view(b,c,h*w).transpose(1,2) 75 | # perform matrix mult. 76 | feature_mul = torch.bmm(feature_B,feature_A) 77 | correlation_tensor = feature_mul.view(b,h,w,h*w).transpose(2,3).transpose(1,2) 78 | return correlation_tensor 79 | 80 | class FeatureRegression(nn.Module): 81 | def __init__(self, output_dim=6, use_cuda=True): 82 | super(FeatureRegression, self).__init__() 83 | self.conv = nn.Sequential( 84 | nn.Conv2d(225, 128, kernel_size=7, padding=0), 85 | nn.BatchNorm2d(128), 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(128, 64, kernel_size=5, padding=0), 88 | nn.BatchNorm2d(64), 89 | nn.ReLU(inplace=True), 90 | ) 91 | self.linear = nn.Linear(64 * 5 * 5, output_dim) 92 | if use_cuda: 93 | self.conv.cuda() 94 | self.linear.cuda() 95 | 96 | def forward(self, x): 97 | x = self.conv(x) 98 | x = x.view(x.size(0), -1) 99 | x = self.linear(x) 100 | return x 101 | 102 | class CNNGeometric(nn.Module): 103 | def __init__(self, geometric_model='affine', normalize_features=True, normalize_matches=True, batch_normalization=True, use_cuda=True, feature_extraction_cnn='vgg'): 104 | super(CNNGeometric, self).__init__() 105 | self.use_cuda = use_cuda 106 | self.normalize_features = normalize_features 107 | self.normalize_matches = normalize_matches 108 | self.FeatureExtraction = FeatureExtraction(use_cuda=self.use_cuda, feature_extraction_cnn=feature_extraction_cnn) 109 | self.FeatureL2Norm = FeatureL2Norm() 110 | self.FeatureCorrelation = FeatureCorrelation() 111 | if geometric_model=='affine': 112 | output_dim = 6 113 | elif geometric_model=='tps': 114 | output_dim = 18 115 | self.FeatureRegression = FeatureRegression(output_dim,use_cuda=self.use_cuda) 116 | self.ReLU = nn.ReLU(inplace=True) 117 | 118 | def forward(self, tnf_batch): 119 | # do feature extraction 120 | feature_A = self.FeatureExtraction(tnf_batch['source_image']) 121 | feature_B = self.FeatureExtraction(tnf_batch['target_image']) 122 | # normalize 123 | if self.normalize_features: 124 | feature_A = self.FeatureL2Norm(feature_A) 125 | feature_B = self.FeatureL2Norm(feature_B) 126 | # do feature correlation 127 | correlation = self.FeatureCorrelation(feature_A,feature_B) 128 | # normalize 129 | if self.normalize_matches: 130 | correlation = self.FeatureL2Norm(self.ReLU(correlation)) 131 | # correlation = self.FeatureL2Norm(correlation) 132 | # do regression to tnf parameters theta 133 | theta = self.FeatureRegression(correlation) 134 | 135 | return theta 136 | -------------------------------------------------------------------------------- /datasets/dataset_process/export_image_pairs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import torch 5 | from skimage import io 6 | import scipy.io as scio 7 | 8 | from ntg_pytorch.register_func import affine_transform 9 
| from tnf_transform.img_process import generate_affine_param
10 | from tnf_transform.transformation import affine_transform_opencv, affine_transform_opencv_2, affine_transform_pytorch, \
11 |     AffineTnf
12 | from traditional_ntg.image_util import symmetricImagePad
13 | from util.csv_opeartor import read_csv_file
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | 
17 | from util.pytorchTcv import param2theta, inverse_theta
18 | from visualization.train_visual import VisdomHelper
19 | 
20 | 
21 | def save_image_tensor(image_batch,output_name):
22 |     if isinstance(image_batch, torch.Tensor):
23 |         image_batch = image_batch.squeeze().numpy()
24 |         if len(image_batch.shape) == 3:
25 |             image_batch = image_batch.transpose(1, 2, 0)
26 | 
27 |     io.imsave(output_name,image_batch)
28 | 
29 | def get_image_information(image_dir,image_name,label_path,vis):
30 |     image_path = os.path.join(image_dir,image_name)
31 | 
32 |     image_np = io.imread(image_path)
33 |     csv_data = read_csv_file(label_path)
34 | 
35 |     label_row_param = csv_data.loc[csv_data['image'] == image_name].values
36 |     label_row_param = np.squeeze(label_row_param)
37 | 
38 |     if image_name != label_row_param[0]:
39 |         raise ValueError("Image file name does not match the file name in the label CSV")
40 | 
41 |     theta_aff = torch.from_numpy(label_row_param[1:].reshape(2, 3).astype(np.float32)).unsqueeze(0)
42 | 
43 | 
44 |     image_batch = torch.from_numpy(image_np).transpose(1,2).transpose(0,1).unsqueeze(0).float()
45 | 
46 |     vis.showImageBatch(image_batch, win='image_batch', title='raw_image_batch')
47 | 
48 |     crop_factor = 9 / 16
49 |     padding_factor = 0.6
50 |     # crop_factor = 3
51 |     # padding_factor = 0.9
52 | 
53 |     padding_image_batch = symmetricImagePad(image_batch,padding_factor=padding_factor)
54 | 
55 |     affTnf = AffineTnf(446, 640,use_cuda=False)
56 | 
57 |     # pixels mapped outside the image range become 0 after the transform
58 |     source_image_batch = affTnf(padding_image_batch, theta_aff, padding_factor,crop_factor)
59 |     target_image_batch = affTnf(padding_image_batch, None, padding_factor,crop_factor)
60 | 
61 |     # inverse_theta_aff = inverse_theta(theta_aff,use_cuda=False)
62 |     warped_image_batch = affTnf(target_image_batch,theta_aff,crop_factor=1,padding_factor=1)
63 | 
64 |     vis.showImageBatch(source_image_batch,win='source_image_batch',title='source_image_batch')
65 |     vis.showImageBatch(target_image_batch,win='target_image_batch',title='target_image_batch')
66 |     vis.showImageBatch(warped_image_batch,win='warped_image_batch',title='warped_image_batch')
67 | 
68 |     # save_image_tensor(image_batch,'raw.jpg')
69 |     save_image_tensor(source_image_batch,'source.jpg')
70 |     save_image_tensor(target_image_batch,'target.jpg')
71 |     save_image_tensor(warped_image_batch,'warped2.jpg')
72 | 
73 | def save_matlab_pic(image_data,theta_aff):
74 |     image_batch = torch.from_numpy(image_data).transpose(1, 2).transpose(0, 1).unsqueeze(1).float()
75 |     vis.showImageBatch(image_batch, win='image_batch', title='raw_image_batch',start_index=16)
76 | 
77 |     crop_factor = 9 / 16
78 |     padding_factor = 0.6
79 |     padding_image_batch = symmetricImagePad(image_batch, padding_factor=padding_factor)
80 |     affTnf = AffineTnf(240, 240, use_cuda=False)
81 |     # pixels mapped outside the image range become 0 after the transform
82 |     source_image_batch = affTnf(padding_image_batch, None, padding_factor, crop_factor)
83 |     target_image_batch = affTnf(padding_image_batch, theta_aff, padding_factor, crop_factor)
84 | 
85 |     vis.showImageBatch(source_image_batch, win='source_image_batch', title='source_image_batch',start_index=16)
86 |     vis.showImageBatch(target_image_batch, win='target_image_batch', title='target_image_batch',start_index=16)
87 | 
88 |     save_image_tensor(source_image_batch[16], 'mul_1s_s.png')
89 |     save_image_tensor(target_image_batch[16], 'mul_1t_s.png')
90 | 
91 | 
92 | def read_matlab_data(data_path):
93 |     image_name_list = os.listdir(data_path)
94 |     mat_image_path = os.path.join(data_path, image_name_list[0])
95 |     print(mat_image_path)
96 |     array_struct = scio.loadmat(mat_image_path)
97 |     array_data = array_struct['ms_image_denoised']
98 |     return array_data
99 | 
100 | def generate_matlab_pair():
101 | 
102 |     data_dir = '/mnt/4T/zlk/datasets/mulitspectral/Harvard'
103 |     array_data = read_matlab_data(data_dir)
104 |     small = True
105 |     if small:
106 |         theta = generate_affine_param(scale=1.1, degree=10, translate_x=-10, translate_y=10)
107 |     else:
108 |         theta = generate_affine_param(scale=1.25, degree=30, translate_x=-20, translate_y=20)
109 | 
110 |     theta = torch.from_numpy(theta).float()
111 |     a,b = theta.shape
112 |     theta = theta.expand(array_data.shape[2],a,b)
113 |     theta = param2theta(theta,240,240,use_cuda=False)
114 |     save_matlab_pic(array_data,theta)
115 | 
116 | def generate_cave_pair():
117 | 
118 |     image_dir = '/Users/zale/project/datasets/complete_ms_data/oil_painting_ms/oil_painting_ms/'
119 |     image_name = "oil_painting_ms_31.png"
120 |     image_np = cv2.imread(image_dir+image_name)
121 |     height,width,channel = image_np.shape
122 |     small = False
123 |     center = (height//2, width//2)
124 |     if small:
125 |         theta = generate_affine_param(scale=1.1, degree=10, translate_x=-10, translate_y=10,center=center)
126 |     else:
127 |         theta = generate_affine_param(scale=1.25, degree=30, translate_x=-20, translate_y=20,center=center)
128 | 
129 |     warped_np = cv2.warpAffine(image_np,theta,(width,height),flags=cv2.INTER_CUBIC)
130 |     cv2.imwrite(image_name, warped_np)
131 | 
132 | if __name__ == '__main__':
133 | 
134 |     env = "export_image_pairs"
135 |     #vis = VisdomHelper(env)
136 | 
137 |     use_remote = True
138 | 
139 |     if use_remote:
140 |         image_dir = '/home/zlk/datasets/coco_test2017_n2000'
141 |         label_path = '../../datasets/row_data/label_file/coco_test2017_n2000_custom_20r_param.csv'
142 |     else:
143 |         image_dir = '/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/multispectral'
144 |         label_path = '/Users/zale/project/myself/registration_cnn_ntg/datasets/row_data/label_file/coco_test2017_n2000_custom_20r_param.csv'
145 | 
146 | 
147 | 
148 |     image_name = "000000007226.jpg"
149 | 
150 | 
151 |     # get_image_information(image_dir,image_name,label_path,vis)
152 |     # generate_matlab_pair()
153 |     generate_cave_pair()
154 | 
155 | 
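A minimal sketch of the pair-generation step used by generate_cave_pair() above (the input file name is hypothetical):

    import cv2
    from tnf_transform.img_process import generate_affine_param

    image = cv2.imread('band.png')  # hypothetical input image
    h, w = image.shape[:2]
    theta = generate_affine_param(scale=1.1, degree=10, translate_x=-10, translate_y=10, center=(h // 2, w // 2))
    warped = cv2.warpAffine(image, theta, (w, h), flags=cv2.INTER_CUBIC)
    cv2.imwrite('band_warped.png', warped)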
--------------------------------------------------------------------------------
/datasets/provider/nirrgbData.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from skimage import io
4 | import numpy as np
5 | import torch
6 | from torch.autograd import Variable
7 | from torch.utils.data import Dataset
8 | import time
9 | import cv2
10 | 
11 | from random import choice
12 | from pathlib import Path
13 | 
14 | from tqdm import tqdm
15 | 
16 | from tnf_transform.img_process import random_affine, generator_affine_param
17 | from tnf_transform.transformation import AffineTnf
18 | from traditional_ntg.image_util import symmetricImagePad
19 | from util.csv_opeartor import read_csv_file
20 | from util.time_util import calculate_diff_time
21 | 
22 | '''
23 | For training: loads paired NIR/RGB images with ground-truth affine parameters from a CSV label file, which are then fed to the network.
24 | Passed to a DataLoader; the custom __getitem__ returns Sample{nir_image, rgb_image, theta, name}.
25 | '''
26 | class NirRgbData(Dataset):
27 | 
28 | 
29 |     def __init__(self,nir_path,rgb_path,label_path,output_size=(480,640),paper_affine_generator = False,transform=None,use_cuda = True):
30 |         '''
31 |         :param training_image_path:
32 |         :param output_size:
33 |         :param transform:
34 |         :param cache_images: cache images in memory for faster reading when the dataset is not too large
35 |         :param use_cuda:
36 |         '''
37 |         self.out_h, self.out_w = output_size
38 |         self.use_cuda = use_cuda
39 |         self.paper_affine_generator = paper_affine_generator
40 |         # read image file
41 |         self.nir_image_path = nir_path
42 |         self.rgb_image_path = rgb_path
43 |         self.nir_image_name_list = sorted(os.listdir(self.nir_image_path))
44 |         self.rgb_image_name_list = sorted(os.listdir(self.rgb_image_path))
45 |         self.csv_data = read_csv_file(label_path)
46 |         self.image_count = len(self.nir_image_name_list)
47 |         self.transform = transform
48 | 
49 |     def __len__(self):
50 |         return self.image_count
51 | 
52 |     def __getitem__(self, idx):
53 | 
54 |         nir_image_name = self.nir_image_name_list[idx]
55 |         rgb_image_name = self.rgb_image_name_list[idx]
56 | 
57 |         nir_image_path = os.path.join(self.nir_image_path, nir_image_name)
58 |         rgb_image_path = os.path.join(self.rgb_image_path, rgb_image_name)
59 |         nir_image = cv2.imread(nir_image_path)  # shape [h,w,c]
60 |         rgb_image = cv2.imread(rgb_image_path)  # shape [h,w,c]
61 | 
62 |         if nir_image.shape[0]!= self.out_h or nir_image.shape[1] != self.out_w:
63 |             nir_image = cv2.resize(nir_image, (self.out_w, self.out_h), interpolation=cv2.INTER_LINEAR)
64 | 
65 |         if rgb_image.shape[0]!= self.out_h or rgb_image.shape[1] != self.out_w:
66 |             rgb_image = cv2.resize(rgb_image, (self.out_w, self.out_h), interpolation=cv2.INTER_LINEAR)
67 | 
68 |         nir_image = nir_image[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
69 |         nir_image = np.ascontiguousarray(nir_image, dtype=np.float32)  # uint8 to float32
70 |         nir_image = torch.from_numpy(nir_image)
71 | 
72 |         rgb_image = rgb_image[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
73 |         rgb_image = np.ascontiguousarray(rgb_image, dtype=np.float32)  # uint8 to float32
74 |         rgb_image = torch.from_numpy(rgb_image)
75 | 
76 |         # if self.paper_affine_generator:
77 |         #     theta = generator_affine_param()
78 |         # else:
79 |         #     theta = random_affine()
80 |         label_row_param = self.csv_data.loc[self.csv_data['image'] == nir_image_name].values
81 |         label_row_param = np.squeeze(label_row_param)
82 |         if nir_image_name != label_row_param[0]:
83 |             raise ValueError("Image file name does not match the file name in the label CSV")
84 | 
85 |         theta_aff = label_row_param[1:].reshape(2,3)
86 | 
87 |         theta_aff_tensor = torch.Tensor(theta_aff.astype(np.float32))
88 | 
89 |         sample = {'nir_image': nir_image, 'rgb_image':rgb_image,'theta': theta_aff_tensor, 'name': nir_image_name}
90 | 
91 |         if self.transform:
92 |             sample = self.transform(sample)
93 | 
94 |         return sample
95 | 
96 | '''
97 | Generates an image pair from the affine transform parameters.
98 | Returns {"source_image, target_image, theta_GT, name"}
99 | '''
100 | class NirRgbTnsPair(object):
101 | 
102 |     def __init__(self, use_cuda=True, crop_factor=9 / 16, output_size=(240, 240),
103 |                  padding_factor=0.6):
104 |         self.use_cuda = use_cuda
105 |         self.crop_factor = crop_factor
106 |         self.padding_factor = padding_factor
107 |         self.out_h, self.out_w = output_size
108 |         self.channel_choicelist = [0,1,2]
109 |         self.rescalingTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
110 |         self.geometricTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
111 | 
112 |     def __call__(self, batch):
113 |         nir_image_batch,rgb_image_batch,theta_batch,image_name = batch['nir_image'], batch['rgb_image'],batch['theta'],batch['name']
114 |         if self.use_cuda:
115 |             nir_image_batch = nir_image_batch.cuda()
116 |             rgb_image_batch = rgb_image_batch.cuda()
117 |             theta_batch = theta_batch.cuda()
118 | 
119 |         b, c, h, w = nir_image_batch.size()
120 | 
121 |         # symmetrically pad the images to allow a larger sampling region
122 |         rgb_image_batch_pad = symmetricImagePad(rgb_image_batch, self.padding_factor,use_cuda=self.use_cuda)
123 |         nir_image_batch_pad = symmetricImagePad(nir_image_batch, self.padding_factor,use_cuda=self.use_cuda)
124 | 
125 |         # convert to variables: a Tensor holds the raw data and knows nothing about gradients;
126 |         # a Variable holds data, grad and grad_fn, where data is the Tensor
127 |         # image_batch = Variable(image_batch, requires_grad=False)
128 |         # theta_batch = Variable(theta_batch, requires_grad=False)
129 | 
130 |         # indices_R = torch.tensor([choice(self.channel_choicelist)])
131 |         # indices_G = torch.tensor([choice(self.channel_choicelist)])
132 | 
133 |         # indices_R = torch.tensor([1])
134 |         # indices_G = torch.tensor([0])
135 |         #
136 |         # if self.use_cuda:
137 |         #     indices_R = indices_R.cuda()
138 |         #     indices_G = indices_G.cuda()
139 |         #
140 |         # rgb_image_batch_pad = torch.index_select(rgb_image_batch_pad, 1, indices_R)
141 |         # nir_image_batch_pad = torch.index_select(nir_image_batch_pad, 1, indices_G)
142 | 
143 |         # get the cropped (source) image
144 |         cropped_image_batch = self.rescalingTnf(rgb_image_batch_pad, None, self.padding_factor,
145 |                                                 self.crop_factor)  # Identity is used as no theta given
146 |         # get the cropped and transformed (target) image
147 |         warped_image_batch = self.geometricTnf(nir_image_batch_pad, theta_batch,
148 |                                                self.padding_factor,
149 |                                                self.crop_factor)  # Identity is used as no theta given
150 | 
151 |         return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch,
152 |                 'name':image_name}
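Usage sketch for the NIR/RGB provider above (paths are hypothetical, not repository data):

    from torch.utils.data import DataLoader
    from datasets.provider.nirrgbData import NirRgbData, NirRgbTnsPair

    dataset = NirRgbData('/data/nir', '/data/rgb', '/data/labels.csv', output_size=(480, 640), use_cuda=False)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    make_pair = NirRgbTnsPair(use_cuda=False)

    batch = next(iter(loader))
    pair = make_pair(batch)  # {'source_image', 'target_image', 'theta_GT', 'name'}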
--------------------------------------------------------------------------------
/tnf_transform/transformation.py:
--------------------------------------------------------------------------------
1 | from skimage import io
2 | import pandas as pd
3 | import numpy as np
4 | import torch
5 | from torch.nn.modules.module import Module
6 | from torch.utils.data import Dataset
7 | from torch.autograd import Variable
8 | import torch.nn.functional as F
9 | import cv2
10 | 
11 | class AffineTnf(object):
12 |     def __init__(self,out_h=240,out_w=240,use_cuda=True):
13 |         self.out_h = out_h
14 |         self.out_w = out_w
15 |         self.use_cuda = use_cuda
16 |         self.gridGen = AffineGridGen(out_h,out_w)
17 |         self.theta_identity = torch.Tensor(np.expand_dims(np.array([[1, 0, 0], [0, 1, 0]]), 0).astype(np.float32))
18 |         if use_cuda:
19 |             self.theta_identity = self.theta_identity.cuda()
20 | 
21 |     def __call__(self, image_batch, theta_batch=None,padding_factor=1.0, crop_factor=1.0):
22 |         b, c, h, w = image_batch.size()
23 |         if theta_batch is None:
24 |             theta_batch = self.theta_identity
25 |             theta_batch = theta_batch.expand(b, 2, 3)  # broadcast the identity to add a batch dimension
26 |             #theta_batch = Variable(theta_batch, requires_grad=False)
27 | 
28 |         # generate the sampling grid
29 |         sampling_grid = self.gridGen(theta_batch)  # theta_batch [1,2,3] sampling_grid [1,360,640,2]
30 | 
31 |         # rescale grid according to crop_factor and padding_factor
32 |         sampling_grid.data = sampling_grid.data * padding_factor * crop_factor
33 |         # sample transformed image
34 |         warped_image_batch = F.grid_sample(image_batch,
35 |                                            sampling_grid)  # image_batch[1,1,360,640] warped_image_batch[1,1,360,640]
36 |         return warped_image_batch
37 | 
38 | 
39 | class AffineGridGen(Module):
40 |     def __init__(self, out_h=240, out_w=240, out_ch=3):
41 |         super(AffineGridGen, self).__init__()
42 |         self.out_h = out_h
43 |         self.out_w = out_w
44 |         self.out_ch = out_ch
45 |         self.factor = out_w/out_h
46 | 
47 |     def forward(self, theta):
48 |         theta = theta.contiguous()
49 |         batch_size = theta.size()[0]
50 |         out_size = torch.Size((batch_size, self.out_ch, self.out_h, self.out_w))
51 |         return F.affine_grid(theta, out_size)
52 | 
53 | # Generates x/y grid points from PyTorch-convention affine parameters; currently unused
54 | def generate_grid(theta,out_h=240,out_w=240):  # theta cuda:0 (16,2,3)
55 |     factor = out_w/out_h
56 |     batch_size = theta.size()[0]
57 |     identity_theta = torch.zeros(batch_size,2,3)
58 |     identity_theta += torch.Tensor(np.expand_dims(np.array([[1, 0, 0], [0, 1, 0]]), 0).astype(np.float32))
59 |     identity_theta = identity_theta.cuda()
60 | 
61 |     custom_grid = F.affine_grid(identity_theta, (theta.size()[0], 1, out_h, out_w))
62 |     custom_grid = custom_grid + 1  # (1,339,568,2)
63 |     vector_x = custom_grid[:,:, :, 0].reshape(batch_size,-1) * factor
64 |     vector_y = custom_grid[:,:, :, 1].reshape(batch_size,-1)
65 |     vector_ones = torch.ones(vector_x.size())
66 |     vector_cat = torch.stack((vector_x.double().cuda(), vector_y.double().cuda(), vector_ones.double().cuda()),1).float()
67 | 
68 |     result = torch.bmm(theta, vector_cat)
69 |     result[:,0, :] = result[:,0, :] / factor - 1
70 |     result[:,1, :] = result[:,1, :] - 1
71 |     result = result.reshape(batch_size,2, out_h, out_w)
72 |     result = result.permute(0, 2, 3, 1).float()
73 |     return result
74 | 
75 | 
76 | # Affine transform using PyTorch-convention parameters
77 | def affine_transform_pytorch(image_batch,theta_batch):
78 |     '''
79 |     :param image_batch: image batch Tensor[batch_size,C,240,240]
80 |     :param theta_batch: parameter batch Tensor[batch_size,2,3]
81 |     :return: warped image batch warped_image_batch Tensor[batch_size,C,240,240]
82 |     '''
83 |     theta_batch = theta_batch.reshape(-1, 2, 3)
84 |     _, channel, height, width = image_batch.shape
85 | 
86 |     theta_batch = theta_batch.contiguous()
87 |     batch_size = theta_batch.size()[0]
88 |     out_size = torch.Size((batch_size, channel, height, width))
89 |     affine_grid = F.affine_grid(theta_batch, out_size)
90 | 
91 |     warped_image_batch = F.grid_sample(image_batch, affine_grid)
92 | 
93 |     return warped_image_batch
94 | 
95 | def affine_transform_opencv_batch(image_batch,theta_batch,use_cuda = False):
96 |     '''
97 |     :param image_batch: image batch Tensor[batch_size,C,240,240]
98 |     :param theta_batch: parameter batch Tensor[batch_size,2,3]
99 |     :return: warped image batch warped_image_batch Tensor[batch_size,C,240,240]
100 |     '''
101 | 
102 |     image_batch = image_batch.cpu()
103 |     theta_batch = theta_batch.cpu()
104 | 
105 | 
106 |     theta_batch = theta_batch.reshape(-1, 2, 3)
107 |     batch_size, channel, height, width = image_batch.shape
108 | 
109 |     image_batch = image_batch.numpy().transpose((0,2,3,1))
110 |     theta_batch = theta_batch.numpy()
111 | 
112 |     warped_image_batch = []
113 | 
114 |     for i in range(batch_size):
115 |         warped_image = cv2.warpAffine(image_batch[i],theta_batch[i],(width,height),flags=cv2.INTER_CUBIC)
116 |         if len(warped_image.shape)==2:
117 |             warped_image = warped_image[:,:,np.newaxis]
118 |         warped_image_batch.append(warped_image)
119 | 
120 |     warped_image_batch = np.array(warped_image_batch).transpose((0,3,1,2))
121 |     warped_image_batch= torch.from_numpy(warped_image_batch).float()
122 | 
123 |     if use_cuda:
124 |         warped_image_batch = warped_image_batch.cuda()
125 |     return warped_image_batch
126 | 
127 | # Affine transform using OpenCV
128 | def single_affine_transform_opencv(im, p):
129 |     height = im.shape[0]
130 |     width = im.shape[1]
131 |     im = cv2.warpAffine(im,p,(width,height))
132 |     return im
133 | 
134 | '''
135 | Deprecated; scheduled for removal.
136 | '''
137 | def affine_transform_opencv(image_batch, theta_batch):
138 |     '''
139 |     :param image_batch: Tensor[batch_size,C,240,240]
140 |     :param theta_batch: Tensor[batch_size,2,3]
141 |     :return: warped_img_batch: Tensor[batch_size,240,240]
142 |     '''
143 |     image_batch = image_batch.squeeze(1).detach().cpu().numpy()
144 |     theta_batch = torch.Tensor(theta_batch)
145 |     warped_img_batch = []
146 |     for i in range(len(image_batch)):
147 |         source_img = image_batch[i].squeeze()
148 |         height = source_img.shape[0]
149 |         width = source_img.shape[1]
150 |         warped_img = cv2.warpAffine(source_img, theta_batch[i].numpy(), (width, height))[np.newaxis,:,:]
151 |         warped_img_batch.append(warped_img)
152 | 
153 |     warped_img_batch = torch.Tensor(warped_img_batch)
154 | 
155 |     return warped_img_batch
156 | 
157 | def affine_transform_opencv_2(image_batch, theta_batch):
158 |     '''
159 |     :param image_batch: Tensor[batch_size,C,240,240]
160 |     :param theta_batch: Tensor[batch_size,2,3]
161 |     :return: warped_img_batch: Tensor[batch_size,C,240,240]
162 |     '''
163 |     image_batch = image_batch.detach().cpu().numpy()
164 |     theta_batch = torch.Tensor(theta_batch)
165 |     warped_img_batch = []
166 |     for i in range(len(image_batch)):
167 |         source_img = image_batch[i].transpose(1,2,0)
168 |         height = source_img.shape[0]
169 |         width = source_img.shape[1]
170 |         warped_img = cv2.warpAffine(source_img, theta_batch[i].numpy(), (width, height))
171 |         warped_img_batch.append(warped_img.transpose(2,0,1))
172 | 
173 |     warped_img_batch = torch.Tensor(warped_img_batch)
174 | 
175 |     return warped_img_batch
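A small sketch of the PyTorch-convention warp above (shapes assumed, run on CPU):

    import torch
    from tnf_transform.transformation import affine_transform_pytorch

    image_batch = torch.rand(2, 3, 240, 240)
    theta = torch.tensor([[1.0, 0.0, 0.1],
                          [0.0, 1.0, 0.0]]).unsqueeze(0).expand(2, 2, 3)  # small x-shift in normalized coordinates
    warped = affine_transform_pytorch(image_batch, theta)
    print(warped.shape)  # torch.Size([2, 3, 240, 240])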
--------------------------------------------------------------------------------
/model/cnn_registration_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torchvision.models as models
6 | import numpy as np
7 | 
8 | """
9 | Feature extractor: drops the backbone's final layers (e.g., the fully connected layer) and freezes the parameters - a form of transfer learning.
10 | """
11 | 
12 | class FeatureExtraction(torch.nn.Module):
13 |     def __init__(self, use_cuda=True, feature_extraction_cnn='vgg', last_layer='',single_channel = False):
14 |         super(FeatureExtraction, self).__init__()
15 |         if feature_extraction_cnn == 'vgg':
16 |             self.model = models.vgg16(pretrained=True)
17 |             vgg_feature_layers = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1',
18 |                                   'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1',
19 |                                   'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1',
20 |                                   'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4',
21 |                                   'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'pool5']
22 |             if last_layer == '':
23 |                 last_layer = 'pool4'
24 |             last_layer_idx = vgg_feature_layers.index(last_layer)
25 |             self.model = nn.Sequential(*list(self.model.features.children())[:last_layer_idx + 1])
26 |         if feature_extraction_cnn == 'resnet101':
27 |             self.model = models.resnet101(pretrained=True)
28 |             # adapt the first conv layer so the network accepts single-channel images
29 |             if single_channel:
30 |                 self.model.conv1.in_channels=1
31 |                 self.model.conv1.weight.data = self.model.conv1.weight.data[:,0,:,:][:,np.newaxis,:,:]
32 |             #
33 |             resnet_feature_layers = ['conv1',
34 |                                      'bn1',
35 |                                      'relu',
36 |                                      'maxpool',
37 |                                      'layer1',
38 |                                      'layer2',
39 |                                      'layer3',
40 |                                      'layer4']
41 |             if last_layer == '':
42 |                 last_layer = 'layer3'
43 |             last_layer_idx = resnet_feature_layers.index(last_layer)
44 |             resnet_module_list = [self.model.conv1,
45 |                                   self.model.bn1,
46 |                                   self.model.relu,
47 |                                   self.model.maxpool,
48 |                                   self.model.layer1,
49 |                                   self.model.layer2,
50 |                                   self.model.layer3,
51 |                                   self.model.layer4]
52 | 
53 |             self.model = nn.Sequential(*resnet_module_list[:last_layer_idx + 1])
54 |         # freeze parameters
55 |         for param in self.model.parameters():
56 |             param.requires_grad = False
57 |         # move to GPU
58 |         if use_cuda:
59 |             self.model.cuda()
60 | 
61 |     def forward(self, image_batch):
62 |         return self.model(image_batch)
63 | 
64 | """
65 | Feature L2 normalization: each feature value is divided by the L2 norm, norm = sqrt(x1^2 + x2^2 + ... + xn^2)
66 | """
67 | class FeatureL2Norm(torch.nn.Module):
68 |     def __init__(self):
69 |         super(FeatureL2Norm, self).__init__()
70 | 
71 |     def forward(self, feature):
72 |         epsilon = 1e-6
73 |         # print(feature.size())
74 |         # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size())
75 |         norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature)
76 |         return torch.div(feature,norm)
77 | 
78 | """
79 | Correlation between the two feature maps.
80 | """
81 | class FeatureCorrelation(torch.nn.Module):
82 |     def __init__(self):
83 |         super(FeatureCorrelation, self).__init__()
84 | 
85 |     def forward(self, feature_A, feature_B):
86 |         b, c, h, w = feature_A.size()
87 |         # reshape features for matrix multiplication
88 |         feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h * w)  # layer3 ([16, 1024, 225]) layer2 ([16, 512, 900])
89 |         feature_B = feature_B.view(b, c, h * w).transpose(1, 2)  # layer3 ([16, 225, 1024]) layer2 ([16, 900, 512])
90 |         # perform matrix mult.
91 |         feature_mul = torch.bmm(feature_B, feature_A)  # torch.bmm: for batch1 of shape b×n×m and batch2 of shape b×m×p, the output has shape b×n×p; layer3 ([16, 225, 225]) layer2 (16,900,900)
92 |         correlation_tensor = feature_mul.view(b, h, w, h * w).transpose(2, 3).transpose(1, 2)  # layer3 ([16, 225, 15, 15])
93 |         return correlation_tensor
94 | 
95 | """
96 | Feature regressor: two convolutional layers plus one fully connected layer regress the 6 parameters.
97 | """
98 | class FeatureRegression(nn.Module):
99 |     def __init__(self, output_dim=6, use_cuda=True):
100 |         super(FeatureRegression, self).__init__()
101 |         self.conv = nn.Sequential(
102 |             nn.Conv2d(225, 128, kernel_size=7, padding=0),  # layer2 (16,900,30,30) layer3 ([16, 225, 15, 15])
103 |             #nn.Conv2d(900, 128, kernel_size=7, padding=0),
104 |             nn.BatchNorm2d(128),
105 |             nn.ReLU(inplace=True),
106 |             nn.Conv2d(128, 64, kernel_size=5, padding=0),
107 |             nn.BatchNorm2d(64),
108 |             nn.ReLU(inplace=True),
109 |         )
110 |         self.linear = nn.Linear(64 * 5 * 5, output_dim)  # layer3 ([16, 225, 15, 15]) -> (16,64,5,5)
111 |         #self.linear = nn.Linear(64 * 20 * 20, output_dim)  # layer2 (16,900,30,30) -> (16,64,20,20)
112 |         if use_cuda:
113 |             self.conv.cuda()
114 |             self.linear.cuda()
115 | 
116 |     def forward(self, x):
117 |         x = self.conv(x)
118 |         x = x.view(x.size(0), -1)
119 |         x = self.linear(x)
120 |         return x
121 | 
122 | class CNNRegistration(nn.Module):
123 |     def __init__(self,single_channel,normalize_features=True,normalize_matches=True,use_cuda=True,feature_extraction_cnn='resnet101'):
124 |         super(CNNRegistration,self).__init__()
125 |         self.use_cuda = use_cuda
126 |         self.normalize_features = normalize_features
127 |         self.normalize_matches = normalize_matches
128 |         self.FeatureExtraction = FeatureExtraction(use_cuda=self.use_cuda,feature_extraction_cnn=feature_extraction_cnn,single_channel=single_channel)
129 |         self.FeatureL2Norm = FeatureL2Norm()
130 |         self.FeatureCorrelation = FeatureCorrelation()
131 |         output_dim = 6  # six parameters regressed by the fully connected layer
132 |         self.FeatureRegression = FeatureRegression(output_dim,use_cuda=self.use_cuda)
133 |         self.ReLu = nn.ReLU(inplace=True)
134 | 
135 |     def forward(self,img_batch):
136 |         # feature extraction
137 |         feature_A = self.FeatureExtraction(img_batch['source_image'])  # layer3 ([16, 1024, 15, 15]) layer2 ([16, 512, 30, 30])
138 |         feature_B = self.FeatureExtraction(img_batch['target_image'])
139 | 
140 |         # feature normalization
141 |         if self.normalize_features:
142 |             feature_A = self.FeatureL2Norm(feature_A)
143 |             feature_B = self.FeatureL2Norm(feature_B)
144 | 
145 |         # correlate the features of the two images
146 |         correlation = self.FeatureCorrelation(feature_A,feature_B)  # layer2 (16,900,30,30) layer3 ([16, 225, 15, 15])
147 | 
148 |         # normalize the match correlation result
149 |         if self.normalize_matches:
150 |             correlation = self.FeatureL2Norm(self.ReLu(correlation))
151 | 
152 |         # regress the six affine transform parameters
153 |         theta = self.FeatureRegression(correlation)
154 | 
155 |         return theta
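Forward-pass sketch for the model above (CPU; the ImageNet backbone is downloaded by torchvision, the regression head is randomly initialized):

    import torch
    from model.cnn_registration_model import CNNRegistration

    model = CNNRegistration(single_channel=False, use_cuda=False)
    model.eval()
    batch = {'source_image': torch.rand(1, 3, 240, 240),
             'target_image': torch.rand(1, 3, 240, 240)}
    with torch.no_grad():
        theta = model(batch)  # [1, 6] affine parameters
    print(theta.shape)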
--------------------------------------------------------------------------------
/eval_pf.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | Script to evaluate a trained model as presented in the CNNGeometric CVPR'17 paper
4 | on the ProposalFlow dataset
5 | 
6 | """
7 | import os
8 | 
9 | import torch
10 | from torch.utils.data import DataLoader
11 | 
12 | from datasets.provider.pf_dataset import PFDataset
13 | from main.test_mulit_images import createModel
14 | 
15 | 
16 | 
17 | # Compute PCK
18 | from ntg_pytorch.register_func import estimate_aff_param_iterator
19 | from tnf_transform.img_process import NormalizeImage, NormalizeImageDict
20 | from tnf_transform.point_tnf import PointTnf, PointsToUnitCoords, PointsToPixelCoords
21 | from tnf_transform.transformation import affine_transform_pytorch
22 | from traditional_ntg.estimate_affine_param import estimate_param_batch
23 | from util.pytorchTcv import theta2param, param2theta
24 | from util.torch_util import BatchTensorToVars
25 | 
26 | 
27 | def correct_keypoints(source_points,warped_points,L_pck,alpha=0.1):
28 |     # compute correct keypoints
29 |     point_distance = torch.pow(torch.sum(torch.pow(source_points-warped_points,2),1),0.5).squeeze(1)
30 |     L_pck_mat = L_pck.expand_as(point_distance)
31 |     correct_points = torch.le(point_distance,L_pck_mat*alpha)
32 |     num_of_correct_points = torch.sum(correct_points)
33 |     num_of_points = correct_points.numel()
34 |     return (num_of_correct_points,num_of_points)
35 | 
36 | def main():
37 |     print("eval pf dataset")
38 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
39 | 
40 |     # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/output/voc2012_coco2014_NTG_resnet101.pth.tar"
41 |     # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_NTG_resnet101.pth.tar"
42 |     # ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_20r_NTG_resnet101.pth.tar"
43 |     # ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/three_channel/checkpoint_NTG_resnet101.pth.tar'
44 |     small_aff_ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/three_channel/coco2014_small_aff_checkpoint_NTG_resnet101.pth.tar'
45 |     ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/best_checkpoint_voc2011_three_channel_paper_NTG_resnet101.pth.tar'
46 |     # ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011_paper_affine/best_checkpoint_voc2011_NTG_resnet101.pth.tar'
47 | 
48 |     #ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_30r_NTG_resnet101.pth.tar"
49 |     # image_path = '../datasets/row_data/VOC/3
50 |     # label_path = '../datasets/row_data/label_file/aff_param2.csv'
51 |     #image_path = '../datasets/row_data/COCO/'
52 |     #label_path = '../datasets/row_data/label_file/aff_param_coco.csv'
53 | 
54 |     pf_data_path = 'datasets/row_data/pf_data'
55 | 
56 |     batch_size = 128
57 |     # load the models
58 |     use_cuda = torch.cuda.is_available()
59 | 
60 |     ntg_model = createModel(ntg_checkpoint_path,use_cuda=use_cuda)
61 |     small_aff_ntg_model = createModel(small_aff_ntg_checkpoint_path,use_cuda=use_cuda)
62 | 
63 |     dataset = PFDataset(csv_file=os.path.join(pf_data_path,'test_pairs_pf.csv'),
64 |                         training_image_path=pf_data_path,
65 |                         transform=NormalizeImageDict(['source_image','target_image']))
66 | 
67 |     dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=4)
68 | 
69 |     batchTensorToVars = BatchTensorToVars(use_cuda=use_cuda)
70 | 
71 |     pt = PointTnf(use_cuda= use_cuda)
72 | 
73 |     print('Computing PCK...')
74 |     total_correct_points_aff = 0
75 |     ntg_total_correct_points_aff = 0
76 |     cnn_ntg_total_correct_points_aff = 0
77 |     total_correct_points_tps = 0
78 |     total_correct_points_aff_tps = 0
79 |     total_points = 0
80 |     ntg_total_points = 0
81 |     cnn_ntg_total_points = 0
82 | 
83 |     for i,batch in enumerate(dataloader):
84 |         batch = batchTensorToVars(batch)
85 |         source_im_size = batch['source_im_size']
86 |         target_im_size = batch['target_im_size']
87 | 
88 |         source_points = batch['source_points']
89 |         target_points = batch['target_points']
90 | 
91 |         source_image_batch = batch['source_image']
92 |         target_image_batch = batch['target_image']
93 | 
94 |         # warp points with estimated transformations
95 |         target_points_norm = PointsToUnitCoords(target_points, target_im_size)
96 | 
97 |         theta_estimate_batch = ntg_model(batch)
98 | 
99 |         #warped_image_batch = affine_transform_pytorch(source_image_batch, theta_estimate_batch)
100 |         #batch['source_image'] = warped_image_batch
101 |         #theta_estimate_batch = small_aff_ntg_model(batch)
102 | 
103 |         # convert PyTorch-convention transform parameters to OpenCV-convention parameters
104 |         #theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)
105 | 
106 |         # P5: refine the CNN result with the traditional NTG method
107 |         #cnn_ntg_param_batch = estimate_param_batch(source_image_batch, target_image_batch, theta_opencv,itermax = 600)
108 |         #theta_pytorch = param2theta(cnn_ntg_param_batch.view(-1, 2, 3),240,240,use_cuda=use_cuda)
109 | 
110 |         # theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)
111 |         # with torch.no_grad():
112 |         #     ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
113 |         #                                                   target_image_batch[:, 0, :, :].unsqueeze(1),
114 |         #                                                   None, use_cuda=use_cuda, itermax=600)
115 |         #
116 |         #     cnn_ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
117 |         #                                                       target_image_batch[:, 0, :, :].unsqueeze(1),
118 |         #                                                       theta_opencv, use_cuda=use_cuda, itermax=600)
119 |         #
120 |         #     ntg_param_pytorch_batch = param2theta(ntg_param_batch,240, 240, use_cuda=use_cuda)
121 |         #     cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch,240, 240, use_cuda=use_cuda)
122 | 
123 | 
124 |         warped_points_aff_norm = pt.affPointTnf(theta_estimate_batch, target_points_norm)
125 |         warped_points_aff = PointsToPixelCoords(warped_points_aff_norm, source_im_size)
126 | 
127 |         # ntg_warped_points_aff_norm = pt.affPointTnf(ntg_param_pytorch_batch, target_points_norm)
128 |         # ntg_warped_points_aff = PointsToPixelCoords(ntg_warped_points_aff_norm, source_im_size)
129 |         #
130 |         # cnn_ntg_warped_points_aff_norm = pt.affPointTnf(cnn_ntg_param_pytorch_batch, 
target_points_norm) 131 | # cnn_ntg_warped_points_aff = PointsToPixelCoords(cnn_ntg_warped_points_aff_norm, source_im_size) 132 | 133 | L_pck = batch['L_pck'].data 134 | 135 | correct_points_aff, num_points = correct_keypoints(source_points.data, 136 | warped_points_aff.data, L_pck) 137 | # ntg_correct_points_aff, ntg_num_points = correct_keypoints(source_points.data, 138 | # ntg_warped_points_aff.data, L_pck) 139 | # cnn_ntg_correct_points_aff, cnn_ntg_num_points = correct_keypoints(source_points.data, 140 | # cnn_ntg_warped_points_aff.data, L_pck) 141 | 142 | total_correct_points_aff += correct_points_aff 143 | total_points += num_points 144 | 145 | # ntg_total_correct_points_aff += ntg_correct_points_aff 146 | # ntg_total_points += ntg_num_points 147 | # 148 | # cnn_ntg_total_correct_points_aff += cnn_ntg_correct_points_aff 149 | # cnn_ntg_total_points += cnn_ntg_num_points 150 | 151 | 152 | print('Batch: [{}/{} ({:.0f}%)]'.format(i, len(dataloader), 100. * i / len(dataloader))) 153 | 154 | total_correct_points_aff = total_correct_points_aff.__float__() 155 | # ntg_total_correct_points_aff = ntg_total_correct_points_aff.__float__() 156 | # cnn_ntg_total_correct_points_aff = cnn_ntg_total_correct_points_aff.__float__() 157 | 158 | PCK_aff=total_correct_points_aff/total_points 159 | # ntg_PCK_aff=ntg_total_correct_points_aff/ntg_total_points 160 | # cnn_ntg_PCK_aff=cnn_ntg_total_correct_points_aff/cnn_ntg_total_points 161 | print('PCK affine:',PCK_aff) 162 | # print('ntg_PCK affine:',ntg_PCK_aff) 163 | # print('cnn_ntg_PCK affine:',cnn_ntg_PCK_aff) 164 | print('Done!') 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /visualization/train_visual.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import visdom 4 | 5 | from tnf_transform.img_process import normalize_image 6 | import numpy as np 7 | 8 | class VisdomHelper: 9 | def __init__(self,env_name,port=8088): 10 | self.env_name = env_name 11 | self.vis = visdom.Visdom(env = self.env_name,port=port) 12 | self.index_slice = [0,5,10,15,20,25] 13 | 14 | def drawImage(self, source_image_batch, warped_image_batch, target_image_batch,single_channel = True,show_size=8): 15 | if source_image_batch.shape[0] > show_size: 16 | source_image_batch =source_image_batch[0:show_size] 17 | warped_image_batch =warped_image_batch[0:show_size] 18 | target_image_batch =target_image_batch[0:show_size] 19 | 20 | source_image_batch = normalize_image(source_image_batch,forward=False) 21 | warped_image_batch = normalize_image(warped_image_batch, forward=False) 22 | target_image_batch = normalize_image(target_image_batch, forward=False) 23 | 24 | source_image_batch = torch.clamp(source_image_batch,0,1) 25 | warped_image_batch = torch.clamp(warped_image_batch,0,1) 26 | target_image_batch = torch.clamp(target_image_batch,0,1) 27 | 28 | if single_channel: 29 | overlayImage = torch.cat((warped_image_batch, target_image_batch, warped_image_batch), 1) 30 | else: 31 | overlayImage = torch.cat((warped_image_batch[:,0:1,:,:],target_image_batch[:,1:2,:,:],warped_image_batch[:,0:1,:,:]),1) 32 | 33 | self.vis.images( 34 | overlayImage, win="overlay", 35 | opts=dict(title='overlay_image', caption='overlay.', width=1400, height=150, jpgquality=40) 36 | ) 37 | 38 | self.vis.images( 39 | source_image_batch, win="s1", 40 | opts=dict(title='source_image_batch', caption='source.', width=1400, height=150, 
jpgquality=40) 41 | ) 42 | self.vis.images( 43 | warped_image_batch, win="s2", 44 | opts=dict(title='warped_image_batch', caption='warped.', width=1400, height=150, jpgquality=40) 45 | ) 46 | 47 | self.vis.images( 48 | target_image_batch, win="s3", 49 | opts=dict(title='target_image_batch', caption='target.', width=1400, height=150, jpgquality=40) 50 | ) 51 | 52 | def drawLoss(self,epoch,train_loss): 53 | layout = dict(title="train_loss", xaxis={'title': 'epoch'}, yaxis={'title': 'train loss'}) 54 | self.vis.line(X=torch.IntTensor([epoch]), Y=torch.FloatTensor([train_loss]), win="lineloss", 55 | update='new' if epoch == 0 else 'append', opts=layout) 56 | 57 | def drawBothLoss(self,epoch,train_loss,test_loss,layout_title,x_axis = 'epoch',y_axis = 'loss',win='win'): 58 | layout = dict(title=layout_title, xaxis={'title': x_axis}, yaxis={'title': y_axis},legend=['train_loss','test_loss']) 59 | print(epoch, train_loss,test_loss) 60 | self.vis.line(X=np.column_stack((epoch,epoch)), Y=np.column_stack((train_loss,test_loss)), win=win, 61 | update='new' if epoch == 0 else 'append', opts=layout) 62 | 63 | def drawGridlossGroup(self,X_list,ntg_list,Y_list_A,Y_list_B,Y_list_cvpr, 64 | layout_title,x_axis ='grid_loss',y_axis ='num', 65 | win='result',update='append'): 66 | layout = dict(title=layout_title, xaxis={'title': x_axis}, yaxis={'title': y_axis}, 67 | legend=['ntg_grid_loss','cnn_grid_loss','cnn_ntg_grid_loss','cvpr_2018_loss'],xlabel=X_list) 68 | self.vis.line(X=np.column_stack((X_list, X_list, X_list,X_list)), 69 | Y=np.column_stack((ntg_list,Y_list_A, Y_list_B,Y_list_cvpr)), 70 | update=update, 71 | win=win,opts=layout) 72 | 73 | def drawGridlossBar(self,X_list,ntg_list,cnn_list_A,cnn_ntg_list_B,cvpr_list,layout_title,xlabel ='grid_loss(pixel)',ylabel ='percentage(%)', 74 | win='result',update='append'): 75 | 76 | layout = dict(title=layout_title, xlabel=xlabel, ylabel=ylabel, 77 | legend=['ntg_grid_loss', 'cnn_grid_loss', 'cnn_ntg_grid_loss', 'cvpr_2018_loss'], xaxis=X_list,stacked=False, 78 | update=update,win=win,xtickstep=1) 79 | 80 | self.vis.bar( 81 | X=np.column_stack((ntg_list, cnn_list_A, cnn_ntg_list_B, cvpr_list)),opts=layout) 82 | 83 | 84 | 85 | def getVisdom(self): 86 | return self.vis 87 | 88 | def show_cnn_result(self,source_image_batch, warped_image_batch,fine_warped_image_batch, target_image_batch,single_channel = True): 89 | source_image_batch = normalize_image(source_image_batch, forward=False) 90 | warped_image_batch = normalize_image(warped_image_batch, forward=False) 91 | target_image_batch = normalize_image(target_image_batch, forward=False) 92 | 93 | source_image_batch = torch.clamp(source_image_batch, 0, 1) 94 | warped_image_batch = torch.clamp(warped_image_batch, 0, 1) 95 | target_image_batch = torch.clamp(target_image_batch, 0, 1) 96 | 97 | if single_channel: 98 | overlayImage = torch.cat((warped_image_batch, target_image_batch, warped_image_batch), 1) 99 | else: 100 | overlayImage = torch.cat( 101 | (warped_image_batch[:, 0:1, :, :], target_image_batch[:, 1:2, :, :], warped_image_batch[:, 0:1, :, :]),1) 102 | 103 | self.vis.images( 104 | overlayImage[0:8], win="overlay", 105 | opts=dict(title='overlay_image', caption='overlay.', width=1400, height=150, jpgquality=40) 106 | ) 107 | 108 | self.vis.images( 109 | source_image_batch[0:8], win="s1", 110 | opts=dict(title='source_image_batch', caption='source.', width=1400, height=150, jpgquality=40) 111 | ) 112 | self.vis.images( 113 | warped_image_batch[0:8], win="s2", 114 | opts=dict(title='warped_image_batch', 
caption='warped.', width=1400, height=150, jpgquality=40) 115 | ) 116 | 117 | self.vis.images( 118 | fine_warped_image_batch[0:8], win="s2_fine", 119 | opts=dict(title='fine_warped_image_batch', caption='warped.', width=1400, height=150, jpgquality=40) 120 | ) 121 | 122 | self.vis.images( 123 | target_image_batch[0:8], win="s3", 124 | opts=dict(title='target_image_batch', caption='target.', width=1400, height=150, jpgquality=40) 125 | ) 126 | 127 | 128 | def showImageBatch(self,image_batch,win='image',title='image',normailze=False,show_num=8,start_index = 0): 129 | if normailze: 130 | image_batch = normalize_image(image_batch, forward=False) 131 | image_batch = torch.clamp(image_batch, 0, 1) 132 | 133 | # self.vis.images(image_batch[0:show_num],win=win,opts=dict(title=title, caption='image_batch', width=1400, height=150, jpgquality=40)) 134 | self.vis.images(image_batch[self.index_slice],win=win,opts=dict(title=title, caption='image_batch', width=1400, height=150, jpgquality=40)) 135 | 136 | def showHarvardBatch(self,image_batch,win='image',title='image',normailze=False,show_num=8,start_index = 0): 137 | if normailze: 138 | image_batch = normalize_image(image_batch, forward=False) 139 | image_batch = torch.clamp(image_batch, 0, 1) 140 | 141 | # self.vis.images(image_batch[0:show_num],win=win,opts=dict(title=title, caption='image_batch', width=1400, height=150, jpgquality=40)) 142 | self.vis.images(image_batch[start_index:start_index+show_num],win=win,opts=dict(title=title, caption='image_batch', width=1400, height=150, jpgquality=40)) 143 | 144 | #def draw_gridloss_group(self,x_list,y_list,win='table',title='table'): 145 | 146 | 147 | 148 | # def showImage(source_image_batch,warped_image_batch,target_image_batch,isShowRGB=True): 149 | # 150 | # source_image_batch = source_image_batch.cpu().detach().numpy() 151 | # warped_image_batch = warped_image_batch.cpu().detach().numpy() 152 | # target_image_batch = target_image_batch.cpu().detach().numpy() 153 | # fig, axs = plt.subplots(3, 4) 154 | # for i in range(4): 155 | # image1 = np.transpose(source_image_batch[i], (1, 2, 0)) 156 | # image2 = np.transpose(warped_image_batch[i], (1, 2, 0)) 157 | # image3 = np.transpose(target_image_batch[i], (1, 2, 0)) 158 | # if isShowRGB: 159 | # axs[0, i].imshow(image1) 160 | # axs[1, i].imshow(image2) 161 | # axs[2, i].imshow(image3) 162 | # else: 163 | # axs[0, i].imshow(image1.squeeze(),cmap='gray') 164 | # axs[1, i].imshow(image2.squeeze(),cmap='gray') 165 | # axs[2, i].imshow(image3.squeeze(),cmap='gray') 166 | # plt.show() 167 | 168 | -------------------------------------------------------------------------------- /visualization/visual_result.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io 3 | import numpy as np 4 | 5 | import torch 6 | from collections import OrderedDict 7 | import cv2 8 | from torch.utils.data import DataLoader 9 | import matplotlib.pyplot as plt 10 | 11 | from datasets.provider.randomTnsData import RandomTnsPair 12 | from datasets.provider.singlechannelData import SinglechannelData, SingleChannelPairTnf 13 | from datasets.provider.test_dataset import TestDataset 14 | from evluate.lossfunc import GridLoss, NTGLoss 15 | from evluate.visualize_result import visualize_compare_result, visualize_iter_result, visualize_spec_epoch_result, \ 16 | visualize_cnn_result 17 | from model.cnn_registration_model import CNNRegistration 18 | from tnf_transform.img_process import preprocess_image, NormalizeImage, 
NormalizeImageDict 19 | from tnf_transform.transformation import AffineTnf, affine_transform_opencv, affine_transform_pytorch, AffineGridGen 20 | from util.pytorchTcv import theta2param, param2theta 21 | from util.time_util import calculate_diff_time 22 | from traditional_ntg.estimate_affine_param import estimate_affine_param, estimate_param_batch 23 | from visualization.matplot_tool import plot_batch_result 24 | import time 25 | import torch.nn.functional as F 26 | 27 | from visualization.train_visual import VisdomHelper 28 | 29 | 30 | def createModel(ntg_checkpoint_path,use_cuda=True): 31 | ''' 32 | 创建模型 33 | :param ntg_checkpoint_path: 34 | :param use_cuda: 35 | :return: 36 | ''' 37 | ntg_model = CNNRegistration(use_cuda=use_cuda) 38 | 39 | print("Loading trained model weights") 40 | print("ntg_checkpoint_path:",ntg_checkpoint_path) 41 | 42 | # 把所有的张量加载到CPU中 GPU ==> CPU 43 | ntg_checkpoint = torch.load(ntg_checkpoint_path,map_location=lambda storage,loc: storage) 44 | ntg_checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in ntg_checkpoint['state_dict'].items()]) 45 | ntg_model.load_state_dict(ntg_checkpoint['state_dict']) 46 | 47 | ntg_model.eval() 48 | 49 | return ntg_model 50 | 51 | def createDataloader(image_path,label_path,batch_size = 16,use_cuda=True): 52 | ''' 53 | 创建dataloader 54 | :param image_path: 55 | :param label_path: 56 | :param batch_size: 57 | :param use_cuda: 58 | :return: 59 | ''' 60 | #dataset = SinglechannelData(image_path,label_path,transform=NormalizeImage(normalize_range=True, normalize_img=False)) 61 | dataset = TestDataset(image_path,label_path,transform=NormalizeImageDict(["image"])) 62 | dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True) 63 | #pair_generator = SingleChannelPairTnf(use_cuda=use_cuda) 64 | pair_generator = RandomTnsPair(use_cuda=use_cuda) 65 | 66 | return dataloader,pair_generator 67 | 68 | def compute_correct_rate(grid_loss_list,threshold = 20): 69 | correct_count = 0 70 | total_count = 0 71 | for grid in grid_loss_list: 72 | for item in grid: 73 | total_count += 1 74 | if item < threshold: 75 | correct_count += 1 76 | print('correct_rate:', correct_count / total_count) 77 | 78 | def compute_average_grid_loss(grid_loss_list): 79 | total_count = 0 80 | total_loss = 0 81 | for grid in grid_loss_list: 82 | for item in grid: 83 | total_count += 1 84 | total_loss += item 85 | print('平均网格点损失:', total_loss / total_count, total_loss,total_count) 86 | 87 | 88 | def iterDataset(dataloader,pair_generator,ntg_model,vis,threshold=10,use_cuda=True): 89 | ''' 90 | 迭代数据集中的批次数据,进行处理 91 | :param dataloader: 92 | :param pair_generator: 93 | :param ntg_model: 94 | :param use_cuda: 95 | :return: 96 | ''' 97 | 98 | grid_loss_hist = [] 99 | grid_loss_traditional_hist = [] 100 | 101 | loss_fn = NTGLoss() 102 | gridGen = AffineGridGen() 103 | 104 | grid_loss = GridLoss(use_cuda=use_cuda) 105 | grid_loss_list = [] 106 | grid_loss_ntg_list = [] 107 | grid_loss_comb_list = [] 108 | 109 | ntg_loss_total = 0 110 | 111 | # batch {image.shape = } 112 | for batch_idx,batch in enumerate(dataloader): 113 | #print("batch_id",batch_idx,'/',len(dataloader)) 114 | 115 | # if batch_idx == 2: 116 | # break 117 | 118 | if batch_idx % 5 == 0: 119 | print('test batch: [{}/{} ({:.0f}%)]'.format( 120 | batch_idx, len(dataloader), 121 | 100. 
* batch_idx / len(dataloader))) 122 | 123 | pair_batch = pair_generator(batch) # image[batch_size,1,w,h] theta_GT[batch_size,2,3] 124 | 125 | theta_estimate_batch = ntg_model(pair_batch) # theta [batch_size,6] 126 | 127 | source_image_batch = pair_batch['source_image'] 128 | target_image_batch = pair_batch['target_image'] 129 | theta_GT_batch = pair_batch['theta_GT'] 130 | 131 | sampling_grid = gridGen(theta_estimate_batch.view(-1,2,3)) 132 | warped_image_batch = F.grid_sample(source_image_batch, sampling_grid) 133 | 134 | loss, g1xy, g2xy = loss_fn(target_image_batch, warped_image_batch) 135 | #print("one batch ntg:",loss.item()) 136 | ntg_loss_total += loss.item() 137 | 138 | # 显示CNN配准结果 139 | # print("显示图片") 140 | visualize_cnn_result(source_image_batch,target_image_batch,theta_estimate_batch,vis) 141 | # # 142 | # time.sleep(10) 143 | # 显示一个epoch的对比结果 144 | #visualize_compare_result(source_image_batch,target_image_batch,theta_GT_batch,theta_estimate_batch,use_cuda=use_cuda) 145 | 146 | # 显示多个epoch的折线图 147 | #visualize_iter_result(source_image_batch,target_image_batch,theta_GT_batch,theta_estimate_batch,use_cuda=use_cuda) 148 | 149 | 150 | ## 计算网格点损失配准误差 151 | # 将pytorch的变换参数转为opencv的变换参数 152 | #theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda) 153 | 154 | # P5使用传统NTG方法进行优化cnn的结果 155 | #ntg_param = estimate_param_batch(source_image_batch,target_image_batch,None,itermax=600) 156 | #ntg_param_pytorch = param2theta(ntg_param,240,240,use_cuda=use_cuda) 157 | #cnn_ntg_param_batch = estimate_param_batch(source_image_batch, target_image_batch, theta_opencv,itermax=800) 158 | #cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch, 240, 240, use_cuda=use_cuda) 159 | 160 | loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch,theta_GT_batch) 161 | #loss_ntg = grid_loss.compute_grid_loss(ntg_param_pytorch,theta_GT_batch) 162 | #loss_cnn_ntg = grid_loss.compute_grid_loss(cnn_ntg_param_pytorch_batch,theta_GT_batch) 163 | 164 | grid_loss_list.append(loss_cnn.detach().cpu()) 165 | #grid_loss_ntg_list.append(loss_ntg) 166 | #grid_loss_comb_list.append(loss_cnn_ntg) 167 | ## 168 | 169 | # 显示特定epoch的gridloss的直方图 170 | # g_loss,g_trad_loss = visualize_spec_epoch_result(source_image_batch, target_image_batch, theta_GT_batch, theta_estimate_batch, 171 | # use_cuda=use_cuda) 172 | # grid_loss_hist.append(g_loss) 173 | # grid_loss_traditional_hist.append(g_trad_loss) 174 | 175 | # loss_cnn = grid_loss.compute_grid_loss(theta_estimate_batch,theta_GT_list) 176 | # 177 | # loss_cnn_ntg = grid_loss.compute_grid_loss(cnn_ntg_param,theta_GT_list) 178 | print("计算平均网格点损失") 179 | compute_average_grid_loss(grid_loss_list) 180 | print("计算平均NTG值",ntg_loss_total / len(dataloader)) 181 | 182 | print("计算正确率") 183 | compute_correct_rate(grid_loss_list,threshold=threshold) 184 | 185 | 186 | 187 | 188 | def main(): 189 | 190 | 191 | print("开始进行测试") 192 | 193 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 194 | 195 | #ntg_checkpoint_path = "../trained_weight/output/checkpoint_NTG_resnet101.pth.tar" 196 | ntg_checkpoint_path = "/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/checkpoint_voc2011_NTG_resnet101.pth.tar" 197 | image_path = '/home/zlk/datasets/coco_test2017' 198 | 199 | use_custom_aff_param = True 200 | if use_custom_aff_param: 201 | label_path = '../datasets/row_data/label_file/coco_test2017_custom_param.csv' 202 | else: 203 | label_path = '../datasets/row_data/label_file/coco_test2017_paper_param.csv' 204 | 205 | threshold = 10 206 | 207 | batch_size 
/datasets/provider/randomTnsData.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from skimage import io
4 | import numpy as np
5 | import torch
6 | from torch.autograd import Variable
7 | from torch.utils.data import Dataset
8 | import time
9 | import cv2
10 |
11 | from random import choice
12 | from pathlib import Path
13 |
14 | from tqdm import tqdm
15 |
16 | from tnf_transform.img_process import random_affine, generator_affine_param
17 | from tnf_transform.transformation import AffineTnf
18 | from traditional_ntg.image_util import symmetricImagePad
19 | from util.pytorchTcv import inverse_theta
20 | from util.time_util import calculate_diff_time
21 |
22 | '''
23 | Training dataset: randomly generates affine transform parameters and feeds them to the network.
24 | Passed to the dataloader; the custom __getitem__ returns Sample{image, theta, name}.
25 | '''
26 | class RandomTnsData(Dataset):
27 |
28 |
29 |     def __init__(self,training_image_path,output_size=(480,640),paper_affine_generator = False,transform=None,cache_images = False,use_cuda = True):
30 |         '''
31 |         :param training_image_path:
32 |         :param output_size:
33 |         :param transform:
34 |         :param cache_images: if the dataset is not too large, cache the images in memory to speed up reading
35 |         :param use_cuda:
36 |         '''
37 |         self.out_h, self.out_w = output_size
38 |         self.use_cuda = use_cuda
39 |         self.cache_images = cache_images
40 |         self.paper_affine_generator = paper_affine_generator
41 |         # read image file
42 |         self.training_image_path = training_image_path
43 |         self.train_data = os.listdir(self.training_image_path)
44 |         self.image_count = len(self.train_data)
45 |         # bi = np.floor(np.arange(n) / batch_size).astype(np.int)  # batch index
46 |         # nb = bi[-1] + 1  # number of batches
47 |         self.imgs = [None]*self.image_count
48 |         self.transform = transform
49 |         self.control_size = 1000000000
50 |
51 |         # cache images into memory for faster training (~5GB)
52 |         if self.cache_images:
53 |             for i in tqdm(range(min(self.image_count,self.control_size)),desc='Reading images'):  # up to control_size images
54 |                 image_name = self.train_data[i]
55 |                 image_path = os.path.join(self.training_image_path, image_name)
56 |                 image = cv2.imread(image_path)  # shape [h,w,c] BGR
57 |                 assert image is not None, 'Image Not Found' + image_path
58 |                 self.imgs[i] = image
59 |
60 |     def __len__(self):
61 |         return len(self.train_data)
62 |
63 |     def __getitem__(self, idx):
64 |
65 |         #total_start_time = time.time()
66 |
67 |         image_name = self.train_data[idx]
68 |
69 |         if self.cache_images:
70 |             image = self.imgs[idx]
71 |         else:
72 |             image_path = os.path.join(self.training_image_path, image_name)
73 |             image = cv2.imread(image_path)  # shape [h,w,c]
74 |
75 |         if image.shape[0] != self.out_h or image.shape[1] != self.out_w:
76 |             image = cv2.resize(image, (self.out_w, self.out_h), interpolation=cv2.INTER_LINEAR)
77 |
78 |         image = image[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
79 |         image = np.ascontiguousarray(image, dtype=np.float32)  # uint8 to float32
80 |         image = torch.from_numpy(image)
81 |
82 |         if self.paper_affine_generator:
83 |             theta = generator_affine_param()
84 |         else:
85 |             theta = random_affine()
86 |
87 |         theta = torch.from_numpy(theta.astype(np.float32))
88 |
89 |         sample = {'image': image, 'theta': theta, 'name': image_name}
90 |
91 |         if self.transform:
92 |             sample = self.transform(sample)
93 |
94 |         # elapsed = calculate_diff_time(total_start_time)
95 |         # print('getitem time:',elapsed)  # 0.011s
96 |
97 |         return sample
98 |
99 | '''
100 | Generates an image pair from the affine transform parameters.
101 | Returns {"source_image, target_image, theta_GT, name"}
102 | '''
103 | class RandomTnsPair(object):
104 |
105 |     def __init__(self, use_cuda=True, crop_factor=9 / 16, output_size=(240, 240),
106 |                  padding_factor=0.6):
107 |         self.use_cuda = use_cuda
108 |         self.crop_factor = crop_factor
109 |         self.padding_factor = padding_factor
110 |         self.out_h, self.out_w = output_size
111 |         self.channel_choicelist = [0,1,2]
112 |         self.rescalingTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
113 |         self.geometricTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
114 |
115 |     def __call__(self, batch):
116 |         image_batch, theta_batch,image_name = batch['image'], batch['theta'],batch['name']
117 |         if self.use_cuda:
118 |             image_batch = image_batch.cuda()
119 |             theta_batch = theta_batch.cuda()
120 |
121 |         b, c, h, w = image_batch.size()
122 |
123 |         # Generate a symmetrically padded image for the larger sampling region
124 |         image_batch = symmetricImagePad(image_batch, self.padding_factor,use_cuda=self.use_cuda)
125 |
126 |         # indices_R = torch.tensor([choice(self.channel_choicelist)])
127 |         # indices_G = torch.tensor([choice(self.channel_choicelist)])
128 |
129 |         indices_R = torch.tensor([0])
130 |         indices_G = torch.tensor([2])
131 |
132 |         if self.use_cuda:
133 |             indices_R = indices_R.cuda()
134 |             indices_G = indices_G.cuda()
135 |
136 |         image_batch_R = torch.index_select(image_batch, 1, indices_R)
137 |         image_batch_G = torch.index_select(image_batch, 1, indices_G)
138 |
139 |         image_batch_R = torch.cat((image_batch_R,image_batch_R,image_batch_R),1)
140 |         image_batch_G = torch.cat((image_batch_G,image_batch_G,image_batch_G),1)
141 |
142 |         # Get the cropped image
143 |         cropped_image_batch = self.rescalingTnf(image_batch_R, None, self.padding_factor,
144 |                                                 self.crop_factor)  # identity transform, since no theta is given
145 |         # Get the cropped and transformed image
146 |         warped_image_batch = self.geometricTnf(image_batch_G, theta_batch,
147 |                                                self.padding_factor,
148 |                                                self.crop_factor)  # warped with theta_batch
149 |
150 |         return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch,
151 |                 'name':image_name}
152 |
153 |         # theta_batch = inverse_theta(theta_batch,use_cuda=True)
154 |         #
155 |         # return {'source_image': warped_image_batch, 'target_image': cropped_image_batch, 'theta_GT': theta_batch,
156 |         #         'name':image_name}
157 |
158 | class RandomTnsPairSingleChannelTest(object):
159 |
160 |     def __init__(self, use_cuda=True, crop_factor=9 / 16, output_size=(240, 240),
161 |                  padding_factor=0.6):
162 |         self.use_cuda = use_cuda
163 |         self.crop_factor = crop_factor
164 |         self.padding_factor = padding_factor
165 |         self.out_h, self.out_w = output_size
166 |         self.channel_choicelist = [0,1,2]
167 |         self.rescalingTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
168 |         self.geometricTnf = AffineTnf(self.out_h, self.out_w,use_cuda=self.use_cuda)
169 |
170 |     def __call__(self, batch):
171 |         image_batch, theta_batch,image_name = batch['image'], batch['theta'],batch['name']
172 |         if self.use_cuda:
173 |             image_batch = image_batch.cuda()
174 |             theta_batch = theta_batch.cuda()
175 |
176 |         b, c, h, w = image_batch.size()
177 |
178 |         # Generate a symmetrically padded image for the larger sampling region
179 |         image_batch = symmetricImagePad(image_batch, self.padding_factor,use_cuda=self.use_cuda)
180 |
181 |         # indices_R = torch.tensor([choice(self.channel_choicelist)])
182 |         # indices_G = torch.tensor([choice(self.channel_choicelist)])
183 |
184 |         indices_R = torch.tensor([0])
185 |         indices_G = torch.tensor([2])
186 |         #
187 |         if self.use_cuda:
188 |             indices_R = indices_R.cuda()
189 |             indices_G = indices_G.cuda()
190 |
191 |         image_batch_R = torch.index_select(image_batch, 1, indices_R)
192 |         image_batch_G = torch.index_select(image_batch, 1, indices_G)
193 |
194 |         image_batch_R = torch.cat((image_batch_R,image_batch_R,image_batch_R),1)
195 |         image_batch_G = torch.cat((image_batch_G,image_batch_G,image_batch_G),1)
196 |
197 |         # Get the cropped image
198 |         cropped_image_batch = self.rescalingTnf(image_batch_R, None, self.padding_factor,
199 |                                                 self.crop_factor)  # identity transform, since no theta is given
200 |         # Get the cropped and transformed image
201 |         warped_image_batch = self.geometricTnf(image_batch_G, theta_batch,
202 |                                                self.padding_factor,
203 |                                                self.crop_factor)  # warped with theta_batch
204 |
205 |         return {'source_image': cropped_image_batch, 'target_image': warped_image_batch, 'theta_GT': theta_batch,
206 |                 'name':image_name}
--------------------------------------------------------------------------------
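# A minimal wiring sketch for the classes above (the image directory is hypothetical;
# the DataLoader settings mirror those used in the eval scripts later in this listing):

import torch
from torch.utils.data import DataLoader

use_cuda = torch.cuda.is_available()
dataset = RandomTnsData('/path/to/training_images', output_size=(480, 640), cache_images=False, use_cuda=use_cuda)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
pair_generator = RandomTnsPair(use_cuda=use_cuda, output_size=(240, 240))

for batch in dataloader:
    pair_batch = pair_generator(batch)
    # pair_batch holds 'source_image', 'target_image', 'theta_GT' and 'name';
    # feed it to the registration network and supervise against theta_GT.
    break
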
/main/eval_harvard_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | from skimage import io
3 | import numpy as np
4 |
5 | import torch
6 | from collections import OrderedDict
7 | import cv2
8 | from torch.utils.data import DataLoader
9 | import matplotlib.pyplot as plt
10 |
11 | from cnn_geometric.cnn_geometric_model import CNNGeometric
12 | from datasets.provider.harvardData import HarvardData, HarvardDataPair
13 | from datasets.provider.nirrgbData import NirRgbData, NirRgbTnsPair
14 | from datasets.provider.randomTnsData import RandomTnsPair, RandomTnsPairSingleChannelTest
15 | from datasets.provider.singlechannelData import SinglechannelData, SingleChannelPairTnf
16 | from datasets.provider.test_dataset import TestDataset
17 | from evluate.lossfunc import GridLoss, NTGLoss
18 | from evluate.visualize_result import visualize_compare_result, visualize_iter_result, visualize_spec_epoch_result, \
19 |     visualize_cnn_result
20 | from main.test_mulit_images import compute_average_grid_loss, compute_correct_rate, createModel, createCVPRModel
21 | from model.cnn_registration_model import CNNRegistration
22 | from ntg_pytorch.register_func import estimate_aff_param_iterator
23 | from tnf_transform.img_process import preprocess_image, NormalizeImage, NormalizeImageDict
24 | from tnf_transform.transformation import AffineTnf, affine_transform_opencv, affine_transform_pytorch, AffineGridGen
25 | from util.pytorchTcv import theta2param, param2theta
26 | from util.time_util import calculate_diff_time
27 | from traditional_ntg.estimate_affine_param import estimate_affine_param, estimate_param_batch
28 | from visualization.matplot_tool import plot_batch_result
29 | import time
30 | import torch.nn.functional as F
31 |
32 | from visualization.train_visual import VisdomHelper
33 |
34 |
35 | def createDataloader(image_path,batch_size = 16,use_cuda=True,single_channel=False):
36 |     '''
37 |     Create the dataloader.
38 |     :param image_path:
39 |     :param single_channel:
40 |     :param batch_size:
41 |     :param use_cuda:
42 |     :return:
43 |     '''
44 |     # dataset = HarvardData(image_path,label_path,transform=NormalizeImageDict(["image"]))
45 |     dataset = HarvardData(image_path)
46 |     dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)
47 |     pair_generator = HarvardDataPair(single_channel=single_channel)
48 |
49 |     return dataloader,pair_generator
50 |
51 | def iterDataset(dataloader,pair_generator,ntg_model,cvpr_model,vis,threshold=10,use_cuda=True):
52 |     '''
53 |     Iterate over the batches in the dataset and process them.
54 |     :param dataloader:
55 |     :param pair_generator:
56 |     :param ntg_model:
57 |     :param use_cuda:
58 |     :return:
59 |     '''
60 |
61 |     fn_grid_loss = GridLoss(use_cuda=use_cuda)
62 |     grid_loss_cnn_list = []
63 |     grid_loss_cvpr_list = []
64 |     grid_loss_ntg_list = []
65 |     grid_loss_comb_list = []
66 |
67 |     ntg_loss_total = 0
68 |     cnn_ntg_loss_total = 0
69 |
70 |     # batch {image.shape = }
71 |     for batch_idx,batch in enumerate(dataloader):
72 |         #print("batch_id",batch_idx,'/',len(dataloader))
73 |
74 |         # if batch_idx == 15:
75 |         #     break
76 |
77 |         if batch_idx % 5 == 0:
78 |             print('test batch: [{}/{} ({:.0f}%)]'.format(
79 |                 batch_idx, len(dataloader),
80 |                 100. * batch_idx / len(dataloader)))
81 |
82 |         pair_batch = pair_generator(batch)  # image [batch_size,1,w,h], theta_GT [batch_size,2,3]
83 |
84 |         theta_estimate_batch = ntg_model(pair_batch)  # theta [batch_size,6]
85 |
86 |         if cvpr_model is not None:
87 |             theta_cvpr_estimate_batch = cvpr_model(pair_batch)
88 |
89 |         source_image_batch = pair_batch['source_image']
90 |         target_image_batch = pair_batch['target_image']
91 |         theta_GT_batch = pair_batch['theta_GT']
92 |         image_name = pair_batch['name']
93 |
94 |         ## Compute the grid-point registration error
95 |         # Convert the PyTorch transform parameters to the OpenCV convention
96 |         theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)
97 |
98 |         #print('estimating with the parallel NTG')
99 |         with torch.no_grad():
100 |
101 |             ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
102 |                                                           target_image_batch[:, 0, :, :].unsqueeze(1),
103 |                                                           None, use_cuda=use_cuda, itermax=600)
104 |
105 |             cnn_ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:,0,:,:].unsqueeze(1),
106 |                                                               target_image_batch[:,0,:,:].unsqueeze(1),
107 |                                                               theta_opencv,use_cuda=use_cuda,itermax=600)
108 |             cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch, 240, 240, use_cuda=use_cuda)
109 |             ntg_param_pytorch_batch = param2theta(ntg_param_batch, 240, 240, use_cuda=use_cuda)
110 |             cnn_ntg_wraped_image = affine_transform_pytorch(source_image_batch, cnn_ntg_param_pytorch_batch)
111 |             ntg_wraped_image = affine_transform_pytorch(source_image_batch, ntg_param_pytorch_batch)
112 |             cnn_wraped_image = affine_transform_pytorch(source_image_batch, theta_estimate_batch)
113 |             GT_image = affine_transform_pytorch(source_image_batch, theta_GT_batch)
114 |
115 |         # loss_cvpr_2018 = fn_grid_loss.compute_grid_loss(theta_cvpr_estimate_batch,theta_GT_batch)
116 |         loss_cnn = fn_grid_loss.compute_grid_loss(theta_estimate_batch.detach(),theta_GT_batch)
117 |         loss_ntg = fn_grid_loss.compute_grid_loss(ntg_param_pytorch_batch.detach(),theta_GT_batch)
118 |         loss_cnn_ntg = fn_grid_loss.compute_grid_loss(cnn_ntg_param_pytorch_batch.detach(),theta_GT_batch)
119 |
120 |         vis.showHarvardBatch(source_image_batch,normailze=True,win='source_image_batch',title='source_image_batch')
121 |         vis.showHarvardBatch(target_image_batch,normailze=True,win='target_image_batch',title='target_image_batch')
122 |         vis.showHarvardBatch(ntg_wraped_image,normailze=True,win='ntg_wraped_image',title='ntg_wraped_image')
123 |         vis.showHarvardBatch(cnn_wraped_image,normailze=True,win='cnn_wraped_image',title='cnn_wraped_image')
124 |         vis.showHarvardBatch(cnn_ntg_wraped_image,normailze=True,win='cnn_ntg_wraped_image',title='cnn_ntg_wraped_image')
125 |         vis.showHarvardBatch(GT_image,normailze=True,win='GT_image',title='GT_image')
126 |
127 |
128 |         grid_loss_ntg_list.append(loss_ntg.detach().cpu())
129 |         grid_loss_cnn_list.append(loss_cnn.detach().cpu())
130 |         grid_loss_comb_list.append(loss_cnn_ntg.detach().cpu())
131 |         # grid_loss_cvpr_list.append(loss_cvpr_2018.detach().cpu())
132 |
133 |     print("Grid-point losses above the threshold are excluded from the average")
134 |     print('NTG grid-point loss')
135 |     ntg_group_list = compute_average_grid_loss(grid_loss_ntg_list)
136 |     print('CNN grid-point loss')
137 |     cnn_group_list = compute_average_grid_loss(grid_loss_cnn_list)
138 |     print('CNN+NTG grid-point loss')
139 |     cnn_ntg_group_list = compute_average_grid_loss(grid_loss_comb_list)
140 |     print('CVPR grid-point loss')
141 |     # cvpr_group_list = compute_average_grid_loss(grid_loss_cvpr_list)
142 |
143 |     x_list = [i for i in range(10)]
144 |
145 |     # vis.drawGridlossBar(x_list,ntg_group_list,cnn_group_list,cnn_ntg_group_list,cvpr_group_list,
146 |     #                     layout_title="Grid_loss_histogram",win='Grid_loss_histogram')
147 |
148 |     print("Average NTG value (CNN):",ntg_loss_total / len(dataloader))
149 |     print("Average NTG value (CNN+NTG):",cnn_ntg_loss_total / len(dataloader))
150 |
151 |     print("Computing the correct rate")
152 |     print('NTG correct rate')
153 |     compute_correct_rate(grid_loss_ntg_list, threshold=threshold)
154 |     print('CNN correct rate')
155 |     compute_correct_rate(grid_loss_cnn_list,threshold=threshold)
156 |     print('CNN+NTG correct rate')
157 |     compute_correct_rate(grid_loss_comb_list,threshold=threshold)
158 |     # print('CVPR correct rate')
159 |     # compute_correct_rate(grid_loss_cvpr_list, threshold=threshold)
160 |
161 | def main():
162 |
163 |     single_channel = False
164 |     print("Starting the test")
165 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
166 |
167 |     ntg_checkpoint_path = '/mnt/4T/zlk/trained_weights/best_checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar'
168 |     #test_image_path = '/home/zlk/datasets/coco_test2017'
169 |     # test_image_path = '/mnt/4T/zlk/datasets/mulitspectral/Harvard'
170 |     test_image_path = '/mnt/4T/zlk/datasets/mulitspectral/complete_ms_data_mat'
171 |
172 |     threshold = 3
173 |     batch_size = 1
174 |     # Load the model
175 |     use_cuda = torch.cuda.is_available()
176 |
177 |     # vis = VisdomHelper(env_name='Harvard_test')
178 |     vis = VisdomHelper(env_name='CAVE_test')
179 |
180 |     ntg_model = createModel(ntg_checkpoint_path,use_cuda=use_cuda,single_channel=single_channel)
181 |     cvpr_model = createCVPRModel(use_cuda=use_cuda)
182 |
183 |     print('Testing the Harvard grid-point loss')
184 |     dataloader,pair_generator = createDataloader(test_image_path,batch_size,use_cuda = use_cuda,single_channel=single_channel)
185 |
186 |     iterDataset(dataloader,pair_generator,ntg_model,cvpr_model,vis,threshold=threshold,use_cuda=use_cuda)
187 |
188 | if __name__ == '__main__':
189 |     main()
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------
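# Hedged sketch of the coarse-to-fine pipeline exercised by iterDataset above, with
# the intermediate conversions spelled out (pair_batch, ntg_model and use_cuda are
# assumed to be set up exactly as in that function):

from ntg_pytorch.register_func import estimate_aff_param_iterator
from util.pytorchTcv import theta2param, param2theta
from tnf_transform.transformation import affine_transform_pytorch

src = pair_batch['source_image'][:, 0, :, :].unsqueeze(1)   # single-channel slices, as in the calls above
tgt = pair_batch['target_image'][:, 0, :, :].unsqueeze(1)

theta = ntg_model(pair_batch)                                                # coarse CNN estimate (PyTorch convention)
theta_cv = theta2param(theta.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)    # to the OpenCV convention
refined_cv = estimate_aff_param_iterator(src, tgt, theta_cv, use_cuda=use_cuda, itermax=600)
refined = param2theta(refined_cv, 240, 240, use_cuda=use_cuda)               # back to the PyTorch convention
warped = affine_transform_pytorch(pair_batch['source_image'], refined)       # CNN+NTG registered image
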
/ntg_pytorch/register_func.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 |
8 | from ntg_pytorch.register_loss import ntg_gradient_torch
9 | from ntg_pytorch.register_pyramid import compute_pyramid, compute_pyramid_pytorch, ScaleTnf, compute_pyramid_iter
10 | import matplotlib.pyplot as plt
11 |
12 | from traditional_ntg.loss_function import ntg_gradient
13 |
14 |
15 | def scale_image(img,IMIN,IMAX):
16 |     return (img-IMIN)/(IMAX-IMIN)
17 |
18 | def affine_transform(im,p):
19 |     height = im.shape[0]
20 |     width = im.shape[1]
21 |     im = cv2.warpAffine(im,p,(width,height),flags=cv2.INTER_CUBIC)
22 |     # im = cv2.warpAffine(im,p,(width,height),flags=cv2.INTER_NEAREST)
23 |     return im
24 |
25 | '''
26 | Note: if the parameters estimated by the CNN are used to seed further iterations of the traditional method, the pyramid must be built without Gaussian filtering, since the filtering lowers accuracy. A plausible explanation is that
27 | some CNN results are not very accurate, and smoothing may then mask out the remaining information.
28 | '''
29 | def estimate_aff_param_iterator(source_batch,target_batch,theta_opencv_batch=None,use_cuda=False,itermax = 800,normalize_func = None):
30 |
31 |     batch_size = source_batch.shape[0]
32 |
33 |     parser = {}
34 |     parser['tol'] = 1e-6
35 |     parser['itermax'] = itermax
36 |     parser['pyramid_spacing'] = 1.5
37 |     parser['minSize'] = 16
38 |
39 |     if theta_opencv_batch is None:
40 |         p = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
41 |         p = np.tile(p, (batch_size, 1, 1)).astype(np.float32)
42 |         p = torch.from_numpy(p)
43 |         parser['initial_affine_param'] = p
44 |     else:
45 |         parser['initial_affine_param'] = theta_opencv_batch.clone()
46 |
47 |     # start_time = time.time()
48 |
49 |     pyramid_level1 = 1 + np.floor(np.log(source_batch.shape[2] / parser['minSize']) / np.log(parser['pyramid_spacing']))
50 |     pyramid_level2 = 1 + np.floor(np.log(source_batch.shape[3] / parser['minSize']) / np.log(parser['pyramid_spacing']))
51 |     parser['pyramid_levels'] = np.min((int(pyramid_level1),int(pyramid_level2)))
52 |     # Empirically, if the pyramid has too few levels, the combined CNN+NTG accuracy can in some cases fall below that of the traditional NTG.
53 |     # print('serial pyramid, level count reduced by 1')
54 |     if theta_opencv_batch is not None:
55 |         parser['pyramid_levels'] = parser['pyramid_levels'] - 1
56 |         # parser['pyramid_levels'] = 1
57 |
58 |
59 |     if normalize_func is not None:
60 |         print('Normalizing source and target separately')
61 |         source_batch = normalize_func.scale_image_batch(source_batch)
62 |         target_batch = normalize_func.scale_image_batch(target_batch)
63 |     else:
64 |         print('Joint normalization of the source and target images')
65 |         source_batch_max = torch.max(source_batch.view(batch_size,1,-1),2)[0].unsqueeze(2).unsqueeze(2)
66 |         target_batch_max = torch.max(target_batch.view(batch_size,1,-1),2)[0].unsqueeze(2).unsqueeze(2)
67 |
68 |         IMAX_index = target_batch_max > source_batch_max
69 |         source_batch_max[IMAX_index] = target_batch_max[IMAX_index]
70 |         IMAX = source_batch_max
71 |
72 |         source_batch_min = torch.min(source_batch.view(batch_size,1,-1),2)[0].unsqueeze(2).unsqueeze(2)
73 |         target_batch_min = torch.min(target_batch.view(batch_size,1,-1),2)[0].unsqueeze(2).unsqueeze(2)
74 |
75 |         IMIN_index = target_batch_min < source_batch_min
76 |         source_batch_min[IMIN_index] = target_batch_min[IMIN_index]
77 |         IMIN = source_batch_min
78 |
79 |         source_batch = scale_image(source_batch,IMIN,IMAX)
80 |         target_batch = scale_image(target_batch,IMIN,IMAX)
81 |
82 |     batch_size,channel, h, w = source_batch.shape
83 |
84 |     smooth_sigma = np.sqrt(parser['pyramid_spacing']) / np.sqrt(3)
85 |     kx = cv2.getGaussianKernel(int(2 * round(1.5 * smooth_sigma)) + 1, smooth_sigma)
86 |     ky = cv2.getGaussianKernel(int(2 * round(1.5 * smooth_sigma)) + 1, smooth_sigma)
87 |     hg = np.multiply(kx, np.transpose(ky))
88 |
89 |     # print("setup time",calculate_diff_time(start_time))
90 |     # start_time = time.time()
91 |
92 |     scaleTnf = ScaleTnf(use_cuda=use_cuda)
93 |
94 |     # print('computing the pyramid with pytorch')  # experiments showed this to be worse than the serial pyramid
95 |     # pyramid_images_list = compute_pyramid_pytorch(source_batch,scaleTnf,hg, int(parser['pyramid_levels']),
96 |     #                                               1 / parser['pyramid_spacing'],use_cuda = use_cuda)
97 |     #
98 |     # target_pyramid_images_list = compute_pyramid_pytorch(target_batch,scaleTnf,hg, int(parser['pyramid_levels']),
99 |     #                                                      1 / parser['pyramid_spacing'],use_cuda = use_cuda)
100 |
101 |
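    # For intuition about parser['pyramid_levels'] computed above: with a 240x240
    # input, minSize=16 and pyramid_spacing=1.5,
    #   levels = 1 + floor(ln(240/16) / ln(1.5)) = 1 + floor(6.68) = 7,
    # so seven pyramid levels are built, each 1/1.5 the size of the previous one.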
102 |     # print('serial pyramid without Gaussian filtering')
103 |     # pyramid_images_list = compute_pyramid(source_batch,hg, int(parser['pyramid_levels']),
104 |     #                                       1 / parser['pyramid_spacing'],use_cuda=use_cuda)
105 |     #
106 |     # target_pyramid_images_list = compute_pyramid(target_batch,hg, int(parser['pyramid_levels']),
107 |     #                                              1 / parser['pyramid_spacing'],use_cuda=use_cuda)
108 |     #
109 |     pyramid_images_list = compute_pyramid_iter(source_batch,hg, int(parser['pyramid_levels']),
110 |                                                1 / parser['pyramid_spacing'],use_cuda=use_cuda)
111 |
112 |     target_pyramid_images_list = compute_pyramid_iter(target_batch,hg, int(parser['pyramid_levels']),
113 |                                                       1 / parser['pyramid_spacing'],use_cuda=use_cuda)
114 |
115 |     # plt.show()
116 |
117 |     # print("pytorch pyramid",calculate_diff_time(start_time))
118 |     # start_time = time.time()
119 |
120 |     # The transform parameters passed in come from the CNN and refer to a 240*240 image, so they must be rescaled before use:
121 |     # the ratio of the smallest to the largest pyramid level gives the scale factor, yielding the CNN transform at the smallest level's size.
122 |     if theta_opencv_batch is not None:
123 |         ration_diff = pyramid_images_list[-1].shape[-1] / pyramid_images_list[0].shape[-1]
124 |         parser['initial_affine_param'][:,0, 2] = parser['initial_affine_param'][:,0, 2]*ration_diff
125 |         parser['initial_affine_param'][:,1, 2] = parser['initial_affine_param'][:,1, 2]*ration_diff
126 |
127 |     for k in range(parser['pyramid_levels'] - 1, -1, -1):
128 |         if k == (parser['pyramid_levels'] - 1):
129 |             p = parser['initial_affine_param']
130 |             if use_cuda:
131 |                 p = p.cuda()
132 |
133 |         else:
134 |             parser['itermax'] = math.ceil(parser['itermax'] / parser['pyramid_spacing'])
135 |             p[:, 0, 2] = p[:, 0, 2] * pyramid_images_list[k].shape[3] / pyramid_images_list[k + 1].shape[3]
136 |             p[:, 1, 2] = p[:, 1, 2] * pyramid_images_list[k].shape[2] / pyramid_images_list[k + 1].shape[2]
137 |
138 |         copy = {}
139 |         copy['parser'] = parser
140 |         # copy['source_images'] = torch.from_numpy(pyramid_images_list[k]).float()
141 |         # copy['target_images'] = torch.from_numpy(target_pyramid_images_list[k]).float()
142 |         copy['source_images'] = pyramid_images_list[k]
143 |         copy['target_images'] = target_pyramid_images_list[k]
144 |
145 |         if use_cuda:
146 |             copy['source_images'] = copy['source_images'].cuda()
147 |             copy['target_images'] = copy['target_images'].cuda()
148 |
149 |         sz = [pyramid_images_list[k].shape[2], pyramid_images_list[k].shape[3]]
150 |         xlist = torch.tensor(range(0, sz[1]))
151 |         ylist = torch.tensor(range(0, sz[0]))
152 |
153 |         if use_cuda:
154 |             xlist = xlist.cuda()
155 |             ylist = ylist.cuda()
156 |
157 |         [X_array, Y_array] = torch.meshgrid(xlist, ylist)
158 |
159 |         X_array = X_array.float().transpose(0, 1)
160 |         Y_array = Y_array.float().transpose(0, 1)
161 |
162 |         # copy['W_array'] = X_array.expand(batch_size, channel, -1, -1)
163 |         # copy['H_array'] = Y_array.expand(batch_size, channel, -1, -1)
164 |
165 |         copy['X_array'] = (X_array / torch.max(X_array)).expand(batch_size,channel,-1,-1)
166 |         copy['Y_array'] = (Y_array / torch.max(Y_array)).expand(batch_size,channel,-1,-1)
167 |
168 |         converged = False
169 |         iter = 0
170 |         steplength = 0.5 / np.max(sz)
171 |
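        # Iterative refinement at this pyramid level: each pass evaluates the NTG
        # gradient and moves p one fixed step along it, normalized per sample by the
        # largest gradient magnitude; steplength = 0.5/max(sz) ties the step size to
        # the current level's resolution.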
        while not converged:
            start_time = time.time()

175 |             # source_image_batch = copy['source_images'].squeeze().numpy()
176 |             # target_image_batch = copy['target_images'].squeeze().numpy()
177 |             # images = np.stack((source_image_batch,target_image_batch),2)
178 |             # copy['images'] = images
179 |             # copy['options'] = None
180 |             # copy['X'] = (X_array / torch.max(X_array)).expand(batch_size, channel, -1, -1).numpy()
181 |             # copy['Y'] = (Y_array / torch.max(Y_array)).expand(batch_size, channel, -1, -1).numpy()
182 |             # g = ntg_gradient(copy,p.squeeze().numpy())
183 |             # g = torch.from_numpy(g).unsqueeze(0).float()
184 |
185 |             g = ntg_gradient_torch(copy, p, use_cuda=use_cuda).detach()
186 |             # print("ntg_gradient_torch", calculate_diff_time(start_time))
187 |             if p is None:
188 |                 print("p is None")
189 |             p = p + steplength * g / torch.max(torch.abs(g+1e-16).view(g.shape[0],-1),1)[0].unsqueeze(1).unsqueeze(1)
190 |             #residualError = torch.max(torch.abs(g[0]))
191 |             residualError = torch.max(torch.abs(g).view(g.shape[0],-1),1)[0]
192 |             iter = iter + 1
193 |             #converged = (iter >= parser['itermax']) or (residualError < parser['tol'])
194 |             converged = iter >= parser['itermax']
195 |             # print(converged)
196 |             # if converged:
197 |             #     print(str(k) + " " + str(iter) + " " + str(residualError[0:8]))
198 |             #print(str(k) + " " + str(iter))
199 |             #torch.cuda.empty_cache()
200 |
201 |     #print("loop finished in:",calculate_diff_time(start_time))
202 |
203 |     return p
204 |
205 |
--------------------------------------------------------------------------------
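# A minimal standalone usage sketch for estimate_aff_param_iterator (random tensors
# stand in for real image batches; the shapes follow the single-channel [B,1,H,W]
# calls made in the eval scripts):

import torch
from ntg_pytorch.register_func import estimate_aff_param_iterator

src = torch.rand(4, 1, 240, 240)  # source batch
tgt = torch.rand(4, 1, 240, 240)  # target batch

# Pure NTG estimation from an identity initialization:
p = estimate_aff_param_iterator(src, tgt, None, use_cuda=False, itermax=600)
print(p.shape)  # [4, 2, 3], OpenCV-convention affine parameters

# Seeding with a CNN estimate instead (theta_opencv from theta2param) runs the same
# optimizer but drops the coarsest pyramid level, as noted in the comments above.
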
/main/eval_cave_images_singlechannel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from os.path import join, abspath, dirname
4 |
5 | from skimage import io
6 | import numpy as np
7 |
8 | import torch
9 | from collections import OrderedDict
10 | import cv2
11 | from torch.utils.data import DataLoader
12 | import matplotlib.pyplot as plt
13 |
14 | from cnn_geometric.cnn_geometric_model import CNNGeometric
15 | from datasets.provider.harvardData import HarvardData, HarvardDataPair
16 | from datasets.provider.nirrgbData import NirRgbData, NirRgbTnsPair
17 | from datasets.provider.randomTnsData import RandomTnsPair, RandomTnsPairSingleChannelTest
18 | from datasets.provider.singlechannelData import SinglechannelData, SingleChannelPairTnf
19 | from datasets.provider.test_dataset import TestDataset
20 | from evluate.lossfunc import GridLoss, NTGLoss
21 | from evluate.visualize_result import visualize_compare_result, visualize_iter_result, visualize_spec_epoch_result, \
22 |     visualize_cnn_result
23 | from main.test_mulit_images import compute_average_grid_loss, compute_correct_rate, createModel, createCVPRModel
24 | from model.cnn_registration_model import CNNRegistration
25 | from ntg_pytorch.register_func import estimate_aff_param_iterator
26 | from tnf_transform.img_process import preprocess_image, NormalizeImage, NormalizeImageDict, normalize_image_simple
27 | from tnf_transform.transformation import AffineTnf, affine_transform_opencv, affine_transform_pytorch, AffineGridGen
28 | from util.pytorchTcv import theta2param, param2theta
29 | from util.time_util import calculate_diff_time
30 | from traditional_ntg.estimate_affine_param import estimate_affine_param, estimate_param_batch
31 | from visualization.matplot_tool import plot_batch_result
32 | import time
33 | import torch.nn.functional as F
34 |
35 | from visualization.train_visual import VisdomHelper
36 |
37 |
38 | def createDataloader(image_path,single_channel = False ,batch_size = 16,use_cuda=True):
39 |     '''
40 |     Create the dataloader.
41 |     :param image_path:
42 |     :param single_channel:
43 |     :param batch_size:
44 |     :param use_cuda:
45 |     :return:
46 |     '''
47 |     # dataset = HarvardData(image_path,label_path,transform=NormalizeImageDict(["image"]))
48 |     dataset = HarvardData(image_path)
49 |     dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=False,num_workers=4,pin_memory=True)
50 |     pair_generator = HarvardDataPair(single_channel=single_channel)
51 |
52 |     return dataloader,pair_generator
53 |
54 | def iterDataset(dataloader,pair_generator,ntg_model,vis,threshold=10,use_cuda=True,single_channel = False):
55 |     '''
56 |     Iterate over the batches in the dataset and process them.
57 |     :param dataloader:
58 |     :param pair_generator:
59 |     :param ntg_model:
60 |     :param use_cuda:
61 |     :return:
62 |     '''
63 |
64 |     fn_grid_loss = GridLoss(use_cuda=use_cuda)
65 |     grid_loss_cnn_list = []
66 |     grid_loss_cvpr_list = []
67 |     grid_loss_ntg_list = []
68 |     grid_loss_comb_list = []
69 |
70 |     ntg_loss_total = 0
71 |     cnn_ntg_loss_total = 0
72 |
73 |     # batch {image.shape = }
74 |     for batch_idx,batch in enumerate(dataloader):
75 |         #print("batch_id",batch_idx,'/',len(dataloader))
76 |
77 |         # if batch_idx == 1:
78 |         #     break
79 |
80 |         if batch_idx % 5 == 0:
81 |             print('test batch: [{}/{} ({:.0f}%)]'.format(
82 |                 batch_idx, len(dataloader),
83 |                 100. * batch_idx / len(dataloader)))
84 |
85 |         pair_batch = pair_generator(batch)  # image [batch_size,1,w,h], theta_GT [batch_size,2,3]
86 |
87 |         theta_estimate_batch = ntg_model(pair_batch)  # theta [batch_size,6]
88 |
89 |         source_image_batch = pair_batch['source_image']
90 |         target_image_batch = pair_batch['target_image']
91 |
92 |         # source_image_batch = normalize_image_simple(source_image_batch,forward=False)
93 |         # target_image_batch = normalize_image_simple(target_image_batch,forward=False)
94 |
95 |         theta_GT_batch = pair_batch['theta_GT']
96 |         image_name = pair_batch['name']
97 |
98 |         ## Compute the grid-point registration error
99 |         # Convert the PyTorch transform parameters to the OpenCV convention
100 |         theta_opencv = theta2param(theta_estimate_batch.view(-1, 2, 3), 240, 240, use_cuda=use_cuda)
101 |
102 |         #print('estimating with the parallel NTG')
103 |         with torch.no_grad():
104 |
105 |             if single_channel:
106 |                 ntg_param_batch = estimate_aff_param_iterator(source_image_batch,
107 |                                                               target_image_batch,
108 |                                                               None, use_cuda=use_cuda, itermax=600)
109 |
110 |                 cnn_ntg_param_batch = estimate_aff_param_iterator(source_image_batch,
111 |                                                                   target_image_batch,
112 |                                                                   theta_opencv, use_cuda=use_cuda, itermax=600)
113 |             else:
114 |                 ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:, 0, :, :].unsqueeze(1),
115 |                                                               target_image_batch[:, 0, :, :].unsqueeze(1),
116 |                                                               None, use_cuda=use_cuda, itermax=600)
117 |
118 |                 cnn_ntg_param_batch = estimate_aff_param_iterator(source_image_batch[:,0,:,:].unsqueeze(1),
119 |                                                                   target_image_batch[:,0,:,:].unsqueeze(1),
120 |                                                                   theta_opencv,use_cuda=use_cuda,itermax=600)
121 |             cnn_ntg_param_pytorch_batch = param2theta(cnn_ntg_param_batch, 240, 240, use_cuda=use_cuda)
122 |             ntg_param_pytorch_batch = param2theta(ntg_param_batch, 240, 240, use_cuda=use_cuda)
123 |             cnn_ntg_wraped_image = affine_transform_pytorch(source_image_batch, cnn_ntg_param_pytorch_batch)
124 |             ntg_wraped_image = affine_transform_pytorch(source_image_batch, ntg_param_pytorch_batch)
125 |             cnn_wraped_image = affine_transform_pytorch(source_image_batch, theta_estimate_batch)
126 |             GT_image = affine_transform_pytorch(source_image_batch, theta_GT_batch)
127 |
128 |         loss_cnn = fn_grid_loss.compute_grid_loss(theta_estimate_batch.detach(),theta_GT_batch)
129 |         loss_ntg = fn_grid_loss.compute_grid_loss(ntg_param_pytorch_batch.detach(),theta_GT_batch)
130 |         loss_cnn_ntg = fn_grid_loss.compute_grid_loss(cnn_ntg_param_pytorch_batch.detach(),theta_GT_batch)
131 |
132 |         vis.showHarvardBatch(source_image_batch,normailze=True,win='source_image_batch',title='source_image_batch')
133 |         vis.showHarvardBatch(target_image_batch,normailze=True,win='target_image_batch',title='target_image_batch')
134 |         vis.showHarvardBatch(ntg_wraped_image,normailze=True,win='ntg_wraped_image',title='ntg_wraped_image')
135 |         vis.showHarvardBatch(cnn_wraped_image,normailze=True,win='cnn_wraped_image',title='cnn_wraped_image')
136 |         vis.showHarvardBatch(cnn_ntg_wraped_image,normailze=True,win='cnn_ntg_wraped_image',title='cnn_ntg_wraped_image')
137 |         vis.showHarvardBatch(GT_image,normailze=True,win='GT_image',title='GT_image')
138 |
139 |
140 |         grid_loss_ntg_list.append(loss_ntg.detach().cpu())
141 |         grid_loss_cnn_list.append(loss_cnn.detach().cpu())
142 |         grid_loss_comb_list.append(loss_cnn_ntg.detach().cpu())
143 |
144 |     print("Grid-point losses above the threshold are excluded from the average")
145 |     print('NTG grid-point loss')
146 |     ntg_group_list = compute_average_grid_loss(grid_loss_ntg_list)
147 |     print('CNN grid-point loss')
148 |     cnn_group_list = compute_average_grid_loss(grid_loss_cnn_list)
149 |     print('CNN+NTG grid-point loss')
150 |     cnn_ntg_group_list = compute_average_grid_loss(grid_loss_comb_list)
151 |
152 |     x_list = [i for i in range(10)]
153 |
154 |     # vis.drawGridlossBar(x_list,ntg_group_list,cnn_group_list,cnn_ntg_group_list,cvpr_group_list,
155 |     #                     layout_title="Grid_loss_histogram",win='Grid_loss_histogram')
156 |
157 |     print("Average NTG value (CNN):",ntg_loss_total / len(dataloader))
158 |     print("Average NTG value (CNN+NTG):",cnn_ntg_loss_total / len(dataloader))
159 |
160 |     print("Computing the correct rate")
161 |     print('NTG correct rate')
162 |     compute_correct_rate(grid_loss_ntg_list, threshold=threshold)
163 |     print('CNN correct rate')
164 |     compute_correct_rate(grid_loss_cnn_list,threshold=threshold)
165 |     print('CNN+NTG correct rate')
166 |     compute_correct_rate(grid_loss_comb_list,threshold=threshold)
167 |
168 | def main():
169 |
170 |     single_channel = True
171 |
172 |     print("Starting the test")
173 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
174 |
175 |     #ntg_checkpoint_path = '/mnt/4T/zlk/trained_weights/best_checkpoint_coco2017_multi_gpu_paper30_NTG_resnet101.pth.tar'
176 |     # ntg_checkpoint_path = '/mnt/4T/zlk/trained_weights/checkpoint_NTG_resnet101.pth.tar'  # these two are the same
177 |     ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/output/voc2012_coco2014_NTG_resnet101.pth.tar'  # these two are the same
178 |     # ntg_checkpoint_path = '/home/zlk/project/registration_cnn_ntg/trained_weight/voc2011/best_checkpoint_voc2011_NTG_resnet101.pth.tar'
179 |     test_image_path = '/mnt/4T/zlk/datasets/mulitspectral/complete_ms_data_mat'
180 |
181 |     threshold = 3
182 |     batch_size = 1
183 |     # Load the model
184 |     use_cuda = torch.cuda.is_available()
185 |
186 |     vis = VisdomHelper(env_name='CAVE_test')
187 |     ntg_model = createModel(ntg_checkpoint_path,use_cuda=use_cuda,single_channel=single_channel)
188 |
189 |     print('Testing the CAVE grid-point loss')
190 |     dataloader,pair_generator = createDataloader(test_image_path,batch_size=batch_size,single_channel=single_channel,use_cuda = use_cuda)
191 |
192 |     iterDataset(dataloader,pair_generator,ntg_model,vis,threshold=threshold,use_cuda=use_cuda,single_channel=single_channel)
193 |
194 | if __name__ == '__main__':
195 |
196 |     main()
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------