├── .gitignore
├── LICENSE
├── README.md
├── imlib
│   ├── __init__.py
│   ├── basic.py
│   ├── dtype.py
│   ├── encode.py
│   └── transform.py
├── models.py
├── pics
│   ├── z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg
│   ├── z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg
│   ├── z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg
│   ├── z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg
│   ├── z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg
│   └── z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg
├── pylib
│   ├── __init__.py
│   ├── path.py
│   └── timer.py
├── tflib
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── disk_image.py
│   │   ├── memory_data.py
│   │   ├── tfrecord.py
│   │   └── tfrecord_creator.py
│   ├── ops
│   │   ├── __init__.py
│   │   └── layers.py
│   ├── utils.py
│   ├── variable.py
│   └── vision
│       ├── __init__.py
│       └── dataset
│           ├── __init__.py
│           └── mnist.py
├── train.py
├── traversal.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | /data/
4 | /output/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 hezhenliang
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a
4 | copy of this software and associated documentation files (the "Software"),
5 | to deal in the Software without restriction, including without limitation
6 | the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | and/or sell copies of the Software, and to permit persons to whom the
8 | Software is furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # (beta-)VAE
2 |
3 | Tensorflow implementation of [VAE](http://arxiv.org/abs/1312.6114) and [beta-VAE](https://openreview.net/pdf?id=Sy2fzU9gl)
4 |
5 | ## Exemplar results
6 |
7 | - Celeba
8 |
9 | - ConvNet (z_dim: 100, beta: 0.05)
10 |
11 | Generation | Reconstruction
12 | :---: | :---:
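13 | <img src="./pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg"> | <img src="./pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg">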
14 |
15 | - Mnist
16 |
17 | - ConvNet (z_dim: 10, beta: 0.1)
18 |
19 | Generation | Reconstruction
20 | :---: | :---:
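21 | <img src="./pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg"> | <img src="./pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg">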
22 |
23 | - MLP (z_dim: 10, beta: 0.1)
24 |
25 | Generation | Reconstruction
26 | :---: | :---:
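27 | <img src="./pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg"> | <img src="./pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg">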
28 |
29 | ## Usage
30 |
31 | - Prerequisites
32 | - Tensorflow 1.8
33 | - Python 2.7 or 3.6
34 |
35 |
36 | - Examples of training
37 |
38 | ```console
39 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 10 --beta 0.1 --dataset mnist --model mlp_mnist --experiment_name z10_beta0.1_mnist_mlp
40 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 10 --beta 0.1 --dataset mnist --model conv_mnist --experiment_name z10_beta0.1_mnist_conv
41 | CUDA_VISIBLE_DEVICES=0 python train.py --z_dim 32 --beta 0.1 --dataset celeba --model conv_64 --experiment_name z32_beta0.1_celeba_conv
42 | ```
43 |
44 | ## Datasets
45 |
46 | 1. Celeba should be prepared by yourself in ***./data/celeba/img_align_celeba/*.jpg***
47 | - Download the dataset: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAB06FXaQRUNtjW9ntaoPGvCa?dl=0
48 | - the above link might be inaccessible; alternatives are
49 | - ***img_align_celeba.zip***
50 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg or
51 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
52 | 2. Mnist will be automatically downloaded
--------------------------------------------------------------------------------
/imlib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.basic import *
6 | from imlib.dtype import *
7 | from imlib.encode import *
8 | from imlib.transform import *
9 |
--------------------------------------------------------------------------------
/imlib/basic.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.dtype import im2float
6 | import numpy as np
7 | import skimage.io as iio
8 |
9 |
def imread(path, as_gray=False):
    """Read an image and scale it to [-1.0, 1.0].

    Args:
        path: Image location, forwarded to `skimage.io.imread`.
        as_gray: If True, skimage converts color images to grayscale.

    Returns:
        Float64 image in [-1.0, 1.0] for uint8 and uint16 images and for color
        images read with `as_gray=True`. Any other dtype is returned unchanged.
    """
    image = iio.imread(path, as_gray)
    if image.dtype == np.uint8:
        image = image / 127.5 - 1
    elif image.dtype == np.uint16:
        # 16-bit files (e.g. PNG/TIFF): scale the full uint16 range too.
        image = image / 32767.5 - 1
    elif as_gray and np.issubdtype(image.dtype, np.floating):
        # For color inputs with as_gray=True, skimage returns rgb2gray output:
        # floats in [0.0, 1.0]. Without this branch the result would stay in
        # [0.0, 1.0] instead of the documented [-1.0, 1.0].
        # NOTE(review): assumes float output here comes from that conversion.
        image = image * 2 - 1
    return image
20 |
21 |
def imwrite(image, path):
    """Write an image whose values lie in [-1.0, 1.0] to `path`."""
    rescaled = im2float(image)  # skimage works on [0.0, 1.0] floats
    iio.imsave(path, rescaled)


def imshow(image):
    """Display an image whose values lie in [-1.0, 1.0]."""
    rescaled = im2float(image)
    iio.imshow(rescaled)


# Alias of `skimage.io.show`, typically called after `imshow`.
show = iio.show
33 |
--------------------------------------------------------------------------------
/imlib/dtype.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 |
def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
    """Validate the type, dtype, finiteness and value range of `images`.

    Args:
        images: Array to validate.
        dtypes: A dtype, or a list/tuple of dtypes, that `images.dtype` must match.
        min_value: Inclusive lower bound; `None` or `-np.inf` means unbounded.
        max_value: Inclusive upper bound; `None` or `np.inf` means unbounded.
            Both bounds are checked with a tolerance of 1e-5.

    Raises:
        AssertionError: If any check fails. The error is raised explicitly
            rather than with `assert`, so validation still runs under `python -O`.
    """
    # check type
    if not isinstance(images, np.ndarray):
        raise AssertionError('`images` should be np.ndarray!')

    # check dtype. Always format a list: `'%s' % some_tuple` would unpack the
    # tuple as format arguments and raise TypeError instead of this error.
    dtypes = list(dtypes) if isinstance(dtypes, (list, tuple)) else [dtypes]
    if images.dtype not in dtypes:
        raise AssertionError('dtype of `images` should be one of %s!' % dtypes)

    # check nan and inf
    if not np.all(np.isfinite(images)):
        raise AssertionError('`images` contains NaN or Inf!')

    # check value range
    if min_value is None or min_value == -np.inf:
        min_value, left = -np.inf, '(-inf'
    else:
        left = '[' + str(min_value)
    if max_value is None or max_value == np.inf:
        max_value, right = np.inf, 'inf)'
    else:
        right = str(max_value) + ']'
    # np.min/np.max raise on empty arrays; an empty array is trivially in range.
    if images.size and (np.min(images) < min_value - 1e-5 or np.max(images) > max_value + 1e-5):
        raise AssertionError('`images` should be in the range of %s!' % (left + ',' + right))


def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
    """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.

    `dtype` defaults to the input dtype. Casting to an integer dtype truncates.
    """
    _check(images, [np.float32, np.float64], -1.0, 1.0)
    dtype = dtype if dtype else images.dtype
    return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)


def float2im(images):
    """Transform images from [0, 1.0] to [-1.0, 1.0]."""
    _check(images, [np.float32, np.float64], 0.0, 1.0)
    return images * 2 - 1.0


def float2uint(images):
    """Transform images from [0, 1.0] to uint8 (truncating)."""
    _check(images, [np.float32, np.float64], -0.0, 1.0)
    return (images * 255).astype(np.uint8)


def im2uint(images):
    """Transform images from [-1.0, 1.0] to uint8 (truncating)."""
    return to_range(images, 0, 255, np.uint8)


def im2float(images):
    """Transform images from [-1.0, 1.0] to [0.0, 1.0]."""
    return to_range(images, 0.0, 1.0)


def uint2im(images):
    """Transform images from uint8 to [-1.0, 1.0] of float64."""
    _check(images, np.uint8)
    return images / 127.5 - 1.0


def uint2float(images):
    """Transform images from uint8 to [0.0, 1.0] of float64."""
    _check(images, np.uint8)
    return images / 255.0
73 |
--------------------------------------------------------------------------------
/imlib/encode.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 |
7 | from imlib.dtype import im2uint, uint2im
8 | import numpy as np
9 | from PIL import Image
10 |
11 |
def imencode(image, format='PNG', quality=95):
    """Encode an image with values in [-1.0, 1.0] into a byte string.

    Args:
        image: Image array in [-1.0, 1.0].
        format: 'PNG' or 'JPEG'.
        quality: JPEG quality. It is passed to PIL for every format but only
            matters for 'JPEG'.

    Returns:
        The encoded bytes.
    """
    pil_image = Image.fromarray(im2uint(image))
    buffer = io.BytesIO()
    pil_image.save(buffer, format=format, quality=quality)
    return buffer.getvalue()


def imdecode(bytes):
    """Decode an encoded byte string into a float64 image in [-1.0, 1.0].

    Args:
        bytes: Encoded image bytes, e.g. produced by `imencode`.

    Returns:
        A float64 image in [-1.0, 1.0].
    """
    buffer = io.BytesIO(bytes)  # in-memory stream for PIL to read from
    return uint2im(np.array(Image.open(buffer)))
43 |
--------------------------------------------------------------------------------
/imlib/transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import skimage.color as color
7 | import skimage.transform as transform
8 |
9 |
# Color-space helpers re-exported from skimage.color.
rgb2gray, gray2rgb = color.rgb2gray, color.gray2rgb

# Resizing helpers re-exported from skimage.transform.
imresize, imrescale = transform.resize, transform.rescale
15 |
16 |
def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0):
    """Merge images into one grid image of (n_row * h) x (n_col * w).

    Args:
        images: Array in shape of N * H * W (* C=1 or 3).
        n_row: Number of grid rows; the column count is derived from it.
        n_col: Number of grid columns; used only when `n_row` is not given.
        padding: Pixels of `pad_value` between neighbouring tiles.
        pad_value: Fill value for padding and for unused grid cells.

    If neither `n_row` nor `n_col` is given, floor(sqrt(N)) rows are used.
    Images fill the grid row by row. An empty batch yields one pad-filled cell
    instead of raising ZeroDivisionError.

    Returns:
        Array of shape (h * n_row + padding * (n_row - 1),
        w * n_col + padding * (n_col - 1)[, C]) with the dtype of `images`.
    """
    n = images.shape[0]

    def ceil_div(a, b):
        # ceil(a / b), but at least one cell so the grid is never empty
        return max((a + b - 1) // b, 1)

    if n_row:
        n_row = max(min(n_row, n), 1)
        n_col = ceil_div(n, n_row)
    elif n_col:
        n_col = max(min(n_col, n), 1)
        n_row = ceil_div(n, n_col)
    else:
        n_row = max(int(n ** 0.5), 1)
        n_col = ceil_div(n, n_row)

    h, w = images.shape[1], images.shape[2]
    shape = (h * n_row + padding * (n_row - 1),
             w * n_col + padding * (n_col - 1))
    if images.ndim == 4:
        shape += (images.shape[3],)
    img = np.full(shape, pad_value, dtype=images.dtype)

    for idx, image in enumerate(images):
        i = idx % n_col  # grid column
        j = idx // n_col  # grid row
        img[j * (h + padding):j * (h + padding) + h,
            i * (w + padding):i * (w + padding) + w, ...] = image

    return img
47 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.slim as slim
9 | import tflib as tl
10 |
# Layer shorthands used by the model factories. Activations default to None
# so each call site picks its nonlinearity explicitly.
conv = partial(slim.conv2d, activation_fn=None)
dconv = partial(slim.conv2d_transpose, activation_fn=None)
fc = partial(tl.flatten_fully_connected, activation_fn=None)

# Nonlinearities.
relu, lrelu = tf.nn.relu, tf.nn.leaky_relu

# Batch norm with a learned scale. Per slim's docs, updates_collections=None
# applies the moving-statistics updates in place rather than via UPDATE_OPS.
batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None)
17 |
18 |
def mlp_mnist():
    """Return (Enc, Dec) builders for a fully connected VAE on 28x28 images."""

    def Enc(img, z_dim, dim=512, is_training=True):
        """Return (z_mu, z_log_sigma_sq), each with `z_dim` units.

        `is_training` is accepted for interface parity with the conv models
        and is unused here.
        """
        dense_relu = partial(fc, activation_fn=relu)

        with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE):
            h = dense_relu(img, dim)
            h = dense_relu(h, dim * 2)
            # NOTE: slim names layers in creation order, so keep this order
            # to stay compatible with existing checkpoints.
            z_mu = fc(h, z_dim)
            z_log_sigma_sq = fc(h, z_dim)
            return z_mu, z_log_sigma_sq

    def Dec(z, dim=512, channels=1, is_training=True):
        """Decode `z` into images of shape [-1, 28, 28, channels] in (-1, 1)."""
        dense_relu = partial(fc, activation_fn=relu)

        with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE):
            h = dense_relu(z, dim * 2)
            h = dense_relu(h, dim)
            flat = tf.tanh(fc(h, 28 * 28 * channels))
            img = tf.reshape(flat, [-1, 28, 28, channels])
            return img

    return Enc, Dec
42 |
43 |
def conv_mnist():
    """Return (Enc, Dec) builders for a convolutional VAE on 28x28 images."""

    def Enc(img, z_dim, dim=64, is_training=True):
        """Return (z_mu, z_log_sigma_sq), each with `z_dim` units."""
        norm = partial(batch_norm, is_training=is_training)
        down = partial(conv, normalizer_fn=norm, activation_fn=lrelu)

        with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE):
            # two 5x5 stride-2 convolutions (28 -> 14 -> 7 for 28x28 inputs)
            h = down(img, dim, 5, 2)
            h = down(h, dim * 2, 5, 2)
            z_mu = fc(h, z_dim)
            z_log_sigma_sq = fc(h, z_dim)
            return z_mu, z_log_sigma_sq

    def Dec(z, dim=64, channels=1, is_training=True):
        """Decode `z` into 28x28 images with `channels` channels in (-1, 1)."""
        norm = partial(batch_norm, is_training=is_training)
        up = partial(dconv, normalizer_fn=norm, activation_fn=relu)

        with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE):
            h = fc(z, 7 * 7 * dim * 2)
            h = relu(h)
            h = tf.reshape(h, [-1, 7, 7, dim * 2])
            # two stride-2 transposed convolutions: 7 -> 14 -> 28
            h = up(h, dim * 1, 5, 2)
            img = tf.tanh(dconv(h, channels, 5, 2))
            return img

    return Enc, Dec
69 |
70 |
def conv_64():
    """Return (Enc, Dec) builders for a convolutional VAE on 64x64 images."""

    def Enc(img, z_dim, dim=64, is_training=True):
        """Return (z_mu, z_log_sigma_sq), each with `z_dim` units."""
        norm = partial(batch_norm, is_training=is_training)
        down = partial(conv, normalizer_fn=norm, activation_fn=lrelu)

        with tf.variable_scope('Enc', reuse=tf.AUTO_REUSE):
            # four 5x5 stride-2 convolutions (64 -> 32 -> 16 -> 8 -> 4 for 64x64 inputs)
            h = down(img, dim, 5, 2)
            h = down(h, dim * 2, 5, 2)
            h = down(h, dim * 4, 5, 2)
            h = down(h, dim * 8, 5, 2)
            z_mu = fc(h, z_dim)
            z_log_sigma_sq = fc(h, z_dim)
            return z_mu, z_log_sigma_sq

    def Dec(z, dim=64, channels=3, is_training=True):
        """Decode `z` into 64x64 images with `channels` channels in (-1, 1)."""
        norm = partial(batch_norm, is_training=is_training)
        up = partial(dconv, normalizer_fn=norm, activation_fn=relu)

        with tf.variable_scope('Dec', reuse=tf.AUTO_REUSE):
            h = fc(z, 4 * 4 * dim * 8)
            h = relu(h)
            h = tf.reshape(h, [-1, 4, 4, dim * 8])
            # four stride-2 transposed convolutions: 4 -> 8 -> 16 -> 32 -> 64
            h = up(h, dim * 4, 5, 2)
            h = up(h, dim * 2, 5, 2)
            h = up(h, dim * 1, 5, 2)
            img = tf.tanh(dconv(h, channels, 5, 2))
            return img

    return Enc, Dec
100 |
--------------------------------------------------------------------------------
/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_rec.jpg
--------------------------------------------------------------------------------
/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z100_beta0.05_celeba_conv_Epoch_(49)_(2915of3165)_img_sample.jpg
--------------------------------------------------------------------------------
/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_rec.jpg
--------------------------------------------------------------------------------
/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_conv_Epoch_(49)_(87of937)_img_sample.jpg
--------------------------------------------------------------------------------
/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_rec.jpg
--------------------------------------------------------------------------------
/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/VAE-Tensorflow/c204a9438a0e7e046ad6cc7257942528accbbe60/pics/z10_beta0.1_mnist_mlp_Epoch_(49)_(87of937)_img_sample.jpg
--------------------------------------------------------------------------------
/pylib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from pylib.path import *
6 | from pylib.timer import *
7 |
--------------------------------------------------------------------------------
/pylib/path.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import fnmatch
6 | import os
7 | import sys
8 |
9 |
def add_path(paths):
    """Prepend `paths` (one path or a list/tuple) to sys.path, skipping known ones."""
    candidates = paths if isinstance(paths, (list, tuple)) else [paths]
    for candidate in candidates:
        if candidate in sys.path:
            continue
        sys.path.insert(0, candidate)
16 |
17 |
def mkdir(paths):
    """Create each directory in `paths` (one path or a list/tuple), with parents.

    Existing directories are left alone, including ones created concurrently
    by another process between a check and `os.makedirs`.

    Raises:
        OSError: If a directory cannot be created, e.g. because a regular
            file already exists at that path.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        try:
            os.makedirs(path)
        except OSError:
            # Python 2 has no `exist_ok`; swallow only "already a directory".
            if not os.path.isdir(path):
                raise
24 |
25 |
def split(path):
    """Split `path` into (directory, name, extension), e.g. 'a/b.txt' -> ('a', 'b', '.txt')."""
    head, tail = os.path.split(path)
    stem, suffix = os.path.splitext(tail)
    return head, stem, suffix


def directory(path):
    """Return the directory part of `path`."""
    head, _, _ = split(path)
    return head


def name(path):
    """Return the file name of `path` without its extension."""
    _, stem, _ = split(path)
    return stem


def ext(path):
    """Return the extension of `path` (including the dot)."""
    _, _, suffix = split(path)
    return suffix


def name_ext(path):
    """Return the file name of `path` including its extension."""
    _, stem, suffix = split(path)
    return stem + suffix
46 |
47 |
# Absolute-path helper. `abspath` is the correctly spelled name; the original
# misspelling `asbpath` stays as an alias for backward compatibility.
abspath = os.path.abspath
asbpath = abspath


join = os.path.join
52 |
53 |
def match(dir, pat, recursive=False):
    """Return paths of files under `dir` whose names match `pat`.

    Args:
        dir: Directory to search.
        pat: fnmatch-style pattern, matched against file names (not full paths).
        recursive: If True, also search all subdirectories.

    Returns:
        List of matching paths joined onto their directory. A missing `dir`
        yields [] in both modes.
    """
    walker = os.walk(dir)
    if not recursive:
        # Top level only. `next(..., None)` avoids leaking a bare StopIteration
        # when os.walk yields nothing (e.g. `dir` does not exist).
        top = next(walker, None)
        walker = [top] if top is not None else []
    return [os.path.join(root, file_name)
            for root, _, file_names in walker
            for file_name in fnmatch.filter(file_names, pat)]
64 |
--------------------------------------------------------------------------------
/pylib/timer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import datetime
6 | import timeit
7 |
8 |
class Timer(object):
    """A timer as a context manager.

    Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py.

    Wraps around a timer. A custom timer can be passed to the constructor.
    The default timer is timeit.default_timer, which measures wall clock time,
    not CPU time.

    The elapsed time is frozen when the context exits (or `stop()` is called);
    before that, `elapsed` keeps running.

    Arguments:
        timer     : Zero-argument callable returning the current time in seconds.
        is_output : If True, print output after exiting context.
        fmt       : 'ms', 's' or 'datetime'
    """

    def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'):
        assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!"
        self._timer = timer
        self._is_output = is_output
        self._fmt = fmt
        self._end_time = None

    def __enter__(self):
        """Start the timer in the context manager scope."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Set the end time and optionally print the elapsed time."""
        self.stop()
        if self._is_output:
            print(str(self))

    def __str__(self):
        if self._fmt != 'datetime':
            return '%s %s' % (self.elapsed, self._fmt)
        else:
            return str(self.elapsed)

    def start(self):
        """(Re)start the timer; clears any previous end time."""
        self.start_time = self._timer()
        self._end_time = None

    def stop(self):
        """Freeze `elapsed` at the current time."""
        self._end_time = self._timer()

    @property
    def elapsed(self):
        """Return the elapsed time since start (up to stop, if stopped)."""
        end = self._timer() if self._end_time is None else self._end_time
        e = end - self.start_time

        if self._fmt == 'ms':
            return e * 1000
        elif self._fmt == 's':
            return e
        elif self._fmt == 'datetime':
            return datetime.timedelta(seconds=e)


def timer(**timer_kwargs):
    """Function decorator displaying the function execution time.

    All kwargs are passed to the Timer class constructor. `is_output` defaults
    to False here because the decorator prints its own message. Pass
    `is_output=True` to also get the Timer's raw output.
    """
    import functools

    # The decorator prints its own message, so don't also print Timer's output
    # unless explicitly requested.
    timer_kwargs.setdefault('is_output', False)

    def wrapped_f(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]'
            with Timer(**timer_kwargs) as t:
                out = f(*args, **kwargs)
            context = {'function_name': f.__name__, 'execution_time': str(t)}
            print(fmt % context)
            return out
        return wrapped

    return wrapped_f
83 |
if __name__ == "__main__":
    import time

    # 1) Timer as a context manager
    print(1)
    with Timer() as t:
        time.sleep(1)
    print(t)
    time.sleep(1)

    with Timer(fmt='datetime') as t:
        time.sleep(1)

    # 2) Timer started manually and printed while still running
    print(2)
    for fmt, seconds in (('ms', 2), ('datetime', 1)):
        t = Timer(fmt=fmt)
        t.start()
        time.sleep(seconds)
        print(t)

    # 3) timer as a function decorator
    print(3)

    @timer(fmt='ms')
    def blah():
        time.sleep(2)

    blah()
117 |
--------------------------------------------------------------------------------
/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.checkpoint import *
6 | from tflib.data import *
7 | from tflib.ops import *
8 | from tflib.utils import *
9 | from tflib.variable import *
10 | from tflib.vision import *
11 |
--------------------------------------------------------------------------------
/tflib/checkpoint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 |
7 | import tensorflow as tf
8 |
9 |
def load_checkpoint(ckpt_dir_or_file, session, var_list=None):
    """Load checkpoint.

    Restores `var_list` (all saveable variables if None) into `session`, using
    a `tf.train.Saver`. This function adds some useless ops to the graph; it is
    better to use tf.train.init_from_checkpoint(...).

    Args:
        ckpt_dir_or_file: A checkpoint path, or a directory whose latest
            checkpoint (per `tf.train.latest_checkpoint`) is restored.
        session: The tf.Session to restore into.
        var_list: Forwarded to `tf.train.Saver`.

    Raises:
        ValueError: If `ckpt_dir_or_file` is a directory with no checkpoint.
    """
    if os.path.isdir(ckpt_dir_or_file):
        ckpt_dir = ckpt_dir_or_file
        ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt_dir_or_file is None:
            # latest_checkpoint returns None for an empty directory. Fail here
            # with an explicit message instead of handing None to restore()
            # (and before adding Saver ops to the graph).
            raise ValueError('No checkpoint found in directory %s!' % ckpt_dir)

    restorer = tf.train.Saver(var_list)
    restorer.restore(session, ckpt_dir_or_file)
    print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file)
22 |
23 |
def init_from_checkpoint(ckpt_dir_or_file, assignment_map=None):
    """Use checkpoint values as the initializers of graph variables.

    Note that this function just changes the initializers but does not
    actually run them; you should still run the initializers manually.

    Args:
        ckpt_dir_or_file: Checkpoint directory or path, forwarded to
            `tf.train.init_from_checkpoint`.
        assignment_map: Mapping forwarded to TF. Defaults to {'/': '/'}; a
            fresh dict is built per call instead of sharing a mutable default.
    """
    if assignment_map is None:
        assignment_map = {'/': '/'}
    tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
    print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file)
30 |
--------------------------------------------------------------------------------
/tflib/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.data.dataset import *
6 | from tflib.data.disk_image import *
7 | from tflib.data.memory_data import *
8 | from tflib.data.tfrecord import *
9 | from tflib.data.tfrecord_creator import *
10 |
--------------------------------------------------------------------------------
/tflib/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.eager as tfe
9 | from tflib.utils import session
10 |
11 |
# Number of CPU cores; default map parallelism (and prefetch depth - 1) below.
_N_CPU = multiprocessing.cpu_count()


def batch_dataset(dataset, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                  map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Turn a per-example dataset into a batched, repeated, prefetched one.

    Pipeline: filter -> shuffle -> map -> batch -> repeat -> prefetch.

    Args:
        dataset: Per-example `tf.data.Dataset`.
        batch_size: Examples per batch.
        prefetch_batch: Number of batches to prefetch.
        drop_remainder: If True, drop the final partial batch of each epoch.
        filter: Optional predicate applied to raw examples.
        map_func: Optional per-example transform, run with `num_threads`
            parallel calls.
        num_threads: Parallelism of `map_func`.
        shuffle: Whether to shuffle examples (buffer of `buffer_size`).
        buffer_size: Shuffle buffer size.
        repeat: Repeat count, forwarded to `Dataset.repeat` (-1 repeats
            indefinitely).

    Returns:
        The transformed `tf.data.Dataset`.
    """
    if filter:
        dataset = dataset.filter(filter)

    if shuffle:
        # Shuffle before `map_func`, so the buffer holds raw elements (e.g.
        # file paths) rather than mapped ones (e.g. decoded images).
        dataset = dataset.shuffle(buffer_size)

    if map_func:
        dataset = dataset.map(map_func, num_parallel_calls=num_threads)

    if drop_remainder:
        # contrib op kept for compatibility with the TF 1.8 the README targets
        dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
    else:
        dataset = dataset.batch(batch_size)

    return dataset.repeat(repeat).prefetch(prefetch_batch)
34 |
35 |
class Dataset(object):
    """Iterator-style wrapper around a batched `tf.data.Dataset`.

    In graph mode, batches are fetched by running the iterator's `get_next` op
    in a tf.Session. In eager mode they come from a `tfe.Iterator`.
    Subclasses build the dataset and pass it to `_build` (alias `_bulid`).
    """

    def __init__(self):
        self._dataset = None
        self._iterator = None
        self._batch_op = None
        self._sess = None
        self._own_sess = False  # True only if `_build` created `_sess` itself

        self._is_eager = tf.executing_eagerly()
        self._eager_iterator = None

    def __del__(self):
        # Close only a session this object created: a caller-supplied session
        # may be shared with other datasets or with the training loop.
        if getattr(self, '_own_sess', False) and self._sess:
            self._sess.close()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch; stop iterating when the data is exhausted."""
        try:
            b = self.get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
            # Only end-of-data ends iteration; any other error propagates
            # instead of being silently turned into StopIteration.
            raise StopIteration
        return b

    next = __next__  # Python 2 iterator protocol

    def get_next(self):
        """Fetch one batch (eager tensors in eager mode, session.run values otherwise)."""
        if self._is_eager:
            return self._eager_iterator.get_next()
        return self._sess.run(self._batch_op)

    def reset(self, feed_dict=None):
        """Restart iteration from the beginning of the dataset.

        Args:
            feed_dict: Optional feed for the iterator initializer (graph mode).
        """
        if self._is_eager:
            self._eager_iterator = tfe.Iterator(self._dataset)
        else:
            self._sess.run(self._iterator.initializer, feed_dict=feed_dict)

    def _build(self, dataset, sess=None):
        """Create the iterator for `dataset`, plus a session if `sess` is None."""
        self._dataset = dataset

        if self._is_eager:
            self._eager_iterator = tfe.Iterator(dataset)
        else:
            self._iterator = dataset.make_initializable_iterator()
            self._batch_op = self._iterator.get_next()
            if sess:
                self._sess = sess
                self._own_sess = False
            else:
                self._sess = session()
                self._own_sess = True

        try:
            self.reset()
        except tf.errors.OpError:
            # Best effort: initialization may fail here, presumably when it
            # needs a feed_dict; the caller then calls reset(feed_dict=...).
            pass

    # Backward-compatible alias for the original (misspelled) method name.
    _bulid = _build

    @property
    def dataset(self):
        return self._dataset

    @property
    def iterator(self):
        return self._iterator

    @property
    def batch_op(self):
        return self._batch_op
105 |
--------------------------------------------------------------------------------
/tflib/data/disk_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | from tflib.data.dataset import batch_dataset, Dataset
9 |
10 |
# Number of CPU cores; default map parallelism (and prefetch depth - 1) below.
_N_CPU = multiprocessing.cpu_count()


def disk_image_batch_dataset(img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                             map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Disk image batch dataset.

    This function is suitable for jpg and png files.

    Arguments:
        img_paths : String list or 1-D tensor, each of which is an image path
        labels    : Label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label
        The remaining arguments are forwarded to `batch_dataset`.
    """
    if labels is None:
        slices = img_paths
    elif isinstance(labels, tuple):
        slices = (img_paths,) + tuple(labels)
    else:
        slices = (img_paths, labels)
    dataset = tf.data.Dataset.from_tensor_slices(slices)

    def parse_func(path, *label):
        # Per TF docs, decode_png also decodes JPEG bytes, hence "jpg and png".
        raw = tf.read_file(path)
        img = tf.image.decode_png(raw, 3)
        return (img,) + label

    # Decoding is fused into the user's map_func: a separate
    # `dataset.map(parse_func, num_parallel_calls=num_threads)` is slower.
    map_func_ = (lambda *args: map_func(*parse_func(*args))) if map_func else parse_func

    return batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func_, num_threads, shuffle, buffer_size, repeat)
48 |
49 |
class DiskImageData(Dataset):
    """DiskImageData.

    This class is suitable for jpg and png files.

    Arguments:
        img_paths : String list or 1-D tensor, each of which is an image path
        labels    : Label list or tensor, each of which is a corresponding label
        sess      : tf.Session used in graph mode (None lets the base class create one)
        The remaining arguments are forwarded to `disk_image_batch_dataset`.
    """

    def __init__(self, img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(DiskImageData, self).__init__()
        dataset = disk_image_batch_dataset(img_paths, batch_size, labels, prefetch_batch, drop_remainder, filter,
                                           map_func, num_threads, shuffle, buffer_size, repeat)
        self._bulid(dataset, sess)
        # `img_paths` may be a 1-D tensor (see docstring), which has no usable
        # len(); read its static shape instead, as MemoryData does.
        if hasattr(img_paths, 'get_shape'):
            self._n_data = img_paths.get_shape().as_list()[0]
        else:
            self._n_data = len(img_paths)

    def __len__(self):
        """Number of image paths (not accounting for filtering or dropped batches)."""
        return self._n_data
70 |
71 |
if __name__ == '__main__':
    import glob
    import sys

    import imlib as im
    import numpy as np
    import pylib

    # Image glob pattern: first command-line argument, else the location the
    # README asks CelebA to be placed in.
    pattern = sys.argv[1] if len(sys.argv) > 1 else './data/celeba/img_align_celeba/*.jpg'
    # Index 182637 onward is presumably the standard CelebA test split.
    paths = sorted(glob.glob(pattern))[182637:]
    # A list rather than a Python 3 range object, to be safe for tensor conversion.
    labels = list(range(len(paths)))

    def keep_func(x, y, *args):
        # Example predicate (unused below, filter=None): keep labels > 1.
        return y > 1

    def map_func(x, *args):
        x = tf.image.resize_images(x, [256, 256])
        x = tf.to_float((x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)) * 2 - 1)
        return (x,) + args

    # tf.enable_eager_execution()

    s = tf.Session()

    data = DiskImageData(paths, 128, (labels, labels), filter=None, map_func=map_func, shuffle=False, sess=s)

    for _ in range(1000):
        with pylib.Timer():
            for _ in range(100):
                b = data.get_next()
        print(b[1][0])
        print(b[2][0])
        im.imshow(np.array(b[0][0]))
        im.show()
        # data.reset()
106 |
--------------------------------------------------------------------------------
/tflib/data/memory_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | from tflib.data.dataset import batch_dataset, Dataset
10 |
11 |
# Number of CPU cores; default map parallelism (and prefetch depth - 1) below.
_N_CPU = multiprocessing.cpu_count()


def memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                              map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Memory data batch dataset.

    `memory_data_dict` example:
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}
        * The value of each item of `memory_data_dict` is in shape of (N, ...).

    The remaining arguments are forwarded to `batch_dataset`.
    """
    per_example = tf.data.Dataset.from_tensor_slices(memory_data_dict)
    return batch_dataset(per_example, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func, num_threads, shuffle, buffer_size, repeat)
28 |
29 |
class MemoryData(Dataset):
    """MemoryData.

    `memory_data_dict` example:
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}
        * The value of each item of `memory_data_dict` is in shape of (N, ...).

    `sess` is the tf.Session used in graph mode (None lets the base class create
    one). The remaining arguments are forwarded to `memory_data_batch_dataset`.
    """

    def __init__(self, memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(MemoryData, self).__init__()
        dataset = memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch, drop_remainder, filter,
                                            map_func, num_threads, shuffle, buffer_size, repeat)
        self._bulid(dataset, sess)
        first_value = next(iter(memory_data_dict.values()))
        # Tensors/variables expose a static shape; ndarrays and plain
        # sequences (e.g. lists) support len().
        if hasattr(first_value, 'get_shape'):
            self._n_data = first_value.get_shape().as_list()[0]
        else:
            self._n_data = len(first_value)

    def __len__(self):
        """Number of examples (not accounting for filtering or dropped batches)."""
        return self._n_data
53 |
if __name__ == '__main__':
    data = {'a': np.array([1.0, 2, 3, 4, 5]),
            'b': np.array([[1, 2],
                           [2, 3],
                           [3, 4],
                           [4, 5],
                           [5, 6]])}

    def filter_fn(x):
        # example predicate; pass it as `filter=filter_fn` below to try it
        return x['a'] > 2

    def map_func(x):
        x['a'] = x['a'] * 10
        return x

    s = tf.Session()

    dataset = MemoryData(data, 2, filter=None, map_func=map_func,
                         shuffle=True, buffer_size=4096, drop_remainder=True, repeat=4, sess=s)

    for i in range(5):
        batch = dataset.get_next()
        # a list, not `map(...)`, which prints as `<map object>` on Python 3
        print([batch[key] for key in ['b', 'a']])

    print([n.name for n in tf.get_default_graph().as_graph_def().node])

    s.close()
81 |
--------------------------------------------------------------------------------
/tflib/data/tfrecord.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | import glob
7 | import json
8 | import multiprocessing
9 | import os
10 |
11 | import tensorflow as tf
12 | from tflib.data.dataset import batch_dataset, Dataset
13 |
14 |
# Default parallelism for this module: `num_threads` defaults to the CPU count and
# `prefetch_batch` to one more than that.
_N_CPU = multiprocessing.cpu_count()

# Read-buffer size handed to tf.data.TFRecordDataset (256 KiB).
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024

# Maps the `dtype_or_format` string stored in the dataset metadata to the op that
# decodes one serialized feature plus the extra keyword arguments for that op.
_DECODERS = dict(
    png=dict(decoder=tf.image.decode_png, decode_param=dict()),
    jpg=dict(decoder=tf.image.decode_jpeg, decode_param=dict()),
    jpeg=dict(decoder=tf.image.decode_jpeg, decode_param=dict()),
    uint8=dict(decoder=tf.decode_raw, decode_param=dict(out_type=tf.uint8)),
    int64=dict(decoder=tf.decode_raw, decode_param=dict(out_type=tf.int64)),
    float32=dict(decoder=tf.decode_raw, decode_param=dict(out_type=tf.float32)),
)
27 |
28 |
def tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True,
                           filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Build a batched `tf.data` pipeline over serialized tf.Example records.

    Every feature named in `infos` is read as a single byte string, decoded with
    ``info['decoder'](raw, **info['decode_param'])`` and reshaped to
    ``info['shape']``. The resulting dict dataset is passed to `batch_dataset`
    together with the remaining arguments.

    `infos` example:
        [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]},
         {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}]
    """
    records = tf.data.TFRecordDataset(tfrecord_files,
                                      compression_type=compression_type,
                                      buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    # every feature is stored as one scalar byte string inside the tf.Example
    feature_spec = {info['name']: tf.FixedLenFeature([], tf.string) for info in infos}

    def parse_func(serialized_example):
        example = tf.parse_single_example(serialized_example, features=feature_spec)
        parsed = {}
        for info in infos:
            key, decode, kwargs, target_shape = info['name'], info['decoder'], info['decode_param'], info['shape']
            parsed[key] = tf.reshape(decode(example[key], **kwargs), target_shape)
        return parsed

    records = records.map(parse_func, num_parallel_calls=num_threads)

    return batch_dataset(records, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func, num_threads, shuffle, buffer_size, repeat)
67 |
68 |
class TfrecordData(Dataset):
    """Batched dataset over the ``*.tfrecord`` files in one directory.

    The directory must also hold the feature metadata: ``info.json``, or the
    legacy ``info.txt`` used by older versions of the creator.
    """

    def __init__(self, tfrecord_path, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True,
                 filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(TfrecordData, self).__init__()

        info_file = os.path.join(tfrecord_path, 'info.json')
        infos, self._data_num, compression_type = self._parse_json(info_file)

        # feature name -> per-example shape, exposed through the `shape` property
        self._shapes = {info['name']: tuple(info['shape']) for info in infos}

        # sorted so that the file order does not depend on the filesystem
        tfrecord_files = sorted(glob.glob(os.path.join(tfrecord_path, '*.tfrecord')))
        dataset = tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch, drop_remainder,
                                         filter, map_func, num_threads, shuffle, buffer_size, repeat)

        self._bulid(dataset, sess)

    def __len__(self):
        """Number of examples as recorded in the metadata (`data_num`)."""
        return self._data_num

    @property
    def shape(self):
        """Dict mapping each feature name to its per-example shape tuple."""
        return self._shapes

    @staticmethod
    def _parse_old(json_file):
        """Parse the legacy ``info.txt`` located next to `json_file`.

        In both legacy layouts the file holds a list whose last item carries
        `data_num` and `compression_type`; the preceding items are feature infos.

        Returns:
            (infos, data_num, compression_type)
        """
        # join with the directory rather than `str.replace`, which would also
        # rewrite an 'info.json' appearing in a directory name
        txt_file = os.path.join(os.path.dirname(json_file), 'info.txt')
        with open(txt_file) as f:
            try:  # older version 1: a JSON document
                infos = json.load(f)
                for info in infos[0:-1]:
                    info['decoder'] = _DECODERS[info['dtype_or_format']]['decoder']
                    info['decode_param'] = _DECODERS[info['dtype_or_format']]['decode_param']
            except (ValueError, KeyError, TypeError):  # older version 2: a Python literal
                f.seek(0)
                infos = ''.join(line.strip('\n') for line in f.readlines())
                # SECURITY: `eval` executes arbitrary code from the metadata file;
                # only load TFRecord directories from trusted sources.
                infos = eval(infos)

        data_num = infos[-1]['data_num']
        compression_type = tf.python_io.TFRecordOptions.compression_type_map[infos[-1]['compression_type']]
        infos[-1:] = []

        return infos, data_num, compression_type

    @staticmethod
    def _parse_json(json_file):
        """Parse ``info.json``; fall back to the legacy ``info.txt`` formats.

        Returns:
            (infos, data_num, compression_type) where every info dict gains
            'decoder' and 'decode_param' entries looked up in `_DECODERS`.
        """
        try:
            with open(json_file) as f:
                info = json.load(f)
            infos = info['item']
            for i in infos:
                i['decoder'] = _DECODERS[i['dtype_or_format']]['decoder']
                i['decode_param'] = _DECODERS[i['dtype_or_format']]['decode_param']
            data_num = info['info']['data_num']
            compression_type = tf.python_io.TFRecordOptions.compression_type_map[info['info']['compression_type']]
        except (IOError, OSError, ValueError, KeyError, TypeError):  # missing or old-format info.json
            # narrowed from a bare `except:` so Ctrl-C / SystemExit are not swallowed
            infos, data_num, compression_type = TfrecordData._parse_old(json_file)

        return infos, data_num, compression_type
129 |
--------------------------------------------------------------------------------
/tflib/data/tfrecord_creator.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 | import json
7 | import os
8 | import shutil
9 |
10 | import numpy as np
11 | from PIL import Image
12 | import tensorflow as tf
13 | from tflib.data import tfrecord
14 |
# Python 2: make classes in this module new-style by default.
__metaclass__ = type


# Feature formats the reader (`tfrecord._DECODERS`) can decode. A concrete list
# rather than a live dict view, so the assertion message in `_add_info` prints
# as a plain list instead of `dict_keys([...])`.
_ALLOWED_TYPES = list(tfrecord._DECODERS.keys())
19 |
20 |
class BytesTfrecordCreator(object):
    """BytesTfrecordCreator.

    Writes examples whose features are all byte strings into
    ``data_%06d.tfrecord`` files under `save_path`, starting a new file every
    `size_each` examples, and writes the feature metadata to ``info.json``
    when the creator is destroyed.

    `infos` example:
        infos = [
            ['img', 'jpg', (64, 64, 3)],
            ['id', 'int64', ()],
            ['attr', 'int64', (40,)],
            ['point', 'float32', (5, 2)]
        ]

    `size_each`: examples per tfrecord file; a falsy value means 2147483647,
        i.e. effectively a single file.

    `compression_type`:
        0 : NONE
        1 : ZLIB
        2 : GZIP
    """

    def __init__(self, save_path, infos, size_each=None, compression_type=0, overwrite_existence=False):
        # refuse to clobber an existing directory unless explicitly allowed
        if os.path.exists(save_path):
            if not overwrite_existence:
                raise Exception('%s exists!' % save_path)
            shutil.rmtree(save_path)
        os.makedirs(save_path)

        self._save_path = save_path

        # add info
        self._infos = []
        self._info_names = []
        for info in infos:
            self._add_info(*info)

        self._data_num = 0
        self._tfrecord_num = 0
        self._size_each = size_each or 2147483647
        self._writer = None

        self._compression_type = compression_type
        # assigned last: `__del__` uses its presence to know __init__ completed
        self._options = tf.python_io.TFRecordOptions(compression_type)

    def __del__(self):
        """Close the current writer and write ``info.json``."""
        # __init__ may have raised part-way (e.g. `save_path` already existed or an
        # info was rejected); attributes can then be missing, and no metadata should
        # be written.
        writer = getattr(self, '_writer', None)
        if writer:
            writer.close()
            self._writer = None
        if not hasattr(self, '_options'):
            return

        info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}}
        info_str = json.dumps(info, indent=4, separators=(',', ':'))

        with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f:
            info_f.write(info_str)

    def add(self, feature_bytes_dict):
        """Add example.

        `feature_bytes_dict` must have exactly the registered feature names as keys.

        `feature_bytes_dict` example:
            feature_bytes_dict = {
                'img' : img_bytes,
                'id' : id_bytes,
                'attr' : attr_bytes,
                'point' : point_bytes
            }
        """
        assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \
            'Feature names are inconsistent with the givens!'

        self._new_tfrecord_check()

        self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString())
        self._data_num += 1

    def _new_tfrecord_check(self):
        """Open the next ``data_%06d.tfrecord`` file once the current one holds `size_each` examples."""
        if self._data_num // self._size_each == self._tfrecord_num:
            self._tfrecord_num += 1

            if self._writer:
                self._writer.close()

            tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1))
            self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options)

    def _add_info(self, name, dtype_or_format, shape):
        """Register one feature (unique name, decodable format, per-example shape)."""
        assert name not in self._info_names, 'Info name "%s" is duplicated!' % name
        assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES)
        self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape))
        self._info_names.append(name)

    @staticmethod
    def _bytes_feature(values):
        """Return a TF-Feature of bytes.

        Arguments:
            values : A byte string or list of byte strings.

        Returns:
            A TF-Feature.
        """
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    @staticmethod
    def _bytes_tfexample(bytes_dict):
        """Convert bytes to tfexample.

        `bytes_dict` example:
            bytes_dict = {
                'img' : img_bytes,
                'id' : id_bytes,
                'attr' : attr_bytes,
                'point' : point_bytes
            }
        """
        feature_dict = {}
        for key, value in bytes_dict.items():
            feature_dict[key] = BytesTfrecordCreator._bytes_feature(value)
        return tf.train.Example(features=tf.train.Features(feature=feature_dict))
140 |
141 |
class DataLablePairTfrecordCreator(BytesTfrecordCreator):
    """DataLablePairTfrecordCreator.

    Writes one `data` item plus a dict of numpy-array labels per example.
    The feature infos are not known up front: they are registered on the
    first `add` call and later calls are checked against them.

    If `data_shape` is None, then the `data` to be added should be a
    numpy array, and the shape and dtype of `data` will be inferred.
    If `data_shape` is not None, `data` should be given as byte string,
    and `data_dtype_or_format` should also be given.
    NOTE(review): the code tests `if data_shape:` (truthiness), so an empty
    shape such as `()` is treated like None — confirm that is intended.

    `compression_type`:
        0 : NONE
        1 : ZLIB
        2 : GZIP
    """

    def __init__(self, save_path, data_shape=None, data_dtype_or_format=None, data_name='data',
                 size_each=None, compression_type=0, overwrite_existence=False):
        # no infos yet: they are registered by `_check_and_build` on the first `add`
        super(DataLablePairTfrecordCreator, self).__init__(save_path, [], size_each, compression_type, overwrite_existence)

        # bytes mode (caller passes pre-serialized data) vs ndarray mode (shape/dtype inferred)
        if data_shape:
            assert data_dtype_or_format, '`data_dtype_or_format` should be given when `data_shape` is given!'
            self._is_data_bytes = True
        else:
            self._is_data_bytes = False

        self._data_shape = data_shape
        self._data_dtype_or_format = data_dtype_or_format
        self._data_name = data_name
        # per-label shape / dtype name recorded on the first `add`, checked afterwards
        self._label_shape_dict = {}
        self._label_dtype_dict = {}

        self._info_built = False

    def add(self, data, label_dict):
        """Add example.

        `data`: a byte string in bytes mode, otherwise a numpy array
        (serialized with `tobytes()`).

        `label_dict` example:
            label_dict = {
                'id' : id_ndarray,
                'attr' : attr_ndarray,
                'point' : point_ndarray
            }
        """
        self._check_and_build(data, label_dict)

        if not self._is_data_bytes:
            data = data.tobytes()

        feature_dict = {self._data_name: data}
        for name, label in label_dict.items():
            feature_dict[name] = label.tobytes()

        super(DataLablePairTfrecordCreator, self).add(feature_dict)

    def _check_and_build(self, data, label_dict):
        """Validate one example; on the first call, register the feature infos from it."""
        # check type
        if self._is_data_bytes:
            assert isinstance(data, (str, bytes)), '`data` should be a byte string!'
        else:
            assert isinstance(data, np.ndarray), '`data` should be a numpy array!'
        for label in label_dict.values():
            assert isinstance(label, np.ndarray), 'labels should be numpy arrays!'

        # check shape and dtype or bulid info at first adding
        if self._info_built:
            # byte-string data cannot be shape-checked; only ndarray data is
            if not self._is_data_bytes:
                assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!'
                assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!'
            # a label name not seen on the first add raises KeyError here
            for name, label in label_dict.items():
                assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name
                assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' % name
        else:
            # ndarray mode: the first example fixes the data shape and dtype
            if not self._is_data_bytes:
                self._data_shape = data.shape
                self._data_dtype_or_format = data.dtype.name
            self._add_info(self._data_name, self._data_dtype_or_format, self._data_shape)

            for name, label in label_dict.items():
                self._label_shape_dict[name] = label.shape
                self._label_dtype_dict[name] = label.dtype.name
                self._add_info(name, label.dtype.name, label.shape)

            self._info_built = True
224 |
225 |
class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator):
    """ImageLablePairTfrecordCreator.

    Stores uint8 images (optionally PNG/JPEG-encoded) together with numpy
    labels. Images of shape H x W, H x W x 1 or H x W x 3 can be encoded;
    without encoding, any channel count is stored as raw bytes. 2-D images are
    recorded with shape (H, W, 1).

    Arguments:
        encode_type : One of [None, 'png', 'jpg'].
        quality : For 'jpg'.
        compression_type :
            0 : NONE
            1 : ZLIB
            2 : GZIP
    """

    def __init__(self, save_path, encode_type='png', quality=95, data_name='img',
                 size_each=None, compression_type=0, overwrite_existence=False):
        super(ImageLablePairTfrecordCreator, self).__init__(
            save_path, None, None, data_name, size_each, compression_type, overwrite_existence)

        assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!"

        self._encode_type = encode_type
        self._quality = quality

        # The parent was built in ndarray mode (no data_shape); images are always
        # turned into bytes by `_encode`, so switch to bytes mode. Shape and format
        # are filled in by `_check` on the first image.
        self._data_shape = None
        self._data_dtype_or_format = None
        self._is_data_bytes = True

    def add(self, image, label_dict):
        """Add example.

        `image`: An H * W (* C) uint8 numpy array.

        `label_dict` example:
            label_dict = {
                'id' : id_ndarray,
                'attr' : attr_ndarray,
                'point' : point_ndarray
            }
        """
        self._check(image)
        image_bytes = self._encode(image)
        super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict)

    def _check(self, image):
        """Validate `image`; the first image fixes the stored shape and format."""
        if not self._data_shape:
            assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \
                '`image` should be an H * W (* C) uint8 numpy array!'
            # H x W x 1 is a 1-channel image too (`_encode` squeezes it); the old
            # check `!= 3` wrongly rejected it.
            if self._encode_type and image.ndim == 3 and image.shape[-1] not in (1, 3):
                raise Exception('Only images with 1 or 3 channels are allowed to be encoded!')

            if image.ndim == 2:
                self._data_shape = image.shape + (1,)
            else:
                self._data_shape = image.shape
            self._data_dtype_or_format = self._encode_type or 'uint8'
        else:
            sp = image.shape
            if image.ndim == 2:
                sp = sp + (1,)
            assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!'
            assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!'

    def _encode(self, image):
        """Return `image` as PNG/JPEG bytes, or as raw bytes when not encoding."""
        if not self._encode_type:
            return image.tobytes()

        if image.ndim == 3 and image.shape[-1] == 1:
            # PIL wants a 2-D array for grayscale. Slice instead of assigning
            # `image.shape`, which mutated the caller's array and fails on
            # non-contiguous arrays.
            image = image[:, :, 0]
        byte = io.BytesIO()
        pil_image = Image.fromarray(image)
        if self._encode_type == 'jpg':
            pil_image.save(byte, 'JPEG', quality=self._quality)
        elif self._encode_type == 'png':
            pil_image.save(byte, 'PNG')
        return byte.getvalue()
301 |
--------------------------------------------------------------------------------
/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.ops.layers import *
6 |
--------------------------------------------------------------------------------
/tflib/ops/layers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 |
8 |
def flatten_fully_connected(inputs,
                            num_outputs,
                            activation_fn=tf.nn.relu,
                            normalizer_fn=None,
                            normalizer_params=None,
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=None,
                            biases_initializer=tf.zeros_initializer(),
                            biases_regularizer=None,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            scope=None):
    """Fully connected layer that first flattens inputs of rank > 2.

    A thin wrapper around `slim.fully_connected`; every argument is forwarded
    unchanged. `scope` names the outer variable scope (default
    'flatten_fully_connected') and is also passed on to `slim.fully_connected`.
    """
    with tf.variable_scope(scope, 'flatten_fully_connected', [inputs]):
        flat_inputs = slim.flatten(inputs) if inputs.shape.ndims > 2 else inputs
        return slim.fully_connected(flat_inputs,
                                    num_outputs,
                                    activation_fn=activation_fn,
                                    normalizer_fn=normalizer_fn,
                                    normalizer_params=normalizer_params,
                                    weights_initializer=weights_initializer,
                                    weights_regularizer=weights_regularizer,
                                    biases_initializer=biases_initializer,
                                    biases_regularizer=biases_regularizer,
                                    reuse=reuse,
                                    variables_collections=variables_collections,
                                    outputs_collections=outputs_collections,
                                    trainable=trainable,
                                    scope=scope)


# short alias
flatten_dense = flatten_fully_connected
42 |
--------------------------------------------------------------------------------
/tflib/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import re
6 |
7 | import tensorflow as tf
8 |
9 |
def session(graph=None, allow_soft_placement=True,
            log_device_placement=False, allow_growth=True):
    """Create a `tf.Session` with a simple config.

    Args:
        graph: graph to launch (None = default graph).
        allow_soft_placement: let TF fall back to another device when an op
            cannot run on the requested one.
        log_device_placement: log which device each op runs on.
        allow_growth: allocate GPU memory on demand instead of up front.
    """
    config = tf.ConfigProto()
    config.allow_soft_placement = allow_soft_placement
    config.log_device_placement = log_device_placement
    config.gpu_options.allow_growth = allow_growth
    return tf.Session(graph=graph, config=config)
17 |
18 |
def print_tensor(tensors):
    """Print one summary line per tensor/variable: index, kind, name, shape, dtype, device.

    `tensors` may be a single object or a list/tuple. The kind is taken from the
    object's type name ('Tensor' checked before 'Variable'); anything else raises.
    """
    items = tensors if isinstance(tensors, (list, tuple)) else [tensors]

    for idx, item in enumerate(items):
        type_repr = str(type(item))
        if 'Tensor' in type_repr:
            kind = 'Tensor'
        elif 'Variable' in type_repr:
            kind = 'Variable'
        else:
            raise Exception('Not a Tensor or Variable!')

        line = '%s: %s("%s", shape=%s, dtype=%s, device=%s)' % (
            idx, kind, item.name, item.get_shape(), item.dtype.name, item.device)
        print(line)

prt = print_tensor
37 |
38 |
def shape(tensor):
    """Return the static shape of `tensor` as a list, with unknown dimensions as -1."""
    dims = tensor.get_shape().as_list()
    return [-1 if dim is None else dim for dim in dims]
42 |
43 |
def summary(tensor_collection,
            summary_type=('mean', 'stddev', 'max', 'min', 'sparsity', 'histogram'),
            scope=None):
    """Build merged TensorBoard summaries for one or more tensors.

    Usage:
        1. summary(tensor)
        2. summary([tensor_a, tensor_b])
        3. summary({tensor_a: 'a', tensor_b: 'b'})

    Args:
        tensor_collection: a tensor, a list/tuple of tensors (summary names are
            derived from `tensor.name`), or a dict mapping tensor -> name.
        summary_type: statistics to record for non-scalar tensors; any
            container supporting `in` (a tuple by default, to avoid a shared
            mutable default). Scalar tensors always get one scalar summary.
        scope: optional name scope (defaults to 'summary').

    Returns:
        A merged summary op.
    """
    def _summary(tensor, name, summary_type):
        """Attach the requested summaries to one tensor and merge them."""
        if name is None:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
            # session. This helps the clarity of presentation on tensorboard.
            name = re.sub('%s_[0-9]*/' % 'tower', '', tensor.name)
            name = re.sub(':', '-', name)

        summaries = []
        # `ndims` is None for an unknown rank, where `len(tensor.shape)` raised
        if tensor.shape.ndims == 0:
            summaries.append(tf.summary.scalar(name, tensor))
        else:
            if 'mean' in summary_type or 'stddev' in summary_type:
                mean = tf.reduce_mean(tensor)
            if 'mean' in summary_type:
                summaries.append(tf.summary.scalar(name + '/mean', mean))
            if 'stddev' in summary_type:
                # population standard deviation over all elements
                stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean)))
                summaries.append(tf.summary.scalar(name + '/stddev', stddev))
            if 'max' in summary_type:
                summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor)))
            if 'min' in summary_type:
                summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor)))
            if 'sparsity' in summary_type:
                summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor)))
            if 'histogram' in summary_type:
                summaries.append(tf.summary.histogram(name, tensor))
        return tf.summary.merge(summaries)

    if not isinstance(tensor_collection, (list, tuple, dict)):
        tensor_collection = [tensor_collection]

    with tf.name_scope(scope, 'summary'):
        summaries = []
        if isinstance(tensor_collection, (list, tuple)):
            for tensor in tensor_collection:
                summaries.append(_summary(tensor, None, summary_type))
        else:
            for tensor, name in tensor_collection.items():
                summaries.append(_summary(tensor, name, summary_type))
        return tf.summary.merge(summaries)
95 |
96 |
def counter(start=0, scope=None):
    """Create a scalar int64 counter variable and an op that increments it by 1.

    Returns:
        (counter_variable, increment_op); the variable is created with
        `tf.get_variable` under `scope` (default 'counter') and starts at `start`.
    """
    with tf.variable_scope(scope, 'counter'):
        count_var = tf.get_variable(name='counter',
                                    initializer=tf.constant_initializer(start),
                                    shape=(),
                                    dtype=tf.int64)
        increment = tf.assign(count_var, tf.add(count_var, 1))
    return count_var, increment
105 |
--------------------------------------------------------------------------------
/tflib/variable.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
def tensors_filter(tensors, filters, combine_type='or'):
    """Select the tensors whose `name` contains the given substring(s).

    Args:
        tensors: list or tuple of objects with a `name` attribute.
        filters: one substring, or a list/tuple of substrings.
        combine_type: 'or' keeps a tensor matching any filter; 'and' keeps a
            tensor matching all of them (so an empty filter list keeps all).

    Returns:
        A new list preserving the input order.
    """
    assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!'
    assert isinstance(filters, (str, list, tuple)), '`filters` should be a string or a list(tuple) of strings!'
    assert combine_type == 'or' or combine_type == 'and', "`combine_type` should be 'or' or 'and'!"

    patterns = [filters] if isinstance(filters, str) else filters
    matches = any if combine_type == 'or' else all
    return [ten for ten in tensors if matches(pat in ten.name for pat in patterns)]
32 |
33 |
def global_variables(filters=None, combine_type='or'):
    """Return `tf.global_variables()`, optionally narrowed by `tensors_filter`."""
    all_vars = tf.global_variables()
    return all_vars if filters is None else tensors_filter(all_vars, filters, combine_type)
40 |
41 |
def trainable_variables(filters=None, combine_type='or'):
    """Return `tf.trainable_variables()`, optionally narrowed by `tensors_filter`."""
    train_vars = tf.trainable_variables()
    return train_vars if filters is None else tensors_filter(train_vars, filters, combine_type)
48 |
--------------------------------------------------------------------------------
/tflib/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset.mnist import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gzip
6 | import multiprocessing
7 | import os
8 | import struct
9 | import subprocess
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | from tflib.data.memory_data import MemoryData
14 |
15 |
# Number of logical CPUs; the default `num_threads` of `Mnist`, and (plus one)
# its default `prefetch_batch`.
_N_CPU = multiprocessing.cpu_count()
17 |
18 |
def unzip_gz(file_name):
    """Decompress the gzip file `file_name` next to itself, without the '.gz' suffix.

    e.g. 'dir/train-labels-idx1-ubyte.gz' -> 'dir/train-labels-idx1-ubyte'.

    Raises:
        ValueError: if removing '.gz' would leave the name unchanged, since the
            output would then overwrite the input being read.
    """
    if file_name.endswith('.gz'):
        # strip only the suffix; `replace` would also rewrite '.gz' inside
        # directory names
        unzip_name = file_name[:-len('.gz')]
    else:
        unzip_name = file_name.replace('.gz', '')
    if unzip_name == file_name:
        raise ValueError('%s has no .gz suffix; refusing to overwrite it' % file_name)

    # context managers close both files even if decompression fails
    with gzip.GzipFile(file_name) as gz_file, open(unzip_name, 'wb+') as out_file:
        out_file.write(gz_file.read())
24 |
25 |
def mnist_download(download_dir):
    """Download the four gzipped MNIST files into `download_dir` using `curl`.

    Files that already exist are skipped. The directory is created if needed.

    Raises:
        subprocess.CalledProcessError: if curl fails, including HTTP errors
            (reported through `--fail` instead of saving the error page).
        OSError: if curl cannot be executed.
        In both cases a partially written file is removed, so a later call
        retries it instead of skipping it.
    """
    url_base = 'http://yann.lecun.com/exdb/mnist/'
    file_names = ['train-images-idx3-ubyte.gz',
                  'train-labels-idx1-ubyte.gz',
                  't10k-images-idx3-ubyte.gz',
                  't10k-labels-idx1-ubyte.gz']

    if not os.path.isdir(download_dir):
        os.makedirs(download_dir)

    for file_name in file_names:
        save_path = os.path.join(download_dir, file_name)
        if os.path.exists(save_path):
            print('%s exists, skip!' % file_name)
            continue

        print('Downloading ', file_name)
        cmd = ['curl', '--fail', '--location', url_base + file_name, '-o', save_path]
        try:
            subprocess.check_call(cmd)
        except (subprocess.CalledProcessError, OSError):
            if os.path.exists(save_path):
                os.remove(save_path)
            raise
41 |
42 |
def mnist_load(data_dir, split='train'):
    """Load MNIST dataset, modified from https://gist.github.com/akesling/5358964.

    Downloads and unzips the files into `data_dir` when missing.

    Returns:
        `imgs`, `lbls`, `num`.

        `imgs` : [-1.0, 1.0] float64 images of shape (N * H * W).
        `lbls` : Int labels of shape (N,).
        `num` : # of datas.

    Raises:
        ValueError: for an unknown `split`, or when a file is not a valid or
            complete IDX file (wrong magic number or item count).
    """
    mnist_download(data_dir)

    if split == 'train':
        prefix = 'train'
    elif split == 'test':
        prefix = 't10k'
    else:
        raise ValueError("split must be 'test' or 'train'")
    fname_img = os.path.join(data_dir, '%s-images-idx3-ubyte' % prefix)
    fname_lbl = os.path.join(data_dir, '%s-labels-idx1-ubyte' % prefix)

    for fname in (fname_img, fname_lbl):
        if not os.path.exists(fname):
            unzip_gz(fname + '.gz')

    # IDX label file header: magic 2049, item count
    with open(fname_lbl, 'rb') as flbl:
        magic, num_lbl = struct.unpack('>II', flbl.read(8))
        if magic != 2049:
            raise ValueError('%s is not an MNIST label file (magic %d)' % (fname_lbl, magic))
        lbls = np.fromfile(flbl, dtype=np.int8)
    if len(lbls) != num_lbl:
        raise ValueError('%s is truncated: %d of %d labels' % (fname_lbl, len(lbls), num_lbl))

    # IDX image file header: magic 2051, item count, rows, cols
    with open(fname_img, 'rb') as fimg:
        magic, num_img, rows, cols = struct.unpack('>IIII', fimg.read(16))
        if magic != 2051:
            raise ValueError('%s is not an MNIST image file (magic %d)' % (fname_img, magic))
        if num_img != num_lbl:
            raise ValueError('%s holds %d images but %s holds %d labels'
                             % (fname_img, num_img, fname_lbl, num_lbl))
        imgs = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbls), rows, cols)
    imgs = imgs / 127.5 - 1  # uint8 [0, 255] -> float [-1, 1]

    return imgs, lbls, len(lbls)
79 |
80 |
class Mnist(MemoryData):
    """MNIST split served through `MemoryData`.

    Elements are dicts with 'img' (float32, shape (rows, cols, 1), values in
    [-1, 1] as produced by `mnist_load`) and 'lbl' (int64). The arrays are fed
    through placeholders; `reset` re-feeds them.
    """

    def __init__(self, data_dir, batch_size, split='train', prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        images, labels, _ = mnist_load(data_dir, split)
        images = np.expand_dims(images, -1)  # trailing channel axis

        images_pl = tf.placeholder(tf.float32, images.shape)
        labels_pl = tf.placeholder(tf.int64, labels.shape)

        # Assigned before the base __init__ in case building the dataset already
        # calls `reset` (NOTE(review): confirm in Dataset._bulid).
        self.feed_dict = {images_pl: images, labels_pl: labels}
        super(Mnist, self).__init__({'img': images_pl, 'lbl': labels_pl}, batch_size, prefetch_batch,
                                    drop_remainder, filter, map_func, num_threads, shuffle, buffer_size,
                                    repeat, sess)

    def reset(self):
        """Re-initialize the underlying dataset, feeding the MNIST arrays again."""
        super(Mnist, self).reset(self.feed_dict)
99 |
if __name__ == '__main__':
    import imlib as im
    from tflib import session

    sess = session()
    try:
        mnist = Mnist('/tmp', 5000, repeat=1, sess=sess)
        print(len(mnist))
        for batch in mnist:
            print(batch['lbl'][-1])
            im.imshow(batch['img'][-1].squeeze())
            im.show()
    finally:
        # release the session even if loading or iteration fails
        sess.close()
111 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import datetime
7 | from functools import partial
8 | import json
9 | import traceback
10 |
11 | import imlib as im
12 | import numpy as np
13 | import pylib
14 | import tensorflow as tf
15 | import tflib as tl
16 | import utils
17 |
18 |
19 | # ==============================================================================
20 | # = param =
21 | # ==============================================================================
22 |
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', dest='epoch', type=int, default=50)
parser.add_argument('--batch_size', dest='batch_size', type=int, default=64)
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--z_dim', dest='z_dim', type=int, default=32, help='dimension of latent')
parser.add_argument('--beta', dest='beta', type=float, default=0.1)
parser.add_argument('--dataset', dest='dataset_name', default='mnist', choices=['mnist', 'celeba'])
parser.add_argument('--model', dest='model_name', default='mlp_mnist', choices=['mlp_mnist', 'conv_mnist', 'conv_64'])
parser.add_argument('--experiment_name', dest='experiment_name', default=datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
args = parser.parse_args()

# training hyper-parameters
epoch, batch_size, lr = args.epoch, args.batch_size, args.lr
z_dim, beta = args.z_dim, args.beta

# dataset / model selection and the experiment's output folder name
dataset_name, model_name, experiment_name = args.dataset_name, args.model_name, args.experiment_name

# save the full command-line configuration next to the experiment outputs
output_dir = './output/%s' % experiment_name
pylib.mkdir(output_dir)
with open(output_dir + '/setting.txt', 'w') as f:
    f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))

# dataset and models: `utils.get_dataset` returns the dataset class, the image
# shape (index 2 is used as the channel count) and a batch -> images getter
Dataset, img_shape, get_imgs = utils.get_dataset(dataset_name)
dataset = Dataset(batch_size=batch_size)
dataset_val = Dataset(batch_size=100)  # source of the fixed 100-image sample batch
Enc, Dec = utils.get_models(model_name)
Enc = partial(Enc, z_dim=z_dim)
Dec = partial(Dec, channels=img_shape[2])
56 |
57 |
58 | # ==============================================================================
59 | # = graph =
60 | # ==============================================================================
61 |
def enc_dec(img, is_training=True):
    """Encode `img`, pick a latent code, and decode it.

    In training mode the latent code is sampled with the reparameterization
    z = z_mu + exp(0.5 * z_log_sigma_sq) * eps, eps ~ N(0, I); otherwise the
    mean z_mu is used directly.

    Returns:
        (z_mu, z_log_sigma_sq, img_rec)
    """
    # encode
    z_mu, z_log_sigma_sq = Enc(img, is_training=is_training)

    # sample (the noise op is only created when it is actually used)
    if is_training:
        epsilon = tf.random_normal(tf.shape(z_mu))
        z = z_mu + tf.exp(0.5 * z_log_sigma_sq) * epsilon
    else:
        z = z_mu

    # decode
    img_rec = Dec(z, is_training=is_training)

    return z_mu, z_log_sigma_sq, img_rec
77 |
# inputs: real images for training / reconstruction, latent codes for sampling
img = tf.placeholder(tf.float32, [None] + img_shape)
z_sample = tf.placeholder(tf.float32, [None, z_dim])

# encode & decode (training path, with sampling)
z_mu, z_log_sigma_sq, img_rec = enc_dec(img)

# loss: reconstruction MSE + beta * KL(q(z|x) || N(0, I)); the KL term is
# averaged over both the batch and the latent dimensions
rec_loss = tf.losses.mean_squared_error(img, img_rec)
kld_terms = 0.5 * (1 + z_log_sigma_sq - z_mu**2 - tf.exp(z_log_sigma_sq))
kld_loss = -tf.reduce_mean(kld_terms)
loss = rec_loss + kld_loss * beta

# optim
optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5)
step = optimizer.minimize(loss)

# summary
summary = tl.summary({rec_loss: 'rec_loss', kld_loss: 'kld_loss'})

# sample: deterministic reconstructions and decodings of given latent codes
_, _, img_rec_sample = enc_dec(img, is_training=False)
img_sample = Dec(z_sample, is_training=False)
99 |
100 |
101 | # ==============================================================================
102 | # = train =
103 | # ==============================================================================
104 |
# session
sess = tl.session()

# saver (keeps only the most recent checkpoint)
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# initialization: resume from a checkpoint if one can be loaded, otherwise
# start from freshly initialized variables.
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
    tl.load_checkpoint(ckpt_dir, sess)
except Exception as e:
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit while loading is
    # not mistaken for "no checkpoint". The reason is printed because a silent
    # fallback would restart training from scratch with no indication.
    print(' [*] Could not load checkpoint from %s (%s); initializing from scratch.' % (ckpt_dir, e))
    sess.run(tf.global_variables_initializer())
121 |
# train
try:
    # Fixed evaluation inputs, reused at every sample dump so images from
    # different iterations are directly comparable.
    img_ipt_sample = get_imgs(dataset_val.get_next())
    z_ipt_sample = np.random.normal(size=[100, z_dim])

    display_period = 1  # print progress every `display_period` iterations
    sample_period = 1000  # dump sample images every `sample_period` iterations

    it = -1  # global iteration index across epochs; incremented before each step
    for ep in range(epoch):
        dataset.reset()
        # Length of the previous epoch (display only); -1 until one epoch has run.
        it_per_epoch = -1 if it == -1 else it_in_epoch
        it_in_epoch = 0
        for it_in_epoch, batch in enumerate(dataset, 1):
            it += 1

            # one optimization step on the VAE loss, logging its summaries
            img_ipt = get_imgs(batch)
            summary_opt, _ = sess.run([summary, step], feed_dict={img: img_ipt})
            summary_writer.add_summary(summary_opt, it)

            # display
            if (it + 1) % display_period == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (ep, it_in_epoch, it_per_epoch))

            # sample (only every `sample_period` iterations)
            if (it + 1) % sample_period != 0:
                continue

            save_dir = './output/%s/sample_training' % experiment_name
            pylib.mkdir(save_dir)

            # reconstructions of the fixed batch, placed next to the inputs along axis 2
            img_rec_opt_sample = sess.run(img_rec_sample, feed_dict={img: img_ipt_sample})
            ipt_rec = np.concatenate((img_ipt_sample, img_rec_opt_sample), axis=2).squeeze()
            # decoder outputs for the fixed latent codes
            img_opt_sample = sess.run(img_sample, feed_dict={z_sample: z_ipt_sample}).squeeze()

            im.imwrite(im.immerge(ipt_rec, padding=img_shape[0] // 8), '%s/Epoch_(%d)_(%dof%d)_img_rec.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch))
            im.imwrite(im.immerge(img_opt_sample), '%s/Epoch_(%d)_(%dof%d)_img_sample.jpg' % (save_dir, ep, it_in_epoch, it_per_epoch))

        # checkpoint once per epoch
        save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
        print('Model is saved in file: %s' % save_path)
except:
    traceback.print_exc()
finally:
    sess.close()
165 |
--------------------------------------------------------------------------------
/traversal.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from functools import partial
7 | import json
8 | import traceback
9 |
10 | import imlib as im
11 | import numpy as np
12 | import pylib
13 | import tensorflow as tf
14 | import tflib as tl
15 | import utils
16 |
17 |
18 | # ==============================================================================
19 | # = param =
20 | # ==============================================================================
21 |
parser = argparse.ArgumentParser()
# Required: without it the script would look for './output/None/setting.txt'
# and fail with a confusing file-not-found error.
parser.add_argument('--experiment_name', dest='experiment_name', required=True, help='experiment_name')
args_ = parser.parse_args()
# Reload the settings saved by train.py so the decoder is rebuilt identically.
with open('./output/%s/setting.txt' % args_.experiment_name) as f:
    args = json.load(f)

z_dim = args["z_dim"]

dataset_name = args["dataset_name"]
model_name = args["model_name"]
experiment_name = args_.experiment_name

# dataset and models (only the image shape and the decoder are needed here)
_, img_shape, _ = utils.get_dataset(dataset_name)
_, Dec = utils.get_models(model_name)
Dec = partial(Dec, channels=img_shape[2])
38 |
39 |
40 | # ==============================================================================
41 | # = graph =
42 | # ==============================================================================
43 |
# Latent codes are fed straight into the decoder; traversal never encodes
# images, so only the decoder is built, in inference mode.
z_sample = tf.placeholder(dtype=tf.float32, shape=[None, z_dim])
img_sample = Dec(z_sample, is_training=False)
49 |
50 |
# ==============================================================================
# =                                  traverse                                  =
# ==============================================================================

# session
sess = tl.session()

# initialization: a trained checkpoint is mandatory for traversal
ckpt_dir = './output/%s/checkpoints' % experiment_name
try:
    tl.load_checkpoint(ckpt_dir, sess)
except Exception:
    # Narrowed from a bare `except:` so e.g. Ctrl-C is not reported as a
    # missing checkpoint; release the session before aborting.
    sess.close()
    raise Exception(' [*] No checkpoint!')

# traverse each latent dimension
try:
    # The output directory is the same for every dimension, so create it once.
    save_dir = './output/%s/sample_traversal' % experiment_name
    pylib.mkdir(save_dir)

    # 10 random base codes shared by all dimensions.
    z_ipt_sample_ = np.random.normal(size=[10, z_dim])
    for i in range(z_dim):
        z_ipt_sample = np.copy(z_ipt_sample_)
        img_opt_samples = []
        # Sweep dimension i over 10 evenly spaced values in [-3, 3] while the
        # other dimensions of each base code stay fixed.
        for v in np.linspace(-3, 3, 10):
            z_ipt_sample[:, i] = v
            img_opt_samples.append(sess.run(img_sample, feed_dict={z_sample: z_ipt_sample}).squeeze())

        # Join the sweep along axis 2 (presumably width), so each base code
        # yields one strip, then tile the strips into a single image.
        im.imwrite(im.immerge(np.concatenate(img_opt_samples, axis=2), 10), '%s/traversal_d%d.jpg' % (save_dir, i))
except Exception:
    traceback.print_exc()
finally:
    sess.close()
82 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 | import glob as glob
7 |
8 | import models
9 | import pylib
10 | import tensorflow as tf
11 | import tflib as tl
12 |
13 |
def get_dataset(dataset_name):
    """Return the dataset factory, image shape and batch-to-image accessor.

    Args:
        dataset_name: 'mnist' or 'celeba'.

    Returns:
        (Dataset, img_shape, get_imgs): `Dataset` is a dataset class with its
        data arguments pre-bound via `functools.partial` (callers still pass
        e.g. `batch_size`), `img_shape` is [height, width, channels], and
        `get_imgs(batch)` extracts the images from one batch.

    Raises:
        ValueError: if `dataset_name` is not supported. Previously the function
            fell through and returned None, so callers failed with an opaque
            tuple-unpacking TypeError.
    """
    if dataset_name == 'mnist':
        # dataset
        pylib.mkdir('./data/mnist')
        Dataset = partial(tl.Mnist, data_dir='./data/mnist', repeat=1)

        # shape
        img_shape = [28, 28, 1]

        # index func: images are stored under the 'img' key of each batch
        def get_imgs(batch):
            return batch['img']

        return Dataset, img_shape, get_imgs

    elif dataset_name == 'celeba':
        # dataset
        def _map_func(img):
            # Center-crop a 108x108 region from the 218x178 aligned image,
            # resize it to 64x64, and map [0, 255] pixel values to [-1, 1].
            crop_size = 108
            re_size = 64
            img = tf.image.crop_to_bounding_box(img, (218 - crop_size) // 2, (178 - crop_size) // 2, crop_size, crop_size)
            img = tf.image.resize_images(img, [re_size, re_size], method=tf.image.ResizeMethod.BICUBIC)
            # bicubic interpolation can overshoot the input range, hence the clip
            img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
            return img

        # Sorted so the file order does not depend on the filesystem (glob
        # returns paths in arbitrary order). An empty list is deliberately
        # tolerated: traversal.py only needs img_shape, not the images.
        paths = sorted(glob.glob('./data/celeba/img_align_celeba/*.jpg'))
        Dataset = partial(tl.DiskImageData, img_paths=paths, repeat=1, map_func=_map_func)

        # shape
        img_shape = [64, 64, 3]

        # index func: each batch is used directly as the image batch
        def get_imgs(batch):
            return batch

        return Dataset, img_shape, get_imgs

    raise ValueError("Unsupported dataset_name: %r (expected 'mnist' or 'celeba')" % (dataset_name,))
50 |
51 |
def get_models(model_name):
    """Look up `model_name` in the `models` module and call it.

    Callers unpack the result as `Enc, Dec`, so the named attribute is
    expected to be a factory returning an (encoder, decoder) pair.
    """
    model_factory = getattr(models, model_name)
    return model_factory()
54 |
--------------------------------------------------------------------------------