├── .gitignore ├── LICENSE ├── README.md ├── data ├── dataset_generate.py └── distortion_model.py ├── dataloaderNetM.py ├── dataloaderNetS.py ├── eval.py ├── imgs └── results.jpg ├── logger.py ├── modelNetM.py ├── modelNetS.py ├── resample └── resampling.py ├── trainNetM.py └── trainNetS.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xiaoyu Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeoProj 2 | 3 | ### [Paper](https://arxiv.org/abs/1909.03459) 4 | 5 | Source code for *Blind Geometric Distortion Correction on Images Through Deep Learning* by Li et al., CVPR 2019. 6 | 7 | 8 | 9 | ## Prerequisites 10 | - Linux or Windows 11 | - Python 3 12 | - CPU or NVIDIA GPU + CUDA + cuDNN 13 | 14 | ## Getting Started 15 | 16 | ### Dataset Generation 17 | To train the models with the provided code, the dataset must first be generated with the script below, which writes `train_distorted`, `train_flow`, `test_distorted` and `test_flow` folders whose files are named `<distortion type>_<index>`. 18 | 19 | You can use any distortion-free images to generate the dataset. In the paper, we use the [Places365-Standard dataset](http://places2.csail.mit.edu/download.html) at a resolution of 512\*512 as the original distortion-free images to generate the 256\*256 dataset. 20 | 21 | Run the following command to generate the dataset: 22 | ```bash 23 | python data/dataset_generate.py [--sourcedir [PATH]] [--datasetdir [PATH]] 24 | [--trainnum [NUMBER]] [--testnum [NUMBER]] 25 | 26 | --sourcedir Path to original non-distorted images 27 | --datasetdir Path to the generated dataset 28 | --trainnum Number of generated training samples 29 | --testnum Number of generated testing samples 30 | ``` 31 | 32 | ### Training 33 | Run the following commands to see a help message describing the optional arguments, such as the learning rate and the dataset directory: 34 | ```bash 35 | python trainNetS.py -h # if you want to train GeoNetS 36 | python trainNetM.py -h # if you want to train GeoNetM 37 | ``` 38 | 39 | ### Use a Pre-trained Model 40 | You can download the pre-trained model [here](https://drive.google.com/open?id=1Tdi92IMA-rrX2ozdUMvfiN0jCZY7wIp_). 41 | 42 | You can also use `eval.py` to generate your own results: edit the model paths, the test image path and the output flow path in the script so they point to your own directories.
43 | 44 | ### Resampling 45 | Import the `resample.resampling.rectification` function to resample the distorted image using the forward flow. 46 | 47 | The distorted image should be a NumPy array of shape H\*W\*3 for a color image or H\*W for a greyscale image; the forward flow should be an array of shape 2\*H\*W. 48 | 49 | The function returns the rectified image and a mask indicating whether each pixel converged within the maximum number of iterations. 50 | ## Citation 51 | ```bibtex 52 | @inproceedings{li2019blind, 53 | title={Blind Geometric Distortion Correction on Images Through Deep Learning}, 54 | author={Li, Xiaoyu and Zhang, Bo and Sander, Pedro V and Liao, Jing}, 55 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 56 | pages={4855--4864}, 57 | year={2019} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /data/dataset_generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage 3 | import skimage.io as io 4 | from skimage.transform import rescale 5 | import scipy.io as scio 6 | import distortion_model 7 | import argparse 8 | import os 9 | 10 | # For parsing commandline arguments 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--sourcedir", type=str, default='/home/xliea/GeoProj/Dataset/Dataset_512_ori') 13 | parser.add_argument("--datasetdir", type=str, default='/home/xliea/GeoProj/Dataset/Dataset_256_gen') 14 | parser.add_argument("--trainnum", type=int, default=50000, help='number of the training set') 15 | parser.add_argument("--testnum", type=int, default=5000, help='number of the test set') 16 | args = parser.parse_args() 17 | 18 | if not os.path.exists(args.datasetdir): 19 | os.mkdir(args.datasetdir) 20 | 21 | trainDisPath = args.datasetdir + '/train_distorted' 22 | trainUvPath = args.datasetdir + '/train_flow' 23 | testDisPath = args.datasetdir + 
'/test_distorted' 24 | testUvPath = args.datasetdir + '/test_flow' 25 | 26 | if not os.path.exists(trainDisPath): 27 | os.mkdir(trainDisPath) 28 | 29 | if not os.path.exists(trainUvPath): 30 | os.mkdir(trainUvPath) 31 | 32 | if not os.path.exists(testDisPath): 33 | os.mkdir(testDisPath) 34 | 35 | if not os.path.exists(testUvPath): 36 | os.mkdir(testUvPath) 37 | 38 | def generatedata(types, k, trainFlag): 39 | 40 | print(types,trainFlag,k) 41 | 42 | width = 512 43 | height = 512 44 | 45 | parameters = distortion_model.distortionParameter(types) 46 | 47 | OriImg = io.imread('%s%s%s%s' % (args.sourcedir, '/', str(k).zfill(6), '.jpg')) 48 | 49 | disImg = np.array(np.zeros(OriImg.shape), dtype = np.uint8) 50 | u = np.array(np.zeros((OriImg.shape[0],OriImg.shape[1])), dtype = np.float32) 51 | v = np.array(np.zeros((OriImg.shape[0],OriImg.shape[1])), dtype = np.float32) 52 | 53 | cropImg = np.array(np.zeros((int(height/2),int(width/2),3)), dtype = np.uint8) 54 | crop_u = np.array(np.zeros((int(height/2),int(width/2))), dtype = np.float32) 55 | crop_v = np.array(np.zeros((int(height/2),int(width/2))), dtype = np.float32) 56 | 57 | # crop range 58 | xmin = int(width*1/4) 59 | xmax = int(width*3/4 - 1) 60 | ymin = int(height*1/4) 61 | ymax = int(height*3/4 - 1) 62 | 63 | for i in range(width): 64 | for j in range(height): 65 | 66 | xu, yu = distortion_model.distortionModel(types, i, j, width, height, parameters) 67 | 68 | if (0 <= xu < width - 1) and (0 <= yu < height - 1): 69 | 70 | u[j][i] = xu - i 71 | v[j][i] = yu - j 72 | 73 | # Bilinear interpolation 74 | Q11 = OriImg[int(yu), int(xu), :] 75 | Q12 = OriImg[int(yu), int(xu) + 1, :] 76 | Q21 = OriImg[int(yu) + 1, int(xu), :] 77 | Q22 = OriImg[int(yu) + 1, int(xu) + 1, :] 78 | 79 | disImg[j,i,:] = Q11*(int(xu) + 1 - xu)*(int(yu) + 1 - yu) + \ 80 | Q12*(xu - int(xu))*(int(yu) + 1 - yu) + \ 81 | Q21*(int(xu) + 1 - xu)*(yu - int(yu)) + \ 82 | Q22*(xu - int(xu))*(yu - int(yu)) 83 | 84 | 85 | if(xmin <= i <= xmax) and (ymin <= 
j <= ymax): 86 | cropImg[j - ymin, i - xmin, :] = disImg[j,i,:] 87 | crop_u[j - ymin, i - xmin] = u[j,i] 88 | crop_v[j - ymin, i - xmin] = v[j,i] 89 | 90 | if trainFlag == True: 91 | saveImgPath = '%s%s%s%s%s%s' % (trainDisPath, '/',types,'_', str(k).zfill(6), '.jpg') 92 | saveMatPath = '%s%s%s%s%s%s' % (trainUvPath, '/',types,'_', str(k).zfill(6), '.mat') 93 | io.imsave(saveImgPath, cropImg) 94 | scio.savemat(saveMatPath, {'u': crop_u,'v': crop_v}) 95 | else: 96 | saveImgPath = '%s%s%s%s%s%s' % (testDisPath, '/',types,'_', str(k).zfill(6), '.jpg') 97 | saveMatPath = '%s%s%s%s%s%s' % (testUvPath, '/',types,'_', str(k).zfill(6), '.mat') 98 | io.imsave(saveImgPath, cropImg) 99 | scio.savemat(saveMatPath, {'u': crop_u,'v': crop_v}) 100 | 101 | def generatepindata(types, k, trainFlag): 102 | 103 | print(types,trainFlag,k) 104 | 105 | width = 256 106 | height = 256 107 | 108 | parameters = distortion_model.distortionParameter(types) 109 | 110 | OriImg = io.imread('%s%s%s%s' % (args.sourcedir, '/', str(k).zfill(6), '.jpg')) 111 | temImg = rescale(OriImg, 0.5, mode='reflect') 112 | ScaImg = skimage.img_as_ubyte(temImg) 113 | 114 | padImg = np.array(np.zeros((ScaImg.shape[0] + 1,ScaImg.shape[1] + 1, 3)), dtype = np.uint8) 115 | padImg[0:height, 0:width, :] = ScaImg[0:height, 0:width, :] 116 | padImg[height, 0:width, :] = ScaImg[height - 1, 0:width, :] 117 | padImg[0:height, width, :] = ScaImg[0:height, width - 1, :] 118 | padImg[height, width, :] = ScaImg[height - 1, width - 1, :] 119 | 120 | disImg = np.array(np.zeros(ScaImg.shape), dtype = np.uint8) 121 | u = np.array(np.zeros((ScaImg.shape[0],ScaImg.shape[1])), dtype = np.float32) 122 | v = np.array(np.zeros((ScaImg.shape[0],ScaImg.shape[1])), dtype = np.float32) 123 | 124 | for i in range(width): 125 | for j in range(height): 126 | 127 | xu, yu = distortion_model.distortionModel(types, i, j, width, height, parameters) 128 | 129 | if (0 <= xu <= width - 1) and (0 <= yu <= height - 1): 130 | 131 | u[j][i] = xu - i 132 | 
v[j][i] = yu - j 133 | 134 | # Bilinear interpolation 135 | Q11 = padImg[int(yu), int(xu), :] 136 | Q12 = padImg[int(yu), int(xu) + 1, :] 137 | Q21 = padImg[int(yu) + 1, int(xu), :] 138 | Q22 = padImg[int(yu) + 1, int(xu) + 1, :] 139 | 140 | disImg[j,i,:] = Q11*(int(xu) + 1 - xu)*(int(yu) + 1 - yu) + \ 141 | Q12*(xu - int(xu))*(int(yu) + 1 - yu) + \ 142 | Q21*(int(xu) + 1 - xu)*(yu - int(yu)) + \ 143 | Q22*(xu - int(xu))*(yu - int(yu)) 144 | 145 | if trainFlag == True: 146 | saveImgPath = '%s%s%s%s%s%s' % (trainDisPath, '/',types,'_', str(k).zfill(6), '.jpg') 147 | saveMatPath = '%s%s%s%s%s%s' % (trainUvPath, '/',types,'_', str(k).zfill(6), '.mat') 148 | io.imsave(saveImgPath, disImg) 149 | scio.savemat(saveMatPath, {'u': u,'v': v}) 150 | else: 151 | saveImgPath = '%s%s%s%s%s%s' % (testDisPath, '/',types,'_', str(k).zfill(6), '.jpg') 152 | saveMatPath = '%s%s%s%s%s%s' % (testUvPath, '/',types,'_', str(k).zfill(6), '.mat') 153 | io.imsave(saveImgPath, disImg) 154 | scio.savemat(saveMatPath, {'u': u,'v': v}) 155 | 156 | 157 | for types in ['barrel','rotation','shear','wave']: 158 | for k in range(args.trainnum): 159 | generatedata(types, k, trainFlag = True) 160 | 161 | for k in range(args.trainnum, args.trainnum + args.testnum): 162 | generatedata(types, k, trainFlag = False) 163 | 164 | for types in ['pincushion','projective']: 165 | for k in range(args.trainnum): 166 | generatepindata(types, k, trainFlag = True) 167 | 168 | for k in range(args.trainnum, args.trainnum + args.testnum): 169 | generatepindata(types, k, trainFlag = False) 170 | -------------------------------------------------------------------------------- /data/distortion_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | def distortionParameter(types): 5 | parameters = [] 6 | 7 | if (types == 'barrel'): 8 | Lambda = np.random.random_sample( )* -5e-5/4 9 | x0 = 256 10 | y0 = 256 11 | parameters.append(Lambda) 12 | 
parameters.append(x0) 13 | parameters.append(y0) 14 | return parameters 15 | 16 | elif (types == 'pincushion'): 17 | Lambda = np.random.random_sample() * 8.6e-5/4 18 | x0 = 128 19 | y0 = 128 20 | parameters.append(Lambda) 21 | parameters.append(x0) 22 | parameters.append(y0) 23 | return parameters 24 | 25 | elif (types == 'rotation'): 26 | theta = np.random.random_sample() * 30 - 15 27 | radian = math.pi*theta/180 28 | sina = math.sin(radian) 29 | cosa = math.cos(radian) 30 | parameters.append(sina) 31 | parameters.append(cosa) 32 | return parameters 33 | 34 | elif (types == 'shear'): 35 | shear = np.random.random_sample() * 0.8 - 0.4 36 | parameters.append(shear) 37 | return parameters 38 | 39 | elif (types == 'projective'): 40 | 41 | x1 = 0 42 | x4 = np.random.random_sample()* 0.1 + 0.1 43 | 44 | x2 = 1 - x1 45 | x3 = 1 - x4 46 | 47 | y1 = 0.005 48 | y4 = 1 - y1 49 | y2 = y1 50 | y3 = y4 51 | 52 | a31 = ((x1-x2+x3-x4)*(y4-y3) - (y1-y2+y3-y4)*(x4-x3))/((x2-x3)*(y4-y3)-(x4-x3)*(y2-y3)) 53 | a32 = ((y1-y2+y3-y4)*(x2-x3) - (x1-x2+x3-x4)*(y2-y3))/((x2-x3)*(y4-y3)-(x4-x3)*(y2-y3)) 54 | 55 | a11 = x2 - x1 + a31*x2 56 | a12 = x4 - x1 + a32*x4 57 | a13 = x1 58 | 59 | a21 = y2 - y1 + a31*y2 60 | a22 = y4 - y1 + a32*y4 61 | a23 = y1 62 | 63 | parameters.append(a11) 64 | parameters.append(a12) 65 | parameters.append(a13) 66 | parameters.append(a21) 67 | parameters.append(a22) 68 | parameters.append(a23) 69 | parameters.append(a31) 70 | parameters.append(a32) 71 | return parameters 72 | 73 | elif (types == 'wave'): 74 | mag = np.random.random_sample() * 32 75 | parameters.append(mag) 76 | return parameters 77 | 78 | 79 | def distortionModel(types, xd, yd, W, H, parameter): 80 | 81 | if (types == 'barrel' or types == 'pincushion'): 82 | Lambda = parameter[0] 83 | x0 = parameter[1] 84 | y0 = parameter[2] 85 | coeff = 1 + Lambda * ((xd - x0)**2 + (yd - y0)**2) 86 | if (coeff == 0): 87 | xu = W 88 | yu = H 89 | else: 90 | xu = (xd - x0)/coeff + x0 91 | yu = (yd - y0)/coeff + y0 
92 | return xu, yu 93 | 94 | elif (types == 'rotation'): 95 | sina = parameter[0] 96 | cosa = parameter[1] 97 | xu = cosa*xd + sina*yd + (1 - sina - cosa)*W/2 98 | yu = -sina*xd + cosa*yd + (1 + sina - cosa)*H/2 99 | return xu, yu 100 | 101 | elif (types == 'shear'): 102 | shear = parameter[0] 103 | xu = xd + shear*yd - shear*W/2 104 | yu = yd 105 | return xu, yu 106 | 107 | elif (types == 'projective'): 108 | a11 = parameter[0] 109 | a12 = parameter[1] 110 | a13 = parameter[2] 111 | a21 = parameter[3] 112 | a22 = parameter[4] 113 | a23 = parameter[5] 114 | a31 = parameter[6] 115 | a32 = parameter[7] 116 | im = xd/(W - 1.0) 117 | jm = yd/(H - 1.0) 118 | xu = (W - 1.0) *(a11*im + a12*jm +a13)/(a31*im + a32*jm + 1) 119 | yu = (H - 1.0)*(a21*im + a22*jm +a23)/(a31*im + a32*jm + 1) 120 | return xu, yu 121 | 122 | elif (types == 'wave'): 123 | mag = parameter[0] 124 | yu = yd 125 | xu = xd + mag*math.sin(math.pi*4*yd/W) 126 | return xu, yu 127 | 128 | -------------------------------------------------------------------------------- /dataloaderNetM.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import scipy.io as spio 5 | import numpy as np 6 | import skimage 7 | import torch 8 | 9 | """Custom Dataset compatible with prebuilt DataLoader.""" 10 | class DistortionDataset(data.Dataset): 11 | def __init__(self, distortedImgDir, flowDir, transform, distortion_type, data_num): 12 | 13 | self.distorted_image_paths = [] 14 | self.displacement_paths = [] 15 | 16 | for fs in os.listdir(distortedImgDir): 17 | 18 | types = fs.split('_')[0] 19 | number = int(fs.split('_')[1].split('.')[0]) 20 | 21 | if types in distortion_type and number < data_num: 22 | self.distorted_image_paths.append(os.path.join(distortedImgDir, fs)) 23 | 24 | for fs in os.listdir(flowDir): 25 | 26 | types = fs.split('_')[0] 27 | number = int(fs.split('_')[1].split('.')[0]) 28 | 29 | if 
types in distortion_type and number < data_num: 30 | self.displacement_paths.append(os.path.join(flowDir, fs)) 31 | 32 | self.distorted_image_paths.sort() 33 | self.displacement_paths.sort() 34 | 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """Reads an image from a file and preprocesses it and returns.""" 39 | distorted_image_path = self.distorted_image_paths[index] 40 | displacement_path = self.displacement_paths[index] 41 | 42 | distorted_image =skimage.io.imread(distorted_image_path) 43 | displacement = spio.loadmat(displacement_path) 44 | 45 | displacement_x = displacement['u'].astype(np.float32) 46 | displacement_y = displacement['v'].astype(np.float32) 47 | 48 | displacement_x = displacement_x[np.newaxis,:] 49 | displacement_y = displacement_y[np.newaxis,:] 50 | 51 | label_type = distorted_image_path.split('_')[0].split('/')[-1] 52 | label = 0 53 | if (label_type == 'barrel'): 54 | label = 0 55 | elif (label_type == 'pincushion'): 56 | label = 1 57 | elif (label_type == 'rotation'): 58 | label = 2 59 | elif (label_type == 'shear'): 60 | label = 3 61 | elif (label_type == 'projective'): 62 | label = 4 63 | elif (label_type == 'randomnew'): 64 | label = 5 65 | 66 | if self.transform is not None: 67 | trans_distorted_image = self.transform(distorted_image) 68 | else: 69 | trans_distorted_image = distorted_image 70 | 71 | return trans_distorted_image, displacement_x, displacement_y, label 72 | 73 | def __len__(self): 74 | """Returns the total number of image files.""" 75 | return len(self.distorted_image_paths) 76 | 77 | 78 | def get_loader(distortedImgDir, flowDir, batch_size, distortion_type, data_num): 79 | """Builds and returns Dataloader.""" 80 | 81 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 82 | dataset = DistortionDataset(distortedImgDir, flowDir, transform, distortion_type, data_num) 83 | data_loader = data.DataLoader(dataset=dataset, 
batch_size=batch_size, shuffle=True, drop_last=True) 84 | 85 | return data_loader -------------------------------------------------------------------------------- /dataloaderNetS.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from torchvision import transforms 4 | import scipy.io as spio 5 | import numpy as np 6 | import skimage 7 | import torch 8 | 9 | """Custom Dataset compatible with prebuilt DataLoader.""" 10 | class DistortionDataset(data.Dataset): 11 | def __init__(self, distortedImgDir, flowDir, transform, distortion_type): 12 | 13 | self.distorted_image_paths = [] 14 | self.displacement_paths = [] 15 | 16 | for fs in os.listdir(distortedImgDir): 17 | types = fs.split('_')[0] 18 | if types in distortion_type: 19 | self.distorted_image_paths.append(os.path.join(distortedImgDir, fs)) 20 | 21 | for fs in os.listdir(flowDir): 22 | types = fs.split('_')[0] 23 | if types in distortion_type: 24 | self.displacement_paths.append(os.path.join(flowDir, fs)) 25 | 26 | self.distorted_image_paths.sort() 27 | self.displacement_paths.sort() 28 | 29 | self.transform = transform 30 | 31 | def __getitem__(self, index): 32 | """Reads an image from a file and preprocesses it and returns.""" 33 | distorted_image_path = self.distorted_image_paths[index] 34 | displacement_path = self.displacement_paths[index] 35 | 36 | distorted_image =skimage.io.imread(distorted_image_path) 37 | displacement = spio.loadmat(displacement_path) 38 | 39 | displacement_x = displacement['u'] 40 | displacement_y = displacement['v'] 41 | 42 | displacement_x = displacement_x[np.newaxis,:] 43 | displacement_y = displacement_y[np.newaxis,:] 44 | 45 | if self.transform is not None: 46 | trans_distorted_image = self.transform(distorted_image) 47 | else: 48 | trans_distorted_image = distorted_image 49 | 50 | return trans_distorted_image, displacement_x, displacement_y 51 | 52 | def __len__(self): 53 | """Returns the total 
number of image files.""" 54 | return len(self.distorted_image_paths) 55 | 56 | 57 | def get_loader(distortedImgDir, flowDir, batch_size, distortion_type): 58 | """Builds and returns Dataloader.""" 59 | 60 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 61 | 62 | dataset = DistortionDataset(distortedImgDir, flowDir, transform, distortion_type) 63 | 64 | data_loader = data.DataLoader(dataset=dataset, 65 | batch_size=batch_size, 66 | shuffle=True, drop_last=True) 67 | return data_loader -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import skimage 5 | import skimage.io as io 6 | from torchvision import transforms 7 | import numpy as np 8 | import scipy.io as scio 9 | 10 | from modelNetM import EncoderNet, DecoderNet, ClassNet, EPELoss 11 | 12 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 13 | 14 | model_en = EncoderNet([1,1,1,1,2]) 15 | model_de = DecoderNet([1,1,1,1,2]) 16 | model_class = ClassNet() 17 | 18 | if torch.cuda.device_count() > 1: 19 | print("Let's use", torch.cuda.device_count(), "GPUs!") 20 | model_en = nn.DataParallel(model_en) 21 | model_de = nn.DataParallel(model_de) 22 | model_class = nn.DataParallel(model_class) 23 | 24 | if torch.cuda.is_available(): 25 | model_en = model_en.cuda() 26 | model_de = model_de.cuda() 27 | model_class = model_class.cuda() 28 | 29 | 30 | model_en.load_state_dict(torch.load('model_en.pkl')) 31 | model_de.load_state_dict(torch.load('model_de.pkl')) 32 | model_class.load_state_dict(torch.load('model_class.pkl')) 33 | 34 | model_en.eval() 35 | model_de.eval() 36 | model_class.eval() 37 | 38 | testImgPath = '/home/xliea/Dataset256/Dataset256/test/distorted' 39 | saveFlowPath = 
'/home/xliea/test/flow_256/flow_cla' 40 | 41 | correct = 0 42 | for index, types in enumerate(['barrel','pincushion','rotation','shear','projective','wave']): 43 | for k in range(50000,55000): 44 | 45 | imgPath = '%s%s%s%s%s%s' % (testImgPath, '/',types,'_', str(k).zfill(6), '.jpg') 46 | disimgs = io.imread(imgPath) 47 | disimgs = transform(disimgs) 48 | 49 | use_GPU = torch.cuda.is_available() 50 | if use_GPU: 51 | disimgs = disimgs.cuda() 52 | 53 | disimgs = disimgs.view(1,3,256,256) 54 | disimgs = Variable(disimgs) 55 | 56 | middle = model_en(disimgs) 57 | flow_output = model_de(middle) 58 | clas = model_class(middle) 59 | 60 | _, predicted = torch.max(clas.data, 1) 61 | if predicted.cpu().numpy()[0] == index: 62 | correct += 1 63 | 64 | u = flow_output.data.cpu().numpy()[0][0] 65 | v = flow_output.data.cpu().numpy()[0][1] 66 | 67 | saveMatPath = '%s%s%s%s%s%s' % (saveFlowPath, '/',types,'_', str(k).zfill(6), '.mat') 68 | scio.savemat(saveMatPath, {'u': u,'v': v}) 69 | 70 | -------------------------------------------------------------------------------- /imgs/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiaoyu258/GeoProj/00ca7a27a5673cdf8363cd7d535505c22300c4c4/imgs/results.jpg -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy as np 4 | import scipy.misc 5 | try: 6 | from StringIO import StringIO # Python 2.7 7 | except ImportError: 8 | from io import BytesIO # Python 3.x 9 | 10 | 11 | class Logger(object): 12 | 13 | def __init__(self, log_dir): 14 | """Create a summary writer logging to log_dir.""" 15 | self.writer = tf.summary.FileWriter(log_dir) 16 | 17 | def scalar_summary(self, tag, value, step): 18 | """Log a scalar 
variable.""" 19 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 20 | self.writer.add_summary(summary, step) 21 | 22 | def image_summary(self, tag, images, step): 23 | """Log a list of images.""" 24 | 25 | img_summaries = [] 26 | for i, img in enumerate(images): 27 | # Write the image to a string 28 | try: 29 | s = StringIO() 30 | except: 31 | s = BytesIO() 32 | scipy.misc.toimage(img).save(s, format="png") 33 | 34 | # Create an Image object 35 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 36 | height=img.shape[0], 37 | width=img.shape[1]) 38 | # Create a Summary value 39 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 40 | 41 | # Create and write Summary 42 | summary = tf.Summary(value=img_summaries) 43 | self.writer.add_summary(summary, step) 44 | 45 | def histo_summary(self, tag, values, step, bins=1000): 46 | """Log a histogram of the tensor of values.""" 47 | 48 | # Create a histogram using numpy 49 | counts, bin_edges = np.histogram(values, bins=bins) 50 | 51 | # Fill the fields of the histogram proto 52 | hist = tf.HistogramProto() 53 | hist.min = float(np.min(values)) 54 | hist.max = float(np.max(values)) 55 | hist.num = int(np.prod(values.shape)) 56 | hist.sum = float(np.sum(values)) 57 | hist.sum_squares = float(np.sum(values**2)) 58 | 59 | # Drop the start of the first bin 60 | bin_edges = bin_edges[1:] 61 | 62 | # Add bin edges and counts 63 | for edge in bin_edges: 64 | hist.bucket_limit.append(edge) 65 | for c in counts: 66 | hist.bucket.append(c) 67 | 68 | # Create and write Summary 69 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 70 | self.writer.add_summary(summary, step) 71 | self.writer.flush() 72 | -------------------------------------------------------------------------------- /modelNetM.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import 
torch.nn.functional as F 5 | 6 | class plainEncoderBlock(nn.Module): 7 | def __init__(self, inChannel, outChannel, stride): 8 | 9 | super(plainEncoderBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=stride, padding=1) 11 | self.bn1 = nn.BatchNorm2d(outChannel) 12 | self.conv2 = nn.Conv2d(outChannel, outChannel, kernel_size=3, stride=1, padding=1) 13 | self.bn2 = nn.BatchNorm2d(outChannel) 14 | 15 | def forward(self, x): 16 | x = F.relu(self.bn1(self.conv1(x))) 17 | x = F.relu(self.bn2(self.conv2(x))) 18 | return x 19 | 20 | class plainDecoderBlock(nn.Module): 21 | def __init__(self, inChannel, outChannel, stride): 22 | 23 | super(plainDecoderBlock, self).__init__() 24 | self.conv1 = nn.Conv2d(inChannel, inChannel, kernel_size=3, stride=1, padding=1) 25 | self.bn1 = nn.BatchNorm2d(inChannel) 26 | 27 | if stride == 1: 28 | self.conv2 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=1, padding=1) 29 | self.bn2 = nn.BatchNorm2d(outChannel) 30 | else: 31 | self.conv2 = nn.ConvTranspose2d(inChannel, outChannel, kernel_size=2, stride=2) 32 | self.bn2 = nn.BatchNorm2d(outChannel) 33 | 34 | def forward(self, x): 35 | x = F.relu(self.bn1(self.conv1(x))) 36 | x = F.relu(self.bn2(self.conv2(x))) 37 | return x 38 | 39 | 40 | class resEncoderBlock(nn.Module): 41 | def __init__(self, inChannel, outChannel, stride): 42 | 43 | super(resEncoderBlock, self).__init__() 44 | self.conv1 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=stride, padding=1) 45 | self.bn1 = nn.BatchNorm2d(outChannel) 46 | self.conv2 = nn.Conv2d(outChannel, outChannel, kernel_size=3, stride=1, padding=1) 47 | self.bn2 = nn.BatchNorm2d(outChannel) 48 | 49 | self.downsample = None 50 | if stride != 1: 51 | self.downsample = nn.Sequential( 52 | nn.Conv2d(inChannel, outChannel, kernel_size=1, stride=stride), 53 | nn.BatchNorm2d(outChannel)) 54 | 55 | def forward(self, x): 56 | residual = x 57 | 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = 
self.bn2(self.conv2(out)) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = F.relu(out) 66 | return out 67 | 68 | class resDecoderBlock(nn.Module): 69 | def __init__(self, inChannel, outChannel, stride): 70 | 71 | super(resDecoderBlock, self).__init__() 72 | self.conv1 = nn.Conv2d(inChannel, inChannel, kernel_size=3, stride=1, padding=1) 73 | self.bn1 = nn.BatchNorm2d(inChannel) 74 | 75 | self.downsample = None 76 | 77 | if stride == 1: 78 | self.conv2 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=1, padding=1) 79 | self.bn2 = nn.BatchNorm2d(outChannel) 80 | else: 81 | self.conv2 = nn.ConvTranspose2d(inChannel, outChannel, kernel_size=2, stride=2) 82 | self.bn2 = nn.BatchNorm2d(outChannel) 83 | 84 | self.downsample = nn.Sequential( 85 | nn.ConvTranspose2d(inChannel, outChannel, kernel_size=1, stride=2, output_padding=1), 86 | nn.BatchNorm2d(outChannel)) 87 | 88 | def forward(self, x): 89 | residual = x 90 | 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.bn2(self.conv2(out)) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = F.relu(out) 99 | return out 100 | 101 | 102 | class EncoderNet(nn.Module): 103 | def __init__(self, layers): 104 | super(EncoderNet, self).__init__() 105 | 106 | self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 107 | self.bn = nn.BatchNorm2d(64) 108 | 109 | self.en_layer1 = self.make_encoder_layer(plainEncoderBlock, 64, 64, layers[0], stride=1) 110 | self.en_layer2 = self.make_encoder_layer(resEncoderBlock, 64, 128, layers[1], stride=2) 111 | self.en_layer3 = self.make_encoder_layer(resEncoderBlock, 128, 256, layers[2], stride=2) 112 | self.en_layer4 = self.make_encoder_layer(resEncoderBlock, 256, 512, layers[3], stride=2) 113 | self.en_layer5 = self.make_encoder_layer(resEncoderBlock, 512, 512, layers[4], stride=2) 114 | 115 | # weight initializaion with Kaiming method 116 | for m 
in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 119 | m.weight.data.normal_(0, math.sqrt(2. / n)) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | 124 | def make_encoder_layer(self, block, inChannel, outChannel, block_num, stride): # the first block applies the stride and channel change; the remaining block_num-1 blocks keep outChannel 125 | layers = [] 126 | layers.append(block(inChannel, outChannel, stride=stride)) 127 | for i in range(1, block_num): 128 | layers.append(block(outChannel, outChannel, stride=1)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): # (N, 3, H, W) -> (N, 512, H/16, W/16); assumes H and W divisible by 16 133 | 134 | x = F.relu(self.bn(self.conv(x))) # 3 -> 64 channels, resolution unchanged 135 | 136 | x = self.en_layer1(x) # H (256 for a 256x256 input) 137 | x = self.en_layer2(x) # H/2 (128) 138 | x = self.en_layer3(x) # H/4 (64) 139 | x = self.en_layer4(x) # H/8 (32) 140 | x = self.en_layer5(x) # H/16 (16) 141 | 142 | return x 143 | 144 | 145 | class DecoderNet(nn.Module): 146 | def __init__(self, layers): 147 | super(DecoderNet, self).__init__() 148 | 149 | self.de_layer5 = self.make_decoder_layer(resDecoderBlock, 512, 512, layers[4], stride=2) 150 | self.de_layer4 = self.make_decoder_layer(resDecoderBlock, 512, 256, layers[3], stride=2) 151 | self.de_layer3 = self.make_decoder_layer(resDecoderBlock, 256, 128, layers[2], stride=2) 152 | self.de_layer2 = self.make_decoder_layer(resDecoderBlock, 128, 64, layers[1], stride=2) 153 | self.de_layer1 = self.make_decoder_layer(plainDecoderBlock, 64, 64, layers[0], stride=1) 154 | 155 | self.conv_end = nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1) 156 | 157 | # weight initialization with the Kaiming (He) method, fan-out variant: std = sqrt(2 / (k_h * k_w * out_channels)) 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): # 
nn.ConvTranspose2d is not an nn.Conv2d subclass, so the transposed convs keep PyTorch's default init 160 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 161 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 162 | elif isinstance(m, nn.BatchNorm2d): 163 | m.weight.data.fill_(1) 164 | m.bias.data.zero_() 165 | 166 | def make_decoder_layer(self, block, inChannel, outChannel, block_num, stride): # mirror of make_encoder_layer: block_num-1 blocks keep inChannel, the last one applies the stride and channel change 167 | 168 | layers = [] 169 | for i in range(0, block_num-1): 170 | layers.append(block(inChannel, inChannel, stride=1)) 171 | 172 | layers.append(block(inChannel, outChannel, stride=stride)) 173 | 174 | return nn.Sequential(*layers) 175 | 176 | def forward(self, x): # (N, 512, H/16, W/16) -> (N, 2, H, W) 177 | 178 | x = self.de_layer5(x) # H/16 -> H/8 (16 -> 32 for a 256x256 input) 179 | x = self.de_layer4(x) # H/8 -> H/4 (32 -> 64) 180 | x = self.de_layer3(x) # H/4 -> H/2 (64 -> 128) 181 | x = self.de_layer2(x) # H/2 -> H (128 -> 256) 182 | x = self.de_layer1(x) # H -> H (256) 183 | 184 | x = self.conv_end(x) # 2 output channels; eval.py reads channel 0 as u and channel 1 as v 185 | return x 186 | 187 | class ClassNet(nn.Module): 188 | def __init__(self): 189 | super(ClassNet, self).__init__() 190 | 191 | self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) 192 | self.bn1 = nn.BatchNorm2d(512) 193 | 194 | self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) 195 | self.bn2 = nn.BatchNorm2d(512) 196 | 197 | self.fc = nn.Linear(512 * 4 * 4, 6) # conv1/conv2 (stride 2) reduce the H/16 encoder map to H/64, i.e. 4x4 for a 256x256 input; 6 distortion classes 198 | 199 | def forward(self, x): 200 | 201 | x = F.relu(self.bn1(self.conv1(x))) 202 | x = F.relu(self.bn2(self.conv2(x))) 203 | 204 | x = x.view(x.size(0), -1) 205 | x = self.fc(x) 206 | 207 | return x 208 | 209 | class EPELoss(nn.Module): 210 | def __init__(self): 211 | super(EPELoss, self).__init__() 212 | def forward(self, output, target): 213 | lossvalue = torch.norm(output - target + 1e-16, p=2, dim=1).mean() # end-point error: L2 norm over dim 1 (assumes (N, 2, H, W) flow tensors — 
confirm with caller), averaged over all remaining elements; the 1e-16 offset presumably avoids a zero norm 214 | return lossvalue -------------------------------------------------------------------------------- /modelNetS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | import math 7 | 8 | W = 256 9 | H = 256 10 | 11 | x0 = W/2 12 | y0 = H/2 13 | 14 | batchSize = 32 15 | 16 | class BasicEncoderPlainBlock(nn.Module): 17 | def
__init__(self, inChannel, outChannel, stride): 18 | 19 | super(BasicEncoderPlainBlock, self).__init__() 20 | self.conv1 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=stride, padding=1) 21 | self.bn1 = nn.BatchNorm2d(outChannel) 22 | self.conv2 = nn.Conv2d(outChannel, outChannel, kernel_size=3, stride=1, padding=1) 23 | self.bn2 = nn.BatchNorm2d(outChannel) 24 | 25 | def forward(self, x): 26 | x = F.relu(self.bn1(self.conv1(x))) 27 | x = F.relu(self.bn2(self.conv2(x))) 28 | return x 29 | 30 | class BasicEncoderBlock(nn.Module): 31 | def __init__(self, inChannel, outChannel, stride): 32 | 33 | super(BasicEncoderBlock, self).__init__() 34 | self.conv1 = nn.Conv2d(inChannel, outChannel, kernel_size=3, stride=stride, padding=1) 35 | self.bn1 = nn.BatchNorm2d(outChannel) 36 | self.conv2 = nn.Conv2d(outChannel, outChannel, kernel_size=3, stride=1, padding=1) 37 | self.bn2 = nn.BatchNorm2d(outChannel) 38 | 39 | self.downsample = None 40 | if stride != 1: 41 | self.downsample = nn.Sequential( 42 | nn.Conv2d(inChannel, outChannel, kernel_size=1, stride=stride), 43 | nn.BatchNorm2d(outChannel)) 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = F.relu(out) 56 | return out 57 | 58 | class GenerateLenFlow(torch.autograd.Function): 59 | 60 | @staticmethod 61 | def forward(self, Input): 62 | 63 | self.save_for_backward(Input) 64 | 65 | Input = Input.cuda() 66 | flow = torch.Tensor(batchSize, 2, H, W).cuda() 67 | 68 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 69 | i = torch.from_numpy(i).float().cuda() 70 | j = torch.from_numpy(j).float().cuda() 71 | 72 | for s in range(batchSize): 73 | coeff = 1 + Input[s] * (-1e-5) * ((i - x0)**2 + (j - y0)**2) 74 | flow[s,0] = (i-x0)/coeff + x0 - i 75 | flow[s,1] = (j-y0)/coeff + y0 - j 76 | 77 | return flow 78 | 79 | 
80 | @staticmethod 81 | def backward(self, grad_output): 82 | 83 | Input, = self.saved_tensors 84 | 85 | Input = Input.cuda() 86 | grad_output = grad_output.cuda() 87 | grad_input = Variable(torch.ones(batchSize, 1), requires_grad=False).cuda() 88 | grad_current = Variable(torch.ones(batchSize, 2, H, W), requires_grad=False).cuda() 89 | 90 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 91 | i = torch.from_numpy(i).float().cuda() 92 | j = torch.from_numpy(j).float().cuda() 93 | 94 | for s in range(batchSize): 95 | r2 = (i - x0)**2 + (j - y0)**2 96 | temp = (1+r2*Input[s]*(-1e-5))**2 97 | grad_current[s,0] = (i - x0)*(-1)*r2*(-1e-5) / temp 98 | grad_current[s,1] = (j - y0)*(-1)*r2*(-1e-5) / temp 99 | 100 | grad = grad_output * grad_current 101 | 102 | for s in range(batchSize): 103 | grad_input[s,0] = torch.sum(grad[s,:,:,:]) 104 | 105 | return grad_input 106 | 107 | 108 | class GenerateRotFlow(torch.autograd.Function): 109 | @staticmethod 110 | def forward(self, Input): 111 | 112 | self.save_for_backward(Input) 113 | 114 | Input = Input.cuda() 115 | flow = torch.Tensor(batchSize, 2, H, W).cuda() 116 | 117 | sina = torch.sin(Input) 118 | cosa = torch.cos(Input) 119 | 120 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 121 | i = torch.from_numpy(i).float().cuda() 122 | j = torch.from_numpy(j).float().cuda() 123 | 124 | for s in range(batchSize): 125 | flow[s,0] = cosa[s]*i + sina[s]*j + (1 - sina[s] - cosa[s])*W/2 - i 126 | flow[s,1] = -sina[s]*i + cosa[s]*j + (1 + sina[s] - cosa[s])*H/2 - j 127 | 128 | return flow 129 | 130 | @staticmethod 131 | def backward(self, grad_output): 132 | 133 | Input, = self.saved_tensors 134 | 135 | Input = Input.cuda() 136 | grad_output = grad_output.cuda() 137 | grad_input = Variable(torch.ones(batchSize, 1), requires_grad=False).cuda() 138 | grad_current = Variable(torch.ones(batchSize, 2, H, W), requires_grad=False).cuda() 139 | 140 | sina = torch.sin(Input) 141 | cosa = torch.cos(Input) 142 | 143 | 
i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 144 | i = torch.from_numpy(i).float().cuda() 145 | j = torch.from_numpy(j).float().cuda() 146 | for s in range(batchSize): 147 | grad_current[s,0] = -sina[s]*i + cosa[s]*j + (sina[s] - cosa[s])*W/2 148 | grad_current[s,1] = -cosa[s]*i - sina[s]*j + (cosa[s] + sina[s])*H/2 149 | 150 | grad = grad_output * grad_current 151 | 152 | for s in range(batchSize): 153 | grad_input[s,0] = torch.sum(grad[s,:,:,:]) 154 | 155 | return grad_input 156 | 157 | 158 | 159 | SheFactor = 1.0/5 160 | class GenerateSheFlow(torch.autograd.Function): 161 | @staticmethod 162 | def forward(self, Input): 163 | 164 | self.save_for_backward(Input) 165 | 166 | Input = Input.cuda() 167 | flow = torch.Tensor(batchSize, 2, H, W).cuda() 168 | 169 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 170 | i = torch.from_numpy(i).float().cuda() 171 | j = torch.from_numpy(j).float().cuda() 172 | for s in range(batchSize): 173 | flow[s,0] = (Input[s]*j - Input[s]*W/2.0)*SheFactor 174 | flow[s,1] = 0 175 | 176 | return flow 177 | 178 | @staticmethod 179 | def backward(self, grad_output): 180 | 181 | Input, = self.saved_tensors 182 | 183 | Input = Input.cuda() 184 | grad_output = grad_output.cuda() 185 | grad_input = Variable(torch.ones(batchSize, 1), requires_grad=False).cuda() 186 | grad_current = Variable(torch.ones(batchSize, 2, H, W), requires_grad=False).cuda() 187 | 188 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 189 | i = torch.from_numpy(i).float().cuda() 190 | j = torch.from_numpy(j).float().cuda() 191 | for s in range(batchSize): 192 | grad_current[s,0] = (j - W/2.0)*SheFactor 193 | grad_current[s,1] = 0 194 | 195 | grad = grad_output * grad_current 196 | 197 | for s in range(batchSize): 198 | grad_input[s,0] = torch.sum(grad[s,:,:,:]) 199 | 200 | return grad_input 201 | 202 | ProFactor = 0.1 203 | class GenerateProFlow(torch.autograd.Function): 204 | @staticmethod 205 | def forward(self, Input): 206 | 
207 | self.save_for_backward(Input) 208 | 209 | Input = Input.cuda() 210 | flow = torch.Tensor(batchSize, 2, H, W).cuda() 211 | 212 | x4 = Input * ProFactor 213 | 214 | a31 = 0 215 | a32 = 2*x4/(1-2*x4) 216 | 217 | a11 = 1 218 | a12 = x4/(1-2*x4) 219 | a13 = 0 220 | 221 | a21 = 0 222 | a22 = 0.99 + a32*0.995 223 | a23 = 0.005 224 | 225 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 226 | i = torch.from_numpy(i).float().cuda() 227 | j = torch.from_numpy(j).float().cuda() 228 | for s in range(batchSize): 229 | 230 | im = i/(W - 1.0) 231 | jm = j/(H - 1.0) 232 | 233 | flow[s,0] = (W - 1.0)*(a11*im + a12[s]*jm +a13)/(a31*im + a32[s]*jm + 1) - i 234 | flow[s,1] = (H - 1.0)*(a21*im + a22[s]*jm +a23)/(a31*im + a32[s]*jm + 1) - j 235 | 236 | 237 | return flow 238 | 239 | @staticmethod 240 | def backward(self, grad_output): 241 | 242 | Input, = self.saved_tensors 243 | 244 | Input = Input.cuda() 245 | grad_output = grad_output.cuda() 246 | grad_input = Variable(torch.ones(batchSize, 1), requires_grad=False).cuda() 247 | grad_current = Variable(torch.ones(batchSize, 2, H, W), requires_grad=False).cuda() 248 | 249 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 250 | i = torch.from_numpy(i).float().cuda() 251 | j = torch.from_numpy(j).float().cuda() 252 | for s in range(batchSize): 253 | 254 | x = Input[s]*ProFactor 255 | 256 | t00 = -1.0*((2*H-2)*i+(1-H)*W+H-1)*j 257 | t01 = (4*j*j+(8-8*H)*j+4*H*H-8*H+4)*x*x+((4*H-4)*j-4*H*H+8*H-4)*x+H*H-2*H+1 258 | grad_current[s,0] = ProFactor*t00/t01 259 | 260 | t10 = -1.0*((99*H-99)*j*j+(-99*H*H+198*H-99)*j) 261 | t11 = (200*j*j+(400-400*H)*j+200*H*H-400*H+200)*x*x+((200*H-200)*j-200*H*H+400*H-200)*x+50*H*H-100*H+50 262 | grad_current[s,1] = ProFactor*t10/t11 263 | 264 | grad = grad_output * grad_current 265 | 266 | for s in range(batchSize): 267 | grad_input[s,0] = torch.sum(grad[s,:,:,:]) 268 | 269 | return grad_input 270 | 271 | class GenerateWavFlow(torch.autograd.Function): 272 | @staticmethod 273 | 
def forward(self, Input): 274 | 275 | self.save_for_backward(Input) 276 | 277 | Input = Input.cuda() 278 | flow = torch.Tensor(batchSize, 2, H, W).cuda() 279 | 280 | temp = torch.ones(batchSize, 1).cuda() 281 | 282 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 283 | i = torch.from_numpy(i).float().cuda() 284 | j = torch.from_numpy(j).float().cuda() 285 | for s in range(batchSize): 286 | flow[s,0] = Input[s]*torch.sin(math.pi*4*j/(W*2)) 287 | flow[s,1] = 0 288 | 289 | return flow 290 | 291 | @staticmethod 292 | def backward(self, grad_output): 293 | 294 | Input, = self.saved_tensors 295 | 296 | Input = Input.cuda() 297 | grad_output = grad_output.cuda() 298 | grad_input = Variable(torch.ones(batchSize, 1), requires_grad=False).cuda() 299 | grad_current = Variable(torch.ones(batchSize, 2, H, W), requires_grad=False).cuda() 300 | 301 | temp = torch.ones(batchSize, 1).cuda() 302 | 303 | i, j = np.meshgrid(np.arange(H), np.arange(W), indexing='xy') 304 | i = torch.from_numpy(i).float().cuda() 305 | j = torch.from_numpy(j).float().cuda() 306 | for s in range(batchSize): 307 | grad_current[s,0] = torch.sin(math.pi*4*j/(W*2)) 308 | grad_current[s,1] = 0 309 | 310 | 311 | grad = grad_output * grad_current 312 | 313 | for s in range(batchSize): 314 | grad_input[s,0] = torch.sum(grad[s,:,:,:]) 315 | 316 | return grad_input 317 | 318 | class EncoderNet(nn.Module): 319 | def __init__(self, layers): 320 | super(EncoderNet, self).__init__() 321 | 322 | self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 323 | self.bn = nn.BatchNorm2d(64) 324 | 325 | self.en_layer1 = self.make_encoder_layer(BasicEncoderPlainBlock, 64, 64, layers[0], stride=1) 326 | self.en_layer2 = self.make_encoder_layer(BasicEncoderBlock, 64, 128, layers[1], stride=2) 327 | self.en_layer3 = self.make_encoder_layer(BasicEncoderBlock, 128, 256, layers[2], stride=2) 328 | self.en_layer4 = self.make_encoder_layer(BasicEncoderBlock, 256, 512, layers[3], stride=2) 329 | self.en_layer5 = 
self.make_encoder_layer(BasicEncoderBlock, 512, 512, layers[4], stride=2) 330 | 331 | self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) 332 | self.bn1 = nn.BatchNorm2d(512) 333 | 334 | self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1) 335 | self.bn2 = nn.BatchNorm2d(512) 336 | 337 | self.fc = nn.Linear(512 * 4 * 4, 1) 338 | 339 | 340 | def make_encoder_layer(self, block, inChannel, outChannel, block_num, stride): 341 | layers = [] 342 | layers.append(block(inChannel, outChannel, stride=stride)) 343 | for i in range(1, block_num): 344 | layers.append(block(outChannel, outChannel, stride=1)) 345 | 346 | return nn.Sequential(*layers) 347 | 348 | def forward(self, x): 349 | 350 | x = F.relu(self.bn(self.conv(x))) 351 | 352 | x = self.en_layer1(x) #128 353 | x = self.en_layer2(x) #64 354 | x = self.en_layer3(x) #32 355 | x = self.en_layer4(x) #16 356 | x = self.en_layer5(x) #8 357 | 358 | x = F.relu(self.bn1(self.conv1(x))) 359 | x = F.relu(self.bn2(self.conv2(x))) 360 | 361 | x = x.view(x.size(0), -1) 362 | para = self.fc(x) 363 | 364 | return para 365 | 366 | class ModelNet(nn.Module): 367 | def __init__(self, types): 368 | super(ModelNet, self).__init__() 369 | self.types = types 370 | 371 | def forward(self, x): 372 | 373 | para = x 374 | 375 | if (self.types == 'barrel' or self.types == 'pincushion'): 376 | OBJflow = GenerateLenFlow.apply 377 | flow = OBJflow(para) 378 | 379 | elif (self.types == 'rotation'): 380 | OBJflow = GenerateRotFlow.apply 381 | flow = OBJflow(para) 382 | 383 | elif (self.types == 'shear'): 384 | OBJflow = GenerateSheFlow.apply 385 | flow = OBJflow(para) 386 | 387 | elif (self.types == 'projective'): 388 | OBJflow = GenerateProFlow.apply 389 | flow = OBJflow(para) 390 | 391 | elif (self.types == 'wave'): 392 | OBJflow = GenerateWavFlow.apply 393 | flow = OBJflow(para) 394 | 395 | return flow 396 | 397 | 398 | class EPELoss(nn.Module): 399 | def __init__(self): 400 | super(EPELoss, self).__init__() 401 | 
def forward(self, output, target): 402 | lossvalue = torch.norm(output - target + 1e-16, p=2, dim=1).mean() 403 | return lossvalue 404 | 405 | 406 | -------------------------------------------------------------------------------- /resample/resampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.io as io 3 | from numba import cuda 4 | import math 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='resamping') 8 | parser.add_argument("--img_path", type=str, default= '/home/xliea/GeoProj/img.png') 9 | parser.add_argument("--flow_path", type=str, default= '/home/xliea/GeoProj/flow.npy') 10 | args = parser.parse_args() 11 | 12 | @cuda.jit(device=True) 13 | def iterSearchShader(padu, padv, xr, yr, maxIter, precision): 14 | 15 | H = padu.shape[0] - 1 16 | W = padu.shape[1] - 1 17 | 18 | if abs(padu[yr,xr]) < precision and abs(padv[yr,xr]) < precision: 19 | return xr, yr 20 | 21 | else: 22 | # Our initialize method in this paper, can see the overleaf for detail 23 | if (xr + 1) <= (W - 1): 24 | dif = padu[yr,xr + 1] - padu[yr,xr] 25 | u_next = padu[yr,xr]/(1 + dif) 26 | else: 27 | dif = padu[yr,xr] - padu[yr,xr - 1] 28 | u_next = padu[yr,xr]/(1 + dif) 29 | 30 | if (yr + 1) <= (H - 1): 31 | dif = padv[yr + 1,xr] - padv[yr,xr] 32 | v_next = padv[yr,xr]/(1 + dif) 33 | else: 34 | dif = padv[yr,xr] - padv[yr - 1,xr] 35 | v_next = padv[yr,xr]/(1 + dif) 36 | 37 | i = xr - u_next 38 | j = yr - v_next 39 | ''' 40 | i = xr - padu[yr,xr] 41 | j = yr - padv[yr,xr] 42 | ''' 43 | # The same as traditinal iterative search method 44 | for iter in range(maxIter): 45 | 46 | if 0<= i <= (W - 1) and 0 <= j <= (H - 1): 47 | 48 | u11 = padu[int(j), int(i)] 49 | v11 = padv[int(j), int(i)] 50 | 51 | u12 = padu[int(j), int(i) + 1] 52 | v12 = padv[int(j), int(i) + 1] 53 | 54 | u21 = padu[int(j) + 1, int(i)] 55 | v21 = padv[int(j) + 1, int(i)] 56 | 57 | u22 = padu[int(j) + 1, int(i) + 1] 58 | v22 
= padv[int(j) + 1, int(i) + 1] 59 | 60 | 61 | u = u11*(int(i) + 1 - i)*(int(j) + 1 - j) + u12*(i - int(i))*(int(j) + 1 - j) + \ 62 | u21*(int(i) + 1 - i)*(j - int(j)) + u22*(i - int(i))*(j - int(j)) 63 | 64 | v = v11*(int(i) + 1 - i)*(int(j) + 1 - j) + v12*(i - int(i))*(int(j) + 1 - j) + \ 65 | v21*(int(i) + 1 - i)*(j - int(j)) + v22*(i - int(i))*(j - int(j)) 66 | 67 | i_next = xr - u 68 | j_next = yr - v 69 | 70 | if abs(i - i_next) 1: 54 | print("Let's use", torch.cuda.device_count(), "GPUs!") 55 | model_en = nn.DataParallel(model_en) 56 | model_de = nn.DataParallel(model_de) 57 | model_class = nn.DataParallel(model_class) 58 | 59 | if torch.cuda.is_available(): 60 | model_en = model_en.cuda() 61 | model_de = model_de.cuda() 62 | model_class = model_class.cuda() 63 | criterion = criterion.cuda() 64 | criterion_clas = criterion_clas.cuda() 65 | 66 | reg = args.reg 67 | lr = args.lr 68 | optimizer = torch.optim.Adam(list(model_en.parameters()) + list(model_de.parameters()) + list(model_class.parameters()), lr=lr) 69 | 70 | step = 0 71 | logger = Logger('./logs') 72 | 73 | model_en.train() 74 | model_de.train() 75 | model_class.train() 76 | 77 | for epoch in range(args.epochs): 78 | for i, (disimgs, disx, disy, labels) in enumerate(train_loader): 79 | 80 | if use_GPU: 81 | disimgs = disimgs.cuda() 82 | disx = disx.cuda() 83 | disy = disy.cuda() 84 | labels = labels.cuda() 85 | 86 | disimgs = Variable(disimgs) 87 | labels_x = Variable(disx) 88 | labels_y = Variable(disy) 89 | labels_clas = Variable(labels) 90 | flow_truth = torch.cat([labels_x, labels_y], dim=1) 91 | 92 | # Forward + Backward + Optimize 93 | optimizer.zero_grad() 94 | 95 | middle = model_en(disimgs) 96 | flow_output = model_de(middle) 97 | clas = model_class(middle) 98 | 99 | loss1 = criterion(flow_output, flow_truth) 100 | loss2 = criterion_clas(clas, labels_clas)*reg 101 | 102 | loss = loss1 + loss2 103 | 104 | loss.backward() 105 | optimizer.step() 106 | 107 | print("Epoch [%d], Iter [%d], Loss: 
%.4f, Loss1: %.4f, Loss2: %.4f" %(epoch + 1, i + 1, loss.data[0], loss1.data[0], loss2.data[0])) 108 | 109 | #============ TensorBoard logging ============# 110 | step = step + 1 111 | #Log the scalar values 112 | info = {'loss': loss.data[0]} 113 | for tag, value in info.items(): 114 | logger.scalar_summary(tag, value, step) 115 | 116 | torch.save(model_en.state_dict(), '%s%s%s' % ('model_en_',epoch + 1,'.pkl')) 117 | torch.save(model_de.state_dict(), '%s%s%s' % ('model_de_',epoch + 1,'.pkl')) 118 | torch.save(model_class.state_dict(), '%s%s%s' % ('model_class_',epoch + 1,'.pkl')) 119 | 120 | torch.save(model_en.state_dict(), 'model_en_last.pkl') 121 | torch.save(model_de.state_dict(), 'model_de_last.pkl') 122 | torch.save(model_class.state_dict(), 'model_class_last.pkl') 123 | 124 | # Test 125 | total = 0 126 | 127 | model_en.eval() 128 | model_de.eval() 129 | model_class.eval() 130 | 131 | for i, (disimgs, disx, disy, labels) in enumerate(val_loader): 132 | 133 | if use_GPU: 134 | disimgs = disimgs.cuda() 135 | disx = disx.cuda() 136 | disy = disy.cuda() 137 | labels = labels.cuda() 138 | 139 | disimgs = Variable(disimgs) 140 | labels_x = Variable(disx) 141 | labels_y = Variable(disy) 142 | labels_clas = Variable(labels) 143 | flow_truth = torch.cat([labels_x, labels_y], dim=1) 144 | 145 | middle = model_en(disimgs) 146 | flow_output = model_de(middle) 147 | clas = model_class(middle) 148 | 149 | loss1 = criterion(flow_output, flow_truth) 150 | loss2 = criterion_clas(clas, labels_clas)*reg 151 | 152 | loss = loss1 + loss2 153 | 154 | total = total + loss.data[0] 155 | print(loss.data[0], loss1.data[0], loss2.data[0]) 156 | 157 | print('val loss',total/(i+1)) 158 | 159 | -------------------------------------------------------------------------------- /trainNetS.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import 
Variable  # completes the `from torch.autograd import` statement ending the preceding line
from logger import Logger
import scipy.io as scio
import skimage
from skimage import io
import numpy as np
import argparse

from dataloaderNetS import get_loader
from modelNetS import EncoderNet, ModelNet, EPELoss

# Distortion families, indexed by the integer passed to --dataset_type.
DISTORTION_TYPES = ('barrel', 'pincushion', 'rotation', 'shear', 'projective', 'wave')

parser = argparse.ArgumentParser(description='GeoNetS')
# `choices` makes argparse reject an unknown index with a usage error up front.
# Before, an out-of-range value fell through the if/elif chain and left
# `distortion_type` unbound, so the script crashed later with a NameError.
parser.add_argument('--dataset_type', type=int, default=0, metavar='N',
                    choices=range(len(DISTORTION_TYPES)))
# NOTE(review): modelNetS.py hard-codes `batchSize = 32` in every Generate*Flow
# (output allocation and per-sample loop), so batch sizes other than 32 are
# expected to fail inside the model. Confirm before changing this default.
parser.add_argument('--batch_size', type=int, default=32, metavar='N')
parser.add_argument('--epochs', type=int, default=8, metavar='N')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR')
parser.add_argument("--dataset_dir", type=str, default='/home/xliea/GeoProj/Dataset/Dataset_256')
args = parser.parse_args()

# One-element list, the same shape get_loader and ModelNet received before.
distortion_type = [DISTORTION_TYPES[args.dataset_type]]

use_GPU = torch.cuda.is_available()

train_loader = get_loader(distortedImgDir='%s%s' % (args.dataset_dir, '/train/distorted'),
                          flowDir='%s%s' % (args.dataset_dir, '/train/uv'),
                          batch_size=args.batch_size,
                          distortion_type=distortion_type)

val_loader = get_loader(distortedImgDir='%s%s' % (args.dataset_dir, '/test/distorted'),
                        flowDir='%s%s' % (args.dataset_dir, '/test/uv'),
                        batch_size=args.batch_size,
                        distortion_type=distortion_type)

# model_1 regresses one distortion parameter per image (its final layer is
# nn.Linear(512*4*4, 1)). model_2 has no learnable parameters: it expands that
# parameter into a dense 2-channel flow field for the selected distortion type.
model_1 = EncoderNet([1, 1, 1, 1, 2])
model_2 = ModelNet(distortion_type[0])
criterion = EPELoss()

print('dataset type:', distortion_type)
print('batch size:', args.batch_size)
print('epochs:', args.epochs)
print('lr:', args.lr)
print('train_loader', len(train_loader))
print('val_loader', len(val_loader))
print(model_1)
print(model_2)
print(criterion)
print(torch.cuda.is_available())

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model_1 = nn.DataParallel(model_1)

if torch.cuda.is_available():
    model_1 = model_1.cuda()
    model_2 = model_2.cuda()
    criterion = criterion.cuda()

# Only model_1 has trainable parameters; model_2 is a fixed flow generator.
lr = args.lr
optimizer = torch.optim.Adam(model_1.parameters(), lr=lr)

# Set the logger
step = 0
logger = Logger('./logs')

model_1.train()
model_2.train()
for epoch in range(args.epochs):
    for i, (disimgs, disx, disy) in enumerate(train_loader):

        if use_GPU:
            disimgs = disimgs.cuda()
            disx = disx.cuda()
            disy = disy.cuda()

        disimgs = Variable(disimgs)
        labels_x = Variable(disx)
        labels_y = Variable(disy)

        # Forward + Backward + Optimize
        optimizer.zero_grad()

        # Ground-truth flow: x and y displacement maps stacked on dim 1.
        flow_truth = torch.cat([labels_x, labels_y], dim=1)
        flow_output = model_2(model_1(disimgs))
        loss = criterion(flow_output, flow_truth)

        loss.backward()
        optimizer.step()

        # The loss is a 0-dim tensor. `.item()` converts it to a float, whereas the
        # old `loss.data[0]` indexing raises IndexError on current PyTorch.
        loss_value = loss.item()
        print("Epoch [%d], Iter [%d], Loss: %.8f" % (epoch + 1, i + 1, loss_value))

        #============ TensorBoard logging ============#
        step = step + 1
        # Log the scalar values
        info = {'loss': loss_value}
        for tag, value in info.items():
            logger.scalar_summary(tag, value, step)

    # Decaying Learning Rate: halve every 2 epochs.
    # NOTE(review): building a new Adam also discards its moment estimates.
    # If that reset is not intended, set optimizer.param_groups[k]['lr'] instead.
    if (epoch + 1) % 2 == 0:
        lr /= 2
        optimizer = torch.optim.Adam(model_1.parameters(), lr=lr)

    # The file names carry no epoch number, so each epoch overwrites the previous checkpoint.
    torch.save(model_1.state_dict(), '%s%s%s%s' % (distortion_type[0], '_', args.lr, '_model_1.pkl'))
    torch.save(model_2.state_dict(), '%s%s%s%s' % (distortion_type[0], '_', args.lr, '_model_2.pkl'))

# Test
model_1.eval()
model_2.eval()
total = 0
num_batches = 0
# Inference only: no_grad stops autograd from recording a graph for every
# validation batch, which would otherwise only waste memory.
with torch.no_grad():
    for disimgs, disx, disy in val_loader:

        if use_GPU:
            disimgs = disimgs.cuda()
            disx = disx.cuda()
            disy = disy.cuda()

        disimgs = Variable(disimgs)
        labels_x = Variable(disx)
        labels_y = Variable(disy)

        flow_truth = torch.cat([labels_x, labels_y], dim=1)
        flow_output = model_2(model_1(disimgs))
        loss = criterion(flow_output, flow_truth)

        # .item() yields a Python float from the 0-dim loss. The old `loss.data[0]`
        # indexing raises on current PyTorch, and summing tensors kept them alive.
        loss_value = loss.item()
        print(loss_value)
        total = total + loss_value
        num_batches += 1

# Mean per-batch loss. The old `total/(i+1).item()` called .item() on an int
# (AttributeError), and `i` was unbound when val_loader yielded nothing.
if num_batches:
    print('test loss', total / num_batches)
else:
    print('test loss: val_loader yielded no batches')