├── LICENSE.md ├── README.md ├── ebb_dataset └── .gitkeep ├── load_dataset.py ├── model.py ├── models └── original │ └── .gitkeep ├── results └── full-resolution │ └── .gitkeep ├── test_model.py ├── train_model.py ├── utils.py ├── vgg.py ├── vgg_pretrained └── .gitkeep └── visual_samples ├── depth_maps ├── 1.png ├── 10.png ├── 11.png ├── 12.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png └── 9.png └── images ├── 1.jpg ├── 10.jpg ├── 11.jpg ├── 12.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg ├── 5.jpg ├── 6.jpg ├── 7.jpg ├── 8.jpg └── 9.jpg /LICENSE.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/LICENSE.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Rendering Natural Camera Bokeh Effect with Deep Learning 2 | 3 |
4 | 5 | 6 | 7 |
8 | 9 | #### 1. Overview  [[Paper]](https://arxiv.org/pdf/2006.05698.pdf) [[Project Webpage]](http://people.ee.ethz.ch/~ihnatova/pynet-bokeh.html) [[PyNET PyTorch]](https://github.com/aiff22/PyNET-PyTorch) 10 | 11 | This repository provides the implementation of the deep learning-based bokeh effect rendering approach presented in [this paper](https://arxiv.org/pdf/2006.05698.pdf). The model is trained to map standard **narrow-aperture images** to shallow depth-of-field photos captured with a professional Canon 7D DSLR camera. The presented approach is camera-independent, **does not require any special hardware**, and can also be applied to existing images. More visual results of this method on the presented EBB! dataset, as well as its comparison to the **Portrait Mode** of the *Google Pixel Camera* app, can be found [here](http://people.ee.ethz.ch/~ihnatova/pynet-bokeh.html#demo). 12 | 13 |
14 | 15 | #### 2. Prerequisites 16 | 17 | - Python: scipy, numpy, imageio and pillow packages 18 | - [TensorFlow 1.x / 2.x](https://www.tensorflow.org/install/) + [CUDA cuDNN](https://developer.nvidia.com/cudnn) 19 | - Nvidia GPU 20 | 21 |
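The scripts rely on `scipy.misc.imread` / `scipy.misc.imresize`, which were removed from newer SciPy releases, so an older SciPy is likely required. One possible way to install the Python packages listed above (the exact version pin is our assumption, not taken from this repository):

```bash
pip install numpy imageio pillow
pip install scipy==1.1.0   # an older release that still ships scipy.misc.imread / imresize
```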
22 | 23 | #### 3. First steps 24 | 25 | - Download the pre-trained [VGG-19 model](https://polybox.ethz.ch/index.php/s/7z5bHNg5r5a0g7k) and put it into the `vgg_pretrained/` folder. 26 | - Download the pre-trained [PyNET model](https://data.vision.ee.ethz.ch/ihnatova/public/ebb/PyNET_Bokeh_pretrained.zip) and put it into the `models/original/` folder. 27 | - Download the [EBB! dataset](http://people.ee.ethz.ch/~ihnatova/pynet-bokeh.html#dataset) and extract it into the `ebb_dataset/` folder. 28 | This folder should contain two subfolders, `train/` and `test/`; the expected layout is sketched below. 29 | 30 | *Please note that Google Drive has a quota limiting the number of downloads per day. To avoid it, you can log in to your Google account and press the "Add to My Drive" button instead of a direct download. Please check [this issue](https://github.com/aiff22/PyNET/issues/4) for more information.* 31 | 32 |
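Based on the directory names hard-coded in `load_dataset.py`, the extracted dataset is expected to have roughly the following layout (images and depth maps are numbered consecutively starting from 0):

```
ebb_dataset/
├── train/
│   ├── original/        # 0.jpg, 1.jpg, ...  narrow-aperture input photos
│   ├── original_depth/  # 0.png, 1.png, ...  pre-computed depth maps
│   └── bokeh/           # 0.jpg, 1.jpg, ...  shallow depth-of-field target photos
└── test/
    ├── original/
    ├── original_depth/
    └── bokeh/
```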
33 | 34 | 35 | #### 4. PyNET CNN 36 | 37 |
38 | 39 | drawing 40 | 41 |
42 | 43 | The proposed PyNET-based architecture has an inverted pyramidal shape and processes the images at **seven different scales** (levels). The model is trained sequentially, starting from the lowest (7th) level, which makes it possible to achieve good semantically-driven reconstruction results at the smaller scales: these operate on images of very low resolution and therefore perform mostly global image manipulations. After the bottom level is pre-trained, the same procedure is applied to the next level, until the training reaches the original resolution. Since each higher level receives **upscaled high-quality features** from the lower part of the model, it mainly learns to reconstruct the missing low-level details and to refine the results. In this work, we additionally use two transposed convolutional layers on top of the main model (Levels 1, 2) that upsample the images to their target size. 44 | 45 |
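For orientation, the snippet below mirrors how `train_model.py` selects the output tensor for the level being trained; `PyNET` is the actual function from `model.py`, while `outputs` and `LEVEL` are illustrative names:

```python
# PyNET returns seven outputs ordered from the finest to the coarsest scale;
# train_model.py unpacks them as output_l1, ..., output_l7.
outputs = PyNET(input_, instance_norm=True, instance_norm_level_1=False)

# When training level N, only the corresponding output is optimized:
bokeh_img = outputs[LEVEL - 1]  # LEVEL = 7 is trained first, then 6, ..., then 1
```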
46 | 47 | #### 5. Training the model 48 | 49 | The model is trained level by level, starting from the lowest (7th) one: 50 | 51 | ```bash 52 | python train_model.py level=<level> 53 | ``` 54 | 55 | Required parameters: 56 | 57 | >```level```: **```7, 6, 5, 4, 3, 2, 1```** 58 | 59 | Optional parameters and their default values: 60 | 61 | >```batch_size```: **```50```**   -   batch size [small values can lead to unstable training]
62 | >```train_size```: **```4894```**   -   the number of training images randomly loaded every 1000 iterations
63 | >```eval_step```: **```1000```**   -   every ```eval_step``` iterations, the accuracy is computed and the model is saved
64 | >```learning_rate```: **```5e-5```**   -   learning rate
65 | >```restore_iter```: **```None```**   -   iteration to restore (when not specified, the last saved model for PyNET's ```level+1``` is loaded)
66 | >```num_train_iters```: **```5K, 5K, 20K, 20K, 30K, 80K, 100K (for levels 7 - 1)```**   -   the number of training iterations
67 | >```vgg_dir```: **```vgg_pretrained/imagenet-vgg-verydeep-19.mat```**   -   path to the pre-trained VGG-19 network
68 | >```dataset_dir```: **```ebb_dataset/```**   -   path to the folder with the **EBB! dataset**
69 | 70 |
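The input patches are always 512×512 pixels (plus a fourth depth channel), while the target resolution depends on the level through `DSLR_SCALE = 1 / 2**(level - 2)` defined in `train_model.py`; a quick sketch of the resulting target sizes:

```python
# Target DSLR patch size per training level (values follow from train_model.py):
PATCH_SIZE = 512
for level in range(7, 0, -1):
    dslr_scale = 1.0 / 2 ** (level - 2)
    print(level, int(PATCH_SIZE * dslr_scale))
# -> 7: 16, 6: 32, 5: 64, 4: 128, 3: 256, 2: 512, 1: 1024
```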
71 | 72 | Below we provide the commands used for training the model on an Nvidia Tesla V100 GPU with 16GB of memory. When using GPUs with a smaller amount of memory, the batch size and the number of training iterations should be adjusted accordingly: 73 | 74 | ```bash 75 | python train_model.py level=7 batch_size=50 num_train_iters=5000 76 | python train_model.py level=6 batch_size=50 num_train_iters=5000 77 | python train_model.py level=5 batch_size=40 num_train_iters=20000 78 | python train_model.py level=4 batch_size=14 num_train_iters=20000 79 | python train_model.py level=3 batch_size=9 num_train_iters=30000 80 | python train_model.py level=2 batch_size=9 num_train_iters=80000 81 | python train_model.py level=1 batch_size=5 num_train_iters=100000 82 | ``` 83 | 84 |
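When training any level below 7, the script restores the weights saved for ```level+1```; to point it at a particular checkpoint, ```restore_iter``` can be passed explicitly (the iteration number below is only an example):

```bash
python train_model.py level=4 restore_iter=20000   # loads models/pynet_level_5_iteration_20000.ckpt
```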
85 | 86 | #### 6. Test the provided pre-trained models on full-resolution test EBB! images 87 | 88 | ```bash 89 | python test_model.py orig=true 90 | ``` 91 | 92 | Optional parameters: 93 | 94 | >```use_gpu```: **```true```**, **```false```**   -   run the model on GPU or CPU
95 | >```dataset_dir```: **```ebb_dataset/```**   -   path to the folder with the **EBB! dataset**
96 | 97 |
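Besides the numerical evaluation, `test_model.py` also renders every photo found in `visual_samples/images/`. To try the model on your own image, place it there together with a same-named `.png` depth map in `visual_samples/depth_maps/` (the file names below are examples):

```bash
cp my_photo.jpg visual_samples/images/
cp my_photo_depth.png visual_samples/depth_maps/my_photo.png
python test_model.py orig=true
# the rendered images are written to results/full-resolution/
```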
98 | 99 | #### 7. Validate the obtained model on full-resolution test EBB! images 100 | 101 | ```bash 102 | python test_model.py 103 | ``` 104 | Optional parameters: 105 | 106 | >```restore_iter```: **```None```**   -   iteration to restore (when not specified, the last saved model for level=```<level>``` is loaded)
107 | >```use_gpu```: **```true```**, **```false```**   -   run the model on GPU or CPU
108 | >```dataset_dir```: **```ebb_dataset/```**   -   path to the folder with the **EBB! dataset**
109 | 110 |
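For example, to evaluate your own level-1 model saved at iteration 100000 (checkpoint names as written by ```train_model.py```; the iteration number is illustrative):

```bash
python test_model.py level=1 restore_iter=100000
```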
111 | 112 | #### 8. Folder structure 113 | 114 | >```models/```   -   logs and models that are saved during the training process
115 | >```models/original/```   -   the folder with the provided pre-trained PyNET model
116 | >```ebb_dataset/```   -   the folder with the EBB! dataset
117 | >```results/```   -   visual results for image crops that are saved while training
118 | >```results/full-resolution/```   -   full-resolution image results saved during testing
119 | >```vgg_pretrained/```   -   the folder with the pre-trained VGG-19 network
120 | 121 | >```load_dataset.py```   -   Python script that loads the training and test data
122 | >```model.py```   -   PyNET implementation (TensorFlow)
123 | >```train_model.py```   -   implementation of the training procedure
124 | >```test_model.py```   -   applying the pre-trained model to full-resolution test images and computing the numerical results
125 | >```utils.py```   -   auxiliary functions
126 | >```vgg.py```   -   loading the pre-trained VGG-19 network
127 | 128 |
129 | 130 | #### 9. License 131 | 132 | Copyright (C) 2020 Andrey Ignatov. All rights reserved. 133 | 134 | Licensed under the [CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 135 | 136 | The code is released for academic research use only. 137 | 138 |
139 | 140 | #### 10. Citation 141 | 142 | ``` 143 | @inproceedings{ignatov2020rendering, 144 | title={Rendering Natural Camera Bokeh Effect with Deep Learning}, 145 | author={Ignatov, Andrey and Patel, Jagruti and Timofte, Radu}, 146 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 147 | pages={0--0}, 148 | year={2020} 149 | } 150 | ``` 151 |
152 | 153 | #### 11. Any further questions? 154 | 155 | ``` 156 | Please contact Andrey Ignatov (andrey@vision.ee.ethz.ch) for more information 157 | ``` 158 | -------------------------------------------------------------------------------- /ebb_dataset/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/ebb_dataset/.gitkeep -------------------------------------------------------------------------------- /load_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 by Andrey Ignatov. All Rights Reserved. 2 | 3 | from __future__ import print_function 4 | from scipy import misc 5 | from PIL import Image 6 | import imageio 7 | import os 8 | import numpy as np 9 | 10 | 11 | def load_test_data(dataset_dir, PATCH_WIDTH, PATCH_HEIGHT, DSLR_SCALE): 12 | 13 | test_directory_orig = dataset_dir + 'test/original/' 14 | test_directory_orig_depth = dataset_dir + 'test/original_depth/' 15 | test_directory_blur = dataset_dir + 'test/bokeh/' 16 | 17 | #NUM_TEST_IMAGES = 200 18 | NUM_TEST_IMAGES = len([name for name in os.listdir(test_directory_orig) 19 | if os.path.isfile(os.path.join(test_directory_orig, name))]) 20 | 21 | test_data = np.zeros((NUM_TEST_IMAGES, PATCH_HEIGHT, PATCH_WIDTH, 4)) 22 | test_answ = np.zeros((NUM_TEST_IMAGES, int(PATCH_HEIGHT * DSLR_SCALE), int(PATCH_WIDTH * DSLR_SCALE), 3)) 23 | 24 | for i in range(0, NUM_TEST_IMAGES): 25 | 26 | I = misc.imread(test_directory_orig + str(i) + '.jpg') 27 | I_depth = misc.imread(test_directory_orig_depth + str(i) + '.png') 28 | 29 | # Downscaling the image by a factor of 2 30 | I = misc.imresize(I, 0.5, interp='bicubic') 31 | 32 | # Making sure that its width is multiple of 32 33 | new_width = int(I.shape[1]/32) * 32 34 | I = I[:, 0:new_width, :] 35 | 36 | # Stacking the image together with its depth map 37 | I_temp = np.zeros((I.shape[0], I.shape[1], 4)) 38 | I_temp[:, :, 0:3] = I 39 | I_temp[:, :, 3] = I_depth 40 | I = I_temp 41 | 42 | h, w, d = I.shape 43 | y = np.random.randint(0, w - 512) 44 | 45 | # Extracting random patch of width PATCH_WIDTH 46 | I = np.float32(I[:, y:y + PATCH_WIDTH, :]) / 255.0 47 | test_data[i, :] = I 48 | 49 | I = misc.imread(test_directory_blur + str(i) + '.jpg') 50 | I = np.float32(misc.imresize(I[:, y*2:y*2 + 1024, :], DSLR_SCALE / 2, interp='bicubic')) / 255.0 51 | test_answ[i, :] = I 52 | 53 | return test_data, test_answ 54 | 55 | 56 | def load_training_batch(dataset_dir, PATCH_WIDTH, PATCH_HEIGHT, DSLR_SCALE, train_size): 57 | 58 | test_directory_orig = dataset_dir + 'train/original/' 59 | test_directory_orig_depth = dataset_dir + 'train/original_depth/' 60 | test_directory_blur = dataset_dir + 'train/bokeh/' 61 | 62 | # NUM_TRAINING_IMAGES = 4894 63 | NUM_TRAINING_IMAGES = len([name for name in os.listdir(test_directory_orig) 64 | if os.path.isfile(os.path.join(test_directory_orig, name))]) 65 | 66 | TRAIN_IMAGES = np.random.choice(np.arange(0, NUM_TRAINING_IMAGES), train_size, replace=False) 67 | 68 | test_data = np.zeros((train_size, PATCH_HEIGHT, PATCH_WIDTH, 4)) 69 | test_answ = np.zeros((train_size, int(PATCH_HEIGHT * DSLR_SCALE), int(PATCH_WIDTH * DSLR_SCALE), 3)) 70 | 71 | i = 0 72 | for img in TRAIN_IMAGES: 73 | 74 | I = misc.imread(test_directory_orig + str(img) + '.jpg') 75 | I_depth = misc.imread(test_directory_orig_depth + str(img) + '.png') 76 | 77 | # Downscaling the image by a factor of 2 78 | I = 
misc.imresize(I, 0.5, interp='bicubic') 79 | 80 | # Making sure that its width is multiple of 32 81 | new_width = int(I.shape[1] / 32) * 32 82 | I = I[:, 0:new_width, :] 83 | 84 | # Stacking the image together with its depth map 85 | I_temp = np.zeros((I.shape[0], I.shape[1], 4)) 86 | I_temp[:, :, 0:3] = I 87 | I_temp[:, :, 3] = I_depth 88 | I = I_temp 89 | 90 | h, w, d = I.shape 91 | y = np.random.randint(0, w - 512) 92 | 93 | # Extracting random patch of width PATCH_WIDTH 94 | I = np.float32(I[:, y:y + PATCH_WIDTH, :]) / 255.0 95 | test_data[i, :] = I 96 | 97 | I = misc.imread(test_directory_blur + str(img) + '.jpg') 98 | I = np.float32(misc.imresize(I[:, y * 2:y * 2 + 1024, :], DSLR_SCALE / 2, interp='bicubic')) / 255.0 99 | test_answ[i, :] = I 100 | 101 | i += 1 102 | 103 | return test_data, test_answ 104 | 105 | 106 | def load_input_image(image_dir, depth_maps_dir, photo): 107 | 108 | I = misc.imread(image_dir + photo) 109 | I_depth = misc.imread(depth_maps_dir + str(photo.split(".")[0]) + '.png') 110 | 111 | # Downscaling the image by a factor of 2 112 | I = misc.imresize(I, 0.5, interp='bicubic') 113 | 114 | # Making sure that its width is multiple of 32 115 | new_width = int(I.shape[1] / 32) * 32 116 | I = I[:, 0:new_width, :] 117 | I_depth = I_depth[:, 0:new_width] 118 | 119 | # Stacking the image together with its depth map 120 | I_temp = np.zeros((I.shape[0], I.shape[1], 4)) 121 | I_temp[:, :, 0:3] = I 122 | I_temp[:, :, 3] = I_depth 123 | 124 | I = np.float32(I_temp) / 255.0 125 | I = np.reshape(I, [1, I.shape[0], I.shape[1], 4]) 126 | 127 | return I 128 | 129 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 by Andrey Ignatov. All Rights Reserved. 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | def PyNET(input, instance_norm=True, instance_norm_level_1=False): 8 | 9 | # Note: the paper uses a different layer naming scheme. 10 | # In this code, layer N corresponds to layer N+2 from the article. 
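# (i.e., "Level 5" below is the paper's level 7, while the final "Level 0" and "Level Up" upscaling stages correspond to the paper's levels 2 and 1)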
11 | 12 | with tf.compat.v1.variable_scope("generator"): 13 | 14 | # ----------------------------------------- 15 | # Space-to-depth layer 16 | 17 | space2depth_l0 = tf.nn.space_to_depth(input, 2) # 512 -> 256 18 | 19 | # ----------------------------------------- 20 | # Downsampling layers 21 | 22 | conv_l1_d1 = _conv_multi_block(space2depth_l0, 3, num_maps=32, instance_norm=False) # 256 -> 256 23 | pool1 = max_pool(conv_l1_d1, 2) # 256 -> 128 24 | 25 | conv_l2_d1 = _conv_multi_block(pool1, 3, num_maps=64, instance_norm=instance_norm) # 128 -> 128 26 | pool2 = max_pool(conv_l2_d1, 2) # 128 -> 64 27 | 28 | conv_l3_d1 = _conv_multi_block(pool2, 3, num_maps=128, instance_norm=instance_norm) # 64 -> 64 29 | pool3 = max_pool(conv_l3_d1, 2) # 64 -> 32 30 | 31 | conv_l4_d1 = _conv_multi_block(pool3, 3, num_maps=256, instance_norm=instance_norm) # 32 -> 32 32 | pool4 = max_pool(conv_l4_d1, 2) # 32 -> 16 33 | 34 | # ----------------------------------------- 35 | # Processing: Level 5, Input size: 16 x 16 36 | 37 | conv_l5_d1 = _conv_multi_block(pool4, 3, num_maps=512, instance_norm=instance_norm) 38 | conv_l5_d2 = _conv_multi_block(conv_l5_d1, 3, num_maps=512, instance_norm=instance_norm) + conv_l5_d1 39 | conv_l5_d3 = _conv_multi_block(conv_l5_d2, 3, num_maps=512, instance_norm=instance_norm) + conv_l5_d2 40 | conv_l5_d4 = _conv_multi_block(conv_l5_d3, 3, num_maps=512, instance_norm=instance_norm) 41 | 42 | conv_t4a = _conv_tranpose_layer(conv_l5_d4, 256, 3, 2) # 16 -> 32 43 | conv_t4b = _conv_tranpose_layer(conv_l5_d4, 256, 3, 2) # 16 -> 32 44 | 45 | # -> Output: Level 5 46 | 47 | conv_l5_out = _conv_layer(conv_l5_d4, 3, 3, 1, relu=False, instance_norm=False) 48 | output_l5 = tf.nn.tanh(conv_l5_out) * 0.58 + 0.5 49 | 50 | # ----------------------------------------- 51 | # Processing: Level 4, Input size: 32 x 32 52 | 53 | conv_l4_d2 = stack(conv_l4_d1, conv_t4a) 54 | conv_l4_d3 = _conv_multi_block(conv_l4_d2, 3, num_maps=256, instance_norm=instance_norm) 55 | conv_l4_d4 = _conv_multi_block(conv_l4_d3, 3, num_maps=256, instance_norm=instance_norm) + conv_l4_d3 56 | conv_l4_d5 = _conv_multi_block(conv_l4_d4, 3, num_maps=256, instance_norm=instance_norm) + conv_l4_d4 57 | conv_l4_d6 = stack(_conv_multi_block(conv_l4_d5, 3, num_maps=256, instance_norm=instance_norm), conv_t4b) 58 | 59 | conv_l4_d7 = _conv_multi_block(conv_l4_d6, 3, num_maps=256, instance_norm=instance_norm) 60 | 61 | conv_t3a = _conv_tranpose_layer(conv_l4_d7, 128, 3, 2) # 32 -> 64 62 | conv_t3b = _conv_tranpose_layer(conv_l4_d7, 128, 3, 2) # 32 -> 64 63 | 64 | # -> Output: Level 4 65 | 66 | conv_l4_out = _conv_layer(conv_l4_d7, 3, 3, 1, relu=False, instance_norm=False) 67 | output_l4 = tf.nn.tanh(conv_l4_out) * 0.58 + 0.5 68 | 69 | # ----------------------------------------- 70 | # Processing: Level 3, Input size: 64 x 64 71 | 72 | conv_l3_d2 = stack(conv_l3_d1, conv_t3a) 73 | conv_l3_d3 = _conv_multi_block(conv_l3_d2, 5, num_maps=128, instance_norm=instance_norm) + conv_l3_d2 74 | conv_l3_d4 = _conv_multi_block(conv_l3_d3, 5, num_maps=128, instance_norm=instance_norm) + conv_l3_d3 75 | conv_l3_d5 = _conv_multi_block(conv_l3_d4, 5, num_maps=128, instance_norm=instance_norm) + conv_l3_d4 76 | conv_l3_d6 = stack(_conv_multi_block(conv_l3_d5, 5, num_maps=128, instance_norm=instance_norm), conv_l3_d1) 77 | conv_l3_d7 = stack(conv_l3_d6, conv_t3b) 78 | 79 | conv_l3_d8 = _conv_multi_block(conv_l3_d7, 3, num_maps=128, instance_norm=instance_norm) 80 | 81 | conv_t2a = _conv_tranpose_layer(conv_l3_d8, 64, 3, 2) # 64 -> 128 82 | conv_t2b = 
_conv_tranpose_layer(conv_l3_d8, 64, 3, 2) # 64 -> 128 83 | 84 | # -> Output: Level 3 85 | 86 | conv_l3_out = _conv_layer(conv_l3_d8, 3, 3, 1, relu=False, instance_norm=False) 87 | output_l3 = tf.nn.tanh(conv_l3_out) * 0.58 + 0.5 88 | 89 | # ------------------------------------------- 90 | # Processing: Level 2, Input size: 128 x 128 91 | 92 | conv_l2_d2 = stack(conv_l2_d1, conv_t2a) 93 | conv_l2_d3 = stack(_conv_multi_block(conv_l2_d2, 5, num_maps=64, instance_norm=instance_norm), conv_l2_d1) 94 | 95 | conv_l2_d4 = _conv_multi_block(conv_l2_d3, 7, num_maps=64, instance_norm=instance_norm) + conv_l2_d3 96 | conv_l2_d5 = _conv_multi_block(conv_l2_d4, 7, num_maps=64, instance_norm=instance_norm) + conv_l2_d4 97 | conv_l2_d6 = _conv_multi_block(conv_l2_d5, 7, num_maps=64, instance_norm=instance_norm) + conv_l2_d5 98 | conv_l2_d7 = stack(_conv_multi_block(conv_l2_d6, 7, num_maps=64, instance_norm=instance_norm), conv_l2_d1) 99 | 100 | conv_l2_d8 = stack(_conv_multi_block(conv_l2_d7, 5, num_maps=64, instance_norm=instance_norm), conv_t2b) 101 | conv_l2_d9 = _conv_multi_block(conv_l2_d8, 3, num_maps=64, instance_norm=instance_norm) 102 | 103 | conv_t1a = _conv_tranpose_layer(conv_l2_d9, 32, 3, 2) # 128 -> 256 104 | conv_t1b = _conv_tranpose_layer(conv_l2_d9, 32, 3, 2) # 128 -> 256 105 | 106 | # -> Output: Level 2 107 | 108 | conv_l2_out = _conv_layer(conv_l2_d9, 3, 3, 1, relu=False, instance_norm=False) 109 | output_l2 = tf.nn.tanh(conv_l2_out) * 0.58 + 0.5 110 | 111 | # ------------------------------------------- 112 | # Processing: Level 1, Input size: 256 x 256 113 | 114 | conv_l1_d2 = stack(conv_l1_d1, conv_t1a) 115 | conv_l1_d3 = stack(_conv_multi_block(conv_l1_d2, 5, num_maps=32, instance_norm=False), conv_l1_d1) 116 | 117 | conv_l1_d4 = _conv_multi_block(conv_l1_d3, 7, num_maps=32, instance_norm=False) 118 | 119 | conv_l1_d5 = _conv_multi_block(conv_l1_d4, 9, num_maps=32, instance_norm=instance_norm_level_1) 120 | conv_l1_d6 = _conv_multi_block(conv_l1_d5, 9, num_maps=32, instance_norm=instance_norm_level_1) + conv_l1_d5 121 | conv_l1_d7 = _conv_multi_block(conv_l1_d6, 9, num_maps=32, instance_norm=instance_norm_level_1) + conv_l1_d6 122 | conv_l1_d8 = _conv_multi_block(conv_l1_d7, 9, num_maps=32, instance_norm=instance_norm_level_1) + conv_l1_d7 123 | 124 | conv_l1_d9 = stack(_conv_multi_block(conv_l1_d8, 7, num_maps=32, instance_norm=False), conv_l1_d1) 125 | 126 | conv_l1_d10 = stack(_conv_multi_block(conv_l1_d9, 5, num_maps=32, instance_norm=False), conv_t1b) 127 | conv_l1_d11 = stack(conv_l1_d10, conv_l1_d1) 128 | 129 | conv_l1_d12 = _conv_multi_block(conv_l1_d11, 3, num_maps=32, instance_norm=False) 130 | 131 | # -> Output: Level 1 132 | 133 | conv_l1_out = _conv_layer(conv_l1_d12, 3, 3, 1, relu=False, instance_norm=False) 134 | output_l1 = tf.nn.tanh(conv_l1_out) * 0.58 + 0.5 135 | 136 | # ---------------------------------------------------------- 137 | # Processing: Level 0 (x2 upscaling), Input size: 256 x 256 138 | 139 | conv_l0 = _conv_tranpose_layer(conv_l1_d12, 8, 3, 2) # 256 -> 512 140 | conv_l0_out = _conv_layer(conv_l0, 3, 3, 1, relu=False, instance_norm=False) 141 | output_l0 = tf.nn.tanh(conv_l0_out) * 0.58 + 0.5 142 | 143 | # ---------------------------------------------------------- 144 | # Processing: Level Up (x4 upscaling), Input size: 512 x 512 145 | 146 | conv_l_up = _conv_tranpose_layer(conv_l0_out, 3, 3, 2) # 512 -> 1024 147 | conv_l_up_out = _conv_layer(conv_l_up, 3, 3, 1, relu=False, instance_norm=False) 148 | 149 | output_l_up = tf.nn.tanh(conv_l_up_out) * 
0.58 + 0.5 150 | 151 | return output_l_up, output_l0, output_l1, output_l2, output_l3, output_l4, output_l5 152 | 153 | 154 | def _conv_multi_block(input, max_size, num_maps, instance_norm): 155 | 156 | conv_3a = _conv_layer(input, num_maps, 3, 1, relu=True, instance_norm=instance_norm) 157 | conv_3b = _conv_layer(conv_3a, num_maps, 3, 1, relu=True, instance_norm=instance_norm) 158 | 159 | output_tensor = conv_3b 160 | 161 | if max_size >= 5: 162 | 163 | conv_5a = _conv_layer(input, num_maps, 5, 1, relu=True, instance_norm=instance_norm) 164 | conv_5b = _conv_layer(conv_5a, num_maps, 5, 1, relu=True, instance_norm=instance_norm) 165 | 166 | output_tensor = stack(output_tensor, conv_5b) 167 | 168 | if max_size >= 7: 169 | 170 | conv_7a = _conv_layer(input, num_maps, 7, 1, relu=True, instance_norm=instance_norm) 171 | conv_7b = _conv_layer(conv_7a, num_maps, 7, 1, relu=True, instance_norm=instance_norm) 172 | 173 | output_tensor = stack(output_tensor, conv_7b) 174 | 175 | if max_size >= 9: 176 | 177 | conv_9a = _conv_layer(input, num_maps, 9, 1, relu=True, instance_norm=instance_norm) 178 | conv_9b = _conv_layer(conv_9a, num_maps, 9, 1, relu=True, instance_norm=instance_norm) 179 | 180 | output_tensor = stack(output_tensor, conv_9b) 181 | 182 | return output_tensor 183 | 184 | 185 | def stack(x, y): 186 | return tf.concat([x, y], 3) 187 | 188 | 189 | def conv2d(x, W): 190 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 191 | 192 | 193 | def leaky_relu(x, alpha=0.2): 194 | return tf.maximum(alpha * x, x) 195 | 196 | 197 | def _conv_layer(net, num_filters, filter_size, strides, relu=True, instance_norm=False, padding='SAME'): 198 | 199 | weights_init = _conv_init_vars(net, num_filters, filter_size) 200 | strides_shape = [1, strides, strides, 1] 201 | bias = tf.Variable(tf.constant(0.01, shape=[num_filters])) 202 | 203 | net = tf.nn.conv2d(net, weights_init, strides_shape, padding=padding) + bias 204 | 205 | if instance_norm: 206 | net = _instance_norm(net) 207 | 208 | if relu: 209 | net = leaky_relu(net) 210 | 211 | return net 212 | 213 | 214 | def _instance_norm(net): 215 | 216 | batch, rows, cols, channels = [i.value for i in net.get_shape()] 217 | var_shape = [channels] 218 | 219 | mu, sigma_sq = tf.compat.v1.nn.moments(net, [1,2], keep_dims=True) 220 | shift = tf.Variable(tf.zeros(var_shape)) 221 | scale = tf.Variable(tf.ones(var_shape)) 222 | 223 | epsilon = 1e-3 224 | normalized = (net-mu)/(sigma_sq + epsilon)**(.5) 225 | 226 | return scale * normalized + shift 227 | 228 | 229 | def _conv_init_vars(net, out_channels, filter_size, transpose=False): 230 | 231 | _, rows, cols, in_channels = [i.value for i in net.get_shape()] 232 | 233 | if not transpose: 234 | weights_shape = [filter_size, filter_size, in_channels, out_channels] 235 | else: 236 | weights_shape = [filter_size, filter_size, out_channels, in_channels] 237 | 238 | weights_init = tf.Variable(tf.compat.v1.truncated_normal(weights_shape, stddev=0.01, seed=1), dtype=tf.float32) 239 | return weights_init 240 | 241 | 242 | def _conv_tranpose_layer(net, num_filters, filter_size, strides): 243 | weights_init = _conv_init_vars(net, num_filters, filter_size, transpose=True) 244 | 245 | net_shape = tf.shape(net) 246 | tf_shape = tf.stack([net_shape[0], net_shape[1] * strides, net_shape[2] * strides, num_filters]) 247 | 248 | strides_shape = [1, strides, strides, 1] 249 | net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding='SAME') 250 | 251 | return leaky_relu(net) 252 | 253 | 254 | def 
max_pool(x, n): 255 | return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID') 256 | -------------------------------------------------------------------------------- /models/original/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/models/original/.gitkeep -------------------------------------------------------------------------------- /results/full-resolution/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/results/full-resolution/.gitkeep -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 by Andrey Ignatov. All Rights Reserved. 2 | 3 | from scipy import misc 4 | import numpy as np 5 | import tensorflow as tf 6 | import sys 7 | import os 8 | 9 | tf.compat.v1.disable_v2_behavior() 10 | 11 | from load_dataset import load_input_image 12 | from model import PyNET 13 | import utils 14 | 15 | 16 | LEVEL, restore_iter, dataset_dir, use_gpu, orig_model = utils.process_test_model_args(sys.argv) 17 | DSLR_SCALE = float(1) / (2 ** (LEVEL - 2)) 18 | 19 | # Disable the GPU if specified 20 | config = tf.compat.v1.ConfigProto(device_count={'GPU': 0}) if use_gpu == "false" else None 21 | 22 | 23 | with tf.compat.v1.Session(config=config) as sess: 24 | 25 | # Placeholders for test data 26 | x_ = tf.compat.v1.placeholder(tf.float32, [1, None, None, 4]) 27 | y_ = tf.compat.v1.placeholder(tf.float32, [1, None, None, 3]) 28 | 29 | # Generate the bokeh image 30 | 31 | output_l1, output_l2, output_l3, output_l4, output_l5, output_l6, output_l7 = \ 32 | PyNET(x_, instance_norm=True, instance_norm_level_1=False) 33 | 34 | if LEVEL < 1: 35 | print("Level number cannot be less than 1. Aborting.") 36 | sys.exit() 37 | if LEVEL > 3: 38 | print("Larger images are needed for computing PSNR / SSIM scores. 
Aborting.") 39 | sys.exit() 40 | if LEVEL == 3: 41 | bokeh_img = output_l3 42 | if LEVEL == 2: 43 | bokeh_img = output_l2 44 | if LEVEL == 1: 45 | bokeh_img = output_l1 46 | 47 | bokeh_img = tf.clip_by_value(bokeh_img, 0.0, 1.0) 48 | 49 | # Removing the boundary (32 px) from the resulting / target images 50 | 51 | crop_height_ = tf.compat.v1.placeholder(tf.int32) 52 | crop_width_ = tf.compat.v1.placeholder(tf.int32) 53 | 54 | bokeh_img_cropped = tf.image.crop_to_bounding_box(bokeh_img, 32, 32, crop_height_, crop_width_) 55 | y_cropped = tf.image.crop_to_bounding_box(y_, 32, 32, crop_height_, crop_width_) 56 | 57 | # Losses 58 | 59 | loss_psnr = tf.reduce_mean(tf.image.psnr(bokeh_img_cropped, y_cropped, 1.0)) 60 | loss_ssim = tf.reduce_mean(tf.image.ssim(bokeh_img_cropped, y_cropped, 1.0)) 61 | loss_ms_ssim = tf.reduce_mean(tf.image.ssim_multiscale(bokeh_img_cropped, y_cropped, 1.0)) 62 | 63 | # Loading pre-trained model 64 | 65 | saver = tf.compat.v1.train.Saver() 66 | 67 | if orig_model == "true": 68 | saver.restore(sess, "models/original/pynet_bokeh_level_0") 69 | else: 70 | saver.restore(sess, "models/pynet_level_" + str(LEVEL) + "_iteration_" + str(restore_iter) + ".ckpt") 71 | 72 | # ------------------------------------------------- 73 | # Part 1: Processing sample full-resolution images 74 | 75 | print("Generating sample visual results") 76 | 77 | sample_images_dir = "visual_samples/images/" 78 | sample_depth_maps_dir = "visual_samples/depth_maps/" 79 | 80 | sample_images = [f for f in os.listdir(sample_images_dir) if os.path.isfile(sample_images_dir + f)] 81 | 82 | for photo in sample_images: 83 | 84 | # Load image 85 | 86 | I = load_input_image(sample_images_dir, sample_depth_maps_dir, photo) 87 | 88 | # Run inference 89 | 90 | bokeh_tensor = sess.run(bokeh_img, feed_dict={x_: I}) 91 | bokeh_image = np.reshape(bokeh_tensor, [int(I.shape[1] * DSLR_SCALE), int(I.shape[2] * DSLR_SCALE), 3]) 92 | 93 | # Save the results as .png images 94 | photo_name = photo.rsplit(".", 1)[0] 95 | misc.imsave("results/full-resolution/" + photo_name + "_level_" + str(LEVEL) + 96 | "_iteration_" + str(restore_iter) + ".png", bokeh_image) 97 | 98 | # ------------------------------------------------------------------------ 99 | # Part 1: Compute PSNR / SSIM scores on the test part of the EBB! 
dataset 100 | 101 | print("Performing quantitative evaluation") 102 | 103 | test_directory_orig = dataset_dir + 'test/original/' 104 | test_directory_orig_depth = dataset_dir + 'test/original_depth/' 105 | test_directory_blur = dataset_dir + 'test/bokeh/' 106 | 107 | test_images = [f for f in os.listdir(test_directory_orig) if os.path.isfile(os.path.join(test_directory_orig, f))] 108 | 109 | loss_psnr_ = 0.0 110 | loss_ssim_ = 0.0 111 | loss_msssim_ = 0.0 112 | 113 | test_size = len(test_images) 114 | iter_ = 0 115 | 116 | for photo in test_images: 117 | 118 | # Load image 119 | 120 | I = load_input_image(test_directory_orig, test_directory_orig_depth, photo) 121 | 122 | Y = misc.imread(test_directory_blur + photo) / 255.0 123 | Y = np.float32(misc.imresize(Y, DSLR_SCALE / 2, interp='bicubic')) / 255.0 124 | Y = np.reshape(Y, [1, Y.shape[0], Y.shape[1], 3]) 125 | 126 | loss_psnr_temp, loss_ssim_temp, loss_msssim_temp = sess.run([loss_psnr, loss_ssim, loss_ms_ssim], 127 | feed_dict={x_: I, y_: Y, crop_height_: Y.shape[1] - 64, crop_width_: Y.shape[2] - 64}) 128 | 129 | print(photo, iter_, loss_psnr_temp, loss_ssim_temp, loss_msssim_temp) 130 | 131 | loss_psnr_ += loss_psnr_temp / test_size 132 | loss_ssim_ += loss_ssim_temp / test_size 133 | loss_msssim_ += loss_msssim_temp / test_size 134 | 135 | iter_ += 1 136 | 137 | output_logs = "PSNR: %.4g, SSIM: %.4g, MS-SSIM: %.4g\n" % (loss_psnr_, loss_ssim_, loss_msssim_) 138 | print(output_logs) 139 | -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 by Andrey Ignatov. All Rights Reserved. 2 | 3 | import tensorflow as tf 4 | from scipy import misc 5 | import numpy as np 6 | import sys 7 | 8 | tf.compat.v1.disable_v2_behavior() 9 | 10 | from load_dataset import load_training_batch, load_test_data 11 | from model import PyNET 12 | import utils 13 | import vgg 14 | 15 | # Processing command arguments 16 | 17 | LEVEL, batch_size, train_size, learning_rate, restore_iter, num_train_iters, dataset_dir, vgg_dir, eval_step = \ 18 | utils.process_command_args(sys.argv) 19 | 20 | # Defining the size of the input and target image patches 21 | 22 | PATCH_WIDTH, PATCH_HEIGHT = 512, 512 23 | DSLR_SCALE = float(1) / (2 ** (LEVEL - 2)) 24 | 25 | TARGET_WIDTH = int(PATCH_WIDTH * DSLR_SCALE) 26 | TARGET_HEIGHT = int(PATCH_HEIGHT * DSLR_SCALE) 27 | TARGET_DEPTH = 3 28 | TARGET_SIZE = TARGET_WIDTH * TARGET_HEIGHT * TARGET_DEPTH 29 | 30 | np.random.seed(0) 31 | 32 | # Defining the model architecture 33 | 34 | with tf.Graph().as_default(), tf.compat.v1.Session() as sess: 35 | 36 | # Placeholders for training data 37 | 38 | input_ = tf.compat.v1.placeholder(tf.float32, [batch_size, PATCH_HEIGHT, PATCH_WIDTH, 4]) 39 | target_ = tf.compat.v1.placeholder(tf.float32, [batch_size, TARGET_HEIGHT, TARGET_WIDTH, TARGET_DEPTH]) 40 | 41 | # Get the rendered bokeh image 42 | 43 | output_l1, output_l2, output_l3, output_l4, output_l5, output_l6, output_l7 = \ 44 | PyNET(input_, instance_norm=True, instance_norm_level_1=False) 45 | 46 | if LEVEL == 7: 47 | bokeh_img = output_l7 48 | if LEVEL == 6: 49 | bokeh_img = output_l6 50 | if LEVEL == 5: 51 | bokeh_img = output_l5 52 | if LEVEL == 4: 53 | bokeh_img = output_l4 54 | if LEVEL == 3: 55 | bokeh_img = output_l3 56 | if LEVEL == 2: 57 | bokeh_img = output_l2 58 | if LEVEL == 1: 59 | bokeh_img = output_l1 60 | 61 | # Losses 62 | 63 | bokeh_img_flat = tf.reshape(bokeh_img, [-1, 
TARGET_SIZE]) 64 | target_flat = tf.reshape(target_, [-1, TARGET_SIZE]) 65 | 66 | # MSE loss 67 | loss_mse = tf.reduce_sum(tf.pow(target_flat - bokeh_img_flat, 2)) / (TARGET_SIZE * batch_size) 68 | 69 | # PSNR loss 70 | loss_psnr = 20 * utils.log10(1.0 / tf.sqrt(loss_mse)) 71 | 72 | # SSIM loss 73 | loss_ssim = tf.reduce_mean(tf.image.ssim(bokeh_img, target_, 1.0)) 74 | 75 | # MS-SSIM loss 76 | loss_ms_ssim = tf.reduce_mean(tf.image.ssim_multiscale(bokeh_img, target_, 1.0)) 77 | 78 | # L1 loss 79 | loss_l1 = tf.compat.v1.losses.absolute_difference(bokeh_img, target_) 80 | 81 | # Content loss 82 | CONTENT_LAYER = 'relu5_4' 83 | 84 | bokeh_img_vgg = vgg.net(vgg_dir, vgg.preprocess(bokeh_img * 255)) 85 | target_vgg = vgg.net(vgg_dir, vgg.preprocess(target_ * 255)) 86 | 87 | content_size = utils._tensor_size(target_vgg[CONTENT_LAYER]) * batch_size 88 | loss_content = 2 * tf.nn.l2_loss(bokeh_img_vgg[CONTENT_LAYER] - target_vgg[CONTENT_LAYER]) / content_size 89 | 90 | # Final loss function 91 | 92 | if LEVEL > 1: 93 | loss_generator = loss_l1 * 100 94 | else: 95 | loss_generator = loss_l1 * 10 + loss_content * 0.1 + (1 - loss_ssim) * 10 96 | 97 | # Optimize network parameters 98 | 99 | generator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("generator")] 100 | train_step_gen = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_generator) 101 | 102 | # Initialize and restore the variables 103 | 104 | print("Initializing variables") 105 | sess.run(tf.compat.v1.global_variables_initializer()) 106 | 107 | saver = tf.compat.v1.train.Saver(var_list=generator_vars, max_to_keep=100) 108 | 109 | if LEVEL < 7: 110 | print("Restoring Variables") 111 | saver.restore(sess, "models/pynet_level_" + str(LEVEL + 1) + "_iteration_" + str(restore_iter) + ".ckpt") 112 | 113 | saver = tf.compat.v1.train.Saver(var_list=generator_vars, max_to_keep=100) 114 | 115 | # Loading training and test data 116 | 117 | print("Loading test data...") 118 | test_data, test_answ = load_test_data(dataset_dir, PATCH_WIDTH, PATCH_HEIGHT, DSLR_SCALE) 119 | print("Test data was loaded\n") 120 | 121 | print("Loading training data...") 122 | train_data, train_answ = load_training_batch(dataset_dir, PATCH_WIDTH, PATCH_HEIGHT, DSLR_SCALE, train_size) 123 | print("Training data was loaded\n") 124 | 125 | TEST_SIZE = test_data.shape[0] 126 | num_test_batches = int(test_data.shape[0] / batch_size) 127 | 128 | visual_crops_ids = np.random.randint(0, TEST_SIZE, batch_size) 129 | visual_test_crops = test_data[visual_crops_ids, :] 130 | visual_target_crops = test_answ[visual_crops_ids, :] 131 | 132 | print("Training network") 133 | 134 | logs = open("models/logs.txt", "w+") 135 | logs.close() 136 | 137 | training_loss = 0.0 138 | 139 | for i in range(num_train_iters + 1): 140 | 141 | # Train PyNET model 142 | 143 | idx_train = np.random.randint(0, train_size, batch_size) 144 | 145 | phone_images = train_data[idx_train] 146 | dslr_images = train_answ[idx_train] 147 | 148 | # Random flips and rotations 149 | 150 | for k in range(batch_size): 151 | 152 | random_rotate = np.random.randint(1, 100) % 4 153 | phone_images[k] = np.rot90(phone_images[k], random_rotate) 154 | dslr_images[k] = np.rot90(dslr_images[k], random_rotate) 155 | random_flip = np.random.randint(1, 100) % 2 156 | 157 | if random_flip == 1: 158 | phone_images[k] = np.flipud(phone_images[k]) 159 | dslr_images[k] = np.flipud(dslr_images[k]) 160 | 161 | # Training step 162 | 163 | [loss_temp, temp] = sess.run([loss_generator, train_step_gen], 
feed_dict={input_: phone_images, target_: dslr_images}) 164 | training_loss += loss_temp / eval_step 165 | 166 | if i % eval_step == 0: 167 | 168 | # Evaluate PyNET model 169 | 170 | test_losses = np.zeros((1, 6 if LEVEL < 4 else 5)) 171 | 172 | for j in range(num_test_batches): 173 | 174 | be = j * batch_size 175 | en = (j+1) * batch_size 176 | 177 | phone_images = test_data[be:en] 178 | dslr_images = test_answ[be:en] 179 | 180 | if LEVEL < 4: 181 | losses = sess.run([loss_generator, loss_content, loss_mse, loss_psnr, loss_l1, loss_ms_ssim], \ 182 | feed_dict={input_: phone_images, target_: dslr_images}) 183 | else: 184 | losses = sess.run([loss_generator, loss_content, loss_mse, loss_psnr, loss_l1], \ 185 | feed_dict={input_: phone_images, target_: dslr_images}) 186 | 187 | test_losses += np.asarray(losses) / num_test_batches 188 | 189 | if LEVEL < 4: 190 | logs_gen = "step %d | training: %.4g, test: %.4g | content: %.4g, mse: %.4g, psnr: %.4g, l1: %.4g, " \ 191 | "ms-ssim: %.4g\n" % (i, training_loss, test_losses[0][0], test_losses[0][1], 192 | test_losses[0][2], test_losses[0][3], test_losses[0][4], test_losses[0][5]) 193 | else: 194 | logs_gen = "step %d | training: %.4g, test: %.4g | content: %.4g, mse: %.4g, psnr: %.4g, l1: %.4g\n" % \ 195 | (i, training_loss, test_losses[0][0], test_losses[0][1], test_losses[0][2], test_losses[0][3], test_losses[0][4]) 196 | print(logs_gen) 197 | 198 | # Save the results to log file 199 | 200 | logs = open("models/logs.txt", "a") 201 | logs.write(logs_gen) 202 | logs.write('\n') 203 | logs.close() 204 | 205 | # Save visual results for several test images 206 | 207 | bokeh_crops = sess.run(bokeh_img, feed_dict={input_: visual_test_crops, target_: dslr_images}) 208 | 209 | idx = 0 210 | for crop in bokeh_crops: 211 | if idx < 7: 212 | before_after = np.hstack(( 213 | np.float32(misc.imresize( 214 | np.reshape(visual_test_crops[idx, :, :, 0:3] * 255, [PATCH_HEIGHT, PATCH_WIDTH, 3]), 215 | [TARGET_HEIGHT, TARGET_WIDTH])) / 255.0, 216 | crop, 217 | np.reshape(visual_target_crops[idx], [TARGET_HEIGHT, TARGET_WIDTH, TARGET_DEPTH]))) 218 | misc.imsave("results/pynet_img_" + str(idx) + "_level_" + str(LEVEL) + "_iter_" + str(i) + ".jpg", 219 | before_after) 220 | idx += 1 221 | 222 | training_loss = 0.0 223 | 224 | # Saving the model that corresponds to the current iteration 225 | saver.save(sess, "models/pynet_level_" + str(LEVEL) + "_iteration_" + str(i) + ".ckpt", write_meta_graph=False) 226 | 227 | # Loading new training data 228 | if i % 1000 == 0: 229 | 230 | del train_data 231 | del train_answ 232 | train_data, train_answ = load_training_batch(dataset_dir, PATCH_WIDTH, PATCH_HEIGHT, DSLR_SCALE, train_size) 233 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 by Andrey Ignatov. All Rights Reserved. 
2 | 3 | from functools import reduce 4 | import tensorflow as tf 5 | import numpy as np 6 | import sys 7 | import os 8 | 9 | NUM_DEFAULT_TRAIN_ITERS = [-1, 100000, 80000, 30000, 20000, 20000, 5000, 5000] 10 | 11 | 12 | def process_command_args(arguments): 13 | 14 | # Specifying the default parameters 15 | 16 | level = 1 17 | batch_size = 50 18 | 19 | train_size = 4894 20 | learning_rate = 5e-5 21 | 22 | eval_step = 1000 23 | restore_iter = None 24 | num_train_iters = None 25 | 26 | dataset_dir = 'ebb_dataset/' 27 | vgg_dir = 'vgg_pretrained/imagenet-vgg-verydeep-19.mat' 28 | 29 | for args in arguments: 30 | 31 | if args.startswith("level"): 32 | level = int(args.split("=")[1]) 33 | 34 | if args.startswith("batch_size"): 35 | batch_size = int(args.split("=")[1]) 36 | 37 | if args.startswith("train_size"): 38 | train_size = int(args.split("=")[1]) 39 | 40 | if args.startswith("learning_rate"): 41 | learning_rate = float(args.split("=")[1]) 42 | 43 | if args.startswith("restore_iter"): 44 | restore_iter = int(args.split("=")[1]) 45 | 46 | if args.startswith("num_train_iters"): 47 | num_train_iters = int(args.split("=")[1]) 48 | 49 | # ----------------------------------- 50 | 51 | if args.startswith("dataset_dir"): 52 | dataset_dir = args.split("=")[1] 53 | 54 | if args.startswith("vgg_dir"): 55 | vgg_dir = args.split("=")[1] 56 | 57 | if args.startswith("eval_step"): 58 | eval_step = int(args.split("=")[1]) 59 | 60 | if restore_iter is None and level < 7: 61 | restore_iter = get_last_iter(level + 1) 62 | if restore_iter == -1: 63 | print("Error: Cannot find any pre-trained models for PyNET's level " + str(level + 1) + ".") 64 | print("Aborting the training.") 65 | sys.exit() 66 | 67 | if num_train_iters is None: 68 | num_train_iters = NUM_DEFAULT_TRAIN_ITERS[level] 69 | 70 | print("The following parameters will be applied for CNN training:") 71 | 72 | print("Training level: " + str(level)) 73 | print("Batch size: " + str(batch_size)) 74 | print("Learning rate: " + str(learning_rate)) 75 | print("Training iterations: " + str(num_train_iters)) 76 | print("Evaluation step: " + str(eval_step)) 77 | print("Restore Iteration: " + str(restore_iter)) 78 | print("Path to the dataset: " + dataset_dir) 79 | print("Path to VGG-19 network: " + vgg_dir) 80 | 81 | return level, batch_size, train_size, learning_rate, restore_iter, num_train_iters,\ 82 | dataset_dir, vgg_dir, eval_step 83 | 84 | 85 | def process_test_model_args(arguments): 86 | 87 | level = 1 88 | restore_iter = None 89 | 90 | dataset_dir = 'ebb_dataset/' 91 | use_gpu = "true" 92 | 93 | orig_model = "false" 94 | 95 | for args in arguments: 96 | 97 | if args.startswith("level"): 98 | level = int(args.split("=")[1]) 99 | 100 | if args.startswith("dataset_dir"): 101 | dataset_dir = args.split("=")[1] 102 | 103 | if args.startswith("restore_iter"): 104 | restore_iter = int(args.split("=")[1]) 105 | 106 | if args.startswith("use_gpu"): 107 | use_gpu = args.split("=")[1] 108 | 109 | if args.startswith("orig"): 110 | orig_model = args.split("=")[1] 111 | 112 | if restore_iter is None and orig_model == "false": 113 | restore_iter = get_last_iter(level) 114 | if restore_iter == -1: 115 | print("Error: Cannot find any pre-trained models for PyNET's level " + str(level) + ".") 116 | sys.exit() 117 | 118 | return level, restore_iter, dataset_dir, use_gpu, orig_model 119 | 120 | 121 | def get_last_iter(level): 122 | 123 | saved_models = [int((model_file.split("_")[-1]).split(".")[0]) 124 | for model_file in os.listdir("models/") 125 | if 
model_file.startswith("pynet_level_" + str(level))] 126 | 127 | if len(saved_models) > 0: 128 | return np.max(saved_models) 129 | else: 130 | return -1 131 | 132 | 133 | def log10(x): 134 | numerator = tf.compat.v1.log(x) 135 | denominator = tf.compat.v1.log(tf.constant(10, dtype=numerator.dtype)) 136 | return numerator / denominator 137 | 138 | 139 | def _tensor_size(tensor): 140 | from operator import mul 141 | return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1) 142 | 143 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019-2020 by Andrey Ignatov. All Rights Reserved. 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import scipy.io 6 | 7 | IMAGE_MEAN = np.array([123.68, 116.779, 103.939]) 8 | 9 | 10 | def net(path_to_vgg_net, input_image): 11 | 12 | layers = ( 13 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 14 | 15 | 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 16 | 17 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 18 | 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 19 | 20 | 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 21 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 22 | 23 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 24 | 'relu5_3', 'conv5_4', 'relu5_4' 25 | ) 26 | 27 | data = scipy.io.loadmat(path_to_vgg_net) 28 | weights = data['layers'][0] 29 | 30 | net = {} 31 | current = input_image 32 | for i, name in enumerate(layers): 33 | layer_type = name[:4] 34 | if layer_type == 'conv': 35 | kernels, bias = weights[i][0][0][0][0] 36 | kernels = np.transpose(kernels, (1, 0, 2, 3)) 37 | bias = bias.reshape(-1) 38 | current = _conv_layer(current, kernels, bias) 39 | elif layer_type == 'relu': 40 | current = tf.nn.relu(current) 41 | elif layer_type == 'pool': 42 | current = _pool_layer(current) 43 | net[name] = current 44 | 45 | return net 46 | 47 | 48 | def _conv_layer(input, weights, bias): 49 | conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1), padding='SAME') 50 | return tf.nn.bias_add(conv, bias) 51 | 52 | 53 | def _pool_layer(input): 54 | return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='SAME') 55 | 56 | 57 | def preprocess(image): 58 | return image - IMAGE_MEAN 59 | -------------------------------------------------------------------------------- /vgg_pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/vgg_pretrained/.gitkeep -------------------------------------------------------------------------------- /visual_samples/depth_maps/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/1.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/10.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/11.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/11.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/12.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/2.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/3.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/4.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/5.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/6.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/7.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/8.png -------------------------------------------------------------------------------- /visual_samples/depth_maps/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/depth_maps/9.png -------------------------------------------------------------------------------- /visual_samples/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/1.jpg -------------------------------------------------------------------------------- /visual_samples/images/10.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/10.jpg -------------------------------------------------------------------------------- /visual_samples/images/11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/11.jpg -------------------------------------------------------------------------------- /visual_samples/images/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/12.jpg -------------------------------------------------------------------------------- /visual_samples/images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/2.jpg -------------------------------------------------------------------------------- /visual_samples/images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/3.jpg -------------------------------------------------------------------------------- /visual_samples/images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/4.jpg -------------------------------------------------------------------------------- /visual_samples/images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/5.jpg -------------------------------------------------------------------------------- /visual_samples/images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/6.jpg -------------------------------------------------------------------------------- /visual_samples/images/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/7.jpg -------------------------------------------------------------------------------- /visual_samples/images/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/8.jpg -------------------------------------------------------------------------------- /visual_samples/images/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiff22/PyNET-Bokeh/98942b9b9c4206b2b8712d317bc2f79b685d50ea/visual_samples/images/9.jpg --------------------------------------------------------------------------------