├── 31.31_0.90 ├── Net1 │ └── 883.pth ├── Net2 │ └── 883.pth ├── Net3 │ └── 883.pth ├── Net4 │ └── 883.pth └── refine │ └── 883.pth ├── Derain-2Net.py ├── Derain.py ├── Derain2.py ├── Derain2_skip.py ├── Derain_ICGAN.py ├── Derain_add_densenet.py ├── MyDataset └── Datasets.py ├── README.md ├── Utils ├── Vidsom.py ├── model_init.py ├── skle_ssim.py ├── ssim_map.py ├── torch_ssim.py └── utils.py ├── checkpoints └── 31.31_0.90 │ ├── Net1 │ └── 883.pth │ ├── Net2 │ └── 883.pth │ ├── Net3 │ └── 883.pth │ ├── Net4 │ └── 883.pth │ └── refine │ └── 883.pth ├── image-criterions-python ├── FSIM.mat ├── NIQE.py ├── PSNR_SSIM.py ├── README.md ├── VIF.py ├── comp_for_derain.py ├── demo-images │ ├── clear.png │ └── rain.png ├── demo-test.py ├── derain_test.py ├── for_class.py └── niqe_image_params.mat ├── net ├── __pycache__ │ ├── model.cpython-37.pyc │ ├── model_skip.cpython-37.pyc │ └── networks.cpython-37.pyc ├── model.py ├── model_skip.py └── networks.py ├── now_loss_derain_train.py └── w.py /31.31_0.90/Net1/883.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/31.31_0.90/Net1/883.pth -------------------------------------------------------------------------------- /31.31_0.90/Net2/883.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/31.31_0.90/Net2/883.pth -------------------------------------------------------------------------------- /31.31_0.90/Net3/883.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/31.31_0.90/Net3/883.pth -------------------------------------------------------------------------------- /31.31_0.90/Net4/883.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/31.31_0.90/Net4/883.pth -------------------------------------------------------------------------------- /31.31_0.90/refine/883.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/31.31_0.90/refine/883.pth -------------------------------------------------------------------------------- /Derain-2Net.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import re 4 | import torch 5 | import argparse 6 | import urllib.request 7 | from Utils.utils import * 8 | from Utils.Vidsom import * 9 | from Utils.model_init import * 10 | from Utils.ssim_map import SSIM_MAP 11 | from Utils.torch_ssim import SSIM 12 | from torch import nn, optim 13 | from torch.backends import cudnn 14 | from torch.autograd import Variable 15 | from torch.utils.data import DataLoader 16 | from torchvision.utils import make_grid 17 | from MyDataset.Datasets import derain_test_datasets , derain_train_datasets 18 | from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop 19 | 20 | from net.model import w_net as net1 21 | from net.model import u_net as net2 22 | from net.model import Net4_2Net as net4 23 | from net.model import refineNet_2Net 24 | 25 | 26 | parser = argparse.ArgumentParser(description="PyTorch Derain") 27 | #root 28 | parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain_Dataset2018/train", type=str, 29 | help="path to load train datasets(default: none)") 30 | parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain_Dataset2018/test", type=str, 31 | help="path to load test datasets(default: none)") 32 | 33 | parser.add_argument("--save_image_root", 
default='./result', type=str, 34 | help="save test image root") 35 | parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 36 | help="path to save networks") 37 | parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 38 | help="path to pretrained net1 net2 net3 root") 39 | 40 | #hypeparameters 41 | parser.add_argument("--batchSize", type=int, default=8, help="training batch size") 42 | parser.add_argument("--nEpoch", type=int, default=500, help="number of epochs to train for") 43 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4") 44 | parser.add_argument("--lr1", type=float, default=5e-5, help="Learning Rate For pretrained net. Default=1e-5") 45 | parser.add_argument("--p", default=0.8, type=float, help="probability of normal conditions") 46 | 47 | parser.add_argument("--train_print_fre", type=int, default=200, help="frequency of print train loss on train phase") 48 | parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test") 49 | parser.add_argument("--test_print_fre", type=int, default=200, help="frequency of print train loss on test phase") 50 | parser.add_argument("--cuda",type=str, default="Ture", help="Use cuda?") 51 | parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use") 52 | parser.add_argument("--startweights", default= 0, type=int, help="start number of net's weight , 0 is None") 53 | parser.add_argument("--initmethod", default='xavier', type=str, help="xavier , kaiming , normal ,orthogonal ,default : xavier") 54 | parser.add_argument("--startepoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 55 | parser.add_argument("--works", type=int, default=4, help="Number of works for data loader to use, Default: 1") 56 | parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9") 57 | parser.add_argument("--report", 
default=False, type=bool, help="report to wechat") 58 | parser.add_argument("--save_image", default=False, type=bool, help="save test image") 59 | parser.add_argument("--pretrain_epoch", default=[93,169,123], type=list, help="pretrained epoch for Net1 Net2 Net3") 60 | 61 | 62 | def main(): 63 | global opt, Net1 , Net2 , Net4 , RefineNet , criterion_mse , criterion_ssim_map,criterion_ssim,criterion_ace 64 | opt = parser.parse_args() 65 | print(opt) 66 | 67 | 68 | cuda = opt.cuda 69 | if cuda and not torch.cuda.is_available(): 70 | raise Exception("No GPU found, please run without --cuda") 71 | 72 | seed = 1334 73 | torch.manual_seed(seed) 74 | if cuda: 75 | torch.cuda.manual_seed(seed) 76 | 77 | cudnn.benchmark = True 78 | 79 | print("==========> Loading datasets") 80 | 81 | train_dataset = derain_train_datasets( data_root= opt.train, transform=Compose([ 82 | ToTensor() 83 | ])) 84 | 85 | test_dataset = derain_test_datasets(opt.test, transform=Compose([ 86 | ToTensor() 87 | ])) 88 | 89 | training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize, 90 | pin_memory=True, shuffle=True) 91 | testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True, 92 | shuffle=True) 93 | 94 | if opt.initmethod == 'orthogonal': 95 | init_function = weights_init_orthogonal 96 | 97 | elif opt.initmethod == 'kaiming': 98 | init_function = weights_init_kaiming 99 | 100 | elif opt.initmethod == 'normal': 101 | init_function = weights_init_normal 102 | 103 | else: 104 | init_function = weights_init_xavier 105 | 106 | Net1 = net1() 107 | Net1.apply(init_function) 108 | Net2 = net2() 109 | Net2.apply(init_function) 110 | # Net3 = net3() 111 | # Net3.apply(init_function) 112 | Net4 = net4() 113 | Net4.apply(init_function) 114 | RefineNet = refineNet_2Net() 115 | RefineNet.apply(init_function) 116 | 117 | 118 | 119 | criterion_mse = nn.MSELoss(size_average=True) 120 | criterion_ssim_map = SSIM_MAP() 
121 | criterion_ssim = SSIM() 122 | criterion_ace = nn.SmoothL1Loss() 123 | 124 | print("==========> Setting GPU") 125 | #if cuda: 126 | if opt.cuda: 127 | Net1 = nn.DataParallel(Net1, device_ids=[i for i in range(opt.gpus)]).cuda() 128 | Net2 = nn.DataParallel(Net2, device_ids=[i for i in range(opt.gpus)]).cuda() 129 | # Net3 = nn.DataParallel(Net3, device_ids=[i for i in range(opt.gpus)]).cuda() 130 | Net4 = nn.DataParallel(Net4, device_ids=[i for i in range(opt.gpus)]).cuda() 131 | RefineNet = nn.DataParallel(RefineNet, device_ids=[i for i in range(opt.gpus)]).cuda() 132 | 133 | criterion_ssim = criterion_ssim.cuda() 134 | criterion_ssim_map = criterion_ssim_map.cuda() 135 | criterion_mse= criterion_mse.cuda() 136 | criterion_ace = criterion_ace.cuda() 137 | else: 138 | raise Exception("it takes a long time without cuda ") 139 | #print(net) 140 | 141 | if opt.pretrain_root: 142 | if os.path.exists(opt.pretrain_root): 143 | print("=> loading net from '{}'".format(opt.pretrain_root)) 144 | weights = torch.load(opt.pretrain_root +"/w/%s.pth"%opt.pretrain_epoch[0]) 145 | Net1.load_state_dict(weights['state_dict'] ) 146 | 147 | weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1]) 148 | Net2.load_state_dict(weights['state_dict'] ) 149 | 150 | # weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2]) 151 | # Net3.load_state_dict(weights['state_dict']) 152 | 153 | del weights 154 | else: 155 | print("=> no net found at '{}'".format(opt.pretrain_root)) 156 | 157 | # weights start from early 158 | if opt.startweights: 159 | if os.path.exists(opt.save_root): 160 | print("=> loading checkpoint '{}'".format(opt.save_root)) 161 | weights = torch.load(opt.save_root + '/Net1/%s.pth'%opt.startweights) 162 | Net1.load_state_dict(weights["state_dict"] ) 163 | 164 | weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights) 165 | Net2.load_state_dict(weights["state_dict"]) 166 | 167 | # weights = torch.load(opt.save_root 
+ '/Net3/%s.pth' % opt.startweights) 168 | # Net3.load_state_dict(weights["state_dict"]) 169 | 170 | weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights) 171 | Net4.load_state_dict(weights["state_dict"]) 172 | 173 | weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights) 174 | RefineNet.load_state_dict(weights["state_dict"]) 175 | 176 | del weights 177 | else: 178 | raise Exception("'{}' is not a file , Check out it again".format(opt.save_root)) 179 | 180 | 181 | 182 | print("==========> Setting Optimizer") 183 | optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1) 184 | optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 185 | # optimizer3 = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 186 | optimizer4 = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 187 | optimizer_Refine = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 188 | 189 | optimizer = [ 1 , optimizer1 , optimizer2 , optimizer4 , optimizer_Refine ] 190 | print("==========> Training") 191 | for epoch in range(opt.startepoch, opt.nEpoch + 1): 192 | 193 | if epoch > 10 : 194 | opt.lr = 1e-4 195 | optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1) 196 | optimizer[2] = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 197 | # optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 198 | optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 199 | optimizer[4] = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 200 | 201 | train(training_data_loader, optimizer, epoch) 202 | 203 | if epoch % opt.test_frequency == 0 and epoch >10: 204 | test(testing_data_loader ,epoch) 205 | 206 | 207 | 208 | def train(training_data_loader, optimizer, epoch): 
209 | print("training ==========> epoch =", epoch, "lr =", opt.lr) 210 | Net1.train() 211 | Net2.train() 212 | # Net3.train() 213 | Net4.train() 214 | RefineNet.train() 215 | t_loss = [] # save trainloss 216 | 217 | for step, (data, label) in enumerate(training_data_loader, 1): 218 | if opt.cuda and torch.cuda.is_available(): 219 | data = data.clone().detach().requires_grad_(True).cuda() 220 | label = label.cuda() 221 | else: 222 | raise Exception("it takes a long time without cuda ") 223 | data = data.cpu() 224 | label = label.cpu() 225 | 226 | Net1_out = Net1(data) 227 | Net2_out = Net2(Net1_out) 228 | # Net3_out = Net3(Net2_out) 229 | Net4_out = Net4( data - Net1_out ,data - Net2_out ) 230 | RefineNet_out =RefineNet( Net1_out , Net2_out , data - Net4_out ) 231 | 232 | init_map = torch.ones(size=Net1_out.size()).cuda() 233 | ssim_map1 = torch.mul(criterion_ssim_map(Net1_out , label) , init_map ) 234 | ssim_map2 = torch.mul(criterion_ssim_map(Net2_out , label) , ssim_map1 ) 235 | # ssim_map3 = torch.mul(criterion_ssim_map(Net3_out , label) , ssim_map2 ) 236 | 237 | loss1 = torch.mul((1 - ssim_map1) , torch.abs(Net1_out - label)).mean() 238 | loss2 = torch.mul((1 - ssim_map2) , torch.abs(Net2_out - label)).mean() 239 | # loss3 = torch.mul((1 - ssim_map3) , torch.abs(Net3_out - label)).mean() 240 | 241 | new_loss = torch.mul((1-criterion_ssim_map(RefineNet_out , label)) ,torch.abs(RefineNet_out-label)).mean().cuda() 242 | ssim_loss = 1- criterion_ssim(RefineNet_out , label) 243 | 244 | loss = new_loss + 0.01 * (loss1 + loss2 ) 245 | 246 | 247 | del Net1_out , Net2_out , Net4_out 248 | Net1.zero_grad() 249 | Net2.zero_grad() 250 | # Net3.zero_grad() 251 | Net4.zero_grad() 252 | RefineNet.zero_grad() 253 | 254 | 255 | optimizer[1].zero_grad() 256 | optimizer[2].zero_grad() 257 | optimizer[3].zero_grad() 258 | optimizer[4].zero_grad() 259 | # optimizer[5].zero_grad() 260 | 261 | loss.backward() 262 | optimizer[1].step() 263 | optimizer[2].step() 264 | 
optimizer[3].step() 265 | optimizer[4].step() 266 | # optimizer[5].step() 267 | 268 | 269 | if step % opt.train_print_fre == 0: 270 | print("epoch{} step {} loss {:6f} new_loss {:6f} ssimloss {:6f} loss1 {:6f} loss2 {:6f} ".format(epoch, step, 271 | loss.item(), 272 | new_loss.item(), 273 | ssim_loss.item(), 274 | loss1.item(), 275 | loss2.item(), 276 | )) 277 | t_loss.append(loss.item()) 278 | del loss1, loss2 , loss 279 | 280 | else: 281 | # displaying to train loss 282 | updata_epoch_loss_display( train_loss= t_loss , v_epoch= epoch , envr= "derain train") 283 | 284 | import time 285 | def test(test_data_loader, epoch): 286 | print("------> testing") 287 | Net1.eval() 288 | Net2.eval() 289 | # Net3.eval() 290 | Net4.eval() 291 | RefineNet.eval() 292 | torch.cuda.empty_cache() 293 | starttime = 0 294 | endtime = 0 295 | with torch.no_grad(): 296 | 297 | test_Psnr_sum = 0.0 298 | test_Ssim_sum = 0.0 299 | 300 | # showing list 301 | test_Psnr_loss = [] 302 | test_Ssim_loss = [] 303 | dict_psnr_ssim = {} 304 | starttime = time.time() 305 | for test_step, (data, label, data_path) in enumerate(test_data_loader, 1): 306 | data = data.cuda() 307 | label = label.cuda() 308 | 309 | Net1_out = Net1(data).cuda() 310 | Net2_out = Net2(Net1_out).cuda() 311 | # Net3_out = Net3(Net2_out).cuda() 312 | Net4_out = Net4(data - Net1_out , data - Net2_out ).cuda() #best rain streaks 313 | refineNet_out = RefineNet(Net1_out , Net2_out , data - Net4_out ).cuda() 314 | 315 | del Net1_out, Net2_out, 316 | 317 | loss = criterion_mse(refineNet_out, label) 318 | Psnr, Ssim = get_psnr_ssim(refineNet_out, label) 319 | 320 | Psnr = round(Psnr.item(), 4) 321 | Ssim = round(Ssim.item(), 4) 322 | 323 | # del derain , label 324 | test_Psnr_sum += Psnr 325 | test_Ssim_sum += Ssim 326 | 327 | #if opt.save_image == True: 328 | # dict_psnr_ssim["Psnr%s_Ssim%s" % (Psnr, Ssim)] = data_path 329 | # out = refineNet_out.cpu().data[0] 330 | # out = ToPILImage()(out) 331 | # image_number = re.findall(r'\d+', 
data_path[0])[1] 332 | # out.save( opt.save_image_root + "/%s_p:%s_s:%s.jpg" % (image_number, Psnr, Ssim)) 333 | if test_step % opt.test_print_fre == 0: 334 | print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, loss.item())) 335 | test_Psnr_loss.append(test_Psnr_sum / test_step) 336 | test_Ssim_loss.append(test_Ssim_sum / test_step) 337 | 338 | else: 339 | del loss 340 | print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step, 341 | test_Ssim_sum / test_step)) 342 | write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step) 343 | # visdom showing 344 | print("---->testing over show in visdom") 345 | display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch, 346 | env="derain_test") 347 | endtime = time.time() 348 | 349 | print("----------TestTime:{}".format(endtime - starttime)) 350 | print("epoch {} train over-----> save net".format(epoch)) 351 | print("saving checkpoint save_root{}".format(opt.save_root)) 352 | if os.path.exists(opt.save_root): 353 | save_checkpoint(root=opt.save_root, model=Net1, epoch=epoch, model_stage="Net1") 354 | save_checkpoint(root=opt.save_root, model=Net2, epoch=epoch, model_stage="Net2") 355 | # save_checkpoint(root=opt.save_root, model=Net3, epoch=epoch, model_stage="Net3") 356 | save_checkpoint(root=opt.save_root, model=Net4, epoch=epoch, model_stage="Net4") 357 | save_checkpoint(root=opt.save_root, model=RefineNet, epoch=epoch, model_stage="refine") 358 | 359 | print("finish save epoch{} checkporint".format({epoch})) 360 | else: 361 | raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root)) 362 | # 363 | 364 | print("all epoch is over ------ ") 365 | print("show epoch and epoch_loss in visdom") 366 | 367 | if __name__ == "__main__": 368 | os.system('clear') 369 | main() 370 | -------------------------------------------------------------------------------- /Derain.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import re 4 | import torch 5 | import argparse 6 | import urllib.request 7 | from Utils.utils import * 8 | from Utils.Vidsom import * 9 | from Utils.model_init import * 10 | from Utils.ssim_map import SSIM_MAP 11 | from Utils.torch_ssim import SSIM 12 | from torch import nn, optim 13 | from torch.backends import cudnn 14 | from torch.autograd import Variable 15 | from torch.utils.data import DataLoader 16 | from torchvision.utils import make_grid 17 | from MyDataset.Datasets import derain_test_datasets , derain_train_datasets 18 | from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop 19 | 20 | from net.model import w_net as net1 21 | from net.model import u_net as net2 22 | from net.model import res_net as net3 23 | from net.model import Net4 as net4 24 | from net.model import refineNet as refine 25 | 26 | 27 | parser = argparse.ArgumentParser(description="PyTorch Derain") 28 | #root 29 | parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain_Dataset2018/train", type=str, 30 | help="path to load train datasets(default: none)") 31 | parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain_Dataset2018/test", type=str, 32 | help="path to load test datasets(default: none)") 33 | 34 | parser.add_argument("--save_image_root", default='./result', type=str, 35 | help="save test image root") 36 | parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 37 | help="path to save networks") 38 | parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 39 | help="path to pretrained net1 net2 net3 root") 40 | 41 | #hypeparameters 42 | parser.add_argument("--batchSize", type=int, default=8, help="training batch size") 43 | parser.add_argument("--nEpoch", type=int, default=500, help="number of epochs to train for") 44 | 
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4") 45 | parser.add_argument("--lr1", type=float, default=5e-5, help="Learning Rate For pretrained net. Default=1e-5") 46 | parser.add_argument("--p", default=0.8, type=float, help="probability of normal conditions") 47 | 48 | parser.add_argument("--train_print_fre", type=int, default=200, help="frequency of print train loss on train phase") 49 | parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test") 50 | parser.add_argument("--test_print_fre", type=int, default=200, help="frequency of print train loss on test phase") 51 | parser.add_argument("--cuda",type=str, default="Ture", help="Use cuda?") 52 | parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use") 53 | parser.add_argument("--startweights", default= 32, type=int, help="start number of net's weight , 0 is None") 54 | parser.add_argument("--initmethod", default='xavier', type=str, help="xavier , kaiming , normal ,orthogonal ,default : xavier") 55 | parser.add_argument("--startepoch", default=38, type=int, help="Manual epoch number (useful on restarts)") 56 | parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1") 57 | parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9") 58 | parser.add_argument("--report", default=False, type=bool, help="report to wechat") 59 | parser.add_argument("--save_image", default=False, type=bool, help="save test image") 60 | parser.add_argument("--pretrain_epoch", default=[93,169,123], type=list, help="pretrained epoch for Net1 Net2 Net3") 61 | 62 | 63 | def main(): 64 | global opt, Net1 , Net2 , Net3 , Net4 , RefineNet , criterion_mse , criterion_ssim_map,criterion_ssim,criterion_ace 65 | opt = parser.parse_args() 66 | print(opt) 67 | 68 | 69 | cuda = opt.cuda 70 | if cuda and not torch.cuda.is_available(): 71 | raise Exception("No GPU found, please run 
without --cuda") 72 | 73 | seed = 1334 74 | torch.manual_seed(seed) 75 | if cuda: 76 | torch.cuda.manual_seed(seed) 77 | 78 | cudnn.benchmark = True 79 | 80 | print("==========> Loading datasets") 81 | 82 | train_dataset = derain_train_datasets( data_root= opt.train, transform=Compose([ 83 | ToTensor() 84 | ])) 85 | 86 | test_dataset = derain_test_datasets(opt.test, transform=Compose([ 87 | ToTensor() 88 | ])) 89 | 90 | training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize, 91 | pin_memory=True, shuffle=True) 92 | testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True, 93 | shuffle=True) 94 | 95 | if opt.initmethod == 'orthogonal': 96 | init_function = weights_init_orthogonal 97 | 98 | elif opt.initmethod == 'kaiming': 99 | init_function = weights_init_kaiming 100 | 101 | elif opt.initmethod == 'normal': 102 | init_function = weights_init_normal 103 | 104 | else: 105 | init_function = weights_init_xavier 106 | 107 | Net1 = net1() 108 | Net1.apply(init_function) 109 | Net2 = net2() 110 | Net2.apply(init_function) 111 | Net3 = net3() 112 | Net3.apply(init_function) 113 | Net4 = net4() 114 | Net4.apply(init_function) 115 | RefineNet = refine() 116 | RefineNet.apply(init_function) 117 | 118 | 119 | 120 | criterion_mse = nn.MSELoss(size_average=True) 121 | criterion_ssim_map = SSIM_MAP() 122 | criterion_ssim = SSIM() 123 | criterion_ace = nn.SmoothL1Loss() 124 | 125 | print("==========> Setting GPU") 126 | #if cuda: 127 | if opt.cuda: 128 | Net1 = nn.DataParallel(Net1, device_ids=[i for i in range(opt.gpus)]).cuda() 129 | Net2 = nn.DataParallel(Net2, device_ids=[i for i in range(opt.gpus)]).cuda() 130 | Net3 = nn.DataParallel(Net3, device_ids=[i for i in range(opt.gpus)]).cuda() 131 | Net4 = nn.DataParallel(Net4, device_ids=[i for i in range(opt.gpus)]).cuda() 132 | RefineNet = nn.DataParallel(RefineNet, device_ids=[i for i in range(opt.gpus)]).cuda() 133 | 134 | 
criterion_ssim = criterion_ssim.cuda() 135 | criterion_ssim_map = criterion_ssim_map.cuda() 136 | criterion_mse= criterion_mse.cuda() 137 | criterion_ace = criterion_ace.cuda() 138 | else: 139 | raise Exception("it takes a long time without cuda ") 140 | #print(net) 141 | 142 | if opt.pretrain_root: 143 | if os.path.exists(opt.pretrain_root): 144 | print("=> loading net from '{}'".format(opt.pretrain_root)) 145 | weights = torch.load(opt.pretrain_root +"/w/%s.pth"%opt.pretrain_epoch[0]) 146 | Net1.load_state_dict(weights['state_dict'] ) 147 | 148 | weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1]) 149 | Net2.load_state_dict(weights['state_dict'] ) 150 | 151 | weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2]) 152 | Net3.load_state_dict(weights['state_dict']) 153 | 154 | del weights 155 | else: 156 | print("=> no net found at '{}'".format(opt.pretrain_root)) 157 | 158 | # weights start from early 159 | if opt.startweights: 160 | if os.path.exists(opt.save_root): 161 | print("=> loading checkpoint '{}'".format(opt.save_root)) 162 | weights = torch.load(opt.save_root + '/Net1/%s.pth'%opt.startweights) 163 | Net1.load_state_dict(weights["state_dict"] ) 164 | 165 | weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights) 166 | Net2.load_state_dict(weights["state_dict"]) 167 | 168 | weights = torch.load(opt.save_root + '/Net3/%s.pth' % opt.startweights) 169 | Net3.load_state_dict(weights["state_dict"]) 170 | 171 | weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights) 172 | Net4.load_state_dict(weights["state_dict"]) 173 | 174 | weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights) 175 | RefineNet.load_state_dict(weights["state_dict"]) 176 | 177 | del weights 178 | else: 179 | raise Exception("'{}' is not a file , Check out it again".format(opt.save_root)) 180 | 181 | 182 | 183 | print("==========> Setting Optimizer") 184 | optimizer1 = optim.Adam(filter(lambda p: 
p.requires_grad, Net1.parameters()), lr=opt.lr1) 185 | optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 186 | optimizer3 = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 187 | optimizer4 = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 188 | optimizer_Refine = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 189 | 190 | optimizer = [ 1 , optimizer1 , optimizer2 , optimizer3 , optimizer4 , optimizer_Refine ] 191 | print("==========> Training") 192 | for epoch in range(opt.startepoch, opt.nEpoch + 1): 193 | 194 | if epoch > 10 : 195 | opt.lr = 1e-4 196 | optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1) 197 | optimizer[2] = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 198 | optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 199 | optimizer[4] = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 200 | optimizer[5] = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 201 | 202 | # train(training_data_loader, optimizer, epoch) 203 | 204 | if epoch % opt.test_frequency == 0 : 205 | test(testing_data_loader ,epoch) 206 | 207 | 208 | 209 | def train(training_data_loader, optimizer, epoch): 210 | print("training ==========> epoch =", epoch, "lr =", opt.lr) 211 | Net1.train() 212 | Net2.train() 213 | Net3.train() 214 | Net4.train() 215 | RefineNet.train() 216 | t_loss = [] # save trainloss 217 | 218 | for step, (data, label) in enumerate(training_data_loader, 1): 219 | if opt.cuda and torch.cuda.is_available(): 220 | data = data.clone().detach().requires_grad_(True).cuda() 221 | label = label.cuda() 222 | else: 223 | raise Exception("it takes a long time without cuda ") 224 | data = data.cpu() 225 | label = label.cpu() 226 | 227 | Net1_out = Net1(data) 228 | Net2_out = 
Net2(Net1_out) 229 | Net3_out = Net3(Net2_out) 230 | Net4_out = Net4( data - Net1_out ,data - Net2_out ,data - Net3_out ) 231 | RefineNet_out =RefineNet( Net1_out , Net2_out , Net3_out , data - Net4_out ) 232 | 233 | init_map = torch.ones(size=Net1_out.size()).cuda() 234 | ssim_map1 = torch.mul(criterion_ssim_map(Net1_out , label) , init_map ) 235 | ssim_map2 = torch.mul(criterion_ssim_map(Net2_out , label) , ssim_map1 ) 236 | ssim_map3 = torch.mul(criterion_ssim_map(Net3_out , label) , ssim_map2 ) 237 | 238 | loss1 = torch.mul((1 - ssim_map1) , torch.abs(Net1_out - label)).mean() 239 | loss2 = torch.mul((1 - ssim_map2) , torch.abs(Net2_out - label)).mean() 240 | loss3 = torch.mul((1 - ssim_map3) , torch.abs(Net3_out - label)).mean() 241 | 242 | new_loss = torch.mul((1-criterion_ssim_map(RefineNet_out , label)) ,torch.abs(RefineNet_out-label)).mean().cuda() 243 | ssim_loss = 1- criterion_ssim(RefineNet_out , label) 244 | 245 | loss = new_loss + 0.01 * (loss1 + loss2 +loss3) 246 | 247 | 248 | del Net1_out , Net2_out , Net3_out , Net4_out 249 | Net1.zero_grad() 250 | Net2.zero_grad() 251 | Net3.zero_grad() 252 | Net4.zero_grad() 253 | RefineNet.zero_grad() 254 | 255 | 256 | optimizer[1].zero_grad() 257 | optimizer[2].zero_grad() 258 | optimizer[3].zero_grad() 259 | optimizer[4].zero_grad() 260 | optimizer[5].zero_grad() 261 | 262 | loss.backward() 263 | optimizer[1].step() 264 | optimizer[2].step() 265 | optimizer[3].step() 266 | optimizer[4].step() 267 | optimizer[5].step() 268 | 269 | 270 | if step % opt.train_print_fre == 0: 271 | print("epoch{} step {} loss {:6f} new_loss {:6f} ssimloss {:6f} loss1 {:6f} loss2 {:6f} loss3 {:6f}".format(epoch, step, 272 | loss.item(), 273 | new_loss.item(), 274 | ssim_loss.item(), 275 | loss1.item(), 276 | loss2.item(), 277 | loss3.item())) 278 | t_loss.append(loss.item()) 279 | del loss1, loss2, loss3 , loss 280 | 281 | else: 282 | # displaying to train loss 283 | updata_epoch_loss_display( train_loss= t_loss , v_epoch= epoch , 
envr= "derain train") 284 | 285 | import time 286 | def test(test_data_loader, epoch): 287 | print("------> testing") 288 | Net1.eval() 289 | Net2.eval() 290 | Net3.eval() 291 | Net4.eval() 292 | RefineNet.eval() 293 | torch.cuda.empty_cache() 294 | starttime = 0 295 | endtime = 0 296 | with torch.no_grad(): 297 | 298 | test_Psnr_sum = 0.0 299 | test_Ssim_sum = 0.0 300 | 301 | # showing list 302 | test_Psnr_loss = [] 303 | test_Ssim_loss = [] 304 | dict_psnr_ssim = {} 305 | starttime = time.time() 306 | for test_step, (data, label, data_path) in enumerate(test_data_loader, 1): 307 | data = data.cuda() 308 | label = label.cuda() 309 | 310 | Net1_out = Net1(data).cuda() 311 | Net2_out = Net2(Net1_out).cuda() 312 | Net3_out = Net3(Net2_out).cuda() 313 | Net4_out = Net4(data - Net1_out , data - Net2_out , data - Net3_out).cuda() #best rain streaks 314 | refineNet_out = RefineNet(Net1_out , Net2_out , Net3_out , data - Net4_out ).cuda() 315 | 316 | del Net1_out, Net2_out, Net3_out 317 | 318 | loss = criterion_mse(refineNet_out, label) 319 | Psnr, Ssim = get_psnr_ssim(refineNet_out, label) 320 | 321 | Psnr = round(Psnr.item(), 4) 322 | Ssim = round(Ssim.item(), 4) 323 | 324 | # del derain , label 325 | test_Psnr_sum += Psnr 326 | test_Ssim_sum += Ssim 327 | 328 | #if opt.save_image == True: 329 | # dict_psnr_ssim["Psnr%s_Ssim%s" % (Psnr, Ssim)] = data_path 330 | # out = refineNet_out.cpu().data[0] 331 | # out = ToPILImage()(out) 332 | # image_number = re.findall(r'\d+', data_path[0])[1] 333 | # out.save( opt.save_image_root + "/%s_p:%s_s:%s.jpg" % (image_number, Psnr, Ssim)) 334 | if test_step % opt.test_print_fre == 0: 335 | print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, loss.item())) 336 | test_Psnr_loss.append(test_Psnr_sum / test_step) 337 | test_Ssim_loss.append(test_Ssim_sum / test_step) 338 | 339 | else: 340 | del loss 341 | print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step, 342 | test_Ssim_sum / test_step)) 343 | 
write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step) 344 | # visdom showing 345 | print("---->testing over show in visdom") 346 | display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch, 347 | env="derain_test") 348 | endtime = time.time() 349 | 350 | print("----------TestTime:{}".format(endtime - starttime)) 351 | print("epoch {} train over-----> save net".format(epoch)) 352 | print("saving checkpoint save_root{}".format(opt.save_root)) 353 | if os.path.exists(opt.save_root): 354 | save_checkpoint(root=opt.save_root, model=Net1, epoch=epoch, model_stage="Net1") 355 | save_checkpoint(root=opt.save_root, model=Net2, epoch=epoch, model_stage="Net2") 356 | save_checkpoint(root=opt.save_root, model=Net3, epoch=epoch, model_stage="Net3") 357 | save_checkpoint(root=opt.save_root, model=Net4, epoch=epoch, model_stage="Net4") 358 | save_checkpoint(root=opt.save_root, model=RefineNet, epoch=epoch, model_stage="refine") 359 | 360 | print("finish save epoch{} checkporint".format({epoch})) 361 | else: 362 | raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root)) 363 | # 364 | 365 | print("all epoch is over ------ ") 366 | print("show epoch and epoch_loss in visdom") 367 | 368 | if __name__ == "__main__": 369 | os.system('clear') 370 | main() 371 | -------------------------------------------------------------------------------- /Derain2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import re 4 | import torch 5 | import time 6 | import argparse 7 | import urllib.request 8 | from Utils.utils import * 9 | from Utils.Vidsom import * 10 | from Utils.model_init import * 11 | from Utils.ssim_map import SSIM_MAP 12 | from Utils.torch_ssim import SSIM 13 | from torch import nn, optim 14 | from torch.nn import functional as F 15 | from torch.backends import cudnn 16 | from torch.autograd import Variable 17 | 
# coding=utf-8
"""Derain2: ensemble deraining training script (2017 dataset variant).

Trains a cascade Net1 -> Net2 -> Net3, a streak-fusion Net4 and a RefineNet,
with a CUDA-stream prefetcher feeding the training loop.
"""
import os
import re
import torch
import time
import argparse
import urllib.request
from Utils.utils import *
from Utils.Vidsom import *
from Utils.model_init import *
from Utils.ssim_map import SSIM_MAP
from Utils.torch_ssim import SSIM
from torch import nn, optim
from torch.nn import functional as F
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from MyDataset.Datasets import derain_test_datasets_17, derain_train_datasets_17
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop, RandomHorizontalFlip
from prefetch_generator import BackgroundGenerator

from net.model import w_net as net1
from net.model import u_net as net2
from net.model import res_net as net3
from net.model import Net4 as net4
from net.model import refineNet_2Net as refine


parser = argparse.ArgumentParser(description="PyTorch Derain")
# roots
parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain2017_datasets_transpose_pl/TrainH", type=str,
                    help="path to load train datasets(default: none)")
parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain2017_datasets_transpose_pl/TestH", type=str,
                    help="path to load test datasets(default: none)")

parser.add_argument("--save_image_root", default='./result', type=str,
                    help="save test image root")
parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints2", type=str,
                    help="path to save networks")

parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str,
                    help="path to pretrained net1 net2 net3 root")

# hyperparameters
parser.add_argument("--batchSize", type=int, default=14, help="training batch size")
parser.add_argument("--nEpoch", type=int, default=100000, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--lr1", type=float, default=5e-5, help="Learning Rate For pretrained net. Default=1e-5")

parser.add_argument("--train_print_fre", type=int, default=50, help="frequency of print train loss on train phase")
parser.add_argument("--test_frequency", type=int, default=3, help="frequency of test")
parser.add_argument("--test_print_fre", type=int, default=50, help="frequency of print train loss on test phase")
# NOTE: the flag is a *string*; any non-empty value is truthy.  The original
# default was the typo "Ture" — behavior is identical either way.
parser.add_argument("--cuda", type=str, default="True", help="Use cuda?")
parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use")
parser.add_argument("--startweights", default=303, type=int, help="start number of net's weight , 0 is None")
parser.add_argument("--initmethod", default='xavier', type=str,
                    help="xavier , kaiming , normal ,orthogonal ,default : xavier")
parser.add_argument("--startepoch", default=310, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1")
parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9")
parser.add_argument("--report", default=False, type=bool, help="report to wechat")
parser.add_argument("--save_image", default=False, type=bool, help="save test image")
# NOTE(review): type=list would split a CLI string into characters; only the
# default is usable as-is — confirm before passing this flag on the command line.
parser.add_argument("--pretrain_epoch", default=[93, 169, 123], type=list,
                    help="pretrained epoch for Net1 Net2 Net3")


class data_prefetcher():
    """Overlap host->device copies with compute using a side CUDA stream.

    Wraps a DataLoader iterator; ``next()`` returns the batch whose transfer
    was started on ``self.stream`` during the previous step.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Kick off the async transfer of the next (input, target) pair."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # Exhausted: next() will hand back (None, None).
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch and start loading the one after it."""
        # Make the default stream wait until the copy issued in preload() is done.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target


def main():
    """Parse options, build datasets/nets/optimizers and run the epoch loop."""
    global opt, Net1, Net2, Net3, Net4, RefineNet, criterion_mse, criterion_ssim_map, criterion_ssim
    global it_train, it_test
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    cudnn.benchmark = True

    print("==========> Loading datasets")

    # NOTE(review): the *train* loader is built from derain_test_datasets_17 —
    # looks intentional (it returns (data, label) pairs) but confirm against
    # MyDataset.Datasets.
    train_dataset = derain_test_datasets_17(opt.train, transform=Compose([
        ToTensor(),
    ]))
    it_train = len(train_dataset)

    test_dataset = derain_test_datasets_17(opt.test, transform=Compose([
        ToTensor(),
    ]))
    it_test = len(test_dataset)

    training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize,
                                      pin_memory=True, shuffle=True)

    testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True,
                                     shuffle=True)

    # Pick the weight-initialisation scheme (default: xavier).
    if opt.initmethod == 'orthogonal':
        init_function = weights_init_orthogonal
    elif opt.initmethod == 'kaiming':
        init_function = weights_init_kaiming
    elif opt.initmethod == 'normal':
        init_function = weights_init_normal
    else:
        init_function = weights_init_xavier

    Net1 = net1()
    Net1.apply(init_function)
    Net2 = net2()
    Net2.apply(init_function)
    Net3 = net3()
    Net3.apply(init_function)
    Net4 = net4()
    Net4.apply(init_function)
    RefineNet = refine()
    RefineNet.apply(init_function)

    # size_average=True is deprecated; reduction='mean' is the equivalent.
    criterion_mse = nn.MSELoss(reduction='mean')
    criterion_ssim_map = SSIM_MAP()
    criterion_ssim = SSIM()

    print("==========> Setting GPU")
    if opt.cuda:
        Net1 = nn.DataParallel(Net1, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net2 = nn.DataParallel(Net2, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net3 = nn.DataParallel(Net3, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net4 = nn.DataParallel(Net4, device_ids=[i for i in range(opt.gpus)]).cuda()
        RefineNet = nn.DataParallel(RefineNet, device_ids=[i for i in range(opt.gpus)]).cuda()

        criterion_ssim = criterion_ssim.cuda()
        criterion_ssim_map = criterion_ssim_map.cuda()
        criterion_mse = criterion_mse.cuda()
    else:
        raise Exception("it takes a long time without cuda ")

    # Fresh run: seed Net1/2/3 from individually pretrained checkpoints.
    if opt.pretrain_root and opt.startweights == 0:
        if os.path.exists(opt.pretrain_root):
            print("=> loading net1 2 3 from '{}'".format(opt.pretrain_root))
            weights = torch.load(opt.pretrain_root + "/w/%s.pth" % opt.pretrain_epoch[0])
            Net1.load_state_dict(weights['state_dict'])

            weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1])
            Net2.load_state_dict(weights['state_dict'])

            weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2])
            Net3.load_state_dict(weights['state_dict'])

            del weights
        else:
            print("=> no net found at '{}'".format(opt.pretrain_root))

    # Resume: reload all five nets from a previous run of this script.
    if opt.startweights:
        if os.path.exists(opt.save_root):
            print("=> resume loading checkpoint '{}'".format(opt.save_root))
            weights = torch.load(opt.save_root + '/Net1/%s.pth' % opt.startweights)
            Net1.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights)
            Net2.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net3/%s.pth' % opt.startweights)
            Net3.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights)
            Net4.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights)
            RefineNet.load_state_dict(weights["state_dict"])

            del weights
        else:
            raise Exception("'{}' is not a file , Check out it again".format(opt.save_root))

    print("==========> Setting Optimizer")
    # Pretrained nets (1-3) get the smaller lr1; Net4/RefineNet get lr.
    optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1)
    optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1)
    optimizer3 = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1)
    optimizer4 = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr)
    optimizer_Refine = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr)

    # Index 0 is a placeholder so optimizer[k] matches Net-k numbering.
    optimizer = [1, optimizer1, optimizer2, optimizer3, optimizer4, optimizer_Refine]
    print("==========> Training")
    for epoch in range(opt.startepoch, opt.nEpoch + 1):
        # time.clock() was removed in Python 3.8; perf_counter() is the
        # documented replacement.
        start = time.perf_counter()
        if epoch > 50000:
            # Late-stage decay: rebuild all optimizers with the lowered lr.
            opt.lr = 5e-5
            optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1)
            optimizer[2] = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1)
            optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1)
            optimizer[4] = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr)
            optimizer[5] = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr)

        train(training_data_loader, optimizer, epoch)

        if epoch % opt.test_frequency == 0:
            test(testing_data_loader, epoch)

        end = time.perf_counter()
        print('--------->run epoch{} takes {}s time'.format(epoch, end - start))


def train(training_data_loader, optimizer, epoch):
    """One training epoch: cascade forward, weighted SSIM-map losses, joint step."""
    print("training ==========> epoch =", epoch, "lr =", opt.lr)
    Net1.train()
    Net2.train()
    Net3.train()
    Net4.train()
    RefineNet.train()
    t_loss = []  # per-print-step training loss trace
    training_data_loader = data_prefetcher(training_data_loader)

    ##### training #####
    data, label = training_data_loader.next()
    step = 0
    while data is not None:
        step += 1

        if opt.cuda and torch.cuda.is_available():
            data = data.clone().detach().requires_grad_(True).cuda()
            label = label.cuda()
        else:
            data = data.cpu()
            label = label.cpu()

        Net1_out = Net1(data)
        Net2_out = Net2(Net1_out)
        Net3_out = Net3(Net2_out)
        Net4_out = Net4(data - Net1_out, data - Net2_out, data - Net3_out)
        RefineNet_out = RefineNet(Net1_out, Net2_out, Net3_out, data - Net4_out)

        # Multiply SSIM maps down the cascade so later stages are judged on
        # regions that every earlier stage already reconstructed well.
        init_map = torch.ones(size=Net1_out.size()).cuda()
        ssim_map1 = torch.mul(criterion_ssim_map(Net1_out, label), init_map)
        ssim_map2 = torch.mul(criterion_ssim_map(Net2_out, label), ssim_map1)
        ssim_map3 = torch.mul(criterion_ssim_map(Net3_out, label), ssim_map2)

        # (1 - SSIM) weights an L1 term: pixels with poor structure count more.
        loss1 = torch.mul((1 - ssim_map1), torch.abs(Net1_out - label)).mean()
        loss2 = torch.mul((1 - ssim_map2), torch.abs(Net2_out - label)).mean()
        loss3 = torch.mul((1 - ssim_map3), torch.abs(Net3_out - label)).mean()

        new_loss = torch.mul((1 - criterion_ssim_map(RefineNet_out, label)),
                             torch.abs(RefineNet_out - label)).mean().cuda()
        ssim_loss = 1 - criterion_ssim(RefineNet_out, label)

        loss = new_loss + 0.001 * (loss1 + loss2 + loss3) + 0.001 * ssim_loss
        del Net1_out, Net2_out, Net3_out, Net4_out
        Net1.zero_grad()
        Net2.zero_grad()
        Net3.zero_grad()
        Net4.zero_grad()
        RefineNet.zero_grad()

        optimizer[1].zero_grad()
        optimizer[2].zero_grad()
        optimizer[3].zero_grad()
        optimizer[4].zero_grad()
        optimizer[5].zero_grad()

        loss.backward()
        optimizer[1].step()
        optimizer[2].step()
        optimizer[3].step()
        optimizer[4].step()
        optimizer[5].step()

        if step % opt.train_print_fre == 0:
            print("epoch{} step {} loss {:6f} new_loss {:6f} ssimloss {:6f} loss1 {:6f} loss2 {:6f} loss3 {:6f}".format(
                epoch, step,
                loss.item(),
                new_loss.item(),
                ssim_loss.item(),
                loss1.item(),
                loss2.item(),
                loss3.item()))
            t_loss.append(loss.item())
        del loss1, loss2, loss3, loss
        ######## next data and label ########
        data, label = training_data_loader.next()

    # displaying to train loss
    updata_epoch_loss_display(train_loss=t_loss, v_epoch=epoch, envr="derain train")


def test(test_data_loader, epoch):
    """Evaluate on the test set (prefetcher-driven) and checkpoint all nets."""
    print("------> testing")
    Net1.eval()
    Net2.eval()
    Net3.eval()
    Net4.eval()
    RefineNet.eval()
    torch.cuda.empty_cache()

    with torch.no_grad():

        test_Psnr_sum = 0.0
        test_Ssim_sum = 0.0

        # running-average traces for display
        test_Psnr_loss = []
        test_Ssim_loss = []
        dict_psnr_ssim = {}
        test_data_loader = data_prefetcher(test_data_loader)

        data, label = test_data_loader.next()
        test_step = 0
        while data is not None:
            test_step += 1
            data = data.cuda()
            label = label.cuda()

            Net1_out = Net1(data)
            Net2_out = Net2(Net1_out)
            Net3_out = Net3(Net2_out)
            Net4_out = Net4(data - Net1_out, data - Net2_out, data - Net3_out)
            refineNet_out = RefineNet(Net1_out, Net2_out, Net3_out, data - Net4_out)

            del Net1_out, Net2_out, Net3_out

            loss = criterion_mse(refineNet_out, label)
            Psnr, Ssim = get_psnr_ssim(refineNet_out, label)

            Psnr = round(Psnr.item(), 4)
            Ssim = round(Ssim.item(), 4)

            test_Psnr_sum += Psnr
            test_Ssim_sum += Ssim

            if test_step % opt.test_print_fre == 0:
                print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, loss.item()))
                test_Psnr_loss.append(test_Psnr_sum / test_step)
                test_Ssim_loss.append(test_Ssim_sum / test_step)

            del loss
            ######## next data and label ########
            data, label = test_data_loader.next()

        print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step,
                                                         test_Ssim_sum / test_step))
        write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step)
        # visdom showing
        print("---->testing over show in visdom")
        display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch,
                          env="derain_test")

    print("epoch {} train over-----> save net".format(epoch))
    print("saving checkpoint to save_root{}".format(opt.save_root))
    if os.path.exists(opt.save_root):
        save_checkpoint(root=opt.save_root, model=Net1, epoch=epoch, model_stage="Net1")
        save_checkpoint(root=opt.save_root, model=Net2, epoch=epoch, model_stage="Net2")
        save_checkpoint(root=opt.save_root, model=Net3, epoch=epoch, model_stage="Net3")
        save_checkpoint(root=opt.save_root, model=Net4, epoch=epoch, model_stage="Net4")
        save_checkpoint(root=opt.save_root, model=RefineNet, epoch=epoch, model_stage="refine")

        # was: .format({epoch}) which rendered a one-element set, plus a
        # "checkporint" misspelling
        print("finish save epoch{} checkpoint".format(epoch))
    else:
        raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root))

    print("all epoch is over ------ ")
    print("show epoch and epoch_loss in visdom")


if __name__ == "__main__":
    os.system('clear')
    main()
# coding=utf-8
"""Derain_ICGAN: ensemble deraining training script for the ICGAN dataset.

Same Net1-4 + RefineNet cascade as the other scripts, trained on
256x256-resized ICGAN pairs with a plain (non-prefetching) loop.
"""
import os
import re
import torch
import argparse
# torch.functional is an internal module; the public functional API lives
# under torch.nn.functional (the alias `f` is kept for compatibility).
import torch.nn.functional as f
import urllib.request
from Utils.utils import *
from Utils.Vidsom import *
from Utils.model_init import *
from Utils.ssim_map import SSIM_MAP
from Utils.torch_ssim import SSIM
from torch import nn, optim
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from MyDataset.Datasets import derain_train_datasets_IC
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop

from net.model import w_net as net1
from net.model import u_net as net2
from net.model import res_net as net3
from net.model import Net4 as net4
from net.model import refineNet as refine
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')


parser = argparse.ArgumentParser(description="PyTorch Derain")
# roots
parser.add_argument("--train", default="/home/ws/Desktop/PL/ICGANdateset/training", type=str,
                    help="path to load train datasets(default: none)")
parser.add_argument("--test", default="/home/ws/Desktop/PL/ICGANdateset/test_syn", type=str,
                    help="path to load test datasets(default: none)")

parser.add_argument("--save_image_root", default='./result', type=str,
                    help="save test image root")
parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints2", type=str,
                    help="path to save networks")
parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str,
                    help="path to pretrained net1 net2 net3 root")

# hyperparameters
parser.add_argument("--batchSize", type=int, default=16, help="training batch size")
parser.add_argument("--nEpoch", type=int, default=10000, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--lr1", type=float, default=5e-5, help="Learning Rate For pretrained net. Default=1e-5")

parser.add_argument("--train_print_fre", type=int, default=20, help="frequency of print train loss on train phase")
parser.add_argument("--test_frequency", type=int, default=3, help="frequency of test")
parser.add_argument("--test_print_fre", type=int, default=200, help="frequency of print train loss on test phase")
# NOTE: the flag is a *string*; any non-empty value is truthy.  The original
# default was the typo "Ture" — behavior is identical either way.
parser.add_argument("--cuda", type=str, default="True", help="Use cuda?")
parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use")
parser.add_argument("--startweights", default=216, type=int, help="start number of net's weight , 0 is None")
parser.add_argument("--initmethod", default='xavier', type=str,
                    help="xavier , kaiming , normal ,orthogonal ,default : xavier")
parser.add_argument("--startepoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1")
parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9")
parser.add_argument("--report", default=False, type=bool, help="report to wechat")
parser.add_argument("--save_image", default=False, type=bool, help="save test image")
parser.add_argument("--pretrain_epoch", default=[93, 169, 123], type=list,
                    help="pretrained epoch for Net1 Net2 Net3")


def main():
    """Parse options, build datasets/nets/optimizers and run the epoch loop."""
    global opt, Net1, Net2, Net3, Net4, RefineNet, criterion_mse, criterion_ssim_map, criterion_ssim
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    cudnn.benchmark = True

    print("==========> Loading datasets")

    train_dataset = derain_train_datasets_IC(opt.train, transform=Compose([
        Resize(size=(256, 256)),
        ToTensor()
    ]))

    test_dataset = derain_train_datasets_IC(opt.test, transform=Compose([
        Resize(size=(256, 256)),
        ToTensor(),
    ]))

    training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize,
                                      pin_memory=True, shuffle=True)
    testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True,
                                     shuffle=True)

    # Pick the weight-initialisation scheme (default: xavier).
    if opt.initmethod == 'orthogonal':
        init_function = weights_init_orthogonal
    elif opt.initmethod == 'kaiming':
        init_function = weights_init_kaiming
    elif opt.initmethod == 'normal':
        init_function = weights_init_normal
    else:
        init_function = weights_init_xavier

    Net1 = net1()
    Net1.apply(init_function)
    Net2 = net2()
    Net2.apply(init_function)
    Net3 = net3()
    Net3.apply(init_function)
    Net4 = net4()
    Net4.apply(init_function)
    RefineNet = refine()
    RefineNet.apply(init_function)

    # size_average=True is deprecated; reduction='mean' is the equivalent.
    criterion_mse = nn.MSELoss(reduction='mean')
    criterion_ssim_map = SSIM_MAP()
    criterion_ssim = SSIM()

    print("==========> Setting GPU")
    if opt.cuda:
        Net1 = nn.DataParallel(Net1, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net2 = nn.DataParallel(Net2, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net3 = nn.DataParallel(Net3, device_ids=[i for i in range(opt.gpus)]).cuda()
        Net4 = nn.DataParallel(Net4, device_ids=[i for i in range(opt.gpus)]).cuda()
        RefineNet = nn.DataParallel(RefineNet, device_ids=[i for i in range(opt.gpus)]).cuda()

        criterion_ssim = criterion_ssim.cuda()
        criterion_ssim_map = criterion_ssim_map.cuda()
        criterion_mse = criterion_mse.cuda()
    else:
        raise Exception("it takes a long time without cuda ")

    # NOTE(review): unlike Derain2.py this load is not guarded by
    # startweights == 0, so a resume run loads the pretrain weights first and
    # then overwrites them below — confirm that is intended.
    if opt.pretrain_root:
        if os.path.exists(opt.pretrain_root):
            print("=> loading net from '{}'".format(opt.pretrain_root))
            weights = torch.load(opt.pretrain_root + "/w/%s.pth" % opt.pretrain_epoch[0])
            Net1.load_state_dict(weights['state_dict'])

            weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1])
            Net2.load_state_dict(weights['state_dict'])

            weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2])
            Net3.load_state_dict(weights['state_dict'])

            del weights
        else:
            print("=> no net found at '{}'".format(opt.pretrain_root))

    # Resume: reload all five nets from a previous run of this script.
    if opt.startweights:
        if os.path.exists(opt.save_root):
            print("=> loading checkpoint '{}'".format(opt.save_root))
            weights = torch.load(opt.save_root + '/Net1/%s.pth' % opt.startweights)
            Net1.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights)
            Net2.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net3/%s.pth' % opt.startweights)
            Net3.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights)
            Net4.load_state_dict(weights["state_dict"])

            weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights)
            RefineNet.load_state_dict(weights["state_dict"])

            del weights
        else:
            raise Exception("'{}' is not a file , Check out it again".format(opt.save_root))

    print("==========> Setting Optimizer")
    # Pretrained nets (1-3) get the smaller lr1; Net4/RefineNet get lr.
    optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1)
    optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1)
    optimizer3 = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1)
    optimizer4 = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr)
    optimizer_Refine = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr)

    # Index 0 is a placeholder so optimizer[k] matches Net-k numbering.
    optimizer = [1, optimizer1, optimizer2, optimizer3, optimizer4, optimizer_Refine]
    print("==========> Training")
    for epoch in range(opt.startepoch, opt.nEpoch + 1):

        if epoch > 50000:
            # Late-stage decay: rebuild all optimizers with the lowered lrs.
            opt.lr = 5e-5
            opt.lr1 = 1e-6

            optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1)
            optimizer[2] = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1)
            optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1)
            optimizer[4] = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr)
            optimizer[5] = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr)

        train(training_data_loader, optimizer, epoch)

        if epoch % opt.test_frequency == 0:
            test(testing_data_loader, epoch)


def train(training_data_loader, optimizer, epoch):
    """One training epoch: cascade forward, weighted SSIM-map losses, joint step."""
    print("training ==========> epoch =", epoch, "lr =", opt.lr)
    Net1.train()
    Net2.train()
    Net3.train()
    Net4.train()
    RefineNet.train()
    t_loss = []  # per-print-step training loss trace

    for step, (data, label, _) in enumerate(training_data_loader, 1):
        if opt.cuda and torch.cuda.is_available():
            data = data.clone().detach().requires_grad_(True).cuda()
            label = label.cuda()
        else:
            # (the original also had unreachable .cpu() assignments after
            # this raise; they have been removed)
            raise Exception("it takes a long time without cuda ")

        Net1_out = Net1(data)
        Net2_out = Net2(Net1_out)
        Net3_out = Net3(Net2_out)
        Net4_out = Net4(data - Net1_out, data - Net2_out, data - Net3_out)
        RefineNet_out = RefineNet(Net1_out, Net2_out, Net3_out, data - Net4_out)

        # Multiply SSIM maps down the cascade so later stages are judged on
        # regions that every earlier stage already reconstructed well.
        init_map = torch.ones(size=Net1_out.size()).cuda()
        ssim_map1 = torch.mul(criterion_ssim_map(Net1_out, label), init_map)
        ssim_map2 = torch.mul(criterion_ssim_map(Net2_out, label), ssim_map1)
        ssim_map3 = torch.mul(criterion_ssim_map(Net3_out, label), ssim_map2)

        # (1 - SSIM) weights an L1 term: pixels with poor structure count more.
        loss1 = torch.mul((1 - ssim_map1), torch.abs(Net1_out - label)).mean()
        loss2 = torch.mul((1 - ssim_map2), torch.abs(Net2_out - label)).mean()
        loss3 = torch.mul((1 - ssim_map3), torch.abs(Net3_out - label)).mean()

        #mse_loss = criterion_mse(RefineNet_out , label)
        new_loss = torch.mul((1 - criterion_ssim_map(RefineNet_out, label)),
                             torch.abs(RefineNet_out - label)).mean().cuda()

        ssim_loss = 1 - criterion_ssim(RefineNet_out, label)

        loss = new_loss + 0.0001 * (loss1 + loss2 + loss3)  # + 0.001*ssim_loss
        del Net1_out, Net2_out, Net3_out, Net4_out
        Net1.zero_grad()
        Net2.zero_grad()
        Net3.zero_grad()
        Net4.zero_grad()
        RefineNet.zero_grad()

        optimizer[1].zero_grad()
        optimizer[2].zero_grad()
        optimizer[3].zero_grad()
        optimizer[4].zero_grad()
        optimizer[5].zero_grad()

        loss.backward()
        optimizer[1].step()
        optimizer[2].step()
        optimizer[3].step()
        optimizer[4].step()
        optimizer[5].step()

        if step % opt.train_print_fre == 0:
            print("epoch{} step {} loss {:6f} new_loss {:6f} ssimloss {:6f} loss1 {:6f} loss2 {:6f} loss3 {:6f}".format(
                epoch, step,
                loss.item(),
                new_loss.item(),
                ssim_loss.item(),
                loss1.item(),
                loss2.item(),
                loss3.item()))
            t_loss.append(loss.item())
        del loss1, loss2, loss3, loss

    # The original used a for/else here; with no break in the loop that is
    # identical to running after the loop, so the call is made unconditionally.
    updata_epoch_loss_display(train_loss=t_loss, v_epoch=epoch, envr="derain train")


def test(test_data_loader, epoch):
    """Evaluate on the test set and checkpoint all five networks."""
    print("------> testing")
    Net1.eval()
    Net2.eval()
    Net3.eval()
    Net4.eval()
    RefineNet.eval()
    torch.cuda.empty_cache()

    with torch.no_grad():

        test_Psnr_sum = 0.0
        test_Ssim_sum = 0.0

        # running-average traces for display
        test_Psnr_loss = []
        test_Ssim_loss = []
        dict_psnr_ssim = {}
        for test_step, (data, label, data_path) in enumerate(test_data_loader, 1):
            data = data.cuda()
            label = label.cuda()

            Net1_out = Net1(data)
            Net2_out = Net2(Net1_out)
            Net3_out = Net3(Net2_out)
            Net4_out = Net4(data - Net1_out, data - Net2_out, data - Net3_out)
            refineNet_out = RefineNet(Net1_out, Net2_out, Net3_out, data - Net4_out)

            del Net1_out, Net2_out, Net3_out

            loss = criterion_mse(refineNet_out, label)
            Psnr, Ssim = get_psnr_ssim(refineNet_out, label)

            Psnr = round(Psnr.item(), 4)
            Ssim = round(Ssim.item(), 4)

            test_Psnr_sum += Psnr
            test_Ssim_sum += Ssim

            if test_step % opt.test_print_fre == 0:
                print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, loss.item()))
                test_Psnr_loss.append(test_Psnr_sum / test_step)
                test_Ssim_loss.append(test_Ssim_sum / test_step)
            # Drop the loss tensor on every iteration (the original deleted it
            # only on non-print steps, keeping a stray reference alive).
            del loss

        print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step,
                                                         test_Ssim_sum / test_step))
        write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step)
        # visdom showing
        print("---->testing over show in visdom")
        display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch,
                          env="derain_test")

    print("epoch {} train over-----> save net".format(epoch))
    print("saving checkpoint save_root{}".format(opt.save_root))
    if os.path.exists(opt.save_root):
        save_checkpoint(root=opt.save_root, model=Net1, epoch=epoch, model_stage="Net1")
        save_checkpoint(root=opt.save_root, model=Net2, epoch=epoch, model_stage="Net2")
        save_checkpoint(root=opt.save_root, model=Net3, epoch=epoch, model_stage="Net3")
        save_checkpoint(root=opt.save_root, model=Net4, epoch=epoch, model_stage="Net4")
        save_checkpoint(root=opt.save_root, model=RefineNet, epoch=epoch, model_stage="refine")

        # was: .format({epoch}) which rendered a one-element set, plus a
        # "checkporint" misspelling
        print("finish save epoch{} checkpoint".format(epoch))
    else:
        raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root))

    print("all epoch is over ------ ")
    print("show epoch and epoch_loss in visdom")


if __name__ == "__main__":
    os.system('clear')
    main()
refine 26 | 27 | 28 | parser = argparse.ArgumentParser(description="PyTorch Derain") 29 | #root 30 | parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain_Dataset2018/train", type=str, 31 | help="path to load train datasets(default: none)") 32 | parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain_Dataset2018/test", type=str, 33 | help="path to load test datasets(default: none)") 34 | 35 | parser.add_argument("--save_image_root", default='./result', type=str, 36 | help="save test image root") 37 | parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 38 | help="path to save networks") 39 | parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 40 | help="path to pretrained net1 net2 net3 root") 41 | 42 | #hypeparameters 43 | parser.add_argument("--batchSize", type=int, default=8, help="training batch size") 44 | parser.add_argument("--nEpoch", type=int, default=500, help="number of epochs to train for") 45 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4") 46 | parser.add_argument("--lr1", type=float, default=5e-5, help="Learning Rate For pretrained net. 
Default=1e-5") 47 | parser.add_argument("--p", default=0.8, type=float, help="probability of normal conditions") 48 | 49 | parser.add_argument("--train_print_fre", type=int, default=200, help="frequency of print train loss on train phase") 50 | parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test") 51 | parser.add_argument("--test_print_fre", type=int, default=200, help="frequency of print train loss on test phase") 52 | parser.add_argument("--cuda",type=str, default="Ture", help="Use cuda?") 53 | parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use") 54 | parser.add_argument("--startweights", default= 32, type=int, help="start number of net's weight , 0 is None") 55 | parser.add_argument("--initmethod", default='xavier', type=str, help="xavier , kaiming , normal ,orthogonal ,default : xavier") 56 | parser.add_argument("--startepoch", default=38, type=int, help="Manual epoch number (useful on restarts)") 57 | parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1") 58 | parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9") 59 | parser.add_argument("--report", default=False, type=bool, help="report to wechat") 60 | parser.add_argument("--save_image", default=False, type=bool, help="save test image") 61 | parser.add_argument("--pretrain_epoch", default=[93,169,123], type=list, help="pretrained epoch for Net1 Net2 Net3") 62 | 63 | 64 | def main(): 65 | global opt, Net1 , Net2 , Net3 , Net4 , RefineNet , criterion_mse , criterion_ssim_map,criterion_ssim,criterion_ace 66 | opt = parser.parse_args() 67 | print(opt) 68 | 69 | 70 | cuda = opt.cuda 71 | if cuda and not torch.cuda.is_available(): 72 | raise Exception("No GPU found, please run without --cuda") 73 | 74 | seed = 1334 75 | torch.manual_seed(seed) 76 | if cuda: 77 | torch.cuda.manual_seed(seed) 78 | 79 | cudnn.benchmark = True 80 | 81 | print("==========> Loading 
datasets") 82 | 83 | train_dataset = derain_train_datasets( data_root= opt.train, transform=Compose([ 84 | ToTensor() 85 | ])) 86 | 87 | test_dataset = derain_test_datasets(opt.test, transform=Compose([ 88 | ToTensor() 89 | ])) 90 | 91 | training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize, 92 | pin_memory=True, shuffle=True) 93 | testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True, 94 | shuffle=True) 95 | 96 | if opt.initmethod == 'orthogonal': 97 | init_function = weights_init_orthogonal 98 | 99 | elif opt.initmethod == 'kaiming': 100 | init_function = weights_init_kaiming 101 | 102 | elif opt.initmethod == 'normal': 103 | init_function = weights_init_normal 104 | 105 | else: 106 | init_function = weights_init_xavier 107 | 108 | Net1 = net1() 109 | Net1.apply(init_function) 110 | Net2 = net2() 111 | Net2.apply(init_function) 112 | Net3 = net3() 113 | Net3.apply(init_function) 114 | Net4 = net4() 115 | Net4.apply(init_function) 116 | RefineNet = refine() 117 | RefineNet.apply(init_function) 118 | 119 | 120 | 121 | criterion_mse = nn.MSELoss(size_average=True) 122 | criterion_ssim_map = SSIM_MAP() 123 | criterion_ssim = SSIM() 124 | criterion_ace = nn.SmoothL1Loss() 125 | 126 | print("==========> Setting GPU") 127 | #if cuda: 128 | if opt.cuda: 129 | Net1 = nn.DataParallel(Net1, device_ids=[i for i in range(opt.gpus)]).cuda() 130 | Net2 = nn.DataParallel(Net2, device_ids=[i for i in range(opt.gpus)]).cuda() 131 | Net3 = nn.DataParallel(Net3, device_ids=[i for i in range(opt.gpus)]).cuda() 132 | Net4 = nn.DataParallel(Net4, device_ids=[i for i in range(opt.gpus)]).cuda() 133 | RefineNet = nn.DataParallel(RefineNet, device_ids=[i for i in range(opt.gpus)]).cuda() 134 | 135 | criterion_ssim = criterion_ssim.cuda() 136 | criterion_ssim_map = criterion_ssim_map.cuda() 137 | criterion_mse= criterion_mse.cuda() 138 | criterion_ace = criterion_ace.cuda() 139 | else: 
140 | raise Exception("it takes a long time without cuda ") 141 | #print(net) 142 | 143 | if opt.pretrain_root: 144 | if os.path.exists(opt.pretrain_root): 145 | print("=> loading net from '{}'".format(opt.pretrain_root)) 146 | weights = torch.load(opt.pretrain_root +"/w/%s.pth"%opt.pretrain_epoch[0]) 147 | Net1.load_state_dict(weights['state_dict'] ) 148 | 149 | weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1]) 150 | Net2.load_state_dict(weights['state_dict'] ) 151 | 152 | weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2]) 153 | Net3.load_state_dict(weights['state_dict']) 154 | 155 | del weights 156 | else: 157 | print("=> no net found at '{}'".format(opt.pretrain_root)) 158 | 159 | # weights start from early 160 | if opt.startweights: 161 | if os.path.exists(opt.save_root): 162 | print("=> loading checkpoint '{}'".format(opt.save_root)) 163 | weights = torch.load(opt.save_root + '/Net1/%s.pth'%opt.startweights) 164 | Net1.load_state_dict(weights["state_dict"] ) 165 | 166 | weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights) 167 | Net2.load_state_dict(weights["state_dict"]) 168 | 169 | weights = torch.load(opt.save_root + '/Net3/%s.pth' % opt.startweights) 170 | Net3.load_state_dict(weights["state_dict"]) 171 | 172 | weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights) 173 | Net4.load_state_dict(weights["state_dict"]) 174 | 175 | weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights) 176 | RefineNet.load_state_dict(weights["state_dict"]) 177 | 178 | del weights 179 | else: 180 | raise Exception("'{}' is not a file , Check out it again".format(opt.save_root)) 181 | 182 | 183 | 184 | print("==========> Setting Optimizer") 185 | optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1) 186 | optimizer2 = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 187 | optimizer3 = optim.Adam(filter(lambda 
p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 188 | optimizer4 = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 189 | optimizer_Refine = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 190 | 191 | optimizer = [ 1 , optimizer1 , optimizer2 , optimizer3 , optimizer4 , optimizer_Refine ] 192 | print("==========> Training") 193 | for epoch in range(opt.startepoch, opt.nEpoch + 1): 194 | 195 | if epoch > 10 : 196 | opt.lr = 1e-4 197 | optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, Net1.parameters()), lr=opt.lr1) 198 | optimizer[2] = optim.Adam(filter(lambda p: p.requires_grad, Net2.parameters()), lr=opt.lr1) 199 | optimizer[3] = optim.Adam(filter(lambda p: p.requires_grad, Net3.parameters()), lr=opt.lr1) 200 | optimizer[4] = optim.Adam(filter(lambda p: p.requires_grad, Net4.parameters()), lr=opt.lr) 201 | optimizer[5] = optim.Adam(filter(lambda p: p.requires_grad, RefineNet.parameters()), lr=opt.lr) 202 | 203 | # train(training_data_loader, optimizer, epoch) 204 | 205 | if epoch % opt.test_frequency == 0 : 206 | test(testing_data_loader ,epoch) 207 | 208 | 209 | 210 | def train(training_data_loader, optimizer, epoch): 211 | print("training ==========> epoch =", epoch, "lr =", opt.lr) 212 | Net1.train() 213 | Net2.train() 214 | Net3.train() 215 | Net4.train() 216 | RefineNet.train() 217 | t_loss = [] # save trainloss 218 | 219 | for step, (data, label) in enumerate(training_data_loader, 1): 220 | if opt.cuda and torch.cuda.is_available(): 221 | data = data.clone().detach().requires_grad_(True).cuda() 222 | label = label.cuda() 223 | else: 224 | raise Exception("it takes a long time without cuda ") 225 | data = data.cpu() 226 | label = label.cpu() 227 | 228 | Net1_out = Net1(data) 229 | Net2_out = Net2(Net1_out) 230 | Net3_out = Net3(Net2_out) 231 | Net4_out = Net4( data - Net1_out ,data - Net2_out ,data - Net3_out ) 232 | RefineNet_out =RefineNet( Net1_out , Net2_out , Net3_out , data - 
Net4_out ) 233 | 234 | init_map = torch.ones(size=Net1_out.size()).cuda() 235 | ssim_map1 = torch.mul(criterion_ssim_map(Net1_out , label) , init_map ) 236 | ssim_map2 = torch.mul(criterion_ssim_map(Net2_out , label) , ssim_map1 ) 237 | ssim_map3 = torch.mul(criterion_ssim_map(Net3_out , label) , ssim_map2 ) 238 | 239 | loss1 = torch.mul((1 - ssim_map1) , torch.abs(Net1_out - label)).mean() 240 | loss2 = torch.mul((1 - ssim_map2) , torch.abs(Net2_out - label)).mean() 241 | loss3 = torch.mul((1 - ssim_map3) , torch.abs(Net3_out - label)).mean() 242 | 243 | new_loss = torch.mul((1-criterion_ssim_map(RefineNet_out , label)) ,torch.abs(RefineNet_out-label)).mean().cuda() 244 | ssim_loss = 1- criterion_ssim(RefineNet_out , label) 245 | 246 | loss = new_loss + 0.01 * (loss1 + loss2 +loss3) 247 | 248 | 249 | del Net1_out , Net2_out , Net3_out , Net4_out 250 | Net1.zero_grad() 251 | Net2.zero_grad() 252 | Net3.zero_grad() 253 | Net4.zero_grad() 254 | RefineNet.zero_grad() 255 | 256 | 257 | optimizer[1].zero_grad() 258 | optimizer[2].zero_grad() 259 | optimizer[3].zero_grad() 260 | optimizer[4].zero_grad() 261 | optimizer[5].zero_grad() 262 | 263 | loss.backward() 264 | optimizer[1].step() 265 | optimizer[2].step() 266 | optimizer[3].step() 267 | optimizer[4].step() 268 | optimizer[5].step() 269 | 270 | 271 | if step % opt.train_print_fre == 0: 272 | print("epoch{} step {} loss {:6f} new_loss {:6f} ssimloss {:6f} loss1 {:6f} loss2 {:6f} loss3 {:6f}".format(epoch, step, 273 | loss.item(), 274 | new_loss.item(), 275 | ssim_loss.item(), 276 | loss1.item(), 277 | loss2.item(), 278 | loss3.item())) 279 | t_loss.append(loss.item()) 280 | del loss1, loss2, loss3 , loss 281 | 282 | else: 283 | # displaying to train loss 284 | updata_epoch_loss_display( train_loss= t_loss , v_epoch= epoch , envr= "derain train") 285 | 286 | import time 287 | def test(test_data_loader, epoch): 288 | print("------> testing") 289 | Net1.eval() 290 | Net2.eval() 291 | Net3.eval() 292 | Net4.eval() 293 | 
RefineNet.eval() 294 | torch.cuda.empty_cache() 295 | starttime = 0 296 | endtime = 0 297 | with torch.no_grad(): 298 | 299 | test_Psnr_sum = 0.0 300 | test_Ssim_sum = 0.0 301 | 302 | # showing list 303 | test_Psnr_loss = [] 304 | test_Ssim_loss = [] 305 | dict_psnr_ssim = {} 306 | starttime = time.time() 307 | for test_step, (data, label, data_path) in enumerate(test_data_loader, 1): 308 | data = data.cuda() 309 | label = label.cuda() 310 | 311 | Net1_out = Net1(data).cuda() 312 | Net2_out = Net2(Net1_out).cuda() 313 | Net3_out = Net3(Net2_out).cuda() 314 | Net4_out = Net4(data - Net1_out , data - Net2_out , data - Net3_out).cuda() #best rain streaks 315 | refineNet_out = RefineNet(Net1_out , Net2_out , Net3_out , data - Net4_out ).cuda() 316 | 317 | del Net1_out, Net2_out, Net3_out 318 | 319 | loss = criterion_mse(refineNet_out, label) 320 | Psnr, Ssim = get_psnr_ssim(refineNet_out, label) 321 | 322 | Psnr = round(Psnr.item(), 4) 323 | Ssim = round(Ssim.item(), 4) 324 | 325 | # del derain , label 326 | test_Psnr_sum += Psnr 327 | test_Ssim_sum += Ssim 328 | 329 | #if opt.save_image == True: 330 | # dict_psnr_ssim["Psnr%s_Ssim%s" % (Psnr, Ssim)] = data_path 331 | # out = refineNet_out.cpu().data[0] 332 | # out = ToPILImage()(out) 333 | # image_number = re.findall(r'\d+', data_path[0])[1] 334 | # out.save( opt.save_image_root + "/%s_p:%s_s:%s.jpg" % (image_number, Psnr, Ssim)) 335 | if test_step % opt.test_print_fre == 0: 336 | print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, loss.item())) 337 | test_Psnr_loss.append(test_Psnr_sum / test_step) 338 | test_Ssim_loss.append(test_Ssim_sum / test_step) 339 | 340 | else: 341 | del loss 342 | print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step, 343 | test_Ssim_sum / test_step)) 344 | write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step) 345 | # visdom showing 346 | print("---->testing over show in visdom") 347 | 
display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch, 348 | env="derain_test") 349 | endtime = time.time() 350 | 351 | print("----------TestTime:{}".format(endtime - starttime)) 352 | print("epoch {} train over-----> save net".format(epoch)) 353 | print("saving checkpoint save_root{}".format(opt.save_root)) 354 | if os.path.exists(opt.save_root): 355 | save_checkpoint(root=opt.save_root, model=Net1, epoch=epoch, model_stage="Net1") 356 | save_checkpoint(root=opt.save_root, model=Net2, epoch=epoch, model_stage="Net2") 357 | save_checkpoint(root=opt.save_root, model=Net3, epoch=epoch, model_stage="Net3") 358 | save_checkpoint(root=opt.save_root, model=Net4, epoch=epoch, model_stage="Net4") 359 | save_checkpoint(root=opt.save_root, model=RefineNet, epoch=epoch, model_stage="refine") 360 | 361 | print("finish save epoch{} checkporint".format({epoch})) 362 | else: 363 | raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root)) 364 | # 365 | 366 | print("all epoch is over ------ ") 367 | print("show epoch and epoch_loss in visdom") 368 | 369 | if __name__ == "__main__": 370 | os.system('clear') 371 | main() 372 | -------------------------------------------------------------------------------- /MyDataset/Datasets.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import torch 4 | import numpy as np 5 | from Utils.torch_ssim import SSIM 6 | import cv2 7 | import torch.utils.data as Data 8 | 9 | from os import listdir 10 | from os.path import join 11 | from PIL import Image 12 | from os.path import basename 13 | from torchvision import transforms as TF 14 | from Utils.utils import get_mean_and_std 15 | 16 | 17 | 18 | 19 | def is_image_file(filename): 20 | filename_lower = filename.lower() 21 | return any(filename_lower.endswith(extension) for extension in ['.png', '.jpg', '.bmp', '.mat']) 22 | 23 | 24 | class derain_train_datasets(Data.Dataset): 25 | 
'''return rain_img ,clear , classfy_label''' 26 | def __init__(self, data_root , transform = None): 27 | super(derain_train_datasets, self).__init__() 28 | 29 | self.data_filenames = [join(data_root, x) for x in listdir(data_root) if is_image_file(x) and '._' not in x] 30 | if transform : 31 | self.transform = transform 32 | 33 | def __getitem__(self, index): 34 | data_path = self.data_filenames[index] 35 | data = Image.open(data_path) 36 | data = self.transform(data) 37 | 38 | label = data[:, :, 512:1024] 39 | data = data[:, :, :512] 40 | 41 | return data, label 42 | 43 | def __len__(self): 44 | return len(self.data_filenames) 45 | 46 | 47 | class derain_test_datasets(Data.Dataset): 48 | '''return rain_img . classfy_label''' 49 | 50 | def __init__(self, data_root , transform = None): 51 | super(derain_test_datasets, self).__init__() 52 | self.data_filenames = [join(data_root, x) for x in listdir(data_root) if is_image_file(x) and '._' not in x] 53 | 54 | if transform: 55 | self.transform = transform 56 | 57 | def __getitem__(self, index): 58 | data_path = self.data_filenames[index] 59 | data = Image.open(data_path) 60 | data = self.transform(data) 61 | 62 | label = data[:, :, 512:1024] 63 | data = data[:, :, :512] 64 | 65 | return data, label ,data_path 66 | 67 | def __len__(self): 68 | return len(self.data_filenames) 69 | 70 | 71 | 72 | 73 | class derain_test_datasets_17(Data.Dataset): 74 | '''return rain_img . 
classfy_label''' 75 | 76 | def __init__(self, data_root , transform = None): 77 | super(derain_test_datasets_17, self).__init__() 78 | self.root = data_root 79 | rain_root = self.root + '/rain/' 80 | self.data_filenames = [join(rain_root, x) for x in listdir(rain_root) if is_image_file(x) and '._' not in x] 81 | self.transform = transform 82 | 83 | def __getitem__(self, index): 84 | data_path = self.data_filenames[index] 85 | number = data_path.split('-')[1].split('.')[0] 86 | label_path = self.root + '/label/norain-' + number + '.png' 87 | 88 | label = Image.open(label_path) 89 | data = Image.open(data_path) 90 | if self.transform: 91 | data = self.transform(data) 92 | label = self.transform(label) 93 | 94 | 95 | return data, label 96 | 97 | def __len__(self): 98 | return len(self.data_filenames) 99 | 100 | import cv2 101 | class derain_train_datasets_17(Data.Dataset): 102 | '''return rain_img . classfy_label''' 103 | 104 | def __init__(self, data_root , transform = None): 105 | super(derain_train_datasets_17, self).__init__() 106 | self.root = data_root 107 | rain_root = self.root + '/rain/' 108 | self.data_filenames = [join(rain_root, x) for x in listdir(rain_root) if is_image_file(x) and '._' not in x] 109 | self.transform = transform 110 | 111 | def __getitem__(self, index): 112 | data_path = self.data_filenames[index] 113 | number = data_path.split('-')[1].split('.')[0] 114 | label_path = self.root + '/label/norain-' + number + '.png' 115 | ########CV2######## 116 | label = cv2.imread(label_path)[:,:,::-1] 117 | data = cv2.imread(data_path)[:,:,::-1] 118 | if data.shape != (481,321,3): 119 | label = cv2.transpose(label) 120 | data = cv2.transpose(data) 121 | 122 | 123 | label = Image.fromarray(label) 124 | data = Image.fromarray(data) 125 | if self.transform: 126 | data = self.transform(data) 127 | label = self.transform(label) 128 | 129 | 130 | return data, label 131 | 132 | def __len__(self): 133 | return len(self.data_filenames) 134 | 135 | class 
derain_train_datasets_IC(Data.Dataset): 136 | '''return rain_img ,clear , classfy_label''' 137 | def __init__(self, data_root , transform = None): 138 | super(derain_train_datasets_IC, self).__init__() 139 | 140 | self.data_filenames = [join(data_root, x) for x in listdir(data_root) if is_image_file(x) and '._' not in x] 141 | if transform : 142 | self.transform = transform 143 | 144 | def __getitem__(self, index): 145 | data_path = self.data_filenames[index] 146 | data = cv2.imread(data_path)[:, :, ::-1] # bgr to rgb 147 | #print(data.shape) 148 | #data = self.transform(data) 149 | h , w , c =data.shape 150 | w = int(w/2) 151 | label = data[:, :w, :] 152 | data = data[:, w:, :] 153 | data = Image.fromarray(data) 154 | label = Image.fromarray(label) 155 | 156 | return self.transform(data), self.transform(label) , data_path 157 | 158 | def __len__(self): 159 | return len(self.data_filenames) 160 | 161 | class derain_train_datasets_2020(Data.Dataset): 162 | '''return rain_img ,clear , classfy_label''' 163 | def __init__(self, p, data_root , transform = None ): 164 | super(derain_train_datasets_2020, self).__init__() 165 | 166 | self.data_filenames = [join(data_root, x) for x in listdir(data_root) if is_image_file(x) and '._' not in x] 167 | if transform : 168 | self.transform = transform 169 | self.p = p 170 | def __getitem__(self, index): 171 | data_path = self.data_filenames[index] 172 | data = Image.open(data_path) 173 | data = self.transform(data) 174 | 175 | if float(index / len(self)) < self.p : 176 | 177 | label = data[:, :, 512:1024] 178 | data = data[:, :, :512] 179 | else: 180 | label = data[:, :, 512:1024] 181 | data = label 182 | 183 | return data, label 184 | 185 | def __len__(self): 186 | return len(self.data_filenames) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EnsembleNet 2 | 3 | ## Installation 4 | 5 | **Framework** 
6 | 1. Python 3.7 7 | 2. Pytorch1.0 (with ubuntu 16) 8 | 3. Torchvision 9 | 10 | **Python Dependencies** 11 | 12 | 1. skimage 13 | 2. numpy 14 | 3. visdom : `pip install visdom` 15 | 16 | 17 | ## Train 18 | 19 | ``` 20 | python3.7 Derain.py 21 | ``` 22 | 23 | 24 | ## Test 25 | 26 | ``` 27 | python3.7 Derain.py 28 | ``` 29 | 30 | ## Citation: 31 | [Ensemble single image deraining network via progressive structural boosting constraints](https://www.sciencedirect.com/science/article/abs/pii/S0923596521002204) 32 | ``` 33 | @article{peng2021ensemble, 34 | title={Ensemble single image deraining network via progressive structural boosting constraints}, 35 | author={Peng, Long and Jiang, Aiwen and Wei, Haoran and Liu, Bo and Wang, Mingwen}, 36 | journal={Signal Processing: Image Communication}, 37 | pages={116460}, 38 | year={2021}, 39 | publisher={Elsevier} 40 | } 41 | 42 | IF it helps you , please quotes us , thanks a lot 43 | ``` 44 | 45 | 46 | -------------------------------------------------------------------------------- /Utils/Vidsom.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import visdom 4 | from skimage.measure import compare_psnr, compare_ssim 5 | 6 | 7 | def updata_epoch_loss_display( train_loss,v_epoch ,envr): 8 | '''updata train_loss . 
test_loss for every epoch 9 | X.dim and Y.dim in line must be >= 1 ''' 10 | epoch_env = visdom.Visdom(env= envr) 11 | train_loss = torch.tensor(train_loss) 12 | Epoch = torch.tensor([i for i in range((v_epoch-1)*len(train_loss)+ 1 ,v_epoch*len(train_loss)+1)])\ 13 | 14 | if train_loss.dim() == 0 : 15 | train_loss = torch.tensor([train_loss]) 16 | #line 17 | epoch_env.line(X=Epoch ,Y=train_loss , win="img_loss" ,update='append' ,name='img_loss') 18 | 19 | 20 | 21 | def test_train_loss_display(test_loss_list , train_loss_list , v_env , epoch ): 22 | '''show train_loss , test loss in just a epoch when testing 23 | Convert test_loss_list(python_list) to tensor 24 | test_train_loss_display([1,2,4,3],[3,4,3,2,3,1,31,31,2],1)''' 25 | 26 | test_loss_list = torch.tensor(test_loss_list) 27 | train_loss_list = torch.tensor(train_loss_list) 28 | 29 | vis = visdom.Visdom(env=v_env) 30 | train_x = torch.arange(len(train_loss_list)) 31 | test_x = torch.arange(len(test_loss_list)) 32 | 33 | 34 | 35 | vis.line(X=train_x,Y=train_loss_list ,win= "epoch%d_loss"%epoch ,update='append',\ 36 | name = 'train_loss') 37 | vis.line(X = test_x , Y = test_loss_list , win="epoch%d_loss"%epoch,update='append',\ 38 | name = 'test_loss') 39 | 40 | 41 | def img_loss_withclassfy_vis_continue(v_epoch , train_loss_list , test_loss_list): 42 | train_x =torch.arange(v_epoch*len(train_loss_list),(v_epoch+1)*len(train_loss_list)) 43 | test_x = torch.arange(v_epoch * len(test_loss_list), (v_epoch + 1) * len(test_loss_list)) 44 | 45 | test_loss_list = torch.tensor(test_loss_list) 46 | train_loss_list = torch.tensor(train_loss_list) 47 | 48 | loss_vis = visdom.Visdom(env = 'loss') 49 | 50 | loss_vis.line(X=train_x,Y=train_loss_list ,win= "img_loss" ,update='append',\ 51 | name = 'train_loss') 52 | loss_vis.line(X = test_x , Y = test_loss_list , win="classfy_loss",update='append',\ 53 | name = 'test_loss') 54 | 55 | 56 | 57 | 58 | def display_Psnr_Ssim( Psnr, Ssim ,v_epoch , env ): 59 | 60 | Psnr_list = 
torch.tensor([Psnr]) 61 | Ssim_list = torch.tensor([Ssim]) 62 | 63 | x = torch.tensor([v_epoch]) 64 | 65 | loss_vis = visdom.Visdom(env=env) 66 | 67 | loss_vis.line(X=x, Y=Psnr_list, win="Psnr", update='append', \ 68 | name='train_loss',opts={'title':"Derain_2019_test",'xlabel': 'epoch' , 'ylabel':'Psnr' }) 69 | loss_vis.line(X=x ,Y=Ssim_list, win="Ssim", update='append', \ 70 | name='test_loss',opts={'title':"Derain_2019_test",'xlabel': 'epoch' , 'ylabel':'Ssim' }) 71 | 72 | 73 | def through_threhold(tensor , threhold): 74 | if tensor > threhold : 75 | tensor = torch.tensor([1.0 - 1e-8] ,requires_grad=True ).cuda() 76 | return tensor 77 | 78 | 79 | def write_test_perform(file_path , psnr , ssim): 80 | """test psnr , ssim """ 81 | if not isinstance(psnr , str): 82 | psnr = str(psnr) 83 | 84 | if not isinstance(ssim , str): 85 | ssim = str(ssim) 86 | 87 | with open(file_path , "a") as f : 88 | f.write("\n psnr :" + psnr + ", ssim :" + ssim ) 89 | 90 | 91 | def get_psnr_ssim(input_img , compared_img) : 92 | '''input and compared should be numpy 93 | batch_size 512 512 3''' 94 | if not isinstance(input_img , np.ndarray): 95 | input_img = np.squeeze(input_img.cpu().detach().numpy().transpose(0,2,3,1)) 96 | 97 | if not isinstance(compared_img , np.ndarray): 98 | compared_img =np.squeeze(compared_img.cpu().detach().numpy().transpose(0,2,3,1)) 99 | 100 | Ssim = compare_ssim(input_img , compared_img ,data_range = 1 , multichannel=True) 101 | Psnr = compare_psnr(input_img , compared_img ,data_range = 1) 102 | return Psnr,Ssim 103 | -------------------------------------------------------------------------------- /Utils/model_init.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import * 3 | 4 | def weights_init_normal(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | init.uniform_(m.weight.data, 0.0, 0.02) 8 | elif classname.find('Linear') != -1: 9 | 
def weights_init_xavier(m):
    """Xavier-normal init for Conv/Linear; uniform(0.02, 1) weight + zero bias for BatchNorm2d."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.uniform_(m.weight.data, b=1.0, a=0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m):
    """Kaiming-normal (fan_in) init for Conv/Linear; uniform weight + zero bias for BatchNorm2d."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1:
        init.uniform_(m.weight.data, b=1.0, a=0.02)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """Orthogonal init for Conv/Linear; uniform weight + zero bias for BatchNorm2d.

    Fix: the original called the deprecated (since removed) non-in-place
    ``init.orthogonal`` / ``init.constant``; switched to ``orthogonal_`` /
    ``constant_`` for consistency with the other initializers in this file.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.uniform_(m.weight.data, b=1.0, a=0.02)
        init.constant_(m.bias.data, 0.0)

# -------------------------------------------------------------------------------
# /Utils/skle_ssim.py:
# -------------------------------------------------------------------------------
# compare_psnr/compare_ssim were removed in scikit-image 0.18; guarded so the
# module still imports on modern installs (callers then get a clear NameError).
try:
    from skimage.measure import compare_psnr, compare_ssim
except ImportError:
    compare_psnr = compare_ssim = None
import numpy as np


def get_psnr_ssim(input_img, compared_img):
    '''input and compared should be numpy
    batch_size 512 512 3'''
    if not type(input_img) == type(np.array([1])):
        input_img = np.squeeze(input_img.cpu().detach().numpy().transpose(0, 2, 3, 1))

    if not type(compared_img) == type(np.array([1])):
        compared_img
def get_psnr_ssim(input_img, compared_img):
    '''input and compared should be numpy
    batch_size 512 512 3'''
    if not type(input_img) == type(np.array([1])):
        input_img = np.squeeze(input_img.cpu().detach().numpy().transpose(0, 2, 3, 1))

    if not type(compared_img) == type(np.array([1])):
        compared_img = np.squeeze(compared_img.cpu().detach().numpy().transpose(0, 2, 3, 1))

    Ssim = compare_ssim(input_img, compared_img, data_range=1, multichannel=True)
    Psnr = compare_psnr(input_img, compared_img, data_range=1)
    return Psnr, Ssim

# -------------------------------------------------------------------------------
# /Utils/ssim_map.py:
# -------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    """1-D Gaussian kernel of length window_size, normalised to sum to 1."""
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """(channel, 1, k, k) Gaussian window for a grouped (per-channel) conv2d."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Per-pixel SSIM map between img1/img2; mean (shape (1,)) when size_average."""
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # standard SSIM stabilising constants (for data_range 1)
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        # Fix: torch.tensor([ssim_map.mean()]) detached the value from the
        # autograd graph; reshape keeps the graph and the (1,) shape.
        return ssim_map.mean().reshape(1)
    else:
        # return ssim_map.mean(1).mean(1).mean(1)
        return ssim_map


class SSIM_MAP(torch.nn.Module):
    """SSIM as an nn.Module; size_average=False (default) returns the full map."""

    def __init__(self, window_size=11, size_average=False):
        super(SSIM_MAP, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # reuse the cached window when channel count and dtype match
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM; builds a fresh window on every call."""
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


# if __name__ == "__main__":
#     ssim = SSIM(size_average=False).cuda()
#     for buffer in ssim.buffers():
#         print(buffer)

# -------------------------------------------------------------------------------
# /Utils/torch_ssim.py:
# -------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp


def gaussian(window_size, sigma):
    """1-D Gaussian kernel of length window_size, normalised to sum to 1."""
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()
def create_window(window_size, channel):
    """(channel, 1, k, k) Gaussian window for a grouped (per-channel) conv2d."""
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """SSIM; per-image means (shape (batch,)) or a (1,) CUDA mean when size_average."""
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    # standard SSIM stabilising constants (for data_range 1)
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        # Fix: torch.tensor([ssim_map.mean()]).cuda() detached the value from
        # the autograd graph; reshape keeps the graph and the (1,) CUDA shape.
        return ssim_map.mean().reshape(1).cuda()
    else:
        return ssim_map.mean(1).mean(1).mean(1)
        # return ssim_map


class SSIM(torch.nn.Module):
    """SSIM as an nn.Module; size_average=True (default) returns a (1,) mean."""

    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        # reuse the cached window when channel count and dtype match
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Functional SSIM; builds a fresh window on every call."""
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


# if __name__ == "__main__":
#     a = torch.rand(size=[2, 3, 512, 512])
#     b = torch.rand(size=[2, 3, 512, 512])
#     ssim = SSIM(size_average=False)
#     print(ssim(a, b).size())

# -------------------------------------------------------------------------------
# /Utils/utils.py:
# -------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import *

import numpy as np
import os
from os import listdir
from os.path import join
import torchvision.transforms as transforms
from PIL import Image


def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    """Load an RGB image as a float CHW tensor, optionally resized or downscaled.

    NOTE(review): Image.ANTIALIAS is deprecated in modern Pillow (use
    Image.LANCZOS); kept for compatibility with the project's pinned version.
    """
    img = Image.open(filename).convert('RGB')
    if size is not None:
        if keep_asp:
            # keep aspect ratio: derive height from the requested width
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), Image.ANTIALIAS)
        else:
            img = img.resize((size, size), Image.ANTIALIAS)

    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img
def tensor_save_bgrimage(tensor, filename, cuda=False):
    """Save a (3, H, W) tensor whose channels are ordered B,G,R by first
    reordering them to RGB and delegating to tensor_save_rgbimage."""
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    tensor_save_rgbimage(tensor, filename, cuda)


def output_psnr_mse(img_orig, img_out):
    """PSNR (dB) between two arrays, assuming a peak signal value of 1.0.

    Returns +inf when the images are identical (mse == 0).
    """
    mse = np.mean(np.square(img_orig - img_out))
    return 10 * np.log10(1.0 / mse)


def is_image_file(filename):
    """True if `filename` has an image extension handled by this project
    (case-insensitive; '.mat' included because rain maps are stored as .mat)."""
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in ['.png', '.jpg', '.bmp', '.mat'])


def load_all_image(path):
    """Return full paths of every image file directly inside `path`."""
    return [join(path, x) for x in listdir(path) if is_image_file(x)]


def save_checkpoint(root, model, epoch, model_stage):
    """Save `model`'s state dict (moved to CPU) to <root>/<model_stage>/<epoch>.pth.

    The saved dict has keys 'epoch' and 'state_dict'.
    """
    model_out_path = root + "/%s/%d.pth" % (model_stage, epoch)
    state_dict = model.state_dict()
    # Move every tensor to CPU so the checkpoint loads on any device.
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()

    # BUG FIX: the original created a hard-coded "checkpoints" directory but
    # then wrote to <root>/<model_stage>/, which raises FileNotFoundError when
    # that directory does not exist.  Create the actual destination instead.
    os.makedirs(os.path.dirname(model_out_path), exist_ok=True)

    torch.save({
        'epoch': epoch,
        'state_dict': state_dict}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))


class FeatureExtractor(nn.Module):
    """Truncate a torchvision-style CNN after `feature_layer` to use it as a
    fixed feature extractor (e.g. VGG features for a perceptual loss)."""

    def __init__(self, cnn, feature_layer=11):
        super(FeatureExtractor, self).__init__()
        self.features = nn.Sequential(*list(cnn.features.children())[:(feature_layer + 1)])

    def forward(self, x):
        return self.features(x)


def get_mean_and_std(dataset):
    """Per-channel mean/std of a dataset, averaged over per-sample statistics.

    NOTE(review): this averages per-batch means/stds, which equals the true
    dataset statistics only because batch_size is fixed at 1 here.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
def get_residue(tensor, r_dim=1):
    """Return the residue channel: per-pixel (max - min) across dimension
    `r_dim` (the RGB channel dim for a (B, 3, H, W) tensor).

    The residue channel is rain-free in rain-streak images, which is why
    derain models use it as guidance.
    """
    # torch.max/min with dim= return (values, indices); [0] selects values.
    max_channel = torch.max(tensor, dim=r_dim, keepdim=True)[0]
    min_channel = torch.min(tensor, dim=r_dim, keepdim=True)[0]
    return max_channel - min_channel


# NOTE(review): a byte-identical duplicate of output_psnr_mse appeared here;
# the duplicate definition was removed (the earlier one in this file stands).


def rgb2grad(img):
    """Convert a (B, 3, H, W) RGB tensor to (B, 1, H, W) grayscale using the
    ITU-R BT.601 luma weights.  Despite the name, this is RGB->gray, not a
    gradient; the name is kept for caller compatibility.
    """
    R = img[:, :1]
    G = img[:, 1:2]
    B = img[:, 2:3]
    return 0.299 * R + 0.587 * G + 0.114 * B
https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/checkpoints/31.31_0.90/Net4/883.pth -------------------------------------------------------------------------------- /checkpoints/31.31_0.90/refine/883.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peylnog/EnsembleNet/75fb52dc1f76d4000311a31024994f2f67e4303e/checkpoints/31.31_0.90/refine/883.pth -------------------------------------------------------------------------------- /image-criterions-python/FSIM.mat: -------------------------------------------------------------------------------- 1 | function [FSIM, FSIMc] = FeatureSIM(imageRef, imageDis) 2 | % ======================================================================== 3 | % FSIM Index with automatic downsampling, Version 1.0 4 | % Copyright(c) 2010 Lin ZHANG, Lei Zhang, Xuanqin Mou and David Zhang 5 | % All Rights Reserved. 6 | % 7 | % ---------------------------------------------------------------------- 8 | % Permission to use, copy, or modify this software and its documentation 9 | % for educational and research purposes only and without fee is here 10 | % granted, provided that this copyright notice and the original authors' 11 | % names appear on all copies and supporting documentation. This program 12 | % shall not be used, rewritten, or adapted as the basis of a commercial 13 | % software or hardware product without first obtaining permission of the 14 | % authors. The authors make no representations about the suitability of 15 | % this software for any purpose. It is provided "as is" without express 16 | % or implied warranty. 17 | %---------------------------------------------------------------------- 18 | % 19 | % This is an implementation of the algorithm for calculating the 20 | % Feature SIMilarity (FSIM) index between two images. 
21 | % 22 | % Please refer to the following paper 23 | % 24 | % Lin Zhang, Lei Zhang, Xuanqin Mou, and David Zhang,"FSIM: a feature similarity 25 | % index for image qualtiy assessment", IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, 2011. 26 | % 27 | %---------------------------------------------------------------------- 28 | % 29 | %Input : (1) imageRef: the first image being compared 30 | % (2) imageDis: the second image being compared 31 | % 32 | %Output: (1) FSIM: is the similarty score calculated using FSIM algorithm. FSIM 33 | % only considers the luminance component of images. For colorful images, 34 | % they will be converted to the grayscale at first. 35 | % (2) FSIMc: is the similarity score calculated using FSIMc algorithm. FSIMc 36 | % considers both the grayscale and the color information. 37 | %Note: For grayscale images, the returned FSIM and FSIMc are the same. 38 | % 39 | %----------------------------------------------------------------------- 40 | % 41 | %Usage: 42 | %Given 2 test images img1 and img2. For gray-scale images, their dynamic range should be 0-255. 43 | %For colorful images, the dynamic range of each color channel should be 0-255. 
44 | % 45 | %[FSIM, FSIMc] = FeatureSIM(img1, img2); 46 | %----------------------------------------------------------------------- 47 | 48 | [rows, cols] = size(imageRef(:,:,1)); 49 | I1 = ones(rows, cols); 50 | I2 = ones(rows, cols); 51 | Q1 = ones(rows, cols); 52 | Q2 = ones(rows, cols); 53 | 54 | if ndims(imageRef) == 3 %images are colorful 55 | Y1 = 0.299 * double(imageRef(:,:,1)) + 0.587 * double(imageRef(:,:,2)) + 0.114 * double(imageRef(:,:,3)); 56 | Y2 = 0.299 * double(imageDis(:,:,1)) + 0.587 * double(imageDis(:,:,2)) + 0.114 * double(imageDis(:,:,3)); 57 | I1 = 0.596 * double(imageRef(:,:,1)) - 0.274 * double(imageRef(:,:,2)) - 0.322 * double(imageRef(:,:,3)); 58 | I2 = 0.596 * double(imageDis(:,:,1)) - 0.274 * double(imageDis(:,:,2)) - 0.322 * double(imageDis(:,:,3)); 59 | Q1 = 0.211 * double(imageRef(:,:,1)) - 0.523 * double(imageRef(:,:,2)) + 0.312 * double(imageRef(:,:,3)); 60 | Q2 = 0.211 * double(imageDis(:,:,1)) - 0.523 * double(imageDis(:,:,2)) + 0.312 * double(imageDis(:,:,3)); 61 | else %images are grayscale 62 | Y1 = imageRef; 63 | Y2 = imageDis; 64 | end 65 | 66 | Y1 = double(Y1); 67 | Y2 = double(Y2); 68 | %%%%%%%%%%%%%%%%%%%%%%%%% 69 | % Downsample the image 70 | %%%%%%%%%%%%%%%%%%%%%%%%% 71 | minDimension = min(rows,cols); 72 | F = max(1,round(minDimension / 256)); 73 | aveKernel = fspecial('average',F); 74 | 75 | aveI1 = conv2(I1, aveKernel,'same'); 76 | aveI2 = conv2(I2, aveKernel,'same'); 77 | I1 = aveI1(1:F:rows,1:F:cols); 78 | I2 = aveI2(1:F:rows,1:F:cols); 79 | 80 | aveQ1 = conv2(Q1, aveKernel,'same'); 81 | aveQ2 = conv2(Q2, aveKernel,'same'); 82 | Q1 = aveQ1(1:F:rows,1:F:cols); 83 | Q2 = aveQ2(1:F:rows,1:F:cols); 84 | 85 | aveY1 = conv2(Y1, aveKernel,'same'); 86 | aveY2 = conv2(Y2, aveKernel,'same'); 87 | Y1 = aveY1(1:F:rows,1:F:cols); 88 | Y2 = aveY2(1:F:rows,1:F:cols); 89 | 90 | %%%%%%%%%%%%%%%%%%%%%%%%% 91 | % Calculate the phase congruency maps 92 | %%%%%%%%%%%%%%%%%%%%%%%%% 93 | PC1 = phasecong2(Y1); 94 | PC2 = phasecong2(Y2); 
95 | 96 | %%%%%%%%%%%%%%%%%%%%%%%%% 97 | % Calculate the gradient map 98 | %%%%%%%%%%%%%%%%%%%%%%%%% 99 | dx = [3 0 -3; 10 0 -10; 3 0 -3]/16; 100 | dy = [3 10 3; 0 0 0; -3 -10 -3]/16; 101 | IxY1 = conv2(Y1, dx, 'same'); 102 | IyY1 = conv2(Y1, dy, 'same'); 103 | gradientMap1 = sqrt(IxY1.^2 + IyY1.^2); 104 | 105 | IxY2 = conv2(Y2, dx, 'same'); 106 | IyY2 = conv2(Y2, dy, 'same'); 107 | gradientMap2 = sqrt(IxY2.^2 + IyY2.^2); 108 | 109 | %%%%%%%%%%%%%%%%%%%%%%%%% 110 | % Calculate the FSIM 111 | %%%%%%%%%%%%%%%%%%%%%%%%% 112 | T1 = 0.85; %fixed 113 | T2 = 160; %fixed 114 | PCSimMatrix = (2 * PC1 .* PC2 + T1) ./ (PC1.^2 + PC2.^2 + T1); 115 | gradientSimMatrix = (2*gradientMap1.*gradientMap2 + T2) ./(gradientMap1.^2 + gradientMap2.^2 + T2); 116 | PCm = max(PC1, PC2); 117 | SimMatrix = gradientSimMatrix .* PCSimMatrix .* PCm; 118 | FSIM = sum(sum(SimMatrix)) / sum(sum(PCm)); 119 | 120 | %%%%%%%%%%%%%%%%%%%%%%%%% 121 | % Calculate the FSIMc 122 | %%%%%%%%%%%%%%%%%%%%%%%%% 123 | T3 = 200; 124 | T4 = 200; 125 | ISimMatrix = (2 * I1 .* I2 + T3) ./ (I1.^2 + I2.^2 + T3); 126 | QSimMatrix = (2 * Q1 .* Q2 + T4) ./ (Q1.^2 + Q2.^2 + T4); 127 | 128 | lambda = 0.03; 129 | 130 | SimMatrixC = gradientSimMatrix .* PCSimMatrix .* real((ISimMatrix .* QSimMatrix) .^ lambda) .* PCm; 131 | FSIMc = sum(sum(SimMatrixC)) / sum(sum(PCm)); 132 | 133 | return; 134 | 135 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 136 | 137 | function [ResultPC]=phasecong2(im) 138 | % ======================================================================== 139 | % Copyright (c) 1996-2009 Peter Kovesi 140 | % School of Computer Science & Software Engineering 141 | % The University of Western Australia 142 | % http://www.csse.uwa.edu.au/ 143 | % 144 | % Permission is hereby granted, free of charge, to any person obtaining a copy 145 | % of this software and associated documentation files (the "Software"), to deal 146 | % in the Software without restriction, subject to the 
following conditions: 147 | % 148 | % The above copyright notice and this permission notice shall be included in all 149 | % copies or substantial portions of the Software. 150 | % 151 | % The software is provided "as is", without warranty of any kind. 152 | % References: 153 | % 154 | % Peter Kovesi, "Image Features From Phase Congruency". Videre: A 155 | % Journal of Computer Vision Research. MIT Press. Volume 1, Number 3, 156 | % Summer 1999 http://mitpress.mit.edu/e-journals/Videre/001/v13.html 157 | 158 | nscale = 4; % Number of wavelet scales. 159 | norient = 4; % Number of filter orientations. 160 | minWaveLength = 6; % Wavelength of smallest scale filter. 161 | mult = 2; % Scaling factor between successive filters. 162 | sigmaOnf = 0.55; % Ratio of the standard deviation of the 163 | % Gaussian describing the log Gabor filter's 164 | % transfer function in the frequency domain 165 | % to the filter center frequency. 166 | dThetaOnSigma = 1.2; % Ratio of angular interval between filter orientations 167 | % and the standard deviation of the angular Gaussian 168 | % function used to construct filters in the 169 | % freq. plane. 170 | k = 2.0; % No of standard deviations of the noise 171 | % energy beyond the mean at which we set the 172 | % noise threshold point. 173 | % below which phase congruency values get 174 | % penalized. 175 | epsilon = .0001; % Used to prevent division by zero. 176 | 177 | thetaSigma = pi/norient/dThetaOnSigma; % Calculate the standard deviation of the 178 | % angular Gaussian function used to 179 | % construct filters in the freq. plane. 180 | 181 | [rows,cols] = size(im); 182 | imagefft = fft2(im); % Fourier transform of image 183 | 184 | zero = zeros(rows,cols); 185 | EO = cell(nscale, norient); % Array of convolution results. 
186 | 187 | estMeanE2n = []; 188 | ifftFilterArray = cell(1,nscale); % Array of inverse FFTs of filters 189 | 190 | % Pre-compute some stuff to speed up filter construction 191 | 192 | % Set up X and Y matrices with ranges normalised to +/- 0.5 193 | % The following code adjusts things appropriately for odd and even values 194 | % of rows and columns. 195 | if mod(cols,2) 196 | xrange = [-(cols-1)/2:(cols-1)/2]/(cols-1); 197 | else 198 | xrange = [-cols/2:(cols/2-1)]/cols; 199 | end 200 | 201 | if mod(rows,2) 202 | yrange = [-(rows-1)/2:(rows-1)/2]/(rows-1); 203 | else 204 | yrange = [-rows/2:(rows/2-1)]/rows; 205 | end 206 | 207 | [x,y] = meshgrid(xrange, yrange); 208 | 209 | radius = sqrt(x.^2 + y.^2); % Matrix values contain *normalised* radius from centre. 210 | theta = atan2(-y,x); % Matrix values contain polar angle. 211 | % (note -ve y is used to give +ve 212 | % anti-clockwise angles) 213 | 214 | radius = ifftshift(radius); % Quadrant shift radius and theta so that filters 215 | theta = ifftshift(theta); % are constructed with 0 frequency at the corners. 216 | radius(1,1) = 1; % Get rid of the 0 radius value at the 0 217 | % frequency point (now at top-left corner) 218 | % so that taking the log of the radius will 219 | % not cause trouble. 220 | 221 | sintheta = sin(theta); 222 | costheta = cos(theta); 223 | clear x; clear y; clear theta; % save a little memory 224 | 225 | % Filters are constructed in terms of two components. 226 | % 1) The radial component, which controls the frequency band that the filter 227 | % responds to 228 | % 2) The angular component, which controls the orientation that the filter 229 | % responds to. 230 | % The two components are multiplied together to construct the overall filter. 231 | 232 | % Construct the radial filter components... 233 | 234 | % First construct a low-pass filter that is as large as possible, yet falls 235 | % away to zero at the boundaries. 
All log Gabor filters are multiplied by 236 | % this to ensure no extra frequencies at the 'corners' of the FFT are 237 | % incorporated as this seems to upset the normalisation process when 238 | % calculating phase congrunecy. 239 | lp = lowpassfilter([rows,cols],.45,15); % Radius .45, 'sharpness' 15 240 | 241 | logGabor = cell(1,nscale); 242 | 243 | for s = 1:nscale 244 | wavelength = minWaveLength*mult^(s-1); 245 | fo = 1.0/wavelength; % Centre frequency of filter. 246 | logGabor{s} = exp((-(log(radius/fo)).^2) / (2 * log(sigmaOnf)^2)); 247 | logGabor{s} = logGabor{s}.*lp; % Apply low-pass filter 248 | logGabor{s}(1,1) = 0; % Set the value at the 0 frequency point of the filter 249 | % back to zero (undo the radius fudge). 250 | end 251 | 252 | % Then construct the angular filter components... 253 | 254 | spread = cell(1,norient); 255 | 256 | for o = 1:norient 257 | angl = (o-1)*pi/norient; % Filter angle. 258 | 259 | % For each point in the filter matrix calculate the angular distance from 260 | % the specified filter orientation. To overcome the angular wrap-around 261 | % problem sine difference and cosine difference values are first computed 262 | % and then the atan2 function is used to determine angular distance. 263 | 264 | ds = sintheta * cos(angl) - costheta * sin(angl); % Difference in sine. 265 | dc = costheta * cos(angl) + sintheta * sin(angl); % Difference in cosine. 266 | dtheta = abs(atan2(ds,dc)); % Absolute angular distance. 267 | spread{o} = exp((-dtheta.^2) / (2 * thetaSigma^2)); % Calculate the 268 | % angular filter component. 269 | end 270 | 271 | % The main loop... 272 | EnergyAll(rows,cols) = 0; 273 | AnAll(rows,cols) = 0; 274 | 275 | for o = 1:norient % For each orientation. 276 | sumE_ThisOrient = zero; % Initialize accumulator matrices. 277 | sumO_ThisOrient = zero; 278 | sumAn_ThisOrient = zero; 279 | Energy = zero; 280 | for s = 1:nscale, % For each scale. 
281 | filter = logGabor{s} .* spread{o}; % Multiply radial and angular 282 | % components to get the filter. 283 | ifftFilt = real(ifft2(filter))*sqrt(rows*cols); % Note rescaling to match power 284 | ifftFilterArray{s} = ifftFilt; % record ifft2 of filter 285 | % Convolve image with even and odd filters returning the result in EO 286 | EO{s,o} = ifft2(imagefft .* filter); 287 | 288 | An = abs(EO{s,o}); % Amplitude of even & odd filter response. 289 | sumAn_ThisOrient = sumAn_ThisOrient + An; % Sum of amplitude responses. 290 | sumE_ThisOrient = sumE_ThisOrient + real(EO{s,o}); % Sum of even filter convolution results. 291 | sumO_ThisOrient = sumO_ThisOrient + imag(EO{s,o}); % Sum of odd filter convolution results. 292 | if s==1 % Record mean squared filter value at smallest 293 | EM_n = sum(sum(filter.^2)); % scale. This is used for noise estimation. 294 | maxAn = An; % Record the maximum An over all scales. 295 | else 296 | maxAn = max(maxAn, An); 297 | end 298 | end % ... and process the next scale 299 | 300 | % Get weighted mean filter response vector, this gives the weighted mean 301 | % phase angle. 302 | 303 | XEnergy = sqrt(sumE_ThisOrient.^2 + sumO_ThisOrient.^2) + epsilon; 304 | MeanE = sumE_ThisOrient ./ XEnergy; 305 | MeanO = sumO_ThisOrient ./ XEnergy; 306 | 307 | % Now calculate An(cos(phase_deviation) - | sin(phase_deviation)) | by 308 | % using dot and cross products between the weighted mean filter response 309 | % vector and the individual filter response vectors at each scale. This 310 | % quantity is phase congruency multiplied by An, which we call energy. 311 | 312 | for s = 1:nscale, 313 | E = real(EO{s,o}); O = imag(EO{s,o}); % Extract even and odd 314 | % convolution results. 315 | Energy = Energy + E.*MeanE + O.*MeanO - abs(E.*MeanO - O.*MeanE); 316 | end 317 | 318 | % Compensate for noise 319 | % We estimate the noise power from the energy squared response at the 320 | % smallest scale. 
If the noise is Gaussian the energy squared will have a 321 | % Chi-squared 2DOF pdf. We calculate the median energy squared response 322 | % as this is a robust statistic. From this we estimate the mean. 323 | % The estimate of noise power is obtained by dividing the mean squared 324 | % energy value by the mean squared filter value 325 | 326 | medianE2n = median(reshape(abs(EO{1,o}).^2,1,rows*cols)); 327 | meanE2n = -medianE2n/log(0.5); 328 | estMeanE2n(o) = meanE2n; 329 | 330 | noisePower = meanE2n/EM_n; % Estimate of noise power. 331 | 332 | % Now estimate the total energy^2 due to noise 333 | % Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj)) 334 | 335 | EstSumAn2 = zero; 336 | for s = 1:nscale 337 | EstSumAn2 = EstSumAn2 + ifftFilterArray{s}.^2; 338 | end 339 | 340 | EstSumAiAj = zero; 341 | for si = 1:(nscale-1) 342 | for sj = (si+1):nscale 343 | EstSumAiAj = EstSumAiAj + ifftFilterArray{si}.*ifftFilterArray{sj}; 344 | end 345 | end 346 | sumEstSumAn2 = sum(sum(EstSumAn2)); 347 | sumEstSumAiAj = sum(sum(EstSumAiAj)); 348 | 349 | EstNoiseEnergy2 = 2*noisePower*sumEstSumAn2 + 4*noisePower*sumEstSumAiAj; 350 | 351 | tau = sqrt(EstNoiseEnergy2/2); % Rayleigh parameter 352 | EstNoiseEnergy = tau*sqrt(pi/2); % Expected value of noise energy 353 | EstNoiseEnergySigma = sqrt( (2-pi/2)*tau^2 ); 354 | 355 | T = EstNoiseEnergy + k*EstNoiseEnergySigma; % Noise threshold 356 | 357 | % The estimated noise effect calculated above is only valid for the PC_1 measure. 358 | % The PC_2 measure does not lend itself readily to the same analysis. However 359 | % empirically it seems that the noise effect is overestimated roughly by a factor 360 | % of 1.7 for the filter parameters used here. 
361 | 362 | T = T/1.7; % Empirical rescaling of the estimated noise effect to 363 | % suit the PC_2 phase congruency measure 364 | Energy = max(Energy - T, zero); % Apply noise threshold 365 | 366 | EnergyAll = EnergyAll + Energy; 367 | AnAll = AnAll + sumAn_ThisOrient; 368 | end % For each orientation 369 | ResultPC = EnergyAll ./ AnAll; 370 | return; 371 | 372 | 373 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 374 | % LOWPASSFILTER - Constructs a low-pass butterworth filter. 375 | % 376 | % usage: f = lowpassfilter(sze, cutoff, n) 377 | % 378 | % where: sze is a two element vector specifying the size of filter 379 | % to construct [rows cols]. 380 | % cutoff is the cutoff frequency of the filter 0 - 0.5 381 | % n is the order of the filter, the higher n is the sharper 382 | % the transition is. (n must be an integer >= 1). 383 | % Note that n is doubled so that it is always an even integer. 384 | % 385 | % 1 386 | % f = -------------------- 387 | % 2n 388 | % 1.0 + (w/cutoff) 389 | % 390 | % The frequency origin of the returned filter is at the corners. 391 | % 392 | % See also: HIGHPASSFILTER, HIGHBOOSTFILTER, BANDPASSFILTER 393 | % 394 | 395 | % Copyright (c) 1999 Peter Kovesi 396 | % School of Computer Science & Software Engineering 397 | % The University of Western Australia 398 | % http://www.csse.uwa.edu.au/ 399 | % 400 | % Permission is hereby granted, free of charge, to any person obtaining a copy 401 | % of this software and associated documentation files (the "Software"), to deal 402 | % in the Software without restriction, subject to the following conditions: 403 | % 404 | % The above copyright notice and this permission notice shall be included in 405 | % all copies or substantial portions of the Software. 406 | % 407 | % The Software is provided "as is", without warranty of any kind. 
408 | 409 | % October 1999 410 | % August 2005 - Fixed up frequency ranges for odd and even sized filters 411 | % (previous code was a bit approximate) 412 | 413 | function f = lowpassfilter(sze, cutoff, n) 414 | 415 | if cutoff < 0 || cutoff > 0.5 416 | error('cutoff frequency must be between 0 and 0.5'); 417 | end 418 | 419 | if rem(n,1) ~= 0 || n < 1 420 | error('n must be an integer >= 1'); 421 | end 422 | 423 | if length(sze) == 1 424 | rows = sze; cols = sze; 425 | else 426 | rows = sze(1); cols = sze(2); 427 | end 428 | 429 | % Set up X and Y matrices with ranges normalised to +/- 0.5 430 | % The following code adjusts things appropriately for odd and even values 431 | % of rows and columns. 432 | if mod(cols,2) 433 | xrange = [-(cols-1)/2:(cols-1)/2]/(cols-1); 434 | else 435 | xrange = [-cols/2:(cols/2-1)]/cols; 436 | end 437 | 438 | if mod(rows,2) 439 | yrange = [-(rows-1)/2:(rows-1)/2]/(rows-1); 440 | else 441 | yrange = [-rows/2:(rows/2-1)]/rows; 442 | end 443 | 444 | [x,y] = meshgrid(xrange, yrange); 445 | radius = sqrt(x.^2 + y.^2); % A matrix with every pixel = radius relative to centre. 
446 | f = ifftshift( 1 ./ (1.0 + (radius ./ cutoff).^(2*n)) ); % The filter 447 | return; 448 | 449 | 450 | -------------------------------------------------------------------------------- /image-criterions-python/NIQE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import scipy.io 4 | from os.path import dirname 5 | from os.path import join 6 | import scipy 7 | from PIL import Image 8 | import numpy as np 9 | import scipy.ndimage 10 | import numpy as np 11 | import scipy.special 12 | import math 13 | 14 | gamma_range = np.arange(0.2, 10, 0.001) 15 | a = scipy.special.gamma(2.0 / gamma_range) 16 | a *= a 17 | b = scipy.special.gamma(1.0 / gamma_range) 18 | c = scipy.special.gamma(3.0 / gamma_range) 19 | prec_gammas = a / (b * c) 20 | 21 | 22 | def aggd_features(imdata): 23 | # flatten imdata 24 | imdata.shape = (len(imdata.flat),) 25 | imdata2 = imdata * imdata 26 | left_data = imdata2[imdata < 0] 27 | right_data = imdata2[imdata >= 0] 28 | left_mean_sqrt = 0 29 | right_mean_sqrt = 0 30 | if len(left_data) > 0: 31 | left_mean_sqrt = np.sqrt(np.average(left_data)) 32 | if len(right_data) > 0: 33 | right_mean_sqrt = np.sqrt(np.average(right_data)) 34 | 35 | if right_mean_sqrt != 0: 36 | gamma_hat = left_mean_sqrt / right_mean_sqrt 37 | else: 38 | gamma_hat = np.inf 39 | # solve r-hat norm 40 | 41 | imdata2_mean = np.mean(imdata2) 42 | if imdata2_mean != 0: 43 | r_hat = (np.average(np.abs(imdata)) ** 2) / (np.average(imdata2)) 44 | else: 45 | r_hat = np.inf 46 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) 47 | 48 | # solve alpha by guessing values that minimize ro 49 | pos = np.argmin((prec_gammas - rhat_norm) ** 2); 50 | alpha = gamma_range[pos] 51 | 52 | gam1 = scipy.special.gamma(1.0 / alpha) 53 | gam2 = scipy.special.gamma(2.0 / alpha) 54 | gam3 = scipy.special.gamma(3.0 / alpha) 55 | 56 | aggdratio = 
np.sqrt(gam1) / np.sqrt(gam3) 57 | bl = aggdratio * left_mean_sqrt 58 | br = aggdratio * right_mean_sqrt 59 | 60 | # mean parameter 61 | N = (br - bl) * (gam2 / gam1) # *aggdratio 62 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) 63 | 64 | 65 | def ggd_features(imdata): 66 | nr_gam = 1 / prec_gammas 67 | sigma_sq = np.var(imdata) 68 | E = np.mean(np.abs(imdata)) 69 | rho = sigma_sq / E ** 2 70 | pos = np.argmin(np.abs(nr_gam - rho)); 71 | return gamma_range[pos], sigma_sq 72 | 73 | 74 | def paired_product(new_im): 75 | shift1 = np.roll(new_im.copy(), 1, axis=1) 76 | shift2 = np.roll(new_im.copy(), 1, axis=0) 77 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) 78 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) 79 | 80 | H_img = shift1 * new_im 81 | V_img = shift2 * new_im 82 | D1_img = shift3 * new_im 83 | D2_img = shift4 * new_im 84 | 85 | return (H_img, V_img, D1_img, D2_img) 86 | 87 | 88 | def gen_gauss_window(lw, sigma): 89 | sd = np.float32(sigma) 90 | lw = int(lw) 91 | weights = [0.0] * (2 * lw + 1) 92 | weights[lw] = 1.0 93 | sum = 1.0 94 | sd *= sd 95 | for ii in range(1, lw + 1): 96 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd) 97 | weights[lw + ii] = tmp 98 | weights[lw - ii] = tmp 99 | sum += 2.0 * tmp 100 | for ii in range(2 * lw + 1): 101 | weights[ii] /= sum 102 | return weights 103 | 104 | 105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'): 106 | if avg_window is None: 107 | avg_window = gen_gauss_window(3, 7.0 / 6.0) 108 | assert len(np.shape(image)) == 2 109 | h, w = np.shape(image) 110 | mu_image = np.zeros((h, w), dtype=np.float32) 111 | var_image = np.zeros((h, w), dtype=np.float32) 112 | image = np.array(image).astype('float32') 113 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode) 114 | scipy.ndimage.correlate1d(mu_image, avg_window, 1, mu_image, mode=extend_mode) 115 | scipy.ndimage.correlate1d(image ** 2, avg_window, 0, 
def _niqe_extract_subband_feats(mscncoefs):
    """18-dim NIQE feature vector for one subband: AGGD fit of the MSCN
    coefficients plus AGGD fits of the four paired-product maps."""
    alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
    pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
    alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
    alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
    alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
    alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
    # BUG FIX: the original repeated bl3/bl4 where the right-scale parameters
    # br3/br4 belong, so the D1/D2 subband features were wrong relative to
    # the reference NIQE feature layout (alpha, N, b_left, b_right).
    return np.array([alpha_m, (bl + br) / 2.0,
                     alpha1, N1, bl1, br1,  # (V)
                     alpha2, N2, bl2, br2,  # (H)
                     alpha3, N3, bl3, br3,  # (D1)
                     alpha4, N4, bl4, br4,  # (D2)
                     ])


def get_patches_train_features(img, patch_size, stride=8):
    """Patch features in training mode (is_train=1)."""
    return _get_patches_generic(img, patch_size, 1, stride)


def get_patches_test_features(img, patch_size, stride=8):
    """Patch features in test mode (is_train=0)."""
    return _get_patches_generic(img, patch_size, 0, stride)


def extract_on_patches(img, patch_size):
    """Split `img` into non-overlapping patch_size x patch_size tiles and
    compute the NIQE subband features for each; returns (n_patches, 18)."""
    h, w = img.shape
    # BUG FIX: np.int was removed in NumPy 1.24; builtin int is equivalent
    # (needed because patch_size arrives as a float from the /2 rescale).
    patch_size = int(patch_size)
    patches = []
    for j in range(0, h - patch_size + 1, patch_size):
        for i in range(0, w - patch_size + 1, patch_size):
            patches.append(img[j:j + patch_size, i:i + patch_size])

    patch_features = [_niqe_extract_subband_feats(p) for p in patches]
    return np.array(patch_features)
def _get_patches_generic(img, patch_size, is_train, stride):
    """Crop `img` so patch_size divides it evenly, then extract NIQE subband
    features at full scale and half scale; returns an (n_patches, 36) array.

    `is_train` and `stride` are kept for interface compatibility with the
    train/test wrappers; the non-overlapping tiling ignores `stride`.
    """
    h, w = np.shape(img)
    if h < patch_size or w < patch_size:
        print("Input image is too small")
        exit(0)

    # crop the bottom/right remainder so the patch grid tiles the image exactly
    hoffset = (h % patch_size)
    woffset = (w % patch_size)
    if hoffset > 0:
        img = img[:-hoffset, :]
    if woffset > 0:
        img = img[:, :-woffset]

    img = img.astype(np.float32)
    # BUG FIX: scipy.misc.imresize was removed in SciPy >= 1.3; use a cubic
    # spline zoom instead (interpolation differs marginally from the legacy
    # PIL-based bicubic resize, but preserves the half-scale semantics).
    img2 = scipy.ndimage.zoom(img, 0.5, order=3).astype(np.float32)

    mscn1, var, mu = compute_image_mscn_transform(img)
    mscn1 = mscn1.astype(np.float32)

    mscn2, _, _ = compute_image_mscn_transform(img2)
    mscn2 = mscn2.astype(np.float32)

    feats_lvl1 = extract_on_patches(mscn1, patch_size)
    feats_lvl2 = extract_on_patches(mscn2, patch_size / 2)

    # concatenate the two scales per patch
    return np.hstack((feats_lvl1, feats_lvl2))
Please supply only the luminance channel" % (C,) 209 | assert M > ( 210 | patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 211 | assert N > ( 212 | patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 213 | 214 | feats = get_patches_test_features(inputImgData, patch_size) 215 | sample_mu = np.mean(feats, axis=0) 216 | sample_cov = np.cov(feats.T) 217 | 218 | X = sample_mu - pop_mu 219 | covmat = ((pop_cov + sample_cov) / 2.0) 220 | pinvmat = scipy.linalg.pinv(covmat) 221 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) 222 | 223 | return niqe_score 224 | 225 | 226 | if __name__ == "__main__": 227 | ref = np.array(Image.open('./test_imgs/bikes.bmp').convert('LA'))[:, :, 0] # ref 228 | dis = np.array(Image.open('./test_imgs/bikes_distorted.bmp').convert('LA'))[:, :, 0] # dis 229 | 230 | print('NIQE of ref bikes image is: %0.3f' % niqe(ref)) 231 | print('NIQE of dis bikes image is: %0.3f' % niqe(dis)) 232 | 233 | ref = np.array(Image.open('./test_imgs/parrots.bmp').convert('LA'))[:, :, 0] # ref 234 | dis = np.array(Image.open('./test_imgs/parrots_distorted.bmp').convert('LA'))[:, :, 0] # dis 235 | 236 | print('NIQE of ref parrot image is: %0.3f' % niqe(ref)) 237 | print('NIQE of dis parrot image is: %0.3f' % niqe(dis)) -------------------------------------------------------------------------------- /image-criterions-python/PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from skimage.measure import compare_psnr , compare_ssim 4 | import numpy as np 5 | from torchvision.transforms import ToPILImage , ToTensor 6 | 7 | 8 | def get_psnr_ssim(input_img , compared_img , data_range = 1 ,multichannel=True ) : 9 | '''input and compared should be numpy 10 | batch_size h w channel''' 11 | #####################torch test with bn ################### 12 | 
if isinstance(input_img , torch.Tensor): 13 | if len(input_img.size()) == 4 : 14 | input_img = np.squeeze(input_img.cpu().detach().numpy().transpose(0,2,3,1)) #rgb -> gbr 15 | 16 | 17 | 18 | 19 | if isinstance(compared_img , torch.Tensor): 20 | if len(compared_img.size()) == 4: 21 | compared_img =np.squeeze(compared_img.cpu().detach().numpy().transpose(0,2,3,1)) 22 | #################numpy test################# 23 | Ssim = compare_ssim(input_img , compared_img ,data_range = data_range , multichannel=multichannel) 24 | Psnr = compare_psnr(input_img , compared_img ,data_range = data_range) 25 | return Psnr,Ssim 26 | 27 | 28 | 29 | def output_psnr_mse(img_orig, img_out): 30 | squared_error = np.square(img_orig - img_out) 31 | mse = np.mean(squared_error) 32 | psnr = 10 * np.log10(1.0 / mse) 33 | return psnr -------------------------------------------------------------------------------- /image-criterions-python/README.md: -------------------------------------------------------------------------------- 1 | # Image-Criterions-Python&MATLAB 2 | 3 | This repository contains python implementation of SSIM,PSNR,NIQE,VIF 4 | 5 | FSIM just can find in matlab,try to repreduce in python soon 6 | [link](https://blog.csdn.net/ccheng_11/article/details/88554902) 7 | 8 | 9 | ## Dependencies 10 | 1) Python (>=3.5) 11 | 2) Numpy (>=1.16) 12 | 3) Python Imaging Library (PIL) (>=6.0) 13 | 4) Steerable Pyramid Toolbox (PyPyrTools) [link](https://github.com/LabForComputationalVision/pyPyrTools) 14 | 15 | ## Usage 16 | Let imref and imdist denote reference and distorted images respectively. Then the VIF value is calculated as 17 | VIF = vifvec(imref, imdist) 18 | 19 | A demo code is provided in test.py for testing purposes 20 | 21 | [1]H.R. Sheikh, A.C. Bovik and G. de Veciana, "An information fidelity criterion for image quality assessment using natural scene statistics," IEEE Transactions on Image Processing , vol.14, no.12pp. 2117- 2128, Dec. 2005. 
22 | -------------------------------------------------------------------------------- /image-criterions-python/VIF.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Portions Copyright (c) 2014 CiiNOW Inc. 5 | Written by Alex Izvorski , 6 | 2014-03-03 Ported from matlab to python/numpy/scipy 7 | 2014-03-04 Added utility functions to read/compare images and video 8 | """ 9 | 10 | """ 11 | -----------COPYRIGHT NOTICE STARTS WITH THIS LINE------------ 12 | Copyright (c) 2005 The University of Texas at Austin 13 | All rights reserved. 14 | Permission is hereby granted, without written agreement and without license or royalty fees, to use, copy, 15 | modify, and distribute this code (the source files) and its documentation for 16 | any purpose, provided that the copyright notice in its entirety appear in all copies of this code, and the 17 | original source of this code, Laboratory for Image and Video Engineering (LIVE, http://live.ece.utexas.edu) 18 | at the University of Texas at Austin (UT Austin, 19 | http://www.utexas.edu), is acknowledged in any publication that reports research using this code. The research 20 | is to be cited in the bibliography as: 21 | H. R. Sheikh and A. C. Bovik, "Image Information and Visual Quality", IEEE Transactions on 22 | Image Processing, (to appear). 23 | IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT AUSTIN BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, 24 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF THIS DATABASE AND ITS DOCUMENTATION, EVEN IF THE UNIVERSITY OF TEXAS 25 | AT AUSTIN HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | THE UNIVERSITY OF TEXAS AT AUSTIN SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 27 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. 
THE DATABASE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, 28 | AND THE UNIVERSITY OF TEXAS AT AUSTIN HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 29 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------ 30 | This software release consists of a MULTISCALE PIXEL DOMAIN, SCALAR GSM implementation of the algorithm described in the paper: 31 | H. R. Sheikh and A. C. Bovik, "Image Information and Visual Quality"., IEEE Transactions on Image Processing, (to appear). 32 | Download manuscript draft from http://live.ece.utexas.edu in the Publications link. 33 | THE PIXEL DOMAIN ALGORITHM IS NOT DESCRIBED IN THE PAPER. THIS IS A COMPUTATIONALLY SIMPLER 34 | DERIVATIVE OF THE ALGORITHM PRESENTED IN THE PAPER 35 | Input : (1) img1: The reference image as a matrix 36 | (2) img2: The distorted image (order is important) 37 | Output: (1) VIF the visual information fidelity measure between the two images 38 | Default Usage: 39 | Given 2 test images img1 and img2, whose dynamic range is 0-255 40 | vif = vifvec(img1, img2); 41 | Advanced Usage: 42 | Users may want to modify the parameters in the code. 43 | (1) Modify sigma_nsq to find tune for your image dataset. 
44 | Email comments and bug reports to hamid.sheikh@ieee.org 45 | """ 46 | 47 | import numpy 48 | import scipy.signal 49 | import scipy.ndimage 50 | 51 | 52 | def vifp_mscale(ref, dist): 53 | sigma_nsq = 2 54 | eps = 1e-10 55 | 56 | num = 0.0 57 | den = 0.0 58 | for scale in range(1, 5): 59 | 60 | N = 2 ** (4 - scale + 1) + 1 61 | sd = N / 5.0 62 | 63 | if (scale > 1): 64 | ref = scipy.ndimage.gaussian_filter(ref, sd) 65 | dist = scipy.ndimage.gaussian_filter(dist, sd) 66 | ref = ref[::2, ::2] 67 | dist = dist[::2, ::2] 68 | 69 | mu1 = scipy.ndimage.gaussian_filter(ref, sd) 70 | mu2 = scipy.ndimage.gaussian_filter(dist, sd) 71 | mu1_sq = mu1 * mu1 72 | mu2_sq = mu2 * mu2 73 | mu1_mu2 = mu1 * mu2 74 | sigma1_sq = scipy.ndimage.gaussian_filter(ref * ref, sd) - mu1_sq 75 | sigma2_sq = scipy.ndimage.gaussian_filter(dist * dist, sd) - mu2_sq 76 | sigma12 = scipy.ndimage.gaussian_filter(ref * dist, sd) - mu1_mu2 77 | 78 | sigma1_sq[sigma1_sq < 0] = 0 79 | sigma2_sq[sigma2_sq < 0] = 0 80 | 81 | g = sigma12 / (sigma1_sq + eps) 82 | sv_sq = sigma2_sq - g * sigma12 83 | 84 | g[sigma1_sq < eps] = 0 85 | sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps] 86 | sigma1_sq[sigma1_sq < eps] = 0 87 | 88 | g[sigma2_sq < eps] = 0 89 | sv_sq[sigma2_sq < eps] = 0 90 | 91 | sv_sq[g < 0] = sigma2_sq[g < 0] 92 | g[g < 0] = 0 93 | sv_sq[sv_sq <= eps] = eps 94 | 95 | num += numpy.sum(numpy.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq))) 96 | den += numpy.sum(numpy.log10(1 + sigma1_sq / sigma_nsq)) 97 | 98 | vifp = num / den 99 | 100 | return vifp 101 | 102 | -------------------------------------------------------------------------------- /image-criterions-python/comp_for_derain.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 5 | 6 | import re 7 | import torch 8 | import argparse 9 | import urllib.request 10 | from Utils.utils import * 11 | from Utils.Vidsom import * 12 
from Utils.model_init import *
from Utils.ssim_map import SSIM_MAP
from Utils.torch_ssim import SSIM
from torch import nn, optim
from torch.backends import cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from MyDataset.Datasets import derain_test_datasets, derain_train_datasets
from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop

from net.model_ori import w_net

parser = argparse.ArgumentParser(description="PyTorch Derain W")

# dataset / checkpoint roots
parser.add_argument("--train", default="/scratch1/hxw170830/derain_ablation_V1/DID-MDN-datasets/train", type=str,
                    help="path to load train datasets(default: none)")
parser.add_argument("--test", default="/scratch1/hxw170830/derain_ablation_V1/DID-MDN-datasets/test", type=str,
                    help="path to load test datasets(default: none)")
parser.add_argument("--save_image_root", default='./result', type=str,
                    help="save test image root")
parser.add_argument("--save_root", default="./checkpoints", type=str,
                    help="path to save networks")
parser.add_argument("--pretrain_root", default="/scratch1/hxw170830/derain_ablation_V1/checkpoints2", type=str,
                    help="path to pretrained net1 net2 net3 root")

# hyperparameters
parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
parser.add_argument("--nEpoch", type=int, default=10000, help="number of epochs to train for")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4")
parser.add_argument("--lr1", type=float, default=1e-4, help="Learning Rate For w")

parser.add_argument("--train_print_fre", type=int, default=50, help="frequency of print train loss on train phase")
parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test")
parser.add_argument("--test_print_fre", type=int, default=400, help="frequency of print train loss on test phase")
# Typo fix: default was "Ture" (any non-empty string is truthy, so behaviour
# is unchanged).  NOTE(review): a string-typed flag like this cannot be
# turned off from the CLI; argparse type=bool does not parse "False" either.
parser.add_argument("--cuda", type=str, default="True", help="Use cuda?")
parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use")
parser.add_argument("--startweights", default=0, type=int, help="start number of net's weight , 0 is None")
parser.add_argument("--initmethod", default='xavier', type=str,
                    help="xavier , kaiming , normal ,orthogonal ,default : xavier")
parser.add_argument("--startepoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1")
parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9")
# NOTE(review): type=bool / type=list are argparse footguns; any non-empty CLI
# string becomes True, and a list default cannot be overridden sensibly.
parser.add_argument("--report", default=False, type=bool, help="report to wechat")
parser.add_argument("--save_image", default=False, type=bool, help="save test image")
parser.add_argument("--pretrain_epoch", default=[93, 169, 123], type=list, help="pretrained epoch for Net1 Net2 Net3")


def main():
    """Build datasets, the w-net model and its criteria, then run the loop."""
    global opt, Net1, Net2, Net3, Net4, RefineNet, criterion_mse, criterion_ssim_map, criterion_ssim, criterion_ace
    global w, criterion_l1
    opt = parser.parse_args()
    print(opt)

    cuda = opt.cuda
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    cudnn.benchmark = True

    print("==========> Loading datasets")

    train_dataset = derain_train_datasets(data_root=opt.train,
                                          transform=Compose([
                                              ToTensor()
                                          ]))
    test_dataset = derain_test_datasets(opt.test, transform=Compose([
        ToTensor()
    ]))

    training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize,
                                      pin_memory=True, shuffle=True)
    testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True,
                                     shuffle=True)

    # Select the weight-initialisation routine (Utils.model_init helpers).
    if opt.initmethod == 'orthogonal':
        init_function = weights_init_orthogonal
    elif opt.initmethod == 'kaiming':
        init_function = weights_init_kaiming
    elif opt.initmethod == 'normal':
        init_function = weights_init_normal
    else:
        init_function = weights_init_xavier

    w = w_net()
    w.apply(init_function)

    criterion_mse = nn.MSELoss()
    criterion_ssim_map = SSIM_MAP()
    criterion_ssim = SSIM()
    criterion_ace = nn.SmoothL1Loss()
    criterion_l1 = nn.L1Loss()

    print("==========> Setting GPU")
    if opt.cuda:
        w = nn.DataParallel(w, device_ids=[i for i in range(opt.gpus)]).cuda()

        criterion_ssim = criterion_ssim.cuda()
        criterion_ssim_map = criterion_ssim_map.cuda()
        criterion_mse = criterion_mse.cuda()
        criterion_ace = criterion_ace.cuda()
        criterion_l1 = criterion_l1.cuda()
    else:
        raise Exception("it takes a long time without cuda ")

    # (Large commented-out pretrain/resume checkpoint-loading blocks removed;
    #  restore from version control if multi-net loading is needed again.)

    print("==========> Setting Optimizer")
    optimizerw = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr1)

    # Index 0 is a placeholder so optimizer[1] always refers to the w-net optimizer.
    optimizer = [1, optimizerw]
    print("==========> Training")
    for epoch in range(opt.startepoch, opt.nEpoch + 1):

        # NOTE(review): this recreates the Adam optimizer (and resets its
        # moment estimates) on *every* epoch past 50, not just once.
        if epoch > 50:
            opt.lr1 = 1e-4
            optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr1)

        train(training_data_loader, optimizer, epoch)

        if epoch % opt.test_frequency == 0:
            test(testing_data_loader, epoch)


def train(training_data_loader, optimizer, epoch):
    """One training epoch of the w-net under a (1 - mean SSIM-map) loss."""
    print("training ==========> epoch =", epoch, "lr =", opt.lr1)
    w.train()
    t_loss = []  # save trainloss

    for step, (data, label) in enumerate(training_data_loader, 1):
        if opt.cuda and torch.cuda.is_available():
            data = data.clone().detach().requires_grad_(True).cuda()
            label = label.cuda()
        else:
            raise Exception("it takes a long time without cuda ")
            # (dead `data = data.cpu()` / `label = label.cpu()` after the
            #  raise removed — it was unreachable)

        w_out = w(data)
        ssim_loss = 1 - criterion_ssim_map(w_out, label).mean()

        w.zero_grad()
        optimizer[1].zero_grad()

        ssim_loss.backward()
        optimizer[1].step()

        if step % opt.train_print_fre == 0:
            print("epoch{} step {} loss {:6f} ".format(epoch, step, ssim_loss.item(), ))
            # Bug fix: this previously appended the undefined name `l1_loss`,
            # raising NameError on the first print step.
            t_loss.append(ssim_loss.item())


def test(test_data_loader, epoch):
    """Evaluate PSNR/SSIM over the test set, log results, save a checkpoint."""
    print("------> testing")
    w.eval()

    torch.cuda.empty_cache()

    with torch.no_grad():

        test_Psnr_sum = 0.0
        test_Ssim_sum = 0.0

        # running per-step averages (for display)
        test_Psnr_loss = []
        test_Ssim_loss = []
        dict_psnr_ssim = {}
        for test_step, (data, label, data_path) in enumerate(test_data_loader, 1):
            data = data.cuda()
            label = label.cuda()

            w_out = w(data)

            ssim_loss = 1 - criterion_ssim(w_out, label)
            Psnr, Ssim = get_psnr_ssim(w_out, label)

            Psnr = round(Psnr.item(), 4)
            Ssim = round(Ssim.item(), 4)

            test_Psnr_sum += Psnr
            test_Ssim_sum += Ssim

            if test_step % opt.test_print_fre == 0:
                print("epoch={} Psnr={} Ssim={} mseloss{}".format(epoch, Psnr, Ssim, ssim_loss.item()))
                test_Psnr_loss.append(test_Psnr_sum / test_step)
                test_Ssim_loss.append(test_Ssim_sum / test_step)

        else:
            # for/else: runs once after the loop finishes normally.
            del ssim_loss
            print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step,
                                                             test_Ssim_sum / test_step))
            write_test_perform("./perform_test_l1.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step)
            print("---->testing over show in visdom")

        print("epoch {} train over-----> save net".format(epoch))
        print("saving checkpoint save_root{}".format(opt.save_root))
        if os.path.exists(opt.save_root):
            save_checkpoint(root=opt.save_root, model=w, epoch=epoch, model_stage="wl1")
            print("finish save epoch{} checkporint".format({epoch}))
        else:
            raise Exception("saveroot :{} not found , Checkout it".format(opt.save_root))

    print("all epoch is over ------ ")
    print("show epoch and epoch_loss in visdom")


if __name__ == "__main__":
    os.system('clear')
    main()

# ---------------------------------------------------------------------------
# image-criterions-python/demo-images/clear.png   (binary asset; see repo)
# image-criterions-python/demo-images/rain.png    (binary asset; see repo)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# image-criterions-python/demo-test.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Written by peylnog
2020-3-30
"""
import cv2
import numpy as np
import scipy.io
from os.path import dirname, join
from VIF import vifp_mscale as vif
from NIQE import niqe
from PIL import Image
from sewar.full_ref import uqi
from PSNR_SSIM import get_psnr_ssim


if __name__ == '__main__':
    image_rain = cv2.imread('./demo-images/rain.png')
    image_clear = cv2.imread('./demo-images/clear.png')
    # ------------------------VIF-----------------
    try:
        print(" vif : ", vif(image_rain, image_clear))
    except IOError:
        print("check out image path")
    # ------------------------niqe-----------------
    try:
        rain = np.array(Image.open('./demo-images/rain.png').convert('LA'))[:, :, 0]
        clear = np.array(Image.open('./demo-images/clear.png').convert('LA'))[:, :, 0]

        # Bug fix: NIQE.niqe takes (image, params) — params were never loaded
        # here — and the two labels were swapped (rain printed as "clear").
        params = scipy.io.loadmat(join(dirname(__file__), 'niqe_image_params.mat'))
        print(" rain image niqe : ", niqe(rain, params))
        print(" clear image niqe : ", niqe(clear, params))
    except IOError:
        print("check out image path")
    # ------------------------uqi-----------------
    try:
        rain = np.array(Image.open('./demo-images/rain.png'))
        clear = np.array(Image.open('./demo-images/clear.png'))

        print(" uqi : ", uqi(clear, rain))
    except IOError:
        print("check out image path")

# ---------------------------------------------------------------------------
# image-criterions-python/derain_test.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-

"""
Written by peylnog
2020-3-30
"""
import cv2
import os
from os.path import join
import numpy as np
from VIF import vifp_mscale as vif
from NIQE import niqe
from PIL import Image
from sewar.full_ref import uqi
from sewar.full_ref import vifp

from PSNR_SSIM import get_psnr_ssim
# Bug fix: a bare `import scipy` does not reliably expose scipy.io;
# import the submodule explicitly for scipy.io.loadmat below.
import scipy.io
from os.path import dirname

if __name__ == '__main__':
    rain_img_root = "/home/ws/Desktop/derain2020/result_new"
    clear_img_root = "/home/ws/Desktop/derain2020/clear"

    AV_VIF = 0
    AV_NIQE = 0
    AV_UQI = 0

    # Pair rain/clear images by their sorted filename order.
    rain_imgs = sorted(join(rain_img_root, x) for x in os.listdir(rain_img_root))
    clear_imgs = sorted(join(clear_img_root, x) for x in os.listdir(clear_img_root))
    module_path = dirname(__file__)

    params = scipy.io.loadmat(join(module_path, 'niqe_image_params.mat'))

    n = len(rain_imgs)
    for i in range(n):
        print(i)
        img_r = rain_imgs[i]
        img_c = clear_imgs[i]
        rain = np.array(Image.open(img_r))
        clear = np.array(Image.open(img_c))

        AV_VIF += vifp(clear, rain)
        AV_UQI += uqi(clear, rain)
        rain = np.array(Image.open(img_r).convert('LA'))[:, :, 0]  # luminance for NIQE
        AV_NIQE += niqe(rain, params)

    print("AV_VIF :", AV_VIF / n)
    print("AV_UQI :", AV_UQI / n)
    print("AV_NIQE :", AV_NIQE / n)

# (large commented-out single-image demo removed; demo-test.py covers it)

# ---------------------------------------------------------------------------
# image-criterions-python/for_class.py
# ---------------------------------------------------------------------------
from sewar import full_ref

# ---------------------------------------------------------------------------
# image-criterions-python/niqe_image_params.mat   (binary asset; see repo)
# net/__pycache__/*.pyc                           (binary assets; see repo)
# net/model.py:
# ---------------------------------------------------------------------------
# net/model.py
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from .networks import *
from Utils.model_init import *
from torchvision import models


def _merge_up(decoded, skip):
    """Bilinearly resize `decoded` to `skip`'s spatial size and add the two."""
    decoded = F.interpolate(decoded, skip.size()[2:], mode='bilinear', align_corners=True)
    return torch.add(decoded, skip)


class u_net(nn.Module):
    '''Single U-shaped encoder/decoder for deraining.

    Input is an RGB image (B, 3, H, W); the output has the same shape.  Each
    decoder stage is fused with the matching encoder stage, and the last
    stage is fused with the input itself (global residual).'''

    def __init__(self):
        super(u_net, self).__init__()

        # Encoder: channels double while the spatial size halves each stage.
        self.encode1 = conv_blocks_size_3(in_dim=3, out_dim=8, Use_pool=True)     # 256
        self.encode2 = conv_blocks_size_3(in_dim=8, out_dim=16, Use_pool=True)    # 128
        self.encode3 = conv_blocks_size_3(in_dim=16, out_dim=32, Use_pool=True)   # 64
        self.encode4 = conv_blocks_size_3(in_dim=32, out_dim=64, Use_pool=True)   # 32
        self.encode5 = conv_blocks_size_3(in_dim=64, out_dim=128, Use_pool=True)  # 16

        # Decoder: mirrors the encoder back up to the input resolution.
        self.decode1 = deconv_blocks_size_3(128, 64, Use_pool=True)  # 32
        self.decode2 = deconv_blocks_size_3(64, 32, Use_pool=True)   # 64
        self.decode3 = deconv_blocks_size_3(32, 16, Use_pool=True)   # 128
        self.decode4 = deconv_blocks_size_3(16, 8, True)             # 256
        self.decode5 = deconv_blocks_size_3(8, 3, True)              # 512

    def forward(self, x):
        """x: rain image (B, 3, H, W) -> derained image of the same shape."""
        skips = [x]
        for encoder in (self.encode1, self.encode2, self.encode3,
                        self.encode4, self.encode5):
            skips.append(encoder(skips[-1]))

        out = skips.pop()  # deepest feature map
        for decoder in (self.decode1, self.decode2, self.decode3,
                        self.decode4, self.decode5):
            out = _merge_up(decoder(out), skips.pop())
        return out


class w_net(nn.Module):
    '''Two stacked U-nets ("W" shape): the second U refines the first's output.

    Input and output are both (B, 3, H, W).'''

    def __init__(self):
        super(w_net, self).__init__()

        # First U: encoder / decoder.
        self.encode1 = conv_blocks_size_3(in_dim=3, out_dim=8, Use_pool=True)     # 256
        self.encode2 = conv_blocks_size_3(in_dim=8, out_dim=16, Use_pool=True)    # 128
        self.encode3 = conv_blocks_size_3(in_dim=16, out_dim=32, Use_pool=True)   # 64
        self.encode4 = conv_blocks_size_3(in_dim=32, out_dim=64, Use_pool=True)   # 32
        self.encode5 = conv_blocks_size_3(in_dim=64, out_dim=128, Use_pool=True)  # 16

        self.decode2 = deconv_blocks_size_3(128, 64, Use_pool=True)  # 32
        self.decode3 = deconv_blocks_size_3(64, 32, Use_pool=True)   # 64
        self.decode4 = deconv_blocks_size_3(32, 16, Use_pool=True)   # 128
        self.decode5 = deconv_blocks_size_3(16, 8, True)             # 256
        self.decode6 = deconv_blocks_size_3(8, 3, True)              # 512

        # Second U: encoder / decoder over the first U's output.
        self.eencode1 = conv_blocks_size_3(3, 8, True)     # 256
        self.eencode2 = conv_blocks_size_3(8, 16, True)    # 128
        self.eencode3 = conv_blocks_size_3(16, 32, True)   # 64
        self.eencode4 = conv_blocks_size_3(32, 64, True)   # 32
        self.eencode5 = conv_blocks_size_3(64, 128, True)  # 16

        self.ddcode2 = deconv_blocks_size_3(128, 64)  # 32
        self.ddcode3 = deconv_blocks_size_3(64, 32)   # 64
        self.ddcode4 = deconv_blocks_size_3(32, 16)   # 128
        self.ddcode5 = deconv_blocks_size_3(16, 8)    # 256
        self.ddcode6 = deconv_blocks_size_3(8, 3)     # 512

    def forward(self, x):
        """x: rain image (B, 3, H, W) -> refined derained image, same shape."""
        # ---- first U ----
        skips = [x]
        for encoder in (self.encode1, self.encode2, self.encode3,
                        self.encode4, self.encode5):
            skips.append(encoder(skips[-1]))

        mid = skips.pop()
        for decoder in (self.decode2, self.decode3, self.decode4,
                        self.decode5, self.decode6):
            mid = _merge_up(decoder(mid), skips.pop())

        # ---- second U, fed with (and finally fused back onto) `mid` ----
        skips = [mid]
        for encoder in (self.eencode1, self.eencode2, self.eencode3,
                        self.eencode4, self.eencode5):
            skips.append(encoder(skips[-1]))

        out = skips.pop()
        for decoder in (self.ddcode2, self.ddcode3, self.ddcode4,
                        self.ddcode5, self.ddcode6):
            out = _merge_up(decoder(out), skips.pop())
        return out


class res_net(nn.Module):
    '''Eleven stacked 3x3 residual blocks operating directly on RGB.'''

    def __init__(self):
        super(res_net, self).__init__()

        self.block1 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block2 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block3 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block4 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block5 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block6 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block7 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block8 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block9 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block10 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block11 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        # NOTE(review): this layer is never used in forward() (and the name is
        # misspelled); it is kept so existing checkpoints still load.
        self.reduce_channle = res_net_blocks(4, 3, ker_size=3, padding=1, stride=1)  # if add insert_img

    def forward(self, x):
        out = x
        for idx in range(1, 12):
            # out = block_i(out) + out, for block1 ... block11 in order.
            out = torch.add(getattr(self, 'block%d' % idx)(out), out)
        return out


class Net4(nn.Module):
    '''Fuse three 3-channel predictions into one RGB map (9 -> 6 -> 3 ch).'''

    def __init__(self):
        super(Net4, self).__init__()

        self.reduce_channel_1 = conv_block_size_3(9, 6)
        self.reduce_channel_2 = conv_block_size_3(6, 3)

    def forward(self, x1, x2, x3):
        fused = torch.cat([x1, x2, x3], dim=1)
        return self.reduce_channel_2(self.reduce_channel_1(fused))


class Net4_2Net(nn.Module):
    '''Two-branch variant of Net4: fuse two 3-channel predictions (6 -> 3 ch).'''

    def __init__(self):
        super(Net4_2Net, self).__init__()

        self.reduce_channel_1 = conv_block_size_3(6, 6)
        self.reduce_channel_2 = conv_block_size_3(6, 3)

    def forward(self, x1, x2):
        fused = torch.cat([x1, x2], dim=1)
        return self.reduce_channel_2(self.reduce_channel_1(fused))


class refineNet(nn.Module):
    '''Refine the concatenation of four 3-channel maps back to one RGB map.'''

    def __init__(self):
        super(refineNet, self).__init__()

        self.up_dim = deconv_blocks_size_3(12, 24)
        self.down_dim1 = conv_blocks_size_3(24, 12, Use_pool=True)

        self.reduce1 = conv_block_size_3(12, 6)
        self.reduce2 = conv_block_size_3(6, 3)

    def forward(self, x1, x2, x3, x4):
        feats = self.down_dim1(self.up_dim(torch.cat([x1, x2, x3, x4], dim=1)))
        return self.reduce2(self.reduce1(feats))


class refineNet_2Net(nn.Module):
    '''Three-input variant of refineNet (9 input channels).'''

    def __init__(self):
        super(refineNet_2Net, self).__init__()

        self.up_dim = deconv_blocks_size_3(9, 24)
        self.down_dim1 = conv_blocks_size_3(24, 12, Use_pool=True)

        self.reduce1 = conv_block_size_3(12, 6)
        self.reduce2 = conv_block_size_3(6, 3)

    def forward(self, x1, x2, x3):
        feats = self.down_dim1(self.up_dim(torch.cat([x1, x2, x3], dim=1)))
        return self.reduce2(self.reduce1(feats))
# ================================ net/model_skip.py ================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from .networks import *
from Utils.model_init import *
from torchvision import models


class u_net(nn.Module):
    """Single U-shaped encoder/decoder for deraining.

    Input / output: (B, 3, H, W).  Decoder stages are bilinearly resized onto
    their skip connections, so arbitrary spatial sizes are accepted.
    """

    def __init__(self):
        super(u_net, self).__init__()

        # Encoder: each block halves H/W via max-pool.
        self.encode1 = conv_blocks_size_3(in_dim=3, out_dim=8, Use_pool=True)
        self.encode2 = conv_blocks_size_3(in_dim=8, out_dim=16, Use_pool=True)
        self.encode3 = conv_blocks_size_3(in_dim=16, out_dim=32, Use_pool=True)
        self.encode4 = conv_blocks_size_3(in_dim=32, out_dim=64, Use_pool=True)
        self.encode5 = conv_blocks_size_3(in_dim=64, out_dim=128, Use_pool=True)

        # Decoder: each block upsamples x2.
        self.decode1 = deconv_blocks_size_3(128, 64, Use_pool=True)
        self.decode2 = deconv_blocks_size_3(64, 32, Use_pool=True)
        self.decode3 = deconv_blocks_size_3(32, 16, Use_pool=True)
        self.decode4 = deconv_blocks_size_3(16, 8, True)
        self.decode5 = deconv_blocks_size_3(8, 3, True)

    def forward(self, x):
        """Return a derained image with the same shape as ``x``."""
        encode1 = self.encode1(x)
        encode2 = self.encode2(encode1)
        encode3 = self.encode3(encode2)
        encode4 = self.encode4(encode3)
        encode5 = self.encode5(encode4)

        # Decode with resize-then-add skip connections; the last stage is
        # added onto the network input itself (global residual).
        decode = encode5
        for stage, skip in zip(
                (self.decode1, self.decode2, self.decode3, self.decode4, self.decode5),
                (encode4, encode3, encode2, encode1, x)):
            decode = stage(decode)
            decode = F.interpolate(decode, skip.size()[2:], mode='bilinear', align_corners=True)
            decode = torch.add(decode, skip)

        return decode


class w_net(nn.Module):
    """Two stacked U-shaped encoder/decoders (a "W" shape) for deraining.

    NOTE(review): duplicated almost verbatim in net/model.py; consider
    sharing one definition.
    """

    def __init__(self):
        super(w_net, self).__init__()

        # First U: encoder (each block halves H/W via max-pool).
        self.encode1 = conv_blocks_size_3(in_dim=3, out_dim=8, Use_pool=True)
        self.encode2 = conv_blocks_size_3(in_dim=8, out_dim=16, Use_pool=True)
        self.encode3 = conv_blocks_size_3(in_dim=16, out_dim=32, Use_pool=True)
        self.encode4 = conv_blocks_size_3(in_dim=32, out_dim=64, Use_pool=True)
        self.encode5 = conv_blocks_size_3(in_dim=64, out_dim=128, Use_pool=True)

        # First U: decoder (decode2 consumes the bottleneck directly).
        self.decode2 = deconv_blocks_size_3(128, 64, Use_pool=True)
        self.decode3 = deconv_blocks_size_3(64, 32, Use_pool=True)
        self.decode4 = deconv_blocks_size_3(32, 16, Use_pool=True)
        self.decode5 = deconv_blocks_size_3(16, 8, True)
        self.decode6 = deconv_blocks_size_3(8, 3, True)

        # Second U, applied to the first U's 3-channel output.
        self.eencode1 = conv_blocks_size_3(3, 8, True)
        self.eencode2 = conv_blocks_size_3(8, 16, True)
        self.eencode3 = conv_blocks_size_3(16, 32, True)
        self.eencode4 = conv_blocks_size_3(32, 64, True)
        self.eencode5 = conv_blocks_size_3(64, 128, True)

        self.ddcode2 = deconv_blocks_size_3(128, 64)
        self.ddcode3 = deconv_blocks_size_3(64, 32)
        self.ddcode4 = deconv_blocks_size_3(32, 16)
        self.ddcode5 = deconv_blocks_size_3(16, 8)
        self.ddcode6 = deconv_blocks_size_3(8, 3)

    def _u_pass(self, x, encoders, decoders):
        """Run one encoder/decoder U with resize-then-add skip connections."""
        skips = [x]
        out = x
        for enc in encoders:
            out = enc(out)
            skips.append(out)
        for dec, skip in zip(decoders, reversed(skips[:-1])):
            out = dec(out)
            out = F.interpolate(out, skip.size()[2:], mode='bilinear', align_corners=True)
            out = torch.add(out, skip)
        return out

    def forward(self, x):
        """Apply both U stages in sequence and return the refined image."""
        mid = self._u_pass(
            x,
            (self.encode1, self.encode2, self.encode3, self.encode4, self.encode5),
            (self.decode2, self.decode3, self.decode4, self.decode5, self.decode6),
        )
        return self._u_pass(
            mid,
            (self.eencode1, self.eencode2, self.eencode3, self.eencode4, self.eencode5),
            (self.ddcode2, self.ddcode3, self.ddcode4, self.ddcode5, self.ddcode6),
        )


class res_net(nn.Module):
    """Eleven 3->3 residual conv blocks applied in sequence."""

    def __init__(self):
        super(res_net, self).__init__()

        self.block1 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block2 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block3 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block4 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block5 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block6 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block7 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block8 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block9 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block10 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        self.block11 = res_net_blocks(3, 3, ker_size=3, padding=1, stride=1)
        # NOTE(review): never used by forward(); kept (with the original
        # misspelling) so existing checkpoints keep loading.
        self.reduce_channle = res_net_blocks(4, 3, ker_size=3, padding=1, stride=1)

    def forward(self, x):
        """Chain the 11 residual blocks: out <- block(out) + out."""
        out = x
        for block in (self.block1, self.block2, self.block3, self.block4,
                      self.block5, self.block6, self.block7, self.block8,
                      self.block9, self.block10, self.block11):
            out = torch.add(block(out), out)
        return out


class Net4(nn.Module):
    """Fuse three 3-channel predictions into one 3-channel image."""

    def __init__(self):
        super(Net4, self).__init__()
        self.reduce_channel_1 = conv_block_size_3(9, 6)
        self.reduce_channel_2 = conv_block_size_3(6, 3)

    def forward(self, x1, x2, x3):
        fused = self.reduce_channel_1(torch.cat([x1, x2, x3], dim=1))
        return self.reduce_channel_2(fused)


class refineNet(nn.Module):
    """Refine four 3-channel intermediate maps plus the rainy input
    (15 channels total) into the final 3-channel derained image."""

    def __init__(self):
        super(refineNet, self).__init__()
        self.up_dim = deconv_blocks_size_3(15, 24)           # x2 upsample, 15 -> 24 ch
        self.down_dim1 = conv_blocks_size_3(24, 15, True)    # pool back down
        self.reduce1 = conv_block_size_3(15, 9)
        self.reduce2 = conv_block_size_3(9, 6)
        self.reduce3 = conv_block_size_3(6, 3)

    def forward(self, x1, x2, x3, x4, data):
        x = self.up_dim(torch.cat([x1, x2, x3, x4, data], dim=1))
        x = self.down_dim1(x)
        return self.reduce3(self.reduce2(self.reduce1(x)))


# ================================ net/networks.py ================================
import torch
import torch.nn as nn
from torch.nn import *


def conv_block_size_3(in_dim, out_dim, bn=False):
    """Single 3x3 conv -> ReLU (-> optional BatchNorm).

    NOTE(review): BN *after* ReLU is unusual (conv->BN->ReLU is conventional);
    kept as-is because existing checkpoints depend on this ordering.
    """
    layer = nn.Sequential()
    layer.add_module("conv1", nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1, stride=1))
    layer.add_module('relu', nn.ReLU(False))

    if bn:
        layer.add_module('bn', nn.BatchNorm2d(out_dim))

    return layer
18 | def conv_blocks_size_3(in_dim,out_dim, Use_pool = False ,Maxpool = True,bn = True): 19 | layer = nn.Sequential() 20 | layer.add_module( "conv1",nn.Conv2d(in_dim ,out_dim , kernel_size= 3 , padding= 1, stride=1)) 21 | if bn: 22 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 23 | layer.add_module('relu', nn.ReLU(False)) 24 | 25 | 26 | layer.add_module( "conv2",nn.Conv2d(out_dim ,out_dim , kernel_size= 3 , padding= 1, stride=1)) 27 | if bn: 28 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 29 | layer.add_module('relu', nn.ReLU(False)) 30 | 31 | if Use_pool : 32 | if Maxpool: 33 | layer.add_module("Maxpool",nn.MaxPool2d(kernel_size=2,stride=2)) 34 | else: 35 | layer.add_module("Avgpool", nn.AvgPool2d(kernel_size=2, stride=2)) 36 | return layer 37 | 38 | 39 | 40 | def res_net_blocks(in_dim,out_dim, Use_pool = False ,Maxpool = True,bn = True ,ker_size =3 , padding = 1 ,stride = 1): 41 | layer = nn.Sequential() 42 | layer.add_module( "conv1",nn.Conv2d(in_dim ,out_dim , kernel_size= ker_size , padding= padding, stride=stride)) 43 | if bn: 44 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 45 | layer.add_module('relu', nn.ReLU(False)) 46 | 47 | 48 | layer.add_module( "conv2",nn.Conv2d(out_dim ,out_dim , kernel_size= ker_size , padding= padding, stride=stride)) 49 | if bn: 50 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 51 | layer.add_module('relu', nn.ReLU(False)) 52 | 53 | return layer 54 | 55 | 56 | def conv_blocks_size_5(in_dim,out_dim, Use_pool = False ,Maxpool = True ,bn = True): 57 | layer = nn.Sequential() 58 | layer.add_module( "conv1",nn.Conv2d(in_dim ,out_dim , kernel_size= 5, padding= 2, stride=1)) 59 | 60 | if bn: 61 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 62 | 63 | layer.add_module('relu', nn.ReLU(False)) 64 | 65 | 66 | 67 | layer.add_module( "conv2",nn.Conv2d(out_dim ,out_dim , kernel_size= 5, padding= 2, stride=1)) 68 | # layer.add_module('relu', nn.ReLU(False)) 69 | 70 | if bn: 71 | layer.add_module('bn' ,nn.BatchNorm2d(out_dim)) 
72 | layer.add_module('relu', nn.ReLU(False)) 73 | 74 | if Use_pool : 75 | if Maxpool: 76 | layer.add_module("Maxpool",nn.MaxPool2d(kernel_size=2,stride=2)) 77 | else: 78 | layer.add_module("Avgpool", nn.AvgPool2d(kernel_size=2, stride=2)) 79 | return layer 80 | 81 | 82 | 83 | 84 | def conv_blocks_size_7(in_dim,out_dim, Use_pool = False ,Maxpool = True , bn = True): 85 | layer = nn.Sequential() 86 | layer.add_module( "conv1",nn.Conv2d(in_dim ,out_dim , kernel_size= 7 , padding= 3, stride=1)) 87 | if bn: 88 | layer.add_module('bn', nn.BatchNorm2d(out_dim)) 89 | layer.add_module('relu', nn.ReLU(False)) 90 | 91 | 92 | layer.add_module( "conv2",nn.Conv2d(out_dim ,out_dim , kernel_size= 7 , padding= 3, stride=1)) 93 | if bn: 94 | layer.add_module('bn', nn.BatchNorm2d(out_dim)) 95 | layer.add_module('relu', nn.ReLU(False)) 96 | 97 | if Use_pool : 98 | if Maxpool: 99 | layer.add_module("Maxpool",nn.MaxPool2d(kernel_size=2,stride=2)) 100 | else: 101 | layer.add_module("Avgpool", nn.AvgPool2d(kernel_size=2, stride=2)) 102 | return layer 103 | 104 | 105 | 106 | 107 | def deconv_blocks_size_3(in_dim,out_dim,Use_pool=True,bn=True): 108 | layer = nn.Sequential() 109 | layer.add_module( "conv1",nn.ConvTranspose2d(in_dim , out_dim , 3 , 1, 1)) 110 | if bn: 111 | layer.add_module('bn', nn.BatchNorm2d(out_dim)) 112 | layer.add_module('relu', nn.ReLU(False)) 113 | 114 | 115 | layer.add_module("conv2", nn.ConvTranspose2d(out_dim, out_dim ,3, 1, 1)) 116 | if bn: 117 | layer.add_module('bn', nn.BatchNorm2d(out_dim)) 118 | layer.add_module('relu', nn.ReLU(False)) 119 | 120 | if Use_pool: 121 | layer.add_module("Upsamp",nn.UpsamplingNearest2d(scale_factor= 2)) 122 | return layer 123 | 124 | 125 | 126 | 127 | def Nonlinear_layer(in_c=0 , name="nonlinear", bn=False , relu=True, LeakReLU = False , dropout=False ): 128 | layer = nn.Sequential() 129 | if relu: 130 | layer.add_module('%s_relu' % name, nn.ReLU(inplace=False)) 131 | if LeakReLU: 132 | layer.add_module('%s_leakyrelu' % name, 
nn.LeakyReLU(0.2, inplace=True)) 133 | if bn: 134 | layer.add_module('%s_bn' % name, nn.BatchNorm2d(in_c)) 135 | 136 | if dropout: 137 | layer.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=False)) 138 | return layer 139 | -------------------------------------------------------------------------------- /now_loss_derain_train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import torch 4 | import argparse 5 | import urllib.request 6 | 7 | from Utils.torch_ssim import SSIM 8 | from Utils.ssim_map import SSIM_MAP 9 | from Utils.utils import * 10 | from Utils.Vidsom import * 11 | from Utils.model_init import * 12 | from torch import nn, optim 13 | from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | from MyDataset.Datasets import derain_test_datasets , derain_train_datasets 16 | from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop 17 | 18 | from net.model import u_net as wNet 19 | 20 | parser = argparse.ArgumentParser(description="PyTorch Derain") 21 | parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain_Dataset2018/train", type=str, 22 | help="path to load train datasets(default: none)") 23 | parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain_Dataset2018/test", type=str, 24 | help="path to load test datasets(default: none)") 25 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size") 26 | parser.add_argument("--nEpoch", type=int, default=400, help="number of epochs to train for") 27 | parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate. 
Default=1e-4") 28 | parser.add_argument("--train_print_freq", type=int, default=100, help="frequency of print train loss on train phase") 29 | parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test") 30 | parser.add_argument("--test_print_freq", type=int, default=200, help="frequency of print train loss on test phase") 31 | 32 | parser.add_argument("--cuda",type=str, default="Ture", help="Use cuda?") 33 | parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use") 34 | parser.add_argument("--startweights", default= 0, type=int, help="start number of net's weight , 0 is None") 35 | parser.add_argument("--initmethod", default='xavier', type=str, help="xavier , kaiming , normal ,orthogonal ,default : xavier") 36 | parser.add_argument("--startepoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 37 | parser.add_argument("--works", type=int, default=8, help="Number of works for data loader to use, Default: 1") 38 | parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9") 39 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained net (default: none)") 40 | parser.add_argument("--saveroot", default="/home/ws/Desktop/derain2020/checkpoints", type=str, help="path to save networks") 41 | parser.add_argument("--report", default=False, type=bool, help="report to wechat") 42 | parser.add_argument("--save_image", default=False, type=bool, help="save test image") 43 | parser.add_argument("--save_image_root", default="", type=str, help="save test image root") 44 | 45 | 46 | 47 | def main(): 48 | global opt, w , criterion_mse , criterion_ssim_map ,criterion_ssim 49 | opt = parser.parse_args() 50 | print(opt) 51 | 52 | 53 | cuda = opt.cuda 54 | if cuda and not torch.cuda.is_available(): 55 | raise Exception("No GPU found, please run without --cuda") 56 | 57 | seed = 1334 58 | torch.manual_seed(seed) 59 | if cuda: 60 | torch.cuda.manual_seed(seed) 
61 | 62 | cudnn.benchmark = True 63 | 64 | print("==========> Loading datasets") 65 | 66 | train_dataset = derain_train_datasets(opt.train, transform=Compose([ 67 | ToTensor() 68 | ])) 69 | 70 | test_dataset = derain_test_datasets(opt.test, transform=Compose([ 71 | ToTensor() 72 | ])) 73 | 74 | training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize, 75 | pin_memory=True, shuffle=True) 76 | testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=True, 77 | shuffle=True) 78 | 79 | if opt.initmethod == 'orthogonal': 80 | init_function = weights_init_orthogonal 81 | 82 | elif opt.initmethod == 'kaiming': 83 | init_function = weights_init_kaiming 84 | 85 | elif opt.initmethod == 'normal': 86 | init_function = weights_init_normal 87 | 88 | else: 89 | init_function = weights_init_xavier 90 | 91 | w = wNet(init_function) # 1: 256, 2:512 92 | criterion_mse = nn.MSELoss() 93 | criterion_ssim_map = SSIM_MAP() 94 | criterion_ssim = SSIM() 95 | 96 | #print(net) 97 | 98 | # weights start from early 99 | if opt.startweights: 100 | if os.path.isfile(opt.saveroot): 101 | print("=> loading checkpoint '{}'".format(opt.saveroot)) 102 | weights = torch.load(opt.saveroot + '/u/%s.pth'%(opt.startweights-1)) 103 | w.load_state_dict(weights["state_dict"]) 104 | 105 | else: 106 | raise Exception("'{}' is not a file , Check out it again".format(opt.savaroot)) 107 | 108 | 109 | # optionally copy weights from a checkpoint 110 | # if opt.pretrained: 111 | # if os.path.isfile(opt.pretrained): 112 | # print("=> loading net '{}'".format(opt.pretrained)) 113 | # weights = torch.load(opt.pretrained) 114 | # wNet.load_state_dict(weights['state_dict']) 115 | # else: 116 | # print("=> no net found at '{}'".format(opt.pretrained)) 117 | 118 | #if cuda: 119 | if opt.cuda and torch.cuda.is_available(): 120 | print("==========> Setting GPU") 121 | w = nn.DataParallel(w, device_ids=[i for i in 
range(opt.gpus)]).cuda() 122 | criterion_mse = criterion_mse.cuda() 123 | criterion_ssim_map = criterion_ssim_map.cuda() 124 | criterion_ssim = criterion_ssim.cuda() 125 | 126 | 127 | else: 128 | print("==========> Setting CPU") 129 | w = w.cpu() 130 | criterion_mse = criterion_mse.cpu() 131 | criterion_ssim_map = criterion_ssim_map.cpu() 132 | 133 | print("==========> Setting Optimizer") 134 | optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr) 135 | 136 | print("==========> Training") 137 | for epoch in range(opt.startepoch, opt.nEpoch + 1): 138 | train(training_data_loader, optimizer1, epoch) 139 | 140 | if epoch % 100 == 0 : 141 | opt.lr = 1e-4 142 | optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr) 143 | 144 | if epoch % opt.test_frequency == 0 : 145 | test(testing_data_loader ,epoch) 146 | 147 | def train(training_data_loader, optimizer, epoch): 148 | print("training ==========> epoch =", epoch, "lr =", opt.lr) 149 | w.train() 150 | # model2.train() 151 | # model3.train() 152 | t_loss = [] # save trainloss 153 | 154 | for step, (data, label) in enumerate(training_data_loader, 1): 155 | if opt.cuda and torch.cuda.is_available(): 156 | data = data.clone().detach().requires_grad_(True).cuda() 157 | label = label.cuda() 158 | else: 159 | data = data.cpu() 160 | label = label.cpu() 161 | 162 | w.zero_grad() 163 | optimizer.zero_grad() 164 | 165 | output = w(data) 166 | 167 | mse_loss = criterion_mse(output, label) 168 | ssim_map = criterion_ssim_map(output, label) 169 | ssim_loss = 1 - criterion_ssim(output , label) 170 | #loss = torch.mul((1 - ssim_map) ,torch.abs(output - label)).mean() + 0.01*ssim_loss 171 | loss = mse_loss 172 | 173 | loss.backward() 174 | optimizer.step() 175 | 176 | 177 | 178 | if step % opt.train_print_freq == 0: 179 | print("epoch{} step {} train_loss {:5f} l1_loss{:6f} ssim_loss{:6f}".format(epoch, step,loss.item(), 180 | mse_loss.item(), 181 | ssim_loss.item())) 182 | 
t_loss.append(loss.item()) 183 | 184 | del loss, mse_loss, ssim_map 185 | # displaying to train loss 186 | 187 | updata_epoch_loss_display(train_loss=t_loss, v_epoch=epoch ,envr= 'w train') 188 | 189 | 190 | def test(test_data_loader, epoch): 191 | print("------> testing") 192 | torch.cuda.empty_cache() 193 | 194 | w.eval() 195 | with torch.no_grad(): 196 | test_Psnr_sum = 0.0 197 | test_Ssim_sum = 0.0 198 | # showing list 199 | test_Psnr_loss = [] 200 | test_Ssim_loss = [] 201 | dict_psnr_ssim = {} 202 | for test_step, (data, label, _) in enumerate(test_data_loader, 1): 203 | data = data.cuda() 204 | label = label.cuda() 205 | 206 | out = w(data) 207 | del data 208 | 209 | mse_loss = criterion_mse(out, label) 210 | Psnr, Ssim = get_psnr_ssim(out, label) 211 | del out 212 | Psnr = round(Psnr.item(), 5) 213 | Ssim = round(Ssim.item(), 5) 214 | # del derain , label 215 | test_Psnr_sum += Psnr 216 | test_Ssim_sum += Ssim 217 | 218 | # if opt.save_image : 219 | # dict_psnr_ssim["Psnr%s_Ssim%s" % (Psnr, Ssim)] = data_path 220 | # out = derain.cpu().data[0] 221 | # out = ToPILImage()(out) 222 | # image_number = re.findall(r'\d+', data_path[0])[1] 223 | # out.save( opt.save_image_root + "/%s_p:%s_s:%s.jpg" % (image_number, Psnr, Ssim)) 224 | if test_step % opt.test_print_freq == 0: 225 | print("epoch={} Psnr={} Ssim={} loss{}".format(epoch, Psnr, Ssim, mse_loss.item())) 226 | test_Psnr_loss.append(test_Psnr_sum / test_step) 227 | test_Ssim_loss.append(test_Ssim_sum / test_step) 228 | 229 | del mse_loss,Psnr,Ssim 230 | 231 | else: 232 | print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step, 233 | test_Ssim_sum / test_step)) 234 | write_test_perform("/home/ws/Desktop/derain2020/perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step) 235 | # visdom showing 236 | print("---->testing over show in visdom") 237 | display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch, 238 | env="w test") 239 | 
240 | print("epoch {} test over-----> save net".format(epoch)) 241 | print("saving checkpoint save_root{}".format(opt.saveroot)) 242 | 243 | #if os.path.isfile(opt.saveroot): 244 | save_checkpoint(root=opt.saveroot, model=w, epoch=epoch, model_stage="u") 245 | print("finish save epoch{} checkporint".format({epoch})) 246 | #else: 247 | # raise Exception("saveroot :{} not found , Checkout it".format(opt.saveroot)) 248 | 249 | if __name__ == "__main__": 250 | os.system('clear') 251 | main() -------------------------------------------------------------------------------- /w.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import re 4 | import torch 5 | import argparse 6 | import urllib.request 7 | from Utils.utils import * 8 | from Utils.Vidsom import * 9 | from Utils.model_init import * 10 | from Utils.ssim_map import SSIM_MAP 11 | from Utils.torch_ssim import SSIM 12 | from torch import nn, optim 13 | from torch.backends import cudnn 14 | from torch.autograd import Variable 15 | from torch.utils.data import DataLoader 16 | from torchvision.utils import make_grid 17 | from MyDataset.Datasets import derain_test_datasets , derain_train_datasets 18 | from torchvision.transforms import Compose, ToTensor, Resize, Normalize, CenterCrop, RandomCrop 19 | 20 | from net.model import w_net 21 | 22 | 23 | 24 | parser = argparse.ArgumentParser(description="PyTorch Derain W") 25 | #root 26 | parser.add_argument("--train", default="/home/ws/Desktop/PL/Derain_Dataset2018/train", type=str, 27 | help="path to load train datasets(default: none)") 28 | parser.add_argument("--test", default="/home/ws/Desktop/PL/Derain_Dataset2018/test", type=str, 29 | help="path to load test datasets(default: none)") 30 | 31 | parser.add_argument("--save_image_root", default='./result_mseloss', type=str, 32 | help="save test image root") 33 | parser.add_argument("--save_root", default="/home/ws/Desktop/derain2020/checkpoints", 
type=str, 34 | help="path to save networks") 35 | parser.add_argument("--pretrain_root", default="/home/ws/Desktop/derain2020/checkpoints", type=str, 36 | help="path to pretrained net1 net2 net3 root") 37 | 38 | #hypeparameters 39 | parser.add_argument("--batchSize", type=int, default=8, help="training batch size") 40 | parser.add_argument("--nEpoch", type=int, default=10000, help="number of epochs to train for") 41 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate. Default=1e-4") 42 | parser.add_argument("--lr1", type=float, default=1e-4, help="Learning Rate For w") 43 | 44 | parser.add_argument("--train_print_fre", type=int, default=500, help="frequency of print train loss on train phase") 45 | parser.add_argument("--test_frequency", type=int, default=1, help="frequency of test") 46 | parser.add_argument("--test_print_fre", type=int, default=200, help="frequency of print train loss on test phase") 47 | parser.add_argument("--cuda",type=str, default="Ture", help="Use cuda?") 48 | parser.add_argument("--gpus", type=int, default=1, help="nums of gpu to use") 49 | parser.add_argument("--startweights", default=0, type=int, help="start number of net's weight , 0 is None") 50 | parser.add_argument("--initmethod", default='xavier', type=str, help="xavier , kaiming , normal ,orthogonal ,default : xavier") 51 | parser.add_argument("--startepoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 52 | parser.add_argument("--works", type=int, default=4, help="Number of works for data loader to use, Default: 1") 53 | parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum, Default: 0.9") 54 | parser.add_argument("--report", default=False, type=bool, help="report to wechat") 55 | parser.add_argument("--save_image", default=False, type=bool, help="save test image") 56 | parser.add_argument("--pretrain_epoch", default=[93,169,123], type=list, help="pretrained epoch for Net1 Net2 Net3") 57 | 58 | from torchvision 
import models 59 | from PIL import Image 60 | 61 | 62 | 63 | class Vgg16(nn.Module): 64 | def __init__(self): 65 | super(Vgg16, self).__init__() 66 | features = models.vgg16(pretrained=True).features 67 | self.to_relu_1_2 = nn.Sequential() 68 | self.to_relu_2_2 = nn.Sequential() 69 | self.to_relu_3_3 = nn.Sequential() 70 | self.to_relu_4_3 = nn.Sequential() 71 | 72 | for x in range(4): 73 | self.to_relu_1_2.add_module(str(x), features[x]) 74 | for x in range(4, 9): 75 | self.to_relu_2_2.add_module(str(x), features[x]) 76 | for x in range(9, 16): 77 | self.to_relu_3_3.add_module(str(x), features[x]) 78 | for x in range(16, 23): 79 | self.to_relu_4_3.add_module(str(x), features[x]) 80 | 81 | # don't need the gradients, just want the features 82 | #for param in self.parameters(): 83 | # param.requires_grad = False 84 | 85 | def forward(self, x): 86 | h = self.to_relu_1_2(x) 87 | h_relu_1_2 = h 88 | h = self.to_relu_2_2(h) 89 | h_relu_2_2 = h 90 | h = self.to_relu_3_3(h) 91 | h_relu_3_3 = h 92 | h = self.to_relu_4_3(h) 93 | h_relu_4_3 = h 94 | out = (h_relu_1_2, h_relu_2_2, h_relu_3_3, h_relu_4_3) 95 | return out 96 | 97 | 98 | class PerceptualLoss: 99 | def __init__(self , content_layer = 1 , content_layers = 0): 100 | self.content_layer = content_layer 101 | self.content_layers = content_layers 102 | 103 | self.vgg = nn.DataParallel(Vgg16()).cuda() 104 | self.vgg.eval() 105 | self.L1Loss = nn.DataParallel(nn.L1Loss()).cuda() 106 | self.L1Loss_sum = nn.DataParallel(nn.L1Loss(reduction='sum')).cuda() 107 | 108 | def __call__(self, x, y_hat): 109 | # b, c, h, w = x.shape 110 | y_content_features = self.vgg(x) 111 | y_hat_features = self.vgg(y_hat) 112 | 113 | recon = y_content_features[self.content_layer].cuda() 114 | recon_hat = y_hat_features[self.content_layer].cuda() 115 | 116 | recon1 = y_content_features[self.content_layers].cuda() 117 | recon_hat1 = y_hat_features[self.content_layers].cuda() 118 | 119 | L_content = self.L1Loss_sum(recon_hat, recon).cuda() 120 | 
L_content1 = self.L1Loss_sum(recon_hat1, recon1).cuda() 121 | 122 | 123 | return L_content+L_content1 124 | 125 | def main(): 126 | global opt, Net1 , Net2 , Net3 , Net4 , RefineNet , criterion_mse , criterion_ssim_map,criterion_ssim,criterion_ace 127 | global w,criterion_p 128 | opt = parser.parse_args() 129 | print(opt) 130 | 131 | 132 | cuda = opt.cuda 133 | if cuda and not torch.cuda.is_available(): 134 | raise Exception("No GPU found, please run without --cuda") 135 | 136 | 137 | cudnn.benchmark = True 138 | 139 | print("==========> Loading datasets") 140 | 141 | train_dataset = derain_train_datasets( data_root= opt.train, transform=Compose([ 142 | ToTensor() 143 | ])) 144 | 145 | test_dataset = derain_test_datasets(opt.test, transform=Compose([ 146 | ToTensor() 147 | ])) 148 | 149 | training_data_loader = DataLoader(dataset=train_dataset, num_workers=opt.works, batch_size=opt.batchSize, 150 | pin_memory=False, shuffle=True) 151 | testing_data_loader = DataLoader(dataset=test_dataset, num_workers=opt.works, batch_size=1, pin_memory=False, 152 | shuffle=True) 153 | 154 | if opt.initmethod == 'orthogonal': 155 | init_function = weights_init_orthogonal 156 | 157 | elif opt.initmethod == 'kaiming': 158 | init_function = weights_init_kaiming 159 | 160 | elif opt.initmethod == 'normal': 161 | init_function = weights_init_normal 162 | 163 | else: 164 | init_function = weights_init_xavier 165 | 166 | w = w_net() 167 | w.apply(init_function) 168 | 169 | 170 | 171 | 172 | criterion_mse = nn.MSELoss() 173 | criterion_ssim_map = SSIM_MAP() 174 | criterion_ssim = SSIM() 175 | criterion_ace = nn.L1Loss() 176 | criterion_p = PerceptualLoss() 177 | print("==========> Setting GPU") 178 | #if cuda: 179 | if opt.cuda: 180 | w = nn.DataParallel(w, device_ids=[i for i in range(opt.gpus)]).cuda() 181 | 182 | 183 | criterion_ssim = criterion_ssim.cuda() 184 | criterion_ssim_map = criterion_ssim_map.cuda() 185 | criterion_mse= criterion_mse.cuda() 186 | criterion_ace = 
criterion_ace.cuda() 187 | #criterion_p = criterion_p.cuda() 188 | else: 189 | raise Exception("it takes a long time without cuda ") 190 | #print(net) 191 | 192 | # if opt.pretrain_root: 193 | # if os.path.exists(opt.pretrain_root): 194 | # print("=> loading net from '{}'".format(opt.pretrain_root)) 195 | # weights = torch.load(opt.pretrain_root +"/w/%s.pth"%opt.pretrain_epoch[0]) 196 | # Net1.load_state_dict(weights['state_dict'] ) 197 | # 198 | # weights = torch.load(opt.pretrain_root + "/u/%s.pth" % opt.pretrain_epoch[1]) 199 | # Net2.load_state_dict(weights['state_dict'] ) 200 | # 201 | # weights = torch.load(opt.pretrain_root + "/res/%s.pth" % opt.pretrain_epoch[2]) 202 | # Net3.load_state_dict(weights['state_dict']) 203 | # 204 | # del weights 205 | # else: 206 | # print("=> no net found at '{}'".format(opt.pretrain_root)) 207 | 208 | # weights start from early 209 | if opt.startweights: 210 | if os.path.exists(opt.save_root): 211 | print("=> loading checkpoint '{}'".format(opt.save_root)) 212 | weights = torch.load(opt.save_root + '/%s.pth'%opt.startweights) 213 | w.load_state_dict(weights["state_dict"] ) 214 | 215 | # weights = torch.load(opt.save_root + '/Net2/%s.pth' % opt.startweights) 216 | # Net2.load_state_dict(weights["state_dict"]) 217 | 218 | # weights = torch.load(opt.save_root + '/Net3/%s.pth' % opt.startweights) 219 | # Net3.load_state_dict(weights["state_dict"]) 220 | 221 | # weights = torch.load(opt.save_root + '/Net4/%s.pth' % opt.startweights) 222 | # Net4.load_state_dict(weights["state_dict"]) 223 | 224 | # weights = torch.load(opt.save_root + '/refine/%s.pth' % opt.startweights) 225 | # RefineNet.load_state_dict(weights["state_dict"]) 226 | 227 | del weights 228 | else: 229 | raise Exception("'{}' is not a file , Check out it again".format(opt.save_root)) 230 | 231 | 232 | 233 | print("==========> Setting Optimizer") 234 | optimizerw = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr1) 235 | 236 | 237 | optimizer = [ 
1 , optimizerw ] 238 | print("==========> Training") 239 | for epoch in range(opt.startepoch, opt.nEpoch + 1): 240 | 241 | if epoch > 400 : 242 | opt.lr1 = 1e-4 243 | optimizer[1] = optim.Adam(filter(lambda p: p.requires_grad, w.parameters()), lr=opt.lr1) 244 | 245 | 246 | train(training_data_loader, optimizer, epoch) 247 | 248 | if epoch % opt.test_frequency == 0 : 249 | test(testing_data_loader ,epoch) 250 | 251 | 252 | 253 | def train(training_data_loader, optimizer, epoch): 254 | print("training ==========> epoch =", epoch, "lr1 =", opt.lr1) 255 | w.train() 256 | # Net1.train() 257 | # Net2.train() 258 | # Net3.train() 259 | # Net4.train() 260 | # RefineNet.train() 261 | t_loss = [] # save trainloss 262 | 263 | for step, (data, label) in enumerate(training_data_loader, 1): 264 | if opt.cuda and torch.cuda.is_available(): 265 | data = data.clone().detach().requires_grad_(True).cuda() 266 | label = label.cuda() 267 | else: 268 | raise Exception("it takes a long time without cuda ") 269 | data = data.cpu() 270 | label = label.cpu() 271 | 272 | w_out = w(data) 273 | loss = criterion_p(w_out , label) 274 | #new_loss = torch.mul((1-criterion_ssim_map(w_out , label)) , torch.abs(w_out-label)).mean().cuda() 275 | 276 | #loss = new_loss 277 | # del Net1_out , Net2_out , Net3_out , Net4_out 278 | 279 | w.zero_grad() 280 | optimizer[1].zero_grad() 281 | loss.backward() 282 | optimizer[1].step() 283 | 284 | 285 | 286 | if step % opt.train_print_fre == 0: 287 | print("epoch{} step {} loss {:6f} ".format(epoch, step,loss.item(),)) 288 | t_loss.append(loss.item()) 289 | 290 | else: 291 | # displaying to train loss 292 | updata_epoch_loss_display( train_loss= t_loss , v_epoch= epoch , envr= "derain train") 293 | 294 | 295 | def test(test_data_loader, epoch): 296 | from torchvision.transforms import ToPILImage 297 | print("------> testing") 298 | w.eval() 299 | 300 | torch.cuda.empty_cache() 301 | 302 | with torch.no_grad(): 303 | 304 | test_Psnr_sum = 0.0 305 | test_Ssim_sum = 
0.0 306 | 307 | # showing list 308 | test_Psnr_loss = [] 309 | test_Ssim_loss = [] 310 | dict_psnr_ssim = {} 311 | for test_step, (data, label, data_path) in enumerate(test_data_loader, 1): 312 | data = data.cuda() 313 | label = label.cuda() 314 | 315 | w_out = w(data) 316 | 317 | # new_loss = torch.mul((1 - criterion_ssim_map(w_out, label)), torch.abs(w_out - label)).mean().cuda() 318 | Psnr, Ssim = get_psnr_ssim(w_out, label) 319 | 320 | Psnr = round(Psnr.item(), 4) 321 | Ssim = round(Ssim.item(), 4) 322 | 323 | test_Psnr_sum += Psnr 324 | test_Ssim_sum += Ssim 325 | 326 | if opt.save_image == True: 327 | 328 | dict_psnr_ssim["Psnr%s_Ssim%s" % (Psnr, Ssim)] = data_path 329 | 330 | out = w_out.cpu().data[0] 331 | out = ToPILImage()(out) 332 | image_number = re.findall(r'\d+', data_path[0])[1] 333 | print(image_number) 334 | #out.save( "/home/ws/Desktop/derain2020/result_p/%s.jpg" % (image_number)) 335 | 336 | # if test_step % opt.test_print_fre == 0: 337 | # print("epoch={} Psnr={} Ssim={} mseloss{}".format(epoch, Psnr, Ssim, new_loss.item())) 338 | # test_Psnr_loss.append(test_Psnr_sum / test_step) 339 | # test_Ssim_loss.append(test_Ssim_sum / test_step) 340 | 341 | else: 342 | #del new_loss 343 | print("epoch={} avr_Psnr ={} avr_Ssim={}".format(epoch, test_Psnr_sum / test_step, 344 | test_Ssim_sum / test_step)) 345 | write_test_perform("./perform_test.txt", test_Psnr_sum / test_step, test_Ssim_sum / test_step) 346 | # visdom showing 347 | print("---->testing over show in visdom") 348 | display_Psnr_Ssim(Psnr=test_Psnr_sum / test_step, Ssim=test_Ssim_sum / test_step, v_epoch=epoch, 349 | env="derain_test") 350 | 351 | print("epoch {} train over-----> save net".format(epoch)) 352 | print("saving checkpoint save_root{}".format(opt.save_root)) 353 | if os.path.exists(opt.save_root): 354 | save_checkpoint(root=opt.save_root, model=w, epoch=epoch, model_stage="w_p") 355 | 356 | print("finish save epoch{} checkporint".format({epoch})) 357 | else: 358 | raise 
Exception("saveroot :{} not found , Checkout it".format(opt.save_root)) 359 | # 360 | 361 | print("all epoch is over ------ ") 362 | print("show epoch and epoch_loss in visdom") 363 | 364 | if __name__ == "__main__": 365 | os.system('clear') 366 | main() --------------------------------------------------------------------------------