├── model
│   ├── __init__.py
│   ├── vgg.pyc
│   ├── __init__.pyc
│   ├── generator.pyc
│   ├── vgg.py
│   └── generator.py
├── models
│   ├── __init__.py
│   ├── NEDB.pyc
│   ├── RNEDB.pyc
│   ├── __init__.pyc
│   ├── networks.pyc
│   ├── non_local_block.pyc
│   ├── region_non_local_block.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── networks.cpython-37.pyc
│   │   └── derain_dense.cpython-36.pyc
│   ├── region_non_local_block.py
│   ├── RNEDB.py
│   ├── NEDB.py
│   ├── NLEDN.py
│   └── non_local_block.py
├── util_metirc
│   ├── __init__.py
│   ├── util.py
│   ├── html.py
│   └── visualizer.py
├── datasets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── my_loader.pyc
│   ├── classification.py
│   ├── pix2pix_class.py
│   ├── pix2pix2.py
│   ├── pix2pix_val.py
│   └── pix2pix.py
├── myutils
│   ├── __init__.py
│   ├── utils.pyc
│   ├── vgg16.pyc
│   ├── __init__.pyc
│   ├── StyleLoader.py
│   ├── vgg16.py
│   └── utils.py
├── transforms
│   ├── __init__.py
│   ├── pix2pix.pyc
│   ├── __init__.pyc
│   ├── pix2pix_val.py
│   ├── pix2pix3.py
│   └── pix2pix_val3.py
├── misc.pyc
├── util.pyc
├── visualizer.pyc
├── models_metric
│   ├── __init__.pyc
│   ├── base_model.pyc
│   ├── dist_model.pyc
│   ├── networks_basic.pyc
│   ├── weights
│   │   ├── v0.0
│   │   │   ├── alex.pth
│   │   │   ├── vgg.pth
│   │   │   └── squeeze.pth
│   │   └── v0.1
│   │       ├── alex.pth
│   │       ├── vgg.pth
│   │       └── squeeze.pth
│   ├── pretrained_networks.pyc
│   ├── base_model.py
│   ├── __init__.py
│   ├── pretrained_networks.py
│   ├── networks_basic.py
│   └── dist_model.py
├── __pycache__
│   └── misc.cpython-37.pyc
├── requirements.txt
├── run_GDN.sh
├── run_FDN_FRN.sh
├── run_test.sh
├── run_LRN.sh
├── README.md
├── util.py
├── rank2_edge_batch.py
├── misc.py
├── visualizer.py
├── test.py
├── train_GDN.py
├── list_7000_f1000.txt
└── train_LRN.py
/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util_metirc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /myutils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /misc.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/misc.pyc -------------------------------------------------------------------------------- /util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/util.pyc -------------------------------------------------------------------------------- /model/vgg.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/vgg.pyc -------------------------------------------------------------------------------- /models/NEDB.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/NEDB.pyc -------------------------------------------------------------------------------- /visualizer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/visualizer.pyc -------------------------------------------------------------------------------- /models/RNEDB.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/RNEDB.pyc -------------------------------------------------------------------------------- /myutils/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/utils.pyc -------------------------------------------------------------------------------- /myutils/vgg16.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/vgg16.pyc -------------------------------------------------------------------------------- /model/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/__init__.pyc -------------------------------------------------------------------------------- /model/generator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/generator.pyc -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__init__.pyc -------------------------------------------------------------------------------- /models/networks.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/networks.pyc -------------------------------------------------------------------------------- /myutils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/__init__.pyc -------------------------------------------------------------------------------- /datasets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/datasets/__init__.pyc -------------------------------------------------------------------------------- /datasets/my_loader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/datasets/my_loader.pyc -------------------------------------------------------------------------------- /transforms/pix2pix.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/transforms/pix2pix.pyc -------------------------------------------------------------------------------- /transforms/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/transforms/__init__.pyc 
-------------------------------------------------------------------------------- /models/non_local_block.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/non_local_block.pyc -------------------------------------------------------------------------------- /models_metric/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/__init__.pyc -------------------------------------------------------------------------------- /models_metric/base_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/base_model.pyc -------------------------------------------------------------------------------- /models_metric/dist_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/dist_model.pyc -------------------------------------------------------------------------------- /__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /models_metric/networks_basic.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/networks_basic.pyc -------------------------------------------------------------------------------- /models/region_non_local_block.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/region_non_local_block.pyc -------------------------------------------------------------------------------- /models_metric/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /models_metric/pretrained_networks.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/pretrained_networks.pyc -------------------------------------------------------------------------------- 
/models_metric/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/derain_dense.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/derain_dense.cpython-36.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | Pillow 4 | scipy 5 | scikit-image 6 | torch==1.0.1 7 | torchvision==0.2.0 8 | torchfile 9 | tqdm 10 | 11 | -------------------------------------------------------------------------------- /run_GDN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1 python2.7 train_GDN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "/media/he/80FE99D1FE99BFB8/longmao_final/val" --pre "" --name "GDN" --exp "GDN" --display_port 8099 --originalSize_h 420 --originalSize_w 420 --imageSize_h 384 --imageSize_w 384 --batchSize 2 3 | -------------------------------------------------------------------------------- /run_FDN_FRN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python2.7 train_FDN_FRN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "" --pre "" --name "FDN_FRN" --exp "FDN_FRN" --netGDN "./ckpt/netGDN.pth" --netLRN "ckpt/netLRN.pth" --display_port 8099 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1024 --imageSize_w 1024 --batchSize 1 3 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=1 python2 test.py --dataroot "/media/he/80FE99D1FE99BFB8/FHDMi/test" --netGDN "ckpt/netGDN.pth" --netLRN "ckpt/netLRN.pth" --netFDN "ckpt/netFDN.pth" --netFRN "ckpt/netFRN.pth" 
--batchSize 2 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1080 --imageSize_w 1920 --image_path "results" --write 1 --record "results.txt" 3 | 4 | -------------------------------------------------------------------------------- /run_LRN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1 python2.7 train_LRN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "/media/he/80FE99D1FE99BFB8/longmao_final/val" --pre "" --name "LRN" --exp "LRN" --netGDN 'ckpt/netGDN.pth' --display_port 8099 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1024 --imageSize_w 1024 --batchSize 2 --list_file 'list_7000_f1000.txt' 3 | -------------------------------------------------------------------------------- /models/region_non_local_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from non_local_block import NONLocalBlock2D 4 | 5 | 6 | class RegionNONLocalBlock(nn.Module): 7 | def __init__(self, in_channels, grid=[6, 6]): 8 | super(RegionNONLocalBlock, self).__init__() 9 | 10 | self.non_local_block = NONLocalBlock2D(in_channels, sub_sample=True, bn_layer=False) 11 | self.grid = grid 12 | 13 | def forward(self, x): 14 | batch_size, _, height, width = x.size() 15 | 16 | input_row_list = x.chunk(self.grid[0], dim=2) 17 | 18 | output_row_list = [] 19 | for i, row in enumerate(input_row_list): 20 | input_grid_list_of_a_row = row.chunk(self.grid[1], dim=3) 21 | output_grid_list_of_a_row = [] 22 | 23 | for j, grid in enumerate(input_grid_list_of_a_row): 24 | grid = self.non_local_block(grid) 25 | output_grid_list_of_a_row.append(grid) 26 | 27 | output_row = torch.cat(output_grid_list_of_a_row, dim=3) 28 | output_row_list.append(output_row) 29 | 30 | output = torch.cat(output_row_list, dim=2) 31 | return output 32 | -------------------------------------------------------------------------------- /myutils/StyleLoader.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | from torch.autograd import Variable 13 | 14 | from myutils import utils 15 | 16 | class StyleLoader(): 17 | def __init__(self, style_folder, style_size, cuda=True): 18 | self.folder = style_folder 19 | self.style_size = style_size 20 | self.files = os.listdir(style_folder) 21 | self.cuda = cuda 22 | 23 | def get(self, i): 24 | idx = i%len(self.files) 25 | filepath = os.path.join(self.folder, self.files[idx]) 26 | style = utils.tensor_load_rgbimage(filepath, self.style_size) 27 | style = style.unsqueeze(0) 28 | style = utils.preprocess_batch(style) 29 | if self.cuda: 30 | style = style.cuda() 31 | style_v = Variable(style, requires_grad=False) 32 | return style_v 33 | 34 | def size(self): 35 | return len(self.files) 36 | 37 | 38 | -------------------------------------------------------------------------------- /util_metirc/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 
2 | 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | def load_image(path): 11 | if(path[-3:] == 'dng'): 12 | import rawpy 13 | with rawpy.imread(path) as raw: 14 | img = raw.postprocess() 15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 16 | import cv2 17 | return cv2.imread(path)[:,:,::-1] 18 | else: 19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 20 | 21 | return img 22 | 23 | def save_image(image_numpy, image_path, ): 24 | image_pil = Image.fromarray(image_numpy) 25 | image_pil.save(image_path) 26 | 27 | def mkdirs(paths): 28 | if isinstance(paths, list) and not isinstance(paths, str): 29 | for path in paths: 30 | mkdir(path) 31 | else: 32 | mkdir(paths) 33 | 34 | def mkdir(path): 35 | if not os.path.exists(path): 36 | os.makedirs(path) 37 | 38 | 39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 41 | image_numpy = image_tensor[0].cpu().float().numpy() 42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 43 | return image_numpy.astype(imtype) 44 | 45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 47 | return torch.Tensor((image / factor - cent) 48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 49 | -------------------------------------------------------------------------------- /models/RNEDB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from region_non_local_block import RegionNONLocalBlock 4 | 5 | 6 | class RNEDB(nn.Module): 7 | def __init__(self, block_num=3, inter_channel=32, channel=64, grid=[8, 8]): 8 | super(RNEDB, self).__init__() 9 | 10 | concat_channels = channel + block_num * inter_channel 11 | channels_now = channel 12 | 13 | self.region_non_local = RegionNONLocalBlock(channels_now, grid=grid) 14 | 15 | self.group_list = [] 16 | for i in range(block_num): 17 | group = nn.Sequential( 18 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 19 | stride=1, padding=1), 20 | nn.ReLU(), 21 | ) 22 | self.add_module(name='group_%d' % i, module=group) 23 | self.group_list.append(group) 24 | 25 | channels_now += inter_channel 26 | 27 | assert channels_now == concat_channels 28 | self.fusion = nn.Sequential( 29 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 30 | ) 31 | 32 | def forward(self, x): 33 | x_rnl = self.region_non_local(x) 34 | feature_list = [x_rnl,] 35 | 36 | for group in self.group_list: 37 | inputs = torch.cat(feature_list, dim=1) 38 | outputs = group(inputs) 39 | feature_list.append(outputs) 40 | 41 | inputs = torch.cat(feature_list, dim=1) 42 | fusion_outputs = self.fusion(inputs) 43 | 44 | block_outputs = fusion_outputs + x 45 | 46 | return block_outputs 47 | -------------------------------------------------------------------------------- /models/NEDB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from non_local_block import NONLocalBlock2D 4 | 5 | 6 | class NEDB(nn.Module): 7 | def __init__(self, block_num=3, inter_channel=32, channel=64): 8 | super(NEDB, self).__init__() 9 | 10 | concat_channels = channel + block_num * inter_channel 11 | channels_now = channel 12 | 13 | self.non_local = NONLocalBlock2D(channels_now, 
bn_layer=False) 14 | 15 | self.group_list = [] 16 | for i in range(block_num): 17 | group = nn.Sequential( 18 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 19 | stride=1, padding=1), 20 | nn.ReLU(), 21 | ) 22 | self.add_module(name='group_%d' % i, module=group) 23 | self.group_list.append(group) 24 | 25 | channels_now += inter_channel 26 | 27 | assert channels_now == concat_channels 28 | self.fusion = nn.Sequential( 29 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 30 | ) 31 | 32 | def forward(self, x, corr1, corr2): 33 | x_nl, ret_corr1, ret_corr2 = self.non_local(x, corr1, corr2) 34 | feature_list = [x_nl,] 35 | 36 | for group in self.group_list: 37 | inputs = torch.cat(feature_list, dim=1) 38 | outputs = group(inputs) 39 | feature_list.append(outputs) 40 | 41 | inputs = torch.cat(feature_list, dim=1) 42 | fusion_outputs = self.fusion(inputs) 43 | 44 | block_outputs = fusion_outputs + x 45 | 46 | return block_outputs, ret_corr1, ret_corr2 47 | -------------------------------------------------------------------------------- /models_metric/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | # from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /myutils/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Vgg16(torch.nn.Module): 7 | def __init__(self): 8 | super(Vgg16, self).__init__() 9 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 10 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 11 | 12 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 13 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 14 | 15 | self.conv3_1 = nn.Conv2d(128, 256, 
kernel_size=3, stride=1, padding=1) 16 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 17 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 18 | 19 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 20 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 21 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 22 | 23 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 24 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 25 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 26 | 27 | def forward(self, X): 28 | h = F.relu(self.conv1_1(X)) 29 | h = F.relu(self.conv1_2(h)) 30 | relu1_2 = h 31 | h = F.max_pool2d(h, kernel_size=2, stride=2) 32 | 33 | h = F.relu(self.conv2_1(h)) 34 | h = F.relu(self.conv2_2(h)) 35 | relu2_2 = h 36 | h = F.max_pool2d(h, kernel_size=2, stride=2) 37 | 38 | h = F.relu(self.conv3_1(h)) 39 | h = F.relu(self.conv3_2(h)) 40 | h = F.relu(self.conv3_3(h)) 41 | relu3_3 = h 42 | h = F.max_pool2d(h, kernel_size=2, stride=2) 43 | 44 | h = F.relu(self.conv4_1(h)) 45 | h = F.relu(self.conv4_2(h)) 46 | h = F.relu(self.conv4_3(h)) 47 | relu4_3 = h 48 | 49 | return [relu1_2, relu2_2, relu3_3, relu4_3] 50 | -------------------------------------------------------------------------------- /datasets/classification.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import h5py 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class classification(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None): 34 | # imgs = make_dataset(root) 35 | # if len(imgs) == 0: 36 | # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | # self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, _): 47 | index = np.random.randint(1,self.__len__()) 48 | # path = self.imgs[index] 49 | # img = self.loader(path) 50 | #img = img.resize((w, h), Image.BILINEAR) 51 | 52 | 53 | 54 | file_name=self.root+'/'+str(index)+'.h5' 55 | f=h5py.File(file_name,'r') 56 | 57 | haze_image=f['haze'][:] 58 | label=f['label'][:] 59 | label=label.mean()-1 60 | 61 | haze_image=np.swapaxes(haze_image,0,2) 62 | haze_image=np.swapaxes(haze_image,1,2) 63 | 64 | 65 | # if self.transform is not None: 66 | # # NOTE preprocessing for each pair of images 67 | # imgA, imgB = self.transform(imgA, imgB) 68 | return haze_image, label 69 | 70 | def __len__(self): 71 | train_list=glob.glob(self.root+'/*h5') 72 | # print len(train_list) 73 | return 
len(train_list) 74 | 75 | # return len(self.imgs) 76 | -------------------------------------------------------------------------------- /util_metirc/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | # self.img_dir = os.path.join(self.web_dir, ) 11 | self.img_subdir = image_subdir 12 | self.img_dir = os.path.join(self.web_dir, image_subdir) 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join(link)): 43 | img(style="width:%dpx" % width, src=os.path.join(im)) 44 | br() 45 | p(txt) 46 | 47 | def save(self,file='index'): 48 | html_file = '%s/%s.html' % (self.web_dir,file) 49 | f = open(html_file, 'wt') 50 | f.write(self.doc.render()) 51 | f.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | html = HTML('web/', 'test_html') 56 | html.add_header('hello world') 57 | 58 | ims = [] 59 | txts = [] 60 | links = [] 61 | for n in range(4): 62 | ims.append('image_%d.png' % n) 63 | txts.append('text_%d' % n) 64 | links.append('image_%d.png' % n) 65 | html.add_images(ims, txts, links) 66 | html.save() 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FHDe2Net 2 | Official implementation of "FHDe2Net: Full High Definition Demoireing Network" (ECCV 2020) 3 | 4 | ## Prerequisites: 5 | 1. Linux 6 | 2. Python 2 or 3 7 | 3. NVIDIA GPU + CUDA CuDNN (CUDA 8.0) 8 | 9 | ## Installation: 10 | 1. Install PyTorch from http://pytorch.org 11 | 2. Install torchvision from https://github.com/pytorch/vision 12 | 3. Install the Python packages: numpy, scipy, PIL, math, skimage, visdom 13 | 14 | ## Download the FHDMi dataset 15 | You can download the training and testing dataset from 16 | https://pan.baidu.com/s/19LTN7unSBAftSpNVs8x9ZQ with password jf2d 17 | or 18 | https://drive.google.com/drive/folders/1IJSeBXepXFpNAvL5OyZ2Y1yu4KPvDxN5?usp=sharing 19 | You can accelerate training with a subset of FHDMi, described by FHDMi_thin.txt. 20 | 21 | ## Testing 22 | 1) Download pre-trained models 23 | You can download the pre-trained models from 24 | https://pan.baidu.com/s/14fo4gdBtx4GDohNNyYObpg with password t8vn 25 | All the models should be placed in the ckpt folder.
26 | 27 | 2) Set up the testing environment 28 | You can set up the testing environment with: 29 | `pip install -r requirements.txt` 30 | 31 | 3) Testing 32 | Specify --dataroot as the testing dataset path, and run 33 | `bash run_test.sh` 34 | 35 | ## Training 36 | 1) Download the VGG19 ckpt from 37 | https://pan.baidu.com/s/1c3eEh29uAfZTzTe0X9Jz_Q with password: zvcy 38 | And put it in models/ 39 | 40 | 2) Open visdom by running 41 | `python -m visdom.server -port 8099` 42 | 43 | 3) Change the dataroot in run_GDN.sh and train GDN by running 44 | `bash run_GDN.sh` 45 | 46 | 4) Change the dataroot in run_LRN.sh and train LRN by running 47 | `bash run_LRN.sh` 48 | (For training LRN, you can either use the distilled dataset generated by rank2_edge_batch.py or directly use the whole dataset. 49 | A subset image list for FHDMi has already been generated in list_7000_f1000.txt.) 50 | 51 | 5) Change the dataroot in run_FDN_FRN.sh and train FDN and FRN by running 52 | `bash run_FDN_FRN.sh` 53 | 54 | ## Citation 55 | ``` 56 | @article{hefhde2net, 57 | title={FHDe2Net: Full High Definition Demoireing Network}, 58 | author={He, Bin and Wang, Ce and Shi, Boxin and Duan, Ling-Yu}, 59 | publisher={Springer} 60 | } 61 | ``` 62 | ## Contact 63 | If you have any questions, please feel free to contact me at 1801213742@pku.edu.cn 64 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | def my_tensor2im(input_image, imtype=np.uint8): 21 | if isinstance(input_image, torch.Tensor): 22 | image_tensor = input_image.data 23 | else: 24 | return input_image 25 | # print(50*'-') 26 | # print(input_image.shape) 27 | # print(image_tensor.shape) 28 | image_numpy = image_tensor.cpu().float().numpy() 29 | if image_numpy.shape[0] == 1: 30 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 31 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 32 | return image_numpy.astype(imtype) 33 | 34 | 35 | def diagnose_network(net, name='network'): 36 | mean = 0.0 37 | count = 0 38 | for param in net.parameters(): 39 | if param.grad is not None: 40 | mean += torch.mean(torch.abs(param.grad.data)) 41 | count += 1 42 | if count > 0: 43 | mean = mean / count 44 | print(name) 45 | print(mean) 46 | 47 | 48 | def save_image(image_numpy, image_path): 49 | image_pil = Image.fromarray(image_numpy) 50 | image_pil.save(image_path) 51 | 52 | 53 | def print_numpy(x, val=True, shp=False): 54 | x = x.astype(np.float64) 55 | if shp: 56 | print('shape,', x.shape) 57 | if val: 58 | x = x.flatten() 59 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 60 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 61 | 62 | 63 | def mkdirs(paths): 64 | if isinstance(paths, list) and not isinstance(paths, str): 65 | for path
in paths: 66 | mkdir(path) 67 | else: 68 | mkdir(paths) 69 | 70 | 71 | def mkdir(path): 72 | if not os.path.exists(path): 73 | os.makedirs(path) 74 | -------------------------------------------------------------------------------- /datasets/pix2pix_class.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class pix2pix(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None, pre=""): 34 | imgs = make_dataset(root) 35 | if len(imgs) == 0: 36 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, index): 47 | 48 | index_sub = np.random.randint(0, 3) 49 | label=index_sub 50 | 51 | 52 | 53 | if index_sub==0: 54 | index = np.random.randint(0,4000) 55 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' 56 | 57 | 58 | if index_sub==1: 59 | index = np.random.randint(0,4000) 60 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' 61 | 62 | if index_sub==2: 63 | index = np.random.randint(0,4000) 64 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Light/train2018new'+'/'+str(index)+'.jpg' 65 | 66 | 67 | 68 | img = self.loader(path) 69 | 70 | 71 | w, h = img.size 72 | 73 | 74 | # NOTE: split a sample into imgA and imgB 75 | imgA = img.crop((0, 0, w/2, h)) 76 | imgB = img.crop((w/2, 0, w, h)) 77 | 78 | 79 | if self.transform is not None: 80 | # NOTE preprocessing for each pair of images 81 | imgA, imgB = self.transform(imgA, imgB) 82 | 83 | return imgA, imgB, label 84 | 85 | def __len__(self): 86 | # return 679 87 | # print(len(self.imgs)) 88 | return len(self.imgs) 89 | -------------------------------------------------------------------------------- /models/NLEDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from NEDB import NEDB 4 | from RNEDB import RNEDB 5 | 6 | 7 | class NLEDN(nn.Module): 8 | def __init__(self): 9 | super(NLEDN, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1) 12 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1) 13 | 14 | self.up_1 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[8, 8]) 15 | self.up_2 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[4, 4]) 16 | self.up_3 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[2, 2]) 17 | 18 | self.down_3 = NEDB(block_num=4, inter_channel=32, channel=64) 19 | self.down_2 = 
RNEDB(block_num=4, inter_channel=32, channel=64, grid=[2, 2]) 20 | self.down_1 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[4, 4]) 21 | 22 | self.down_2_fusion = nn.Conv2d(64 + 64, 64, 1, 1, 0) 23 | self.down_1_fusion = nn.Conv2d(64 + 64, 64, 1, 1, 0) 24 | 25 | self.fusion = nn.Sequential( 26 | nn.Conv2d(64 * 3, 64, 1, 1, 0), 27 | nn.Conv2d(64, 64, 3, 1, 1), 28 | ) 29 | 30 | self.final_conv = nn.Sequential( 31 | nn.Conv2d(64, 3, 3, 1, 1), 32 | nn.Tanh(), 33 | ) 34 | 35 | def forward(self, x): 36 | feature_neg_1 = self.conv1(x) 37 | feature_0 = self.conv2(feature_neg_1) 38 | 39 | up_1_banch = self.up_1(feature_0) 40 | up_1, indices_1 = nn.MaxPool2d(2, 2, return_indices=True)(up_1_banch) 41 | 42 | up_2 = self.up_2(up_1) 43 | up_2, indices_2 = nn.MaxPool2d(2, 2, return_indices=True)(up_2) 44 | 45 | up_3 = self.up_3(up_2) 46 | up_3, indices_3 = nn.MaxPool2d(2, 2, return_indices=True)(up_3) 47 | 48 | down_3 = self.down_3(up_3) 49 | 50 | down_3 = nn.MaxUnpool2d(2, 2)(down_3, indices_3, output_size=up_2.size()) 51 | 52 | down_3 = torch.cat([up_2, down_3], dim=1) 53 | down_3 = self.down_2_fusion(down_3) 54 | down_2 = self.down_2(down_3) 55 | 56 | down_2 = nn.MaxUnpool2d(2, 2)(down_2, indices_2, output_size=up_1.size()) 57 | 58 | down_2 = torch.cat([up_1, down_2], dim=1) 59 | down_2 = self.down_1_fusion(down_2) 60 | down_1 = self.down_1(down_2) 61 | down_1 = nn.MaxUnpool2d(2, 2)(down_1, indices_1, output_size=feature_0.size()) 62 | 63 | down_1 = torch.cat([feature_0, down_1], dim=1) 64 | 65 | cat_block_feature = torch.cat([down_1, up_1_banch], 1) 66 | feature = self.fusion(cat_block_feature) 67 | feature = feature + feature_neg_1 68 | 69 | outputs = self.final_conv(feature) 70 | 71 | return outputs 72 | -------------------------------------------------------------------------------- /datasets/pix2pix2.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import h5py 7 | import glob 8 | import scipy.ndimage 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class pix2pix(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None): 34 | # imgs = make_dataset(root) 35 | # if len(imgs) == 0: 36 | # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | # self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, _): 47 | index = np.random.randint(1,self.__len__()) 48 | # index = np.random.randint(self.__len__(), size=1)[0] 49 | 50 | # path = self.imgs[index] 51 | # img = self.loader(path) 52 | #img = img.resize((w, h), Image.BILINEAR) 53 | 54 | 55 | 56 | file_name=self.root+'/'+str(index)+'.h5' 57 | 
f=h5py.File(file_name,'r') 58 | 59 | haze_image=f['haze'][:] 60 | trans_map=f['trans'][:] 61 | ato_map=f['ato'][:] 62 | GT=f['gt'][:] 63 | 64 | 65 | 66 | haze_image=np.swapaxes(haze_image,0,2) 67 | trans_map=np.swapaxes(trans_map,0,2) 68 | ato_map=np.swapaxes(ato_map,0,2) 69 | GT=np.swapaxes(GT,0,2) 70 | 71 | 72 | 73 | haze_image=np.swapaxes(haze_image,1,2) 74 | trans_map=np.swapaxes(trans_map,1,2) 75 | ato_map=np.swapaxes(ato_map,1,2) 76 | GT=np.swapaxes(GT,1,2) 77 | 78 | # if np.random.uniform()>0.5: 79 | # haze_image=np.flip(haze_image,2).copy() 80 | # GT = np.flip(GT, 2).copy() 81 | # trans_map=np.flip(trans_map, 2).copy() 82 | # if np.random.uniform()>0.5: 83 | # angle = np.random.uniform(-10, 10) 84 | # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) 85 | # GT = scipy.ndimage.interpolation.rotate(GT, angle) 86 | 87 | # if np.random.uniform()>0.5: 88 | # angle = np.random.uniform(-10, 10) 89 | # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) 90 | # GT = scipy.ndimage.interpolation.rotate(GT, angle) 91 | 92 | # if np.random.uniform()>0.5: 93 | # std = np.random.uniform(0.2, 1.2) 94 | # haze_image = scipy.ndimage.filters.gaussian_filter(haze_image, std,mode='constant') 95 | 96 | # haze_image=np.random.uniform(-10/5000,10/5000,size=haze_image.shape) 97 | # haze_image = np.maximum(0, haze_image) 98 | 99 | # if self.transform is not None: 100 | # # NOTE preprocessing for each pair of images 101 | # imgA, imgB = self.transform(imgA, imgB) 102 | return haze_image, GT, trans_map, ato_map 103 | 104 | def __len__(self): 105 | train_list=glob.glob(self.root+'/*h5') 106 | # print len(train_list) 107 | return len(train_list) 108 | 109 | # return len(self.imgs) 110 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def conv2d(in_channels, out_channels): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 8 | nn.ReLU(inplace=True)) 9 | 10 | 11 | def conv(in_channels, out_channels): 12 | return nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 13 | 14 | 15 | class VGG19(nn.Module): 16 | def __init__(self): 17 | super(VGG19, self).__init__() 18 | self.mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(2).unsqueeze(3).cuda() 19 | self.std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(2).unsqueeze(3).cuda() 20 | self.conv1_1 = nn.Sequential(conv(3, 64), nn.ReLU()) 21 | self.conv1_2 = nn.Sequential(conv(64, 64), nn.ReLU()) 22 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 23 | self.conv2_1 = nn.Sequential(conv(64, 128), nn.ReLU()) 24 | self.conv2_2 = nn.Sequential(conv(128, 128), nn.ReLU()) 25 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 26 | self.conv3_1 = nn.Sequential(conv(128, 256), nn.ReLU()) 27 | self.conv3_2 = nn.Sequential(conv(256, 256), nn.ReLU()) 28 | self.conv3_3 = nn.Sequential(conv(256, 256), nn.ReLU()) 29 | self.conv3_4 = nn.Sequential(conv(256, 256), nn.ReLU()) 30 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 31 | self.conv4_1 = nn.Sequential(conv(256, 512), nn.ReLU()) 32 | self.conv4_2 = nn.Sequential(conv(512, 512), nn.ReLU()) 33 | 34 | def load_model(self, model_file): 35 | vgg19_dict = self.state_dict() 36 | pretrained_dict = 
torch.load(model_file) 37 | vgg19_keys = vgg19_dict.keys() 38 | pretrained_keys = pretrained_dict.keys() 39 | for k, pk in zip(vgg19_keys, pretrained_keys): 40 | vgg19_dict[k] = pretrained_dict[pk] 41 | self.load_state_dict(vgg19_dict) 42 | 43 | # def forward(self, input_images): 44 | # # print(self.mean) 45 | # # input_images = (input_images - self.mean) / self.std 46 | # feature = {} 47 | # feature['conv1_1'] = self.conv1_1(input_images) 48 | # feature['conv1_2'] = self.conv1_2(feature['conv1_1']) 49 | # feature['pool1'] = self.pool1(feature['conv1_2']) 50 | # feature['conv2_1'] = self.conv2_1(feature['pool1']) 51 | # feature['conv2_2'] = self.conv2_2(feature['conv2_1']) 52 | # feature['pool2'] = self.pool2(feature['conv2_2']) 53 | # feature['conv3_1'] = self.conv3_1(feature['pool2']) 54 | # feature['conv3_2'] = self.conv3_2(feature['conv3_1']) 55 | # feature['conv3_3'] = self.conv3_3(feature['conv3_2']) 56 | # feature['conv3_4'] = self.conv3_4(feature['conv3_3']) 57 | # feature['pool3'] = self.pool3(feature['conv3_4']) 58 | # feature['conv4_1'] = self.conv4_1(feature['pool3']) 59 | # feature['conv4_2'] = self.conv4_2(feature['conv4_1']) 60 | # 61 | # return feature 62 | def forward(self, input_images): 63 | # print(self.mean) 64 | # input_images = (input_images - self.mean) / self.std 65 | feature = {} 66 | tmp = self.conv1_1(input_images) 67 | tmp = self.conv1_2(tmp) 68 | feature['conv1_2'] = tmp 69 | tmp = self.pool1(tmp) 70 | tmp = self.conv2_1(tmp) 71 | tmp = self.conv2_2(tmp) 72 | feature['conv2_2'] = tmp 73 | tmp = self.pool2(tmp) 74 | tmp = self.conv3_1(tmp) 75 | feature['conv3_2'] = self.conv3_2(tmp) 76 | tmp = self.conv3_3(feature['conv3_2']) 77 | feature['conv3_4'] = self.conv3_4(tmp) 78 | # tmp = self.conv3_3(feature['conv3_2']) 79 | # tmp = self.conv3_4(tmp) 80 | # tmp = self.pool3(feature['conv3_4']) 81 | # feature['conv4_1'] = self.conv4_1(feature['pool3']) 82 | # feature['conv4_2'] = self.conv4_2(feature['conv4_1']) 83 | 84 | return feature 85 | -------------------------------------------------------------------------------- /datasets/pix2pix_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | # from guidedfilter import guidedfilter 8 | # import guidedfilter.guidedfilter as guidedfilter 9 | 10 | 11 | 12 | 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 17 | ] 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | def make_dataset(dir): 23 | images = [] 24 | if not os.path.isdir(dir): 25 | raise Exception('Check dataroot') 26 | for root, _, fnames in sorted(os.walk(dir)): 27 | for fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(dir, fname) 30 | item = path 31 | images.append(item) 32 | return images 33 | 34 | def default_loader(path): 35 | return Image.open(path).convert('RGB') 36 | 37 | class pix2pix_val(data.Dataset): 38 | def __init__(self, root, transform=None, loader=default_loader, seed=None, pre =""): 39 | imgs = make_dataset(root) 40 | if len(imgs) == 0: 41 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 42 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 43 | self.root = root 44 | self.imgs = imgs 45 | self.transform = transform 46 | self.loader = loader 47 | 48 | if seed is not None: 49 | 
np.random.seed(seed) 50 | 51 | def __getitem__(self, index): 52 | # index = np.random.randint(self.__len__(), size=1)[0] 53 | # index = np.random.randint(self.__len__(), size=1)[0] 54 | 55 | path = self.imgs[index] 56 | 57 | path=self.root+'/'+str(index)+'.jpg' 58 | 59 | index_folder = np.random.randint(0,4) 60 | label=index_folder 61 | 62 | # path='/home/openset/Desktop/derain2018/facades/DB_Rain_test/Rain_Heavy/test2018'+'/'+str(index)+'.jpg' 63 | img = self.loader(path) 64 | 65 | # # NOTE: img -> PIL Image 66 | # w, h = img.size 67 | # w, h = 1024, 512 68 | # img = img.resize((w, h), Image.BILINEAR) 69 | # # NOTE: split a sample into imgA and imgB 70 | # imgA = img.crop((0, 0, w/2, h)) 71 | # imgB = img.crop((w/2, 0, w, h)) 72 | # if self.transform is not None: 73 | # # NOTE preprocessing for each pair of images 74 | # imgA, imgB = self.transform(imgA, imgB) 75 | # return imgA, imgB 76 | 77 | 78 | # w, h = 1536, 512 79 | # img = img.resize((w, h), Image.BILINEAR) 80 | # 81 | # 82 | # # NOTE: split a sample into imgA and imgB 83 | # imgA = img.crop((0, 0, w/3, h)) 84 | # imgC = img.crop((2*w/3, 0, w, h)) 85 | # 86 | # imgB = img.crop((w/3, 0, 2*w/3, h)) 87 | 88 | # w, h = 1024, 512 89 | # img = img.resize((w, h), Image.BILINEAR) 90 | # 91 | # r = 16 92 | # eps = 1 93 | # 94 | # # I = img.crop((0, 0, w/2, h)) 95 | # # pix = np.array(I) 96 | # # print 97 | # # base[idx,:,:,:]=guidedfilter(pix[], pix[], r, eps) 98 | # # base[]=guidedfilter(pix[], pix[], r, eps) 99 | # # base[]=guidedfilter(pix[], pix[], r, eps) 100 | # 101 | # 102 | # # base = PIL.Image.fromarray(numpy.uint8(base)) 103 | # 104 | # # NOTE: split a sample into imgA and imgB 105 | # imgA = img.crop((0, 0, w/3, h)) 106 | # imgC = img.crop((2*w/3, 0, w, h)) 107 | # 108 | # imgB = img.crop((w/3, 0, 2*w/3, h)) 109 | # imgA=base 110 | # imgB=I-base 111 | # imgC = img.crop((w/2, 0, w, h)) 112 | w, h = img.size 113 | # w, h = 586*2, 586 114 | 115 | # img = img.resize((w, h), Image.BILINEAR) 116 | 117 | 118 | # NOTE: split a sample into imgA and imgB 119 | imgA = img.crop((0, 0, w/2, h)) 120 | # imgC = img.crop((2*w/3, 0, w, h)) 121 | 122 | imgB = img.crop((w/2, 0, w, h)) 123 | 124 | 125 | if self.transform is not None: 126 | # NOTE preprocessing for each pair of images 127 | # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) 128 | imgA, imgB = self.transform(imgA, imgB) 129 | 130 | return imgA, imgB, path 131 | 132 | def __len__(self): 133 | return len(self.imgs) 134 | -------------------------------------------------------------------------------- /datasets/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | 9 | 10 | IMG_EXTENSIONS = [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 13 | ] 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def make_dataset(dir): 19 | images = [] 20 | if not os.path.isdir(dir): 21 | raise Exception('Check dataroot') 22 | for root, _, fnames in sorted(os.walk(dir)): 23 | for fname in fnames: 24 | if is_image_file(fname): 25 | path = os.path.join(dir, fname) 26 | item = path 27 | images.append(item) 28 | return images 29 | 30 | def default_loader(path): 31 | return Image.open(path).convert('RGB') 32 | 33 | class pix2pix(data.Dataset): 34 | def __init__(self, root, transform=None, loader=default_loader, 
seed=None, pre=""): 35 | imgs = make_dataset(root) 36 | if len(imgs) == 0: 37 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 38 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 39 | self.root = root 40 | self.imgs = imgs 41 | self.transform = transform 42 | self.loader = loader 43 | 44 | if seed is not None: 45 | np.random.seed(seed) 46 | 47 | def __getitem__(self, index): 48 | # index = np.random.randint(self.__len__(), size=1)[0] 49 | # index = np.random.randint(self.__len__(), size=1)[0]+1 50 | # index = np.random.randint(self.__len__(), size=1)[0] 51 | 52 | # index_folder = np.random.randint(1,4) 53 | index_folder = np.random.randint(0,1) 54 | 55 | index_sub = np.random.randint(2, 5) 56 | 57 | label=index_folder 58 | 59 | 60 | if index_folder==0: 61 | path='/home/openset/Desktop/derain2018/facades/training2'+'/'+str(index)+'.jpg' 62 | 63 | 64 | 65 | if index_folder==1: 66 | if index_sub<4: 67 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' 68 | if index_sub==4: 69 | index = np.random.randint(0,400) 70 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Heavy/trainnew'+'/'+str(index)+'.jpg' 71 | 72 | if index_folder==2: 73 | if index_sub<4: 74 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' 75 | if index_sub==4: 76 | index = np.random.randint(0,400) 77 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Medium/trainnew'+'/'+str(index)+'.jpg' 78 | 79 | if index_folder==3: 80 | if index_sub<4: 81 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Light/train2018new'+'/'+str(index)+'.jpg' 82 | if index_sub==4: 83 | index = np.random.randint(0,400) 84 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Light/trainnew'+'/'+str(index)+'.jpg' 85 | 86 | 87 | 88 | # img = self.loader(path) 89 | 90 | img = self.loader(path) 91 | 92 | # NOTE: img -> PIL Image 93 | # w, h = img.size 94 | # w, h = 1024, 512 95 | # img = img.resize((w, h), Image.BILINEAR) 96 | # pix = np.array(I) 97 | # 98 | # r = 16 99 | # eps = 1 100 | # 101 | # I = img.crop((0, 0, w/2, h)) 102 | # pix = np.array(I) 103 | # base=guidedfilter(pix, pix, r, eps) 104 | # base = PIL.Image.fromarray(numpy.uint8(base)) 105 | # 106 | # 107 | # 108 | # imgA=base 109 | # imgB=I-base 110 | # imgC = img.crop((w/2, 0, w, h)) 111 | 112 | w, h = img.size 113 | # img = img.resize((w, h), Image.BILINEAR) 114 | 115 | 116 | # NOTE: split a sample into imgA and imgB 117 | imgA = img.crop((0, 0, w/2, h)) 118 | # imgC = img.crop((2*w/3, 0, w, h)) 119 | 120 | imgB = img.crop((w/2, 0, w, h)) 121 | 122 | 123 | if self.transform is not None: 124 | # NOTE preprocessing for each pair of images 125 | # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) 126 | imgA, imgB = self.transform(imgA, imgB) 127 | 128 | return imgA, imgB, label 129 | 130 | def __len__(self): 131 | # return 679 132 | print(len(self.imgs)) 133 | return len(self.imgs) 134 | -------------------------------------------------------------------------------- /myutils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.autograd import Variable 7 | # from torch.utils.serialization import load_lua 8 | import torchfile 9 | from myutils.vgg16 import Vgg16 10 | 11 | from torch.optim import lr_scheduler 12 | def tensor_load_rgbimage(filename, size=None, 
scale=None, keep_asp=False): 13 | img = Image.open(filename).convert('RGB') 14 | if size is not None: 15 | if keep_asp: 16 | size2 = int(size * 1.0 / img.size[0] * img.size[1]) 17 | img = img.resize((size, size2), Image.ANTIALIAS) 18 | else: 19 | img = img.resize((size, size), Image.ANTIALIAS) 20 | 21 | elif scale is not None: 22 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS) 23 | img = np.array(img).transpose(2, 0, 1) 24 | img = torch.from_numpy(img).float() 25 | return img 26 | 27 | 28 | def tensor_save_rgbimage(tensor, filename, cuda=False): 29 | if cuda: 30 | img = tensor.clone().cpu().clamp(0, 255).numpy() 31 | else: 32 | img = tensor.clone().clamp(0, 255).numpy() 33 | img = img.transpose(1, 2, 0).astype('uint8') 34 | img = Image.fromarray(img) 35 | img.save(filename) 36 | 37 | 38 | def tensor_save_bgrimage(tensor, filename, cuda=False): 39 | (b, g, r) = torch.chunk(tensor, 3) 40 | tensor = torch.cat((r, g, b)) 41 | tensor_save_rgbimage(tensor, filename, cuda) 42 | 43 | 44 | def gram_matrix(y): 45 | (b, ch, h, w) = y.size() 46 | features = y.view(b, ch, w * h) 47 | features_t = features.transpose(1, 2) 48 | gram = features.bmm(features_t) / (ch * h * w) 49 | return gram 50 | 51 | 52 | def subtract_imagenet_mean_batch(batch): 53 | """Subtract ImageNet mean pixel-wise from a BGR image.""" 54 | tensortype = type(batch.data) 55 | mean = tensortype(batch.data.size()) 56 | mean[:, 0, :, :] = 103.939 57 | mean[:, 1, :, :] = 116.779 58 | mean[:, 2, :, :] = 123.680 59 | return batch - Variable(mean) 60 | 61 | 62 | def add_imagenet_mean_batch(batch): 63 | """Add ImageNet mean pixel-wise from a BGR image.""" 64 | tensortype = type(batch.data) 65 | mean = tensortype(batch.data.size()) 66 | mean[:, 0, :, :] = 103.939 67 | mean[:, 1, :, :] = 116.779 68 | mean[:, 2, :, :] = 123.680 69 | return batch + Variable(mean) 70 | 71 | def imagenet_clamp_batch(batch, low, high): 72 | batch[:,0,:,:].data.clamp_(low-103.939, high-103.939) 73 | batch[:,1,:,:].data.clamp_(low-116.779, high-116.779) 74 | batch[:,2,:,:].data.clamp_(low-123.680, high-123.680) 75 | 76 | 77 | def preprocess_batch(batch): 78 | batch = batch.transpose(0, 1) 79 | (r, g, b) = torch.chunk(batch, 3) 80 | batch = torch.cat((b, g, r)) 81 | batch = batch.transpose(0, 1) 82 | return batch 83 | 84 | 85 | def init_vgg16(model_folder): 86 | """load the vgg16 model feature""" 87 | if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')): 88 | if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')): 89 | os.system( 90 | 'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7')) 91 | vgglua = torchfile.load(os.path.join(model_folder, 'vgg16.t7')) 92 | vgg = Vgg16() 93 | for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): 94 | dst.data[:] = src 95 | torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight')) 96 | 97 | 98 | def set_requires_grad(nets, requires_grad=False): 99 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 100 | Parameters: 101 | nets (network list) -- a list of networks 102 | requires_grad (bool) -- whether the networks require gradients or not 103 | """ 104 | if not isinstance(nets, list): 105 | nets = [nets] 106 | for net in nets: 107 | if net is not None: 108 | for param in net.parameters(): 109 | param.requires_grad = requires_grad 110 | 111 | def get_scheduler(optimizer, opt): 112 | if opt.lr_policy == 'lambda': 113 | def lambda_rule(epoch): 114 | 
print(epoch) 115 | print(opt.epoch_count) 116 | print(opt.niter) 117 | 118 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 119 | print(lr_l) 120 | print(50*'-') 121 | return lr_l 122 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 123 | elif opt.lr_policy == 'step': 124 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 125 | elif opt.lr_policy == 'plateau': 126 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 127 | else: 128 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 129 | return scheduler 130 | 131 | def my_tensor2im(input_image, imtype=np.uint8): 132 | if isinstance(input_image, torch.Tensor): 133 | image_tensor = input_image.data 134 | else: 135 | return input_image 136 | # print(50*'-') 137 | # print(input_image.shape) 138 | # print(image_tensor.shape) 139 | image_numpy = image_tensor.cpu().float().numpy() 140 | if image_numpy.shape[0] == 1: 141 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 142 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 143 | return image_numpy.astype(imtype) -------------------------------------------------------------------------------- /rank2_edge_batch.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import cv2 3 | import numpy as np 4 | import sys 5 | import math 6 | import os 7 | import models_metric 8 | import torchvision.transforms as transforms 9 | import torch.nn as nn 10 | import torch 11 | import shutil 12 | def get_imlist(path): 13 | return [os.path.join(path,f) for f in os.listdir(path)] 14 | def get_selected_imlist(path, txt_path, num): 15 | f = open(txt_path) 16 | cc = 0 17 | line = f.readline() 18 | L = [] 19 | while line: 20 | cc+=1 21 | if cc == num: 22 | break 23 | name = line.split()[0] 24 | L.append(path + os.sep + name + '.png') 25 | line = f.readline() 26 | f.close() 27 | return L 28 | 29 | def set_requires_grad(nets, requires_grad=False): 30 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 31 | Parameters: 32 | nets (network list) -- a list of networks 33 | requires_grad (bool) -- whether the networks require gradients or not 34 | """ 35 | if not isinstance(nets, list): 36 | nets = [nets] 37 | for net in nets: 38 | if net is not None: 39 | for param in net.parameters(): 40 | param.requires_grad = requires_grad 41 | a = np.array([[-1, 0, 1],[-2, 0, 2],[-1, 0, 1]], dtype=np.float32) 42 | a = a.reshape(1, 1, 3, 3) # out_c/3, in_c, w, h 43 | a = np.repeat(a, 3, axis=0) 44 | conv1=nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 45 | conv1.weight.data.copy_(torch.from_numpy(a)) 46 | conv1.weight.requires_grad = False 47 | conv1.cuda() 48 | 49 | b = np.array([[-1, -2, -1],[0, 0, 0],[1, 2, 1]], dtype=np.float32) 50 | b = b.reshape(1, 1, 3, 3) 51 | b = np.repeat(b, 3, axis=0) 52 | conv2=nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 53 | conv2.weight.data.copy_(torch.from_numpy(b)) 54 | conv2.weight.requires_grad = False 55 | conv2.cuda() 56 | 57 | net_metric = models_metric.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, spatial=False) 58 | net_metric = net_metric.cuda() 59 | set_requires_grad(net_metric, requires_grad=False) 60 | 61 | 62 | src_folder = sys.argv[1] # The folder for demoired results by GDN 63 | tar_folder = sys.argv[2] # The folder for GT 64 | 
txt_path = sys.argv[3] # The txt containing image names 65 | num = int(sys.argv[4]) # The number of candidates 66 | front_num = int(sys.argv[5]) # the number of topK 67 | res_folder = src_folder + os.sep + 'd' 68 | 69 | 70 | 71 | # res_img_path = sys.argv[1] 72 | # target_img_path = sys.argv[2] 73 | 74 | # res_img = cv2.imread(res_img_path) 75 | # target_img = cv2.imread(target_img_path) 76 | f = open("recoed_rank2_edge.txt", 'a+') 77 | cnt = 0 78 | D = {} 79 | res_pre_fix = 'res' 80 | tar_pre_fix = 'tar' 81 | if os.path.exists('tmp1.npy'): 82 | A = np.load('tmp.npy') 83 | else: 84 | for res_img_path in get_selected_imlist(res_folder, txt_path, num): 85 | cnt +=1 86 | if cnt %10 ==0: 87 | print(cnt) 88 | (filename, tempfilename) = os.path.split(res_img_path) 89 | (short_name, extension) = os.path.splitext(tempfilename) 90 | tmp = tempfilename.split('_') 91 | tmp2 = short_name.split('_') 92 | target_img_path = src_folder + os.sep + 'g' + os.sep + tar_pre_fix + '_'+tmp[1] 93 | # ori_img_path = src_folder + os.sep + 'o' + os.sep + tempfilename 94 | res_img = cv2.imread(res_img_path) 95 | target_img = cv2.imread(target_img_path) 96 | # ori_img = cv2.imread(ori_img_path) 97 | gray_target = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY) 98 | gray_res = cv2.cvtColor(res_img, cv2.COLOR_BGR2GRAY) 99 | ret1, binary_map = cv2.threshold(gray_target, 0, 255, cv2.THRESH_OTSU) 100 | 101 | # binary_map = binary_map / 255 102 | binary_map = binary_map[:, :, np.newaxis] 103 | binary_map = np.repeat(binary_map, 3, axis=2) 104 | 105 | 106 | my_transform = transforms.Compose([ 107 | transforms.ToTensor(), 108 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 109 | ]) 110 | 111 | tensor_res_img = my_transform(res_img).unsqueeze_(0).cuda() 112 | tensor_target_img = my_transform(target_img).unsqueeze_(0).cuda() 113 | tensor_binary_map = my_transform(binary_map).unsqueeze_(0).cuda() 114 | # print(tensor_binary_map) 115 | tensor_binary_map = (tensor_binary_map + 1 )/2 116 | # print(tensor_binary_map) 117 | 118 | i_G_x = conv1(tensor_res_img) 119 | i_G_y = conv2(tensor_res_img) 120 | x_hat_edge = torch.tanh(torch.abs(i_G_x) + torch.abs(i_G_y)) 121 | 122 | t_G_x = conv1(tensor_target_img) 123 | t_G_y = conv2(tensor_target_img) 124 | target_edge = torch.tanh(torch.abs(t_G_x) + torch.abs(t_G_y)) 125 | 126 | tensor_merge_res = x_hat_edge * tensor_binary_map 127 | tensor_merge_target = target_edge * tensor_binary_map 128 | res = net_metric(tensor_merge_res, tensor_merge_target).detach()[0][0][0][0].cpu().numpy() 129 | D[short_name] = res 130 | # break 131 | A = sorted(D.iteritems(), key=lambda x: x[1], reverse=True) 132 | np.save('tmp.npy', A) 133 | cc = 0 134 | for item in A: 135 | if cc==front_num: 136 | break 137 | f.write(str(item[0]) + ' ' + str(item[1]) + '\n') 138 | src_res_img_path = src_folder + os.sep + 'd' + os.sep + item[0] + '.png' 139 | tar_res_img_path = tar_folder + os.sep + 'd' + os.sep + item[0] + '.png' 140 | # tar_res_img_path = tar_folder + os.sep + 'd' + os.sep + 'src_' + '0' * (5 - len(str(cc))) + str(cc) + '.png' 141 | src_tar_img_path = src_folder + os.sep + 'g' + os.sep + 'tar_' + item[0].split('_')[1] + '.png' 142 | tar_tar_img_path = tar_folder + os.sep + 'g' + os.sep + 'tar_' + item[0].split('_')[1] + '.png' 143 | # tar_tar_img_path = tar_folder + os.sep + 'g' + os.sep + 'tar_' + '0' * (5 - len(str(cc))) + str(cc) + '.png' 144 | 145 | shutil.copy(src_res_img_path, tar_res_img_path) 146 | shutil.copy(src_tar_img_path, tar_tar_img_path) 147 | cc+=1 
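# A minimal, self-contained sketch of the ranking criterion implemented above: Sobel edge maps of a
# demoireed result and its ground truth are masked by an Otsu threshold of the ground truth and then
# compared with LPIPS. The helper names (`score_pair`, `_edges`) and the example paths are
# illustrative only, and the sketch assumes the LPIPS weights shipped under models_metric/weights
# are available; it is not part of the repository's pipeline.
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
import models_metric

_sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3).repeat(3, 1, 1, 1)
_sobel_y = _sobel_x.transpose(2, 3).contiguous()
_to_tensor = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
_lpips = models_metric.PerceptualLoss(model='net-lin', net='alex', use_gpu=False, spatial=False)

def _edges(img):
    # img: 1x3xHxW in [-1, 1]; depth-wise Sobel filtering, squashed to [0, 1) with tanh
    gx = F.conv2d(img, _sobel_x, padding=1, groups=3)
    gy = F.conv2d(img, _sobel_y, padding=1, groups=3)
    return torch.tanh(gx.abs() + gy.abs())

def score_pair(res_path, gt_path):
    res, gt = cv2.imread(res_path), cv2.imread(gt_path)
    _, mask = cv2.threshold(cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY), 0, 255, cv2.THRESH_OTSU)
    mask = torch.from_numpy(mask / 255.).float()[None, None].repeat(1, 3, 1, 1)
    res_t, gt_t = _to_tensor(res).unsqueeze(0), _to_tensor(gt).unsqueeze(0)
    return _lpips(_edges(res_t) * mask, _edges(gt_t) * mask).item()

# e.g. score_pair('results/d/res_00001.png', 'results/g/tar_00001.png') -> larger values mean a
# larger masked edge discrepancy, i.e. the pairs that the script above ranks first.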
-------------------------------------------------------------------------------- /models_metric/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from models_metric import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 
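# the *100 above and the +50 on the next line undo tensor2tensorlab's normalisation: the lab values
# are rescaled to their native range and the L channel is re-centred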
90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /models/non_local_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.nn.init 5 | 6 | class SEBlock(nn.Module): 7 | def __init__(self, input_dim, reduction): 8 | super(SEBlock, self).__init__() 9 | mid = int(input_dim / reduction) 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.fc = nn.Sequential( 12 | nn.Linear(input_dim, reduction), 13 | nn.ReLU(inplace=True), 14 | nn.Linear(reduction, input_dim), 15 | nn.Sigmoid() 16 | ) 17 | 18 | def forward(self, x): 19 | b, c, _, _ = x.size() 20 | y = self.avg_pool(x).view(b, c) 21 | y = self.fc(y).view(b, c, 1, 1) 22 | return x * y 23 | 24 | 25 | 26 | 27 | class 
ConvGRU(nn.Module): 28 | def __init__(self, inp_dim, oup_dim, kernel, dilation): 29 | super(ConvGRU, self).__init__() 30 | pad_x = int(dilation * (kernel - 1) / 2) 31 | self.conv_xz = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 32 | self.conv_xr = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 33 | self.conv_xn = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 34 | 35 | pad_h = int((kernel - 1) / 2) 36 | self.conv_hz = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 37 | self.conv_hr = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 38 | self.conv_hn = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 39 | 40 | # self.nl = NEDB(inter_channel=oup_dim / 2, channel=oup_dim) 41 | self.relu = nn.LeakyReLU(0.2) 42 | 43 | def forward(self, x, h=None): 44 | if h is None: 45 | z = F.sigmoid(self.conv_xz(x)) 46 | f = F.tanh(self.conv_xn(x)) 47 | h = z * f 48 | else: 49 | # h.unsqueeze_(1) 50 | z = F.sigmoid(self.conv_xz(x) + self.conv_hz(h)) 51 | r = F.sigmoid(self.conv_xr(x) + self.conv_hr(h)) 52 | n = F.tanh(self.conv_xn(x) + self.conv_hn(r * h)) 53 | h = (1 - z) * h + z * n 54 | 55 | h = self.relu(h) 56 | return h, h 57 | 58 | class _NonLocalBlockND(nn.Module): 59 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 60 | super(_NonLocalBlockND, self).__init__() 61 | 62 | assert dimension in [1, 2, 3] 63 | 64 | self.dimension = dimension 65 | self.sub_sample = sub_sample 66 | 67 | self.in_channels = in_channels 68 | self.inter_channels = inter_channels 69 | 70 | 71 | if self.inter_channels is None: 72 | self.inter_channels = in_channels // 2 73 | if self.inter_channels == 0: 74 | self.inter_channels = 1 75 | self.se0 = SEBlock(self.inter_channels, self.inter_channels/2) 76 | self.se1 = SEBlock(self.inter_channels, self.inter_channels/2) 77 | self.se2 = SEBlock(self.inter_channels, self.inter_channels/2) 78 | self.se3 = SEBlock(self.inter_channels, self.inter_channels/2) 79 | self.gru1 = ConvGRU(self.inter_channels, self.inter_channels, 1, 1) 80 | self.gru2 = ConvGRU(self.inter_channels, self.inter_channels, 1, 1) 81 | if dimension == 3: 82 | conv_nd = nn.Conv3d 83 | max_pool = nn.MaxPool3d 84 | bn = nn.BatchNorm3d 85 | elif dimension == 2: 86 | conv_nd = nn.Conv2d 87 | max_pool = nn.MaxPool2d 88 | bn = nn.BatchNorm2d 89 | else: 90 | conv_nd = nn.Conv1d 91 | max_pool = nn.MaxPool1d 92 | bn = nn.BatchNorm1d 93 | 94 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 95 | kernel_size=1, stride=1, padding=0) 96 | self.theta_ = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 97 | kernel_size=1, stride=1, padding=0) 98 | if bn_layer: 99 | self.W = nn.Sequential( 100 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 101 | kernel_size=1, stride=1, padding=0), 102 | bn(self.in_channels) 103 | ) 104 | nn.init.constant_(self.W[1].weight, 0) 105 | nn.init.constant_(self.W[1].bias, 0) 106 | else: 107 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 108 | kernel_size=1, stride=1, padding=0) 109 | nn.init.constant_(self.W.weight, 0) 110 | nn.init.constant_(self.W.bias, 0) 111 | 112 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 113 | kernel_size=1, stride=1, padding=0) 114 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 115 | kernel_size=1, stride=1, padding=0) 116 | 117 | if sub_sample: 118 | self.g = 
nn.Sequential(self.g, max_pool(kernel_size=2)) 119 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 120 | def forward(self, x, corr1, corr2): 121 | ''' 122 | :param x: (b, c, t, h, w) 123 | :return: 124 | ''' 125 | 126 | batch_size = x.size(0) 127 | # print(x.shape) 128 | g_x, ret_corr1 = self.gru1(self.se0(self.g(x)), corr1) 129 | g_x = g_x.view(batch_size, self.inter_channels, -1) 130 | g_x = g_x.permute(0, 2, 1) 131 | 132 | theta_x, ret_corr2 = self.gru2(self.se1(self.theta(x)), corr2) 133 | theta_x = theta_x.view(batch_size, self.inter_channels, -1) 134 | theta_x = theta_x.permute(0, 2, 1) 135 | # print(theta_x.shape) 136 | phi_x = self.se3(self.phi(x)) 137 | phi_x = phi_x.view(batch_size, self.inter_channels, -1) 138 | # print(phi_x.shape) 139 | f = torch.matmul(theta_x, phi_x) 140 | 141 | theta_x_ = self.theta_(x) 142 | theta_x_ = self.se3(theta_x_) 143 | theta_x_ = theta_x_.view(batch_size, self.inter_channels, -1) 144 | theta_x_ = theta_x_.permute(0, 2, 1) 145 | 146 | 147 | f_div_C = F.softmax(f, dim=-1) 148 | 149 | y = torch.matmul(f_div_C, g_x) 150 | y = y + theta_x_ 151 | y = y.permute(0, 2, 1).contiguous() 152 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 153 | W_y = self.W(y) 154 | z = W_y + x 155 | 156 | return z, ret_corr1, ret_corr2 157 | 158 | 159 | class NONLocalBlock2D(_NonLocalBlockND): 160 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 161 | super(NONLocalBlock2D, self).__init__(in_channels, 162 | inter_channels=inter_channels, 163 | dimension=2, sub_sample=sub_sample, 164 | bn_layer=bn_layer) 165 | -------------------------------------------------------------------------------- /models_metric/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | # from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", 
['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | 
elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | 5 | 6 | def create_exp_dir(exp): 7 | try: 8 | os.makedirs(exp) 9 | print('Creating exp dir: %s' % exp) 10 | except OSError: 11 | pass 12 | return True 13 | 14 | 15 | def weights_init(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | m.weight.data.normal_(0.0, 0.02) 19 | elif classname.find('BatchNorm') != -1: 20 | m.weight.data.normal_(1.0, 0.02) 21 | m.bias.data.fill_(0) 22 | 23 | 24 | def getLoader(datasetName, dataroot, originalSize_h, originalSize_w, imageSize_h, imageSize_w, batchSize=64, workers=4, 25 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None, pre="", label_file="", list_file=""): 26 | 27 | 28 | if datasetName == 'my_loader': 29 | from datasets.my_loader import my_loader_LRN as commonDataset 30 | import transforms.pix2pix as transforms 31 | 32 | elif datasetName == 'my_loader_fs': 33 | # from datasets.pix2pix import pix2pix as commonDataset 34 | # import transforms.pix2pix as transforms 35 | from datasets.my_loader import my_loader_fs as commonDataset 36 | import transforms.pix2pix as transforms 37 | elif datasetName == 'my_loader_LRN_f2_rand3': 38 | # from datasets.pix2pix import pix2pix as commonDataset 39 | # import transforms.pix2pix as transforms 40 | from datasets.my_loader import my_loader_LRN_f2_rand3 as commonDataset 41 | import transforms.pix2pix as transforms 42 | elif datasetName == 'my_loader_LRN_f2_rand2': 43 | # from datasets.pix2pix import pix2pix as commonDataset 44 | # import transforms.pix2pix as transforms 45 | from datasets.my_loader import my_loader_LRN_f2_rand2 as commonDataset 46 | import transforms.pix2pix as transforms 47 | if split == 'test': 48 | dataset = commonDataset(root=dataroot, 49 | transform1=transforms.Compose([ 50 | transforms.Scale(originalSize_h, originalSize_w), 51 | transforms.CenterCrop(imageSize_h, imageSize_w), 52 | ]), 53 | transform2=transforms.Compose([ 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean, std), 56 | ]), 57 | transform3=transforms.Compose1([ 58 | transforms.Scale1(384, 384), 59 | transforms.ToTensor1(), 60 | transforms.Normalize1(mean, std), 61 | ]), 62 | seed=seed, 63 | pre=pre, 64 | label_file=label_file) 65 | elif split == 'train': 66 | dataset = commonDataset(root=dataroot, 67 | transform=transforms.Compose([ 68 | transforms.Scale(originalSize_h, originalSize_w), 69 | transforms.RandomCrop(imageSize_h, imageSize_w), 70 | transforms.RandomHorizontalFlip(), 
71 | transforms.ToTensor(), 72 | transforms.Normalize(mean, std), 73 | ]), 74 | seed=seed, 75 | pre=pre, 76 | label_file=label_file) 77 | elif split == 'LRN_train_guide2': 78 | dataset = commonDataset(root=dataroot, 79 | transform1=transforms.Scale(originalSize_h, originalSize_w), 80 | transform2=transforms.RandomCrop_index(imageSize_h, imageSize_w), 81 | transform3=transforms.RandomHorizontalFlip_index(), 82 | transform4=transforms.GuideCrop(384, 384), 83 | transform5=transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize(mean, std), 86 | ]), 87 | transform6=transforms.Compose1([ 88 | transforms.Scale1(384, 384), 89 | transforms.ToTensor1(), 90 | transforms.Normalize1(mean, std), 91 | ]), 92 | seed=seed, 93 | pre=pre, 94 | list_file=list_file) 95 | elif split == 'LRN_train_guide': 96 | dataset = commonDataset(root=dataroot, 97 | transform1=transforms.Scale(originalSize_h, originalSize_w), 98 | transform2=transforms.RandomCrop_index(imageSize_h, imageSize_w), 99 | transform3=transforms.RandomHorizontalFlip_index(), 100 | transform4=transforms.GuideCrop(384, 384), 101 | transform5=transforms.Compose([ 102 | transforms.ToTensor(), 103 | transforms.Normalize(mean, std), 104 | ]), 105 | transform6=transforms.Compose1([ 106 | transforms.Scale1(384, 384), 107 | transforms.ToTensor1(), 108 | transforms.Normalize1(mean, std), 109 | ]), 110 | seed=seed, 111 | pre=pre, 112 | label_file=label_file) 113 | 114 | dataloader = torch.utils.data.DataLoader(dataset, 115 | batch_size=batchSize, 116 | shuffle=shuffle, 117 | num_workers=int(workers)) 118 | return dataloader 119 | 120 | 121 | class AverageMeter(object): 122 | """Computes and stores the average and current value""" 123 | def __init__(self): 124 | self.reset() 125 | 126 | def reset(self): 127 | self.val = 0 128 | self.avg = 0 129 | self.sum = 0 130 | self.count = 0 131 | 132 | def update(self, val, n=1): 133 | self.val = val 134 | self.sum += val * n 135 | self.count += n 136 | self.avg = self.sum / self.count 137 | 138 | 139 | import numpy as np 140 | class ImagePool: 141 | def __init__(self, pool_size=50): 142 | self.pool_size = pool_size 143 | if pool_size > 0: 144 | self.num_imgs = 0 145 | self.images = [] 146 | 147 | def query(self, image): 148 | if self.pool_size == 0: 149 | return image 150 | if self.num_imgs < self.pool_size: 151 | self.images.append(image.clone()) 152 | self.num_imgs += 1 153 | return image 154 | else: 155 | if np.random.uniform(0,1) > 0.5: 156 | random_id = np.random.randint(self.pool_size, size=1)[0] 157 | tmp = self.images[random_id].clone() 158 | self.images[random_id] = image.clone() 159 | return tmp 160 | else: 161 | return image 162 | 163 | 164 | def adjust_learning_rate(optimizer, init_lr, epoch, factor, every): 165 | #import pdb; pdb.set_trace() 166 | lrd = init_lr / every 167 | old_lr = optimizer.param_groups[0]['lr'] 168 | # linearly decaying lr 169 | lr = old_lr - lrd 170 | if lr < 0: lr = 0 171 | for param_group in optimizer.param_groups: 172 | param_group['lr'] = lr 173 | -------------------------------------------------------------------------------- /transforms/pix2pix_val.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | class Compose(object): 11 | """Composes several transforms together. 
12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB): 24 | for t in self.transforms: 25 | imgA, imgB = t(imgA, imgB) 26 | return imgA, imgB 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB): 33 | pics = [picA, picB] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 52 | output.append(img) 53 | return output[0], output[1] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB): 61 | pics = [picA, picB] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB): 87 | tensors = [tensorA, tensorB] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 
99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB): 109 | imgs = [imgA, imgB] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB): 139 | imgs = [imgA, imgB] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | # output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | 148 | output.append(img) 149 | 150 | return output[0], output[1] 151 | 152 | class Pad(object): 153 | """Pads the given PIL.Image on all sides with the given "pad" value""" 154 | def __init__(self, padding, fill=0): 155 | assert isinstance(padding, numbers.Number) 156 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 157 | self.padding = padding 158 | self.fill = fill 159 | 160 | def __call__(self, imgA, imgB): 161 | imgs = [imgA, imgB] 162 | output = [] 163 | for img in imgs: 164 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 165 | return output[0], output[1] 166 | 167 | class Lambda(object): 168 | """Applies a lambda as a transform.""" 169 | def __init__(self, lambd): 170 | assert isinstance(lambd, types.LambdaType) 171 | self.lambd = lambd 172 | 173 | def __call__(self, imgA, imgB): 174 | imgs = [imgA, imgB] 175 | output = [] 176 | for img in imgs: 177 | output.append(self.lambd(img)) 178 | return output[0], output[1] 179 | 180 | class RandomCrop(object): 181 | """Crops the given PIL.Image at a random location to have a region of 182 | the given size. 
size can be a tuple (target_height, target_width) 183 | or an integer, in which case the target will be of a square shape (size, size) 184 | """ 185 | def __init__(self, size, padding=0): 186 | if isinstance(size, numbers.Number): 187 | self.size = (int(size), int(size)) 188 | else: 189 | self.size = size 190 | self.padding = padding 191 | 192 | def __call__(self, imgA, imgB): 193 | imgs = [imgA, imgB] 194 | output = [] 195 | x1 = -1 196 | y1 = -1 197 | for img in imgs: 198 | if self.padding > 0: 199 | img = ImageOps.expand(img, border=self.padding, fill=0) 200 | 201 | w, h = img.size 202 | th, tw = self.size 203 | if w == tw and h == th: 204 | output.append(img) 205 | continue 206 | 207 | if x1 == -1 and y1 == -1: 208 | x1 = random.randint(0, w - tw) 209 | y1 = random.randint(0, h - th) 210 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 211 | return output[0], output[1] 212 | 213 | class RandomHorizontalFlip(object): 214 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 215 | """ 216 | def __call__(self, imgA, imgB): 217 | imgs = [imgA, imgB] 218 | output = [] 219 | flag = random.random() < 0.5 220 | for img in imgs: 221 | if flag: 222 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 223 | else: 224 | output.append(img) 225 | return output[0], output[1] 226 | -------------------------------------------------------------------------------- /transforms/pix2pix3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | class Compose(object): 11 | """Composes several transforms together. 12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB, imgC): 24 | for t in self.transforms: 25 | imgA, imgB, imgC = t(imgA, imgB, imgC) 26 | return imgA, imgB, imgC 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB, picC): 33 | pics = [picA, picB, picC] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 
52 | output.append(img) 53 | return output[0], output[1], output[2] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB, picC): 61 | pics = [picA, picB, picC] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1], output[2] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB, tensorC): 87 | tensors = [tensorA, tensorB, tensorC] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1], output[2] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB, imgC): 109 | imgs = [imgA, imgB, imgC] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1], output[2] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. 
size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB, imgC): 139 | imgs = [imgA, imgB, imgC] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | return output[0], output[1], output[2] 148 | 149 | class Pad(object): 150 | """Pads the given PIL.Image on all sides with the given "pad" value""" 151 | def __init__(self, padding, fill=0): 152 | assert isinstance(padding, numbers.Number) 153 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 154 | self.padding = padding 155 | self.fill = fill 156 | 157 | def __call__(self, imgA, imgB, imgC): 158 | imgs = [imgA, imgB, imgC] 159 | output = [] 160 | for img in imgs: 161 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 162 | return output[0], output[1], output[2] 163 | 164 | class Lambda(object): 165 | """Applies a lambda as a transform.""" 166 | def __init__(self, lambd): 167 | assert isinstance(lambd, types.LambdaType) 168 | self.lambd = lambd 169 | 170 | def __call__(self, imgA, imgB, imgC): 171 | imgs = [imgA, imgB, imgC] 172 | output = [] 173 | for img in imgs: 174 | output.append(self.lambd(img)) 175 | return output[0], output[1], output[2] 176 | 177 | class RandomCrop(object): 178 | """Crops the given PIL.Image at a random location to have a region of 179 | the given size. size can be a tuple (target_height, target_width) 180 | or an integer, in which case the target will be of a square shape (size, size) 181 | """ 182 | def __init__(self, size, padding=0): 183 | if isinstance(size, numbers.Number): 184 | self.size = (int(size), int(size)) 185 | else: 186 | self.size = size 187 | self.padding = padding 188 | 189 | def __call__(self, imgA, imgB, imgC): 190 | imgs = [imgA, imgB, imgC] 191 | output = [] 192 | x1 = -1 193 | y1 = -1 194 | for img in imgs: 195 | if self.padding > 0: 196 | img = ImageOps.expand(img, border=self.padding, fill=0) 197 | 198 | w, h = img.size 199 | th, tw = self.size 200 | if w == tw and h == th: 201 | output.append(img) 202 | continue 203 | 204 | if x1 == -1 and y1 == -1: 205 | x1 = random.randint(0, w - tw) 206 | y1 = random.randint(0, h - th) 207 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 208 | return output[0], output[1], output[2] 209 | 210 | class RandomHorizontalFlip(object): 211 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 212 | """ 213 | def __call__(self, imgA, imgB, imgC): 214 | imgs = [imgA, imgB, imgC] 215 | output = [] 216 | # flag = random.random() < 0.5 217 | flag = random.random() < -1 218 | 219 | for img in imgs: 220 | if flag: 221 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 222 | else: 223 | output.append(img) 224 | return output[0], output[1], output[2] 225 | -------------------------------------------------------------------------------- /transforms/pix2pix_val3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 
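# Like the other files in transforms/, the classes below mirror torchvision's transforms but every
# __call__ takes three images and applies identical parameters to all of them, so random crops and
# (here disabled) flips stay spatially aligned across the triplet. A hedged usage sketch, with file
# names purely illustrative:
#
#   import transforms.pix2pix_val3 as transforms3
#   from PIL import Image
#   t = transforms3.Compose([transforms3.Scale(512), transforms3.CenterCrop(384),
#                            transforms3.ToTensor(),
#                            transforms3.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#   a, b, c = t(Image.open('input.png'), Image.open('guide.png'), Image.open('gt.png'))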
10 | class Compose(object): 11 | """Composes several transforms together. 12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB, imgC): 24 | for t in self.transforms: 25 | imgA, imgB, imgC = t(imgA, imgB, imgC) 26 | return imgA, imgB, imgC 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB, picC): 33 | pics = [picA, picB, picC] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 52 | output.append(img) 53 | return output[0], output[1], output[2] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB, picC): 61 | pics = [picA, picB, picC] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1], output[2] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB, tensorC): 87 | tensors = [tensorA, tensorB, tensorC] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1], output[2] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 
99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB, imgC): 109 | imgs = [imgA, imgB, imgC] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1], output[2] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB, imgC): 139 | imgs = [imgA, imgB, imgC] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | return output[0], output[1], output[2] 148 | 149 | class Pad(object): 150 | """Pads the given PIL.Image on all sides with the given "pad" value""" 151 | def __init__(self, padding, fill=0): 152 | assert isinstance(padding, numbers.Number) 153 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 154 | self.padding = padding 155 | self.fill = fill 156 | 157 | def __call__(self, imgA, imgB, imgC): 158 | imgs = [imgA, imgB, imgC] 159 | output = [] 160 | for img in imgs: 161 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 162 | return output[0], output[1], output[2] 163 | 164 | class Lambda(object): 165 | """Applies a lambda as a transform.""" 166 | def __init__(self, lambd): 167 | assert isinstance(lambd, types.LambdaType) 168 | self.lambd = lambd 169 | 170 | def __call__(self, imgA, imgB, imgC): 171 | imgs = [imgA, imgB, imgC] 172 | output = [] 173 | for img in imgs: 174 | output.append(self.lambd(img)) 175 | return output[0], output[1], output[2] 176 | 177 | class RandomCrop(object): 178 | """Crops the given PIL.Image at a random location to have a region of 179 | the given size. 
size can be a tuple (target_height, target_width) 180 | or an integer, in which case the target will be of a square shape (size, size) 181 | """ 182 | def __init__(self, size, padding=0): 183 | if isinstance(size, numbers.Number): 184 | self.size = (int(size), int(size)) 185 | else: 186 | self.size = size 187 | self.padding = padding 188 | 189 | def __call__(self, imgA, imgB, imgC): 190 | imgs = [imgA, imgB, imgC] 191 | output = [] 192 | x1 = -1 193 | y1 = -1 194 | for img in imgs: 195 | if self.padding > 0: 196 | img = ImageOps.expand(img, border=self.padding, fill=0) 197 | 198 | w, h = img.size 199 | th, tw = self.size 200 | if w == tw and h == th: 201 | output.append(img) 202 | continue 203 | 204 | if x1 == -1 and y1 == -1: 205 | x1 = random.randint(0, w - tw) 206 | y1 = random.randint(0, h - th) 207 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 208 | return output[0], output[1], output[2] 209 | 210 | class RandomHorizontalFlip(object): 211 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 212 | """ 213 | def __call__(self, imgA, imgB, imgC): 214 | imgs = [imgA, imgB, imgC] 215 | output = [] 216 | # flag = random.random() < 0.5 217 | flag = random.random() < -1 218 | 219 | for img in imgs: 220 | if flag: 221 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 222 | else: 223 | output.append(img) 224 | return output[0], output[1], output[2] 225 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | import util 6 | import html 7 | from scipy.misc import imresize 8 | 9 | 10 | class Visualizer(): 11 | def __init__(self, display_port, name): 12 | self.display_id = 1 13 | self.use_html = False 14 | self.win_size = 256 15 | self.name = name 16 | self.saved = False 17 | if self.display_id > 0: 18 | import visdom 19 | self.ncols = 4 20 | self.vis = visdom.Visdom(server="http://localhost", port=display_port) 21 | 22 | # if self.use_html: 23 | # self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 24 | # self.img_dir = os.path.join(self.web_dir, 'images') 25 | # print('create web directory %s...' 
% self.web_dir) 26 | # util.mkdirs([self.web_dir, self.img_dir]) 27 | # self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 28 | # with open(self.log_name, "a") as log_file: 29 | # now = time.strftime("%c") 30 | # log_file.write('================ Training Loss (%s) ================\n' % now) 31 | 32 | def reset(self): 33 | self.saved = False 34 | 35 | # |visuals|: dictionary of images to display or save 36 | def display_current_results(self, visuals, epoch, save_result): 37 | if self.display_id > 0: # show images in the browser 38 | ncols = self.ncols 39 | if ncols > 0: 40 | ncols = min(ncols, len(visuals)) 41 | h, w = next(iter(visuals.values())).shape[:2] 42 | table_css = """<style> 43 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 44 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 45 | </style>""" % (w, h) 46 | title = self.name 47 | label_html = '' 48 | label_html_row = '' 49 | images = [] 50 | idx = 0 51 | for label, image in visuals.items(): 52 | image_numpy = util.tensor2im(image) 53 | label_html_row += '<td>%s</td>' % label 54 | images.append(image_numpy.transpose([2, 0, 1])) 55 | idx += 1 56 | if idx % ncols == 0: 57 | label_html += '<tr>%s</tr>' % label_html_row 58 | label_html_row = '' 59 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 60 | while idx % ncols != 0: 61 | images.append(white_image) 62 | label_html_row += '<td></td>' 63 | idx += 1 64 | if label_html_row != '': 65 | label_html += '<tr>%s</tr>' % label_html_row 66 | # pane col = image row 67 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 68 | padding=2, opts=dict(title=title + ' images')) 69 | label_html = '<table>%s</table>
' % label_html 70 | self.vis.text(table_css + label_html, win=self.display_id + 2, 71 | opts=dict(title=title + ' labels')) 72 | else: 73 | idx = 1 74 | for label, image in visuals.items(): 75 | image_numpy = util.tensor2im(image) 76 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 77 | win=self.display_id + idx) 78 | idx += 1 79 | 80 | if self.use_html and (save_result or not self.saved): # save images to a html file 81 | self.saved = True 82 | for label, image in visuals.items(): 83 | image_numpy = util.tensor2im(image) 84 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 85 | util.save_image(image_numpy, img_path) 86 | # update website 87 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 88 | for n in range(epoch, 0, -1): 89 | webpage.add_header('epoch [%d]' % n) 90 | ims, txts, links = [], [], [] 91 | 92 | for label, image_numpy in visuals.items(): 93 | image_numpy = util.tensor2im(image) 94 | img_path = 'epoch%.3d_%s.png' % (n, label) 95 | ims.append(img_path) 96 | txts.append(label) 97 | links.append(img_path) 98 | webpage.add_images(ims, txts, links, width=self.win_size) 99 | webpage.save() 100 | 101 | # losses: dictionary of error labels and values 102 | # def plot_current_losses(self, epoch, counter_ratio, opt, losses): 103 | # if not hasattr(self, 'plot_data'): 104 | # self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 105 | # self.plot_data['X'].append(epoch + counter_ratio) 106 | # self.plot_data['Y'].append([util.tensor2float(losses[k]) for k in self.plot_data['legend']]) 107 | # self.vis.line( 108 | # X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 109 | # Y=np.array(self.plot_data['Y']), 110 | # opts={ 111 | # 'title': self.name + ' loss over time', 112 | # 'legend': self.plot_data['legend'], 113 | # 'xlabel': 'epoch', 114 | # 'ylabel': 'loss'}, 115 | # win=self.display_id) 116 | 117 | def plot_current_losses(self, epoch, counter_ratio, opt, losses): 118 | if not hasattr(self, 'plot_data'): 119 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 120 | self.plot_data['X'].append(epoch + counter_ratio) 121 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 122 | # print(50*'-') 123 | # print(losses['L_img']) 124 | self.vis.line( 125 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 126 | Y=np.array(self.plot_data['Y']), 127 | opts={ 128 | 'title': self.name + ' loss over time', 129 | 'legend': self.plot_data['legend'], 130 | 'xlabel': 'epoch', 131 | 'ylabel': 'loss'}, 132 | win=self.display_id) 133 | 134 | 135 | # losses: same format as |losses| of plot_current_losses 136 | def print_current_losses(self, epoch, i, losses, t, t_data): 137 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 138 | for k, v in losses.items(): 139 | message += '%s: %.3f ' % (k, v) 140 | 141 | print(message) 142 | return message 143 | # with open(self.log_name, "a") as log_file: 144 | # log_file.write('%s\n' % message) 145 | 146 | # save image to the disk 147 | def save_images(self, webpage, visuals, image_path, aspect_ratio=1.0): 148 | image_dir = webpage.get_image_dir() 149 | short_path = ntpath.basename(image_path[0]) 150 | name = os.path.splitext(short_path)[0] 151 | 152 | webpage.add_header(name) 153 | ims, txts, links = [], [], [] 154 | 155 | for label, im_data in visuals.items(): 156 | im = util.tensor2im(im_data) 157 | image_name = '%s_%s.png' % (name, 
label) 158 | save_path = os.path.join(image_dir, image_name) 159 | h, w, _ = im.shape 160 | if aspect_ratio > 1.0: 161 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 162 | if aspect_ratio < 1.0: 163 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 164 | util.save_image(im, save_path) 165 | 166 | ims.append(image_name) 167 | txts.append(label) 168 | links.append(image_name) 169 | webpage.add_images(ims, txts, links, width=self.win_size) 170 | -------------------------------------------------------------------------------- /models_metric/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from pdb import set_trace as st 11 | from skimage import color 12 | # from IPython import embed 13 | from . import pretrained_networks as pn 14 | 15 | import models_metric as util 16 | 17 | def spatial_average(in_tens, keepdim=True): 18 | return in_tens.mean([2,3],keepdim=keepdim) 19 | 20 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 21 | in_H = in_tens.shape[2] 22 | scale_factor = 1.*out_H/in_H 23 | 24 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 25 | 26 | # Learned perceptual metric 27 | class PNetLin(nn.Module): 28 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 29 | super(PNetLin, self).__init__() 30 | 31 | self.pnet_type = pnet_type 32 | self.pnet_tune = pnet_tune 33 | self.pnet_rand = pnet_rand 34 | self.spatial = spatial 35 | self.lpips = lpips 36 | self.version = version 37 | self.scaling_layer = ScalingLayer() 38 | 39 | if(self.pnet_type in ['vgg','vgg16']): 40 | net_type = pn.vgg16 41 | self.chns = [64,128,256,512,512] 42 | elif(self.pnet_type=='alex'): 43 | net_type = pn.alexnet 44 | self.chns = [64,192,384,256,256] 45 | elif(self.pnet_type=='squeeze'): 46 | net_type = pn.squeezenet 47 | self.chns = [64,128,256,384,384,512,512] 48 | self.L = len(self.chns) 49 | 50 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 51 | 52 | if(lpips): 53 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 54 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 55 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 56 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 57 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 58 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 59 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 60 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 61 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 62 | self.lins+=[self.lin5,self.lin6] 63 | 64 | def forward(self, in0, in1, retPerLayer=False): 65 | # v0.0 - original release had a bug, where input was not scaled 66 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 67 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 68 | feats0, feats1, diffs = {}, {}, {} 69 | 70 | for kk in range(self.L): 71 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 72 | diffs[kk] = (feats0[kk]-feats1[kk])**2 73 | 74 | 
if(self.lpips): 75 | if(self.spatial): 76 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 77 | else: 78 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 79 | else: 80 | if(self.spatial): 81 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 82 | else: 83 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 84 | 85 | val = res[0] 86 | for l in range(1,self.L): 87 | val += res[l] 88 | 89 | if(retPerLayer): 90 | return (val, res) 91 | else: 92 | return val 93 | 94 | class ScalingLayer(nn.Module): 95 | def __init__(self): 96 | super(ScalingLayer, self).__init__() 97 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 98 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 99 | 100 | def forward(self, inp): 101 | return (inp - self.shift) / self.scale 102 | 103 | 104 | class NetLinLayer(nn.Module): 105 | ''' A single linear layer which does a 1x1 conv ''' 106 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 107 | super(NetLinLayer, self).__init__() 108 | 109 | layers = [nn.Dropout(),] if(use_dropout) else [] 110 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 111 | self.model = nn.Sequential(*layers) 112 | 113 | 114 | class Dist2LogitLayer(nn.Module): 115 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 116 | def __init__(self, chn_mid=32, use_sigmoid=True): 117 | super(Dist2LogitLayer, self).__init__() 118 | 119 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 120 | layers += [nn.LeakyReLU(0.2,True),] 121 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 122 | layers += [nn.LeakyReLU(0.2,True),] 123 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 124 | if(use_sigmoid): 125 | layers += [nn.Sigmoid(),] 126 | self.model = nn.Sequential(*layers) 127 | 128 | def forward(self,d0,d1,eps=0.1): 129 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 130 | 131 | class BCERankingLoss(nn.Module): 132 | def __init__(self, chn_mid=32): 133 | super(BCERankingLoss, self).__init__() 134 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 135 | # self.parameters = list(self.net.parameters()) 136 | self.loss = torch.nn.BCELoss() 137 | 138 | def forward(self, d0, d1, judge): 139 | per = (judge+1.)/2. 
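# judge is a preference score in [-1, 1]; per = (judge + 1) / 2 rescales it to a target probability in [0, 1], which the BCE loss below compares against the Dist2LogitLayer output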
140 | self.logit = self.net.forward(d0,d1) 141 | return self.loss(self.logit, per) 142 | 143 | # L2, DSSIM metrics 144 | class FakeNet(nn.Module): 145 | def __init__(self, use_gpu=True, colorspace='Lab'): 146 | super(FakeNet, self).__init__() 147 | self.use_gpu = use_gpu 148 | self.colorspace=colorspace 149 | 150 | class L2(FakeNet): 151 | 152 | def forward(self, in0, in1, retPerLayer=None): 153 | assert(in0.size()[0]==1) # currently only supports batchSize 1 154 | 155 | if(self.colorspace=='RGB'): 156 | (N,C,X,Y) = in0.size() 157 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 158 | return value 159 | elif(self.colorspace=='Lab'): 160 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 161 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 162 | ret_var = Variable( torch.Tensor((value,) ) ) 163 | if(self.use_gpu): 164 | ret_var = ret_var.cuda() 165 | return ret_var 166 | 167 | class DSSIM(FakeNet): 168 | 169 | def forward(self, in0, in1, retPerLayer=None): 170 | assert(in0.size()[0]==1) # currently only supports batchSize 1 171 | 172 | if(self.colorspace=='RGB'): 173 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 174 | elif(self.colorspace=='Lab'): 175 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 176 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 177 | ret_var = Variable( torch.Tensor((value,) ) ) 178 | if(self.use_gpu): 179 | ret_var = ret_var.cuda() 180 | return ret_var 181 | 182 | def print_network(net): 183 | num_params = 0 184 | for param in net.parameters(): 185 | num_params += param.numel() 186 | print('Network',net) 187 | print('Total number of parameters: %d' % num_params) 188 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import sys 5 | import random 6 | import time 7 | import pdb 8 | from PIL import Image 9 | import math 10 | import numpy as np 11 | import cv2 12 | from skimage.measure import compare_ssim as ssim 13 | from skimage.measure import compare_psnr as Psnr 14 | from collections import OrderedDict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.parallel 19 | import torch.backends.cudnn as cudnn 20 | import torch.optim as optim 21 | import torchvision.utils as vutils 22 | from torch.autograd import Variable 23 | import torch.nn.functional as F 24 | cudnn.benchmark = True 25 | cudnn.fastest = True 26 | 27 | from misc import * 28 | import models.networks as net 29 | from myutils import utils 30 | import models_metric 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--dataset', required=False, 34 | default='my_loader', help='name for the dataset loader') 35 | parser.add_argument('--dataroot', required=False, 36 | default='', help='path to trn dataset') 37 | parser.add_argument('--netGDN', default='', help="path to netGDN") 38 | parser.add_argument('--netLRN', default='', help="path to netLRN") 39 | parser.add_argument('--netFDN', default='', help="path to netFDN") 40 | parser.add_argument('--netFRN', default='', help="path to netFRN") 41 | parser.add_argument('--kernel_size', type=int, default=8, help='patch size for dct') 42 | 
parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 43 | parser.add_argument('--originalSize_h', type=int, 44 | default=539, help='the height of the original input image') 45 | parser.add_argument('--originalSize_w', type=int, 46 | default=959, help='the height of the original input image') 47 | parser.add_argument('--imageSize_h', type=int, 48 | default=512, help='the height of the cropped input image to network') 49 | parser.add_argument('--imageSize_w', type=int, 50 | default=512, help='the width of the cropped input image to network') 51 | parser.add_argument('--pre', type=str, default='', help='prefix of different dataset') 52 | parser.add_argument('--image_path', type=str, default='', help='path to save the evaluated image') 53 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) 54 | parser.add_argument('--record', type=str, default='default.txt', help='file to record scores for each image') 55 | parser.add_argument('--write', type=int, default=0, help='determine whether we save the result images') 56 | opt = parser.parse_args() 57 | print(opt) 58 | 59 | opt.manualSeed = random.randint(1, 10000) 60 | random.seed(opt.manualSeed) 61 | torch.manual_seed(opt.manualSeed) 62 | torch.cuda.manual_seed_all(opt.manualSeed) 63 | print("Random Seed: ", opt.manualSeed) 64 | 65 | val_dataloader = getLoader(opt.dataset, 66 | opt.dataroot, 67 | opt.originalSize_h, 68 | opt.originalSize_w, 69 | opt.imageSize_h, 70 | opt.imageSize_w, 71 | opt.batchSize, 72 | opt.workers, 73 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 74 | split='test', 75 | shuffle=False, 76 | seed=opt.manualSeed, 77 | pre=opt.pre) 78 | 79 | 80 | if opt.write==0: 81 | print('no') 82 | else: 83 | print('yes') 84 | 85 | device = torch.device("cuda:0") 86 | 87 | # dfine and load models 88 | netGDN = net.GDN() 89 | if opt.netGDN != '': 90 | print("load pre-trained GDN model!!!!!!!!!!!!!!!!!") 91 | netGDN.load_state_dict(torch.load(opt.netGDN)) 92 | netGDN.eval() 93 | utils.set_requires_grad(netGDN, False) 94 | 95 | netLRN = net.LRN() 96 | if opt.netLRN != '': 97 | print("load pre-trained LRN model!!!!!!!!!!!!!!!!!") 98 | netLRN.load_state_dict(torch.load(opt.netLRN)) 99 | netLRN.eval() 100 | utils.set_requires_grad(netLRN, False) 101 | 102 | netFDN = net.FDN(ORI_SIZE=opt.imageSize_w, KERNEL_SIZE=opt.kernel_size) 103 | if opt.netFDN != '': 104 | print("load pre-trained FDN model!!!!!!!!!!!!!!!!!") 105 | netFDN.load_state_dict(torch.load(opt.netFDN)) 106 | netFDN.eval() 107 | utils.set_requires_grad(netFDN, False) 108 | 109 | netFRN = net.FRN() 110 | if opt.netFRN != '': 111 | print("load pre-trained FRN model!!!!!!!!!!!!!!!!!") 112 | netFRN.load_state_dict(torch.load(opt.netFRN)) 113 | netFRN.eval() 114 | utils.set_requires_grad(netFRN, False) 115 | 116 | # load metric 117 | net_metric = models_metric.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, spatial=False) 118 | net_metric = net_metric.cuda() 119 | utils.set_requires_grad(net_metric, requires_grad=False) 120 | 121 | # to gpu 122 | netLRN.to(device) 123 | netGDN.to(device) 124 | netFDN.to(device) 125 | netFRN.to(device) 126 | 127 | my_psnr = 0 128 | my_ssim_multi = 0 129 | patch_size = 384 130 | res = 0 131 | cnt1 = 0 132 | f = open(opt.record, "w") 133 | for i, data in enumerate(val_dataloader, 0): 134 | # netG.eval() 135 | print(50*'-') 136 | print(i) 137 | 138 | input, target, down_input, name= data 139 | batch_size = input.size(0) 140 | input = input.cuda() 141 | target = target.cuda() 142 | 
down_input = down_input.cuda() 143 | 144 | gray_input = 0.299 * input[:, 0, :, : ] + 0.587 * input[:, 1, :, : ] + 0.114 * input[:, 2, :, : ] 145 | gray_input.unsqueeze_(1) 146 | gray_target = 0.299 * target[:, 0, :, : ] + 0.587 * target[:, 1, :, : ] + 0.114 * target[:, 2, :, : ] 147 | gray_target.unsqueeze_(1) 148 | 149 | # GDN 150 | demoire_down = netGDN(down_input)[-1].detach() 151 | 152 | # upsampling 153 | demoire_up = F.interpolate(demoire_down, size=[opt.imageSize_h, opt.imageSize_w], mode='bilinear') 154 | 155 | # LRN 156 | demoire_up = netLRN(demoire_up) 157 | 158 | # get Y channel 159 | gray_demoire_up = 0.299 * demoire_up[:, 0, :, : ] + 0.587 * demoire_up[:, 1, :, : ] + 0.114 * demoire_up[:, 2, :, : ] 160 | gray_demoire_up.unsqueeze_(1) 161 | 162 | # FDN 163 | dct_oup = netFDN(gray_input, gray_demoire_up) 164 | 165 | # merge YUV from spatial and frequency domain 166 | demoire_up_u = -0.169 * demoire_up[:, 0, :, : ] - 0.331 * demoire_up[:, 1, :, : ] + 0.5 * demoire_up[:, 2, :, : ] - 1 167 | demoire_up_u.unsqueeze_(1) 168 | demoire_up_v = 0.5 * demoire_up[:, 0, :, : ] - 0.419 * demoire_up[:, 1, :, : ] - 0.081 * demoire_up[:, 2, :, : ] - 1 169 | demoire_up_v.unsqueeze_(1) 170 | yuv_merged_image = torch.cat([dct_oup, demoire_up_u, demoire_up_v], dim=1) 171 | 172 | # YUV to RGB 173 | r_merged_image = yuv_merged_image[:,0,:,:] + 1.403 * yuv_merged_image[:,2,:,:] + 1.403 174 | r_merged_image.unsqueeze_(1) 175 | g_merged_image = yuv_merged_image[:,0,:,:] -0.344 * yuv_merged_image[:,1,:,:] -0.714 * yuv_merged_image[:,2,:,:] -1.058 176 | g_merged_image.unsqueeze_(1) 177 | b_merged_image = yuv_merged_image[:,0,:,:] +1.773 * yuv_merged_image[:,1,:,:] + 1.773 178 | b_merged_image.unsqueeze_(1) 179 | 180 | # FRN 181 | merged = torch.cat([r_merged_image, g_merged_image, b_merged_image], dim=1) 182 | x_hat = netFRN(merged) 183 | 184 | # calculate scores 185 | cnt1+=batch_size 186 | tmp = torch.sum(net_metric(target, x_hat).detach()) 187 | res += tmp 188 | L = str(tmp) 189 | print(res / cnt1) 190 | 191 | for j in range(x_hat.shape[0]): 192 | b, c, w, h = x_hat.shape 193 | ti1 = x_hat[j, :,:,: ] 194 | tt1 = target[j, :,:,: ] 195 | mi1 = cv2.cvtColor(utils.my_tensor2im(ti1), cv2.COLOR_BGR2RGB) 196 | mt1 = cv2.cvtColor(utils.my_tensor2im(tt1), cv2.COLOR_BGR2RGB) 197 | tmp2 = Psnr(mt1, mi1) 198 | my_psnr += tmp2 199 | tmp3 = ssim(mt1, mi1, multichannel=True) 200 | my_ssim_multi += tmp3 201 | L = L +' ' + str(tmp2) +str(tmp3) + '\n' 202 | f.write(L) 203 | if opt.write == 1 and i<200/batch_size: 204 | if os.path.exists(opt.image_path) == False: 205 | os.makedirs(opt.image_path) 206 | cv2.imwrite(opt.image_path +os.sep+'res_' + name[j] +'.png', mi1) 207 | print(my_psnr / cnt1) 208 | print(my_ssim_multi / cnt1) 209 | 210 | print("avergaed results:") 211 | print(res / cnt1) 212 | print(my_psnr / cnt1) 213 | print(my_ssim_multi / cnt1) 214 | f.close() 215 | -------------------------------------------------------------------------------- /util_metirc/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | from . import util 5 | from . 
import html 6 | # from pdb import set_trace as st 7 | import matplotlib.pyplot as plt 8 | import math 9 | # from IPython import embed 10 | 11 | def zoom_to_res(img,res=256,order=0,axis=0): 12 | # img 3xXxX 13 | from scipy.ndimage import zoom 14 | zoom_factor = res/img.shape[1] 15 | if(axis==0): 16 | return zoom(img,[1,zoom_factor,zoom_factor],order=order) 17 | elif(axis==2): 18 | return zoom(img,[zoom_factor,zoom_factor,1],order=order) 19 | 20 | class Visualizer(): 21 | def __init__(self, opt): 22 | # self.opt = opt 23 | self.display_id = opt.display_id 24 | # self.use_html = opt.is_train and not opt.no_html 25 | self.win_size = opt.display_winsize 26 | self.name = opt.name 27 | self.display_cnt = 0 # display_current_results counter 28 | self.display_cnt_high = 0 29 | self.use_html = opt.use_html 30 | 31 | if self.display_id > 0: 32 | import visdom 33 | self.vis = visdom.Visdom(port = opt.display_port) 34 | 35 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 36 | util.mkdirs([self.web_dir,]) 37 | if self.use_html: 38 | self.img_dir = os.path.join(self.web_dir, 'images') 39 | print('create web directory %s...' % self.web_dir) 40 | util.mkdirs([self.img_dir,]) 41 | 42 | # |visuals|: dictionary of images to display or save 43 | def display_current_results(self, visuals, epoch, nrows=None, res=256): 44 | if self.display_id > 0: # show images in the browser 45 | title = self.name 46 | if(nrows is None): 47 | nrows = int(math.ceil(len(visuals.items()) / 2.0)) 48 | images = [] 49 | idx = 0 50 | for label, image_numpy in visuals.items(): 51 | title += " | " if idx % nrows == 0 else ", " 52 | title += label 53 | img = image_numpy.transpose([2, 0, 1]) 54 | img = zoom_to_res(img,res=res,order=0) 55 | images.append(img) 56 | idx += 1 57 | if len(visuals.items()) % 2 != 0: 58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 59 | white_image = zoom_to_res(white_image,res=res,order=0) 60 | images.append(white_image) 61 | self.vis.images(images, nrow=nrows, win=self.display_id + 1, 62 | opts=dict(title=title)) 63 | 64 | if self.use_html: # save images to a html file 65 | for label, image_numpy in visuals.items(): 66 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label)) 67 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path) 68 | 69 | self.display_cnt += 1 70 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt) 71 | 72 | # update website 73 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 74 | for n in range(epoch, 0, -1): 75 | webpage.add_header('epoch [%d]' % n) 76 | if(n==epoch): 77 | high = self.display_cnt 78 | else: 79 | high = self.display_cnt_high 80 | for c in range(high-1,-1,-1): 81 | ims = [] 82 | txts = [] 83 | links = [] 84 | 85 | for label, image_numpy in visuals.items(): 86 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label) 87 | ims.append(os.path.join('images',img_path)) 88 | txts.append(label) 89 | links.append(os.path.join('images',img_path)) 90 | webpage.add_images(ims, txts, links, width=self.win_size) 91 | webpage.save() 92 | 93 | # save errors into a directory 94 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False): 95 | if not hasattr(self, 'plot_data'): 96 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 97 | self.plot_data['X'].append(epoch + counter_ratio) 98 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 99 | 100 
| # embed() 101 | if(keys=='+ALL'): 102 | plot_keys = self.plot_data['legend'] 103 | else: 104 | plot_keys = keys 105 | 106 | if(to_plot): 107 | (f,ax) = plt.subplots(1,1) 108 | for (k,kname) in enumerate(plot_keys): 109 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0] 110 | x = self.plot_data['X'] 111 | y = np.array(self.plot_data['Y'])[:,kk] 112 | if(to_plot): 113 | ax.plot(x, y, 'o-', label=kname) 114 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x) 115 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y) 116 | 117 | if(to_plot): 118 | plt.legend(loc=0,fontsize='small') 119 | plt.xlabel('epoch') 120 | plt.ylabel('Value') 121 | f.savefig(os.path.join(self.web_dir,'%s.png'%name)) 122 | f.clf() 123 | plt.close() 124 | 125 | # errors: dictionary of error labels and values 126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 127 | if not hasattr(self, 'plot_data'): 128 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 129 | self.plot_data['X'].append(epoch + counter_ratio) 130 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 131 | self.vis.line( 132 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 133 | Y=np.array(self.plot_data['Y']), 134 | opts={ 135 | 'title': self.name + ' loss over time', 136 | 'legend': self.plot_data['legend'], 137 | 'xlabel': 'epoch', 138 | 'ylabel': 'loss'}, 139 | win=self.display_id) 140 | 141 | # errors: same format as |errors| of plotCurrentErrors 142 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None): 143 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2) 144 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()]) 145 | 146 | print(message) 147 | if(fid is not None): 148 | fid.write('%s\n'%message) 149 | 150 | 151 | # save image to the disk 152 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256): 153 | image_dir = webpage.get_image_dir() 154 | ims = [] 155 | txts = [] 156 | links = [] 157 | 158 | for name, image_numpy, txt in zip(names, images, in_txts): 159 | image_name = '%s_%s.png' % (prefix, name) 160 | save_path = os.path.join(image_dir, image_name) 161 | if(res is not None): 162 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path) 163 | else: 164 | util.save_image(image_numpy, save_path) 165 | 166 | ims.append(os.path.join(webpage.img_subdir,image_name)) 167 | # txts.append(name) 168 | txts.append(txt) 169 | links.append(os.path.join(webpage.img_subdir,image_name)) 170 | # embed() 171 | webpage.add_images(ims, txts, links, width=self.win_size) 172 | 173 | # save image to the disk 174 | def save_images(self, webpage, images, names, image_path, title=''): 175 | image_dir = webpage.get_image_dir() 176 | # short_path = ntpath.basename(image_path) 177 | # name = os.path.splitext(short_path)[0] 178 | # name = short_path 179 | # webpage.add_header('%s, %s' % (name, title)) 180 | ims = [] 181 | txts = [] 182 | links = [] 183 | 184 | for label, image_numpy in zip(names, images): 185 | image_name = '%s.jpg' % (label,) 186 | save_path = os.path.join(image_dir, image_name) 187 | util.save_image(image_numpy, save_path) 188 | 189 | ims.append(image_name) 190 | txts.append(label) 191 | links.append(image_name) 192 | webpage.add_images(ims, txts, links, width=self.win_size) 193 | 194 | # save image to the disk 195 | # def save_images(self, webpage, visuals, image_path, short=False): 196 | # image_dir = webpage.get_image_dir() 197 | 
# if short: 198 | # short_path = ntpath.basename(image_path) 199 | # name = os.path.splitext(short_path)[0] 200 | # else: 201 | # name = image_path 202 | 203 | # webpage.add_header(name) 204 | # ims = [] 205 | # txts = [] 206 | # links = [] 207 | 208 | # for label, image_numpy in visuals.items(): 209 | # image_name = '%s_%s.png' % (name, label) 210 | # save_path = os.path.join(image_dir, image_name) 211 | # util.save_image(image_numpy, save_path) 212 | 213 | # ims.append(image_name) 214 | # txts.append(label) 215 | # links.append(image_name) 216 | # webpage.add_images(ims, txts, links, width=self.win_size) 217 | -------------------------------------------------------------------------------- /train_GDN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import sys 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | cudnn.benchmark = True 11 | cudnn.fastest = True 12 | import torch.optim as optim 13 | # import torchvision.utils as vutils 14 | from torch.autograd import Variable 15 | from misc import * 16 | import models.networks as net 17 | from myutils.vgg16 import Vgg16 18 | from myutils import utils 19 | from visualizer import Visualizer 20 | import time 21 | import pdb 22 | import torch.nn.functional as F 23 | import sys 24 | import os 25 | from PIL import Image 26 | import math 27 | import numpy as np 28 | from skimage.measure import compare_ssim as ssim 29 | from skimage.measure import compare_psnr as Psnr 30 | import cv2 31 | import util 32 | from collections import OrderedDict 33 | 34 | from model.vgg import VGG19 35 | from model.generator import CXLoss 36 | def cal_psnr(src, tar, avg=False): 37 | 38 | data_range = 2 39 | diff = (src - tar)**2 40 | err = torch.sum(diff, (1,2,3)) / (src.shape[-1] * src.shape[-2] * src.shape[-3] ) 41 | # err = criterion(src, tar) 42 | 43 | psnr = 10 * torch.log10((data_range ** 2) / err) 44 | if avg == False: 45 | return torch.sum(psnr) 46 | else: 47 | return torch.mean(psnr) 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--dataset', required=False, default='my_loader_fs', help='') 50 | parser.add_argument('--dataroot', required=False, default='', help='path to trn dataset') 51 | parser.add_argument('--valDataroot', required=False, default='', help='path to val dataset') 52 | parser.add_argument('--pre', type=str, default='', help='prefix of different dataset') 53 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 54 | parser.add_argument('--epoch_count', type=int, default=1, help='input batch size') 55 | parser.add_argument('--valBatchSize', type=int, default=1, help='input batch size') 56 | parser.add_argument('--originalSize_h', type=int, default=539, help='the height / width of the original input image') 57 | parser.add_argument('--originalSize_w', type=int, default=959, help='the height / width of the original input image') 58 | parser.add_argument('--imageSize_h', type=int, default=512, help='the height / width of the cropped input image to network') 59 | parser.add_argument('--imageSize_w', type=int, default=512, help='the height / width of the cropped input image to network') 60 | parser.add_argument('--inputChannelSize', type=int, default=3, help='size of the input channels') 61 | parser.add_argument('--outputChannelSize', type=int, default=3, help='size of the output channels') 62 | 63 | 
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002') 64 | parser.add_argument('--annealEvery', type=int, default=150, help='epoch to reaching at learning rate of 0') 65 | parser.add_argument('--lambdaIMG', type=float, default=1, help='lambdaIMG') 66 | parser.add_argument('--lambdaCX', type=float, default=1, help='lambdaCX') 67 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam') 68 | parser.add_argument('--netG', default='', help="path to netG1 (to continue training)") 69 | parser.add_argument('--netGDN', default='', help="path to netGDN (to continue training)") 70 | parser.add_argument('--netG2', default='', help="path to netG2 (to continue training)") 71 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") 72 | parser.add_argument('--netD_moire', default='', help="path to netD_moire (to continue training)") 73 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) 74 | parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints') 75 | parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs') 76 | parser.add_argument('--evalIter', type=int, default=500, help='interval for evauating(generating) images from valDataroot') 77 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 78 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 79 | parser.add_argument('--vgg19', default='./models/vgg19-dcbb9e9d.pth', help="path to vgg19.pth") 80 | opt = parser.parse_args() 81 | print(opt) 82 | opt.manualSeed = random.randint(1, 10000) 83 | create_exp_dir(opt.exp) 84 | device = torch.device("cuda:0") 85 | 86 | # get dataloader 87 | dataloader = getLoader(opt.dataset, 88 | opt.dataroot, 89 | opt.originalSize_h, 90 | opt.originalSize_w, 91 | opt.imageSize_h, 92 | opt.imageSize_w, 93 | opt.batchSize, 94 | opt.workers, 95 | # mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), 96 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 97 | split='train', 98 | shuffle=True, 99 | seed=opt.manualSeed, 100 | pre=opt.pre) 101 | 102 | # val_dataloader = getLoader(opt.dataset, 103 | # opt.valDataroot, 104 | # opt.originalSize_h, 105 | # opt.originalSize_w, 106 | # opt.imageSize_h, 107 | # opt.imageSize_w, 108 | # opt.valBatchSize, 109 | # 1, 110 | # # mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), 111 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 112 | # split='test', 113 | # shuffle=False, 114 | # seed=opt.manualSeed, 115 | # pre=opt.pre) 116 | 117 | print(len(dataloader)) 118 | 119 | # val_iterator = enumerate(val_dataloader, 0) 120 | # val_cnt = 0 121 | 122 | # get logger 123 | trainLogger = open('%s/train.log' % opt.exp, 'a+') 124 | 125 | inputChannelSize = opt.inputChannelSize 126 | outputChannelSize= opt.outputChannelSize 127 | 128 | # get models 129 | netGDN = net.GDN() 130 | print(netGDN) 131 | visualizer = Visualizer(opt.display_port, opt.name) 132 | 133 | if opt.netGDN != '': 134 | print("load pre-trained model!!!!!!!!!!!!!!!!!") 135 | netGDN.load_state_dict(torch.load(opt.netGDN)) 136 | 137 | netGDN.train() 138 | 139 | vgg19 = VGG19() 140 | vgg19.load_model(opt.vgg19) 141 | vgg19.to(device) 142 | utils.set_requires_grad(vgg19, False) 143 | vgg19.eval() 144 | 145 | vgg_layers = {'conv3_2':1} 146 | criterionCAE = nn.L1Loss() 147 | criterionCX = 
CXLoss(sigma=0.5, spatial_weight=0.5) 148 | criterionCX_patch = CXLoss(sigma=0.5, spatial_weight=0.5) 149 | netGDN.to(device) 150 | 151 | criterionCAE.to(device) 152 | criterionCX.to(device) 153 | criterionCX_patch.to(device) 154 | 155 | lambdaIMG = opt.lambdaIMG 156 | lambdaCX = opt.lambdaCX 157 | 158 | # get optimizer 159 | my_lrG = opt.lrG - opt.epoch_count * (opt.lrG/opt.annealEvery) 160 | optimizerG = optim.Adam(netGDN.parameters(), lr = my_lrG, betas = (opt.beta1, 0.999)) 161 | 162 | # NOTE training loop 163 | ganIterations = 0 164 | total_steps = 0 165 | dataset_size = len(dataloader) 166 | 167 | for epoch in range(opt.epoch_count, opt.annealEvery): 168 | trainLogger = open('%s/train.log' % opt.exp, 'a+') 169 | # switch the state ! 170 | netGDN.train() 171 | 172 | epoch_iter = 0 173 | epoch_start_time = time.time() 174 | iter_data_time = time.time() 175 | 176 | my_psnr = 0 177 | my_ssim = 0 178 | my_ssim_multi = 0 179 | 180 | adjust_learning_rate(optimizerG, opt.lrG, epoch, None, opt.annealEvery) 181 | print(50 * '-' + 'lr' + 50 * '-') 182 | print(str(optimizerG.param_groups[0]['lr'])) 183 | print(50 * '-' + 'lr' + 50 * '-') 184 | 185 | ccnt = 0 186 | for i, data in enumerate(dataloader, 0): 187 | 188 | netGDN.train() 189 | iter_start_time = time.time() 190 | if total_steps % 100 == 0: 191 | t_data = iter_start_time - iter_data_time 192 | visualizer.reset() 193 | total_steps += opt.batchSize 194 | epoch_iter += opt.batchSize 195 | 196 | input, target, name = data 197 | batch_size = target.size(0) 198 | input = input.cuda() 199 | target = target.cuda() 200 | 201 | optimizerG.zero_grad() 202 | oups = netGDN(input) 203 | vgg_target = vgg19(target) 204 | 205 | feat_target = vgg_target['conv3_2'] 206 | CX_loss_list = [criterionCX(vgg19(x_hat)['conv3_2'] ,feat_target) for x_hat in oups] 207 | loss_CX = CX_loss_list[0] 208 | 209 | L = loss_CX 210 | L.backward() 211 | optimizerG.step() 212 | ganIterations += 1 213 | 214 | for i in range(x_hat.shape[0]): 215 | ccnt += 1 216 | ti1 = x_hat[i, :, :, :] 217 | tt1 = target[i, :, :, :] 218 | mi1 = util.my_tensor2im(ti1) 219 | mt1 = util.my_tensor2im(tt1) 220 | g_mi1 = cv2.cvtColor(mi1, cv2.COLOR_BGR2RGB) 221 | g_mt1 = cv2.cvtColor(mt1, cv2.COLOR_BGR2RGB) 222 | # cv2.imwrite("res.jpg", mt1) 223 | my_psnr += Psnr(g_mt1, g_mi1) 224 | my_ssim_multi += ssim(g_mt1, g_mi1, multichannel=True) 225 | 226 | if total_steps % 100 == 0: 227 | 228 | current_visuals = OrderedDict([('input', input), 229 | ('output0', oups[0]), 230 | ('GT', target) 231 | ]) 232 | 233 | losses = OrderedDict([ 234 | ('loss_CX', loss_CX.detach().cpu().float().numpy()), 235 | ('my_psnr', my_psnr / (ccnt)), ('my_ssim_multi', my_ssim_multi / ccnt)]) 236 | # print(losses) 237 | t = (time.time() - iter_start_time) / opt.batchSize 238 | trainLogger.write(visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) + '\n') 239 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 240 | r = float(epoch_iter) / (dataset_size*opt.batchSize) 241 | if opt.display_port!=-1: 242 | visualizer.display_current_results(current_visuals, epoch, False) 243 | visualizer.plot_current_losses(epoch, r, opt, losses) 244 | netGDN.train() 245 | 246 | if epoch % 5 == 0: 247 | 248 | print('hit') 249 | my_file = open("./" + opt.name + "_" + "evaluation.txt", 'a+') 250 | torch.save(netGDN.state_dict(), '%s/netG_epoch_%d.pth' % (opt.exp, epoch)) 251 | torch.save(netGDN.state_dict(), 'ckpt/netGDN.pth') 252 | # vcnt = 0 253 | # vpsnr = 0 254 | # vssim = 0 255 | # netGDN.eval() 256 | # for i, 
data in enumerate(val_dataloader, 0): 257 | # input, target = data 258 | # batch_size = target.size(0) 259 | # input, target = input.to(device), target.to(device) 260 | # x_hat = netGDN(input) 261 | # for i in range(x_hat.shape[0]): 262 | # vcnt += 1 263 | # ti1 = x_hat[i, :, :, :] 264 | # tt1 = target[i, :, :, :] 265 | # mi1 = util.my_tensor2im(ti1) 266 | # mt1 = util.my_tensor2im(tt1) 267 | # g_mi1 = cv2.cvtColor(mi1, cv2.COLOR_BGR2RGB) 268 | # g_mt1 = cv2.cvtColor(mt1, cv2.COLOR_BGR2RGB) 269 | # vpsnr += Psnr(g_mt1, g_mi1) 270 | # 271 | # my_file.write(str(epoch) + str('-') + str(total_steps) + '\n') 272 | # my_file.write(str(float(vpsnr) / vcnt) + '\n') 273 | # print("val:") 274 | # print(float(vpsnr) / vcnt) 275 | # trainLogger.close() 276 | # netGDN.train() 277 | 278 | 279 | my_file.close() 280 | trainLogger.close() -------------------------------------------------------------------------------- /list_7000_f1000.txt: -------------------------------------------------------------------------------- 1 | 05867 2 | 00011 3 | 00031 4 | 00035 5 | 00038 6 | 00058 7 | 00123 8 | 00130 9 | 00138 10 | 00166 11 | 00168 12 | 00183 13 | 00190 14 | 00197 15 | 00211 16 | 00229 17 | 00265 18 | 00275 19 | 00280 20 | 00293 21 | 00314 22 | 00333 23 | 00369 24 | 00374 25 | 00390 26 | 00409 27 | 04611 28 | 04612 29 | 04616 30 | 04630 31 | 04642 32 | 04643 33 | 04665 34 | 04673 35 | 04680 36 | 04690 37 | 04718 38 | 04734 39 | 04783 40 | 04789 41 | 04793 42 | 04830 43 | 04851 44 | 04865 45 | 04871 46 | 04878 47 | 08758 48 | 08783 49 | 08796 50 | 08818 51 | 08838 52 | 08841 53 | 08849 54 | 08895 55 | 08916 56 | 08918 57 | 08959 58 | 08961 59 | 08963 60 | 08969 61 | 08987 62 | 08993 63 | 09035 64 | 09076 65 | 09077 66 | 09095 67 | 09123 68 | 09136 69 | 09155 70 | 09159 71 | 09166 72 | 09182 73 | 09186 74 | 09187 75 | 02742 76 | 02744 77 | 02747 78 | 02755 79 | 02776 80 | 02778 81 | 02779 82 | 02783 83 | 02789 84 | 02790 85 | 02810 86 | 02811 87 | 02849 88 | 02854 89 | 02855 90 | 02874 91 | 02875 92 | 02885 93 | 02899 94 | 02906 95 | 06168 96 | 06172 97 | 06182 98 | 06212 99 | 06230 100 | 06271 101 | 06275 102 | 06277 103 | 06296 104 | 06298 105 | 06306 106 | 06311 107 | 06316 108 | 06324 109 | 06325 110 | 06338 111 | 06342 112 | 06353 113 | 06375 114 | 06423 115 | 06432 116 | 06462 117 | 06463 118 | 06495 119 | 06516 120 | 06531 121 | 06534 122 | 00431 123 | 00764 124 | 01070 125 | 01421 126 | 01752 127 | 02045 128 | 02423 129 | 02740 130 | 02917 131 | 03135 132 | 03505 133 | 03761 134 | 04164 135 | 04379 136 | 04610 137 | 04882 138 | 05078 139 | 05430 140 | 05596 141 | 01427 142 | 01442 143 | 01452 144 | 01462 145 | 01479 146 | 01481 147 | 01485 148 | 01502 149 | 01506 150 | 01508 151 | 01524 152 | 01543 153 | 01569 154 | 01576 155 | 01620 156 | 01625 157 | 01629 158 | 01642 159 | 01661 160 | 01664 161 | 01669 162 | 01675 163 | 01684 164 | 01728 165 | 01750 166 | 10400 167 | 10419 168 | 10421 169 | 10448 170 | 10451 171 | 10454 172 | 10475 173 | 10476 174 | 10478 175 | 10491 176 | 10506 177 | 10537 178 | 10571 179 | 10618 180 | 10621 181 | 10622 182 | 10628 183 | 10633 184 | 10638 185 | 10644 186 | 10652 187 | 10658 188 | 03764 189 | 03789 190 | 03798 191 | 03818 192 | 03825 193 | 03848 194 | 03863 195 | 03867 196 | 03869 197 | 03879 198 | 03893 199 | 03898 200 | 03907 201 | 03914 202 | 03915 203 | 03921 204 | 03975 205 | 03976 206 | 03982 207 | 03985 208 | 04005 209 | 04007 210 | 04009 211 | 04014 212 | 04027 213 | 04034 214 | 04038 215 | 04045 216 | 04060 217 | 04071 218 | 04095 219 | 04098 220 | 04118 221 
| 04122 222 | 04138 223 | 07306 224 | 07320 225 | 07322 226 | 07336 227 | 07349 228 | 07350 229 | 07369 230 | 07406 231 | 07415 232 | 07421 233 | 07424 234 | 07436 235 | 07441 236 | 07465 237 | 07466 238 | 07471 239 | 07476 240 | 07484 241 | 07498 242 | 07529 243 | 07530 244 | 07568 245 | 07569 246 | 07578 247 | 07601 248 | 07651 249 | 07662 250 | 11131 251 | 11151 252 | 11163 253 | 11176 254 | 11192 255 | 11249 256 | 11287 257 | 11331 258 | 11341 259 | 11360 260 | 11382 261 | 11468 262 | 11501 263 | 11510 264 | 11517 265 | 11546 266 | 11557 267 | 11565 268 | 11596 269 | 11599 270 | 02166 271 | 02177 272 | 02185 273 | 02186 274 | 02197 275 | 02199 276 | 02208 277 | 02210 278 | 02212 279 | 02215 280 | 02216 281 | 02240 282 | 02248 283 | 02251 284 | 02253 285 | 02276 286 | 02279 287 | 02280 288 | 02282 289 | 02291 290 | 02308 291 | 02316 292 | 02319 293 | 02333 294 | 02334 295 | 02345 296 | 02351 297 | 02354 298 | 02357 299 | 02361 300 | 02364 301 | 02385 302 | 02405 303 | 02409 304 | 02414 305 | 05432 306 | 05438 307 | 05442 308 | 05447 309 | 05449 310 | 05451 311 | 05466 312 | 05471 313 | 05505 314 | 05508 315 | 05509 316 | 05515 317 | 05536 318 | 05545 319 | 05549 320 | 05550 321 | 05566 322 | 05576 323 | 05580 324 | 05587 325 | 05591 326 | 09505 327 | 09509 328 | 09517 329 | 09524 330 | 09544 331 | 09557 332 | 09561 333 | 09566 334 | 09575 335 | 09589 336 | 09606 337 | 09607 338 | 09612 339 | 09621 340 | 09666 341 | 09674 342 | 09677 343 | 09688 344 | 09693 345 | 09696 346 | 09715 347 | 09721 348 | 09723 349 | 09729 350 | 09748 351 | 09801 352 | 09810 353 | 09880 354 | 09885 355 | 09887 356 | 09892 357 | 09908 358 | 00765 359 | 00769 360 | 00774 361 | 00785 362 | 00799 363 | 00822 364 | 00835 365 | 00839 366 | 00841 367 | 00843 368 | 00862 369 | 00870 370 | 00889 371 | 00898 372 | 00914 373 | 00918 374 | 00926 375 | 00945 376 | 00970 377 | 00972 378 | 00976 379 | 01015 380 | 01066 381 | 03137 382 | 03148 383 | 03149 384 | 03186 385 | 03197 386 | 03215 387 | 03234 388 | 03255 389 | 03259 390 | 03282 391 | 03338 392 | 03351 393 | 03359 394 | 03365 395 | 03369 396 | 03405 397 | 03417 398 | 03434 399 | 03475 400 | 03493 401 | 03494 402 | 08089 403 | 08090 404 | 08102 405 | 08160 406 | 08203 407 | 08210 408 | 08224 409 | 08239 410 | 08241 411 | 08243 412 | 08261 413 | 08279 414 | 08314 415 | 08346 416 | 08375 417 | 08380 418 | 08393 419 | 08399 420 | 08400 421 | 08405 422 | 08407 423 | 08416 424 | 08421 425 | 08446 426 | 08454 427 | 08458 428 | 02424 429 | 02429 430 | 02442 431 | 02456 432 | 02480 433 | 02485 434 | 02503 435 | 02509 436 | 02528 437 | 02540 438 | 02550 439 | 02569 440 | 02570 441 | 02576 442 | 02579 443 | 02587 444 | 02603 445 | 02621 446 | 02626 447 | 02637 448 | 02640 449 | 02646 450 | 02652 451 | 02656 452 | 02668 453 | 02670 454 | 02680 455 | 02696 456 | 02702 457 | 02707 458 | 02725 459 | 11857 460 | 11862 461 | 11866 462 | 11873 463 | 11887 464 | 11931 465 | 11932 466 | 11937 467 | 11942 468 | 11946 469 | 11972 470 | 12004 471 | 12023 472 | 12037 473 | 12039 474 | 12049 475 | 12069 476 | 12104 477 | 12120 478 | 12129 479 | 12145 480 | 12150 481 | 12154 482 | 12155 483 | 12161 484 | 12188 485 | 06828 486 | 06838 487 | 06843 488 | 06875 489 | 06886 490 | 06909 491 | 06940 492 | 06954 493 | 06972 494 | 06974 495 | 06980 496 | 07008 497 | 07013 498 | 07029 499 | 07030 500 | 07049 501 | 07069 502 | 07075 503 | 07087 504 | 07088 505 | 07090 506 | 07119 507 | 07131 508 | 07136 509 | 07150 510 | 07168 511 | 07187 512 | 07200 513 | 07204 514 | 07217 515 | 07260 516 | 07262 517 | 
07275 518 | 07278 519 | 05081 520 | 05085 521 | 05088 522 | 05107 523 | 05136 524 | 05143 525 | 05150 526 | 05169 527 | 05170 528 | 05177 529 | 05196 530 | 05233 531 | 05236 532 | 05237 533 | 05252 534 | 05266 535 | 05279 536 | 05297 537 | 05302 538 | 05303 539 | 05305 540 | 05310 541 | 05317 542 | 05336 543 | 05341 544 | 05357 545 | 05360 546 | 05367 547 | 05372 548 | 05375 549 | 05380 550 | 05382 551 | 05420 552 | 04170 553 | 04174 554 | 04176 555 | 04177 556 | 04181 557 | 04215 558 | 04223 559 | 04229 560 | 04234 561 | 04236 562 | 04245 563 | 04273 564 | 04286 565 | 04302 566 | 04316 567 | 04318 568 | 04326 569 | 04327 570 | 04346 571 | 04349 572 | 04358 573 | 04366 574 | 09963 575 | 09969 576 | 09978 577 | 10006 578 | 10021 579 | 10031 580 | 10033 581 | 10044 582 | 10046 583 | 10059 584 | 10070 585 | 10073 586 | 10092 587 | 10096 588 | 10104 589 | 10111 590 | 10116 591 | 10119 592 | 05889 593 | 05891 594 | 05898 595 | 05915 596 | 05922 597 | 05924 598 | 05936 599 | 05937 600 | 05952 601 | 05954 602 | 05961 603 | 05987 604 | 05990 605 | 06011 606 | 06022 607 | 06026 608 | 06031 609 | 06048 610 | 06068 611 | 06077 612 | 06082 613 | 06083 614 | 06111 615 | 06119 616 | 06130 617 | 06136 618 | 06145 619 | 06150 620 | 06157 621 | 00452 622 | 00458 623 | 00477 624 | 00513 625 | 00519 626 | 00532 627 | 00539 628 | 00556 629 | 00581 630 | 00592 631 | 00593 632 | 00605 633 | 00614 634 | 00624 635 | 00625 636 | 00626 637 | 00636 638 | 00639 639 | 00642 640 | 00646 641 | 00666 642 | 00689 643 | 00692 644 | 00712 645 | 00725 646 | 00734 647 | 00746 648 | 00749 649 | 10697 650 | 10710 651 | 10726 652 | 10736 653 | 10746 654 | 10762 655 | 10774 656 | 10801 657 | 10817 658 | 10864 659 | 10868 660 | 10877 661 | 10917 662 | 10939 663 | 10957 664 | 10967 665 | 10981 666 | 10991 667 | 11013 668 | 11022 669 | 11023 670 | 11033 671 | 11044 672 | 11046 673 | 11067 674 | 11088 675 | 11097 676 | 11099 677 | 11107 678 | 11112 679 | 01072 680 | 01084 681 | 01112 682 | 01122 683 | 01130 684 | 01137 685 | 01138 686 | 01164 687 | 01205 688 | 01210 689 | 01216 690 | 01222 691 | 01223 692 | 01240 693 | 01242 694 | 01248 695 | 01251 696 | 01273 697 | 01276 698 | 01289 699 | 01298 700 | 01306 701 | 01312 702 | 01315 703 | 01319 704 | 01325 705 | 01372 706 | 01379 707 | 01390 708 | 01392 709 | 09210 710 | 09239 711 | 09275 712 | 09281 713 | 09282 714 | 09313 715 | 09344 716 | 09360 717 | 09365 718 | 09370 719 | 09372 720 | 09374 721 | 09384 722 | 09388 723 | 09412 724 | 09429 725 | 09449 726 | 09450 727 | 09468 728 | 09472 729 | 09486 730 | 09500 731 | 01759 732 | 01761 733 | 01766 734 | 01770 735 | 01773 736 | 01786 737 | 01794 738 | 01809 739 | 01814 740 | 01816 741 | 01826 742 | 01830 743 | 01838 744 | 01842 745 | 01849 746 | 01921 747 | 01948 748 | 01955 749 | 01961 750 | 01974 751 | 01983 752 | 01997 753 | 02004 754 | 02016 755 | 02039 756 | 02041 757 | 12211 758 | 12223 759 | 12234 760 | 12242 761 | 12243 762 | 12250 763 | 12272 764 | 12275 765 | 12290 766 | 12302 767 | 12314 768 | 12344 769 | 12355 770 | 12359 771 | 12370 772 | 12374 773 | 12380 774 | 12387 775 | 12392 776 | 12400 777 | 12408 778 | 07686 779 | 07699 780 | 07707 781 | 07727 782 | 07762 783 | 07774 784 | 07815 785 | 07819 786 | 07820 787 | 07840 788 | 07860 789 | 07872 790 | 07886 791 | 07901 792 | 07939 793 | 07941 794 | 07969 795 | 07977 796 | 08012 797 | 08015 798 | 08024 799 | 08034 800 | 08057 801 | 08059 802 | 08074 803 | 08080 804 | 08505 805 | 08506 806 | 08525 807 | 08542 808 | 08543 809 | 08554 810 | 08556 811 | 08594 812 | 08608 813 | 
08614 814 | 08619 815 | 08655 816 | 08664 817 | 08688 818 | 08694 819 | 08695 820 | 08705 821 | 08716 822 | 08718 823 | 08719 824 | 08721 825 | 08746 826 | 06539 827 | 06577 828 | 06591 829 | 06597 830 | 06606 831 | 06621 832 | 06659 833 | 06663 834 | 06670 835 | 06687 836 | 06688 837 | 06705 838 | 06725 839 | 06729 840 | 06750 841 | 06761 842 | 06765 843 | 06777 844 | 06795 845 | 04886 846 | 04899 847 | 04905 848 | 04916 849 | 04932 850 | 04940 851 | 04945 852 | 04961 853 | 04965 854 | 04986 855 | 04993 856 | 04999 857 | 05003 858 | 05004 859 | 05011 860 | 05016 861 | 05030 862 | 05038 863 | 03527 864 | 03540 865 | 03547 866 | 03553 867 | 03557 868 | 03559 869 | 03597 870 | 03607 871 | 03619 872 | 03626 873 | 03629 874 | 03631 875 | 03651 876 | 03654 877 | 03677 878 | 03679 879 | 03690 880 | 03698 881 | 03728 882 | 03730 883 | 03734 884 | 03758 885 | 05608 886 | 05666 887 | 05712 888 | 05723 889 | 05728 890 | 05729 891 | 05739 892 | 05762 893 | 05766 894 | 05783 895 | 05796 896 | 05799 897 | 05811 898 | 05813 899 | 05817 900 | 05819 901 | 05821 902 | 05838 903 | 05846 904 | 05853 905 | 06163 906 | 06536 907 | 06824 908 | 07297 909 | 07673 910 | 08086 911 | 08483 912 | 08753 913 | 09201 914 | 09501 915 | 09910 916 | 10123 917 | 10377 918 | 10660 919 | 11118 920 | 11608 921 | 11852 922 | 12208 923 | 04391 924 | 04397 925 | 04404 926 | 04411 927 | 04424 928 | 04431 929 | 04432 930 | 04440 931 | 04443 932 | 04457 933 | 04489 934 | 04516 935 | 04517 936 | 04540 937 | 04552 938 | 04554 939 | 04578 940 | 04581 941 | 04583 942 | 04589 943 | 04608 944 | 10131 945 | 10144 946 | 10177 947 | 10181 948 | 10193 949 | 10251 950 | 10259 951 | 10262 952 | 10275 953 | 10282 954 | 10288 955 | 10299 956 | 10308 957 | 10328 958 | 10331 959 | 10333 960 | 10339 961 | 10344 962 | 10345 963 | 11618 964 | 11633 965 | 11638 966 | 11647 967 | 11666 968 | 11677 969 | 11688 970 | 11704 971 | 11736 972 | 11739 973 | 11746 974 | 11753 975 | 11769 976 | 11771 977 | 11778 978 | 11795 979 | 11800 980 | 11805 981 | 11818 982 | 02925 983 | 02946 984 | 02949 985 | 02962 986 | 02972 987 | 02975 988 | 02990 989 | 02993 990 | 03008 991 | 03017 992 | 03021 993 | 03023 994 | 03029 995 | 03036 996 | 03076 997 | 03108 998 | 03117 999 | 03126 1000 | 03133 1001 | -------------------------------------------------------------------------------- /models_metric/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | # from IPython import embed 20 | 21 | from . 
import networks_basic as networks 22 | import models as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
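# input_judge is the human 2AFC preference in [0, 1]; it is rescaled to [-1, 1] here because BCERankingLoss.forward maps it back to a probability via per = (judge + 1) / 2 before applying BCELoss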
159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() 168 | judge_per = judge.cpu().numpy().flatten() 169 | return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) 170 | 171 | def get_current_errors(self): 172 | retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), 173 | ('acc_r', self.acc_r)]) 174 | 175 | for key in retDict.keys(): 176 | retDict[key] = np.mean(retDict[key]) 177 | 178 | return retDict 179 | 180 | def get_current_visuals(self): 181 | zoom_factor = 256/self.var_ref.data.size()[2] 182 | 183 | ref_img = util.tensor2im(self.var_ref.data) 184 | p0_img = util.tensor2im(self.var_p0.data) 185 | p1_img = util.tensor2im(self.var_p1.data) 186 | 187 | ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0) 188 | p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0) 189 | p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0) 190 | 191 | return OrderedDict([('ref', ref_img_vis), 192 | ('p0', p0_img_vis), 193 | ('p1', p1_img_vis)]) 194 | 195 | def save(self, path, label): 196 | if(self.use_gpu): 197 | self.save_network(self.net.module, path, '', label) 198 | else: 199 | self.save_network(self.net, path, '', label) 200 | self.save_network(self.rankLoss.net, path, 'rank', label) 201 | 202 | def update_learning_rate(self,nepoch_decay): 203 | lrd = self.lr / nepoch_decay 204 | lr = self.old_lr - lrd 205 | 206 | for param_group in self.optimizer_net.param_groups: 207 | param_group['lr'] = lr 208 | 209 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s
- self.spatial_weight) * cx_feat + self.spatial_weight * cx_sp 136 | 137 | CX = CX.max(dim=3)[0].max(dim=2)[0] 138 | CX = CX.mean(1) 139 | CX = -torch.log(CX) 140 | CX = torch.mean(CX) 141 | return CX 142 | 143 | 144 | def init_params(modules): 145 | for m in modules: 146 | if isinstance(m, nn.Conv2d): 147 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 148 | if m.bias is not None: 149 | nn.init.constant_(m.bias, 0) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | elif isinstance(m, nn.Linear): 154 | nn.init.normal_(m.weight, std=0.02) 155 | if m.bias is not None: 156 | nn.init.constant_(m.bias, 0) 157 | 158 | 159 | class Down(nn.Module): 160 | 161 | def __init__(self, size, in_channels, out_channels): 162 | super(Down, self).__init__() 163 | self.size = size 164 | self.features = [nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1), 165 | nn.LayerNorm([out_channels, size, size]), 166 | nn.LeakyReLU(0.2), 167 | nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1), 168 | nn.LayerNorm([out_channels, size, size]), 169 | nn.LeakyReLU(0.2)] 170 | self.features = nn.Sequential(*self.features) 171 | self.upsample = nn.Upsample(size=(self.size, self.size), mode='bilinear') 172 | 173 | def forward(self, image, x=None): 174 | out = self.upsample(image) 175 | if x is not None: 176 | out = torch.cat([x, out], dim=1) 177 | 178 | return self.features(out) 179 | 180 | 181 | class Generator(nn.Module): 182 | 183 | def __init__(self, image_size): 184 | super(Generator, self).__init__() 185 | 186 | self.image_size = image_size 187 | 188 | self.image_sizes = [256, 128, 64, 32, 16, 8, 4] 189 | 190 | self.in_dims = [259, 515, 515, 515, 515, 515, 3] 191 | self.out_dims = [256, 256, 512, 512, 512, 512, 512] 192 | 193 | self.rec = nn.Sigmoid() 194 | self.conv = nn.Conv2d(256, 3, kernel_size=(1, 1)) 195 | self.down1 = Down(self.image_sizes[0], 259, 256) 196 | self.donw2 = Down(self.image_sizes[1], 515, 256) 197 | self.down3 = Down(self.image_sizes[2], 515, 512) 198 | self.down4 = Down(self.image_sizes[3], 515, 512) 199 | self.down5 = Down(self.image_sizes[4], 515, 512) 200 | self.down6 = Down(self.image_sizes[5], 515, 512) 201 | self.down7 = Down(self.image_sizes[6], 3, 512) 202 | 203 | init_params(self.modules()) 204 | 205 | def forward(self, x): 206 | down7 = self.down7(x) 207 | down7 = F.interpolate(down7, size=(self.image_sizes[-2], self.image_sizes[-2]), mode='bilinear') 208 | down6 = self.down6(x, down7) 209 | down6 = F.interpolate(down6, size=(self.image_sizes[-3], self.image_sizes[-3]), mode='bilinear') 210 | down5 = self.down5(x, down6) 211 | down5 = F.interpolate(down5, size=(self.image_sizes[-4], self.image_sizes[-4]), mode='bilinear') 212 | down4 = self.down4(x, down5) 213 | down4 = F.interpolate(down4, size=(self.image_sizes[-5], self.image_sizes[-5]), mode='bilinear') 214 | down3 = self.down3(x, down4) 215 | down3 = F.interpolate(down3, size=(self.image_sizes[-6], self.image_sizes[-6]), mode='bilinear') 216 | down2 = self.donw2(x, down3) 217 | down2 = F.interpolate(down2, size=(self.image_sizes[-7], self.image_sizes[-7]), mode='bilinear') 218 | down1 = self.down1(x, down2) 219 | return (self.conv(down1) + 1.) / 2. 
220 | 221 | 222 | # class Down(nn.Module): 223 | # 224 | # def __init__(self, in_channels, out_channels, BN=False, IN=True): 225 | # super(Down, self).__init__() 226 | # modules = [nn.LeakyReLU(0.2, inplace=False), 227 | # nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)] 228 | # if BN: 229 | # modules.append(nn.BatchNorm2d(out_channels)) 230 | # if IN: 231 | # modules.append(nn.InstanceNorm2d(out_channels)) 232 | # 233 | # self.feature = nn.Sequential(*modules) 234 | # init_params(self.feature) 235 | # 236 | # def forward(self, x): 237 | # return self.feature(x) 238 | 239 | 240 | class Up(nn.Module): 241 | 242 | def __init__(self, in_channels, out_channels, BN=False, IN=True, dropout=True): 243 | super(Up, self).__init__() 244 | modules = [nn.LeakyReLU(0.2, inplace=False), 245 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)] 246 | if BN: 247 | modules.append(nn.BatchNorm2d(out_channels)) 248 | if IN: 249 | modules.append(nn.InstanceNorm2d(out_channels)) 250 | if dropout: 251 | modules.append(nn.Dropout(0.5)) 252 | 253 | self.feature = nn.Sequential(*modules) 254 | 255 | def forward(self, c1, c2=None): 256 | c1 = self.feature(c1) 257 | if c2 is not None: 258 | return torch.cat([c1, c2], dim=1) 259 | else: 260 | return c1 261 | 262 | 263 | # class Generator(nn.Module): 264 | # 265 | # def __init__(self, out_channels=64): 266 | # super(Generator, self).__init__() 267 | # # 256 x 256 x 6 268 | # self.down1 = nn.Conv2d(6, out_channels, kernel_size=4, stride=2, padding=1) 269 | # # 128 x 128 x 64 270 | # self.down2 = Down(out_channels, out_channels * 2) 271 | # # 64 x 64 x 128 272 | # self.down3 = Down(out_channels * 2, out_channels * 2) 273 | # # 32 x 32 x 256 274 | # self.down4 = Down(out_channels * 2, out_channels * 2) 275 | # # 16 x 16 x 512 276 | # self.down5 = Down(out_channels * 2, out_channels * 2) 277 | # # 8 x 8 x 512 278 | # self.down6 = Down(out_channels * 2, out_channels * 2) 279 | # # 4 x 4 x 512 280 | # self.down7 = Down(out_channels * 2, out_channels * 2) 281 | # # 2 x 2 x 512 282 | # self.down8 = Down(out_channels * 2, out_channels * 2, IN=False) 283 | # 284 | # # 1 x 1 x 512 285 | # self.up1 = Up(out_channels * 2, out_channels * 2) 286 | # # 2 x 2 x (512 + 512) 287 | # self.up2 = Up(out_channels * 2 * 2, out_channels * 2) 288 | # # 4 x 4 x (512 + 512) 289 | # self.up3 = Up(out_channels * 2 * 2, out_channels * 2) 290 | # # 8 x 8 x (512 + 512) 291 | # self.up4 = Up(out_channels * 2 * 2, out_channels * 2, dropout=False) 292 | # # 16 x 16 x (512 + 512) 293 | # self.up5 = Up(out_channels * 2 * 2, out_channels * 2, dropout=False) 294 | # # 32 x 32 x (256 + 256) 295 | # self.up6 = Up(out_channels * 2 * 2, out_channels * 2, dropout=False) 296 | # # 64 x 64 x (128 + 128) 297 | # self.up7 = Up(out_channels * 2 * 2, out_channels, dropout=False) 298 | # # 128 x 128 x (64 + 64) 299 | # self.up8 = Up(out_channels * 2, 3, IN=False, dropout=False) 300 | # # 256 x 256 x 3 301 | # self.rec = nn.Sigmoid() 302 | # 303 | # def forward(self, s, t): 304 | # ''' 305 | # :param dImage: degraded image 306 | # :param wImage: wrap guidance image 307 | # :return: 308 | # ''' 309 | # x = torch.cat([s, t], dim=1) 310 | # down1 = self.down1(x) 311 | # down2 = self.down2(down1) 312 | # down3 = self.down3(down2) 313 | # down4 = self.down4(down3) 314 | # down5 = self.down5(down4) 315 | # down6 = self.down6(down5) 316 | # down7 = self.down7(down6) 317 | # down8 = self.down8(down7) 318 | # 319 | # up1 = self.up1(down8, down7) 320 | # up2 = self.up2(up1, 
down6) 321 | # up3 = self.up3(up2, down5) 322 | # up4 = self.up4(up3, down4) 323 | # up5 = self.up5(up4, down3) 324 | # up6 = self.up6(up5, down2) 325 | # up7 = self.up7(up6, down1) 326 | # up8 = self.up8(up7) 327 | # rec = self.rec(up8) 328 | # return rec 329 | -------------------------------------------------------------------------------- /train_LRN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import sys 5 | import random 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | cudnn.benchmark = True 11 | cudnn.fastest = True 12 | import torch.optim as optim 13 | # import torchvision.utils as vutils 14 | from torch.autograd import Variable 15 | from misc import * 16 | import models.networks as net 17 | from myutils.vgg16 import Vgg16 18 | from myutils import utils 19 | from visualizer import Visualizer 20 | import time 21 | import pdb 22 | import torch.nn.functional as F 23 | import sys 24 | import os 25 | from PIL import Image 26 | import math 27 | import numpy as np 28 | from skimage.measure import compare_ssim as ssim 29 | from skimage.measure import compare_psnr as Psnr 30 | import cv2 31 | import util 32 | from collections import OrderedDict 33 | 34 | from model.vgg import VGG19 35 | from model.generator import CXLoss 36 | 37 | import models_metric 38 | import itertools 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument('--dataset', required=False, default='my_loader_LRN_f2_rand3', help='dataset loader to use') 42 | parser.add_argument('--dataroot', required=False, default='', help='path to training dataset') 43 | parser.add_argument('--valDataroot', required=False, default='', help='path to validation dataset') 44 | parser.add_argument('--pre', type=str, default='', help='prefix of different dataset') 45 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 46 | parser.add_argument('--epoch_count', type=int, default=1, help='starting epoch count (for resuming training)') 47 | parser.add_argument('--valBatchSize', type=int, default=1, help='validation batch size') 48 | parser.add_argument('--originalSize_h', type=int, default=539, help='the height of the original input image') 49 | parser.add_argument('--originalSize_w', type=int, default=959, help='the width of the original input image') 50 | parser.add_argument('--imageSize_h', type=int, default=512, help='the height of the cropped input image to network') 51 | parser.add_argument('--imageSize_w', type=int, default=512, help='the width of the cropped input image to network') 52 | parser.add_argument('--inputChannelSize', type=int, default=3, help='size of the input channels') 53 | parser.add_argument('--outputChannelSize', type=int, default=3, help='size of the output channels') 54 | parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002') 55 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 56 | parser.add_argument('--annealEvery', type=int, default=150, help='epoch at which the learning rate reaches 0') 57 | parser.add_argument('--lambdaIMG', type=float, default=1, help='weight of the image loss') 58 | parser.add_argument('--lambdaCX', type=float, default=1, help='weight of the contextual (CX) loss') 59 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam') 60 | 61 | parser.add_argument('--netGDN', default='', help="path to netGDN (to
continue training)") 62 | parser.add_argument('--netLRN', default='', help="path to netLRN (to continue training)") 63 | parser.add_argument('--kernel_size', type=int, default=8, help='patch size for DCT') 64 | 65 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) 66 | parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints') 67 | parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs') 68 | parser.add_argument('--evalIter', type=int, default=500, help='interval for evaluating (generating) images from valDataroot') 69 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 70 | parser.add_argument('--name', type=str, default='experiment_name', 71 | help='name of the experiment. It decides where to store samples and models') 72 | parser.add_argument('--vgg19', default='./models/vgg19-dcbb9e9d.pth', help="path to vgg19.pth") 73 | parser.add_argument('--list_file', type=str, default='', help='path to the image list file') 74 | parser.add_argument('--spatial_weight', type=float, default=0.5, help='spatial weight for CXloss') 75 | 76 | opt = parser.parse_args() 77 | print(opt) 78 | opt.manualSeed = random.randint(1, 10000) 79 | create_exp_dir(opt.exp) 80 | device = torch.device("cuda:0") 81 | 82 | # get dataloader 83 | dataloader = getLoader(opt.dataset, 84 | opt.dataroot, 85 | opt.originalSize_h, 86 | opt.originalSize_w, 87 | opt.imageSize_h, 88 | opt.imageSize_w, 89 | opt.batchSize, 90 | opt.workers, 91 | # mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), 92 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 93 | split='LRN_train_guide2', 94 | shuffle=True, 95 | seed=opt.manualSeed, 96 | pre=opt.pre, 97 | label_file='final_position.txt', 98 | list_file=opt.list_file) 99 | 100 | # val_dataloader = getLoader("my_loader_LSN", 101 | # opt.dataroot, 102 | # 1024, 103 | # 1024, 104 | # 1024, 105 | # 1024, 106 | # opt.batchSize, 107 | # opt.workers, 108 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 109 | # split='LSN_test', 110 | # shuffle=False, 111 | # seed=opt.manualSeed, 112 | # pre=opt.pre) 113 | 114 | print(len(dataloader)) 115 | 116 | # val_iterator = enumerate(val_dataloader, 0) 117 | # val_cnt = 0 118 | 119 | # get logger 120 | trainLogger = open('%s/train.log' % opt.exp, 'a+') 121 | 122 | inputChannelSize = opt.inputChannelSize 123 | outputChannelSize = opt.outputChannelSize 124 | 125 | visualizer = Visualizer(opt.display_port, opt.name) 126 | 127 | # get models 128 | netGDN = net.GDN() 129 | if opt.netGDN != '': 130 | print("load pre-trained GDN model") 131 | netGDN.load_state_dict(torch.load(opt.netGDN)) 132 | netGDN.eval() 133 | utils.set_requires_grad(netGDN, False) 134 | 135 | netLRN = net.LRN() 136 | if opt.netLRN != '': 137 | print("load pre-trained LRN model") 138 | netLRN.load_state_dict(torch.load(opt.netLRN)) 139 | netLRN.train() 140 | 141 | 142 | 143 | # vgg = Vgg16() 144 | # utils.init_vgg16('./models/') 145 | # vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight"))) 146 | # vgg.to(device) 147 | 148 | vgg19 = VGG19() 149 | vgg19.load_model(opt.vgg19) 150 | vgg19.to(device) 151 | utils.set_requires_grad(vgg19, False) 152 | vgg19.eval() 153 | # vgg_layers = ['conv3_3', 'conv4_2'] 154 | # vgg_layers = ['conv1_2', 'conv2_2', 'conv3_2'] 155 | # vgg_layers = {'conv1_2': 1.0, 'conv2_2': 1.0, 'conv3_2':0.5} 156 | # vgg_layers = {'conv2_2': 1.0} 157 | vgg_layers = {'conv3_2': 1}
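The dict above selects which VGG-19 feature maps drive the contextual loss: each key is a layer name indexed into the VGG19 forward output, and each value is that layer's weight. A sketch of the per-layer accumulation it feeds (editorial; it mirrors the loss_CX loop inside the training loop below, and the two-layer weighting is only a hypothetical example):

#   vgg_layers = {'conv3_2': 1.0, 'conv4_2': 0.5}    # hypothetical multi-layer weighting
#   feats_pred, feats_gt = vgg19(x_hat_patch), vgg19(down_target_up_patch)
#   loss_CX = lambdaCX * sum(w * criterionCX(feats_pred[l], feats_gt[l])
#                            for l, w in vgg_layers.items())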
158 | criterionCAE = nn.L1Loss() 159 | criterionCX = CXLoss(sigma=0.5, spatial_weight=opt.spatial_weight) 160 | criterionCX_patch = CXLoss(sigma=0.5, spatial_weight=0.5) 161 | 162 | netLRN.to(device) 163 | netGDN.to(device) 164 | 165 | criterionCAE.to(device) 166 | criterionCX.to(device) 167 | criterionCX_patch.to(device) 168 | 169 | lambdaIMG = opt.lambdaIMG 170 | lambdaCX = opt.lambdaCX 171 | 172 | # get optimizer 173 | my_lrG = opt.lrG - opt.epoch_count * (opt.lrG/opt.annealEvery) 174 | optimizerG = optim.Adam( netLRN.parameters(), lr = my_lrG, betas = (opt.beta1, 0.999)) 175 | # NOTE training loop 176 | ganIterations = 0 177 | total_steps = 0 178 | dataset_size = len(dataloader) 179 | st = [0, 256, 512, 640] 180 | patch_size = 384 181 | 182 | net_metric = models_metric.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, spatial=False) 183 | net_metric = net_metric.cuda() 184 | utils.set_requires_grad(net_metric, requires_grad=False) 185 | demoire_up_patch = torch.zeros([opt.batchSize, 3, patch_size, patch_size]).cuda() 186 | res = 0 187 | train_res = 0 188 | for epoch in range(opt.epoch_count, opt.annealEvery): 189 | trainLogger = open('%s/train.log' % opt.exp, 'a+') 190 | # switch the state ! 191 | netLRN.train() 192 | 193 | epoch_iter = 0 194 | epoch_start_time = time.time() 195 | iter_data_time = time.time() 196 | 197 | my_psnr = 0 198 | my_ssim = 0 199 | my_ssim_multi = 0 200 | 201 | adjust_learning_rate(optimizerG, opt.lrG, epoch, None, opt.annealEvery) 202 | print(50 * '-' + 'lr' + 50 * '-') 203 | print(str(optimizerG.param_groups[0]['lr'])) 204 | print(50 * '-' + 'lr' + 50 * '-') 205 | 206 | ccnt = 0 207 | train_res = 0 208 | for i, data in enumerate(dataloader, 0): 209 | 210 | netLRN.train() 211 | iter_start_time = time.time() 212 | if total_steps % 100 == 0: 213 | t_data = iter_start_time - iter_data_time 214 | visualizer.reset() 215 | total_steps += opt.batchSize 216 | epoch_iter += opt.batchSize 217 | 218 | input_patch, target_patch, down_input, down_target, indexes1, indexes2 = data 219 | # print(indexes1.size()) 220 | r = indexes1[:].numpy() 221 | c = indexes2[:].numpy() 222 | # print(r) 223 | batch_size = target_patch.size(0) 224 | 225 | # print(indexes) 226 | 227 | input_patch = input_patch.cuda() 228 | target_patch = target_patch.cuda() 229 | gray_input_patch = 0.299 * input_patch[:, 0, :, : ] + 0.587 * input_patch[:, 1, :, : ] + 0.114 * input_patch[:, 2, :, : ] 230 | gray_input_patch.unsqueeze_(1) 231 | gray_target_patch = 0.299 * target_patch[:, 0, :, : ] + 0.587 * target_patch[:, 1, :, : ] + 0.114 * target_patch[:, 2, :, : ] 232 | gray_target_patch.unsqueeze_(1) 233 | down_input = down_input.cuda() 234 | down_target = down_target.cuda() 235 | 236 | optimizerG.zero_grad() 237 | demoire_down = netGDN(down_input)[-1] 238 | demoire_up = F.interpolate(demoire_down, size=1024, mode='bilinear') 239 | down_target_up = F.interpolate(down_target, size=1024, mode='bilinear') 240 | # for i in range(batch_size): 241 | # print(demoire_up.shape) 242 | demoire_up_patch = demoire_up[:, :, indexes2[0]:indexes2[0] + patch_size, indexes1[0]: indexes1[0] + patch_size] 243 | down_target_up_patch = down_target_up[:, :, indexes2[0]:indexes2[0] + patch_size, indexes1[0]: indexes1[0] + patch_size] 244 | # demoire_up_patch.unsqueeze_(0) 245 | # down_target_up_patch.unsqueeze_(0) 246 | # print(demoire_up_patch.shape) 247 | x_hat_patch = netLRN(demoire_up_patch) 248 | 249 | vgg_x_hat = vgg19(x_hat_patch) 250 | vgg_target = vgg19(down_target_up_patch) 251 | 252 | loss_CX = 0.0 253 | if 
lambdaCX >0.0: 254 | for l, w in vgg_layers.items(): 255 | loss_CX += w * criterionCX(vgg_x_hat[l], vgg_target[l]) 256 | 257 | loss_CX = lambdaCX * loss_CX 258 | L = loss_CX 259 | 260 | L.backward() 261 | optimizerG.step() 262 | 263 | ganIterations += 1 264 | ccnt+=1 265 | 266 | if total_steps % 10 == 0: 267 | current_visuals = OrderedDict([ 268 | ('down_input', down_input), 269 | ('demoire_down', demoire_down), 270 | ('down_target', down_target), 271 | ('input_patch', input_patch), 272 | ('demoire_up_patch', demoire_up_patch), 273 | ('x_hat_patch', x_hat_patch), 274 | ('down_target_up_patch', down_target_up_patch), 275 | ('target_patch', target_patch) 276 | # ('fake_moire', fake) 277 | ]) 278 | 279 | losses = OrderedDict([ 280 | ('loss_CX', loss_CX.detach().cpu().float().numpy()), 281 | # ('loss_perce', loss_perce.detach().cpu().float().numpy()), 282 | # ('L_img2', L_img2.detach().cpu().float().numpy()), 283 | # ('content_loss0', content_loss0.detach().cpu().float().numpy()), 284 | # ('content_loss1', content_loss1.detach().cpu().float().numpy()), 285 | # ('percep_metric', train_res.detach() .cpu().float().numpy()[0][0][0][0]/ (ccnt)), 286 | ('my_psnr', my_psnr / (ccnt)), ('my_ssim_multi', my_ssim_multi / ccnt)]) 287 | # print(losses) 288 | t = (time.time() - iter_start_time) / opt.batchSize 289 | trainLogger.write(visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) + '\n') 290 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 291 | r = float(epoch_iter) / (dataset_size*opt.batchSize) 292 | if opt.display_port!=-1: 293 | visualizer.display_current_results(current_visuals, epoch, False) 294 | visualizer.plot_current_losses(epoch, r, opt, losses) 295 | 296 | if epoch % 5 == 0: 297 | my_file = open(opt.name + "/" + opt.name + "_" + "evaluation.txt", 'a+') 298 | print('hit') 299 | torch.save(netLRN.state_dict(), '%s/netLRN_epoch_%d.pth' % (opt.exp, epoch)) 300 | # torch.save(netGDN.state_dict(), '%s/netGDN_epoch_%d.pth' % (opt.exp, epoch)) 301 | # vcnt = 0 302 | # vpsnr = 0 303 | # vssim = 0 304 | # res = 0 305 | # netGDN.eval() 306 | # netLRN.eval() 307 | # utils.set_requires_grad(netGDN, False) 308 | # utils.set_requires_grad(netLRN, False) 309 | # for i, data in enumerate(val_dataloader, 0): 310 | # if i % 50 == 0: 311 | # print('testing: ' + str(i)) 312 | # input, target, down_input = data 313 | # batch_size = input.size(0) 314 | # input = input.cuda() 315 | # target = target.cuda() 316 | # down_input = down_input.cuda() 317 | # 318 | # gray_input = 0.299 * input[:, 0, :, :] + 0.587 * input[:, 1, :, :] + 0.114 * input[:, 2, :, :] 319 | # gray_input.unsqueeze_(1) 320 | # gray_target = 0.299 * target[:, 0, :, :] + 0.587 * target[:, 1, :, :] + 0.114 * target[:, 2, :, :] 321 | # gray_target.unsqueeze_(1) 322 | # 323 | # demoire_down = netGDN(down_input)[-1].detach() 324 | # demoire_up = F.interpolate(demoire_down, size=1024, mode='bilinear') 325 | # 326 | # x_hat = netLRN(demoire_up) 327 | # res += torch.sum(net_metric(target, x_hat).detach()) 328 | # vcnt += batch_size 329 | # for i in range(x_hat.shape[0]): 330 | # ti1 = x_hat[i, :, :, :] 331 | # tt1 = target[i, :, :, :] 332 | # mi1 = util.my_tensor2im(ti1) 333 | # mt1 = util.my_tensor2im(tt1) 334 | # g_mi1 = cv2.cvtColor(mi1, cv2.COLOR_BGR2RGB) 335 | # g_mt1 = cv2.cvtColor(mt1, cv2.COLOR_BGR2RGB) 336 | # vpsnr += Psnr(g_mt1, g_mi1) 337 | # vssim += ssim(g_mt1, g_mi1, multichannel=True) 338 | # 339 | # my_file.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '(' + str(epoch) + ')' + '\n') 
340 | # my_file.write(str(float(res) / vcnt) + '-' + str(float(vpsnr) / vcnt) + '-' + str(float(vssim) / vcnt) + '\n') 341 | # my_file.close() 342 | # netGDN.train() 343 | # netLRN.train() 344 | # utils.set_requires_grad(netGDN, True) 345 | # utils.set_requires_grad(netLRN, True) 346 | 347 | trainLogger.close() --------------------------------------------------------------------------------
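For readability, the optimization step performed inside the training loop above can be restated as one function (an editorial sketch, not part of the repository; the function and argument names are mine, the patch coordinates r and c are assumed to be plain ints, and only the contextual loss is applied, matching the loop above):

import torch.nn.functional as F

def lrn_step(netGDN, netLRN, vgg19, criterionCX, vgg_layers, optimizerG,
             down_input, down_target, r, c, patch_size=384, lambdaCX=1.0):
    # Global demoireing at low resolution, then upsample to the 1024 working size.
    optimizerG.zero_grad()
    demoire_down = netGDN(down_input)[-1]
    demoire_up = F.interpolate(demoire_down, size=1024, mode='bilinear')
    target_up = F.interpolate(down_target, size=1024, mode='bilinear')
    # Crop the same patch from the coarse result and the target, refine it locally.
    pred_patch = netLRN(demoire_up[:, :, c:c + patch_size, r:r + patch_size])
    gt_patch = target_up[:, :, c:c + patch_size, r:r + patch_size]
    # Contextual loss between VGG-19 features of the refined and target patches.
    feats_pred, feats_gt = vgg19(pred_patch), vgg19(gt_patch)
    loss_CX = lambdaCX * sum(w * criterionCX(feats_pred[l], feats_gt[l])
                             for l, w in vgg_layers.items())
    loss_CX.backward()
    optimizerG.step()
    return loss_CX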