├── LICENSE
├── README.md
├── dataset.py
├── env.yml
├── gen_data.py
├── imgs
    └── overview.jpg
├── network.py
├── output
    ├── B_0001.png
    ├── B_0002.png
    ├── B_0003.png
    ├── B_0004.png
    ├── B_0005.png
    ├── B_0006.png
    ├── B_0007.png
    ├── B_0008.png
    ├── B_0009.png
    ├── B_0010.png
    ├── B_0011.png
    ├── B_0012.png
    ├── B_0013.png
    ├── R_0001.png
    ├── R_0002.png
    ├── R_0003.png
    ├── R_0004.png
    ├── R_0005.png
    ├── R_0006.png
    ├── R_0007.png
    ├── R_0008.png
    ├── R_0009.png
    ├── R_0010.png
    ├── R_0011.png
    ├── R_0012.png
    └── R_0013.png
├── samples
    ├── 0001.jpg
    ├── 0002.jpg
    ├── 0003.jpg
    ├── 0004.jpg
    ├── 0005.jpg
    ├── 0006.jpg
    ├── 0007.jpg
    ├── 0008.jpg
    ├── 0009.jpg
    ├── 0010(synthetic).png
    ├── 0011(synthetic).png
    ├── 0012(synthetic).png
    └── 0013(synthetic).png
├── test.py
├── test.sh
└── vutil.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Jie Yang and Dong Gong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# bdn-refremv
Deep Bidirectional Estimation for Single Image Reflection Removal. This package is the implementation of the paper:

*[Seeing Deeply and Bidirectionally: A Deep Learning Approach for Single Image Reflection Removal](http://openaccess.thecvf.com/content_ECCV_2018/papers/Jie_Yang_Seeing_Deeply_and_ECCV_2018_paper.pdf)
[Jie Yang](https://github.com/yangj1e)\*, [Dong Gong](https://donggong1.github.io)\*, [Lingqiao Liu](https://sites.google.com/site/lingqiaoliu83/), [Qinfeng Shi](https://cs.adelaide.edu.au/~javen/index.html).
In European Conference on Computer Vision (ECCV), 2018.* (* Equal contribution)

## Requirements

+ Python packages
```
pytorch>=0.4.0
torchvision
numpy
pillow
```
+ An NVIDIA GPU and CUDA 9.0 or higher (`test.py` runs on the GPU only)

### Conda environment

A minimal conda environment for running `test.sh` is provided in `env.yml`:

```
conda env create -f env.yml
conda activate bdn-refremv
```

## Usage

+ Download our pretrained model [here](https://drive.google.com/open?id=1zBCl2qI_fT3CwPZkVvZEv37bDIlhakF6) and unpack the archive into the `model` folder; `test.sh` loads the weights from `./model/model.pth`.

+ Put test images into the `samples` folder and run `bash test.sh`. For each input image, the recovered reflection-free image and the estimated reflection are written to the `output` folder as `B_XXXX.png` and `R_XXXX.png`, numbered in the sorted order of the input file names.

## Examples and Real-world Testing Images
Two examples on real-world images taken with a mobile phone show, from left to right: I (the observed image with reflection), B (the recovered reflection-free image) and R (the intermediate reflection image). Please see our paper for details and more examples.

More real-world reflection images for testing can be found in the `samples` folder.
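Which sub-network produces the saved B and R images follows from the loop in `Generator_cascade.forward` (in `network.py`) and the parity rule `test.py` uses to pick the last two outputs. The minimal pure-Python sketch below traces that bookkeeping with string labels standing in for tensors; the helpers `cascade_stage_order` and `select_outputs` are hypothetical and not part of the repository.

```python
def cascade_stage_order(iteration):
    """List which sub-network produced each entry of the cascade's output list.

    Mirrors the loop in Generator_cascade.forward: model1 runs once on the
    input image, then the loop runs iteration + 1 times, calling model2 on
    even steps and model3 on odd steps.
    """
    res = ["model1"]
    for i in range(iteration + 1):
        res.append("model2" if i % 2 == 0 else "model3")
    return res


def select_outputs(res):
    """Return (B, R) with the same parity rule as test.py."""
    if len(res) % 2 == 1:
        return res[-1], res[-2]
    return res[-2], res[-1]


if __name__ == "__main__":
    for k in range(3):
        print(k, cascade_stage_order(k), select_outputs(cascade_stage_order(k)))
```

With `--iteration 1`, as in `test.sh`, the output list is model1, model2, model3, so B is taken from model3 and R from model2. For any iteration count, B comes from model1 or model3 and R from model2, which matches their inputs: model2 receives the current B estimate concatenated with I, and model3 receives the current R estimate concatenated with I.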
## Datasets

The synthetic datasets used for training and testing in our paper:

+ [Training data](https://drive.google.com/open?id=1bbWsGG1qQgB-sbktI2h5vO8UhD1uHaj7)
+ [Test data](https://drive.google.com/open?id=1ZeeKJVbZ_bifsdpAlbguDleViDA4QjCw)

To run `test.py` on synthetic data, omit `--real`. The data root must then contain the subfolders `I` (observed images), `B` (reflection-free images) and `R` (reflection images), each holding identically named files.

## Citation
If you use this code for your research, please cite our paper:
````
@inproceedings{eccv18refrmv,
  title={Seeing deeply and bidirectionally: a deep learning approach for single image reflection removal},
  author={Yang, Jie and Gong, Dong and Liu, Lingqiao and Shi, Qinfeng},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  pages={654--669},
  year={2018}
}
````

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import itertools
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import random


class ref_dataset(Dataset):
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 rf_transform=None,
                 real=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.rf_transform = rf_transform
        self.real = real
        if real:
            # Real images: root holds the observed images directly.
            self.ids = sorted(os.listdir(root))
        else:
            # Synthetic data: root/I, root/B and root/R share file names.
            self.ids = sorted(os.listdir(os.path.join(root, 'I')))

    def __getitem__(self, index):
        img = self.ids[index]
        if self.real:
            input = Image.open(os.path.join(self.root, img)).convert('RGB')
            if self.transform is not None:
                input = self.transform(input)
            return input
        else:
            input = Image.open(os.path.join(self.root, 'I', img)).convert('RGB')
            target = Image.open(os.path.join(self.root, 'B', img)).convert('RGB')
            target_rf = Image.open(os.path.join(self.root, 'R', img)).convert('RGB')
            if self.transform is not None:
                input = self.transform(input)
            if self.target_transform is not None:
                target = self.target_transform(target)
            if self.rf_transform is not None:
                target_rf = self.rf_transform(target_rf)
            return input, target, target_rf

    def __len__(self):
        return len(self.ids)

--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
name: bdn-refremv
channels:
  - defaults
dependencies:
  - blas=1.0
  - ca-certificates=2019.1.23
  - certifi=2019.3.9
  - cffi=1.12.2
  - cudatoolkit=9.0
  - cudnn=7.3.1
  - freetype=2.9.1
  - intel-openmp=2019.1
  - jpeg=9b
  - libedit=3.1.20181209
  - libffi=3.2.1
  - libgcc-ng=8.2.0
  - libgfortran-ng=7.3.0
  - libpng=1.6.36
  - libstdcxx-ng=8.2.0
  - libtiff=4.0.10
  - mkl=2019.1
  - mkl_fft=1.0.10
  - mkl_random=1.0.2
  - ncurses=6.1
  - ninja=1.8.2
  - numpy=1.16.2
  - numpy-base=1.16.2
  - olefile=0.46
  - openssl=1.1.1b
  - pillow=5.4.1
  - pip=19.0.3
  - pycparser=2.19
  - python=3.6.8
  - pytorch=1.0.1
  - readline=7.0
  - setuptools=40.8.0
  - six=1.12.0
  - sqlite=3.27.2
  - tk=8.6.8
  - torchvision=0.2.1
  - wheel=0.33.1
  - xz=5.2.4
  - zlib=1.2.11
  - zstd=1.3.7

--------------------------------------------------------------------------------
/gen_data.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import itertools
import random
from glob import glob
import argparse

import cv2
import numpy as np

from PIL import Image

SIZES = (3, 5, 7)        # candidate Gaussian kernel sizes for blurring R
SIGMAS = (0, 2)          # range of the Gaussian sigma
THRESHOLDS = (0.2, 0.4)  # range of the blending weight beta of R


def get_img_list(folders, ext='.jpg'):
    if ext is None:
        pattern = '*'
    else:
        pattern = '*' + ext
    return list(itertools.chain.from_iterable(glob(os.path.join(folder, pattern)) for folder in folders))


# img1 and img2 are PIL images (currently unused; see the commented-out call below)
def sample_patches(img1, img2, size):
    w1, h1 = img1.size
    w2, h2 = img2.size
    if all(np.array((w1, h1, w2, h2)) >= size):
        th = min(h1, h2)
        tw = min(w1, w2)
        x1 = random.randint(0, w1 - tw)
        y1 = random.randint(0, h1 - th)
        x2 = random.randint(0, w2 - tw)
        y2 = random.randint(0, h2 - th)
        img1 = img1.crop((x1, y1, x1 + tw, y1 + th))
        img2 = img2.crop((x2, y2, x2 + tw, y2 + th))
        return img1, img2
    else:
        return None


def sample_patch(img, crop_h, crop_w=None):
    # Random crop of an H x W x C array; returns None if the image is too small.
    if crop_w is None:
        crop_w = crop_h
    h, w, c = img.shape
    if h < crop_h or w < crop_w:
        return None
    j = random.randint(0, h - crop_h)
    i = random.randint(0, w - crop_w)
    return img[j:j + crop_h, i:i + crop_w, ...]


def merge(img1, img2, beta):
    # I = (1 - beta) * img1 + beta * img2, saturated to uint8 by OpenCV
    return cv2.addWeighted(img1, 1 - beta, img2, beta, 0)


def generate_images(opt):
    # Expects a PASCAL VOC-style layout: ImageSets/Main/{train,val}.txt and JPEGImages/
    if not opt.test:
        train_list_f = os.path.join(opt.dataroot, 'ImageSets', 'Main', 'train.txt')
    else:
        train_list_f = os.path.join(opt.dataroot, 'ImageSets', 'Main', 'val.txt')
    with open(train_list_f) as f:
        train_list = f.read().splitlines()

    obs_dir = os.path.join(opt.outf, 'obs')
    trans_dir = os.path.join(opt.outf, 'trans')
    ref_dir = os.path.join(opt.outf, 'ref')
    refb_dir = os.path.join(opt.outf, 'refb')
    # label_dir = os.path.join(opt.outf, 'label')

    if not os.path.exists(opt.outf):
        os.mkdir(opt.outf)
    if not os.path.exists(obs_dir):
        os.mkdir(obs_dir)
    if not os.path.exists(trans_dir):
        os.mkdir(trans_dir)
    if not os.path.exists(ref_dir):
        os.mkdir(ref_dir)
    if not os.path.exists(refb_dir):
        os.mkdir(refb_dir)
    # if not os.path.exists(label_dir):
    #     os.mkdir(label_dir)
    print('Number of source images: %d' % len(train_list))

    # random_crop = transforms.RandomCrop(opt.imageSize)
    # f = open(os.path.join(opt.outf, 'stat.txt'), 'w')
    for i in range(opt.numImages):
        # Draw a transmission (T) and a reflection (R) source image, with
        # replacement, until both are large enough for an imageSize crop.
        while True:
            T_f, R_f = random.choices(train_list, k=2)
            T = np.array(Image.open(os.path.join(opt.dataroot, 'JPEGImages', T_f + '.jpg')).convert('RGB'))
            R = np.array(Image.open(os.path.join(opt.dataroot, 'JPEGImages', R_f + '.jpg')).convert('RGB'))
            T_crop = sample_patch(T, opt.imageSize)
            R_crop = sample_patch(R, opt.imageSize)
            if T_crop is not None and R_crop is not None:
                break
            # patches = sample_patches(T, R, opt.imageSize)
            # if patches is not None:
            #     T_crop, R_crop = patches
            #     break
        # T_crop = np.array(T_crop)
        # R_crop = np.array(R_crop)
        # Blur the reflection, then blend: I = (1 - beta) * T + beta * blur(R)
        beta = random.uniform(*THRESHOLDS)
        sigma = random.uniform(*SIGMAS)
        size = random.choice(SIZES)
        R_blur = cv2.GaussianBlur(R_crop, (size, size), sigma)
        I = merge(T_crop, R_blur, beta)
        # scipy.misc.imsave has been removed from recent SciPy releases; save with Pillow.
        Image.fromarray(I).save(os.path.join(obs_dir, '{:06d}.jpg'.format(i + 1)))
        Image.fromarray(T_crop).save(os.path.join(trans_dir, '{:06d}.jpg'.format(i + 1)))
        Image.fromarray(R_crop).save(os.path.join(ref_dir, '{:06d}.jpg'.format(i + 1)))
        Image.fromarray(R_blur).save(os.path.join(refb_dir, '{:06d}.jpg'.format(i + 1)))
        # f.write('{}\t{}\t{}\t{}\t{}\n'.format(T_f, R_f, beta, size, sigma))
    # f.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to the source dataset (VOC-style layout with ImageSets/Main and JPEGImages)')
    parser.add_argument('--outf', required=True, help='folder to output generated dataset')
    parser.add_argument('--numImages', type=int, default=10000, help='number of images to generate')
    parser.add_argument('--imageSize', type=int, default=256, help='the height / width of the image')
    parser.add_argument('--test', action='store_true', help='generate test images from val.txt instead of train.txt')
    opt = parser.parse_args()
    print(opt)

    generate_images(opt)

--------------------------------------------------------------------------------
/imgs/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/imgs/overview.jpg

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from torch.nn import init


###############################################################################
# Functions
###############################################################################
def get_norm_layer(norm_type):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        # Fail loudly instead of returning an undefined norm_layer.
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def define_G(input_nc,
             output_nc,
             ngf,
             which_model_netG,
             ns,
             norm='batch',
             use_dropout=False,
             gpu_ids=[],
             iteration=0,
             padding_type='zero',
             upsample_type='transpose',
             init_type='normal'):
    # ns: number of downsamplings of each U-Net (model1, model2[, model3]).
    # padding_type, upsample_type and init_type are currently unused.
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    if which_model_netG == 'cascade_unet':
        netG = Generator_cascade(
            input_nc,
            output_nc,
            'unet',
            ns,
            ngf,
            norm_layer=norm_layer,
            use_dropout=use_dropout,
            gpu_ids=gpu_ids,
            iteration=iteration)
    else:
        raise NotImplementedError('Model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(device=gpu_ids[0])
    # init_weights(netG, init_type=init_type)
    return netG


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################
class Generator_cascade(nn.Module):
    def __init__(self,
                 input_nc,
                 output_nc,
                 base_model,
                 ns,
                 ngf=64,
                 norm_layer=nn.BatchNorm2d,
                 use_dropout=False,
                 gpu_ids=[],
                 iteration=0,
                 padding_type='zero',
                 upsample_type='transpose'):
        super(Generator_cascade, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids
        self.iteration = iteration

        if base_model == 'unet':
            # model1: observed image I -> first estimate of B
            self.model1 = UnetGenerator(
                input_nc,
                output_nc,
                ns[0],
                ngf,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
                gpu_ids=gpu_ids)
            # model2: (B estimate, I) -> estimate of R
            self.model2 = UnetGenerator(
                input_nc * 2,
                output_nc,
                ns[1],
                ngf,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
                gpu_ids=gpu_ids)
            if self.iteration > 0:
                # model3: (R estimate, I) -> refined estimate of B
                self.model3 = UnetGenerator(
                    input_nc * 2,
                    output_nc,
                    ns[2],
                    ngf,
                    norm_layer=norm_layer,
                    use_dropout=use_dropout,
                    gpu_ids=gpu_ids)

    def forward(self, input):
        x = self.model1(input)
        res = [x]
        # Alternate between estimating R (model2) and refining B (model3).
        for i in range(self.iteration + 1):
            if i % 2 == 0:
                xy = torch.cat([x, input], 1)
                z = self.model2(xy)
                res += [z]
            else:
                zy = torch.cat([z, input], 1)
                x = self.model3(zy)
                res += [x]
        return res


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
For example, 136 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 137 | # at the bottleneck 138 | class UnetGenerator(nn.Module): 139 | def __init__(self, 140 | input_nc, 141 | output_nc, 142 | num_downs, 143 | ngf=64, 144 | norm_layer=nn.BatchNorm2d, 145 | use_dropout=False, 146 | gpu_ids=[]): 147 | super(UnetGenerator, self).__init__() 148 | self.gpu_ids = gpu_ids 149 | 150 | # currently support only input_nc == output_nc 151 | # assert (input_nc == output_nc) 152 | 153 | # construct unet structure 154 | unet_block = UnetSkipConnectionBlock( 155 | ngf * 8, 156 | ngf * 8, 157 | norm_layer=norm_layer, 158 | innermost=True, 159 | use_dropout=use_dropout) 160 | for i in range(num_downs - 5): 161 | unet_block = UnetSkipConnectionBlock( 162 | ngf * 8, 163 | ngf * 8, 164 | unet_block, 165 | norm_layer=norm_layer, 166 | use_dropout=use_dropout) 167 | unet_block = UnetSkipConnectionBlock( 168 | ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer) 169 | unet_block = UnetSkipConnectionBlock( 170 | ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer) 171 | unet_block = UnetSkipConnectionBlock( 172 | ngf, ngf * 2, unet_block, norm_layer=norm_layer) 173 | unet_block = UnetSkipConnectionBlock( 174 | output_nc, 175 | ngf, 176 | unet_block, 177 | outermost=True, 178 | norm_layer=norm_layer, 179 | outermost_input_nc=input_nc) 180 | 181 | self.model = unet_block 182 | 183 | def forward(self, input): 184 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): 185 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids) 186 | else: 187 | return self.model(input) 188 | 189 | 190 | # Defines the submodule with skip connection. 
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self,
                 outer_nc,
                 inner_nc,
                 submodule=None,
                 outermost=False,
                 innermost=False,
                 norm_layer=nn.BatchNorm2d,
                 use_dropout=False,
                 outermost_input_nc=-1):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        if outermost and outermost_input_nc > 0:
            downconv = nn.Conv2d(
                outermost_input_nc,
                inner_nc,
                kernel_size=4,
                stride=2,
                padding=1)
        else:
            downconv = nn.Conv2d(
                outer_nc, inner_nc, kernel_size=4, stride=2, padding=1)

        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x1 = self.model(x)
        # Downsampling then upsampling loses one pixel along an odd side;
        # pad x1 back to the size of x so the two can be concatenated.
        diff_h = x.size()[2] - x1.size()[2]
        diff_w = x.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diff_w // 2, diff_w - diff_w // 2, diff_h // 2,
                        diff_h - diff_h // 2))
        if self.outermost:
            return x1
        else:
            # Skip connection: stack the block's output with its input.
            return torch.cat([x1, x], 1)

--------------------------------------------------------------------------------
/output/B_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0001.png

--------------------------------------------------------------------------------
/output/B_0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0002.png

--------------------------------------------------------------------------------
/output/B_0003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0003.png

--------------------------------------------------------------------------------
/output/B_0004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0004.png

--------------------------------------------------------------------------------
/output/B_0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0005.png

--------------------------------------------------------------------------------
/output/B_0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0006.png

--------------------------------------------------------------------------------
/output/B_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0007.png -------------------------------------------------------------------------------- /output/B_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0008.png -------------------------------------------------------------------------------- /output/B_0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0009.png -------------------------------------------------------------------------------- /output/B_0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0010.png -------------------------------------------------------------------------------- /output/B_0011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0011.png -------------------------------------------------------------------------------- /output/B_0012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0012.png -------------------------------------------------------------------------------- /output/B_0013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0013.png -------------------------------------------------------------------------------- /output/R_0001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0001.png -------------------------------------------------------------------------------- /output/R_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0002.png -------------------------------------------------------------------------------- /output/R_0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0003.png -------------------------------------------------------------------------------- /output/R_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0004.png -------------------------------------------------------------------------------- /output/R_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0005.png -------------------------------------------------------------------------------- /output/R_0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0006.png -------------------------------------------------------------------------------- /output/R_0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0007.png 
-------------------------------------------------------------------------------- /output/R_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0008.png -------------------------------------------------------------------------------- /output/R_0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0009.png -------------------------------------------------------------------------------- /output/R_0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0010.png -------------------------------------------------------------------------------- /output/R_0011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0011.png -------------------------------------------------------------------------------- /output/R_0012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0012.png -------------------------------------------------------------------------------- /output/R_0013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0013.png -------------------------------------------------------------------------------- /samples/0001.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0001.jpg -------------------------------------------------------------------------------- /samples/0002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0002.jpg -------------------------------------------------------------------------------- /samples/0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0003.jpg -------------------------------------------------------------------------------- /samples/0004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0004.jpg -------------------------------------------------------------------------------- /samples/0005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0005.jpg -------------------------------------------------------------------------------- /samples/0006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0006.jpg -------------------------------------------------------------------------------- /samples/0007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0007.jpg -------------------------------------------------------------------------------- /samples/0008.jpg: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0008.jpg
--------------------------------------------------------------------------------
/samples/0009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0009.jpg
--------------------------------------------------------------------------------
/samples/0010(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0010(synthetic).png
--------------------------------------------------------------------------------
/samples/0011(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0011(synthetic).png
--------------------------------------------------------------------------------
/samples/0012(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0012(synthetic).png
--------------------------------------------------------------------------------
/samples/0013(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0013(synthetic).png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import os
import random
import time
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from math import log10
from PIL import Image

from dataset import ref_dataset
from vutil import save_image
import network

parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', required=True, help='path to dataset')
parser.add_argument(
    '--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument(
    '--batchSize', type=int, default=8, help='input batch size')
parser.add_argument(
    '--which_model_netG',
    type=str,
    default='cascade_unet',
    help='selects model to use for netG')
parser.add_argument(
    '--ns', type=str, default='5', help='number of blocks for each module')
parser.add_argument(
    '--netG', default='', help="path to netG (to continue training)")
parser.add_argument(
    '--norm',
    type=str,
    default='batch',
    help='instance normalization or batch normalization')
parser.add_argument(
    '--use_dropout', action='store_true', help='use dropout for the generator')
parser.add_argument(
    '--imageSize',
    type=int,
    default=256,
    help='the height / width of the input image to network')
parser.add_argument(
    '--outf',
    default='.',
    help='folder to output images and model checkpoints')
parser.add_argument('--real', action='store_true', help='test real images')
parser.add_argument(
    '--iteration', type=int, default=0, help='number of iterative updates')
parser.add_argument(
    '--n_outputs', type=int, default=0, help='number of images to save')

opt = parser.parse_args()

str_ids = opt.ns.split(',')
opt.ns = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        opt.ns.append(id)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

nc = 3
ngf = 64
netG = network.define_G(nc, nc, ngf, opt.which_model_netG, opt.ns, opt.norm,
                        opt.use_dropout, [], opt.iteration)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))

transform = transforms.Compose([
    # transforms.Scale(opt.imageSize),
    # transforms.CenterCrop(opt.imageSize),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ref_dataset(
    opt.dataroot,
    transform=transform,
    target_transform=transform,
    rf_transform=transform,
    real=opt.real)
assert dataset

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=False,
    num_workers=int(opt.workers))

input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
input = input.cuda()
netG.cuda()
netG.eval()

criterion = nn.MSELoss()
criterion.cuda()

for i, data in enumerate(dataloader, 1):
    if opt.real:
        input_cpu = data
        category = 'real'
    else:
        input_cpu, target_B_cpu, target_R_cpu = data
        category = 'test'
    input.resize_(input_cpu.size()).copy_(input_cpu)
    if opt.which_model_netG.startswith('cascade'):
        res = netG(input)
        if len(res) % 2 == 1:
            output_B, output_R = res[-1], res[-2]
        else:
            output_B, output_R = res[-2], res[-1]
    else:
        output_B = netG(input)

    if opt.n_outputs == 0 or i <= opt.n_outputs:
        save_image(output_B / 2 + 0.5, '%s/B_%04d.png' % (opt.outf, i))
        if opt.which_model_netG.startswith('cascade'):
            save_image(output_R / 2 + 0.5, '%s/R_%04d.png' % (opt.outf, i))

--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
python ./test.py --dataroot ./samples \
    --batchSize 1 \
    --norm batch \
    --which_model_netG cascade_unet \
    --ns 7,5,5 \
    --iteration 1 \
    --outf ./output \
    --netG ./model/model.pth \
    --real
--------------------------------------------------------------------------------
/vutil.py:
--------------------------------------------------------------------------------
from PIL import Image
import os
import torchvision.utils
import numpy as np


def save_image(tensor, filename):
    if tensor.size()[0] == 1:
        tensor = tensor.cpu()[0, ...]
        ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
        im = Image.fromarray(ndarr)
        im.save(filename)
    else:
        torchvision.utils.save_image(
            tensor, filename, normalize=False, range=(0, 1))

--------------------------------------------------------------------------------
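Two of the less obvious steps in `test.py`, parsing the comma-separated `--ns` value and choosing `output_B`/`output_R` from the cascade network's outputs, can be sketched as standalone functions. This is a minimal sketch: `parse_ns` and `pick_outputs` are hypothetical names (the script inlines this logic), and the B, R, B, R, ... ordering of the network outputs is an assumption inferred from the parity check, since `network.py` is not shown here.

```python
# Minimal sketch of two inline steps from test.py, pulled out into
# hypothetical helper functions for illustration.


def parse_ns(ns_arg):
    """Turn a '--ns' string such as '7,5,5' into a list of ints,
    dropping negative entries the same way test.py does."""
    blocks = []
    for str_id in ns_arg.split(','):
        n = int(str_id)
        if n >= 0:
            blocks.append(n)
    return blocks


def pick_outputs(res):
    """Return (output_B, output_R) from a cascade output sequence.

    ASSUMPTION: the sequence alternates B, R, B, R, ... starting with B,
    so an odd-length sequence ends with a B estimate and an even-length
    sequence ends with an R estimate.
    """
    if len(res) % 2 == 1:
        return res[-1], res[-2]
    return res[-2], res[-1]


print(parse_ns('7,5,5'))                        # [7, 5, 5]
print(parse_ns('5,-1'))                         # [5]
print(pick_outputs(['B0', 'R0', 'B1']))         # ('B1', 'R0')
print(pick_outputs(['B0', 'R0', 'B1', 'R1']))   # ('B1', 'R1')
```

Either way, the most recent B estimate is returned together with the R estimate adjacent to it in the sequence.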
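`test.py` normalizes inputs with `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which computes (x - 0.5) / 0.5 per channel and maps [0, 1] to [-1, 1]. Before saving, `output / 2 + 0.5` maps the network output back to [0, 1], and `vutil.save_image` then scales by 255, clamps, truncates to `uint8` with `.byte()`, and permutes C x H x W to H x W x C for PIL. The following pure-Python sketch mirrors that arithmetic on plain lists. It uses no torch, and the helper names are hypothetical.

```python
# Pure-Python sketch of the value mapping in test.py and vutil.save_image.
# Helper names are hypothetical; lists stand in for torch tensors.


def normalize(x, mean=0.5, std=0.5):
    """Per-channel transforms.Normalize: maps [0, 1] to [-1, 1]."""
    return (x - mean) / std


def denormalize(y):
    """Inverse applied before saving (output / 2 + 0.5)."""
    return y / 2 + 0.5


def to_uint8(v):
    """Mirror of mul(255).clamp(0, 255).byte(): scale, clamp, truncate."""
    return int(min(max(v * 255, 0.0), 255.0))


def chw_to_hwc_uint8(img):
    """Convert a C x H x W float image in [0, 1] to an H x W x C list of
    0-255 ints, the layout PIL's Image.fromarray expects."""
    c, h, w = len(img), len(img[0]), len(img[0][0])
    return [[[to_uint8(img[k][i][j]) for k in range(c)] for j in range(w)]
            for i in range(h)]


print(normalize(0.0), normalize(1.0))             # -1.0 1.0
print(denormalize(normalize(0.25)))               # 0.25
print(to_uint8(1.2), to_uint8(-0.1), to_uint8(0.5))  # 255 0 127

# 3 channels, 1 row, 2 columns
img = [[[0.0, 1.0]], [[0.5, 0.5]], [[1.0, 0.0]]]
print(chw_to_hwc_uint8(img))                      # [[[0, 127, 255], [255, 127, 0]]]
```

Because `.byte()` truncates rather than rounds, a value of exactly 0.5 is saved as 127, not 128.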
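In `test.py` the index `i` comes from `enumerate(dataloader, 1)`, so it counts batches, not images, and `--n_outputs` limits how many batches are saved. With `--batchSize 1`, as in `test.sh`, the 13 files in `samples` produce `B_0001.png` through `B_0013.png` plus the matching `R_*.png` files, which agrees with the `output` folder in the listing. The naming can be sketched as follows. This is a sketch only: `output_paths` is a hypothetical helper, `cascade` stands in for `which_model_netG.startswith('cascade')`, and it assumes the `DataLoader` default `drop_last=False`, so a final partial batch is kept.

```python
import math


def output_paths(n_images, batch_size=1, outf='./output', n_outputs=0,
                 cascade=True):
    """List the files test.py would write for n_images inputs.

    Hypothetical helper: one B file (and, for cascade models, one R file)
    per batch, numbered from 1, limited by n_outputs when it is non-zero.
    """
    n_batches = math.ceil(n_images / batch_size)  # drop_last=False
    paths = []
    for i in range(1, n_batches + 1):
        if n_outputs == 0 or i <= n_outputs:
            paths.append('%s/B_%04d.png' % (outf, i))
            if cascade:
                paths.append('%s/R_%04d.png' % (outf, i))
    return paths


paths = output_paths(13)
print(len(paths), paths[0], paths[-1])     # 26 ./output/B_0001.png ./output/R_0013.png
print(len(output_paths(13, batch_size=8)))  # 4
```

With a batch size above 1, each file holds a whole batch. In that case `vutil.save_image` falls back to `torchvision.utils.save_image`, which tiles the batch into a single grid image.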