├── .gitignore ├── LICENSE.md ├── README.md ├── requirements.txt ├── srez_demo.py ├── srez_input.py ├── srez_main.py ├── srez_model.py ├── srez_sample_output.png └── srez_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoint*/ 3 | dataset/ 4 | train*/ 5 | 6 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 David Garcia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # srez 2 | 3 | Image super-resolution through deep learning. This project uses deep learning to upscale 16x16 images by a 4x factor. 
The resulting 64x64 images display sharp features that are plausible based on the dataset that was used to train the neural net. 4 | 5 | Here's a random, non-cherry-picked example of what this network can do. From left to right, the first column is the 16x16 input image, the second one is what you would get from a standard bicubic interpolation, the third is the output generated by the neural net, and on the right is the ground truth. 6 | 7 | ![Example output](srez_sample_output.png) 8 | 9 | As you can see, the network is able to produce a very plausible reconstruction of the original face. As the dataset is mainly composed of well-illuminated faces looking straight ahead, the reconstruction is poorer when the face is at an angle, poorly illuminated, or partially occluded by eyeglasses or hands. 10 | 11 | This particular example was produced after training the network for 3 hours on a GTX 1080 GPU, equivalent to 130,000 batches or about 10 epochs. 12 | 13 | # How it works 14 | 15 | In essence the architecture is a DCGAN where the input to the generator network is the 16x16 image rather than a noise vector drawn from a multivariate Gaussian distribution. 16 | 17 | In addition, the loss function of the generator has a term that measures the L1 difference between the 16x16 input and a downscaled version of the image produced by the generator. 18 | 19 | The adversarial term of the loss function ensures the generator produces plausible faces, while the L1 term ensures that those faces resemble the low-res input data. We have found that this L1 term greatly accelerates the convergence of the network during the first batches and also appears to prevent the generator from getting stuck in a poor local solution. 20 | 21 | Finally, the generator network relies on ResNet modules as we've found them to train substantially faster than more old-fashioned architectures. The adversarial (discriminator) network is much simpler as the use of ResNet modules did not provide an advantage during our experimentation. 
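The two-term generator loss described in "How it works" can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the project's TensorFlow code: the helper names `downscale`, `sigmoid_cross_entropy` and `generator_loss` are made up for this example, the downscaling is a simple KxK average pool, and the 0.9 weight mirrors the default of the `gene_l1_factor` flag in `srez_main.py`.

```python
import numpy as np

def downscale(images, k):
    """Average-pool a batch of (batch, height, width, channels) images by a factor of k."""
    n, h, w, c = images.shape
    return images.reshape(n, h // k, k, w // k, k, c).mean(axis=(2, 4))

def sigmoid_cross_entropy(logits, targets):
    """Numerically stable sigmoid cross-entropy computed from raw logits."""
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

def generator_loss(disc_logits, gene_output, features, l1_factor=0.9):
    """Blend the adversarial term with the L1 consistency term (hypothetical helper)."""
    # Adversarial term: the generator wants the discriminator to answer "real" (1).
    adversarial = sigmoid_cross_entropy(disc_logits, np.ones_like(disc_logits)).mean()
    # L1 term: the downscaled output should match the low-res input.
    k = gene_output.shape[1] // features.shape[1]
    l1 = np.abs(downscale(gene_output, k) - features).mean()
    return (1.0 - l1_factor) * adversarial + l1_factor * l1

# Usage: a nearest-neighbour upscale of the input downsamples back to the input,
# so the L1 term vanishes; a confident "real" logit makes the adversarial term tiny.
features = np.random.default_rng(0).random((2, 16, 16, 3))
upscaled = np.repeat(np.repeat(features, 4, axis=1), 4, axis=2)
loss = generator_loss(np.full((2, 1), 20.0), upscaled, features)
```

With a large L1 weight, the consistency term dominates early training, which matches the observation above that it speeds up convergence during the first batches.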
22 | 23 | # Requirements 24 | 25 | You will need Python 3 with TensorFlow, numpy, scipy and [moviepy](http://zulko.github.io/moviepy/). See `requirements.txt` for details. 26 | 27 | ## Dataset 28 | 29 | Once the required software is installed you will also need the `Large-scale CelebFaces Attributes (CelebA) Dataset`. The model expects the `Align&Cropped Images` version. Extract all images to a subfolder named `dataset`, e.g. `srez/dataset/lotsoffiles.jpg`. 30 | 31 | # Training the model 32 | 33 | Training with default settings: `python3 srez_main.py --run train`. The script will periodically write an example batch in PNG format to the `srez/train` folder, and checkpoint data will be stored in the `srez/checkpoint` folder. 34 | 35 | After the network has trained you can also produce an animation showing the evolution of the output by running `python3 srez_main.py --run demo`. 36 | 37 | # About the author 38 | 39 | [LinkedIn profile of David Garcia](https://ca.linkedin.com/in/david-garcia-70913311). 
40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | moviepy==0.2.2.11 2 | numpy==1.11.1 3 | scipy==0.18.0 4 | six==1.10.0 5 | tensorflow==0.10.0rc0 6 | -------------------------------------------------------------------------------- /srez_demo.py: -------------------------------------------------------------------------------- 1 | import moviepy.editor as mpe 2 | import numpy as np 3 | import numpy.random 4 | import os.path 5 | import scipy.misc 6 | import tensorflow as tf 7 | 8 | FLAGS = tf.app.flags.FLAGS 9 | 10 | def demo1(sess): 11 | """Demo based on images dumped during training""" 12 | 13 | # Get images that were dumped during training 14 | filenames = tf.gfile.ListDirectory(FLAGS.train_dir) 15 | filenames = sorted(filenames) 16 | filenames = [os.path.join(FLAGS.train_dir, f) for f in filenames if f[-4:]=='.png'] 17 | 18 | assert len(filenames) >= 1 19 | 20 | fps = 30 21 | 22 | # Create video file from PNGs 23 | print("Producing video file...") 24 | filename = os.path.join(FLAGS.train_dir, 'demo1.mp4') 25 | clip = mpe.ImageSequenceClip(filenames, fps=fps) 26 | clip.write_videofile(filename) 27 | print("Done!") 28 | 29 | -------------------------------------------------------------------------------- /srez_input.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | FLAGS = tf.app.flags.FLAGS 4 | 5 | def setup_inputs(sess, filenames, image_size=None, capacity_factor=3): 6 | 7 | if image_size is None: 8 | image_size = FLAGS.sample_size 9 | 10 | # Read each JPEG file 11 | reader = tf.WholeFileReader() 12 | filename_queue = tf.train.string_input_producer(filenames) 13 | key, value = reader.read(filename_queue) 14 | channels = 3 15 | image = tf.image.decode_jpeg(value, channels=channels, name="dataset_image") 16 | image.set_shape([None, None, channels]) 17 | 18 | 
# Crop and other random augmentations 19 | image = tf.image.random_flip_left_right(image) 20 | image = tf.image.random_saturation(image, .95, 1.05) 21 | image = tf.image.random_brightness(image, .05) 22 | image = tf.image.random_contrast(image, .95, 1.05) 23 | 24 | wiggle = 8 25 | off_x, off_y = 25-wiggle, 60-wiggle 26 | crop_size = 128 27 | crop_size_plus = crop_size + 2*wiggle 28 | image = tf.image.crop_to_bounding_box(image, off_y, off_x, crop_size_plus, crop_size_plus) 29 | image = tf.random_crop(image, [crop_size, crop_size, 3]) 30 | 31 | image = tf.reshape(image, [1, crop_size, crop_size, 3]) 32 | image = tf.cast(image, tf.float32)/255.0 33 | 34 | if crop_size != image_size: 35 | image = tf.image.resize_area(image, [image_size, image_size]) 36 | 37 | # The feature is simply a Kx downscaled version 38 | K = 4 39 | downsampled = tf.image.resize_area(image, [image_size//K, image_size//K]) 40 | 41 | feature = tf.reshape(downsampled, [image_size//K, image_size//K, 3]) 42 | label = tf.reshape(image, [image_size, image_size, 3]) 43 | 44 | # Using asynchronous queues 45 | features, labels = tf.train.batch([feature, label], 46 | batch_size=FLAGS.batch_size, 47 | num_threads=4, 48 | capacity = capacity_factor*FLAGS.batch_size, 49 | name='labels_and_features') 50 | 51 | tf.train.start_queue_runners(sess=sess) 52 | 53 | return features, labels 54 | -------------------------------------------------------------------------------- /srez_main.py: -------------------------------------------------------------------------------- 1 | import srez_demo 2 | import srez_input 3 | import srez_model 4 | import srez_train 5 | 6 | import os.path 7 | import random 8 | import numpy as np 9 | import numpy.random 10 | 11 | import tensorflow as tf 12 | 13 | FLAGS = tf.app.flags.FLAGS 14 | 15 | # Configuration (alphabetically) 16 | tf.app.flags.DEFINE_integer('batch_size', 16, 17 | "Number of samples per batch.") 18 | 19 | tf.app.flags.DEFINE_string('checkpoint_dir', 'checkpoint', 20 | 
"Output folder where checkpoints are dumped.") 21 | 22 | tf.app.flags.DEFINE_integer('checkpoint_period', 10000, 23 | "Number of batches in between checkpoints") 24 | 25 | tf.app.flags.DEFINE_string('dataset', 'dataset', 26 | "Path to the dataset directory.") 27 | 28 | tf.app.flags.DEFINE_float('epsilon', 1e-8, 29 | "Fuzz term to avoid numerical instability") 30 | 31 | tf.app.flags.DEFINE_string('run', 'demo', 32 | "Which operation to run. [demo|train]") 33 | 34 | tf.app.flags.DEFINE_float('gene_l1_factor', .90, 35 | "Multiplier for generator L1 loss term") 36 | 37 | tf.app.flags.DEFINE_float('learning_beta1', 0.5, 38 | "Beta1 parameter used for AdamOptimizer") 39 | 40 | tf.app.flags.DEFINE_float('learning_rate_start', 0.00020, 41 | "Starting learning rate used for AdamOptimizer") 42 | 43 | tf.app.flags.DEFINE_integer('learning_rate_half_life', 5000, 44 | "Number of batches until learning rate is halved") 45 | 46 | tf.app.flags.DEFINE_bool('log_device_placement', False, 47 | "Log the device where variables are placed.") 48 | 49 | tf.app.flags.DEFINE_integer('sample_size', 64, 50 | "Image sample size in pixels. 
Range [64,128]") 51 | 52 | tf.app.flags.DEFINE_integer('summary_period', 200, 53 | "Number of batches between summary data dumps") 54 | 55 | tf.app.flags.DEFINE_integer('random_seed', 0, 56 | "Seed used to initialize rng.") 57 | 58 | tf.app.flags.DEFINE_integer('test_vectors', 16, 59 | """Number of features to use for testing""") 60 | 61 | tf.app.flags.DEFINE_string('train_dir', 'train', 62 | "Output folder where training logs are dumped.") 63 | 64 | tf.app.flags.DEFINE_integer('train_time', 20, 65 | "Time in minutes to train the model") 66 | 67 | def prepare_dirs(delete_train_dir=False): 68 | # Create checkpoint dir (do not delete anything) 69 | if not tf.gfile.Exists(FLAGS.checkpoint_dir): 70 | tf.gfile.MakeDirs(FLAGS.checkpoint_dir) 71 | 72 | # Cleanup train dir 73 | if delete_train_dir: 74 | if tf.gfile.Exists(FLAGS.train_dir): 75 | tf.gfile.DeleteRecursively(FLAGS.train_dir) 76 | tf.gfile.MakeDirs(FLAGS.train_dir) 77 | 78 | # Return names of training files 79 | if not tf.gfile.Exists(FLAGS.dataset) or \ 80 | not tf.gfile.IsDirectory(FLAGS.dataset): 81 | raise FileNotFoundError("Could not find folder `%s'" % (FLAGS.dataset,)) 82 | 83 | filenames = tf.gfile.ListDirectory(FLAGS.dataset) 84 | filenames = sorted(filenames) 85 | random.shuffle(filenames) 86 | filenames = [os.path.join(FLAGS.dataset, f) for f in filenames] 87 | 88 | return filenames 89 | 90 | 91 | def setup_tensorflow(): 92 | # Create session 93 | config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) 94 | sess = tf.Session(config=config) 95 | 96 | # Initialize rng with a deterministic seed 97 | with sess.graph.as_default(): 98 | tf.set_random_seed(FLAGS.random_seed) 99 | 100 | random.seed(FLAGS.random_seed) 101 | np.random.seed(FLAGS.random_seed) 102 | 103 | summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph) 104 | 105 | return sess, summary_writer 106 | 107 | def _demo(): 108 | # Load checkpoint 109 | if not tf.gfile.IsDirectory(FLAGS.checkpoint_dir): 110 | raise 
FileNotFoundError("Could not find folder `%s'" % (FLAGS.checkpoint_dir,)) 111 | 112 | # Setup global tensorflow state 113 | sess, summary_writer = setup_tensorflow() 114 | 115 | # Prepare directories 116 | filenames = prepare_dirs(delete_train_dir=False) 117 | 118 | # Setup async input queues 119 | features, labels = srez_input.setup_inputs(sess, filenames) 120 | 121 | # Create and initialize model 122 | [gene_minput, gene_moutput, 123 | gene_output, gene_var_list, 124 | disc_real_output, disc_fake_output, disc_var_list] = \ 125 | srez_model.create_model(sess, features, labels) 126 | 127 | # Restore variables from checkpoint 128 | saver = tf.train.Saver() 129 | filename = 'checkpoint_new.txt' 130 | filename = os.path.join(FLAGS.checkpoint_dir, filename) 131 | saver.restore(sess, filename) 132 | 133 | # Execute demo 134 | srez_demo.demo1(sess) 135 | 136 | class TrainData(object): 137 | def __init__(self, dictionary): 138 | self.__dict__.update(dictionary) 139 | 140 | def _train(): 141 | # Setup global tensorflow state 142 | sess, summary_writer = setup_tensorflow() 143 | 144 | # Prepare directories 145 | all_filenames = prepare_dirs(delete_train_dir=True) 146 | 147 | # Separate training and test sets 148 | train_filenames = all_filenames[:-FLAGS.test_vectors] 149 | test_filenames = all_filenames[-FLAGS.test_vectors:] 150 | 151 | # TBD: Maybe download dataset here 152 | 153 | # Setup async input queues 154 | train_features, train_labels = srez_input.setup_inputs(sess, train_filenames) 155 | test_features, test_labels = srez_input.setup_inputs(sess, test_filenames) 156 | 157 | # Add some noise during training (think denoising autoencoders) 158 | noise_level = .03 159 | noisy_train_features = train_features + \ 160 | tf.random_normal(train_features.get_shape(), stddev=noise_level) 161 | 162 | # Create and initialize model 163 | [gene_minput, gene_moutput, 164 | gene_output, gene_var_list, 165 | disc_real_output, disc_fake_output, disc_var_list] = \ 166 | 
srez_model.create_model(sess, noisy_train_features, train_labels) 167 | 168 | gene_loss = srez_model.create_generator_loss(disc_fake_output, gene_output, train_features) 169 | disc_real_loss, disc_fake_loss = \ 170 | srez_model.create_discriminator_loss(disc_real_output, disc_fake_output) 171 | disc_loss = tf.add(disc_real_loss, disc_fake_loss, name='disc_loss') 172 | 173 | (global_step, learning_rate, gene_minimize, disc_minimize) = \ 174 | srez_model.create_optimizers(gene_loss, gene_var_list, 175 | disc_loss, disc_var_list) 176 | 177 | # Train model 178 | train_data = TrainData(locals()) 179 | srez_train.train_model(train_data) 180 | 181 | def main(argv=None): 182 | # Training or showing off? 183 | 184 | if FLAGS.run == 'demo': 185 | _demo() 186 | elif FLAGS.run == 'train': 187 | _train() 188 | 189 | if __name__ == '__main__': 190 | tf.app.run() 191 | -------------------------------------------------------------------------------- /srez_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | FLAGS = tf.app.flags.FLAGS 5 | 6 | class Model: 7 | """A neural network model. 8 | 9 | Currently only supports a feedforward architecture.""" 10 | 11 | def __init__(self, name, features): 12 | self.name = name 13 | self.outputs = [features] 14 | 15 | def _get_layer_str(self, layer=None): 16 | if layer is None: 17 | layer = self.get_num_layers() 18 | 19 | return '%s_L%03d' % (self.name, layer+1) 20 | 21 | def _get_num_inputs(self): 22 | return int(self.get_output().get_shape()[-1]) 23 | 24 | def _glorot_initializer(self, prev_units, num_units, stddev_factor=1.0): 25 | """Initialization in the style of Glorot 2010. 
26 | 27 | stddev_factor should be 1.0 for linear activations, and 2.0 for ReLUs""" 28 | stddev = np.sqrt(stddev_factor / np.sqrt(prev_units*num_units)) 29 | return tf.truncated_normal([prev_units, num_units], 30 | mean=0.0, stddev=stddev) 31 | 32 | def _glorot_initializer_conv2d(self, prev_units, num_units, mapsize, stddev_factor=1.0): 33 | """Initialization in the style of Glorot 2010. 34 | 35 | stddev_factor should be 1.0 for linear activations, and 2.0 for ReLUs""" 36 | 37 | stddev = np.sqrt(stddev_factor / (np.sqrt(prev_units*num_units)*mapsize*mapsize)) 38 | return tf.truncated_normal([mapsize, mapsize, prev_units, num_units], 39 | mean=0.0, stddev=stddev) 40 | 41 | def get_num_layers(self): 42 | return len(self.outputs) 43 | 44 | def add_batch_norm(self, scale=False): 45 | """Adds a batch normalization layer to this model. 46 | 47 | See ArXiv 1502.03167v3 for details.""" 48 | 49 | # TBD: This appears to be very flaky, often raising InvalidArgumentError internally 50 | with tf.variable_scope(self._get_layer_str()): 51 | out = tf.contrib.layers.batch_norm(self.get_output(), scale=scale) 52 | 53 | self.outputs.append(out) 54 | return self 55 | 56 | def add_flatten(self): 57 | """Transforms the output of this network to a 1D tensor""" 58 | 59 | with tf.variable_scope(self._get_layer_str()): 60 | batch_size = int(self.get_output().get_shape()[0]) 61 | out = tf.reshape(self.get_output(), [batch_size, -1]) 62 | 63 | self.outputs.append(out) 64 | return self 65 | 66 | def add_dense(self, num_units, stddev_factor=1.0): 67 | """Adds a dense linear layer to this model. 
68 | 69 | Uses Glorot 2010 initialization assuming linear activation.""" 70 | 71 | assert len(self.get_output().get_shape()) == 2, "Previous layer must be 2-dimensional (batch, channels)" 72 | 73 | with tf.variable_scope(self._get_layer_str()): 74 | prev_units = self._get_num_inputs() 75 | 76 | # Weight term 77 | initw = self._glorot_initializer(prev_units, num_units, 78 | stddev_factor=stddev_factor) 79 | weight = tf.get_variable('weight', initializer=initw) 80 | 81 | # Bias term 82 | initb = tf.constant(0.0, shape=[num_units]) 83 | bias = tf.get_variable('bias', initializer=initb) 84 | 85 | # Output of this layer 86 | out = tf.matmul(self.get_output(), weight) + bias 87 | 88 | self.outputs.append(out) 89 | return self 90 | 91 | def add_sigmoid(self): 92 | """Adds a sigmoid (0,1) activation function layer to this model.""" 93 | 94 | with tf.variable_scope(self._get_layer_str()): 95 | prev_units = self._get_num_inputs() 96 | out = tf.nn.sigmoid(self.get_output()) 97 | 98 | self.outputs.append(out) 99 | return self 100 | 101 | def add_softmax(self): 102 | """Adds a softmax operation to this model""" 103 | 104 | with tf.variable_scope(self._get_layer_str()): 105 | this_input = tf.square(self.get_output()) 106 | reduction_indices = list(range(1, len(this_input.get_shape()))) 107 | acc = tf.reduce_sum(this_input, reduction_indices=reduction_indices, keep_dims=True) 108 | out = this_input / (acc+FLAGS.epsilon) 109 | #out = tf.verify_tensor_all_finite(out, "add_softmax failed; is sum equal to zero?") 110 | 111 | self.outputs.append(out) 112 | return self 113 | 114 | def add_relu(self): 115 | """Adds a ReLU activation function to this model""" 116 | 117 | with tf.variable_scope(self._get_layer_str()): 118 | out = tf.nn.relu(self.get_output()) 119 | 120 | self.outputs.append(out) 121 | return self 122 | 123 | def add_elu(self): 124 | """Adds a ELU activation function to this model""" 125 | 126 | with tf.variable_scope(self._get_layer_str()): 127 | out = 
tf.nn.elu(self.get_output()) 128 | 129 | self.outputs.append(out) 130 | return self 131 | 132 | def add_lrelu(self, leak=.2): 133 | """Adds a leaky ReLU (LReLU) activation function to this model""" 134 | 135 | with tf.variable_scope(self._get_layer_str()): 136 | t1 = .5 * (1 + leak) 137 | t2 = .5 * (1 - leak) 138 | out = t1 * self.get_output() + \ 139 | t2 * tf.abs(self.get_output()) 140 | 141 | self.outputs.append(out) 142 | return self 143 | 144 | def add_conv2d(self, num_units, mapsize=1, stride=1, stddev_factor=1.0): 145 | """Adds a 2D convolutional layer.""" 146 | 147 | assert len(self.get_output().get_shape()) == 4, "Previous layer must be 4-dimensional (batch, height, width, channels)" 148 | 149 | with tf.variable_scope(self._get_layer_str()): 150 | prev_units = self._get_num_inputs() 151 | 152 | # Weight term and convolution 153 | initw = self._glorot_initializer_conv2d(prev_units, num_units, 154 | mapsize, 155 | stddev_factor=stddev_factor) 156 | weight = tf.get_variable('weight', initializer=initw) 157 | out = tf.nn.conv2d(self.get_output(), weight, 158 | strides=[1, stride, stride, 1], 159 | padding='SAME') 160 | 161 | # Bias term 162 | initb = tf.constant(0.0, shape=[num_units]) 163 | bias = tf.get_variable('bias', initializer=initb) 164 | out = tf.nn.bias_add(out, bias) 165 | 166 | self.outputs.append(out) 167 | return self 168 | 169 | def add_conv2d_transpose(self, num_units, mapsize=1, stride=1, stddev_factor=1.0): 170 | """Adds a transposed 2D convolutional layer""" 171 | 172 | assert len(self.get_output().get_shape()) == 4, "Previous layer must be 4-dimensional (batch, height, width, channels)" 173 | 174 | with tf.variable_scope(self._get_layer_str()): 175 | prev_units = self._get_num_inputs() 176 | 177 | # Weight term and convolution 178 | initw = self._glorot_initializer_conv2d(prev_units, num_units, 179 | mapsize, 180 | stddev_factor=stddev_factor) 181 | weight = tf.get_variable('weight', initializer=initw) 182 | weight = 
tf.transpose(weight, perm=[0, 1, 3, 2]) 183 | prev_output = self.get_output() 184 | output_shape = [FLAGS.batch_size, 185 | int(prev_output.get_shape()[1]) * stride, 186 | int(prev_output.get_shape()[2]) * stride, 187 | num_units] 188 | out = tf.nn.conv2d_transpose(self.get_output(), weight, 189 | output_shape=output_shape, 190 | strides=[1, stride, stride, 1], 191 | padding='SAME') 192 | 193 | # Bias term 194 | initb = tf.constant(0.0, shape=[num_units]) 195 | bias = tf.get_variable('bias', initializer=initb) 196 | out = tf.nn.bias_add(out, bias) 197 | 198 | self.outputs.append(out) 199 | return self 200 | 201 | def add_residual_block(self, num_units, mapsize=3, num_layers=2, stddev_factor=1e-3): 202 | """Adds a residual block as per Arxiv 1512.03385, Figure 3""" 203 | 204 | assert len(self.get_output().get_shape()) == 4, "Previous layer must be 4-dimensional (batch, height, width, channels)" 205 | 206 | # Add projection in series if needed prior to shortcut 207 | if num_units != int(self.get_output().get_shape()[3]): 208 | self.add_conv2d(num_units, mapsize=1, stride=1, stddev_factor=1.) 209 | 210 | bypass = self.get_output() 211 | 212 | # Residual block 213 | for _ in range(num_layers): 214 | self.add_batch_norm() 215 | self.add_relu() 216 | self.add_conv2d(num_units, mapsize=mapsize, stride=1, stddev_factor=stddev_factor) 217 | 218 | self.add_sum(bypass) 219 | 220 | return self 221 | 222 | def add_bottleneck_residual_block(self, num_units, mapsize=3, stride=1, transpose=False): 223 | """Adds a bottleneck residual block as per Arxiv 1512.03385, Figure 3""" 224 | 225 | assert len(self.get_output().get_shape()) == 4, "Previous layer must be 4-dimensional (batch, height, width, channels)" 226 | 227 | # Add projection in series if needed prior to shortcut 228 | if num_units != int(self.get_output().get_shape()[3]) or stride != 1: 229 | ms = 1 if stride == 1 else mapsize 230 | #bypass.add_batch_norm() # TBD: Needed? 
231 | if transpose: 232 | self.add_conv2d_transpose(num_units, mapsize=ms, stride=stride, stddev_factor=1.) 233 | else: 234 | self.add_conv2d(num_units, mapsize=ms, stride=stride, stddev_factor=1.) 235 | 236 | bypass = self.get_output() 237 | 238 | # Bottleneck residual block 239 | self.add_batch_norm() 240 | self.add_relu() 241 | self.add_conv2d(num_units//4, mapsize=1, stride=1, stddev_factor=2.) 242 | 243 | self.add_batch_norm() 244 | self.add_relu() 245 | if transpose: 246 | self.add_conv2d_transpose(num_units//4, 247 | mapsize=mapsize, 248 | stride=1, 249 | stddev_factor=2.) 250 | else: 251 | self.add_conv2d(num_units//4, 252 | mapsize=mapsize, 253 | stride=1, 254 | stddev_factor=2.) 255 | 256 | self.add_batch_norm() 257 | self.add_relu() 258 | self.add_conv2d(num_units, mapsize=1, stride=1, stddev_factor=2.) 259 | 260 | self.add_sum(bypass) 261 | 262 | return self 263 | 264 | def add_sum(self, term): 265 | """Adds a layer that sums the top layer with the given term""" 266 | 267 | with tf.variable_scope(self._get_layer_str()): 268 | prev_shape = self.get_output().get_shape() 269 | term_shape = term.get_shape() 270 | #print("%s %s" % (prev_shape, term_shape)) 271 | assert prev_shape == term_shape, "Can't sum terms with a different size" 272 | out = tf.add(self.get_output(), term) 273 | 274 | self.outputs.append(out) 275 | return self 276 | 277 | def add_mean(self): 278 | """Adds a layer that averages the inputs from the previous layer""" 279 | 280 | with tf.variable_scope(self._get_layer_str()): 281 | prev_shape = self.get_output().get_shape() 282 | reduction_indices = list(range(len(prev_shape))) 283 | assert len(reduction_indices) > 2, "Can't average a (batch, activation) tensor" 284 | reduction_indices = reduction_indices[1:-1] 285 | out = tf.reduce_mean(self.get_output(), reduction_indices=reduction_indices) 286 | 287 | self.outputs.append(out) 288 | return self 289 | 290 | def add_upscale(self): 291 | """Adds a layer that upscales the output by 2x 
through nearest neighbor interpolation""" 292 | 293 | prev_shape = self.get_output().get_shape() 294 | size = [2 * int(s) for s in prev_shape[1:3]] 295 | out = tf.image.resize_nearest_neighbor(self.get_output(), size) 296 | 297 | self.outputs.append(out) 298 | return self 299 | 300 | def get_output(self): 301 | """Returns the output from the topmost layer of the network""" 302 | return self.outputs[-1] 303 | 304 | def get_variable(self, layer, name): 305 | """Returns a variable given its layer and name. 306 | 307 | The variable must already exist.""" 308 | 309 | scope = self._get_layer_str(layer) 310 | collection = tf.get_collection(tf.GraphKeys.VARIABLES, scope=scope) 311 | 312 | # TBD: Ugly! 313 | for var in collection: 314 | if var.name[:-2] == scope+'/'+name: 315 | return var 316 | 317 | return None 318 | 319 | def get_all_layer_variables(self, layer): 320 | """Returns all variables in the given layer""" 321 | scope = self._get_layer_str(layer) 322 | return tf.get_collection(tf.GraphKeys.VARIABLES, scope=scope) 323 | 324 | def _discriminator_model(sess, features, disc_input): 325 | # Fully convolutional model 326 | mapsize = 3 327 | layers = [64, 128, 256, 512] 328 | 329 | old_vars = tf.all_variables() 330 | 331 | model = Model('DIS', 2*disc_input - 1) 332 | 333 | for layer in range(len(layers)): 334 | nunits = layers[layer] 335 | stddev_factor = 2.0 336 | 337 | model.add_conv2d(nunits, mapsize=mapsize, stride=2, stddev_factor=stddev_factor) 338 | model.add_batch_norm() 339 | model.add_relu() 340 | 341 | # Finalization a la "all convolutional net" 342 | model.add_conv2d(nunits, mapsize=mapsize, stride=1, stddev_factor=stddev_factor) 343 | model.add_batch_norm() 344 | model.add_relu() 345 | 346 | model.add_conv2d(nunits, mapsize=1, stride=1, stddev_factor=stddev_factor) 347 | model.add_batch_norm() 348 | model.add_relu() 349 | 350 | # Linearly map to real/fake and return average score 351 | # (softmax will be applied later) 352 | model.add_conv2d(1, mapsize=1, 
stride=1, stddev_factor=stddev_factor) 353 | model.add_mean() 354 | 355 | new_vars = tf.all_variables() 356 | disc_vars = list(set(new_vars) - set(old_vars)) 357 | 358 | return model.get_output(), disc_vars 359 | 360 | def _generator_model(sess, features, labels, channels): 361 | # Upside-down all-convolutional resnet 362 | 363 | mapsize = 3 364 | res_units = [256, 128, 96] 365 | 366 | old_vars = tf.all_variables() 367 | 368 | # See Arxiv 1603.05027 369 | model = Model('GEN', features) 370 | 371 | for ru in range(len(res_units)-1): 372 | nunits = res_units[ru] 373 | 374 | for j in range(2): 375 | model.add_residual_block(nunits, mapsize=mapsize) 376 | 377 | # Spatial upscale (see http://distill.pub/2016/deconv-checkerboard/) 378 | # and transposed convolution 379 | model.add_upscale() 380 | 381 | model.add_batch_norm() 382 | model.add_relu() 383 | model.add_conv2d_transpose(nunits, mapsize=mapsize, stride=1, stddev_factor=1.) 384 | 385 | # Finalization a la "all convolutional net" 386 | nunits = res_units[-1] 387 | model.add_conv2d(nunits, mapsize=mapsize, stride=1, stddev_factor=2.) 388 | # Worse: model.add_batch_norm() 389 | model.add_relu() 390 | 391 | model.add_conv2d(nunits, mapsize=1, stride=1, stddev_factor=2.) 392 | # Worse: model.add_batch_norm() 393 | model.add_relu() 394 | 395 | # Last layer is sigmoid with no batch normalization 396 | model.add_conv2d(channels, mapsize=1, stride=1, stddev_factor=1.) 397 | model.add_sigmoid() 398 | 399 | new_vars = tf.all_variables() 400 | gene_vars = list(set(new_vars) - set(old_vars)) 401 | 402 | return model.get_output(), gene_vars 403 | 404 | def create_model(sess, features, labels): 405 | # Generator 406 | rows = int(features.get_shape()[1]) 407 | cols = int(features.get_shape()[2]) 408 | channels = int(features.get_shape()[3]) 409 | 410 | gene_minput = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, rows, cols, channels]) 411 | 412 | # TBD: Is there a better way to instance the generator? 
413 | with tf.variable_scope('gene') as scope: 414 | gene_output, gene_var_list = \ 415 | _generator_model(sess, features, labels, channels) 416 | 417 | scope.reuse_variables() 418 | 419 | gene_moutput, _ = _generator_model(sess, gene_minput, labels, channels) 420 | 421 | # Discriminator with real data 422 | disc_real_input = tf.identity(labels, name='disc_real_input') 423 | 424 | # TBD: Is there a better way to instance the discriminator? 425 | with tf.variable_scope('disc') as scope: 426 | disc_real_output, disc_var_list = \ 427 | _discriminator_model(sess, features, disc_real_input) 428 | 429 | scope.reuse_variables() 430 | 431 | disc_fake_output, _ = _discriminator_model(sess, features, gene_output) 432 | 433 | return [gene_minput, gene_moutput, 434 | gene_output, gene_var_list, 435 | disc_real_output, disc_fake_output, disc_var_list] 436 | 437 | def _downscale(images, K): 438 | """Differentiable image downscaling by a factor of K""" 439 | arr = np.zeros([K, K, 3, 3]) 440 | arr[:,:,0,0] = 1.0/(K*K) 441 | arr[:,:,1,1] = 1.0/(K*K) 442 | arr[:,:,2,2] = 1.0/(K*K) 443 | dowscale_weight = tf.constant(arr, dtype=tf.float32) 444 | 445 | downscaled = tf.nn.conv2d(images, dowscale_weight, 446 | strides=[1, K, K, 1], 447 | padding='SAME') 448 | return downscaled 449 | 450 | def create_generator_loss(disc_output, gene_output, features): 451 | # I.e. did we fool the discriminator? 452 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(disc_output, tf.ones_like(disc_output)) 453 | gene_ce_loss = tf.reduce_mean(cross_entropy, name='gene_ce_loss') 454 | 455 | # I.e. does the result look like the feature? 
456 | K = int(gene_output.get_shape()[1])//int(features.get_shape()[1]) 457 | assert K == 2 or K == 4 or K == 8 458 | downscaled = _downscale(gene_output, K) 459 | 460 | gene_l1_loss = tf.reduce_mean(tf.abs(downscaled - features), name='gene_l1_loss') 461 | 462 | gene_loss = tf.add((1.0 - FLAGS.gene_l1_factor) * gene_ce_loss, 463 | FLAGS.gene_l1_factor * gene_l1_loss, name='gene_loss') 464 | 465 | return gene_loss 466 | 467 | def create_discriminator_loss(disc_real_output, disc_fake_output): 468 | # I.e. did we correctly identify the input as real or not? 469 | cross_entropy_real = tf.nn.sigmoid_cross_entropy_with_logits(disc_real_output, tf.ones_like(disc_real_output)) 470 | disc_real_loss = tf.reduce_mean(cross_entropy_real, name='disc_real_loss') 471 | 472 | cross_entropy_fake = tf.nn.sigmoid_cross_entropy_with_logits(disc_fake_output, tf.zeros_like(disc_fake_output)) 473 | disc_fake_loss = tf.reduce_mean(cross_entropy_fake, name='disc_fake_loss') 474 | 475 | return disc_real_loss, disc_fake_loss 476 | 477 | def create_optimizers(gene_loss, gene_var_list, 478 | disc_loss, disc_var_list): 479 | # TBD: Does this global step variable need to be manually incremented? I think so. 
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False, name='global_step')
    learning_rate = tf.placeholder(dtype=tf.float32, name='learning_rate')

    gene_opti = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       beta1=FLAGS.learning_beta1,
                                       name='gene_optimizer')
    disc_opti = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                       beta1=FLAGS.learning_beta1,
                                       name='disc_optimizer')

    gene_minimize = gene_opti.minimize(gene_loss, var_list=gene_var_list, name='gene_loss_minimize', global_step=global_step)

    disc_minimize = disc_opti.minimize(disc_loss, var_list=disc_var_list, name='disc_loss_minimize', global_step=global_step)

    return (global_step, learning_rate, gene_minimize, disc_minimize)

--------------------------------------------------------------------------------
/srez_sample_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-gpu/srez/d65a2aabdb04122be95e61c8aa420036cd6336d6/srez_sample_output.png
--------------------------------------------------------------------------------
/srez_train.py:
--------------------------------------------------------------------------------
import numpy as np
import os.path
import scipy.misc
import tensorflow as tf
import time

FLAGS = tf.app.flags.FLAGS

def _summarize_progress(train_data, feature, label, gene_output, batch, suffix, max_samples=8):
    td = train_data

    size = [label.shape[1], label.shape[2]]

    # Upscale the low-res input two ways for side-by-side comparison
    nearest = tf.image.resize_nearest_neighbor(feature, size)
    nearest = tf.maximum(tf.minimum(nearest, 1.0), 0.0)

    bicubic = tf.image.resize_bicubic(feature, size)
    bicubic = tf.maximum(tf.minimum(bicubic, 1.0), 0.0)

    clipped = tf.maximum(tf.minimum(gene_output, 1.0), 0.0)

    # Columns: nearest neighbor, bicubic, generator output, ground truth
    image = tf.concat(2, [nearest, bicubic, clipped, label])

    # Stack the first max_samples rows vertically into a single image
    image = image[0:max_samples,:,:,:]
    image = tf.concat(0, [image[i,:,:,:] for i in range(max_samples)])
    image = td.sess.run(image)

    filename = 'batch%06d_%s.png' % (batch, suffix)
    filename = os.path.join(FLAGS.train_dir, filename)
    scipy.misc.toimage(image, cmin=0., cmax=1.).save(filename)
    print("    Saved %s" % (filename,))

def _save_checkpoint(train_data, batch):
    td = train_data

    oldname = 'checkpoint_old.txt'
    newname = 'checkpoint_new.txt'

    oldname = os.path.join(FLAGS.checkpoint_dir, oldname)
    newname = os.path.join(FLAGS.checkpoint_dir, newname)

    # Delete oldest checkpoint
    try:
        tf.gfile.Remove(oldname)
        tf.gfile.Remove(oldname + '.meta')
    except Exception:
        # Ignore failures, e.g. when no old checkpoint exists yet
        pass

    # Rename old checkpoint
    try:
        tf.gfile.Rename(newname, oldname)
        tf.gfile.Rename(newname + '.meta', oldname + '.meta')
    except Exception:
        # Ignore failures, e.g. when no previous checkpoint exists yet
        pass

    # Generate new checkpoint
    saver = tf.train.Saver()
    saver.save(td.sess, newname)

    print("    Checkpoint saved")

def train_model(train_data):
    td = train_data

    summaries = tf.merge_all_summaries()
    td.sess.run(tf.initialize_all_variables())

    lrval = FLAGS.learning_rate_start
    start_time = time.time()
    done = False
    batch = 0

    # The learning rate is only updated inside the every-10-batches block below
    assert FLAGS.learning_rate_half_life % 10 == 0

    # Cache test features and labels (they are small)
    test_feature, test_label = td.sess.run([td.test_features, td.test_labels])

    while not done:
        batch += 1
        gene_loss = disc_real_loss = disc_fake_loss = -1.234

        feed_dict = {td.learning_rate : lrval}

        ops = [td.gene_minimize, td.disc_minimize, td.gene_loss, td.disc_real_loss, td.disc_fake_loss]
        _, _, gene_loss, disc_real_loss, disc_fake_loss = td.sess.run(ops, feed_dict=feed_dict)

        if batch % 10 == 0:
            # Show we are alive
            elapsed = int(time.time() - start_time)/60
            print('Progress[%3d%%], ETA[%4dm], Batch [%4d], G_Loss[%3.3f], D_Real_Loss[%3.3f], D_Fake_Loss[%3.3f]' %
                  (int(100*elapsed/FLAGS.train_time), FLAGS.train_time - elapsed,
                   batch, gene_loss, disc_real_loss, disc_fake_loss))

            # Finished?
            current_progress = elapsed / FLAGS.train_time
            if current_progress >= 1.0:
                done = True

            # Update learning rate
            if batch % FLAGS.learning_rate_half_life == 0:
                lrval *= .5

        if batch % FLAGS.summary_period == 0:
            # Show progress with test features
            feed_dict = {td.gene_minput: test_feature}
            gene_output = td.sess.run(td.gene_moutput, feed_dict=feed_dict)
            _summarize_progress(td, test_feature, test_label, gene_output, batch, 'out')

        if batch % FLAGS.checkpoint_period == 0:
            # Save checkpoint
            _save_checkpoint(td, batch)

    _save_checkpoint(td, batch)
    print('Finished training!')

--------------------------------------------------------------------------------
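As a rough, framework-free illustration of what `_downscale` and the L1 term of `create_generator_loss` in `srez_model.py` compute, the sketch below averages non-overlapping K x K blocks of a single-channel image held in nested Python lists, then takes the mean absolute difference against a low-resolution reference. The helper names `box_downscale` and `mean_abs_diff` are hypothetical and not part of the repository. The sketch assumes the image height and width are multiples of K, in which case the strided convolution in `_downscale` needs no padding and reduces to this block average.

```python
def box_downscale(image, k):
    """Average non-overlapping k x k blocks of a 2-D list of floats.

    Hypothetical helper: mirrors the per-channel box filter in _downscale
    for one channel, assuming both dimensions are multiples of k.
    """
    rows, cols = len(image), len(image[0])
    assert rows % k == 0 and cols % k == 0, "dimensions must be multiples of k"
    out = []
    for r in range(0, rows, k):
        out_row = []
        for c in range(0, cols, k):
            block_sum = sum(image[r + i][c + j] for i in range(k) for j in range(k))
            out_row.append(block_sum / float(k * k))
        out.append(out_row)
    return out


def mean_abs_diff(a, b):
    """Mean absolute (L1) difference between two equally sized 2-D lists."""
    diffs = [abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / float(len(diffs))


# A 4x4 "generated" image and the 2x2 input it should resemble once downscaled.
generated = [[0.0, 1.0, 2.0, 2.0],
             [1.0, 2.0, 2.0, 2.0],
             [4.0, 4.0, 0.0, 0.0],
             [4.0, 4.0, 0.0, 4.0]]
low_res = [[1.0, 2.0],
           [4.0, 0.0]]

downscaled = box_downscale(generated, 2)
l1 = mean_abs_diff(downscaled, low_res)
print(downscaled, l1)
```

In the repository this L1 value is blended with the adversarial term as `(1.0 - FLAGS.gene_l1_factor) * gene_ce_loss + FLAGS.gene_l1_factor * gene_l1_loss`.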