├── README.md
├── cnn_toys
    ├── __init__.py
    ├── colorize
    │   ├── model.py
    │   ├── run_sample.py
    │   └── run_train.py
    ├── cyclegan
    │   ├── __init__.py
    │   ├── history.py
    │   ├── model.py
    │   ├── run_single.py
    │   ├── run_train.py
    │   └── test_history.py
    ├── data.py
    ├── graphics.py
    ├── real_nvp
    │   ├── __init__.py
    │   ├── interp.py
    │   ├── layer.py
    │   ├── models.py
    │   ├── objective.py
    │   ├── run_interp.py
    │   ├── run_sample.py
    │   ├── run_train.py
    │   └── test_layer.py
    ├── saving.py
    └── schedules.py
└── setup.py

/README.md:
--------------------------------------------------------------------------------
1 | # cnn-toys
2 |
3 | Convolutional neural networks are fairly simple, and it's easy to apply them. However, it's also easy to lose track of which architectures work well for different applications.
4 |
5 | I want to use this repo to play with different applications of CNNs. That way, I can build intuition for when to use transposed convolutions, upsampling, residual connections, pixel CNNs, leaky ReLUs, etc.
6 |
7 | # Contents
8 |
9 | * [colorize](cnn_toys/colorize) - Grayscale -> color predictor. *Current status:* the model sometimes colors in skies correctly, but it's generally pretty terrible.
10 | * [cyclegan](cnn_toys/cyclegan) - A re-implementation of [CycleGAN](https://github.com/junyanz/CycleGAN). *Current status:* works fairly well.
11 | * [real_nvp](cnn_toys/real_nvp) - A re-implementation of [real NVP](https://arxiv.org/abs/1605.08803). *Current status:* works on the problems I've tested it on, but generally requires tuning the architecture (usually making it much deeper).
12 | -------------------------------------------------------------------------------- /cnn_toys/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/unixpickle/cnn-toys/429402c61c07d5562467a9d6beb62300ed369a11/cnn_toys/__init__.py -------------------------------------------------------------------------------- /cnn_toys/colorize/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model for image colorization. 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def sample_loss(input_ph): 9 | """ 10 | Generate a loss for converting a color image to 11 | grayscale and then back again. 12 | """ 13 | colorized = colorize(tf.reduce_mean(input_ph, axis=-1, keep_dims=True)) 14 | return tf.reduce_mean(tf.abs(colorized - input_ph)) 15 | 16 | 17 | def colorize(input_ph): 18 | """ 19 | Apply a neural network to produce a colorized version 20 | of the input images. 21 | """ 22 | def activation(x): return tf.nn.relu(tf.contrib.layers.layer_norm(x)) 23 | output = input_ph 24 | output = tf.pad(output, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT') 25 | output = tf.layers.conv2d(output, 32, 7, activation=tf.nn.relu) 26 | for features in [64, 128]: 27 | output = tf.pad(output, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 28 | output = tf.layers.conv2d(output, features, 3, strides=2, activation=activation) 29 | for _ in range(6): 30 | old_output = output 31 | output = tf.pad(output, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 32 | output = tf.layers.conv2d(output, 128, 3, activation=activation) 33 | output = tf.pad(output, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='REFLECT') 34 | output = tf.layers.conv2d(output, 128, 3, activation=tf.contrib.layers.layer_norm) 35 | output = tf.nn.relu(old_output + output) 36 | output = tf.layers.conv2d_transpose(output, 64, 3, strides=2, padding='same', 37 | activation=activation) 38 | output = 
tf.layers.conv2d_transpose(output, 32, 3, strides=2, padding='same', 39 | activation=activation) 40 | output = tf.pad(output, [[0, 0], [3, 3], [3, 3], [0, 0]], mode='REFLECT') 41 | return tf.layers.conv2d(output, 3, 7, activation=tf.sigmoid) 42 | -------------------------------------------------------------------------------- /cnn_toys/colorize/run_sample.py: -------------------------------------------------------------------------------- 1 | """Sample from a colorization model.""" 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from cnn_toys.colorize.model import colorize 9 | from cnn_toys.data import dir_train_val 10 | from cnn_toys.graphics import save_image_grid 11 | from cnn_toys.saving import restore_state 12 | 13 | 14 | def main(args): 15 | """Sample a batch of colorized images.""" 16 | _, val_set = dir_train_val(args.data_dir, args.size) 17 | images = val_set.batch(args.batch).repeat().make_one_shot_iterator().get_next() 18 | grayscale = tf.reduce_mean(images, axis=-1, keep_dims=True) 19 | with tf.variable_scope('colorize'): 20 | colorized = colorize(grayscale) 21 | with tf.Session() as sess: 22 | sess.run(tf.global_variables_initializer()) 23 | restore_state(sess, args.state_file) 24 | rows = sess.run([images, tf.tile(grayscale, [1, 1, 1, 3]), colorized]) 25 | save_image_grid(np.array(rows), 'images.png') 26 | 27 | 28 | def _parse_args(): 29 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 30 | parser.add_argument('--data-dir', help='data directory', default='data') 31 | parser.add_argument('--size', help='image size', type=int, default=64) 32 | parser.add_argument('--batch', help='number of samples', type=int, default=16) 33 | parser.add_argument('--state-file', help='state input file', default='state.pkl') 34 | return parser.parse_args() 35 | 36 | 37 | if __name__ == '__main__': 38 | main(_parse_args()) 39 | 
-------------------------------------------------------------------------------- /cnn_toys/colorize/run_train.py: -------------------------------------------------------------------------------- 1 | """Train a colorization model.""" 2 | 3 | import argparse 4 | import itertools 5 | 6 | import tensorflow as tf 7 | 8 | from cnn_toys.data import dir_train_val 9 | from cnn_toys.saving import save_state, restore_state 10 | from cnn_toys.colorize.model import sample_loss 11 | 12 | 13 | def main(args): 14 | """Training outer loop.""" 15 | train, val = [d.batch(args.batch).repeat().make_one_shot_iterator().get_next() 16 | for d in dir_train_val(args.data_dir, args.size)] 17 | with tf.variable_scope('colorize'): 18 | train_loss = sample_loss(train) 19 | with tf.variable_scope('colorize', reuse=True): 20 | val_loss = sample_loss(val) 21 | with tf.control_dependencies([train_loss, val_loss]): 22 | optimize = tf.train.AdamOptimizer(learning_rate=args.step_size).minimize(train_loss) 23 | 24 | with tf.Session() as sess: 25 | sess.run(tf.global_variables_initializer()) 26 | restore_state(sess, args.state_file) 27 | for i in itertools.count(): 28 | losses, _ = sess.run([(train_loss, val_loss), optimize]) 29 | print('step %d: train=%f val=%f' % ((i,) + losses)) 30 | if i % args.save_interval == 0: 31 | save_state(sess, args.state_file) 32 | 33 | 34 | def _parse_args(): 35 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 36 | parser.add_argument('--data-dir', help='data directory', default='data') 37 | parser.add_argument('--size', help='image size', type=int, default=64) 38 | parser.add_argument('--batch', help='batch size', type=int, default=16) 39 | parser.add_argument('--step-size', help='training step size', type=float, default=1e-3) 40 | parser.add_argument('--state-file', help='state output file', default='state.pkl') 41 | parser.add_argument('--save-interval', help='steps per save', type=int, default=100) 42 | return 
parser.parse_args() 43 | 44 | 45 | if __name__ == '__main__': 46 | main(_parse_args()) 47 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/__init__.py: -------------------------------------------------------------------------------- 1 | """The CycleGAN architecture.""" 2 | 3 | from .history import history_image 4 | from .model import (CycleGAN, standard_discriminator, standard_generator, instance_norm, 5 | reflection_pad) 6 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/history.py: -------------------------------------------------------------------------------- 1 | """ 2 | An in-graph image history buffer. 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def history_image(image, buffer_size, name='image_buffer'): 9 | """ 10 | Get an image from a history buffer and submit the 11 | image to the same buffer. 12 | 13 | Args: 14 | image: an image Tensor. The shape must be known in 15 | advance. 16 | buffer_size: the number of images to store in the 17 | history buffer. 18 | name: the default name of the scope in which the 19 | buffer is stored. Uniquified as needed. 20 | """ 21 | with tf.variable_scope(None, default_name=name): 22 | buf = tf.get_variable('images', shape=[buffer_size] + [x.value for x in image.get_shape()], 23 | dtype=image.dtype, trainable=False) 24 | size = tf.get_variable('size', dtype=tf.int32, initializer=0, trainable=False) 25 | 26 | def _insert_new(): 27 | insert_idx = tf.assign_add(size, 1) - 1 28 | return _assign_buf_entry(buf, insert_idx, image) 29 | 30 | def _sample_old(): 31 | idx = tf.random_uniform((), maxval=buffer_size, dtype=tf.int32) 32 | # `+ 0` hack to deal with buffer_size == 1. 33 | # See https://github.com/tensorflow/tensorflow/issues/4663. 
34 | old = buf[idx] + 0 35 | with tf.control_dependencies([old]): 36 | assign = _assign_buf_entry(buf, idx, image) 37 | with tf.control_dependencies([assign]): 38 | return tf.identity(old) 39 | return tf.cond(size < buffer_size, 40 | _insert_new, 41 | lambda: tf.cond(tf.random_uniform(()) < 0.5, 42 | _sample_old, 43 | lambda: image)) 44 | 45 | 46 | def _assign_buf_entry(buf, idx, image): 47 | pieces = [buf[:idx], tf.expand_dims(image, 0), buf[idx+1:]] 48 | return tf.assign(buf, tf.concat(pieces, 0))[idx] 49 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Architectures for CycleGANs. 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | from .history import history_image 8 | 9 | # pylint: disable=R0902,R0903 10 | 11 | 12 | class CycleGAN: 13 | """ 14 | A CycleGAN model. 15 | 16 | The generator() and discriminator() methods can be 17 | overridden to change the model architecture. 18 | 19 | Unless otherwise stated, all pixels are assumed to 20 | range from [0, 1]. 
21 | """ 22 | 23 | def __init__(self, real_x, real_y, buffer_size=50, cycle_weight=10): 24 | initializer = tf.truncated_normal_initializer(stddev=0.02) 25 | self.buffer_size = buffer_size 26 | with tf.variable_scope('cyclegan', initializer=initializer): 27 | self.real_x = _add_image_noise(real_x) 28 | self.real_y = _add_image_noise(real_y) 29 | self._setup_generators() 30 | self._setup_discriminators() 31 | self._setup_cycles(cycle_weight) 32 | self._setup_gradients() 33 | 34 | def optimize(self, learning_rate=0.0002, beta1=0.5, global_step=None): 35 | """Create an Op that takes a training step.""" 36 | opt = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1) 37 | return opt.apply_gradients(self.gradients, global_step=global_step) 38 | 39 | def generator(self, image): 40 | """Apply the generator to an image.""" 41 | return standard_generator(image, self._num_residual()) 42 | 43 | def discriminator(self, images): # pylint: disable=R0201 44 | """Apply the descriminator to a batch of images.""" 45 | return standard_discriminator(images) 46 | 47 | def _setup_generators(self): 48 | start_vars = tf.trainable_variables() 49 | with tf.variable_scope('gen_x'): 50 | self.gen_x = self.generator(self.real_y) 51 | with tf.variable_scope('gen_y'): 52 | self.gen_y = self.generator(self.real_x) 53 | self.gen_vars = [v for v in tf.trainable_variables() if v not in start_vars] 54 | 55 | def _setup_discriminators(self): 56 | start_vars = tf.trainable_variables() 57 | with tf.variable_scope('disc_x'): 58 | disc_x_loss, gen_x_loss = self._discriminate(self.real_x, self.gen_x) 59 | with tf.variable_scope('disc_y'): 60 | disc_y_loss, gen_y_loss = self._discriminate(self.real_y, self.gen_y) 61 | self.disc_vars = [v for v in tf.trainable_variables() if v not in start_vars] 62 | self.disc_loss = (disc_x_loss + disc_y_loss) / 2 63 | self.gen_loss = gen_x_loss + gen_y_loss 64 | 65 | def _setup_cycles(self, weight): 66 | with tf.variable_scope('gen_y', reuse=True): 67 | 
self.cycle_y = self.generator(self.gen_x) 68 | with tf.variable_scope('gen_x', reuse=True): 69 | self.cycle_x = self.generator(self.gen_y) 70 | self.cycle_loss = weight * (tf.reduce_mean(tf.abs(self.real_x - self.cycle_x)) + 71 | tf.reduce_mean(tf.abs(self.real_y - self.cycle_y))) 72 | 73 | def _setup_gradients(self): 74 | gen_grad = _grad_dict(self.gen_loss + self.cycle_loss, self.gen_vars) 75 | disc_grad = _grad_dict(self.disc_loss, self.disc_vars) 76 | total_grad = _add_grad_dicts(gen_grad, disc_grad) 77 | self.gradients = [(g, v) for v, g in total_grad.items()] 78 | 79 | def _discriminate(self, real_image, gen_image): 80 | """ 81 | Run samples through a discriminator to get the GAN 82 | losses. 83 | 84 | Returns: 85 | A tuple (discriminator_loss, generator_loss). 86 | """ 87 | buf_image = history_image(gen_image, self.buffer_size) 88 | batch = tf.stack([real_image, gen_image, buf_image]) 89 | disc = self.discriminator(batch) 90 | real_outs, gen_outs, buf_outs = disc[0], disc[1], disc[2] 91 | disc_loss = tf.reduce_mean(tf.square(real_outs - 1)) + tf.reduce_mean(tf.square(buf_outs)) 92 | gen_loss = tf.reduce_mean(tf.square(gen_outs - 1)) 93 | return disc_loss, gen_loss 94 | 95 | def _num_residual(self): 96 | if self.real_x.get_shape()[1].value > 128: 97 | return 9 98 | return 6 99 | 100 | 101 | def standard_discriminator(images): 102 | """ 103 | Apply the standard CycleGAN discriminator to a batch 104 | of images. 105 | 106 | The output Tensor may have any rank, but the outer 107 | dimension must match the batch size. 
108 | """ 109 | activation = tf.nn.leaky_relu 110 | outputs = 2 * images - 1 111 | outputs = tf.layers.conv2d(outputs, 64, 4, strides=2, activation=activation) 112 | for num_filters in [128, 256, 512]: 113 | strides = 2 114 | if num_filters == 512: 115 | strides = 1 116 | outputs = tf.layers.conv2d(outputs, num_filters, 4, strides=strides, use_bias=False) 117 | outputs = activation(instance_norm(outputs)) 118 | return tf.layers.conv2d(outputs, 1, 1) 119 | 120 | 121 | def standard_generator(image, num_residual): 122 | """ 123 | Apply the standard CycleGAN generator to an image. 124 | 125 | Args: 126 | image: an image Tensor with pixel values in the 127 | range [0, 1]. 128 | num_residual: the number of residual layers. In the 129 | original paper, this varied by image size. 130 | """ 131 | def activation(x): return tf.nn.relu(instance_norm(x)) 132 | output = 2 * image - 1 133 | output = reflection_pad(tf.expand_dims(output, 0), 7) 134 | output = tf.layers.conv2d(output, 32, 7, padding='valid', activation=activation, use_bias=False) 135 | for num_filters in [64, 128]: 136 | output = reflection_pad(output, 3) 137 | output = tf.layers.conv2d(output, num_filters, 3, strides=2, padding='valid', 138 | activation=activation, use_bias=False) 139 | for _ in range(num_residual): 140 | new_out = output 141 | for i in range(2): 142 | activation = instance_norm if i == 1 else lambda x: tf.nn.relu(instance_norm(x)) 143 | new_out = reflection_pad(new_out, 3) 144 | new_out = tf.layers.conv2d(new_out, 128, 3, padding='valid', activation=activation, 145 | use_bias=False) 146 | output = output + new_out 147 | for num_filters in [64, 32]: 148 | output = tf.layers.conv2d_transpose(output, num_filters, 3, strides=2, padding='same', 149 | activation=activation, use_bias=False) 150 | output = reflection_pad(output, 7) 151 | return tf.sigmoid(tf.layers.conv2d(output, 3, 7, padding='valid'))[0] 152 | 153 | 154 | def instance_norm(images, epsilon=1e-5, name='instance_norm'): 155 | """Apply 
instance normalization to the batch.""" 156 | means = tf.reduce_mean(images, axis=[1, 2], keep_dims=True) 157 | stddevs = tf.sqrt(tf.reduce_mean(tf.square(images - means), axis=[1, 2], keep_dims=True)) 158 | results = (images - means) / (stddevs + epsilon) 159 | with tf.variable_scope(None, default_name=name): 160 | biases = tf.get_variable('biases', shape=images.get_shape()[-1].value, dtype=images.dtype, 161 | initializer=tf.zeros_initializer()) 162 | scales = tf.get_variable('scales', shape=images.get_shape()[-1].value, dtype=images.dtype, 163 | initializer=tf.ones_initializer()) 164 | return results*scales + biases 165 | 166 | 167 | def reflection_pad(images, filter_size): 168 | """Perform reflection padding for a convolution.""" 169 | num = filter_size // 2 170 | return tf.pad(images, [[0, 0], [num, num], [num, num], [0, 0]], mode='REFLECT') 171 | 172 | 173 | def _grad_dict(term, variables): 174 | grads = tf.gradients(term, variables) 175 | res = {} 176 | for var, grad in zip(variables, grads): 177 | if grad is not None: 178 | res[var] = grad 179 | return res 180 | 181 | 182 | def _add_grad_dicts(dict1, dict2): 183 | res = dict1.copy() 184 | for var, grad in dict2.items(): 185 | if var in res: 186 | res[var] += grad 187 | else: 188 | res[var] = grad 189 | return res 190 | 191 | 192 | def _add_image_noise(image): 193 | return tf.clip_by_value(image + tf.random_normal(tf.shape(image), stddev=0.001), 0, 1) 194 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/run_single.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run a trained CycleGAN on a single image. 3 | 4 | Runs both generators on the image and produces a grid 5 | containing both outputs. 
6 | """ 7 | 8 | import argparse 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from cnn_toys.cyclegan.model import CycleGAN 14 | from cnn_toys.data import images_dataset 15 | from cnn_toys.graphics import save_image_grid 16 | from cnn_toys.saving import restore_state 17 | 18 | 19 | def main(args): 20 | """Load and use a model.""" 21 | print('loading input image...') 22 | dataset = images_dataset([args.in_file], args.size, bigger_size=args.bigger_size) 23 | image = dataset.repeat().make_one_shot_iterator().get_next() 24 | print('setting up model...') 25 | model = CycleGAN(image, image) 26 | tf.get_variable('global_step', dtype=tf.int64, shape=(), initializer=tf.zeros_initializer()) 27 | with tf.Session() as sess: 28 | print('initializing variables...') 29 | sess.run(tf.global_variables_initializer()) 30 | print('attempting to restore model...') 31 | restore_state(sess, args.state_file) 32 | print('running model...') 33 | row = sess.run([model.gen_x, model.gen_y]) 34 | save_image_grid(np.array([row]), args.out_file) 35 | 36 | 37 | def _parse_args(): 38 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 39 | parser.add_argument('--size', help='image size', type=int, default=256) 40 | parser.add_argument('--bigger-size', help='size to crop from', type=int, default=286) 41 | parser.add_argument('--state-file', help='state input file', default='state.pkl') 42 | parser.add_argument('--iters', help='number of training steps', type=int, default=100000) 43 | parser.add_argument('in_file', help='path to input file') 44 | parser.add_argument('out_file', help='path to output file') 45 | return parser.parse_args() 46 | 47 | 48 | if __name__ == '__main__': 49 | main(_parse_args()) 50 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/run_train.py: -------------------------------------------------------------------------------- 1 | """Train a CycleGAN model.""" 2 | 3 | 
import argparse 4 | import os 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from cnn_toys.cyclegan.model import CycleGAN 10 | from cnn_toys.data import dir_dataset 11 | from cnn_toys.graphics import save_image_grid 12 | from cnn_toys.saving import save_state, restore_state 13 | from cnn_toys.schedules import half_annealed_lr 14 | 15 | 16 | def main(args): 17 | """The main training loop.""" 18 | print('loading datasets...') 19 | real_x = tf.image.random_flip_left_right(_load_dataset(args.data_dir_1, args.size, 20 | args.bigger_size)) 21 | real_y = tf.image.random_flip_left_right(_load_dataset(args.data_dir_2, args.size, 22 | args.bigger_size)) 23 | print('setting up model...') 24 | model = CycleGAN(real_x, real_y) 25 | global_step = tf.get_variable('global_step', dtype=tf.int64, shape=(), 26 | initializer=tf.zeros_initializer()) 27 | optimize = model.optimize( 28 | learning_rate=half_annealed_lr(args.step_size, args.iters, global_step), 29 | global_step=global_step) 30 | with tf.Session() as sess: 31 | print('initializing variables...') 32 | sess.run(tf.global_variables_initializer()) 33 | print('attempting to restore model...') 34 | restore_state(sess, args.state_file) 35 | print('training...') 36 | while sess.run(global_step) < args.iters: 37 | terms = sess.run((optimize, model.disc_loss, model.gen_loss, model.cycle_loss)) 38 | step = sess.run(global_step) 39 | print('step %d: disc=%f gen=%f cycle=%f' % ((step,) + terms[1:])) 40 | if step % args.sample_interval == 0: 41 | save_state(sess, args.state_file) 42 | print('saving samples...') 43 | _generate_samples(sess, args, model, step) 44 | _generate_cycle_samples(sess, args, model, step) 45 | 46 | 47 | def _parse_args(): 48 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 49 | parser.add_argument('--data-dir-1', help='first data directory', default='data_1') 50 | parser.add_argument('--data-dir-2', help='second data directory', default='data_2') 51 | 
parser.add_argument('--size', help='image size', type=int, default=256) 52 | parser.add_argument('--bigger-size', help='size to crop from', type=int, default=286) 53 | parser.add_argument('--step-size', help='training step size', type=float, default=2e-4) 54 | parser.add_argument('--state-file', help='state output file', default='state.pkl') 55 | parser.add_argument('--iters', help='number of training steps', type=int, default=100000) 56 | parser.add_argument('--sample-interval', help='iters per sample', type=int, default=1000) 57 | parser.add_argument('--sample-dir', help='directory to dump samples', default='samples') 58 | parser.add_argument('--sample-count', help='number of samples to draw', type=int, default=16) 59 | return parser.parse_args() 60 | 61 | 62 | def _load_dataset(dir_path, size, bigger_size): 63 | dataset = dir_dataset(dir_path, size, bigger_size=bigger_size) 64 | return dataset.repeat().make_one_shot_iterator().get_next() 65 | 66 | 67 | def _generate_samples(sess, args, model, step): 68 | _generate_grid(sess, args, step, 'samples', 69 | (model.real_x, model.gen_y, model.real_y, model.gen_x)) 70 | 71 | 72 | def _generate_cycle_samples(sess, args, model, step): 73 | _generate_grid(sess, args, step, 'cycles', 74 | (model.real_x, model.cycle_x, model.real_y, model.cycle_y)) 75 | 76 | 77 | def _generate_grid(sess, args, step, filename, tensors): 78 | if not os.path.exists(args.sample_dir): 79 | os.mkdir(args.sample_dir) 80 | grid = [] 81 | for _ in range(args.sample_count): 82 | grid.append(sess.run(tensors)) 83 | save_image_grid(np.array(grid), os.path.join(args.sample_dir, '%s_%d.png' % (filename, step))) 84 | 85 | 86 | if __name__ == '__main__': 87 | main(_parse_args()) 88 | -------------------------------------------------------------------------------- /cnn_toys/cyclegan/test_history.py: -------------------------------------------------------------------------------- 1 | """Tests for image histories.""" 2 | 3 | # pylint: disable=E1129 4 | 5 | 
import tensorflow as tf 6 | 7 | from .history import history_image 8 | 9 | 10 | def test_history_image_append(): 11 | """Test underfull image histories""" 12 | with tf.Graph().as_default(): 13 | in_image = tf.random_normal((5, 5)) 14 | hist = history_image(in_image, 5) 15 | with tf.Session() as sess: 16 | sess.run(tf.global_variables_initializer()) 17 | for _ in range(5): 18 | in_arr, out_arr = sess.run((in_image, hist)) 19 | assert (in_arr == out_arr).all() 20 | 21 | 22 | def test_history_image_sample(): 23 | """Test sampling from image histories""" 24 | with tf.Graph().as_default(): 25 | in_image = tf.random_normal((5, 5)) 26 | hist = history_image(in_image, 5) 27 | with tf.Session() as sess: 28 | sess.run(tf.global_variables_initializer()) 29 | history = [] 30 | for _ in range(5): 31 | history.append(sess.run(hist)) 32 | sample_count = 0 33 | while sample_count < 5: 34 | in_arr, out_arr = sess.run((in_image, hist)) 35 | if (in_arr == out_arr).all(): 36 | continue 37 | found = False 38 | for i, hist_entry in enumerate(history): 39 | if (hist_entry == out_arr).all(): 40 | found = True 41 | history[i] = in_arr 42 | assert found 43 | sample_count += 1 44 | 45 | 46 | def test_history_image_single(): 47 | """Test a buffer with one sample""" 48 | with tf.Graph().as_default(): 49 | in_image = tf.random_normal((5, 5)) 50 | hist = history_image(in_image, 1) 51 | with tf.Session() as sess: 52 | sess.run(tf.global_variables_initializer()) 53 | history = sess.run(hist) 54 | sample_count = 0 55 | while sample_count < 5: 56 | in_arr, out_arr = sess.run((in_image, hist)) 57 | if (in_arr == out_arr).all(): 58 | continue 59 | assert (history == out_arr).all() 60 | history = in_arr 61 | sample_count += 1 62 | -------------------------------------------------------------------------------- /cnn_toys/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for dealing with datasets of images. 
3 | """ 4 | 5 | import glob 6 | from hashlib import md5 7 | import os 8 | 9 | import tensorflow as tf 10 | 11 | 12 | def dir_train_val(image_dir, size, bigger_size=None): 13 | """ 14 | Create a training and validation Dataset by reading 15 | and splitting a directory of images. 16 | 17 | Returns (train, validation). 18 | """ 19 | paths = _find_paths(image_dir) 20 | train_paths = [p for p in paths if not _use_for_val(p)] 21 | val_paths = [p for p in paths if _use_for_val(p)] 22 | if not train_paths or not val_paths: 23 | raise RuntimeError('not enough data') 24 | return (images_dataset(train_paths, size, bigger_size=bigger_size), 25 | images_dataset(val_paths, size, bigger_size=bigger_size)) 26 | 27 | 28 | def dir_dataset(image_dir, size, bigger_size=None): 29 | """Create a Dataset of images from a directory.""" 30 | return images_dataset(_find_paths(image_dir), size, bigger_size=bigger_size) 31 | 32 | 33 | def images_dataset(paths, size, bigger_size=None): 34 | """ 35 | Create a Dataset of images from image file paths. 36 | 37 | Args: 38 | paths: a sequence of image paths. 39 | size: the size of the resulting images. 40 | bigger_size: if not None, the images are scaled to 41 | this size before being randomly cropped to size. 
42 | """ 43 | paths_ds = tf.data.Dataset.from_tensor_slices(paths) 44 | 45 | def _read_image(path_tensor): 46 | data_tensor = tf.read_file(path_tensor) 47 | image_tensor = tf.image.decode_image(data_tensor, channels=3) 48 | image_tensor.set_shape((None, None, 3)) 49 | if bigger_size is None: 50 | return tf.cast(tf.image.resize_images(image_tensor, [size, size]), tf.float32) / 0xff 51 | big = tf.image.resize_images(image_tensor, [bigger_size, bigger_size]) 52 | small = tf.random_crop(big, [size, size, 3]) 53 | return tf.cast(small, tf.float32) / 0xff 54 | return paths_ds.shuffle(buffer_size=len(paths)).map(_read_image) 55 | 56 | 57 | def _use_for_val(path): 58 | return md5(bytes(path, 'utf-8')).digest()[0] < 0x80 59 | 60 | 61 | def _find_paths(image_dir): 62 | if not os.path.isdir(image_dir): 63 | if '*' in image_dir: 64 | return glob.glob(image_dir) 65 | else: 66 | raise RuntimeError('image directory not found: ' + image_dir) 67 | paths = [] 68 | for name in os.listdir(image_dir): 69 | if not name.startswith('.'): 70 | paths.append(os.path.join(image_dir, name)) 71 | return paths 72 | -------------------------------------------------------------------------------- /cnn_toys/graphics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for creating graphics from CNN models. 3 | """ 4 | 5 | from PIL import Image 6 | import numpy as np 7 | 8 | 9 | def save_image_grid(grid, out_file, padding=10): 10 | """ 11 | Save a grid of RGB images to a file. 12 | 13 | Args: 14 | grid: a 5-dimensional np.array of images. The shape 15 | is [rows x cols x img_height x img_width x 3]. 16 | Pixel values should range from 0 to 1. 17 | out_file: the path where the file should be saved. 18 | padding: pixels of space to put around each image. 
19 | """ 20 | grid = np.clip(grid, 0, 1) 21 | num_rows = grid.shape[0] 22 | num_cols = grid.shape[1] 23 | img_height = grid.shape[2] 24 | img_width = grid.shape[3] 25 | grid_img = np.zeros((num_rows * img_height + (num_rows + 1) * padding, 26 | num_cols * img_width + (num_cols + 1) * padding, 3), dtype='float32') 27 | # White background for the border. 28 | grid_img += 1 29 | for row, row_imgs in enumerate(grid): 30 | for col, img in enumerate(row_imgs): 31 | row_start = row * img_height + (row + 1) * padding 32 | col_start = col * img_width + (col + 1) * padding 33 | grid_img[row_start: row_start + img_height, 34 | col_start: col_start + img_width] = img 35 | img = Image.fromarray((grid_img * 0xff).astype('uint8'), 'RGB') 36 | img.save(out_file) 37 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of RealNVP: 3 | https://arxiv.org/abs/1605.08803 4 | """ 5 | 6 | from .interp import interpolate, interpolate_linear 7 | from .layer import (FactorHalf, MaskedConv, MaskedFC, NVPLayer, Network, PaddedLogit, Squeeze, 8 | checkerboard_mask, depth_mask, one_cold_mask) 9 | from .models import simple_network 10 | from .objective import (bits_per_pixel, bits_per_pixel_and_grad, log_likelihood, 11 | output_log_likelihood) 12 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/interp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Interpolation in latent space. 3 | """ 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def interpolate(latents, fracs): 9 | """ 10 | Interpolate between two images' latent variables. 11 | 12 | Args: 13 | latents: latents from running a batch of two images 14 | through an NVPLayer. 
15 | fracs: a sequence of interpolation fractions, where 16 | 0 means the first image and 1 means the second. 17 | 18 | Returns: 19 | A new set of latents of batch size `len(fracs)`. 20 | """ 21 | new_latents = [] 22 | for latent in latents: 23 | img_1 = latent[0] 24 | img_2 = latent[1] 25 | spread = [img_1 * f + img_2 * (1 - f) for f in fracs] 26 | new_latents.append(tf.stack(spread)) 27 | return new_latents 28 | 29 | 30 | def interpolate_linear(latents, num_stops): 31 | """ 32 | Linearly interpolate between two images' latent 33 | variables. 34 | 35 | Args: 36 | latents: latents from running a batch of two images 37 | through an NVPLayer. 38 | num_stops: the number of samples to produce. 39 | 40 | Returns: 41 | A new set of latents of batch size num_stops. 42 | """ 43 | return interpolate(latents, [i / (num_stops - 1) for i in range(num_stops)]) 44 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Real-valued non-volume preserving transformations. 3 | """ 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | class NVPLayer(ABC): 12 | """ 13 | A layer in a real NVP model. 14 | 15 | Subclasses must override _forward() and _inverse(). 16 | Subclasses may also override test_feed_dict() and 17 | num_latents() if appropriate. 18 | """ 19 | @property 20 | def num_latents(self): 21 | """ 22 | Get the size of the latent tuple returned by 23 | forward(). 24 | """ 25 | return 0 26 | 27 | def test_feed_dict(self): 28 | """ 29 | Get a feed_dict to pass to TensorFlow when testing 30 | the model. Typically, this will tell BatchNorm to 31 | use pre-computed statistics. 32 | """ 33 | return {} 34 | 35 | @abstractmethod 36 | def _forward(self, inputs): 37 | """ 38 | Apply the layer to a batch of inputs. 39 | 40 | Args: 41 | inputs: an input batch for the layer. 
42 | 43 | Returns: 44 | A tuple (outputs, latents, log_det): 45 | outputs: the values to be passed to the next 46 | layer of the network. May be None for the 47 | last layer of the network. 48 | latents: A tuple of factored out Tensors. 49 | This may be an empty tuple. 50 | log_det: a batch of log of the determinants. 51 | """ 52 | pass 53 | 54 | @abstractmethod 55 | def _inverse(self, outputs, latents): 56 | """ 57 | Apply the inverse of the layer. 58 | 59 | Args: 60 | outputs: the outputs from the layer. 61 | latents: the latent outputs from the layer. 62 | 63 | Returns: 64 | The recovered input batch for the layer. 65 | """ 66 | pass 67 | 68 | def forward(self, inputs, name='layer', reuse=False): 69 | """ 70 | Apply the layer to a batch of inputs. 71 | 72 | Args: 73 | inputs: an input batch for the layer. 74 | name: the name of the variable scope. 75 | reuse: the variable scope reuse flag. 76 | 77 | Returns: 78 | A tuple (outputs, latents, log_det): 79 | outputs: the values to be passed to the next 80 | layer of the network. May be None for the 81 | last layer of the network. 82 | latents: A tuple of factored out Tensors. 83 | This may be an empty tuple. 84 | log_det: a batch of log of the determinants. 85 | """ 86 | with tf.variable_scope(name, reuse=reuse): 87 | return self._forward(inputs) 88 | 89 | def inverse(self, outputs, latents, name='layer', reuse=False): 90 | """ 91 | Apply the inverse of the layer. 92 | 93 | Args: 94 | outputs: the outputs from the layer. 95 | latents: the latent outputs from the layer. 96 | name: the name of the variable scope. 97 | reuse: the variable scope reuse flag. 98 | 99 | Returns: 100 | The recovered input batch for the layer. 101 | """ 102 | with tf.variable_scope(name, reuse=reuse): 103 | return self._inverse(outputs, latents) 104 | 105 | def backward(self, outputs, outputs_grad, latents, latents_grad, log_det_grad, 106 | var_list=None, name='layer', reuse=False): 107 | """ 108 | Compute a gradient through the layer. 
109 | 110 | This is optimized for memory consumption. 111 | Currently, it does not support 2nd derivatives. 112 | 113 | Args: 114 | outputs: the outputs of the layer. May be None. 115 | outputs_grad: the gradient of the objective with 116 | respect to the outputs. May be None. 117 | latents: the latent outputs from the layer. 118 | latents_grad: the gradient of the objective with 119 | respect to the latents. 120 | log_det_grad: the gradient of the objective with 121 | respect to the log determinant. 122 | var_list: the list of variables to differentiate 123 | with respect to. If None, use all trainable 124 | variables. 125 | name: the name of the variable scope. 126 | reuse: the variable scope reuse flag. 127 | 128 | Returns: 129 | A tuple (inputs, upstream, grads): 130 | inputs: the recovered inputs to the layer. 131 | upstream: a Tensor representing the gradient 132 | of the objective with respect to the inputs 133 | to the layer. 134 | grads: a list of (gradient, variable) pairs 135 | for the parameters of the layer.
136 | """ 137 | inputs = tf.stop_gradient(self.inverse(outputs, latents, name=name, reuse=reuse)) 138 | new_outputs, new_latents, new_log_dets = self.forward(inputs, 139 | name=name, 140 | reuse=True) 141 | objective = tf.reduce_sum(new_log_dets * tf.stop_gradient(log_det_grad)) 142 | if new_outputs is not None: 143 | objective += tf.reduce_sum(new_outputs * tf.stop_gradient(outputs_grad)) 144 | for latent, latent_grad in zip(new_latents, latents_grad): 145 | objective += tf.reduce_sum(latent * tf.stop_gradient(latent_grad)) 146 | variables = var_list if var_list is not None else tf.trainable_variables() 147 | grads = tf.gradients(objective, [inputs] + variables) 148 | input_grad = grads[0] 149 | if input_grad is None: 150 | input_grad = tf.Print(tf.zeros_like(input_grad), [], 151 | message='WARNING: gradient does not flow to inputs', 152 | first_n=1) 153 | var_grads = [pair for pair in zip(grads[1:], variables) if pair[0] is not None] 154 | return inputs, input_grad, var_grads 155 | 156 | def gradients(self, outputs, latents, log_det, loss, var_list=None, name='layer', reuse=True): 157 | """ 158 | Perform backpropagation through the layer using 159 | the backward() method. 160 | 161 | This computes gradients without needing to store 162 | intermediate Tensors from the forward pass. 163 | Currently, it does not support 2nd derivatives. 164 | 165 | Args: 166 | outputs: the layer outputs. 167 | latents: the layer's latent outputs. 168 | log_det: the output log determinants. 169 | loss: the loss value resulting from the latents 170 | and log determinants. 171 | var_list: the variables to find gradients for, 172 | or None to use all trainable variables. 173 | name: the name of the variable scope. 174 | reuse: the variable scope reuse flag. 175 | 176 | Returns: 177 | A list of (gradient, variable) pairs. 
178 | """ 179 | assert len(latents) == self.num_latents 180 | if outputs is not None: 181 | outputs_grad = tf.gradients(loss, outputs)[0] 182 | if outputs_grad is None: 183 | outputs_grad = tf.zeros_like(outputs) 184 | else: 185 | outputs_grad = None 186 | latents_grad = [grad if grad is not None else tf.zeros_like(latent) 187 | for grad, latent in zip(tf.gradients(loss, latents), latents)] 188 | log_det_grad = tf.gradients(loss, log_det)[0] 189 | if log_det_grad is None: 190 | log_det_grad = tf.zeros_like(log_det) 191 | return self.backward(outputs, outputs_grad, latents, latents_grad, log_det_grad, 192 | var_list=var_list, name=name, reuse=reuse)[2] 193 | 194 | 195 | class Network(NVPLayer): 196 | """ 197 | A feed-forward composition of NVP layers. 198 | """ 199 | 200 | def __init__(self, layers): 201 | self.layers = layers 202 | 203 | @property 204 | def num_latents(self): 205 | return 1 + sum(l.num_latents for l in self.layers) 206 | 207 | def test_feed_dict(self): 208 | res = {} 209 | for layer in self.layers: 210 | res.update(layer.test_feed_dict()) 211 | 212 | def _forward(self, inputs): 213 | latents = [] 214 | outputs = inputs 215 | log_det = tf.zeros(shape=[tf.shape(inputs)[0]], dtype=inputs.dtype) 216 | for i, layer in enumerate(self.layers): 217 | outputs, sub_latents, sub_log_det = layer.forward(outputs, name='layer_%d' % i) 218 | latents.extend(sub_latents) 219 | log_det = log_det + sub_log_det 220 | latents.append(outputs) 221 | return None, tuple(latents), log_det 222 | 223 | def _inverse(self, outputs, latents): 224 | assert outputs is None 225 | assert len(latents) == self.num_latents 226 | inputs = latents[-1] 227 | latents = latents[:-1] 228 | for i, layer in list(enumerate(self.layers))[::-1]: 229 | if layer.num_latents > 0: 230 | sub_latents = latents[-layer.num_latents:] 231 | latents = latents[:-layer.num_latents] 232 | else: 233 | sub_latents = () 234 | inputs = layer.inverse(inputs, sub_latents, name='layer_%d' % i) 235 | return inputs 236 
| 237 | def backward(self, outputs, outputs_grad, latents, latents_grad, log_det_grad, 238 | var_list=None, name='layer', reuse=False): 239 | with tf.variable_scope(name, reuse=reuse): 240 | assert outputs is None 241 | assert outputs_grad is None 242 | outputs = latents[-1] 243 | outputs_grad = latents_grad[-1] 244 | latents = latents[:-1] 245 | latents_grad = latents_grad[:-1] 246 | total_grads = {} 247 | prev_grads = [] 248 | for i, layer in list(enumerate(self.layers))[::-1]: 249 | if layer.num_latents > 0: 250 | sub_latents = latents[-layer.num_latents:] 251 | sub_latents_grad = latents_grad[-layer.num_latents:] 252 | latents = latents[:-layer.num_latents] 253 | latents_grad = latents_grad[:-layer.num_latents] 254 | else: 255 | sub_latents = () 256 | sub_latents_grad = () 257 | with tf.control_dependencies(prev_grads): 258 | outputs, outputs_grad, vars_grad = layer.backward(outputs, 259 | outputs_grad, 260 | sub_latents, 261 | sub_latents_grad, 262 | log_det_grad, 263 | var_list=var_list, 264 | name='layer_%d' % i) 265 | for grad, var in vars_grad: 266 | if var in total_grads: 267 | total_grads[var] += grad 268 | else: 269 | total_grads[var] = grad 270 | prev_grads = [g for g, _ in vars_grad] 271 | return outputs, outputs_grad, [(grad, var) for var, grad in total_grads.items()] 272 | 273 | 274 | class PaddedLogit(NVPLayer): 275 | """ 276 | An NVP layer that applies `logit(a + (1-2a)x)`. 
277 | """ 278 | 279 | def __init__(self, alpha=0.05): 280 | self.alpha = alpha 281 | 282 | def _forward(self, inputs): 283 | padded = self.alpha + (1 - 2 * self.alpha) * inputs 284 | logits = tf.log(padded / (1 - padded)) 285 | log_dets = tf.log(1 / padded + 1 / (1 - padded)) + tf.log((1 - 2 * self.alpha)) 286 | return logits, (), sum_batch(log_dets) 287 | 288 | def _inverse(self, outputs, latents): 289 | assert latents == () 290 | sigmoids = tf.nn.sigmoid(outputs) 291 | return (sigmoids - self.alpha) / (1 - 2 * self.alpha) 292 | 293 | 294 | class FactorHalf(NVPLayer): 295 | """ 296 | A layer that factors out half of the inputs. 297 | """ 298 | @property 299 | def num_latents(self): 300 | return 1 301 | 302 | def _forward(self, inputs): 303 | return (inputs[..., ::2], (inputs[..., 1::2],), 304 | tf.constant(0, dtype=inputs.dtype)) 305 | 306 | def _inverse(self, outputs, latents): 307 | assert len(latents) == 1 308 | # Trick to undo the alternating split. 309 | expanded_1 = tf.expand_dims(outputs, axis=-1) 310 | expanded_2 = tf.expand_dims(latents[0], axis=-1) 311 | concatenated = tf.concat([expanded_1, expanded_2], axis=-1) 312 | new_shape = [tf.shape(outputs)[0]] + [x.value for x in outputs.get_shape()[1:]] 313 | new_shape[-1] *= 2 314 | return tf.reshape(concatenated, new_shape) 315 | 316 | 317 | class Squeeze(NVPLayer): 318 | """ 319 | A layer that squeezes 2x2x1 blocks into 1x1x4 blocks. 
320 | """ 321 | 322 | def _forward(self, inputs): 323 | assert all([x.value % 2 == 0 for x in inputs.get_shape()[1:3]]), 'even shape required' 324 | conv_filter = self._permutation_filter(inputs.get_shape()[-1].value, inputs.dtype) 325 | return (tf.nn.conv2d(inputs, conv_filter, [1, 2, 2, 1], 'VALID'), 326 | (), tf.constant(0, dtype=inputs.dtype)) 327 | 328 | def _inverse(self, outputs, latents): 329 | assert latents == () 330 | in_depth = outputs.get_shape()[-1].value // 4 331 | conv_filter = self._permutation_filter(in_depth, outputs.dtype) 332 | out_shape = ([tf.shape(outputs)[0]] + [x.value * 2 for x in outputs.get_shape()[1:3]] + 333 | [in_depth]) 334 | return tf.nn.conv2d_transpose(outputs, conv_filter, out_shape, [1, 2, 2, 1], 'VALID') 335 | 336 | @staticmethod 337 | def _permutation_filter(depth, dtype): 338 | """ 339 | Generate a convolutional filter that performs the 340 | squeeze operation. 341 | """ 342 | res = np.zeros((2, 2, depth, depth * 4)) 343 | for i in range(depth): 344 | for row in range(2): 345 | for col in range(2): 346 | res[row, col, i, 4 * i + row * 2 + col] = 1 347 | return tf.constant(res, dtype=dtype) 348 | 349 | 350 | class MaskedLayer(NVPLayer): 351 | """ 352 | An abstract NVP transformation that uses a masked 353 | neural network. 354 | """ 355 | 356 | def _forward(self, inputs): 357 | biases, log_scales = self._apply_masked(inputs) 358 | log_det = sum_batch(log_scales) 359 | return inputs * tf.exp(log_scales) + biases, (), log_det 360 | 361 | def _inverse(self, outputs, latents): 362 | assert latents == () 363 | biases, log_scales = self._apply_masked(outputs) 364 | return (outputs - biases) * tf.exp(-log_scales) 365 | 366 | @abstractmethod 367 | def _apply_masked(self, inputs): 368 | """ 369 | Get (biases, log_scales) for the inputs. 370 | """ 371 | pass 372 | 373 | 374 | class MaskedConv(MaskedLayer): 375 | """ 376 | A masked convolution NVP transformation. 
377 | """ 378 | 379 | def __init__(self, mask_fn, num_residual, num_features=32, kernel_size=3, **conv_kwargs): 380 | """ 381 | Create a masked convolution layer. 382 | 383 | Args: 384 | mask_fn: a function which takes a Tensor and 385 | produces a boolean mask Tensor. 386 | num_residual: the number of residual blocks. 387 | num_features: the number of latent features. 388 | kernel_size: the convolutional kernel size. 389 | conv_kwargs: other arguments for conv2d(). 390 | """ 391 | self.mask_fn = mask_fn 392 | self.num_residual = num_residual 393 | self.num_features = num_features 394 | self.kernel_size = kernel_size 395 | self.conv_kwargs = conv_kwargs 396 | self._training = tf.constant(True) 397 | 398 | def test_feed_dict(self): 399 | return {self._training: False} 400 | 401 | def _apply_masked(self, inputs): 402 | """ 403 | Get (biases, log_scales) for the inputs. 404 | """ 405 | mask = self.mask_fn(inputs) 406 | depth = inputs.get_shape()[3].value 407 | masked = tf.where(mask, inputs, tf.zeros_like(inputs)) 408 | latent = tf.layers.conv2d(masked, self.num_features, 1) 409 | for _ in range(self.num_residual): 410 | latent = self._residual_block(tf.nn.relu(latent)) 411 | with tf.variable_scope(None, default_name='mask_biases'): 412 | bias_params = tf.layers.conv2d(latent, depth, 1) 413 | biases = tf.where(mask, tf.zeros_like(inputs), bias_params) 414 | with tf.variable_scope(None, default_name='mask_scales'): 415 | scale_params = tf.layers.conv2d(latent, depth, 1) 416 | log_scales = tf.where(mask, 417 | tf.zeros_like(inputs), 418 | tf.tanh(scale_params) * self._get_tanh_scale(inputs)) 419 | return biases, log_scales 420 | 421 | def _residual_block(self, inputs): 422 | with tf.variable_scope(None, default_name='residual'): 423 | output = tf.layers.conv2d(inputs, self.num_features, self.kernel_size, padding='same', 424 | **self.conv_kwargs) 425 | output = tf.nn.relu(self._batch_norm(output)) 426 | output = tf.layers.conv2d(output, self.num_features, 
self.kernel_size, padding='same', 427 | **self.conv_kwargs) 428 | output = self._batch_norm(output) 429 | return output + inputs 430 | 431 | def _batch_norm(self, values): 432 | return tf.layers.batch_normalization(values, training=self._training) 433 | 434 | @staticmethod 435 | def _get_tanh_scale(in_out): 436 | with tf.variable_scope(None, default_name='mask_params'): 437 | return tf.get_variable('tanh_scale', 438 | shape=[x.value for x in in_out.get_shape()[1:]], 439 | dtype=in_out.dtype, 440 | initializer=tf.zeros_initializer()) 441 | 442 | 443 | class MaskedFC(MaskedLayer): 444 | """ 445 | A fully-connected layer that scales certain dimensions 446 | using information from other dimensions. 447 | """ 448 | 449 | def __init__(self, mask_fn, num_features=64, num_layers=2): 450 | """ 451 | Create a masked layer. 452 | 453 | Args: 454 | mask_fn: a function which takes a Tensor and 455 | produces a boolean mask Tensor. 456 | num_features: the number of hidden units. 457 | num_layers: the number of hidden layers. 458 | """ 459 | self.mask_fn = mask_fn 460 | self.num_features = num_features 461 | self.num_layers = num_layers 462 | 463 | def _apply_masked(self, inputs): 464 | """ 465 | Get (biases, log_scales) for the inputs. 
466 | """ 467 | depth = inputs.get_shape()[-1] 468 | mask = self.mask_fn(inputs) 469 | 470 | masked_in = tf.where(mask, inputs, tf.zeros_like(inputs)) 471 | out = tf.layers.dense(masked_in, self.num_features, activation=tf.nn.relu) 472 | for _ in range(self.num_layers - 1): 473 | out = tf.layers.dense(out, self.num_features, activation=tf.nn.relu) 474 | log_scales = tf.layers.dense(out, depth, kernel_initializer=tf.zeros_initializer()) 475 | log_scales = tf.where(mask, tf.zeros_like(inputs), log_scales) 476 | biases = tf.layers.dense(out, depth, kernel_initializer=tf.zeros_initializer()) 477 | biases = tf.where(mask, tf.zeros_like(inputs), biases) 478 | 479 | return biases, log_scales 480 | 481 | 482 | def checkerboard_mask(is_even, tensor): 483 | """ 484 | Create a checkerboard mask in the shape of a Tensor. 485 | 486 | Args: 487 | is_even: determines which of two masks to use. 488 | tensor: the Tensor whose shape to match. 489 | """ 490 | result = np.zeros([x.value for x in tensor.get_shape()[1:]], dtype='bool') 491 | for row in range(result.shape[0]): 492 | for col in range(result.shape[1]): 493 | result[row, col, :] = (((row + col) % 2 == 0) == is_even) 494 | return tf.tile(tf.expand_dims(result, axis=0), [tf.shape(tensor)[0], 1, 1, 1]) 495 | 496 | 497 | def one_cold_mask(idx, tensor): 498 | """ 499 | Create a mask that only masks out one channel in a 2-D 500 | input Tensor. 501 | """ 502 | result = np.ones([tensor.get_shape()[-1].value], dtype='bool') 503 | result[idx] = False 504 | return tf.tile(tf.expand_dims(result, axis=0), [tf.shape(tensor)[0], 1]) 505 | 506 | 507 | def depth_mask(is_even, tensor): 508 | """ 509 | Create a depth mask in the shape of a Tensor. 510 | 511 | Args: 512 | is_even: determines which of two masks to use. 513 | tensor: the Tensor whose shape to match. 
514 | """ 515 | assert tensor.get_shape()[-1] % 4 == 0, 'depth must be divisible by 4' 516 | if is_even: 517 | mask = [True, True, False, False] 518 | else: 519 | mask = [False, False, True, True] 520 | one_dim = tf.tile(tf.constant(mask, dtype=tf.bool), [tensor.get_shape()[-1].value // 4]) 521 | # Broadcast, since + doesn't work for booleans. 522 | return tf.logical_or(one_dim, tf.zeros(shape=tf.shape(tensor), dtype=tf.bool)) 523 | 524 | 525 | def sum_batch(tensor): 526 | """ 527 | Compute a 1-D batch of sums. 528 | """ 529 | return tf.reduce_sum(tf.reshape(tensor, [tf.shape(tensor)[0], -1]), axis=1) 530 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pre-built networks. 3 | """ 4 | 5 | from functools import partial 6 | 7 | from .layer import (FactorHalf, MaskedConv, Network, PaddedLogit, Squeeze, 8 | checkerboard_mask, depth_mask) 9 | 10 | 11 | def simple_network(): 12 | """ 13 | A Network that is good for experimenting. 14 | """ 15 | main_layers = [ 16 | MaskedConv(partial(checkerboard_mask, True), 2), 17 | MaskedConv(partial(checkerboard_mask, False), 2), 18 | MaskedConv(partial(checkerboard_mask, True), 2), 19 | Squeeze(), 20 | MaskedConv(partial(depth_mask, True), 2), 21 | MaskedConv(partial(depth_mask, False), 2), 22 | MaskedConv(partial(depth_mask, True), 2), 23 | FactorHalf() 24 | ] 25 | return Network([PaddedLogit()] + (main_layers * 3)) 26 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/objective.py: -------------------------------------------------------------------------------- 1 | """ 2 | Likelihood objectives. 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from .layer import sum_batch 9 | 10 | 11 | def bits_per_pixel(layer, inputs, noise=1.0 / 255.0): 12 | """ 13 | Compute the bits per pixel for each input image. 
14 | 15 | Args: 16 | layer: the network to apply. 17 | inputs: the input images. 18 | noise: the amount of noise to add for a Monte Carlo 19 | integral. This is used to turn discrete inputs 20 | into continuous inputs. 21 | """ 22 | # Compute a Monte Carlo integral with one sample. 23 | sampled_noise = tf.random_uniform(tf.shape(inputs), maxval=noise) 24 | log_probs = log_likelihood(layer, inputs + sampled_noise) 25 | num_pixels = int(np.prod([x.value for x in inputs.get_shape()[1:]])) 26 | return -(log_probs / num_pixels + tf.log(float(noise))) / tf.log(2.0) 27 | 28 | 29 | def bits_per_pixel_and_grad(layer, inputs, noise=1.0 / 255.0, var_list=None): 30 | """ 31 | Like bits_per_pixel(), but also computes the gradients 32 | for the mean bits per pixel. 33 | 34 | Returns: 35 | A pair (bits, grads): 36 | bits: a 1-D Tensor of bits-per-pixel values. 37 | grads: a list of (gradient, variable) pairs. 38 | """ 39 | sampled_noise = tf.random_uniform(tf.shape(inputs), maxval=noise) 40 | outputs, latents, log_dets = layer.forward(inputs + sampled_noise) 41 | assert outputs is None 42 | log_probs = output_log_likelihood(latents, log_dets) 43 | num_pixels = int(np.prod([x.value for x in inputs.get_shape()[1:]])) 44 | bits = -(log_probs / num_pixels + tf.log(float(noise))) / tf.log(2.0) 45 | loss = tf.reduce_mean(bits) 46 | grads = layer.gradients(outputs, latents, log_dets, loss, var_list=var_list) 47 | return bits, grads 48 | 49 | 50 | def log_likelihood(layer, inputs): 51 | """ 52 | Compute the log likelihood for each input in a batch, 53 | assuming a Gaussian latent distribution. 54 | """ 55 | outputs, latents, log_dets = layer.forward(inputs) 56 | assert outputs is None, 'extraneous non-latent outputs' 57 | return output_log_likelihood(latents, log_dets) 58 | 59 | 60 | def output_log_likelihood(latents, log_dets): 61 | """ 62 | Like log_likelihood(), but with a pre-computed output 63 | from an NVPLayer. 
64 | """ 65 | log_probs = log_dets 66 | for latent in latents: 67 | log_probs = log_probs + gaussian_log_prob(latent) 68 | return log_probs 69 | 70 | 71 | def gaussian_log_prob(tensor): 72 | """ 73 | For each sub-tensor in a batch, compute the Gaussian 74 | log-density. 75 | """ 76 | dist = tf.distributions.Normal(0.0, 1.0) 77 | return sum_batch(dist.log_prob(tensor)) 78 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/run_interp.py: -------------------------------------------------------------------------------- 1 | """Interpolate between samples with a real NVP model.""" 2 | 3 | import argparse 4 | 5 | import tensorflow as tf 6 | 7 | from cnn_toys.data import images_dataset 8 | from cnn_toys.graphics import save_image_grid 9 | from cnn_toys.real_nvp import interpolate_linear, simple_network 10 | from cnn_toys.saving import restore_state 11 | 12 | 13 | def main(args): 14 | """Interpolation entry-point.""" 15 | print('loading images...') 16 | dataset = images_dataset([args.image_1, args.image_2], args.size) 17 | images = dataset.batch(2).make_one_shot_iterator().get_next() 18 | print('setting up model...') 19 | network = simple_network() 20 | with tf.variable_scope('model'): 21 | _, latents, _ = network.forward(images) 22 | latents = interpolate_linear(latents, args.rows * args.cols) 23 | with tf.variable_scope('model', reuse=True): 24 | images = network.inverse(None, latents) 25 | with tf.Session() as sess: 26 | print('initializing variables...') 27 | sess.run(tf.global_variables_initializer()) 28 | print('attempting to restore model...') 29 | restore_state(sess, args.state_file) 30 | print('generating images...') 31 | samples = sess.run(tf.reshape(images, [args.rows, args.cols, args.size, args.size, 3]), 32 | feed_dict=network.test_feed_dict()) 33 | save_image_grid(samples, args.out_file) 34 | 35 | 36 | def _parse_args(): 37 | parser = 
argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 38 | parser.add_argument('--size', help='image size', type=int, default=64) 39 | parser.add_argument('--rows', help='rows in output', type=int, default=4) 40 | parser.add_argument('--cols', help='columns in output', type=int, default=4) 41 | parser.add_argument('--state-file', help='state output file', default='state.pkl') 42 | parser.add_argument('--out-file', help='image output file', default='interp.png') 43 | parser.add_argument('image_1', help='first input image') 44 | parser.add_argument('image_2', help='first input image') 45 | return parser.parse_args() 46 | 47 | 48 | if __name__ == '__main__': 49 | main(_parse_args()) 50 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/run_sample.py: -------------------------------------------------------------------------------- 1 | """Sample a real NVP model.""" 2 | 3 | import argparse 4 | 5 | import tensorflow as tf 6 | 7 | from cnn_toys.graphics import save_image_grid 8 | from cnn_toys.real_nvp import simple_network 9 | from cnn_toys.saving import restore_state 10 | 11 | 12 | def main(args): 13 | """Sampling entry-point.""" 14 | if args.seed: 15 | tf.set_random_seed(args.seed) 16 | print('setting up model...') 17 | network = simple_network() 18 | with tf.variable_scope('model'): 19 | fake_batch = tf.zeros((args.rows * args.cols, args.size, args.size, 3), dtype=tf.float32) 20 | _, latents, _ = network.forward(fake_batch) 21 | with tf.variable_scope('model', reuse=True): 22 | gauss_latents = [tf.random_normal(latent.shape, seed=args.seed) for latent in latents] 23 | images = network.inverse(None, gauss_latents) 24 | with tf.Session() as sess: 25 | print('initializing variables...') 26 | sess.run(tf.global_variables_initializer()) 27 | print('attempting to restore model...') 28 | restore_state(sess, args.state_file) 29 | print('generating samples...') 30 | samples = 
sess.run(tf.reshape(images, [args.rows, args.cols, args.size, args.size, 3]), 31 | feed_dict=network.test_feed_dict()) 32 | save_image_grid(samples, args.out_file) 33 | 34 | 35 | def _parse_args(): 36 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 37 | parser.add_argument('--size', help='image size', type=int, default=64) 38 | parser.add_argument('--rows', help='rows in output', type=int, default=4) 39 | parser.add_argument('--cols', help='columns in output', type=int, default=4) 40 | parser.add_argument('--state-file', help='state output file', default='state.pkl') 41 | parser.add_argument('--out-file', help='image output file', default='samples.png') 42 | parser.add_argument('--seed', help='seed for outputs', type=int, default=None) 43 | return parser.parse_args() 44 | 45 | 46 | if __name__ == '__main__': 47 | main(_parse_args()) 48 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/run_train.py: -------------------------------------------------------------------------------- 1 | """Train a real NVP model.""" 2 | 3 | import argparse 4 | from itertools import count 5 | 6 | import tensorflow as tf 7 | 8 | from cnn_toys.data import dir_train_val 9 | from cnn_toys.real_nvp import bits_per_pixel, bits_per_pixel_and_grad, simple_network 10 | from cnn_toys.saving import save_state, restore_state 11 | 12 | 13 | def main(args): 14 | """The main training loop.""" 15 | print('loading dataset...') 16 | train_data, val_data = dir_train_val(args.data_dir, args.size) 17 | train_images = train_data.repeat().batch(args.batch).make_one_shot_iterator().get_next() 18 | val_images = val_data.repeat().batch(args.batch).make_one_shot_iterator().get_next() 19 | print('setting up model...') 20 | network = simple_network() 21 | with tf.variable_scope('model'): 22 | if args.low_mem: 23 | bpp, train_gradients = bits_per_pixel_and_grad(network, train_images) 24 | train_loss = tf.reduce_mean(bpp) 25 
| else: 26 | train_loss = tf.reduce_mean(bits_per_pixel(network, train_images)) 27 | with tf.variable_scope('model', reuse=True): 28 | val_loss = tf.reduce_mean(bits_per_pixel(network, val_images)) 29 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 30 | with tf.control_dependencies(update_ops): 31 | optimizer = tf.train.AdamOptimizer(learning_rate=args.step_size) 32 | if args.low_mem: 33 | optimize = optimizer.apply_gradients(train_gradients) 34 | else: 35 | optimize = optimizer.minimize(train_loss) 36 | with tf.Session() as sess: 37 | print('initializing variables...') 38 | sess.run(tf.global_variables_initializer()) 39 | print('attempting to restore model...') 40 | restore_state(sess, args.state_file) 41 | print('training...') 42 | for i in count(): 43 | cur_loss, _ = sess.run((train_loss, optimize)) 44 | if i % args.val_interval == 0: 45 | cur_val_loss = sess.run(val_loss, feed_dict=network.test_feed_dict()) 46 | print('step %d: loss=%f val=%f' % (i, cur_loss, cur_val_loss)) 47 | else: 48 | print('step %d: loss=%f' % (i, cur_loss)) 49 | if i % args.save_interval == 0: 50 | save_state(sess, args.state_file) 51 | 52 | 53 | def _parse_args(): 54 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 55 | parser.add_argument('--data-dir', help='data directory', default='data') 56 | parser.add_argument('--size', help='image size', type=int, default=64) 57 | parser.add_argument('--batch', help='batch size', type=int, default=32) 58 | parser.add_argument('--step-size', help='training step size', type=float, default=2e-4) 59 | parser.add_argument('--state-file', help='state output file', default='state.pkl') 60 | parser.add_argument('--save-interval', help='iterations per save', type=int, default=100) 61 | parser.add_argument('--val-interval', help='iterations per validation', type=int, default=10) 62 | parser.add_argument('--low-mem', help='use memory-efficient backprop', action='store_true') 63 | return parser.parse_args() 
64 | 65 | 66 | if __name__ == '__main__': 67 | main(_parse_args()) 68 | -------------------------------------------------------------------------------- /cnn_toys/real_nvp/test_layer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for real NVP layers. 3 | """ 4 | 5 | from functools import partial 6 | from random import random 7 | 8 | import numpy as np 9 | import pytest 10 | import tensorflow as tf 11 | 12 | from .layer import (FactorHalf, MaskedConv, MaskedFC, Network, PaddedLogit, Squeeze, 13 | checkerboard_mask, depth_mask, one_cold_mask) 14 | 15 | 16 | def test_squeeze_forward(): 17 | """ 18 | Test the forward pass of the Squeeze layer. 19 | """ 20 | inputs = np.array([ 21 | [ 22 | [[1, 2, 3], [4, 5, 6]], 23 | [[7, 8, 9], [10, 11, 12]], 24 | [[13, 14, 15], [16, 17, 18]], 25 | [[19, 20, 21], [22, 23, 24]] 26 | ] 27 | ], dtype='float32') 28 | with tf.Session() as sess: 29 | actual = sess.run(Squeeze().forward(tf.constant(inputs))[0]) 30 | expected = np.array([ 31 | [ 32 | [[1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12]], 33 | [[13, 16, 19, 22, 14, 17, 20, 23, 15, 18, 21, 24]] 34 | ] 35 | ], dtype='float32') 36 | assert not np.isnan(actual).any() 37 | assert np.allclose(actual, expected) 38 | 39 | 40 | def test_gradients(): 41 | """ 42 | Test that manual gradient computation works properly. 
43 | """ 44 | with tf.Graph().as_default(): # pylint: disable=E1129 45 | layers = [ 46 | PaddedLogit(), 47 | MaskedConv(partial(checkerboard_mask, True), 2), 48 | MaskedConv(partial(checkerboard_mask, False), 2), 49 | Squeeze(), 50 | MaskedConv(partial(depth_mask, True), 2), 51 | FactorHalf(), 52 | ] 53 | network = Network(layers) 54 | inputs = tf.random_uniform([3, 8, 8, 4]) 55 | outputs, latents, log_det = network.forward(inputs) 56 | loss = (tf.reduce_sum(tf.stack([(random() + 1) * tf.reduce_sum(x) for x in latents])) + 57 | (random() + 1) * tf.reduce_sum(log_det)) 58 | 59 | manual_grads = network.gradients(outputs, latents, log_det, loss) 60 | manual_grads = {var: grad for grad, var in manual_grads} 61 | manual_grads = [manual_grads[v] for v in tf.trainable_variables()] 62 | 63 | true_grads = tf.gradients(loss, tf.trainable_variables()) 64 | 65 | diffs = [tf.reduce_max(x - y) for x, y in zip(manual_grads, true_grads)] 66 | max_diff = tf.reduce_max(tf.stack(diffs)) 67 | 68 | with tf.Session() as sess: 69 | sess.run(tf.global_variables_initializer()) 70 | _randomized_init(sess) 71 | assert sess.run(max_diff) < 1e-4 72 | 73 | 74 | def test_padded_logit_inverse(): 75 | """ 76 | A specialized test for PaddedLogit inverses. 77 | """ 78 | inputs = np.random.random(size=(1, 3, 3, 7)).astype('float32') * 0.95 79 | _inverse_test(PaddedLogit(), inputs) 80 | 81 | 82 | @pytest.mark.parametrize("layer,shape", 83 | [(FactorHalf(), (3, 27, 15, 8)), 84 | (MaskedConv(partial(checkerboard_mask, True), 1), (3, 28, 14, 8)), 85 | (MaskedConv(partial(depth_mask, False), 1), (3, 28, 14, 8)), 86 | (MaskedFC(partial(one_cold_mask, 3)), (3, 6)), 87 | (Network([FactorHalf(), Squeeze()]), (3, 28, 14, 4)), 88 | (Squeeze(), (4, 8, 18, 4))]) 89 | def test_inverses(layer, shape): 90 | """ 91 | Tests for inverses on unbounded inputs. 
92 | """ 93 | inputs = np.random.normal(size=shape).astype('float32') 94 | _inverse_test(layer, inputs) 95 | 96 | 97 | def _inverse_test(layer, inputs): 98 | with tf.Graph().as_default(): # pylint: disable=E1129 99 | with tf.Session() as sess: 100 | in_constant = tf.constant(inputs) 101 | with tf.variable_scope('model'): 102 | out, latent, _ = layer.forward(in_constant) 103 | with tf.variable_scope('model', reuse=True): 104 | inverse = layer.inverse(out, latent) 105 | _randomized_init(sess) 106 | actual = sess.run(inverse) 107 | assert not np.isnan(actual).any() 108 | assert np.allclose(actual, inputs, atol=1e-4, rtol=1e-4) 109 | 110 | 111 | def test_padded_logit_log_det(): 112 | """ 113 | A specialized test for PaddedLogit determinants. 114 | """ 115 | inputs = np.random.random(size=(1, 3, 3, 2)).astype('float32') * 0.95 116 | _log_det_test(PaddedLogit(), inputs) 117 | 118 | 119 | @pytest.mark.parametrize("layer,shape", 120 | [(MaskedConv(partial(checkerboard_mask, True), 1), (3, 4, 6, 2)), 121 | (MaskedConv(partial(depth_mask, False), 1), (3, 4, 4, 8)), 122 | (MaskedFC(partial(one_cold_mask, 3)), (3, 6))]) 123 | def test_log_det(layer, shape): 124 | """ 125 | Tests log determinants. 
126 | """ 127 | inputs = np.random.normal(size=shape).astype('float32') 128 | _log_det_test(layer, inputs) 129 | 130 | 131 | def _log_det_test(layer, inputs): 132 | with tf.Graph().as_default(): # pylint: disable=E1129 133 | with tf.Session() as sess: 134 | in_vecs = tf.constant(np.reshape(inputs, [inputs.shape[0], -1])) 135 | in_tensor = tf.reshape(in_vecs, inputs.shape) 136 | with tf.variable_scope('model'): 137 | out, _, log_dets = layer.forward(in_tensor) 138 | out_vecs = tf.reshape(out, in_vecs.get_shape()) 139 | jacobians = _compute_jacobians(in_vecs, out_vecs) 140 | real_log_dets = tf.linalg.slogdet(jacobians)[1] 141 | _randomized_init(sess) 142 | real_log_dets, log_dets = sess.run([real_log_dets, log_dets]) 143 | assert log_dets.shape == (inputs.shape[0],) 144 | assert not np.isnan(log_dets).any() 145 | assert not np.isnan(real_log_dets).any() 146 | assert np.allclose(real_log_dets, log_dets, atol=1e-4, rtol=1e-4) 147 | 148 | 149 | def _compute_jacobians(in_vecs, out_vecs): 150 | num_dims = in_vecs.get_shape()[-1].value 151 | res = [] 152 | for i in range(in_vecs.get_shape()[0].value): 153 | rows = [] 154 | for comp in range(num_dims): 155 | rows.append(tf.gradients(out_vecs[i, comp], in_vecs)[0][i]) 156 | res.append(tf.stack(rows, axis=0)) 157 | return tf.stack(res, axis=0) 158 | 159 | 160 | def _randomized_init(sess): 161 | """ 162 | Initialize all the TF variables in a way that prevents 163 | default identity behavior. 164 | 165 | Without a random init, some layers are essentially 166 | just identity transforms. 
167 | """ 168 | sess.run(tf.global_variables_initializer()) 169 | for variable in tf.trainable_variables(): 170 | shape = [x.value for x in variable.get_shape()] 171 | val = tf.glorot_uniform_initializer()(shape) 172 | sess.run(tf.assign(variable, val)) 173 | -------------------------------------------------------------------------------- /cnn_toys/saving.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tools for saving/restoring models. 3 | """ 4 | 5 | import os 6 | import pickle 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def save_state(sess, path): 12 | """Export all TensorFlow variables.""" 13 | with open(path, 'wb+') as outfile: 14 | pickle.dump(sess.run(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)), outfile) 15 | 16 | 17 | def restore_state(sess, path): 18 | """Import all TensorFlow variables.""" 19 | if not os.path.exists(path): 20 | return 21 | with open(path, 'rb') as infile: 22 | state = pickle.load(infile) 23 | all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 24 | placeholders = [tf.placeholder(dtype=v.dtype.base_dtype, shape=v.get_shape()) for v in all_vars] 25 | assigns = [tf.assign(var, ph) for var, ph in zip(all_vars, placeholders)] 26 | sess.run(tf.group(*assigns), feed_dict=dict(zip(placeholders, state))) 27 | -------------------------------------------------------------------------------- /cnn_toys/schedules.py: -------------------------------------------------------------------------------- 1 | """Learning rate schedules.""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def half_annealed_lr(initial, iters, global_step): 7 | """ 8 | Create a learning rate that stays at an initial value 9 | for the first half of training, then is linearly 10 | annealed for the second half of training. 11 | 12 | Args: 13 | initial: the initial LR. 14 | iters: the total number of iterations. 15 | global_step: the step counter Tensor. 
16 | """ 17 | frac_done = 1 - tf.cast(iters - global_step, tf.float32) / float(iters) 18 | return tf.cond(frac_done < 0.5, lambda: initial, lambda: (1 - frac_done) * 2 * initial) 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Package meta-data. 3 | """ 4 | 5 | from setuptools import setup 6 | 7 | setup( 8 | name='cnn-toys', 9 | version='0.0.1', 10 | description='Playing around with CNNs.', 11 | url='https://github.com/unixpickle/cnn-toys', 12 | author='Alex Nichol', 13 | author_email='unixpickle@gmail.com', 14 | license='MIT', 15 | packages=['cnn_toys'], 16 | install_requires=['numpy'] 17 | ) 18 | --------------------------------------------------------------------------------