├── output
│   └── .gitkeep
├── corrupt.png
├── pupper.png
├── 2000_pupper.png
├── 4000_pupper.png
├── LICENSE
├── README.md
├── .gitignore
└── deepimg.py
/output/.gitkeep:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/corrupt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beala/deep-image-prior-tensorflow/HEAD/corrupt.png
--------------------------------------------------------------------------------
/pupper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beala/deep-image-prior-tensorflow/HEAD/pupper.png
--------------------------------------------------------------------------------
/2000_pupper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beala/deep-image-prior-tensorflow/HEAD/2000_pupper.png
--------------------------------------------------------------------------------
/4000_pupper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beala/deep-image-prior-tensorflow/HEAD/4000_pupper.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
All code is licensed under CC BY-SA 3.0:

This work is licensed under the Creative Commons Attribution-ShareAlike 3.0 Unported License. To view a copy of this license, visit http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.

`pupper.png` is Copyright Ben Augarten 2017 and used here with permission.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This is a TensorFlow implementation of [Deep Image Prior](https://dmitryulyanov.github.io/deep_image_prior).

# Setup
- Install Python 3.6, [TensorFlow](https://www.tensorflow.org/install/) (tested on 1.5.0+), NumPy, and SciPy (`scipy.stats` is used to build the blur kernel).
- Run it: `python deepimg.py`
- Every 100 iterations, the current image is written to the `output` directory.
- The input image will be blurred and written to `output/corrupt.png`. This is the starting image that the model attempts to sharpen.

# Known issues and discrepancies
- This only implements super-resolution.
- This uses a Gaussian blur rather than a Lanczos2 kernel for the downsampling operator in E; a possible substitute is sketched below.
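
If you want an operator closer to the paper's, here is a minimal sketch (not part of `deepimg.py`): TF 1.x has no Lanczos2 resize, so bicubic stands in, and the `downsample` helper and `factor` value are illustrative. It assumes a 4-D image tensor with known static height and width.

```python
import tensorflow as tf

def downsample(images, factor=4):
    """Bicubic downsampling; a rough stand-in for the paper's Lanczos2 kernel."""
    height, width = images.get_shape().as_list()[1:3]
    return tf.image.resize_images(
        images,
        size=[height // factor, width // factor],
        method=tf.image.ResizeMethod.BICUBIC)
```

With such an operator, the data term would compare a downsampled network output against a genuinely low-resolution target, instead of comparing two blurred full-resolution images as `deepimg.py` does.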

# Example
![Blurred pupper](corrupt.png?raw=true "Original")
![2000 iterations pupper](2000_pupper.png?raw=true "After 2000 iterations")
![4000 iterations pupper](4000_pupper.png?raw=true "After 4000 iterations")
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
.idea/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

output/*
--------------------------------------------------------------------------------
/deepimg.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import scipy.stats as st

# Name of the image to upscale. Must be a 256x256 PNG.
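# The file is read from the current working directory, and results are written
# into ./output/, which must already exist (the repo ships output/.gitkeep).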
image_name = "pupper.png"
dim = 256

def load_image(filename, dim):
    with open(filename, 'rb') as f:
        raw_image = tf.image.decode_png(f.read())

    converted = tf.image.convert_image_dtype(
        raw_image,
        tf.float32,
        saturate=True
    )

    resized = tf.image.resize_images(
        images=converted,
        size=[dim, dim]
    )

    resized.set_shape((dim, dim, 3))

    blur = gblur(tf.expand_dims(resized, 0))

    return blur

def save_image(filename, image):
    converted_img = tf.image.convert_image_dtype(
        image,
        tf.uint8,
        saturate=True)

    encoded_img = tf.image.encode_png(converted_img)

    with open(filename, 'wb') as f:
        f.write(encoded_img.eval())

def down_layer(layer):
    layer = tf.contrib.layers.conv2d(
        inputs=layer,
        num_outputs=128,
        kernel_size=3,
        stride=2,
        padding='SAME',
        activation_fn=None)

    layer = tf.contrib.layers.batch_norm(
        inputs=layer,
        activation_fn=tf.nn.leaky_relu)

    layer = tf.contrib.layers.conv2d(
        inputs=layer,
        num_outputs=128,
        kernel_size=3,
        stride=1,
        padding='SAME',
        activation_fn=None)

    layer = tf.contrib.layers.batch_norm(
        inputs=layer,
        activation_fn=tf.nn.leaky_relu)

    return layer

def up_layer(layer):
    layer = tf.contrib.layers.batch_norm(
        inputs=layer)

    layer = tf.contrib.layers.conv2d(
        inputs=layer,
        num_outputs=128,
        kernel_size=3,
        padding='SAME',
        activation_fn=None)

    layer = tf.contrib.layers.batch_norm(
        inputs=layer,
        activation_fn=tf.nn.leaky_relu
    )

    layer = tf.contrib.layers.conv2d(
        inputs=layer,
        num_outputs=3,
        kernel_size=1,
        padding='SAME',
        activation_fn=None)

    layer = tf.contrib.layers.batch_norm(
        inputs=layer,
        activation_fn=tf.nn.leaky_relu)

    height, width = layer.get_shape()[1:3]
    layer = tf.image.resize_images(
        images=layer,
        size=[height*2, width*2]
    )

    return layer

def skip(layer):
    conv_out = tf.contrib.layers.conv2d(
        inputs=layer,
        num_outputs=4,
        kernel_size=1,
        stride=1,
        padding='SAME',
        normalizer_fn=tf.contrib.layers.batch_norm,
        activation_fn=tf.nn.leaky_relu)

    return conv_out

# Code from https://stackoverflow.com/a/29731818
def gkern(kernlen=5, nsig=3):
    """Returns a 2D Gaussian kernel array."""

    interval = (2*nsig+1.)/(kernlen)
    x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw/kernel_raw.sum()
    return tf.convert_to_tensor(kernel, dtype=tf.float32)

# Apply the Gaussian kernel to each channel to give the image a
# Gaussian blur.
def gblur(layer):
    gaus_filter = tf.expand_dims(tf.stack([gkern(), gkern(), gkern()], axis=2), axis=3)
    return tf.nn.depthwise_conv2d(layer, gaus_filter, strides=[1, 1, 1, 1], padding='SAME')

# The number of downsampling and upsampling layers.
# These should be equal if the output image is to have the same
# dimensions as the input.
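# With dim = 256 and five stride-2 downsampling layers, the bottleneck is
# 256 / 2**5 = 8x8; five 2x upsampling layers bring it back to 256x256.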
down_layer_count = 5
up_layer_count = 5

image = load_image(image_name, dim)

rand = tf.placeholder(shape=(1, dim, dim, 32), dtype=tf.float32)

# TODO: test if 32 channels improves performance
out = tf.constant(np.random.uniform(0, 0.1, size=(1, dim, dim, 32)), dtype=tf.float32) + rand

# Connect up all the downsampling layers.
skips = []
for i in range(down_layer_count):
    out = down_layer(out)
    # Keep a list of the skip layers, so they can be connected
    # to the upsampling layers.
    skips.append(skip(out))

print("Shape after downsample: " + str(out.get_shape()))

# Connect up the upsampling layers, from smallest to largest.
skips.reverse()
for i in range(up_layer_count):
    if i == 0:
        # As specified in the paper, the first upsampling layer is connected to
        # the last downsampling layer through a skip layer.
        out = up_layer(skip(out))
    else:
        # The output of the rest of the skip layers is concatenated onto
        # the input of each upsampling layer.
        # Note: It's not clear from the paper if concat is the right operator,
        # but nothing else makes sense for the shape of the tensors.
        out = up_layer(tf.concat([out, skips[i]], axis=3))

print("Shape after upsample: " + str(out.get_shape()))

# Restore the original image dimensions and channels.
out = tf.contrib.layers.conv2d(
    inputs=out,
    num_outputs=3,
    kernel_size=1,
    stride=1,
    padding='SAME',
    activation_fn=tf.nn.sigmoid)
print("Output shape: " + str(out.get_shape()))

E = tf.losses.mean_squared_error(image, gblur(out))

optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(E)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

save_image("output/corrupt.png", tf.reshape(image, (dim, dim, 3)))

for i in range(5001):
    new_rand = np.random.uniform(0, 1.0/30.0, size=(1, dim, dim, 32))
    _, lossval = sess.run(
        [train_op, E],
        feed_dict={rand: new_rand}
    )
    if i % 100 == 0:
        image_out = sess.run(out, feed_dict={rand: new_rand}).reshape(dim, dim, 3)
        save_image("output/%d_%s" % (i, image_name), image_out)
        print(i, lossval)
--------------------------------------------------------------------------------