├── .gitignore ├── LICENSE ├── README.md ├── assets ├── generation_2016_08_01_16_40_28.jpg ├── model.png ├── pixel_cnn.png ├── pixel_rnn.png └── pixel_rnn_cnn_relative.png ├── cifar10.py ├── main.py ├── network.py ├── ops.py ├── statistic.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # misc 2 | logs 3 | data 4 | *.jpg 5 | MNIST_data 6 | 7 | # data 8 | samples 9 | *checkpoints/ 10 | *.npy 11 | *.pkl 12 | *.tgz 13 | *.zip 14 | *.tar.gz 15 | 16 | 17 | # Created by https://www.gitignore.io/api/python,vim 18 | 19 | ### IPythonNotebook ### 20 | ## Temporary data 21 | .ipynb_checkpoints/ 22 | 23 | ### Python ### 24 | # Byte-compiled / optimized / DLL files 25 | __pycache__/ 26 | *.py[cod] 27 | *$py.class 28 | 29 | # C extensions 30 | *.so 31 | 32 | # Distribution / packaging 33 | .Python 34 | env/ 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *,cover 69 | .hypothesis/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | 85 | ### Vim ### 86 | [._]*.s[a-w][a-z] 87 | [._]s[a-w][a-z] 88 | *.un~ 89 | Session.vim 90 | .netrwhist 91 | *~ 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Taehoon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PixelCNN & PixelRNN in TensorFlow 2 | 3 | TensorFlow implementation of [Pixel Recurrent Neural Networks](https://arxiv.org/abs/1601.06759). This implementation contains: 4 | 5 | ![model](./assets/model.png) 6 | 7 | 1. PixelCNN 8 | - Masked Convolution (A, B) 9 | 2. PixelRNN 10 | - Row LSTM (in progress) 11 | - Diagonal BiLSTM (skew, unskew) 12 | - Residual Connections 13 | - Multi-Scale PixelRNN (in progress) 14 | 3. Datasets 15 | - MNIST 16 | - cifar10 (in progress) 17 | - ImageNet (in progress) 18 | 19 | 20 | ## Requirements 21 | 22 | - Python 2.7 23 | - [Scipy](https://www.scipy.org/) 24 | - [TensorFlow](https://www.tensorflow.org/) 0.9+ 25 | 26 | 27 | ## Usage 28 | 29 | First, install prerequisites with: 30 | 31 | $ pip install tqdm gym[all] 32 | 33 | To train a `pixel_rnn` model with `mnist` data (slow iteration, fast convergence): 34 | 35 | $ python main.py --data=mnist --model=pixel_rnn 36 | 37 | To train a `pixel_cnn` model with `mnist` data (fast iteration, slow convergence): 38 | 39 | $ python main.py --data=mnist --model=pixel_cnn --hidden_dims=64 --recurrent_length=2 --out_hidden_dims=64 40 | 41 | To generate images with trained model: 42 | 43 | $ python main.py --data=mnist --model=pixel_rnn --is_train=False 44 | 45 | 46 | ## Samples 47 | 48 | Samples generated with `pixel_cnn` after 50 epochs. 
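Sampling is sequential: `Network.generate` in `network.py` starts from an all-zero canvas and fills it one pixel at a time, feeding the partial image back through the network at every step. The loop can be sketched in plain NumPy as below. This is a minimal sketch under stated assumptions, not the repository's code: `dummy_predict` is a hypothetical stand-in for the trained network, and `binarize` is assumed to draw each pixel from a Bernoulli distribution (the real helper lives in `utils.py`, which may differ).

```python
import numpy as np


def binarize(probs, rng):
    """Draw 0/1 pixels from per-pixel Bernoulli probabilities (assumed behaviour)."""
    return (rng.uniform(size=probs.shape) < probs).astype(np.float32)


def generate(predict, height, width, channel=1, num_samples=4, seed=0):
    """Fill an empty canvas pixel by pixel, in raster-scan order."""
    rng = np.random.default_rng(seed)
    samples = np.zeros((num_samples, height, width, channel), dtype=np.float32)
    for i in range(height):
        for j in range(width):
            for k in range(channel):
                # Re-run the model on the partially filled canvas, then keep
                # only the newly sampled value at position (i, j, k).
                next_sample = binarize(predict(samples), rng)
                samples[:, i, j, k] = next_sample[:, i, j, k]
    return samples


def dummy_predict(images):
    # Hypothetical stand-in for the network: every pixel is "on" with p = 0.5.
    return np.full_like(images, 0.5)


samples = generate(dummy_predict, height=4, width=4)
```

Because the model is re-evaluated once per pixel, generating a 28 x 28 MNIST image takes 784 forward passes, which is why sampling is much slower than a single training step.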
49 |
50 | ![generation_2016_08_01_16_40_28.jpg](./assets/generation_2016_08_01_16_40_28.jpg)
51 |
52 |
53 | ## Training details
54 |
55 | The results below use two different parameter settings:
56 |
57 | [1] `--hidden_dims=16 --recurrent_length=7 --out_hidden_dims=32`
58 | [2] `--hidden_dims=64 --recurrent_length=2 --out_hidden_dims=64`
59 |
60 | Training results of `pixel_rnn` with \[1\] (yellow) and \[2\] (green) with `epoch` as x-axis:
61 |
62 | ![pixel_rnn](./assets/pixel_rnn.png)
63 |
64 | Training results of `pixel_cnn` with \[1\] (orange) and \[2\] (purple) with `epoch` as x-axis:
65 |
66 | ![pixel_cnn](./assets/pixel_cnn.png)
67 |
68 | Training results of `pixel_rnn` (yellow, green) and `pixel_cnn` (orange, purple) with `hour` as x-axis:
69 |
70 | ![pixel_rnn_cnn_relative](./assets/pixel_rnn_cnn_relative.png)
71 |
72 |
73 |
74 | ## References
75 |
76 | - [Pixel Recurrent Neural Networks](https://arxiv.org/abs/1601.06759)
77 | - [Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)
78 | - [Review by Kyle Kastner](https://github.com/tensorflow/magenta/blob/master/magenta/reviews/pixelrnn.md)
79 | - [igul222/pixel_rnn](https://github.com/igul222/pixel_rnn)
80 | - [kundan2510/pixelCNN](https://github.com/kundan2510/pixelCNN)
81 |
82 |
83 | ## Author
84 |
85 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
86 |
--------------------------------------------------------------------------------
/assets/generation_2016_08_01_16_40_28.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/pixel-rnn-tensorflow/1d98c28c62ed9c9584a342b63e4612dd9e5ae688/assets/generation_2016_08_01_16_40_28.jpg
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/pixel-rnn-tensorflow/1d98c28c62ed9c9584a342b63e4612dd9e5ae688/assets/model.png -------------------------------------------------------------------------------- /assets/pixel_cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/pixel-rnn-tensorflow/1d98c28c62ed9c9584a342b63e4612dd9e5ae688/assets/pixel_cnn.png -------------------------------------------------------------------------------- /assets/pixel_rnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/pixel-rnn-tensorflow/1d98c28c62ed9c9584a342b63e4612dd9e5ae688/assets/pixel_rnn.png -------------------------------------------------------------------------------- /assets/pixel_rnn_cnn_relative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/pixel-rnn-tensorflow/1d98c28c62ed9c9584a342b63e4612dd9e5ae688/assets/pixel_rnn_cnn_relative.png -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Routine for decoding the CIFAR-10 binary file format.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | from six.moves import xrange # pylint: disable=redefined-builtin 25 | import tensorflow as tf 26 | 27 | # Process images of this size. Note that this differs from the original CIFAR 28 | # image size of 32 x 32. If one alters this number, then the entire model 29 | # architecture will change and any model would need to be retrained. 30 | IMAGE_SIZE = 24 31 | 32 | # Global constants describing the CIFAR-10 data set. 33 | NUM_CLASSES = 10 34 | NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 35 | NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 36 | 37 | 38 | def read_cifar10(filename_queue): 39 | """Reads and parses examples from CIFAR10 data files. 40 | 41 | Recommendation: if you want N-way read parallelism, call this function 42 | N times. This will give you N independent Readers reading different 43 | files & positions within those files, which will give better mixing of 44 | examples. 45 | 46 | Args: 47 | filename_queue: A queue of strings with the filenames to read from. 48 | 49 | Returns: 50 | An object representing a single example, with the following fields: 51 | height: number of rows in the result (32) 52 | width: number of columns in the result (32) 53 | depth: number of color channels in the result (3) 54 | key: a scalar string Tensor describing the filename & record number 55 | for this example. 56 | label: an int32 Tensor with the label in the range 0..9. 57 | uint8image: a [height, width, depth] uint8 Tensor with the image data 58 | """ 59 | 60 | class CIFAR10Record(object): 61 | pass 62 | result = CIFAR10Record() 63 | 64 | # Dimensions of the images in the CIFAR-10 dataset. 
65 | # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the 66 | # input format. 67 | label_bytes = 1 # 2 for CIFAR-100 68 | result.height = 32 69 | result.width = 32 70 | result.depth = 3 71 | image_bytes = result.height * result.width * result.depth 72 | # Every record consists of a label followed by the image, with a 73 | # fixed number of bytes for each. 74 | record_bytes = label_bytes + image_bytes 75 | 76 | # Read a record, getting filenames from the filename_queue. No 77 | # header or footer in the CIFAR-10 format, so we leave header_bytes 78 | # and footer_bytes at their default of 0. 79 | reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) 80 | result.key, value = reader.read(filename_queue) 81 | 82 | # Convert from a string to a vector of uint8 that is record_bytes long. 83 | record_bytes = tf.decode_raw(value, tf.uint8) 84 | 85 | # The first bytes represent the label, which we convert from uint8->int32. 86 | result.label = tf.cast( 87 | tf.slice(record_bytes, [0], [label_bytes]), tf.int32) 88 | 89 | # The remaining bytes after the label represent the image, which we reshape 90 | # from [depth * height * width] to [depth, height, width]. 91 | depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), 92 | [result.depth, result.height, result.width]) 93 | # Convert from [depth, height, width] to [height, width, depth]. 94 | result.uint8image = tf.transpose(depth_major, [1, 2, 0]) 95 | 96 | return result 97 | 98 | 99 | def _generate_image_and_label_batch(image, label, min_queue_examples, 100 | batch_size): 101 | """Construct a queued batch of images and labels. 102 | 103 | Args: 104 | image: 3-D Tensor of [height, width, 3] of type.float32. 105 | label: 1-D Tensor of type.int32 106 | min_queue_examples: int32, minimum number of samples to retain 107 | in the queue that provides of batches of examples. 108 | batch_size: Number of images per batch. 109 | 110 | Returns: 111 | images: Images. 
4D tensor of [batch_size, height, width, 3] size. 112 | labels: Labels. 1D tensor of [batch_size] size. 113 | """ 114 | # Create a queue that shuffles the examples, and then 115 | # read 'batch_size' images + labels from the example queue. 116 | num_preprocess_threads = 16 117 | images, label_batch = tf.train.shuffle_batch( 118 | [image, label], 119 | batch_size=batch_size, 120 | num_threads=num_preprocess_threads, 121 | capacity=min_queue_examples + 3 * batch_size, 122 | min_after_dequeue=min_queue_examples) 123 | 124 | # Display the training images in the visualizer. 125 | # FIXED pre-1.0 # tf.image_summary('images', images) 126 | tf.summary.image('images', images) 127 | 128 | return images, tf.reshape(label_batch, [batch_size]) 129 | 130 | 131 | def distorted_inputs(data_dir, batch_size): 132 | """Construct distorted input for CIFAR training using the Reader ops. 133 | 134 | Args: 135 | data_dir: Path to the CIFAR-10 data directory. 136 | batch_size: Number of images per batch. 137 | 138 | Returns: 139 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 140 | labels: Labels. 1D tensor of [batch_size] size. 141 | """ 142 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 143 | for i in xrange(1, 6)] 144 | for f in filenames: 145 | if not tf.gfile.Exists(f): 146 | raise ValueError('Failed to find file: ' + f) 147 | 148 | # Create a queue that produces the filenames to read. 149 | filename_queue = tf.train.string_input_producer(filenames) 150 | 151 | # Read examples from files in the filename queue. 152 | read_input = read_cifar10(filename_queue) 153 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 154 | 155 | height = IMAGE_SIZE 156 | width = IMAGE_SIZE 157 | 158 | # Image processing for training the network. Note the many random 159 | # distortions applied to the image. 160 | 161 | # Randomly crop a [height, width] section of the image. 
162 |   distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
163 |
164 |   # Randomly flip the image horizontally.
165 |   distorted_image = tf.image.random_flip_left_right(distorted_image)
166 |
167 |   # Because these operations are not commutative, consider randomizing
168 |   # the order in which they are applied.
169 |   distorted_image = tf.image.random_brightness(distorted_image,
170 |                                                max_delta=63)
171 |   distorted_image = tf.image.random_contrast(distorted_image,
172 |                                              lower=0.2, upper=1.8)
173 |
174 |   # Subtract off the mean and divide by the standard deviation of the pixels.
175 |   # FIXED pre-1.0 # float_image = tf.image.per_image_whitening(distorted_image)
176 |   float_image = tf.image.per_image_standardization(distorted_image)
177 |
178 |   # Ensure that the random shuffling has good mixing properties.
179 |   min_fraction_of_examples_in_queue = 0.4
180 |   min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
181 |                            min_fraction_of_examples_in_queue)
182 |   print ('Filling queue with %d CIFAR images before starting to train. '
183 |          'This will take a few minutes.' % min_queue_examples)
184 |
185 |   # Generate a batch of images and labels by building up a queue of examples.
186 |   return _generate_image_and_label_batch(float_image, read_input.label,
187 |                                          min_queue_examples, batch_size)
188 |
189 |
190 | def inputs(eval_data, data_dir, batch_size):
191 |   """Construct input for CIFAR evaluation using the Reader ops.
192 |
193 |   Args:
194 |     eval_data: bool, indicating if one should use the train or eval data set.
195 |     data_dir: Path to the CIFAR-10 data directory.
196 |     batch_size: Number of images per batch.
197 |
198 |   Returns:
199 |     images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
200 |     labels: Labels. 1D tensor of [batch_size] size.
201 | """ 202 | if not eval_data: 203 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 204 | for i in xrange(1, 6)] 205 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 206 | else: 207 | filenames = [os.path.join(data_dir, 'test_batch.bin')] 208 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL 209 | 210 | for f in filenames: 211 | if not tf.gfile.Exists(f): 212 | raise ValueError('Failed to find file: ' + f) 213 | 214 | # Create a queue that produces the filenames to read. 215 | filename_queue = tf.train.string_input_producer(filenames) 216 | 217 | # Read examples from files in the filename queue. 218 | read_input = read_cifar10(filename_queue) 219 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 220 | 221 | height = IMAGE_SIZE 222 | width = IMAGE_SIZE 223 | 224 | # Image processing for evaluation. 225 | # Crop the central [height, width] of the image. 226 | resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 227 | width, height) 228 | 229 | # Subtract off the mean and divide by the variance of the pixels. 230 | # FIXED pre-1.0 # float_image = tf.image.per_image_whitening(resized_image) 231 | float_image = tf.image.per_image_standardization(resized_image) 232 | 233 | # Ensure that the random shuffling has good mixing properties. 234 | min_fraction_of_examples_in_queue = 0.4 235 | min_queue_examples = int(num_examples_per_epoch * 236 | min_fraction_of_examples_in_queue) 237 | 238 | # Generate a batch of images and labels by building up a queue of examples. 
239 |   return _generate_image_and_label_batch(float_image, read_input.label,
240 |                                          min_queue_examples, batch_size)
241 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | logging.basicConfig(format="[%(asctime)s] %(message)s", datefmt="%m-%d %H:%M:%S")
4 |
5 | import numpy as np
6 | from tqdm import trange
7 | import tensorflow as tf
8 |
9 | from utils import *
10 | from network import Network
11 | from statistic import Statistic
12 |
13 | flags = tf.app.flags
14 |
15 | # network
16 | flags.DEFINE_string("model", "pixel_cnn", "name of model [pixel_rnn, pixel_cnn]")
17 | flags.DEFINE_integer("batch_size", 100, "size of a batch")
18 | flags.DEFINE_integer("hidden_dims", 16, "dimension of hidden states of LSTM or Conv layers")
19 | flags.DEFINE_integer("recurrent_length", 7, "the length of LSTM or Conv layers")
20 | flags.DEFINE_integer("out_hidden_dims", 32, "dimension of hidden states of output Conv layers")
21 | flags.DEFINE_integer("out_recurrent_length", 2, "the length of output Conv layers")
22 | flags.DEFINE_boolean("use_residual", False, "whether to use residual connections or not")
23 | # flags.DEFINE_boolean("use_dynamic_rnn", False, "whether to use dynamic_rnn or not")
24 |
25 | # training
26 | flags.DEFINE_integer("max_epoch", 100000, "maximum # of epochs to train")
27 | flags.DEFINE_integer("test_step", 100, "# of step to test a model")
28 | flags.DEFINE_integer("save_step", 1000, "# of step to save a model")
29 | flags.DEFINE_float("learning_rate", 1e-3, "learning rate")
30 | flags.DEFINE_float("grad_clip", 1, "value of gradient to be used for clipping")
31 | flags.DEFINE_boolean("use_gpu", True, "whether to use gpu for training")
32 |
33 | # data
34 | flags.DEFINE_string("data", "mnist", "name of dataset [mnist, cifar]")
35 | flags.DEFINE_string("data_dir", "data", "name of data directory")
36 |
flags.DEFINE_string("sample_dir", "samples", "name of sample directory") 37 | 38 | # Debug 39 | flags.DEFINE_boolean("is_train", True, "training or testing") 40 | flags.DEFINE_boolean("display", False, "whether to display the training results or not") 41 | flags.DEFINE_string("log_level", "INFO", "log level [DEBUG, INFO, WARNING, ERROR, CRITICAL]") 42 | flags.DEFINE_integer("random_seed", 123, "random seed for python") 43 | 44 | conf = flags.FLAGS 45 | 46 | # logging 47 | logger = logging.getLogger() 48 | logger.setLevel(conf.log_level) 49 | 50 | # random seed 51 | tf.set_random_seed(conf.random_seed) 52 | np.random.seed(conf.random_seed) 53 | 54 | def main(_): 55 | model_dir = get_model_dir(conf, 56 | ['data_dir', 'sample_dir', 'max_epoch', 'test_step', 'save_step', 57 | 'is_train', 'random_seed', 'log_level', 'display']) 58 | preprocess_conf(conf) 59 | 60 | DATA_DIR = os.path.join(conf.data_dir, conf.data) 61 | SAMPLE_DIR = os.path.join(conf.sample_dir, conf.data, model_dir) 62 | 63 | check_and_create_dir(DATA_DIR) 64 | check_and_create_dir(SAMPLE_DIR) 65 | 66 | # 0. 
prepare datasets 67 | if conf.data == "mnist": 68 | from tensorflow.examples.tutorials.mnist import input_data 69 | mnist = input_data.read_data_sets(DATA_DIR, one_hot=True) 70 | 71 | next_train_batch = lambda x: mnist.train.next_batch(x)[0] 72 | next_test_batch = lambda x: mnist.test.next_batch(x)[0] 73 | 74 | height, width, channel = 28, 28, 1 75 | 76 | train_step_per_epoch = mnist.train.num_examples / conf.batch_size 77 | test_step_per_epoch = mnist.test.num_examples / conf.batch_size 78 | elif conf.data == "cifar": 79 | from cifar10 import IMAGE_SIZE, inputs 80 | 81 | maybe_download_and_extract(DATA_DIR) 82 | images, labels = inputs(eval_data=False, 83 | data_dir=os.path.join(DATA_DIR, 'cifar-10-batches-bin'), batch_size=conf.batch_size) 84 | 85 | height, width, channel = IMAGE_SIZE, IMAGE_SIZE, 3 86 | 87 | with tf.Session() as sess: 88 | network = Network(sess, conf, height, width, channel) 89 | 90 | stat = Statistic(sess, conf.data, model_dir, tf.trainable_variables(), conf.test_step) 91 | stat.load_model() 92 | 93 | if conf.is_train: 94 | logger.info("Training starts!") 95 | 96 | initial_step = stat.get_t() if stat else 0 97 | iterator = trange(conf.max_epoch, ncols=70, initial=initial_step) 98 | 99 | for epoch in iterator: 100 | # 1. train 101 | total_train_costs = [] 102 | for idx in xrange(train_step_per_epoch): 103 | images = binarize(next_train_batch(conf.batch_size)) \ 104 | .reshape([conf.batch_size, height, width, channel]) 105 | 106 | cost = network.test(images, with_update=True) 107 | total_train_costs.append(cost) 108 | 109 | # 2. 
test 110 | total_test_costs = [] 111 | for idx in xrange(test_step_per_epoch): 112 | images = binarize(next_test_batch(conf.batch_size)) \ 113 | .reshape([conf.batch_size, height, width, channel]) 114 | 115 | cost = network.test(images, with_update=False) 116 | total_test_costs.append(cost) 117 | 118 | avg_train_cost, avg_test_cost = np.mean(total_train_costs), np.mean(total_test_costs) 119 | 120 | stat.on_step(avg_train_cost, avg_test_cost) 121 | 122 | # 3. generate samples 123 | samples = network.generate() 124 | save_images(samples, height, width, 10, 10, 125 | directory=SAMPLE_DIR, prefix="epoch_%s" % epoch) 126 | 127 | iterator.set_description("train l: %.3f, test l: %.3f" % (avg_train_cost, avg_test_cost)) 128 | print 129 | else: 130 | logger.info("Image generation starts!") 131 | 132 | samples = network.generate() 133 | save_images(samples, height, width, 10, 10, directory=SAMPLE_DIR) 134 | 135 | 136 | if __name__ == "__main__": 137 | tf.app.run() 138 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from logging import getLogger 3 | 4 | from ops import * 5 | from utils import * 6 | 7 | logger = getLogger(__name__) 8 | 9 | class Network: 10 | def __init__(self, sess, conf, height, width, channel): 11 | logger.info("Building %s starts!" 
% conf.model) 12 | 13 | self.sess = sess 14 | self.data = conf.data 15 | self.height, self.width, self.channel = height, width, channel 16 | 17 | if conf.use_gpu: 18 | data_format = "NHWC" 19 | else: 20 | data_format = "NCHW" 21 | 22 | if data_format == "NHWC": 23 | input_shape = [None, height, width, channel] 24 | elif data_format == "NCHW": 25 | input_shape = [None, channel, height, width] 26 | else: 27 | raise ValueError("Unknown data_format: %s" % data_format) 28 | 29 | self.l = {} 30 | 31 | self.l['inputs'] = tf.placeholder(tf.float32, [None, height, width, channel],) 32 | 33 | if conf.data =='mnist': 34 | self.l['normalized_inputs'] = self.l['inputs'] 35 | else: 36 | self.l['normalized_inputs'] = tf.div(self.l['inputs'], 255., name="normalized_inputs") 37 | 38 | # input of main reccurent layers 39 | scope = "conv_inputs" 40 | logger.info("Building %s" % scope) 41 | 42 | if conf.use_residual and conf.model == "pixel_rnn": 43 | self.l[scope] = conv2d(self.l['normalized_inputs'], conf.hidden_dims * 2, [7, 7], "A", scope=scope) 44 | else: 45 | self.l[scope] = conv2d(self.l['normalized_inputs'], conf.hidden_dims, [7, 7], "A", scope=scope) 46 | 47 | # main reccurent layers 48 | l_hid = self.l[scope] 49 | for idx in xrange(conf.recurrent_length): 50 | if conf.model == "pixel_rnn": 51 | scope = 'LSTM%d' % idx 52 | self.l[scope] = l_hid = diagonal_bilstm(l_hid, conf, scope=scope) 53 | elif conf.model == "pixel_cnn": 54 | scope = 'CONV%d' % idx 55 | self.l[scope] = l_hid = conv2d(l_hid, 3, [1, 1], "B", scope=scope) 56 | else: 57 | raise ValueError("wrong type of model: %s" % (conf.model)) 58 | logger.info("Building %s" % scope) 59 | 60 | # output reccurent layers 61 | for idx in xrange(conf.out_recurrent_length): 62 | scope = 'CONV_OUT%d' % idx 63 | self.l[scope] = l_hid = tf.nn.relu(conv2d(l_hid, conf.out_hidden_dims, [1, 1], "B", scope=scope)) 64 | logger.info("Building %s" % scope) 65 | 66 | if channel == 1: 67 | self.l['conv2d_out_logits'] = conv2d(l_hid, 1, [1, 
1], "B", scope='conv2d_out_logits') 68 | self.l['output'] = tf.nn.sigmoid(self.l['conv2d_out_logits']) 69 | 70 | logger.info("Building loss and optims") 71 | # FIXED pre-1.0 72 | # self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 73 | # self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss')) 74 | self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( 75 | logits=self.l['conv2d_out_logits'], labels=self.l['normalized_inputs'], name='loss')) 76 | else: 77 | raise ValueError("Implementation in progress for RGB colors") 78 | 79 | COLOR_DIM = 256 80 | 81 | self.l['conv2d_out_logits'] = conv2d(l_hid, COLOR_DIM, [1, 1], "B", scope='conv2d_out_logits') 82 | 83 | self.l['conv2d_out_logits_flat'] = tf.reshape( 84 | self.l['conv2d_out_logits'], [-1, self.height * self.width, COLOR_DIM]) 85 | self.l['normalized_inputs_flat'] = tf.reshape( 86 | self.l['normalized_inputs'], [-1, self.height * self.width, COLOR_DIM]) 87 | 88 | # FIXED pre-1.0 # pred_pixels = [tf.squeeze(pixel, squeeze_dims=[1]) 89 | pred_pixels = [tf.squeeze(pixel, axis=[1]) 90 | # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['conv2d_out_logits_flat'])] 91 | for pixel in tf.split(self.l['conv2d_out_logits_flat'], self.height * self.width, 1)] 92 | # FIXED pre-1.0 # target_pixels = [tf.squeeze(pixel, squeeze_dims=[1]) 93 | target_pixels = [tf.squeeze(pixel, axis=[1]) 94 | # FIXED pre-1.0 # for pixel in tf.split(1, self.height * self.width, self.l['normalized_inputs_flat'])] 95 | for pixel in tf.split(self.l['normalized_inputs_flat'], self.height * self.width, 1)] 96 | 97 | softmaxed_pixels = [tf.nn.softmax(pixel) for pixel in pred_pixels] 98 | 99 | losses = [tf.nn.sampled_softmax_loss( 100 | pred_pixel, tf.zeros_like(pred_pixel), pred_pixel, target_pixel, 1, COLOR_DIM) \ 101 | for pred_pixel, target_pixel in zip(pred_pixels, target_pixels)] 102 | 103 | self.l['output'] = tf.nn.softmax(self.l['conv2d_out_logits']) 104 | 105 | 
logger.info("Building loss and optims") 106 | # FIXED pre-1.0 107 | # self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 108 | # self.l['conv2d_out_logits'], self.l['normalized_inputs'], name='loss')) 109 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 110 | logits=self.l['conv2d_out_logits'], labels=self.l['normalized_inputs'], name='loss')) 111 | 112 | optimizer = tf.train.RMSPropOptimizer(conf.learning_rate) 113 | grads_and_vars = optimizer.compute_gradients(self.loss) 114 | 115 | new_grads_and_vars = \ 116 | [(tf.clip_by_value(gv[0], -conf.grad_clip, conf.grad_clip), gv[1]) for gv in grads_and_vars] 117 | self.optim = optimizer.apply_gradients(new_grads_and_vars) 118 | 119 | show_all_variables() 120 | 121 | logger.info("Building %s finished!" % conf.model) 122 | 123 | def predict(self, images): 124 | return self.sess.run(self.l['output'], {self.l['inputs']: images}) 125 | 126 | def test(self, images, with_update=False): 127 | if with_update: 128 | _, cost = self.sess.run([ 129 | self.optim, self.loss, 130 | ], feed_dict={ self.l['inputs']: images }) 131 | else: 132 | cost = self.sess.run(self.loss, feed_dict={ self.l['inputs']: images }) 133 | return cost 134 | 135 | def generate(self): 136 | samples = np.zeros((100, self.height, self.width, 1), dtype='float32') 137 | 138 | for i in xrange(self.height): 139 | for j in xrange(self.width): 140 | for k in xrange(self.channel): 141 | next_sample = binarize(self.predict(samples)) 142 | samples[:, i, j, k] = next_sample[:, i, j, k] 143 | 144 | if self.data == 'mnist': 145 | print "=" * (self.width/2), "(%2d, %2d)" % (i, j), "=" * (self.width/2) 146 | mprint(next_sample[0,:,:,:]) 147 | 148 | return samples 149 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(format="[%(asctime)s] %(message)s", datefmt="%m-%d 
%H:%M:%S") 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.python.ops import rnn_cell 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | from tensorflow.contrib.layers import variance_scaling_initializer 9 | 10 | WEIGHT_INITIALIZER = tf.contrib.layers.xavier_initializer() 11 | #WEIGHT_INITIALIZER = tf.uniform_unit_scaling_initializer() 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | he_uniform = variance_scaling_initializer(factor=2.0, mode="FAN_IN", uniform=False) 16 | data_format = "NCHW" 17 | 18 | def get_shape(layer): 19 | return layer.get_shape().as_list() 20 | 21 | #def get_shape(layer): 22 | # if data_format == "NHWC": 23 | # batch, height, width, channel = layer.get_shape().as_list() 24 | # elif data_format == "NCHW": 25 | # batch, channel, height, width = layer.get_shape().as_list() 26 | # else: 27 | # raise ValueError("Unknown data_format: %s" % data_format) 28 | # return batch, height, width, channel 29 | 30 | def skew(inputs, scope="skew"): 31 | with tf.name_scope(scope): 32 | batch, height, width, channel = get_shape(inputs) # [batch, height, width, channel] 33 | # FIXED pre-1.0 # rows = tf.split(1, height, inputs) # [batch, 1, width, channel] 34 | rows = tf.split(inputs, height, 1) # [batch, 1, width, channel] 35 | 36 | new_width = width + height - 1 37 | new_rows = [] 38 | 39 | for idx, row in enumerate(rows): 40 | transposed_row = tf.transpose(tf.squeeze(row, [1]), [0, 2, 1]) # [batch, channel, width] 41 | squeezed_row = tf.reshape(transposed_row, [-1, width]) # [batch*channel, width] 42 | padded_row = tf.pad(squeezed_row, ((0, 0), (idx, height - 1 - idx))) # [batch*channel, width*2-1] 43 | 44 | unsqueezed_row = tf.reshape(padded_row, [-1, channel, new_width]) # [batch, channel, width*2-1] 45 | untransposed_row = tf.transpose(unsqueezed_row, [0, 2, 1]) # [batch, width*2-1, channel] 46 | 47 | assert get_shape(untransposed_row) == [batch, new_width, channel], "wrong shape of skewed row" 48 | 
new_rows.append(untransposed_row) 49 | 50 | # FIXED pre-1.0 # outputs = tf.pack(new_rows, axis=1, name="output") 51 | outputs = tf.stack(new_rows, axis=1, name="output") 52 | assert get_shape(outputs) == [None, height, new_width, channel], "wrong shape of skewed output" 53 | 54 | logger.debug('[skew] %s : %s %s -> %s %s' \ 55 | % (scope, inputs.name, inputs.get_shape(), outputs.name, outputs.get_shape())) 56 | return outputs 57 | 58 | def unskew(inputs, width=None, scope="unskew"): 59 | with tf.name_scope(scope): 60 | batch, height, skewed_width, channel = get_shape(inputs) 61 | width = width if width else height 62 | 63 | new_rows = [] 64 | # FIXED pre-1.0 # rows = tf.split(1, height, inputs) 65 | rows = tf.split(inputs, height, 1) 66 | 67 | for idx, row in enumerate(rows): 68 | new_rows.append(tf.slice(row, [0, 0, idx, 0], [-1, -1, width, -1])) 69 | # FIXED pre-1.0 # outputs = tf.concat(1, new_rows, name="output") 70 | outputs = tf.concat(new_rows, 1, name="output") 71 | 72 | logger.debug('[unskew] %s : %s %s -> %s %s' \ 73 | % (scope, inputs.name, inputs.get_shape(), outputs.name, outputs.get_shape())) 74 | return outputs 75 | 76 | def conv2d( 77 | inputs, 78 | num_outputs, 79 | kernel_shape, # [kernel_height, kernel_width] 80 | mask_type, # None, "A" or "B", 81 | strides=[1, 1], # [column_wise_stride, row_wise_stride] 82 | padding="SAME", 83 | activation_fn=None, 84 | weights_initializer=WEIGHT_INITIALIZER, 85 | weights_regularizer=None, 86 | # FIXED pre-1.0 # biases_initializer=tf.zeros_initializer, 87 | biases_initializer=tf.zeros_initializer(), 88 | biases_regularizer=None, 89 | scope="conv2d"): 90 | with tf.variable_scope(scope): 91 | mask_type = mask_type.lower() 92 | batch_size, height, width, channel = inputs.get_shape().as_list() 93 | 94 | kernel_h, kernel_w = kernel_shape 95 | stride_h, stride_w = strides 96 | 97 | assert kernel_h % 2 == 1 and kernel_w % 2 == 1, \ 98 | "kernel height and width should be odd number" 99 | 100 | center_h = kernel_h // 2 
101 | center_w = kernel_w // 2 102 | 103 | weights_shape = [kernel_h, kernel_w, channel, num_outputs] 104 | weights = tf.get_variable("weights", weights_shape, 105 | tf.float32, weights_initializer, weights_regularizer) 106 | 107 | if mask_type is not None: 108 | mask = np.ones( 109 | (kernel_h, kernel_w, channel, num_outputs), dtype=np.float32) 110 | 111 | mask[center_h, center_w+1: ,: ,:] = 0. 112 | mask[center_h+1:, :, :, :] = 0. 113 | 114 | if mask_type == 'a': 115 | mask[center_h,center_w,:,:] = 0. 116 | 117 | weights *= tf.constant(mask, dtype=tf.float32) 118 | tf.add_to_collection('conv2d_weights_%s' % mask_type, weights) 119 | 120 | outputs = tf.nn.conv2d(inputs, 121 | weights, [1, stride_h, stride_w, 1], padding=padding, name='outputs') 122 | tf.add_to_collection('conv2d_outputs', outputs) 123 | 124 | if biases_initializer != None: 125 | biases = tf.get_variable("biases", [num_outputs,], 126 | tf.float32, biases_initializer, biases_regularizer) 127 | outputs = tf.nn.bias_add(outputs, biases, name='outputs_plus_b') 128 | 129 | if activation_fn: 130 | outputs = activation_fn(outputs, name='outputs_with_fn') 131 | 132 | logger.debug('[conv2d_%s] %s : %s %s -> %s %s' \ 133 | % (mask_type, scope, inputs.name, inputs.get_shape(), outputs.name, outputs.get_shape())) 134 | 135 | return outputs 136 | 137 | def conv1d( 138 | inputs, 139 | num_outputs, 140 | kernel_size, 141 | strides=[1, 1], # [column_wise_stride, row_wise_stride] 142 | padding="SAME", 143 | activation_fn=None, 144 | weights_initializer=WEIGHT_INITIALIZER, 145 | weights_regularizer=None, 146 | # FIXED pre-1.0 # biases_initializer=tf.zeros_initializer, 147 | biases_initializer=tf.zeros_initializer(), 148 | biases_regularizer=None, 149 | scope="conv1d"): 150 | with tf.variable_scope(scope): 151 | batch_size, height, _, channel = inputs.get_shape().as_list() # [batch, height, 1, channel] 152 | 153 | kernel_h, kernel_w = kernel_size, 1 154 | stride_h, stride_w = strides 155 | 156 | weights_shape = 
[kernel_h, kernel_w, channel, num_outputs] 157 | weights = tf.get_variable("weights", weights_shape, 158 | tf.float32, weights_initializer, weights_regularizer) 159 | tf.add_to_collection('conv1d_weights', weights) 160 | 161 | outputs = tf.nn.conv2d(inputs, 162 | weights, [1, stride_h, stride_w, 1], padding=padding, name='outputs') 163 | tf.add_to_collection('conv1d_outputs', outputs) 164 | 165 | if biases_initializer is not None: 166 | biases = tf.get_variable("biases", [num_outputs,], 167 | tf.float32, biases_initializer, biases_regularizer) 168 | outputs = tf.nn.bias_add(outputs, biases, name='outputs_plus_b') 169 | 170 | if activation_fn: 171 | outputs = activation_fn(outputs, name='outputs_with_fn') 172 | 173 | logger.debug('[conv1d] %s : %s %s -> %s %s' \ 174 | % (scope, inputs.name, inputs.get_shape(), outputs.name, outputs.get_shape())) 175 | 176 | return outputs 177 | 178 | def diagonal_bilstm(inputs, conf, scope='diagonal_bilstm'): 179 | with tf.variable_scope(scope): 180 | def reverse(inputs): 181 | # FIXED pre-1.0 # return tf.reverse(inputs, [False, False, True, False]) 182 | return tf.reverse(inputs, [2]) # reverse along the width axis 183 | 184 | output_state_fw = diagonal_lstm(inputs, conf, scope='output_state_fw') 185 | output_state_bw = reverse(diagonal_lstm(reverse(inputs), conf, scope='output_state_bw')) 186 | 187 | tf.add_to_collection('output_state_fw', output_state_fw) 188 | tf.add_to_collection('output_state_bw', output_state_bw) 189 | 190 | if conf.use_residual: 191 | residual_state_fw = conv2d(output_state_fw, conf.hidden_dims * 2, [1, 1], "B", scope="residual_fw") 192 | output_state_fw = residual_state_fw + inputs 193 | 194 | residual_state_bw = conv2d(output_state_bw, conf.hidden_dims * 2, [1, 1], "B", scope="residual_bw") 195 | output_state_bw = residual_state_bw + inputs 196 | 197 | tf.add_to_collection('residual_state_fw', residual_state_fw) 198 | tf.add_to_collection('residual_state_bw', residual_state_bw) 199 | 
tf.add_to_collection('residual_output_state_fw', output_state_fw) 200 | tf.add_to_collection('residual_output_state_bw', output_state_bw) 201 | 202 | batch, height, width, channel = get_shape(output_state_bw) 203 | 204 | output_state_bw_except_last = tf.slice(output_state_bw, [0, 0, 0, 0], [-1, height-1, -1, -1]) 205 | output_state_bw_only_last = tf.slice(output_state_bw, [0, height-1, 0, 0], [-1, 1, -1, -1]) 206 | dummy_zeros = tf.zeros_like(output_state_bw_only_last) 207 | 208 | # FIXED pre-1.0 # output_state_bw_with_last_zeros = tf.concat(1, [output_state_bw_except_last, dummy_zeros]) 209 | output_state_bw_with_last_zeros = tf.concat([output_state_bw_except_last, dummy_zeros], 1) 210 | 211 | tf.add_to_collection('output_state_bw_with_last_zeros', output_state_bw_with_last_zeros) 212 | 213 | return output_state_fw + output_state_bw_with_last_zeros 214 | 215 | def diagonal_lstm(inputs, conf, scope='diagonal_lstm'): 216 | with tf.variable_scope(scope): 217 | tf.add_to_collection('lstm_inputs', inputs) 218 | 219 | skewed_inputs = skew(inputs, scope="skewed_i") 220 | tf.add_to_collection('skewed_lstm_inputs', skewed_inputs) 221 | 222 | # input-to-state (K_is * x_i) : 1x1 convolution. generate 4h x n x n tensor. 
223 | input_to_state = conv2d(skewed_inputs, conf.hidden_dims * 4, [1, 1], "B", scope="i_to_s") 224 | column_wise_inputs = tf.transpose( 225 | input_to_state, [0, 2, 1, 3]) # [batch, width, height, hidden_dims * 4] 226 | 227 | tf.add_to_collection('skewed_conv_inputs', input_to_state) 228 | tf.add_to_collection('column_wise_inputs', column_wise_inputs) 229 | 230 | batch, width, height, channel = get_shape(column_wise_inputs) 231 | rnn_inputs = tf.reshape(column_wise_inputs, 232 | [-1, width, height * channel]) # [batch, max_time, height * hidden_dims * 4] 233 | 234 | tf.add_to_collection('rnn_inputs', rnn_inputs) 235 | 236 | # FIXED pre-1.0 # rnn_input_list = [tf.squeeze(rnn_input, squeeze_dims=[1]) 237 | rnn_input_list = [tf.squeeze(rnn_input, axis=[1]) 238 | # FIXED pre-1.0 # for rnn_input in tf.split(split_dim=1, num_split=width, value=rnn_inputs)] 239 | for rnn_input in tf.split(rnn_inputs, width, 1)] 240 | 241 | cell = DiagonalLSTMCell(conf.hidden_dims, height, channel) 242 | 243 | # if conf.use_dynamic_rnn: 244 | if True: 245 | # XXX FIXME: sequence_length ? 
246 | outputs, states = tf.nn.dynamic_rnn(cell, 247 | inputs=rnn_inputs, dtype=tf.float32) # [batch, width, height * hidden_dims] 248 | packed_outputs = outputs # dynamic_rnn(), [batch, width, height * hidden_dims] 249 | 250 | # else: 251 | # output_list, state_list = tf.nn.rnn(cell, 252 | # inputs=rnn_input_list, dtype=tf.float32) # width * [batch, height * hidden_dims] 253 | 254 | # # FIXED pre-1.0 # packed_outputs = tf.pack(output_list, 1) # [batch, width, height * hidden_dims] 255 | # packed_outputs = tf.stack(output_list, 1) # [batch, width, height * hidden_dims] 256 | 257 | width_first_outputs = tf.reshape(packed_outputs, 258 | [-1, width, height, conf.hidden_dims]) # [batch, width, height, hidden_dims] 259 | 260 | skewed_outputs = tf.transpose(width_first_outputs, [0, 2, 1, 3]) 261 | tf.add_to_collection('skewed_outputs', skewed_outputs) 262 | 263 | outputs = unskew(skewed_outputs) 264 | tf.add_to_collection('unskewed_outputs', outputs) 265 | 266 | return outputs 267 | 268 | class DiagonalLSTMCell(rnn_cell.RNNCell): 269 | def __init__(self, hidden_dims, height, channel): 270 | self._num_unit_shards = 1 271 | self._forget_bias = 1. 
272 | 273 | self._height = height 274 | self._channel = channel 275 | 276 | self._hidden_dims = hidden_dims 277 | self._num_units = self._hidden_dims * self._height 278 | self._state_size = self._num_units * 2 279 | self._output_size = self._num_units 280 | 281 | @property 282 | def state_size(self): 283 | return self._state_size 284 | 285 | @property 286 | def output_size(self): 287 | return self._output_size 288 | 289 | def __call__(self, i_to_s, state, scope="DiagonalBiLSTMCell"): 290 | c_prev = tf.slice(state, [0, 0], [-1, self._num_units]) 291 | h_prev = tf.slice(state, [0, self._num_units], [-1, self._num_units]) # [batch, height * hidden_dims] 292 | 293 | # i_to_s : [batch, 4 * height * hidden_dims] 294 | input_size = i_to_s.get_shape().with_rank(2)[1] 295 | 296 | if input_size.value is None: 297 | raise ValueError("Could not infer input size from i_to_s.get_shape()[1]") 298 | 299 | with tf.variable_scope(scope): 300 | # state-to-state (K_ss * h_{i-1}) : 2x1 convolution. generate 4h x n x n tensor. 
301 | conv1d_inputs = tf.reshape(h_prev, 302 | [-1, self._height, 1, self._hidden_dims], name='conv1d_inputs') # [batch, height, 1, hidden_dims] 303 | 304 | tf.add_to_collection('i_to_s', i_to_s) 305 | tf.add_to_collection('conv1d_inputs', conv1d_inputs) 306 | 307 | conv_s_to_s = conv1d(conv1d_inputs, 308 | 4 * self._hidden_dims, 2, scope='s_to_s') # [batch, height, 1, hidden_dims * 4] 309 | s_to_s = tf.reshape(conv_s_to_s, 310 | [-1, self._height * self._hidden_dims * 4]) # [batch, height * hidden_dims * 4] 311 | 312 | tf.add_to_collection('conv_s_to_s', conv_s_to_s) 313 | tf.add_to_collection('s_to_s', s_to_s) 314 | 315 | lstm_matrix = tf.sigmoid(s_to_s + i_to_s) 316 | 317 | # i = input_gate, g = new_input, f = forget_gate, o = output_gate 318 | # FIXED pre-1.0 # i, g, f, o = tf.split(1, 4, lstm_matrix) 319 | i, g, f, o = tf.split(lstm_matrix, 4, 1) 320 | 321 | c = f * c_prev + i * g 322 | # FIXED pre-1.0 # h = tf.mul(o, tf.tanh(c), name='hid') 323 | h = tf.multiply(o, tf.tanh(c), name='hid') 324 | 325 | logger.debug('[DiagonalLSTMCell] %s : %s %s -> %s %s' \ 326 | % (scope, i_to_s.name, i_to_s.get_shape(), h.name, h.get_shape())) 327 | 328 | # FIXED pre-1.0 # new_state = tf.concat(1, [c, h]) 329 | new_state = tf.concat([c, h], 1) 330 | return h, new_state 331 | 332 | class RowLSTMCell(rnn_cell.RNNCell): 333 | def __init__(self, num_units, kernel_shape=[3, 1]): 334 | self._num_units = num_units 335 | self._state_size = num_units * 2 336 | self._output_size = num_units 337 | self._kernel_shape = kernel_shape 338 | 339 | @property 340 | def state_size(self): 341 | return self._state_size 342 | 343 | @property 344 | def output_size(self): 345 | return self._output_size 346 | 347 | def __call__(self, inputs, state, scope="RowLSTMCell"): 348 | raise Exception("Not implemented") 349 | -------------------------------------------------------------------------------- /statistic.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import numpy as np 3 | import tensorflow as tf 4 | from logging import getLogger 5 | 6 | logger = getLogger(__name__) 7 | 8 | class Statistic(object): 9 | def __init__(self, sess, data, model_dir, variables, test_step, max_to_keep=20): 10 | self.sess = sess 11 | self.test_step = test_step 12 | self.reset() 13 | 14 | with tf.variable_scope('t'): 15 | self.t_op = tf.Variable(0, trainable=False, name='t') 16 | self.t_add_op = self.t_op.assign_add(1) 17 | 18 | self.model_dir = model_dir 19 | self.saver = tf.train.Saver(variables + [self.t_op], max_to_keep=max_to_keep) 20 | # FIXED pre-1.0 # self.writer = tf.train.SummaryWriter('./logs/%s' % self.model_dir, self.sess.graph) 21 | self.writer = tf.summary.FileWriter('./logs/%s' % self.model_dir, self.sess.graph) 22 | 23 | with tf.variable_scope('summary'): 24 | scalar_summary_tags = ['train_l', 'test_l'] 25 | 26 | self.summary_placeholders = {} 27 | self.summary_ops = {} 28 | 29 | for tag in scalar_summary_tags: 30 | self.summary_placeholders[tag] = tf.placeholder('float32', None, name=tag.replace(' ', '_')) 31 | # FIXED pre-1.0 # self.summary_ops[tag] = tf.scalar_summary('%s/%s' % (data, tag), self.summary_placeholders[tag]) 32 | self.summary_ops[tag] = tf.summary.scalar('%s/%s' % (data, tag), self.summary_placeholders[tag]) 33 | 34 | def reset(self): 35 | pass 36 | 37 | def on_step(self, train_l, test_l): 38 | self.t = self.t_add_op.eval(session=self.sess) 39 | 40 | self.inject_summary({'train_l': train_l, 'test_l': test_l}, self.t) 41 | 42 | self.save_model(self.t) 43 | self.reset() 44 | 45 | def get_t(self): 46 | return self.t_op.eval(session=self.sess) 47 | 48 | def inject_summary(self, tag_dict, t): 49 | summary_str_lists = self.sess.run([self.summary_ops[tag] for tag in tag_dict.keys()], { 50 | self.summary_placeholders[tag]: value for tag, value in tag_dict.items() 51 | }) 52 | for summary_str in summary_str_lists: 53 | self.writer.add_summary(summary_str, t) 54 | 55 | def save_model(self, t): 56 | 
logger.info("Saving checkpoints...") 57 | model_name = type(self).__name__ 58 | 59 | if not os.path.exists(self.model_dir): 60 | os.makedirs(self.model_dir) 61 | self.saver.save(self.sess, self.model_dir, global_step=t) 62 | 63 | def load_model(self): 64 | logger.info("Initializing all variables") 65 | # FIXED pre-1.0 # tf.initialize_all_variables().run() 66 | self.sess.run(tf.global_variables_initializer()) # use the stored session instead of relying on a default session 67 | 68 | logger.info("Loading checkpoints...") 69 | ckpt = tf.train.get_checkpoint_state(self.model_dir) 70 | if ckpt and ckpt.model_checkpoint_path: 71 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 72 | fname = os.path.join(self.model_dir, ckpt_name) 73 | self.saver.restore(self.sess, fname) 74 | logger.info("Load SUCCESS: %s" % fname) 75 | else: 76 | logger.info("Load FAILED: %s" % self.model_dir) 77 | 78 | self.t = self.t_add_op.eval(session=self.sess) 79 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(format="[%(asctime)s] %(message)s", datefmt="%m-%d %H:%M:%S") 3 | 4 | import os 5 | import sys 6 | import urllib 7 | import pprint 8 | import tarfile 9 | import tensorflow as tf 10 | 11 | import datetime 12 | import dateutil.tz 13 | import numpy as np 14 | 15 | import scipy.misc 16 | 17 | pp = pprint.PrettyPrinter().pprint 18 | logger = logging.getLogger(__name__) 19 | 20 | def mprint(matrix, pivot=0.5): 21 | for array in matrix: 22 | print("".join("#" if i > pivot else " " for i in array)) 23 | 24 | def show_all_variables(): 25 | total_count = 0 26 | for idx, op in enumerate(tf.trainable_variables()): 27 | shape = op.get_shape() 28 | count = np.prod(shape) 29 | print("[%2d] %s %s = %s" % (idx, op.name, shape, count)) 30 | total_count += int(count) 31 | print("[Total] variable size: %s" % "{:,}".format(total_count)) 32 | 33 | def get_timestamp(): 34 | now = 
datetime.datetime.now(dateutil.tz.tzlocal()) 35 | return now.strftime('%Y_%m_%d_%H_%M_%S') 36 | 37 | def binarize(images): 38 | return (np.random.uniform(size=images.shape) < images).astype('float32') 39 | 40 | def save_images(images, height, width, n_row, n_col, 41 | cmin=0.0, cmax=1.0, directory="./", prefix="sample"): 42 | images = images.reshape((n_row, n_col, height, width)) 43 | images = images.transpose(0, 2, 1, 3) # [n_row, height, n_col, width], so the grid also tiles correctly when n_row != n_col 44 | images = images.reshape((height * n_row, width * n_col)) 45 | 46 | filename = '%s_%s.jpg' % (prefix, get_timestamp()) 47 | scipy.misc.toimage(images, cmin=cmin, cmax=cmax) \ 48 | .save(os.path.join(directory, filename)) 49 | 50 | def get_model_dir(config, exceptions=None): 51 | attrs = config.__dict__['__flags'] 52 | pp(attrs) 53 | 54 | keys = list(attrs.keys()) 55 | keys.sort() 56 | keys.remove('data') 57 | keys = ['data'] + keys 58 | 59 | names = [] 60 | for key in keys: 61 | # Only use useful flags 62 | if exceptions is None or key not in exceptions: 63 | names.append("%s=%s" % (key, ",".join([str(i) for i in attrs[key]]) 64 | if type(attrs[key]) == list else attrs[key])) 65 | return os.path.join('checkpoints', *names) + '/' 66 | 67 | def preprocess_conf(conf): 68 | options = conf.__flags 69 | 70 | for option, value in options.items(): 71 | option = option.lower() 72 | 73 | def check_and_create_dir(directory): 74 | if not os.path.exists(directory): 75 | logger.info('Creating directory: %s' % directory) 76 | os.makedirs(directory) 77 | else: 78 | logger.info('Skip creating directory: %s' % directory) 79 | 80 | def maybe_download_and_extract(dest_directory): 81 | """ 82 | Download and extract the tarball from Alex's website. 
83 | From https://github.com/tensorflow/tensorflow/blob/r0.9/tensorflow/models/image/cifar10/cifar10.py 84 | """ 85 | DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' 86 | 87 | if not os.path.exists(dest_directory): 88 | os.makedirs(dest_directory) 89 | 90 | filename = DATA_URL.split('/')[-1] 91 | filepath = os.path.join(dest_directory, filename) 92 | 93 | if not os.path.exists(filepath): 94 | def _progress(count, block_size, total_size): 95 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 96 | float(count * block_size) / float(total_size) * 100.0)) 97 | sys.stdout.flush() 98 | filepath, _ = urllib.urlretrieve(DATA_URL, filepath, _progress) 99 | print() 100 | statinfo = os.stat(filepath) 101 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 102 | tarfile.open(filepath, 'r:gz').extractall(dest_directory) 103 | --------------------------------------------------------------------------------
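The masking in `conv2d` and the `skew` / `unskew` pair in ops.py are easier to check outside a TensorFlow graph. The following is a minimal NumPy sketch of the same index arithmetic, not code from the repository: the name `build_mask` is hypothetical, and `skew` / `unskew` here are plain-array stand-ins for the graph ops of the same name, assuming `[batch, height, width, channel]` layout as in the ops.py comments.

```python
import numpy as np


def build_mask(kernel_h, kernel_w, in_channels, out_channels, mask_type):
    """Hypothetical helper: a {0, 1} mask shaped like the conv2d weights."""
    assert kernel_h % 2 == 1 and kernel_w % 2 == 1, "kernel sizes must be odd"
    center_h, center_w = kernel_h // 2, kernel_w // 2

    mask = np.ones((kernel_h, kernel_w, in_channels, out_channels), dtype=np.float32)
    mask[center_h, center_w + 1:, :, :] = 0.0  # pixels to the right of the center
    mask[center_h + 1:, :, :, :] = 0.0         # all rows below the center
    if mask_type.lower() == "a":
        mask[center_h, center_w, :, :] = 0.0   # type A also hides the center pixel
    return mask


def skew(x):
    """Shift row i right by i columns: [B, H, W, C] -> [B, H, W + H - 1, C]."""
    batch, height, width, channel = x.shape
    out = np.zeros((batch, height, width + height - 1, channel), dtype=x.dtype)
    for i in range(height):
        out[:, i, i:i + width, :] = x[:, i, :, :]
    return out


def unskew(x, width=None):
    """Undo skew by slicing `width` columns starting at offset i in row i."""
    batch, height, skewed_width, channel = x.shape
    width = width if width else height  # same default as unskew in ops.py
    return np.stack([x[:, i, i:i + width, :] for i in range(height)], axis=1)


if __name__ == "__main__":
    print(build_mask(3, 3, 1, 1, "A")[:, :, 0, 0])
    image = np.arange(9, dtype=np.float32).reshape(1, 3, 3, 1)
    print(skew(image)[0, :, :, 0])
```

After skewing, each column of the result holds one diagonal of the original image, which is why `diagonal_lstm` can run an ordinary RNN over the width axis. Note that `unskew` defaults `width` to `height`, so it only round-trips square inputs unless `width` is passed explicitly.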