├── README.md
├── dped
│   └── .gitkeep
├── load_dataset.py
├── models.py
├── models
│   └── .gitkeep
├── models_orig
│   ├── blackberry_orig.data-00000-of-00001
│   ├── blackberry_orig.index
│   ├── iphone_orig.data-00000-of-00001
│   ├── iphone_orig.index
│   ├── sony_orig.data-00000-of-00001
│   └── sony_orig.index
├── results
│   └── .gitkeep
├── ssim.py
├── test_model.py
├── train_model.py
├── utils.py
├── vgg.py
├── vgg_pretrained
│   └── .gitkeep
└── visual_results
    └── .gitkeep
/README.md:
--------------------------------------------------------------------------------
1 | ## DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks
2 |
9 | #### 1. Overview [[Paper]](https://arxiv.org/pdf/1704.02470.pdf) [[Project webpage]](http://people.ee.ethz.ch/~ihnatova/) [[Enhancing RAW photos]](https://github.com/aiff22/PyNET) [[Rendering Bokeh Effect]](https://github.com/aiff22/PyNET-Bokeh)
10 |
11 | The provided code implements the method presented in the paper: an end-to-end deep learning approach for translating ordinary smartphone photos into DSLR-quality images. The trained model can be applied to photos of arbitrary resolution, while the methodology itself generalizes to
12 | any type of digital camera. More visual results can be found [here](http://people.ee.ethz.ch/~ihnatova/#demo).
13 |
14 |
15 | #### 2. Prerequisites
16 |
17 | - Python + Pillow, scipy, numpy, imageio packages
18 | - [TensorFlow 1.x / 2.x](https://www.tensorflow.org/install/) + [CUDA CuDNN](https://developer.nvidia.com/cudnn)
19 | - Nvidia GPU
20 |
21 |
22 | #### 3. First steps
23 |
24 | - Download the pre-trained [VGG-19 model](https://polybox.ethz.ch/index.php/s/7z5bHNg5r5a0g7k) [Mirror](https://drive.google.com/file/d/0BwOLOmqkYj-jMGRwaUR2UjhSNDQ/view?usp=sharing&resourcekey=0-Ff-0HUQsoKJxZ84trhsHpA) and put it into the `vgg_pretrained/` folder
25 | - Download the [DPED dataset](http://people.ee.ethz.ch/~ihnatova/#dataset) (patches for CNN training) and extract it into the `dped/` folder.
26 | This folder should contain three subfolders: `sony/`, `iphone/` and `blackberry/`; the exact paths expected by the scripts are sketched below
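
A minimal sketch of the DPED paths that `load_dataset.py` (training and test patches) and `test_model.py` (full-resolution test images) read, shown for `model=iphone` with the default `dped_dir=dped/` (`sony` and `blackberry` follow the same pattern):

```python
# Paths derived exactly as in load_dataset.py and test_model.py
phone, dped_dir = "iphone", "dped/"
train_phone = dped_dir + phone + "/training_data/" + phone + "/"     # input (phone) patches
train_dslr  = dped_dir + phone + "/training_data/canon/"             # target (DSLR) patches
test_phone  = dped_dir + phone + "/test_data/patches/" + phone + "/"
test_dslr   = dped_dir + phone + "/test_data/patches/canon/"
full_size   = dped_dir + phone + "/test_data/full_size_test_images/"
```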
27 |
28 |
29 |
30 | #### 4. Train the model
31 |
32 | ```bash
33 | python train_model.py model=
34 | ```
35 |
36 | Obligatory parameters:
37 |
38 | >```model```: **```iphone```**, **```blackberry```** or **```sony```**
39 |
40 | Optional parameters and their default values:
41 |
42 | >```batch_size```: **```50```** - batch size [smaller values can lead to unstable training]
43 | >```train_size```: **```30000```** - the number of training patches randomly loaded every ```eval_step``` iterations
44 | >```eval_step```: **```1000```** - every ```eval_step``` iterations the model is saved and the training data is reloaded
45 | >```num_train_iters```: **```20000```** - the number of training iterations
46 | >```learning_rate```: **```5e-4```** - learning rate
47 | >```w_content```: **```10```** - the weight of the content loss
48 | >```w_color```: **```0.5```** - the weight of the color loss
49 | >```w_texture```: **```1```** - the weight of the texture [adversarial] loss
50 | >```w_tv```: **```2000```** - the weight of the total variation loss
51 | >```dped_dir```: **```dped/```** - path to the folder with DPED dataset
52 | >```vgg_dir```: **```vgg_pretrained/imagenet-vgg-verydeep-19.mat```** - path to the pre-trained VGG-19 network
53 |
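The four loss weights above are combined in `train_model.py` into a single generator objective. A minimal sketch with the default weights and dummy scalar loss values (in the actual code the loss terms are TensorFlow tensors computed on each batch):

```python
# loss_generator as composed in train_model.py, here with placeholder numbers
w_content, w_color, w_texture, w_tv = 10, 0.5, 1, 2000                   # default weights
loss_content, loss_color, loss_texture, loss_tv = 1.2, 0.03, 0.7, 1e-4   # dummy loss values
loss_generator = (w_content * loss_content + w_texture * loss_texture
                  + w_color * loss_color + w_tv * loss_tv)
print(loss_generator)
```
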
54 | Example:
55 |
56 | ```bash
57 | python train_model.py model=iphone batch_size=50 dped_dir=dped/ w_color=0.7
58 | ```
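
Note that all options are passed as plain `name=value` tokens rather than `--` flags. A simplified sketch of how `utils.process_command_args` reads them (the real function also handles the remaining parameters, validates the `model` value and prints the final configuration):

```python
# Minimal name=value argument parsing, mirroring utils.process_command_args
import sys

phone, batch_size, w_color = "", 50, 0.5        # defaults as in utils.py
for arg in sys.argv[1:]:
    if arg.startswith("model"):
        phone = arg.split("=")[1]
    if arg.startswith("batch_size"):
        batch_size = int(arg.split("=")[1])
    if arg.startswith("w_color"):
        w_color = float(arg.split("=")[1])
print(phone, batch_size, w_color)
```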
59 |
60 |
61 |
62 | #### 5. Test the provided pre-trained models
63 |
64 | ```bash
65 | python test_model.py model=
66 | ```
67 |
68 | Obligatory parameters:
69 |
70 | >```model```: **```iphone_orig```**, **```blackberry_orig```** or **```sony_orig```**
71 |
72 | Optional parameters:
73 |
74 | >```test_subset```: **```full```**,**```small```** - all 29 or only 5 test images will be processed
75 | >```resolution```: **```orig```**,**```high```**,**```medium```**,**```small```**,**```tiny```** - the resolution of the test images [**```orig```** means original resolution]
76 | >```use_gpu```: **```true```**,**```false```** - run models on GPU or CPU
77 | >```dped_dir```: **```dped/```** - path to the folder with DPED dataset
78 |
79 | Example:
80 |
81 | ```bash
82 | python test_model.py model=iphone_orig test_subset=full resolution=orig use_gpu=true
83 | ```
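
For every test photo the script writes two PNG files to `visual_results/`: the enhanced image and a side-by-side before/after comparison. A sketch of the naming scheme used in `test_model.py` (the photo name `1` is only a hypothetical example):

```python
# Output files produced by test_model.py for a single test photo
phone, photo_name = "iphone_orig", "1"            # "1" is a hypothetical photo name
enhanced_png     = "visual_results/" + phone + "_" + photo_name + "_enhanced.png"
before_after_png = "visual_results/" + phone + "_" + photo_name + "_before_after.png"
print(enhanced_png, before_after_png)
```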
84 |
85 |
86 |
87 | #### 6. Test the obtained models
88 |
89 | ```bash
90 | python test_model.py model=
91 | ```
92 |
93 | Obligatory parameters:
94 |
95 | >```model```: **```iphone```**, **```blackberry```** or **```sony```**
96 |
97 | Optional parameters:
98 |
99 | >```test_subset```: **```full```**,**```small```** - all 29 or only 5 test images will be processed
100 | >```iteration```: **```all```** or **```<number>```** - get visual results for all iterations or for the specific iteration,
101 | > **```<number>```** must be a multiple of ```eval_step```
102 | >```resolution```: **```orig```**,**```high```**,**```medium```**,**```small```**,**```tiny```** - the resolution of the test images [**```orig```** means original resolution]
104 | >```use_gpu```: **```true```**,**```false```** - run models on GPU or CPU
105 | >```dped_dir```: **```dped/```** - path to the folder with DPED dataset
106 |
107 | Example:
108 |
109 | ```bash
110 | python test_model.py model=iphone iteration=13000 test_subset=full resolution=orig use_gpu=true
111 | ```
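
These models are the checkpoints that `train_model.py` writes to `models/` every `eval_step` iterations. A sketch of the file names that `test_model.py` looks for (assuming the default `eval_step=1000` and `num_train_iters=20000`):

```python
# Checkpoints saved by train_model.py as models/<model>_iteration_<i>.ckpt
eval_step, num_train_iters = 1000, 20000
checkpoints = ["models/iphone_iteration_%d.ckpt" % i
               for i in range(0, num_train_iters, eval_step)]
print(checkpoints[0], checkpoints[-1])
```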
112 |
113 |
114 | #### 7. Folder structure
115 |
116 | >```dped/``` - the folder with the DPED dataset
117 | >```models/``` - logs and models that are saved during the training process
118 | >```models_orig/``` - the provided pre-trained models for **```iphone```**, **```sony```** and **```blackberry```**
119 | >```results/``` - visual results for small image patches that are saved while training
120 | >```vgg_pretrained/``` - the folder with the pre-trained VGG-19 network
121 | >```visual_results/``` - processed [enhanced] test images
122 |
123 | >```load_dataset.py``` - python script that loads training data
124 | >```models.py``` - architecture of the image enhancement [resnet] and adversarial networks
125 | >```ssim.py``` - implementation of the MS-SSIM score (see the usage sketch after this list)
126 | >```train_model.py``` - implementation of the training procedure
127 | >```test_model.py``` - applying the pre-trained models to test images
128 | >```utils.py``` - auxiliary functions
129 | >```vgg.py``` - loading the pre-trained vgg-19 network
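
`ssim.py` is used during training to report an MS-SSIM score on the test patches. A minimal usage sketch (inputs are image batches of shape `[batch, height, width, 3]` with values in `0..255`, as in `train_model.py`; the random arrays below are just placeholders):

```python
# Calling the MS-SSIM implementation from ssim.py on two dummy 100x100 patches
import numpy as np
from ssim import MultiScaleSSIM

a = np.random.rand(1, 100, 100, 3) * 255
b = np.random.rand(1, 100, 100, 3) * 255
print(MultiScaleSSIM(a, b, max_val=255))
```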
130 |
131 |
132 |
133 | #### 8. Problems and errors
134 |
135 | ```
136 | What if I get an error: "OOM when allocating tensor with shape [...]"?
137 | ```
138 |
139 | Your GPU does not have enough memory. If this happens during the training process:
140 |
141 | - Decrease the size of the training batch [```batch_size```]. Note however that smaller values can lead to unstable training.
142 |
143 | If this happens while testing the models:
144 |
145 | - Run the model on CPU (set the parameter ```use_gpu``` to **```false```**). Note that this can take up to 5 minutes per image.
146 | - Use cropped images, set the parameter ```resolution``` to:
147 |
148 | > **```high```** - center crop of size ```1680x1260``` pixels
149 | > **```medium```** - center crop of size ```1366x1024``` pixels
150 | > **```small```** - center crop of size ```1024x768``` pixels
151 | > **```tiny```** - center crop of size ```800x600``` pixels
152 |
153 | The lower the resolution, the smaller the part of the image that will be processed.
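
These crops are taken from the center of the full-resolution image by `utils.extract_crop`. A minimal sketch of that computation (sizes taken from `utils.get_resolutions` for the iPhone images and the ```medium``` setting; the zero array stands in for a real photo):

```python
# Center crop as computed by utils.extract_crop for resolution="medium" on iphone images
import numpy as np

full_h, full_w = 1536, 2048      # res_sizes["iphone"]
crop_h, crop_w = 1024, 1366      # res_sizes["medium"]

image = np.zeros((full_h, full_w, 3), dtype=np.float32)   # placeholder image
y_up = (full_h - crop_h) // 2
x_up = (full_w - crop_w) // 2
crop = image[y_up:y_up + crop_h, x_up:x_up + crop_w, :]
print(crop.shape)                # (1024, 1366, 3)
```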
154 |
155 |
156 |
157 | #### 9. Citation
158 |
159 | ```
160 | @inproceedings{ignatov2017dslr,
161 | title={DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks},
162 | author={Ignatov, Andrey and Kobyshev, Nikolay and Timofte, Radu and Vanhoey, Kenneth and Van Gool, Luc},
163 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
164 | pages={3277--3285},
165 | year={2017}
166 | }
167 | ```
168 |
169 |
170 | #### 10. Any further questions?
171 |
172 | ```
173 | Please contact Andrey Ignatov (andrey.ignatoff@gmail.com) for more information
174 | ```
175 |
--------------------------------------------------------------------------------
/dped/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/dped/.gitkeep
--------------------------------------------------------------------------------
/load_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import imageio
3 | import os
4 | import numpy as np
5 | import sys
6 |
7 | def load_test_data(phone, dped_dir, IMAGE_SIZE):
8 |
9 | test_directory_phone = dped_dir + str(phone) + '/test_data/patches/' + str(phone) + '/'
10 | test_directory_dslr = dped_dir + str(phone) + '/test_data/patches/canon/'
11 |
12 | NUM_TEST_IMAGES = len([name for name in os.listdir(test_directory_phone)
13 |                            if os.path.isfile(os.path.join(test_directory_phone, name))])
14 |
15 | test_data = np.zeros((NUM_TEST_IMAGES, IMAGE_SIZE))
16 | test_answ = np.zeros((NUM_TEST_IMAGES, IMAGE_SIZE))
17 |
18 | for i in range(0, NUM_TEST_IMAGES):
19 |
20 |         I = np.asarray(imageio.imread(test_directory_phone + str(i) + '.jpg'))
21 |         I = np.float16(np.reshape(I, [1, IMAGE_SIZE]))/255
22 |         test_data[i, :] = I
23 | 
24 |         I = np.asarray(imageio.imread(test_directory_dslr + str(i) + '.jpg'))
25 |         I = np.float16(np.reshape(I, [1, IMAGE_SIZE]))/255
26 |         test_answ[i, :] = I
27 | 
28 |         if i % 100 == 0:
29 |             print(str(round(i * 100 / NUM_TEST_IMAGES)) + "% done", end="\r")
30 |
31 | return test_data, test_answ
32 |
33 |
34 | def load_batch(phone, dped_dir, TRAIN_SIZE, IMAGE_SIZE):
35 |
36 | train_directory_phone = dped_dir + str(phone) + '/training_data/' + str(phone) + '/'
37 | train_directory_dslr = dped_dir + str(phone) + '/training_data/canon/'
38 |
39 |     NUM_TRAINING_IMAGES = len([name for name in os.listdir(train_directory_phone)
40 |                                if os.path.isfile(os.path.join(train_directory_phone, name))])
41 |
42 | # if TRAIN_SIZE == -1 then load all images
43 |
44 | if TRAIN_SIZE == -1:
45 |         TRAIN_SIZE = NUM_TRAINING_IMAGES
46 |         TRAIN_IMAGES = np.arange(0, TRAIN_SIZE)
47 |     else:
48 |         TRAIN_IMAGES = np.random.choice(np.arange(0, NUM_TRAINING_IMAGES), TRAIN_SIZE, replace=False)
49 |
50 | train_data = np.zeros((TRAIN_SIZE, IMAGE_SIZE))
51 | train_answ = np.zeros((TRAIN_SIZE, IMAGE_SIZE))
52 |
53 | i = 0
54 | for img in TRAIN_IMAGES:
55 |
56 |         I = np.asarray(imageio.imread(train_directory_phone + str(img) + '.jpg'))
57 |         I = np.float16(np.reshape(I, [1, IMAGE_SIZE])) / 255
58 |         train_data[i, :] = I
59 | 
60 |         I = np.asarray(imageio.imread(train_directory_dslr + str(img) + '.jpg'))
61 |         I = np.float16(np.reshape(I, [1, IMAGE_SIZE])) / 255
62 |         train_answ[i, :] = I
63 | 
64 |         i += 1
65 |         if i % 100 == 0:
66 |             print(str(round(i * 100 / TRAIN_SIZE)) + "% done", end="\r")
67 |
68 | return train_data, train_answ
69 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def resnet(input_image):
5 |
6 | with tf.compat.v1.variable_scope("generator"):
7 |
8 |         W1 = weight_variable([9, 9, 3, 64], name="W1"); b1 = bias_variable([64], name="b1");
9 |         c1 = tf.nn.relu(conv2d(input_image, W1) + b1)
10 | 
11 |         # residual 1
12 | 
13 |         W2 = weight_variable([3, 3, 64, 64], name="W2"); b2 = bias_variable([64], name="b2");
14 |         c2 = tf.nn.relu(_instance_norm(conv2d(c1, W2) + b2))
15 | 
16 |         W3 = weight_variable([3, 3, 64, 64], name="W3"); b3 = bias_variable([64], name="b3");
17 |         c3 = tf.nn.relu(_instance_norm(conv2d(c2, W3) + b3)) + c1
18 | 
19 |         # residual 2
20 | 
21 |         W4 = weight_variable([3, 3, 64, 64], name="W4"); b4 = bias_variable([64], name="b4");
22 |         c4 = tf.nn.relu(_instance_norm(conv2d(c3, W4) + b4))
23 | 
24 |         W5 = weight_variable([3, 3, 64, 64], name="W5"); b5 = bias_variable([64], name="b5");
25 |         c5 = tf.nn.relu(_instance_norm(conv2d(c4, W5) + b5)) + c3
26 | 
27 |         # residual 3
28 | 
29 |         W6 = weight_variable([3, 3, 64, 64], name="W6"); b6 = bias_variable([64], name="b6");
30 |         c6 = tf.nn.relu(_instance_norm(conv2d(c5, W6) + b6))
31 | 
32 |         W7 = weight_variable([3, 3, 64, 64], name="W7"); b7 = bias_variable([64], name="b7");
33 |         c7 = tf.nn.relu(_instance_norm(conv2d(c6, W7) + b7)) + c5
34 | 
35 |         # residual 4
36 | 
37 |         W8 = weight_variable([3, 3, 64, 64], name="W8"); b8 = bias_variable([64], name="b8");
38 |         c8 = tf.nn.relu(_instance_norm(conv2d(c7, W8) + b8))
39 | 
40 |         W9 = weight_variable([3, 3, 64, 64], name="W9"); b9 = bias_variable([64], name="b9");
41 |         c9 = tf.nn.relu(_instance_norm(conv2d(c8, W9) + b9)) + c7
42 | 
43 |         # Convolutional
44 | 
45 |         W10 = weight_variable([3, 3, 64, 64], name="W10"); b10 = bias_variable([64], name="b10");
46 |         c10 = tf.nn.relu(conv2d(c9, W10) + b10)
47 | 
48 |         W11 = weight_variable([3, 3, 64, 64], name="W11"); b11 = bias_variable([64], name="b11");
49 |         c11 = tf.nn.relu(conv2d(c10, W11) + b11)
50 | 
51 |         # Final
52 | 
53 |         W12 = weight_variable([9, 9, 64, 3], name="W12"); b12 = bias_variable([3], name="b12");
54 |         enhanced = tf.nn.tanh(conv2d(c11, W12) + b12) * 0.58 + 0.5
55 | 
56 |     return enhanced
57 |
58 |
59 | def adversarial(image_):
60 |
61 | with tf.compat.v1.variable_scope("discriminator"):
62 |
63 |         conv1 = _conv_layer(image_, 48, 11, 4, batch_nn = False)
64 |         conv2 = _conv_layer(conv1, 128, 5, 2)
65 |         conv3 = _conv_layer(conv2, 192, 3, 1)
66 |         conv4 = _conv_layer(conv3, 192, 3, 1)
67 |         conv5 = _conv_layer(conv4, 128, 3, 2)
68 | 
69 |         flat_size = 128 * 7 * 7
70 |         conv5_flat = tf.reshape(conv5, [-1, flat_size])
71 | 
72 |         W_fc = tf.Variable(tf.compat.v1.truncated_normal([flat_size, 1024], stddev=0.01))
73 |         bias_fc = tf.Variable(tf.constant(0.01, shape=[1024]))
74 | 
75 |         fc = leaky_relu(tf.matmul(conv5_flat, W_fc) + bias_fc)
76 | 
77 |         W_out = tf.Variable(tf.compat.v1.truncated_normal([1024, 2], stddev=0.01))
78 |         bias_out = tf.Variable(tf.constant(0.01, shape=[2]))
79 | 
80 |         adv_out = tf.nn.softmax(tf.matmul(fc, W_out) + bias_out)
81 | 
82 |     return adv_out
83 |
84 |
85 | def weight_variable(shape, name):
86 |
87 | initial = tf.compat.v1.truncated_normal(shape, stddev=0.01)
88 | return tf.Variable(initial, name=name)
89 |
90 |
91 | def bias_variable(shape, name):
92 |
93 | initial = tf.constant(0.01, shape=shape)
94 | return tf.Variable(initial, name=name)
95 |
96 |
97 | def conv2d(x, W):
98 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
99 |
100 |
101 | def leaky_relu(x, alpha = 0.2):
102 | return tf.maximum(alpha * x, x)
103 |
104 |
105 | def _conv_layer(net, num_filters, filter_size, strides, batch_nn=True):
106 |
107 | weights_init = _conv_init_vars(net, num_filters, filter_size)
108 | strides_shape = [1, strides, strides, 1]
109 | bias = tf.Variable(tf.constant(0.01, shape=[num_filters]))
110 |
111 | net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME') + bias
112 | net = leaky_relu(net)
113 |
114 | if batch_nn:
115 |         net = _instance_norm(net)
116 |
117 | return net
118 |
119 |
120 | def _instance_norm(net):
121 |
122 | batch, rows, cols, channels = [i.value for i in net.get_shape()]
123 | var_shape = [channels]
124 |
125 | mu, sigma_sq = tf.compat.v1.nn.moments(net, [1,2], keepdims=True)
126 | shift = tf.Variable(tf.zeros(var_shape))
127 | scale = tf.Variable(tf.ones(var_shape))
128 |
129 | epsilon = 1e-3
130 | normalized = (net-mu)/(sigma_sq + epsilon)**(.5)
131 |
132 | return scale * normalized + shift
133 |
134 |
135 | def _conv_init_vars(net, out_channels, filter_size, transpose=False):
136 |
137 | _, rows, cols, in_channels = [i.value for i in net.get_shape()]
138 |
139 | if not transpose:
140 |         weights_shape = [filter_size, filter_size, in_channels, out_channels]
141 |     else:
142 |         weights_shape = [filter_size, filter_size, out_channels, in_channels]
143 |
144 | weights_init = tf.Variable(tf.compat.v1.truncated_normal(weights_shape, stddev=0.01, seed=1), dtype=tf.float32)
145 | return weights_init
146 |
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models/.gitkeep
--------------------------------------------------------------------------------
/models_orig/blackberry_orig.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/blackberry_orig.data-00000-of-00001
--------------------------------------------------------------------------------
/models_orig/blackberry_orig.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/blackberry_orig.index
--------------------------------------------------------------------------------
/models_orig/iphone_orig.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/iphone_orig.data-00000-of-00001
--------------------------------------------------------------------------------
/models_orig/iphone_orig.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/iphone_orig.index
--------------------------------------------------------------------------------
/models_orig/sony_orig.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/sony_orig.data-00000-of-00001
--------------------------------------------------------------------------------
/models_orig/sony_orig.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/models_orig/sony_orig.index
--------------------------------------------------------------------------------
/results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/results/.gitkeep
--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import signal
3 | from scipy.ndimage import convolve
4 | import tensorflow as tf
5 |
6 |
7 | def _FSpecialGauss(size, sigma):
8 |
9 | radius = size // 2
10 | offset = 0.0
11 | start, stop = -radius, radius + 1
12 |
13 | if size % 2 == 0:
14 |         offset = 0.5
15 |         stop -= 1
16 |
17 | x, y = np.mgrid[offset + start:stop, offset + start:stop]
18 | g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2)))
19 |
20 | return g / g.sum()
21 |
22 |
23 | def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03):
24 |
25 | img1 = img1.astype(np.float64)
26 | img2 = img2.astype(np.float64)
27 | _, height, width, _ = img1.shape
28 |
29 | size = min(filter_size, height, width)
30 | sigma = size * filter_sigma / filter_size if filter_size else 0
31 |
32 | if filter_size:
33 |
34 |         window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1))
35 |         mu1 = signal.fftconvolve(img1, window, mode='valid')
36 |         mu2 = signal.fftconvolve(img2, window, mode='valid')
37 |         sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid')
38 |         sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid')
39 |         sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid')
40 | 
41 |     else:
42 | 
43 |         mu1, mu2 = img1, img2
44 |         sigma11 = img1 * img1
45 |         sigma22 = img2 * img2
46 |         sigma12 = img1 * img2
47 |
48 | mu11 = mu1 * mu1
49 | mu22 = mu2 * mu2
50 | mu12 = mu1 * mu2
51 | sigma11 -= mu11
52 | sigma22 -= mu22
53 | sigma12 -= mu12
54 |
55 | c1 = (k1 * max_val) ** 2
56 | c2 = (k2 * max_val) ** 2
57 | v1 = 2.0 * sigma12 + c2
58 | v2 = sigma11 + sigma22 + c2
59 |
60 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2)))
61 | cs = np.mean(v1 / v2)
62 |
63 | return ssim, cs
64 |
65 |
66 | def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, weights=None):
67 |
68 | weights = np.array(weights if weights else [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
69 | levels = weights.size
70 |
71 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0
72 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]]
73 |
74 | mssim = np.array([])
75 | mcs = np.array([])
76 |
77 | for _ in range(levels):
78 |
79 |         ssim, cs = _SSIMForMultiScale(im1, im2, max_val=max_val, filter_size=filter_size, filter_sigma=filter_sigma, k1=k1, k2=k2)
80 |         mssim = np.append(mssim, ssim)
81 |         mcs = np.append(mcs, cs)
82 | 
83 |         filtered = [convolve(im, downsample_filter, mode='reflect') for im in [im1, im2]]
84 |         im1, im2 = [x[:, ::2, ::2, :] for x in filtered]
85 |
86 | return np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * (mssim[levels-1] ** weights[levels-1])
87 |
--------------------------------------------------------------------------------
/test_model.py:
--------------------------------------------------------------------------------
1 | # python test_model.py model=iphone_orig dped_dir=dped/ test_subset=full iteration=all resolution=orig use_gpu=true
2 |
3 | import imageio
4 | from PIL import Image
5 | import numpy as np
6 | import tensorflow as tf
7 | from models import resnet
8 | import utils
9 | import os
10 | import sys
11 |
12 | tf.compat.v1.disable_v2_behavior()
13 |
14 | # process command arguments
15 | phone, dped_dir, test_subset, iteration, resolution, use_gpu = utils.process_test_model_args(sys.argv)
16 |
17 | # get all available image resolutions
18 | res_sizes = utils.get_resolutions()
19 |
20 | # get the specified image resolution
21 | IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = utils.get_specified_res(res_sizes, phone, resolution)
22 |
23 | # disable gpu if specified
24 | config = tf.compat.v1.ConfigProto(device_count={'GPU': 0}) if use_gpu == "false" else None
25 |
26 | # create placeholders for input images
27 | x_ = tf.compat.v1.placeholder(tf.float32, [None, IMAGE_SIZE])
28 | x_image = tf.reshape(x_, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
29 |
30 | # generate enhanced image
31 | enhanced = resnet(x_image)
32 |
33 | with tf.compat.v1.Session(config=config) as sess:
34 |
35 | test_dir = dped_dir + phone.replace("_orig", "") + "/test_data/full_size_test_images/"
36 | test_photos = [f for f in os.listdir(test_dir) if os.path.isfile(test_dir + f)]
37 |
38 | if test_subset == "small":
39 |         # use the first five images only
40 |         test_photos = test_photos[0:5]
41 |
42 | if phone.endswith("_orig"):
43 |
44 |         # load pre-trained model
45 |         saver = tf.compat.v1.train.Saver()
46 |         saver.restore(sess, "models_orig/" + phone)
47 | 
48 |         for photo in test_photos:
49 | 
50 |             # load the test image and crop it if necessary
51 | 
52 |             print("Testing original " + phone.replace("_orig", "") + " model, processing image " + photo)
53 |             image = np.float16(np.array(Image.fromarray(imageio.imread(test_dir + photo))
54 |                                         .resize([res_sizes[phone][1], res_sizes[phone][0]]))) / 255
55 | 
56 |             image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
57 |             image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
58 | 
59 |             # get enhanced image
60 | 
61 |             enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
62 |             enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
63 | 
64 |             before_after = np.hstack((image_crop, enhanced_image))
65 |             photo_name = photo.rsplit(".", 1)[0]
66 | 
67 |             # save the results as .png images
68 | 
69 |             imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_enhanced.png", enhanced_image)
70 |             imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_before_after.png", before_after)
71 |
72 | else:
73 |
74 |         num_saved_models = int(len([f for f in os.listdir("models/") if f.startswith(phone + "_iteration")]) / 2)
75 | 
76 |         if iteration == "all":
77 |             iteration = np.arange(1, num_saved_models) * 1000
78 |         else:
79 |             iteration = [int(iteration)]
80 | 
81 |         for i in iteration:
82 | 
83 |             # load pre-trained model
84 |             saver = tf.compat.v1.train.Saver()
85 |             saver.restore(sess, "models/" + phone + "_iteration_" + str(i) + ".ckpt")
86 | 
87 |             for photo in test_photos:
88 | 
89 |                 # load the test image and crop it if necessary
90 | 
91 |                 print("iteration " + str(i) + ", processing image " + photo)
92 |                 image = np.float16(np.array(Image.fromarray(imageio.imread(test_dir + photo))
93 |                                             .resize([res_sizes[phone][1], res_sizes[phone][0]]))) / 255
94 | 
95 |                 image_crop = utils.extract_crop(image, resolution, phone, res_sizes)
96 |                 image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
97 | 
98 |                 # get enhanced image
99 | 
100 |                 enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
101 |                 enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
102 | 
103 |                 before_after = np.hstack((image_crop, enhanced_image))
104 |                 photo_name = photo.rsplit(".", 1)[0]
105 | 
106 |                 # save the results as .png images
107 | 
108 |                 imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_iteration_" + str(i) + "_enhanced.png", enhanced_image)
109 |                 imageio.imwrite("visual_results/" + phone + "_" + photo_name + "_iteration_" + str(i) + "_before_after.png", before_after)
110 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | # python train_model.py model={iphone,sony,blackberry} dped_dir=dped vgg_dir=vgg_pretrained/imagenet-vgg-verydeep-19.mat
2 |
3 | import tensorflow as tf
4 | import imageio
5 | import numpy as np
6 | import sys
7 |
8 | tf.compat.v1.disable_v2_behavior()
9 |
10 | from load_dataset import load_test_data, load_batch
11 | from ssim import MultiScaleSSIM
12 | import models
13 | import utils
14 | import vgg
15 |
16 | # defining size of the training image patches
17 |
18 | PATCH_WIDTH = 100
19 | PATCH_HEIGHT = 100
20 | PATCH_SIZE = PATCH_WIDTH * PATCH_HEIGHT * 3
21 |
22 | # processing command arguments
23 |
24 | phone, batch_size, train_size, learning_rate, num_train_iters, \
25 | w_content, w_color, w_texture, w_tv, \
26 | dped_dir, vgg_dir, eval_step = utils.process_command_args(sys.argv)
27 |
28 | np.random.seed(0)
29 |
30 | # defining system architecture
31 |
32 | with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
33 |
34 | # placeholders for training data
35 |
36 | phone_ = tf.compat.v1.placeholder(tf.float32, [None, PATCH_SIZE])
37 | phone_image = tf.reshape(phone_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])
38 |
39 | dslr_ = tf.compat.v1.placeholder(tf.float32, [None, PATCH_SIZE])
40 | dslr_image = tf.reshape(dslr_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])
41 |
42 | adv_ = tf.compat.v1.placeholder(tf.float32, [None, 1])
43 |
44 | # get processed enhanced image
45 |
46 | enhanced = models.resnet(phone_image)
47 |
48 | # transform both dslr and enhanced images to grayscale
49 |
50 | enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced), [-1, PATCH_WIDTH * PATCH_HEIGHT])
51 | dslr_gray = tf.reshape(tf.image.rgb_to_grayscale(dslr_image),[-1, PATCH_WIDTH * PATCH_HEIGHT])
52 |
53 | # push randomly the enhanced or dslr image to an adversarial CNN-discriminator
54 |
55 | adversarial_ = tf.multiply(enhanced_gray, 1 - adv_) + tf.multiply(dslr_gray, adv_)
56 | adversarial_image = tf.reshape(adversarial_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 1])
57 |
58 | discrim_predictions = models.adversarial(adversarial_image)
59 |
60 | # losses
61 | # 1) texture (adversarial) loss
62 |
63 | discrim_target = tf.concat([adv_, 1 - adv_], 1)
64 |
65 | loss_discrim = -tf.reduce_sum(discrim_target * tf.compat.v1.log(tf.clip_by_value(discrim_predictions, 1e-10, 1.0)))
66 | loss_texture = -loss_discrim
67 |
68 | correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1), tf.argmax(discrim_target, 1))
69 | discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
70 |
71 | # 2) content loss
72 |
73 | CONTENT_LAYER = 'relu5_4'
74 |
75 | enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
76 | dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))
77 |
78 | content_size = utils._tensor_size(dslr_vgg[CONTENT_LAYER]) * batch_size
79 | loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[CONTENT_LAYER] - dslr_vgg[CONTENT_LAYER]) / content_size
80 |
81 | # 3) color loss
82 |
83 | enhanced_blur = utils.blur(enhanced)
84 | dslr_blur = utils.blur(dslr_image)
85 |
86 | loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur, 2))/(2 * batch_size)
87 |
88 | # 4) total variation loss
89 |
90 | batch_shape = (batch_size, PATCH_WIDTH, PATCH_HEIGHT, 3)
91 | tv_y_size = utils._tensor_size(enhanced[:,1:,:,:])
92 | tv_x_size = utils._tensor_size(enhanced[:,:,1:,:])
93 | y_tv = tf.nn.l2_loss(enhanced[:,1:,:,:] - enhanced[:,:batch_shape[1]-1,:,:])
94 | x_tv = tf.nn.l2_loss(enhanced[:,:,1:,:] - enhanced[:,:,:batch_shape[2]-1,:])
95 | loss_tv = 2 * (x_tv/tv_x_size + y_tv/tv_y_size) / batch_size
96 |
97 | # final loss
98 |
99 | loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv
100 |
101 | # psnr loss
102 |
103 | enhanced_flat = tf.reshape(enhanced, [-1, PATCH_SIZE])
104 |
105 | loss_mse = tf.reduce_sum(tf.pow(dslr_ - enhanced_flat, 2))/(PATCH_SIZE * batch_size)
106 | loss_psnr = 20 * utils.log10(1.0 / tf.sqrt(loss_mse))
107 |
108 | # optimize parameters of image enhancement (generator) and discriminator networks
109 |
110 | generator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("generator")]
111 | discriminator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("discriminator")]
112 |
113 | train_step_gen = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
114 | train_step_disc = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)
115 |
116 | saver = tf.compat.v1.train.Saver(var_list=generator_vars, max_to_keep=100)
117 |
118 | print('Initializing variables')
119 | sess.run(tf.compat.v1.global_variables_initializer())
120 |
121 | # loading training and test data
122 |
123 | print("Loading test data...")
124 | test_data, test_answ = load_test_data(phone, dped_dir, PATCH_SIZE)
125 | print("Test data was loaded\n")
126 |
127 | print("Loading training data...")
128 | train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)
129 | print("Training data was loaded\n")
130 |
131 | TEST_SIZE = test_data.shape[0]
132 | num_test_batches = int(test_data.shape[0] / batch_size)
133 |
134 | print('Training network')
135 |
136 | train_loss_gen = 0.0
137 | train_acc_discrim = 0.0
138 |
139 | all_zeros = np.reshape(np.zeros((batch_size, 1)), [batch_size, 1])
140 | test_crops = test_data[np.random.randint(0, TEST_SIZE, 5), :]
141 |
142 | logs = open('models/' + phone + '.txt', "w+")
143 | logs.close()
144 |
145 | for i in range(num_train_iters):
146 |
147 | # train generator
148 |
149 | idx_train = np.random.randint(0, train_size, batch_size)
150 |
151 | phone_images = train_data[idx_train]
152 | dslr_images = train_answ[idx_train]
153 |
154 | [loss_temp, temp] = sess.run([loss_generator, train_step_gen],
155 | feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros})
156 | train_loss_gen += loss_temp / eval_step
157 |
158 | # train discriminator
159 |
160 | idx_train = np.random.randint(0, train_size, batch_size)
161 |
162 | # generate image swaps (dslr or enhanced) for discriminator
163 | swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])
164 |
165 | phone_images = train_data[idx_train]
166 | dslr_images = train_answ[idx_train]
167 |
168 | [accuracy_temp, temp] = sess.run([discim_accuracy, train_step_disc],
169 | feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
170 | train_acc_discrim += accuracy_temp / eval_step
171 |
172 | if i % eval_step == 0:
173 |
174 | # test generator and discriminator CNNs
175 |
176 | test_losses_gen = np.zeros((1, 6))
177 | test_accuracy_disc = 0.0
178 | loss_ssim = 0.0
179 |
180 | for j in range(num_test_batches):
181 |
182 | be = j * batch_size
183 | en = (j+1) * batch_size
184 |
185 | swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])
186 |
187 | phone_images = test_data[be:en]
188 | dslr_images = test_answ[be:en]
189 |
190 | [enhanced_crops, accuracy_disc, losses] = sess.run([enhanced, discim_accuracy, \
191 | [loss_generator, loss_content, loss_color, loss_texture, loss_tv, loss_psnr]], \
192 | feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps})
193 |
194 | test_losses_gen += np.asarray(losses) / num_test_batches
195 | test_accuracy_disc += accuracy_disc / num_test_batches
196 |
197 | loss_ssim += MultiScaleSSIM(np.reshape(dslr_images * 255, [batch_size, PATCH_HEIGHT, PATCH_WIDTH, 3]),
198 | enhanced_crops * 255) / num_test_batches
199 |
200 | logs_disc = "step %d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
201 | (i, phone, train_acc_discrim, test_accuracy_disc)
202 |
203 | logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | psnr: %.4g, ms-ssim: %.4g\n" % \
204 | (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
205 | test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], loss_ssim)
206 |
207 | print(logs_disc)
208 | print(logs_gen)
209 |
210 | # save the results to log file
211 |
212 | logs = open('models/' + phone + '.txt', "a")
213 | logs.write(logs_disc)
214 | logs.write('\n')
215 | logs.write(logs_gen)
216 | logs.write('\n')
217 | logs.close()
218 |
219 | # save visual results for several test image crops
220 |
221 | enhanced_crops = sess.run(enhanced, feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros})
222 |
223 | idx = 0
224 | for crop in enhanced_crops:
225 | before_after = np.hstack((np.reshape(test_crops[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3]), crop))
226 | imageio.imwrite('results/' + str(phone)+ "_" + str(idx) + '_iteration_' + str(i) + '.jpg', before_after)
227 | idx += 1
228 |
229 | train_loss_gen = 0.0
230 | train_acc_discrim = 0.0
231 |
232 | # save the model that corresponds to the current iteration
233 |
234 | saver.save(sess, 'models/' + str(phone) + '_iteration_' + str(i) + '.ckpt', write_meta_graph=False)
235 |
236 | # reload a different batch of training data
237 |
238 | del train_data
239 | del train_answ
240 | train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)
241 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.stats as st
2 | import tensorflow as tf
3 | import numpy as np
4 | import sys
5 |
6 | from functools import reduce
7 |
8 |
9 | def log10(x):
10 | numerator = tf.compat.v1.log(x)
11 | denominator = tf.compat.v1.log(tf.constant(10, dtype=numerator.dtype))
12 | return numerator / denominator
13 |
14 |
15 | def _tensor_size(tensor):
16 | from operator import mul
17 | return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
18 |
19 |
20 | def gauss_kernel(kernlen=21, nsig=3, channels=1):
21 | interval = (2*nsig+1.)/(kernlen)
22 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
23 | kern1d = np.diff(st.norm.cdf(x))
24 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
25 | kernel = kernel_raw/kernel_raw.sum()
26 | out_filter = np.array(kernel, dtype = np.float32)
27 | out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
28 | out_filter = np.repeat(out_filter, channels, axis = 2)
29 | return out_filter
30 |
31 |
32 | def blur(x):
33 | kernel_var = gauss_kernel(21, 3, 3)
34 | return tf.nn.depthwise_conv2d(x, kernel_var, [1, 1, 1, 1], padding='SAME')
35 |
36 |
37 | def process_command_args(arguments):
38 |
39 | # specifying default parameters
40 |
41 | batch_size = 50
42 | train_size = 30000
43 | learning_rate = 5e-4
44 | num_train_iters = 20000
45 |
46 | w_content = 10
47 | w_color = 0.5
48 | w_texture = 1
49 | w_tv = 2000
50 |
51 | dped_dir = 'dped/'
52 | vgg_dir = 'vgg_pretrained/imagenet-vgg-verydeep-19.mat'
53 | eval_step = 1000
54 |
55 | phone = ""
56 |
57 | for args in arguments:
58 |
59 | if args.startswith("model"):
60 | phone = args.split("=")[1]
61 |
62 | if args.startswith("batch_size"):
63 | batch_size = int(args.split("=")[1])
64 |
65 | if args.startswith("train_size"):
66 | train_size = int(args.split("=")[1])
67 |
68 | if args.startswith("learning_rate"):
69 | learning_rate = float(args.split("=")[1])
70 |
71 | if args.startswith("num_train_iters"):
72 | num_train_iters = int(args.split("=")[1])
73 |
74 | # -----------------------------------
75 |
76 | if args.startswith("w_content"):
77 | w_content = float(args.split("=")[1])
78 |
79 | if args.startswith("w_color"):
80 | w_color = float(args.split("=")[1])
81 |
82 | if args.startswith("w_texture"):
83 | w_texture = float(args.split("=")[1])
84 |
85 | if args.startswith("w_tv"):
86 | w_tv = float(args.split("=")[1])
87 |
88 | # -----------------------------------
89 |
90 | if args.startswith("dped_dir"):
91 | dped_dir = args.split("=")[1]
92 |
93 | if args.startswith("vgg_dir"):
94 | vgg_dir = args.split("=")[1]
95 |
96 | if args.startswith("eval_step"):
97 | eval_step = int(args.split("=")[1])
98 |
99 |
100 | if phone == "":
101 | print("\nPlease specify the camera model by running the script with the following parameter:\n")
102 | print("python train_model.py model={iphone,blackberry,sony}\n")
103 | sys.exit()
104 |
105 | if phone not in ["iphone", "sony", "blackberry"]:
106 | print("\nPlease specify the correct camera model:\n")
107 | print("python train_model.py model={iphone,blackberry,sony}\n")
108 | sys.exit()
109 |
110 | print("\nThe following parameters will be applied for CNN training:\n")
111 |
112 | print("Phone model:", phone)
113 | print("Batch size:", batch_size)
114 | print("Learning rate:", learning_rate)
115 | print("Training iterations:", str(num_train_iters))
116 | print()
117 | print("Content loss:", w_content)
118 | print("Color loss:", w_color)
119 | print("Texture loss:", w_texture)
120 | print("Total variation loss:", str(w_tv))
121 | print()
122 | print("Path to DPED dataset:", dped_dir)
123 | print("Path to VGG-19 network:", vgg_dir)
124 | print("Evaluation step:", str(eval_step))
125 | print()
126 | return phone, batch_size, train_size, learning_rate, num_train_iters, \
127 | w_content, w_color, w_texture, w_tv,\
128 | dped_dir, vgg_dir, eval_step
129 |
130 |
131 | def process_test_model_args(arguments):
132 |
133 | phone = ""
134 | dped_dir = 'dped/'
135 | test_subset = "small"
136 | iteration = "all"
137 | resolution = "orig"
138 | use_gpu = "true"
139 |
140 | for args in arguments:
141 |
142 | if args.startswith("model"):
143 | phone = args.split("=")[1]
144 |
145 | if args.startswith("dped_dir"):
146 | dped_dir = args.split("=")[1]
147 |
148 | if args.startswith("test_subset"):
149 | test_subset = args.split("=")[1]
150 |
151 | if args.startswith("iteration"):
152 | iteration = args.split("=")[1]
153 |
154 | if args.startswith("resolution"):
155 | resolution = args.split("=")[1]
156 |
157 | if args.startswith("use_gpu"):
158 | use_gpu = args.split("=")[1]
159 |
160 | if phone == "":
161 | print("\nPlease specify the model by running the script with the following parameter:\n")
162 | print("python test_model.py model={iphone,blackberry,sony,iphone_orig,blackberry_orig,sony_orig}\n")
163 | sys.exit()
164 |
165 | return phone, dped_dir, test_subset, iteration, resolution, use_gpu
166 |
167 |
168 | def get_resolutions():
169 |
170 | # IMAGE_HEIGHT, IMAGE_WIDTH
171 |
172 | res_sizes = {}
173 |
174 | res_sizes["iphone"] = [1536, 2048]
175 | res_sizes["iphone_orig"] = [1536, 2048]
176 | res_sizes["blackberry"] = [1560, 2080]
177 | res_sizes["blackberry_orig"] = [1560, 2080]
178 | res_sizes["sony"] = [1944, 2592]
179 | res_sizes["sony_orig"] = [1944, 2592]
180 | res_sizes["high"] = [1260, 1680]
181 | res_sizes["medium"] = [1024, 1366]
182 | res_sizes["small"] = [768, 1024]
183 | res_sizes["tiny"] = [600, 800]
184 |
185 | return res_sizes
186 |
187 |
188 | def get_specified_res(res_sizes, phone, resolution):
189 |
190 | if resolution == "orig":
191 | IMAGE_HEIGHT = res_sizes[phone][0]
192 | IMAGE_WIDTH = res_sizes[phone][1]
193 | else:
194 | IMAGE_HEIGHT = res_sizes[resolution][0]
195 | IMAGE_WIDTH = res_sizes[resolution][1]
196 |
197 | IMAGE_SIZE = IMAGE_WIDTH * IMAGE_HEIGHT * 3
198 |
199 | return IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE
200 |
201 |
202 | def extract_crop(image, resolution, phone, res_sizes):
203 |
204 | if resolution == "orig":
205 | return image
206 |
207 | else:
208 |
209 | x_up = int((res_sizes[phone][1] - res_sizes[resolution][1]) / 2)
210 | y_up = int((res_sizes[phone][0] - res_sizes[resolution][0]) / 2)
211 |
212 | x_down = x_up + res_sizes[resolution][1]
213 | y_down = y_up + res_sizes[resolution][0]
214 |
215 | return image[y_up : y_down, x_up : x_down, :]
216 |
--------------------------------------------------------------------------------
/vgg.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import scipy.io
4 |
5 | IMAGE_MEAN = np.array([123.68 , 116.779, 103.939])
6 |
7 | def net(path_to_vgg_net, input_image):
8 |
9 | layers = (
10 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
11 |
12 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
13 |
14 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
15 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
16 |
17 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
18 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
19 |
20 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
21 | 'relu5_3', 'conv5_4', 'relu5_4'
22 | )
23 |
24 | data = scipy.io.loadmat(path_to_vgg_net)
25 | weights = data['layers'][0]
26 |
27 | net = {}
28 | current = input_image
29 | for i, name in enumerate(layers):
30 | layer_type = name[:4]
31 | if layer_type == 'conv':
32 | kernels, bias = weights[i][0][0][0][0]
33 | kernels = np.transpose(kernels, (1, 0, 2, 3))
34 | bias = bias.reshape(-1)
35 | current = _conv_layer(current, kernels, bias)
36 | elif layer_type == 'relu':
37 | current = tf.nn.relu(current)
38 | elif layer_type == 'pool':
39 | current = _pool_layer(current)
40 | net[name] = current
41 |
42 | return net
43 |
44 | def _conv_layer(input, weights, bias):
45 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), padding='SAME')
46 | return tf.nn.bias_add(conv, bias)
47 |
48 | def _pool_layer(input):
49 | return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME')
50 |
51 | def preprocess(image):
52 | return image - IMAGE_MEAN
53 |
--------------------------------------------------------------------------------
/vgg_pretrained/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/vgg_pretrained/.gitkeep
--------------------------------------------------------------------------------
/visual_results/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aiff22/DPED/ebb01315238430f7c66eaaf84996fcb59877f97f/visual_results/.gitkeep
--------------------------------------------------------------------------------