├── .gitignore
├── .gitmodules
├── Cifar Plot.ipynb
├── ImageNet Plot.ipynb
├── Mnist Plot.ipynb
├── README.md
├── assets
│   ├── admiral.png
│   ├── alp.png
│   ├── brown_bear.png
│   ├── cifar10_test_original.png
│   ├── cifar10_test_recon.png
│   ├── coral_reef.png
│   ├── gray_whale.png
│   ├── imagenet_val_original.png
│   ├── imagenet_val_recon.png
│   ├── mnist_diff_codes.png
│   ├── mnist_randomwalk.gif
│   ├── mnist_randomwalk.mp4
│   ├── mnist_test_original.png
│   ├── mnist_test_recon.png
│   ├── pickup.png
│   ├── sampled_cifar10.png
│   └── sampled_mnist.png
├── cifar10.py
├── commons
│   ├── .gitignore
│   ├── __init__.py
│   └── ops.py
├── imagenet.py
├── mnist.py
├── model.py
└── models
    ├── cifar10
    │   ├── last-pixelcnn.ckpt.data-00000-of-00001
    │   ├── last-pixelcnn.ckpt.index
    │   ├── last-pixelcnn.ckpt.meta
    │   ├── last.ckpt.data-00000-of-00001
    │   ├── last.ckpt.index
    │   └── last.ckpt.meta
    ├── imagenet
    │   ├── last.ckpt.data-00000-of-00001
    │   ├── last.ckpt.index
    │   └── last.ckpt.meta
    └── mnist
        ├── last-pixelcnn.ckpt.data-00000-of-00001
        ├── last-pixelcnn.ckpt.index
        ├── last-pixelcnn.ckpt.meta
        ├── last.ckpt.data-00000-of-00001
        ├── last.ckpt.index
        └── last.ckpt.meta

/.gitignore:
--------------------------------------------------------------------------------
datasets/
log/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "slim_models"]
	path = slim_models
	url = git@github.com:tensorflow/models.git
[submodule "Conditional-PixelCNN-decoder"]
	path = pixelcnn
	url = git@github.com:anantzoid/Conditional-PixelCNN-decoder.git

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VQ-VAE (Neural Discrete Representation Learning) Tensorflow

## Intro

This repository implements the paper [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) (VQ-VAE) in TensorFlow.
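
The core quantization step of the paper can be illustrated with a minimal NumPy sketch (illustrative only; it is not the code in `model.py`, which is not shown here, and the function names are hypothetical). Each encoder output vector is replaced by its nearest codebook vector, and the loss adds a codebook (VQ) term and a commitment term weighted by beta (`BETA`, 0.25 in `cifar10.py`). Gradient copying (the straight-through estimator) cannot be expressed in NumPy; in TensorFlow it is commonly written as `z_e + tf.stop_gradient(z_q - z_e)`, so the forward pass uses `z_q` while the decoder's gradient is passed to `z_e` unchanged.

```python
import numpy as np


def vector_quantize(z_e, codebook):
    """Replace each encoder output with its nearest codebook vector.

    z_e:      (N, D) encoder outputs, one row per latent position.
    codebook: (K, D) embedding vectors e_1..e_K.
    Returns (k, z_q): (N,) nearest-code indices and (N, D) quantized vectors.
    """
    # ||z - e||^2 = ||z||^2 - 2 z.e + ||e||^2, computed for all (N, K) pairs.
    dists = (np.sum(z_e ** 2, axis=1, keepdims=True)
             - 2.0 * z_e @ codebook.T
             + np.sum(codebook ** 2, axis=1))
    k = np.argmin(dists, axis=1)
    return k, codebook[k]


def vq_and_commitment_terms(z_e, z_q, beta=0.25):
    """Forward values of the codebook (VQ) and commitment terms.

    With autodiff, the VQ term ||sg[z_e] - e||^2 updates only the codebook and
    the commitment term beta * ||z_e - sg[e]||^2 updates only the encoder, so
    their forward values differ only by beta. Averaging over latent positions
    is an assumed reduction.
    """
    sq_dist = np.mean(np.sum((z_e - z_q) ** 2, axis=1))
    return sq_dist, beta * sq_dist


# Toy example: 3 latent vectors with D=2 and a codebook of K=3 entries.
codebook = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
z_e = np.array([[0.1, -0.2], [0.9, 1.2], [4.0, 6.0]])
k, z_q = vector_quantize(z_e, codebook)
vq, commit = vq_and_commitment_terms(z_e, z_q)
print(k)  # [0 1 2]
```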

:warning: This is not an official implementation, and it might have some glitches (or even a major defect).

## Requirements

- Python 3.5
- Tensorflow (v1.3 or higher)
- numpy, better_exceptions, tqdm, etc.
- ffmpeg

## Updated Result: ImageNet

- [x] ImageNet

| Validation Set Images | Reconstructed Images |
| ------------- |:-------------:|
|![Imagenet original images](/assets/imagenet_val_original.png) | ![Imagenet Reconstructed Images](/assets/imagenet_val_recon.png) |

- Class-conditioned sampled images (not cherry-picked; just random samples)

![alp](/assets/alp.png)

![admiral](/assets/admiral.png)

![coral reef](/assets/coral_reef.png)

![gray_whale](/assets/gray_whale.png)

![brown bear](/assets/brown_bear.png)

![pickup truck](/assets/pickup.png)

- I could not reproduce images as sharp as the ones in the paper.
- But some of the results seem recognizable.
- Natural scene images with consistent pixel patterns, such as Alp or coral reef, usually show better results.
- More tuning might give better results.

## Updated Result: Sampling with PixelCNN

- [x] Pixel CNN

- MNIST Sampled Images (Conditioned on class labels)

![MNIST Sampled Images](/assets/sampled_mnist.png)

- Cifar10 Sampled Images (Conditioned on class labels)

![Cifar10 Sampled Images](/assets/sampled_cifar10.png)

From the top row to the bottom, the rows show samples for the classes airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.

Not that satisfying so far; I guess the VQ-VAE hyperparameters should be tuned first to generate sharper results.

## Results

All training was done on a Quadro M4000 GPU. Training on MNIST takes less than 10 minutes.

- [x] MNIST

| Original Images | Reconstructed Images |
| ------------- |:-------------:|
|![MNIST original images](/assets/mnist_test_original.png) | ![MNIST Reconstructed Images](/assets/mnist_test_recon.png) |

Results on the MNIST test set (K=20, D=64, latent space = 3 by 3).

I also explored the latent space by starting from one observed latent code and changing the value at a single latent position at a time. The result is shown below.
![MNIST Latent Observation](/assets/mnist_diff_codes.png)

It seems that the spatial location of a latent code is important: changing the code at a specific location disturbs the pixels at the matching location.

![MNIST Latent Observation - Random Walk](/assets/mnist_randomwalk.gif)

This shows 1000 images generated by starting from a known latent code and repeatedly changing a single latent code at a random location by +1 or -1.
Most of the images are redundant (unrealistic), which indicates that there is much room for compression.

If you want to explore the latent space further, try playing with the notebook files I provided.

- [x] CIFAR 10

| Original Images | Reconstructed Images |
| ------------- |:-------------:|
|![CIFAR-10 original images](/assets/cifar10_test_original.png) | ![CIFAR-10 Reconstructed Images](/assets/cifar10_test_recon.png) |

I was able to get 4.65 bits/dim (K=10, D=256, latent space = 8 by 8).


## Training

The scripts download the required datasets into `./datasets/{mnist,cifar10}` by themselves,
so just running the code does the trick.

### Run train

- Run mnist: `python mnist.py`
- Run cifar10: `python cifar10.py`

Change the hyperparameters as you want; they are defined at the bottom of each script.
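
A hedged sketch of one common bits/dim convention, relevant to the CIFAR-10 figure reported above and to the NLL question under "Thoughts and Help request". With the paper's uniform prior over K codes, the KL term of the bound is constant, log K nats per latent position, so -log p(x) <= -log p(x | z_q) + (number of latents) * log K; dividing nats by (number of dimensions * ln 2) gives bits/dim. Whether `net.nll` in this repository already includes the KL term, or how it treats discrete pixel values, cannot be seen from the files shown here (`model.py` is not included), and the helper names below are hypothetical.

```python
import math


def bits_per_dim(nll_nats, num_dims):
    """Convert a per-example negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))


def uniform_prior_kl_nats(num_latents, K):
    """KL term of the VQ-VAE bound under a uniform prior: log K nats per latent.

    Assumes one independent categorical latent per position on the latent grid.
    """
    return num_latents * math.log(K)


# CIFAR-10: 32 * 32 * 3 = 3072 dimensions per image.
# README setting: an 8 by 8 latent grid with K=10 codes.
CIFAR10_DIMS = 32 * 32 * 3
kl_bits = bits_per_dim(uniform_prior_kl_nats(8 * 8, 10), CIFAR10_DIMS)
print(round(kl_bits, 3))  # 0.069
```

With that setting, the KL term alone adds roughly 0.07 bits/dim on top of the reconstruction term, if it is not already counted.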

## Evaluation

I provide the trained models and the code for generating (or reconstructing) images as Jupyter notebooks.
Run a Jupyter notebook server, then open the notebooks to see more results with the provided models.

If you want to evaluate NLL, uncomment the `test()` call at the bottom of `cifar10.py` and run the script.

## TODO

- [ ] WaveNet?

Contributions are welcome!

## Thoughts and Help request

- The results seem correct, but the implementation may not be perfectly correct (especially the gradient copying). If you find any glitches (or a major defect), please let me know!
- I am currently not sure exactly how NLL should be computed. Could anyone give me a proper explanation of this?

## Acknowledgement

- The code for PixelCNN is borrowed from [anantzoid's repo](https://github.com/anantzoid/Conditional-PixelCNN-decoder).

--------------------------------------------------------------------------------
/assets/admiral.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/admiral.png
--------------------------------------------------------------------------------
/assets/alp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/alp.png
--------------------------------------------------------------------------------
/assets/brown_bear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/brown_bear.png
--------------------------------------------------------------------------------
/assets/cifar10_test_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/cifar10_test_original.png
--------------------------------------------------------------------------------
/assets/cifar10_test_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/cifar10_test_recon.png
--------------------------------------------------------------------------------
/assets/coral_reef.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/coral_reef.png
--------------------------------------------------------------------------------
/assets/gray_whale.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/gray_whale.png
--------------------------------------------------------------------------------
/assets/imagenet_val_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/imagenet_val_original.png
--------------------------------------------------------------------------------
/assets/imagenet_val_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/imagenet_val_recon.png
--------------------------------------------------------------------------------
/assets/mnist_diff_codes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/mnist_diff_codes.png
--------------------------------------------------------------------------------
/assets/mnist_randomwalk.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/mnist_randomwalk.gif
--------------------------------------------------------------------------------
/assets/mnist_randomwalk.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/mnist_randomwalk.mp4
--------------------------------------------------------------------------------
/assets/mnist_test_original.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/mnist_test_original.png
--------------------------------------------------------------------------------
/assets/mnist_test_recon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/mnist_test_recon.png
--------------------------------------------------------------------------------
/assets/pickup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/pickup.png
--------------------------------------------------------------------------------
/assets/sampled_cifar10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/sampled_cifar10.png
--------------------------------------------------------------------------------
/assets/sampled_mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/assets/sampled_mnist.png
--------------------------------------------------------------------------------
/cifar10.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from six.moves import xrange
import os
import better_exceptions
import tensorflow as tf
import numpy as np
from tqdm import tqdm

from model import VQVAE, _cifar10_arch, PixelCNN

# This code is borrowed from
# https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py
# https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_input.py
DATA_DIR = 'datasets/cifar10'
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
def maybe_download_and_extract():
    """Download and extract the tarball from Alex's website."""
    import sys, tarfile
    from six.moves import urllib
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(DATA_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    extracted_dir_path = os.path.join(DATA_DIR, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        tarfile.open(filepath,
                     'r:gz').extractall(DATA_DIR)

def read_cifar10(filename_queue):
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()
    record_bytes = 1 + 32*32*3

    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)
    record_bytes = tf.decode_raw(value, tf.uint8)

    result.label = tf.cast(
        tf.strided_slice(record_bytes, [0], [1]), tf.int32)
    depth_major = tf.reshape(
        tf.strided_slice(record_bytes, [1],
                         [1 + 32*32*3]),
        [3, 32, 32])
    # Convert from [depth, height, width] to [height, width, depth].
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])
    return result

def get_image(train=True,num_epochs=None):
    maybe_download_and_extract()
    if train:
        filenames = [os.path.join(DATA_DIR, 'cifar-10-batches-bin', 'data_batch_%d.bin' % i) for i in xrange(1, 6)]
    else:
        filenames = [os.path.join(DATA_DIR, 'cifar-10-batches-bin', 'test_batch.bin')]
    filename_queue = tf.train.string_input_producer(filenames,num_epochs=num_epochs)
    read_input = read_cifar10(filename_queue)
    return tf.cast(read_input.uint8image, tf.float32) / 255.0, tf.reshape(read_input.label,[])


def main(config,
         RANDOM_SEED,
         LOG_DIR,
         TRAIN_NUM,
         BATCH_SIZE,
         LEARNING_RATE,
         DECAY_VAL,
         DECAY_STEPS,
         DECAY_STAIRCASE,
         BETA,
         K,
         D,
         SAVE_PERIOD,
         SUMMARY_PERIOD,
         **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    image,_ = get_image()
    images = tf.train.shuffle_batch(
        [image],
        batch_size=BATCH_SIZE,
        num_threads=4,
        capacity=BATCH_SIZE*10,
        min_after_dequeue=BATCH_SIZE*2)
    valid_image,_ = get_image(False)
    valid_images = tf.train.shuffle_batch(
        [valid_image],
        batch_size=BATCH_SIZE,
        num_threads=1,
        capacity=BATCH_SIZE*10,
        min_after_dequeue=BATCH_SIZE*2)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('train'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr',learning_rate)

        with tf.variable_scope('params') as params:
            pass
        net = VQVAE(learning_rate,global_step,BETA,images,K,D,_cifar10_arch,params,True)

    with tf.variable_scope('valid'):
        params.reuse_variables()
        valid_net = VQVAE(None,None,BETA,valid_images,K,D,_cifar10_arch,params,False)

    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss',net.loss)
        tf.summary.scalar('recon',net.recon)
        tf.summary.scalar('vq',net.vq)
        tf.summary.scalar('commit',BETA*net.commit)
        tf.summary.scalar('nll',tf.reduce_mean(net.nll))
        tf.summary.image('origin',images,max_outputs=4)
        tf.summary.image('recon',net.p_x_z,max_outputs=4)
        # TODO: log-likelihood

        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[])

        extended_summary_op = tf.summary.merge([
            tf.summary.scalar('valid_loss',valid_net.loss),
            tf.summary.scalar('valid_recon',valid_net.recon),
            tf.summary.scalar('valid_vq',valid_net.vq),
            tf.summary.scalar('valid_commit',BETA*valid_net.commit),
            tf.summary.scalar('valid_nll',tf.reduce_mean(valid_net.nll)),
            tf.summary.image('valid_origin',valid_images,max_outputs=4),
            tf.summary.image('valid_recon',valid_net.p_x_z,max_outputs=4),
        ])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    try:
        # Start Queueing
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord,sess=sess)
        for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True):
            it,loss,_ = sess.run([global_step,net.loss,net.train_op])

            if( it % SAVE_PERIOD == 0 ):
                net.save(sess,LOG_DIR,step=it)

            if( it % SUMMARY_PERIOD == 0 ):
                tqdm.write('[%5d] Loss: %1.3f'%(it,loss))
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary,it)

            if( it % (SUMMARY_PERIOD*2) == 0 ): #Extended Summary
                summary = sess.run(extended_summary_op)
                summary_writer.add_summary(summary,it)

    except Exception as e:
        coord.request_stop(e)
    finally :
        net.save(sess,LOG_DIR)

        coord.request_stop()
        coord.join(threads)

def test(MODEL,
         BETA,
         K,
         D,
         **kwargs):
    # >>>>>>> DATASET
    image,_ = get_image(num_epochs=1)
    images = tf.train.batch(
        [image],
        batch_size=100,
        num_threads=1,
        capacity=100,
        allow_smaller_final_batch=True)
    valid_image,_ = get_image(False,num_epochs=1)
    valid_images = tf.train.batch(
        [valid_image],
        batch_size=100,
        num_threads=1,
        capacity=100,
        allow_smaller_final_batch=True)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x = tf.placeholder(tf.float32,[None,32,32,3])
        net= VQVAE(None,None,BETA,x,K,D,_cifar10_arch,params,False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess,MODEL)


    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord,sess=sess)
    try:
        nlls = []
        while not coord.should_stop():
            nlls.append(
                sess.run(net.nll,feed_dict={x:sess.run(valid_images)}))
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        nlls = np.concatenate(nlls,axis=0)
        print(nlls.shape)
        print('NLL for test set: %f bits/dims'%(np.mean(nlls)))

    try:
        nlls = []
        while not coord.should_stop():
            nlls.append(
                sess.run(net.nll,feed_dict={x:sess.run(images)}))
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        nlls = np.concatenate(nlls,axis=0)
        print(nlls.shape)
        print('NLL for training set: %f bits/dims'%(np.mean(nlls)))

    coord.request_stop()
    coord.join(threads)

def extract_z(MODEL,
              BATCH_SIZE,
              BETA,
              K,
              D,
              **kwargs):
    # >>>>>>> DATASET
    image,label = get_image(num_epochs=1)
    images,labels = tf.train.batch(
        [image,label],
        batch_size=BATCH_SIZE,
        num_threads=1,
        capacity=BATCH_SIZE,
        allow_smaller_final_batch=True)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        x_ph = tf.placeholder(tf.float32,[None,32,32,3])
        net= VQVAE(None,None,BETA,x_ph,K,D,_cifar10_arch,params,False)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    net.load(sess,MODEL)


    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord,sess=sess)
    try:
        ks = []
        ys = []
        while not coord.should_stop():
            x,y = sess.run([images,labels])
            k = sess.run(net.k,feed_dict={x_ph:x})
            ks.append(k)
            ys.append(y)
            print('.', end='', flush=True)
    except tf.errors.OutOfRangeError:
        print('Extracting Finished')

    ks = np.concatenate(ks,axis=0)
    ys = np.concatenate(ys,axis=0)
    np.savez(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'),ks=ks,ys=ys)

    coord.request_stop()
    coord.join(threads)

def train_prior(config,
                RANDOM_SEED,
                MODEL,
                TRAIN_NUM,
                BATCH_SIZE,
                LEARNING_RATE,
                DECAY_VAL,
                DECAY_STEPS,
                DECAY_STAIRCASE,
                GRAD_CLIP,
                K,
                D,
                BETA,
                NUM_LAYERS,
                NUM_FEATURE_MAPS,
                SUMMARY_PERIOD,
                SAVE_PERIOD,
                **kwargs):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)
    LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn_6')
    # >>>>>>> DATASET
    class Latents():
        def __init__(self,path,validation_size=1):
            from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet
            from tensorflow.contrib.learn.python.learn.datasets import base

            data = np.load(path)
            train = DataSet(data['ks'][validation_size:], data['ys'][validation_size:],reshape=False,dtype=np.uint8,one_hot=False) #dtype won't bother even in the case when latent is int32 type.
            validation = DataSet(data['ks'][:validation_size], data['ys'][:validation_size],reshape=False,dtype=np.uint8,one_hot=False)
            #test = DataSet(data['test_x'],np.argmax(data['test_y'],axis=1),reshape=False,dtype=np.float32,one_hot=False)
            self.size = data['ks'].shape[1]
            self.data = base.Datasets(train=train, validation=validation, test=None)
    latent = Latents(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'))
    # <<<<<<<

    # >>>>>>> MODEL for Generating Images
    with tf.variable_scope('net'):
        with tf.variable_scope('params') as params:
            pass
        _not_used = tf.placeholder(tf.float32,[None,32,32,3])
        vq_net = VQVAE(None,None,BETA,_not_used,K,D,_cifar10_arch,params,False)
    # <<<<<<<

    # >>>>>> MODEL for Training Prior
    with tf.variable_scope('pixelcnn'):
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr',learning_rate)

        net = PixelCNN(learning_rate,global_step,GRAD_CLIP,
                       latent.size,vq_net.embeds,K,D,
                       10,NUM_LAYERS,NUM_FEATURE_MAPS)
    # <<<<<<
    with tf.variable_scope('misc'):
        # Summary Operations
        tf.summary.scalar('loss',net.loss)
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[])

        sample_images = tf.placeholder(tf.float32,[None,32,32,3])
        sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20)

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)
    vq_net.load(sess,MODEL)

    summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True):
        batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE)
        it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={net.X:batch_xs,net.h:batch_ys})

        if( it % SAVE_PERIOD == 0 ):
            net.save(sess,LOG_DIR,step=it)

        if( it % SUMMARY_PERIOD == 0 ):
            tqdm.write('[%5d] Loss: %1.3f'%(it,loss))
            summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys})
            summary_writer.add_summary(summary,it)

        if( it % (SUMMARY_PERIOD * 2) == 0 ):
            sampled_zs,log_probs = net.sample_from_prior(sess,np.arange(10),2)
            sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs})
            summary_writer.add_summary(
                sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it)

    net.save(sess,LOG_DIR)

def get_default_param():
    from datetime import datetime
    now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return {
        'LOG_DIR':'./log/cifar10/%s'%(now),
        'MODEL' : './log/cifar10/%s/last.ckpt'%(now),

        'TRAIN_NUM' : 250000, #Size corresponds to one epoch
        'BATCH_SIZE': 128,

        'LEARNING_RATE' : 0.0002,
        'DECAY_VAL' : 1.0,
        'DECAY_STEPS' : 20000, # Half of the training procedure.
        'DECAY_STAIRCASE' : False,

        'BETA':0.25,
        'K':10,
        'D':256,

        # PixelCNN Params
        'GRAD_CLIP' : 5.0,
        'NUM_LAYERS' : 12,
        'NUM_FEATURE_MAPS' : 64,

        'SUMMARY_PERIOD' : 100,
        'SAVE_PERIOD' : 10000,
        'RANDOM_SEED': 0,
    }

if __name__ == "__main__":
    class MyConfig(dict):
        pass
    params = get_default_param()
    config = MyConfig(params)
    def as_matrix() :
        return [[k, str(w)] for k, w in config.items()]
    config.as_matrix = as_matrix

    main(config=config,**config)
    extract_z(**config)
    config['TRAIN_NUM'] = 300000
    config['LEARNING_RATE'] = 0.001
    config['DECAY_VAL'] = 0.5
    config['DECAY_STEPS'] = 100000
    train_prior(config=config,**config)

    # config already contains a 'MODEL' key, so passing MODEL=... next to
    # **config would raise a TypeError; dict(config, MODEL=...) overrides it instead.
    #test(**dict(config, MODEL='models/cifar10/last.ckpt'))

--------------------------------------------------------------------------------
/commons/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/commons/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/commons/__init__.py
--------------------------------------------------------------------------------
/commons/ops.py:
--------------------------------------------------------------------------------
import tensorflow as tf

class Conv2d(object) :
    def __init__(self,name,input_dim,output_dim,k_h=4,k_w=4,d_h=2,d_w=2,
                 stddev=0.02, data_format='NCHW') :
        with tf.variable_scope(name) :
            assert(data_format == 'NCHW' or data_format == 'NHWC')
            self.w = tf.get_variable('w', [k_h, k_w, input_dim, output_dim],
                                     initializer=tf.truncated_normal_initializer(stddev=stddev))
            self.b = tf.get_variable('b',[output_dim], initializer=tf.constant_initializer(0.0))
            if( data_format == 'NCHW' ) :
                self.strides = [1, 1, d_h, d_w]
            else :
                self.strides = [1, d_h, d_w, 1]
            self.data_format = data_format
    def __call__(self,input_var,name=None,w=None,b=None,**kwargs) :
        w = w if w is not None else self.w
        b = b if b is not None else self.b

        if( self.data_format =='NCHW' ) :
            return tf.nn.bias_add(
                tf.nn.conv2d(input_var, w,
                             use_cudnn_on_gpu=True,data_format='NCHW',
                             strides=self.strides, padding='SAME'),
                b,data_format='NCHW',name=name)
        else :
            return tf.nn.bias_add(
                tf.nn.conv2d(input_var, w,data_format='NHWC',
                             strides=self.strides, padding='SAME'),
                b,data_format='NHWC',name=name)
    def get_variables(self):
        return {'w':self.w,'b':self.b}

class Linear(object) :
    def __init__(self,name,input_dim,output_dim,stddev=0.02) :
        with tf.variable_scope(name) :
            self.w = tf.get_variable('w',[input_dim, output_dim],
                                     initializer=tf.random_normal_initializer(stddev=stddev))
            self.b = tf.get_variable('b',[output_dim],
                                     initializer=tf.constant_initializer(0.0))

    def __call__(self,input_var,name=None,w=None,b=None,**kwargs) :
        w = w if w is not None else self.w
        b = b if b is not None else self.b

        if( len(input_var.get_shape().dims) > 2 ) :
            dims = tf.reduce_prod(tf.shape(input_var)[1:])
            return tf.matmul(tf.reshape(input_var,[-1,dims]),w) + b
        else :
            return tf.matmul(input_var,w)+b
    def get_variables(self):
        return {'w':self.w,'b':self.b}

class TransposedConv2d(object):
    def __init__(self,name,input_dim,out_dim,
                 k_h=4,k_w=4,d_h=2,d_w=2,stddev=0.02,data_format='NCHW') :
        with tf.variable_scope(name) :
            self.w = tf.get_variable('w', [k_h, k_w, out_dim, input_dim],
                                     initializer=tf.random_normal_initializer(stddev=stddev))
            self.b = tf.get_variable('b',[out_dim],
initializer=tf.constant_initializer(0.0)) 61 | 62 | self.data_format = data_format 63 | if( data_format =='NCHW' ): 64 | self.strides = [1, 1, d_h, d_w] 65 | else: 66 | self.strides = [1, d_h, d_w, 1] 67 | 68 | def __call__(self,input_var,name=None,**xargs): 69 | shapes = tf.shape(input_var) 70 | if( self.data_format == 'NCHW' ): 71 | shapes = tf.stack([shapes[0],tf.shape(self.b)[0],shapes[2]*2,shapes[3]*2]) 72 | else: 73 | shapes = tf.stack([shapes[0],shapes[1]*2,shapes[2]*2,tf.shape(self.b)[0]]) 74 | 75 | return tf.nn.bias_add( 76 | tf.nn.conv2d_transpose(input_var,self.w,output_shape=shapes, 77 | data_format=self.data_format, 78 | strides=self.strides,padding='SAME'), 79 | self.b,data_format=self.data_format,name=name) 80 | 81 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | from six.moves import xrange 2 | import os 3 | import better_exceptions 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from model import VQVAE, _imagenet_arch, PixelCNN 9 | 10 | import sys 11 | sys.path.append('slim_models/research/slim') 12 | from datasets import imagenet 13 | slim = tf.contrib.slim 14 | def _build_batch(dataset,batch_size,num_threads): 15 | with tf.device('/cpu'): 16 | provider = slim.dataset_data_provider.DatasetDataProvider( 17 | dataset, 18 | num_readers=num_threads, 19 | common_queue_capacity=20*batch_size, 20 | common_queue_min=10*batch_size, 21 | shuffle=True) 22 | image,label = provider.get(['image','label']) 23 | # Slim module has a background label as 0. By changing this, you need to use (label_num-1) 24 | # on Jupyter notebook to generate class conditioned samples. 
25 | #label -= 1 26 | pp_image = tf.image.resize_images(image,[128,128]) / 255.0 27 | 28 | images,labels = tf.train.batch( 29 | [pp_image,label], 30 | batch_size=batch_size, 31 | num_threads=num_threads, 32 | capacity=5*batch_size, 33 | allow_smaller_final_batch=True) 34 | return images, labels 35 | 36 | def main(config, 37 | RANDOM_SEED, 38 | LOG_DIR, 39 | TRAIN_NUM, 40 | BATCH_SIZE, 41 | LEARNING_RATE, 42 | DECAY_VAL, 43 | DECAY_STEPS, 44 | DECAY_STAIRCASE, 45 | BETA, 46 | K, 47 | D, 48 | SAVE_PERIOD, 49 | SUMMARY_PERIOD, 50 | **kwargs): 51 | np.random.seed(RANDOM_SEED) 52 | tf.set_random_seed(RANDOM_SEED) 53 | 54 | # >>>>>>> DATASET 55 | train_dataset = imagenet.get_split('train','datasets/ILSVRC2012') 56 | valid_dataset = imagenet.get_split('validation','datasets/ILSVRC2012') 57 | train_ims,_ = _build_batch(train_dataset,BATCH_SIZE,4) 58 | valid_ims,_ = _build_batch(valid_dataset,4,1) 59 | 60 | # >>>>>>> MODEL 61 | with tf.variable_scope('train'): 62 | global_step = tf.Variable(0, trainable=False) 63 | learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE) 64 | tf.summary.scalar('lr',learning_rate) 65 | 66 | with tf.variable_scope('params') as params: 67 | pass 68 | net = VQVAE(learning_rate,global_step,BETA,train_ims,K,D,_imagenet_arch,params,True) 69 | 70 | with tf.variable_scope('valid'): 71 | params.reuse_variables() 72 | valid_net = VQVAE(None,None,BETA,valid_ims,K,D,_imagenet_arch,params,False) 73 | 74 | with tf.variable_scope('misc'): 75 | # Summary Operations 76 | tf.summary.scalar('loss',net.loss) 77 | tf.summary.scalar('recon',net.recon) 78 | tf.summary.scalar('vq',net.vq) 79 | tf.summary.scalar('commit',BETA*net.commit) 80 | tf.summary.scalar('nll',tf.reduce_mean(net.nll)) 81 | tf.summary.image('origin',train_ims,max_outputs=4) 82 | tf.summary.image('recon',net.p_x_z,max_outputs=4) 83 | summary_op = tf.summary.merge_all() 84 | 85 | # Initialize op 86 | init_op = 
tf.group(tf.global_variables_initializer(), 87 | tf.local_variables_initializer()) 88 | config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[]) 89 | 90 | extended_summary_op = tf.summary.merge([ 91 | tf.summary.scalar('valid_loss',valid_net.loss), 92 | tf.summary.scalar('valid_recon',valid_net.recon), 93 | tf.summary.scalar('valid_vq',valid_net.vq), 94 | tf.summary.scalar('valid_commit',BETA*valid_net.commit), 95 | tf.summary.scalar('valid_nll',tf.reduce_mean(valid_net.nll)), 96 | tf.summary.image('valid_origin',valid_ims,max_outputs=4), 97 | tf.summary.image('valid_recon',valid_net.p_x_z,max_outputs=4), 98 | ]) 99 | # <<<<<<<<<< 100 | 101 | 102 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run! 103 | config = tf.ConfigProto() 104 | config.gpu_options.allow_growth = True 105 | sess = tf.Session(config=config) 106 | sess.graph.finalize() 107 | sess.run(init_op) 108 | 109 | summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph) 110 | summary_writer.add_summary(config_summary.eval(session=sess)) 111 | 112 | try: 113 | # Start Queueing 114 | coord = tf.train.Coordinator() 115 | threads = tf.train.start_queue_runners(coord=coord,sess=sess) 116 | for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True): 117 | it,loss,_ = sess.run([global_step,net.loss,net.train_op]) 118 | 119 | if( it % SAVE_PERIOD == 0 ): 120 | net.save(sess,LOG_DIR,step=it) 121 | 122 | if( it % SUMMARY_PERIOD == 0 ): 123 | tqdm.write('[%5d] Loss: %1.3f'%(it,loss)) 124 | summary = sess.run(summary_op) 125 | summary_writer.add_summary(summary,it) 126 | 127 | if( it % (SUMMARY_PERIOD*2) == 0 ): #Extended Summary 128 | summary = sess.run(extended_summary_op) 129 | summary_writer.add_summary(summary,it) 130 | 131 | except Exception as e: 132 | coord.request_stop(e) 133 | finally : 134 | net.save(sess,LOG_DIR) 135 | 136 | coord.request_stop() 137 | coord.join(threads) 138 | 139 | def train_prior(config, 140 | RANDOM_SEED, 141 | MODEL, 142 | TRAIN_NUM, 143 | 
BATCH_SIZE, 144 | LEARNING_RATE, 145 | DECAY_VAL, 146 | DECAY_STEPS, 147 | DECAY_STAIRCASE, 148 | GRAD_CLIP, 149 | K, 150 | D, 151 | BETA, 152 | NUM_LAYERS, 153 | NUM_FEATURE_MAPS, 154 | SUMMARY_PERIOD, 155 | SAVE_PERIOD, 156 | **kwargs): 157 | np.random.seed(RANDOM_SEED) 158 | tf.set_random_seed(RANDOM_SEED) 159 | LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn') 160 | # >>>>>>> DATASET 161 | train_dataset = imagenet.get_split('train','datasets/ILSVRC2012') 162 | ims,labels = _build_batch(train_dataset,BATCH_SIZE,4) 163 | # <<<<<<< 164 | 165 | # >>>>>>> MODEL for Generate Images 166 | with tf.variable_scope('net'): 167 | with tf.variable_scope('params') as params: 168 | pass 169 | vq_net = VQVAE(None,None,BETA,ims,K,D,_imagenet_arch,params,False) 170 | # <<<<<<< 171 | 172 | # >>>>>> MODEL for Training Prior 173 | with tf.variable_scope('pixelcnn'): 174 | global_step = tf.Variable(0, trainable=False) 175 | learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE) 176 | tf.summary.scalar('lr',learning_rate) 177 | 178 | net = PixelCNN(learning_rate,global_step,GRAD_CLIP, 179 | vq_net.k.get_shape()[1],vq_net.embeds,K,D, 180 | 1000,NUM_LAYERS,NUM_FEATURE_MAPS) 181 | # <<<<<< 182 | with tf.variable_scope('misc'): 183 | # Summary Operations 184 | tf.summary.scalar('loss',net.loss) 185 | summary_op = tf.summary.merge_all() 186 | 187 | # Initialize op 188 | init_op = tf.group(tf.global_variables_initializer(), 189 | tf.local_variables_initializer()) 190 | config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[]) 191 | 192 | sample_images = tf.placeholder(tf.float32,[None,128,128,3]) 193 | sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20) 194 | 195 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run! 
196 | config = tf.ConfigProto() 197 | config.gpu_options.allow_growth = True 198 | sess = tf.Session(config=config) 199 | sess.graph.finalize() 200 | sess.run(init_op) 201 | vq_net.load(sess,MODEL) 202 | 203 | summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph) 204 | summary_writer.add_summary(config_summary.eval(session=sess)) 205 | 206 | coord = tf.train.Coordinator() 207 | threads = tf.train.start_queue_runners(coord=coord,sess=sess) 208 | try: 209 | for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True): 210 | batch_xs,batch_ys = sess.run([vq_net.k,labels]) 211 | it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={net.X:batch_xs,net.h:batch_ys}) 212 | 213 | if( it % SAVE_PERIOD == 0 ): 214 | net.save(sess,LOG_DIR,step=it) 215 | sampled_zs,log_probs = net.sample_from_prior(sess,np.random.randint(0,1000,size=(10,)),2) 216 | sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs}) 217 | summary_writer.add_summary( 218 | sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it) 219 | 220 | if( it % SUMMARY_PERIOD == 0 ): 221 | tqdm.write('[%5d] Loss: %1.3f'%(it,loss)) 222 | summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys}) 223 | summary_writer.add_summary(summary,it) 224 | 225 | except Exception as e: 226 | coord.request_stop(e) 227 | finally : 228 | net.save(sess,LOG_DIR) 229 | 230 | coord.request_stop() 231 | coord.join(threads) 232 | 233 | def get_default_param(): 234 | from datetime import datetime 235 | now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 236 | return { 237 | #'LOG_DIR':'./log/imagenet/%s'%('test'), 238 | 'LOG_DIR':'./log/imagenet/%s'%(now), 239 | 240 | 'TRAIN_NUM' : 50000, #Size corresponds to one epoch 241 | 'BATCH_SIZE': 16, 242 | 243 | 'LEARNING_RATE' : 0.0002, 244 | 'DECAY_VAL' : 0.5, 245 | 'DECAY_STEPS' : 25000, # Half of the training procedure. 
246 | 'DECAY_STAIRCASE' : False, 247 | 248 | 'BETA':0.25, 249 | 'K':512, 250 | 'D':128, 251 | 252 | # PixelCNN Params 253 | 'GRAD_CLIP' : 5.0, 254 | 'NUM_LAYERS' : 18, 255 | 'NUM_FEATURE_MAPS' : 256, 256 | 257 | 'SUMMARY_PERIOD' : 50, 258 | 'SAVE_PERIOD' : 10000, 259 | 'RANDOM_SEED': 0, 260 | } 261 | 262 | if __name__ == "__main__": 263 | class MyConfig(dict): 264 | pass 265 | params = get_default_param() 266 | config = MyConfig(params) 267 | def as_matrix() : 268 | return [[k, str(w)] for k, w in config.items()] 269 | config.as_matrix = as_matrix 270 | config['MODEL'] = os.path.join(config['LOG_DIR'],'last.ckpt') # Checkpoint written by main(); required by train_prior(). 271 | main(config=config,**config) 272 | config['LEARNING_RATE'] = 0.0004 273 | config['TRAIN_NUM'] = 300000 274 | config['BATCH_SIZE'] = 16 275 | config['DECAY_STEPS'] = 100000 276 | train_prior(config=config,**config) 277 | 278 | #TODO: 279 | # Reduce memory usage by batch learn batch_xs gathering process with batchsize 1 280 | # Train only on specific class labels (1000 classes is too many). 281 | # Find correct ys...(Coral Reef, or something) 282 | 283 | #Warning: 284 | # Uncomment line 25 (`#label -= 1` in _build_batch) when training from scratch. The slim module assigns label 0 to the background class.
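A minimal NumPy sketch of why the label shift in the warning above matters. It assumes, as the comments in `_build_batch` state, that slim's ImageNet labels run from 0 (background) to 1000, and that `tf.one_hot` leaves a row all zeros when the index falls outside `[0, depth)`, which is how we understand its documented behavior. `one_hot` here is a hypothetical stand-in, not TensorFlow's implementation; `NUM_CLASSES` mirrors the `1000` passed to `PixelCNN` in `train_prior`.

```python
import numpy as np

def one_hot(labels, depth):
    """Hypothetical NumPy stand-in for tf.one_hot: out-of-range labels give all-zero rows."""
    labels = np.asarray(labels)
    out = np.zeros((labels.shape[0], depth), dtype=np.float32)
    valid = (labels >= 0) & (labels < depth)
    out[np.arange(labels.shape[0])[valid], labels[valid]] = 1.0
    return out

NUM_CLASSES = 1000  # num_classes given to PixelCNN in imagenet.py

# Assumed slim-style labels: 0 = background, 1..1000 = ImageNet classes.
slim_labels = np.array([1, 500, 1000])

unshifted = one_hot(slim_labels, NUM_CLASSES)      # labels used as-is
shifted = one_hot(slim_labels - 1, NUM_CLASSES)    # effect of `label -= 1`

print(unshifted.sum(axis=1))   # last row sums to 0: class 1000 loses its conditioning
print(shifted.argmax(axis=1))  # classes now occupy indices 0..999
```

Without the shift, the last ImageNet class is conditioned on an all-zero vector and index 0 is reserved for a background label that real images never carry.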
285 | -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | from six.moves import xrange 2 | import os 3 | import better_exceptions 4 | import tensorflow as tf 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from model import VQVAE, _mnist_arch, PixelCNN 9 | 10 | def main(config, 11 | RANDOM_SEED, 12 | LOG_DIR, 13 | TRAIN_NUM, 14 | BATCH_SIZE, 15 | LEARNING_RATE, 16 | DECAY_VAL, 17 | DECAY_STEPS, 18 | DECAY_STAIRCASE, 19 | BETA, 20 | K, 21 | D, 22 | SAVE_PERIOD, 23 | SUMMARY_PERIOD, 24 | **kwargs): 25 | np.random.seed(RANDOM_SEED) 26 | tf.set_random_seed(RANDOM_SEED) 27 | 28 | # >>>>>>> DATASET 29 | from tensorflow.examples.tutorials.mnist import input_data 30 | mnist = input_data.read_data_sets("datasets/mnist", one_hot=False) 31 | # <<<<<<< 32 | 33 | # >>>>>>> MODEL 34 | x = tf.placeholder(tf.float32,[None,784]) 35 | resized = tf.image.resize_images( 36 | tf.reshape(x,[-1,28,28,1]), 37 | (24,24), 38 | method=tf.image.ResizeMethod.BILINEAR) 39 | 40 | with tf.variable_scope('train'): 41 | global_step = tf.Variable(0, trainable=False) 42 | learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE) 43 | tf.summary.scalar('lr',learning_rate) 44 | 45 | with tf.variable_scope('params') as params: 46 | pass 47 | net = VQVAE(learning_rate,global_step,BETA,resized,K,D,_mnist_arch,params,True) 48 | 49 | with tf.variable_scope('valid'): 50 | params.reuse_variables() 51 | valid_net = VQVAE(None,None,BETA,resized,K,D,_mnist_arch,params,False) 52 | 53 | with tf.variable_scope('misc'): 54 | # Summary Operations 55 | tf.summary.scalar('loss',net.loss) 56 | tf.summary.scalar('recon',net.recon) 57 | tf.summary.scalar('vq',net.vq) 58 | tf.summary.scalar('commit',BETA*net.commit) 59 | tf.summary.image('origin',resized,max_outputs=4) 60 | tf.summary.image('recon',net.p_x_z,max_outputs=4) 61 | 
# TODO: log-likelihood 62 | 63 | summary_op = tf.summary.merge_all() 64 | 65 | # Initialize op 66 | init_op = tf.group(tf.global_variables_initializer(), 67 | tf.local_variables_initializer()) 68 | config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[]) 69 | 70 | extended_summary_op = tf.summary.merge([ 71 | tf.summary.scalar('valid_loss',valid_net.loss), 72 | tf.summary.scalar('valid_recon',valid_net.recon), 73 | tf.summary.scalar('valid_vq',valid_net.vq), 74 | tf.summary.scalar('valid_commit',BETA*valid_net.commit), 75 | tf.summary.image('valid_recon',valid_net.p_x_z,max_outputs=10), 76 | ]) 77 | 78 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run! 79 | config = tf.ConfigProto() 80 | config.gpu_options.allow_growth = True 81 | sess = tf.Session(config=config) 82 | sess.graph.finalize() 83 | sess.run(init_op) 84 | 85 | summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph) 86 | summary_writer.add_summary(config_summary.eval(session=sess)) 87 | 88 | for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True): 89 | batch_xs, _= mnist.train.next_batch(BATCH_SIZE) 90 | it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={x:batch_xs}) 91 | 92 | if( it % SAVE_PERIOD == 0 ): 93 | net.save(sess,LOG_DIR,step=it) 94 | 95 | if( it % SUMMARY_PERIOD == 0 ): 96 | tqdm.write('[%5d] Loss: %1.3f'%(it,loss)) 97 | summary = sess.run(summary_op,feed_dict={x:batch_xs}) 98 | summary_writer.add_summary(summary,it) 99 | 100 | if( it % (SUMMARY_PERIOD*2) == 0 ): #Extended Summary 101 | batch_xs, _= mnist.test.next_batch(BATCH_SIZE) 102 | summary = sess.run(extended_summary_op,feed_dict={x:batch_xs}) 103 | summary_writer.add_summary(summary,it) 104 | 105 | net.save(sess,LOG_DIR) 106 | 107 | def extract_z(MODEL, 108 | BATCH_SIZE, 109 | BETA, 110 | K, 111 | D, 112 | **kwargs): 113 | # >>>>>>> DATASET 114 | from tensorflow.examples.tutorials.mnist import input_data 115 | mnist = input_data.read_data_sets("datasets/mnist",
one_hot=False) 116 | # <<<<<<< 117 | 118 | # >>>>>>> MODEL 119 | x = tf.placeholder(tf.float32,[None,784]) 120 | resized = tf.image.resize_images( 121 | tf.reshape(x,[-1,28,28,1]), 122 | (24,24), 123 | method=tf.image.ResizeMethod.BILINEAR) 124 | 125 | with tf.variable_scope('net'): 126 | with tf.variable_scope('params') as params: 127 | pass 128 | net = VQVAE(None,None,BETA,resized,K,D,_mnist_arch,params,False) 129 | 130 | # Initialize op 131 | init_op = tf.group(tf.global_variables_initializer(), 132 | tf.local_variables_initializer()) 133 | 134 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run! 135 | config = tf.ConfigProto() 136 | config.gpu_options.allow_growth = True 137 | sess = tf.Session(config=config) 138 | sess.graph.finalize() 139 | sess.run(init_op) 140 | net.load(sess,MODEL) 141 | 142 | xs,ys = mnist.train.images, mnist.train.labels 143 | ks = [] 144 | for i in tqdm(range(0,len(xs),BATCH_SIZE)): 145 | batch = xs[i:i+BATCH_SIZE] 146 | 147 | k = sess.run(net.k,feed_dict={x:batch}) 148 | ks.append(k) 149 | ks = np.concatenate(ks,axis=0) 150 | 151 | np.savez(os.path.join(os.path.dirname(MODEL),'ks_ys.npz'),ks=ks,ys=ys) 152 | 153 | def train_prior(config, 154 | RANDOM_SEED, 155 | MODEL, 156 | TRAIN_NUM, 157 | BATCH_SIZE, 158 | LEARNING_RATE, 159 | DECAY_VAL, 160 | DECAY_STEPS, 161 | DECAY_STAIRCASE, 162 | GRAD_CLIP, 163 | K, 164 | D, 165 | BETA, 166 | NUM_LAYERS, 167 | NUM_FEATURE_MAPS, 168 | SUMMARY_PERIOD, 169 | SAVE_PERIOD, 170 | **kwargs): 171 | np.random.seed(RANDOM_SEED) 172 | tf.set_random_seed(RANDOM_SEED) 173 | LOG_DIR = os.path.join(os.path.dirname(MODEL),'pixelcnn') 174 | # >>>>>>> DATASET 175 | class Latents(): 176 | def __init__(self,path,validation_size=5000): 177 | from tensorflow.contrib.learn.python.learn.datasets.mnist import DataSet 178 | from tensorflow.contrib.learn.python.learn.datasets import base 179 | 180 | data = np.load(path) 181 | train = DataSet(data['ks'][validation_size:], 
data['ys'][validation_size:],reshape=False,dtype=np.uint8,one_hot=False) #dtype won't bother even in the case when latent is int32 type. 182 | validation = DataSet(data['ks'][:validation_size], data['ys'][:validation_size],reshape=False,dtype=np.uint8,one_hot=False) 183 | #test = DataSet(data['test_x'],np.argmax(data['test_y'],axis=1),reshape=False,dtype=np.float32,one_hot=False) 184 | self.size = data['ks'].shape[1] 185 | self.data = base.Datasets(train=train, validation=validation, test=None) 186 | latent = Latents(os.path.join(os.path.dirname(MODEL),'ks_ys.npz')) 187 | # <<<<<<< 188 | 189 | # >>>>>>> MODEL for Generate Images 190 | with tf.variable_scope('net'): 191 | with tf.variable_scope('params') as params: 192 | pass 193 | _not_used = tf.placeholder(tf.float32,[None,24,24,1]) 194 | vq_net = VQVAE(None,None,BETA,_not_used,K,D,_mnist_arch,params,False) 195 | # <<<<<<< 196 | 197 | # >>>>>> MODEL for Training Prior 198 | with tf.variable_scope('pixelcnn'): 199 | global_step = tf.Variable(0, trainable=False) 200 | learning_rate = tf.train.exponential_decay(LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL, staircase=DECAY_STAIRCASE) 201 | tf.summary.scalar('lr',learning_rate) 202 | 203 | net = PixelCNN(learning_rate,global_step,GRAD_CLIP, 204 | latent.size,vq_net.embeds,K,D, 205 | 10,NUM_LAYERS,NUM_FEATURE_MAPS) 206 | # <<<<<< 207 | with tf.variable_scope('misc'): 208 | # Summary Operations 209 | tf.summary.scalar('loss',net.loss) 210 | summary_op = tf.summary.merge_all() 211 | 212 | # Initialize op 213 | init_op = tf.group(tf.global_variables_initializer(), 214 | tf.local_variables_initializer()) 215 | config_summary = tf.summary.text('TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[]) 216 | 217 | sample_images = tf.placeholder(tf.float32,[None,24,24,1]) 218 | sample_summary_op = tf.summary.image('samples',sample_images,max_outputs=20) 219 | 220 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run! 
221 | config = tf.ConfigProto() 222 | config.gpu_options.allow_growth = True 223 | sess = tf.Session(config=config) 224 | sess.graph.finalize() 225 | sess.run(init_op) 226 | vq_net.load(sess,MODEL) 227 | 228 | summary_writer = tf.summary.FileWriter(LOG_DIR,sess.graph) 229 | summary_writer.add_summary(config_summary.eval(session=sess)) 230 | 231 | for step in tqdm(xrange(TRAIN_NUM),dynamic_ncols=True): 232 | batch_xs, batch_ys = latent.data.train.next_batch(BATCH_SIZE) 233 | it,loss,_ = sess.run([global_step,net.loss,net.train_op],feed_dict={net.X:batch_xs,net.h:batch_ys}) 234 | 235 | if( it % SAVE_PERIOD == 0 ): 236 | net.save(sess,LOG_DIR,step=it) 237 | 238 | if( it % SUMMARY_PERIOD == 0 ): 239 | tqdm.write('[%5d] Loss: %1.3f'%(it,loss)) 240 | summary = sess.run(summary_op,feed_dict={net.X:batch_xs,net.h:batch_ys}) 241 | summary_writer.add_summary(summary,it) 242 | 243 | if( it % (SUMMARY_PERIOD * 2) == 0 ): 244 | sampled_zs,log_probs = net.sample_from_prior(sess,np.arange(10),2) 245 | sampled_ims = sess.run(vq_net.gen,feed_dict={vq_net.latent:sampled_zs}) 246 | summary_writer.add_summary( 247 | sess.run(sample_summary_op,feed_dict={sample_images:sampled_ims}),it) 248 | 249 | net.save(sess,LOG_DIR) 250 | 251 | 252 | def get_default_param(): 253 | from datetime import datetime 254 | now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 255 | return { 256 | 'LOG_DIR':'./log/mnist/%s'%(now), 257 | 'MODEL' : './log/mnist/%s/last.ckpt'%(now), 258 | 259 | 'TRAIN_NUM' : 60000, #Size corresponds to one epoch 260 | 'BATCH_SIZE': 32, 261 | 262 | 'LEARNING_RATE' : 0.0002, 263 | 'DECAY_VAL' : 1.0, 264 | 'DECAY_STEPS' : 20000, # One third of the training procedure (no effect while DECAY_VAL is 1.0).
265 | 'DECAY_STAIRCASE' : False, 266 | 267 | 'BETA':0.25, 268 | 'K':5, 269 | 'D':64, 270 | 271 | # PixelCNN Params 272 | 'GRAD_CLIP' : 1.0, 273 | 'NUM_LAYERS' : 12, 274 | 'NUM_FEATURE_MAPS' : 32, 275 | 276 | 'SUMMARY_PERIOD' : 100, 277 | 'SAVE_PERIOD' : 10000, 278 | 'RANDOM_SEED': 0, 279 | } 280 | 281 | if __name__ == "__main__": 282 | class MyConfig(dict): 283 | pass 284 | params = get_default_param() 285 | config = MyConfig(params) 286 | def as_matrix() : 287 | return [[k, str(w)] for k, w in config.items()] 288 | config.as_matrix = as_matrix 289 | 290 | main(config=config,**config) 291 | extract_z(**config) 292 | train_prior(config=config,**config) 293 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from six.moves import xrange 2 | import better_exceptions 3 | import tensorflow as tf 4 | import numpy as np 5 | from commons.ops import * 6 | 7 | def _mnist_arch(d): 8 | with tf.variable_scope('enc') as enc_param_scope : 9 | enc_spec = [ 10 | Conv2d('conv2d_1',1,d//4,data_format='NHWC'), 11 | lambda t,**kwargs : tf.nn.relu(t), 12 | Conv2d('conv2d_2',d//4,d//2,data_format='NHWC'), 13 | lambda t,**kwargs : tf.nn.relu(t), 14 | Conv2d('conv2d_3',d//2,d,data_format='NHWC'), 15 | lambda t,**kwargs : tf.nn.relu(t), 16 | ] 17 | with tf.variable_scope('dec') as dec_param_scope : 18 | dec_spec = [ 19 | TransposedConv2d('tconv2d_1',d,d//2,data_format='NHWC'), 20 | lambda t,**kwargs : tf.nn.relu(t), 21 | TransposedConv2d('tconv2d_2',d//2,d//4,data_format='NHWC'), 22 | lambda t,**kwargs : tf.nn.relu(t), 23 | TransposedConv2d('tconv2d_3',d//4,1,data_format='NHWC'), 24 | lambda t,**kwargs : tf.nn.sigmoid(t), 25 | ] 26 | return enc_spec,enc_param_scope,dec_spec,dec_param_scope 27 | 28 | def _cifar10_arch(d): 29 | def _residual(t,conv3,conv1): 30 | return conv1(tf.nn.relu(conv3(tf.nn.relu(t))))+t 31 | from functools import partial 32 | 33 | with 
tf.variable_scope('enc') as enc_param_scope : 34 | enc_spec = [ 35 | Conv2d('conv2d_1',3,d,data_format='NHWC'), 36 | lambda t,**kwargs : tf.nn.relu(t), 37 | Conv2d('conv2d_2',d,d,data_format='NHWC'), 38 | lambda t,**kwargs : tf.nn.relu(t), 39 | partial(_residual, 40 | conv3=Conv2d('res_1_3',d,d,3,3,1,1,data_format='NHWC'), 41 | conv1=Conv2d('res_1_1',d,d,1,1,1,1,data_format='NHWC')), 42 | partial(_residual, 43 | conv3=Conv2d('res_2_3',d,d,3,3,1,1,data_format='NHWC'), 44 | conv1=Conv2d('res_2_1',d,d,1,1,1,1,data_format='NHWC')), 45 | ] 46 | with tf.variable_scope('dec') as dec_param_scope : 47 | dec_spec = [ 48 | partial(_residual, 49 | conv3=Conv2d('res_1_3',d,d,3,3,1,1,data_format='NHWC'), 50 | conv1=Conv2d('res_1_1',d,d,1,1,1,1,data_format='NHWC')), 51 | partial(_residual, 52 | conv3=Conv2d('res_2_3',d,d,3,3,1,1,data_format='NHWC'), 53 | conv1=Conv2d('res_2_1',d,d,1,1,1,1,data_format='NHWC')), 54 | TransposedConv2d('tconv2d_1',d,d,data_format='NHWC'), 55 | lambda t,**kwargs : tf.nn.relu(t), 56 | TransposedConv2d('tconv2d_2',d,3,data_format='NHWC'), 57 | lambda t,**kwargs : tf.nn.sigmoid(t), 58 | ] 59 | return enc_spec,enc_param_scope,dec_spec,dec_param_scope 60 | 61 | def _imagenet_arch(d,num_residual=4): 62 | def _residual(t,conv3,conv1): 63 | return conv1(tf.nn.relu(conv3(tf.nn.relu(t))))+t 64 | from functools import partial 65 | 66 | with tf.variable_scope('enc') as enc_param_scope : 67 | enc_spec = [ 68 | Conv2d('conv2d_1',3,d//2,data_format='NHWC'), 69 | lambda t,**kwargs : tf.nn.relu(t), 70 | Conv2d('conv2d_2',d//2,d,data_format='NHWC'), 71 | lambda t,**kwargs : tf.nn.relu(t), 72 | ] 73 | enc_spec += [ 74 | partial(_residual, 75 | conv3=Conv2d('res_%d_3'%i,d,d,3,3,1,1,data_format='NHWC'), 76 | conv1=Conv2d('res_%d_1'%i,d,d,1,1,1,1,data_format='NHWC')) 77 | for i in range(num_residual) 78 | ] 79 | with tf.variable_scope('dec') as dec_param_scope : 80 | dec_spec = [ 81 | partial(_residual, 82 | conv3=Conv2d('res_%d_3'%i,d,d,3,3,1,1,data_format='NHWC'), 83 | 
conv1=Conv2d('res_%d_1'%i,d,d,1,1,1,1,data_format='NHWC')) 84 | for i in range(num_residual) 85 | ] 86 | dec_spec += [ 87 | lambda t,**kwargs : tf.nn.relu(t), 88 | TransposedConv2d('tconv2d_1',d,d//2,data_format='NHWC'), 89 | lambda t,**kwargs : tf.nn.relu(t), 90 | TransposedConv2d('tconv2d_2',d//2,3,data_format='NHWC'), 91 | lambda t,**kwargs : tf.nn.sigmoid(t), 92 | ] 93 | return enc_spec,enc_param_scope,dec_spec,dec_param_scope 94 | 95 | class VQVAE(): 96 | def __init__(self,lr,global_step,beta, 97 | x,K,D, 98 | arch_fn, 99 | param_scope,is_training=False): 100 | with tf.variable_scope(param_scope): 101 | enc_spec,enc_param_scope,dec_spec,dec_param_scope = arch_fn(D) 102 | with tf.variable_scope('embed') : 103 | embeds = tf.get_variable('embed', [K,D], 104 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 105 | self.embeds = embeds 106 | 107 | with tf.variable_scope('forward') as forward_scope: 108 | # Encoder Pass 109 | _t = x 110 | for block in enc_spec : 111 | _t = block(_t) 112 | z_e = _t 113 | 114 | # Middle Area (Compression or Discretize) 115 | _t = tf.expand_dims(z_e, axis=-2) 116 | _e = embeds 117 | _t = tf.norm(_t-_e,axis=-1) 118 | k = tf.argmin(_t,axis=-1) # -> [latent_h,latent_w] 119 | z_q = tf.gather(embeds,k) 120 | 121 | self.z_e = z_e # -> [batch,latent_h,latent_w,D] 122 | self.k = k 123 | self.z_q = z_q # -> [batch,latent_h,latent_w,D] 124 | 125 | # Decoder Pass 126 | _t = z_q 127 | for block in dec_spec: 128 | _t = block(_t) 129 | self.p_x_z = _t 130 | 131 | # Losses 132 | self.recon = tf.reduce_mean((self.p_x_z - x)**2,axis=[0,1,2,3]) 133 | self.vq = tf.reduce_mean( 134 | tf.norm(tf.stop_gradient(self.z_e) - z_q,axis=-1)**2, 135 | axis=[0,1,2]) 136 | self.commit = tf.reduce_mean( 137 | tf.norm(self.z_e - tf.stop_gradient(z_q),axis=-1)**2, 138 | axis=[0,1,2]) 139 | self.loss = self.recon + self.vq + beta * self.commit 140 | 141 | # NLL 142 | # TODO: is it correct impl? 
143 | # it seems tf.reduce_prod(tf.shape(self.z_q)[1:3]) (the number of latent positions, latent_h*latent_w) should be multiplied 144 | # in front of log(1/K) if we assume a uniform prior on z. 145 | self.nll = -1.*(tf.reduce_mean(tf.log(self.p_x_z),axis=[1,2,3]) + tf.log(1/tf.cast(K,tf.float32)))/tf.log(2.) 146 | 147 | if( is_training ): 148 | with tf.variable_scope('backward'): 149 | # Decoder Grads 150 | decoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,dec_param_scope.name) 151 | decoder_grads = list(zip(tf.gradients(self.loss,decoder_vars),decoder_vars)) 152 | # Encoder Grads 153 | encoder_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,enc_param_scope.name) 154 | grad_z = tf.gradients(self.recon,z_q) 155 | encoder_grads = [(tf.gradients(z_e,var,grad_z)[0]+beta*tf.gradients(self.commit,var)[0],var) 156 | for var in encoder_vars] 157 | # Embedding Grads 158 | embed_grads = list(zip(tf.gradients(self.vq,embeds),[embeds])) 159 | 160 | optimizer = tf.train.AdamOptimizer(lr) 161 | self.train_op= optimizer.apply_gradients(decoder_grads+encoder_grads+embed_grads,global_step=global_step) 162 | else : 163 | # Another decoder pass that we can play with!
164 | size = self.z_e.get_shape()[1] 165 | self.latent = tf.placeholder(tf.int64,[None,size,size]) 166 | _t = tf.gather(embeds,self.latent) 167 | for block in dec_spec: 168 | _t = block(_t) 169 | self.gen = _t 170 | 171 | save_vars = {('train/'+'/'.join(var.name.split('/')[1:])).split(':')[0] : var for var in 172 | tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,param_scope.name) } 173 | #for name,var in save_vars.items(): 174 | # print(name,var) 175 | 176 | self.saver = tf.train.Saver(var_list=save_vars,max_to_keep = 3) 177 | 178 | def save(self,sess,dir,step=None): 179 | if(step is not None): 180 | self.saver.save(sess,dir+'/model.ckpt',global_step=step) 181 | else : 182 | self.saver.save(sess,dir+'/last.ckpt') 183 | 184 | def load(self,sess,model): 185 | self.saver.restore(sess,model) 186 | 187 | class PixelCNN(object): 188 | def __init__(self,lr,global_step,grad_clip, 189 | size, embeds, K, D, 190 | num_classes, num_layers, num_maps, 191 | is_training=True): 192 | import sys 193 | sys.path.append('pixelcnn') 194 | from layers import GatedCNN 195 | self.X = tf.placeholder(tf.int32,[None,size,size]) 196 | 197 | if( num_classes is not None ): 198 | self.h = tf.placeholder(tf.int32,[None,]) 199 | onehot_h = tf.one_hot(self.h,num_classes,axis=-1) 200 | else: 201 | onehot_h = None 202 | 203 | if( embeds is not None ): 204 | X_processed = tf.gather(tf.stop_gradient(embeds),self.X) 205 | else: 206 | embeds = tf.get_variable('embed', [K,D], 207 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 208 | X_processed = tf.gather(embeds,self.X) 209 | 210 | v_stack_in, h_stack_in = X_processed, X_processed 211 | for i in range(num_layers): 212 | filter_size = 3 if i > 0 else 7 213 | mask = 'b' if i > 0 else 'a' 214 | residual = True if i > 0 else False 215 | i = str(i) 216 | with tf.variable_scope("v_stack"+i): 217 | v_stack = GatedCNN([filter_size, filter_size, num_maps], v_stack_in, mask=mask, conditional=onehot_h).output() 218 | v_stack_in = v_stack 219 | 220 | 
            with tf.variable_scope("v_stack_1"+i):
                v_stack_1 = GatedCNN([1, 1, num_maps], v_stack_in, gated=False, mask=mask).output()

            with tf.variable_scope("h_stack"+i):
                h_stack = GatedCNN([1, filter_size, num_maps], h_stack_in, payload=v_stack_1, mask=mask, conditional=onehot_h).output()

            with tf.variable_scope("h_stack_1"+i):
                h_stack_1 = GatedCNN([1, 1, num_maps], h_stack, gated=False, mask=mask).output()
                if residual:
                    h_stack_1 += h_stack_in # Residual connection
                h_stack_in = h_stack_1

        with tf.variable_scope("fc_1"):
            fc1 = GatedCNN([1, 1, num_maps], h_stack_in, gated=False, mask='b').output()

        with tf.variable_scope("fc_2"):
            self.fc2 = GatedCNN([1, 1, K], fc1, gated=False, mask='b', activation=False).output()
        self.dist = tf.distributions.Categorical(logits=self.fc2)
        self.sampled = self.dist.sample()
        self.log_prob = self.dist.log_prob(self.sampled)

        # Per-example NLL: sum the per-position cross-entropies, then average over the batch.
        loss_per_batch = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.fc2,
                                                                                      labels=self.X),axis=[1,2])
        self.loss = tf.reduce_mean(loss_per_batch,axis=0)

        save_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,tf.contrib.framework.get_name_scope())
        self.saver = tf.train.Saver(var_list=save_vars,max_to_keep = 3)

        if( is_training ):
            with tf.variable_scope('backward'):
                optimizer = tf.train.AdamOptimizer(lr)

                gradients = optimizer.compute_gradients(self.loss,var_list=save_vars)
                if( grad_clip is None ):
                    clipped_gradients = gradients
                else :
                    # Element-wise clipping of every gradient to [-grad_clip, grad_clip].
                    clipped_gradients = [(tf.clip_by_value(g, -grad_clip, grad_clip), v) for g, v in gradients]
                    #clipped_gradients = [(tf.clip_by_average_norm(g, grad_clip), v) for g, v in gradients]
                self.train_op = optimizer.apply_gradients(clipped_gradients,global_step)
                #for var in save_vars:
                #    print(var,var.name)

    def sample_from_prior(self,sess,classes,batch_size):
        # Generates len(classes)*batch_size latent maps Z (batch_size maps when classes is None),
        # filling the positions one at a time in raster order.
        size = int(self.X.get_shape()[1])
        num_samples = batch_size if classes is None else len(classes)*batch_size
        feed_dict={
            self.X: np.zeros([num_samples,size,size],np.int32)
        }
        if( classes is not None ):
            feed_dict[self.h] = np.repeat(classes,batch_size).astype(np.int32)

        log_probs = np.zeros((num_samples,))
        for i in range(size):
            for j in range(size):
                # One forward pass samples every position, but only (i,j) is kept;
                # later positions are resampled once their context has been filled in.
                sampled,log_prob = sess.run([self.sampled,self.log_prob],feed_dict=feed_dict)
                feed_dict[self.X][:,i,j]= sampled[:,i,j]
                log_probs += log_prob[:,i,j]
        return feed_dict[self.X], log_probs

    def save(self,sess,dir,step=None):
        if(step is not None):
            self.saver.save(sess,dir+'/model-pixelcnn.ckpt',global_step=step)
        else :
            self.saver.save(sess,dir+'/last-pixelcnn.ckpt')

    def load(self,sess,model):
        self.saver.restore(sess,model)

if __name__ == "__main__":
    with tf.variable_scope('params') as params:
        pass

    x = tf.placeholder(tf.float32,[None,32,32,3])
    global_step = tf.Variable(0, trainable=False)

    net = VQVAE(0.1,global_step,0.1,x,20,256,_cifar10_arch,params,True)

    with tf.variable_scope('pixelcnn'):
        latent = tf.placeholder(tf.int32,[None,3,3])
        embeds = net.embeds

        # num_classes must cover the 10 labels passed to sample_from_prior below.
        pixelcnn = PixelCNN(0.1,global_step,1.0,
                            3,embeds,20,32,
                            num_classes=10,num_layers=10,num_maps=20)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.graph.finalize()
    sess.run(init_op)

    #print(sess.run(net.train_op,feed_dict={x:np.random.random((10,32,32,3))}))
    sampled,log_prob = pixelcnn.sample_from_prior(sess,np.arange(10),1)
    print(sampled[0], np.exp(log_prob[0]))

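The raster-scan loop in `sample_from_prior` can be mirrored outside TensorFlow to see why positions must be committed one at a time. The following is a minimal NumPy sketch, not code from this repository: `sample_raster` and `toy_logits` are hypothetical names, and `logits_fn` is an assumed stand-in for one `sess.run` of the PixelCNN. Unlike the TF code, which samples every position and keeps only `(i, j)`, the sketch draws only the current position; the kept code has the same distribution either way.

```python
import numpy as np


def sample_raster(logits_fn, n, size, rng):
    """Hypothetical NumPy mirror of PixelCNN.sample_from_prior.

    logits_fn(grid) must return logits of shape [n, size, size, K]; it stands
    in for one sess.run of the PixelCNN (an assumption of this sketch).
    Returns the sampled code grid and the summed log-probability per sample.
    """
    grid = np.zeros((n, size, size), np.int32)
    log_probs = np.zeros(n)
    for i in range(size):
        for j in range(size):
            logits = logits_fn(grid)  # full forward pass over the current grid
            # Numerically stable log-softmax over the K codes.
            shifted = logits - logits.max(axis=-1, keepdims=True)
            logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
            # Draw a code for position (i, j) only, one per batch element.
            probs = np.exp(logp[:, i, j])
            codes = np.array([rng.choice(len(p), p=p) for p in probs])
            grid[:, i, j] = codes
            log_probs += logp[np.arange(n), i, j, codes]
    return grid, log_probs


def toy_logits(grid, K=4):
    """Hypothetical stand-in model: strongly predicts (previous code + 1) % K,
    where 'previous' is the preceding position in raster order."""
    n, size, _ = grid.shape
    flat = grid.reshape(n, -1)
    prev = np.concatenate([np.full((n, 1), -1), flat[:, :-1]], axis=1)
    target = (prev + 1) % K
    logits = np.zeros((n, size * size, K))
    np.put_along_axis(logits, target[..., None], 50.0, axis=-1)
    return logits.reshape(n, size, size, K)
```

Because `toy_logits` always predicts one more than the previously committed code, each draw depends on the earlier ones, which is why the loop can only fill position `(i, j)` after everything before it in raster order has been fixed.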
--------------------------------------------------------------------------------
/models/cifar10/last-pixelcnn.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last-pixelcnn.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/models/cifar10/last-pixelcnn.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last-pixelcnn.ckpt.index
--------------------------------------------------------------------------------
/models/cifar10/last-pixelcnn.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last-pixelcnn.ckpt.meta
--------------------------------------------------------------------------------
/models/cifar10/last.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/models/cifar10/last.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last.ckpt.index
--------------------------------------------------------------------------------
/models/cifar10/last.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/cifar10/last.ckpt.meta
--------------------------------------------------------------------------------
/models/imagenet/last.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/imagenet/last.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/models/imagenet/last.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/imagenet/last.ckpt.index
--------------------------------------------------------------------------------
/models/imagenet/last.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/imagenet/last.ckpt.meta
--------------------------------------------------------------------------------
/models/mnist/last-pixelcnn.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last-pixelcnn.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/models/mnist/last-pixelcnn.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last-pixelcnn.ckpt.index
--------------------------------------------------------------------------------
/models/mnist/last-pixelcnn.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last-pixelcnn.ckpt.meta
--------------------------------------------------------------------------------
/models/mnist/last.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/models/mnist/last.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last.ckpt.index
--------------------------------------------------------------------------------
/models/mnist/last.ckpt.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiwonjoon/tf-vqvae/6d69ff97dd2ca62208697cc54da06f3bd1e845bb/models/mnist/last.ckpt.meta
--------------------------------------------------------------------------------
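The checkpoint files listed above follow the naming used by the `save` methods in `model.py`: `last.ckpt` or `last-pixelcnn.ckpt` when no step is given, otherwise `model.ckpt-<step>` or `model-pixelcnn.ckpt-<step>`, each stored as a `.data-00000-of-00001`, `.index` and `.meta` triple. The pure-Python sketch below is not part of the repository; `checkpoint_key` and `checkpoint_files` are hypothetical helpers that spell out that naming together with the key renaming done by `save_vars` in `VQVAE`. It assumes TensorFlow's V2 checkpoint format with a single data shard, and the variable name used in the example is made up.

```python
import os


def checkpoint_key(var_name, root="train"):
    """Hypothetical mirror of the save_vars renaming in VQVAE.__init__:
    swap the outermost scope for `root` and drop the ':0' output suffix."""
    return (root + "/" + "/".join(var_name.split("/")[1:])).split(":")[0]


def checkpoint_files(directory, name, step=None):
    """Hypothetical helper listing the files a V2 tf.train.Saver.save() call
    typically writes for the prefix built by save(): one data shard, an
    index, and the meta graph (the 'checkpoint' state file is not listed)."""
    prefix = os.path.join(directory, name)
    if step is not None:
        prefix = "%s-%d" % (prefix, step)  # Saver appends '-<global_step>'
    return [prefix + ext for ext in (".data-00000-of-00001", ".index", ".meta")]
```

Because the outermost scope is rewritten to `train/`, a VQ-VAE checkpoint can be restored into a graph whose parameters live under a differently named outer scope, as long as the inner scope names match.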