├── .gitignore
├── LICENSE
├── README.md
├── imlib
├── __init__.py
├── basic.py
├── dtype.py
├── encode.py
└── transform.py
├── models.py
├── models_64x64.py
├── pics
├── GAN_normalG.jpg
├── GAN_trickyG.jpg
├── Jensen-Shannon_normalG.jpg
├── Jensen-Shannon_trickyG.jpg
├── Kullback-Leibler_normalG.jpg
├── Kullback-Leibler_trickyG.jpg
├── Pearson-X2_normalG.jpg
├── Pearson-X2_trickyG.jpg
├── Reverse-KL_normalG.jpg
└── Reverse-KL_trickyG.jpg
├── pylib
├── __init__.py
├── timer.py
└── utils.py
├── tflib
├── __init__.py
├── checkpoint.py
├── data
│ ├── __init__.py
│ ├── dataset.py
│ ├── disk_image.py
│ ├── memory_data.py
│ ├── tfrecord.py
│ └── tfrecord_creator.py
├── ops
│ ├── __init__.py
│ └── layers.py
├── utils.py
├── variable.py
└── vision
│ ├── __init__.py
│ └── dataset
│ ├── __init__.py
│ └── mnist.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | /data/
4 | /output/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 hezhenliang
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a
4 | copy of this software and associated documentation files (the "Software"),
5 | to deal in the Software without restriction, including without limitation
6 | the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | and/or sell copies of the Software, and to permit persons to whom the
8 | Software is furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # f-GAN
2 |
3 | Tensorflow implementation of f-GAN (NIPS 2016) - [f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization](https://arxiv.org/abs/1606.00709).
4 |
5 | ## TODO
6 |
7 | - [ ] make these divergences work (suggestions are welcome)
8 | - [ ] ***Kullback-Leibler*** with tricky G loss
9 | - [ ] ***Reverse-KL*** with tricky G loss
10 | - [x] ***Pearson-X2*** with tricky G loss
11 | - [ ] ***Squared-Hellinger*** with tricky G loss
12 | - [x] ***Jensen-Shannon*** with tricky G loss
13 | - [x] ***GAN*** with tricky G loss
14 | - [ ] test more divergences
15 |
16 | ## Exemplar Results
17 |
18 | - Using tricky G loss (see Section 3.2 in the paper)
19 |
20 | Kullback-Leibler | Reverse-KL | Pearson-X2
21 | :---: | :---: | :---:
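22 | ![](./pics/Kullback-Leibler_trickyG.jpg) | ![](./pics/Reverse-KL_trickyG.jpg) | ![](./pics/Pearson-X2_trickyG.jpg)
23 | **Squared-Hellinger** | **Jensen-Shannon** | **GAN**
24 | NaN | ![](./pics/Jensen-Shannon_trickyG.jpg) | ![](./pics/GAN_trickyG.jpg)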
25 |
26 | - Using theoretically correct G loss
27 |
28 | Kullback-Leibler | Reverse-KL | Pearson-X2
29 | :---: | :---: | :---:
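30 | ![](./pics/Kullback-Leibler_normalG.jpg) | ![](./pics/Reverse-KL_normalG.jpg) | ![](./pics/Pearson-X2_normalG.jpg)
31 | **Squared-Hellinger** | **Jensen-Shannon** | **GAN**
32 | NaN | ![](./pics/Jensen-Shannon_normalG.jpg) | ![](./pics/GAN_normalG.jpg)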
33 |
34 | ## Usage
35 |
36 | - Prerequisites
37 | - tensorflow 1.7 or 1.8
38 | - python 2.7
39 |
40 |
41 | - Examples of training
42 | - training
43 |
44 | ```console
45 | CUDA_VISIBLE_DEVICES=0 python train.py --dataset=mnist --divergence=Pearson-X2 --tricky_G
46 | ```
47 |
48 | - tensorboard for loss visualization
49 |
50 | ```console
51 | CUDA_VISIBLE_DEVICES='' tensorboard --logdir ./output/mnist_Pearson-X2_trickyG/summaries --port 6006
52 | ```
53 |
54 | ## Citation
55 | If you find [f-GAN](https://arxiv.org/abs/1606.00709) useful in your research work, please consider citing:
56 |
57 | @inproceedings{nowozin2016f,
58 | title={f-GAN: Training Generative Neural Samplers Using Variational Divergence Minimization},
59 | author={Nowozin, Sebastian and Cseke, Botond and Tomioka, Ryota},
60 | booktitle={Advances in Neural Information Processing Systems (NIPS)},
61 | year={2016}
62 | }
--------------------------------------------------------------------------------
/imlib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.basic import *
6 | from imlib.dtype import *
7 | from imlib.encode import *
8 | from imlib.transform import *
9 |
--------------------------------------------------------------------------------
/imlib/basic.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.dtype import *
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import scipy.misc
9 |
10 |
def imread(paths, mode='RGB'):
    """Load one image, or a batch of images, scaled into [-1.0, 1.0].

    A single path yields a single array. A list or tuple of paths yields a
    numpy array with the images stacked along a new leading axis, i.e.
    N * H * W (* C).

    Args:
        paths: An image path, or a list/tuple of image paths.
        mode: Conversion mode forwarded to `scipy.misc.imread`, one of
            'L' (8-bit black and white), 'P' (8-bit palette), 'RGB' (3x8-bit),
            'RGBA' (4x8-bit with alpha), 'CMYK' (4x8-bit), 'YCbCr' (3x8-bit),
            'I' (32-bit signed int) or 'F' (32-bit float).

    Returns:
        Float image(s) in [-1.0, 1.0] (float64 for the integer-valued modes).
    """
    def _load(path, mode='RGB'):
        raw = scipy.misc.imread(path, mode=mode)
        # [0, 255] -> [-1.0, 1.0]
        return raw / 127.5 - 1

    if not isinstance(paths, (list, tuple)):
        return _load(paths, mode)
    return np.array([_load(path, mode) for path in paths])
39 |
40 |
def imwrite(image, path):
    """Save an image whose values lie in [-1.0, 1.0] to `path`.

    An H * W * 1 array is saved as a 2-D gray image. Values are mapped to
    uint8 [0, 255] with `to_range` before writing via `scipy.misc.imsave`.
    """
    is_single_channel = image.ndim == 3 and image.shape[2] == 1
    if is_single_channel:  # for gray image
        image = np.array(image[:, :, 0], copy=True)
    pixels = to_range(image, 0, 255, np.uint8)
    return scipy.misc.imsave(path, pixels)
47 |
48 |
def imshow(image):
    """Display an image whose values lie in [-1.0, 1.0] with matplotlib.

    An H * W * 1 array is flattened to H * W first. Values are mapped to
    [0.0, 1.0] with `to_range`. Single-channel images use the 'gray'
    colormap; matplotlib ignores `cmap` for RGB(A) input.
    """
    if image.ndim == 3 and image.shape[2] == 1:  # for gray image
        image = np.array(image[:, :, 0], copy=True)
    # Pass the colormap by name: plt.gray() returns None and, as a side
    # effect, changes the global default colormap for every later plot.
    plt.imshow(to_range(image), cmap='gray')
55 |
56 |
# Re-export of matplotlib's `plt.show`, so figures drawn with `imshow` can be displayed via this module.
57 | show = plt.show
58 |
--------------------------------------------------------------------------------
/imlib/dtype.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 |
def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
    """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.

    Args:
        images: float32/float64 array with values in [-1.0, 1.0]; a 1e-5
            tolerance is allowed at both ends.
        min_value: Lower end of the target range.
        max_value: Upper end of the target range.
        dtype: Output dtype; defaults to `images.dtype`.

    Returns:
        The linearly rescaled array. For integer output dtypes the values are
        rounded to the nearest integer instead of truncated, so an image
        produced by `uint2im` maps back to exactly the same uint8 pixels.

    Raises:
        AssertionError: If `images` is not float32/float64 or out of range.
    """
    assert np.min(images) >= -1.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \
        and (images.dtype == np.float32 or images.dtype == np.float64), \
        'The input images should be float64(32) and in the range of [-1.0, 1.0]!'
    if dtype is None:
        dtype = images.dtype
    scaled = (images + 1.) / 2. * (max_value - min_value) + min_value
    if np.issubdtype(np.dtype(dtype), np.integer):
        # astype() truncates toward zero, so floating-point error such as
        # 2.9999999 would become 2; round to the nearest integer first.
        scaled = np.round(scaled)
    return scaled.astype(dtype)
17 |
18 |
def uint2im(images):
    """Map uint8 image(s) in [0, 255] to float64 values in [-1.0, 1.0]."""
    assert images.dtype == np.uint8, 'The input images type should be uint8!'
    half_range = 127.5
    return images / half_range - 1.0
23 |
24 |
def float2im(images):
    """Rescale float image(s) from [0.0, 1.0] to [-1.0, 1.0].

    Input must be float32 or float64 (the dtype is preserved); values may
    leave [0.0, 1.0] by at most 1e-5.
    """
    assert np.min(images) >= 0.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \
        and (images.dtype == np.float32 or images.dtype == np.float64), \
        'The input images should be float64(32) and in the range of [0.0, 1.0]!'
    doubled = 2 * images
    return doubled - 1.0
31 |
32 |
def im2uint(images):
    """Quantize image(s) from [-1.0, 1.0] to uint8 values in [0, 255]."""
    return to_range(images, min_value=0, max_value=255, dtype=np.uint8)
36 |
37 |
def im2float(images):
    """Rescale image(s) from [-1.0, 1.0] to [0.0, 1.0], keeping the dtype."""
    return to_range(images, min_value=0.0, max_value=1.0)
41 |
42 |
def float2uint(images):
    """Transform images from [0, 1.0] to uint8 values in [0, 255].

    Values are rounded to the nearest integer rather than truncated, so
    `float2uint(uint2float(x))` returns `x` exactly.

    Raises:
        AssertionError: If `images` is not float32/float64 or out of range.
    """
    assert np.min(images) >= 0.0 - 1e-5 and np.max(images) <= 1.0 + 1e-5 \
        and (images.dtype == np.float32 or images.dtype == np.float64), \
        'The input images should be float64(32) and in the range of [0.0, 1.0]!'
    # astype() truncates (e.g. 254.9999 -> 254); round first.
    return np.round(images * 255).astype(np.uint8)
49 |
50 |
def uint2float(images):
    """Map uint8 image(s) in [0, 255] to float64 values in [0.0, 1.0]."""
    assert images.dtype == np.uint8, 'The input images type should be uint8!'
    return np.true_divide(images, 255.0)
55 |
--------------------------------------------------------------------------------
/imlib/encode.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 |
7 | from imlib.dtype import *
8 | from PIL import Image
9 |
10 |
def imencode(image, format='PNG', quality=95):
    """Encode an image with values in [-1.0, 1.0] into a byte string.

    The image is quantized to uint8 with `im2uint` and encoded with PIL.
    An H * W * 1 array is first flattened to H * W (as `imwrite` does),
    because PIL's `Image.fromarray` does not accept a trailing channel axis
    of size 1.

    Args:
        image: Image array with values in [-1.0, 1.0].
        format: 'PNG' or 'JPEG'.
        quality: JPEG quality, forwarded to `Image.save` (presumably ignored
            by the PNG encoder).

    Returns:
        Byte string with the encoded image.
    """
    if image.ndim == 3 and image.shape[2] == 1:  # for gray image
        image = image[:, :, 0]
    with io.BytesIO() as byte_io:
        Image.fromarray(im2uint(image)).save(byte_io, format=format, quality=quality)
        encoded = byte_io.getvalue()
    return encoded
26 |
27 |
def imdecode(bytes):
    """Decode an encoded image byte string into a float64 image in [-1.0, 1.0].

    Args:
        bytes: Encoded image data; PIL detects the format from the content.

    Returns:
        A float64 image in [-1.0, 1.0].

    Raises:
        AssertionError: From `uint2im`, if the decoded pixel array is not uint8.
    """
    decoded = Image.open(io.BytesIO(bytes))
    return uint2im(np.array(decoded))
42 |
--------------------------------------------------------------------------------
/imlib/transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.dtype import *
6 | import numpy as np
7 | import scipy.misc
8 |
9 |
def rgb2gray(images):
    """Convert RGB image(s) to grayscale with 0.299/0.587/0.114 channel weights.

    Accepts H * W * 3 or N * H * W * 3 arrays and drops the channel axis.
    The result is cast back to the input dtype.

    Raises:
        Exception: If `images` is neither 3-D nor 4-D.
        AssertionError: If the last axis is not of size 3.
    """
    if images.ndim not in (3, 4):
        raise Exception('Wrong dimensions!')
    assert images.shape[-1] == 3, 'Channel size should be 3!'

    red, green, blue = images[..., 0], images[..., 1], images[..., 2]
    gray = red * 0.299 + green * 0.587 + blue * 0.114
    return gray.astype(images.dtype)
17 |
18 |
def gray2rgb(images):
    """Replicate grayscale image(s) into three identical channels.

    Accepts H * W or N * H * W arrays and appends a trailing channel axis of
    size 3, keeping the input dtype.
    """
    assert images.ndim == 2 or images.ndim == 3, 'Wrong dimensions!'
    return np.repeat(images[..., np.newaxis], 3, axis=-1)
26 |
27 |
def imresize(image, size, interp='bilinear'):
    """Resize an image whose values lie in [-1.0, 1.0].

    Args:
        size: int (percentage of the current size), float (fraction of the
            current size) or tuple (output size).
        interp: Interpolation for `scipy.misc.imresize`: 'nearest',
            'lanczos', 'bilinear', 'bicubic' or 'cubic'.

    Returns:
        The resized image in [-1.0, 1.0], with the input's dtype.
    """
    # scipy.misc.imresize should be given a uint8 image; other input would be
    # rescaled to [0, 255] by scipy itself. Quantize first, then map back.
    as_uint8 = im2uint(image)
    resized = scipy.misc.imresize(as_uint8, size, interp=interp)
    return (resized / 127.5 - 1).astype(image.dtype)
44 |
45 |
def resize_images(images, size, interp='bilinear'):
    """Resize every image of a [-1.0, 1.0] batch shaped N * H * W (* 3).

    Args:
        size: int (percentage of the current size), float (fraction of the
            current size) or tuple (output size).
        interp: Interpolation: 'nearest', 'lanczos', 'bilinear', 'bicubic'
            or 'cubic'.

    Returns:
        A numpy array stacking the resized images (via `imresize`).
    """
    return np.array([imresize(single, size, interp) for single in images])
63 |
64 |
def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0):
    """Tile a batch of images into a single grid image.

    Args:
        images: Array of shape N * H * W or N * H * W * C (C = 1 or 3).
        n_row: Number of grid rows, clamped to [1, N]; the column count then
            becomes ceil(N / n_row). Takes precedence over `n_col`.
        n_col: Number of grid columns, clamped to [1, N]; used only when
            `n_row` is not given, and the row count becomes ceil(N / n_col).
        padding: Pixels of spacing between neighbouring tiles.
        pad_value: Fill value for the spacing and for unused grid cells.

    Returns:
        Array of shape (n_row * H + padding * (n_row - 1),
        n_col * W + padding * (n_col - 1)[, C]) with the dtype of `images`,
        filled row by row. Without `n_row`/`n_col` the grid has
        floor(sqrt(N)) rows. An empty batch gives a single tile of
        `pad_value`.
    """
    n = images.shape[0]
    if n_row:
        n_row = max(min(n_row, n), 1)
        n_col = int(n - 0.5) // n_row + 1  # ceil(n / n_row) for n >= 1
    elif n_col:
        n_col = max(min(n_col, n), 1)
        n_row = int(n - 0.5) // n_col + 1  # ceil(n / n_col) for n >= 1
    else:
        # Clamp to 1 so an empty batch does not divide by zero below.
        n_row = max(int(n ** 0.5), 1)
        n_col = int(n - 0.5) // n_row + 1

    h, w = images.shape[1], images.shape[2]
    shape = (h * n_row + padding * (n_row - 1),
             w * n_col + padding * (n_col - 1))
    if images.ndim == 4:
        shape += (images.shape[3],)
    img = np.full(shape, pad_value, dtype=images.dtype)

    for idx, image in enumerate(images):
        i = idx % n_col
        j = idx // n_col
        img[j * (h + padding):j * (h + padding) + h,
            i * (w + padding):i * (w + padding) + w, ...] = image

    return img
95 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.slim as slim
9 | import tflib as tl
10 |
# Layer shorthands shared by G and D below. Every layer defaults to no activation so each call site picks its own.
# batch_norm: scale=True learns a gamma; updates_collections=None makes slim update the moving statistics in place
# (per the slim.batch_norm docs) instead of collecting update ops that the train op would have to run.
# fc wraps tl.flatten_fully_connected (defined in tflib, not shown here).
11 | conv = partial(slim.conv2d, activation_fn=None)
12 | dconv = partial(slim.conv2d_transpose, activation_fn=None)
13 | fc = partial(tl.flatten_fully_connected, activation_fn=None)
14 | relu = tf.nn.relu
15 | lrelu = tf.nn.leaky_relu
16 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None)
17 |
18 |
# Generator: z -> FC+BN+ReLU(1024) -> FC+BN+ReLU(7*7*2*dim) -> reshape to 7x7x(2*dim)
# -> stride-2 deconv+BN+ReLU (2*dim) -> stride-2 deconv to 1 channel -> tanh, so pixels lie in [-1, 1].
# Two stride-2 upsamplings of 7x7 give 28x28 output, assuming slim's default 'SAME' padding.
# is_training switches batch norm between batch statistics and moving averages.
# Variables live under scope 'G' with tf.AUTO_REUSE, so repeated calls share them. The layers are not given explicit
# names, so their variable names presumably depend on call order; keep the order stable for checkpoint compatibility.
19 | def G(z, dim=64, is_training=True):
20 | bn = partial(batch_norm, is_training=is_training)
21 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu)
22 | fc_bn_relu = partial(fc, normalizer_fn=bn, activation_fn=relu)
23 |
24 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE):
25 | y = fc_bn_relu(z, 1024)
26 | y = fc_bn_relu(y, 7 * 7 * dim * 2)
27 | y = tf.reshape(y, [-1, 7, 7, dim * 2])
28 | y = dconv_bn_relu(y, dim * 2, 5, 2)
29 | img = tf.tanh(dconv(y, 1, 5, 2))
30 | return img
31 |
32 |
# Discriminator: image -> stride-2 conv + leaky ReLU -> stride-2 conv+BN+leaky ReLU (dim)
# -> FC+BN+leaky ReLU(1024) -> FC(1). Returns the raw logit: the last layer has no activation,
# so any f-divergence output activation is presumably applied by the caller (TODO confirm in train.py).
# NOTE(review): the first conv emits only 1 channel, whereas models_64x64.D uses `dim` there. Confirm this is
# intended; changing it alters variable shapes and breaks existing checkpoints.
# Variables live under scope 'D' with tf.AUTO_REUSE, so real and fake batches share weights across calls.
33 | def D(img, dim=64, is_training=True):
34 | bn = partial(batch_norm, is_training=is_training)
35 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu)
36 | fc_bn_lrelu = partial(fc, normalizer_fn=bn, activation_fn=lrelu)
37 |
38 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE):
39 | y = lrelu(conv(img, 1, 5, 2))
40 | y = conv_bn_lrelu(y, dim, 5, 2)
41 | y = fc_bn_lrelu(y, 1024)
42 | logit = fc(y, 1)
43 | return logit
44 |
--------------------------------------------------------------------------------
/models_64x64.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.slim as slim
9 | import tflib as tl
10 |
# Layer shorthands shared by G and D below (same definitions as models.py). Every layer defaults to no activation.
# batch_norm: scale=True learns a gamma; updates_collections=None makes slim update the moving statistics in place
# (per the slim.batch_norm docs).
11 | conv = partial(slim.conv2d, activation_fn=None)
12 | dconv = partial(slim.conv2d_transpose, activation_fn=None)
13 | fc = partial(tl.flatten_fully_connected, activation_fn=None)
14 | relu = tf.nn.relu
15 | lrelu = tf.nn.leaky_relu
16 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None)
17 |
18 |
# Generator: z -> FC+BN+ReLU to 4x4x(8*dim) -> three stride-2 deconv+BN+ReLU (4*dim, 2*dim, dim)
# -> stride-2 deconv to 3 channels -> tanh, so pixels lie in [-1, 1].
# Four stride-2 upsamplings of 4x4 give 64x64 output, assuming slim's default 'SAME' padding.
# Variables live under scope 'G' with tf.AUTO_REUSE. The layers are not given explicit names, so keep the call
# order stable to stay checkpoint-compatible.
19 | def G(z, dim=64, is_training=True):
20 | bn = partial(batch_norm, is_training=is_training)
21 | dconv_bn_relu = partial(dconv, normalizer_fn=bn, activation_fn=relu)
22 | fc_bn_relu = partial(fc, normalizer_fn=bn, activation_fn=relu)
23 |
24 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE):
25 | y = fc_bn_relu(z, 4 * 4 * dim * 8)
26 | y = tf.reshape(y, [-1, 4, 4, dim * 8])
27 | y = dconv_bn_relu(y, dim * 4, 5, 2)
28 | y = dconv_bn_relu(y, dim * 2, 5, 2)
29 | y = dconv_bn_relu(y, dim * 1, 5, 2)
30 | img = tf.tanh(dconv(y, 3, 5, 2))
31 | return img
32 |
33 |
# Discriminator: stride-2 conv + leaky ReLU (dim), then three stride-2 conv+BN+leaky ReLU (2*dim, 4*dim, 8*dim),
# then FC(1). Returns the raw logit (no final activation).
# Input is presumably 64x64x3, matching this file's G; TODO confirm against the data pipeline.
# Variables live under scope 'D' with tf.AUTO_REUSE, so repeated calls share weights.
34 | def D(img, dim=64, is_training=True):
35 | bn = partial(batch_norm, is_training=is_training)
36 | conv_bn_lrelu = partial(conv, normalizer_fn=bn, activation_fn=lrelu)
37 |
38 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE):
39 | y = lrelu(conv(img, dim, 5, 2))
40 | y = conv_bn_lrelu(y, dim * 2, 5, 2)
41 | y = conv_bn_lrelu(y, dim * 4, 5, 2)
42 | y = conv_bn_lrelu(y, dim * 8, 5, 2)
43 | logit = fc(y, 1)
44 | return logit
45 |
--------------------------------------------------------------------------------
/pics/GAN_normalG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/GAN_normalG.jpg
--------------------------------------------------------------------------------
/pics/GAN_trickyG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/GAN_trickyG.jpg
--------------------------------------------------------------------------------
/pics/Jensen-Shannon_normalG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Jensen-Shannon_normalG.jpg
--------------------------------------------------------------------------------
/pics/Jensen-Shannon_trickyG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Jensen-Shannon_trickyG.jpg
--------------------------------------------------------------------------------
/pics/Kullback-Leibler_normalG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Kullback-Leibler_normalG.jpg
--------------------------------------------------------------------------------
/pics/Kullback-Leibler_trickyG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Kullback-Leibler_trickyG.jpg
--------------------------------------------------------------------------------
/pics/Pearson-X2_normalG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Pearson-X2_normalG.jpg
--------------------------------------------------------------------------------
/pics/Pearson-X2_trickyG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Pearson-X2_trickyG.jpg
--------------------------------------------------------------------------------
/pics/Reverse-KL_normalG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Reverse-KL_normalG.jpg
--------------------------------------------------------------------------------
/pics/Reverse-KL_trickyG.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/f-GAN-Tensorflow/d285e84a713d44a9ea32a883d2a91119e219608e/pics/Reverse-KL_trickyG.jpg
--------------------------------------------------------------------------------
/pylib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from pylib.timer import *
6 | from pylib.utils import *
7 |
--------------------------------------------------------------------------------
/pylib/timer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import datetime
6 | import timeit
7 |
8 |
class Timer(object):
    """A timer usable as a context manager or through explicit `start()` calls.

    Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py.

    The clock is the `timer` callable (default `timeit.default_timer`, which
    measures wall-clock time, not CPU time). Leaving a `with` block freezes
    the measurement, so `elapsed` and `str(t)` read after the block report
    the block's duration. Calling `start()` again restarts a live measurement.

    Keyword arguments:
    timer -- zero-argument callable returning the current time in seconds.
    is_output -- if True, print the elapsed time after exiting the context.
    fmt -- 'ms' (float milliseconds), 's' (float seconds) or 'datetime'
           (a `datetime.timedelta`).
    """

    def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'):
        assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!"
        self._timer = timer
        self._is_output = is_output
        self._fmt = fmt
        self._end_time = None

    def __enter__(self):
        """Start the timer in the context manager scope."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Set the end time, then print the elapsed time if requested."""
        self._end_time = self._timer()
        if self._is_output:
            print(str(self))

    def __str__(self):
        if self._fmt != 'datetime':
            return '%s %s' % (self.elapsed, self._fmt)
        else:
            return str(self.elapsed)

    def start(self):
        """(Re)start the measurement, clearing any frozen end time."""
        self._end_time = None
        self.start_time = self._timer()

    @property
    def elapsed(self):
        """Time since `start()`; up to the context exit if it already happened."""
        end = self._end_time if self._end_time is not None else self._timer()
        e = end - self.start_time

        if self._fmt == 'ms':
            return e * 1000
        elif self._fmt == 's':
            return e
        elif self._fmt == 'datetime':
            return datetime.timedelta(seconds=e)
62 |
63 |
def timer(**timer_kwargs):
    """Function decorator displaying the function execution time.

    All kwargs are the arguments taken by the Timer class constructor (e.g.
    `fmt='ms'`). After every call the decorated function prints one line:
    '[*] function "<name>" execution time: <time> [*]'. Because that line
    already reports the time, `is_output` defaults to False here; pass
    `is_output=True` to also get the Timer's own printout. The wrapper keeps
    the wrapped function's name and docstring (`functools.wraps`).
    """
    import functools

    # `**timer_kwargs` is a fresh dict per call, so setting a default here
    # cannot leak into the caller's arguments.
    timer_kwargs.setdefault('is_output', False)
    fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]'

    def wrapped_f(f):
        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            with Timer(**timer_kwargs) as t:
                out = f(*args, **kwargs)
            context = {'function_name': f.__name__, 'execution_time': str(t)}
            print(fmt % context)
            return out
        return wrapped

    return wrapped_f
83 |
# Manual demo of Timer and timer, run only when this module is executed directly; it sleeps about 8 s in total.
84 | if __name__ == "__main__":
85 | import time
86 |
# Usage 1: Timer as a context manager (prints on exit because is_output defaults to True).
87 | # 1
88 | print(1)
89 | with Timer() as t:
90 | time.sleep(1)
91 | print(t)
92 | time.sleep(1)
93 |
94 | with Timer(fmt='datetime') as t:
95 | time.sleep(1)
96 |
# Usage 2: explicit start(); printing the Timer reads the live elapsed time.
97 | # 2
98 | print(2)
99 | t = Timer(fmt='ms')
100 | t.start()
101 | time.sleep(2)
102 | print(t)
103 |
104 | t = Timer(fmt='datetime')
105 | t.start()
106 | time.sleep(1)
107 | print(t)
108 |
# Usage 3: the timer decorator, which times each call of the decorated function.
109 | # 3
110 | print(3)
111 |
112 | @timer(fmt='ms')
113 | def blah():
114 | time.sleep(2)
115 |
116 | blah()
117 |
--------------------------------------------------------------------------------
/pylib/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import sys
7 |
8 |
def add_path(paths):
    """Prepend path(s) to `sys.path`, skipping any that are already present.

    Each new entry is inserted at index 0, so for a list the last new path
    ends up first.

    Args:
        paths: A path, or a list/tuple of paths.
    """
    candidates = paths if isinstance(paths, (list, tuple)) else [paths]
    for candidate in candidates:
        if candidate in sys.path:
            continue
        sys.path.insert(0, candidate)
15 |
16 |
def mkdir(paths):
    """Create each directory (including missing parents) unless it exists.

    Args:
        paths: A path, or a list/tuple of paths.

    Raises:
        OSError: If a path cannot be created, or exists but is not a
            directory.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        try:
            os.makedirs(path)
        except OSError:
            # Creating first avoids the check-then-create race: another
            # process may create the directory between an isdir() check and
            # makedirs(). Only re-raise if there is still no directory.
            if not os.path.isdir(path):
                raise
23 |
24 |
# Running this module directly does nothing; it only provides the helpers above.
25 | if __name__ == '__main__':
26 | pass
27 |
--------------------------------------------------------------------------------
/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.checkpoint import *
6 | from tflib.data import *
7 | from tflib.ops import *
8 | from tflib.utils import *
9 | from tflib.variable import *
10 | from tflib.vision import *
11 |
--------------------------------------------------------------------------------
/tflib/checkpoint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 |
7 | import tensorflow as tf
8 |
9 |
def load_checkpoint(ckpt_dir_or_file, session, var_list=None):
    """Restore variables from a checkpoint into `session`.

    Args:
        ckpt_dir_or_file: A checkpoint path prefix, or a directory, in which
            case its latest checkpoint is used.
        session: Session to restore the variables into.
        var_list: Variables to restore, forwarded to `tf.train.Saver`
            (None means all saveable variables).

    Raises:
        ValueError: If `ckpt_dir_or_file` is a directory that contains no
            checkpoint.

    Note:
        This function adds some useless ops to the graph. It is better
        to use tf.train.init_from_checkpoint(...).
    """
    if os.path.isdir(ckpt_dir_or_file):
        ckpt_dir = ckpt_dir_or_file
        ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir)
        if ckpt_dir_or_file is None:
            # latest_checkpoint returns None for a directory without a
            # checkpoint; fail here with a clear message, before a Saver
            # adds ops to the graph.
            raise ValueError('No checkpoint found in %s!' % ckpt_dir)

    restorer = tf.train.Saver(var_list)
    restorer.restore(session, ckpt_dir_or_file)
    print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file)
23 |
24 |
def init_from_checkpoint(ckpt_dir_or_file, assignment_map=None):
    """Use checkpoint values as the initializers of graph variables.

    This only replaces the variables' initializers; it does not run them, so
    the initializers must still be run manually afterwards.

    Args:
        ckpt_dir_or_file: Checkpoint directory or path prefix, forwarded to
            `tf.train.init_from_checkpoint`.
        assignment_map: Mapping from checkpoint scopes/names to graph
            scopes/variables. Defaults to {'/': '/'} (root scope to root
            scope). A fresh dict is built on each call rather than sharing a
            mutable default argument.
    """
    if assignment_map is None:
        assignment_map = {'/': '/'}
    tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
    print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file)
31 |
--------------------------------------------------------------------------------
/tflib/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.data.dataset import *
6 | from tflib.data.disk_image import *
7 | from tflib.data.memory_data import *
8 | from tflib.data.tfrecord import *
9 | from tflib.data.tfrecord_creator import *
10 |
--------------------------------------------------------------------------------
/tflib/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.eager as tfe
9 | from tflib.utils import session
10 |
11 |
# CPU count used for the default map parallelism (num_threads) and prefetch depth (_N_CPU + 1) in batch_dataset.
12 | _N_CPU = multiprocessing.cpu_count()
13 |
14 |
def batch_dataset(dataset, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                  map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Turn a per-example `tf.data.Dataset` into a batched, repeated, prefetched one.

    Stages are applied in a fixed order: filter -> map -> shuffle -> batch
    -> repeat -> prefetch.

    Args:
        dataset: Per-example `tf.data.Dataset`.
        batch_size: Number of examples per batch.
        prefetch_batch: Number of batches to prefetch.
        drop_remainder: If True, drop a final batch smaller than `batch_size`
            (via `tf.contrib.data.batch_and_drop_remainder`).
        filter: Optional predicate for `Dataset.filter`.
        map_func: Optional per-example transform for `Dataset.map`.
        num_threads: `num_parallel_calls` used with `map_func`.
        shuffle: Whether to shuffle examples before batching.
        buffer_size: Shuffle buffer size.
        repeat: Count for `Dataset.repeat` (-1 repeats indefinitely).

    Returns:
        The resulting `tf.data.Dataset`.
    """
    pipeline = dataset
    if filter:
        pipeline = pipeline.filter(filter)
    if map_func:
        pipeline = pipeline.map(map_func, num_parallel_calls=num_threads)
    if shuffle:
        pipeline = pipeline.shuffle(buffer_size)

    if not drop_remainder:
        pipeline = pipeline.batch(batch_size)
    else:
        pipeline = pipeline.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

    return pipeline.repeat(repeat).prefetch(prefetch_batch)
34 |
35 |
class Dataset(object):
    """Base wrapper pairing a `tf.data.Dataset` with an iterator.

    In graph mode it builds an initializable iterator and fetches batches
    through a session. In eager mode it uses a `tfe.Iterator`. Subclasses
    build a dataset and hand it to `_bulid`. Instances are Python iterators,
    and iteration stops when the underlying iterator is exhausted.
    """

    def __init__(self):
        self._dataset = None
        self._iterator = None
        self._batch_op = None
        self._sess = None
        # Only sessions created by this object are closed in __del__; a
        # session passed in by the caller is left open for them to manage.
        self._owns_sess = True

        self._is_eager = tf.executing_eagerly()
        self._eager_iterator = None

    def __del__(self):
        # getattr: __init__ may not have completed if construction failed.
        sess = getattr(self, '_sess', None)
        if sess and getattr(self, '_owns_sess', True):
            sess.close()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return self.get_next()
        except (tf.errors.OutOfRangeError, StopIteration):
            # Exhausted iterator. Other errors (bad data, KeyboardInterrupt,
            # uninitialized iterator) propagate instead of silently ending
            # the iteration.
            raise StopIteration

    next = __next__

    def get_next(self):
        """Return the next batch (values fetched with `sess.run` in graph mode)."""
        if self._is_eager:
            return self._eager_iterator.get_next()
        else:
            return self._sess.run(self._batch_op)

    def reset(self, feed_dict=None):
        """Restart iteration from the beginning.

        Args:
            feed_dict: Feeds for the iterator initializer (graph mode only).
        """
        if self._is_eager:
            self._eager_iterator = tfe.Iterator(self._dataset)
        else:
            self._sess.run(self._iterator.initializer, feed_dict=feed_dict)

    def _bulid(self, dataset, sess=None):
        """Attach `dataset` and create its iterator (name kept: subclasses call it).

        Args:
            dataset: The `tf.data.Dataset` to iterate.
            sess: Optional session to run the iterator in (graph mode). If
                omitted, a new session is created and owned by this object.
        """
        self._dataset = dataset

        if self._is_eager:
            self._eager_iterator = tfe.Iterator(dataset)
        else:
            self._iterator = dataset.make_initializable_iterator()
            self._batch_op = self._iterator.get_next()
            if sess:
                self._sess = sess
                self._owns_sess = False
            else:
                self._sess = session()
                self._owns_sess = True

        try:
            self.reset()
        except Exception:
            # Best effort: presumably an initializer that needs a feed_dict
            # can only be run by an explicit reset(feed_dict=...).
            pass

    @property
    def dataset(self):
        return self._dataset

    @property
    def iterator(self):
        return self._iterator

    @property
    def batch_op(self):
        return self._batch_op
105 |
--------------------------------------------------------------------------------
/tflib/data/disk_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | from tflib.data.dataset import batch_dataset, Dataset
9 |
10 |
# CPU count used for the default map parallelism (num_threads) and prefetch depth (_N_CPU + 1) below.
11 | _N_CPU = multiprocessing.cpu_count()
12 |
13 |
def disk_image_batch_dataset(img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                             map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Build a batched dataset that reads and decodes image files from disk.

    Intended for jpg and png files. Each file is read with `tf.read_file` and
    decoded with `tf.image.decode_png` into 3 channels. The decoded image is
    followed by its label(s), if any, and then passed to `map_func`.

    Args:
        img_paths: String list or 1-D tensor, each element an image path.
        labels: Label list/tuple_of_list or tensor/tuple_of_tensor, each
            element the label of the corresponding image.
        Remaining arguments are forwarded to `batch_dataset`.

    Returns:
        The batched `tf.data.Dataset`.
    """
    if labels is None:
        slices = img_paths
    elif isinstance(labels, tuple):
        slices = (img_paths,) + tuple(labels)
    else:
        slices = (img_paths, labels)
    dataset = tf.data.Dataset.from_tensor_slices(slices)

    def parse_func(path, *label):
        encoded = tf.read_file(path)
        decoded = tf.image.decode_png(encoded, 3)
        return (decoded,) + label

    def parse_then_map(*args):
        return map_func(*parse_func(*args))

    # Decoding is fused into the single parallel map inside batch_dataset;
    # running it as its own map stage is slower.
    per_example_fn = parse_then_map if map_func else parse_func

    return batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter,
                         per_example_fn, num_threads, shuffle, buffer_size, repeat)
47 |
48 |
class DiskImageData(Dataset):
    """Iterable batches of images decoded from files on disk.

    Intended for jpg and png files.

    img_paths: string list or 1-D tensor, each element an image path
    labels: label list or tensor, each element the label of the corresponding image
    sess: optional session to run the iterator in (graph mode)
    Remaining arguments are forwarded to `disk_image_batch_dataset`.
    """

    def __init__(self, img_paths, batch_size, labels=None, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(DiskImageData, self).__init__()
        built = disk_image_batch_dataset(img_paths, batch_size, labels=labels,
                                         prefetch_batch=prefetch_batch,
                                         drop_remainder=drop_remainder,
                                         filter=filter,
                                         map_func=map_func,
                                         num_threads=num_threads,
                                         shuffle=shuffle,
                                         buffer_size=buffer_size,
                                         repeat=repeat)
        self._bulid(built, sess)
64 |
--------------------------------------------------------------------------------
/tflib/data/memory_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | from tflib.data.dataset import batch_dataset, Dataset
9 |
10 |
# Number of CPUs; used below as the default `num_threads` and (plus one) as the
# default `prefetch_batch` of the in-memory pipelines.
_N_CPU = multiprocessing.cpu_count()
12 |
13 |
def memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                              map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Build a batched dataset from in-memory arrays or tensors.

    memory_data_dict: a dict whose values all share a leading dimension N and
        are sliced along it, for example
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}

    Every other argument is forwarded unchanged to `batch_dataset`.
    """
    sliced = tf.data.Dataset.from_tensor_slices(memory_data_dict)
    return batch_dataset(sliced, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func, num_threads, shuffle, buffer_size, repeat)
28 |
29 |
class MemoryData(Dataset):
    """Dataset object built from `memory_data_batch_dataset`.

    memory_data_dict: a dict whose values all share a leading dimension N, e.g.
        {'img': img_ndarray, 'label': label_ndarray} or
        {'img': img_tftensor, 'label': label_tftensor}
    sess: passed, together with the built dataset, to the inherited `_bulid`.
    """

    def __init__(self, memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(MemoryData, self).__init__()
        batched = memory_data_batch_dataset(memory_data_dict, batch_size,
                                            prefetch_batch=prefetch_batch,
                                            drop_remainder=drop_remainder,
                                            filter=filter,
                                            map_func=map_func,
                                            num_threads=num_threads,
                                            shuffle=shuffle,
                                            buffer_size=buffer_size,
                                            repeat=repeat)
        # `_bulid` (sic) is the name of the inherited build hook.
        self._bulid(batched, sess)
46 |
47 |
if __name__ == '__main__':
    # Smoke test: build a MemoryData pipeline over two toy arrays, print a few
    # batches, then list the graph's node names.
    import numpy as np
    data = {'a': np.array([1.0, 2, 3, 4, 5]),
            'b': np.array([[1, 2],
                           [2, 3],
                           [3, 4],
                           [4, 5],
                           [5, 6]])}

    def filter_fn(x):
        # Example predicate; not used below (the pipeline is built with filter=None).
        return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False))

    def map_func(x):
        x['a'] = x['a'] * 10
        return x

    # tf.enable_eager_execution()

    s = tf.Session()

    dataset = MemoryData(data, 2, filter=None, map_func=map_func,
                         shuffle=True, buffer_size=4096, drop_remainder=True, repeat=4, sess=s)

    for _ in range(5):
        batch = dataset.get_next()
        # Build a list explicitly: on Python 3 `map(...)` is lazy and would
        # print "<map object ...>" instead of the batch values.
        print([batch['b'], batch['a']])

    print([n.name for n in tf.get_default_graph().as_graph_def().node])
75 |
--------------------------------------------------------------------------------
/tflib/data/tfrecord.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | import glob
7 | import json
8 | import multiprocessing
9 | import os
10 |
11 | import tensorflow as tf
12 | from tflib.data.dataset import batch_dataset, Dataset
13 |
14 |
# Number of CPUs; default `num_threads` (and, plus one, `prefetch_batch`).
_N_CPU = multiprocessing.cpu_count()

# Read buffer size handed to tf.data.TFRecordDataset (256 KiB).
_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024

# Maps the `dtype_or_format` string stored in a dataset's info file to the op
# that decodes a feature's raw bytes and the extra kwargs for that op:
# encoded images go through decode_png/decode_jpeg, numeric dtypes through
# decode_raw with the matching `out_type`.
# NOTE: `TfrecordData` hands these `decode_param` dicts out by reference, so
# they must not be mutated.
_DECODERS = {
    'png': {'decoder': tf.image.decode_png, 'decode_param': dict()},
    'jpg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()},
    'jpeg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()},
    'uint8': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.uint8)},
    'int64': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.int64)},
    'float32': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.float32)},
}
27 |
28 |
def tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True,
                           filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1):
    """Build a batched dataset from TFRecord files of serialized tf.train.Examples.

    Each feature named in `infos` is stored as a single byte string. It is
    decoded with `info['decoder'](raw, **info['decode_param'])` and reshaped to
    `info['shape']`. Each parsed element is a dict keyed by feature name.

    infos:
        for example
        [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]},
         {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}]

    The batching arguments are forwarded unchanged to `batch_dataset`.
    """
    records = tf.data.TFRecordDataset(tfrecord_files,
                                      compression_type=compression_type,
                                      buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)

    feature_spec = {info['name']: tf.FixedLenFeature([], tf.string) for info in infos}

    def _decode_feature(example, info):
        # Decode one feature's bytes, then give it its declared shape.
        decoded = info['decoder'](example[info['name']], **info['decode_param'])
        return tf.reshape(decoded, info['shape'])

    def _parse(serialized_example):
        example = tf.parse_single_example(serialized_example, features=feature_spec)
        return {info['name']: _decode_feature(example, info) for info in infos}

    parsed = records.map(_parse, num_parallel_calls=num_threads)

    return batch_dataset(parsed, batch_size, prefetch_batch, drop_remainder, filter,
                         map_func, num_threads, shuffle, buffer_size, repeat)
68 |
69 |
class TfrecordData(Dataset):
    """Dataset object over a directory of `*.tfrecord` files.

    The directory must contain an `info.json` or, for datasets written by older
    versions, an `info.txt`. The info file lists each feature's name,
    `dtype_or_format` and shape, plus the example count and compression type.
    """

    def __init__(self, tfrecord_path, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True,
                 filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        super(TfrecordData, self).__init__()

        info_file = os.path.join(tfrecord_path, 'info.json')
        infos, self._data_num, compression_type = self._parse_json(info_file)

        self._shapes = {info['name']: tuple(info['shape']) for info in infos}

        # Sorted so the record files are always read in the same order.
        tfrecord_files = sorted(glob.glob(os.path.join(tfrecord_path, '*.tfrecord')))
        dataset = tfrecord_batch_dataset(tfrecord_files, infos, compression_type, batch_size, prefetch_batch, drop_remainder,
                                         filter, map_func, num_threads, shuffle, buffer_size, repeat)

        self._bulid(dataset, sess)

    def __len__(self):
        """Return the example count recorded in the info file."""
        return self._data_num

    @property
    def shape(self):
        """Dict mapping each feature name to its per-example shape tuple."""
        return self._shapes

    @staticmethod
    def _parse_old(json_file):
        """Parse the `info.txt` written by older versions of the creator.

        In both legacy formats the file holds a list whose last entry carries
        'data_num' and 'compression_type'; the other entries describe features.

        Returns:
            (infos, data_num, compression_type)
        """
        with open(json_file.replace('info.json', 'info.txt')) as f:
            try:  # older version 1: JSON
                infos = json.load(f)
                for info in infos[0:-1]:
                    info['decoder'] = _DECODERS[info['dtype_or_format']]['decoder']
                    info['decode_param'] = _DECODERS[info['dtype_or_format']]['decode_param']
            except (ValueError, KeyError, TypeError):  # older version 2: Python literal
                f.seek(0)
                infos = ''.join(line.strip('\n') for line in f.readlines())
                # SECURITY: `eval` executes arbitrary code from info.txt; only
                # load datasets from trusted sources. It is kept because this
                # legacy format may reference TF objects (e.g. decoder
                # functions), which a literal-only parser could not read.
                infos = eval(infos)

        data_num = infos[-1]['data_num']
        compression_type = tf.python_io.TFRecordOptions.compression_type_map[infos[-1]['compression_type']]
        infos[-1:] = []

        return infos, data_num, compression_type

    @staticmethod
    def _parse_json(json_file):
        """Parse `info.json`, falling back to the legacy `info.txt` formats.

        Returns:
            (infos, data_num, compression_type), where every info dict has its
            'decoder' and 'decode_param' filled in from `_DECODERS`.
        """
        try:
            with open(json_file) as f:
                info = json.load(f)
            infos = info['item']
            for i in infos:
                i['decoder'] = _DECODERS[i['dtype_or_format']]['decoder']
                i['decode_param'] = _DECODERS[i['dtype_or_format']]['decode_param']
            data_num = info['info']['data_num']
            compression_type = tf.python_io.TFRecordOptions.compression_type_map[info['info']['compression_type']]
        # Missing/unreadable file, invalid JSON, or an unexpected layout means an
        # older dataset. Errors such as KeyboardInterrupt are no longer swallowed.
        except (IOError, OSError, ValueError, KeyError, TypeError):
            infos, data_num, compression_type = TfrecordData._parse_old(json_file)

        return infos, data_num, compression_type
130 |
--------------------------------------------------------------------------------
/tflib/data/tfrecord_creator.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 | import json
7 | import os
8 | import shutil
9 |
10 | import numpy as np
11 | from PIL import Image
12 | import tensorflow as tf
13 | from tflib.data import tfrecord
14 |
# Python 2 only: makes classes defined in this module new-style by default.
__metaclass__ = type


# Feature formats/dtypes a creator accepts: exactly the keys that the reader
# (`tfrecord._DECODERS`) knows how to decode.
# NOTE(review): on Python 3 this is a dict_keys view, so the assertion message
# in `_add_info` renders as "dict_keys([...])".
_ALLOWED_TYPES = tfrecord._DECODERS.keys()
19 |
20 |
class BytesTfrecordCreator(object):
    """Write examples whose features are raw byte strings to TFRecord files.

    Examples go to `save_path/data_XXXXXX.tfrecord`, and a new file is started
    every `size_each` examples. When the creator is closed, an `info.json` is
    written next to the record files. It describes the features, the number of
    examples and the compression type. Closing happens on an explicit `close()`,
    on leaving a `with` block, or when the object is garbage collected.

    `infos` example:
        infos = [
            ['img', 'jpg', (64, 64, 3)],
            ['id', 'int64', ()],
            ['attr', 'int64', (40,)],
            ['point', 'float32', (5, 2)]
        ]

    `compression_type`:
        0: NONE
        1: ZLIB
        2: GZIP
    """

    def __init__(self, save_path, infos, size_each=None, compression_type=0, overwrite_existence=False):
        # Nothing to finalize until construction succeeds. `close()`/`__del__`
        # check this flag, so a failed __init__ neither raises again from the
        # finalizer nor writes a bogus info.json.
        self._closed = True

        # overwrite existence
        if os.path.exists(save_path):
            if not overwrite_existence:
                raise Exception('%s exists!' % save_path)
            shutil.rmtree(save_path)
        os.makedirs(save_path)

        self._save_path = save_path

        # add info
        self._infos = []
        self._info_names = []
        for info in infos:
            self._add_info(*info)

        self._data_num = 0
        self._tfrecord_num = 0
        # A falsy `size_each` (None or 0) means "no practical limit per file".
        self._size_each = size_each or 2147483647
        self._writer = None

        self._compression_type = compression_type
        self._options = tf.python_io.TFRecordOptions(compression_type)

        self._closed = False

    def close(self):
        """Write `info.json` and close the current record file.

        Safe to call more than once; only the first call has an effect.
        """
        if getattr(self, '_closed', True):
            return
        self._closed = True

        info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}}
        info_str = json.dumps(info, indent=4, separators=(',', ':'))

        with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f:
            info_f.write(info_str)

        if self._writer:
            self._writer.close()
            self._writer = None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def add(self, feature_bytes_dict):
        """Add example.

        `feature_bytes_dict` example:
            feature_bytes_dict = {
                'img': img_bytes,
                'id': id_bytes,
                'attr': attr_bytes,
                'point': point_bytes
            }

        Raises:
            ValueError: if the creator has already been closed.
        """
        if self._closed:
            raise ValueError('Cannot add examples: the creator has been closed!')
        assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \
            'Feature names are inconsistent with the givens!'

        self._new_tfrecord_check()

        self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString())
        self._data_num += 1

    def _new_tfrecord_check(self):
        """Open the next `data_XXXXXX.tfrecord` when the current one is full.

        Also opens the first file on the first `add`, when nothing is open yet.
        """
        if self._data_num // self._size_each == self._tfrecord_num:
            self._tfrecord_num += 1

            if self._writer:
                self._writer.close()

            tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1))
            self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options)

    def _add_info(self, name, dtype_or_format, shape):
        """Register one feature's name, format/dtype and shape for info.json."""
        assert name not in self._info_names, 'Info name "%s" is duplicated!' % name
        assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES)
        self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape))
        self._info_names.append(name)

    @staticmethod
    def _bytes_feature(values):
        """Return a TF-Feature of bytes.

        Args:
            values: A byte string or list of byte strings.

        Returns:
            a TF-Feature.
        """
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    @staticmethod
    def _bytes_tfexample(bytes_dict):
        """Convert bytes to tfexample.

        `bytes_dict` example:
            bytes_dict = {
                'img': img_bytes,
                'id': id_bytes,
                'attr': attr_bytes,
                'point': point_bytes
            }
        """
        feature_dict = {}
        for key, value in bytes_dict.items():
            feature_dict[key] = BytesTfrecordCreator._bytes_feature(value)
        return tf.train.Example(features=tf.train.Features(feature=feature_dict))
140 |
141 |
class DataLablePairTfrecordCreator(BytesTfrecordCreator):
    """Write (data, labels) examples to TFRecord files.

    If `data_shape` is None, then the `data` to be added should be a
    numpy array, and the shape and dtype of `data` will be inferred from the
    first example.
    If `data_shape` is not None, `data` should be given as byte string,
    and `data_dtype_or_format` should also be given.

    Labels are numpy arrays. Their names, shapes and dtypes are recorded from
    the first example, and every later example must match them.

    `compression_type`:
        0: NONE
        1: ZLIB
        2: GZIP
    """

    def __init__(self, save_path, data_shape=None, data_dtype_or_format=None, data_name='data',
                 size_each=None, compression_type=0, overwrite_existence=False):
        super(DataLablePairTfrecordCreator, self).__init__(save_path, [], size_each, compression_type, overwrite_existence)

        # Compare with None, as documented: an explicit scalar shape `()` is
        # falsy but still means "data is given as bytes".
        if data_shape is not None:
            assert data_dtype_or_format, '`data_dtype_or_format` should be given when `data_shape` is given!'
            self._is_data_bytes = True
        else:
            self._is_data_bytes = False

        self._data_shape = data_shape
        self._data_dtype_or_format = data_dtype_or_format
        self._data_name = data_name
        # Per-label shape and dtype name, recorded from the first example.
        self._label_shape_dict = {}
        self._label_dtype_dict = {}

        # Feature infos are registered lazily, on the first `add`.
        self._info_built = False

    def add(self, data, label_dict):
        """Add example.

        `data`: a numpy array, or a byte string when `data_shape` was given.
        `label_dict` example:
            label_dict = {
                'id': id_ndarray,
                'attr': attr_ndarray,
                'point': point_ndarray
            }
        Arrays are serialized with `ndarray.tobytes()`.
        """
        self._check_and_build(data, label_dict)

        if not self._is_data_bytes:
            data = data.tobytes()

        feature_dict = {self._data_name: data}
        for name, label in label_dict.items():
            feature_dict[name] = label.tobytes()

        super(DataLablePairTfrecordCreator, self).add(feature_dict)

    def _check_and_build(self, data, label_dict):
        """Validate one example; on the first call, register the feature infos."""
        # check type
        if self._is_data_bytes:
            assert isinstance(data, (str, bytes)), '`data` should be a byte string!'
        else:
            assert isinstance(data, np.ndarray), '`data` should be a numpy array!'
        for label in label_dict.values():
            assert isinstance(label, np.ndarray), 'labels should be numpy arrays!'

        # check shape and dtype or build info at first adding
        if self._info_built:
            if not self._is_data_bytes:
                assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!'
                assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!'
            # Check names first so an unknown label gives a clear message
            # instead of a bare KeyError below.
            assert sorted(label_dict) == sorted(self._label_shape_dict), \
                'Label names are inconsistent with those of the first example!'
            for name, label in label_dict.items():
                assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name
                assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' % name
        else:
            if not self._is_data_bytes:
                self._data_shape = data.shape
                self._data_dtype_or_format = data.dtype.name
            self._add_info(self._data_name, self._data_dtype_or_format, self._data_shape)

            for name, label in label_dict.items():
                self._label_shape_dict[name] = label.shape
                self._label_dtype_dict[name] = label.dtype.name
                self._add_info(name, label.dtype.name, label.shape)

            self._info_built = True
224 |
225 |
class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator):
    """Write (image, labels) examples to TFRecord files.

    `encode_type`: in [None, 'png', 'jpg']; None stores the raw uint8 pixels.
    `quality`: JPEG quality, used only for 'jpg'.
    `compression_type`:
        0: NONE
        1: ZLIB
        2: GZIP
    """

    def __init__(self, save_path, encode_type='png', quality=95, data_name='img',
                 size_each=None, compression_type=0, overwrite_existence=False):
        super(ImageLablePairTfrecordCreator, self).__init__(
            save_path, None, None, data_name, size_each, compression_type, overwrite_existence)

        assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!"

        self._encode_type = encode_type
        self._quality = quality

        # The shape and format are taken from the first image (see `_check`).
        # Image data always reaches the parent class as a byte string.
        self._data_shape = None
        self._data_dtype_or_format = None
        self._is_data_bytes = True

    def add(self, image, label_dict):
        """Add example.

        `image`: an H * W (* C) uint8 numpy array; it is not modified.

        `label_dict` example:
            label_dict = {
                'id': id_ndarray,
                'attr': attr_ndarray,
                'point': point_ndarray
            }
        """
        self._check(image)
        image_bytes = self._encode(image)
        super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict)

    def _check(self, image):
        """Validate `image`; on the first call, record its shape and format."""
        if not self._data_shape:
            assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \
                '`image` should be an H * W (* C) uint8 numpy array!'
            # Encoding supports gray and RGB images; an (H, W, 1) array is gray.
            if self._encode_type and image.ndim == 3 and image.shape[-1] not in (1, 3):
                raise Exception('Only images with 1 or 3 channels are allowed to be encoded!')

            if image.ndim == 2:
                self._data_shape = image.shape + (1,)
            else:
                self._data_shape = image.shape
            self._data_dtype_or_format = self._encode_type or 'uint8'
        else:
            sp = image.shape
            if image.ndim == 2:
                sp = sp + (1,)
            assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!'
            assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!'

    def _encode(self, image):
        """Return `image` as PNG/JPEG bytes, or its raw bytes when not encoding."""
        if not self._encode_type:
            return image.tobytes()

        # PIL wants an (H, W) array for single-channel images. Take a view
        # rather than assigning to `image.shape`, which would reshape the
        # caller's array in place.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[:, :, 0]
        buf = io.BytesIO()
        pil_image = Image.fromarray(image)
        if self._encode_type == 'jpg':
            pil_image.save(buf, 'JPEG', quality=self._quality)
        elif self._encode_type == 'png':
            pil_image.save(buf, 'PNG')
        return buf.getvalue()
300 |
--------------------------------------------------------------------------------
/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.ops.layers import *
6 |
--------------------------------------------------------------------------------
/tflib/ops/layers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 |
8 |
def flatten_fully_connected(inputs,
                            num_outputs,
                            activation_fn=tf.nn.relu,
                            normalizer_fn=None,
                            normalizer_params=None,
                            weights_initializer=slim.xavier_initializer(),
                            weights_regularizer=None,
                            biases_initializer=tf.zeros_initializer(),
                            biases_regularizer=None,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            scope=None):
    """Fully connected layer that first flattens inputs of rank > 2.

    All arguments mirror `slim.fully_connected` and are forwarded to it. The
    layer is built inside `tf.variable_scope(scope, 'flatten_fully_connected')`,
    and `scope` is also forwarded to `slim.fully_connected`.

    NOTE(review): arguments are forwarded positionally. Wrapping this call in a
    `slim.arg_scope` that sets any of them for `slim.fully_connected` would
    presumably pass that argument twice; confirm before relying on arg_scope.
    """
    with tf.variable_scope(scope, 'flatten_fully_connected', [inputs]):
        flat_inputs = slim.flatten(inputs) if inputs.shape.ndims > 2 else inputs
        return slim.fully_connected(flat_inputs, num_outputs, activation_fn, normalizer_fn, normalizer_params,
                                    weights_initializer, weights_regularizer, biases_initializer, biases_regularizer,
                                    reuse, variables_collections, outputs_collections, trainable, scope)
40 |
# Alternative name for `flatten_fully_connected`.
flatten_dense = flatten_fully_connected
42 |
--------------------------------------------------------------------------------
/tflib/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import re
6 |
7 | import tensorflow as tf
8 |
9 |
def session(graph=None, allow_soft_placement=True,
            log_device_placement=False, allow_growth=True):
    """Create a `tf.Session` with a small, commonly used config.

    `allow_soft_placement` and `log_device_placement` go into `tf.ConfigProto`;
    `allow_growth` sets `gpu_options.allow_growth`.
    """
    cfg = tf.ConfigProto(log_device_placement=log_device_placement,
                         allow_soft_placement=allow_soft_placement)
    cfg.gpu_options.allow_growth = allow_growth
    return tf.Session(config=cfg, graph=graph)
17 |
18 |
def print_tensor(tensors):
    """Print one line per tensor/variable: index, kind, name, shape, dtype, device.

    `tensors` may be a single object or a list/tuple of them. The kind comes
    from the object's type string: 'Tensor' is checked before 'Variable'.
    Anything else raises `Exception('Not a Tensor or Variable!')`.
    """
    seq = tensors if isinstance(tensors, (list, tuple)) else [tensors]

    for idx, item in enumerate(seq):
        type_repr = str(type(item))
        if 'Tensor' in type_repr:
            kind = 'Tensor'
        elif 'Variable' in type_repr:
            kind = 'Variable'
        else:
            raise Exception('Not a Tensor or Variable!')

        line = '{}: {}("{}", shape={}, dtype={}, device={})'.format(
            str(idx), kind, str(item.name), str(item.get_shape()),
            str(item.dtype.name), str(item.device))
        print(line)
35 |
# Short alias for `print_tensor`.
prt = print_tensor
37 |
38 |
def shape(tensor):
    """Return the static shape of `tensor` as a list, with unknown dims as -1."""
    dims = tensor.get_shape().as_list()
    return [-1 if d is None else d for d in dims]
42 |
43 |
def summary(tensor_collection,
            summary_type=('mean', 'stddev', 'max', 'min', 'sparsity', 'histogram'),
            scope=None):
    """Build a merged summary op for one or more tensors.

    usage:
        1. summary(tensor)
        2. summary([tensor_a, tensor_b])
        3. summary({tensor_a: 'a', tensor_b: 'b'})

    Args:
        tensor_collection: a tensor; a list/tuple of tensors (named after the
            tensors themselves); or a dict mapping tensor -> summary name.
        summary_type: statistics to record for non-scalar tensors. Any container
            supporting `in` works. The default is an immutable tuple, so it
            cannot be shared and mutated across calls. Rank-0 tensors always get
            a single scalar summary.
        scope: name scope for the summary ops (default 'summary').

    Returns:
        The merged summary op.
    """
    def _summary(tensor, name, summary_type):
        """Attach the requested summaries to one tensor and merge them."""
        if name is None:
            # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
            # session. This helps the clarity of presentation on tensorboard.
            name = re.sub(r'tower_[0-9]*/', '', tensor.name)
            # Tensor names carry an output index ('x:0'); replace ':' with '-'.
            name = re.sub(':', '-', name)

        summaries = []
        if len(tensor.shape) == 0:
            summaries.append(tf.summary.scalar(name, tensor))
        else:
            if 'mean' in summary_type or 'stddev' in summary_type:
                # Built once and shared by the mean and stddev summaries.
                mean = tf.reduce_mean(tensor)
            if 'mean' in summary_type:
                summaries.append(tf.summary.scalar(name + '/mean', mean))
            if 'stddev' in summary_type:
                stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean)))
                summaries.append(tf.summary.scalar(name + '/stddev', stddev))
            if 'max' in summary_type:
                summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor)))
            if 'min' in summary_type:
                summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor)))
            if 'sparsity' in summary_type:
                summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor)))
            if 'histogram' in summary_type:
                summaries.append(tf.summary.histogram(name, tensor))
        return tf.summary.merge(summaries)

    if not isinstance(tensor_collection, (list, tuple, dict)):
        tensor_collection = [tensor_collection]

    with tf.name_scope(scope, 'summary'):
        if isinstance(tensor_collection, dict):
            summaries = [_summary(tensor, name, summary_type) for tensor, name in tensor_collection.items()]
        else:
            summaries = [_summary(tensor, None, summary_type) for tensor in tensor_collection]
        return tf.summary.merge(summaries)
95 |
96 |
def counter(start=0, scope=None):
    """Create an int64 scalar counter variable and an op that increments it.

    Args:
        start: initial value of the counter.
        scope: variable scope name (default 'counter').

    Returns:
        (counter_variable, increment_op); running `increment_op` assigns
        `counter_variable + 1` back to the variable.
    """
    with tf.variable_scope(scope, 'counter'):
        cnt = tf.get_variable(name='counter',
                              shape=(),
                              dtype=tf.int64,
                              initializer=tf.constant_initializer(start))
        increment = tf.assign(cnt, tf.add(cnt, 1))
    return cnt, increment
105 |
--------------------------------------------------------------------------------
/tflib/variable.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
def tensors_filter(tensors, filters, combine_type='or'):
    """Select the tensors whose names contain the given substring(s).

    Args:
        tensors: list or tuple of objects with a `.name` attribute.
        filters: a substring, or a list/tuple of substrings.
        combine_type: 'or' keeps a tensor if ANY filter occurs in its name;
            'and' keeps it only if ALL of them do.

    Returns:
        A new list of the matching tensors, in their original order.
    """
    assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!'
    assert isinstance(filters, (str, list, tuple)), '`filters` should be a string or a list(tuple) of strings!'
    assert combine_type in ('or', 'and'), "`combine_type` should be 'or' or 'and'!"

    patterns = [filters] if isinstance(filters, str) else filters
    matcher = any if combine_type == 'or' else all
    return [t for t in tensors if matcher(p in t.name for p in patterns)]
32 |
33 |
def global_variables(filters=None, combine_type='or'):
    """Return `tf.global_variables()`, optionally narrowed by name.

    `filters` and `combine_type` have the meaning of `tensors_filter`;
    `filters=None` returns every global variable.
    """
    variables = tf.global_variables()
    if filters is not None:
        variables = tensors_filter(variables, filters, combine_type)
    return variables
40 |
41 |
def trainable_variables(filters=None, combine_type='or'):
    """Return `tf.trainable_variables()`, optionally narrowed by name.

    `filters` and `combine_type` have the meaning of `tensors_filter`;
    `filters=None` returns every trainable variable.
    """
    variables = tf.trainable_variables()
    if filters is not None:
        variables = tensors_filter(variables, filters, combine_type)
    return variables
48 |
--------------------------------------------------------------------------------
/tflib/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset.mnist import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gzip
6 | import multiprocessing
7 | import os
8 | import struct
9 | import subprocess
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | from tflib.data.memory_data import MemoryData
14 |
15 |
# Number of CPUs; used as the default `num_threads` and (plus one) as the
# default `prefetch_batch` of `Mnist`.
_N_CPU = multiprocessing.cpu_count()
17 |
18 |
def unzip_gz(file_name):
    """Decompress a '*.gz' file next to itself, dropping the '.gz' suffix.

    Returns:
        The path of the decompressed file.

    Raises:
        ValueError: if `file_name` does not end with '.gz'; otherwise the output
            path would equal the input and overwrite it.
    """
    import shutil

    if not file_name.endswith('.gz'):
        raise ValueError('%s is not a .gz file!' % file_name)
    # Strip only the suffix, so a '.gz' elsewhere in the path is left alone.
    unzip_name = file_name[:-len('.gz')]
    # Binary on both sides: the payload is raw bytes, and writing bytes to a
    # text-mode file fails on Python 3. `with` closes both files even on error.
    with gzip.open(file_name, 'rb') as gz_file, open(unzip_name, 'wb') as out_file:
        shutil.copyfileobj(gz_file, out_file)
    return unzip_name
24 |
25 |
def mnist_download(download_dir):
    """Download the four MNIST '.gz' archives into `download_dir` with curl.

    Files that already exist are skipped. `download_dir` is created if missing.

    Raises:
        subprocess.CalledProcessError: if curl fails (including HTTP errors).
        OSError: if curl cannot be run.
        In both cases any partially written file is removed first, so a later
        call retries the download instead of skipping a broken file.
    """
    url_base = 'http://yann.lecun.com/exdb/mnist/'
    file_names = ['train-images-idx3-ubyte.gz',
                  'train-labels-idx1-ubyte.gz',
                  't10k-images-idx3-ubyte.gz',
                  't10k-labels-idx1-ubyte.gz']
    if not os.path.isdir(download_dir):
        os.makedirs(download_dir)
    for file_name in file_names:
        url = url_base + file_name
        save_path = os.path.join(download_dir, file_name)
        if os.path.exists(save_path):
            print('%s exists, skip!' % file_name)
            continue
        print('Downloading ', file_name)
        # --fail: exit non-zero on HTTP errors instead of saving the error page
        # as if it were the archive. --location: follow redirects.
        cmd = ['curl', '--fail', '--location', url, '-o', save_path]
        try:
            subprocess.check_call(cmd)
        except (subprocess.CalledProcessError, OSError):
            if os.path.exists(save_path):
                os.remove(save_path)
            raise
41 |
42 |
def mnist_load(data_dir, split='train'):
    """Load MNIST dataset, modified from https://gist.github.com/akesling/5358964.

    Downloads the archives if needed (via `mnist_download`) and decompresses
    the requested split if its raw files are missing.

    Returns:
        A tuple as (`imgs`, `lbls`, `num`).

        `imgs`: [-1.0, 1.0] float64 images of shape (N, H, W).
        `lbls`: int8 labels of shape (N,).
        `num`: N, the number of examples.

    Raises:
        ValueError: if `split` is neither 'train' nor 'test'.
    """
    mnist_download(data_dir)

    if split == 'train':
        prefix = 'train'
    elif split == 'test':
        prefix = 't10k'
    else:
        raise ValueError("split must be 'test' or 'train'")
    fname_img = os.path.join(data_dir, prefix + '-images-idx3-ubyte')
    fname_lbl = os.path.join(data_dir, prefix + '-labels-idx1-ubyte')

    for fname in (fname_img, fname_lbl):
        if not os.path.exists(fname):
            unzip_gz(fname + '.gz')

    # IDX files: a big-endian header, then the raw payload.
    with open(fname_lbl, 'rb') as flbl:
        _magic, _count = struct.unpack('>II', flbl.read(8))
        labels = np.fromfile(flbl, dtype=np.int8)

    with open(fname_img, 'rb') as fimg:
        _magic, _count, rows, cols = struct.unpack('>IIII', fimg.read(16))
        pixels = np.fromfile(fimg, dtype=np.uint8).reshape(len(labels), rows, cols)

    # Map uint8 [0, 255] to float [-1.0, 1.0].
    images = pixels / 127.5 - 1

    return images, labels, len(labels)
80 |
81 |
class Mnist(MemoryData):
    """MNIST as a `MemoryData` pipeline.

    The arrays from `mnist_load` are fed through placeholders (kept in
    `self.feed_dict`) under the keys 'img' (float32, with a trailing channel
    axis) and 'lbl' (int64). `reset()` passes that feed dict to the base class.
    """

    def __init__(self, data_dir, batch_size, split='train', prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None,
                 map_func=None, num_threads=_N_CPU, shuffle=True, buffer_size=4096, repeat=-1, sess=None):
        imgs, lbls, self.n_data = mnist_load(data_dir, split)
        # Add a trailing channel axis: (N, H, W) -> (N, H, W, 1).
        imgs = imgs.reshape(imgs.shape + (1,))

        imgs_pl = tf.placeholder(tf.float32, imgs.shape)
        lbls_pl = tf.placeholder(tf.int64, lbls.shape)

        # The arrays are supplied at reset time through this feed dict rather
        # than being embedded in the graph.
        self.feed_dict = {imgs_pl: imgs, lbls_pl: lbls}
        super(Mnist, self).__init__({'img': imgs_pl, 'lbl': lbls_pl}, batch_size, prefetch_batch, drop_remainder,
                                    filter, map_func, num_threads, shuffle, buffer_size, repeat, sess)

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return self.n_data

    def reset(self):
        """Call the base-class `reset` with the placeholder feed dict."""
        super(Mnist, self).reset(self.feed_dict)
103 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | import traceback
7 |
8 | import imlib as im
9 | import numpy as np
10 | import pylib
11 | import tensorflow as tf
12 | import tflib as tl
13 | import utils
14 |
15 |
# ==============================================================================
# =                                    param                                   =
# ==============================================================================

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', dest='epoch', type=int, default=50)
parser.add_argument('--batch_size', dest='batch_size', type=int, default=64)
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--z_dim', dest='z_dim', type=int, default=100, help='dimension of latent')
parser.add_argument('--divergence', dest='divergence', default='Jensen-Shannon',
                    choices=['Kullback-Leibler', 'Reverse-KL', 'Pearson-X2', 'Squared-Hellinger', 'Jensen-Shannon', 'GAN'])
parser.add_argument('--tricky_G', dest='tricky_G', action='store_true', help='use tricky G loss or not')
parser.add_argument('--dataset', dest='dataset_name', default='mnist', choices=['mnist', 'celeba'])

args = parser.parse_args()

epoch = args.epoch
batch_size = args.batch_size
lr = args.lr
z_dim = args.z_dim

divergence = args.divergence
tricky_G = args.tricky_G
dataset_name = args.dataset_name
experiment_name = '%s_%s_%s' % (dataset_name, divergence, 'trickyG' if tricky_G else 'normalG')
# Identify the run (dataset, divergence, G-loss variant) in the console output.
print('Experiment: %s' % experiment_name)

# dataset and models
Dataset, models = utils.get_dataset_models(dataset_name)
dataset = Dataset(batch_size=batch_size)
G = models['G']
D = models['D']
activation_fn, conjugate_fn = utils.get_divergence_funcs(divergence)
49 |
50 |
# ==============================================================================
# =                                    graph                                   =
# ==============================================================================

# inputs
# NOTE(review): the real-image placeholder is hard-coded to 28x28x1 (MNIST),
# although '--dataset celeba' is accepted; presumably celeba needs a
# different shape here. Confirm.
real = tf.placeholder(tf.float32, [None, 28, 28, 1])
z = tf.placeholder(tf.float32, [None, z_dim])

# generate
fake = G(z)

# discriminate
r_output = D(real)
f_output = D(fake)

# losses (f-GAN style):
# D minimizes -E_real[act(V)] + E_fake[f*(act(V))], where act is the output
# activation and f* the conjugate of the chosen divergence.
d_r_loss = -tf.reduce_mean(activation_fn(r_output))
d_f_loss = tf.reduce_mean(conjugate_fn(activation_fn(f_output)))
d_loss = d_r_loss + d_f_loss
if tricky_G:
    # Alternative G objective: maximize act(V) on generated samples directly.
    g_loss = -tf.reduce_mean(activation_fn(f_output))
else:
    # G works against D; only the fake term of d_loss depends on G.
    g_loss = -d_f_loss

# optimizers
# NOTE(review): variables are selected by substring match on their names
# ('D' / 'G'), so scope names must not contain the other letter.
d_var = tl.trainable_variables('D')
g_var = tl.trainable_variables('G')
d_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(d_loss, var_list=d_var)
g_step = tf.train.AdamOptimizer(learning_rate=lr, beta1=0.5).minimize(g_loss, var_list=g_var)

# summaries
d_summary = tl.summary({d_r_loss: 'd_r_loss',
                        d_f_loss: 'd_f_loss',
                        -d_loss: '%s_divergence' % divergence}, scope='D')
g_summary = tl.summary({g_loss: 'g_loss'}, scope='G')

# sample
f_sample = G(z, is_training=False)
89 |
90 |
# ==============================================================================
# =                                    train                                   =
# ==============================================================================

# session
sess = tl.session()

# saver (keeps only the most recent checkpoint file)
saver = tf.train.Saver(max_to_keep=1)

# summary writer
summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)

# initialization: try to restore from ckpt_dir; if that fails (presumably because
# no checkpoint exists yet) initialize all variables from scratch. Catch
# Exception rather than everything so Ctrl-C / SystemExit are not swallowed.
ckpt_dir = './output/%s/checkpoints' % experiment_name
pylib.mkdir(ckpt_dir)
try:
    tl.load_checkpoint(ckpt_dir, sess)
except Exception:
    sess.run(tf.global_variables_initializer())

# train
print_every = 1      # iterations between progress lines
sample_every = 1000  # iterations between saved sample grids
try:
    # fixed noise, drawn once, so sample grids are comparable across training
    z_ipt_sample = np.random.normal(size=[100, z_dim])

    it = -1
    # Assumes len(dataset) counts samples and each epoch yields
    # len(dataset) // batch_size batches -- TODO confirm against tl.Mnist.
    # max(..., 1) avoids a ZeroDivisionError below for a dataset smaller than one batch.
    it_per_epoch = max(len(dataset) // batch_size, 1)
    for ep in range(epoch):
        dataset.reset()
        for batch in dataset:
            it += 1
            it_in_epoch = it % it_per_epoch + 1

            # batch data
            real_ipt = batch['img']
            z_ipt = np.random.normal(size=[batch_size, z_dim])

            # train D
            d_summary_opt, _ = sess.run([d_summary, d_step], feed_dict={real: real_ipt, z: z_ipt})
            summary_writer.add_summary(d_summary_opt, it)

            # train G (reuses the same noise batch as the D step)
            g_summary_opt, _ = sess.run([g_summary, g_step], feed_dict={z: z_ipt})
            summary_writer.add_summary(g_summary_opt, it)

            # display
            if (it + 1) % print_every == 0:
                print("Epoch: (%3d) (%5d/%5d)" % (ep, it_in_epoch, it_per_epoch))

            # sample
            if (it + 1) % sample_every == 0:
                f_sample_opt = sess.run(f_sample, feed_dict={z: z_ipt_sample})

                save_dir = './output/%s/sample_training' % experiment_name
                pylib.mkdir(save_dir)
                im.imwrite(im.immerge(f_sample_opt), '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, ep, it_in_epoch, it_per_epoch))

        # checkpoint once per epoch
        save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
        print('Model is saved in file: %s' % save_path)
except KeyboardInterrupt:
    # Ctrl-C stops training cleanly; any other error propagates (non-zero exit)
    # instead of being printed and swallowed.
    print('Training interrupted by user.')
finally:
    sess.close()
154 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import pylib
8 | import tensorflow as tf
9 | import tflib as tl
10 |
11 |
def get_divergence_funcs(divergence):
    """Return the ``(activation_fn, conjugate_fn)`` pair for an f-divergence.

    ``activation_fn`` is applied to the discriminator's raw output and
    ``conjugate_fn`` is the convex conjugate f* of the divergence's generator
    function, evaluated on the activated output (the f-GAN formulation of
    Nowozin et al., 2016). Both take and return TensorFlow tensors.

    Args:
        divergence: one of 'Kullback-Leibler', 'Reverse-KL', 'Pearson-X2',
            'Squared-Hellinger', 'Jensen-Shannon', 'GAN'.

    Returns:
        A tuple ``(activation_fn, conjugate_fn)`` of single-argument callables.

    Raises:
        ValueError: if ``divergence`` is not one of the names above.
    """
    if divergence == 'Kullback-Leibler':
        def activation_fn(v):
            return v

        def conjugate_fn(t):
            return tf.exp(t - 1)

    elif divergence == 'Reverse-KL':
        def activation_fn(v):
            return -tf.exp(-v)

        def conjugate_fn(t):
            return -1 - tf.log(-t)

    elif divergence == 'Pearson-X2':
        def activation_fn(v):
            return v

        def conjugate_fn(t):
            return 0.25 * t * t + t

    elif divergence == 'Squared-Hellinger':
        def activation_fn(v):
            return 1 - tf.exp(-v)

        def conjugate_fn(t):
            return t / (1 - t)

    elif divergence == 'Jensen-Shannon':
        def activation_fn(v):
            # softplus(-v) == log(1 + exp(-v)), but does not overflow to inf
            # (and produce NaN gradients) for large negative v.
            return tf.log(2.0) - tf.nn.softplus(-v)

        def conjugate_fn(t):
            return -tf.log(2 - tf.exp(t))

    elif divergence == 'GAN':
        def activation_fn(v):
            # numerically stable form of -log(1 + exp(-v))
            return -tf.nn.softplus(-v)

        def conjugate_fn(t):
            return -tf.log(1 - tf.exp(t))

    else:
        # Without this branch an unknown name fell through to the return and
        # raised an obscure UnboundLocalError.
        raise ValueError(
            'Unknown divergence %r; expected one of: Kullback-Leibler, Reverse-KL, '
            'Pearson-X2, Squared-Hellinger, Jensen-Shannon, GAN' % (divergence,))

    return activation_fn, conjugate_fn
44 |
45 |
def get_dataset_models(dataset_name):
    """Return ``(Dataset, models)`` for the named dataset.

    ``Dataset`` is a dataset factory with its data directory already bound
    (remaining keyword arguments such as ``batch_size`` are passed by the
    caller); ``models`` maps ``'D'`` and ``'G'`` to the discriminator and
    generator builders.

    Args:
        dataset_name: 'mnist' (supported) or 'celeba' (not implemented yet).

    Raises:
        NotImplementedError: for 'celeba'.
        ValueError: for any other unrecognized name (previously this silently
            returned None and broke the caller's tuple unpacking).
    """
    if dataset_name == 'mnist':
        import models
        pylib.mkdir('./data/mnist')
        Dataset = partial(tl.Mnist, data_dir='./data/mnist', repeat=1)
        return Dataset, {'D': models.D, 'G': models.G}

    if dataset_name == 'celeba':
        # TODO: add a CelebA dataset loader and wire up models_64x64.
        raise NotImplementedError('The celeba dataset is not supported yet')

    raise ValueError("Unknown dataset %r; expected 'mnist' or 'celeba'" % (dataset_name,))
56 |
--------------------------------------------------------------------------------