├── LICENSE
├── README.md
├── dataset.py
├── env.yml
├── gen_data.py
├── imgs
│   └── overview.jpg
├── network.py
├── output
│   ├── B_0001.png
│   ├── B_0002.png
│   ├── B_0003.png
│   ├── B_0004.png
│   ├── B_0005.png
│   ├── B_0006.png
│   ├── B_0007.png
│   ├── B_0008.png
│   ├── B_0009.png
│   ├── B_0010.png
│   ├── B_0011.png
│   ├── B_0012.png
│   ├── B_0013.png
│   ├── R_0001.png
│   ├── R_0002.png
│   ├── R_0003.png
│   ├── R_0004.png
│   ├── R_0005.png
│   ├── R_0006.png
│   ├── R_0007.png
│   ├── R_0008.png
│   ├── R_0009.png
│   ├── R_0010.png
│   ├── R_0011.png
│   ├── R_0012.png
│   └── R_0013.png
├── samples
│   ├── 0001.jpg
│   ├── 0002.jpg
│   ├── 0003.jpg
│   ├── 0004.jpg
│   ├── 0005.jpg
│   ├── 0006.jpg
│   ├── 0007.jpg
│   ├── 0008.jpg
│   ├── 0009.jpg
│   ├── 0010(synthetic).png
│   ├── 0011(synthetic).png
│   ├── 0012(synthetic).png
│   └── 0013(synthetic).png
├── test.py
├── test.sh
└── vutil.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Jie Yang and Dong Gong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # bdn-refremv
2 | Deep Bidirectional Estimation for Single Image Reflection Removal. This package is the implementation of the paper:
3 |
4 | *[Seeing Deeply and Bidirectionally: A Deep Learning Approach for Single Image Reflection Removal](http://openaccess.thecvf.com/content_ECCV_2018/papers/Jie_Yang_Seeing_Deeply_and_ECCV_2018_paper.pdf)
5 | [Jie Yang](https://github.com/yangj1e)\*, [Dong Gong](https://donggong1.github.io)\*, [Lingqiao Liu](https://sites.google.com/site/lingqiaoliu83/), [Qinfeng Shi](https://cs.adelaide.edu.au/~javen/index.html).
6 | In European Conference on Computer Vision (ECCV), 2018.* (* Equal contribution)
7 |
8 |
9 |
10 | ![overview](imgs/overview.jpg)
11 |
12 | ## Requirements
13 |
14 | + Python packages
15 | ```
16 | pytorch>=0.4.0
17 | numpy
18 | pillow
19 | ```
20 | + An NVIDIA GPU and CUDA 9.0 or higher
21 |
22 | ### Conda environment
23 |
24 | A minimal conda environment for running `test.sh` is provided.
25 |
26 | ```
27 | conda env create -f env.yml
28 | ```
29 |
30 | ## Usage
31 |
32 | + Download our pretrained model [here](https://drive.google.com/open?id=1zBCl2qI_fT3CwPZkVvZEv37bDIlhakF6) and unpack the archive into the `model` folder.
33 |
34 | + Put test images into the `samples` folder, then run `bash test.sh`.
35 |
36 | ## Examples and Real-world Testing Images
37 | Two examples (real-world images taken with a mobile phone) are shown below. From left to right: I (the observed image with reflections), B (the recovered reflection-free image) and R (the intermediate reflection image). Please see more details and examples in our paper.
38 |
39 | More real-world reflection images for testing can be found in `samples/`.
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 | ## Datasets
54 |
55 | The synthetic datasets used for training and testing in our paper:
56 |
57 | + [Training data](https://drive.google.com/open?id=1bbWsGG1qQgB-sbktI2h5vO8UhD1uHaj7)
58 | + [Test data](https://drive.google.com/open?id=1ZeeKJVbZ_bifsdpAlbguDleViDA4QjCw)
59 |
60 |
61 | ## Citation
62 | If you use this code for your research, please cite our paper:
63 | ````
64 | @inproceedings{eccv18refrmv,
65 | title={Seeing deeply and bidirectionally: a deep learning approach for single image reflection removal},
66 | author={Yang, Jie and Gong, Dong and Liu, Lingqiao and Shi, Qinfeng},
67 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
68 | pages={654--669},
69 | year={2018}
70 | }
71 | ````
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import itertools
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 | from torch.utils.data import Dataset
7 | import random
8 |
9 |
10 | class ref_dataset(Dataset):
11 | def __init__(self,
12 | root,
13 | transform=None,
14 | target_transform=None,
15 | rf_transform=None,
16 | real=False):
17 | self.root = root
18 | self.transform = transform
19 | self.target_transform = target_transform
20 | self.rf_transform = rf_transform
21 | self.real = real
22 | if real:
23 | self.ids = sorted(os.listdir(root))
24 | else:
25 |             self.ids = sorted(os.listdir(os.path.join(root, 'I')))  # paired layout: root holds I/, B/ and R/ with matching filenames
26 |
27 | def __getitem__(self, index):
28 | img = self.ids[index]
29 | if self.real:
30 | input = Image.open(os.path.join(self.root, img)).convert('RGB')
31 | if self.transform is not None:
32 | input = self.transform(input)
33 | return input
34 | else:
35 | input = Image.open(os.path.join(self.root, 'I', img)).convert('RGB')
36 | target = Image.open(os.path.join(self.root, 'B', img)).convert('RGB')
37 | target_rf = Image.open(os.path.join(self.root, 'R', img)).convert('RGB')
38 | if self.transform is not None:
39 | input = self.transform(input)
40 | if self.target_transform is not None:
41 | target = self.target_transform(target)
42 | if self.rf_transform is not None:
43 | target_rf = self.rf_transform(target_rf)
44 | return input, target, target_rf
45 |
46 | def __len__(self):
47 | return len(self.ids)
48 |
--------------------------------------------------------------------------------
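A minimal usage sketch for `ref_dataset` (paths are placeholders; in the paired case the root is assumed to contain `I/`, `B/` and `R/` subfolders with identically named files, as read in `__getitem__`):

```
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dataset import ref_dataset

# Same normalization as test.py: maps [0, 1] image tensors to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Paired synthetic data, e.g. ./data/test/I/0001.jpg with matching B/ and R/ files.
dataset = ref_dataset('./data/test', transform=transform,
                      target_transform=transform, rf_transform=transform)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
I, B, R = next(iter(loader))  # observation, background, reflection
```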
/env.yml:
--------------------------------------------------------------------------------
1 | name: bdn-refremv
2 | channels:
3 | - defaults
4 | dependencies:
5 | - blas=1.0
6 | - ca-certificates=2019.1.23
7 | - certifi=2019.3.9
8 | - cffi=1.12.2
9 | - cudatoolkit=9.0
10 | - cudnn=7.3.1
11 | - freetype=2.9.1
12 | - intel-openmp=2019.1
13 | - jpeg=9b
14 | - libedit=3.1.20181209
15 | - libffi=3.2.1
16 | - libgcc-ng=8.2.0
17 | - libgfortran-ng=7.3.0
18 | - libpng=1.6.36
19 | - libstdcxx-ng=8.2.0
20 | - libtiff=4.0.10
21 | - mkl=2019.1
22 | - mkl_fft=1.0.10
23 | - mkl_random=1.0.2
24 | - ncurses=6.1
25 | - ninja=1.8.2
26 | - numpy=1.16.2
27 | - numpy-base=1.16.2
28 | - olefile=0.46
29 | - openssl=1.1.1b
30 | - pillow=5.4.1
31 | - pip=19.0.3
32 | - pycparser=2.19
33 | - python=3.6.8
34 | - pytorch=1.0.1
35 | - readline=7.0
36 | - setuptools=40.8.0
37 | - six=1.12.0
38 | - sqlite=3.27.2
39 | - tk=8.6.8
40 | - torchvision=0.2.1
41 | - wheel=0.33.1
42 | - xz=5.2.4
43 | - zlib=1.2.11
44 | - zstd=1.3.7
45 |
46 |
--------------------------------------------------------------------------------
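The environment name is fixed by the first line of `env.yml`, so after creation it can be activated as below; note that the pinned `pytorch=1.0.1` / `cudatoolkit=9.0` combination still needs a CUDA-capable GPU driver at runtime:

```
conda env create -f env.yml
conda activate bdn-refremv
```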
/gen_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import itertools
7 | import random
8 | from glob import glob
9 | import argparse
10 |
11 | import cv2
12 | import scipy.misc  # imsave was removed in scipy >= 1.2; pin an older scipy or switch to imageio
13 | import numpy as np
14 | from skimage import color
15 |
16 | from PIL import Image
17 |
18 | SIZES = (3, 5, 7)
19 | SIGMAS = (0, 2)
20 | THRESHOLDS = (0.2, 0.4)
21 |
22 | def get_img_list(folders, ext='.jpg'):
23 | if ext is None:
24 | pattern = '*'
25 | else:
26 | pattern = '*' + ext
27 | return list(itertools.chain.from_iterable(glob(os.path.join(folder, pattern)) for folder in folders))
28 |
29 |
30 | # img1 and img2 are PIL images
31 | def sample_patches(img1, img2, size):
32 | w1, h1 = img1.size
33 | w2, h2 = img2.size
34 | if all(np.array((w1, h1, w2, h2)) >= 256):
35 | th = min(h1, h2)
36 | tw = min(w1, w2)
37 | x1 = random.randint(0, w1 - tw)
38 | y1 = random.randint(0, h1 - th)
39 | x2 = random.randint(0, w2 - tw)
40 | y2 = random.randint(0, h2 - th)
41 | img1 = img1.crop((x1, y1, x1 + tw, y1 + th))
42 | img2 = img2.crop((x2, y2, x2 + tw, y2 + th))
43 | return img1, img2
44 | else:
45 | return None
46 |
47 |
48 | def sample_patch(img, crop_h, crop_w=None):
49 | if crop_w is None:
50 | crop_w = crop_h
51 | h, w, c = img.shape
52 | if h < crop_h or w < crop_w:
53 | return None
54 | j = random.randint(0, h - crop_h)
55 | i = random.randint(0, w - crop_w)
56 | return img[j:j + crop_h, i:i + crop_w, ...]
57 |
58 |
59 | def merge(img1, img2, beta):
60 | return cv2.addWeighted(img1, 1 - beta, img2, beta, 0)
61 |
62 |
63 | def generate_images(opt):
64 | if not opt.test:
65 | train_list_f = os.path.join(opt.dataroot, 'ImageSets', 'Main', 'train.txt')
66 | else:
67 | train_list_f = os.path.join(opt.dataroot, 'ImageSets', 'Main', 'val.txt')
68 | with open(train_list_f) as f:
69 | train_list = f.read().splitlines()
70 |
71 | obs_dir = os.path.join(opt.outf, 'obs')
72 | trans_dir = os.path.join(opt.outf, 'trans')
73 | ref_dir = os.path.join(opt.outf, 'ref')
74 | refb_dir = os.path.join(opt.outf, 'refb')
75 | # label_dir = os.path.join(opt.outf, 'label')
76 |
77 | if not os.path.exists(opt.outf):
78 | os.mkdir(opt.outf)
79 | if not os.path.exists(obs_dir):
80 | os.mkdir(obs_dir)
81 | if not os.path.exists(trans_dir):
82 | os.mkdir(trans_dir)
83 | if not os.path.exists(ref_dir):
84 | os.mkdir(ref_dir)
85 | if not os.path.exists(refb_dir):
86 | os.mkdir(refb_dir)
87 | # if not os.path.exists(label_dir):
88 | # os.mkdir(label_dir)
89 | print('Number of source images: %d' % len(train_list))
90 |
91 | # random_crop = transforms.RandomCrop(opt.imageSize)
92 | # f = open(os.path.join(opt.outf, 'stat.txt'), 'w')
93 | for i in range(opt.numImages):
94 | while True:
95 | T_f, R_f = random.choices(train_list, k=2)
96 |             T = np.array(Image.open(os.path.join(opt.dataroot, 'JPEGImages', T_f + '.jpg')).convert('RGB'))  # convert: guards against grayscale JPEGs
97 |             R = np.array(Image.open(os.path.join(opt.dataroot, 'JPEGImages', R_f + '.jpg')).convert('RGB'))
98 | T_crop = sample_patch(T, opt.imageSize)
99 | R_crop = sample_patch(R, opt.imageSize)
100 | if T_crop is not None and R_crop is not None:
101 | break
102 | # patches = sample_patches(T, R, opt.imageSize)
103 | # if patches is not None:
104 | # T_crop, R_crop = patches
105 | # break
106 | # T_crop = np.array(T_crop)
107 | # R_crop = np.array(R_crop)
108 | beta = random.uniform(*THRESHOLDS)
109 | sigma = random.uniform(*SIGMAS)
110 | size = random.choice(SIZES)
111 | R_blur = cv2.GaussianBlur(R_crop, (size, size), sigma)
112 | I = merge(T_crop, R_blur, beta)
113 | scipy.misc.imsave(os.path.join(obs_dir, '{:06d}.jpg'.format(i + 1)), I)
114 | scipy.misc.imsave(os.path.join(trans_dir, '{:06d}.jpg'.format(i + 1)), T_crop)
115 | scipy.misc.imsave(os.path.join(ref_dir, '{:06d}.jpg'.format(i + 1)), R_crop)
116 | scipy.misc.imsave(os.path.join(refb_dir, '{:06d}.jpg'.format(i + 1)), R_blur)
117 | # f.write('{}\t{}\t{}\t{}\t{}\n'.format(T_f, R_f, beta, size, sigma))
118 |     # f.close()  # leftover from the commented-out stat.txt logging above
119 |
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 |     parser.add_argument('--dataroot', required=True, help='path to source dataset (PASCAL-VOC-style layout with ImageSets/Main and JPEGImages)')
124 | parser.add_argument('--outf', required=True, help='folder to output generated dataset')
125 | parser.add_argument('--numImages', type=int, default=10000, help='number of images to generate')
126 | parser.add_argument('--imageSize', type=int, default=256, help='the height / width of the image')
127 | parser.add_argument('--test', action='store_true', help='generate test images')
128 | opt = parser.parse_args()
129 | print(opt)
130 |
131 | generate_images(opt)
132 |
--------------------------------------------------------------------------------
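A sketch of how the generator might be invoked (paths are placeholders; the script expects a PASCAL-VOC-style layout with `ImageSets/Main/train.txt`, `ImageSets/Main/val.txt` and `JPEGImages/`). Note the output folders are named `obs/`, `trans/`, `ref/` and `refb/`, which would need to be mapped to the `I/`, `B/`, `R/` layout read by `dataset.py`:

```
# training split (reads ImageSets/Main/train.txt)
python gen_data.py --dataroot /path/to/VOC2012 --outf ./data/train --numImages 10000
# test split (reads ImageSets/Main/val.txt)
python gen_data.py --dataroot /path/to/VOC2012 --outf ./data/test --numImages 100 --test
```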
/imgs/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/imgs/overview.jpg
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import torch.nn.functional as F
7 | from torch.nn import init
8 |
9 |
10 | ###############################################################################
11 | # Functions
12 | ###############################################################################
13 | def get_norm_layer(norm_type):
14 | if norm_type == 'batch':
15 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
16 | elif norm_type == 'instance':
17 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
18 |     else:
19 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
20 |     return norm_layer
21 |
22 |
23 | def define_G(input_nc,
24 | output_nc,
25 | ngf,
26 | which_model_netG,
27 | ns,
28 | norm='batch',
29 | use_dropout=False,
30 | gpu_ids=[],
31 | iteration=0,
32 | padding_type='zero',
33 | upsample_type='transpose',
34 | init_type='normal'):
35 | netG = None
36 | use_gpu = len(gpu_ids) > 0
37 | norm_layer = get_norm_layer(norm_type=norm)
38 |
39 | if use_gpu:
40 | assert (torch.cuda.is_available())
41 |
42 | if which_model_netG == 'cascade_unet':
43 | netG = Generator_cascade(
44 | input_nc,
45 | output_nc,
46 | 'unet',
47 | ns,
48 | ngf,
49 | norm_layer=norm_layer,
50 | use_dropout=use_dropout,
51 | gpu_ids=gpu_ids,
52 | iteration=iteration)
53 |     else:
54 |         raise NotImplementedError('Model name [%s] is not recognized' % which_model_netG)
55 | if len(gpu_ids) > 0:
56 | netG.cuda(device=gpu_ids[0])
57 | # init_weights(netG, init_type=init_type)
58 | return netG
59 |
60 |
61 | def print_network(net):
62 | num_params = 0
63 | for param in net.parameters():
64 | num_params += param.numel()
65 | print(net)
66 | print('Total number of parameters: %d' % num_params)
67 |
68 |
69 | ##############################################################################
70 | # Classes
71 | ##############################################################################
72 | class Generator_cascade(nn.Module):
73 | def __init__(self,
74 | input_nc,
75 | output_nc,
76 | base_model,
77 | ns,
78 | ngf=64,
79 | norm_layer=nn.BatchNorm2d,
80 | use_dropout=False,
81 | gpu_ids=[],
82 | iteration=0,
83 | padding_type='zero',
84 | upsample_type='transpose'):
85 | super(Generator_cascade, self).__init__()
86 | self.input_nc = input_nc
87 | self.output_nc = output_nc
88 | self.ngf = ngf
89 | self.gpu_ids = gpu_ids
90 | self.iteration = iteration
91 |
92 | if base_model == 'unet':
93 | self.model1 = UnetGenerator(
94 | input_nc,
95 | output_nc,
96 | ns[0],
97 | ngf,
98 | norm_layer=norm_layer,
99 | use_dropout=use_dropout,
100 | gpu_ids=gpu_ids)
101 | self.model2 = UnetGenerator(
102 | input_nc * 2,
103 | output_nc,
104 | ns[1],
105 | ngf,
106 | norm_layer=norm_layer,
107 | use_dropout=use_dropout,
108 | gpu_ids=gpu_ids)
109 | if self.iteration > 0:
110 | self.model3 = UnetGenerator(
111 | input_nc * 2,
112 | output_nc,
113 | ns[2],
114 | ngf,
115 | norm_layer=norm_layer,
116 | use_dropout=use_dropout,
117 | gpu_ids=gpu_ids)
118 |
119 | def forward(self, input):
120 | x = self.model1(input)
121 | res = [x]
122 | for i in range(self.iteration + 1):
123 | if i % 2 == 0:
124 | xy = torch.cat([x, input], 1)
125 | z = self.model2(xy)
126 | res += [z]
127 | else:
128 | zy = torch.cat([z, input], 1)
129 | x = self.model3(zy)
130 | res += [x]
131 | return res
132 |
133 |
134 | # Defines the Unet generator.
135 | # |num_downs|: number of downsamplings in UNet. For example,
136 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
137 | # at the bottleneck
138 | class UnetGenerator(nn.Module):
139 | def __init__(self,
140 | input_nc,
141 | output_nc,
142 | num_downs,
143 | ngf=64,
144 | norm_layer=nn.BatchNorm2d,
145 | use_dropout=False,
146 | gpu_ids=[]):
147 | super(UnetGenerator, self).__init__()
148 | self.gpu_ids = gpu_ids
149 |
150 | # currently support only input_nc == output_nc
151 | # assert (input_nc == output_nc)
152 |
153 | # construct unet structure
154 | unet_block = UnetSkipConnectionBlock(
155 | ngf * 8,
156 | ngf * 8,
157 | norm_layer=norm_layer,
158 | innermost=True,
159 | use_dropout=use_dropout)
160 | for i in range(num_downs - 5):
161 | unet_block = UnetSkipConnectionBlock(
162 | ngf * 8,
163 | ngf * 8,
164 | unet_block,
165 | norm_layer=norm_layer,
166 | use_dropout=use_dropout)
167 | unet_block = UnetSkipConnectionBlock(
168 | ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
169 | unet_block = UnetSkipConnectionBlock(
170 | ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
171 | unet_block = UnetSkipConnectionBlock(
172 | ngf, ngf * 2, unet_block, norm_layer=norm_layer)
173 | unet_block = UnetSkipConnectionBlock(
174 | output_nc,
175 | ngf,
176 | unet_block,
177 | outermost=True,
178 | norm_layer=norm_layer,
179 | outermost_input_nc=input_nc)
180 |
181 | self.model = unet_block
182 |
183 | def forward(self, input):
184 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
185 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
186 | else:
187 | return self.model(input)
188 |
189 |
190 | # Defines the submodule with skip connection.
191 | # X -------------------identity---------------------- X
192 | # |-- downsampling -- |submodule| -- upsampling --|
193 | class UnetSkipConnectionBlock(nn.Module):
194 | def __init__(self,
195 | outer_nc,
196 | inner_nc,
197 | submodule=None,
198 | outermost=False,
199 | innermost=False,
200 | norm_layer=nn.BatchNorm2d,
201 | use_dropout=False,
202 | outermost_input_nc=-1):
203 | super(UnetSkipConnectionBlock, self).__init__()
204 | self.outermost = outermost
205 |
206 | if outermost and outermost_input_nc > 0:
207 | downconv = nn.Conv2d(
208 | outermost_input_nc,
209 | inner_nc,
210 | kernel_size=4,
211 | stride=2,
212 | padding=1)
213 | else:
214 | downconv = nn.Conv2d(
215 | outer_nc, inner_nc, kernel_size=4, stride=2, padding=1)
216 |
217 | downrelu = nn.LeakyReLU(0.2, True)
218 | downnorm = norm_layer(inner_nc)
219 | uprelu = nn.ReLU(True)
220 | upnorm = norm_layer(outer_nc)
221 |
222 | if outermost:
223 | upconv = nn.ConvTranspose2d(
224 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
225 | down = [downconv]
226 | up = [uprelu, upconv, nn.Tanh()]
227 | model = down + [submodule] + up
228 | elif innermost:
229 | upconv = nn.ConvTranspose2d(
230 | inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
231 | down = [downrelu, downconv]
232 | up = [uprelu, upconv, upnorm]
233 | model = down + up
234 | else:
235 | upconv = nn.ConvTranspose2d(
236 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
237 | down = [downrelu, downconv, downnorm]
238 | up = [uprelu, upconv, upnorm]
239 |
240 | if use_dropout:
241 | model = down + [submodule] + up + [nn.Dropout(0.5)]
242 | else:
243 | model = down + [submodule] + up
244 |
245 | self.model = nn.Sequential(*model)
246 |
247 | def forward(self, x):
248 | x1 = self.model(x)
249 | diff_h = x.size()[2] - x1.size()[2]
250 | diff_w = x.size()[3] - x1.size()[3]
251 | x1 = F.pad(x1, (diff_w // 2, diff_w - diff_w // 2, diff_h // 2,
252 | diff_h - diff_h // 2))
253 | if self.outermost:
254 | return x1
255 | else:
256 | return torch.cat([x1, x], 1)
--------------------------------------------------------------------------------
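A minimal CPU smoke test for the cascade generator, assuming the settings used in `test.sh` (`--ns 7,5,5 --iteration 1`):

```
import torch

import network

# Three stacked U-Nets: 7-, 5- and 5-level, with one extra refinement pass.
netG = network.define_G(3, 3, 64, 'cascade_unet', [7, 5, 5],
                        norm='batch', iteration=1)
netG.eval()  # use running BatchNorm statistics

x = torch.randn(1, 3, 256, 256)  # dummy 256x256 RGB batch
with torch.no_grad():
    res = netG(x)

# With iteration=1 the cascade returns [B from model1, R from model2,
# refined B from model3]; test.py takes res[-1] as B and res[-2] as R
# when len(res) is odd.
print([tuple(t.shape) for t in res])  # three (1, 3, 256, 256) tensors
```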
/output/B_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0001.png
--------------------------------------------------------------------------------
/output/B_0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0002.png
--------------------------------------------------------------------------------
/output/B_0003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0003.png
--------------------------------------------------------------------------------
/output/B_0004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0004.png
--------------------------------------------------------------------------------
/output/B_0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0005.png
--------------------------------------------------------------------------------
/output/B_0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0006.png
--------------------------------------------------------------------------------
/output/B_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0007.png
--------------------------------------------------------------------------------
/output/B_0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0008.png
--------------------------------------------------------------------------------
/output/B_0009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0009.png
--------------------------------------------------------------------------------
/output/B_0010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0010.png
--------------------------------------------------------------------------------
/output/B_0011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0011.png
--------------------------------------------------------------------------------
/output/B_0012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0012.png
--------------------------------------------------------------------------------
/output/B_0013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/B_0013.png
--------------------------------------------------------------------------------
/output/R_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0001.png
--------------------------------------------------------------------------------
/output/R_0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0002.png
--------------------------------------------------------------------------------
/output/R_0003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0003.png
--------------------------------------------------------------------------------
/output/R_0004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0004.png
--------------------------------------------------------------------------------
/output/R_0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0005.png
--------------------------------------------------------------------------------
/output/R_0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0006.png
--------------------------------------------------------------------------------
/output/R_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0007.png
--------------------------------------------------------------------------------
/output/R_0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0008.png
--------------------------------------------------------------------------------
/output/R_0009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0009.png
--------------------------------------------------------------------------------
/output/R_0010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0010.png
--------------------------------------------------------------------------------
/output/R_0011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0011.png
--------------------------------------------------------------------------------
/output/R_0012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0012.png
--------------------------------------------------------------------------------
/output/R_0013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/output/R_0013.png
--------------------------------------------------------------------------------
/samples/0001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0001.jpg
--------------------------------------------------------------------------------
/samples/0002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0002.jpg
--------------------------------------------------------------------------------
/samples/0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0003.jpg
--------------------------------------------------------------------------------
/samples/0004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0004.jpg
--------------------------------------------------------------------------------
/samples/0005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0005.jpg
--------------------------------------------------------------------------------
/samples/0006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0006.jpg
--------------------------------------------------------------------------------
/samples/0007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0007.jpg
--------------------------------------------------------------------------------
/samples/0008.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0008.jpg
--------------------------------------------------------------------------------
/samples/0009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0009.jpg
--------------------------------------------------------------------------------
/samples/0010(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0010(synthetic).png
--------------------------------------------------------------------------------
/samples/0011(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0011(synthetic).png
--------------------------------------------------------------------------------
/samples/0012(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0012(synthetic).png
--------------------------------------------------------------------------------
/samples/0013(synthetic).png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangj1e/bdn-refremv/c9f44478564bbb19100359733f4d987cc6a124a0/samples/0013(synthetic).png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import random
5 | import time
6 | from collections import OrderedDict
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim as optim
12 | import torch.utils.data
13 | import torchvision.datasets as dset
14 | import torchvision.transforms as transforms
15 | import torchvision.utils as vutils
16 | from torch.autograd import Variable
17 | from math import log10
18 | from PIL import Image
19 |
20 | from dataset import ref_dataset
21 | from vutil import save_image
22 | import network
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--dataroot', required=True, help='path to dataset')
26 | parser.add_argument(
27 | '--workers', type=int, help='number of data loading workers', default=2)
28 | parser.add_argument(
29 | '--batchSize', type=int, default=8, help='input batch size')
30 | parser.add_argument(
31 | '--which_model_netG',
32 | type=str,
33 | default='cascade_unet',
34 | help='selects model to use for netG')
35 | parser.add_argument(
36 | '--ns', type=str, default='5', help='number of blocks for each module')
37 | parser.add_argument(
38 | '--netG', default='', help="path to netG (to continue training)")
39 | parser.add_argument(
40 | '--norm',
41 | type=str,
42 | default='batch',
43 | help='instance normalization or batch normalization')
44 | parser.add_argument(
45 | '--use_dropout', action='store_true', help='use dropout for the generator')
46 | parser.add_argument(
47 | '--imageSize',
48 | type=int,
49 | default=256,
50 | help='the height / width of the input image to network')
51 | parser.add_argument(
52 | '--outf',
53 | default='.',
54 | help='folder to output images and model checkpoints')
55 | parser.add_argument('--real', action='store_true', help='test real images')
56 | parser.add_argument(
57 | '--iteration', type=int, default=0, help='number of iterative updates')
58 | parser.add_argument(
59 | '--n_outputs', type=int, default=0, help='number of images to save')
60 |
61 | opt = parser.parse_args()
62 |
63 | str_ids = opt.ns.split(',')
64 | opt.ns = []
65 | for str_id in str_ids:
66 | id = int(str_id)
67 | if id >= 0:
68 | opt.ns.append(id)
69 |
70 | try:
71 | os.makedirs(opt.outf)
72 | except OSError:
73 | pass
74 |
75 | nc = 3
76 | ngf = 64
77 | netG = network.define_G(nc, nc, ngf, opt.which_model_netG, opt.ns, opt.norm,
78 | opt.use_dropout, [], opt.iteration)
79 | if opt.netG != '':
80 | netG.load_state_dict(torch.load(opt.netG))
81 |
82 | transform = transforms.Compose([
83 | # transforms.Scale(opt.imageSize),
84 | # transforms.CenterCrop(opt.imageSize),
85 | transforms.ToTensor(),
86 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
87 | ])
88 |
89 | dataset = ref_dataset(
90 | opt.dataroot,
91 | transform=transform,
92 | target_transform=transform,
93 | rf_transform=transform,
94 | real=opt.real)
95 | assert dataset
96 |
97 | dataloader = torch.utils.data.DataLoader(
98 | dataset,
99 | batch_size=opt.batchSize,
100 | shuffle=False,
101 | num_workers=int(opt.workers))
102 |
103 | input = torch.FloatTensor(opt.batchSize, 3, opt.imageSize, opt.imageSize)
104 | input = input.cuda()
105 | netG.cuda()
106 | netG.eval()
107 |
108 | criterion = nn.MSELoss()  # built for paired evaluation (see the PSNR sketch after this listing); unused when only saving outputs
109 | criterion.cuda()
110 |
111 | for i, data in enumerate(dataloader, 1):
112 | if opt.real:
113 | input_cpu = data
114 | category = 'real'
115 | else:
116 | input_cpu, target_B_cpu, target_R_cpu = data
117 | category = 'test'
118 | input.resize_(input_cpu.size()).copy_(input_cpu)
119 | if opt.which_model_netG.startswith('cascade'):
120 | res = netG(input)
121 | if len(res) % 2 == 1:
122 | output_B, output_R = res[-1], res[-2]
123 | else:
124 | output_B, output_R = res[-2], res[-1]
125 | else:
126 | output_B = netG(input)
127 |
128 | if opt.n_outputs == 0 or i <= opt.n_outputs:
129 | save_image(output_B / 2 + 0.5, '%s/B_%04d.png' % (opt.outf, i))
130 | if opt.which_model_netG.startswith('cascade'):
131 | save_image(output_R / 2 + 0.5, '%s/R_%04d.png' % (opt.outf, i))
132 |
--------------------------------------------------------------------------------
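`test.py` imports `log10` and constructs an MSE criterion but never uses them; on paired data a PSNR could be computed along these lines (a sketch to be placed in the non-`--real` branch of the loop; outputs and targets are in [-1, 1], so both are mapped back to [0, 1] first):

```
# after computing output_B in the loop, with target_B_cpu from the dataloader:
target_B = target_B_cpu.cuda()
mse = criterion(output_B / 2 + 0.5, target_B / 2 + 0.5).item()
psnr = 10 * log10(1 / mse)  # PSNR in dB for images scaled to [0, 1]
print('image %d: PSNR %.2f dB' % (i, psnr))
```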
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python ./test.py --dataroot ./samples \
3 | --batchSize 1 \
4 | --norm batch \
5 | --which_model_netG cascade_unet \
6 | --ns 7,5,5 \
7 | --iteration 1 \
8 | --outf ./output \
9 | --netG ./model/model.pth \
10 | --real
--------------------------------------------------------------------------------
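A hypothetical companion invocation for paired synthetic data: dropping `--real` makes `dataset.py` read `I/`, `B/` and `R/` subfolders under the data root (the path below is a placeholder):

```
python ./test.py --dataroot ./data/test \
        --batchSize 1 \
        --norm batch \
        --which_model_netG cascade_unet \
        --ns 7,5,5 \
        --iteration 1 \
        --outf ./output_syn \
        --netG ./model/model.pth
```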
/vutil.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | import torchvision.utils
4 | import numpy as np
5 |
6 |
7 | def save_image(tensor, filename):
8 |     if tensor.size()[0] == 1:  # batch of one: save the single image directly via PIL
9 | tensor = tensor.cpu()[0, ...]
10 | ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
11 | im = Image.fromarray(ndarr)
12 | im.save(filename)
13 | else:
14 |         # grid writer for larger batches; `range` is only honoured with normalize=True
15 |         torchvision.utils.save_image(tensor, filename, normalize=False)
16 |
--------------------------------------------------------------------------------
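Usage sketch for `save_image`: tensors are expected in [0, 1] (test.py rescales with `output / 2 + 0.5` before calling); a batch of one is written directly through PIL, larger batches fall through to torchvision's grid writer:

```
import torch
from vutil import save_image

img = torch.rand(1, 3, 64, 64)   # one RGB image with values in [0, 1]
save_image(img, 'example.png')   # single image: saved directly, no grid padding
```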