├── .gitignore ├── LICENSE ├── README.md ├── imlib ├── __init__.py ├── basic.py ├── dtype.py ├── encode.py └── transform.py ├── models.py ├── pics ├── z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg ├── z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg ├── z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg ├── z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg ├── z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg └── z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg ├── pylib ├── __init__.py ├── path.py └── timer.py ├── tflib ├── __init__.py ├── checkpoint.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── disk_image.py │ ├── memory_data.py │ ├── tfrecord.py │ └── tfrecord_creator.py ├── ops │ ├── __init__.py │ └── layers.py ├── utils.py ├── variable.py └── vision │ ├── __init__.py │ └── dataset │ ├── __init__.py │ └── mnist.py ├── train.py ├── traversal.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | /data/ 4 | /output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 hezhenliang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a 4 | copy of this software and associated documentation files (the "Software"), 5 | to deal in the Software without restriction, including without limitation 6 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | and/or sell copies of the Software, and to permit persons to whom the 8 | Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

(beta-)VAE

2 | 3 | Tensorflow implementation of [VAE](http://arxiv.org/abs/1312.6114) and [beta-VAE](https://openreview.net/pdf?id=Sy2fzU9gl) 4 | 5 | ## Exemplar results 6 | 7 | - Celeba 8 | 9 | - ConvNet (z_dim: 100, beta: 0.05) 10 | 11 | Generation | Reconstruction 12 | :---: | :---: 13 | | 14 | 15 | - Mnist 16 | 17 | - ConvNet (z_dim: 10, beta: 0.1) 18 | 19 | Generation | Reconstruction 20 | :---: | :---: 21 | | 22 | 23 | - MLP (z_dim: 10, beta: 0.1) 24 | 25 | Generation | Reconstruction 26 | :---: | :---: 27 | | 28 | 29 | ## Usage 30 | 31 | - Prerequisites 32 | - Tensorflow 1.8 33 | - Python 2.7 or 3.6 34 | 35 | 36 | - Examples of training 37 | 38 | ```console 39 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 10 --beta 0.1 --dataset mnist --model mlp_mnist --experiment_name z10_beta0.1_mnist_mlp 40 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 10 --beta 0.1 --dataset mnist --model conv_mnist --experiment_name z10_beta0.1_mnist_conv 41 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 32 --beta 0.1 --dataset celeba --model conv_64 --experiment_name z32_beta0.1_celeba_conv 42 | ``` 43 | 44 | ## Datasets 45 | 46 | 1. Celeba should be prepared by yourself in ***./data/celeba/img_align_celeba/*.jpg*** 47 | - Download the dataset: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0 48 | - the above links might be inaccessible, the alternatives are 49 | - ***img_align_celeba.zip*** 50 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg or 51 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg 52 | 2. Mnist will be automatically downloaded -------------------------------------------------------------------------------- /imlib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.basic import * 6 | from imlib.dtype import * 7 | from imlib.encode import * 8 | from imlib.transform import * 9 | -------------------------------------------------------------------------------- /imlib/basic.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.dtype import im2float 6 | import numpy as np 7 | import skimage.io as iio 8 | 9 | 10 | def imread(path, as_gray=False): 11 | """Read image. 12 | 13 | Returns: 14 | Float64 image in [-1.0, 1.0]. 15 | """ 16 | image = iio.imread(path, as_gray) 17 | if image.dtype == np.uint8: 18 | image = image / 127.5 - 1 19 | return image 20 | 21 | 22 | def imwrite(image, path): 23 | """Save an [-1.0, 1.0] image.""" 24 | iio.imsave(path, im2float(image)) 25 | 26 | 27 | def imshow(image): 28 | """Show a [-1.0, 1.0] image.""" 29 | iio.imshow(im2float(image)) 30 | 31 | 32 | show = iio.show 33 | -------------------------------------------------------------------------------- /imlib/dtype.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | 8 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf): 9 | # check type 10 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!' 11 | 12 | # check dtype 13 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes] 14 | assert images.dtype in dtypes, 'dtype of `images` shoud be one of %s!' % dtypes 15 | 16 | # check nan and inf 17 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!' 18 | 19 | # check value 20 | if min_value not in [None, -np.inf]: 21 | l = '[' + str(min_value) 22 | else: 23 | l = '(-inf' 24 | min_value = -np.inf 25 | if max_value not in [None, np.inf]: 26 | r = str(max_value) + ']' 27 | else: 28 | r = 'inf)' 29 | max_value = np.inf 30 | assert np.min(images) >= min_value - 1e-5 and np.max(images) <= max_value + 1e-5, \ 31 | '`images` should be in the range of %s!' % (l + ',' + r) 32 | 33 | 34 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None): 35 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.""" 36 | _check(images, [np.float32, np.float64], -1.0, 1.0) 37 | dtype = dtype if dtype else images.dtype 38 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype) 39 | 40 | 41 | def float2im(images): 42 | """Transform images from [0, 1.0] to [-1.0, 1.0].""" 43 | _check(images, [np.float32, np.float64], 0.0, 1.0) 44 | return images * 2 - 1.0 45 | 46 | 47 | def float2uint(images): 48 | """Transform images from [0, 1.0] to uint8.""" 49 | _check(images, [np.float32, np.float64], -0.0, 1.0) 50 | return (images * 255).astype(np.uint8) 51 | 52 | 53 | def im2uint(images): 54 | """Transform images from [-1.0, 1.0] to uint8.""" 55 | return to_range(images, 0, 255, np.uint8) 56 | 57 | 58 | def im2float(images): 59 | """Transform images from [-1.0, 1.0] to [0.0, 1.0].""" 60 | return to_range(images, 0.0, 1.0) 61 | 62 | 63 | def uint2im(images): 64 | """Transform images from uint8 to [-1.0, 1.0] of float64.""" 65 | _check(images, np.uint8) 66 | return images / 127.5 - 1.0 67 | 68 | 69 | def uint2float(images): 70 | """Transform images from uint8 to [0.0, 1.0] of float64.""" 71 | _check(images, np.uint8) 72 | return images / 255.0 73 | -------------------------------------------------------------------------------- /imlib/encode.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import io 6 | 7 | from imlib.dtype import im2uint, uint2im 8 | import numpy as np 9 | from PIL import Image 10 | 11 | 12 | def imencode(image, format='PNG', quality=95): 13 | """Encode an [-1.0, 1.0] image into byte string. 14 | 15 | Args: 16 | format : 'PNG' or 'JPEG'. 17 | quality : Only for 'JPEG'. 18 | 19 | Returns: 20 | Byte string. 21 | """ 22 | byte_io = io.BytesIO() 23 | image = Image.fromarray(im2uint(image)) 24 | image.save(byte_io, format=format, quality=quality) 25 | bytes = byte_io.getvalue() 26 | return bytes 27 | 28 | 29 | def imdecode(bytes): 30 | """Decode byte string to float64 image in [-1.0, 1.0]. 31 | 32 | Args: 33 | bytes: Byte string. 34 | 35 | Returns: 36 | A float64 image in [-1.0, 1.0]. 37 | """ 38 | byte_io = io.BytesIO() 39 | byte_io.write(bytes) 40 | image = np.array(Image.open(byte_io)) 41 | image = uint2im(image) 42 | return image 43 | -------------------------------------------------------------------------------- /imlib/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import skimage.color as color 7 | import skimage.transform as transform 8 | 9 | 10 | rgb2gray = color.rgb2gray 11 | gray2rgb = color.gray2rgb 12 | 13 | imresize = transform.resize 14 | imrescale = transform.rescale 15 | 16 | 17 | def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0): 18 | """Merge images to an image with (n_row * h) * (n_col * w). 19 | 20 | `images` is in shape of N * H * W(* C=1 or 3) 21 | """ 22 | n = images.shape[0] 23 | if n_row: 24 | n_row = max(min(n_row, n), 1) 25 | n_col = int(n - 0.5) // n_row + 1 26 | elif n_col: 27 | n_col = max(min(n_col, n), 1) 28 | n_row = int(n - 0.5) // n_col + 1 29 | else: 30 | n_row = int(n ** 0.5) 31 | n_col = int(n - 0.5) // n_row + 1 32 | 33 | h, w = images.shape[1], images.shape[2] 34 | shape = (h * n_row + padding * (n_row - 1), 35 | w * n_col + padding * (n_col - 1)) 36 | if images.ndim == 4: 37 | shape += (images.shape[3],) 38 | img = np.full(shape, pad_value, dtype=images.dtype) 39 | 40 | for idx, image in enumerate(images): 41 | i = idx % n_col 42 | j = idx // n_col 43 | img[j * (h + padding):j * (h + padding) + h, 44 | i * (w + padding):i * (w + padding) + w, ...] = image 45 | 46 | return img 47 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.slim as slim 9 | import tflib as tl 10 | 11 | conv = partial(slim.conv2d, activation_fn=None) 12 | dconv = partial(slim.conv2d_transpose, activation_fn=None) 13 | fc = partial(tl.flatten_fully_connected, activation_fn=None) 14 | relu = tf.nn.relu 15 | lrelu = tf.nn.leaky_relu 16 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None) 17 | 18 | 19 | def mlp_mnist(): 20 | 21 | def Enc(img, z_dim, dim=512, is_training=True): 22 | fc_relu = partial(fc, activation_fn=relu) 23 | 24 | with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE): 25 | y = fc_relu(img, dim) 26 | y = fc_relu(y, dim * 2) 27 | z_mu = fc(y, z_dim) 28 | z_log_sigma_sq = fc(y, z_dim) 29 | return z_mu, z_log_sigma_sq 30 | 31 | def Dec(z, dim=512, channels=1, is_training=True): 32 | fc_relu = partial(fc, activation_fn=relu) 33 | 34 | with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE): 35 | y = fc_relu(z, dim * 2) 36 | y = fc_relu(y, dim) 37 | y = tf.tanh(fc(y, 28 * 28 * channels)) 38 | img = tf.reshape(y, [-1, 28, 28, channels]) 39 | return img 40 | 41 | return Enc, Dec 42 | 43 | 44 | def conv_mnist(): 45 | 46 | def Enc(img, z_dim, dim=64, is_training=True): 47 | bn = partial(batch_norm, is_training=is_training) 48 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu) 49 | 50 | with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE): 51 | y = conv_bn_lrelu(img, dim, 5, 2) 52 | y = conv_bn_lrelu(y, dim * 2, 5, 2) 53 | z_mu = fc(y, z_dim) 54 | z_log_sigma_sq = fc(y, z_dim) 55 | return z_mu, z_log_sigma_sq 56 | 57 | def Dec(z, dim=64, channels=1, is_training=True): 58 | bn = partial(batch_norm, is_training=is_training) 59 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu) 60 | 61 | with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE): 62 | y = relu(fc(z, 7 * 7 * dim * 2)) 63 | y = tf.reshape(y, [-1, 7, 7, dim * 2]) 64 | y = dconv_bn_relu(y, dim * 1, 5, 2) 65 | img = tf.tanh(dconv(y, channels, 5, 2)) 66 | return img 67 | 68 | return Enc, Dec 69 | 70 | 71 | def conv_64(): 72 | 73 | def Enc(img, z_dim, dim=64, is_training=True): 74 | bn = partial(batch_norm, is_training=is_training) 75 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu) 76 | 77 | with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE): 78 | y = conv_bn_lrelu(img, dim, 5, 2) 79 | y = conv_bn_lrelu(y, dim * 2, 5, 2) 80 | y = conv_bn_lrelu(y, dim * 4, 5, 2) 81 | y = conv_bn_lrelu(y, dim * 8, 5, 2) 82 | z_mu = fc(y, z_dim) 83 | z_log_sigma_sq = fc(y, z_dim) 84 | return z_mu, z_log_sigma_sq 85 | 86 | def Dec(z, dim=64, channels=3, is_training=True): 87 | bn = partial(batch_norm, is_training=is_training) 88 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu) 89 | 90 | with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE): 91 | y = relu(fc(z, 4 * 4 * dim * 8)) 92 | y = tf.reshape(y, [-1, 4, 4, dim * 8]) 93 | y = dconv_bn_relu(y, dim * 4, 5, 2) 94 | y = dconv_bn_relu(y, dim * 2, 5, 2) 95 | y = dconv_bn_relu(y, dim * 1, 5, 2) 96 | img = tf.tanh(dconv(y, channels, 5, 2)) 97 | return img 98 | 99 | return Enc, Dec 100 | -------------------------------------------------------------------------------- /pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg -------------------------------------------------------------------------------- /pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg -------------------------------------------------------------------------------- /pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg -------------------------------------------------------------------------------- /pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg -------------------------------------------------------------------------------- /pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg -------------------------------------------------------------------------------- /pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg -------------------------------------------------------------------------------- /pylib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pylib.path import * 6 | from pylib.timer import * 7 | -------------------------------------------------------------------------------- /pylib/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import fnmatch 6 | import os 7 | import sys 8 | 9 | 10 | def add_path(paths): 11 | if not isinstance(paths, (list, tuple)): 12 | paths = [paths] 13 | for path in paths: 14 | if path not in sys.path: 15 | sys.path.insert(0, path) 16 | 17 | 18 | def mkdir(paths): 19 | if not isinstance(paths, (list, tuple)): 20 | paths = [paths] 21 | for path in paths: 22 | if not os.path.isdir(path): 23 | os.makedirs(path) 24 | 25 | 26 | def split(path): 27 | dir, name_ext = os.path.split(path) 28 | name, ext = os.path.splitext(name_ext) 29 | return dir, name, ext 30 | 31 | 32 | def directory(path): 33 | return split(path)[0] 34 | 35 | 36 | def name(path): 37 | return split(path)[1] 38 | 39 | 40 | def ext(path): 41 | return split(path)[2] 42 | 43 | 44 | def name_ext(path): 45 | return ''.join(split(path)[1:]) 46 | 47 | 48 | asbpath = os.path.abspath 49 | 50 | 51 | join = os.path.join 52 | 53 | 54 | def match(dir, pat, recursive=False): 55 | if recursive: 56 | iterator = os.walk(dir) 57 | else: 58 | iterator = [next(os.walk(dir))] 59 | matches = [] 60 | for root, _, file_names in iterator: 61 | for file_name in fnmatch.filter(file_names, pat): 62 | matches.append(os.path.join(root, file_name)) 63 | return matches 64 | -------------------------------------------------------------------------------- /pylib/timer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import datetime 6 | import timeit 7 | 8 | 9 | class Timer(object): 10 | """A timer as a context manager. 11 | 12 | Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py. 13 | 14 | Wraps around a timer. A custom timer can be passed 15 | to the constructor. The default timer is timeit.default_timer. 16 | 17 | Note that the latter measures wall clock time, not CPU time! 18 | On Unix systems, it corresponds to time.time. 19 | On Windows systems, it corresponds to time.clock. 20 | 21 | Arguments: 22 | is_output : If True, print output after exiting context. 23 | format : 'ms', 's' or 'datetime' 24 | """ 25 | 26 | def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'): 27 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!" 28 | self._timer = timer 29 | self._is_output = is_output 30 | self._fmt = fmt 31 | 32 | def __enter__(self): 33 | """Start the timer in the context manager scope.""" 34 | self.start() 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_value, exc_traceback): 38 | """Set the end time.""" 39 | if self._is_output: 40 | print(str(self)) 41 | 42 | def __str__(self): 43 | if self._fmt != 'datetime': 44 | return '%s %s' % (self.elapsed, self._fmt) 45 | else: 46 | return str(self.elapsed) 47 | 48 | def start(self): 49 | self.start_time = self._timer() 50 | 51 | @property 52 | def elapsed(self): 53 | """Return the current elapsed time since start.""" 54 | e = self._timer() - self.start_time 55 | 56 | if self._fmt == 'ms': 57 | return e * 1000 58 | elif self._fmt == 's': 59 | return e 60 | elif self._fmt == 'datetime': 61 | return datetime.timedelta(seconds=e) 62 | 63 | 64 | def timer(**timer_kwargs): 65 | """Function decorator displaying the function execution time. 66 | 67 | All kwargs are the arguments taken by the Timer class constructor. 68 | """ 69 | # store Timer kwargs in local variable so the namespace isn't polluted 70 | # by different level args and kwargs 71 | 72 | def wrapped_f(f): 73 | def wrapped(*args, **kwargs): 74 | fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]' 75 | with Timer(**timer_kwargs) as t: 76 | out = f(*args, **kwargs) 77 | context = {'function_name': f.__name__, 'execution_time': str(t)} 78 | print(fmt % context) 79 | return out 80 | return wrapped 81 | 82 | return wrapped_f 83 | 84 | if __name__ == "__main__": 85 | import time 86 | 87 | # 1 88 | print(1) 89 | with Timer() as t: 90 | time.sleep(1) 91 | print(t) 92 | time.sleep(1) 93 | 94 | with Timer(fmt='datetime') as t: 95 | time.sleep(1) 96 | 97 | # 2 98 | print(2) 99 | t = Timer(fmt='ms') 100 | t.start() 101 | time.sleep(2) 102 | print(t) 103 | 104 | t = Timer(fmt='datetime') 105 | t.start() 106 | time.sleep(1) 107 | print(t) 108 | 109 | # 3 110 | print(3) 111 | 112 | @timer(fmt='ms') 113 | def blah(): 114 | time.sleep(2) 115 | 116 | blah() 117 | -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.checkpoint import * 6 | from tflib.data import * 7 | from tflib.ops import * 8 | from tflib.utils import * 9 | from tflib.variable import * 10 | from tflib.vision import * 11 | -------------------------------------------------------------------------------- /tflib/checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def load_checkpoint(ckpt_dir_or_file, session, var_list=None): 11 | """Load checkpoint. 12 | 13 | This function add some useless ops to the graph. It is better 14 | to use tf.train.init_from_checkpoint(...). 15 | """ 16 | if os.path.isdir(ckpt_dir_or_file): 17 | ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) 18 | 19 | restorer = tf.train.Saver(var_list) 20 | restorer.restore(session, ckpt_dir_or_file) 21 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file) 22 | 23 | 24 | def init_from_checkpoint(ckpt_dir_or_file, assignment_map={'/': '/'}): 25 | # Use the checkpoint values for the variables' initializers. Note that this 26 | # function just changes the initializers but does not actually run them, and 27 | # you should still run the initializers manually. 28 | tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map) 29 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file) 30 | -------------------------------------------------------------------------------- /tflib/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.data.dataset import * 6 | from tflib.data.disk_image import * 7 | from tflib.data.memory_data import * 8 | from tflib.data.tfrecord import * 9 | from tflib.data.tfrecord_creator import * 10 | -------------------------------------------------------------------------------- /tflib/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.eager as tfe 9 | from tflib.utils import session 10 | 11 | 12 | _N_CPU = multiprocessing.cpu_count() 13 | 14 | 15 | def batch_dataset(dataset, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 16 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 17 | if filter: 18 | dataset = dataset.filter(filter) 19 | 20 | if map_func: 21 | dataset = dataset.map(map_func, num_parallel_calls=num_threads) 22 | 23 | if shuffle: 24 | dataset = dataset.shuffle(buffer_size) 25 | 26 | if drop_remainder: 27 | dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) 28 | else: 29 | dataset = dataset.batch(batch_size) 30 | 31 | dataset = dataset.repeat(repeat).prefetch(prefetch_batch) 32 | 33 | return dataset 34 | 35 | 36 | class Dataset(object): 37 | 38 | def __init__(self): 39 | self._dataset = None 40 | self._iterator = None 41 | self._batch_op = None 42 | self._sess = None 43 | 44 | self._is_eager = tf.executing_eagerly() 45 | self._eager_iterator = None 46 | 47 | def __del__(self): 48 | if self._sess: 49 | self._sess.close() 50 | 51 | def __iter__(self): 52 | return self 53 | 54 | def __next__(self): 55 | try: 56 | b = self.get_next() 57 | except: 58 | raise StopIteration 59 | else: 60 | return b 61 | 62 | next = __next__ 63 | 64 | def get_next(self): 65 | if self._is_eager: 66 | return self._eager_iterator.get_next() 67 | else: 68 | return self._sess.run(self._batch_op) 69 | 70 | def reset(self, feed_dict={}): 71 | if self._is_eager: 72 | self._eager_iterator = tfe.Iterator(self._dataset) 73 | else: 74 | self._sess.run(self._iterator.initializer, feed_dict=feed_dict) 75 | 76 | def _bulid(self, dataset, sess=None): 77 | self._dataset = dataset 78 | 79 | if self._is_eager: 80 | self._eager_iterator = tfe.Iterator(dataset) 81 | else: 82 | self._iterator = dataset.make_initializable_iterator() 83 | self._batch_op = self._iterator.get_next() 84 | if sess: 85 | self._sess = sess 86 | else: 87 | self._sess = session() 88 | 89 | try: 90 | self.reset() 91 | except: 92 | pass 93 | 94 | @property 95 | def dataset(self): 96 | return self._dataset 97 | 98 | @property 99 | def iterator(self): 100 | return self._iterator 101 | 102 | @property 103 | def batch_op(self): 104 | return self._batch_op 105 | -------------------------------------------------------------------------------- /tflib/data/disk_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | from tflib.data.dataset import batch_dataset, Dataset 9 | 10 | 11 | _N_CPU = multiprocessing.cpu_count() 12 | 13 | 14 | def disk_image_batch_dataset(img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 15 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 16 | """Disk image batch dataset. 17 | 18 | This function is suitable for jpg and png files 19 | 20 | Arguments: 21 | img_paths : String list or 1-D tensor, each of which is an iamge path 22 | labels : Label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label 23 | """ 24 | if labels is None: 25 | dataset = tf.data.Dataset.from_tensor_slices(img_paths) 26 | elif isinstance(labels, tuple): 27 | dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels)) 28 | else: 29 | dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels)) 30 | 31 | def parse_func(path, *label): 32 | img = tf.read_file(path) 33 | img = tf.image.decode_png(img, 3) 34 | return (img,) + label 35 | 36 | if map_func: 37 | def map_func_(*args): 38 | return map_func(*parse_func(*args)) 39 | else: 40 | map_func_ = parse_func 41 | 42 | # dataset = dataset.map(parse_func, num_parallel_calls=num_threads) is slower 43 | 44 | dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, 45 | map_func_, num_threads, shuffle, buffer_size, repeat) 46 | 47 | return dataset 48 | 49 | 50 | class DiskImageData(Dataset): 51 | """DiskImageData. 52 | 53 | This class is suitable for jpg and png files 54 | 55 | Arguments: 56 | img_paths : String list or 1-D tensor, each of which is an iamge path 57 | labels : Label list or tensor, each of which is a corresponding label 58 | """ 59 | 60 | def __init__(self, img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 61 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None): 62 | super(DiskImageData, self).__init__() 63 | dataset = disk_image_batch_dataset(img_paths, batch_size, labels, prefetch_batch, drop_remainder, filter, 64 | map_func, num_threads, shuffle, buffer_size, repeat) 65 | self._bulid(dataset, sess) 66 | self._n_data = len(img_paths) 67 | 68 | def __len__(self): 69 | return self._n_data 70 | 71 | 72 | if __name__ == '__main__': 73 | import glob 74 | 75 | import imlib as im 76 | import numpy as np 77 | import pylib 78 | 79 | paths = glob.glob('/home/hezhenliang/Resource/face/CelebA/origin/origin/processed_by_hezhenliang/align_celeba/img_align_celeba/*.jpg') 80 | paths = sorted(paths)[182637:] 81 | labels = range(len(paths)) 82 | 83 | def filter(x, y, *args): 84 | return tf.cond(y > 1, lambda: tf.constant(True), lambda: tf.constant(False)) 85 | 86 | def map_func(x, *args): 87 | x = tf.image.resize_images(x, [256, 256]) 88 | x = tf.to_float((x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)) * 2 - 1) 89 | return (x,) + args 90 | 91 | # tf.enable_eager_execution() 92 | 93 | s = tf.Session() 94 | 95 | data = DiskImageData(paths, 128, (labels, labels), filter=None, map_func=map_func, shuffle=False, sess=s) 96 | 97 | for _ in range(1000): 98 | with pylib.Timer(): 99 | for i in range(100): 100 | b = data.get_next() 101 | print(b[1][0]) 102 | print(b[2][0]) 103 | im.imshow(np.array(b[0][0])) 104 | im.show() 105 | # data.reset() 106 | -------------------------------------------------------------------------------- /tflib/data/memory_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | from tflib.data.dataset import batch_dataset, Dataset 10 | 11 | 12 | _N_CPU = multiprocessing.cpu_count() 13 | 14 | 15 | def memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 16 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 17 | """Memory data batch dataset. 18 | 19 | `memory_data_dict` example: 20 | {'img': img_ndarray, 'label': label_ndarray} or 21 | {'img': img_tftensor, 'label': label_tftensor} 22 | * The value of each item of `memory_data_dict` is in shape of (N, ...). 23 | """ 24 | dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict) 25 | dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, 26 | map_func, num_threads, shuffle, buffer_size, repeat) 27 | return dataset 28 | 29 | 30 | class MemoryData(Dataset): 31 | """MemoryData. 32 | 33 | `memory_data_dict` example: 34 | {'img': img_ndarray, 'label': label_ndarray} or 35 | {'img': img_tftensor, 'label': label_tftensor} 36 | * The value of each item of `memory_data_dict` is in shape of (N, ...). 37 | """ 38 | 39 | def __init__(self, memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 40 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None): 41 | super(MemoryData, self).__init__() 42 | dataset = memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch, drop_remainder, filter, 43 | map_func, num_threads, shuffle, buffer_size, repeat) 44 | self._bulid(dataset, sess) 45 | first_value = next(iter(memory_data_dict.values())) 46 | if isinstance(first_value, np.ndarray): 47 | self._n_data = len(first_value) 48 | else: 49 | self._n_data = first_value.get_shape().as_list()[0] 50 | 51 | def __len__(self): 52 | return self._n_data 53 | 54 | if __name__ == '__main__': 55 | import numpy as np 56 | data = {'a': np.array([1.0, 2, 3, 4, 5]), 57 | 'b': np.array([[1, 2], 58 | [2, 3], 59 | [3, 4], 60 | [4, 5], 61 | [5, 6]])} 62 | 63 | def filter(x): 64 | return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False)) 65 | 66 | def map_func(x): 67 | x['a'] = x['a'] * 10 68 | return x 69 | 70 | # tf.enable_eager_execution() 71 | 72 | s = tf.Session() 73 | 74 | dataset = MemoryData(data, 2, filter=None, map_func=map_func, 75 | shuffle=True, buffer_size=4096, drop_remainder=True, repeat=4, sess=s) 76 | 77 | for i in range(5): 78 | print(map(dataset.get_next().__getitem__, ['b', 'a'])) 79 | 80 | print([n.name for n in tf.get_default_graph().as_graph_def().node]) 81 | -------------------------------------------------------------------------------- /tflib/data/tfrecord.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import glob 7 | import json 8 | import multiprocessing 9 | import os 10 | 11 | import tensorflow as tf 12 | from tflib.data.dataset import batch_dataset, Dataset 13 | 14 | 15 | _N_CPU = multiprocessing.cpu_count() 16 | 17 | _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 18 | 19 | _DECODERS = { 20 | 'png': {'decoder': tf.image.decode_png, 'decode_param': dict()}, 21 | 'jpg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()}, 22 | 'jpeg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()}, 23 | 'uint8': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.uint8)}, 24 | 'int64': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.int64)}, 25 | 'float32': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.float32)}, 26 | } 27 | 28 | 29 | def tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, 30 | filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1): 31 | """Tfrecord batch dataset. 32 | 33 | `infos` example: 34 | [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]}, 35 | {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}] 36 | """ 37 | dataset = tf.data.TFRecordDataset(tfrecord_files, 38 | compression_type=compression_type, 39 | buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES) 40 | 41 | features = {} 42 | for info in infos: 43 | features[info['name']] = tf.FixedLenFeature([], tf.string) 44 | 45 | def parse_func(serialized_example): 46 | example = tf.parse_single_example(serialized_example, features=features) 47 | 48 | feature_dict = {} 49 | for info in infos: 50 | name = info['name'] 51 | decoder = info['decoder'] 52 | decode_param = info['decode_param'] 53 | shape = info['shape'] 54 | 55 | feature = decoder(example[name], **decode_param) 56 | feature = tf.reshape(feature, shape) 57 | feature_dict[name] = feature 58 | 59 | return feature_dict 60 | 61 | dataset = dataset.map(parse_func, num_parallel_calls=num_threads) 62 | 63 | dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, 64 | map_func, num_threads, shuffle, buffer_size, repeat) 65 | 66 | return dataset 67 | 68 | 69 | class TfrecordData(Dataset): 70 | 71 | def __init__(self, tfrecord_path, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, 72 | filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None): 73 | super(TfrecordData, self).__init__() 74 | 75 | info_file = os.path.join(tfrecord_path, 'info.json') 76 | infos, self._data_num, compression_type = self._parse_json(info_file) 77 | 78 | self._shapes = {info['name']: tuple(info['shape']) for info in infos} 79 | 80 | tfrecord_files = sorted(glob.glob(os.path.join(tfrecord_path, '*.tfrecord'))) 81 | dataset = tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch, drop_remainder, 82 | filter, map_func, num_threads, shuffle, buffer_size, repeat) 83 | 84 | self._bulid(dataset, sess) 85 | 86 | def __len__(self): 87 | return self._data_num 88 | 89 | @property 90 | def shape(self): 91 | return self._shapes 92 | 93 | @staticmethod 94 | def _parse_old(json_file): 95 | with open(json_file.replace('info.json', 'info.txt')) as f: 96 | try: # older version 1 97 | infos = json.load(f) 98 | for info in infos[0:-1]: 99 | info['decoder'] = _DECODERS[info['dtype_or_format']]['decoder'] 100 | info['decode_param'] = _DECODERS[info['dtype_or_format']]['decode_param'] 101 | except: # older version 2 102 | f.seek(0) 103 | infos = '' 104 | for line in f.readlines(): 105 | infos += line.strip('\n') 106 | infos = eval(infos) 107 | 108 | data_num = infos[-1]['data_num'] 109 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[infos[-1]['compression_type']] 110 | infos[-1:] = [] 111 | 112 | return infos, data_num, compression_type 113 | 114 | @staticmethod 115 | def _parse_json(json_file): 116 | try: 117 | with open(json_file) as f: 118 | info = json.load(f) 119 | infos = info['item'] 120 | for i in infos: 121 | i['decoder'] = _DECODERS[i['dtype_or_format']]['decoder'] 122 | i['decode_param'] = _DECODERS[i['dtype_or_format']]['decode_param'] 123 | data_num = info['info']['data_num'] 124 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[info['info']['compression_type']] 125 | except: # for older version 126 | infos, data_num, compression_type = TfrecordData._parse_old(json_file) 127 | 128 | return infos, data_num, compression_type 129 | -------------------------------------------------------------------------------- /tflib/data/tfrecord_creator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import io 6 | import json 7 | import os 8 | import shutil 9 | 10 | import numpy as np 11 | from PIL import Image 12 | import tensorflow as tf 13 | from tflib.data import tfrecord 14 | 15 | __metaclass__ = type 16 | 17 | 18 | _ALLOWED_TYPES = tfrecord._DECODERS.keys() 19 | 20 | 21 | class BytesTfrecordCreator(object): 22 | """BytesTfrecordCreator. 23 | 24 | `infos` example: 25 | infos = [ 26 | ['img', 'jpg', (64, 64, 3)], 27 | ['id', 'int64', ()], 28 | ['attr', 'int64', (40,)], 29 | ['point', 'float32', (5, 2)] 30 | ] 31 | 32 | `compression_type`: 33 | 0 : NONE 34 | 1 : ZLIB 35 | 2 : GZIP 36 | """ 37 | 38 | def __init__(self, save_path, infos, size_each=None, compression_type=0, overwrite_existence=False): 39 | # overwrite existence 40 | if os.path.exists(save_path): 41 | if not overwrite_existence: 42 | raise Exception('%s exists!' % save_path) 43 | else: 44 | shutil.rmtree(save_path) 45 | os.makedirs(save_path) 46 | else: 47 | os.makedirs(save_path) 48 | 49 | self._save_path = save_path 50 | 51 | # add info 52 | self._infos = [] 53 | self._info_names = [] 54 | for info in infos: 55 | self._add_info(*info) 56 | 57 | self._data_num = 0 58 | self._tfrecord_num = 0 59 | self._size_each = [size_each, 2147483647][not size_each] 60 | self._writer = None 61 | 62 | self._compression_type = compression_type 63 | self._options = tf.python_io.TFRecordOptions(compression_type) 64 | 65 | def __del__(self): 66 | info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}} 67 | info_str = json.dumps(info, indent=4, separators=(',', ':')) 68 | 69 | with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f: 70 | info_f.write(info_str) 71 | 72 | if self._writer: 73 | self._writer.close() 74 | 75 | def add(self, feature_bytes_dict): 76 | """Add example. 77 | 78 | `feature_bytes_dict` example: 79 | feature_bytes_dict = { 80 | 'img' : img_bytes, 81 | 'id' : id_bytes, 82 | 'attr' : attr_bytes, 83 | 'point' : point_bytes 84 | } 85 | """ 86 | assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \ 87 | 'Feature names are inconsistent with the givens!' 88 | 89 | self._new_tfrecord_check() 90 | 91 | self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString()) 92 | self._data_num += 1 93 | 94 | def _new_tfrecord_check(self): 95 | if self._data_num // self._size_each == self._tfrecord_num: 96 | self._tfrecord_num += 1 97 | 98 | if self._writer: 99 | self._writer.close() 100 | 101 | tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1)) 102 | self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options) 103 | 104 | def _add_info(self, name, dtype_or_format, shape): 105 | assert name not in self._info_names, 'Info name "%s" is duplicated!' % name 106 | assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES) 107 | self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape)) 108 | self._info_names.append(name) 109 | 110 | @staticmethod 111 | def _bytes_feature(values): 112 | """Return a TF-Feature of bytes. 113 | 114 | Arguments: 115 | values : A byte string or list of byte strings. 116 | 117 | Returns: 118 | A TF-Feature. 119 | """ 120 | if not isinstance(values, (tuple, list)): 121 | values = [values] 122 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) 123 | 124 | @staticmethod 125 | def _bytes_tfexample(bytes_dict): 126 | """Convert bytes to tfexample. 127 | 128 | `bytes_dict` example: 129 | bytes_dict = { 130 | 'img' : img_bytes, 131 | 'id' : id_bytes, 132 | 'attr' : attr_bytes, 133 | 'point' : point_bytes 134 | } 135 | """ 136 | feature_dict = {} 137 | for key, value in bytes_dict.items(): 138 | feature_dict[key] = BytesTfrecordCreator._bytes_feature(value) 139 | return tf.train.Example(features=tf.train.Features(feature=feature_dict)) 140 | 141 | 142 | class DataLablePairTfrecordCreator(BytesTfrecordCreator): 143 | """DataLablePairTfrecordCreator. 144 | 145 | If `data_shape` is None, then the `data` to be added should be a 146 | numpy array, and the shape and dtype of `data` will be inferred. 147 | If `data_shape` is not None, `data` should be given as byte string, 148 | and `data_dtype_or_format` should also be given. 149 | 150 | `compression_type`: 151 | 0 : NONE 152 | 1 : ZLIB 153 | 2 : GZIP 154 | """ 155 | 156 | def __init__(self, save_path, data_shape=None, data_dtype_or_format=None, data_name='data', 157 | size_each=None, compression_type=0, overwrite_existence=False): 158 | super(DataLablePairTfrecordCreator, self).__init__(save_path, [], size_each, compression_type, overwrite_existence) 159 | 160 | if data_shape: 161 | assert data_dtype_or_format, '`data_dtype_or_format` should be given when `data_shape` is given!' 162 | self._is_data_bytes = True 163 | else: 164 | self._is_data_bytes = False 165 | 166 | self._data_shape = data_shape 167 | self._data_dtype_or_format = data_dtype_or_format 168 | self._data_name = data_name 169 | self._label_shape_dict = {} 170 | self._label_dtype_dict = {} 171 | 172 | self._info_built = False 173 | 174 | def add(self, data, label_dict): 175 | """Add example. 176 | 177 | `label_dict` example: 178 | label_dict = { 179 | 'id' : id_ndarray, 180 | 'attr' : attr_ndarray, 181 | 'point' : point_ndarray 182 | } 183 | """ 184 | self._check_and_build(data, label_dict) 185 | 186 | if not self._is_data_bytes: 187 | data = data.tobytes() 188 | 189 | feature_dict = {self._data_name: data} 190 | for name, label in label_dict.items(): 191 | feature_dict[name] = label.tobytes() 192 | 193 | super(DataLablePairTfrecordCreator, self).add(feature_dict) 194 | 195 | def _check_and_build(self, data, label_dict): 196 | # check type 197 | if self._is_data_bytes: 198 | assert isinstance(data, (str, bytes)), '`data` should be a byte string!' 199 | else: 200 | assert isinstance(data, np.ndarray), '`data` should be a numpy array!' 201 | for label in label_dict.values(): 202 | assert isinstance(label, np.ndarray), 'labels should be numpy arrays!' 203 | 204 | # check shape and dtype or bulid info at first adding 205 | if self._info_built: 206 | if not self._is_data_bytes: 207 | assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!' 208 | assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!' 209 | for name, label in label_dict.items(): 210 | assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name 211 | assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' % name 212 | else: 213 | if not self._is_data_bytes: 214 | self._data_shape = data.shape 215 | self._data_dtype_or_format = data.dtype.name 216 | self._add_info(self._data_name, self._data_dtype_or_format, self._data_shape) 217 | 218 | for name, label in label_dict.items(): 219 | self._label_shape_dict[name] = label.shape 220 | self._label_dtype_dict[name] = label.dtype.name 221 | self._add_info(name, label.dtype.name, label.shape) 222 | 223 | self._info_built = True 224 | 225 | 226 | class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator): 227 | """ImageLablePairTfrecordCreator. 228 | 229 | Arguments: 230 | encode_type : One of [None, 'png', 'jpg']. 231 | quality : For 'jpg'. 232 | compression_type : 233 | 0 : NONE 234 | 1 : ZLIB 235 | 2 : GZIP 236 | """ 237 | 238 | def __init__(self, save_path, encode_type='png', quality=95, data_name='img', 239 | size_each=None, compression_type=0, overwrite_existence=False): 240 | super(ImageLablePairTfrecordCreator, self).__init__( 241 | save_path, None, None, data_name, size_each, compression_type, overwrite_existence) 242 | 243 | assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!" 244 | 245 | self._encode_type = encode_type 246 | self._quality = quality 247 | 248 | self._data_shape = None 249 | self._data_dtype_or_format = None 250 | self._is_data_bytes = True 251 | 252 | def add(self, image, label_dict): 253 | """Add example. 254 | 255 | `image`: An H * W (* C) uint8 numpy array. 256 | 257 | `label_dict` example: 258 | label_dict = { 259 | 'id' : id_ndarray, 260 | 'attr' : attr_ndarray, 261 | 'point' : point_ndarray 262 | } 263 | """ 264 | self._check(image) 265 | image_bytes = self._encode(image) 266 | super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict) 267 | 268 | def _check(self, image): 269 | if not self._data_shape: 270 | assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \ 271 | '`image` should be an H * W (* C) uint8 numpy array!' 272 | if self._encode_type and image.ndim == 3 and image.shape[-1] != 3: 273 | raise Exception('Only images with 1 or 3 channels are allowed to be encoded!') 274 | 275 | if image.ndim == 2: 276 | self._data_shape = image.shape + (1,) 277 | else: 278 | self._data_shape = image.shape 279 | self._data_dtype_or_format = [self._encode_type, 'uint8'][not self._encode_type] 280 | else: 281 | sp = image.shape 282 | if image.ndim == 2: 283 | sp = sp + (1,) 284 | assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!' 285 | assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!' 286 | 287 | def _encode(self, image): 288 | if self._encode_type: 289 | if image.shape[-1] == 1: 290 | image.shape = image.shape[:2] 291 | byte = io.BytesIO() 292 | image = Image.fromarray(image) 293 | if self._encode_type == 'jpg': 294 | image.save(byte, 'JPEG', quality=self._quality) 295 | elif self._encode_type == 'png': 296 | image.save(byte, 'PNG') 297 | image_bytes = byte.getvalue() 298 | else: 299 | image_bytes = image.tobytes() 300 | return image_bytes 301 | -------------------------------------------------------------------------------- /tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.ops.layers import * 6 | -------------------------------------------------------------------------------- /tflib/ops/layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | 8 | 9 | def flatten_fully_connected(inputs, 10 | num_outputs, 11 | activation_fn=tf.nn.relu, 12 | normalizer_fn=None, 13 | normalizer_params=None, 14 | weights_initializer=slim.xavier_initializer(), 15 | weights_regularizer=None, 16 | biases_initializer=tf.zeros_initializer(), 17 | biases_regularizer=None, 18 | reuse=None, 19 | variables_collections=None, 20 | outputs_collections=None, 21 | trainable=True, 22 | scope=None): 23 | with tf.variable_scope(scope, 'flatten_fully_connected', [inputs]): 24 | if inputs.shape.ndims > 2: 25 | inputs = slim.flatten(inputs) 26 | return slim.fully_connected(inputs, 27 | num_outputs, 28 | activation_fn, 29 | normalizer_fn, 30 | normalizer_params, 31 | weights_initializer, 32 | weights_regularizer, 33 | biases_initializer, 34 | biases_regularizer, 35 | reuse, 36 | variables_collections, 37 | outputs_collections, 38 | trainable, 39 | scope) 40 | 41 | flatten_dense = flatten_fully_connected 42 | -------------------------------------------------------------------------------- /tflib/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import re 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def session(graph=None, allow_soft_placement=True, 11 | log_device_placement=False, allow_growth=True): 12 | """Return a Session with simple config.""" 13 | config = tf.ConfigProto(allow_soft_placement=allow_soft_placement, 14 | log_device_placement=log_device_placement) 15 | config.gpu_options.allow_growth = allow_growth 16 | return tf.Session(graph=graph, config=config) 17 | 18 | 19 | def print_tensor(tensors): 20 | if not isinstance(tensors, (list, tuple)): 21 | tensors = [tensors] 22 | 23 | for i, tensor in enumerate(tensors): 24 | ctype = str(type(tensor)) 25 | if 'Tensor' in ctype: 26 | type_name = 'Tensor' 27 | elif 'Variable' in ctype: 28 | type_name = 'Variable' 29 | else: 30 | raise Exception('Not a Tensor or Variable!') 31 | 32 | print(str(i) + (': %s("%s", shape=%s, dtype=%s, device=%s)' 33 | % (type_name, tensor.name, str(tensor.get_shape()), 34 | tensor.dtype.name, tensor.device))) 35 | 36 | prt = print_tensor 37 | 38 | 39 | def shape(tensor): 40 | sp = tensor.get_shape().as_list() 41 | return [num if num is not None else -1 for num in sp] 42 | 43 | 44 | def summary(tensor_collection, 45 | summary_type=['mean', 'stddev', 'max', 'min', 'sparsity', 'histogram'], 46 | scope=None): 47 | """Summary. 48 | 49 | Usage: 50 | 1. summary(tensor) 51 | 2. summary([tensor_a, tensor_b]) 52 | 3. summary({tensor_a: 'a', tensor_b: 'b}) 53 | """ 54 | def _summary(tensor, name, summary_type): 55 | """Attach a lot of summaries to a Tensor.""" 56 | if name is None: 57 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 58 | # session. This helps the clarity of presentation on tensorboard. 59 | name = re.sub('%s_[0-9]*/' % 'tower', '', tensor.name) 60 | name = re.sub(':', '-', name) 61 | 62 | summaries = [] 63 | if len(tensor.shape) == 0: 64 | summaries.append(tf.summary.scalar(name, tensor)) 65 | else: 66 | if 'mean' in summary_type: 67 | mean = tf.reduce_mean(tensor) 68 | summaries.append(tf.summary.scalar(name + '/mean', mean)) 69 | if 'stddev' in summary_type: 70 | mean = tf.reduce_mean(tensor) 71 | stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean))) 72 | summaries.append(tf.summary.scalar(name + '/stddev', stddev)) 73 | if 'max' in summary_type: 74 | summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor))) 75 | if 'min' in summary_type: 76 | summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor))) 77 | if 'sparsity' in summary_type: 78 | summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor))) 79 | if 'histogram' in summary_type: 80 | summaries.append(tf.summary.histogram(name, tensor)) 81 | return tf.summary.merge(summaries) 82 | 83 | if not isinstance(tensor_collection, (list, tuple, dict)): 84 | tensor_collection = [tensor_collection] 85 | 86 | with tf.name_scope(scope, 'summary'): 87 | summaries = [] 88 | if isinstance(tensor_collection, (list, tuple)): 89 | for tensor in tensor_collection: 90 | summaries.append(_summary(tensor, None, summary_type)) 91 | else: 92 | for tensor, name in tensor_collection.items(): 93 | summaries.append(_summary(tensor, name, summary_type)) 94 | return tf.summary.merge(summaries) 95 | 96 | 97 | def counter(start=0, scope=None): 98 | with tf.variable_scope(scope, 'counter'): 99 | counter = tf.get_variable(name='counter', 100 | initializer=tf.constant_initializer(start), 101 | shape=(), 102 | dtype=tf.int64) 103 | update_cnt = tf.assign(counter, tf.add(counter, 1)) 104 | return counter, update_cnt 105 | -------------------------------------------------------------------------------- /tflib/variable.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def tensors_filter(tensors, filters, combine_type='or'): 9 | assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!' 10 | assert isinstance(filters, (str, list, tuple)), '`filters` should be a string or a list(tuple) of strings!' 11 | assert combine_type == 'or' or combine_type == 'and', "`combine_type` should be 'or' or 'and'!" 12 | 13 | if isinstance(filters, str): 14 | filters = [filters] 15 | 16 | f_tens = [] 17 | for ten in tensors: 18 | if combine_type == 'or': 19 | for filt in filters: 20 | if filt in ten.name: 21 | f_tens.append(ten) 22 | break 23 | elif combine_type == 'and': 24 | all_pass = True 25 | for filt in filters: 26 | if filt not in ten.name: 27 | all_pass = False 28 | break 29 | if all_pass: 30 | f_tens.append(ten) 31 | return f_tens 32 | 33 | 34 | def global_variables(filters=None, combine_type='or'): 35 | global_vars = tf.global_variables() 36 | if filters is None: 37 | return global_vars 38 | else: 39 | return tensors_filter(global_vars, filters, combine_type) 40 | 41 | 42 | def trainable_variables(filters=None, combine_type='or'): 43 | t_var = tf.trainable_variables() 44 | if filters is None: 45 | return t_var 46 | else: 47 | return tensors_filter(t_var, filters, combine_type) 48 | -------------------------------------------------------------------------------- /tflib/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.vision.dataset import * 6 | -------------------------------------------------------------------------------- /tflib/vision/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.vision.dataset.mnist import * 6 | -------------------------------------------------------------------------------- /tflib/vision/dataset/mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gzip 6 | import multiprocessing 7 | import os 8 | import struct 9 | import subprocess 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from tflib.data.memory_data import MemoryData 14 | 15 | 16 | _N_CPU = multiprocessing.cpu_count() 17 | 18 | 19 | def unzip_gz(file_name): 20 | unzip_name = file_name.replace('.gz', '') 21 | gz_file = gzip.GzipFile(file_name) 22 | open(unzip_name, 'wb+').write(gz_file.read()) 23 | gz_file.close() 24 | 25 | 26 | def mnist_download(download_dir): 27 | url_base = 'http://yann.lecun.com/exdb/mnist/' 28 | file_names = ['train-images-idx3-ubyte.gz', 29 | 'train-labels-idx1-ubyte.gz', 30 | 't10k-images-idx3-ubyte.gz', 31 | 't10k-labels-idx1-ubyte.gz'] 32 | for file_name in file_names: 33 | url = url_base + file_name 34 | save_path = os.path.join(download_dir, file_name) 35 | cmd = ['curl', url, '-o', save_path] 36 | print('Downloading ', file_name) 37 | if not os.path.exists(save_path): 38 | subprocess.call(cmd) 39 | else: 40 | print('%s exists, skip!' % file_name) 41 | 42 | 43 | def mnist_load(data_dir, split='train'): 44 | """Load MNIST dataset, modified from https://gist.github.com/akesling/5358964. 45 | 46 | Returns: 47 | `imgs`, `lbls`, `num`. 48 | 49 | `imgs` : [-1.0, 1.0] float64 images of shape (N * H * W). 50 | `lbls` : Int labels of shape (N,). 51 | `num` : # of datas. 52 | """ 53 | mnist_download(data_dir) 54 | 55 | if split == 'train': 56 | fname_img = os.path.join(data_dir, 'train-images-idx3-ubyte') 57 | fname_lbl = os.path.join(data_dir, 'train-labels-idx1-ubyte') 58 | elif split == 'test': 59 | fname_img = os.path.join(data_dir, 't10k-images-idx3-ubyte') 60 | fname_lbl = os.path.join(data_dir, 't10k-labels-idx1-ubyte') 61 | else: 62 | raise ValueError("split must be 'test' or 'train'") 63 | 64 | if not os.path.exists(fname_img): 65 | unzip_gz(fname_img + '.gz') 66 | if not os.path.exists(fname_lbl): 67 | unzip_gz(fname_lbl + '.gz') 68 | 69 | with open(fname_lbl, 'rb') as flbl: 70 | struct.unpack('>II', flbl.read(8)) 71 | lbls = np.fromfile(flbl, dtype=np.int8) 72 | 73 | with open(fname_img, 'rb') as fimg: 74 | _, _, rows, cols = struct.unpack('>IIII', fimg.read(16)) 75 | imgs = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbls), rows, cols) 76 | imgs = imgs / 127.5 - 1 77 | 78 | return imgs, lbls, len(lbls) 79 | 80 | 81 | class Mnist(MemoryData): 82 | 83 | def __init__(self, data_dir, batch_size, split='train', prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, 84 | map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None): 85 | imgs, lbls, _ = mnist_load(data_dir, split) 86 | imgs.shape = imgs.shape + (1,) 87 | 88 | imgs_pl = tf.placeholder(tf.float32, imgs.shape) 89 | lbls_pl = tf.placeholder(tf.int64, lbls.shape) 90 | 91 | memory_data_dict = {'img': imgs_pl, 'lbl': lbls_pl} 92 | 93 | self.feed_dict = {imgs_pl: imgs, lbls_pl: lbls} 94 | super(Mnist, self).__init__(memory_data_dict, batch_size, prefetch_batch, drop_remainder, filter, 95 | map_func, num_threads, shuffle, buffer_size, repeat, sess) 96 | 97 | def reset(self): 98 | super(Mnist, self).reset(self.feed_dict) 99 | 100 | if __name__ == '__main__': 101 | import imlib as im 102 | from tflib import session 103 | sess = session() 104 | mnist = Mnist('/tmp', 5000, repeat=1, sess=sess) 105 | print(len(mnist)) 106 | for batch in mnist: 107 | print(batch['lbl'][-1]) 108 | im.imshow(batch['img'][-1].squeeze()) 109 | im.show() 110 | sess.close() 111 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import datetime 7 | from functools import partial 8 | import json 9 | import traceback 10 | 11 | import imlib as im 12 | import numpy as np 13 | import pylib 14 | import tensorflow as tf 15 | import tflib as tl 16 | import utils 17 | 18 | 19 | # ============================================================================== 20 | # = param = 21 | # ============================================================================== 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--epoch', dest='epoch', type=int, default=50) 25 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64) 26 | parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate') 27 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=32, help='dimension of latent') 28 | parser.add_argument('--beta', dest='beta', type=float, default=0.1) 29 | parser.add_argument('--dataset', dest='dataset_name', default='mnist', choices=['mnist', 'celeba']) 30 | parser.add_argument('--model', dest='model_name', default='mlp_mnist', choices=['mlp_mnist', 'conv_mnist', 'conv_64']) 31 | parser.add_argument('--experiment_name', dest='experiment_name', default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 32 | 33 | args = parser.parse_args() 34 | 35 | epoch = args.epoch 36 | batch_size = args.batch_size 37 | lr = args.lr 38 | z_dim = args.z_dim 39 | beta = args.beta 40 | 41 | dataset_name = args.dataset_name 42 | model_name = args.model_name 43 | experiment_name = args.experiment_name 44 | 45 | pylib.mkdir('./output/%s' % experiment_name) 46 | with open('./output/%s/setting.txt' % experiment_name, 'w') as f: 47 | f.write(json.dumps(vars(args), indent=4, separators=(',', ':'))) 48 | 49 | # dataset and models 50 | Dataset, img_shape, get_imgs = utils.get_dataset(dataset_name) 51 | dataset = Dataset(batch_size=batch_size) 52 | dataset_val = Dataset(batch_size=100) 53 | Enc, Dec = utils.get_models(model_name) 54 | Enc = partial(Enc, z_dim=z_dim) 55 | Dec = partial(Dec, channels=img_shape[2]) 56 | 57 | 58 | # ============================================================================== 59 | # = graph = 60 | # ============================================================================== 61 | 62 | def enc_dec(img, is_training=True): 63 | # encode 64 | z_mu, z_log_sigma_sq = Enc(img, is_training=is_training) 65 | 66 | # sample 67 | epsilon = tf.random_normal(tf.shape(z_mu)) 68 | if is_training: 69 | z = z_mu + tf.exp(0.5 * z_log_sigma_sq) * epsilon 70 | else: 71 | z = z_mu 72 | 73 | # decode 74 | img_rec = Dec(z, is_training=is_training) 75 | 76 | return z_mu, z_log_sigma_sq, img_rec 77 | 78 | # input 79 | img = tf.placeholder(tf.float32, [None] + img_shape) 80 | z_sample = tf.placeholder(tf.float32, [None, z_dim]) 81 | 82 | # encode & decode 83 | z_mu, z_log_sigma_sq, img_rec = enc_dec(img) 84 | 85 | # loss 86 | rec_loss = tf.losses.mean_squared_error(img, img_rec) 87 | kld_loss = -tf.reduce_mean(0.5 * (1 + z_log_sigma_sq - z_mu**2 - tf.exp(z_log_sigma_sq))) 88 | loss = rec_loss + kld_loss * beta 89 | 90 | # otpim 91 | step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(loss) 92 | 93 | # summary 94 | summary = tl.summary({rec_loss: 'rec_loss', kld_loss: 'kld_loss'}) 95 | 96 | # sample 97 | _, _, img_rec_sample = enc_dec(img, is_training=False) 98 | img_sample = Dec(z_sample, is_training=False) 99 | 100 | 101 | # ============================================================================== 102 | # = train = 103 | # ============================================================================== 104 | 105 | # session 106 | sess = tl.session() 107 | 108 | # saver 109 | saver = tf.train.Saver(max_to_keep=1) 110 | 111 | # summary writer 112 | summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph) 113 | 114 | # initialization 115 | ckpt_dir = './output/%s/checkpoints' % experiment_name 116 | pylib.mkdir(ckpt_dir) 117 | try: 118 | tl.load_checkpoint(ckpt_dir, sess) 119 | except: 120 | sess.run(tf.global_variables_initializer()) 121 | 122 | # train 123 | try: 124 | img_ipt_sample = get_imgs(dataset_val.get_next()) 125 | z_ipt_sample = np.random.normal(size=[100, z_dim]) 126 | 127 | it = -1 128 | for ep in range(epoch): 129 | dataset.reset() 130 | it_per_epoch = it_in_epoch if it != -1 else -1 131 | it_in_epoch = 0 132 | for batch in dataset: 133 | it += 1 134 | it_in_epoch += 1 135 | 136 | # batch data 137 | img_ipt = get_imgs(batch) 138 | 139 | # train D 140 | summary_opt, _ = sess.run([summary, step], feed_dict={img: img_ipt}) 141 | summary_writer.add_summary(summary_opt, it) 142 | 143 | # display 144 | if (it + 1) % 1 == 0: 145 | print("Epoch: (%3d) (%5d/%5d)" % (ep, it_in_epoch, it_per_epoch)) 146 | 147 | # sample 148 | if (it + 1) % 1000 == 0: 149 | save_dir = './output/%s/sample_training' % experiment_name 150 | pylib.mkdir(save_dir) 151 | 152 | img_rec_opt_sample = sess.run(img_rec_sample, feed_dict={img: img_ipt_sample}) 153 | ipt_rec = np.concatenate((img_ipt_sample, img_rec_opt_sample), axis=2).squeeze() 154 | img_opt_sample = sess.run(img_sample, feed_dict={z_sample: z_ipt_sample}).squeeze() 155 | 156 | im.imwrite(im.immerge(ipt_rec, padding=img_shape[0] // 8), '%s/Epoch_(%d)_(%dof%d)_img_rec.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch)) 157 | im.imwrite(im.immerge(img_opt_sample), '%s/Epoch_(%d)_(%dof%d)_img_sample.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch)) 158 | 159 | save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) 160 | print('Model is saved in file: %s' % save_path) 161 | except: 162 | traceback.print_exc() 163 | finally: 164 | sess.close() 165 | -------------------------------------------------------------------------------- /traversal.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from functools import partial 7 | import json 8 | import traceback 9 | 10 | import imlib as im 11 | import numpy as np 12 | import pylib 13 | import tensorflow as tf 14 | import tflib as tl 15 | import utils 16 | 17 | 18 | # ============================================================================== 19 | # = param = 20 | # ============================================================================== 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--experiment_name', dest='experiment_name', help='experiment_name') 24 | args_ = parser.parse_args() 25 | with open('./output/%s/setting.txt' % args_.experiment_name) as f: 26 | args = json.load(f) 27 | 28 | z_dim = args["z_dim"] 29 | 30 | dataset_name = args["dataset_name"] 31 | model_name = args["model_name"] 32 | experiment_name = args_.experiment_name 33 | 34 | # dataset and models 35 | _, img_shape, _ = utils.get_dataset(dataset_name) 36 | _, Dec = utils.get_models(model_name) 37 | Dec = partial(Dec, channels=img_shape[2]) 38 | 39 | 40 | # ============================================================================== 41 | # = graph = 42 | # ============================================================================== 43 | 44 | # input 45 | z_sample = tf.placeholder(tf.float32, [None, z_dim]) 46 | 47 | # sample 48 | img_sample = Dec(z_sample, is_training=False) 49 | 50 | 51 | # ============================================================================== 52 | # = train = 53 | # ============================================================================== 54 | 55 | # session 56 | sess = tl.session() 57 | 58 | # initialization 59 | ckpt_dir = './output/%s/checkpoints' % experiment_name 60 | try: 61 | tl.load_checkpoint(ckpt_dir, sess) 62 | except: 63 | raise Exception(' [*] No checkpoint!') 64 | 65 | # train 66 | try: 67 | z_ipt_sample_ = np.random.normal(size=[10, z_dim]) 68 | for i in range(z_dim): 69 | z_ipt_sample = np.copy(z_ipt_sample_) 70 | img_opt_samples = [] 71 | for v in np.linspace(-3, 3, 10): 72 | z_ipt_sample[:, i] = v 73 | img_opt_samples.append(sess.run(img_sample, feed_dict={z_sample: z_ipt_sample}).squeeze()) 74 | 75 | save_dir = './output/%s/sample_traversal' % experiment_name 76 | pylib.mkdir(save_dir) 77 | im.imwrite(im.immerge(np.concatenate(img_opt_samples, axis=2), 10), '%s/traversal_d%d.jpg' % (save_dir, i)) 78 | except: 79 | traceback.print_exc() 80 | finally: 81 | sess.close() 82 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | import glob as glob 7 | 8 | import models 9 | import pylib 10 | import tensorflow as tf 11 | import tflib as tl 12 | 13 | 14 | def get_dataset(dataset_name): 15 | if dataset_name == 'mnist': 16 | # dataset 17 | pylib.mkdir('./data/mnist') 18 | Dataset = partial(tl.Mnist, data_dir='./data/mnist', repeat=1) 19 | 20 | # shape 21 | img_shape = [28, 28, 1] 22 | 23 | # index func 24 | def get_imgs(batch): 25 | return batch['img'] 26 | 27 | return Dataset, img_shape, get_imgs 28 | 29 | elif dataset_name == 'celeba': 30 | # dataset 31 | def _map_func(img): 32 | crop_size = 108 33 | re_size = 64 34 | img = tf.image.crop_to_bounding_box(img, (218 - crop_size) // 2, (178 - crop_size) // 2, crop_size, crop_size) 35 | img = tf.image.resize_images(img, [re_size, re_size], method=tf.image.ResizeMethod.BICUBIC) 36 | img = tf.clip_by_value(img, 0, 255) / 127.5 - 1 37 | return img 38 | 39 | paths = glob.glob('./data/celeba/img_align_celeba/*.jpg') 40 | Dataset = partial(tl.DiskImageData, img_paths=paths, repeat=1, map_func=_map_func) 41 | 42 | # shape 43 | img_shape = [64, 64, 3] 44 | 45 | # index func 46 | def get_imgs(batch): 47 | return batch 48 | 49 | return Dataset, img_shape, get_imgs 50 | 51 | 52 | def get_models(model_name): 53 | return getattr(models, model_name)() 54 | --------------------------------------------------------------------------------