├── .gitattributes ├── .gitignore ├── README.md ├── VOC2012_224_train_png.txt ├── checkpoints └── ERRNet_AdaNEC_NI │ └── merge_model.py ├── data ├── __init__.py ├── image_folder.py ├── reflect_dataset.py ├── torchdata.py └── transforms.py ├── datasets └── eval │ └── real20 │ ├── blended │ ├── 103.png │ ├── 107.png │ ├── 110.png │ ├── 12.png │ ├── 15.png │ ├── 22.png │ ├── 23.png │ ├── 25.png │ ├── 29.png │ ├── 3.png │ ├── 39.png │ ├── 4.png │ ├── 46.png │ ├── 47.png │ ├── 58.png │ ├── 86.png │ ├── 87.png │ ├── 89.png │ ├── 9.png │ └── 93.png │ ├── data_list.txt │ └── transmission_layer │ ├── 103.png │ ├── 107.png │ ├── 110.png │ ├── 12.png │ ├── 15.png │ ├── 22.png │ ├── 23.png │ ├── 25.png │ ├── 29.png │ ├── 3.png │ ├── 39.png │ ├── 4.png │ ├── 46.png │ ├── 47.png │ ├── 58.png │ ├── 86.png │ ├── 87.png │ ├── 89.png │ ├── 9.png │ └── 93.png ├── engine.py ├── models ├── CX │ ├── CX_distance.py │ ├── CX_helper.py │ ├── __init__.py │ └── enums.py ├── __init__.py ├── arch │ ├── __init__.py │ └── default.py ├── base_model.py ├── errnet_model.py ├── losses.py ├── networks.py └── vgg.py ├── options ├── __init__.py ├── base_option.py └── errnet │ ├── __init__.py │ ├── base_options.py │ └── train_options.py ├── real_test.txt ├── test_AdaNEC_NI.sh ├── test_AdaNEC_OF.sh ├── test_errnet.py ├── train.sh ├── train_errnet.py └── util ├── __init__.py ├── html.py ├── index.py ├── util.py └── visualizer.py /.gitattributes: -------------------------------------------------------------------------------- 1 | checkpoints/ERRNet_AdaNEC_OF/final_release.pt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pyc 3 | *.pth 4 | __pycache__ 5 | result 6 | runs 7 | .ssh* 8 | .rsyncignore 9 | bak 10 | *.zip 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaNEC 2 | 3 | The implementation of Anonymous Submission 3582 "[Adaptive Network Combination for Single-Image Reflection Removal: A Domain Generalization Perspective](https://github.com/Anonymous3582/AdaNEC)" 4 | 5 | 6 | ## Environment 7 | 8 | This repo has been tested in the following environment: 9 | * Platforms: Ubuntu 20.04 10 | * Framework: We use PyTorch 1.8, and it should work with PyTorch 1.2 - PyTorch 1.10 11 | * Requirements: opencv-python, tensorboardX, visdom, dominate, scikit-image, etc. 12 | * GPU: We use an RTX 3090 with CUDA 11 13 | 14 | 15 | ## Quick Start 16 | 17 | ### Downloading this repository 18 | We provide the pre-trained model via git-lfs (Large File Storage). Please clone this repository via one of the following methods: 19 | 1) `git lfs clone https://github.com/Anonymous3582/AdaNEC` or 20 | 2) `git clone https://github.com/Anonymous3582/AdaNEC`, then download the pre-trained model from [this page](./checkpoints/ERRNet_AdaNEC_OF/final_release.pt), and place it in the `checkpoints/ERRNet_AdaNEC_OF` folder. 21 | 22 | ### Preparing your testing datasets 23 | 24 | Note that we have provided the *Real20* testing set in the `datasets` folder; the *SIR^2* subsets are not provided due to their policy. You can request them and organize the *SIR^2* sub-datasets in the same layout as *Real20* (see the sketch after the list below). 25 | * 20 real testing images from the [Berkeley real dataset](https://github.com/ceciliavision/perceptual-reflection-removal). 26 | * Three sub-datasets, namely *Objects*, *Postcard*, and *Wild*, from the [SIR^2 dataset](https://sir2data.github.io/) 27 |
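For reference, `CEILTestDataset` in `data/reflect_dataset.py` reads each evaluation set from two parallel folders, so a *SIR^2* subset organized as follows (mirroring `datasets/eval/real20`) should work directly:
```
datasets/eval/<dataset_name>/
├── blended/              # input images containing reflections
├── transmission_layer/   # ground-truth transmission layers (same filenames as in blended/)
└── data_list.txt         # list of test filenames (real20 ships with one)
```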
28 | ### Testing 29 | We provide two working schemes of AdaNEC, i.e., output fusion (OF) and network interpolation (NI). 30 | 31 | #### Output Fusion (OF) 32 | The OF model is already provided; you can run the OF scheme via 33 | ```shell 34 | $ bash test_AdaNEC_OF.sh 35 | ``` 36 | 37 | #### Network Interpolation (NI) 38 | The NI model can be generated by running `python merge_model.py` under the `checkpoints/ERRNet_AdaNEC_NI` directory. 39 | The NI scheme can then be executed via 40 | ```shell 41 | $ bash test_AdaNEC_NI.sh 42 | ``` 43 | In particular, note that you should generate a separate NI model for each testing dataset. Please see [`merge_model.py`](./checkpoints/ERRNet_AdaNEC_NI/merge_model.py) for more details. 44 | 45 | ### Results 46 | The results will be placed in the `results` folder. 47 | 48 | ## Acknowledgments 49 | The code is built upon [ERRNet](https://github.com/Vandermode/ERRNet), one of the backbone models of our work. We will include their GitHub commits when releasing the code publicly. 50 | We highly appreciate all the authors of [ERRNet](https://github.com/Vandermode/ERRNet), [IBCLN](https://github.com/JHL-HUST/IBCLN), and [RAGNet](https://github.com/liyucs/RAGNet) for their efforts. 51 | -------------------------------------------------------------------------------- /checkpoints/ERRNet_AdaNEC_NI/merge_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | model_weights = torch.load('../ERRNet_AdaNEC_OF/final_release.pt')['icnn'] 3 | 4 | # We use the domain-level RTAW weights for network interpolation, which are averaged on a dataset. 5 | weights = [0.108686739, 0.009344623, 0.881968638] # real20 6 | # weights = [0.217511492, 0.111057787, 0.671430722] # wild 7 | # weights = [0.101291473, 0.580948268, 0.317760259] # postcard 8 | # weights = [0.186080102, 0.355502779, 0.458417119] # solid 9 | 10 | 11 | ckpt = {'icnn': {}} 12 | for k in model_weights.keys(): # interpolate the experts' parameters: theta_NI = sum_i w_i * theta_i 13 | if k.startswith('0'): # parameter names are prefixed with the expert index ('0.', '1.', '2.') 14 | k_ = k[2:] 15 | ckpt['icnn'][k_] = model_weights['0.'+k_] * weights[0] \ 16 | + model_weights['1.'+k_] * weights[1] \ 17 | + model_weights['2.'+k_] * weights[2] 18 | 19 | torch.save(ckpt, 'final_release.pt') -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/data/__init__.py -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def read_fns(filename): 21 | with open(filename) as f: 22 | fns = f.readlines() 23 | fns = [fn.strip() for fn in fns] 24
| return fns 25 | 26 | 27 | def is_image_file(filename): 28 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 29 | 30 | 31 | def make_dataset(dir, fns=None): 32 | images = [] 33 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 34 | 35 | if fns is None: 36 | for root, _, fnames in sorted(os.walk(dir)): 37 | for fname in fnames: 38 | if is_image_file(fname): 39 | path = os.path.join(root, fname) 40 | images.append(path) 41 | else: 42 | for fname in fns: 43 | if is_image_file(fname): 44 | path = os.path.join(dir, fname) 45 | images.append(path) 46 | 47 | return images 48 | 49 | 50 | def default_loader(path): 51 | return Image.open(path).convert('RGB') 52 | -------------------------------------------------------------------------------- /data/reflect_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from os.path import join 3 | from data.image_folder import make_dataset 4 | from data.transforms import Sobel, to_norm_tensor, to_tensor, ReflectionSythesis_1, ReflectionSythesis_2 5 | from PIL import Image 6 | import random 7 | import torch 8 | import math 9 | 10 | import torchvision.transforms as transforms 11 | import torchvision.transforms.functional as F 12 | 13 | import util.util as util 14 | import data.torchdata as torchdata 15 | 16 | 17 | def __scale_width(img, target_width): 18 | ow, oh = img.size 19 | if (ow == target_width): 20 | return img 21 | w = target_width 22 | h = int(target_width * oh / ow) 23 | h = math.ceil(h / 2.) * 2 # round up to even 24 | return img.resize((w, h), Image.BICUBIC) 25 | 26 | def __scale_height(img, target_height): 27 | ow, oh = img.size 28 | if (oh == target_height): 29 | return img 30 | h = target_height 31 | w = int(target_height * ow / oh) 32 | w = math.ceil(w / 2.) * 2 # round up to even 33 | return img.resize((w, h), Image.BICUBIC) 34 | 35 | 36 | def paired_data_transforms(img_1, img_2, unaligned_transforms=False): 37 | def get_params(img, output_size): 38 | w, h = img.size 39 | th, tw = output_size 40 | if w == tw and h == th: 41 | return 0, 0, h, w 42 | 43 | i = random.randint(0, h - th) 44 | j = random.randint(0, w - tw) 45 | return i, j, th, tw 46 | 47 | # target_size = int(random.randint(224+10, 448) / 2.) * 2 48 | target_size = int(random.randint(224, 448) / 2.) * 2 49 | # target_size = int(random.randint(256, 480) / 2.) 
* 2 50 | ow, oh = img_1.size 51 | if ow >= oh: 52 | img_1 = __scale_height(img_1, target_size) 53 | img_2 = __scale_height(img_2, target_size) 54 | else: 55 | img_1 = __scale_width(img_1, target_size) 56 | img_2 = __scale_width(img_2, target_size) 57 | 58 | if random.random() < 0.5: 59 | img_1 = F.hflip(img_1) 60 | img_2 = F.hflip(img_2) 61 | 62 | i, j, h, w = get_params(img_1, (224,224)) 63 | # i, j, h, w = get_params(img_1, (256,256)) 64 | img_1 = F.crop(img_1, i, j, h, w) 65 | 66 | if unaligned_transforms: 67 | # print('random shift') 68 | i_shift = random.randint(-10, 10) 69 | j_shift = random.randint(-10, 10) 70 | i += i_shift 71 | j += j_shift 72 | 73 | img_2 = F.crop(img_2, i, j, h, w) 74 | 75 | return img_1,img_2 76 | 77 | 78 | BaseDataset = torchdata.Dataset 79 | 80 | 81 | class DataLoader(torch.utils.data.DataLoader): 82 | def __init__(self, dataset, batch_size, shuffle, *args, **kwargs): 83 | super(DataLoader, self).__init__(dataset, batch_size, shuffle, *args, **kwargs) 84 | self.shuffle = shuffle 85 | 86 | def reset(self): 87 | if self.shuffle: 88 | print('Reset Dataset...') 89 | self.dataset.reset() 90 | 91 | 92 | class CEILDataset(BaseDataset): 93 | def __init__(self, datadir, fns=None, size=None, enable_transforms=True, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3): 94 | super(CEILDataset, self).__init__() 95 | self.size = size 96 | self.datadir = datadir 97 | self.enable_transforms = enable_transforms 98 | 99 | sortkey = lambda key: os.path.split(key)[-1] 100 | self.paths = sorted(make_dataset(datadir, fns), key=sortkey) 101 | if size is not None: 102 | self.paths = self.paths[:size] 103 | 104 | self.syn_model = ReflectionSythesis_1(kernel_sizes=[11], low_sigma=low_sigma, high_sigma=high_sigma, low_gamma=low_gamma, high_gamma=high_gamma) 105 | self.reset(shuffle=False) 106 | 107 | def reset(self, shuffle=True): 108 | if shuffle: 109 | random.shuffle(self.paths) 110 | num_paths = len(self.paths) // 2 111 | self.B_paths = self.paths[0:num_paths] 112 | self.R_paths = self.paths[num_paths:2*num_paths] 113 | 114 | def data_synthesis(self, t_img, r_img): 115 | if self.enable_transforms: 116 | t_img, r_img = paired_data_transforms(t_img, r_img) 117 | syn_model = self.syn_model 118 | t_img, r_img, m_img = syn_model(t_img, r_img) 119 | 120 | B = to_tensor(t_img) 121 | R = to_tensor(r_img) 122 | M = to_tensor(m_img) 123 | 124 | return B, R, M 125 | 126 | def __getitem__(self, index): 127 | index_B = index % len(self.B_paths) 128 | index_R = index % len(self.R_paths) 129 | 130 | B_path = self.B_paths[index_B] 131 | R_path = self.R_paths[index_R] 132 | 133 | t_img = Image.open(B_path).convert('RGB') 134 | r_img = Image.open(R_path).convert('RGB') 135 | 136 | B, R, M = self.data_synthesis(t_img, r_img) 137 | 138 | fn = os.path.basename(B_path) 139 | return {'input': M, 'target_t': B, 'target_r': R, 'fn': fn} 140 | 141 | def __len__(self): 142 | if self.size is not None: 143 | return min(max(len(self.B_paths), len(self.R_paths)), self.size) 144 | else: 145 | return max(len(self.B_paths), len(self.R_paths)) 146 | 147 | 148 | class CEILTestDataset(BaseDataset): 149 | def __init__(self, datadir, fns=None, size=None, enable_transforms=False, unaligned_transforms=False, round_factor=1, flag=None): 150 | super(CEILTestDataset, self).__init__() 151 | self.size = size 152 | self.datadir = datadir 153 | self.fns = fns or os.listdir(join(datadir, 'blended')) 154 | self.enable_transforms = enable_transforms 155 | self.unaligned_transforms = unaligned_transforms 156 | self.round_factor 
= round_factor 157 | self.flag = flag 158 | 159 | if size is not None: 160 | self.fns = self.fns[:size] 161 | 162 | def __getitem__(self, index): 163 | fn = self.fns[index] 164 | 165 | t_img = Image.open(join(self.datadir, 'transmission_layer', fn)).convert('RGB') 166 | m_img = Image.open(join(self.datadir, 'blended', fn)).convert('RGB') 167 | 168 | if self.enable_transforms: 169 | t_img, m_img = paired_data_transforms(t_img, m_img, self.unaligned_transforms) 170 | 171 | B = to_tensor(t_img) 172 | M = to_tensor(m_img) 173 | 174 | dic = {'input': M, 'target_t': B, 'fn': fn, 'real':True, 'target_r': B} # fake reflection gt 175 | if self.flag is not None: 176 | dic.update(self.flag) 177 | return dic 178 | 179 | def __len__(self): 180 | if self.size is not None: 181 | return min(len(self.fns), self.size) 182 | else: 183 | return len(self.fns) 184 | 185 | 186 | class RealDataset(BaseDataset): 187 | def __init__(self, datadir, fns=None, size=None): 188 | super(RealDataset, self).__init__() 189 | self.size = size 190 | self.datadir = datadir 191 | self.fns = fns or os.listdir(join(datadir)) 192 | 193 | if size is not None: 194 | self.fns = self.fns[:size] 195 | 196 | def __getitem__(self, index): 197 | fn = self.fns[index] 198 | B = -1 199 | 200 | m_img = Image.open(join(self.datadir, fn)).convert('RGB') 201 | 202 | M = to_tensor(m_img) 203 | data = {'input': M, 'target_t': B, 'fn': fn} 204 | return data 205 | 206 | def __len__(self): 207 | if self.size is not None: 208 | return min(len(self.fns), self.size) 209 | else: 210 | return len(self.fns) 211 | 212 | 213 | class PairedCEILDataset(CEILDataset): 214 | def __init__(self, datadir, fns=None, size=None, enable_transforms=True, low_sigma=2, high_sigma=5): 215 | self.size = size 216 | self.datadir = datadir 217 | 218 | self.fns = fns or os.listdir(join(datadir, 'reflection_layer')) 219 | if size is not None: 220 | self.fns = self.fns[:size] 221 | 222 | self.syn_model = ReflectionSythesis_1(kernel_sizes=[11], low_sigma=low_sigma, high_sigma=high_sigma) 223 | self.enable_transforms = enable_transforms 224 | self.reset() 225 | 226 | def reset(self): 227 | return 228 | 229 | def __getitem__(self, index): 230 | fn = self.fns[index] 231 | B_path = join(self.datadir, 'transmission_layer', fn) 232 | R_path = join(self.datadir, 'reflection_layer', fn) 233 | 234 | t_img = Image.open(B_path).convert('RGB') 235 | r_img = Image.open(R_path).convert('RGB') 236 | 237 | B, R, M = self.data_synthesis(t_img, r_img) 238 | 239 | data = {'input': M, 'target_t': B, 'target_r': R, 'fn': fn} 240 | # return M, B 241 | return data 242 | 243 | def __len__(self): 244 | if self.size is not None: 245 | return min(len(self.fns), self.size) 246 | else: 247 | return len(self.fns) 248 | 249 | 250 | class FusionDataset(BaseDataset): 251 | def __init__(self, datasets, fusion_ratios=None): 252 | self.datasets = datasets 253 | self.size = sum([len(dataset) for dataset in datasets]) 254 | self.fusion_ratios = fusion_ratios or [1./len(datasets)] * len(datasets) 255 | print('[i] using a fusion dataset: %d %s imgs fused with ratio %s' %(self.size, [len(dataset) for dataset in datasets], self.fusion_ratios)) 256 | 257 | def reset(self): 258 | for dataset in self.datasets: 259 | dataset.reset() 260 | 261 | def __getitem__(self, index): 262 | residual = 1 263 | for i, ratio in enumerate(self.fusion_ratios): 264 | if random.random() < ratio/residual or i == len(self.fusion_ratios) - 1: 265 | dataset = self.datasets[i] 266 | # return dataset[index%len(dataset)] 267 | ret = 
dataset[index%len(dataset)] 268 | ret['idx'] = i 269 | return ret 270 | residual -= ratio 271 | 272 | def __len__(self): 273 | return self.size 274 | 275 | 276 | class RepeatedDataset(BaseDataset): 277 | def __init__(self, dataset, repeat=1): 278 | self.dataset = dataset 279 | self.size = len(dataset) * repeat 280 | # self.reset() 281 | 282 | def reset(self): 283 | 284 | self.dataset.reset() 285 | 286 | def __getitem__(self, index): 287 | dataset = self.dataset 288 | return dataset[index%len(dataset)] 289 | 290 | def __len__(self): 291 | return self.size 292 | -------------------------------------------------------------------------------- /data/torchdata.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, which provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | def reset(self): 26 | return 27 | 28 | 29 | class ConcatDataset(Dataset): 30 | """ 31 | Dataset to concatenate multiple datasets. 32 | Purpose: useful to assemble different existing datasets, possibly 33 | large-scale ones, as the concatenation operation is done in an 34 | on-the-fly manner. 35 | 36 | Arguments: 37 | datasets (sequence): List of datasets to be concatenated 38 | """ 39 | 40 | @staticmethod 41 | def cumsum(sequence): 42 | r, s = [], 0 43 | for e in sequence: 44 | l = len(e) 45 | r.append(l + s) 46 | s += l 47 | return r 48 | 49 | def __init__(self, datasets): 50 | super(ConcatDataset, self).__init__() 51 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 52 | self.datasets = list(datasets) 53 | self.cumulative_sizes = self.cumsum(self.datasets) 54 | 55 | def __len__(self): 56 | return self.cumulative_sizes[-1] 57 | 58 | def __getitem__(self, idx): 59 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 60 | if dataset_idx == 0: 61 | sample_idx = idx 62 | else: 63 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 64 | return self.datasets[dataset_idx][sample_idx] 65 | 66 | @property 67 | def cummulative_sizes(self): 68 | warnings.warn("cummulative_sizes attribute is renamed to " 69 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 70 | return self.cumulative_sizes -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import random 5 | from PIL import Image, ImageOps, ImageEnhance 6 | try: 7 | import accimage 8 | except ImportError: 9 | accimage = None 10 | import numpy as np 11 | import scipy.stats as st 12 | import cv2 13 | import numbers 14 | import types 15 | import collections.abc 16 | import matplotlib.pyplot as plt 17 | import torchvision.transforms as transforms 18 | import util.util as util 19 | from scipy.signal import convolve2d 20 | 21 | 22 | # utility 23 | def _is_pil_image(img): 24 | if accimage is not None: 25 | return
isinstance(img, (Image.Image, accimage.Image)) 26 | else: 27 | return isinstance(img, Image.Image) 28 | 29 | 30 | def _is_tensor_image(img): 31 | return torch.is_tensor(img) and img.ndimension() == 3 32 | 33 | 34 | def _is_numpy_image(img): 35 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 36 | 37 | 38 | def arrshow(arr): 39 | Image.fromarray(arr.astype(np.uint8)).show() 40 | 41 | 42 | def get_transform(opt): 43 | transform_list = [] 44 | osizes = util.parse_args(opt.loadSize) 45 | fineSize = util.parse_args(opt.fineSize) 46 | if opt.resize_or_crop == 'resize_and_crop': 47 | transform_list.append( 48 | transforms.RandomChoice([ 49 | transforms.Resize([osize, osize], Image.BICUBIC) for osize in osizes 50 | ])) 51 | transform_list.append(transforms.RandomCrop(fineSize)) 52 | elif opt.resize_or_crop == 'crop': 53 | transform_list.append(transforms.RandomCrop(fineSize)) 54 | elif opt.resize_or_crop == 'scale_width': 55 | transform_list.append(transforms.Lambda( 56 | lambda img: __scale_width(img, fineSize))) 57 | elif opt.resize_or_crop == 'scale_width_and_crop': 58 | transform_list.append(transforms.Lambda( 59 | lambda img: __scale_width(img, opt.loadSize))) 60 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 61 | 62 | if opt.isTrain and not opt.no_flip: 63 | transform_list.append(transforms.RandomHorizontalFlip()) 64 | 65 | return transforms.Compose(transform_list) 66 | 67 | 68 | to_norm_tensor = transforms.Compose([ 69 | transforms.ToTensor(), 70 | transforms.Normalize( 71 | (0.5, 0.5, 0.5), 72 | (0.5, 0.5, 0.5) 73 | ) 74 | ]) 75 | 76 | to_tensor = transforms.ToTensor() 77 | 78 | 79 | def __scale_width(img, target_width): 80 | ow, oh = img.size 81 | if (ow == target_width): 82 | return img 83 | w = target_width 84 | h = int(target_width * oh / ow) 85 | h = math.ceil(h / 2.) * 2 # round up to even 86 | return img.resize((w, h), Image.BICUBIC) 87 | 88 | 89 | # functional 90 | def gaussian_blur(img, kernel_size, sigma): 91 | from scipy.ndimage.filters import gaussian_filter 92 | if not _is_pil_image(img): 93 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 94 | 95 | img = np.asarray(img) 96 | # the 3rd dimension (i.e. 
inter-band) would be filtered, which is unwanted for our purpose 97 | # new = gaussian_filter(img, sigma=sigma, truncate=truncate) 98 | if isinstance(kernel_size, int): 99 | kernel_size = (kernel_size, kernel_size) 100 | elif isinstance(kernel_size, collections.abc.Sequence): 101 | assert len(kernel_size) == 2 102 | new = cv2.GaussianBlur(img, kernel_size, sigma) # apply gaussian filter band by band 103 | return Image.fromarray(new) 104 | 105 | 106 | # transforms 107 | class GaussianBlur(object): 108 | def __init__(self, kernel_size=11, sigma=3): 109 | self.kernel_size = kernel_size 110 | self.sigma = sigma 111 | 112 | def __call__(self, img): 113 | return gaussian_blur(img, self.kernel_size, self.sigma) 114 | 115 | 116 | class ReflectionSythesis_1(object): 117 | """Reflection image data synthesis for weakly-supervised learning 118 | of ICCV 2017 paper *"A Generic Deep Architecture for Single Image Reflection Removal and Image Smoothing"* 119 | """ 120 | def __init__(self, kernel_sizes=None, low_sigma=2, high_sigma=5, low_gamma=1.3, high_gamma=1.3): 121 | self.kernel_sizes = kernel_sizes or [11] 122 | self.low_sigma = low_sigma 123 | self.high_sigma = high_sigma 124 | self.low_gamma = low_gamma 125 | self.high_gamma = high_gamma 126 | print('[i] reflection synthesis model: {}'.format({ 127 | 'kernel_sizes': kernel_sizes, 'low_sigma': low_sigma, 'high_sigma': high_sigma, 128 | 'low_gamma': low_gamma, 'high_gamma': high_gamma})) 129 | 130 | def __call__(self, B, R): 131 | if not _is_pil_image(B): 132 | raise TypeError('B should be PIL Image. Got {}'.format(type(B))) 133 | if not _is_pil_image(R): 134 | raise TypeError('R should be PIL Image. Got {}'.format(type(R))) 135 | 136 | B_ = np.asarray(B, np.float32) / 255. 137 | R_ = np.asarray(R, np.float32) / 255. 138 | 139 | kernel_size = np.random.choice(self.kernel_sizes) 140 | sigma = np.random.uniform(self.low_sigma, self.high_sigma) 141 | gamma = np.random.uniform(self.low_gamma, self.high_gamma) 142 | R_blur = R_ 143 | kernel = cv2.getGaussianKernel(11, sigma) # note: the kernel size is fixed to 11 here, not kernel_size 144 | kernel2d = np.dot(kernel, kernel.T) 145 | 146 | for i in range(3): 147 | R_blur[...,i] = convolve2d(R_blur[...,i], kernel2d, mode='same') 148 | 149 | M_ = B_ + R_blur 150 | 151 | if np.max(M_) > 1: 152 | m = M_[M_ > 1] 153 | m = (np.mean(m) - 1) * gamma 154 | R_blur = np.clip(R_blur - m, 0, 1) 155 | M_ = np.clip(R_blur + B_, 0, 1) 156 | 157 | return B_, R_blur, M_ 158 | 159 |
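# A usage sketch for ReflectionSythesis_1 above (see also the `__main__` examples at
# the bottom of this file); inputs are PIL images, outputs are float arrays in [0, 1]:
#   syn = ReflectionSythesis_1(kernel_sizes=[11], low_sigma=2, high_sigma=5)
#   B, R_blur, M = syn(t_img, r_img)  # transmission, blurred reflection, mixture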
160 | class Sobel(object): 161 | def __call__(self, img): 162 | if not _is_pil_image(img): 163 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 164 | 165 | gray_img = np.array(img.convert('L')) 166 | x = cv2.Sobel(gray_img,cv2.CV_16S,1,0) 167 | y = cv2.Sobel(gray_img,cv2.CV_16S,0,1) 168 | 169 | absX = cv2.convertScaleAbs(x) 170 | absY = cv2.convertScaleAbs(y) 171 | 172 | dst = cv2.addWeighted(absX,0.5,absY,0.5,0) 173 | return Image.fromarray(dst) 174 | 175 | 176 | class ReflectionSythesis_2(object): 177 | """Reflection image data synthesis for weakly-supervised learning 178 | of CVPR 2018 paper *"Single Image Reflection Separation with Perceptual Losses"* 179 | """ 180 | def __init__(self, kernel_sizes=None): 181 | self.kernel_sizes = kernel_sizes or np.linspace(1,5,80) 182 | 183 | @staticmethod 184 | def gkern(kernlen=100, nsig=1): 185 | """Returns a 2D Gaussian kernel array.""" 186 | interval = (2*nsig+1.)/(kernlen) 187 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) 188 | kern1d = np.diff(st.norm.cdf(x)) 189 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 190 | kernel = kernel_raw/kernel_raw.sum() 191 | kernel = kernel/kernel.max() 192 | return kernel 193 | 194 | def __call__(self, t, r): 195 | t = np.float32(t) / 255. 196 | r = np.float32(r) / 255. 197 | ori_t = t 198 | # create a vignetting mask 199 | g_mask=self.gkern(560,3) 200 | g_mask=np.dstack((g_mask,g_mask,g_mask)) 201 | sigma=self.kernel_sizes[np.random.randint(0, len(self.kernel_sizes))] 202 | 203 | t=np.power(t,2.2) 204 | r=np.power(r,2.2) 205 | 206 | sz=int(2*np.ceil(2*sigma)+1) 207 | 208 | r_blur=cv2.GaussianBlur(r,(sz,sz),sigma,sigma,0) 209 | blend=r_blur+t 210 | 211 | att=1.08+np.random.random()/10.0 212 | 213 | for i in range(3): 214 | maski=blend[:,:,i]>1 215 | mean_i=max(1.,np.sum(blend[:,:,i]*maski)/(maski.sum()+1e-6)) 216 | r_blur[:,:,i]=r_blur[:,:,i]-(mean_i-1)*att 217 | r_blur[r_blur>=1]=1 218 | r_blur[r_blur<=0]=0 219 | 220 | h,w=r_blur.shape[0:2] 221 | neww=np.random.randint(0, 560-w-10) 222 | newh=np.random.randint(0, 560-h-10) 223 | alpha1=g_mask[newh:newh+h,neww:neww+w,:] 224 | alpha2 = 1-np.random.random()/5.0 225 | r_blur_mask=np.multiply(r_blur,alpha1) 226 | blend=r_blur_mask+t*alpha2 227 | 228 | t=np.power(t,1/2.2) 229 | r_blur_mask=np.power(r_blur_mask,1/2.2) 230 | blend=np.power(blend,1/2.2) 231 | blend[blend>=1]=1 232 | blend[blend<=0]=0 233 | 234 | return np.float32(ori_t), np.float32(r_blur_mask), np.float32(blend) 235 | 236 | 237 | # Examples 238 | if __name__ == '__main__': 239 | """cv2 imread""" 240 | # img = cv2.imread('testdata_reflection_real/19-input.png') 241 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 242 | # img2 = cv2.GaussianBlur(img, (11,11), 3) 243 | 244 | """Sobel Operator""" 245 | # img = np.array(Image.open('datasets/VOC224/train/B/2007_000250.png').convert('L')) 246 | 247 | 248 | """Reflection Synthesis""" 249 | b = Image.open('datasets/VOCsmall/train/B/2008_000148.png') 250 | r = Image.open('datasets/VOCsmall/train/B/2007_000243.png') 251 | G = ReflectionSythesis_1() 252 | b_, r_blur, m = G(b, r) # G returns three float arrays: (B, blurred R, mixture M) 253 | Image.fromarray((m * 255).astype(np.uint8)).show() 254 | 255 | # img2 = gaussian_blur(img, 11, 3) 256 | # img2 = GaussianBlur(1, 1)(img) 257 | # print(np.sum(np.array(img2) - np.array(img))) 258 | # img2.show() 259 | -------------------------------------------------------------------------------- /datasets/eval/real20/blended/103.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/103.png --------------------------------------------------------------------------------
/datasets/eval/real20/blended/107.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/107.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/110.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/12.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/15.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/22.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/23.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/25.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/29.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/3.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/39.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/4.png 
-------------------------------------------------------------------------------- /datasets/eval/real20/blended/46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/46.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/47.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/58.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/86.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/86.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/87.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/87.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/89.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/89.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/9.png -------------------------------------------------------------------------------- /datasets/eval/real20/blended/93.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/blended/93.png -------------------------------------------------------------------------------- /datasets/eval/real20/data_list.txt: -------------------------------------------------------------------------------- 1 | 103.png 2 | 107.png 3 | 110.png 4 | 12.png 5 | 15.png 6 | 22.png 7 | 23.png 8 | 25.png 9 | 29.png 10 | 3.png 11 | 39.png 12 | 4.png 13 | 46.png 14 | 47.png 15 | 58.png 16 | 86.png 17 | 87.png 18 | 89.png 19 | 9.png 20 | 93.png 21 | -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/103.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/103.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/107.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/107.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/110.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/12.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/15.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/22.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/23.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/25.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/29.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/29.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/3.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/39.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/39.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/4.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/46.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/46.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/47.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/47.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/58.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/58.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/86.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/86.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/87.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/87.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/89.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/89.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/9.png -------------------------------------------------------------------------------- /datasets/eval/real20/transmission_layer/93.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/datasets/eval/real20/transmission_layer/93.png -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import util.util as util 3 | import models 4 | import time 5 | import os 6 | import sys 7 | from os.path import join 8 | from util.visualizer import Visualizer 9 | import cv2 10 | import numpy as np 11 | import time 12 | import copy 13 | 14 | class Engine(object): 15 | def __init__(self, opt): 16 | self.opt = opt 17 | self.writer = None 18 | self.visualizer = None 19 | self.model = None 20 | self.best_val_loss = 1e6 21 | 22 | self.__setup() 23 | 
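# __setup (below) creates checkpoints/<opt.name>, instantiates the model class named
# by opt.model from models.__dict__, and, unless opt.no_log is set, prepares the
# summary writer and the Visualizer.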
24 | def __setup(self): 25 | self.basedir = join('checkpoints', self.opt.name) 26 | if not os.path.exists(self.basedir): 27 | os.mkdir(self.basedir) 28 | 29 | opt = self.opt 30 | 31 | """Model""" 32 | self.model = models.__dict__[self.opt.model]() 33 | self.model.initialize(opt) 34 | if not opt.no_log: 35 | self.writer = util.get_summary_writer(os.path.join(self.basedir, 'logs')) 36 | self.visualizer = Visualizer(opt) 37 | 38 | def train(self, train_loader, **kwargs): 39 | print('\nEpoch: %d' % self.epoch) 40 | avg_meters = util.AverageMeters() 41 | opt = self.opt 42 | model = self.model 43 | epoch = self.epoch 44 | 45 | epoch_start_time = time.time() 46 | for i, data in enumerate(train_loader): 47 | iter_start_time = time.time() 48 | iterations = self.iterations 49 | 50 | 51 | model.set_input(data, mode='train') 52 | model.optimize_parameters(**kwargs) 53 | 54 | errors = model.get_current_errors() 55 | avg_meters.update(errors) 56 | util.progress_bar(i, len(train_loader), str(avg_meters)) 57 | 58 | if not opt.no_log: 59 | util.write_loss(self.writer, 'train', avg_meters, iterations) 60 | 61 | if iterations % opt.display_freq == 0 and opt.display_id != 0: 62 | save_result = iterations % opt.update_html_freq == 0 63 | self.visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 64 | 65 | if iterations % opt.print_freq == 0 and opt.display_id != 0: 66 | t = (time.time() - iter_start_time) 67 | 68 | self.iterations += 1 69 | 70 | self.epoch += 1 71 | 72 | if not self.opt.no_log: 73 | if self.epoch % opt.save_epoch_freq == 0: 74 | print('saving the model at epoch %d, iters %d' % 75 | (self.epoch, self.iterations)) 76 | model.save() 77 | 78 | print('saving the latest model at the end of epoch %d, iters %d' % 79 | (self.epoch, self.iterations)) 80 | model.save(label='latest') 81 | 82 | print('Time Taken: %d sec' % 83 | (time.time() - epoch_start_time)) 84 | 85 | # model.update_learning_rate() 86 | train_loader.reset() 87 | 88 | def eval(self, val_loader, dataset_name, savedir=None, loss_key=None, **kwargs): 89 | 90 | avg_meters = util.AverageMeters() 91 | model = self.model 92 | opt = self.opt 93 | with torch.no_grad(): 94 | for i, data in enumerate(val_loader): 95 | index = model.eval(data, savedir=savedir, **kwargs) 96 | avg_meters.update(index) 97 | 98 | util.progress_bar(i, len(val_loader), str(avg_meters)) 99 | 100 | if not opt.no_log: 101 | util.write_loss(self.writer, join('eval', dataset_name), avg_meters, self.epoch) 102 | 103 | if loss_key is not None: 104 | val_loss = avg_meters[loss_key] 105 | if val_loss < self.best_val_loss: 106 | self.best_val_loss = val_loss 107 | print('saving the best model at the end of epoch %d, iters %d' % 108 | (self.epoch, self.iterations)) 109 | model.save(label='best_{}_{}'.format(loss_key, dataset_name)) 110 | 111 | return avg_meters 112 | 113 | def test(self, test_loader, savedir=None, **kwargs): 114 | model = self.model 115 | opt = self.opt 116 | with torch.no_grad(): 117 | for i, data in enumerate(test_loader): 118 | model.test(data, savedir=savedir, **kwargs) 119 | util.progress_bar(i, len(test_loader)) 120 | 121 | @property 122 | def iterations(self): 123 | return self.model.iterations 124 | 125 | @iterations.setter 126 | def iterations(self, i): 127 | self.model.iterations = i 128 | 129 | @property 130 | def epoch(self): 131 | return self.model.epoch 132 | 133 | @epoch.setter 134 | def epoch(self, e): 135 | self.model.epoch = e 136 | -------------------------------------------------------------------------------- 
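For orientation, a minimal sketch of how a driver script might use `Engine` above; the `TrainOptions` class name and its `parse()` method are assumptions inferred from the `options/errnet/` layout (neither `train_errnet.py` nor the option classes are included in this dump):

```python
from engine import Engine
from options.errnet.train_options import TrainOptions  # assumed entry point

opt = TrainOptions().parse()  # hypothetical: builds the `opt` namespace Engine expects
engine = Engine(opt)          # sets up the model, logging, and visualization

# val_loader would be a data.reflect_dataset.DataLoader over a CEILTestDataset;
# engine.eval() averages the metrics returned by the model and logs them.
# res = engine.eval(val_loader, dataset_name='real20', savedir='./results/real20')
```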
/models/CX/CX_distance.py: -------------------------------------------------------------------------------- 1 | # import tensorflow as tf 2 | import torch 3 | import numpy as np 4 | # import sklearn.manifold.t_sne 5 | 6 | class TensorAxis: 7 | N = 0 8 | C = 1 9 | H = 2 10 | W = 3 11 | 12 | 13 | class CSFlow: 14 | def __init__(self, sigma=float(0.1), b=float(1.0)): 15 | self.b = b 16 | self.sigma = sigma 17 | 18 | def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C): 19 | self.scaled_distances = scaled_distances 20 | self.cs_weights_before_normalization = torch.exp((self.b - scaled_distances) / self.sigma) 21 | # self.cs_weights_before_normalization = 1 / (1 + scaled_distances) 22 | self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization) 23 | 24 | # print('cs_NHWC:{}'.format(self.cs_NHWC.sum())) 25 | 26 | # self.cs_NHWC = self.cs_weights_before_normalization 27 | 28 | # def reversed_direction_CS(self): 29 | # cs_flow_opposite = CSFlow(self.sigma, self.b) 30 | # cs_flow_opposite.raw_distances = self.raw_distances 31 | # work_axis = [TensorAxis.H, TensorAxis.W] 32 | # relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis) 33 | # cs_flow_opposite.__calculate_CS(relative_dist, work_axis) 34 | # return cs_flow_opposite 35 | 36 | # -- 37 | @staticmethod 38 | def create_using_L2(I_features, T_features, sigma=float(0.5), b=float(1.0)): 39 | cs_flow = CSFlow(sigma, b) 40 | sT = T_features.shape 41 | sI = I_features.shape # N, C, H, W 42 | 43 | Ivecs = torch.reshape(I_features, (sI[0], sI[1], -1)) 44 | Tvecs = torch.reshape(T_features, (sI[0], sT[1], -1)) 45 | r_Ts = torch.sum(Tvecs * Tvecs, 1) # N C P 46 | r_Is = torch.sum(Ivecs * Ivecs, 1) 47 | 48 | # print('r_Ts:{}'.format(r_Ts.sum())) 49 | # print('r_Is:{}'.format(r_Is.sum())) 50 | 51 | raw_distances_list = [] 52 | for i in range(sT[0]): 53 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i] 54 | A = torch.transpose(Tvec, 0, 1) @ Ivec # (matrix multiplication) 55 | cs_flow.A = A 56 | # A = tf.matmul(Tvec, tf.transpose(Ivec)) 57 | r_T = torch.reshape(r_T, [-1, 1]) # turn to column vector 58 | dist = r_T - 2 * A + r_I 59 | 60 | # print('dist:{}'.format(dist.mean())) 61 | 62 | dist = torch.reshape(dist, shape=(1, dist.shape[0], sI[2], sI[3])) 63 | # protecting against numerical problems, dist should be positive 64 | dist = torch.clamp(dist, min=float(0.0)) 65 | # dist = tf.sqrt(dist) 66 | raw_distances_list += [dist] 67 | 68 | cs_flow.raw_distances = torch.cat(raw_distances_list) 69 | 70 | # print('raw_distances:{}'.format(cs_flow.raw_distances.mean())) 71 | 72 | relative_dist = cs_flow.calc_relative_distances() 73 | 74 | # print('relative_dist:{}'.format(relative_dist.mean())) 75 | 76 | cs_flow.__calculate_CS(relative_dist) 77 | return cs_flow 78 | 79 | # -- 80 | @staticmethod 81 | def create_using_L1(I_features, T_features, sigma=float(0.5), b=float(1.0)): 82 | cs_flow = CSFlow(sigma, b) 83 | sT = T_features.shape 84 | sI = I_features.shape 85 | 86 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3])) 87 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3])) 88 | raw_distances_list = [] 89 | for i in range(sT[0]): 90 | Ivec, Tvec = Ivecs[i], Tvecs[i] 91 | dist = torch.abs(torch.sum(Ivec.unsqueeze(1) - Tvec.unsqueeze(0), dim=2)) 92 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0])) 93 | # protecting against numerical problems, dist should be positive 94 | dist = torch.clamp(dist, min=float(0.0)) 95 | # dist = 
tf.sqrt(dist) 96 | raw_distances_list += [dist] 97 | 98 | cs_flow.raw_distances = torch.cat(raw_distances_list) 99 | 100 | relative_dist = cs_flow.calc_relative_distances() 101 | cs_flow.__calculate_CS(relative_dist) 102 | return cs_flow 103 | 104 | # -- 105 | @staticmethod 106 | def create_using_dotP(I_features, T_features, sigma=float(0.5), b=float(1.0)): 107 | cs_flow = CSFlow(sigma, b) 108 | # prepare features before calculating cosine distance 109 | T_features, I_features = cs_flow.center_by_T(T_features, I_features) 110 | T_features = CSFlow.l2_normalize_channelwise(T_features) 111 | I_features = CSFlow.l2_normalize_channelwise(I_features) 112 | 113 | # work separately on each example in the batch (dim 0) 114 | cosine_dist_l = [] 115 | N = T_features.size()[0] 116 | for i in range(N): 117 | T_features_i = T_features[i, :, :, :].unsqueeze_(0) 118 | I_features_i = I_features[i, :, :, :].unsqueeze_(0) 119 | patches_PC11_i = cs_flow.patch_decomposition(T_features_i) # 1CHW --> PC11, with P=H*W (C_out x C_in x H x W) 120 | cosine_dist_i = torch.nn.functional.conv2d(I_features_i, patches_PC11_i) 121 | # cosine_dist_1HWC = cosine_dist_i.permute((0, 2, 3, 1)) 122 | cosine_dist_l.append(cosine_dist_i) # 1PHW 123 | 124 | cs_flow.cosine_dist = torch.cat(cosine_dist_l, dim=0) 125 | 126 | cs_flow.raw_distances = - (cs_flow.cosine_dist - 1) / 2 # cosine similarity in [-1, 1] mapped to a distance in [0, 1] 127 | 128 | relative_dist = cs_flow.calc_relative_distances() 129 | cs_flow.__calculate_CS(relative_dist) 130 | return cs_flow 131 | 132 | def calc_relative_distances(self, axis=TensorAxis.C): 133 | epsilon = 1e-5 134 | div = torch.min(self.raw_distances, dim=axis, keepdim=True)[0] 135 | relative_dist = self.raw_distances / (div + epsilon) 136 | return relative_dist 137 | 138 | @staticmethod 139 | def sum_normalize(cs, axis=TensorAxis.C): 140 | reduce_sum = torch.sum(cs, dim=axis, keepdim=True) 141 | cs_normalize = torch.div(cs, reduce_sum) 142 | return cs_normalize 143 | 144 | def center_by_T(self, T_features, I_features): 145 | # assuming both inputs are of the same size 146 | # calculate stats over [batch, height, width], expecting 1x1xDepth tensor 147 | axes = [0, 1, 2] 148 | self.meanT = T_features.mean(TensorAxis.N, keepdim=True).mean(TensorAxis.H, keepdim=True).mean(TensorAxis.W, keepdim=True) 149 | # self.varT = T_features.var(0, keepdim=True).var(2, keepdim=True).var(3, keepdim=True) 150 | self.T_features_centered = T_features - self.meanT 151 | self.I_features_centered = I_features - self.meanT 152 | 153 | return self.T_features_centered, self.I_features_centered 154 | 155 | @staticmethod 156 | def l2_normalize_channelwise(features): 157 | norms = features.norm(p=2, dim=TensorAxis.C, keepdim=True) 158 | features = features.div(norms) 159 | return features 160 | 161 | def patch_decomposition(self, T_features): 162 | # 1HWC --> 11PC --> PC11, with P=H*W 163 | (_, C, H, W) = T_features.shape 164 | P = H * W 165 | patches_PC11 = T_features.reshape(shape=(C, P, 1, 1)).permute(dims=(1, 0, 2, 3)) 166 | return patches_PC11 167 | 168 | @staticmethod 169 | def pdist2(x, keepdim=False): 170 | sx = x.shape 171 | x = x.reshape(shape=(sx[0], sx[1] * sx[2], sx[3])) 172 | differences = x.unsqueeze(2) - x.unsqueeze(1) 173 | distances = torch.sum(differences**2, -1) 174 | if keepdim: 175 | distances = distances.reshape(shape=(sx[0], sx[1], sx[2], sx[3])) 176 | return distances 177 | 178 | @staticmethod 179 | def calcR_static(sT, order='C', deformation_sigma=0.05): 180 | # order can be 'C' or 'F' (MATLAB order) 181 | pixel_count = sT[0] * sT[1] 182 | 183 | rangeRows = range(0, sT[1]) 184 | rangeCols = range(0, sT[0]) 185 | Js, Is = np.meshgrid(rangeRows, rangeCols) 186 | row_diff_from_first_row = Is 187 | col_diff_from_first_col = Js 188 | 189 | row_diff_from_first_row_3d_repeat = np.repeat(row_diff_from_first_row[:, :, np.newaxis], pixel_count, axis=2) 190 | col_diff_from_first_col_3d_repeat = np.repeat(col_diff_from_first_col[:, :, np.newaxis], pixel_count, axis=2) 191 | 192 | rowDiffs = -row_diff_from_first_row_3d_repeat + row_diff_from_first_row.flatten(order).reshape(1, 1, -1) 193 | colDiffs = -col_diff_from_first_col_3d_repeat + col_diff_from_first_col.flatten(order).reshape(1, 1, -1) 194 | R = rowDiffs ** 2 + colDiffs ** 2 195 | R = R.astype(np.float32) 196 | R = np.exp(-(R) / (2 * deformation_sigma ** 2)) 197 | return R 198 | 199 | 200 | 201 | 202 | 203 | 204 | # -------------------------------------------------- 205 | # CX loss 206 | # -------------------------------------------------- 207 | 208 | 209 | 210 | def CX_loss(I_features, T_features, deformation=False, dis=False): 211 | # T_features = tf.convert_to_tensor(T_features, dtype=tf.float32) 212 | # I_features = tf.convert_to_tensor(I_features, dtype=tf.float32) 213 | # since this is a conversion of tensorflow to pytorch we permute the tensor from 214 | # T_features = normalize_tensor(T_features) 215 | # I_features = normalize_tensor(I_features) 216 | 217 | # since this is originally a TensorFlow implementation, 218 | # we modify all tensors to be as TF convention and not as the convention of pytorch. 219 | # def from_pt2tf(Tpt): 220 | # Ttf = Tpt.permute(0, 2, 3, 1) 221 | # return Ttf 222 | # N x C x H x W --> N x H x W x C 223 | # T_features_tf = from_pt2tf(T_features) 224 | # I_features_tf = from_pt2tf(I_features) 225 | 226 | cs_flow = CSFlow.create_using_dotP(I_features, T_features, sigma=1.0) 227 | # cs_flow = CSFlow.create_using_L2(I_features, T_features, sigma=1.0) 228 | # sum_normalize: 229 | # To: 230 | cs = cs_flow.cs_NHWC 231 | 232 | # print('cs:{}'.format(cs.std())) 233 | 234 | if deformation: 235 | deforma_sigma = 0.001 236 | sT = T_features.shape[2:4] # spatial size (H, W); the TF version used NHWC 237 | R = CSFlow.calcR_static(sT, deformation_sigma=deforma_sigma) 238 | cs *= torch.Tensor(R).unsqueeze(dim=0).cuda() 239 | 240 | if dis: 241 | CS = [] 242 | k_max_NC = torch.max(torch.max(cs, dim=1)[1], dim=1)[1] 243 | indices = k_max_NC.cpu() 244 | N, C = indices.shape 245 | for i in range(N): 246 | CS.append((C - len(torch.unique(indices[i, :]))) / C) 247 | score = torch.FloatTensor(CS) 248 | else: 249 | # reduce_max X and Y dims 250 | # cs = CSFlow.pdist2(cs,keepdim=True) N C H W 251 | k_max_NC = torch.max(torch.max(cs, dim=2)[0], dim=2)[0] 252 | # reduce mean over C(H*W) dim 253 | CS = torch.mean(k_max_NC, dim=1) 254 | # score = 1/CS 255 | # score = torch.exp(-CS*10) 256 | score = -torch.log(CS) 257 | # reduce mean over N dim 258 | # CX_loss = torch.mean(CX_loss) 259 | return score 260 | 261 | 262 | def symetric_CX_loss(T_features, I_features): 263 | score = (CX_loss(T_features, I_features) + CX_loss(I_features, T_features)) / 2 264 | return score 265 | --------------------------------------------------------------------------------
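For reference, a minimal, self-contained sketch of calling the contextual (CX) loss defined above. Feature maps are N x C x H x W; in the full pipeline they would presumably be VGG activations of the restored and ground-truth images (the actual wiring lives in `models/losses.py`, which is listed in the tree but not included in this dump). Run from the repo root:

```python
import torch
from models.CX.CX_distance import CX_loss, symetric_CX_loss

pred_feat = torch.rand(1, 64, 32, 32)  # stand-in for VGG features of the prediction
gt_feat = torch.rand(1, 64, 32, 32)    # stand-in for VGG features of the target

score = CX_loss(pred_feat, gt_feat)    # one -log(CS) score per batch element
loss = score.mean()                    # scalar to backpropagate
sym = symetric_CX_loss(pred_feat, gt_feat)  # symmetrized variant
```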
/models/CX/CX_helper.py: -------------------------------------------------------------------------------- 1 | from CX import CSFlow 2 | import tensorflow as tf # NOTE: this is the original TensorFlow-based helper; only CX_distance.py is exported via models/CX/__init__.py 3 | 4 | 5 | def random_sampling(tensor_NHWC, n, indices=None): 6 | N, H, W, C = tf.convert_to_tensor(tensor_NHWC).shape.as_list() 7 | S = H * W 8 | tensor_NSC = tf.reshape(tensor_NHWC, [N, S, C]) 9 | all_indices = list(range(S)) 10
shuffled_indices = tf.random_shuffle(all_indices) 11 | indices = tf.gather(shuffled_indices, list(range(n)), axis=0) if indices is None else indices 12 | indices_old = tf.random_uniform([n], 0, S, tf.int32) if indices is None else indices 13 | res = tf.gather(tensor_NSC, indices, axis=1) 14 | return res, indices 15 | 16 | 17 | def random_pooling(feats, output_1d_size=100): 18 | is_input_tensor = type(feats) is tf.Tensor 19 | 20 | if is_input_tensor: 21 | feats = [feats] 22 | 23 | # convert all inputs to tensors 24 | feats = [tf.convert_to_tensor(feats_i) for feats_i in feats] 25 | 26 | N, H, W, C = feats[0].shape.as_list() 27 | feats_sampled_0, indices = random_sampling(feats[0], output_1d_size ** 2) 28 | res = [feats_sampled_0] 29 | for i in range(1, len(feats)): 30 | feats_sampled_i, _ = random_sampling(feats[i], -1, indices) 31 | res.append(feats_sampled_i) 32 | 33 | res = [tf.reshape(feats_sampled_i, [N, output_1d_size, output_1d_size, C]) for feats_sampled_i in res] 34 | if is_input_tensor: 35 | return res[0] 36 | return res 37 | 38 | 39 | def crop_quarters(feature_tensor): 40 | N, fH, fW, fC = feature_tensor.shape.as_list() 41 | quarters_list = [] 42 | quarter_size = [N, round(fH / 2), round(fW / 2), fC] 43 | quarters_list.append(tf.slice(feature_tensor, [0, 0, 0, 0], quarter_size)) 44 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), 0, 0], quarter_size)) 45 | quarters_list.append(tf.slice(feature_tensor, [0, 0, round(fW / 2), 0], quarter_size)) 46 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), round(fW / 2), 0], quarter_size)) 47 | feature_tensor = tf.concat(quarters_list, axis=0) 48 | return feature_tensor 49 | 50 | 51 | def CX_loss_helper(vgg_A, vgg_B, CX_config): 52 | if CX_config.crop_quarters is True: 53 | vgg_A = crop_quarters(vgg_A) 54 | vgg_B = crop_quarters(vgg_B) 55 | 56 | N, fH, fW, fC = vgg_A.shape.as_list() 57 | if fH * fW <= CX_config.max_sampling_1d_size ** 2: 58 | print(' #### Skipping pooling for CX....') 59 | else: 60 | print(' #### pooling for CX %d**2 out of %dx%d' % (CX_config.max_sampling_1d_size, fH, fW)) 61 | vgg_A, vgg_B = random_pooling([vgg_A, vgg_B], output_1d_size=CX_config.max_sampling_1d_size) 62 | 63 | CX_loss = CSFlow.CX_loss(vgg_A, vgg_B, distance=CX_config.Dist, nnsigma=CX_config.nn_stretch_sigma) 64 | return CX_loss 65 | -------------------------------------------------------------------------------- /models/CX/__init__.py: -------------------------------------------------------------------------------- 1 | # The Contextual Loss for Image Transformation with Non-Aligned Data 2 | # https://arxiv.org/abs/1803.02077 3 | # https://github.com/roimehrez/contextualLoss 4 | from .CX_distance import * -------------------------------------------------------------------------------- /models/CX/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Distance(Enum): 5 | L2 = 0 6 | DotProduct = 1 7 | 8 | 9 | class TensorAxis: 10 | N = 0 11 | H = 1 12 | W = 2 13 | C = 3 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .errnet_model import ERRNetModel 2 | 3 | def errnet_model(): 4 | return ERRNetModel() 5 | -------------------------------------------------------------------------------- /models/arch/__init__.py: -------------------------------------------------------------------------------- 1 | # Add your 
custom network here 2 | from .default import DRNet 3 | import torch.nn as nn 4 | 5 | 6 | def basenet(in_channels, out_channels, **kwargs): 7 | return DRNet(in_channels, out_channels, 256, 13, norm=None, res_scale=0.1, bottom_kernel_size=1, **kwargs) 8 | 9 | 10 | def errnet(in_channels, out_channels, **kwargs): 11 | return DRNet(in_channels, out_channels, 256, 13, norm=None, res_scale=0.1, se_reduction=8, bottom_kernel_size=1, pyramid=True, **kwargs) 12 | -------------------------------------------------------------------------------- /models/arch/default.py: -------------------------------------------------------------------------------- 1 | # Define network components here 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class PyramidPooling(nn.Module): 8 | def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1): 9 | super().__init__() 10 | self.stages = [] 11 | self.stages = nn.ModuleList([self._make_stage(in_channels, scale, ct_channels) for scale in scales]) 12 | self.bottleneck = nn.Conv2d(in_channels + len(scales) * ct_channels, out_channels, kernel_size=1, stride=1) 13 | self.relu = nn.LeakyReLU(0.2, inplace=True) 14 | 15 | def _make_stage(self, in_channels, scale, ct_channels): 16 | # prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 17 | prior = nn.AvgPool2d(kernel_size=(scale, scale)) 18 | conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False) 19 | relu = nn.LeakyReLU(0.2, inplace=True) 20 | return nn.Sequential(prior, conv, relu) 21 | 22 | def forward(self, feats): 23 | h, w = feats.size(2), feats.size(3) 24 | priors = torch.cat([F.interpolate(input=stage(feats), size=(h, w), mode='nearest') for stage in self.stages] + [feats], dim=1) 25 | return self.relu(self.bottleneck(priors)) 26 | 27 | 28 | class SELayer(nn.Module): 29 | def __init__(self, channel, reduction=16): 30 | super(SELayer, self).__init__() 31 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 32 | self.fc = nn.Sequential( 33 | nn.Linear(channel, channel // reduction), 34 | nn.ReLU(inplace=True), 35 | nn.Linear(channel // reduction, channel), 36 | nn.Sigmoid() 37 | ) 38 | 39 | def forward(self, x): 40 | b, c, _, _ = x.size() 41 | y = self.avg_pool(x).view(b, c) 42 | y = self.fc(y).view(b, c, 1, 1) 43 | 44 | return x * y 45 | 46 | 47 | class DRNet(torch.nn.Module): 48 | def __init__(self, in_channels, out_channels, n_feats, n_resblocks, norm=nn.BatchNorm2d, 49 | se_reduction=None, res_scale=1, bottom_kernel_size=3, pyramid=False): 50 | super(DRNet, self).__init__() 51 | # Initial convolution layers 52 | conv = nn.Conv2d 53 | deconv = nn.ConvTranspose2d 54 | act = nn.ReLU(True) 55 | 56 | self.pyramid_module = None 57 | self.conv1 = ConvLayer(conv, in_channels, n_feats, kernel_size=bottom_kernel_size, stride=1, norm=None, act=act) 58 | self.conv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act) 59 | self.conv3 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=2, norm=norm, act=act) 60 | 61 | # Residual layers 62 | dilation_config = [1] * n_resblocks 63 | 64 | self.res_module = nn.Sequential(*[ResidualBlock( 65 | n_feats, dilation=dilation_config[i], norm=norm, act=act, 66 | se_reduction=se_reduction, res_scale=res_scale) for i in range(n_resblocks)]) 67 | 68 | # Upsampling Layers 69 | self.deconv1 = ConvLayer(deconv, n_feats, n_feats, kernel_size=4, stride=2, padding=1, norm=norm, act=act) 70 | 71 | if not pyramid: 72 | self.deconv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, 
norm=norm, act=act) 73 | self.deconv3 = ConvLayer(conv, n_feats, out_channels, kernel_size=1, stride=1, norm=None, act=act) 74 | else: 75 | self.deconv2 = ConvLayer(conv, n_feats, n_feats, kernel_size=3, stride=1, norm=norm, act=act) 76 | self.pyramid_module = PyramidPooling(n_feats, n_feats, scales=(4,8,16,32), ct_channels=n_feats//4) 77 | self.deconv3 = ConvLayer(conv, n_feats, out_channels, kernel_size=1, stride=1, norm=None, act=act) 78 | 79 | def forward(self, x): 80 | x = self.conv1(x) 81 | x = self.conv2(x) 82 | x = self.conv3(x) 83 | x = self.res_module(x) 84 | 85 | x = self.deconv1(x) 86 | x = self.deconv2(x) 87 | if self.pyramid_module is not None: 88 | x = self.pyramid_module(x) 89 | x = self.deconv3(x) 90 | 91 | return x 92 | 93 | 94 | class ConvLayer(torch.nn.Sequential): 95 | def __init__(self, conv, in_channels, out_channels, kernel_size, stride, padding=None, dilation=1, norm=None, act=None): 96 | super(ConvLayer, self).__init__() 97 | # padding = padding or kernel_size // 2 98 | padding = padding or dilation * (kernel_size - 1) // 2 99 | self.add_module('conv2d', conv(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation)) 100 | if norm is not None: 101 | self.add_module('norm', norm(out_channels)) 102 | # self.add_module('norm', norm(out_channels, track_running_stats=True)) 103 | if act is not None: 104 | self.add_module('act', act) 105 | 106 | 107 | class ResidualBlock(torch.nn.Module): 108 | def __init__(self, channels, dilation=1, norm=nn.BatchNorm2d, act=nn.ReLU(True), se_reduction=None, res_scale=1): 109 | super(ResidualBlock, self).__init__() 110 | conv = nn.Conv2d 111 | self.conv1 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=act) 112 | self.conv2 = ConvLayer(conv, channels, channels, kernel_size=3, stride=1, dilation=dilation, norm=norm, act=None) 113 | self.se_layer = None 114 | self.res_scale = res_scale 115 | if se_reduction is not None: 116 | self.se_layer = SELayer(channels, se_reduction) 117 | 118 | def forward(self, x): 119 | residual = x 120 | out = self.conv1(x) 121 | out = self.conv2(out) 122 | if self.se_layer: 123 | out = self.se_layer(out) 124 | out = out * self.res_scale 125 | out = out + residual 126 | return out 127 | 128 | def extra_repr(self): 129 | return 'res_scale={}'.format(self.res_scale) 130 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import util.util as util 4 | 5 | 6 | class BaseModel(): 7 | def name(self): 8 | return self.__class__.__name__.lower() 9 | 10 | def initialize(self, opt): 11 | self.opt = opt 12 | self.gpu_ids = opt.gpu_ids 13 | self.isTrain = opt.isTrain 14 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 16 | self._count = 0 17 | 18 | def set_input(self, input): 19 | self.input = input 20 | 21 | def forward(self): 22 | pass 23 | 24 | # used in test time, no backprop 25 | def test(self): 26 | pass 27 | 28 | def get_image_paths(self): 29 | pass 30 | 31 | def optimize_parameters(self): 32 | pass 33 | 34 | def get_current_visuals(self): 35 | return self.input 36 | 37 | def get_current_errors(self): 38 | return {} 39 | 40 | def save(self, label): 41 | pass 42 | 43 | def print_optimizer_param(self): 44 | # for optimizer in self.optimizers: 45 | # print(optimizer) 46 | print(self.optimizers[-1]) 47 
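# NOTE: the save() defined below supersedes the empty save(label) stub above
# (a later definition in the class body replaces the earlier one). Checkpoints
# are named '<model>_<epoch>_<iterations>.pt', or '<model>_<label>.pt' when a
# label is given.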
| 48 | def save(self, label=None): 49 | epoch = self.epoch 50 | iterations = self.iterations 51 | 52 | if label is None: 53 | model_name = os.path.join(self.save_dir, self.name() + '_%03d_%08d.pt' % ((epoch), (iterations))) 54 | else: 55 | model_name = os.path.join(self.save_dir, self.name() + '_' + label + '.pt') 56 | 57 | torch.save(self.state_dict(), model_name) 58 | 59 | def _init_optimizer(self, optimizers): 60 | self.optimizers = optimizers 61 | for optimizer in self.optimizers: 62 | util.set_opt_param(optimizer, 'initial_lr', self.opt.lr) 63 | util.set_opt_param(optimizer, 'weight_decay', self.opt.wd) 64 | -------------------------------------------------------------------------------- /models/errnet_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import os 6 | import numpy as np 7 | import itertools 8 | from collections import OrderedDict 9 | 10 | import util.util as util 11 | import util.index as index 12 | import models.networks as networks 13 | import models.losses as losses 14 | from models import arch 15 | 16 | from .base_model import BaseModel 17 | from PIL import Image 18 | from os.path import join 19 | 20 | 21 | def tensor2im(image_tensor, imtype=np.uint8): 22 | image_tensor = image_tensor.detach() 23 | image_numpy = image_tensor[0].cpu().float().numpy() 24 | image_numpy = np.clip(image_numpy, 0, 1) 25 | if image_numpy.shape[0] == 1: 26 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 27 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 28 | # image_numpy = image_numpy.astype(imtype) 29 | return image_numpy 30 | 31 | 32 | class EdgeMap(nn.Module): 33 | def __init__(self, scale=1): 34 | super(EdgeMap, self).__init__() 35 | self.scale = scale 36 | self.requires_grad = False 37 | 38 | def forward(self, img): 39 | img = img / self.scale 40 | 41 | N, C, H, W = img.shape 42 | gradX = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device) 43 | gradY = torch.zeros(N, 1, H, W, dtype=img.dtype, device=img.device) 44 | 45 | gradx = (img[...,1:,:] - img[...,:-1,:]).abs().sum(dim=1, keepdim=True) 46 | grady = (img[...,1:] - img[...,:-1]).abs().sum(dim=1, keepdim=True) 47 | 48 | gradX[...,:-1,:] += gradx 49 | gradX[...,1:,:] += gradx 50 | gradX[...,1:-1,:] /= 2 51 | 52 | gradY[...,:-1] += grady 53 | gradY[...,1:] += grady 54 | gradY[...,1:-1] /= 2 55 | 56 | # edge = (gradX + gradY) / 2 57 | edge = (gradX + gradY) 58 | 59 | return edge 60 | 61 | 62 | class ERRNetBase(BaseModel): 63 | def _init_optimizer(self, optimizers): 64 | self.optimizers = optimizers 65 | for optimizer in self.optimizers: 66 | util.set_opt_param(optimizer, 'initial_lr', self.opt.lr) 67 | util.set_opt_param(optimizer, 'weight_decay', self.opt.wd) 68 | 69 | def set_input(self, data, mode='train'): 70 | target_t = None 71 | target_r = None 72 | data_name = None 73 | mode = mode.lower() 74 | if mode == 'train': 75 | input, target_t, target_r = data['input'], data['target_t'], data['target_r'] 76 | elif mode == 'eval': 77 | input, target_t, target_r, data_name = data['input'], data['target_t'], data['target_r'], data['fn'] 78 | elif mode == 'test': 79 | input, data_name = data['input'], data['fn'] 80 | else: 81 | raise NotImplementedError('Mode [%s] is not implemented' % mode) 82 | 83 | if 'idx' in data: 84 | self.idx = data['idx'].to(device=self.gpu_ids[0]) 85 | self.idx_vec = torch.eye(self.opt.nModel)[self.idx].to(device=self.gpu_ids[0], dtype=torch.bool) 86 | 87 | if 
len(self.gpu_ids) > 0: # transfer data into gpu 88 | input = input.to(device=self.gpu_ids[0]) 89 | if target_t is not None: 90 | target_t = target_t.to(device=self.gpu_ids[0]) 91 | if target_r is not None: 92 | target_r = target_r.to(device=self.gpu_ids[0]) 93 | 94 | self.input = input 95 | 96 | self.input_edge = self.edge_map(self.input) 97 | self.target_t = target_t 98 | self.data_name = data_name 99 | 100 | self.issyn = False if 'real' in data else True 101 | self.aligned = False if 'unaligned' in data else True 102 | 103 | if target_t is not None: 104 | self.target_edge = self.edge_map(self.target_t) 105 | 106 | def eval(self, data, savedir=None, suffix=None, pieapp=None): 107 | # only the 1st input of the whole minibatch would be evaluated 108 | self._eval() 109 | self.set_input(data, 'eval') 110 | 111 | with torch.no_grad(): 112 | self.forward() 113 | 114 | output_i = tensor2im(self.output_i) 115 | target = tensor2im(self.target_t) 116 | 117 | if self.aligned: 118 | res = index.quality_assess(output_i, target) 119 | else: 120 | res = {} 121 | 122 | if savedir is not None: 123 | if self.data_name is not None: 124 | name = os.path.splitext(os.path.basename(self.data_name[0]))[0] 125 | if not os.path.exists(join(savedir, name)): 126 | os.makedirs(join(savedir, name)) 127 | if suffix is not None: 128 | Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name,'{}_{}.png'.format(self.opt.name, suffix))) 129 | else: 130 | Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}.png'.format(self.opt.name))) 131 | Image.fromarray(target.astype(np.uint8)).save(join(savedir, name, 't_label.png')) 132 | Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png')) 133 | else: 134 | if not os.path.exists(join(savedir, 'transmission_layer')): 135 | os.makedirs(join(savedir, 'transmission_layer')) 136 | os.makedirs(join(savedir, 'blended')) 137 | Image.fromarray(target.astype(np.uint8)).save(join(savedir, 'transmission_layer', str(self._count)+'.png')) 138 | Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, 'blended', str(self._count)+'.png')) 139 | self._count += 1 140 | 141 | return res 142 | 143 | def test(self, data, savedir=None): 144 | # only the 1st input of the whole minibatch would be evaluated 145 | self._eval() 146 | self.set_input(data, 'test') 147 | 148 | if self.data_name is not None and savedir is not None: 149 | name = os.path.splitext(os.path.basename(self.data_name[0]))[0] 150 | if not os.path.exists(join(savedir, name)): 151 | os.makedirs(join(savedir, name)) 152 | 153 | if os.path.exists(join(savedir, name, '{}.png'.format(self.opt.name))): 154 | return 155 | 156 | with torch.no_grad(): 157 | output_i = self.forward() 158 | output_i = tensor2im(output_i) 159 | # if os.path.exists(join(savedir, name,'t_output.png')): 160 | # i = 2 161 | # while True: 162 | # if not os.path.exists(join(savedir, name,'t_output_{}.png'.format(i))): 163 | # Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name,'t_output_{}.png'.format(i))) 164 | # break 165 | # i += 1 166 | # else: 167 | # Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name,'t_output.png')) 168 | if self.data_name is not None and savedir is not None: 169 | Image.fromarray(output_i.astype(np.uint8)).save(join(savedir, name, '{}.png'.format(self.opt.name))) 170 | Image.fromarray(tensor2im(self.input).astype(np.uint8)).save(join(savedir, name, 'm_input.png')) 171 | 172 | 173 | class Predictor(nn.Module): 174 | """ This is the 
feature extractor in the paper. """ 175 | def __init__(self, feats=64, in_c=3, inter_c=32, n_layers=5): 176 | super().__init__() 177 | 178 | layers = [] 179 | prev_c, curr_c = in_c, inter_c 180 | for i in range(1, n_layers+1): 181 | layers.append(('conv%d'%i, nn.Conv2d(prev_c, curr_c, 4, 2, 1))) 182 | layers.append(('relu%d'%i, nn.ReLU(True))) 183 | prev_c, curr_c = curr_c, curr_c*2 184 | 185 | layers.append(('avg', nn.AdaptiveAvgPool2d(1))) 186 | layers.append(('flatten', nn.Flatten())) 187 | layers.append(('FC', nn.Linear(prev_c, feats))) 188 | 189 | self.model = nn.Sequential(OrderedDict(layers)) 190 | 191 | def forward(self, x): 192 | return self.model(x) 193 | 194 | 195 | class Attention(nn.Module): 196 | 197 | def __init__(self, temperature): 198 | super().__init__() 199 | self.temperature = temperature 200 | 201 | def forward(self, q, k, v, attn_=None): 202 | ########################## Input shape ########################## 203 | # q = [N n_head 1 d_k] 204 | # k = [N n_head len_k d_k] # len_k is the number of k's 205 | # v = [N len_k 3 H W] 206 | ########################## k * v -> attn ########################## 207 | # q = [N n_head 1 d_k] 208 | # @ 209 | # k = [N n_head len_k d_k] --> [N n_head d_k len_k] 210 | # ↓ 211 | # attn = [N n_head 1 len_k] 212 | ####################################################################### 213 | 214 | attn = (q / self.temperature) @ (k.transpose(2, 3)) 215 | if attn_ is None: 216 | attn_ = F.softmax(attn, dim=-1) 217 | else: 218 | attn_ = torch.tensor([[[attn_]]], dtype=torch.float32, device=q.device) 219 | 220 | v_ = v.permute(0, 2, 3, 4, 1).unsqueeze(4) 221 | attn_ = attn_.permute(0, 2, 3, 1)[:, None, None] 222 | res_ = v_ @ attn_ 223 | res = res_[..., 0, :].permute(0, 4, 1, 2, 3) 224 | 225 | C = attn.shape[-1] 226 | return res, attn.view(-1, C) 227 | 228 | 229 | class CDAM(nn.Module): 230 | 231 | def __init__(self, n_head=1, d_model=64, d_k=64): 232 | super().__init__() 233 | 234 | self.n_head = n_head 235 | self.d_k = d_k 236 | 237 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 238 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 239 | 240 | self.attention = Attention(temperature=d_k**0.5) 241 | 242 | def forward(self, q, k, v, attn_=None): 243 | d_k, n_head = self.d_k, self.n_head 244 | sz_b, len_q, len_k = q.size(0), q.size(1), k.size(1) 245 | 246 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 247 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 248 | q, k = q.transpose(1, 2), k.transpose(1, 2) 249 | 250 | res, attn = self.attention(q, k, v, attn_) 251 | return res, attn 252 | 253 | 254 | class ERRNetModel(ERRNetBase): 255 | def name(self): 256 | return 'errnet' 257 | 258 | def __init__(self): 259 | self.epoch = 0 260 | self.iterations = 0 261 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 262 | 263 | def print_network(self): 264 | pass 265 | # print('--------------------- Model ---------------------') 266 | # print('##################### NetG #####################') 267 | # networks.print_network(self.net_i) 268 | # if self.isTrain and self.opt.lambda_gan > 0: 269 | # print('##################### NetD #####################') 270 | # networks.print_network(self.netD) 271 | 272 | def _eval(self): 273 | try: 274 | self.net_i_list.eval() 275 | self.net_p_list.eval() 276 | self.net_p.eval() 277 | self.net_A.eval() 278 | except: 279 | pass 280 | 281 | def _train(self): 282 | try: 283 | self.net_i_list.train() 284 | self.net_p_list.train() 285 | self.net_p.train() 286 | 
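# the CDAM attention module is toggled together with the expert networks and
# their predictors, so every sub-module reports the same training/eval mode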
self.net_A.train() 287 | except: 288 | pass 289 | 290 | def initialize(self, opt): 291 | BaseModel.initialize(self, opt) 292 | 293 | if self.isTrain: 294 | assert self.opt.nModel == len(self.opt.icnn_path) 295 | self.net_i_list = nn.ModuleList() 296 | self.net_p_list = nn.ModuleList() 297 | 298 | in_channels = 3 299 | self.vgg = None 300 | 301 | if opt.hyper: 302 | self.vgg = losses.Vgg19(requires_grad=False).to(self.device) 303 | in_channels += 1472 304 | 305 | if not self.isTrain and self.opt.nModel == 1: 306 | self.net_i = arch.__dict__[self.opt.inet](in_channels, 3).to(self.device) 307 | networks.init_weights(self.net_i, init_type=opt.init_type) # using default initialization as EDSR 308 | else: 309 | for idx in range(self.opt.nModel): 310 | net_i = arch.__dict__[self.opt.inet](in_channels, 3).to(self.device) 311 | networks.init_weights(net_i, init_type=opt.init_type) 312 | self.net_i_list.append(net_i) 313 | 314 | netP = Predictor().to(self.device) 315 | self.net_p_list.append(netP) 316 | 317 | self.net_p = Predictor().to(self.device) 318 | self.net_A = CDAM().to(self.device) 319 | 320 | self.edge_map = EdgeMap(scale=1).to(self.device) 321 | 322 | if self.isTrain: 323 | # define loss functions 324 | self.loss_dic = losses.init_loss(opt, self.Tensor) 325 | vggloss = losses.ContentLoss() 326 | vggloss.initialize(losses.VGGLoss(self.vgg)) 327 | self.loss_dic['t_vgg'] = vggloss 328 | 329 | cxloss = losses.ContentLoss() 330 | if opt.unaligned_loss == 'vgg': 331 | cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[opt.vgg_layer])) 332 | elif opt.unaligned_loss == 'ctx': 333 | cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1,0.1,0.1], indices=[8, 13, 22])) 334 | elif opt.unaligned_loss == 'mse': 335 | cxloss.initialize(nn.MSELoss()) 336 | elif opt.unaligned_loss == 'ctx_vgg': 337 | cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1,0.1,0.1,0.1], indices=[8, 13, 22, 31], criterions=[losses.CX_loss]*3+[nn.L1Loss()])) 338 | else: 339 | raise NotImplementedError 340 | 341 | self.loss_dic['t_cx'] = cxloss 342 | self.loss_dic['ce'] = nn.CrossEntropyLoss() 343 | 344 | # Define discriminator 345 | # if self.opt.lambda_gan > 0: 346 | self.netD = networks.define_D(opt, 3) 347 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 348 | lr=opt.lr, betas=(0.9, 0.999)) 349 | self._init_optimizer([self.optimizer_D]) 350 | 351 | # initialize optimizers 352 | param_list = [self.net_p_list.parameters(), self.net_p.parameters(), self.net_A.parameters()] 353 | self.optimizer_G = torch.optim.Adam(itertools.chain(*param_list), 354 | lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.wd) 355 | 356 | self._init_optimizer([self.optimizer_G]) 357 | 358 | if opt.resume: 359 | self.load(self, opt.resume_epoch) 360 | 361 | if opt.no_verbose is False: 362 | self.print_network() 363 | 364 | def backward_D(self): 365 | for p in self.netD.parameters(): 366 | p.requires_grad = True 367 | 368 | self.loss_D, self.pred_fake, self.pred_real = self.loss_dic['gan'].get_loss( 369 | self.netD, self.input, self.output_i, self.target_t) 370 | 371 | (self.loss_D*self.opt.lambda_gan).backward(retain_graph=True) 372 | 373 | def backward_G(self): 374 | # Make it a tiny bit faster 375 | for p in self.netD.parameters(): 376 | p.requires_grad = False 377 | 378 | self.loss_G = 0 379 | self.loss_CX = None 380 | self.loss_icnn_pixel = None 381 | self.loss_icnn_vgg = None 382 | self.loss_G_GAN = None 383 | self.loss_ce = None 384 | 385 | if self.opt.lambda_gan > 0: 386 | self.loss_G_GAN = 
self.loss_dic['gan'].get_g_loss( 387 | self.netD, self.input, self.output_i, self.target_t) #self.pred_real.detach()) 388 | self.loss_G += self.loss_G_GAN*self.opt.lambda_gan 389 | 390 | if self.aligned: 391 | self.loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss( 392 | self.output_i, self.target_t) 393 | 394 | self.loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss( 395 | self.output_i, self.target_t) 396 | 397 | self.loss_G += self.loss_icnn_pixel+self.loss_icnn_vgg*self.opt.lambda_vgg 398 | else: 399 | self.loss_CX = self.loss_dic['t_cx'].get_loss(self.output_i, self.target_t) 400 | 401 | self.loss_G += self.loss_CX 402 | 403 | if self.opt.lambda_ce > 0: 404 | attn_shape, idx_shape = self.attn.shape[0], self.idx.shape[0] 405 | if attn_shape != idx_shape: 406 | assert attn_shape % idx_shape == 0 407 | idx = self.idx.repeat_interleave(attn_shape//idx_shape, 0) 408 | else: 409 | idx = self.idx 410 | self.loss_ce = self.loss_dic['ce'](self.attn, idx) 411 | self.loss_G += self.loss_ce * self.opt.lambda_ce 412 | 413 | self.loss_G.backward() 414 | 415 | def forward(self): 416 | # without edge 417 | input_i = self.input 418 | 419 | if self.vgg is not None: 420 | hypercolumn = self.vgg(self.input) 421 | _, C, H, W = self.input.shape 422 | hypercolumn = [F.interpolate(feature.detach(), size=(H, W), mode='bilinear', align_corners=False) for feature in hypercolumn] 423 | input_i = [input_i] 424 | input_i.extend(hypercolumn) 425 | input_i = torch.cat(input_i, dim=1) 426 | 427 | if not self.isTrain and self.opt.nModel == 1: 428 | output_i = self.net_i(input_i) 429 | self.output_i = output_i 430 | return self.output_i 431 | 432 | with torch.no_grad(): 433 | self.v = torch.stack([net(input_i) for net in self.net_i_list], dim=1) 434 | N, X, C, H, W = self.v.shape 435 | self.q = self.net_p(self.input).unsqueeze(1) # [N 1 C] 436 | self.k = torch.stack([net(self.input) for net in self.net_p_list], dim=1) # [N X C] 437 | if self.isTrain: 438 | _, self.attn = self.net_A(self.q, self.k, self.v) # for IDE Loss 439 | 440 | self.v = self.v[~self.idx_vec].reshape(N, X-1, C, H, W) # [N X-1 C H W] 441 | self.k = self.k[~self.idx_vec].reshape(N, X-1, -1) # [N X-1 C] 442 | self.attn_res, _ = self.net_A(self.q, self.k, self.v) # for v_i^C $\mathbf{v}_\mathit{i}^\complement 443 | else: # use all experts in the testing phase 444 | self.attn_res, self.attn = self.net_A(self.q, self.k, self.v, self.opt.avg) 445 | if self.opt.show_expertise_level: 446 | print(self.attn) 447 | self.output_i = self.attn_res.reshape(N, -1, H, W) 448 | 449 | return self.output_i 450 | 451 | def optimize_parameters(self): 452 | self._train() 453 | self.forward() 454 | 455 | if self.opt.lambda_gan > 0: 456 | self.optimizer_D.zero_grad() 457 | self.backward_D() 458 | self.optimizer_D.step() 459 | 460 | self.optimizer_G.zero_grad() 461 | self.backward_G() 462 | self.optimizer_G.step() 463 | 464 | def get_current_errors(self): 465 | ret_errors = OrderedDict() 466 | if self.loss_icnn_pixel is not None: 467 | ret_errors['IPixel'] = self.loss_icnn_pixel.item() 468 | if self.loss_icnn_vgg is not None: 469 | ret_errors['VGG'] = self.loss_icnn_vgg.item() 470 | 471 | if self.opt.lambda_gan > 0 and self.loss_G_GAN is not None: 472 | ret_errors['G'] = self.loss_G_GAN.item() 473 | ret_errors['D'] = self.loss_D.item() 474 | 475 | if self.loss_CX is not None: 476 | ret_errors['CX'] = self.loss_CX.item() 477 | 478 | if self.loss_ce is not None: 479 | ret_errors['CE'] = self.loss_ce.item() 480 | 481 | return ret_errors 482 | 483 | def get_current_visuals(self): 
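# tensors shown by the visualizer; 'residual' (input - output_i) roughly
# visualizes the reflection component removed from the input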
484 | ret_visuals = OrderedDict() 485 | ret_visuals['input'] = tensor2im(self.input).astype(np.uint8) 486 | ret_visuals['output_i'] = tensor2im(self.output_i).astype(np.uint8) 487 | ret_visuals['target'] = tensor2im(self.target_t).astype(np.uint8) 488 | ret_visuals['residual'] = tensor2im((self.input - self.output_i)).astype(np.uint8) 489 | 490 | return ret_visuals 491 | 492 | @staticmethod 493 | def load(model, resume_epoch=None): 494 | icnn_path = model.opt.icnn_path 495 | state_dict = None 496 | 497 | if model.opt.nModel == 1 and icnn_path[0] is None: 498 | model_path = util.get_model_list(model.save_dir, model.name(), epoch=resume_epoch) 499 | state_dict = torch.load(model_path) 500 | model.epoch = state_dict['epoch'] 501 | model.iterations = state_dict['iterations'] 502 | model.net_i_list[0].load_state_dict(state_dict['icnn']) 503 | if model.isTrain: 504 | model.optimizer_G.load_state_dict(state_dict['opt_g']) 505 | else: 506 | if len(icnn_path) == 1 and model.opt.nModel == 1: 507 | state_dict = torch.load(icnn_path[0]) 508 | model.net_i.load_state_dict(state_dict['icnn']) 509 | elif len(icnn_path) == 1 and model.opt.nModel != 1: 510 | state_dict = torch.load(icnn_path[0]) 511 | model.net_i_list.load_state_dict(state_dict['icnn']) 512 | if hasattr(model, 'net_p_list'): 513 | model.net_p_list.load_state_dict(state_dict['net_p_list']) 514 | if hasattr(model, 'net_p'): 515 | model.net_p.load_state_dict(state_dict['net_p']) 516 | if hasattr(model, 'net_A'): 517 | model.net_A.load_state_dict(state_dict['net_A']) 518 | model.epoch = state_dict['epoch'] 519 | model.iterations = state_dict['iterations'] 520 | else: 521 | assert len(icnn_path) == model.opt.nModel 522 | for idx, ckpt in enumerate(icnn_path): 523 | state_dict = torch.load(ckpt) 524 | model.net_i_list[idx].load_state_dict(state_dict['icnn']) 525 | model.epoch = state_dict['epoch'] 526 | model.iterations = state_dict['iterations'] 527 | 528 | print('Resume from epoch %d, iteration %d' % (model.epoch, model.iterations)) 529 | return state_dict 530 | 531 | def state_dict(self): 532 | state_dict = { 533 | 'icnn': self.net_i_list.state_dict(), 534 | 'opt_g': self.optimizer_G.state_dict(), 535 | 'epoch': self.epoch, 'iterations': self.iterations 536 | } 537 | 538 | if hasattr(self, 'net_p_list'): 539 | state_dict.update({'net_p_list': self.net_p_list.state_dict()}) 540 | if hasattr(self, 'net_p'): 541 | state_dict.update({'net_p': self.net_p.state_dict()}) 542 | if hasattr(self, 'net_A'): 543 | state_dict.update({'net_A': self.net_A.state_dict()}) 544 | 545 | if self.opt.lambda_gan > 0: 546 | state_dict.update({ 547 | 'opt_d': self.optimizer_D.state_dict(), 548 | 'netD': self.netD.state_dict(), 549 | }) 550 | 551 | return state_dict 552 | 553 | 554 | class NetworkWrapper(ERRNetBase): 555 | # You can use this class to wrap other module into our training framework (\eg BDN module) 556 | def __init__(self): 557 | self.epoch = 0 558 | self.iterations = 0 559 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 560 | 561 | def print_network(self): 562 | print('--------------------- NetworkWrapper ---------------------') 563 | networks.print_network(self.net) 564 | 565 | def _eval(self): 566 | self.net.eval() 567 | 568 | def _train(self): 569 | self.net.train() 570 | 571 | def initialize(self, opt, net): 572 | BaseModel.initialize(self, opt) 573 | self.net = net.to(self.device) 574 | self.edge_map = EdgeMap(scale=1).to(self.device) 575 | 576 | if self.isTrain: 577 | # define loss functions 578 | self.vgg = 
losses.Vgg19(requires_grad=False).to(self.device) 579 | self.loss_dic = losses.init_loss(opt, self.Tensor) 580 | vggloss = losses.ContentLoss() 581 | vggloss.initialize(losses.VGGLoss(self.vgg)) 582 | self.loss_dic['t_vgg'] = vggloss 583 | 584 | cxloss = losses.ContentLoss() 585 | if opt.unaligned_loss == 'vgg': 586 | cxloss.initialize(losses.VGGLoss(self.vgg, weights=[0.1], indices=[31])) 587 | elif opt.unaligned_loss == 'ctx': 588 | cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1,0.1,0.1], indices=[8, 13, 22])) 589 | elif opt.unaligned_loss == 'mse': 590 | cxloss.initialize(nn.MSELoss()) 591 | elif opt.unaligned_loss == 'ctx_vgg': 592 | cxloss.initialize(losses.CXLoss(self.vgg, weights=[0.1,0.1,0.1,0.1], indices=[8, 13, 22, 31], criterions=[losses.CX_loss]*3+[nn.L1Loss()])) 593 | 594 | else: 595 | raise NotImplementedError 596 | 597 | self.loss_dic['t_cx'] = cxloss 598 | 599 | # initialize optimizers 600 | self.optimizer_G = torch.optim.Adam(self.net.parameters(), 601 | lr=opt.lr, betas=(opt.beta1, 0.999), weight_decay=opt.wd) 602 | 603 | self._init_optimizer([self.optimizer_G]) 604 | 605 | # define discriminator 606 | # if self.opt.lambda_gan > 0: 607 | self.netD = networks.define_D(opt, 3) 608 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 609 | lr=opt.lr, betas=(opt.beta1, 0.999)) 610 | self._init_optimizer([self.optimizer_D]) 611 | 612 | if opt.no_verbose is False: 613 | self.print_network() 614 | 615 | def backward_D(self): 616 | for p in self.netD.parameters(): 617 | p.requires_grad = True 618 | 619 | self.loss_D, self.pred_fake, self.pred_real = self.loss_dic['gan'].get_loss( 620 | self.netD, self.input, self.output_i, self.target_t) 621 | 622 | (self.loss_D*self.opt.lambda_gan).backward(retain_graph=True) 623 | 624 | def backward_G(self): 625 | for p in self.netD.parameters(): 626 | p.requires_grad = False 627 | 628 | self.loss_G = 0 629 | self.loss_CX = None 630 | self.loss_icnn_pixel = None 631 | self.loss_icnn_vgg = None 632 | self.loss_G_GAN = None 633 | 634 | if self.opt.lambda_gan > 0: 635 | self.loss_G_GAN = self.loss_dic['gan'].get_g_loss( 636 | self.netD, self.input, self.output_i, self.target_t) #self.pred_real.detach()) 637 | self.loss_G += self.loss_G_GAN*self.opt.lambda_gan 638 | 639 | if self.aligned: 640 | self.loss_icnn_pixel = self.loss_dic['t_pixel'].get_loss( 641 | self.output_i, self.target_t) 642 | 643 | self.loss_icnn_vgg = self.loss_dic['t_vgg'].get_loss( 644 | self.output_i, self.target_t) 645 | 646 | # self.loss_G += self.loss_icnn_pixel 647 | self.loss_G += self.loss_icnn_pixel+self.loss_icnn_vgg*self.opt.lambda_vgg 648 | # self.loss_G += self.loss_fm * self.opt.lambda_vgg 649 | else: 650 | self.loss_CX = self.loss_dic['t_cx'].get_loss(self.output_i, self.target_t) 651 | 652 | self.loss_G += self.loss_CX 653 | 654 | self.loss_G.backward() 655 | 656 | def forward(self): 657 | raise NotImplementedError 658 | 659 | def optimize_parameters(self): 660 | self._train() 661 | self.forward() 662 | 663 | if self.opt.lambda_gan > 0: 664 | self.optimizer_D.zero_grad() 665 | self.backward_D() 666 | self.optimizer_D.step() 667 | 668 | self.optimizer_G.zero_grad() 669 | self.backward_G() 670 | self.optimizer_G.step() 671 | 672 | def get_current_errors(self): 673 | ret_errors = OrderedDict() 674 | if self.loss_icnn_pixel is not None: 675 | ret_errors['IPixel'] = self.loss_icnn_pixel.item() 676 | if self.loss_icnn_vgg is not None: 677 | ret_errors['VGG'] = self.loss_icnn_vgg.item() 678 | if self.opt.lambda_gan > 0 and self.loss_G_GAN is not None: 
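# adversarial terms are reported only when GAN training is enabled and a
# generator GAN loss was actually computed in this iteration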
679 | ret_errors['G'] = self.loss_G_GAN.item() 680 | ret_errors['D'] = self.loss_D.item() 681 | if self.loss_CX is not None: 682 | ret_errors['CX'] = self.loss_CX.item() 683 | 684 | return ret_errors 685 | 686 | def get_current_visuals(self): 687 | ret_visuals = OrderedDict() 688 | ret_visuals['input'] = tensor2im(self.input).astype(np.uint8) 689 | ret_visuals['output_i'] = tensor2im(self.output_i).astype(np.uint8) 690 | ret_visuals['target'] = tensor2im(self.target_t).astype(np.uint8) 691 | ret_visuals['residual'] = tensor2im((self.input - self.output_i)).astype(np.uint8) 692 | return ret_visuals 693 | 694 | def state_dict(self): 695 | state_dict = self.net.state_dict() 696 | return state_dict 697 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import functools 6 | import numpy as np 7 | from torch import autograd 8 | import torchvision.models as models 9 | import util.util as util 10 | from models.vgg import Vgg19 11 | from torch.autograd import Function 12 | from models.CX import CX_loss 13 | 14 | ############################################################################### 15 | # Functions 16 | ############################################################################### 17 | def compute_gradient(img): 18 | gradx=img[...,1:,:]-img[...,:-1,:] 19 | grady=img[...,1:]-img[...,:-1] 20 | return gradx,grady 21 | 22 | 23 | class GradientLoss(nn.Module): 24 | def __init__(self): 25 | super(GradientLoss, self).__init__() 26 | self.loss = nn.L1Loss() 27 | 28 | def forward(self, predict, target): 29 | predict_gradx, predict_grady = compute_gradient(predict) 30 | target_gradx, target_grady = compute_gradient(target) 31 | 32 | return self.loss(predict_gradx, target_gradx) + self.loss(predict_grady, target_grady) 33 | 34 | 35 | class MultipleLoss(nn.Module): 36 | def __init__(self, losses, weight=None): 37 | super(MultipleLoss, self).__init__() 38 | self.losses = nn.ModuleList(losses) 39 | self.weight = weight or [1/len(self.losses)] * len(self.losses) 40 | 41 | def forward(self, predict, target): 42 | total_loss = 0 43 | for weight, loss in zip(self.weight, self.losses): 44 | total_loss += loss(predict, target) * weight 45 | return total_loss 46 | 47 | 48 | class MeanShift(nn.Conv2d): 49 | def __init__(self, data_mean, data_std, data_range=1, norm=True): 50 | """norm (bool): normalize/denormalize the stats""" 51 | c = len(data_mean) 52 | super(MeanShift, self).__init__(c, c, kernel_size=1) 53 | std = torch.Tensor(data_std) 54 | self.weight.data = torch.eye(c).view(c, c, 1, 1) 55 | if norm: 56 | self.weight.data.div_(std.view(c, 1, 1, 1)) 57 | self.bias.data = -1 * data_range * torch.Tensor(data_mean) 58 | self.bias.data.div_(std) 59 | else: 60 | self.weight.data.mul_(std.view(c, 1, 1, 1)) 61 | self.bias.data = data_range * torch.Tensor(data_mean) 62 | self.requires_grad = False 63 | 64 | 65 | class VGGLoss(nn.Module): 66 | def __init__(self, vgg=None, weights=None, indices=None, normalize=True): 67 | super(VGGLoss, self).__init__() 68 | if vgg is None: 69 | self.vgg = Vgg19().cuda() 70 | else: 71 | self.vgg = vgg 72 | self.criterion = nn.L1Loss() 73 | self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] 74 | self.indices = indices or [2, 7, 12, 21, 30] 75 | if normalize: 76 | self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 
0.224, 0.225], norm=True).cuda() 77 | else: 78 | self.normalize = None 79 | 80 | def forward(self, x, y): 81 | if self.normalize is not None: 82 | x = self.normalize(x) 83 | y = self.normalize(y) 84 | x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices) 85 | loss = 0 86 | for i in range(len(x_vgg)): 87 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 88 | 89 | return loss 90 | 91 | 92 | class CXLoss(VGGLoss): 93 | # Contextual Loss from 94 | # https://arxiv.org/abs/1803.02077 95 | def __init__(self, vgg=None, weights=None, indices=None, criterions=None): 96 | super(CXLoss, self).__init__(vgg, weights, indices) 97 | self.criterions = criterions or [CX_loss] * (len(weights)) 98 | 99 | def forward(self, x, y): 100 | x = self.normalize(x) 101 | y = self.normalize(y) 102 | x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices) 103 | loss = 0 104 | for i in range(len(x_vgg)): 105 | loss += self.weights[i] * self.criterions[i](x_vgg[i], y_vgg[i].detach()) 106 | 107 | loss = loss[0] if loss.dim() == 1 else loss 108 | return loss 109 | 110 | 111 | class ContentLoss(): 112 | def initialize(self, loss): 113 | self.criterion = loss 114 | 115 | def get_loss(self, fakeIm, realIm): 116 | return self.criterion(fakeIm, realIm) 117 | 118 | 119 | class GANLoss(nn.Module): 120 | def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0, 121 | tensor=torch.FloatTensor): 122 | super(GANLoss, self).__init__() 123 | self.real_label = target_real_label 124 | self.fake_label = target_fake_label 125 | self.real_label_var = None 126 | self.fake_label_var = None 127 | self.Tensor = tensor 128 | if use_l1: 129 | self.loss = nn.L1Loss() 130 | else: 131 | self.loss = nn.BCEWithLogitsLoss() # absorb sigmoid into BCELoss 132 | 133 | def get_target_tensor(self, input, target_is_real): 134 | target_tensor = None 135 | if target_is_real: 136 | create_label = ((self.real_label_var is None) or 137 | (self.real_label_var.numel() != input.numel())) 138 | if create_label: 139 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 140 | self.real_label_var = real_tensor 141 | target_tensor = self.real_label_var 142 | else: 143 | create_label = ((self.fake_label_var is None) or 144 | (self.fake_label_var.numel() != input.numel())) 145 | if create_label: 146 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 147 | self.fake_label_var = fake_tensor 148 | target_tensor = self.fake_label_var 149 | return target_tensor 150 | 151 | def __call__(self, input, target_is_real): 152 | if isinstance(input, list): 153 | loss = 0 154 | for input_i in input: 155 | target_tensor = self.get_target_tensor(input_i, target_is_real) 156 | loss += self.loss(input_i, target_tensor) 157 | return loss 158 | else: 159 | target_tensor = self.get_target_tensor(input, target_is_real) 160 | return self.loss(input, target_tensor) 161 | 162 | 163 | class DiscLoss(): 164 | def name(self): 165 | return 'SGAN' 166 | 167 | def initialize(self, opt, tensor): 168 | self.criterionGAN = GANLoss(use_l1=False, tensor=tensor) 169 | 170 | def get_g_loss(self, net, realA, fakeB, realB): 171 | # First, G(A) should fake the discriminator 172 | pred_fake = net.forward(fakeB) 173 | return self.criterionGAN(pred_fake, 1) 174 | 175 | def get_loss(self, net, realA=None, fakeB=None, realB=None): 176 | pred_fake = None 177 | pred_real = None 178 | loss_D_fake = 0 179 | loss_D_real = 0 180 | # Fake 181 | # stop backprop to the generator by detaching fake_B 182 | # Generated Image Disc Output 
should be close to zero 183 | 184 | if fakeB is not None: 185 | pred_fake = net.forward(fakeB.detach()) 186 | loss_D_fake = self.criterionGAN(pred_fake, 0) 187 | 188 | # Real 189 | if realB is not None: 190 | pred_real = net.forward(realB) 191 | loss_D_real = self.criterionGAN(pred_real, 1) 192 | 193 | # Combined loss 194 | loss_D = (loss_D_fake + loss_D_real) * 0.5 195 | return loss_D, pred_fake, pred_real 196 | 197 | 198 | class DiscLossR(DiscLoss): 199 | # RSGAN from 200 | # https://arxiv.org/abs/1807.00734 201 | def name(self): 202 | return 'RSGAN' 203 | 204 | def initialize(self, opt, tensor): 205 | DiscLoss.initialize(self, opt, tensor) 206 | self.criterionGAN = GANLoss(use_l1=False, tensor=tensor) 207 | 208 | def get_g_loss(self, net, realA, fakeB, realB, pred_real=None): 209 | if pred_real is None: 210 | pred_real = net.forward(realB) 211 | pred_fake = net.forward(fakeB) 212 | return self.criterionGAN(pred_fake - pred_real, 1) 213 | 214 | def get_loss(self, net, realA, fakeB, realB): 215 | pred_real = net.forward(realB) 216 | pred_fake = net.forward(fakeB.detach()) 217 | 218 | loss_D = self.criterionGAN(pred_real - pred_fake, 1) # BCE_stable loss 219 | return loss_D, pred_fake, pred_real 220 | 221 | 222 | class DiscLossRa(DiscLoss): 223 | # RaSGAN from 224 | # https://arxiv.org/abs/1807.00734 225 | def name(self): 226 | return 'RaSGAN' 227 | 228 | def initialize(self, opt, tensor): 229 | DiscLoss.initialize(self, opt, tensor) 230 | self.criterionGAN = GANLoss(use_l1=False, tensor=tensor) 231 | 232 | def get_g_loss(self, net, realA, fakeB, realB, pred_real=None): 233 | if pred_real is None: 234 | pred_real = net.forward(realB) 235 | pred_fake = net.forward(fakeB) 236 | 237 | loss_G = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 0) 238 | loss_G += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 1) 239 | return loss_G * 0.5 240 | 241 | def get_loss(self, net, realA, fakeB, realB): 242 | pred_real = net.forward(realB) 243 | pred_fake = net.forward(fakeB.detach()) 244 | 245 | loss_D = self.criterionGAN(pred_real - torch.mean(pred_fake, dim=0, keepdim=True), 1) 246 | loss_D += self.criterionGAN(pred_fake - torch.mean(pred_real, dim=0, keepdim=True), 0) 247 | return loss_D * 0.5, pred_fake, pred_real 248 | 249 | 250 | def init_loss(opt, tensor): 251 | disc_loss = None 252 | content_loss = None 253 | 254 | loss_dic = {} 255 | 256 | pixel_loss = ContentLoss() 257 | pixel_loss.initialize(MultipleLoss([nn.MSELoss(), GradientLoss()], [0.2,0.4])) 258 | 259 | loss_dic['t_pixel'] = pixel_loss 260 | loss_dic['r_pixel'] = pixel_loss 261 | 262 | if opt.lambda_gan > 0: 263 | if opt.gan_type == 'sgan' or opt.gan_type == 'gan': 264 | disc_loss = DiscLoss() 265 | elif opt.gan_type == 'rsgan': 266 | disc_loss = DiscLossR() 267 | elif opt.gan_type == 'rasgan': 268 | disc_loss = DiscLossRa() 269 | else: 270 | raise ValueError("GAN [%s] not recognized." 
% opt.gan_type) 271 | 272 | disc_loss.initialize(opt, tensor) 273 | loss_dic['gan'] = disc_loss 274 | 275 | return loss_dic 276 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import init 5 | import functools 6 | from torch.optim import lr_scheduler 7 | 8 | import util.util as util 9 | from collections import OrderedDict 10 | from .vgg import Vgg16, Vgg19 11 | ############################################################################### 12 | # Functions 13 | ############################################################################### 14 | 15 | 16 | def weights_init_normal(m): 17 | classname = m.__class__.__name__ 18 | # print(classname) 19 | if isinstance(m, nn.Sequential): 20 | return 21 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 22 | init.normal_(m.weight.data, 0.0, 0.02) 23 | elif isinstance(m, nn.Linear): 24 | init.normal_(m.weight.data, 0.0, 0.02) 25 | elif isinstance(m, nn.BatchNorm2d): 26 | init.normal_(m.weight.data, 1.0, 0.02) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | 30 | def weights_init_xavier(m): 31 | classname = m.__class__.__name__ 32 | # print(classname) 33 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 34 | init.xavier_normal_(m.weight.data, gain=0.02) 35 | elif isinstance(m, nn.Linear): 36 | init.xavier_normal_(m.weight.data, gain=0.02) 37 | elif isinstance(m, nn.BatchNorm2d): 38 | init.normal_(m.weight.data, 1.0, 0.02) 39 | init.constant_(m.bias.data, 0.0) 40 | 41 | 42 | def weights_init_kaiming(m): 43 | classname = m.__class__.__name__ 44 | # print(classname) 45 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 46 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 47 | elif isinstance(m, nn.Linear): 48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 49 | elif isinstance(m, nn.BatchNorm2d): 50 | init.normal_(m.weight.data, 1.0, 0.02) 51 | init.constant_(m.bias.data, 0.0) 52 | 53 | 54 | def weights_init_orthogonal(m): 55 | classname = m.__class__.__name__ 56 | print(classname) 57 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 58 | init.orthogonal(m.weight.data, gain=1) 59 | elif isinstance(m, nn.Linear): 60 | init.orthogonal(m.weight.data, gain=1) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | init.normal(m.weight.data, 1.0, 0.02) 63 | init.constant_(m.bias.data, 0.0) 64 | 65 | 66 | def init_weights(net, init_type='normal'): 67 | print('[i] initialization method [%s]' % init_type) 68 | if init_type == 'normal': 69 | net.apply(weights_init_normal) 70 | elif init_type == 'xavier': 71 | net.apply(weights_init_xavier) 72 | elif init_type == 'kaiming': 73 | net.apply(weights_init_kaiming) 74 | elif init_type == 'orthogonal': 75 | net.apply(weights_init_orthogonal) 76 | elif init_type == 'edsr': 77 | pass 78 | else: 79 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 80 | 81 | 82 | def get_norm_layer(norm_type='instance'): 83 | if norm_type == 'batch': 84 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 85 | elif norm_type == 'instance': 86 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) 87 | elif norm_type == 'none': 88 | norm_layer = None 89 | else: 90 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 91 | return norm_layer 92 | 93 | 94 | def define_D(opt, in_channels=3): 95 | # use_sigmoid = opt.gan_type == 
'gan' 96 | use_sigmoid = False # incorporate sigmoid into BCE_stable loss 97 | 98 | if opt.which_model_D == 'disc_vgg': 99 | netD = Discriminator_VGG(in_channels, use_sigmoid=use_sigmoid) 100 | init_weights(netD, init_type='kaiming') 101 | elif opt.which_model_D == 'disc_patch': 102 | netD = NLayerDiscriminator(in_channels, 64, 3, nn.InstanceNorm2d, use_sigmoid, getIntermFeat=False) 103 | init_weights(netD, init_type='normal') 104 | else: 105 | raise NotImplementedError('%s is not implemented' %opt.which_model_D) 106 | 107 | if len(opt.gpu_ids) > 0: 108 | assert(torch.cuda.is_available()) 109 | netD.cuda(opt.gpu_ids[0]) 110 | 111 | return netD 112 | 113 | 114 | def print_network(net): 115 | num_params = 0 116 | for param in net.parameters(): 117 | num_params += param.numel() 118 | print(net) 119 | print('Total number of parameters: %d' % num_params) 120 | print('The size of receptive field: %d' % receptive_field(net)) 121 | 122 | 123 | def receptive_field(net): 124 | def _f(output_size, ksize, stride, dilation): 125 | return (output_size - 1) * stride + ksize * dilation - dilation + 1 126 | 127 | stats = [] 128 | for m in net.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | stats.append((m.kernel_size, m.stride, m.dilation)) 131 | 132 | rsize = 1 133 | for (ksize, stride, dilation) in reversed(stats): 134 | if type(ksize) == tuple: ksize = ksize[0] 135 | if type(stride) == tuple: stride = stride[0] 136 | if type(dilation) == tuple: dilation = dilation[0] 137 | rsize = _f(rsize, ksize, stride, dilation) 138 | return rsize 139 | 140 | 141 | def debug_network(net): 142 | def _hook(m, i, o): 143 | print(o.size()) 144 | for m in net.modules(): 145 | m.register_forward_hook(_hook) 146 | 147 | 148 | ############################################################################## 149 | # Classes 150 | ############################################################################## 151 | 152 | # Defines the PatchGAN discriminator with the specified arguments. 
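# A minimal usage sketch (hypothetical shapes, for illustration only):
#   netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3,
#                              norm_layer=nn.InstanceNorm2d)
#   scores = netD(torch.randn(1, 3, 256, 256))  # patch-wise real/fake logits
# Each score corresponds to one local receptive-field patch of the input.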
153 | class NLayerDiscriminator(nn.Module): 154 | def __init__(self, input_nc, ndf=64, n_layers=3, 155 | norm_layer=nn.BatchNorm2d, use_sigmoid=False, 156 | branch=1, bias=True, getIntermFeat=False): 157 | super(NLayerDiscriminator, self).__init__() 158 | self.getIntermFeat = getIntermFeat 159 | self.n_layers = n_layers 160 | kw = 4 161 | padw = int(np.ceil((kw-1.0)/2)) 162 | sequence = [[nn.Conv2d(input_nc*branch, ndf*branch, kernel_size=kw, stride=2, padding=padw, groups=branch, bias=True), nn.LeakyReLU(0.2, True)]] 163 | 164 | nf = ndf 165 | for n in range(1, n_layers): 166 | nf_prev = nf 167 | nf = min(nf * 2, 512) 168 | sequence += [[ 169 | nn.Conv2d(nf_prev*branch, nf*branch, groups=branch, kernel_size=kw, stride=2, padding=padw, bias=bias), 170 | norm_layer(nf*branch), nn.LeakyReLU(0.2, True) 171 | ]] 172 | 173 | nf_prev = nf 174 | nf = min(nf * 2, 512) 175 | sequence += [[ 176 | nn.Conv2d(nf_prev*branch, nf*branch, groups=branch, kernel_size=kw, stride=1, padding=padw, bias=bias), 177 | norm_layer(nf*branch), 178 | nn.LeakyReLU(0.2, True) 179 | ]] 180 | 181 | sequence += [[nn.Conv2d(nf*branch, 1*branch, groups=branch, kernel_size=kw, stride=1, padding=padw, bias=True)]] 182 | 183 | if use_sigmoid: 184 | sequence += [[nn.Sigmoid()]] 185 | 186 | if getIntermFeat: 187 | for n in range(len(sequence)): 188 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 189 | else: 190 | sequence_stream = [] 191 | for n in range(len(sequence)): 192 | sequence_stream += sequence[n] 193 | self.model = nn.Sequential(*sequence_stream) 194 | 195 | def forward(self, input): 196 | if self.getIntermFeat: 197 | res = [input] 198 | for n in range(self.n_layers+2): 199 | model = getattr(self, 'model'+str(n)) 200 | res.append(model(res[-1])) 201 | return res[1:] 202 | else: 203 | return self.model(input) 204 | 205 | 206 | class Discriminator_VGG(nn.Module): 207 | def __init__(self, in_channels=3, use_sigmoid=True): 208 | super(Discriminator_VGG, self).__init__() 209 | def conv(*args, **kwargs): 210 | return nn.Conv2d(*args, **kwargs) 211 | 212 | num_groups = 32 213 | 214 | body = [ 215 | conv(in_channels, 64, kernel_size=3, padding=1), # 224 216 | nn.LeakyReLU(0.2), 217 | 218 | conv(64, 64, kernel_size=3, stride=2, padding=1), # 112 219 | nn.GroupNorm(num_groups, 64), 220 | nn.LeakyReLU(0.2), 221 | 222 | conv(64, 128, kernel_size=3, padding=1), 223 | nn.GroupNorm(num_groups, 128), 224 | nn.LeakyReLU(0.2), 225 | 226 | conv(128, 128, kernel_size=3, stride=2, padding=1), # 56 227 | nn.GroupNorm(num_groups, 128), 228 | nn.LeakyReLU(0.2), 229 | 230 | conv(128, 256, kernel_size=3, padding=1), 231 | nn.GroupNorm(num_groups, 256), 232 | nn.LeakyReLU(0.2), 233 | 234 | conv(256, 256, kernel_size=3, stride=2, padding=1), # 28 235 | nn.GroupNorm(num_groups, 256), 236 | nn.LeakyReLU(0.2), 237 | 238 | conv(256, 512, kernel_size=3, padding=1), 239 | nn.GroupNorm(num_groups, 512), 240 | nn.LeakyReLU(0.2), 241 | 242 | conv(512, 512, kernel_size=3, stride=2, padding=1), # 14 243 | nn.GroupNorm(num_groups, 512), 244 | nn.LeakyReLU(0.2), 245 | 246 | conv(512, 512, kernel_size=3, stride=1, padding=1), 247 | nn.GroupNorm(num_groups, 512), 248 | nn.LeakyReLU(0.2), 249 | 250 | conv(512, 512, kernel_size=3, stride=2, padding=1), # 7 251 | nn.GroupNorm(num_groups, 512), 252 | nn.LeakyReLU(0.2), 253 | ] 254 | 255 | tail = [ 256 | nn.AdaptiveAvgPool2d(1), 257 | nn.Conv2d(512, 1024, kernel_size=1), 258 | nn.LeakyReLU(0.2), 259 | nn.Conv2d(1024, 1, kernel_size=1) 260 | ] 261 | 262 | if use_sigmoid: 263 | tail.append(nn.Sigmoid()) 
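# with use_sigmoid=True the head outputs probabilities; the relativistic GAN
# losses in losses.py expect raw logits, so define_D builds this
# discriminator with use_sigmoid=False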
264 | 265 | self.body = nn.Sequential(*body) 266 | self.tail = nn.Sequential(*tail) 267 | 268 | def forward(self, x): 269 | x = self.body(x) 270 | out = self.tail(x) 271 | return out 272 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torchvision import models 5 | 6 | 7 | class Vgg16(torch.nn.Module): 8 | def __init__(self, requires_grad=False): 9 | super(Vgg16, self).__init__() 10 | vgg_pretrained_features = models.vgg16(pretrained=True).features 11 | self.slice1 = torch.nn.Sequential() 12 | self.slice2 = torch.nn.Sequential() 13 | self.slice3 = torch.nn.Sequential() 14 | self.slice4 = torch.nn.Sequential() 15 | for x in range(4): 16 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 17 | for x in range(4, 9): 18 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 19 | for x in range(9, 16): 20 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(16, 23): 22 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 23 | if not requires_grad: 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def forward(self, X): 28 | h = self.slice1(X) 29 | h_relu1_2 = h 30 | h = self.slice2(h) 31 | h_relu2_2 = h 32 | h = self.slice3(h) 33 | h_relu3_3 = h 34 | h = self.slice4(h) 35 | h_relu4_3 = h 36 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 37 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 38 | return out 39 | 40 | 41 | class Vgg19(torch.nn.Module): 42 | def __init__(self, requires_grad=False): 43 | super(Vgg19, self).__init__() 44 | # vgg_pretrained_features = models.vgg19(pretrained=True).features 45 | self.vgg_pretrained_features = models.vgg19(pretrained=True).features 46 | # self.slice1 = torch.nn.Sequential() 47 | # self.slice2 = torch.nn.Sequential() 48 | # self.slice3 = torch.nn.Sequential() 49 | # self.slice4 = torch.nn.Sequential() 50 | # self.slice5 = torch.nn.Sequential() 51 | # for x in range(2): 52 | # self.slice1.add_module(str(x), vgg_pretrained_features[x]) 53 | # for x in range(2, 7): 54 | # self.slice2.add_module(str(x), vgg_pretrained_features[x]) 55 | # for x in range(7, 12): 56 | # self.slice3.add_module(str(x), vgg_pretrained_features[x]) 57 | # for x in range(12, 21): 58 | # self.slice4.add_module(str(x), vgg_pretrained_features[x]) 59 | # for x in range(21, 30): 60 | # self.slice5.add_module(str(x), vgg_pretrained_features[x]) 61 | if not requires_grad: 62 | for param in self.parameters(): 63 | param.requires_grad = False 64 | 65 | def forward(self, X, indices=None): 66 | if indices is None: 67 | indices = [2, 7, 12, 21, 30] 68 | out = [] 69 | #indices = sorted(indices) 70 | for i in range(indices[-1]): 71 | X = self.vgg_pretrained_features[i](X) 72 | if (i+1) in indices: 73 | out.append(X) 74 | 75 | return out 76 | 77 | # h_relu1 = self.slice1(X) 78 | # h_relu2 = self.slice2(h_relu1) 79 | # h_relu3 = self.slice3(h_relu2) 80 | # h_relu4 = self.slice4(h_relu3) 81 | # h_relu5 = self.slice5(h_relu4) 82 | # out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 83 | # return out 84 | 85 | 86 | if __name__ == '__main__': 87 | vgg = Vgg19() 88 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /options/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/options/__init__.py -------------------------------------------------------------------------------- /options/base_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import models 3 | 4 | model_names = sorted(name for name in models.__dict__ 5 | if name.islower() and not name.startswith("__") 6 | and callable(models.__dict__[name])) 7 | 8 | def str2bool(x): 9 | if x.lower() in ('y', 't', '1', 'yes', 'true'): 10 | return True 11 | elif x.lower() in ('n', 'f', '0', 'no', 'false'): 12 | return False 13 | else: 14 | raise ValueError('The given parameter [%s] is invalid'%x) 15 | 16 | 17 | class BaseOptions(): 18 | def __init__(self): 19 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | self.initialized = False 21 | 22 | def initialize(self): 23 | # experiment specifics 24 | self.parser.add_argument('--name', type=str, default=None, help='name of the experiment. It decides where to store samples and models') 25 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0; 0,1,2; 0,2. Use -1 for CPU') 26 | self.parser.add_argument('--model', type=str, default='errnet_model', help='chooses which model to use.', choices=model_names) 27 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 28 | self.parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 29 | self.parser.add_argument('--resume_epoch', '-re', type=int, default=None, help='checkpoint to use. (default: latest)') 30 | self.parser.add_argument('--seed', type=int, default=2018, help='random seed to use. Default=2018') 31 | 32 | # for setting input 33 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 34 | self.parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 35 | self.parser.add_argument('--max_dataset_size', type=int, default=None, help='Maximum number of samples allowed per dataset.
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 36 | 37 | # for display 38 | self.parser.add_argument('--no-log', action='store_true', help='disable tf logger?') 39 | self.parser.add_argument('--no-verbose', action='store_true', help='disable verbose info?') 40 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 41 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 42 | self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (use 0 to disable visdom)') 43 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 44 | 45 | # for AdaNEC 46 | self.parser.add_argument('--nModel', type=int, default=1) 47 | self.parser.add_argument('--avg', type=float, default=[0.108686739, 0.009344623, 0.881968638], nargs='+', help='real20') 48 | # self.parser.add_argument('--avg', type=float, default=[0.217511492, 0.111057787, 0.671430722], nargs='+', help='wild') 49 | # self.parser.add_argument('--avg', type=float, default=[0.101291473, 0.580948268, 0.317760259], nargs='+', help='postcard') 50 | # self.parser.add_argument('--avg', type=float, default=[0.186080102, 0.355502779, 0.458417119], nargs='+', help='solid') 51 | # self.parser.add_argument('--avg', type=float, default=[0.190290713, 0.423943493, 0.385765795], nargs='+', help='nature20') 52 | self.parser.add_argument('--show_expertise_level', type=str2bool, default=False) 53 | 54 | self.initialized = True 55 | -------------------------------------------------------------------------------- /options/errnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/options/errnet/__init__.py -------------------------------------------------------------------------------- /options/errnet/base_options.py: -------------------------------------------------------------------------------- 1 | from options.base_option import BaseOptions as Base 2 | from util import util 3 | import os 4 | import torch 5 | import numpy as np 6 | import random 7 | 8 | class BaseOptions(Base): 9 | def initialize(self): 10 | Base.initialize(self) 11 | # experiment specifics 12 | self.parser.add_argument('--inet', type=str, default='errnet', help='chooses which architecture to use for inet.') 13 | # self.parser.add_argument('--icnn_path', type=str, default=None, help='icnn checkpoint to use.') 14 | self.parser.add_argument('--icnn_path', type=str, default=[None], nargs='+', help='icnn checkpoint to use.') 15 | self.parser.add_argument('--init_type', type=str, default='edsr', help='network initialization [normal|xavier|kaiming|orthogonal|uniform]') 16 | # for network 17 | self.parser.add_argument('--hyper', action='store_true', help='if true, augment input with vgg hypercolumn feature') 18 | 19 | self.initialized = True 20 | 21 | def parse(self): 22 | if not self.initialized: 23 | self.initialize() 24 | self.opt = self.parser.parse_args() 25 | self.opt.isTrain = self.isTrain # train or test 26 | 27 | torch.backends.cudnn.deterministic = True 28 | torch.manual_seed(self.opt.seed) 29 | np.random.seed(self.opt.seed) # seed for every module 30 | random.seed(self.opt.seed) 31 | 32 | str_ids = self.opt.gpu_ids.split(',') 33 | self.opt.gpu_ids = [] 34 | for 
34 | for str_id in str_ids: 35 | id = int(str_id) 36 | if id >= 0: 37 | self.opt.gpu_ids.append(id) 38 | 39 | # set gpu ids 40 | if len(self.opt.gpu_ids) > 0: 41 | torch.cuda.set_device(self.opt.gpu_ids[0]) 42 | 43 | args = vars(self.opt) 44 | 45 | print('------------ Options -------------') 46 | for k, v in sorted(args.items()): 47 | print('%s: %s' % (str(k), str(v))) 48 | print('-------------- End ----------------') 49 | 50 | # save to the disk 51 | self.opt.name = self.opt.name or '_'.join([self.opt.model]) 52 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 53 | util.mkdirs(expr_dir) 54 | file_name = os.path.join(expr_dir, 'opt.txt') 55 | with open(file_name, 'wt') as opt_file: 56 | opt_file.write('------------ Options -------------\n') 57 | for k, v in sorted(args.items()): 58 | opt_file.write('%s: %s\n' % (str(k), str(v))) 59 | opt_file.write('-------------- End ----------------\n') 60 | 61 | if self.opt.debug: 62 | self.opt.display_freq = 20 63 | self.opt.print_freq = 20 64 | self.opt.nEpochs = 40 65 | self.opt.max_dataset_size = 100 66 | self.opt.no_log = False 67 | self.opt.nThreads = 0 68 | self.opt.decay_iter = 0 69 | self.opt.serial_batches = True 70 | self.opt.no_flip = True 71 | 72 | return self.opt 73 | -------------------------------------------------------------------------------- /options/errnet/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | # for displays 8 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 9 | self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 12 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 13 | self.parser.add_argument('--debug', action='store_true', help='only run one epoch and display at each iteration') 14 | 15 | # for training (Note: in train_errnet.py, we manually tune the training protocol, but you can also use the following settings by modifying the code in errnet_model.py) 16 | self.parser.add_argument('--nEpochs', '-n', type=int, default=60, help='# of epochs to run') 17 | self.parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') 18 | self.parser.add_argument('--wd', type=float, default=0, help='weight decay for adam') 19 | 20 | self.parser.add_argument('--low_sigma', type=float, default=2, help='min sigma in synthetic dataset') 21 | self.parser.add_argument('--high_sigma', type=float, default=5, help='max sigma in synthetic dataset') 22 | self.parser.add_argument('--low_gamma', type=float, default=1.3, help='min gamma in synthetic dataset') 23 | self.parser.add_argument('--high_gamma', type=float, default=1.3, help='max gamma in synthetic dataset') 24 | 25 | # data augmentation 26 | self.parser.add_argument('--batchSize', '-b', type=int, default=1, help='input batch size') 27 | self.parser.add_argument('--loadSize', type=str, default='224,336,448', help='scale images to multiple sizes')
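# Note: with the default resize_or_crop='resize_and_crop' (below), each training image is presumably first rescaled to one of the comma-separated loadSize values and then randomly cropped to fineSize; util.parse_args in util/util.py converts such strings to int lists, e.g. parse_args('224,336,448') returns [224, 336, 448]. 28 | 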
self.parser.add_argument('--fineSize', type=str, default='224,224', help='then crop to this size') 29 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 30 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 31 | 32 | # for discriminator 33 | self.parser.add_argument('--which_model_D', type=str, default='disc_vgg', choices=['disc_vgg', 'disc_patch']) 34 | self.parser.add_argument('--gan_type', type=str, default='rasgan', help='gan/sgan : Vanilla GAN; rasgan : relativistic gan') 35 | 36 | # loss weight 37 | self.parser.add_argument('--unaligned_loss', type=str, default='vgg', help='type of unaligned loss: vgg|mse|ctx|ctx_vgg') 38 | self.parser.add_argument('--vgg_layer', type=int, default=31, help='vgg layer of unaligned loss') 39 | 40 | self.parser.add_argument('--lambda_gan', type=float, default=0.01, help='weight for gan loss') 41 | self.parser.add_argument('--lambda_vgg', type=float, default=0.1, help='weight for vgg loss') 42 | self.parser.add_argument('--lambda_ce', type=float, default=0.1) 43 | 44 | self.isTrain = True 45 | -------------------------------------------------------------------------------- /real_test.txt: -------------------------------------------------------------------------------- 1 | 3.jpg 2 | 4.jpg 3 | 9.jpg 4 | 12.jpg 5 | 15.jpg 6 | 22.jpg 7 | 23.jpg 8 | 25.jpg 9 | 29.jpg 10 | 39.jpg 11 | 46.jpg 12 | 47.jpg 13 | 58.jpg 14 | 86.jpg 15 | 87.jpg 16 | 89.jpg 17 | 93.jpg 18 | 103.jpg 19 | 107.jpg 20 | 110.jpg -------------------------------------------------------------------------------- /test_AdaNEC_NI.sh: -------------------------------------------------------------------------------- 1 | python test_errnet.py --name ERRNet_AdaNEC_NI -r --icnn_path ./checkpoints/ERRNet_AdaNEC_NI/final_release.pt --hyper --nModel 1 -------------------------------------------------------------------------------- /test_AdaNEC_OF.sh: -------------------------------------------------------------------------------- 1 | python test_errnet.py --name ERRNet_AdaNEC_OF -r --icnn_path ./checkpoints/ERRNet_AdaNEC_OF/final_release.pt --hyper --nModel 3 -------------------------------------------------------------------------------- /test_errnet.py: -------------------------------------------------------------------------------- 1 | from os.path import join, basename 2 | from options.errnet.train_options import TrainOptions 3 | from engine import Engine 4 | from data.image_folder import read_fns 5 | from data.transforms import __scale_width 6 | import torch.backends.cudnn as cudnn 7 | import data.reflect_dataset as datasets 8 | import util.util as util 9 | 10 | 11 | opt = TrainOptions().parse() 12 | 13 | opt.isTrain = False 14 | cudnn.benchmark = False # set to True on the SIR^2 subsets (wild, postcard, solid) for a speedup 15 | opt.no_log = True 16 | opt.display_id = 0 17 | opt.verbose = False 18 | 19 | datadir = './datasets/eval' 20 | 21 | # Define evaluation/test dataset 22 | 23 | eval_dataset_real = datasets.CEILTestDataset(join(datadir, 'real20'), fns=read_fns(join(datadir, 'real20', 'data_list.txt'))) 24 | # eval_dataset_wild = datasets.CEILTestDataset(join(datadir, 'wild'), fns=read_fns(join(datadir, 'wild', 'data_list.txt'))) 25 | # eval_dataset_postcard = datasets.CEILTestDataset(join(datadir, 'postcard'), fns=read_fns(join(datadir, 'postcard', 'data_list.txt')))
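# Each SIR^2 subset, once obtained, should follow the same layout the code assumes for real20: a folder under ./datasets/eval holding the image pairs plus a data_list.txt naming them. 26 | 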
# eval_dataset_solid = datasets.CEILTestDataset(join(datadir, 'solid'), fns=read_fns(join(datadir, 'solid', 'data_list.txt'))) 27 | 28 | 29 | eval_dataloader_real = datasets.DataLoader( 30 | eval_dataset_real, batch_size=1, shuffle=False, 31 | num_workers=opt.nThreads, pin_memory=True) 32 | 33 | # eval_dataloader_wild = datasets.DataLoader( 34 | # eval_dataset_wild, batch_size=1, shuffle=False, 35 | # num_workers=opt.nThreads, pin_memory=True) 36 | 37 | # eval_dataloader_solid = datasets.DataLoader( 38 | # eval_dataset_solid, batch_size=1, shuffle=False, 39 | # num_workers=opt.nThreads, pin_memory=True) 40 | 41 | # eval_dataloader_postcard = datasets.DataLoader( 42 | # eval_dataset_postcard, batch_size=1, shuffle=False, 43 | # num_workers=opt.nThreads, pin_memory=True) 44 | 45 | 46 | engine = Engine(opt) 47 | 48 | """Main Loop""" 49 | result_dir = './results' 50 | 51 | all_res = {} 52 | res = engine.eval(eval_dataloader_real, dataset_name='testdata_real', savedir=join(result_dir, 'real20')) 53 | all_res['real20'] = res 54 | print('real20', res) 55 | # res = engine.eval(eval_dataloader_wild, dataset_name='testdata_wild', savedir=join(result_dir, 'wild')) 56 | # all_res['wild'] = res 57 | # print('wild', res) 58 | # res = engine.eval(eval_dataloader_postcard, dataset_name='testdata_postcard', savedir=join(result_dir, 'postcard')) 59 | # all_res['postcard'] = res 60 | # print('postcard', res) 61 | # res = engine.eval(eval_dataloader_solid, dataset_name='testdata_solid', savedir=join(result_dir, 'solid')) 62 | # all_res['solid'] = res 63 | # print('solid', res) 64 | 65 | 66 | num = { 67 | 'real20': 20, 68 | 'wild': 50, 69 | 'postcard': 199, 70 | 'solid': 200, 71 | } 72 | 73 | 74 | avg_res = {} 75 | cnt = 0 76 | for d, res in all_res.items(): 77 | for k in res.keys(): 78 | avg_res[k] = avg_res.get(k, 0) + res[k] * num[d] 79 | cnt += num[d] 80 | for k, v in avg_res.items(): 81 | avg_res[k] = v / cnt 82 | print('avg:', avg_res) 83 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train_errnet.py \ 2 | --name errnet_AdaNEC_OF \ 3 | --hyper \ 4 | -r \ 5 | --unaligned_loss vgg \ 6 | --icnn_path ./checkpoints/errnet_ceilnet/errnet_latest.pt \ 7 | ./checkpoints/errnet_unaligned/errnet_latest.pt \ 8 | ./checkpoints/errnet_real90/errnet_latest.pt \ 9 | --nModel 3 -------------------------------------------------------------------------------- /train_errnet.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from options.errnet.train_options import TrainOptions 3 | from engine import Engine 4 | from data.image_folder import read_fns 5 | import torch.backends.cudnn as cudnn 6 | import data.reflect_dataset as datasets 7 | import util.util as util 8 | import data 9 | 10 | opt = TrainOptions().parse() 11 | 12 | cudnn.benchmark = True 13 | 14 | # modify the following paths to point to your own training data 15 | datadir = './datasets/train' 16 | 17 | datadir_syn = join(datadir, 'VOC2012_PNGImages') 18 | datadir_real = join(datadir, 'real90') 19 | datadir_unaligned = join(datadir, 'unaligned_train250') 20 | 21 | train_dataset = datasets.CEILDataset(datadir_syn, read_fns('VOC2012_224_train_png.txt'), size=opt.max_dataset_size) 22 | train_dataset_real = datasets.CEILTestDataset(datadir_real, enable_transforms=True) 23 | train_dataset_unaligned = datasets.CEILTestDataset(datadir_unaligned, enable_transforms=True, flag={'unaligned':True}, size=None) 24 |
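# The fused training set mixes the three source domains behind the AdaNEC experts: synthetic CEIL-style pairs, unaligned real pairs, and aligned real pairs, matching the three checkpoints passed to --icnn_path in train.sh. 25 | 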
train_dataset_fusion = datasets.FusionDataset([train_dataset, train_dataset_unaligned, train_dataset_real]) 26 | 27 | train_dataloader_fusion = datasets.DataLoader( 28 | train_dataset_fusion, batch_size=opt.batchSize, shuffle=not opt.serial_batches, 29 | num_workers=opt.nThreads, pin_memory=True) 30 | 31 | engine = Engine(opt) 32 | """Main Loop""" 33 | def set_learning_rate(lr): 34 | for optimizer in engine.model.optimizers: 35 | util.set_opt_param(optimizer, 'lr', lr) 36 | 37 | engine.epoch = 60 38 | 39 | set_learning_rate(1e-4) 40 | while engine.epoch < 80: 41 | if engine.epoch == 65: 42 | set_learning_rate(5e-5) 43 | if engine.epoch == 70: 44 | set_learning_rate(1e-5) 45 | 46 | engine.train(train_dataloader_fusion) -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csmliu/AdaNEC/c0eb518b070b720216e72dad94a31dbeab21385b/util/__init__.py -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if refresh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(refresh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, text): 26 | with self.doc: 27 | h3(text) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, height=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="height:%dpx" % height, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /util/index.py: -------------------------------------------------------------------------------- 1 | # Metrics/Indexes 2 | try: 3 | from skimage.measure import compare_ssim, compare_psnr 4 | except ImportError: # newer scikit-image moved/renamed these metrics 5 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr, structural_similarity as compare_ssim 6 | from functools import partial 7 | import numpy as np 8 | 9 | 10 | class Bandwise(object): 11 | def __init__(self, index_fn): 12 | self.index_fn = index_fn 13 |
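# Bandwise applies a scalar metric channel-by-channel to channel-last (HxWxC) arrays; e.g. np.mean(cal_bwpsnr(gt, pred)) gives the channel-averaged PSNR of two uint8-range images, as done in quality_assess below. 14 | 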
def __call__(self, X, Y): 15 | C = X.shape[-1] 16 | bwindex = [] 17 | for ch in range(C): 18 | x = X[..., ch] 19 | y = Y[..., ch] 20 | index = self.index_fn(x, y) 21 | bwindex.append(index) 22 | return bwindex 23 | 24 | cal_bwpsnr = Bandwise(partial(compare_psnr, data_range=255)) 25 | cal_bwssim = Bandwise(partial(compare_ssim, data_range=255)) 26 | 27 | 28 | def compare_ncc(x, y): 29 | return np.mean((x-np.mean(x)) * (y-np.mean(y))) / (np.std(x) * np.std(y)) 30 | 31 | 32 | def ssq_error(correct, estimate): 33 | """Compute the sum-squared-error for an image, where the estimate is 34 | multiplied by a scalar which minimizes the error. Sums over all 35 | pixels. If the inputs are color, each color channel can be 36 | rescaled independently.""" 37 | assert correct.ndim == 2 38 | if np.sum(estimate**2) > 1e-5: 39 | alpha = np.sum(correct * estimate) / np.sum(estimate**2) 40 | else: 41 | alpha = 0. 42 | return np.sum((correct - alpha*estimate) ** 2) 43 | 44 | 45 | def local_error(correct, estimate, window_size, window_shift): 46 | """Returns the sum of the local sum-squared-errors, where the estimate may 47 | be rescaled within each local region to minimize the error. The windows are 48 | window_size x window_size, and they are spaced by window_shift.""" 49 | M, N, C = correct.shape 50 | ssq = total = 0. 51 | for c in range(C): 52 | for i in range(0, M - window_size + 1, window_shift): 53 | for j in range(0, N - window_size + 1, window_shift): 54 | correct_curr = correct[i:i+window_size, j:j+window_size, c] 55 | estimate_curr = estimate[i:i+window_size, j:j+window_size, c] 56 | ssq += ssq_error(correct_curr, estimate_curr) 57 | total += np.sum(correct_curr**2) 58 | # assert np.isnan(ssq/total) 59 | return ssq / total 60 | 61 | def quality_assess(X, Y): 62 | # Y: correct; X: estimate 63 | psnr = np.mean(cal_bwpsnr(Y, X)) 64 | ssim = np.mean(cal_bwssim(Y, X)) 65 | lmse = local_error(Y, X, 20, 10) 66 | ncc = compare_ncc(Y, X) 67 | return {'PSNR': psnr, 'SSIM': ssim, 'LMSE': lmse, 'NCC': ncc} 68 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import sys 4 | import time 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | import yaml 10 | from PIL import Image 11 | from torch.optim import lr_scheduler 12 | 13 | 14 | def get_config(config): 15 | with open(config, 'r') as stream: 16 | return yaml.safe_load(stream) 17 | 18 | 19 | # Converts a Tensor into a Numpy array 20 | # |imtype|: the desired type of the converted numpy array 21 | def tensor2im(image_tensor, imtype=np.uint8): 22 | image_numpy = image_tensor[0].cpu().float().numpy() 23 | if image_numpy.shape[0] == 1: 24 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 25 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 26 | image_numpy = image_numpy.astype(imtype) 27 | if image_numpy.shape[-1] == 6: 28 | image_numpy = np.concatenate([image_numpy[:,:,:3], image_numpy[:,:,3:]], axis=1) 29 | if image_numpy.shape[-1] == 7: 30 | edge_map = np.tile(image_numpy[:,:,6:7], (1, 1, 3)) 31 | image_numpy = np.concatenate([image_numpy[:,:,:3], image_numpy[:,:,3:6], edge_map], axis=1) 32 | return image_numpy 33 | 34 | 35 | def tensor2numpy(image_tensor): 36 | image_numpy = torch.squeeze(image_tensor).cpu().float().numpy() 37 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 38 | image_numpy = image_numpy.astype(np.float32) 39 | return image_numpy 40 | 41 |
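# Note: get_model_list below assumes checkpoints are named like '<key>_latest.pt' or '<key>_<epoch>_<step>.pt' (a hypothetical example: 'errnet_060_00200.pt'), since the epoch id is parsed from the second-to-last '_'-separated field. 42 | 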
# Get model list for resume 43 | def get_model_list(dirname, key, epoch=None): 44 | if epoch is None: 45 | return os.path.join(dirname, key+'_latest.pt') 46 | if not os.path.exists(dirname): 47 | return None 48 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 49 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f and 'latest' not in f] 50 | if len(gen_models) == 0: 51 | return None 52 | 53 | epoch_index = [int(os.path.basename(model_name).split('_')[-2]) for model_name in gen_models if 'latest' not in model_name] 54 | print('[i] available epoch list: %s' % epoch_index, gen_models) 55 | i = epoch_index.index(int(epoch)) 56 | 57 | return gen_models[i] 58 | 59 | 60 | def vgg_preprocess(batch): 61 | # normalize using imagenet mean and std 62 | mean = batch.new(batch.size()) 63 | std = batch.new(batch.size()) 64 | mean[:, 0, :, :] = 0.485 65 | mean[:, 1, :, :] = 0.456 66 | mean[:, 2, :, :] = 0.406 67 | std[:, 0, :, :] = 0.229 68 | std[:, 1, :, :] = 0.224 69 | std[:, 2, :, :] = 0.225 70 | batch = (batch + 1) / 2 71 | batch -= mean 72 | batch = batch / std 73 | return batch 74 | 75 | 76 | def diagnose_network(net, name='network'): 77 | mean = 0.0 78 | count = 0 79 | for param in net.parameters(): 80 | if param.grad is not None: 81 | mean += torch.mean(torch.abs(param.grad.data)) 82 | count += 1 83 | if count > 0: 84 | mean = mean / count 85 | print(name) 86 | print(mean) 87 | 88 | 89 | def save_image(image_numpy, image_path): 90 | image_pil = Image.fromarray(image_numpy) 91 | image_pil.save(image_path) 92 | 93 | 94 | def print_numpy(x, val=True, shp=False): 95 | x = x.astype(np.float64) 96 | if shp: 97 | print('shape,', x.shape) 98 | if val: 99 | x = x.flatten() 100 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 101 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 102 | 103 | 104 | def mkdirs(paths): 105 | if isinstance(paths, list) and not isinstance(paths, str): 106 | for path in paths: 107 | mkdir(path) 108 | else: 109 | mkdir(paths) 110 | 111 | 112 | def mkdir(path): 113 | if not os.path.exists(path): 114 | os.makedirs(path) 115 | 116 | 117 | def set_opt_param(optimizer, key, value): 118 | for group in optimizer.param_groups: 119 | group[key] = value 120 | 121 | 122 | def vis(x): 123 | if isinstance(x, torch.Tensor): 124 | Image.fromarray(tensor2im(x)).show() 125 | elif isinstance(x, np.ndarray): 126 | Image.fromarray(x.astype(np.uint8)).show() 127 | else: 128 | raise NotImplementedError('vis for type [%s] is not implemented' % type(x)) 129 | 130 | """tensorboard""" 131 | from tensorboardX import SummaryWriter 132 | from datetime import datetime 133 | 134 | def get_summary_writer(log_dir): 135 | if not os.path.exists(log_dir): 136 | os.mkdir(log_dir) 137 | log_dir = os.path.join(log_dir, datetime.now().strftime('%b%d_%H-%M-%S')+'_'+socket.gethostname()) 138 | if not os.path.exists(log_dir): 139 | os.mkdir(log_dir) 140 | writer = SummaryWriter(log_dir) 141 | return writer 142 | 143 | 144 | class AverageMeters(object): 145 | def __init__(self, dic=None, total_num=None): 146 | self.dic = dic or {} 147 | # self.total_num = total_num 148 | self.total_num = total_num or {} 149 | 150 | def update(self, new_dic): 151 | for key in new_dic: 152 | if key not in self.dic: 153 | self.dic[key] = new_dic[key] 154 | self.total_num[key] = 1 155 | else: 156 | self.dic[key] += new_dic[key] 157 | self.total_num[key] += 1 158 | # self.total_num += 1 159 | 160 | def __getitem__(self, key):
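# Indexing an AverageMeters instance returns the running mean for a key, i.e. the accumulated sum divided by the number of updates recorded for it. 161 | return self.dic[key] / 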
self.total_num[key] 162 | 163 | def __str__(self): 164 | keys = sorted(self.keys()) 165 | res = '' 166 | for key in keys: 167 | res += (key + ': %.4f' % self[key] + ' | ') 168 | return res 169 | 170 | def keys(self): 171 | return self.dic.keys() 172 | 173 | 174 | def write_loss(writer, prefix, avg_meters, iteration): 175 | for key in avg_meters.keys(): 176 | meter = avg_meters[key] 177 | writer.add_scalar( 178 | os.path.join(prefix, key), meter, iteration) 179 | 180 | 181 | """progress bar""" 182 | import socket 183 | 184 | _, term_width = os.popen('stty size', 'r').read().split() 185 | term_width = int(term_width) 186 | 187 | TOTAL_BAR_LENGTH = 65. 188 | last_time = time.time() 189 | begin_time = last_time 190 | def progress_bar(current, total, msg=None): 191 | global last_time, begin_time 192 | if current == 0: 193 | begin_time = time.time() # Reset for new bar. 194 | 195 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 196 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 197 | 198 | sys.stdout.write(' [') 199 | for i in range(cur_len): 200 | sys.stdout.write('=') 201 | sys.stdout.write('>') 202 | for i in range(rest_len): 203 | sys.stdout.write('.') 204 | sys.stdout.write(']') 205 | 206 | cur_time = time.time() 207 | step_time = cur_time - last_time 208 | last_time = cur_time 209 | tot_time = cur_time - begin_time 210 | 211 | L = [] 212 | L.append(' Step: %s' % format_time(step_time)) 213 | L.append(' | Tot: %s' % format_time(tot_time)) 214 | if msg: 215 | L.append(' | ' + msg) 216 | 217 | msg = ''.join(L) 218 | sys.stdout.write(msg) 219 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 220 | sys.stdout.write(' ') 221 | 222 | # Go back to the center of the bar. 223 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 224 | sys.stdout.write('\b') 225 | sys.stdout.write(' %d/%d ' % (current+1, total)) 226 | 227 | if current < total-1: 228 | sys.stdout.write('\r') 229 | else: 230 | sys.stdout.write('\n') 231 | sys.stdout.flush() 232 | 233 | def format_time(seconds): 234 | days = int(seconds / 3600/24) 235 | seconds = seconds - days*3600*24 236 | hours = int(seconds / 3600) 237 | seconds = seconds - hours*3600 238 | minutes = int(seconds / 60) 239 | seconds = seconds - minutes*60 240 | secondsf = int(seconds) 241 | seconds = seconds - secondsf 242 | millis = int(seconds*1000) 243 | 244 | f = '' 245 | i = 1 246 | if days > 0: 247 | f += str(days) + 'D' 248 | i += 1 249 | if hours > 0 and i <= 2: 250 | f += str(hours) + 'h' 251 | i += 1 252 | if minutes > 0 and i <= 2: 253 | f += str(minutes) + 'm' 254 | i += 1 255 | if secondsf > 0 and i <= 2: 256 | f += str(secondsf) + 's' 257 | i += 1 258 | if millis > 0 and i <= 2: 259 | f += str(millis) + 'ms' 260 | i += 1 261 | if f == '': 262 | f = '0ms' 263 | return f 264 | 265 | 266 | def parse_args(args): 267 | str_args = args.split(',') 268 | parsed_args = [] 269 | for str_arg in str_args: 270 | arg = int(str_arg) 271 | if arg >= 0: 272 | parsed_args.append(arg) 273 | return parsed_args 274 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . 
import html 7 | 8 | 9 | class Visualizer(): 10 | def __init__(self, opt): 11 | self.display_id = opt.display_id 12 | self.use_html = opt.isTrain and not opt.no_html 13 | self.win_size = opt.display_winsize 14 | self.name = opt.name 15 | self.opt = opt 16 | self.saved = False 17 | if self.display_id > 0: 18 | import visdom 19 | self.vis = visdom.Visdom(port=opt.display_port, ipv6=False) 20 | 21 | if self.use_html: 22 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 23 | self.img_dir = os.path.join(self.web_dir, 'images') 24 | print('create web directory %s...' % self.web_dir) 25 | util.mkdirs([self.web_dir, self.img_dir]) 26 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 27 | with open(self.log_name, "a") as log_file: 28 | now = time.strftime("%c") 29 | log_file.write('================ Training Loss (%s) ================\n' % now) 30 | 31 | def reset(self): 32 | self.saved = False 33 | 34 | # |visuals|: dictionary of images to display or save 35 | def display_current_results(self, visuals, epoch, save_result): 36 | if self.display_id > 0: # show images in the browser 37 | ncols = self.opt.display_single_pane_ncols 38 | if ncols > 0: 39 | h, w = next(iter(visuals.values())).shape[:2] 40 | table_css = """<style> 41 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 42 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 43 | </style>""" % (w, h) 44 | title = self.name 45 | label_html = '' 46 | label_html_row = '' 47 | nrows = int(np.ceil(len(visuals.items()) / ncols)) 48 | images = [] 49 | idx = 0 50 | for label, image_numpy in visuals.items(): 51 | label_html_row += '