├── .gitignore ├── LICENSE ├── README.md ├── imlib ├── __init__.py ├── basic.py ├── dtype.py ├── encode.py └── transform.py ├── models.py ├── models_64x64.py ├── pics ├── GAN_normalG.jpg ├── GAN_trickyG.jpg ├── Jensen-Shannon_normalG.jpg ├── Jensen-Shannon_trickyG.jpg ├── Kullback-Leibler_normalG.jpg ├── Kullback-Leibler_trickyG.jpg ├── Pearson-X2_normalG.jpg ├── Pearson-X2_trickyG.jpg ├── Reverse-KL_normalG.jpg └── Reverse-KL_trickyG.jpg ├── pylib ├── __init__.py ├── timer.py └── utils.py ├── tflib ├── __init__.py ├── checkpoint.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── disk_image.py │ ├── memory_data.py │ ├── tfrecord.py │ └── tfrecord_creator.py ├── ops │ ├── __init__.py │ └── layers.py ├── utils.py ├── variable.py └── vision │ ├── __init__.py │ └── dataset │ ├── __init__.py │ └── mnist.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | /data/ 4 | /output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 hezhenliang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a 4 | copy of this software and associated documentation files (the "Software"), 5 | to deal in the Software without restriction, including without limitation 6 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | and/or sell copies of the Software, and to permit persons to whom the 8 | Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

f-GAN

2 | 3 | TensorFlow implementation of f-GAN (NIPS 2016) - [f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization](https://arxiv.org/abs/1606.00709). 4 | 5 | ## TODO 6 | 7 | - [ ] make these divergences work (suggestions are welcome) 8 | - [ ] ***Kullback-Leibler*** with tricky G loss 9 | - [ ] ***Reverse-KL*** with tricky G loss 10 | - [x] ***Pearson-X2*** with tricky G loss 11 | - [ ] ***Squared-Hellinger*** with tricky G loss 12 | - [x] ***Jensen-Shannon*** with tricky G loss 13 | - [x] ***GAN*** with tricky G loss 14 | - [ ] test more divergences 15 | 16 | ## Exemplar Results 17 | 18 | - Using tricky G loss (see Section 3.2 in the paper) 19 | 20 | Kullback-Leibler | Reverse-KL | Pearson-X2 21 | :---: | :---: | :---: 22 | | | 23 | **Squared-Hellinger** | **Jensen-Shannon** | **GAN** 24 | NaN | | 25 | 26 | - Using theoretically correct G loss 27 | 28 | Kullback-Leibler | Reverse-KL | Pearson-X2 29 | :---: | :---: | :---: 30 | | | 31 | **Squared-Hellinger** | **Jensen-Shannon** | **GAN** 32 | NaN | | 33 | 34 | ## Usage 35 | 36 | - Prerequisites 37 | - tensorflow 1.7 or 1.8 38 | - python 2.7 39 | 40 | 41 | - Examples of training 42 | - training 43 | 44 | ```console 45 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=mnist --divergence=Pearson-X2 --tricky_G 46 | ``` 47 | 48 | - tensorboard for loss visualization 49 | 50 | ```console 51 | CUDA_VISIBLE_DEVICES='' tensorboard --logdir ./output/mnist_Pearson-X2_trickyG/summaries --port 6006 52 | ``` 53 | 54 | ## Citation 55 | If you find [f-GAN](https://arxiv.org/abs/1606.00709) useful in your research work, please consider citing: 56 | 57 | @inproceedings{nowozin2016f, 58 | title={f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization}, 59 | author={Nowozin, Sebastian and Cseke, Botond and Tomioka, Ryota}, 60 | booktitle={Advances in Neural Information Processing Systems (NIPS)}, 61 | year={2016} 62 | } 
-------------------------------------------------------------------------------- /imlib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.basic import * 6 | from imlib.dtype import * 7 | from imlib.encode import * 8 | from imlib.transform import * 9 | -------------------------------------------------------------------------------- /imlib/basic.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.dtype import * 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import scipy.misc 9 | 10 | 11 | def imread(paths, mode='RGB'): 12 | """Read image(s). 13 | 14 | if `paths` is a list or tuple, then read a list of images into [-1.0, 1.0] 15 | of float and return the numpy array batch in shape of N * H * W (* C) 16 | if `paths` is a single str, then read an image into [-1.0, 1.0] of float 17 | 18 | Args: 19 | mode: It can be one of the following strings: 20 | * 'L' (8 - bit pixels, black and white) 21 | * 'P' (8 - bit pixels, mapped to any other mode using a color palette) 22 | * 'RGB' (3x8 - bit pixels, true color) 23 | * 'RGBA' (4x8 - bit pixels, true color with transparency mask) 24 | * 'CMYK' (4x8 - bit pixels, color separation) 25 | * 'YCbCr' (3x8 - bit pixels, color video format) 26 | * 'I' (32 - bit signed integer pixels) 27 | * 'F' (32 - bit floating point pixels) 28 | 29 | Returns: 30 | Float64 image in [-1.0, 1.0]. 
31 | """ 32 | def _imread(path, mode='RGB'): 33 | return scipy.misc.imread(path, mode=mode) / 127.5 - 1 34 | 35 | if isinstance(paths, (list, tuple)): 36 | return np.array([_imread(path, mode) for path in paths]) 37 | else: 38 | return _imread(paths, mode) 39 | 40 | 41 | def imwrite(image, path): 42 | """Save an [-1.0, 1.0] image.""" 43 | if image.ndim == 3 and image.shape[2] == 1: # for gray image 44 | image = np.array(image, copy=True) 45 | image.shape = image.shape[0:2] 46 | return scipy.misc.imsave(path, to_range(image, 0, 255, np.uint8)) 47 | 48 | 49 | def imshow(image): 50 | """Show a [-1.0, 1.0] image.""" 51 | if image.ndim == 3 and image.shape[2] == 1: # for gray image 52 | image = np.array(image, copy=True) 53 | image.shape = image.shape[0:2] 54 | plt.imshow(to_range(image), cmap=plt.gray()) 55 | 56 | 57 | show = plt.show 58 | -------------------------------------------------------------------------------- /imlib/dtype.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | 8 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None): 9 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.""" 10 | assert np.min(images) >= -1.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \ 11 | and (images.dtype == np.float32 or images.dtype == np.float64), \ 12 | 'The input images should be float64(32) and in the range of [-1.0, 1.0]!' 13 | if dtype is None: 14 | dtype = images.dtype 15 | return ((images + 1.) / 2. * (max_value - min_value) + 16 | min_value).astype(dtype) 17 | 18 | 19 | def uint2im(images): 20 | """Transform images from uint8 to [-1.0, 1.0] of float64.""" 21 | assert images.dtype == np.uint8, 'The input images type should be uint8!' 
22 | return images / 127.5 - 1.0 23 | 24 | 25 | def float2im(images): 26 | """Transform images from [0, 1.0] to [-1.0, 1.0].""" 27 | assert np.min(images) >= 0.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \ 28 | and (images.dtype == np.float32 or images.dtype == np.float64), \ 29 | 'The input images should be float64(32) and in the range of [0.0, 1.0]!' 30 | return images * 2 - 1.0 31 | 32 | 33 | def im2uint(images): 34 | """Transform images from [-1.0, 1.0] to uint8.""" 35 | return to_range(images, 0, 255, np.uint8) 36 | 37 | 38 | def im2float(images): 39 | """Transform images from [-1.0, 1.0] to [0.0, 1.0].""" 40 | return to_range(images, 0.0, 1.0) 41 | 42 | 43 | def float2uint(images): 44 | """Transform images from [0, 1.0] to uint8.""" 45 | assert np.min(images) >= 0.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \ 46 | and (images.dtype == np.float32 or images.dtype == np.float64), \ 47 | 'The input images should be float64(32) and in the range of [0.0, 1.0]!' 48 | return (images * 255).astype(np.uint8) 49 | 50 | 51 | def uint2float(images): 52 | """Transform images from uint8 to [0.0, 1.0] of float64.""" 53 | assert images.dtype == np.uint8, 'The input images type should be uint8!' 54 | return images / 255.0 55 | -------------------------------------------------------------------------------- /imlib/encode.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import io 6 | 7 | from imlib.dtype import * 8 | from PIL import Image 9 | 10 | 11 | def imencode(image, format='PNG', quality=95): 12 | """Encode an [-1.0, 1.0] into byte str. 13 | 14 | Args: 15 | format: 'PNG' or 'JPEG'. 16 | quality: for 'JPEG'. 17 | 18 | Returns: 19 | Byte string. 
20 | """ 21 | byte_io = io.BytesIO() 22 | image = Image.fromarray(im2uint(image)) 23 | image.save(byte_io, format=format, quality=quality) 24 | bytes = byte_io.getvalue() 25 | return bytes 26 | 27 | 28 | def imdecode(bytes): 29 | """Decode byte str to image in [-1.0, 1.0] of float64. 30 | 31 | Args: 32 | bytes: Byte string. 33 | 34 | Returns: 35 | A float64 image in [-1.0, 1.0]. 36 | """ 37 | byte_io = io.BytesIO() 38 | byte_io.write(bytes) 39 | image = np.array(Image.open(byte_io)) 40 | image = uint2im(image) 41 | return image 42 | -------------------------------------------------------------------------------- /imlib/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.dtype import * 6 | import numpy as np 7 | import scipy.misc 8 | 9 | 10 | def rgb2gray(images): 11 | if images.ndim == 4 or images.ndim == 3: 12 | assert images.shape[-1] == 3, 'Channel size should be 3!' 13 | else: 14 | raise Exception('Wrong dimensions!') 15 | 16 | return (images[..., 0] * 0.299 + images[..., 1] * 0.587 + images[..., 2] * 0.114).astype(images.dtype) 17 | 18 | 19 | def gray2rgb(images): 20 | assert images.ndim == 2 or images.ndim == 3, 'Wrong dimensions!' 21 | rgb_imgs = np.zeros(images.shape + (3,), dtype=images.dtype) 22 | rgb_imgs[..., 0] = images 23 | rgb_imgs[..., 1] = images 24 | rgb_imgs[..., 2] = images 25 | return rgb_imgs 26 | 27 | 28 | def imresize(image, size, interp='bilinear'): 29 | """Resize an [-1.0, 1.0] image. 30 | 31 | Args: 32 | size : int, float or tuple 33 | * int - Percentage of current size. 34 | * float - Fraction of current size. 35 | * tuple - Size of the output image. 36 | 37 | interp : str, optional 38 | Interpolation to use for re-sizing ('nearest', 'lanczos', 39 | 'bilinear', 'bicubic' or 'cubic'). 
40 | """ 41 | # scipy.misc.imresize should deal with uint8 image, or it would cause some 42 | # problem (scale the image to [0, 255]) 43 | return (scipy.misc.imresize(im2uint(image), size, interp=interp) / 127.5 - 1).astype(image.dtype) 44 | 45 | 46 | def resize_images(images, size, interp='bilinear'): 47 | """Resize batch [-1.0, 1.0] images of shape (N * H * W (* 3)). 48 | 49 | Args: 50 | size : int, float or tuple 51 | * int - Percentage of current size. 52 | * float - Fraction of current size. 53 | * tuple - Size of the output image. 54 | 55 | interp : str, optional 56 | Interpolation to use for re-sizing ('nearest', 'lanczos', 57 | 'bilinear', 'bicubic' or 'cubic'). 58 | """ 59 | rs_imgs = [] 60 | for img in images: 61 | rs_imgs.append(imresize(img, size, interp)) 62 | return np.array(rs_imgs) 63 | 64 | 65 | def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0): 66 | """Merge images into an image with (n_row * h) * (n_col * w). 67 | 68 | `images` is in shape of N * H * W(* C=1 or 3) 69 | """ 70 | n = images.shape[0] 71 | if n_row: 72 | n_row = max(min(n_row, n), 1) 73 | n_col = int(n - 0.5) // n_row + 1 74 | elif n_col: 75 | n_col = max(min(n_col, n), 1) 76 | n_row = int(n - 0.5) // n_col + 1 77 | else: 78 | n_row = int(n ** 0.5) 79 | n_col = int(n - 0.5) // n_row + 1 80 | 81 | h, w = images.shape[1], images.shape[2] 82 | shape = (h * n_row + padding * (n_row - 1), 83 | w * n_col + padding * (n_col - 1)) 84 | if images.ndim == 4: 85 | shape += (images.shape[3],) 86 | img = np.full(shape, pad_value, dtype=images.dtype) 87 | 88 | for idx, image in enumerate(images): 89 | i = idx % n_col 90 | j = idx // n_col 91 | img[j * (h + padding):j * (h + padding) + h, 92 | i * (w + padding):i * (w + padding) + w, ...] 
= image 93 | 94 | return img 95 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | import tflib as tl 10 | 11 | conv = partial(slim.conv2d, activation_fn=None) 12 | dconv = partial(slim.conv2d_transpose, activation_fn=None) 13 | fc = partial(tl.flatten_fully_connected, activation_fn=None) 14 | relu = tf.nn.relu 15 | lrelu = tf.nn.leaky_relu 16 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None) 17 | 18 | 19 | def G(z, dim=64, is_training=True): 20 | bn = partial(batch_norm, is_training=is_training) 21 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu) 22 | fc_bn_relu = partial(fc, normalizer_fn=bn, activation_fn=relu) 23 | 24 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE): 25 | y = fc_bn_relu(z, 1024) 26 | y = fc_bn_relu(y, 7 * 7 * dim * 2) 27 | y = tf.reshape(y, [-1, 7, 7, dim * 2]) 28 | y = dconv_bn_relu(y, dim * 2, 5, 2) 29 | img = tf.tanh(dconv(y, 1, 5, 2)) 30 | return img 31 | 32 | 33 | def D(img, dim=64, is_training=True): 34 | bn = partial(batch_norm, is_training=is_training) 35 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu) 36 | fc_bn_lrelu = partial(fc, normalizer_fn=bn, activation_fn=lrelu) 37 | 38 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE): 39 | y = lrelu(conv(img, 1, 5, 2)) 40 | y = conv_bn_lrelu(y, dim, 5, 2) 41 | y = fc_bn_lrelu(y, 1024) 42 | logit = fc(y, 1) 43 | return logit 44 | -------------------------------------------------------------------------------- /models_64x64.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from 
__future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | import tflib as tl 10 | 11 | conv = partial(slim.conv2d, activation_fn=None) 12 | dconv = partial(slim.conv2d_transpose, activation_fn=None) 13 | fc = partial(tl.flatten_fully_connected, activation_fn=None) 14 | relu = tf.nn.relu 15 | lrelu = tf.nn.leaky_relu 16 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None) 17 | 18 | 19 | def G(z, dim=64, is_training=True): 20 | bn = partial(batch_norm, is_training=is_training) 21 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu) 22 | fc_bn_relu = partial(fc, normalizer_fn=bn, activation_fn=relu) 23 | 24 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE): 25 | y = fc_bn_relu(z, 4 * 4 * dim * 8) 26 | y = tf.reshape(y, [-1, 4, 4, dim * 8]) 27 | y = dconv_bn_relu(y, dim * 4, 5, 2) 28 | y = dconv_bn_relu(y, dim * 2, 5, 2) 29 | y = dconv_bn_relu(y, dim * 1, 5, 2) 30 | img = tf.tanh(dconv(y, 3, 5, 2)) 31 | return img 32 | 33 | 34 | def D(img, dim=64, is_training=True): 35 | bn = partial(batch_norm, is_training=is_training) 36 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu) 37 | 38 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE): 39 | y = lrelu(conv(img, dim, 5, 2)) 40 | y = conv_bn_lrelu(y, dim * 2, 5, 2) 41 | y = conv_bn_lrelu(y, dim * 4, 5, 2) 42 | y = conv_bn_lrelu(y, dim * 8, 5, 2) 43 | logit = fc(y, 1) 44 | return logit 45 | -------------------------------------------------------------------------------- /pics/GAN_normalG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/GAN_normalG.jpg -------------------------------------------------------------------------------- /pics/GAN_trickyG.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/GAN_trickyG.jpg -------------------------------------------------------------------------------- /pics/Jensen-Shannon_normalG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Jensen-Shannon_normalG.jpg -------------------------------------------------------------------------------- /pics/Jensen-Shannon_trickyG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Jensen-Shannon_trickyG.jpg -------------------------------------------------------------------------------- /pics/Kullback-Leibler_normalG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Kullback-Leibler_normalG.jpg -------------------------------------------------------------------------------- /pics/Kullback-Leibler_trickyG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Kullback-Leibler_trickyG.jpg -------------------------------------------------------------------------------- /pics/Pearson-X2_normalG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Pearson-X2_normalG.jpg -------------------------------------------------------------------------------- /pics/Pearson-X2_trickyG.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Pearson-X2_trickyG.jpg -------------------------------------------------------------------------------- /pics/Reverse-KL_normalG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Reverse-KL_normalG.jpg -------------------------------------------------------------------------------- /pics/Reverse-KL_trickyG.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Reverse-KL_trickyG.jpg -------------------------------------------------------------------------------- /pylib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pylib.timer import * 6 | from pylib.utils import * 7 | -------------------------------------------------------------------------------- /pylib/timer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import datetime 6 | import timeit 7 | 8 | 9 | class Timer(object): 10 | """A timer as a context manager. 11 | 12 | Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py. 13 | 14 | Wraps around a timer. A custom timer can be passed 15 | to the constructor. The default timer is timeit.default_timer. 16 | 17 | Note that the latter measures wall clock time, not CPU time! 18 | On Unix systems, it corresponds to time.time. 
19 | On Windows systems, it corresponds to time.clock. 20 | 21 | Keyword arguments: 22 | is_output -- if True, print output after exiting context. 23 | format -- 'ms', 's' or 'datetime' 24 | """ 25 | 26 | def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'): 27 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!" 28 | self._timer = timer 29 | self._is_output = is_output 30 | self._fmt = fmt 31 | 32 | def __enter__(self): 33 | """Start the timer in the context manager scope.""" 34 | self.start() 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_value, exc_traceback): 38 | """Set the end time.""" 39 | if self._is_output: 40 | print(str(self)) 41 | 42 | def __str__(self): 43 | if self._fmt != 'datetime': 44 | return '%s %s' % (self.elapsed, self._fmt) 45 | else: 46 | return str(self.elapsed) 47 | 48 | def start(self): 49 | self.start_time = self._timer() 50 | 51 | @property 52 | def elapsed(self): 53 | """Return the current elapsed time since start.""" 54 | e = self._timer() - self.start_time 55 | 56 | if self._fmt == 'ms': 57 | return e * 1000 58 | elif self._fmt == 's': 59 | return e 60 | elif self._fmt == 'datetime': 61 | return datetime.timedelta(seconds=e) 62 | 63 | 64 | def timer(**timer_kwargs): 65 | """Function decorator displaying the function execution time. 66 | 67 | All kwargs are the arguments taken by the Timer class constructor. 
68 | """ 69 | # store Timer kwargs in local variable so the namespace isn't polluted 70 | # by different level args and kwargs 71 | 72 | def wrapped_f(f): 73 | def wrapped(*args, **kwargs): 74 | fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]' 75 | with Timer(**timer_kwargs) as t: 76 | out = f(*args, **kwargs) 77 | context = {'function_name': f.__name__, 'execution_time': str(t)} 78 | print(fmt % context) 79 | return out 80 | return wrapped 81 | 82 | return wrapped_f 83 | 84 | if __name__ == "__main__": 85 | import time 86 | 87 | # 1 88 | print(1) 89 | with Timer() as t: 90 | time.sleep(1) 91 | print(t) 92 | time.sleep(1) 93 | 94 | with Timer(fmt='datetime') as t: 95 | time.sleep(1) 96 | 97 | # 2 98 | print(2) 99 | t = Timer(fmt='ms') 100 | t.start() 101 | time.sleep(2) 102 | print(t) 103 | 104 | t = Timer(fmt='datetime') 105 | t.start() 106 | time.sleep(1) 107 | print(t) 108 | 109 | # 3 110 | print(3) 111 | 112 | @timer(fmt='ms') 113 | def blah(): 114 | time.sleep(2) 115 | 116 | blah() 117 | -------------------------------------------------------------------------------- /pylib/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | 8 | 9 | def add_path(paths): 10 | if not isinstance(paths, (list, tuple)): 11 | paths = [paths] 12 | for path in paths: 13 | if path not in sys.path: 14 | sys.path.insert(0, path) 15 | 16 | 17 | def mkdir(paths): 18 | if not isinstance(paths, (list, tuple)): 19 | paths = [paths] 20 | for path in paths: 21 | if not os.path.isdir(path): 22 | os.makedirs(path) 23 | 24 | 25 | if __name__ == '__main__': 26 | pass 27 | -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.checkpoint import * 6 | from tflib.data import * 7 | from tflib.ops import * 8 | from tflib.utils import * 9 | from tflib.variable import * 10 | from tflib.vision import * 11 | -------------------------------------------------------------------------------- /tflib/checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def load_checkpoint(ckpt_dir_or_file, session, var_list=None): 11 | """Load checkpoint. 12 | 13 | Note: 14 | This function add some useless ops to the graph. It is better 15 | to use tf.train.init_from_checkpoint(...). 16 | """ 17 | if os.path.isdir(ckpt_dir_or_file): 18 | ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) 19 | 20 | restorer = tf.train.Saver(var_list) 21 | restorer.restore(session, ckpt_dir_or_file) 22 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file) 23 | 24 | 25 | def init_from_checkpoint(ckpt_dir_or_file, assignment_map={'/': '/'}): 26 | # Use the checkpoint values for the variables' initializers. Note that this 27 | # function just changes the initializers but does not actually run them, and 28 | # you should still run the initializers manually. 29 | tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map) 30 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' 
% ckpt_dir_or_file) 31 | -------------------------------------------------------------------------------- /tflib/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.data.dataset import * 6 | from tflib.data.disk_image import * 7 | from tflib.data.memory_data import * 8 | from tflib.data.tfrecord import * 9 | from tflib.data.tfrecord_creator import * 10 | -------------------------------------------------------------------------------- /tflib/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.eager as tfe 9 | from tflib.utils import session 10 | 11 | 12 | _N_CPU = multiprocessing.cpu_count() 13 | 14 | 15 | def batch_dataset(dataset, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 16 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 17 | if filter: 18 | dataset = dataset.filter(filter) 19 | 20 | if map_func: 21 | dataset = dataset.map(map_func, num_parallel_calls=num_threads) 22 | 23 | if shuffle: 24 | dataset = dataset.shuffle(buffer_size) 25 | 26 | if drop_remainder: 27 | dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) 28 | else: 29 | dataset = dataset.batch(batch_size) 30 | 31 | dataset = dataset.repeat(repeat).prefetch(prefetch_batch) 32 | 33 | return dataset 34 | 35 | 36 | class Dataset(object): 37 | 38 | def __init__(self): 39 | self._dataset = None 40 | self._iterator = None 41 | self._batch_op = None 42 | self._sess = None 43 | 44 | self._is_eager = tf.executing_eagerly() 45 | self._eager_iterator = None 46 | 47 | def __del__(self): 48 
| if self._sess: 49 | self._sess.close() 50 | 51 | def __iter__(self): 52 | return self 53 | 54 | def __next__(self): 55 | try: 56 | b = self.get_next() 57 | except: 58 | raise StopIteration 59 | else: 60 | return b 61 | 62 | next = __next__ 63 | 64 | def get_next(self): 65 | if self._is_eager: 66 | return self._eager_iterator.get_next() 67 | else: 68 | return self._sess.run(self._batch_op) 69 | 70 | def reset(self, feed_dict={}): 71 | if self._is_eager: 72 | self._eager_iterator = tfe.Iterator(self._dataset) 73 | else: 74 | self._sess.run(self._iterator.initializer, feed_dict=feed_dict) 75 | 76 | def _bulid(self, dataset, sess=None): 77 | self._dataset = dataset 78 | 79 | if self._is_eager: 80 | self._eager_iterator = tfe.Iterator(dataset) 81 | else: 82 | self._iterator = dataset.make_initializable_iterator() 83 | self._batch_op = self._iterator.get_next() 84 | if sess: 85 | self._sess = sess 86 | else: 87 | self._sess = session() 88 | 89 | try: 90 | self.reset() 91 | except: 92 | pass 93 | 94 | @property 95 | def dataset(self): 96 | return self._dataset 97 | 98 | @property 99 | def iterator(self): 100 | return self._iterator 101 | 102 | @property 103 | def batch_op(self): 104 | return self._batch_op 105 | -------------------------------------------------------------------------------- /tflib/data/disk_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | from tflib.data.dataset import batch_dataset, Dataset 9 | 10 | 11 | _N_CPU = multiprocessing.cpu_count() 12 | 13 | 14 | def disk_image_batch_dataset(img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 15 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 16 | """Disk image batch dataset. 
17 | 18 | This function is suitable for jpg and png files 19 | 20 | img_paths: string list or 1-D tensor, each of which is an iamge path 21 | labels: label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label 22 | """ 23 | if labels is None: 24 | dataset = tf.data.Dataset.from_tensor_slices(img_paths) 25 | elif isinstance(labels, tuple): 26 | dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels)) 27 | else: 28 | dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels)) 29 | 30 | def parse_func(path, *label): 31 | img = tf.read_file(path) 32 | img = tf.image.decode_png(img, 3) 33 | return (img,) + label 34 | 35 | if map_func: 36 | def map_func_(*args): 37 | return map_func(*parse_func(*args)) 38 | else: 39 | map_func_ = parse_func 40 | 41 | # dataset = dataset.map(parse_func, num_parallel_calls=num_threads) is slower 42 | 43 | dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, 44 | map_func_, num_threads, shuffle, buffer_size, repeat) 45 | 46 | return dataset 47 | 48 | 49 | class DiskImageData(Dataset): 50 | """DiskImageData. 
51 | 52 | This function is suitable for jpg and png files 53 | 54 | img_paths: string list or 1-D tensor, each of which is an iamge path 55 | labels: label list or tensor, each of which is a corresponding label 56 | """ 57 | 58 | def __init__(self, img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 59 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None): 60 | super(DiskImageData, self).__init__() 61 | dataset = disk_image_batch_dataset(img_paths, batch_size, labels, prefetch_batch, drop_remainder, filter, 62 | map_func, num_threads, shuffle, buffer_size, repeat) 63 | self._bulid(dataset, sess) 64 | -------------------------------------------------------------------------------- /tflib/data/memory_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | from tflib.data.dataset import batch_dataset, Dataset 9 | 10 | 11 | _N_CPU = multiprocessing.cpu_count() 12 | 13 | 14 | def memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 15 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 16 | """Memory data batch dataset. 17 | 18 | memory_data_dict: 19 | for example 20 | {'img': img_ndarray, 'label': label_ndarray} or 21 | {'img': img_tftensor, 'label': label_tftensor} 22 | the value of each item of `memory_data_dict` is in shape of (N, ...) 23 | """ 24 | dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict) 25 | dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, 26 | map_func, num_threads, shuffle, buffer_size, repeat) 27 | return dataset 28 | 29 | 30 | class MemoryData(Dataset): 31 | """MemoryData. 
if __name__ == '__main__':
    # Small manual smoke test: build a MemoryData pipeline over two toy arrays,
    # pull a few batches and dump the names of the graph nodes that were created.
    import numpy as np
    data = {'a': np.array([1.0, 2, 3, 4, 5]),
            'b': np.array([[1, 2],
                           [2, 3],
                           [3, 4],
                           [4, 5],
                           [5, 6]])}

    # Example element filter; it is not enabled below (`filter=None` is passed).
    def filter(x):
        return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False))

    # Example per-element transform: scale the 'a' entry by 10.
    def map_func(x):
        x['a'] = x['a'] * 10
        return x

    # NOTE: eager mode is intentionally left off; enable it here when experimenting:
    # tf.enable_eager_execution()

    s = tf.Session()

    dataset = MemoryData(data, 2, filter=None, map_func=map_func,
                         shuffle=True, buffer_size=4096, drop_remainder=True, repeat=4, sess=s)

    for _ in range(5):
        batch = dataset.get_next()
        # Build a real list: printing a bare `map(...)` shows only
        # "<map object ...>" on Python 3 instead of the batch contents.
        print([batch[key] for key in ('b', 'a')])

    print([n.name for n in tf.get_default_graph().as_graph_def().node])
def tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True,
                           filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Build a batched dataset that reads and decodes TFRecord files.

    Every entry of `infos` describes one feature stored as a byte string in
    each record: it is parsed as a scalar string feature, decoded with the
    entry's `decoder` (called with `decode_param` as keyword arguments) and
    reshaped to the entry's `shape`. Each parsed example becomes a dict
    mapping feature name to decoded tensor.

    infos:
        for example
        [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]},
         {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}]

    The remaining arguments are forwarded unchanged to `batch_dataset`.
    """
    records = tf.data.TFRecordDataset(tfrecord_files,
                                      compression_type=compression_type,
                                      buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    # Every feature is serialized as a single byte string.
    string_features = {info['name']: tf.FixedLenFeature([], tf.string) for info in infos}

    def parse_func(serialized_example):
        parsed = tf.parse_single_example(serialized_example, features=string_features)

        def decode_one(info):
            key = info['name']
            decode = info['decoder']
            decode_kwargs = info['decode_param']
            target_shape = info['shape']
            return key, tf.reshape(decode(parsed[key], **decode_kwargs), target_shape)

        return dict(decode_one(info) for info in infos)

    decoded = records.map(parse_func, num_parallel_calls=num_threads)

    return batch_dataset(decoded, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func, num_threads, shuffle, buffer_size, repeat)
class BytesTfrecordCreator(object):
    """Write examples made of raw byte strings into sharded TFRecord files.

    Shards are written to `save_path` as `data_%06d.tfrecord`. When the
    creator is finalized, a description of the stored features plus the
    number of examples and the compression type is written to
    `save_path/info.json`.

    `infos` example:
        infos = [
            ['img', 'jpg', (64, 64, 3)],
            ['id', 'int64', ()],
            ['attr', 'int64', (40,)],
            ['point', 'float32', (5, 2)]
        ]

    `size_each`: maximum number of examples per shard; a falsy value means
        a single, practically unbounded shard.

    `compression_type`:
        0: NONE
        1: ZLIB
        2: GZIP

    `overwrite_existence`: when `save_path` already exists, delete it first
        instead of raising.
    """

    def __init__(self, save_path, infos, size_each=None, compression_type=0, overwrite_existence=False):
        # Refuse to clobber an existing directory unless explicitly allowed.
        if os.path.exists(save_path):
            if not overwrite_existence:
                raise Exception('%s exists!' % save_path)
            shutil.rmtree(save_path)
        os.makedirs(save_path)

        self._save_path = save_path

        # Feature descriptions, in insertion order, and their names.
        self._infos = []
        self._info_names = []
        for info in infos:
            self._add_info(*info)

        self._data_num = 0
        self._tfrecord_num = 0
        # A falsy `size_each` means "never roll over to a new shard".
        self._size_each = size_each if size_each else 2147483647
        self._writer = None

        self._compression_type = compression_type
        self._options = tf.python_io.TFRecordOptions(compression_type)

    def __del__(self):
        # `__del__` also runs when `__init__` raised half-way (e.g. the
        # "exists!" error). In that case the attributes below are missing
        # and there is nothing meaningful to record, so bail out quietly
        # instead of raising AttributeError from a finalizer.
        required = ('_save_path', '_infos', '_data_num', '_compression_type')
        if not all(hasattr(self, attr) for attr in required):
            return

        info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}}
        info_str = json.dumps(info, indent=4, separators=(',', ':'))

        with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f:
            info_f.write(info_str)

        writer = getattr(self, '_writer', None)
        if writer:
            writer.close()

    def add(self, feature_bytes_dict):
        """Add example.

        `feature_bytes_dict` example:
            feature_bytes_dict = {
                'img': img_bytes,
                'id': id_bytes,
                'attr': attr_bytes,
                'point': point_bytes
            }

        The keys must be exactly the registered feature names.
        """
        assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \
            'Feature names are inconsistent with the givens!'

        self._new_tfrecord_check()

        self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString())
        self._data_num += 1

    def _new_tfrecord_check(self):
        """Open the next shard when the current one holds `size_each` examples."""
        if self._data_num // self._size_each == self._tfrecord_num:
            self._tfrecord_num += 1

            if self._writer:
                self._writer.close()

            # Shard files are numbered from zero.
            tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1))
            self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options)

    def _add_info(self, name, dtype_or_format, shape):
        """Register one feature description; names must be unique."""
        assert name not in self._info_names, 'Info name "%s" is duplicated!' % name
        assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES)
        self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape))
        self._info_names.append(name)

    @staticmethod
    def _bytes_feature(values):
        """Return a TF-Feature of bytes.

        Args:
            values: A byte string or list of byte strings.

        Returns:
            a TF-Feature.
        """
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    @staticmethod
    def _bytes_tfexample(bytes_dict):
        """Convert bytes to tfexample.

        `bytes_dict` example:
            bytes_dict = {
                'img': img_bytes,
                'id': id_bytes,
                'attr': attr_bytes,
                'point': point_bytes
            }
        """
        feature_dict = {key: BytesTfrecordCreator._bytes_feature(value)
                        for key, value in bytes_dict.items()}
        return tf.train.Example(features=tf.train.Features(feature=feature_dict))
162 | self._is_data_bytes = True 163 | else: 164 | self._is_data_bytes = False 165 | 166 | self._data_shape = data_shape 167 | self._data_dtype_or_format = data_dtype_or_format 168 | self._data_name = data_name 169 | self._label_shape_dict = {} 170 | self._label_dtype_dict = {} 171 | 172 | self._info_built = False 173 | 174 | def add(self, data, label_dict): 175 | """Add example. 176 | 177 | `label_dict` example: 178 | label_dict = { 179 | 'id': id_ndarray, 180 | 'attr': attr_ndarray, 181 | 'point': point_ndarray 182 | } 183 | """ 184 | self._check_and_build(data, label_dict) 185 | 186 | if not self._is_data_bytes: 187 | data = data.tobytes() 188 | 189 | feature_dict = {self._data_name: data} 190 | for name, label in label_dict.items(): 191 | feature_dict[name] = label.tobytes() 192 | 193 | super(DataLablePairTfrecordCreator, self).add(feature_dict) 194 | 195 | def _check_and_build(self, data, label_dict): 196 | # check type 197 | if self._is_data_bytes: 198 | assert isinstance(data, (str, bytes)), '`data` should be a byte string!' 199 | else: 200 | assert isinstance(data, np.ndarray), '`data` should be a numpy array!' 201 | for label in label_dict.values(): 202 | assert isinstance(label, np.ndarray), 'labels should be numpy arrays!' 203 | 204 | # check shape and dtype or bulid info at first adding 205 | if self._info_built: 206 | if not self._is_data_bytes: 207 | assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!' 208 | assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!' 209 | for name, label in label_dict.items(): 210 | assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name 211 | assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' 
class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator):
    """Tfrecord creator for (image, labels) pairs.

    Images are given as uint8 numpy arrays and stored either encoded
    (png / jpg) or as raw bytes.

    `encode_type`: in [None, 'png', 'jpg']; None stores the raw uint8 bytes.
    `quality`: for 'jpg'.
    `compression_type`:
        0: NONE
        1: ZLIB
        2: GZIP
    """

    def __init__(self, save_path, encode_type='png', quality=95, data_name='img',
                 size_each=None, compression_type=0, overwrite_existence=False):
        # Validate before the parent constructor touches the file system, so a
        # bad argument can never create (or wipe) the output directory.
        assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!"

        super(ImageLablePairTfrecordCreator, self).__init__(
            save_path, None, None, data_name, size_each, compression_type, overwrite_existence)

        self._encode_type = encode_type
        self._quality = quality

        # Shape / format are inferred from the first image; the parent class
        # always receives the image as an already-serialized byte string.
        self._data_shape = None
        self._data_dtype_or_format = None
        self._is_data_bytes = True

    def add(self, image, label_dict):
        """Add example.

        `image`: an H * W (* C) uint8 numpy array.

        `label_dict` example:
            label_dict = {
                'id': id_ndarray,
                'attr': attr_ndarray,
                'point': point_ndarray
            }
        """
        self._check(image)
        image_bytes = self._encode(image)
        super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict)

    def _check(self, image):
        """Validate `image`; on the first call also record its shape and format."""
        if not self._data_shape:
            assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \
                '`image` should be an H * W (* C) uint8 numpy array!'
            # Encoded images may have 1 or 3 channels. An explicit trailing
            # channel axis of size 1 (H * W * 1) is as valid as a plain H * W
            # array; `_encode` drops that axis before encoding.
            if self._encode_type and image.ndim == 3 and image.shape[-1] not in (1, 3):
                raise Exception('Only images with 1 or 3 channels are allowed to be encoded!')

            # The recorded shape always carries an explicit channel axis.
            if image.ndim == 2:
                self._data_shape = image.shape + (1,)
            else:
                self._data_shape = image.shape
            self._data_dtype_or_format = self._encode_type or 'uint8'
        else:
            sp = image.shape
            if image.ndim == 2:
                sp = sp + (1,)
            assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!'
            assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!'

    def _encode(self, image):
        """Serialize `image` to png / jpg bytes, or raw bytes when no encoding is set."""
        if not self._encode_type:
            return image.tobytes()

        if image.ndim == 3 and image.shape[-1] == 1:
            # Drop the singleton channel axis for PIL. `reshape` returns a new
            # array object, so the caller's array keeps its original shape.
            image = image.reshape(image.shape[:2])

        byte = io.BytesIO()
        pil_image = Image.fromarray(image)
        if self._encode_type == 'jpg':
            pil_image.save(byte, 'JPEG', quality=self._quality)
        elif self._encode_type == 'png':
            pil_image.save(byte, 'PNG')
        return byte.getvalue()
def print_tensor(tensors):
    """Print a one-line description of each given Tensor / Variable.

    `tensors` may be a single object or a list / tuple of them. An object is
    classified by looking for 'Tensor' (checked first) or 'Variable' in the
    string form of its type; anything else raises an Exception. Each printed
    line is prefixed with the object's position in the input.
    """
    items = tensors if isinstance(tensors, (list, tuple)) else [tensors]

    line_format = ': %s("%s", shape=%s, dtype=%s, device=%s)'
    for position, item in enumerate(items):
        type_text = str(type(item))
        for kind in ('Tensor', 'Variable'):
            if kind in type_text:
                break
        else:
            raise Exception('Not a Tensor or Variable!')

        details = (kind, item.name, str(item.get_shape()), item.dtype.name, item.device)
        print(str(position) + (line_format % details))
summary({tensor_a: 'a', tensor_b: 'b}) 53 | """ 54 | def _summary(tensor, name, summary_type): 55 | """Attach a lot of summaries to a Tensor.""" 56 | if name is None: 57 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 58 | # session. This helps the clarity of presentation on tensorboard. 59 | name = re.sub('%s_[0-9]*/' % 'tower', '', tensor.name) 60 | name = re.sub(':', '-', name) 61 | 62 | summaries = [] 63 | if len(tensor.shape) == 0: 64 | summaries.append(tf.summary.scalar(name, tensor)) 65 | else: 66 | if 'mean' in summary_type: 67 | mean = tf.reduce_mean(tensor) 68 | summaries.append(tf.summary.scalar(name + '/mean', mean)) 69 | if 'stddev' in summary_type: 70 | mean = tf.reduce_mean(tensor) 71 | stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean))) 72 | summaries.append(tf.summary.scalar(name + '/stddev', stddev)) 73 | if 'max' in summary_type: 74 | summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor))) 75 | if 'min' in summary_type: 76 | summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor))) 77 | if 'sparsity' in summary_type: 78 | summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor))) 79 | if 'histogram' in summary_type: 80 | summaries.append(tf.summary.histogram(name, tensor)) 81 | return tf.summary.merge(summaries) 82 | 83 | if not isinstance(tensor_collection, (list, tuple, dict)): 84 | tensor_collection = [tensor_collection] 85 | 86 | with tf.name_scope(scope, 'summary'): 87 | summaries = [] 88 | if isinstance(tensor_collection, (list, tuple)): 89 | for tensor in tensor_collection: 90 | summaries.append(_summary(tensor, None, summary_type)) 91 | else: 92 | for tensor, name in tensor_collection.items(): 93 | summaries.append(_summary(tensor, name, summary_type)) 94 | return tf.summary.merge(summaries) 95 | 96 | 97 | def counter(start=0, scope=None): 98 | with tf.variable_scope(scope, 'counter'): 99 | counter = tf.get_variable(name='counter', 100 | 
def tensors_filter(tensors, filters, combine_type='or'):
    """Select the tensors whose `name` contains the given substrings.

    tensors: list or tuple of objects exposing a `name` attribute.
    filters: a substring, or a list / tuple of substrings.
    combine_type: 'or' keeps a tensor when any substring occurs in its name,
        'and' keeps it only when every substring occurs.

    Returns a new list with the kept tensors in their original order.
    """
    assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!'
    assert isinstance(filters, (str, list, tuple)), '`filters` should be a string or a list(tuple) of strings!'
    assert combine_type in ('or', 'and'), "`combine_type` should be 'or' or 'and'!"

    substrings = [filters] if isinstance(filters, str) else filters
    # `any` / `all` short-circuit exactly like the explicit break-loops would.
    combine = any if combine_type == 'or' else all

    return [tensor for tensor in tensors
            if combine(substring in tensor.name for substring in substrings)]
def unzip_gz(file_name):
    """Decompress the gzip file `file_name` into a sibling file.

    The output path is `file_name` with every occurrence of '.gz' removed
    (e.g. 'train-images-idx3-ubyte.gz' -> 'train-images-idx3-ubyte').
    """
    unzip_name = file_name.replace('.gz', '')

    # Read everything first so that a corrupt archive raises before the
    # output file is created (an empty leftover would look like a finished
    # extraction to a later existence check).
    with gzip.GzipFile(file_name) as gz_file:
        payload = gz_file.read()

    # The payload is bytes: the output must be opened in binary mode
    # (text mode raises TypeError on Python 3), and the `with` block makes
    # sure the data is flushed and the handle closed.
    with open(unzip_name, 'wb') as out_file:
        out_file.write(payload)
class Mnist(MemoryData):
    """MNIST served through the in-memory dataset pipeline.

    The arrays returned by `mnist_load` are not baked into the graph: the
    pipeline is built on placeholders, and the arrays are kept in
    `self.feed_dict` so that `reset` can pass them to the base class.
    Batches expose the keys 'img' and 'lbl'.
    """

    def __init__(self, data_dir, batch_size, split='train', prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        images, labels, self.n_data = mnist_load(data_dir, split)
        # Append a trailing channel axis to the image array.
        images = images[..., np.newaxis]

        image_input = tf.placeholder(tf.float32, images.shape)
        label_input = tf.placeholder(tf.int64, labels.shape)

        self.feed_dict = {image_input: images, label_input: labels}
        super(Mnist, self).__init__({'img': image_input, 'lbl': label_input},
                                    batch_size, prefetch_batch, drop_remainder, filter,
                                    map_func, num_threads, shuffle, buffer_size, repeat, sess)

    def __len__(self):
        # Number of examples in the loaded split, as reported by `mnist_load`.
        return self.n_data

    def reset(self):
        # Hand the stored arrays to the base-class reset via the feed dict.
        super(Mnist, self).reset(self.feed_dict)
utils.get_dataset_models(dataset_name) 45 | dataset = Dataset(batch_size=batch_size) 46 | G = models['G'] 47 | D = models['D'] 48 | activation_fn, conjugate_fn = utils.get_divergence_funcs(divergence) 49 | 50 | 51 | # ============================================================================== 52 | # = graph = 53 | # ============================================================================== 54 | 55 | # inputs 56 | real = tf.placeholder(tf.float32, [None, 28, 28, 1]) 57 | z = tf.placeholder(tf.float32, [None, z_dim]) 58 | 59 | # generate 60 | fake = G(z) 61 | 62 | # dicriminate 63 | r_output = D(real) 64 | f_output = D(fake) 65 | 66 | # losses 67 | d_r_loss = -tf.reduce_mean(activation_fn(r_output)) 68 | d_f_loss = tf.reduce_mean(conjugate_fn(activation_fn(f_output))) 69 | d_loss = d_r_loss + d_f_loss 70 | if tricky_G: 71 | g_loss = -tf.reduce_mean(activation_fn(f_output)) 72 | else: 73 | g_loss = -d_f_loss 74 | 75 | # otpims 76 | d_var = tl.trainable_variables('D') 77 | g_var = tl.trainable_variables('G') 78 | d_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(d_loss, var_list=d_var) 79 | g_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(g_loss, var_list=g_var) 80 | 81 | # summaries 82 | d_summary = tl.summary({d_r_loss: 'd_r_loss', 83 | d_f_loss: 'd_f_loss', 84 | -d_loss: '%s_diverngence' % divergence}, scope='D') 85 | g_summary = tl.summary({g_loss: 'g_loss'}, scope='G') 86 | 87 | # sample 88 | f_sample = G(z, is_training=False) 89 | 90 | 91 | # ============================================================================== 92 | # = train = 93 | # ============================================================================== 94 | 95 | # session 96 | sess = tl.session() 97 | 98 | # saver 99 | saver = tf.train.Saver(max_to_keep=1) 100 | 101 | # summary writer 102 | summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph) 103 | 104 | # initialization 105 | ckpt_dir = 
'./output/%s/checkpoints' % experiment_name 106 | pylib.mkdir(ckpt_dir) 107 | try: 108 | tl.load_checkpoint(ckpt_dir, sess) 109 | except: 110 | sess.run(tf.global_variables_initializer()) 111 | 112 | # train 113 | try: 114 | z_ipt_sample = np.random.normal(size=[100, z_dim]) 115 | 116 | it = -1 117 | it_per_epoch = len(dataset) // batch_size 118 | for ep in range(epoch): 119 | dataset.reset() 120 | for batch in dataset: 121 | it += 1 122 | it_in_epoch = it % it_per_epoch + 1 123 | 124 | # batch data 125 | real_ipt = batch['img'] 126 | z_ipt = np.random.normal(size=[batch_size, z_dim]) 127 | 128 | # train D 129 | d_summary_opt, _ = sess.run([d_summary, d_step], feed_dict={real: real_ipt, z: z_ipt}) 130 | summary_writer.add_summary(d_summary_opt, it) 131 | 132 | # train G 133 | g_summary_opt, _ = sess.run([g_summary, g_step], feed_dict={z: z_ipt}) 134 | summary_writer.add_summary(g_summary_opt, it) 135 | 136 | # display 137 | if (it + 1) % 1 == 0: 138 | print("Epoch: (%3d) (%5d/%5d)" % (ep, it_in_epoch, it_per_epoch)) 139 | 140 | # sample 141 | if (it + 1) % 1000 == 0: 142 | f_sample_opt = sess.run(f_sample, feed_dict={z: z_ipt_sample}) 143 | 144 | save_dir = './output/%s/sample_training' % experiment_name 145 | pylib.mkdir(save_dir) 146 | im.imwrite(im.immerge(f_sample_opt), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, ep, it_in_epoch, it_per_epoch)) 147 | 148 | save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) 149 | print('Model is saved in file: %s' % save_path) 150 | except: 151 | traceback.print_exc() 152 | finally: 153 | sess.close() 154 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import pylib 8 | import tensorflow as tf 9 | import tflib as tl 10 | 11 | 12 
def get_divergence_funcs(divergence):
    """Return the `(activation_fn, conjugate_fn)` pair for a divergence.

    `activation_fn(v)` maps a raw output `v` to `t`, and `conjugate_fn(t)`
    is applied to that result. The formulas appear to follow the output
    activations and conjugate functions tabulated in the f-GAN paper
    (confirm against the paper before changing them).

    divergence: one of 'Kullback-Leibler', 'Reverse-KL', 'Pearson-X2',
        'Squared-Hellinger', 'Jensen-Shannon', 'GAN'.

    Raises:
        ValueError: if `divergence` is not one of the supported names.
    """
    # name -> (activation_fn, conjugate_fn); nothing is evaluated until the
    # returned functions are called.
    funcs = {
        'Kullback-Leibler': (
            lambda v: v,
            lambda t: tf.exp(t - 1),
        ),
        'Reverse-KL': (
            lambda v: -tf.exp(-v),
            lambda t: -1 - tf.log(-t),
        ),
        'Pearson-X2': (
            lambda v: v,
            lambda t: 0.25 * t * t + t,
        ),
        'Squared-Hellinger': (
            lambda v: 1 - tf.exp(-v),
            lambda t: t / (1 - t),
        ),
        'Jensen-Shannon': (
            lambda v: tf.log(2.0) - tf.log(1 + tf.exp(-v)),
            lambda t: -tf.log(2 - tf.exp(t)),
        ),
        'GAN': (
            lambda v: -tf.log(1 + tf.exp(-v)),
            lambda t: -tf.log(1 - tf.exp(t)),
        ),
    }

    if divergence not in funcs:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError('Unknown divergence %r, expected one of: %s'
                         % (divergence, ', '.join(sorted(funcs))))

    activation_fn, conjugate_fn = funcs[divergence]
    return activation_fn, conjugate_fn