├── .gitignore ├── EdgeLoss.py ├── README.md ├── TVLoss.py ├── Test_SSIM.py ├── checkpoints └── 1.txt ├── input └── 000297.jpg ├── lap.py ├── makedataset.py ├── model.py ├── output └── 000297_DDNet.jpg ├── prepare_patches.py ├── test.py └── utils_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /EdgeLoss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 15 14:37:45 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | def Laplacian(x): 13 | weight=torch.tensor([ 14 | [[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]],[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]],[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]]], 15 | [[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]],[[8.,0.,0.],[0.,8.,0.],[0.,0.,8.]],[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]]], 16 | [[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]],[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]],[[-1.,0.,0.],[0.,-1.,0.],[0.,0.,-1.]]] 17 | ]).cuda() 18 | 19 | 20 | frame= nn.functional.conv2d(x, weight, bias=None, stride=1, padding=1, dilation=1, groups=1) 21 | 22 | return frame 23 | 24 | 25 | def edge(x, imitation): 26 | 27 | def inference_mse_loss(frame_hr, frame_sr): 28 | content_base_loss = torch.mean(torch.sqrt((frame_hr - frame_sr) ** 2+(1e-3)**2)) 29 | return torch.mean(content_base_loss) 30 | 31 | x_edge = Laplacian(x) 32 | imitation_edge = Laplacian(imitation) 33 | edge_loss = inference_mse_loss(x_edge, imitation_edge) 34 | 35 | return edge_loss 36 | 37 | class edgeloss(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | 41 | def forward(self, out_image, gt_image): 42 | 43 | loss = edge(out_image,gt_image) 44 | 45 | return loss 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Double Domain Guided Real-Time Low-Light Image Enhancement for Ultra-High-Definition Transportation Surveillance 2 | 3 | 4 | ## 1. Requirement ## 5 | * __Python__ == 3.7 6 | * __Torch__ == 1.12.0 7 | 8 | ## 2. Test platform 9 | * The experimental computational device is a PC with an AMD EPYC 7543 32-Core Processor CPU accelerated by an Nvidia A40 GPU, which is also widely used in industrial-grade servers (e.g., Advantech SKY-6000 series and Thinkmate GPX servers). 10 | 11 | ## 3. Test 12 | * Put the test images into the input folder. 13 | * Run test.py 14 | * The results will be saved into the output folder. 15 | * For timing, the inference time is measured as the test code's ending time minus its starting time. 16 | * To record the exact CUDA ending time, add 'torch.cuda.synchronize()' before the ending time is recorded, as shown in the sketch below. (Thanks for the reminder from @CuddleSabe) 17 |
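As a quick illustration of the synchronized timing mentioned in the note above, here is a minimal sketch (not part of the repository): it assumes a loaded `model` and the `input_var` / `input_var_gl` tensors prepared on the GPU exactly as in `test.py`, and calls `torch.cuda.synchronize()` so the clock only stops after the forward pass has actually finished.

```python
import time
import torch

# Hypothetical timing snippet: `model`, `input_var` and `input_var_gl` are assumed to
# already exist on the GPU, prepared as in test.py.
with torch.no_grad():
    torch.cuda.synchronize()   # flush any pending GPU work before starting the clock
    start = time.time()
    _, _, enhanced = model(input_var, input_var_gl)
    torch.cuda.synchronize()   # wait for the forward pass to finish before stopping it
    end = time.time()
print('GPU inference time: %.4f s' % (end - start))
```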
18 | ## 4. Downloads 19 | * The checkpoint and UHD test data are available at: https://pan.baidu.com/s/1LGc7ox7QyLIdEAahmwYtxg Pass code: mipc 20 | * Google Drive: https://drive.google.com/file/d/1X18X50iMKfRrGgrr1PT6tE8_ubqTAnpx/view?usp=drive_link 21 | 22 | * PS: Due to personal reasons (graduation and changing computers), I have contacted my collaborator and we cannot find the training code. We are sorry about that. -------------------------------------------------------------------------------- /TVLoss.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class tvloss(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, est_noise, gt_noise): 11 | h_x = est_noise.size()[2] 12 | w_x = est_noise.size()[3] 13 | count_h = self._tensor_size(est_noise[:, :, 1:, :]) 14 | count_w = self._tensor_size(est_noise[:, :, : ,1:]) 15 | h_tv = torch.pow((est_noise[:, :, 1:, :] - est_noise[:, :, :h_x-1, :]), 2).sum() 16 | w_tv = torch.pow((est_noise[:, :, :, 1:] - est_noise[:, :, :, :w_x-1]), 2).sum() 17 | loss = h_tv / count_h + w_tv / count_w 18 | 19 | return loss 20 | 21 | def _tensor_size(self,t): 22 | return t.size()[1]*t.size()[2]*t.size()[3] 23 | -------------------------------------------------------------------------------- /Test_SSIM.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Nov 2 13:56:32 2021 4 | 5 | @author: 13362 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | import math 11 | import os 12 | 13 | 14 | def ssim(img1, img2): 15 | C1 = (0.01 * 255)**2 16 | C2 = (0.03 * 255)**2 17 | img1 = img1.astype(np.float64) 18 | img2 = img2.astype(np.float64) 19 | kernel = cv2.getGaussianKernel(11, 1.5) 20 | window = np.outer(kernel, kernel.transpose()) 21 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 22 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 23 | mu1_sq = mu1**2 24 | mu2_sq = mu2**2 25 | mu1_mu2 = mu1 * mu2 26 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 27 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 28 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 29 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 30 | (sigma1_sq + sigma2_sq + C2)) 31 | return ssim_map.mean() 32 | def calculate_ssim(img1, img2): 33 | '''calculate SSIM 34 | the same outputs as MATLAB's 35 | img1, img2: [0, 255] 36 | ''' 37 | if not img1.shape == img2.shape: 38 | raise ValueError('Input images must have the same dimensions.') 39 | if img1.ndim == 2: 40 | return ssim(img1, img2) 41 | elif img1.ndim == 3: 42 | if img1.shape[2] == 3: 43 | ssims = [] 44 | for i in range(3): 45 | ssims.append(ssim(img1[..., i], img2[..., i])) 46 | return np.array(ssims).mean() 47 | elif img1.shape[2] == 1: 48 | return ssim(np.squeeze(img1), np.squeeze(img2)) 49 | else: 50 | raise ValueError('Wrong input image dimensions.') 51 | 52 | 53 | def psnr1(img1, img2): 54 | mse = np.mean((img1 - img2) ** 2 ) 55 | if mse < 1.0e-10: 56 | return 100 57 | return 10 * math.log10(255.0**2/mse) 58 | 59 | def psnr(target, ref): 60 | 61 | target_data = np.array(target, dtype=np.float64) 62 | ref_data = np.array(ref,dtype=np.float64) 63 | 64 | diff = ref_data - target_data 65 | diff = diff.flatten('C') 66 | rmse = math.sqrt(np.mean(diff ** 2.)) 67 | 68 | eps = np.finfo(np.float64).eps 69 | if(rmse == 0): 70
| rmse = eps 71 | return 20*math.log10(255.0/rmse) 72 | 73 | 74 | def C_PSNR_SSIM(): 75 | files = os.listdir('./Clear') 76 | PSNR = 0 77 | SSIM = 0 78 | PSNR_STD = 0 79 | SSIM_STD = 0 80 | for i in range(len(files)): 81 | img1 = cv2.imread('./Clear/' + files[i]) 82 | img2 = cv2.imread('./GPANet/' + files[i][:-4] + '_GPANet.png') 83 | 84 | ss = calculate_ssim(img1, img2) 85 | ps = psnr(img1, img2) 86 | SSIM +=ss 87 | PSNR +=ps 88 | 89 | return PSNR/15,SSIM/15 90 | 91 | print(C_PSNR_SSIM()) 92 | # ============================================================================= 93 | # files = os.listdir('./Clear') 94 | # for i in range(len(files)): 95 | # img1 = cv2.imread('./Clear/' + files[i]) 96 | # img2 = cv2.imread('./GPANet/' + files[i][:-4] + '_GPANet.png') 97 | # #img2 = cv2.imread('./LLFlow/' + files[i]) 98 | # 99 | # ss = calculate_ssim(img1, img2) 100 | # ps = psnr(img1, img2) 101 | # print(ss) 102 | # ============================================================================= 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /checkpoints/1.txt: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /input/000297.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuJX/DDNet/2f9cc98126c35e177d9cc5246c7850dc58f2955c/input/000297.jpg -------------------------------------------------------------------------------- /lap.py: -------------------------------------------------------------------------------- 1 | import os, time, scipy.io, shutil 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | #from tensorboardX import SummaryWriter 8 | import numpy as np 9 | import cv2 10 | import scipy.misc 11 | from model_0 import * 12 | from makedataset import Dataset 13 | import utils_train 14 | from Test_SSIM import * 15 | from EdgeLoss import edgeloss 16 | from TVLoss import tvloss 17 | 18 | def hwc_to_chw(img): 19 | return np.transpose(img, axes=[2, 0, 1]) 20 | 21 | def chw_to_hwc(img): 22 | return np.transpose(img, axes=[1, 2, 0]) 23 | 24 | def GFLap(data): 25 | x = cv2.GaussianBlur(data, (3,3),0) 26 | x = cv2.Laplacian(np.clip(x*255,0,255).astype('uint8'),cv2.CV_8U,ksize =3) 27 | Lap = cv2.convertScaleAbs(x) 28 | return Lap/255.0 29 | 30 | 31 | img = cv2.imread('2015_00719.jpg')/255.0 32 | cv2.imwrite('output2.jpg',GFLap(img)*255.0) 33 | 34 | -------------------------------------------------------------------------------- /makedataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Feb 12 20:00:46 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | import os 9 | import os.path 10 | import random 11 | import numpy as np 12 | import cv2 13 | import h5py 14 | import torch 15 | import torch.utils.data as udata 16 | 17 | class Dataset(udata.Dataset): 18 | r"""Implements torch.utils.data.Dataset 19 | """ 20 | def __init__(self, trainrgb=True,trainsyn = True, shuffle=False): 21 | super(Dataset, self).__init__() 22 | self.trainrgb = trainrgb 23 | self.trainsyn = trainsyn 24 | self.train_haze = 'train_ImageEdge.h5' 25 | 26 | if self.trainrgb: 27 | if self.trainsyn: 28 | h5f = h5py.File(self.train_haze, 'r') 29 | else: 30 | h5f = 
h5py.File(self.train_real_rgb, 'r') 31 | else: 32 | if self.trainsyn: 33 | h5f = h5py.File(self.train_syn_gray, 'r') 34 | else: 35 | h5f = h5py.File(self.train_real_gray, 'r') 36 | self.keys = list(h5f.keys()) 37 | if shuffle: 38 | random.shuffle(self.keys) 39 | h5f.close() 40 | 41 | def __len__(self): 42 | return len(self.keys) 43 | 44 | def __getitem__(self, index): 45 | if self.trainrgb: 46 | if self.trainsyn: 47 | h5f = h5py.File(self.train_haze, 'r') 48 | else: 49 | h5f = h5py.File(self.train_real_rgb, 'r') 50 | else: 51 | if self.trainsyn: 52 | h5f = h5py.File(self.train_syn_gray, 'r') 53 | else: 54 | h5f = h5py.File(self.train_real_gray, 'r') 55 | key = self.keys[index] 56 | data = np.array(h5f[key]) 57 | h5f.close() 58 | return torch.Tensor(data) 59 | 60 | 61 | def data_augmentation(image, mode): 62 | r"""Performs dat augmentation of the input image 63 | 64 | Args: 65 | image: a cv2 (OpenCV) image 66 | mode: int. Choice of transformation to apply to the image 67 | 0 - no transformation 68 | 1 - flip up and down 69 | 2 - rotate counterwise 90 degree 70 | 3 - rotate 90 degree and flip up and down 71 | 4 - rotate 180 degree 72 | 5 - rotate 180 degree and flip 73 | 6 - rotate 270 degree 74 | 7 - rotate 270 degree and flip 75 | """ 76 | out = np.transpose(image, (1, 2, 0)) 77 | if mode == 0: 78 | # original 79 | out = out 80 | elif mode == 1: 81 | # flip up and down 82 | out = np.flipud(out) 83 | elif mode == 2: 84 | # rotate counterwise 90 degree 85 | out = np.rot90(out) 86 | elif mode == 3: 87 | # rotate 90 degree and flip up and down 88 | out = np.rot90(out) 89 | out = np.flipud(out) 90 | elif mode == 4: 91 | # rotate 180 degree 92 | out = np.rot90(out, k=2) 93 | elif mode == 5: 94 | # rotate 180 degree and flip 95 | out = np.rot90(out, k=2) 96 | out = np.flipud(out) 97 | elif mode == 6: 98 | # rotate 270 degree 99 | out = np.rot90(out, k=3) 100 | elif mode == 7: 101 | # rotate 270 degree and flip 102 | out = np.rot90(out, k=3) 103 | out = np.flipud(out) 104 | else: 105 | raise Exception('Invalid choice of image transformation') 106 | return np.transpose(out, (2, 0, 1)) 107 | 108 | def img_to_patches(img,win,stride,Syn=True): 109 | 110 | chl,raw,col = img.shape 111 | chl = int(chl) 112 | num_raw = np.ceil((raw-win)/stride+1).astype(np.uint8) 113 | num_col = np.ceil((col-win)/stride+1).astype(np.uint8) 114 | count = 0 115 | total_process = int(num_col)*int(num_raw) 116 | img_patches = np.zeros([chl,win,win,total_process]) 117 | if Syn: 118 | for i in range(num_raw): 119 | for j in range(num_col): 120 | if stride * i + win <= raw and stride * j + win <=col: 121 | img_patches[:,:,:,count] = img[:,stride*i : stride*i + win, stride*j : stride*j + win] 122 | elif stride * i + win > raw and stride * j + win<=col: 123 | img_patches[:,:,:,count] = img[:,raw-win : raw,stride * j : stride * j + win] 124 | elif stride * i + win <= raw and stride*j + win>col: 125 | img_patches[:,:,:,count] = img[:,stride*i : stride*i + win, col-win : col] 126 | else: 127 | img_patches[:,:,:,count] = img[:,raw-win : raw,col-win : col] 128 | count +=1 129 | 130 | return img_patches 131 | 132 | 133 | def readfiles(filepath): 134 | '''Get dataset images names''' 135 | files = os.listdir(filepath) 136 | return files 137 | 138 | def normalize(data): 139 | 140 | return np.float32(data/255.0) 141 | 142 | def samesize(img,size): 143 | 144 | img = cv2.resize(img,size) 145 | return img 146 | 147 | def concatenate2imgs(img,depth): 148 | c,w,h = img.shape 149 | conimg = np.zeros((c+c,w,h)) 150 | conimg[0:c,:,:] = img 151 | 
conimg[c:2*c,:,:] = depth 152 | 153 | return conimg 154 | 155 | def Edge_TrainSynRGB(img_filepath, depth_filepath, patch_size, stride): 156 | '''synthetic ImageEdge images''' 157 | train_haze = 'train_ImageEdge.h5' 158 | img_files = readfiles(img_filepath) 159 | count = 0 160 | scales = [1.0]#[0.6,0.8,1.0] 161 | 162 | with h5py.File(train_haze, 'w') as h5f: 163 | for i in range(len(img_files)): 164 | filename = img_files[i][:-4] 165 | oimg = cv2.imread(img_filepath + '/' + filename + '.png') 166 | 167 | odepth = cv2.imread(depth_filepath + '/' + filename + '.png') 168 | 169 | 170 | for sca in scales: 171 | #img = cv2.resize(oimg, (0, 0), fx=sca, fy=sca, interpolation=cv2.INTER_CUBIC) 172 | #depth = cv2.resize(odepth, (0, 0), fx=sca, fy=sca, interpolation=cv2.INTER_CUBIC) 173 | 174 | img = oimg.transpose(2, 0, 1) 175 | depth = odepth.transpose(2, 0, 1) 176 | #depth = depth.transpose((1,0)) 177 | 178 | print(img.shape,depth.shape) 179 | 180 | img = normalize(img) 181 | depth = normalize(depth) 182 | img_depth = concatenate2imgs(img,depth) 183 | img_patches = img_to_patches(img_depth, win=patch_size, stride=stride) 184 | print("\tfile: %s scale %.1f # samples: %d" %(img_files[i], sca,img_patches.shape[3])) 185 | for nx in range(img_patches.shape[3]): 186 | data = data_augmentation(img_patches[:, :, :, nx].copy(), np.random.randint(0, 7)) 187 | h5f.create_dataset(str(count), data=data) 188 | count += 1 189 | i += 1 190 | print(data.shape) 191 | h5f.close() 192 | 193 | 194 | def TrainSynRGB(img_filepath, patch_size, stride): 195 | '''synthetic Haze images''' 196 | train_haze = 'train_haze.h5' 197 | img_files = readfiles(img_filepath) 198 | count = 0 199 | scales = [0.6,0.8,1.0] 200 | 201 | with h5py.File(train_haze, 'w') as h5f: 202 | for i in range(len(img_files)): 203 | filename = img_files[i] 204 | o_img = cv2.imread(img_filepath + '/' + filename) 205 | o_img = cv2.resize(o_img,(360,360)) 206 | 207 | #img= samesize(img,(360,360)) 208 | for sca in scales: 209 | img = o_img 210 | 211 | img = cv2.resize(o_img, (0, 0), fx=sca, fy=sca, interpolation=cv2.INTER_CUBIC) 212 | img = img.transpose(2, 0, 1) 213 | 214 | img = normalize(img) 215 | img_patches = img_to_patches(img, win=patch_size, stride=stride) 216 | print("\tfile: %s scale %.1f # samples: %d" %(img_files[i], sca,img_patches.shape[3])) 217 | for nx in range(img_patches.shape[3]): 218 | data = data_augmentation(img_patches[:, :, :, nx].copy(), np.random.randint(0, 7)) 219 | h5f.create_dataset(str(count), data=data) 220 | count += 1 221 | i += 1 222 | print(data.shape) 223 | h5f.close() 224 | 225 | 226 | def TrainSynRGB_NA(img_filepath, patch_size, stride): 227 | '''synthetic Haze images''' 228 | train_haze = 'train_haze.h5' 229 | img_files = readfiles(img_filepath) 230 | count = 0 231 | scales = [1.0]#[0.6,0.8,1.0] 232 | 233 | with h5py.File(train_haze, 'w') as h5f: 234 | for i in range(len(img_files)): 235 | filename = img_files[i] 236 | oooimg = cv2.imread(img_filepath + '/' + filename) 237 | img = cv2.resize(oooimg,(256,256)) 238 | 239 | for sca in scales: 240 | img = cv2.resize(img, (0, 0), fx=sca, fy=sca, interpolation=cv2.INTER_CUBIC) 241 | img = img.transpose(2, 0, 1) 242 | img = normalize(img) 243 | print("\tfile: %s scale %.1f" %(img_files[i], sca)) 244 | data = data_augmentation(img.copy(), np.random.randint(0, 7)) 245 | h5f.create_dataset(str(count), data=data) 246 | count += 1 247 | i += 1 248 | print(data.shape) 249 | h5f.close() 250 | -------------------------------------------------------------------------------- /model.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jun 20 16:14:37 2021 4 | 5 | @author: Administrator 6 | """ 7 | 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | 18 | 19 | class Main(nn.Module): 20 | def __init__(self): 21 | super(Main,self).__init__() 22 | 23 | self.left = LeftED(4,32) 24 | self.right = RightED(3,32) 25 | 26 | def forward(self,x,xgl): 27 | 28 | x_fout, x_eout, ed1, ed2, ed3, fd1, fd2, fd3, e1, e2, e3 = self.left(x,xgl) 29 | x_out = self.right(x, ed1, ed2, ed3, fd1, fd2, fd3, e1, e2, e3) 30 | 31 | return x_fout, x_eout, x_out 32 | 33 | class LeftED(nn.Module): 34 | def __init__(self,inchannel,channel): 35 | super(LeftED,self).__init__() 36 | 37 | self.e1 = ResUnit(channel) 38 | self.e2 = ResUnit(channel*2) 39 | self.e3 = ResUnit(channel*4) 40 | 41 | 42 | 43 | self.ed1 = ResUnit(channel*2) 44 | self.ed2 = ResUnit(channel*1) 45 | self.ed3 = ResUnit(int(channel*0.5)) 46 | 47 | self.fd1 = ResUnit(channel*2) 48 | self.fd2 = ResUnit(channel*1) 49 | self.fd3 = ResUnit(int(channel*0.5)) 50 | 51 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 52 | 53 | self.conv_in = nn.Conv2d(inchannel,channel,kernel_size=3,stride=1,padding=1,bias=False) 54 | 55 | self.conv_eout = nn.Conv2d(int(0.5*channel),3,kernel_size=1,stride=1,padding=0,bias=False) 56 | self.conv_fout = nn.Conv2d(int(0.5*channel),1,kernel_size=1,stride=1,padding=0,bias=False) 57 | 58 | self.conv_e1te2 = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 59 | self.conv_e2te3 = nn.Conv2d(2*channel,4*channel,kernel_size=1,stride=1,padding=0,bias=False) 60 | 61 | self.conv_e1_a = nn.Conv2d(channel,int(0.5*channel),kernel_size=1,stride=1,padding=0,bias=False) 62 | self.conv_e2_a = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 63 | self.conv_e3_a = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 64 | 65 | 66 | self.conv_fd1td2 = nn.Conv2d(2*channel,1*channel,kernel_size=1,stride=1,padding=0,bias=False) 67 | self.conv_fd2td3 = nn.Conv2d(channel,int(0.5*channel),kernel_size=1,stride=1,padding=0,bias=False) 68 | 69 | 70 | 71 | self.conv_ed1td2 = nn.Conv2d(2*channel,1*channel,kernel_size=1,stride=1,padding=0,bias=False) 72 | self.conv_ed2td3 = nn.Conv2d(channel,int(0.5*channel),kernel_size=1,stride=1,padding=0,bias=False) 73 | 74 | 75 | def _upsample(self,x,y): 76 | _,_,H,W = y.size() 77 | return F.upsample(x,size=(H,W),mode='bilinear') 78 | 79 | def forward(self,x,xgl): 80 | 81 | x_in = self.conv_in(torch.cat((x,xgl),1)) 82 | 83 | e1 = self.e1(x_in) 84 | e2 = self.e2(self.conv_e1te2(self.maxpool(e1))) 85 | e3 = self.e3(self.conv_e2te3(self.maxpool(e2))) 86 | 87 | e1_a = self.conv_e1_a(e1) 88 | e2_a = self.conv_e2_a(e2) 89 | e3_a = self.conv_e3_a(e3) 90 | 91 | 92 | fd1 = self.fd1(e3_a) 93 | fd2 = self.fd2(self.conv_fd1td2(self._upsample(fd1,e2)) + e2_a) 94 | fd3 = self.fd3(self.conv_fd2td3(self._upsample(fd2,e1)) + e1_a) 95 | 96 | 97 | ed1 = self.ed1(e3_a + fd1) 98 | ed2 = self.ed2(self.conv_ed1td2(self._upsample(ed1,e2)) + fd2 + e2_a) 99 | ed3 = self.ed3(self.conv_ed2td3(self._upsample(ed2,e1)) + fd3 + e1_a) 100 | 101 | 102 | x_fout = self.conv_fout(fd3) 103 | x_eout = self.conv_eout(ed3) 104 | 105 | return x_fout, x_eout, ed1, ed2, ed3, fd1, fd2, fd3, e1, e2, e3 106 | 107 | 108 | 109 | 
class RightED(nn.Module): 110 | def __init__(self,inchannel,channel): 111 | super(RightED,self).__init__() 112 | 113 | self.ee1 = ResUnit(int(0.5*channel)) 114 | self.ee2 = ResUnit(channel) 115 | self.ee3 = ResUnit(channel*2) 116 | 117 | self.fe1 = ResUnit(int(0.5*channel)) 118 | self.fe2 = ResUnit(channel) 119 | self.fe3 = ResUnit(channel*2) 120 | 121 | self.d1 = ResUnit(channel*4) 122 | self.d2 = ResUnit(channel*2) 123 | self.d3 = ResUnit(channel) 124 | 125 | self.d4 = ResUnit(channel) 126 | 127 | self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1) 128 | 129 | 130 | self.conv_out = nn.Conv2d(channel,3,kernel_size=3,stride=1,padding=1,bias=False) 131 | 132 | self.conv_fe0te1 = nn.Conv2d(int(0.5*channel),channel,kernel_size=1,stride=1,padding=0,bias=False) 133 | self.conv_fe1te2 = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 134 | 135 | self.conv_ee0te1 = nn.Conv2d(int(0.5*channel),channel,kernel_size=1,stride=1,padding=0,bias=False) 136 | self.conv_ee1te2 = nn.Conv2d(channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 137 | 138 | self.conv_e0te1 = nn.Conv2d(int(1*channel),channel,kernel_size=3,stride=1,padding=1,bias=False) 139 | self.conv_e1te2 = nn.Conv2d(int(2*channel),2*channel,kernel_size=3,stride=1,padding=1,bias=False) 140 | self.conv_e2te3 = nn.Conv2d(int(4*channel),4*channel,kernel_size=3,stride=1,padding=1,bias=False) 141 | 142 | self.conv_d1td2 = nn.Conv2d(4*channel,2*channel,kernel_size=1,stride=1,padding=0,bias=False) 143 | self.conv_d2td3 = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 144 | 145 | 146 | self.act1 = nn.PReLU(channel) 147 | self.norm1 = nn.GroupNorm(num_channels=channel,num_groups=1) 148 | 149 | self.act2 = nn.PReLU(channel*2) 150 | self.norm2 = nn.GroupNorm(num_channels=channel*2,num_groups=1) 151 | 152 | self.act3 = nn.PReLU(channel*4) 153 | self.norm3 = nn.GroupNorm(num_channels=channel*4,num_groups=1) 154 | 155 | def _upsample(self,x,y): 156 | _,_,H,W = y.size() 157 | return F.upsample(x,size=(H,W),mode='bilinear') 158 | 159 | def forward(self,x, ed1, ed2, ed3, fd1, fd2, fd3, e1, e2, e3): 160 | 161 | 162 | fe1 = self.fe1(fd3) 163 | fe2 = self.fe2(self.conv_fe0te1(self.maxpool(fe1)) + fd2) 164 | fe3 = self.fe3(self.conv_fe1te2(self.maxpool(fe2)) + fd1) 165 | 166 | ee1 = self.ee1(ed3 + fe1) 167 | ee2 = self.ee2(self.conv_ee0te1(self.maxpool(ee1)) + fe2 + ed2) 168 | ee3 = self.ee3(self.conv_ee1te2(self.maxpool(ee2)) + fe3 + ed1) 169 | 170 | fde1 = self.act1(self.norm1(self.conv_e0te1(torch.cat((ee1 , fe1),1)))) 171 | fde2 = self.act2(self.norm2(self.conv_e1te2(torch.cat((ee2 , fe2),1)))) 172 | fde3 = self.act3(self.norm3(self.conv_e2te3(torch.cat((ee3 , fe3),1)))) 173 | 174 | d1 = self.d1(fde3 + e3) 175 | d2 = self.d2(self.conv_d1td2(self._upsample(d1,e2)) + fde2 + e2) 176 | d3 = self.d3(self.conv_d2td3(self._upsample(d2,e1)) + fde1 + e1) 177 | 178 | 179 | x_out = self.conv_out(self.d4(d3)) 180 | 181 | return x_out + x 182 | 183 | 184 | 185 | class ResUnit1(nn.Module): # Edge-oriented Residual Convolution Block 186 | def __init__(self,channel,norm=False): 187 | super(ResUnit1,self).__init__() 188 | 189 | self.conv_1 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 190 | self.conv_2 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 191 | self.conv_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 192 | 193 | self.act = nn.PReLU(channel) 194 | self.norm = nn.GroupNorm(num_channels=channel,num_groups=1)# 
nn.InstanceNorm2d(channel)# 195 | 196 | def forward(self,x): 197 | 198 | x_1 = self.act(self.norm(self.conv_1(x))) 199 | x_2 = self.act(self.norm(self.conv_2(x_1))) 200 | x_3 = self.act(self.norm(self.conv_3(x_2))+x_1) 201 | 202 | return x_3 203 | 204 | 205 | # ============================================================================= 206 | # class ResUnit1(nn.Module): 207 | # def __init__(self,channel): 208 | # super(ResUnit1,self).__init__() 209 | # 210 | # self.conv_cam_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 211 | # self.conv_sam_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 212 | # self.conv_scl_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 213 | # 214 | # self.conv_cam_m3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 215 | # self.conv_sam_m3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 216 | # 217 | # self.conv_cam_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 218 | # self.conv_sam_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 219 | # self.conv_scl_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 220 | # 221 | # 222 | # self.conv_11 = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 223 | # self.conv_12 = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 224 | # 225 | # self.act = nn.PReLU(channel) 226 | # 227 | # self.scl = StandardConvolutionalLayers(channel) 228 | # self.cam = ChannelAttention(channel) 229 | # self.sam = SpatialAttention() 230 | # # ============================================================================= 231 | # # self.cam = ChannelAttentionModule(channel) 232 | # # self.sam = SpatialAttentionModule(channel) 233 | # # ============================================================================= 234 | # def forward(self,x): 235 | # 236 | # x_cam_1 = self.conv_cam_1(x) 237 | # x_sam_1 = self.conv_sam_1(x) 238 | # x_scl_1 = self.conv_scl_1(x) 239 | # 240 | # x_cam_2 = self.conv_cam_m3(x_cam_1) + self.cam(x_cam_1) 241 | # x_sam_2 = self.conv_sam_m3(x_sam_1) + self.sam(x_sam_1) 242 | # x_scl_2 = self.scl(x_scl_1) 243 | # 244 | # x_cam_3 = self.conv_cam_3(x_cam_2) 245 | # x_sam_3 = self.conv_sam_3(x_sam_2) 246 | # x_scl_3 = self.conv_scl_3(x_scl_2) 247 | # 248 | # x_1 = self.conv_11(torch.cat((x_cam_3,x_sam_3),1)) 249 | # x_2 = self.conv_12(torch.cat((x_1,x_scl_3),1)) 250 | # 251 | # x_out = self.act(x_2+x) 252 | # 253 | # return x_out 254 | # 255 | # ============================================================================= 256 | 257 | class ResUnit(nn.Module): 258 | def __init__(self,channel): 259 | super(ResUnit,self).__init__() 260 | 261 | self.conv_cam_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 262 | self.conv_sam_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 263 | self.conv_scl_1 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 264 | 265 | self.conv_cam_m3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 266 | self.conv_sam_m3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 267 | 268 | self.conv_cam_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 269 | self.conv_sam_3 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 270 | self.conv_scl_3 = 
nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 271 | 272 | 273 | self.conv_11 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 274 | self.conv_12 = nn.Conv2d(2*channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 275 | self.conv_13 = nn.Conv2d(channel,channel,kernel_size=1,stride=1,padding=0,bias=False) 276 | 277 | self.act = nn.PReLU(channel) 278 | 279 | self.scl = StandardConvolutionalLayers(channel) 280 | self.cam = ChannelAttention(channel) 281 | self.sam = SpatialAttention() 282 | #self.cam = ChannelAttentionModule(channel) 283 | #self.sam = ChannelAttentionModule(channel)#SpatialAttentionModule(channel) 284 | self.norm = nn.GroupNorm(num_channels=channel,num_groups=1)# nn.InstanceNorm2d(channel)# 285 | 286 | 287 | def forward(self,x): 288 | 289 | #x_cam_1 = self.conv_cam_1(x) 290 | x_sam_1 = self.conv_sam_1(x) 291 | x_scl_1 = self.conv_scl_1(x) 292 | 293 | #x_cam_2 = self.conv_cam_m3(x_cam_1) * self.cam(x_cam_1) 294 | x_sam_2 = self.conv_sam_m3(x_sam_1) * self.sam(x_sam_1) 295 | x_scl_2 = self.scl(x_scl_1) 296 | 297 | #x_cam_3 = self.conv_cam_3(x_cam_2) 298 | x_sam_3 = self.scl(x_sam_2) 299 | #x_sam_3 = self.conv_sam_3(x_sam_2) 300 | #x_scl_3 = self.conv_scl_3(x_scl_2) 301 | x_scl_3 = self.scl(x_scl_2) 302 | 303 | #x_1 = self.conv_11(x_cam_3+x_sam_3) 304 | x_2 = self.conv_12(torch.cat((x_sam_3,x_scl_3),1)) 305 | 306 | x_out = self.conv_13(x_2) + x 307 | 308 | return x_out 309 | 310 | class StandardConvolutionalLayers(nn.Module): # StandardConvolutional 311 | def __init__(self,channel): 312 | super(StandardConvolutionalLayers,self).__init__() 313 | 314 | self.conv_1 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 315 | self.conv_2 = nn.Conv2d(channel,channel,kernel_size=3,stride=1,padding=1,bias=False) 316 | 317 | self.act = nn.PReLU(channel) 318 | self.norm = nn.GroupNorm(num_channels=channel,num_groups=1)# nn.InstanceNorm2d(channel)# 319 | 320 | def forward(self,x): 321 | 322 | x_1 = self.act(self.norm(self.conv_1(x))) 323 | #x_2 = self.act(self.norm(self.conv_2(x_1))) 324 | 325 | return x_1 326 | 327 | 328 | class SpatialAttentionModule(nn.Module): 329 | def __init__(self, in_channels): 330 | super(SpatialAttentionModule, self).__init__() 331 | self.query = nn.Conv2d(in_channels, in_channels // 4, kernel_size=(1, 3), padding=(0, 1)) 332 | self.key = nn.Conv2d(in_channels, in_channels // 4, kernel_size=(3, 1), padding=(1, 0)) 333 | self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1) 334 | self.gamma = nn.Parameter(torch.zeros(1)) 335 | self.softmax = nn.Softmax(dim=-1) 336 | 337 | def forward(self, x): 338 | """ 339 | :param x: input( BxCxHxW ) 340 | :return: affinity value + x 341 | """ 342 | B, C, H, W = x.size() 343 | # compress x: [B,C,H,W]-->[B,H*W,C], make a matrix transpose 344 | proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1) 345 | proj_key = self.key(x).view(B, -1, W * H) 346 | affinity = torch.matmul(proj_query, proj_key) 347 | affinity = self.softmax(affinity) 348 | proj_value = self.value(x).view(B, -1, H * W) 349 | weights = torch.matmul(proj_value, affinity.permute(0, 2, 1)) 350 | weights = weights.view(B, C, H, W) 351 | out = self.gamma * weights + x 352 | return out 353 | 354 | 355 | class ChannelAttentionModule(nn.Module): 356 | def __init__(self, in_channels): 357 | super(ChannelAttentionModule, self).__init__() 358 | self.gamma = nn.Parameter(torch.zeros(1)) 359 | self.softmax = nn.Softmax(dim=-1) 360 | 361 | def forward(self, x): 362 | """ 363 | :param x: 
input( BxCxHxW ) 364 | :return: affinity value + x 365 | """ 366 | B, C, H, W = x.size() 367 | proj_query = x.view(B, C, -1) 368 | proj_key = x.view(B, C, -1).permute(0, 2, 1) 369 | affinity = torch.matmul(proj_query, proj_key) 370 | affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity 371 | affinity_new = self.softmax(affinity_new) 372 | proj_value = x.view(B, C, -1) 373 | weights = torch.matmul(affinity_new, proj_value) 374 | weights = weights.view(B, C, H, W) 375 | out = self.gamma * weights + x 376 | return out 377 | 378 | class ChannelAttention(nn.Module): 379 | def __init__(self, in_planes, ratio=16): 380 | super(ChannelAttention, self).__init__() 381 | # average pooling 382 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 383 | # max pooling 384 | self.max_pool = nn.AdaptiveMaxPool2d(1) 385 | 386 | # shared MLP; the channel dimension is reduced by a factor of 4 here (the ratio argument is not used) 387 | self.fc1 = nn.Conv2d(in_planes, in_planes // 4, 1, bias=False) #kernel_size=1 388 | self.relu1 = nn.ReLU() 389 | self.fc2 = nn.Conv2d(in_planes // 4, in_planes, 1, bias=False) 390 | 391 | self.sigmoid = nn.Sigmoid() 392 | 393 | def forward(self, x): 394 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 395 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 396 | # add the two pooled branches 397 | out = avg_out + max_out 398 | return self.sigmoid(out) 399 | 400 | # spatial attention 401 | class SpatialAttention(nn.Module): 402 | def __init__(self, kernel_size=7): 403 | super(SpatialAttention, self).__init__() 404 | # the kernel size must be 3 or 7 405 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 406 | # apply the corresponding 'same' padding 407 | padding = 3 if kernel_size == 7 else 1 408 | 409 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 410 | self.sigmoid = nn.Sigmoid() 411 | 412 | def forward(self, x): 413 | avg_out = torch.mean(x, dim=1, keepdim=True) # average pooling over channels 414 | max_out, _ = torch.max(x, dim=1, keepdim=True) # max pooling over channels 415 | # concatenate along the channel dimension 416 | x = torch.cat([avg_out, max_out], dim=1) 417 | x = self.conv1(x) # 7x7 convolution with padding 3, 2 input channels, 1 output channel 418 | return self.sigmoid(x) -------------------------------------------------------------------------------- /output/000297_DDNet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QuJX/DDNet/2f9cc98126c35e177d9cc5246c7850dc58f2955c/output/000297_DDNet.jpg -------------------------------------------------------------------------------- /prepare_patches.py: -------------------------------------------------------------------------------- 1 | """ 2 | Construction of the training and validation databases 3 | 4 | Copyright (C) 2018, Matias Tassano 5 | 6 | This program is free software: you can use, modify and/or 7 | redistribute it under the terms of the GNU General Public 8 | License as published by the Free Software Foundation, either 9 | version 3 of the License, or (at your option) any later 10 | version. You should have received a copy of this license along 11 | this program. If not, see .
12 | """ 13 | 14 | from makedataset import * 15 | import argparse 16 | 17 | if __name__ == "__main__": 18 | 19 | parser = argparse.ArgumentParser(description=\ 20 | "Building the training patch database") 21 | 22 | parser.add_argument("--rgb", action='store_true',default = True,\ 23 | help='prepare RGB database instead of grayscale') 24 | # Preprocessing parameters 25 | parser.add_argument("--patch_size", "--p", type=int, default=128, \ 26 | help="Patch size") 27 | parser.add_argument("--stride", "--s", type=int, default=128, \ 28 | help="Size of stride") 29 | 30 | args = parser.parse_args() 31 | 32 | if args.rgb: 33 | Edge_TrainSynRGB('./dataset/High','./dataset/Low',args.patch_size,args.stride) 34 | #Train_100 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Mar 21 20:48:05 2020 4 | 5 | @author: Administrator 6 | """ 7 | 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | import numpy as np 13 | import cv2 14 | import time 15 | import os 16 | from model import * 17 | import utils_train 18 | 19 | 20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 22 | 23 | def load_checkpoint(checkpoint_dir,IsGPU): 24 | 25 | if IsGPU == 1: 26 | model_info = torch.load(checkpoint_dir + 'checkpoint.pth.tar') 27 | net = Main() 28 | device_ids = [0] 29 | model = nn.DataParallel(net, device_ids=device_ids).cuda() 30 | model.load_state_dict(model_info['state_dict']) 31 | optimizer = torch.optim.Adam(model.parameters()) 32 | optimizer.load_state_dict(model_info['optimizer']) 33 | cur_epoch = model_info['epoch'] 34 | else: 35 | 36 | model_info = torch.load(checkpoint_dir + 'checkpoint.pth.tar',map_location=torch.device('cpu')) 37 | net = Main() 38 | device_ids = [0] 39 | model = nn.DataParallel(net, device_ids=device_ids) 40 | model.load_state_dict(model_info['state_dict']) 41 | optimizer = torch.optim.Adam(model.parameters()) 42 | optimizer.load_state_dict(model_info['optimizer']) 43 | cur_epoch = model_info['epoch'] 44 | 45 | return model, optimizer,cur_epoch 46 | 47 | def adjust_learning_rate(optimizer, epoch, lr_update_freq): 48 | if not epoch % lr_update_freq and epoch: 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = param_group['lr'] * 0.1 51 | print( param_group['lr']) 52 | return optimizer 53 | 54 | def train_psnr(train_in,train_out): 55 | 56 | psnr = utils_train.batch_psnr(train_in,train_out,1.) 
57 | return psnr 58 | 59 | 60 | def hwc_to_chw(img): 61 | return np.transpose(img, axes=[2, 0, 1]) 62 | 63 | def chw_to_hwc(img): 64 | return np.transpose(img, axes=[1, 2, 0]) 65 | 66 | def GFLap(data): 67 | x = cv2.GaussianBlur(data, (3,3),0) 68 | x = cv2.Laplacian(np.clip(x*255,0,255).astype('uint8'),cv2.CV_8U,ksize =3) 69 | Lap = cv2.convertScaleAbs(x) 70 | return Lap/255.0 71 | 72 | 73 | if __name__ == '__main__': 74 | checkpoint_dir = './checkpoints/' 75 | test_dir = './input' 76 | result_dir = './output' 77 | testfiles = os.listdir(test_dir) 78 | 79 | IsGPU = 1 #GPU is 1, CPU is 0 80 | 81 | print('> Loading dataset ...') 82 | model,optimizer,cur_epoch = load_checkpoint(checkpoint_dir,IsGPU) 83 | 84 | if IsGPU == 1: 85 | for f in range(len(testfiles)): 86 | model.eval() 87 | with torch.no_grad(): 88 | img_c = cv2.imread(test_dir + '/' + testfiles[f]) / 255.0 89 | img_l = hwc_to_chw(np.array(img_c).astype('float32')) 90 | img_g = cv2.imread(test_dir + '/' + testfiles[f],0) / 255.0 91 | input_var = torch.from_numpy(img_l.copy()).type(torch.FloatTensor).unsqueeze(0).cuda() 92 | input_var_gl = torch.from_numpy(GFLap(img_g.copy())).type(torch.FloatTensor).unsqueeze(0).unsqueeze(0).cuda() 93 | s = time.time() 94 | _,_,E_out = model(input_var,input_var_gl) 95 | e = time.time() 96 | print(input_var.shape) 97 | print('GPUTime:%.4f'%(e-s)) 98 | E_out = chw_to_hwc(E_out.squeeze().cpu().detach().numpy()) 99 | cv2.imwrite(result_dir + '/' + testfiles[f][:-4] + '_DDNet.png',np.clip(E_out*255,0.0,255.0)) 100 | 101 | else: 102 | for f in range(len(testfiles)): 103 | model.eval() 104 | with torch.no_grad(): 105 | img_c = cv2.imread(test_dir + '/' + testfiles[f]) / 255.0 106 | img_l = hwc_to_chw(np.array(img_c).astype('float32')) 107 | img_g = cv2.imread(test_dir + '/' + testfiles[f],0) / 255.0 108 | input_var = torch.from_numpy(img_l.copy()).type(torch.FloatTensor).unsqueeze(0) 109 | input_var_gl = torch.from_numpy(GFLap(img_g.copy())).type(torch.FloatTensor).unsqueeze(0).unsqueeze(0) 110 | s = time.time() 111 | _,_,E_out = model(input_var,input_var_gl) 112 | e = time.time() 113 | print(input_var.shape) 114 | print('CPUTime:%.4f'%(e-s)) 115 | E_out = chw_to_hwc(E_out.squeeze().cpu().detach().numpy()) 116 | 117 | cv2.imwrite(result_dir + '/' + testfiles[f][:-4] + '_DDNet.png',np.clip(E_out*255,0.0,255.0)) 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /utils_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Different utilities such as orthogonalization of weights, initialization of 3 | loggers, etc 4 | 5 | Copyright (C) 2018, Matias Tassano 6 | 7 | This program is free software: you can use, modify and/or 8 | redistribute it under the terms of the GNU General Public 9 | License as published by the Free Software Foundation, either 10 | version 3 of the License, or (at your option) any later 11 | version. You should have received a copy of this license along 12 | this program. If not, see .
13 | """ 14 | import subprocess 15 | import math 16 | import logging 17 | import numpy as np 18 | import cv2 19 | import torch 20 | import torch.nn as nn 21 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 22 | from copy import copy 23 | import torch.nn.functional as F 24 | 25 | def weights_init_kaiming(lyr): 26 | r"""Initializes weights of the model according to the "He" initialization 27 | method described in "Delving deep into rectifiers: Surpassing human-level 28 | performance on ImageNet classification" - He, K. et al. (2015), using a 29 | normal distribution. 30 | This function is to be called by the torch.nn.Module.apply() method, 31 | which applies weights_init_kaiming() to every layer of the model. 32 | """ 33 | classname = lyr.__class__.__name__ 34 | if classname.find('Conv') != -1: 35 | nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in') 36 | elif classname.find('Linear') != -1: 37 | nn.init.kaiming_normal_(lyr.weight.data, a=0, mode='fan_in') 38 | elif classname.find('BatchNorm') != -1: 39 | lyr.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).\ 40 | clamp_(-0.025, 0.025) 41 | nn.init.constant_(lyr.bias.data, 0.0) 42 | 43 | def batch_psnr(img, imclean, data_range): 44 | r""" 45 | Computes the PSNR along the batch dimension (not pixel-wise) 46 | 47 | Args: 48 | img: a `torch.Tensor` containing the restored image 49 | imclean: a `torch.Tensor` containing the reference image 50 | data_range: The data range of the input image (distance between 51 | minimum and maximum possible values). By default, this is estimated 52 | from the image data-type. 53 | """ 54 | img_cpu = img.data.cpu().numpy().astype(np.float32) 55 | imgclean = imclean.data.cpu().numpy().astype(np.float32) 56 | psnr = 0 57 | for i in range(img_cpu.shape[0]): 58 | psnr += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], \ 59 | data_range=data_range) 60 | return psnr/img_cpu.shape[0] 61 | 62 | def data_augmentation(image, mode): 63 | r"""Performs dat augmentation of the input image 64 | 65 | Args: 66 | image: a cv2 (OpenCV) image 67 | mode: int. 
Choice of transformation to apply to the image 68 | 0 - no transformation 69 | 1 - flip up and down 70 | 2 - rotate counterwise 90 degree 71 | 3 - rotate 90 degree and flip up and down 72 | 4 - rotate 180 degree 73 | 5 - rotate 180 degree and flip 74 | 6 - rotate 270 degree 75 | 7 - rotate 270 degree and flip 76 | """ 77 | out = np.transpose(image, (1, 2, 0)) 78 | if mode == 0: 79 | # original 80 | out = out 81 | elif mode == 1: 82 | # flip up and down 83 | out = np.flipud(out) 84 | elif mode == 2: 85 | # rotate counterwise 90 degree 86 | out = np.rot90(out) 87 | elif mode == 3: 88 | # rotate 90 degree and flip up and down 89 | out = np.rot90(out) 90 | out = np.flipud(out) 91 | elif mode == 4: 92 | # rotate 180 degree 93 | out = np.rot90(out, k=2) 94 | elif mode == 5: 95 | # rotate 180 degree and flip 96 | out = np.rot90(out, k=2) 97 | out = np.flipud(out) 98 | elif mode == 6: 99 | # rotate 270 degree 100 | out = np.rot90(out, k=3) 101 | elif mode == 7: 102 | # rotate 270 degree and flip 103 | out = np.rot90(out, k=3) 104 | out = np.flipud(out) 105 | else: 106 | raise Exception('Invalid choice of image transformation') 107 | return np.transpose(out, (2, 0, 1)) 108 | 109 | def variable_to_cv2_image(varim): 110 | r"""Converts a torch.autograd.Variable to an OpenCV image 111 | 112 | Args: 113 | varim: a torch.autograd.Variable 114 | """ 115 | nchannels = varim.size()[1] 116 | if nchannels == 1: 117 | res = (varim.data.cpu().numpy()[0, 0, :]*255.).clip(0, 255).astype(np.uint8) 118 | elif nchannels == 3: 119 | res = varim.data.cpu().numpy()[0] 120 | res = cv2.cvtColor(res.transpose(1, 2, 0), cv2.COLOR_RGB2BGR) 121 | res = (res*255.).clip(0, 255).astype(np.uint8) 122 | else: 123 | raise Exception('Number of color channels not supported') 124 | return res 125 | 126 | def get_git_revision_short_hash(): 127 | r"""Returns the current Git commit. 
128 | """ 129 | return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip() 130 | 131 | def init_logger(argdict): 132 | r"""Initializes a logging.Logger to save all the running parameters to a 133 | log file 134 | 135 | Args: 136 | argdict: dictionary of parameters to be logged 137 | """ 138 | from os.path import join 139 | 140 | logger = logging.getLogger(__name__) 141 | logger.setLevel(level=logging.INFO) 142 | fh = logging.FileHandler(join(argdict.log_dir, 'log.txt'), mode='a') 143 | formatter = logging.Formatter('%(asctime)s - %(message)s') 144 | fh.setFormatter(formatter) 145 | logger.addHandler(fh) 146 | try: 147 | logger.info("Commit: {}".format(get_git_revision_short_hash())) 148 | except Exception as e: 149 | logger.error("Couldn't get commit number: {}".format(e)) 150 | logger.info("Arguments: ") 151 | for k in argdict.__dict__: 152 | logger.info("\t{}: {}".format(k, argdict.__dict__[k])) 153 | 154 | return logger 155 | 156 | def init_logger_ipol(): 157 | r"""Initializes a logging.Logger in order to log the results after 158 | testing a model 159 | 160 | Args: 161 | result_dir: path to the folder with the denoising results 162 | """ 163 | logger = logging.getLogger('testlog') 164 | logger.setLevel(level=logging.INFO) 165 | fh = logging.FileHandler('out.txt', mode='w') 166 | formatter = logging.Formatter('%(message)s') 167 | fh.setFormatter(formatter) 168 | logger.addHandler(fh) 169 | 170 | return logger 171 | 172 | def init_logger_test(result_dir): 173 | r"""Initializes a logging.Logger in order to log the results after testing 174 | a model 175 | 176 | Args: 177 | result_dir: path to the folder with the denoising results 178 | """ 179 | from os.path import join 180 | 181 | logger = logging.getLogger('testlog') 182 | logger.setLevel(level=logging.INFO) 183 | fh = logging.FileHandler(join(result_dir, 'log.txt'), mode='a') 184 | formatter = logging.Formatter('%(asctime)s - %(message)s') 185 | fh.setFormatter(formatter) 186 | logger.addHandler(fh) 187 | 188 | return logger 189 | 190 | def normalize(data): 191 | r"""Normalizes a unit8 image to a float32 image in the range [0, 1] 192 | 193 | Args: 194 | data: a unint8 numpy array to normalize from [0, 255] to [0, 1] 195 | """ 196 | return np.float32(data/255.) 197 | 198 | def svd_orthogonalization(lyr): 199 | r"""Applies regularization to the training by performing the 200 | orthogonalization technique described in the paper "FFDNet: Toward a fast 201 | and flexible solution for CNN based image denoising." Zhang et al. (2017). 202 | For each Conv layer in the model, the method replaces the matrix whose columns 203 | are the filters of the layer by new filters which are orthogonal to each other. 204 | This is achieved by setting the singular values of a SVD decomposition to 1. 205 | 206 | This function is to be called by the torch.nn.Module.apply() method, 207 | which applies svd_orthogonalization() to every layer of the model. 
208 | """ 209 | classname = lyr.__class__.__name__ 210 | if classname.find('Conv') != -1: 211 | weights = lyr.weight.data.clone() 212 | c_out, c_in, f1, f2 = weights.size() 213 | dtype = lyr.weight.data.type() 214 | 215 | # Reshape filters to columns 216 | # From (c_out, c_in, f1, f2) to (f1*f2*c_in, c_out) 217 | weights = weights.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out) 218 | 219 | # Convert filter matrix to numpy array 220 | weights = weights.cpu().numpy() 221 | 222 | # SVD decomposition and orthogonalization 223 | mat_u, _, mat_vh = np.linalg.svd(weights, full_matrices=False) 224 | weights = np.dot(mat_u, mat_vh) 225 | 226 | # As full_matrices=False we don't need to set s[:] = 1 and do mat_u*s 227 | lyr.weight.data = torch.Tensor(weights).view(f1, f2, c_in, c_out).\ 228 | permute(3, 2, 0, 1).type(dtype) 229 | else: 230 | pass 231 | 232 | def remove_dataparallel_wrapper(state_dict): 233 | r"""Converts a DataParallel model to a normal one by removing the "module." 234 | wrapper in the module dictionary 235 | 236 | Args: 237 | state_dict: a torch.nn.DataParallel state dictionary 238 | """ 239 | from collections import OrderedDict 240 | 241 | new_state_dict = OrderedDict() 242 | for k, vl in state_dict.items(): 243 | name = k[7:] # remove 'module.' of DataParallel 244 | new_state_dict[name] = vl 245 | 246 | return new_state_dict 247 | 248 | def is_rgb(im_path): 249 | r""" Returns True if the image in im_path is an RGB image 250 | """ 251 | from skimage.io import imread 252 | rgb = False 253 | im = imread(im_path) 254 | if (len(im.shape) == 3): 255 | if not(np.allclose(im[...,0], im[...,1]) and np.allclose(im[...,2], im[...,1])): 256 | rgb = True 257 | print("rgb: {}".format(rgb)) 258 | print("im shape: {}".format(im.shape)) 259 | return rgb 260 | 261 | 262 | --------------------------------------------------------------------------------