├── model
│   ├── __init__.py
│   ├── vgg.pyc
│   ├── __init__.pyc
│   ├── generator.pyc
│   ├── vgg.py
│   └── generator.py
├── models
│   ├── __init__.py
│   ├── NEDB.pyc
│   ├── RNEDB.pyc
│   ├── __init__.pyc
│   ├── networks.pyc
│   ├── non_local_block.pyc
│   ├── region_non_local_block.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── networks.cpython-37.pyc
│   │   └── derain_dense.cpython-36.pyc
│   ├── region_non_local_block.py
│   ├── RNEDB.py
│   ├── NEDB.py
│   ├── NLEDN.py
│   └── non_local_block.py
├── util_metirc
│   ├── __init__.py
│   ├── util.py
│   ├── html.py
│   └── visualizer.py
├── datasets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── my_loader.pyc
│   ├── classification.py
│   ├── pix2pix_class.py
│   ├── pix2pix2.py
│   ├── pix2pix_val.py
│   └── pix2pix.py
├── myutils
│   ├── __init__.py
│   ├── utils.pyc
│   ├── vgg16.pyc
│   ├── __init__.pyc
│   ├── StyleLoader.py
│   ├── vgg16.py
│   └── utils.py
├── transforms
│   ├── __init__.py
│   ├── pix2pix.pyc
│   ├── __init__.pyc
│   ├── pix2pix_val.py
│   ├── pix2pix3.py
│   └── pix2pix_val3.py
├── misc.pyc
├── util.pyc
├── visualizer.pyc
├── models_metric
│   ├── __init__.pyc
│   ├── base_model.pyc
│   ├── dist_model.pyc
│   ├── networks_basic.pyc
│   ├── weights
│   │   ├── v0.0
│   │   │   ├── alex.pth
│   │   │   ├── vgg.pth
│   │   │   └── squeeze.pth
│   │   └── v0.1
│   │       ├── alex.pth
│   │       ├── vgg.pth
│   │       └── squeeze.pth
│   ├── pretrained_networks.pyc
│   ├── base_model.py
│   ├── __init__.py
│   ├── pretrained_networks.py
│   ├── networks_basic.py
│   └── dist_model.py
├── __pycache__
│   └── misc.cpython-37.pyc
├── requirements.txt
├── run_GDN.sh
├── run_FDN_FRN.sh
├── run_test.sh
├── run_LRN.sh
├── README.md
├── util.py
├── rank2_edge_batch.py
├── misc.py
├── visualizer.py
├── test.py
├── train_GDN.py
├── list_7000_f1000.txt
└── train_LRN.py
/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /util_metirc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /myutils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /misc.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/misc.pyc -------------------------------------------------------------------------------- /util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/util.pyc -------------------------------------------------------------------------------- /model/vgg.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/vgg.pyc -------------------------------------------------------------------------------- /models/NEDB.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/NEDB.pyc -------------------------------------------------------------------------------- /visualizer.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/visualizer.pyc -------------------------------------------------------------------------------- /models/RNEDB.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/RNEDB.pyc -------------------------------------------------------------------------------- /myutils/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/utils.pyc -------------------------------------------------------------------------------- /myutils/vgg16.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/vgg16.pyc -------------------------------------------------------------------------------- /model/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/__init__.pyc -------------------------------------------------------------------------------- /model/generator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/model/generator.pyc -------------------------------------------------------------------------------- /models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__init__.pyc -------------------------------------------------------------------------------- /models/networks.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/networks.pyc -------------------------------------------------------------------------------- /myutils/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/myutils/__init__.pyc -------------------------------------------------------------------------------- /datasets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/datasets/__init__.pyc -------------------------------------------------------------------------------- /datasets/my_loader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/datasets/my_loader.pyc -------------------------------------------------------------------------------- /transforms/pix2pix.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/transforms/pix2pix.pyc -------------------------------------------------------------------------------- /transforms/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/transforms/__init__.pyc 
-------------------------------------------------------------------------------- /models/non_local_block.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/non_local_block.pyc -------------------------------------------------------------------------------- /models_metric/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/__init__.pyc -------------------------------------------------------------------------------- /models_metric/base_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/base_model.pyc -------------------------------------------------------------------------------- /models_metric/dist_model.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/dist_model.pyc -------------------------------------------------------------------------------- /__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /models_metric/networks_basic.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/networks_basic.pyc -------------------------------------------------------------------------------- /models/region_non_local_block.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/region_non_local_block.pyc -------------------------------------------------------------------------------- /models_metric/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /models_metric/pretrained_networks.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/pretrained_networks.pyc -------------------------------------------------------------------------------- 
/models_metric/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /models_metric/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models_metric/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/derain_dense.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IMRE/FHDe2Net/HEAD/models/__pycache__/derain_dense.cpython-36.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | Pillow 4 | scipy 5 | scikit-image 6 | torch==1.0.1 7 | torchvision==0.2.0 8 | torchfile 9 | tqdm 10 | 11 | -------------------------------------------------------------------------------- /run_GDN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1 python2.7 train_GDN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "/media/he/80FE99D1FE99BFB8/longmao_final/val" --pre "" --name "GDN" --exp "GDN" --display_port 8099 --originalSize_h 420 --originalSize_w 420 --imageSize_h 384 --imageSize_w 384 --batchSize 2 3 | -------------------------------------------------------------------------------- /run_FDN_FRN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0 python2.7 train_FDN_FRN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "" --pre "" --name "FDN_FRN" --exp "FDN_FRN" --netGDN "./ckpt/netGDN.pth" --netLRN "ckpt/netLRN.pth" --display_port 8099 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1024 --imageSize_w 1024 --batchSize 1 3 | -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=1 python2 test.py --dataroot "/media/he/80FE99D1FE99BFB8/FHDMi/test" --netGDN "ckpt/netGDN.pth" --netLRN "ckpt/netLRN.pth" --netFDN "ckpt/netFDN.pth" --netFRN "ckpt/netFRN.pth" 
--batchSize 2 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1080 --imageSize_w 1920 --image_path "results" --write 1 --record "results.txt" 3 | 4 | -------------------------------------------------------------------------------- /run_LRN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=0,1 python2.7 train_LRN.py --dataroot "/media/he/80FE99D1FE99BFB8/longmao_final/train" --valDataroot "/media/he/80FE99D1FE99BFB8/longmao_final/val" --pre "" --name "LRN" --exp "LRN" --netGDN 'ckpt/netGDN.pth' --display_port 8099 --originalSize_h 1080 --originalSize_w 1920 --imageSize_h 1024 --imageSize_w 1024 --batchSize 2 --list_file 'list_7000_f1000.txt' 3 | -------------------------------------------------------------------------------- /models/region_non_local_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from non_local_block import NONLocalBlock2D 4 | 5 | 6 | class RegionNONLocalBlock(nn.Module): 7 | def __init__(self, in_channels, grid=[6, 6]): 8 | super(RegionNONLocalBlock, self).__init__() 9 | 10 | self.non_local_block = NONLocalBlock2D(in_channels, sub_sample=True, bn_layer=False) 11 | self.grid = grid 12 | 13 | def forward(self, x): 14 | batch_size, _, height, width = x.size() 15 | 16 | input_row_list = x.chunk(self.grid[0], dim=2) 17 | 18 | output_row_list = [] 19 | for i, row in enumerate(input_row_list): 20 | input_grid_list_of_a_row = row.chunk(self.grid[1], dim=3) 21 | output_grid_list_of_a_row = [] 22 | 23 | for j, grid in enumerate(input_grid_list_of_a_row): 24 | grid = self.non_local_block(grid) 25 | output_grid_list_of_a_row.append(grid) 26 | 27 | output_row = torch.cat(output_grid_list_of_a_row, dim=3) 28 | output_row_list.append(output_row) 29 | 30 | output = torch.cat(output_row_list, dim=2) 31 | return output 32 | -------------------------------------------------------------------------------- /myutils/StyleLoader.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import os 12 | from torch.autograd import Variable 13 | 14 | from myutils import utils 15 | 16 | class StyleLoader(): 17 | def __init__(self, style_folder, style_size, cuda=True): 18 | self.folder = style_folder 19 | self.style_size = style_size 20 | self.files = os.listdir(style_folder) 21 | self.cuda = cuda 22 | 23 | def get(self, i): 24 | idx = i%len(self.files) 25 | filepath = os.path.join(self.folder, self.files[idx]) 26 | style = utils.tensor_load_rgbimage(filepath, self.style_size) 27 | style = style.unsqueeze(0) 28 | style = utils.preprocess_batch(style) 29 | if self.cuda: 30 | style = style.cuda() 31 | style_v = Variable(style, requires_grad=False) 32 | return style_v 33 | 34 | def size(self): 35 | return len(self.files) 36 | 37 | 38 | -------------------------------------------------------------------------------- /util_metirc/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 
2 | 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | def load_image(path): 11 | if(path[-3:] == 'dng'): 12 | import rawpy 13 | with rawpy.imread(path) as raw: 14 | img = raw.postprocess() 15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 16 | import cv2 17 | return cv2.imread(path)[:,:,::-1] 18 | else: 19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 20 | 21 | return img 22 | 23 | def save_image(image_numpy, image_path, ): 24 | image_pil = Image.fromarray(image_numpy) 25 | image_pil.save(image_path) 26 | 27 | def mkdirs(paths): 28 | if isinstance(paths, list) and not isinstance(paths, str): 29 | for path in paths: 30 | mkdir(path) 31 | else: 32 | mkdir(paths) 33 | 34 | def mkdir(path): 35 | if not os.path.exists(path): 36 | os.makedirs(path) 37 | 38 | 39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 41 | image_numpy = image_tensor[0].cpu().float().numpy() 42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 43 | return image_numpy.astype(imtype) 44 | 45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 47 | return torch.Tensor((image / factor - cent) 48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 49 | -------------------------------------------------------------------------------- /models/RNEDB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from region_non_local_block import RegionNONLocalBlock 4 | 5 | 6 | class RNEDB(nn.Module): 7 | def __init__(self, block_num=3, inter_channel=32, channel=64, grid=[8, 8]): 8 | super(RNEDB, self).__init__() 9 | 10 | concat_channels = channel + block_num * inter_channel 11 | channels_now = channel 12 | 13 | self.region_non_local = RegionNONLocalBlock(channels_now, grid=grid) 14 | 15 | self.group_list = [] 16 | for i in range(block_num): 17 | group = nn.Sequential( 18 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 19 | stride=1, padding=1), 20 | nn.ReLU(), 21 | ) 22 | self.add_module(name='group_%d' % i, module=group) 23 | self.group_list.append(group) 24 | 25 | channels_now += inter_channel 26 | 27 | assert channels_now == concat_channels 28 | self.fusion = nn.Sequential( 29 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 30 | ) 31 | 32 | def forward(self, x): 33 | x_rnl = self.region_non_local(x) 34 | feature_list = [x_rnl,] 35 | 36 | for group in self.group_list: 37 | inputs = torch.cat(feature_list, dim=1) 38 | outputs = group(inputs) 39 | feature_list.append(outputs) 40 | 41 | inputs = torch.cat(feature_list, dim=1) 42 | fusion_outputs = self.fusion(inputs) 43 | 44 | block_outputs = fusion_outputs + x 45 | 46 | return block_outputs 47 | -------------------------------------------------------------------------------- /models/NEDB.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from non_local_block import NONLocalBlock2D 4 | 5 | 6 | class NEDB(nn.Module): 7 | def __init__(self, block_num=3, inter_channel=32, channel=64): 8 | super(NEDB, self).__init__() 9 | 10 | concat_channels = channel + block_num * inter_channel 11 | channels_now = channel 12 | 13 | self.non_local = NONLocalBlock2D(channels_now, 
bn_layer=False) 14 | 15 | self.group_list = [] 16 | for i in range(block_num): 17 | group = nn.Sequential( 18 | nn.Conv2d(in_channels=channels_now, out_channels=inter_channel, kernel_size=3, 19 | stride=1, padding=1), 20 | nn.ReLU(), 21 | ) 22 | self.add_module(name='group_%d' % i, module=group) 23 | self.group_list.append(group) 24 | 25 | channels_now += inter_channel 26 | 27 | assert channels_now == concat_channels 28 | self.fusion = nn.Sequential( 29 | nn.Conv2d(concat_channels, channel, kernel_size=1, stride=1, padding=0), 30 | ) 31 | 32 | def forward(self, x, corr1, corr2): 33 | x_nl, ret_corr1, ret_corr2 = self.non_local(x, corr1, corr2) 34 | feature_list = [x_nl,] 35 | 36 | for group in self.group_list: 37 | inputs = torch.cat(feature_list, dim=1) 38 | outputs = group(inputs) 39 | feature_list.append(outputs) 40 | 41 | inputs = torch.cat(feature_list, dim=1) 42 | fusion_outputs = self.fusion(inputs) 43 | 44 | block_outputs = fusion_outputs + x 45 | 46 | return block_outputs, ret_corr1, ret_corr2 47 | -------------------------------------------------------------------------------- /models_metric/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | # from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /myutils/vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Vgg16(torch.nn.Module): 7 | def __init__(self): 8 | super(Vgg16, self).__init__() 9 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 10 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 11 | 12 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 13 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 14 | 15 | self.conv3_1 = nn.Conv2d(128, 256, 
kernel_size=3, stride=1, padding=1) 16 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 17 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 18 | 19 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 20 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 21 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 22 | 23 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 24 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 25 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 26 | 27 | def forward(self, X): 28 | h = F.relu(self.conv1_1(X)) 29 | h = F.relu(self.conv1_2(h)) 30 | relu1_2 = h 31 | h = F.max_pool2d(h, kernel_size=2, stride=2) 32 | 33 | h = F.relu(self.conv2_1(h)) 34 | h = F.relu(self.conv2_2(h)) 35 | relu2_2 = h 36 | h = F.max_pool2d(h, kernel_size=2, stride=2) 37 | 38 | h = F.relu(self.conv3_1(h)) 39 | h = F.relu(self.conv3_2(h)) 40 | h = F.relu(self.conv3_3(h)) 41 | relu3_3 = h 42 | h = F.max_pool2d(h, kernel_size=2, stride=2) 43 | 44 | h = F.relu(self.conv4_1(h)) 45 | h = F.relu(self.conv4_2(h)) 46 | h = F.relu(self.conv4_3(h)) 47 | relu4_3 = h 48 | 49 | return [relu1_2, relu2_2, relu3_3, relu4_3] 50 | -------------------------------------------------------------------------------- /datasets/classification.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import h5py 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class classification(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None): 34 | # imgs = make_dataset(root) 35 | # if len(imgs) == 0: 36 | # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | # self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, _): 47 | index = np.random.randint(1,self.__len__()) 48 | # path = self.imgs[index] 49 | # img = self.loader(path) 50 | #img = img.resize((w, h), Image.BILINEAR) 51 | 52 | 53 | 54 | file_name=self.root+'/'+str(index)+'.h5' 55 | f=h5py.File(file_name,'r') 56 | 57 | haze_image=f['haze'][:] 58 | label=f['label'][:] 59 | label=label.mean()-1 60 | 61 | haze_image=np.swapaxes(haze_image,0,2) 62 | haze_image=np.swapaxes(haze_image,1,2) 63 | 64 | 65 | # if self.transform is not None: 66 | # # NOTE preprocessing for each pair of images 67 | # imgA, imgB = self.transform(imgA, imgB) 68 | return haze_image, label 69 | 70 | def __len__(self): 71 | train_list=glob.glob(self.root+'/*h5') 72 | # print len(train_list) 73 | return 
len(train_list) 74 | 75 | # return len(self.imgs) 76 | -------------------------------------------------------------------------------- /util_metirc/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | # self.img_dir = os.path.join(self.web_dir, ) 11 | self.img_subdir = image_subdir 12 | self.img_dir = os.path.join(self.web_dir, image_subdir) 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join(link)): 43 | img(style="width:%dpx" % width, src=os.path.join(im)) 44 | br() 45 | p(txt) 46 | 47 | def save(self,file='index'): 48 | html_file = '%s/%s.html' % (self.web_dir,file) 49 | f = open(html_file, 'wt') 50 | f.write(self.doc.render()) 51 | f.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | html = HTML('web/', 'test_html') 56 | html.add_header('hello world') 57 | 58 | ims = [] 59 | txts = [] 60 | links = [] 61 | for n in range(4): 62 | ims.append('image_%d.png' % n) 63 | txts.append('text_%d' % n) 64 | links.append('image_%d.png' % n) 65 | html.add_images(ims, txts, links) 66 | html.save() 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FHDe2Net 2 | Official Implementation for "FHDe2Net: Full High Definition Demoireing Network" (ECCV 20) 3 | 4 | ## Prerequisites: 5 | 1. Linux 6 | 2. python2 or 3 7 | 3. NVIDIA GPU + CUDA CuDNN (CUDA 8.0) 8 | 9 | ## Installation: 10 | 1. Install PyTorch from http://pytorch.org 11 | 2. Install Torch vision from https://github.com/pytorch/vision 12 | 3. Install python package: numpy, scipy, PIL, math, skimage, visdom 13 | 14 | ## Download the FHDMi dataset 15 | You can download the training and testing dataset from 16 | https://pan.baidu.com/s/19LTN7unSBAftSpNVs8x9ZQ with password jf2d 17 | or 18 | https://drive.google.com/drive/folders/1IJSeBXepXFpNAvL5OyZ2Y1yu4KPvDxN5?usp=sharing 19 | You accelerate the training with a subset of FHDMi described by FHDMi_thin.txt. 20 | 21 | ## Testing 22 | 1) Download pre-trained models 23 | You can download the pre-trained models from 24 | https://pan.baidu.com/s/14fo4gdBtx4GDohNNyYObpg with password t8vn 25 | And all the models are supposed to be placed in the ckpt folder. 
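For reference, run_test.sh points --netGDN/--netLRN/--netFDN/--netFRN at relative paths under ckpt/, so a layout like the sketch below is assumed (the file names inside the downloaded archive may differ; rename them to match if necessary):
```
ckpt/
├── netGDN.pth
├── netLRN.pth
├── netFDN.pth
└── netFRN.pth
```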
26 | 27 | 2) Build up the testing environment 28 | You can easily build the testing environment by: 29 | `pip install -r requirements.txt` 30 | 31 | 3) testing 32 | Specify the --dataroot with the testing dataset path, and run 33 | `bash run_test.sh ` 34 | 35 | ## Training 36 | 1) Download Vgg19 ckpt from 37 | https://pan.baidu.com/s/1c3eEh29uAfZTzTe0X9Jz_Q by password: zvcy 38 | And put it in models/ 39 | 40 | 2) open visdom by 41 | `python -m visdom.server -port 8099` 42 | 43 | 3) change the dataroot in run_GDN.sh and train GND by running 44 | `bash run_GDN.sh` 45 | 46 | 4) change the dataroot in run_LRN.sh and train LRN by running 47 | `bash run_LRN.sh` 48 | (For trainnig LRN, you can either use the distilled dataset generated by rank2_edge_batch.py or directly use the whole dataset. 49 | A subset image list for FHDMi has been generated ahead in list_7000_f1000.txt.) 50 | 51 | 5) change the dataroot in run_FDN_FRN.sh and train FDN and FRN by running 52 | `bash run_FDN_FRN.sh` 53 | 54 | ## Citation 55 | ``` 56 | @article{hefhde2net, 57 | title={FHDe2Net: Full High Definition Demoireing Network}, 58 | author={He, Bin and Wang, Ce and Shi, Boxin and Duan, Ling-Yu}, 59 | publisher={Springer} 60 | } 61 | ``` 62 | ## Contactor 63 | If you have any question, please feel free to contact me with 1801213742@pku.edu.cn 64 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | def my_tensor2im(input_image, imtype=np.uint8): 21 | if isinstance(input_image, torch.Tensor): 22 | image_tensor = input_image.data 23 | else: 24 | return input_image 25 | # print(50*'-') 26 | # print(input_image.shape) 27 | # print(image_tensor.shape) 28 | image_numpy = image_tensor.cpu().float().numpy() 29 | if image_numpy.shape[0] == 1: 30 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 31 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 32 | return image_numpy.astype(imtype) 33 | 34 | 35 | def diagnose_network(net, name='network'): 36 | mean = 0.0 37 | count = 0 38 | for param in net.parameters(): 39 | if param.grad is not None: 40 | mean += torch.mean(torch.abs(param.grad.data)) 41 | count += 1 42 | if count > 0: 43 | mean = mean / count 44 | print(name) 45 | print(mean) 46 | 47 | 48 | def save_image(image_numpy, image_path): 49 | image_pil = Image.fromarray(image_numpy) 50 | image_pil.save(image_path) 51 | 52 | 53 | def print_numpy(x, val=True, shp=False): 54 | x = x.astype(np.float64) 55 | if shp: 56 | print('shape,', x.shape) 57 | if val: 58 | x = x.flatten() 59 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 60 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 61 | 62 | 63 | def mkdirs(paths): 64 | if isinstance(paths, list) and not isinstance(paths, str): 65 | for path 
in paths: 66 | mkdir(path) 67 | else: 68 | mkdir(paths) 69 | 70 | 71 | def mkdir(path): 72 | if not os.path.exists(path): 73 | os.makedirs(path) 74 | -------------------------------------------------------------------------------- /datasets/pix2pix_class.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class pix2pix(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None, pre=""): 34 | imgs = make_dataset(root) 35 | if len(imgs) == 0: 36 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, index): 47 | 48 | index_sub = np.random.randint(0, 3) 49 | label=index_sub 50 | 51 | 52 | 53 | if index_sub==0: 54 | index = np.random.randint(0,4000) 55 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' 56 | 57 | 58 | if index_sub==1: 59 | index = np.random.randint(0,4000) 60 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' 61 | 62 | if index_sub==2: 63 | index = np.random.randint(0,4000) 64 | path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Light/train2018new'+'/'+str(index)+'.jpg' 65 | 66 | 67 | 68 | img = self.loader(path) 69 | 70 | 71 | w, h = img.size 72 | 73 | 74 | # NOTE: split a sample into imgA and imgB 75 | imgA = img.crop((0, 0, w/2, h)) 76 | imgB = img.crop((w/2, 0, w, h)) 77 | 78 | 79 | if self.transform is not None: 80 | # NOTE preprocessing for each pair of images 81 | imgA, imgB = self.transform(imgA, imgB) 82 | 83 | return imgA, imgB, label 84 | 85 | def __len__(self): 86 | # return 679 87 | # print(len(self.imgs)) 88 | return len(self.imgs) 89 | -------------------------------------------------------------------------------- /models/NLEDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from NEDB import NEDB 4 | from RNEDB import RNEDB 5 | 6 | 7 | class NLEDN(nn.Module): 8 | def __init__(self): 9 | super(NLEDN, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(3, 64, 3, 1, 1) 12 | self.conv2 = nn.Conv2d(64, 64, 3, 1, 1) 13 | 14 | self.up_1 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[8, 8]) 15 | self.up_2 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[4, 4]) 16 | self.up_3 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[2, 2]) 17 | 18 | self.down_3 = NEDB(block_num=4, inter_channel=32, channel=64) 19 | self.down_2 = 
RNEDB(block_num=4, inter_channel=32, channel=64, grid=[2, 2]) 20 | self.down_1 = RNEDB(block_num=4, inter_channel=32, channel=64, grid=[4, 4]) 21 | 22 | self.down_2_fusion = nn.Conv2d(64 + 64, 64, 1, 1, 0) 23 | self.down_1_fusion = nn.Conv2d(64 + 64, 64, 1, 1, 0) 24 | 25 | self.fusion = nn.Sequential( 26 | nn.Conv2d(64 * 3, 64, 1, 1, 0), 27 | nn.Conv2d(64, 64, 3, 1, 1), 28 | ) 29 | 30 | self.final_conv = nn.Sequential( 31 | nn.Conv2d(64, 3, 3, 1, 1), 32 | nn.Tanh(), 33 | ) 34 | 35 | def forward(self, x): 36 | feature_neg_1 = self.conv1(x) 37 | feature_0 = self.conv2(feature_neg_1) 38 | 39 | up_1_banch = self.up_1(feature_0) 40 | up_1, indices_1 = nn.MaxPool2d(2, 2, return_indices=True)(up_1_banch) 41 | 42 | up_2 = self.up_2(up_1) 43 | up_2, indices_2 = nn.MaxPool2d(2, 2, return_indices=True)(up_2) 44 | 45 | up_3 = self.up_3(up_2) 46 | up_3, indices_3 = nn.MaxPool2d(2, 2, return_indices=True)(up_3) 47 | 48 | down_3 = self.down_3(up_3) 49 | 50 | down_3 = nn.MaxUnpool2d(2, 2)(down_3, indices_3, output_size=up_2.size()) 51 | 52 | down_3 = torch.cat([up_2, down_3], dim=1) 53 | down_3 = self.down_2_fusion(down_3) 54 | down_2 = self.down_2(down_3) 55 | 56 | down_2 = nn.MaxUnpool2d(2, 2)(down_2, indices_2, output_size=up_1.size()) 57 | 58 | down_2 = torch.cat([up_1, down_2], dim=1) 59 | down_2 = self.down_1_fusion(down_2) 60 | down_1 = self.down_1(down_2) 61 | down_1 = nn.MaxUnpool2d(2, 2)(down_1, indices_1, output_size=feature_0.size()) 62 | 63 | down_1 = torch.cat([feature_0, down_1], dim=1) 64 | 65 | cat_block_feature = torch.cat([down_1, up_1_banch], 1) 66 | feature = self.fusion(cat_block_feature) 67 | feature = feature + feature_neg_1 68 | 69 | outputs = self.final_conv(feature) 70 | 71 | return outputs 72 | -------------------------------------------------------------------------------- /datasets/pix2pix2.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import h5py 7 | import glob 8 | import scipy.ndimage 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def make_dataset(dir): 18 | images = [] 19 | if not os.path.isdir(dir): 20 | raise Exception('Check dataroot') 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in fnames: 23 | if is_image_file(fname): 24 | path = os.path.join(dir, fname) 25 | item = path 26 | images.append(item) 27 | return images 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class pix2pix(data.Dataset): 33 | def __init__(self, root, transform=None, loader=default_loader, seed=None): 34 | # imgs = make_dataset(root) 35 | # if len(imgs) == 0: 36 | # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 37 | # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 38 | self.root = root 39 | # self.imgs = imgs 40 | self.transform = transform 41 | self.loader = loader 42 | 43 | if seed is not None: 44 | np.random.seed(seed) 45 | 46 | def __getitem__(self, _): 47 | index = np.random.randint(1,self.__len__()) 48 | # index = np.random.randint(self.__len__(), size=1)[0] 49 | 50 | # path = self.imgs[index] 51 | # img = self.loader(path) 52 | #img = img.resize((w, h), Image.BILINEAR) 53 | 54 | 55 | 56 | file_name=self.root+'/'+str(index)+'.h5' 57 | 
f=h5py.File(file_name,'r') 58 | 59 | haze_image=f['haze'][:] 60 | trans_map=f['trans'][:] 61 | ato_map=f['ato'][:] 62 | GT=f['gt'][:] 63 | 64 | 65 | 66 | haze_image=np.swapaxes(haze_image,0,2) 67 | trans_map=np.swapaxes(trans_map,0,2) 68 | ato_map=np.swapaxes(ato_map,0,2) 69 | GT=np.swapaxes(GT,0,2) 70 | 71 | 72 | 73 | haze_image=np.swapaxes(haze_image,1,2) 74 | trans_map=np.swapaxes(trans_map,1,2) 75 | ato_map=np.swapaxes(ato_map,1,2) 76 | GT=np.swapaxes(GT,1,2) 77 | 78 | # if np.random.uniform()>0.5: 79 | # haze_image=np.flip(haze_image,2).copy() 80 | # GT = np.flip(GT, 2).copy() 81 | # trans_map=np.flip(trans_map, 2).copy() 82 | # if np.random.uniform()>0.5: 83 | # angle = np.random.uniform(-10, 10) 84 | # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) 85 | # GT = scipy.ndimage.interpolation.rotate(GT, angle) 86 | 87 | # if np.random.uniform()>0.5: 88 | # angle = np.random.uniform(-10, 10) 89 | # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) 90 | # GT = scipy.ndimage.interpolation.rotate(GT, angle) 91 | 92 | # if np.random.uniform()>0.5: 93 | # std = np.random.uniform(0.2, 1.2) 94 | # haze_image = scipy.ndimage.filters.gaussian_filter(haze_image, std,mode='constant') 95 | 96 | # haze_image=np.random.uniform(-10/5000,10/5000,size=haze_image.shape) 97 | # haze_image = np.maximum(0, haze_image) 98 | 99 | # if self.transform is not None: 100 | # # NOTE preprocessing for each pair of images 101 | # imgA, imgB = self.transform(imgA, imgB) 102 | return haze_image, GT, trans_map, ato_map 103 | 104 | def __len__(self): 105 | train_list=glob.glob(self.root+'/*h5') 106 | # print len(train_list) 107 | return len(train_list) 108 | 109 | # return len(self.imgs) 110 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def conv2d(in_channels, out_channels): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), 8 | nn.ReLU(inplace=True)) 9 | 10 | 11 | def conv(in_channels, out_channels): 12 | return nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 13 | 14 | 15 | class VGG19(nn.Module): 16 | def __init__(self): 17 | super(VGG19, self).__init__() 18 | self.mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(2).unsqueeze(3).cuda() 19 | self.std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(2).unsqueeze(3).cuda() 20 | self.conv1_1 = nn.Sequential(conv(3, 64), nn.ReLU()) 21 | self.conv1_2 = nn.Sequential(conv(64, 64), nn.ReLU()) 22 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 23 | self.conv2_1 = nn.Sequential(conv(64, 128), nn.ReLU()) 24 | self.conv2_2 = nn.Sequential(conv(128, 128), nn.ReLU()) 25 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 26 | self.conv3_1 = nn.Sequential(conv(128, 256), nn.ReLU()) 27 | self.conv3_2 = nn.Sequential(conv(256, 256), nn.ReLU()) 28 | self.conv3_3 = nn.Sequential(conv(256, 256), nn.ReLU()) 29 | self.conv3_4 = nn.Sequential(conv(256, 256), nn.ReLU()) 30 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=False) 31 | self.conv4_1 = nn.Sequential(conv(256, 512), nn.ReLU()) 32 | self.conv4_2 = nn.Sequential(conv(512, 512), nn.ReLU()) 33 | 34 | def load_model(self, model_file): 35 | vgg19_dict = self.state_dict() 36 | pretrained_dict = 
torch.load(model_file) 37 | vgg19_keys = vgg19_dict.keys() 38 | pretrained_keys = pretrained_dict.keys() 39 | for k, pk in zip(vgg19_keys, pretrained_keys): 40 | vgg19_dict[k] = pretrained_dict[pk] 41 | self.load_state_dict(vgg19_dict) 42 | 43 | # def forward(self, input_images): 44 | # # print(self.mean) 45 | # # input_images = (input_images - self.mean) / self.std 46 | # feature = {} 47 | # feature['conv1_1'] = self.conv1_1(input_images) 48 | # feature['conv1_2'] = self.conv1_2(feature['conv1_1']) 49 | # feature['pool1'] = self.pool1(feature['conv1_2']) 50 | # feature['conv2_1'] = self.conv2_1(feature['pool1']) 51 | # feature['conv2_2'] = self.conv2_2(feature['conv2_1']) 52 | # feature['pool2'] = self.pool2(feature['conv2_2']) 53 | # feature['conv3_1'] = self.conv3_1(feature['pool2']) 54 | # feature['conv3_2'] = self.conv3_2(feature['conv3_1']) 55 | # feature['conv3_3'] = self.conv3_3(feature['conv3_2']) 56 | # feature['conv3_4'] = self.conv3_4(feature['conv3_3']) 57 | # feature['pool3'] = self.pool3(feature['conv3_4']) 58 | # feature['conv4_1'] = self.conv4_1(feature['pool3']) 59 | # feature['conv4_2'] = self.conv4_2(feature['conv4_1']) 60 | # 61 | # return feature 62 | def forward(self, input_images): 63 | # print(self.mean) 64 | # input_images = (input_images - self.mean) / self.std 65 | feature = {} 66 | tmp = self.conv1_1(input_images) 67 | tmp = self.conv1_2(tmp) 68 | feature['conv1_2'] = tmp 69 | tmp = self.pool1(tmp) 70 | tmp = self.conv2_1(tmp) 71 | tmp = self.conv2_2(tmp) 72 | feature['conv2_2'] = tmp 73 | tmp = self.pool2(tmp) 74 | tmp = self.conv3_1(tmp) 75 | feature['conv3_2'] = self.conv3_2(tmp) 76 | tmp = self.conv3_3(feature['conv3_2']) 77 | feature['conv3_4'] = self.conv3_4(tmp) 78 | # tmp = self.conv3_3(feature['conv3_2']) 79 | # tmp = self.conv3_4(tmp) 80 | # tmp = self.pool3(feature['conv3_4']) 81 | # feature['conv4_1'] = self.conv4_1(feature['pool3']) 82 | # feature['conv4_2'] = self.conv4_2(feature['conv4_1']) 83 | 84 | return feature 85 | -------------------------------------------------------------------------------- /datasets/pix2pix_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | # from guidedfilter import guidedfilter 8 | # import guidedfilter.guidedfilter as guidedfilter 9 | 10 | 11 | 12 | 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 17 | ] 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | def make_dataset(dir): 23 | images = [] 24 | if not os.path.isdir(dir): 25 | raise Exception('Check dataroot') 26 | for root, _, fnames in sorted(os.walk(dir)): 27 | for fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(dir, fname) 30 | item = path 31 | images.append(item) 32 | return images 33 | 34 | def default_loader(path): 35 | return Image.open(path).convert('RGB') 36 | 37 | class pix2pix_val(data.Dataset): 38 | def __init__(self, root, transform=None, loader=default_loader, seed=None, pre =""): 39 | imgs = make_dataset(root) 40 | if len(imgs) == 0: 41 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 42 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 43 | self.root = root 44 | self.imgs = imgs 45 | self.transform = transform 46 | self.loader = loader 47 | 48 | if seed is not None: 49 | 
np.random.seed(seed) 50 | 51 | def __getitem__(self, index): 52 | # index = np.random.randint(self.__len__(), size=1)[0] 53 | # index = np.random.randint(self.__len__(), size=1)[0] 54 | 55 | path = self.imgs[index] 56 | 57 | path=self.root+'/'+str(index)+'.jpg' 58 | 59 | index_folder = np.random.randint(0,4) 60 | label=index_folder 61 | 62 | # path='/home/openset/Desktop/derain2018/facades/DB_Rain_test/Rain_Heavy/test2018'+'/'+str(index)+'.jpg' 63 | img = self.loader(path) 64 | 65 | # # NOTE: img -> PIL Image 66 | # w, h = img.size 67 | # w, h = 1024, 512 68 | # img = img.resize((w, h), Image.BILINEAR) 69 | # # NOTE: split a sample into imgA and imgB 70 | # imgA = img.crop((0, 0, w/2, h)) 71 | # imgB = img.crop((w/2, 0, w, h)) 72 | # if self.transform is not None: 73 | # # NOTE preprocessing for each pair of images 74 | # imgA, imgB = self.transform(imgA, imgB) 75 | # return imgA, imgB 76 | 77 | 78 | # w, h = 1536, 512 79 | # img = img.resize((w, h), Image.BILINEAR) 80 | # 81 | # 82 | # # NOTE: split a sample into imgA and imgB 83 | # imgA = img.crop((0, 0, w/3, h)) 84 | # imgC = img.crop((2*w/3, 0, w, h)) 85 | # 86 | # imgB = img.crop((w/3, 0, 2*w/3, h)) 87 | 88 | # w, h = 1024, 512 89 | # img = img.resize((w, h), Image.BILINEAR) 90 | # 91 | # r = 16 92 | # eps = 1 93 | # 94 | # # I = img.crop((0, 0, w/2, h)) 95 | # # pix = np.array(I) 96 | # # print 97 | # # base[idx,:,:,:]=guidedfilter(pix[], pix[], r, eps) 98 | # # base[]=guidedfilter(pix[], pix[], r, eps) 99 | # # base[]=guidedfilter(pix[], pix[], r, eps) 100 | # 101 | # 102 | # # base = PIL.Image.fromarray(numpy.uint8(base)) 103 | # 104 | # # NOTE: split a sample into imgA and imgB 105 | # imgA = img.crop((0, 0, w/3, h)) 106 | # imgC = img.crop((2*w/3, 0, w, h)) 107 | # 108 | # imgB = img.crop((w/3, 0, 2*w/3, h)) 109 | # imgA=base 110 | # imgB=I-base 111 | # imgC = img.crop((w/2, 0, w, h)) 112 | w, h = img.size 113 | # w, h = 586*2, 586 114 | 115 | # img = img.resize((w, h), Image.BILINEAR) 116 | 117 | 118 | # NOTE: split a sample into imgA and imgB 119 | imgA = img.crop((0, 0, w/2, h)) 120 | # imgC = img.crop((2*w/3, 0, w, h)) 121 | 122 | imgB = img.crop((w/2, 0, w, h)) 123 | 124 | 125 | if self.transform is not None: 126 | # NOTE preprocessing for each pair of images 127 | # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) 128 | imgA, imgB = self.transform(imgA, imgB) 129 | 130 | return imgA, imgB, path 131 | 132 | def __len__(self): 133 | return len(self.imgs) 134 | -------------------------------------------------------------------------------- /datasets/pix2pix.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | 9 | 10 | IMG_EXTENSIONS = [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', 13 | ] 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def make_dataset(dir): 19 | images = [] 20 | if not os.path.isdir(dir): 21 | raise Exception('Check dataroot') 22 | for root, _, fnames in sorted(os.walk(dir)): 23 | for fname in fnames: 24 | if is_image_file(fname): 25 | path = os.path.join(dir, fname) 26 | item = path 27 | images.append(item) 28 | return images 29 | 30 | def default_loader(path): 31 | return Image.open(path).convert('RGB') 32 | 33 | class pix2pix(data.Dataset): 34 | def __init__(self, root, transform=None, loader=default_loader, 
seed=None, pre=""): 35 | imgs = make_dataset(root) 36 | if len(imgs) == 0: 37 | raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" 38 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 39 | self.root = root 40 | self.imgs = imgs 41 | self.transform = transform 42 | self.loader = loader 43 | 44 | if seed is not None: 45 | np.random.seed(seed) 46 | 47 | def __getitem__(self, index): 48 | # index = np.random.randint(self.__len__(), size=1)[0] 49 | # index = np.random.randint(self.__len__(), size=1)[0]+1 50 | # index = np.random.randint(self.__len__(), size=1)[0] 51 | 52 | # index_folder = np.random.randint(1,4) 53 | index_folder = np.random.randint(0,1) 54 | 55 | index_sub = np.random.randint(2, 5) 56 | 57 | label=index_folder 58 | 59 | 60 | if index_folder==0: 61 | path='/home/openset/Desktop/derain2018/facades/training2'+'/'+str(index)+'.jpg' 62 | 63 | 64 | 65 | if index_folder==1: 66 | if index_sub<4: 67 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' 68 | if index_sub==4: 69 | index = np.random.randint(0,400) 70 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Heavy/trainnew'+'/'+str(index)+'.jpg' 71 | 72 | if index_folder==2: 73 | if index_sub<4: 74 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' 75 | if index_sub==4: 76 | index = np.random.randint(0,400) 77 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Medium/trainnew'+'/'+str(index)+'.jpg' 78 | 79 | if index_folder==3: 80 | if index_sub<4: 81 | path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Light/train2018new'+'/'+str(index)+'.jpg' 82 | if index_sub==4: 83 | index = np.random.randint(0,400) 84 | path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Light/trainnew'+'/'+str(index)+'.jpg' 85 | 86 | 87 | 88 | # img = self.loader(path) 89 | 90 | img = self.loader(path) 91 | 92 | # NOTE: img -> PIL Image 93 | # w, h = img.size 94 | # w, h = 1024, 512 95 | # img = img.resize((w, h), Image.BILINEAR) 96 | # pix = np.array(I) 97 | # 98 | # r = 16 99 | # eps = 1 100 | # 101 | # I = img.crop((0, 0, w/2, h)) 102 | # pix = np.array(I) 103 | # base=guidedfilter(pix, pix, r, eps) 104 | # base = PIL.Image.fromarray(numpy.uint8(base)) 105 | # 106 | # 107 | # 108 | # imgA=base 109 | # imgB=I-base 110 | # imgC = img.crop((w/2, 0, w, h)) 111 | 112 | w, h = img.size 113 | # img = img.resize((w, h), Image.BILINEAR) 114 | 115 | 116 | # NOTE: split a sample into imgA and imgB 117 | imgA = img.crop((0, 0, w/2, h)) 118 | # imgC = img.crop((2*w/3, 0, w, h)) 119 | 120 | imgB = img.crop((w/2, 0, w, h)) 121 | 122 | 123 | if self.transform is not None: 124 | # NOTE preprocessing for each pair of images 125 | # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) 126 | imgA, imgB = self.transform(imgA, imgB) 127 | 128 | return imgA, imgB, label 129 | 130 | def __len__(self): 131 | # return 679 132 | print(len(self.imgs)) 133 | return len(self.imgs) 134 | -------------------------------------------------------------------------------- /myutils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | from torch.autograd import Variable 7 | # from torch.utils.serialization import load_lua 8 | import torchfile 9 | from myutils.vgg16 import Vgg16 10 | 11 | from torch.optim import lr_scheduler 12 | def tensor_load_rgbimage(filename, size=None, 
scale=None, keep_asp=False): 13 | img = Image.open(filename).convert('RGB') 14 | if size is not None: 15 | if keep_asp: 16 | size2 = int(size * 1.0 / img.size[0] * img.size[1]) 17 | img = img.resize((size, size2), Image.ANTIALIAS) 18 | else: 19 | img = img.resize((size, size), Image.ANTIALIAS) 20 | 21 | elif scale is not None: 22 | img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS) 23 | img = np.array(img).transpose(2, 0, 1) 24 | img = torch.from_numpy(img).float() 25 | return img 26 | 27 | 28 | def tensor_save_rgbimage(tensor, filename, cuda=False): 29 | if cuda: 30 | img = tensor.clone().cpu().clamp(0, 255).numpy() 31 | else: 32 | img = tensor.clone().clamp(0, 255).numpy() 33 | img = img.transpose(1, 2, 0).astype('uint8') 34 | img = Image.fromarray(img) 35 | img.save(filename) 36 | 37 | 38 | def tensor_save_bgrimage(tensor, filename, cuda=False): 39 | (b, g, r) = torch.chunk(tensor, 3) 40 | tensor = torch.cat((r, g, b)) 41 | tensor_save_rgbimage(tensor, filename, cuda) 42 | 43 | 44 | def gram_matrix(y): 45 | (b, ch, h, w) = y.size() 46 | features = y.view(b, ch, w * h) 47 | features_t = features.transpose(1, 2) 48 | gram = features.bmm(features_t) / (ch * h * w) 49 | return gram 50 | 51 | 52 | def subtract_imagenet_mean_batch(batch): 53 | """Subtract ImageNet mean pixel-wise from a BGR image.""" 54 | tensortype = type(batch.data) 55 | mean = tensortype(batch.data.size()) 56 | mean[:, 0, :, :] = 103.939 57 | mean[:, 1, :, :] = 116.779 58 | mean[:, 2, :, :] = 123.680 59 | return batch - Variable(mean) 60 | 61 | 62 | def add_imagenet_mean_batch(batch): 63 | """Add ImageNet mean pixel-wise from a BGR image.""" 64 | tensortype = type(batch.data) 65 | mean = tensortype(batch.data.size()) 66 | mean[:, 0, :, :] = 103.939 67 | mean[:, 1, :, :] = 116.779 68 | mean[:, 2, :, :] = 123.680 69 | return batch + Variable(mean) 70 | 71 | def imagenet_clamp_batch(batch, low, high): 72 | batch[:,0,:,:].data.clamp_(low-103.939, high-103.939) 73 | batch[:,1,:,:].data.clamp_(low-116.779, high-116.779) 74 | batch[:,2,:,:].data.clamp_(low-123.680, high-123.680) 75 | 76 | 77 | def preprocess_batch(batch): 78 | batch = batch.transpose(0, 1) 79 | (r, g, b) = torch.chunk(batch, 3) 80 | batch = torch.cat((b, g, r)) 81 | batch = batch.transpose(0, 1) 82 | return batch 83 | 84 | 85 | def init_vgg16(model_folder): 86 | """load the vgg16 model feature""" 87 | if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')): 88 | if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')): 89 | os.system( 90 | 'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7')) 91 | vgglua = torchfile.load(os.path.join(model_folder, 'vgg16.t7')) 92 | vgg = Vgg16() 93 | for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): 94 | dst.data[:] = src 95 | torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight')) 96 | 97 | 98 | def set_requires_grad(nets, requires_grad=False): 99 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 100 | Parameters: 101 | nets (network list) -- a list of networks 102 | requires_grad (bool) -- whether the networks require gradients or not 103 | """ 104 | if not isinstance(nets, list): 105 | nets = [nets] 106 | for net in nets: 107 | if net is not None: 108 | for param in net.parameters(): 109 | param.requires_grad = requires_grad 110 | 111 | def get_scheduler(optimizer, opt): 112 | if opt.lr_policy == 'lambda': 113 | def lambda_rule(epoch): 114 | 
print(epoch) 115 | print(opt.epoch_count) 116 | print(opt.niter) 117 | 118 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 119 | print(lr_l) 120 | print(50*'-') 121 | return lr_l 122 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 123 | elif opt.lr_policy == 'step': 124 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 125 | elif opt.lr_policy == 'plateau': 126 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 127 | else: 128 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 129 | return scheduler 130 | 131 | def my_tensor2im(input_image, imtype=np.uint8): 132 | if isinstance(input_image, torch.Tensor): 133 | image_tensor = input_image.data 134 | else: 135 | return input_image 136 | # print(50*'-') 137 | # print(input_image.shape) 138 | # print(image_tensor.shape) 139 | image_numpy = image_tensor.cpu().float().numpy() 140 | if image_numpy.shape[0] == 1: 141 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 142 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 143 | return image_numpy.astype(imtype) -------------------------------------------------------------------------------- /rank2_edge_batch.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import cv2 3 | import numpy as np 4 | import sys 5 | import math 6 | import os 7 | import models_metric 8 | import torchvision.transforms as transforms 9 | import torch.nn as nn 10 | import torch 11 | import shutil 12 | def get_imlist(path): 13 | return [os.path.join(path,f) for f in os.listdir(path)] 14 | def get_selected_imlist(path, txt_path, num): 15 | f = open(txt_path) 16 | cc = 0 17 | line = f.readline() 18 | L = [] 19 | while line: 20 | cc+=1 21 | if cc == num: 22 | break 23 | name = line.split()[0] 24 | L.append(path + os.sep + name + '.png') 25 | line = f.readline() 26 | f.close() 27 | return L 28 | 29 | def set_requires_grad(nets, requires_grad=False): 30 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 31 | Parameters: 32 | nets (network list) -- a list of networks 33 | requires_grad (bool) -- whether the networks require gradients or not 34 | """ 35 | if not isinstance(nets, list): 36 | nets = [nets] 37 | for net in nets: 38 | if net is not None: 39 | for param in net.parameters(): 40 | param.requires_grad = requires_grad 41 | a = np.array([[-1, 0, 1],[-2, 0, 2],[-1, 0, 1]], dtype=np.float32) 42 | a = a.reshape(1, 1, 3, 3) # out_c/3, in_c, w, h 43 | a = np.repeat(a, 3, axis=0) 44 | conv1=nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 45 | conv1.weight.data.copy_(torch.from_numpy(a)) 46 | conv1.weight.requires_grad = False 47 | conv1.cuda() 48 | 49 | b = np.array([[-1, -2, -1],[0, 0, 0],[1, 2, 1]], dtype=np.float32) 50 | b = b.reshape(1, 1, 3, 3) 51 | b = np.repeat(b, 3, axis=0) 52 | conv2=nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 53 | conv2.weight.data.copy_(torch.from_numpy(b)) 54 | conv2.weight.requires_grad = False 55 | conv2.cuda() 56 | 57 | net_metric = models_metric.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, spatial=False) 58 | net_metric = net_metric.cuda() 59 | set_requires_grad(net_metric, requires_grad=False) 60 | 61 | 62 | src_folder = sys.argv[1] # The folder for demoired results by GDN 63 | tar_folder = sys.argv[2] # The folder for GT 64 | 
txt_path = sys.argv[3] # The txt containing image names 65 | num = int(sys.argv[4]) # The number of candidates 66 | front_num = int(sys.argv[5]) # the number of topK 67 | res_folder = src_folder + os.sep + 'd' 68 | 69 | 70 | 71 | # res_img_path = sys.argv[1] 72 | # target_img_path = sys.argv[2] 73 | 74 | # res_img = cv2.imread(res_img_path) 75 | # target_img = cv2.imread(target_img_path) 76 | f = open("recoed_rank2_edge.txt", 'a+') 77 | cnt = 0 78 | D = {} 79 | res_pre_fix = 'res' 80 | tar_pre_fix = 'tar' 81 | if os.path.exists('tmp1.npy'): 82 | A = np.load('tmp.npy') 83 | else: 84 | for res_img_path in get_selected_imlist(res_folder, txt_path, num): 85 | cnt +=1 86 | if cnt %10 ==0: 87 | print(cnt) 88 | (filename, tempfilename) = os.path.split(res_img_path) 89 | (short_name, extension) = os.path.splitext(tempfilename) 90 | tmp = tempfilename.split('_') 91 | tmp2 = short_name.split('_') 92 | target_img_path = src_folder + os.sep + 'g' + os.sep + tar_pre_fix + '_'+tmp[1] 93 | # ori_img_path = src_folder + os.sep + 'o' + os.sep + tempfilename 94 | res_img = cv2.imread(res_img_path) 95 | target_img = cv2.imread(target_img_path) 96 | # ori_img = cv2.imread(ori_img_path) 97 | gray_target = cv2.cvtColor(target_img, cv2.COLOR_BGR2GRAY) 98 | gray_res = cv2.cvtColor(res_img, cv2.COLOR_BGR2GRAY) 99 | ret1, binary_map = cv2.threshold(gray_target, 0, 255, cv2.THRESH_OTSU) 100 | 101 | # binary_map = binary_map / 255 102 | binary_map = binary_map[:, :, np.newaxis] 103 | binary_map = np.repeat(binary_map, 3, axis=2) 104 | 105 | 106 | my_transform = transforms.Compose([ 107 | transforms.ToTensor(), 108 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 109 | ]) 110 | 111 | tensor_res_img = my_transform(res_img).unsqueeze_(0).cuda() 112 | tensor_target_img = my_transform(target_img).unsqueeze_(0).cuda() 113 | tensor_binary_map = my_transform(binary_map).unsqueeze_(0).cuda() 114 | # print(tensor_binary_map) 115 | tensor_binary_map = (tensor_binary_map + 1 )/2 116 | # print(tensor_binary_map) 117 | 118 | i_G_x = conv1(tensor_res_img) 119 | i_G_y = conv2(tensor_res_img) 120 | x_hat_edge = torch.tanh(torch.abs(i_G_x) + torch.abs(i_G_y)) 121 | 122 | t_G_x = conv1(tensor_target_img) 123 | t_G_y = conv2(tensor_target_img) 124 | target_edge = torch.tanh(torch.abs(t_G_x) + torch.abs(t_G_y)) 125 | 126 | tensor_merge_res = x_hat_edge * tensor_binary_map 127 | tensor_merge_target = target_edge * tensor_binary_map 128 | res = net_metric(tensor_merge_res, tensor_merge_target).detach()[0][0][0][0].cpu().numpy() 129 | D[short_name] = res 130 | # break 131 | A = sorted(D.iteritems(), key=lambda x: x[1], reverse=True) 132 | np.save('tmp.npy', A) 133 | cc = 0 134 | for item in A: 135 | if cc==front_num: 136 | break 137 | f.write(str(item[0]) + ' ' + str(item[1]) + '\n') 138 | src_res_img_path = src_folder + os.sep + 'd' + os.sep + item[0] + '.png' 139 | tar_res_img_path = tar_folder + os.sep + 'd' + os.sep + item[0] + '.png' 140 | # tar_res_img_path = tar_folder + os.sep + 'd' + os.sep + 'src_' + '0' * (5 - len(str(cc))) + str(cc) + '.png' 141 | src_tar_img_path = src_folder + os.sep + 'g' + os.sep + 'tar_' + item[0].split('_')[1] + '.png' 142 | tar_tar_img_path = tar_folder + os.sep + 'g' + os.sep + 'tar_' + item[0].split('_')[1] + '.png' 143 | # tar_tar_img_path = tar_folder + os.sep + 'g' + os.sep + 'tar_' + '0' * (5 - len(str(cc))) + str(cc) + '.png' 144 | 145 | shutil.copy(src_res_img_path, tar_res_img_path) 146 | shutil.copy(src_tar_img_path, tar_tar_img_path) 147 | cc+=1 
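# --- Note (editorial, not part of the original script) ----------------------
# Summary of the ranking logic above: the fixed Sobel kernels (conv1 / conv2)
# extract horizontal and vertical gradients from both the GDN result and the
# ground truth; the GT is Otsu-thresholded into a binary mask, the masked
# edge maps are scored with the 'net-lin'/'alex' LPIPS model, and the pairs
# with the largest distances (the top `front_num`) are copied into
# `tar_folder`.  A hypothetical invocation, with placeholder paths only:
#
#   python rank2_edge_batch.py ./gdn_results ./hard_cases names.txt 1000 100
#
# where ./gdn_results is assumed to contain 'd' (demoireed results) and
# 'g' (ground-truth) sub-folders, matching how src_folder is indexed above.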
-------------------------------------------------------------------------------- /models_metric/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from models_metric import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 
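# undo the normalization from tensor2tensorlab above: the LAB values were
# divided by 100 and the L channel shifted by -50, so scale back by 100 here
# and add 50 to L on the next line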
90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /models/non_local_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.nn.init 5 | 6 | class SEBlock(nn.Module): 7 | def __init__(self, input_dim, reduction): 8 | super(SEBlock, self).__init__() 9 | mid = int(input_dim / reduction) 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.fc = nn.Sequential( 12 | nn.Linear(input_dim, reduction), 13 | nn.ReLU(inplace=True), 14 | nn.Linear(reduction, input_dim), 15 | nn.Sigmoid() 16 | ) 17 | 18 | def forward(self, x): 19 | b, c, _, _ = x.size() 20 | y = self.avg_pool(x).view(b, c) 21 | y = self.fc(y).view(b, c, 1, 1) 22 | return x * y 23 | 24 | 25 | 26 | 27 | class 
ConvGRU(nn.Module): 28 | def __init__(self, inp_dim, oup_dim, kernel, dilation): 29 | super(ConvGRU, self).__init__() 30 | pad_x = int(dilation * (kernel - 1) / 2) 31 | self.conv_xz = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 32 | self.conv_xr = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 33 | self.conv_xn = nn.Conv2d(inp_dim, oup_dim, kernel, padding=pad_x, dilation=dilation) 34 | 35 | pad_h = int((kernel - 1) / 2) 36 | self.conv_hz = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 37 | self.conv_hr = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 38 | self.conv_hn = nn.Conv2d(oup_dim, oup_dim, kernel, padding=pad_h) 39 | 40 | # self.nl = NEDB(inter_channel=oup_dim / 2, channel=oup_dim) 41 | self.relu = nn.LeakyReLU(0.2) 42 | 43 | def forward(self, x, h=None): 44 | if h is None: 45 | z = F.sigmoid(self.conv_xz(x)) 46 | f = F.tanh(self.conv_xn(x)) 47 | h = z * f 48 | else: 49 | # h.unsqueeze_(1) 50 | z = F.sigmoid(self.conv_xz(x) + self.conv_hz(h)) 51 | r = F.sigmoid(self.conv_xr(x) + self.conv_hr(h)) 52 | n = F.tanh(self.conv_xn(x) + self.conv_hn(r * h)) 53 | h = (1 - z) * h + z * n 54 | 55 | h = self.relu(h) 56 | return h, h 57 | 58 | class _NonLocalBlockND(nn.Module): 59 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 60 | super(_NonLocalBlockND, self).__init__() 61 | 62 | assert dimension in [1, 2, 3] 63 | 64 | self.dimension = dimension 65 | self.sub_sample = sub_sample 66 | 67 | self.in_channels = in_channels 68 | self.inter_channels = inter_channels 69 | 70 | 71 | if self.inter_channels is None: 72 | self.inter_channels = in_channels // 2 73 | if self.inter_channels == 0: 74 | self.inter_channels = 1 75 | self.se0 = SEBlock(self.inter_channels, self.inter_channels/2) 76 | self.se1 = SEBlock(self.inter_channels, self.inter_channels/2) 77 | self.se2 = SEBlock(self.inter_channels, self.inter_channels/2) 78 | self.se3 = SEBlock(self.inter_channels, self.inter_channels/2) 79 | self.gru1 = ConvGRU(self.inter_channels, self.inter_channels, 1, 1) 80 | self.gru2 = ConvGRU(self.inter_channels, self.inter_channels, 1, 1) 81 | if dimension == 3: 82 | conv_nd = nn.Conv3d 83 | max_pool = nn.MaxPool3d 84 | bn = nn.BatchNorm3d 85 | elif dimension == 2: 86 | conv_nd = nn.Conv2d 87 | max_pool = nn.MaxPool2d 88 | bn = nn.BatchNorm2d 89 | else: 90 | conv_nd = nn.Conv1d 91 | max_pool = nn.MaxPool1d 92 | bn = nn.BatchNorm1d 93 | 94 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 95 | kernel_size=1, stride=1, padding=0) 96 | self.theta_ = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 97 | kernel_size=1, stride=1, padding=0) 98 | if bn_layer: 99 | self.W = nn.Sequential( 100 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 101 | kernel_size=1, stride=1, padding=0), 102 | bn(self.in_channels) 103 | ) 104 | nn.init.constant_(self.W[1].weight, 0) 105 | nn.init.constant_(self.W[1].bias, 0) 106 | else: 107 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 108 | kernel_size=1, stride=1, padding=0) 109 | nn.init.constant_(self.W.weight, 0) 110 | nn.init.constant_(self.W.bias, 0) 111 | 112 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 113 | kernel_size=1, stride=1, padding=0) 114 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 115 | kernel_size=1, stride=1, padding=0) 116 | 117 | if sub_sample: 118 | self.g = 
nn.Sequential(self.g, max_pool(kernel_size=2)) 119 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 120 | def forward(self, x, corr1, corr2): 121 | ''' 122 | :param x: (b, c, t, h, w) 123 | :return: 124 | ''' 125 | 126 | batch_size = x.size(0) 127 | # print(x.shape) 128 | g_x, ret_corr1 = self.gru1(self.se0(self.g(x)), corr1) 129 | g_x = g_x.view(batch_size, self.inter_channels, -1) 130 | g_x = g_x.permute(0, 2, 1) 131 | 132 | theta_x, ret_corr2 = self.gru2(self.se1(self.theta(x)), corr2) 133 | theta_x = theta_x.view(batch_size, self.inter_channels, -1) 134 | theta_x = theta_x.permute(0, 2, 1) 135 | # print(theta_x.shape) 136 | phi_x = self.se3(self.phi(x)) 137 | phi_x = phi_x.view(batch_size, self.inter_channels, -1) 138 | # print(phi_x.shape) 139 | f = torch.matmul(theta_x, phi_x) 140 | 141 | theta_x_ = self.theta_(x) 142 | theta_x_ = self.se3(theta_x_) 143 | theta_x_ = theta_x_.view(batch_size, self.inter_channels, -1) 144 | theta_x_ = theta_x_.permute(0, 2, 1) 145 | 146 | 147 | f_div_C = F.softmax(f, dim=-1) 148 | 149 | y = torch.matmul(f_div_C, g_x) 150 | y = y + theta_x_ 151 | y = y.permute(0, 2, 1).contiguous() 152 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 153 | W_y = self.W(y) 154 | z = W_y + x 155 | 156 | return z, ret_corr1, ret_corr2 157 | 158 | 159 | class NONLocalBlock2D(_NonLocalBlockND): 160 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 161 | super(NONLocalBlock2D, self).__init__(in_channels, 162 | inter_channels=inter_channels, 163 | dimension=2, sub_sample=sub_sample, 164 | bn_layer=bn_layer) 165 | -------------------------------------------------------------------------------- /models_metric/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | # from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", 
['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | 
elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import sys 4 | 5 | 6 | def create_exp_dir(exp): 7 | try: 8 | os.makedirs(exp) 9 | print('Creating exp dir: %s' % exp) 10 | except OSError: 11 | pass 12 | return True 13 | 14 | 15 | def weights_init(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Conv') != -1: 18 | m.weight.data.normal_(0.0, 0.02) 19 | elif classname.find('BatchNorm') != -1: 20 | m.weight.data.normal_(1.0, 0.02) 21 | m.bias.data.fill_(0) 22 | 23 | 24 | def getLoader(datasetName, dataroot, originalSize_h, originalSize_w, imageSize_h, imageSize_w, batchSize=64, workers=4, 25 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None, pre="", label_file="", list_file=""): 26 | 27 | 28 | if datasetName == 'my_loader': 29 | from datasets.my_loader import my_loader_LRN as commonDataset 30 | import transforms.pix2pix as transforms 31 | 32 | elif datasetName == 'my_loader_fs': 33 | # from datasets.pix2pix import pix2pix as commonDataset 34 | # import transforms.pix2pix as transforms 35 | from datasets.my_loader import my_loader_fs as commonDataset 36 | import transforms.pix2pix as transforms 37 | elif datasetName == 'my_loader_LRN_f2_rand3': 38 | # from datasets.pix2pix import pix2pix as commonDataset 39 | # import transforms.pix2pix as transforms 40 | from datasets.my_loader import my_loader_LRN_f2_rand3 as commonDataset 41 | import transforms.pix2pix as transforms 42 | elif datasetName == 'my_loader_LRN_f2_rand2': 43 | # from datasets.pix2pix import pix2pix as commonDataset 44 | # import transforms.pix2pix as transforms 45 | from datasets.my_loader import my_loader_LRN_f2_rand2 as commonDataset 46 | import transforms.pix2pix as transforms 47 | if split == 'test': 48 | dataset = commonDataset(root=dataroot, 49 | transform1=transforms.Compose([ 50 | transforms.Scale(originalSize_h, originalSize_w), 51 | transforms.CenterCrop(imageSize_h, imageSize_w), 52 | ]), 53 | transform2=transforms.Compose([ 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean, std), 56 | ]), 57 | transform3=transforms.Compose1([ 58 | transforms.Scale1(384, 384), 59 | transforms.ToTensor1(), 60 | transforms.Normalize1(mean, std), 61 | ]), 62 | seed=seed, 63 | pre=pre, 64 | label_file=label_file) 65 | elif split == 'train': 66 | dataset = commonDataset(root=dataroot, 67 | transform=transforms.Compose([ 68 | transforms.Scale(originalSize_h, originalSize_w), 69 | transforms.RandomCrop(imageSize_h, imageSize_w), 70 | transforms.RandomHorizontalFlip(), 
71 | transforms.ToTensor(), 72 | transforms.Normalize(mean, std), 73 | ]), 74 | seed=seed, 75 | pre=pre, 76 | label_file=label_file) 77 | elif split == 'LRN_train_guide2': 78 | dataset = commonDataset(root=dataroot, 79 | transform1=transforms.Scale(originalSize_h, originalSize_w), 80 | transform2=transforms.RandomCrop_index(imageSize_h, imageSize_w), 81 | transform3=transforms.RandomHorizontalFlip_index(), 82 | transform4=transforms.GuideCrop(384, 384), 83 | transform5=transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize(mean, std), 86 | ]), 87 | transform6=transforms.Compose1([ 88 | transforms.Scale1(384, 384), 89 | transforms.ToTensor1(), 90 | transforms.Normalize1(mean, std), 91 | ]), 92 | seed=seed, 93 | pre=pre, 94 | list_file=list_file) 95 | elif split == 'LRN_train_guide': 96 | dataset = commonDataset(root=dataroot, 97 | transform1=transforms.Scale(originalSize_h, originalSize_w), 98 | transform2=transforms.RandomCrop_index(imageSize_h, imageSize_w), 99 | transform3=transforms.RandomHorizontalFlip_index(), 100 | transform4=transforms.GuideCrop(384, 384), 101 | transform5=transforms.Compose([ 102 | transforms.ToTensor(), 103 | transforms.Normalize(mean, std), 104 | ]), 105 | transform6=transforms.Compose1([ 106 | transforms.Scale1(384, 384), 107 | transforms.ToTensor1(), 108 | transforms.Normalize1(mean, std), 109 | ]), 110 | seed=seed, 111 | pre=pre, 112 | label_file=label_file) 113 | 114 | dataloader = torch.utils.data.DataLoader(dataset, 115 | batch_size=batchSize, 116 | shuffle=shuffle, 117 | num_workers=int(workers)) 118 | return dataloader 119 | 120 | 121 | class AverageMeter(object): 122 | """Computes and stores the average and current value""" 123 | def __init__(self): 124 | self.reset() 125 | 126 | def reset(self): 127 | self.val = 0 128 | self.avg = 0 129 | self.sum = 0 130 | self.count = 0 131 | 132 | def update(self, val, n=1): 133 | self.val = val 134 | self.sum += val * n 135 | self.count += n 136 | self.avg = self.sum / self.count 137 | 138 | 139 | import numpy as np 140 | class ImagePool: 141 | def __init__(self, pool_size=50): 142 | self.pool_size = pool_size 143 | if pool_size > 0: 144 | self.num_imgs = 0 145 | self.images = [] 146 | 147 | def query(self, image): 148 | if self.pool_size == 0: 149 | return image 150 | if self.num_imgs < self.pool_size: 151 | self.images.append(image.clone()) 152 | self.num_imgs += 1 153 | return image 154 | else: 155 | if np.random.uniform(0,1) > 0.5: 156 | random_id = np.random.randint(self.pool_size, size=1)[0] 157 | tmp = self.images[random_id].clone() 158 | self.images[random_id] = image.clone() 159 | return tmp 160 | else: 161 | return image 162 | 163 | 164 | def adjust_learning_rate(optimizer, init_lr, epoch, factor, every): 165 | #import pdb; pdb.set_trace() 166 | lrd = init_lr / every 167 | old_lr = optimizer.param_groups[0]['lr'] 168 | # linearly decaying lr 169 | lr = old_lr - lrd 170 | if lr < 0: lr = 0 171 | for param_group in optimizer.param_groups: 172 | param_group['lr'] = lr 173 | -------------------------------------------------------------------------------- /transforms/pix2pix_val.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | class Compose(object): 11 | """Composes several transforms together. 
12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB): 24 | for t in self.transforms: 25 | imgA, imgB = t(imgA, imgB) 26 | return imgA, imgB 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB): 33 | pics = [picA, picB] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 52 | output.append(img) 53 | return output[0], output[1] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB): 61 | pics = [picA, picB] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB): 87 | tensors = [tensorA, tensorB] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 
99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB): 109 | imgs = [imgA, imgB] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB): 139 | imgs = [imgA, imgB] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | # output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | 148 | output.append(img) 149 | 150 | return output[0], output[1] 151 | 152 | class Pad(object): 153 | """Pads the given PIL.Image on all sides with the given "pad" value""" 154 | def __init__(self, padding, fill=0): 155 | assert isinstance(padding, numbers.Number) 156 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 157 | self.padding = padding 158 | self.fill = fill 159 | 160 | def __call__(self, imgA, imgB): 161 | imgs = [imgA, imgB] 162 | output = [] 163 | for img in imgs: 164 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 165 | return output[0], output[1] 166 | 167 | class Lambda(object): 168 | """Applies a lambda as a transform.""" 169 | def __init__(self, lambd): 170 | assert isinstance(lambd, types.LambdaType) 171 | self.lambd = lambd 172 | 173 | def __call__(self, imgA, imgB): 174 | imgs = [imgA, imgB] 175 | output = [] 176 | for img in imgs: 177 | output.append(self.lambd(img)) 178 | return output[0], output[1] 179 | 180 | class RandomCrop(object): 181 | """Crops the given PIL.Image at a random location to have a region of 182 | the given size. 
size can be a tuple (target_height, target_width) 183 | or an integer, in which case the target will be of a square shape (size, size) 184 | """ 185 | def __init__(self, size, padding=0): 186 | if isinstance(size, numbers.Number): 187 | self.size = (int(size), int(size)) 188 | else: 189 | self.size = size 190 | self.padding = padding 191 | 192 | def __call__(self, imgA, imgB): 193 | imgs = [imgA, imgB] 194 | output = [] 195 | x1 = -1 196 | y1 = -1 197 | for img in imgs: 198 | if self.padding > 0: 199 | img = ImageOps.expand(img, border=self.padding, fill=0) 200 | 201 | w, h = img.size 202 | th, tw = self.size 203 | if w == tw and h == th: 204 | output.append(img) 205 | continue 206 | 207 | if x1 == -1 and y1 == -1: 208 | x1 = random.randint(0, w - tw) 209 | y1 = random.randint(0, h - th) 210 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 211 | return output[0], output[1] 212 | 213 | class RandomHorizontalFlip(object): 214 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 215 | """ 216 | def __call__(self, imgA, imgB): 217 | imgs = [imgA, imgB] 218 | output = [] 219 | flag = random.random() < 0.5 220 | for img in imgs: 221 | if flag: 222 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 223 | else: 224 | output.append(img) 225 | return output[0], output[1] 226 | -------------------------------------------------------------------------------- /transforms/pix2pix3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 10 | class Compose(object): 11 | """Composes several transforms together. 12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB, imgC): 24 | for t in self.transforms: 25 | imgA, imgB, imgC = t(imgA, imgB, imgC) 26 | return imgA, imgB, imgC 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB, picC): 33 | pics = [picA, picB, picC] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 
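# the image has now been converted to a CHW float tensor; the identical
# conversion is applied to each of picA, picB and picC so the three
# returned tensors stay consistent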
52 | output.append(img) 53 | return output[0], output[1], output[2] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB, picC): 61 | pics = [picA, picB, picC] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1], output[2] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB, tensorC): 87 | tensors = [tensorA, tensorB, tensorC] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1], output[2] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB, imgC): 109 | imgs = [imgA, imgB, imgC] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1], output[2] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. 
size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB, imgC): 139 | imgs = [imgA, imgB, imgC] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | return output[0], output[1], output[2] 148 | 149 | class Pad(object): 150 | """Pads the given PIL.Image on all sides with the given "pad" value""" 151 | def __init__(self, padding, fill=0): 152 | assert isinstance(padding, numbers.Number) 153 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 154 | self.padding = padding 155 | self.fill = fill 156 | 157 | def __call__(self, imgA, imgB, imgC): 158 | imgs = [imgA, imgB, imgC] 159 | output = [] 160 | for img in imgs: 161 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 162 | return output[0], output[1], output[2] 163 | 164 | class Lambda(object): 165 | """Applies a lambda as a transform.""" 166 | def __init__(self, lambd): 167 | assert isinstance(lambd, types.LambdaType) 168 | self.lambd = lambd 169 | 170 | def __call__(self, imgA, imgB, imgC): 171 | imgs = [imgA, imgB, imgC] 172 | output = [] 173 | for img in imgs: 174 | output.append(self.lambd(img)) 175 | return output[0], output[1], output[2] 176 | 177 | class RandomCrop(object): 178 | """Crops the given PIL.Image at a random location to have a region of 179 | the given size. size can be a tuple (target_height, target_width) 180 | or an integer, in which case the target will be of a square shape (size, size) 181 | """ 182 | def __init__(self, size, padding=0): 183 | if isinstance(size, numbers.Number): 184 | self.size = (int(size), int(size)) 185 | else: 186 | self.size = size 187 | self.padding = padding 188 | 189 | def __call__(self, imgA, imgB, imgC): 190 | imgs = [imgA, imgB, imgC] 191 | output = [] 192 | x1 = -1 193 | y1 = -1 194 | for img in imgs: 195 | if self.padding > 0: 196 | img = ImageOps.expand(img, border=self.padding, fill=0) 197 | 198 | w, h = img.size 199 | th, tw = self.size 200 | if w == tw and h == th: 201 | output.append(img) 202 | continue 203 | 204 | if x1 == -1 and y1 == -1: 205 | x1 = random.randint(0, w - tw) 206 | y1 = random.randint(0, h - th) 207 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 208 | return output[0], output[1], output[2] 209 | 210 | class RandomHorizontalFlip(object): 211 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 212 | """ 213 | def __call__(self, imgA, imgB, imgC): 214 | imgs = [imgA, imgB, imgC] 215 | output = [] 216 | # flag = random.random() < 0.5 217 | flag = random.random() < -1 218 | 219 | for img in imgs: 220 | if flag: 221 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 222 | else: 223 | output.append(img) 224 | return output[0], output[1], output[2] 225 | -------------------------------------------------------------------------------- /transforms/pix2pix_val3.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps 6 | import numpy as np 7 | import numbers 8 | import types 9 | 
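# --- Note (editorial) --------------------------------------------------------
# The classes in this file mirror transforms/pix2pix3.py but are kept as a
# separate validation variant: every transform takes an (imgA, imgB, imgC)
# triple and applies the same parameters to all three images, and the
# RandomHorizontalFlip defined below never flips (its flag is always False).
# A minimal, hypothetical usage sketch (sizes are illustrative only):
#
#   import transforms.pix2pix_val3 as T
#   t = T.Compose([
#       T.Scale(1024),                                   # resize the smaller edge
#       T.CenterCrop((960, 960)),                        # same crop for all three
#       T.ToTensor(),
#       T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # map [0, 1] to [-1, 1]
#   ])
#   imgA, imgB, imgC = t(imgA, imgB, imgC)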
10 | class Compose(object): 11 | """Composes several transforms together. 12 | Args: 13 | transforms (List[Transform]): list of transforms to compose. 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | def __init__(self, transforms): 21 | self.transforms = transforms 22 | 23 | def __call__(self, imgA, imgB, imgC): 24 | for t in self.transforms: 25 | imgA, imgB, imgC = t(imgA, imgB, imgC) 26 | return imgA, imgB, imgC 27 | 28 | class ToTensor(object): 29 | """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 30 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 31 | """ 32 | def __call__(self, picA, picB, picC): 33 | pics = [picA, picB, picC] 34 | output = [] 35 | for pic in pics: 36 | if isinstance(pic, np.ndarray): 37 | # handle numpy array 38 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 39 | else: 40 | # handle PIL Image 41 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 42 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 43 | if pic.mode == 'YCbCr': 44 | nchannel = 3 45 | else: 46 | nchannel = len(pic.mode) 47 | img = img.view(pic.size[1], pic.size[0], nchannel) 48 | # put it from HWC to CHW format 49 | # yikes, this transpose takes 80% of the loading time/CPU 50 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 51 | img = img.float().div(255.) 52 | output.append(img) 53 | return output[0], output[1], output[2] 54 | 55 | class ToPILImage(object): 56 | """Converts a torch.*Tensor of range [0, 1] and shape C x H x W 57 | or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C 58 | to a PIL.Image of range [0, 255] 59 | """ 60 | def __call__(self, picA, picB, picC): 61 | pics = [picA, picB, picC] 62 | output = [] 63 | for pic in pics: 64 | npimg = pic 65 | mode = None 66 | if not isinstance(npimg, np.ndarray): 67 | npimg = pic.mul(255).byte().numpy() 68 | npimg = np.transpose(npimg, (1, 2, 0)) 69 | 70 | if npimg.shape[2] == 1: 71 | npimg = npimg[:, :, 0] 72 | mode = "L" 73 | output.append(Image.fromarray(npimg, mode=mode)) 74 | 75 | return output[0], output[1], output[2] 76 | 77 | class Normalize(object): 78 | """Given mean: (R, G, B) and std: (R, G, B), 79 | will normalize each channel of the torch.*Tensor, i.e. 80 | channel = (channel - mean) / std 81 | """ 82 | def __init__(self, mean, std): 83 | self.mean = mean 84 | self.std = std 85 | 86 | def __call__(self, tensorA, tensorB, tensorC): 87 | tensors = [tensorA, tensorB, tensorC] 88 | output = [] 89 | for tensor in tensors: 90 | # TODO: make efficient 91 | for t, m, s in zip(tensor, self.mean, self.std): 92 | t.sub_(m).div_(s) 93 | output.append(tensor) 94 | return output[0], output[1], output[2] 95 | 96 | class Scale(object): 97 | """Rescales the input PIL.Image to the given 'size'. 98 | 'size' will be the size of the smaller edge. 
99 | For example, if height > width, then image will be 100 | rescaled to (size * height / width, size) 101 | size: size of the smaller edge 102 | interpolation: Default: PIL.Image.BILINEAR 103 | """ 104 | def __init__(self, size, interpolation=Image.BILINEAR): 105 | self.size = size 106 | self.interpolation = interpolation 107 | 108 | def __call__(self, imgA, imgB, imgC): 109 | imgs = [imgA, imgB, imgC] 110 | output = [] 111 | for img in imgs: 112 | w, h = img.size 113 | if (w <= h and w == self.size) or (h <= w and h == self.size): 114 | output.append(img) 115 | continue 116 | if w < h: 117 | ow = self.size 118 | oh = int(self.size * h / w) 119 | output.append(img.resize((ow, oh), self.interpolation)) 120 | continue 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | output.append(img.resize((ow, oh), self.interpolation)) 125 | return output[0], output[1], output[2] 126 | 127 | class CenterCrop(object): 128 | """Crops the given PIL.Image at the center to have a region of 129 | the given size. size can be a tuple (target_height, target_width) 130 | or an integer, in which case the target will be of a square shape (size, size) 131 | """ 132 | def __init__(self, size): 133 | if isinstance(size, numbers.Number): 134 | self.size = (int(size), int(size)) 135 | else: 136 | self.size = size 137 | 138 | def __call__(self, imgA, imgB, imgC): 139 | imgs = [imgA, imgB, imgC] 140 | output = [] 141 | for img in imgs: 142 | w, h = img.size 143 | th, tw = self.size 144 | x1 = int(round((w - tw) / 2.)) 145 | y1 = int(round((h - th) / 2.)) 146 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 147 | return output[0], output[1], output[2] 148 | 149 | class Pad(object): 150 | """Pads the given PIL.Image on all sides with the given "pad" value""" 151 | def __init__(self, padding, fill=0): 152 | assert isinstance(padding, numbers.Number) 153 | assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) 154 | self.padding = padding 155 | self.fill = fill 156 | 157 | def __call__(self, imgA, imgB, imgC): 158 | imgs = [imgA, imgB, imgC] 159 | output = [] 160 | for img in imgs: 161 | output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) 162 | return output[0], output[1], output[2] 163 | 164 | class Lambda(object): 165 | """Applies a lambda as a transform.""" 166 | def __init__(self, lambd): 167 | assert isinstance(lambd, types.LambdaType) 168 | self.lambd = lambd 169 | 170 | def __call__(self, imgA, imgB, imgC): 171 | imgs = [imgA, imgB, imgC] 172 | output = [] 173 | for img in imgs: 174 | output.append(self.lambd(img)) 175 | return output[0], output[1], output[2] 176 | 177 | class RandomCrop(object): 178 | """Crops the given PIL.Image at a random location to have a region of 179 | the given size. 
size can be a tuple (target_height, target_width) 180 | or an integer, in which case the target will be of a square shape (size, size) 181 | """ 182 | def __init__(self, size, padding=0): 183 | if isinstance(size, numbers.Number): 184 | self.size = (int(size), int(size)) 185 | else: 186 | self.size = size 187 | self.padding = padding 188 | 189 | def __call__(self, imgA, imgB, imgC): 190 | imgs = [imgA, imgB, imgC] 191 | output = [] 192 | x1 = -1 193 | y1 = -1 194 | for img in imgs: 195 | if self.padding > 0: 196 | img = ImageOps.expand(img, border=self.padding, fill=0) 197 | 198 | w, h = img.size 199 | th, tw = self.size 200 | if w == tw and h == th: 201 | output.append(img) 202 | continue 203 | 204 | if x1 == -1 and y1 == -1: 205 | x1 = random.randint(0, w - tw) 206 | y1 = random.randint(0, h - th) 207 | output.append(img.crop((x1, y1, x1 + tw, y1 + th))) 208 | return output[0], output[1], output[2] 209 | 210 | class RandomHorizontalFlip(object): 211 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 212 | """ 213 | def __call__(self, imgA, imgB, imgC): 214 | imgs = [imgA, imgB, imgC] 215 | output = [] 216 | # flag = random.random() < 0.5 217 | flag = random.random() < -1 218 | 219 | for img in imgs: 220 | if flag: 221 | output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 222 | else: 223 | output.append(img) 224 | return output[0], output[1], output[2] 225 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | import util 6 | import html 7 | from scipy.misc import imresize 8 | 9 | 10 | class Visualizer(): 11 | def __init__(self, display_port, name): 12 | self.display_id = 1 13 | self.use_html = False 14 | self.win_size = 256 15 | self.name = name 16 | self.saved = False 17 | if self.display_id > 0: 18 | import visdom 19 | self.ncols = 4 20 | self.vis = visdom.Visdom(server="http://localhost", port=display_port) 21 | 22 | # if self.use_html: 23 | # self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 24 | # self.img_dir = os.path.join(self.web_dir, 'images') 25 | # print('create web directory %s...' % self.web_dir) 26 | # util.mkdirs([self.web_dir, self.img_dir]) 27 | # self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 28 | # with open(self.log_name, "a") as log_file: 29 | # now = time.strftime("%c") 30 | # log_file.write('================ Training Loss (%s) ================\n' % now) 31 | 32 | def reset(self): 33 | self.saved = False 34 | 35 | # |visuals|: dictionary of images to display or save 36 | def display_current_results(self, visuals, epoch, save_result): 37 | if self.display_id > 0: # show images in the browser 38 | ncols = self.ncols 39 | if ncols > 0: 40 | ncols = min(ncols, len(visuals)) 41 | h, w = next(iter(visuals.values())).shape[:2] 42 | table_css = """""" % (w, h) 46 | title = self.name 47 | label_html = '' 48 | label_html_row = '' 49 | images = [] 50 | idx = 0 51 | for label, image in visuals.items(): 52 | image_numpy = util.tensor2im(image) 53 | label_html_row += '