├── .gitignore ├── LICENSE ├── README.md ├── __pycache__ ├── block.cpython-36.pyc ├── data_loader.cpython-36.pyc ├── data_loader.cpython-38.pyc ├── model.cpython-36.pyc ├── solver.cpython-36.pyc ├── solver.cpython-38.pyc ├── utils.cpython-36.pyc └── utils.cpython-38.pyc ├── block.py ├── data_loader.py ├── input └── t10.png ├── logger.py ├── model.py ├── requirements.txt ├── solver.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | Result/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Siyeong 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Recursive HDRI in PyTorch 2 | [paper](https://openaccess.thecvf.com/content_ECCV_2018/papers/Siyeong_Lee_Deep_Recursive_HDRI_ECCV_2018_paper.pdf) 3 | 4 | We provide PyTorch implementations for GAN-based multiple exposure stack generation. 5 | - [x] Deep recursive HDRI 6 | 7 | ## General 8 | If you use the code for your research work, please cite our paper.
9 | 10 | ``` 11 | @inproceedings{lee2018deep, 12 | title={Deep recursive hdri: Inverse tone mapping using generative adversarial networks}, 13 | author={Lee, Siyeong and Hwan An, Gwon and Kang, Suk-Ju}, 14 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 15 | pages={596--611}, 16 | year={2018} 17 | } 18 | ``` 19 | 20 | ### Model inference 21 | * Conda environment 22 | ``` 23 | conda create -n hdr python=3.6 24 | conda activate hdr 25 | conda install -c anaconda mkl 26 | conda install pytorch==1.0.0 torchvision==0.2.1 cuda100 -c pytorch 27 | ``` 28 | 29 | * Install the requirements 30 | ``` 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | * Please download the two model weights below and organize the downloaded files as follows: 35 | ``` 36 | DeepRecursive_HDRI 37 | ├──Result 38 | └──model 39 | ├── HDRGAN_stopdown_G_param_ch3_batch1_epoch20_lr0.0002.pkl 40 | └── HDRGAN_stopup_G_param_ch3_batch1_epoch20_lr0.0002.pkl 41 | ``` 42 | 43 | * Prepare your test images 44 | ``` 45 | DeepRecursive_HDRI 46 | ├──input 47 | ├── t10.png 48 | ├── t11.png 49 | ``` 50 | 51 | * Run the pretrained model 52 | ``` 53 | python test.py --test_dataset './input' 54 | ``` 55 | 56 | * Output 57 | ``` 58 | DeepRecursive_HDRI 59 | ├──Result 60 | ├── t10 (multi exposure stack) 61 | ├── t11 (multi exposure stack) 62 | ``` 63 | 64 | **Note: We used the HDR Toolbox implementation of [Debevec and Malik 1997] to generate the results in our paper.** 65 | * see https://github.com/banterle/HDR_Toolbox 66 | 67 | ### Model weight 68 | | Model Name | Model Weight | 69 | |:-------------------:|:------------:| 70 | |Deep Recursive HDRI | [stopdown](https://drive.google.com/file/d/1EBNzkpPAlb01baNhw878BTGkmQpjFKdJ/view?usp=sharing) <br> [stopup](https://drive.google.com/file/d/1qiCfOxOn7rfEbNrOvkp1RkpFk91hmvF3/view?usp=sharing) | 71 | 72 | ## License 73 | 74 | Copyright (c) 2020, Siyeong Lee. 75 | All rights reserved. 76 | 77 | The code is distributed under a BSD license. See `LICENSE` for information. 78 | --------------------------------------------------------------------------------
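(Optional) Merging the generated stack into an HDR image: the paper's results were merged with the MATLAB HDR Toolbox linked above. The sketch below is a rough Python alternative using OpenCV's Debevec implementation, not the pipeline used for the paper; the file names follow what `test.py` writes for `input/t10.png` with the default `extend=3`, and the EV-to-exposure-time mapping (one EV step doubles the exposure) is an assumption for illustration.

```python
# Illustrative sketch only; the paper's HDR merges used the MATLAB HDR Toolbox.
# Assumes test.py has produced Result/t10/t10_EV-3.png ... Result/t10/t10_EV3.png.
import cv2
import numpy as np

evs = list(range(-3, 4))
images = [cv2.imread('Result/t10/t10_EV%d.png' % ev) for ev in evs]
assert all(img is not None for img in images), 'missing exposure image'

# Relative exposure times: assume each EV step doubles the exposure time.
times = np.array([2.0 ** ev for ev in evs], dtype=np.float32)

merge = cv2.createMergeDebevec()
hdr = merge.process(images, times)      # float32 radiance map
cv2.imwrite('Result/t10/t10.hdr', hdr)  # written as a Radiance .hdr file
```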
/__pycache__/block.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/block.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/data_loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/data_loader.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/solver.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/solver.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/solver.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class MPReLU(torch.nn.Module):  # mirrored PReLU: identity for x < 0, learned slope a*x for x >= 0 (computed as -PReLU(-x)) 6 | def __init__(self, num_parameters=1, init=0.25): 7 | self.num_parameters = num_parameters 8 | super(MPReLU, self).__init__() 9 | self.weight = torch.nn.Parameter(torch.Tensor(num_parameters).fill_(init)) 10 | 11 | def forward(self, input): 12 | return -torch.nn.functional.prelu(-input, 
self.weight) 13 | 14 | 15 | def __repr__(self): 16 | return self.__class__.__name__ + '(' \ 17 | + 'num_parameters=' + str(self.num_parameters) + ')' 18 | 19 | 20 | class DenseBlock(torch.nn.Module): 21 | def __init__(self, input_size, output_size, bias=True, activation='prelu', norm='batch'): 22 | super(DenseBlock, self).__init__() 23 | self.fc = torch.nn.Linear(input_size, output_size, bias=bias) 24 | 25 | self.norm = norm 26 | if self.norm =='batch': 27 | self.bn = torch.nn.BatchNorm1d(output_size) 28 | elif self.norm == 'instance': 29 | self.bn = torch.nn.InstanceNorm1d(output_size) 30 | 31 | self.activation = activation 32 | if self.activation == 'relu': 33 | self.act = torch.nn.ReLU(True) 34 | elif self.activation == 'prelu': 35 | self.act = torch.nn.PReLU() 36 | elif self.activation == 'lrelu': 37 | self.act = torch.nn.LeakyReLU(0.2, True) 38 | elif self.activation == 'tanh': 39 | self.act = torch.nn.Tanh() 40 | elif self.activation == 'sigmoid': 41 | self.act = torch.nn.Sigmoid() 42 | elif self.activation == 'mprelu': 43 | self.act = MPReLU() 44 | 45 | def forward(self, x): 46 | if self.norm is not None: 47 | out = self.bn(self.fc(x)) 48 | else: 49 | out = self.fc(x) 50 | 51 | if self.activation is not None: 52 | return self.act(out) 53 | else: 54 | return out 55 | 56 | 57 | class ConvBlock(torch.nn.Module): 58 | def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation='prelu', norm='batch'): 59 | super(ConvBlock, self).__init__() 60 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 61 | 62 | self.norm = norm 63 | if self.norm == 'batch': 64 | self.bn = torch.nn.BatchNorm2d(output_size) 65 | elif self.norm == 'instance': 66 | self.bn = torch.nn.InstanceNorm2d(output_size) 67 | 68 | self.activation = activation 69 | if self.activation == 'relu': 70 | self.act = torch.nn.ReLU(True) 71 | elif self.activation == 'prelu': 72 | self.act = torch.nn.PReLU() 73 | elif self.activation == 'lrelu': 74 | self.act = torch.nn.LeakyReLU(0.2, True) 75 | elif self.activation == 'tanh': 76 | self.act = torch.nn.Tanh() 77 | elif self.activation == 'sigmoid': 78 | self.act = torch.nn.Sigmoid() 79 | elif self.activation == 'mprelu': 80 | self.act = MPReLU() 81 | 82 | def forward(self, x): 83 | if self.norm is not None: 84 | out = self.bn(self.conv(x)) 85 | else: 86 | out = self.conv(x) 87 | 88 | if self.activation is not None: 89 | return self.act(out) 90 | else: 91 | return out 92 | 93 | 94 | class Upsample2xBlock(torch.nn.Module): 95 | def __init__(self, input_size, output_size, bias=True, upsample='rnc', activation='relu', norm='batch'): 96 | super(Upsample2xBlock, self).__init__() 97 | scale_factor = 2 98 | 99 | self.upsample = torch.nn.Sequential( 100 | torch.nn.Upsample(scale_factor=scale_factor, mode='nearest'), 101 | ConvBlock(input_size, output_size, 102 | kernel_size=3, stride=1, padding=1, 103 | bias=bias, activation=activation, norm=norm) 104 | ) 105 | 106 | def forward(self, x): 107 | out = self.upsample(x) 108 | return out 109 | 110 | 111 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import join 3 | 4 | from PIL import Image 5 | 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.dataset import Dataset 8 | from torchvision import transforms 9 | 10 | import random 11 | 12 | class data_loader(Dataset): 13 | 
def __init__(self, dataset_path, img_size=64, fliplr=True, fliptb=False, rotate=False, gray=False): 14 | super(data_loader, self).__init__() 15 | 16 | self.img_size = img_size 17 | self.dataset_path = dataset_path 18 | self.fliptb = fliptb 19 | self.fliplr = fliplr 20 | self.rotate = rotate 21 | self.gray = gray 22 | 23 | self.input_img = [join(dataset_path + '/X/', x) for x in sorted(listdir(dataset_path + '/X/')) if is_image_file(x)] 24 | self.target_img = [join(dataset_path + '/y/', y) for y in sorted(listdir(dataset_path + '/y/')) if is_image_file(y)] 25 | 26 | assert len(self.input_img) == len(self.target_img) 27 | 28 | def __getitem__(self, index): 29 | 30 | input_img = load_img(self.input_img[index]) 31 | target_img = load_img(self.target_img[index]) 32 | 33 | if self.rotate: 34 | rv = random.randint(1,3) 35 | input_img = input_img.rotate(90 * rv, expand = True) 36 | target_img = target_img.rotate(90 * rv, expand = True) 37 | 38 | if self.fliplr: 39 | if random.random() < 0.5: 40 | input_img = input_img.transpose(Image.FLIP_LEFT_RIGHT) 41 | target_img = target_img.transpose(Image.FLIP_LEFT_RIGHT) 42 | 43 | if self.fliptb: 44 | if random.random() < 0.5: 45 | input_img = input_img.transpose(Image.FLIP_TOP_BOTTOM) 46 | target_img = target_img.transpose(Image.FLIP_TOP_BOTTOM) 47 | 48 | total = transforms.Compose([transforms.Scale(4*self.img_size), 49 | transforms.ToTensor()#, 50 | #transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | # std=[0.229, 0.224, 0.225]) 52 | ]) 53 | 54 | input_tensor = total(input_img) 55 | target_tensor = total(target_img) 56 | 57 | return input_tensor, target_tensor 58 | 59 | def __len__(self): 60 | return len(self.input_img) 61 | 62 | def load_img(filepath): 63 | img = Image.open(filepath).convert('RGB') 64 | 65 | return img 66 | 67 | def is_image_file(filename): 68 | return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.PNG', '.JPG']) 69 | 70 | 71 | def get_loader(image_path, img_size=64, is_gray=False): 72 | dataset = data_loader(dataset_path=image_path, img_size=img_size) 73 | return dataset 74 | -------------------------------------------------------------------------------- /input/t10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Siyeong-Lee/Deep_Recursive_HDRI/e8dbbb124526e0301ed679709d6d9fab6e3af991/input/t10.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | #code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | # import scipy.misc 6 | import matplotlib.pyplot as plt 7 | 8 | try: 9 | from StringIO import StringIO # Python 2.7 10 | except ImportError: 11 | from io import BytesIO # Python 3.x 12 | 13 | 14 | class Logger(object): 15 | def __init__(self, log_dir): 16 | """Create a summary writer logging to log_dir.""" 17 | self.writer = tf.summary.FileWriter(log_dir) 18 | 19 | def scalar_summary(self, tag, value, step): 20 | """Log a scalar variable.""" 21 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 22 | self.writer.add_summary(summary, step) 23 | self.writer.flush() 24 | 25 | def image_summary(self, tag, images, step): 26 | """Log a list of images.""" 27 | 28 | img_summaries = [] 29 | for i, img in enumerate(images): 30 | # Write the image to a string 31 | try: 32 | s = StringIO() 33 | except: 34 | s = 
BytesIO() 35 | # scipy.misc.toimage(img).save(s, format="png") 36 | plt.imsave(s, img, format='png') 37 | 38 | # Create an Image object 39 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 40 | height=img.shape[0], 41 | width=img.shape[1]) 42 | # Create a Summary value 43 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 44 | 45 | # Create and write Summary 46 | summary = tf.Summary(value=img_summaries) 47 | self.writer.add_summary(summary, step) 48 | self.writer.flush() 49 | 50 | def histo_summary(self, tag, values, step, bins=1000): 51 | """Log a histogram of the tensor of values.""" 52 | 53 | # Create a histogram using numpy 54 | counts, bin_edges = np.histogram(values, bins=bins) 55 | 56 | # Fill the fields of the histogram proto 57 | hist = tf.HistogramProto() 58 | hist.min = float(np.min(values)) 59 | hist.max = float(np.max(values)) 60 | hist.num = int(np.prod(values.shape)) 61 | hist.sum = float(np.sum(values)) 62 | hist.sum_squares = float(np.sum(values ** 2)) 63 | 64 | # Drop the start of the first bin 65 | bin_edges = bin_edges[1:] 66 | 67 | # Add bin edges and counts 68 | for edge in bin_edges: 69 | hist.bucket_limit.append(edge) 70 | for c in counts: 71 | hist.bucket.append(c) 72 | 73 | # Create and write Summary 74 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 75 | self.writer.add_summary(summary, step) 76 | self.writer.flush() 77 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from block import * 3 | import utils 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | class Generator(torch.nn.Module): 10 | def __init__(self, num_channels, base_filter, stop, n_block=4): 11 | super(Generator, self).__init__() 12 | self.stop = stop 13 | 14 | if stop == 'up': 15 | act = 'prelu' 16 | else: 17 | act = 'mprelu' 18 | 19 | self.input_conv = ConvBlock(num_channels, base_filter, 4, 2, 1, activation=act, norm=None, bias=True) 20 | self.conv1 = ConvBlock(1*base_filter, 2*base_filter, 4, 2, 1, activation=act, norm='batch', bias=False) 21 | self.conv2 = ConvBlock(2*base_filter, 4*base_filter, 4, 2, 1, activation=act, norm='batch', bias=False) 22 | self.conv3 = ConvBlock(4*base_filter, 8*base_filter, 4, 2, 1, activation=act, norm='batch', bias=False) 23 | self.conv4 = ConvBlock(8*base_filter, 8*base_filter, 4, 2, 1, activation=act, norm='batch', bias=False) 24 | 25 | self.deconv4 = Upsample2xBlock(8*base_filter, 8*base_filter, activation=act, norm='batch', bias=False, upsample='rnc') 26 | self.deconv5 = Upsample2xBlock(16*base_filter, 4*base_filter, activation=act, norm='batch', bias=False, upsample='rnc') 27 | self.deconv6 = Upsample2xBlock(8*base_filter,2*base_filter, activation=act, norm='batch', bias=False, upsample='rnc') 28 | self.deconv7 = Upsample2xBlock(4*base_filter, 1*base_filter, activation=act, norm='batch', bias=False, upsample='rnc') 29 | self.output_deconv = Upsample2xBlock(2*base_filter, num_channels, activation=act, norm='batch', bias=False, upsample='rnc') 30 | self.output = ConvBlock(2*num_channels, num_channels, 3, 1, 1, activation='tanh', norm=None, bias=False) 31 | 32 | def forward(self, x): 33 | e1 = self.input_conv(x) 34 | e2 = self.conv1(e1) 35 | e3 = self.conv2(e2) 36 | e4 = self.conv3(e3) 37 | e5 = self.conv4(e4) 38 | 39 | d4 = F.dropout(self.deconv4(e5), 0.5, training=True) 
40 | d4 = torch.cat([d4, e4], 1) 41 | 42 | d5 = F.dropout(self.deconv5(d4), 0.5, training=True) 43 | d5 = torch.cat([d5, e3], 1) 44 | 45 | d6 = F.dropout(self.deconv6(d5), 0.5, training=True) 46 | d6 = torch.cat([d6, e2], 1) 47 | 48 | d7 = self.deconv7(d6) 49 | d7 = torch.cat([d7, e1], 1) 50 | 51 | d8 = self.output_deconv(d7) 52 | 53 | in_out = torch.cat([d8, x], 1) 54 | out = self.output(in_out) 55 | return out 56 | 57 | def weight_init(self, mean=0.0, std=0.02): 58 | for m in self.modules(): 59 | utils.weights_init_normal(m, mean=mean, std=std) 60 | 61 | # Defines the PatchGAN discriminator. 62 | class NLayerDiscriminator(nn.Module): 63 | def __init__(self, num_channels, base_filter, image_size, n_layers=4): 64 | super(NLayerDiscriminator, self).__init__() 65 | 66 | kw = 4 67 | padw = 1 68 | 69 | # global feature extraction 70 | sequence = [ 71 | nn.Conv2d(num_channels, base_filter, kernel_size=kw, stride=2, padding=padw), 72 | nn.LeakyReLU(0.2, True) 73 | ] 74 | 75 | nf_mult = 1 76 | nf_mult_prev = 1 77 | for n in range(1, n_layers): 78 | nf_mult_prev = nf_mult 79 | nf_mult = min(2**n, 8) 80 | sequence += [ 81 | nn.Conv2d(base_filter * nf_mult_prev, base_filter * nf_mult, kernel_size=kw, stride=2, 82 | padding=padw), nn.BatchNorm2d(base_filter * nf_mult, 83 | affine=True), nn.LeakyReLU(0.2, True) 84 | ] 85 | 86 | nf_mult_prev = nf_mult 87 | nf_mult = min(2**n_layers, 8) 88 | sequence += [ 89 | nn.Conv2d(base_filter * nf_mult_prev, base_filter * nf_mult, kernel_size=kw, stride=1, 90 | padding=padw), nn.BatchNorm2d(base_filter * nf_mult, 91 | affine=True), nn.LeakyReLU(0.2, True) 92 | ] 93 | 94 | sequence += [nn.Conv2d(base_filter * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 95 | sequence += [nn.Sigmoid()] 96 | 97 | self.model = nn.Sequential(*sequence) 98 | 99 | def forward(self, x): 100 | out = self.model(x) 101 | return out 102 | 103 | def weight_init(self, mean=0.0, std=0.02): 104 | for m in self.modules(): 105 | utils.weights_init_normal(m, mean=mean, std=std) 106 | 107 | class GANLoss(nn.Module): 108 | def __init__(self, use_lsgan=True, target_real_label=0.9, target_fake_label=0.1, 109 | tensor=torch.cuda.FloatTensor): 110 | super(GANLoss, self).__init__() 111 | self.real_label = target_real_label 112 | self.fake_label = target_fake_label 113 | self.real_label_var = None 114 | self.fake_label_var = None 115 | self.Tensor = tensor 116 | if use_lsgan: 117 | self.loss = nn.MSELoss() 118 | else: 119 | self.loss = nn.BCELoss() 120 | 121 | def get_target_tensor(self, input, target_is_real): 122 | target_tensor = None 123 | if target_is_real: 124 | create_label = ((self.real_label_var is None) or 125 | (self.real_label_var.numel() != input.numel())) 126 | if create_label: 127 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 128 | self.real_label_var = Variable(real_tensor, requires_grad=False) 129 | target_tensor = self.real_label_var 130 | else: 131 | create_label = ((self.fake_label_var is None) or 132 | (self.fake_label_var.numel() != input.numel())) 133 | if create_label: 134 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 135 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 136 | target_tensor = self.fake_label_var 137 | return target_tensor 138 | 139 | def __call__(self, input, target_is_real): 140 | target_tensor = self.get_target_tensor(input, target_is_real) 141 | return self.loss(input, target_tensor) 142 | --------------------------------------------------------------------------------
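For orientation, `Generator` in `model.py` is a U-Net: five stride-2 convolutions halve the resolution (so inputs should be multiples of 32), the decoder mirrors them with skip connections, and the final `tanh` layer also sees the input image. The snippet below is an illustrative shape check, not repository code; `base_filter=64` and 3 channels mirror `Solver.build_model`, and it assumes it is run from the repository root.

```python
# Illustrative shape check for the U-Net generator in model.py (not repo code).
import torch
from model import Generator

G = Generator(num_channels=3, base_filter=64, stop='up')
G.weight_init()  # normal(0, 0.02) init, as in Solver.build_model

x = torch.rand(1, 3, 256, 256)  # five stride-2 convs: size must be divisible by 32
with torch.no_grad():
    y = G(x)
print(y.shape)  # torch.Size([1, 3, 256, 256]) -- same size as the input LDR image
```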
/requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | torchvision==0.1.8 3 | matplotlib 4 | imageio 5 | scipy==1.0.0 6 | 7 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | from torch.utils.data import DataLoader 7 | import utils 8 | #from logger import Logger 9 | from model import * 10 | 11 | from data_loader import * 12 | from torchvision import transforms 13 | import numpy as np 14 | 15 | class Solver(object): 16 | def __init__(self, args): 17 | # parameters 18 | self.model_name = args.model_name 19 | self.patch_size = args.patch_size 20 | self.num_threads = args.num_threads 21 | self.exposure_value = args.exposure_value 22 | self.num_channels = args.num_channels 23 | 24 | self.num_epochs = args.num_epochs 25 | self.save_epochs = args.save_epochs 26 | self.batch_size = args.batch_size 27 | self.test_batch_size = args.test_batch_size 28 | self.lr = args.lr 29 | 30 | self.train_dataset = args.train_dataset 31 | self.test_dataset = args.test_dataset 32 | 33 | self.save_dir = args.save_dir 34 | self.gpu_mode = args.gpu_mode 35 | 36 | self.stride = args.stride 37 | 38 | self.build_model() 39 | 40 | def build_model(self): 41 | # networks 42 | self.stopup_G = Generator(num_channels=self.num_channels, base_filter=64, stop='up') 43 | self.stopdown_G = Generator(num_channels=self.num_channels, base_filter=64, stop='down') 44 | 45 | self.stopup_D = NLayerDiscriminator(num_channels=2*self.num_channels,base_filter=64, image_size=self.patch_size) 46 | self.stopdown_D = NLayerDiscriminator(num_channels=2*self.num_channels,base_filter=64, image_size=self.patch_size) 47 | 48 | print('---------- Networks architecture -------------') 49 | utils.print_network(self.stopup_G) 50 | utils.print_network(self.stopdown_D) 51 | print('----------------------------------------------') 52 | 53 | # weight initialization 54 | self.stopup_G.weight_init() 55 | self.stopdown_G.weight_init() 56 | 57 | self.stopup_D.weight_init() 58 | self.stopdown_D.weight_init() 59 | 60 | # optimizer 61 | self.stopup_G_optimizer = optim.Adam(self.stopup_G.parameters(), lr=self.lr, betas=(0.5, 0.999)) 62 | self.stopdown_G_optimizer = optim.Adam(self.stopdown_G.parameters(), lr=self.lr, betas=(0.5, 0.999)) 63 | 64 | self.stopup_D_optimizer = optim.Adam(self.stopup_D.parameters(), lr=self.lr, betas=(0.5, 0.999)) 65 | self.stopdown_D_optimizer = optim.Adam(self.stopdown_D.parameters(), lr=self.lr, betas=(0.5, 0.999)) 66 | 67 | # loss function 68 | if self.gpu_mode: 69 | self.stopup_G = nn.DataParallel(self.stopup_G) 70 | self.stopdown_G = nn.DataParallel(self.stopdown_G) 71 | 72 | self.stopup_D = nn.DataParallel(self.stopup_D) 73 | self.stopdown_D = nn.DataParallel(self.stopdown_D) 74 | 75 | self.stopup_G.cuda() 76 | self.stopdown_G.cuda() 77 | 78 | self.stopup_D.cuda() 79 | self.stopdown_D.cuda() 80 | 81 | self.L1_loss = nn.L1Loss().cuda() 82 | self.criterionGAN = GANLoss().cuda() 83 | 84 | else: 85 | self.L1_loss = nn.L1Loss() 86 | self.MSE_loss = nn.MSELoss() 87 | self.BCE_loss = nn.BCELoss() 88 | self.criterionGAN = GANLoss() 89 | 90 | return 91 | 92 | def load_dataset(self, dataset, is_train=True): 93 | if self.num_channels == 1: 94 | is_gray = True 95 | else: 96 | is_gray = False 97 | 98 | if is_train: 99 | print('Loading 
train datasets...') 100 | train_set = get_loader(self.train_dataset) 101 | return DataLoader(dataset=train_set, num_workers=self.num_threads, batch_size=self.batch_size, 102 | shuffle=True) 103 | else: 104 | print('Loading test datasets...') 105 | test_set = get_loader(self.test_dataset) 106 | return DataLoader(dataset=test_set, num_workers=self.num_threads, 107 | batch_size=self.test_batch_size, 108 | shuffle=False) 109 | 110 | def train(self): 111 | # load dataset 112 | train_data_loader = self.load_dataset(dataset=self.train_dataset, is_train=True) 113 | test_data_loader = self.load_dataset(dataset=self.test_dataset[0], is_train=False) 114 | 115 | # set the logger 116 | stopup_G_log_dir = os.path.join(self.save_dir, 'stopup_G_logs') 117 | if not os.path.exists(stopup_G_log_dir): 118 | os.mkdir(stopup_G_log_dir) 119 | #stopup_G_logger = Logger(stopup_G_log_dir) 120 | 121 | stopup_D_log_dir = os.path.join(self.save_dir, 'stopup_D_logs') 122 | if not os.path.exists(stopup_D_log_dir): 123 | os.mkdir(stopup_D_log_dir) 124 | #stopup_D_logger = Logger(stopup_D_log_dir) 125 | 126 | 127 | stopdown_G_log_dir = os.path.join(self.save_dir, 'stopdown_G_logs') 128 | if not os.path.exists(stopdown_G_log_dir): 129 | os.mkdir(stopdown_G_log_dir) 130 | #stopdown_G_logger = Logger(stopdown_G_log_dir) 131 | 132 | stopdown_D_log_dir = os.path.join(self.save_dir, 'stopdown_D_logs') 133 | if not os.path.exists(stopdown_D_log_dir): 134 | os.mkdir(stopdown_D_log_dir) 135 | #stopdown_D_logger = Logger(stopdown_D_log_dir) 136 | 137 | 138 | ################# Pre-train generator ################# 139 | self.epoch_pretrain = 10 140 | 141 | # Load pre-trained parameters of generator 142 | if not self.load_model(is_pretrain=True): 143 | # Pre-training generator for 10 epochs 144 | print('Pre-training is started.') 145 | self.stopup_G.train() 146 | self.stopdown_G.train() 147 | 148 | for epoch in range(self.epoch_pretrain): 149 | for iter, (lr, hr) in enumerate(train_data_loader): 150 | # input data (low dynamic image) 151 | if self.num_channels == 1: 152 | x_ = Variable(utils.norm(hr[:, 0].unsqueeze(1), vgg=False)) 153 | y_ = Variable(utils.norm(lr[:, 0].unsqueeze(1), vgg=False)) 154 | else: 155 | x_ = Variable(utils.norm(hr, vgg=False)) 156 | y_ = Variable(utils.norm(lr, vgg=False)) 157 | 158 | if self.gpu_mode: 159 | x_ = x_.cuda() 160 | y_ = y_.cuda() 161 | 162 | # Train generator 163 | self.stopup_G_optimizer.zero_grad() 164 | self.stopdown_G_optimizer.zero_grad() 165 | 166 | ''' 167 | stopup 168 | ''' 169 | stopup_est = self.stopup_G(y_) 170 | 171 | # Content losses 172 | stopup_content_loss = self.L1_loss(stopup_est, x_) 173 | stopup_G_loss = stopup_content_loss 174 | 175 | stopup_G_loss.backward() 176 | self.stopup_G_optimizer.step() 177 | 178 | ''' 179 | stopdown 180 | ''' 181 | stopdown_est = self.stopdown_G(x_) 182 | 183 | # Content losses 184 | stopdown_content_loss = self.L1_loss(stopdown_est, y_) 185 | stopdown_G_loss = stopdown_content_loss 186 | 187 | stopdown_G_loss.backward() 188 | self.stopdown_G_optimizer.step() 189 | 190 | # log 191 | print("Epoch: [%2d] [%4d/%4d] stopup_G: %.6f/stopdown_G: %.6f" 192 | % ((epoch + 1), (iter + 1), len(train_data_loader), stopup_G_loss.data[0], stopdown_G_loss.data[0]), end='\r') 193 | 194 | # siyeong 195 | if (iter % 100 == 0): 196 | import random 197 | index = random.randrange(0,self.batch_size) 198 | 199 | input_data = torch.cat((y_[index], x_[index]), 1) 200 | est_data = torch.cat((stopup_est[index], stopdown_est[index]),1) 201 | square = torch.cat((input_data, 
est_data), 2) 202 | square = utils.denorm(square.cpu().data, vgg=False) 203 | 204 | square_img = transforms.ToPILImage()(square) 205 | 206 | square_img.show() 207 | 208 | print('Pre-training is finished.') 209 | 210 | # Save pre-trained parameters of generator 211 | self.save_model(is_pretrain=True) 212 | 213 | ################# Adversarial train ################# 214 | print('Training is started.') 215 | # Avg. losses 216 | stopup_G_avg_loss = [] 217 | stopup_D_avg_loss = [] 218 | stopdown_G_avg_loss = [] 219 | stopdown_D_avg_loss = [] 220 | 221 | step = 0 222 | 223 | # test image 224 | test_lr, test_hr = test_data_loader.dataset.__getitem__(2) 225 | test_lr = test_lr.unsqueeze(0) 226 | test_hr = test_hr.unsqueeze(0) 227 | 228 | self.stopup_G.train() 229 | self.stopup_D.train() 230 | 231 | self.stopdown_G.train() 232 | self.stopdown_D.train() 233 | 234 | for epoch in range(self.num_epochs): 235 | # learning rate is halved every 20 epochs 236 | if (epoch + 1) % 20 == 0: 237 | for param_group in self.stopup_G_optimizer.param_groups: 238 | param_group["lr"] /= 2.0 239 | print("Learning rate decay for G: lr={}".format(self.stopup_G_optimizer.param_groups[0]["lr"])) 240 | for param_group in self.stopup_D_optimizer.param_groups: 241 | param_group["lr"] /= 2.0 242 | print("Learning rate decay for D: lr={}".format(self.stopup_D_optimizer.param_groups[0]["lr"])) 243 | 244 | for param_group in self.stopdown_G_optimizer.param_groups: 245 | param_group["lr"] /= 2.0 246 | print("Learning rate decay for G: lr={}".format(self.stopdown_G_optimizer.param_groups[0]["lr"])) 247 | for param_group in self.stopdown_D_optimizer.param_groups: 248 | param_group["lr"] /= 2.0 249 | print("Learning rate decay for D: lr={}".format(self.stopdown_D_optimizer.param_groups[0]["lr"])) 250 | 251 | stopup_G_epoch_loss = 0 252 | stopup_D_epoch_loss = 0 253 | 254 | stopdown_G_epoch_loss = 0 255 | stopdown_D_epoch_loss = 0 256 | 257 | for iter, (lr, hr) in enumerate(train_data_loader): 258 | # input data (low dynamic image) 259 | mini_batch = lr.size()[0] 260 | 261 | if self.num_channels == 1: 262 | x_ = Variable(utils.norm(hr[:, 0].unsqueeze(1), vgg=False)) 263 | y_ = Variable(utils.norm(lr[:, 0].unsqueeze(1), vgg=False)) 264 | else: 265 | x_ = Variable(utils.norm(hr, vgg=False)) 266 | y_ = Variable(utils.norm(lr, vgg=False)) 267 | 268 | if self.gpu_mode: 269 | x_ = x_.cuda() 270 | y_ = y_.cuda() 271 | # labels 272 | real_label = Variable(torch.ones(mini_batch).cuda()) 273 | fake_label = Variable(torch.zeros(mini_batch).cuda()) 274 | else: 275 | # labels 276 | real_label = Variable(torch.ones(mini_batch)) 277 | fake_label = Variable(torch.zeros(mini_batch)) 278 | 279 | # Reset gradient 280 | self.stopup_D_optimizer.zero_grad() 281 | self.stopdown_D_optimizer.zero_grad() 282 | 283 | # Train discriminator with real data 284 | stopup_D_real_decision = self.stopup_D(torch.cat((x_, y_),1)) 285 | stopdown_D_real_decision = self.stopdown_D(torch.cat((y_, x_),1)) 286 | 287 | stopup_D_real_loss = self.criterionGAN(stopup_D_real_decision, True) 288 | stopdown_D_real_loss = self.criterionGAN(stopdown_D_real_decision, True) 289 | 290 | # Train discriminator with fake data 291 | stopup_est = self.stopup_G(y_) 292 | stopdown_est = self.stopdown_G(x_) 293 | 294 | stopup_D_fake_decision = self.stopup_D(torch.cat((stopup_est, y_),1)) 295 | stopdown_D_fake_decision = self.stopdown_D(torch.cat((stopdown_est, x_),1)) 296 | 297 | stopup_D_fake_loss = self.criterionGAN(stopup_D_fake_decision, False) 298 | stopdown_D_fake_loss = 
self.criterionGAN(stopdown_D_fake_decision, False) 299 | 300 | stopup_D_loss = 0.5*stopup_D_real_loss + 0.5*stopup_D_fake_loss 301 | stopdown_D_loss = 0.5*stopdown_D_real_loss + 0.5*stopdown_D_fake_loss 302 | 303 | # Back propagation 304 | stopup_D_loss.backward(retain_graph=True) 305 | stopdown_D_loss.backward(retain_graph=True) 306 | 307 | self.stopup_D_optimizer.step() 308 | self.stopdown_D_optimizer.step() 309 | 310 | # Reset gradient 311 | self.stopup_G_optimizer.zero_grad() 312 | self.stopdown_G_optimizer.zero_grad() 313 | 314 | # Train generator 315 | stopup_est = self.stopup_G(y_) 316 | stopdown_est = self.stopdown_G(x_) 317 | 318 | stopup_D_fake_decision = self.stopup_D(torch.cat((stopup_est, y_), 1)) 319 | stopdown_D_fake_decision = self.stopdown_D(torch.cat((stopdown_est, x_), 1)) 320 | 321 | # Adversarial loss 322 | stopup_GAN_loss = self.criterionGAN(stopup_D_fake_decision, True) 323 | stopdown_GAN_loss = self.criterionGAN(stopdown_D_fake_decision, True) 324 | 325 | # Content losses 326 | stopup_mae_loss = self.L1_loss(stopup_est, x_) 327 | stopdown_mae_loss = self.L1_loss(stopdown_est, y_) 328 | 329 | # Total loss 330 | stopup_G_loss = stopup_mae_loss + 1e-2*stopup_GAN_loss 331 | stopdown_G_loss = stopdown_mae_loss + 1e-2*stopdown_GAN_loss 332 | 333 | stopup_G_loss.backward() 334 | self.stopup_G_optimizer.step() 335 | 336 | stopdown_G_loss.backward() 337 | self.stopdown_G_optimizer.step() 338 | 339 | # siyeong 340 | if (iter % 100 == 0): 341 | import random 342 | index = random.randrange(0,self.batch_size) 343 | 344 | input_data = torch.cat((y_[index], x_[index]), 1) 345 | est_data = torch.cat((stopup_est[index], stopdown_est[index]),1) 346 | 347 | square = torch.cat((input_data, est_data), 2) 348 | 349 | square = utils.denorm(square.cpu().data, vgg=False) 350 | square_img = transforms.ToPILImage()(square) 351 | 352 | square_img.show() 353 | 354 | # log 355 | stopup_G_epoch_loss += stopup_G_loss.data[0] 356 | stopup_D_epoch_loss += stopup_D_loss.data[0] 357 | 358 | stopdown_G_epoch_loss += stopdown_G_loss.data[0] 359 | stopdown_D_epoch_loss += stopdown_D_loss.data[0] 360 | 361 | print("Epoch: [%02d] [%05d/%05d] stopup_G/D: %.6f/%.6f, stopdown_G/D: %.6f/%.6f" 362 | % ((epoch + 1), (iter + 1), len(train_data_loader), stopup_G_loss.data[0], stopup_D_loss.data[0], stopdown_G_loss.data[0], stopdown_D_loss.data[0]), end="\r") 363 | 364 | # tensorboard logging (disabled: the Logger setup above is commented out, so these calls would raise NameError) 365 | #stopup_G_logger.scalar_summary('losses', stopup_G_loss.data[0], step + 1) 366 | #stopup_D_logger.scalar_summary('losses', stopup_D_loss.data[0], step + 1) 367 | 368 | #stopdown_G_logger.scalar_summary('losses', stopdown_G_loss.data[0], step + 1) 369 | #stopdown_D_logger.scalar_summary('losses', stopdown_D_loss.data[0], step + 1) 370 | 371 | step += 1 372 | 373 | # avg. loss per epoch 374 | stopup_G_avg_loss.append(stopup_G_epoch_loss / len(train_data_loader)) 375 | stopup_D_avg_loss.append(stopup_D_epoch_loss / len(train_data_loader)) 376 | 377 | stopdown_G_avg_loss.append(stopdown_G_epoch_loss / len(train_data_loader)) 378 | stopdown_D_avg_loss.append(stopdown_D_epoch_loss / len(train_data_loader)) 379 | 380 | self.save_model(epoch + 1) 381 | 382 | # Plot avg. loss
383 | utils.plot_loss([stopup_G_avg_loss, stopup_D_avg_loss, stopdown_G_avg_loss, stopdown_D_avg_loss], self.num_epochs, save_dir=self.save_dir) 384 | print("Training is finished.") 385 | 386 | # Save final trained parameters of model 387 | self.save_model(epoch=None) 388 | 389 | # siyeong3 390 | def test(self, input_path='./', out_path='./Result/', extend=3): 391 | # load model 392 | self.load_model(is_pretrain=False) 393 | scenes = listdir(input_path) 394 | for i, scene in enumerate(scenes): 395 | scene_path = join(input_path, scene) 396 | if not os.path.isdir(out_path): 397 | os.mkdir(out_path) 398 | 399 | out_name = os.path.splitext(os.path.split(scene_path)[1])[0] 400 | storage_path = out_path + out_name + '/' 401 | 402 | # mkdir storage folder 403 | if not os.path.isdir(storage_path): 404 | os.mkdir(storage_path) 405 | 406 | # cp middle exposure file 407 | cmd = "cp " + scene_path + " " + storage_path + out_name + '_EV0.png' 408 | os.system(cmd) 409 | 410 | target = scene_path 411 | for i in range(1, extend+1): 412 | reconst = self.image_single(target, True) 413 | output_name = storage_path + out_name + '_EV%d' % i + '.png' 414 | reconst.save(output_name) 415 | 416 | target = output_name 417 | 418 | target = scene_path 419 | for i in range(1, extend+1): 420 | reconst = self.image_single(target, False) 421 | output_name = storage_path + out_name + '_EV-%d' % i + '.png' 422 | reconst.save(output_name) 423 | 424 | target = output_name 425 | print('\tImage [', out_name, '] is finished.') 426 | print('Test is finished.') 427 | 428 | def image_single(self, img_fn, stopup): 429 | # load data 430 | img = Image.open(img_fn).convert('RGB') 431 | 432 | img = img.resize((256, 256), 4) 433 | tensor = transforms.ToTensor()(img) 434 | tensor_norm = Variable(utils.norm(tensor, vgg=False)) 435 | tensor_expand = tensor_norm.unsqueeze(0) 436 | 437 | if stopup: 438 | self.stopup_G.train() 439 | recon_norm = self.stopup_G(tensor_expand) 440 | 441 | else: 442 | self.stopdown_G.train() 443 | recon_norm = self.stopdown_G(tensor_expand) 444 | 445 | recon = utils.denorm(recon_norm.cpu().data, vgg=False) 446 | recon = recon.squeeze(0) 447 | recon = torch.clamp(recon, min=0, max=1) 448 | recon_img = transforms.ToPILImage()(recon) 449 | return recon_img 450 | 451 | def save_model(self, epoch=None, is_pretrain=False): 452 | model_dir = os.path.join(self.save_dir, 'model') 453 | if not os.path.exists(model_dir): 454 | os.mkdir(model_dir) 455 | 456 | if is_pretrain: 457 | torch.save(self.stopup_G.state_dict(), model_dir + '/' + self.model_name + '_stopup_G_param_pretrain.pkl') 458 | torch.save(self.stopdown_G.state_dict(), model_dir + '/' + self.model_name + '_stopdown_G_param_pretrain.pkl') 459 | 460 | print('Pre-trained generator model is saved.') 461 | else: 462 | if epoch is not None: 463 | torch.save(self.stopup_G.state_dict(), model_dir + '/' + self.model_name + 464 | '_stopup_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 465 | % (self.num_channels, self.batch_size, epoch, self.lr)) 466 | torch.save(self.stopup_D.state_dict(), model_dir + '/' + self.model_name + 467 | '_stopup_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 468 | % (self.num_channels, self.batch_size, epoch, self.lr)) 469 | 470 | torch.save(self.stopdown_G.state_dict(), model_dir + '/' + self.model_name + 471 | '_stopdown_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 472 | % (self.num_channels, self.batch_size, epoch, self.lr)) 473 | torch.save(self.stopdown_D.state_dict(), model_dir + '/' + self.model_name + 474 | 
'_stopdown_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 475 | % (self.num_channels, self.batch_size, epoch, self.lr)) 476 | 477 | else: 478 | torch.save(self.stopup_G.state_dict(), model_dir + '/' + self.model_name + 479 | '_stopup_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 480 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 481 | torch.save(self.stopup_D.state_dict(), model_dir + '/' + self.model_name + 482 | '_stopup_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 483 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 484 | torch.save(self.stopdown_G.state_dict(), model_dir + '/' + self.model_name + 485 | '_stopdown_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 486 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 487 | torch.save(self.stopdown_D.state_dict(), model_dir + '/' + self.model_name + 488 | '_stopdown_D_param_ch%d_batch%d_epoch%d_lr%.g.pkl' 489 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr)) 490 | 491 | print('Trained models are saved.') 492 | 493 | def load_model(self, is_pretrain=False): 494 | model_dir = os.path.join(self.save_dir, 'model') 495 | 496 | if is_pretrain: 497 | flag_stopup = False 498 | flag_stopdown = False 499 | 500 | model_name_stopup = model_dir + '/' + self.model_name + '_stopup_G_param_pretrain.pkl' 501 | model_name_stopup_D = model_dir + '/' + self.model_name + '_stopup_D_param_pretrain.pkl' 502 | 503 | if os.path.exists(model_name_stopup): 504 | self.stopup_G.load_state_dict(torch.load(model_name_stopup)) 505 | #self.stopup_D.load_state_dict(torch.load(model_name_stopup_D))  # disabled: save_model(is_pretrain=True) never saves discriminator weights 506 | flag_stopup = True 507 | 508 | model_name_stopdown = model_dir + '/' + self.model_name + '_stopdown_G_param_pretrain.pkl' 509 | model_name_stopdown_D = model_dir + '/' + self.model_name + '_stopdown_D_param_pretrain.pkl' 510 | 511 | if os.path.exists(model_name_stopdown): 512 | self.stopdown_G.load_state_dict(torch.load(model_name_stopdown)) 513 | #self.stopdown_D.load_state_dict(torch.load(model_name_stopdown_D))  # disabled: the pretrain discriminator file is never written 514 | 515 | flag_stopdown = True 516 | 517 | print("[loading] (up):", flag_stopup, ', (down):', flag_stopdown) 518 | print(model_name_stopup) 519 | print(model_name_stopdown) 520 | 521 | if flag_stopdown and flag_stopup: 522 | print('Pre-trained generator model is loaded.') 523 | return True 524 | else: 525 | return False 526 | 527 | else: 528 | flag_stopup = False 529 | flag_stopdown = False 530 | 531 | model_name_stopup = model_dir + '/' + self.model_name + \ 532 | '_stopup_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' \ 533 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr) 534 | print(model_name_stopup) 535 | 536 | if os.path.exists(model_name_stopup): 537 | self.stopup_G.load_state_dict(torch.load(model_name_stopup)) 538 | flag_stopup = True 539 | 540 | model_name_stopdown = model_dir + '/' + self.model_name + \ 541 | '_stopdown_G_param_ch%d_batch%d_epoch%d_lr%.g.pkl' \ 542 | % (self.num_channels, self.batch_size, self.num_epochs, self.lr) 543 | 544 | if os.path.exists(model_name_stopdown): 545 | self.stopdown_G.load_state_dict(torch.load(model_name_stopdown)) 546 | flag_stopdown = True 547 | 548 | print("[loading] (up):", flag_stopup, ', (down):', flag_stopdown) 549 | print(model_name_stopup) 550 | print(model_name_stopdown) 551 | 552 | if flag_stopup and flag_stopdown: 553 | print('Trained generator model is loaded.') 554 | return True 555 | 556 | else: 557 | return False 558 | 559 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os, argparse 3 | 4 | from data_loader import get_loader 5 | from solver import Solver 6 | 7 | """parsing and configuration""" 8 | def parse_args(): 9 | desc = "ECCV 2018: Deep Recursive HDRI" 10 | parser = argparse.ArgumentParser(description=desc) 11 | parser.add_argument('--model_name', type=str, default='HDRGAN', help='The type of model') 12 | parser.add_argument('--data_dir', type=str, default='../Data') 13 | parser.add_argument('--train_dataset', type=str, default='/database/ECCV/minus_ev/train/', help='Train set path') 14 | parser.add_argument('--test_dataset', type=str, default='/database2/Junghee/Stack_HDR_Eye', help='Test dataset') 15 | parser.add_argument('--patch_size', type=int, default=128, help='input patch size') 16 | parser.add_argument('--num_channels', type=int, default=3, help='The number of image channels') 17 | 18 | parser.add_argument('--num_threads', type=int, default=24, help='number of threads for data loader to use') 19 | parser.add_argument('--exposure_value', type=int, default=1, help='exposure value') 20 | 21 | parser.add_argument('--num_epochs', type=int, default=20, help='The number of epochs to run') 22 | parser.add_argument('--save_epochs', type=int, default=1, help='Save the trained model every this many epochs') 23 | parser.add_argument('--batch_size', type=int, default=1, help='training batch size') 24 | parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size') 25 | 26 | parser.add_argument('--save_dir', type=str, default='Result', help='Directory name to save the results') 27 | parser.add_argument('--lr', type=float, default=0.0002) 28 | parser.add_argument('--gpu_mode', type=bool, default=True) 29 | 30 | parser.add_argument('--stride', type=int, default=32) 31 | 32 | return check_args(parser.parse_args()) 33 | 34 | """checking arguments""" 35 | def check_args(args): 36 | # --save_dir 37 | args.save_dir = os.path.join(args.save_dir, args.model_name) 38 | if not os.path.exists(args.save_dir): 39 | os.makedirs(args.save_dir) 40 | 41 | # --epoch 42 | try: 43 | assert args.num_epochs >= 1 44 | except: 45 | print('number of epochs must be larger than or equal to one') 46 | 47 | # --batch_size 48 | try: 49 | assert args.batch_size >= 1 50 | except: 51 | print('batch size must be larger than or equal to one') 52 | 53 | # --stride 54 | try: 55 | assert args.stride < args.patch_size 56 | except: 57 | print('stride should be smaller than the patch size; image reconstruction may fail') 58 | 59 | return args 60 | 61 | """main""" 62 | def main(): 63 | # parse arguments 64 | args = parse_args() 65 | if args is None: 66 | exit() 67 | 68 | if args.gpu_mode and not torch.cuda.is_available(): 69 | raise Exception("No GPU found; please run with --gpu_mode=False") 70 | 71 | # model 72 | net = Solver(args) 73 | 74 | # test 75 | net.test(input_path=args.test_dataset) 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, argparse 3 | 4 | from data_loader import get_loader 5 | from solver import Solver 6 | 7 | """parsing and configuration""" 8 | def parse_args(): 9 | desc = "ECCV 2018: Deep Recursive HDRI" 10 | parser = argparse.ArgumentParser(description=desc) 11 | parser.add_argument('--model_name', type=str, default='HDRGAN', help='The type of model') 12 | 
parser.add_argument('--data_dir', type=str, default='../Data') 13 | parser.add_argument('--train_dataset', type=str, default='/database/hdr/HDRv45960/', help='Train set path') 14 | parser.add_argument('--test_dataset', type=str, default='/database/hdr/HDRv45960/', help='Test dataset') 15 | parser.add_argument('--patch_size', type=int, default=256, help='input patch size') 16 | parser.add_argument('--num_channels', type=int, default=3, help='The number of image channels') 17 | 18 | parser.add_argument('--num_threads', type=int, default=24, help='number of threads for data loader to use') 19 | parser.add_argument('--exposure_value', type=int, default=1, help='exposure value') 20 | 21 | parser.add_argument('--num_epochs', type=int, default=20, help='The number of epochs to run') 22 | parser.add_argument('--save_epochs', type=int, default=1, help='Save the trained model every this many epochs') 23 | parser.add_argument('--batch_size', type=int, default=1, help='training batch size') 24 | parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size') 25 | 26 | parser.add_argument('--save_dir', type=str, default='Result', help='Directory name to save the results') 27 | parser.add_argument('--lr', type=float, default=0.0002) 28 | parser.add_argument('--gpu_mode', type=bool, default=True) 29 | 30 | parser.add_argument('--stride', type=int, default=32) 31 | 32 | return check_args(parser.parse_args()) 33 | 34 | """checking arguments""" 35 | def check_args(args): 36 | # --save_dir 37 | args.save_dir = os.path.join(args.save_dir, args.model_name) 38 | if not os.path.exists(args.save_dir): 39 | os.makedirs(args.save_dir) 40 | 41 | # --epoch 42 | try: 43 | assert args.num_epochs >= 1 44 | except: 45 | print('number of epochs must be larger than or equal to one') 46 | 47 | # --batch_size 48 | try: 49 | assert args.batch_size >= 1 50 | except: 51 | print('batch size must be larger than or equal to one') 52 | 53 | # --stride 54 | try: 55 | assert args.stride < args.patch_size 56 | except: 57 | print('stride should be smaller than the patch size; image reconstruction may fail') 58 | 59 | return args 60 | 61 | """main""" 62 | def main(): 63 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 64 | 65 | # parse arguments 66 | args = parse_args() 67 | if args is None: 68 | exit() 69 | 70 | if args.gpu_mode and not torch.cuda.is_available(): 71 | raise Exception("No GPU found; please run with --gpu_mode=False") 72 | 73 | # model 74 | net = Solver(args) 75 | 76 | # train 77 | net.train() 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | from math import log10 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | import os 10 | import imageio 11 | from scipy.misc import imsave 12 | 13 | import torch.nn.modules as modules 14 | 15 | 16 | def print_network(net): 17 | num_params = 0 18 | for param in net.parameters(): 19 | num_params += param.numel() 20 | print(net) 21 | print('Total number of parameters: %d' % num_params) 22 | 23 | 24 | # For logger 25 | def to_np(x): 26 | return x.data.cpu().numpy() 27 | 28 | 29 | def to_var(x): 30 | if torch.cuda.is_available(): 31 | x = torch.from_numpy(x).cuda() 32 | return Variable(x) 33 | 34 | 35 | # Plot losses 36 | def plot_loss(avg_losses, 
num_epochs, save_dir='', show=False): 37 | fig, ax = plt.subplots() 38 | ax.set_xlim(0, num_epochs) 39 | temp = 0.0 40 | for i in range(len(avg_losses)): 41 | temp = max(np.max(avg_losses[i]), temp) 42 | ax.set_ylim(0, temp*1.1) 43 | plt.xlabel('# of Epochs') 44 | plt.ylabel('Loss values') 45 | 46 | if len(avg_losses) == 1: 47 | plt.plot(avg_losses[0], label='loss') 48 | else: 49 | plt.plot(avg_losses[0], label='G_loss') 50 | plt.plot(avg_losses[1], label='D_loss') 51 | plt.legend() 52 | 53 | # save figure 54 | if not os.path.exists(save_dir): 55 | os.makedirs(save_dir) 56 | 57 | save_fn = 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png' 58 | save_fn = os.path.join(save_dir, save_fn) 59 | plt.savefig(save_fn) 60 | 61 | if show: 62 | plt.show() 63 | else: 64 | plt.close() 65 | 66 | 67 | # Make gif 68 | def make_gif(dataset, num_epochs, save_dir='results/'): 69 | gen_image_plots = [] 70 | for epoch in range(num_epochs): 71 | # plot for generating gif 72 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png' 73 | gen_image_plots.append(imageio.imread(save_fn)) 74 | 75 | imageio.mimsave(save_dir + dataset + '_result_epochs_{:d}'.format(num_epochs) + '.gif', gen_image_plots, fps=5) 76 | 77 | 78 | def weights_init_normal(m, mean=0.0, std=0.02): 79 | classname = m.__class__.__name__ 80 | if classname.find('Linear') != -1: 81 | m.weight.data.normal_(mean, std) 82 | if m.bias is not None: 83 | m.bias.data.zero_() 84 | elif classname.find('Conv2d') != -1: 85 | m.weight.data.normal_(mean, std) 86 | if m.bias is not None: 87 | m.bias.data.zero_() 88 | elif classname.find('ConvTranspose2d') != -1: 89 | m.weight.data.normal_(mean, std) 90 | if m.bias is not None: 91 | m.bias.data.zero_() 92 | elif classname.find('Batch') != -1: 93 | m.weight.data.normal_(1.0, 0.02) 94 | if m.bias is not None: 95 | m.bias.data.zero_() 96 | 97 | 98 | def weights_init_kaming(m): 99 | classname = m.__class__.__name__ 100 | if classname.find('Linear') != -1: 101 | torch.nn.init.kaiming_normal(m.weight) 102 | if m.bias is not None: 103 | m.bias.data.zero_() 104 | elif classname.find('Conv2d') != -1: 105 | torch.nn.init.kaiming_normal(m.weight) 106 | if m.bias is not None: 107 | m.bias.data.zero_() 108 | elif classname.find('ConvTranspose2d') != -1: 109 | torch.nn.init.kaiming_normal(m.weight) 110 | if m.bias is not None: 111 | m.bias.data.zero_() 112 | elif classname.find('Norm') != -1: 113 | m.weight.data.normal_(1.0, 0.02) 114 | if m.bias is not None: 115 | m.bias.data.zero_() 116 | 117 | 118 | def save_img(img, img_num, save_dir='', is_training=False): 119 | # img.clamp(0, 1) 120 | if list(img.shape)[0] == 3: 121 | save_img = img*255.0 122 | save_img = save_img.clamp(0, 255).numpy().transpose(1, 2, 0).astype(np.uint8) 123 | # img = (((img - img.min()) * 255) / (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(np.uint8) 124 | else: 125 | save_img = img.squeeze().clamp(0, 1).numpy() 126 | 127 | # save img 128 | if not os.path.exists(save_dir): 129 | os.makedirs(save_dir) 130 | if is_training: 131 | save_fn = save_dir + '/SR_result_epoch_{:d}'.format(img_num) + '.png' 132 | else: 133 | save_fn = save_dir + '/SR_result_{:d}'.format(img_num) + '.png' 134 | imsave(save_fn, save_img) 135 | 136 | 137 | def plot_test_result(imgs, psnrs, img_num, save_dir='', is_training=False, show_label=True, show=False): 138 | size = list(imgs[0].shape) 139 | if show_label: 140 | h = 3 141 | w = h * len(imgs) 142 | else: 143 | h = size[2] / 100 144 | w = size[1] * len(imgs) / 100 145 | 146 | fig, axes = plt.subplots(1, 
len(imgs), figsize=(w, h)) 147 | # axes.axis('off') 148 | for i, (ax, img, psnr) in enumerate(zip(axes.flatten(), imgs, psnrs)): 149 | ax.axis('off') 150 | ax.set_adjustable('box-forced') 151 | if list(img.shape)[0] == 3: 152 | # Scale to 0-255 153 | # img = (((img - img.min()) * 255) / (img.max() - img.min())).numpy().transpose(1, 2, 0).astype(np.uint8) 154 | img *= 255.0 155 | img = img.clamp(0, 255).numpy().transpose(1, 2, 0).astype(np.uint8) 156 | 157 | ax.imshow(img, cmap=None, aspect='equal') 158 | else: 159 | # img = ((img - img.min()) / (img.max() - img.min())).numpy().transpose(1, 2, 0) 160 | img = img.squeeze().clamp(0, 1).numpy() 161 | ax.imshow(img, cmap='gray', aspect='equal') 162 | 163 | if show_label: 164 | ax.axis('on') 165 | if i == 0: 166 | ax.set_xlabel('HR image') 167 | elif i == 1: 168 | ax.set_xlabel('LR image') 169 | elif i == 2: 170 | ax.set_xlabel('Bicubic (PSNR: %.2fdB)' % psnr) 171 | elif i == 3: 172 | ax.set_xlabel('SR image (PSNR: %.2fdB)' % psnr) 173 | 174 | if show_label: 175 | plt.tight_layout() 176 | else: 177 | plt.subplots_adjust(wspace=0, hspace=0) 178 | plt.subplots_adjust(bottom=0) 179 | plt.subplots_adjust(top=1) 180 | plt.subplots_adjust(right=1) 181 | plt.subplots_adjust(left=0) 182 | 183 | # save figure 184 | result_dir = os.path.join(save_dir, 'plot') 185 | if not os.path.exists(result_dir): 186 | os.makedirs(result_dir) 187 | if is_training: 188 | save_fn = result_dir + '/Train_result_epoch_{:d}'.format(img_num) + '.png' 189 | else: 190 | save_fn = result_dir + '/Test_result_{:d}'.format(img_num) + '.png' 191 | plt.savefig(save_fn) 192 | 193 | if show: 194 | plt.show() 195 | else: 196 | plt.close() 197 | 198 | 199 | def shave(imgs, border_size=0): 200 | size = list(imgs.shape) 201 | if len(size) == 4: 202 | shave_imgs = torch.FloatTensor(size[0], size[1], size[2]-border_size*2, size[3]-border_size*2) 203 | for i, img in enumerate(imgs): 204 | shave_imgs[i, :, :, :] = img[:, border_size:-border_size, border_size:-border_size] 205 | return shave_imgs 206 | else: 207 | return imgs[:, border_size:-border_size, border_size:-border_size] 208 | 209 | 210 | def PSNR(pred, gt): 211 | pred = pred.clamp(0, 1) 212 | # pred = (pred - pred.min()) / (pred.max() - pred.min()) 213 | 214 | diff = pred - gt 215 | mse = np.mean(diff.numpy() ** 2) 216 | if mse == 0: 217 | return 100 218 | return 10 * log10(1.0 / mse) 219 | 220 | 221 | def norm(img, vgg=False): 222 | if vgg: 223 | # normalize for pre-trained vgg model 224 | # https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101 225 | transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], 226 | std=[0.229, 0.224, 0.225]) 227 | return transform(img) 228 | else: 229 | # normalize [-1, 1] 230 | transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], 231 | std=[0.5, 0.5, 0.5]) 232 | return transform(img) 233 | 234 | 235 | def denorm(img, vgg=False): 236 | if vgg: 237 | transform = transforms.Normalize(mean=[-2.118, -2.036, -1.804], 238 | std=[4.367, 4.464, 4.444]) 239 | return transform(img) 240 | else: 241 | out = (img + 1) / 2 242 | #out = img 243 | return out.clamp(0, 1) 244 | 245 | 246 | def img_interp(imgs, scale_factor, interpolation='bicubic'): 247 | if interpolation == 'bicubic': 248 | interpolation = Image.BICUBIC 249 | elif interpolation == 'bilinear': 250 | interpolation = Image.BILINEAR 251 | elif interpolation == 'nearest': 252 | interpolation = Image.NEAREST 253 | 254 | size = list(imgs.shape) 255 | 256 | if len(size) == 4: 257 | target_height = 
int(size[2] * scale_factor) 258 | target_width = int(size[3] * scale_factor) 259 | interp_imgs = torch.FloatTensor(size[0], size[1], target_height, target_width) 260 | for i, img in enumerate(imgs): 261 | transform = transforms.Compose([transforms.ToPILImage(), 262 | transforms.Scale((target_width, target_height), interpolation=interpolation), 263 | transforms.ToTensor()]) 264 | 265 | interp_imgs[i, :, :, :] = transform(img) 266 | return interp_imgs 267 | else: 268 | target_height = int(size[1] * scale_factor) 269 | target_width = int(size[2] * scale_factor) 270 | transform = transforms.Compose([transforms.ToPILImage(), 271 | transforms.Scale((target_width, target_height), interpolation=interpolation), 272 | transforms.ToTensor()]) 273 | return transform(imgs) 274 | 275 | from collections import OrderedDict 276 | def summary(input_size, model): 277 | def register_hook(module): 278 | def hook(module, input, output): 279 | class_name = str(module.__class__).split('.')[-1].split("'")[0] 280 | module_idx = len(summary) 281 | 282 | m_key = '%s-%i' % (class_name, module_idx+1) 283 | summary[m_key] = OrderedDict() 284 | summary[m_key]['input_shape'] = list(input[0].size()) 285 | summary[m_key]['input_shape'][0] = -1 286 | summary[m_key]['output_shape'] = list(output.size()) 287 | summary[m_key]['output_shape'][0] = -1 288 | 289 | params = 0 290 | if hasattr(module, 'weight'): 291 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 292 | if module.weight.requires_grad: 293 | summary[m_key]['trainable'] = True 294 | else: 295 | summary[m_key]['trainable'] = False 296 | if hasattr(module, 'bias'): 297 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 298 | summary[m_key]['nb_params'] = params 299 | 300 | if not isinstance(module, nn.Sequential) and \ 301 | not isinstance(module, nn.ModuleList) and \ 302 | not (module == model): 303 | hooks.append(module.register_forward_hook(hook)) 304 | 305 | # check if there are multiple inputs to the network 306 | if isinstance(input_size[0], (list, tuple)): 307 | x = [Variable(torch.rand(1,*in_size)) for in_size in input_size] 308 | else: 309 | x = Variable(torch.rand(1,*input_size)) 310 | 311 | # create properties 312 | summary = OrderedDict() 313 | hooks = [] 314 | # register hook 315 | model.apply(register_hook) 316 | # make a forward pass 317 | model(x).cpu() 318 | # remove these hooks 319 | for h in hooks: 320 | h.remove() 321 | 322 | return summary 323 | --------------------------------------------------------------------------------
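For reference, a minimal stand-alone inference sketch mirroring `Solver.image_single` (an illustrative addition, not repository code). The weight path follows the README layout and is an assumption; adjust it if your weights live elsewhere (e.g., `Result/HDRGAN/model/` when produced by `solver.py`). Weights saved from the `nn.DataParallel`-wrapped generators carry a `module.` key prefix, which is stripped before loading.

```python
# Illustrative stand-alone inference (not repo code); run from the repository root.
import torch
from PIL import Image
from torchvision import transforms

import utils
from model import Generator

G = Generator(num_channels=3, base_filter=64, stop='up')
state = torch.load('Result/model/HDRGAN_stopup_G_param_ch3_batch1_epoch20_lr0.0002.pkl',
                   map_location='cpu')  # path is an assumption (README layout)
state = {k.replace('module.', '', 1): v for k, v in state.items()}  # drop DataParallel prefix
G.load_state_dict(state)
G.train()  # intentional: image_single keeps dropout/batch norm in train mode at test time

img = Image.open('input/t10.png').convert('RGB').resize((256, 256), Image.BICUBIC)
x = utils.norm(transforms.ToTensor()(img), vgg=False).unsqueeze(0)  # [0,1] -> [-1,1]
with torch.no_grad():
    y = G(x)
out = utils.denorm(y.squeeze(0), vgg=False)  # back to [0,1], clamped
transforms.ToPILImage()(out).save('t10_EV1.png')  # one stop-up (EV+1) estimate
```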