├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── imlib
│   ├── __init__.py
│   ├── basic.py
│   ├── dtype.py
│   ├── encode.py
│   └── transform.py
├── model.py
├── pics
│   ├── bangs.png
│   └── eyeglasses.png
├── pylib
│   ├── __init__.py
│   ├── path.py
│   └── timer.py
├── tflib
│   ├── __init__.py
│   ├── checkpoint.py
│   ├── collection.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── disk_image.py
│   │   ├── memory_data.py
│   │   ├── tfrecord.py
│   │   └── tfrecord_creator.py
│   ├── layers
│   │   ├── __init__.py
│   │   └── layers.py
│   ├── ops
│   │   ├── __init__.py
│   │   └── ops.py
│   ├── parallel.py
│   ├── utils.py
│   └── vision
│       ├── __init__.py
│       └── dataset
│           ├── __init__.py
│           └── mnist.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | /data/
4 | /output/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018 hezhenliang
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a
4 | copy of this software and associated documentation files (the "Software"),
5 | to deal in the Software without restriction, including without limitation
6 | the rights to use, copy, modify, merge, publish, distribute, sublicense,
7 | and/or sell copies of the Software, and to permit persons to whom the
8 | Software is furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19 | DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DTLC-GAN
2 |
3 | TensorFlow implementation of [DTLC-GAN (CVPR 2018):
4 | Generative Adversarial Image Synthesis with Decision Tree Latent Controller](https://arxiv.org/abs/1805.10603).
5 |
6 | ## Usage
7 |
8 | - Prerequisites
9 | - TensorFlow 1.9
10 | - Python 3.6
11 |
12 | - Training
13 | - Important Arguments (See the others in [train.py](train.py))
14 | - `--att`: attribute to learn (default: `''`)
15 | - `--ks`: # of children of each node at each layer (default: `[2, 3, 3]`)
16 | - `--lambdas`: loss weights of each layer (default: `[1.0, 1.0, 1.0]`)
17 | - `--n_d`: # of d steps in each iteration (default: `1`)
18 | - `--n_g`: # of g steps in each iteration (default: `1`)
19 | - `--loss_mode`: GAN loss (choices: `[gan, lsgan, wgan, hinge]`, default: `gan`)
20 | - `--gp_mode`: type of gradient penalty (choices: `[none, dragan, wgan-gp]`, default: `none`)
21 | - `--norm`: normalization (choices: `[batch_norm, instance_norm, layer_norm, none]`, default: `batch_norm`)
22 | - `--experiment_name`: name for current experiment (default: `default`)
23 | - Example
24 | ```console
25 | CUDA_VISIBLE_DEVICES=0 \
26 | python train.py \
27 | --att Eyeglasses \
28 | --ks 2 3 3 \
29 | --lambdas 1 1 1 \
30 | --n_d 1 \
31 | --n_g 1 \
32 | --loss_mode hinge \
33 | --gp_mode dragan \
34 | --norm layer_norm \
35 | --experiment_name att{Eyeglasses}_ks{2-3-3}_lambdas{1-1-1}_continuous_last{False}_loss{hinge}_gp{dragan}_norm{layer_norm}
36 | ```
37 |
38 | ## Dataset
39 |
40 | - [CelebA](http://openaccess.thecvf.com/content_iccv_2015/papers/Liu_Deep_Learning_Face_ICCV_2015_paper.pdf) dataset
41 | - [Images](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADSNUu0bseoCKuxuI5ZeTl1a/Img?dl=0&preview=img_align_celeba.zip) should be placed in ***./data/img_align_celeba/\*.jpg***
42 | - [Attribute labels](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAA8YmAHNNU6BEfWMPMfM6r9a/Anno?dl=0&preview=list_attr_celeba.txt) should be placed in ***./data/list_attr_celeba.txt***
43 | - if the above links are inaccessible, the alternatives are
44 | - ***img_align_celeba.zip***
45 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg or
46 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg
47 | - ***list_attr_celeba.txt***
48 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FAnno&parentPath=%2F or
49 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pOC0wOVZlQnFfaGs
50 |
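After downloading, the expected `./data` layout is as follows (a sketch implied by the paths above):

```
data
├── img_align_celeba
│   ├── 000001.jpg
│   ├── 000002.jpg
│   └── ...
└── list_attr_celeba.txt
```
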
51 | ## Exemplar Results
52 |
53 | 1. Eyeglasses, 3 layers
54 | 
55 | ![](pics/eyeglasses.png)
56 | 
57 | 2. Bangs, 3 layers
58 | 
59 | ![](pics/bangs.png)
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import multiprocessing
7 |
8 | import numpy as np
9 | import tensorflow as tf
10 | import tflib as tl
11 |
12 |
13 | _N_CPU = multiprocessing.cpu_count()
14 |
15 |
16 | class Celeba(tl.Dataset):
17 |
18 | att_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2,
19 | 'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6,
20 | 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10,
21 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13,
22 | 'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16,
23 | 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19,
24 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22,
25 | 'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25,
26 | 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28,
27 | 'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31,
28 | 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34,
29 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36,
30 | 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}
31 |
32 | def __init__(self,
33 | data_dir,
34 | atts,
35 | img_resize,
36 | batch_size,
37 | prefetch_batch=_N_CPU + 1,
38 | drop_remainder=True,
39 | num_threads=_N_CPU,
40 | shuffle=True,
41 | shuffle_buffer_size=None,
42 | repeat=-1,
43 | sess=None,
44 | split='train',
45 | crop=True):
46 | super(Celeba, self).__init__()
47 |
48 | list_file = os.path.join(data_dir, 'list_attr_celeba.txt')
49 | if crop:
50 | img_dir_jpg = os.path.join(data_dir, 'img_align_celeba')
51 | img_dir_png = os.path.join(data_dir, 'img_align_celeba_png')
52 | else:
53 | img_dir_jpg = os.path.join(data_dir, 'img_crop_celeba')
54 | img_dir_png = os.path.join(data_dir, 'img_crop_celeba_png')
55 |
56 | names = np.loadtxt(list_file, skiprows=2, usecols=[0], dtype=np.str)
57 | if os.path.exists(img_dir_png):
58 | img_paths = [os.path.join(img_dir_png, name.replace('jpg', 'png')) for name in names]
59 | else:  # fall back to the jpg directory
60 | img_paths = [os.path.join(img_dir_jpg, name) for name in names]
61 |
62 | att_id = [Celeba.att_dict[att] + 1 for att in atts]
63 | labels = np.loadtxt(list_file, skiprows=2, usecols=att_id, dtype=np.int64)
64 |
65 | if img_resize == 64:
66 | # crop as VAE/GAN does
67 | offset_h = 40
68 | offset_w = 15
69 | img_size = 148
70 | else:
71 | offset_h = 26
72 | offset_w = 3
73 | img_size = 170
74 |
75 | def _map_func(img, label):
76 | if crop:
77 | img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, img_size, img_size)
78 | # img = tf.image.resize_images(img, [img_resize, img_resize]) / 127.5 - 1
79 | # or
80 | img = tf.image.resize_images(img, [img_resize, img_resize], tf.image.ResizeMethod.BICUBIC)
81 | img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
82 | label = (label + 1) // 2  # map {-1, 1} labels to {0, 1}
83 | return img, label
84 |
85 | if split == 'test':
86 | drop_remainder = False
87 | shuffle = False
88 | repeat = 1
89 | img_paths = img_paths[182637:]
90 | labels = labels[182637:]
91 | elif split == 'val':
92 | img_paths = img_paths[182000:182637]
93 | labels = labels[182000:182637]
94 | else:
95 | img_paths = img_paths[:182000]
96 | labels = labels[:182000]
97 |
98 | dataset = tl.disk_image_batch_dataset(img_paths=img_paths,
99 | labels=labels,
100 | batch_size=batch_size,
101 | prefetch_batch=prefetch_batch,
102 | drop_remainder=drop_remainder,
103 | map_func=_map_func,
104 | num_threads=num_threads,
105 | shuffle=shuffle,
106 | shuffle_buffer_size=shuffle_buffer_size,
107 | repeat=repeat)
108 | self._bulid(dataset, sess)
109 |
110 | self._img_num = len(img_paths)
111 |
112 | def __len__(self):
113 | return self._img_num
114 |
115 |
116 | if __name__ == '__main__':
117 | import imlib as im
118 | atts = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses',
119 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young']
120 | data = Celeba('./data', atts, 128, 32, split='val')
121 | batch = data.get_next()
122 | print(len(data))
123 | print(batch[1][1], batch[1].dtype)
124 | print(batch[0].min(), batch[0].max(), batch[0].dtype)
125 | im.imshow(batch[0][1])
126 | im.show()
127 |
--------------------------------------------------------------------------------
/imlib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from imlib.basic import *
6 | from imlib.dtype import *
7 | from imlib.encode import *
8 | from imlib.transform import *
9 |
--------------------------------------------------------------------------------
/imlib/basic.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import skimage.io as iio
7 |
8 | from imlib.dtype import im2float
9 |
10 |
11 | def imread(path, as_gray=False):
12 | """Read image.
13 |
14 | Returns:
15 | Float64 image in [-1.0, 1.0] (uint8/uint16 inputs are rescaled; float inputs are returned as-is).
16 | """
17 | image = iio.imread(path, as_gray)
18 | if image.dtype == np.uint8:
19 | image = image / 127.5 - 1
20 | elif image.dtype == np.uint16:
21 | image = image / 32767.5 - 1
22 | return image
23 |
24 |
25 | def imwrite(image, path):
26 | """Save an [-1.0, 1.0] image."""
27 | iio.imsave(path, im2float(image))
28 |
29 |
30 | def imshow(image):
31 | """Show a [-1.0, 1.0] image."""
32 | iio.imshow(im2float(image))
33 |
34 |
35 | show = iio.show
36 |
--------------------------------------------------------------------------------
/imlib/dtype.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 |
7 |
8 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
9 | # check type
10 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!'
11 |
12 | # check dtype
13 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes]
14 | assert images.dtype in dtypes, 'dtype of `images` should be one of %s!' % dtypes
15 |
16 | # check nan and inf
17 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!'
18 |
19 | # check value
20 | if min_value not in [None, -np.inf]:
21 | l = '[' + str(min_value)
22 | else:
23 | l = '(-inf'
24 | min_value = -np.inf
25 | if max_value not in [None, np.inf]:
26 | r = str(max_value) + ']'
27 | else:
28 | r = 'inf)'
29 | max_value = np.inf
30 | assert np.min(images) >= min_value - 1e-5 and np.max(images) <= max_value + 1e-5, \
31 | '`images` should be in the range of %s!' % (l + ',' + r)
32 |
33 |
34 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
35 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype."""
36 | _check(images, [np.float32, np.float64], -1.0, 1.0)
37 | dtype = dtype if dtype else images.dtype
38 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)
39 |
40 |
41 | def float2im(images):
42 | """Transform images from [0, 1.0] to [-1.0, 1.0]."""
43 | _check(images, [np.float32, np.float64], 0.0, 1.0)
44 | return images * 2 - 1.0
45 |
46 |
47 | def float2uint(images):
48 | """Transform images from [0, 1.0] to uint8."""
49 | _check(images, [np.float32, np.float64], 0.0, 1.0)
50 | return (images * 255).astype(np.uint8)
51 |
52 |
53 | def im2uint(images):
54 | """Transform images from [-1.0, 1.0] to uint8."""
55 | return to_range(images, 0, 255, np.uint8)
56 |
57 |
58 | def im2float(images):
59 | """Transform images from [-1.0, 1.0] to [0.0, 1.0]."""
60 | return to_range(images, 0.0, 1.0)
61 |
62 |
63 | def uint2im(images):
64 | """Transform images from uint8 to [-1.0, 1.0] of float64."""
65 | _check(images, np.uint8)
66 | return images / 127.5 - 1.0
67 |
68 |
69 | def uint2float(images):
70 | """Transform images from uint8 to [0.0, 1.0] of float64."""
71 | _check(images, np.uint8)
72 | return images / 255.0
73 |
--------------------------------------------------------------------------------
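A minimal round-trip sketch of the converters above (an editorial example, not part of the repo; assumes `imlib` is importable):

```python
import numpy as np
import imlib as im

# a float image in the library's working range [-1.0, 1.0]
x = np.random.uniform(-1.0, 1.0, size=(64, 64, 3)).astype(np.float32)

u = im.im2uint(x)   # [-1, 1] float -> [0, 255] uint8
y = im.uint2im(u)   # [0, 255] uint8 -> [-1, 1] float64

assert u.dtype == np.uint8 and y.dtype == np.float64
assert np.abs(x - y).max() <= 2.0 / 255  # uint8 quantization error only
```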
/imlib/encode.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 |
7 | import numpy as np
8 |
9 | from imlib.dtype import im2uint, uint2im
10 | from PIL import Image
11 |
12 |
13 | def imencode(image, format='PNG', quality=95):
14 | """Encode an [-1.0, 1.0] image into byte string.
15 |
16 | Args:
17 | format : 'PNG' or 'JPEG'.
18 | quality : Only for 'JPEG'.
19 |
20 | Returns:
21 | Byte string.
22 | """
23 | byte_io = io.BytesIO()
24 | image = Image.fromarray(im2uint(image))
25 | image.save(byte_io, format=format, quality=quality)
26 | byte_string = byte_io.getvalue()
27 | return byte_string
28 |
29 |
30 | def imdecode(byte_string):
31 | """Decode byte string to float64 image in [-1.0, 1.0].
32 |
33 | Args:
34 | byte_string : Byte string.
35 |
36 | Returns:
37 | A float64 image in [-1.0, 1.0].
38 | """
39 | byte_io = io.BytesIO()
40 | byte_io.write(byte_string)
41 | image = np.array(Image.open(byte_io))
42 | image = uint2im(image)
43 | return image
44 |
--------------------------------------------------------------------------------
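A round-trip sketch for these helpers (editorial example; requires Pillow, as does the module itself):

```python
import numpy as np
import imlib as im

x = np.random.uniform(-1.0, 1.0, size=(32, 32, 3))  # float64 in [-1, 1]

png_bytes = im.imencode(x, format='PNG')  # lossless encoding
y = im.imdecode(png_bytes)                # back to float64 in [-1, 1]

# PNG preserves the image up to the uint8 quantization done by imencode
assert np.abs(x - y).max() <= 2.0 / 255
```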
/imlib/transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import numpy as np
6 | import skimage.color as color
7 | import skimage.transform as transform
8 |
9 |
10 | rgb2gray = color.rgb2gray
11 | gray2rgb = color.gray2rgb
12 |
13 | imresize = transform.resize
14 | imrescale = transform.rescale
15 |
16 |
17 | def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0):
18 | """Merge images into one big image of shape (n_row * h) * (n_col * w).
19 | 
20 | `images` is in shape of N * H * W (* C), where C is 1 or 3.
21 | """
22 | n = images.shape[0]
23 | if n_row:
24 | n_row = max(min(n_row, n), 1)
25 | n_col = int(n - 0.5) // n_row + 1
26 | elif n_col:
27 | n_col = max(min(n_col, n), 1)
28 | n_row = int(n - 0.5) // n_col + 1
29 | else:
30 | n_row = int(n ** 0.5)
31 | n_col = int(n - 0.5) // n_row + 1
32 |
33 | h, w = images.shape[1], images.shape[2]
34 | shape = (h * n_row + padding * (n_row - 1),
35 | w * n_col + padding * (n_col - 1))
36 | if images.ndim == 4:
37 | shape += (images.shape[3],)
38 | img = np.full(shape, pad_value, dtype=images.dtype)
39 |
40 | for idx, image in enumerate(images):
41 | i = idx % n_col
42 | j = idx // n_col
43 | img[j * (h + padding):j * (h + padding) + h,
44 | i * (w + padding):i * (w + padding) + w, ...] = image
45 |
46 | return img
47 |
--------------------------------------------------------------------------------
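A quick sketch of `immerge` (editorial example): tiling a batch of samples into one grid image.

```python
import numpy as np
import imlib as im

# 16 fake samples in [-1, 1], shape N * H * W * C
images = np.random.uniform(-1.0, 1.0, size=(16, 32, 32, 3))

grid = im.immerge(images, n_row=4, padding=2, pad_value=-1.0)
print(grid.shape)  # (134, 134, 3): 4 * 32 plus 3 gaps of 2 pixels per axis
# im.imwrite(grid, 'samples.png')  # hypothetical output path
```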
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 | import tensorflow.contrib.slim as slim
10 | import tflib as tl
11 |
12 |
13 | # ==============================================================================
14 | # = alias =
15 | # ==============================================================================
16 |
17 | conv = partial(slim.conv2d, activation_fn=None)
18 | dconv = partial(slim.conv2d_transpose, activation_fn=None)
19 | fc = partial(tl.flatten_fully_connected, activation_fn=None)
20 | relu = tf.nn.relu
21 | lrelu = tf.nn.leaky_relu
22 | # batch_norm = partial(slim.batch_norm, scale=True)
23 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None)
24 | layer_norm = slim.layer_norm
25 | instance_norm = slim.instance_norm
26 |
27 |
28 | # ==============================================================================
29 | # = models =
30 | # ==============================================================================
31 |
32 | def _get_norm_fn(norm_name, is_training):
33 | if norm_name == 'none':
34 | norm = None
35 | elif norm_name == 'batch_norm':
36 | norm = partial(batch_norm, is_training=is_training)
37 | elif norm_name == 'instance_norm':
38 | norm = instance_norm
39 | elif norm_name == 'layer_norm':
40 | norm = layer_norm
41 | return norm
42 |
43 |
44 | def G(z, c, dim=64, is_training=True):
45 | norm = _get_norm_fn('batch_norm', is_training)
46 | fc_norm_relu = partial(fc, normalizer_fn=norm, activation_fn=relu)
47 | dconv_norm_relu = partial(dconv, normalizer_fn=norm, activation_fn=relu)
48 |
49 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE):
50 | y = tf.concat([z, c], axis=1)
51 | y = fc_norm_relu(y, 4 * 4 * dim * 8)
52 | y = tf.reshape(y, [-1, 4, 4, dim * 8])
53 | y = dconv_norm_relu(y, dim * 4, 4, 2)
54 | y = dconv_norm_relu(y, dim * 2, 4, 2)
55 | y = dconv_norm_relu(y, dim * 1, 4, 2)
56 | x = tf.tanh(dconv(y, 3, 4, 2))
57 | return x
58 |
59 |
60 | def D(x, c_dim, dim=64, norm_name='batch_norm', is_training=True):
61 | norm = _get_norm_fn(norm_name, is_training)
62 | conv_norm_lrelu = partial(conv, normalizer_fn=norm, activation_fn=lrelu)
63 |
64 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE):
65 | y = conv_norm_lrelu(x, dim, 4, 2)
66 | y = conv_norm_lrelu(y, dim * 2, 4, 2)
67 | y = conv_norm_lrelu(y, dim * 4, 4, 2)
68 | y = conv_norm_lrelu(y, dim * 8, 4, 2)
69 | logit = fc(y, 1)
70 | c_logit = fc(y, c_dim)
71 | return logit, c_logit
72 |
73 |
74 | # ==============================================================================
75 | # = loss function =
76 | # ==============================================================================
77 |
78 | def get_loss_fn(mode):
79 | if mode == 'gan':
80 | def d_loss_fn(r_logit, f_logit):
81 | r_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(r_logit), r_logit)
82 | f_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(f_logit), f_logit)
83 | return r_loss, f_loss
84 |
85 | def g_loss_fn(f_logit):
86 | f_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(f_logit), f_logit)
87 | return f_loss
88 |
89 | elif mode == 'lsgan':
90 | def d_loss_fn(r_logit, f_logit):
91 | r_loss = tf.losses.mean_squared_error(tf.ones_like(r_logit), r_logit)
92 | f_loss = tf.losses.mean_squared_error(tf.zeros_like(f_logit), f_logit)
93 | return r_loss, f_loss
94 |
95 | def g_loss_fn(f_logit):
96 | f_loss = tf.losses.mean_squared_error(tf.ones_like(f_logit), f_logit)
97 | return f_loss
98 |
99 | elif mode == 'wgan':
100 | def d_loss_fn(r_logit, f_logit):
101 | r_loss = - tf.reduce_mean(r_logit)
102 | f_loss = tf.reduce_mean(f_logit)
103 | return r_loss, f_loss
104 |
105 | def g_loss_fn(f_logit):
106 | f_loss = - tf.reduce_mean(f_logit)
107 | return f_loss
108 |
109 | elif mode == 'hinge':
110 | def d_loss_fn(r_logit, f_logit):
111 | r_loss = tf.reduce_mean(tf.maximum(1 - r_logit, 0))
112 | f_loss = tf.reduce_mean(tf.maximum(1 + f_logit, 0))
113 | return r_loss, f_loss
114 |
115 | def g_loss_fn(f_logit):
116 | # f_loss = tf.reduce_mean(tf.maximum(1 - f_logit, 0))
117 | f_loss = tf.reduce_mean(- f_logit)
118 | return f_loss
119 |
120 | return d_loss_fn, g_loss_fn
121 |
122 |
123 | # ==============================================================================
124 | # = others =
125 | # ==============================================================================
126 |
127 | def gradient_penalty(f, real, fake, mode):
128 | def _gradient_penalty(f, real, fake=None):
129 | def _interpolate(a, b=None):
130 | with tf.name_scope('interpolate'):
131 | if b is None: # interpolation in DRAGAN
132 | beta = tf.random_uniform(shape=tf.shape(a), minval=0., maxval=1.)
133 | _, variance = tf.nn.moments(a, list(range(a.shape.ndims)))
134 | b = a + 0.5 * tf.sqrt(variance) * beta
135 | shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1)
136 | alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.)
137 | inter = a + alpha * (b - a)
138 | inter.set_shape(a.get_shape().as_list())
139 | return inter
140 |
141 | with tf.name_scope('gradient_penalty'):
142 | x = _interpolate(real, fake)
143 | pred = f(x)
144 | if isinstance(pred, tuple):
145 | pred = pred[0]
146 | grad = tf.gradients(pred, x)[0]
147 | norm = tf.norm(slim.flatten(grad), axis=1)
148 | gp = tf.reduce_mean((norm - 1.)**2)
149 | return gp
150 |
151 | if mode == 'none':
152 | gp = tf.constant(0, dtype=tf.float32)
153 | elif mode == 'wgan-gp':
154 | gp = _gradient_penalty(f, real, fake)
155 | elif mode == 'dragan':
156 | gp = _gradient_penalty(f, real)
157 |
158 | return gp
159 |
160 |
161 | def sample_c(ks, c_1=None, continuous_last=False):
162 | assert c_1 is None or ks[0] == len(c_1), '`ks[0]` is inconsistent with `c_1`!'
163 |
164 | c_tree = [[np.array([1.])]]
165 | mask_tree = []
166 |
167 | for l, k in enumerate(ks):
168 | if c_1 is not None and l == 0:
169 | c_l = [c_1]
170 | mask_l = [np.ones_like(c_1)]
171 | else:
172 | c_l = []
173 | mask_l = []
174 | for i in range(len(c_tree[-1])):
175 | for j in range(len(c_tree[-1][-1])):  # all nodes in a layer share the same width
176 | if c_tree[-1][i][j] == 1.:
177 | if continuous_last is True and l == len(ks) - 1:
178 | c_l.append(np.random.uniform(-1, 1, size=[k]))
179 | else:
180 | c_l.append(np.eye(k)[np.random.randint(k)])
181 | mask_l.append(np.ones([k]))
182 | else:
183 | c_l.append(np.zeros([k]))
184 | mask_l.append(np.zeros([k]))
185 | c_tree.append(c_l)
186 | mask_tree.append(mask_l)
187 |
188 | c_tree[0:1] = []  # drop the constant root node
189 | c = np.concatenate([k for l in c_tree for k in l])
190 | mask = np.concatenate([k for l in mask_tree for k in l])
191 |
192 | return c, mask, c_tree, mask_tree
193 |
194 |
195 | def traversal_trees(ks, continuous_last=False):
196 | trees = []
197 | if len(ks) == 1:
198 | if continuous_last:
199 | trees.append([[np.random.uniform(-1, 1, size=[ks[0]])]])
200 | else:
201 | for i in range(ks[0]):
202 | trees.append([[np.eye(ks[0])[i]]])
203 | else:
204 | def _merge_trees(trees):
205 | tree = []
206 | for l in range(len(trees[0])):
207 | tree_l = []
208 | for t in trees:
209 | tree_l += t[l]
210 | tree.append(tree_l)
211 | return tree
212 |
213 | def _zero_tree(tree):
214 | zero_tree = []
215 | for l in tree:
216 | zero_tree_l = []
217 | for i in l:
218 | zero_tree_l.append(i * 0.)
219 | zero_tree.append(zero_tree_l)
220 | return zero_tree
221 |
222 | for i in range(ks[0]):
223 | trees_i = []
224 | sub_trees, _ = traversal_trees(ks[1:], continuous_last=continuous_last)
225 | for j, s_t in enumerate(sub_trees):
226 | to_merge = [_zero_tree(s_t)] * ks[0]
227 | to_merge[i] = s_t
228 | sub_trees[j] = _merge_trees(to_merge)
229 | for s_t in sub_trees:
230 | trees_i.append([[np.eye(ks[0])[i]]] + s_t)
231 | trees += trees_i
232 |
233 | cs = []
234 | for t in trees:
235 | cs.append(np.concatenate([k for l in t for k in l]))
236 | return trees, cs
237 |
238 |
239 | def to_tree(x, ks):
240 | size_splits = []
241 | n_l = 1
242 | for k in ks:
243 | for _ in range(n_l):
244 | size_splits.append(k)
245 | n_l *= k
246 |
247 | splits = tf.split(x, size_splits, axis=1)
248 |
249 | tree = []
250 | n_l = 1
251 | i = 0
252 | for k in ks:
253 | tree_l = []
254 | for _ in range(n_l):
255 | tree_l.append(splits[i])
256 | i += 1
257 | n_l *= k
258 | tree.append(tree_l)
259 |
260 | return tree
261 |
262 |
263 | def tree_loss(logits, c, mask, ks, continuous_last=False):
264 | logits_tree = to_tree(logits, ks)
265 | c_tree = to_tree(c, ks)
266 | mask_tree = to_tree(mask, ks)
267 |
268 | losses = []
269 | for l, logits_l, c_l, mask_l in zip(range(len(logits_tree)), logits_tree, c_tree, mask_tree):
270 | loss_l = 0
271 | for lo, c, m in zip(logits_l, c_l, mask_l):
272 | weights = tf.reduce_mean(m, axis=1)
273 | if continuous_last is True and l == len(ks) - 1:
274 | loss_l += tf.losses.mean_squared_error(c, lo, weights=weights)
275 | else:
276 | loss_l += tf.losses.softmax_cross_entropy(c, lo, weights=weights)
277 | losses.append(loss_l)
278 | return losses
279 |
280 | if __name__ == '__main__':
281 | from pprint import pprint as pp
282 | s = tf.Session()
283 | ks = [2, 2, 2]
284 | c_1 = np.array([1., 0])
285 | continuous_last = False
286 | if len(ks) == 1 and c_1 is not None:
287 | continuous_last = False
288 |
289 | c, mask, c_tree, mask_tree = sample_c(ks, c_1, continuous_last)
290 | pp(c)
291 | pp(mask)
292 | pp(c_tree)
293 | pp(mask_tree)
294 |
295 | tree = to_tree(tf.constant(np.array([c]), dtype=tf.float32), ks)
296 | pp(tree)
297 | pp(s.run(tree))
298 |
299 | pp(tree_loss(tf.constant(np.array([c]), dtype=tf.float32),
300 | tf.constant(np.array([c]), dtype=tf.float32),
301 | tf.constant(np.array([mask]), dtype=tf.float32),
302 | ks,
303 | continuous_last))
304 |
305 | pp(s.run(tf.losses.softmax_cross_entropy([[0, 10]], [[0.5, 0.5]])))
306 | pp(s.run(tf.losses.mean_squared_error([[2, 1]], [[0, 0]])))
307 |
308 | for tree, c in zip(*traversal_trees(ks, continuous_last=continuous_last)):
309 | pp(tree)
310 | pp(c)
311 |
--------------------------------------------------------------------------------
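A worked sketch of the latent-code layout (editorial example; importing `model` requires this repo's TF 1.x environment). For the default tree `ks = [2, 3, 3]`, the layers hold 2, 2 * 3 = 6, and 6 * 3 = 18 code units, so `c` has dimension 26:

```python
from model import sample_c, traversal_trees

ks = [2, 3, 3]  # branching factors of the decision tree latent controller

c, mask, c_tree, mask_tree = sample_c(ks)
print(c.shape, mask.shape)       # (26,) (26,)
print([len(l) for l in c_tree])  # [1, 2, 6] nodes per layer

# exhaustive traversal of all 2 * 3 * 3 = 18 root-to-leaf assignments
trees, cs = traversal_trees(ks)
print(len(cs), cs[0].shape)      # 18 (26,)
```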
/pics/bangs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/DTLC-GAN-Tensorflow/ca053af68a47e4678c172d756d64e3576a51e009/pics/bangs.png
--------------------------------------------------------------------------------
/pics/eyeglasses.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/DTLC-GAN-Tensorflow/ca053af68a47e4678c172d756d64e3576a51e009/pics/eyeglasses.png
--------------------------------------------------------------------------------
/pylib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from pylib.path import *
6 | from pylib.timer import *
7 |
--------------------------------------------------------------------------------
/pylib/path.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import fnmatch
6 | import os
7 | import sys
8 |
9 |
10 | def add_path(paths):
11 | if not isinstance(paths, (list, tuple)):
12 | paths = [paths]
13 | for path in paths:
14 | if path not in sys.path:
15 | sys.path.insert(0, path)
16 |
17 |
18 | def mkdir(paths):
19 | if not isinstance(paths, (list, tuple)):
20 | paths = [paths]
21 | for path in paths:
22 | if not os.path.isdir(path):
23 | os.makedirs(path)
24 |
25 |
26 | def split(path):
27 | dir, name_ext = os.path.split(path)
28 | name, ext = os.path.splitext(name_ext)
29 | return dir, name, ext
30 |
31 |
32 | def directory(path):
33 | return split(path)[0]
34 |
35 |
36 | def name(path):
37 | return split(path)[1]
38 |
39 |
40 | def ext(path):
41 | return split(path)[2]
42 |
43 |
44 | def name_ext(path):
45 | return ''.join(split(path)[1:])
46 |
47 |
48 | abspath = asbpath = os.path.abspath  # 'asbpath' kept for backward compatibility
49 |
50 |
51 | join = os.path.join
52 |
53 |
54 | def match(dir, pat, recursive=False):
55 | if recursive:
56 | iterator = os.walk(dir)
57 | else:
58 | iterator = [next(os.walk(dir))]
59 | matches = []
60 | for root, _, file_names in iterator:
61 | for file_name in fnmatch.filter(file_names, pat):
62 | matches.append(os.path.join(root, file_name))
63 | return matches
64 |
--------------------------------------------------------------------------------
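A small usage sketch (editorial example; the paths are hypothetical):

```python
import pylib

pylib.mkdir('./output/sample')  # like os.makedirs, but skips existing dirs

d, n, e = pylib.split('./data/img/000001.jpg')
print(d, n, e)  # ./data/img 000001 .jpg

jpgs = pylib.match('./data/img', '*.jpg')                  # top level only
all_jpgs = pylib.match('./data', '*.jpg', recursive=True)  # walk subdirs
```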
/pylib/timer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import datetime
6 | import timeit
7 |
8 |
9 | class Timer(object):
10 | """A timer as a context manager.
11 |
12 | Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py.
13 |
14 | Wraps around a timer. A custom timer can be passed
15 | to the constructor. The default timer is timeit.default_timer.
16 |
17 | Note that the latter measures wall clock time, not CPU time!
18 | On Python 3 it corresponds to time.perf_counter; on Python 2
19 | it is time.time on Unix and time.clock on Windows.
20 |
21 | Arguments:
22 | is_output : If True, print output after exiting context.
23 | fmt : 'ms', 's' or 'datetime'
24 | """
25 |
26 | def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'):
27 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!"
28 | self._timer = timer
29 | self._is_output = is_output
30 | self._fmt = fmt
31 |
32 | def __enter__(self):
33 | """Start the timer in the context manager scope."""
34 | self.start()
35 | return self
36 |
37 | def __exit__(self, exc_type, exc_value, exc_traceback):
38 | """Set the end time."""
39 | if self._is_output:
40 | print(str(self))
41 |
42 | def __str__(self):
43 | if self._fmt != 'datetime':
44 | return '%s %s' % (self.elapsed, self._fmt)
45 | else:
46 | return str(self.elapsed)
47 |
48 | def start(self):
49 | self.start_time = self._timer()
50 |
51 | @property
52 | def elapsed(self):
53 | """Return the current elapsed time since start."""
54 | e = self._timer() - self.start_time
55 |
56 | if self._fmt == 'ms':
57 | return e * 1000
58 | elif self._fmt == 's':
59 | return e
60 | elif self._fmt == 'datetime':
61 | return datetime.timedelta(seconds=e)
62 |
63 |
64 | def timer(**timer_kwargs):
65 | """Function decorator displaying the function execution time.
66 |
67 | All kwargs are the arguments taken by the Timer class constructor.
68 | """
69 | # store Timer kwargs in local variable so the namespace isn't polluted
70 | # by different level args and kwargs
71 |
72 | def wrapped_f(f):
73 | def wrapped(*args, **kwargs):
74 | fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]'
75 | with Timer(**timer_kwargs) as t:
76 | out = f(*args, **kwargs)
77 | context = {'function_name': f.__name__, 'execution_time': str(t)}
78 | print(fmt % context)
79 | return out
80 | return wrapped
81 |
82 | return wrapped_f
83 |
84 | if __name__ == "__main__":
85 | import time
86 |
87 | # 1
88 | print(1)
89 | with Timer() as t:
90 | time.sleep(1)
91 | print(t)
92 | time.sleep(1)
93 |
94 | with Timer(fmt='datetime') as t:
95 | time.sleep(1)
96 |
97 | # 2
98 | print(2)
99 | t = Timer(fmt='ms')
100 | t.start()
101 | time.sleep(2)
102 | print(t)
103 |
104 | t = Timer(fmt='datetime')
105 | t.start()
106 | time.sleep(1)
107 | print(t)
108 |
109 | # 3
110 | print(3)
111 |
112 | @timer(fmt='ms')
113 | def blah():
114 | time.sleep(2)
115 |
116 | blah()
117 |
--------------------------------------------------------------------------------
/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.checkpoint import *
6 | from tflib.collection import *
7 | from tflib.data import *
8 | from tflib.layers import *
9 | from tflib.parallel import *
10 | from tflib.utils import *
11 | from tflib.vision import *
12 |
--------------------------------------------------------------------------------
/tflib/checkpoint.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 |
7 | import tensorflow as tf
8 |
9 |
10 | def load_checkpoint(ckpt_dir_or_file, session, var_list=None):
11 | """Load checkpoint.
12 |
13 | This function adds extra restore ops to the graph, so it is
14 | usually better to use tf.train.init_from_checkpoint(...).
15 | """
16 | if os.path.isdir(ckpt_dir_or_file):
17 | ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
18 |
19 | restorer = tf.train.Saver(var_list)
20 | restorer.restore(session, ckpt_dir_or_file)
21 | print(' [*] Loading checkpoint succeeded! Copied variables from %s!' % ckpt_dir_or_file)
22 |
23 |
24 | def init_from_checkpoint(ckpt_dir_or_file, assignment_map={'/': '/'}):
25 | # Use the checkpoint values for the variables' initializers. Note that this
26 | # function just changes the initializers but does not actually run them, and
27 | # you should still run the initializers manually.
28 | tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map)
29 | print(' [*] Loading checkpoint succeeded! Copied variables from %s!' % ckpt_dir_or_file)
30 |
--------------------------------------------------------------------------------
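A minimal restore sketch (editorial example; the checkpoint directory is hypothetical):

```python
import tensorflow as tf
import tflib as tl

w = tf.get_variable('w', shape=[2, 2])  # a variable to restore
sess = tf.Session()

# accepts a checkpoint directory (the latest checkpoint is picked) or a file
tl.load_checkpoint('./output/default/checkpoints', sess, var_list=[w])
```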
/tflib/collection.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from functools import partial
6 |
7 | import tensorflow as tf
8 |
9 |
10 | def tensors_filter(tensors,
11 | includes='',
12 | includes_combine_type='or',
13 | excludes=[],
14 | excludes_combine_type='or'):
15 | assert isinstance(tensors, (list, tuple)), '`tensors` should be a list or tuple!'
16 | assert isinstance(includes, (str, list, tuple)), '`includes` should be a string or a list(tuple) of strings!'
17 | assert includes_combine_type in ['or', 'and'], "`includes_combine_type` should be 'or' or 'and'!"
18 | assert isinstance(excludes, (str, list, tuple)), '`excludes` should be a string or a list(tuple) of strings!'
19 | assert excludes_combine_type in ['or', 'and'], "`excludes_combine_type` should be 'or' or 'and'!"
20 |
21 | def _select(filters, combine_type):
22 | if isinstance(filters, str):
23 | filters = [filters]
24 |
25 | selected = []
26 | for t in tensors:
27 | if combine_type == 'or':
28 | for filt in filters:
29 | if filt in t.name:
30 | selected.append(t)
31 | break
32 | elif combine_type == 'and':
33 | all_pass = bool(filters)  # an empty `filters` list selects nothing
34 | for filt in filters:
35 | if filt not in t.name:
36 | all_pass = False
37 | break
38 | if all_pass:
39 | selected.append(t)
40 |
41 | return selected
42 |
43 | include_set = _select(includes, includes_combine_type)
44 | exclude_set = _select(excludes, excludes_combine_type)
45 | select_set = [t for t in include_set if t not in exclude_set]
46 |
47 | return select_set
48 |
49 |
50 | def get_collection(key,
51 | includes='',
52 | includes_combine_type='or',
53 | excludes=[],
54 | excludes_combine_type='and'):
55 | tensors = tf.get_collection(key)
56 | return tensors_filter(tensors,
57 | includes,
58 | includes_combine_type,
59 | excludes,
60 | excludes_combine_type)
61 |
62 | global_variables = partial(get_collection, key=tf.GraphKeys.GLOBAL_VARIABLES)
63 | trainable_variables = partial(get_collection, key=tf.GraphKeys.TRAINABLE_VARIABLES)
64 | update_ops = partial(get_collection, key=tf.GraphKeys.UPDATE_OPS)
65 |
--------------------------------------------------------------------------------
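A small sketch of the collection filters (editorial example; the variable names are made up):

```python
import tensorflow as tf
import tflib as tl

tf.get_variable('G/fc/weights', shape=[4, 4])
tf.get_variable('D/conv/weights', shape=[3, 3])

g_vars = tl.trainable_variables(includes='G/')  # names containing 'G/'
no_d = tl.global_variables(excludes='D/')       # everything except 'D/...'
print([v.name for v in g_vars])  # ['G/fc/weights:0']
print([v.name for v in no_d])    # ['G/fc/weights:0']
```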
/tflib/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.data.dataset import *
6 | from tflib.data.disk_image import *
7 | from tflib.data.memory_data import *
8 | from tflib.data.tfrecord import *
9 | from tflib.data.tfrecord_creator import *
10 |
--------------------------------------------------------------------------------
/tflib/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 | import tensorflow.contrib.eager as tfe
9 |
10 | from tflib.utils import session
11 |
12 |
13 | _N_CPU = multiprocessing.cpu_count()
14 |
15 |
16 | def batch_dataset(dataset,
17 | batch_size,
18 | prefetch_batch=_N_CPU + 1,
19 | drop_remainder=True,
20 | filter=None,
21 | map_func=None,
22 | num_threads=_N_CPU,
23 | shuffle=True,
24 | shuffle_buffer_size=None,
25 | repeat=-1):
26 | if filter:
27 | dataset = dataset.filter(filter)
28 |
29 | if map_func:
30 | dataset = dataset.map(map_func, num_parallel_calls=num_threads)
31 |
32 | if shuffle:
33 | if shuffle_buffer_size is None:
34 | shuffle_buffer_size = batch_size * 100
35 | dataset = dataset.shuffle(shuffle_buffer_size)
36 |
37 | if drop_remainder:
38 | dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
39 | else:
40 | dataset = dataset.batch(batch_size)
41 |
42 | dataset = dataset.repeat(repeat).prefetch(prefetch_batch)
43 |
44 | return dataset
45 |
46 |
47 | class Dataset(object):
48 |
49 | def __init__(self):
50 | self._dataset = None
51 | self._iterator = None
52 | self._batch_op = None
53 | self._sess = None
54 |
55 | self._is_eager = tf.executing_eagerly()
56 | self._eager_iterator = None
57 |
58 | def __del__(self):
59 | if self._sess:
60 | self._sess.close()
61 |
62 | def __iter__(self):
63 | return self
64 |
65 | def __next__(self):
66 | try:
67 | b = self.get_next()
68 | except:  # out-of-range (graph mode) or StopIteration (eager mode)
69 | raise StopIteration
70 | else:
71 | return b
72 |
73 | next = __next__
74 |
75 | def get_next(self):
76 | if self._is_eager:
77 | return self._eager_iterator.get_next()
78 | else:
79 | return self._sess.run(self._batch_op)
80 |
81 | def reset(self, feed_dict={}):
82 | if self._is_eager:
83 | self._eager_iterator = tfe.Iterator(self._dataset)
84 | else:
85 | self._sess.run(self._iterator.initializer, feed_dict=feed_dict)
86 |
87 | def _bulid(self, dataset, sess=None):
88 | self._dataset = dataset
89 |
90 | if self._is_eager:
91 | self._eager_iterator = tfe.Iterator(dataset)
92 | else:
93 | self._iterator = dataset.make_initializable_iterator()
94 | self._batch_op = self._iterator.get_next()
95 | if sess:
96 | self._sess = sess
97 | else:
98 | self._sess = session()
99 |
100 | try:
101 | self.reset()
102 | except:  # initialization may need a feed_dict; call reset(feed_dict) manually in that case
103 | pass
104 |
105 | @property
106 | def dataset(self):
107 | return self._dataset
108 |
109 | @property
110 | def iterator(self):
111 | return self._iterator
112 |
113 | @property
114 | def batch_op(self):
115 | return self._batch_op
116 |
--------------------------------------------------------------------------------
/tflib/data/disk_image.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import tensorflow as tf
8 |
9 | from tflib.data.dataset import batch_dataset, Dataset
10 |
11 |
12 | _N_CPU = multiprocessing.cpu_count()
13 |
14 |
15 | def disk_image_batch_dataset(img_paths,
16 | batch_size,
17 | labels=None,
18 | prefetch_batch=_N_CPU + 1,
19 | drop_remainder=True,
20 | filter=None,
21 | map_func=None,
22 | num_threads=_N_CPU,
23 | shuffle=True,
24 | shuffle_buffer_size=None,
25 | repeat=-1):
26 | """Disk image batch dataset.
27 |
28 | This function is suitable for jpg and png files
29 |
30 | Arguments:
31 | img_paths : String list or 1-D tensor, each of which is an image path
32 | labels : Label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label
33 | """
34 | if labels is None:
35 | dataset = tf.data.Dataset.from_tensor_slices(img_paths)
36 | elif isinstance(labels, tuple):
37 | dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels))
38 | else:
39 | dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels))
40 |
41 | def parse_func(path, *label):
42 | img = tf.read_file(path)
43 | img = tf.image.decode_png(img, 3)
44 | return (img,) + label
45 |
46 | if map_func:
47 | def map_func_(*args):
48 | return map_func(*parse_func(*args))
49 | else:
50 | map_func_ = parse_func
51 |
52 | # dataset = dataset.map(parse_func, num_parallel_calls=num_threads) is slower
53 |
54 | dataset = batch_dataset(dataset,
55 | batch_size,
56 | prefetch_batch,
57 | drop_remainder,
58 | filter,
59 | map_func_,
60 | num_threads,
61 | shuffle,
62 | shuffle_buffer_size,
63 | repeat)
64 |
65 | return dataset
66 |
67 |
68 | class DiskImageData(Dataset):
69 | """DiskImageData.
70 |
71 | This class is suitable for jpg and png files
72 |
73 | Arguments:
74 | img_paths : String list or 1-D tensor, each of which is an image path
75 | labels : Label list or tensor, each of which is a corresponding label
76 | """
77 |
78 | def __init__(self,
79 | img_paths,
80 | batch_size,
81 | labels=None,
82 | prefetch_batch=_N_CPU + 1,
83 | drop_remainder=True,
84 | filter=None,
85 | map_func=None,
86 | num_threads=_N_CPU,
87 | shuffle=True,
88 | shuffle_buffer_size=None,
89 | repeat=-1,
90 | sess=None):
91 | super(DiskImageData, self).__init__()
92 | dataset = disk_image_batch_dataset(img_paths,
93 | batch_size,
94 | labels,
95 | prefetch_batch,
96 | drop_remainder,
97 | filter,
98 | map_func,
99 | num_threads,
100 | shuffle,
101 | shuffle_buffer_size,
102 | repeat)
103 | self._bulid(dataset, sess)
104 | self._n_data = len(img_paths)
105 |
106 | def __len__(self):
107 | return self._n_data
108 |
109 |
110 | if __name__ == '__main__':
111 | import glob
112 |
113 | import imlib as im
114 | import numpy as np
115 | import pylib
116 |
117 | paths = glob.glob('/home/hezhenliang/Resource/face/CelebA/origin/origin/processed_by_hezhenliang/align_celeba/img_align_celeba/*.jpg')
118 | paths = sorted(paths)[182637:]
119 | labels = list(range(len(paths)))
120 |
121 | def filter(x, y, *args):
122 | return tf.cond(y > 1, lambda: tf.constant(True), lambda: tf.constant(False))
123 |
124 | def map_func(x, *args):
125 | x = tf.image.resize_images(x, [256, 256])
126 | x = tf.to_float((x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)) * 2 - 1)
127 | return (x,) + args
128 |
129 | # tf.enable_eager_execution()  # incompatible with the tf.Session below
130 |
131 | s = tf.Session()
132 |
133 | data = DiskImageData(paths, 32, (labels, labels), filter=None, map_func=None, shuffle=True, sess=s)
134 |
135 | for _ in range(1000):
136 | with pylib.Timer():
137 | for i in range(100):
138 | b = data.get_next()
139 | # print(b[1][0])
140 | # print(b[2][0])
141 | # im.imshow(np.array(b[0][0]))
142 | # im.show()
143 | # data.reset()
144 |
--------------------------------------------------------------------------------
/tflib/data/memory_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import multiprocessing
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | from tflib.data.dataset import batch_dataset, Dataset
11 |
12 |
13 | _N_CPU = multiprocessing.cpu_count()
14 |
15 |
16 | def memory_data_batch_dataset(memory_data_dict,
17 | batch_size,
18 | prefetch_batch=_N_CPU + 1,
19 | drop_remainder=True,
20 | filter=None,
21 | map_func=None,
22 | num_threads=_N_CPU,
23 | shuffle=True,
24 | shuffle_buffer_size=None,
25 | repeat=-1):
26 | """Memory data batch dataset.
27 |
28 | `memory_data_dict` example:
29 | {'img': img_ndarray, 'label': label_ndarray} or
30 | {'img': img_tftensor, 'label': label_tftensor}
31 | * The value of each item of `memory_data_dict` is in shape of (N, ...).
32 | """
33 | dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict)
34 | dataset = batch_dataset(dataset,
35 | batch_size,
36 | prefetch_batch,
37 | drop_remainder,
38 | filter,
39 | map_func,
40 | num_threads,
41 | shuffle,
42 | shuffle_buffer_size,
43 | repeat)
44 | return dataset
45 |
46 |
47 | class MemoryData(Dataset):
48 | """MemoryData.
49 |
50 | `memory_data_dict` example:
51 | {'img': img_ndarray, 'label': label_ndarray} or
52 | {'img': img_tftensor, 'label': label_tftensor}
53 | * The value of each item of `memory_data_dict` is in shape of (N, ...).
54 | """
55 |
56 | def __init__(self,
57 | memory_data_dict,
58 | batch_size,
59 | prefetch_batch=_N_CPU + 1,
60 | drop_remainder=True,
61 | filter=None,
62 | map_func=None,
63 | num_threads=_N_CPU,
64 | shuffle=True,
65 | shuffle_buffer_size=None,
66 | repeat=-1,
67 | sess=None):
68 | super(MemoryData, self).__init__()
69 | dataset = memory_data_batch_dataset(memory_data_dict,
70 | batch_size,
71 | prefetch_batch,
72 | drop_remainder,
73 | filter,
74 | map_func,
75 | num_threads,
76 | shuffle,
77 | shuffle_buffer_size,
78 | repeat)
79 | self._bulid(dataset, sess)
80 | if isinstance(list(memory_data_dict.values())[0], np.ndarray):
81 | self._n_data = len(list(memory_data_dict.values())[0])
82 | else:
83 | self._n_data = list(memory_data_dict.values())[0].get_shape().as_list()[0]  # dict views are not indexable on Python 3
84 |
85 | def __len__(self):
86 | return self._n_data
87 |
88 | if __name__ == '__main__':
89 | data = {'a': np.array([1.0, 2, 3, 4, 5]),
90 | 'b': np.array([[1, 2],
91 | [2, 3],
92 | [3, 4],
93 | [4, 5],
94 | [5, 6]])}
95 |
96 | def filter(x):
97 | return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False))
98 |
99 | def map_func(x):
100 | x['a'] = x['a'] * 10
101 | return x
102 |
103 | # tf.enable_eager_execution()
104 |
105 | s = tf.Session()
106 |
107 | dataset = MemoryData(data,
108 | 2,
109 | filter=None,
110 | map_func=map_func,
111 | shuffle=True,
112 | shuffle_buffer_size=None,
113 | drop_remainder=True,
114 | repeat=4,
115 | sess=s)
116 |
117 | for i in range(5):
118 | print(list(map(dataset.get_next().__getitem__, ['b', 'a'])))  # list(...) for Python 3
119 |
120 | print([n.name for n in tf.get_default_graph().as_graph_def().node])
121 |
--------------------------------------------------------------------------------
/tflib/data/tfrecord.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | import glob
7 | import json
8 | import multiprocessing
9 | import os
10 |
11 | import tensorflow as tf
12 |
13 | from tflib.data.dataset import batch_dataset, Dataset
14 |
15 |
16 | _N_CPU = multiprocessing.cpu_count()
17 |
18 | _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024
19 |
20 | _DECODERS = {
21 | 'png': {'decoder': tf.image.decode_png, 'decode_param': dict()},
22 | 'jpg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()},
23 | 'jpeg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()},
24 | 'uint8': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.uint8)},
25 | 'int64': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.int64)},
26 | 'float32': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.float32)},
27 | }
28 |
29 |
30 | def tfrecord_batch_dataset(tfrecord_files,
31 | infos,
32 | compression_type,
33 | batch_size,
34 | prefetch_batch=_N_CPU + 1,
35 | drop_remainder=True,
36 | filter=None,
37 | map_func=None,
38 | num_threads=_N_CPU,
39 | shuffle=True,
40 | shuffle_buffer_size=None,
41 | repeat=-1):
42 | """Tfrecord batch dataset.
43 |
44 | `infos` example:
45 | [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]},
46 | {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}]
47 | """
48 | dataset = tf.data.TFRecordDataset(tfrecord_files,
49 | compression_type=compression_type,
50 | buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES)
51 |
52 | features = {}
53 | for info in infos:
54 | features[info['name']] = tf.FixedLenFeature([], tf.string)
55 |
56 | def parse_func(serialized_example):
57 | example = tf.parse_single_example(serialized_example, features=features)
58 |
59 | feature_dict = {}
60 | for info in infos:
61 | name = info['name']
62 | decoder = info['decoder']
63 | decode_param = info['decode_param']
64 | shape = info['shape']
65 |
66 | feature = decoder(example[name], **decode_param)
67 | feature = tf.reshape(feature, shape)
68 | feature_dict[name] = feature
69 |
70 | return feature_dict
71 |
72 | dataset = dataset.map(parse_func, num_parallel_calls=num_threads)
73 |
74 | dataset = batch_dataset(dataset,
75 | batch_size,
76 | prefetch_batch,
77 | drop_remainder,
78 | filter,
79 | map_func,
80 | num_threads,
81 | shuffle,
82 | shuffle_buffer_size,
83 | repeat)
84 |
85 | return dataset
86 |
87 |
88 | class TfrecordData(Dataset):
89 |
90 | def __init__(self,
91 | tfrecord_path,
92 | batch_size,
93 | prefetch_batch=_N_CPU + 1,
94 | drop_remainder=True,
95 | filter=None,
96 | map_func=None,
97 | num_threads=_N_CPU,
98 | shuffle=True,
99 | shuffle_buffer_size=None,
100 | repeat=-1,
101 | sess=None):
102 | super(TfrecordData, self).__init__()
103 |
104 | info_file = os.path.join(tfrecord_path, 'info.json')
105 | infos, self._data_num, compression_type = self._parse_json(info_file)
106 |
107 | self._shapes = {info['name']: tuple(info['shape']) for info in infos}
108 |
109 | tfrecord_files = sorted(glob.glob(os.path.join(tfrecord_path, '*.tfrecord')))
110 | dataset = tfrecord_batch_dataset(tfrecord_files,
111 | infos,
112 | compression_type,
113 | batch_size,
114 | prefetch_batch,
115 | drop_remainder,
116 | filter,
117 | map_func,
118 | num_threads,
119 | shuffle,
120 | shuffle_buffer_size,
121 | repeat)
122 |
123 | self._bulid(dataset, sess)
124 |
125 | def __len__(self):
126 | return self._data_num
127 |
128 | @property
129 | def shape(self):
130 | return self._shapes
131 |
132 | @staticmethod
133 | def _parse_old(json_file):
134 | with open(json_file.replace('info.json', 'info.txt')) as f:
135 | try: # older version 1
136 | infos = json.load(f)
137 | for info in infos[0:-1]:
138 | info['decoder'] = _DECODERS[info['dtype_or_format']]['decoder']
139 | info['decode_param'] = _DECODERS[info['dtype_or_format']]['decode_param']
140 | except: # older version 2
141 | f.seek(0)
142 | infos = ''
143 | for line in f.readlines():
144 | infos += line.strip('\n')
145 | infos = eval(infos)
146 |
147 | data_num = infos[-1]['data_num']
148 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[infos[-1]['compression_type']]
149 | infos[-1:] = []
150 |
151 | return infos, data_num, compression_type
152 |
153 | @staticmethod
154 | def _parse_json(json_file):
155 | try:
156 | with open(json_file) as f:
157 | info = json.load(f)
158 | infos = info['item']
159 | for i in infos:
160 | i['decoder'] = _DECODERS[i['dtype_or_format']]['decoder']
161 | i['decode_param'] = _DECODERS[i['dtype_or_format']]['decode_param']
162 | data_num = info['info']['data_num']
163 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[info['info']['compression_type']]
164 | except: # for older version
165 | infos, data_num, compression_type = TfrecordData._parse_old(json_file)
166 |
167 | return infos, data_num, compression_type
168 |
--------------------------------------------------------------------------------
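A reading sketch (editorial example; the folder is hypothetical and must contain the `info.json` and `*.tfrecord` shards written by the creators in `tfrecord_creator.py`):

```python
import tflib as tl

data = tl.TfrecordData('./data/celeba_tfrecord', batch_size=32, shuffle=True)
print(len(data), data.shape)  # number of examples, {name: shape} dict

batch = data.get_next()  # a dict of numpy arrays, e.g. batch['img']
```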
/tflib/data/tfrecord_creator.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import io
6 | import json
7 | import os
8 | import shutil
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from PIL import Image
14 | from tflib.data import tfrecord
15 |
16 | __metaclass__ = type  # Python 2 new-style classes; no effect on Python 3
17 |
18 |
19 | _ALLOWED_TYPES = tfrecord._DECODERS.keys()
20 |
21 |
22 | class BytesTfrecordCreator(object):
23 | """BytesTfrecordCreator.
24 |
25 | `infos` example:
26 | infos = [
27 | ['img', 'jpg', (64, 64, 3)],
28 | ['id', 'int64', ()],
29 | ['attr', 'int64', (40,)],
30 | ['point', 'float32', (5, 2)]
31 | ]
32 |
33 | `compression_type`:
34 | 0 : NONE
35 | 1 : ZLIB
36 | 2 : GZIP
37 | """
38 |
39 | def __init__(self,
40 | save_path,
41 | infos,
42 | size_each=None,
43 | compression_type=0,
44 | overwrite_existence=False):
45 | # overwrite existence
46 | if os.path.exists(save_path):
47 | if not overwrite_existence:
48 | raise Exception('%s exists!' % save_path)
49 | else:
50 | shutil.rmtree(save_path)
51 | os.makedirs(save_path)
52 | else:
53 | os.makedirs(save_path)
54 |
55 | self._save_path = save_path
56 |
57 | # add info
58 | self._infos = []
59 | self._info_names = []
60 | for info in infos:
61 | self._add_info(*info)
62 |
63 | self._data_num = 0
64 | self._tfrecord_num = 0
65 | self._size_each = size_each if size_each else 2147483647  # max examples per tfrecord shard
66 | self._writer = None
67 |
68 | self._compression_type = compression_type
69 | self._options = tf.python_io.TFRecordOptions(compression_type)
70 |
71 | def __del__(self):
72 | info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}}
73 | info_str = json.dumps(info, indent=4, separators=(',', ':'))
74 |
75 | with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f:
76 | info_f.write(info_str)
77 |
78 | if self._writer:
79 | self._writer.close()
80 |
81 | def add(self, feature_bytes_dict):
82 | """Add example.
83 |
84 | `feature_bytes_dict` example:
85 | feature_bytes_dict = {
86 | 'img' : img_bytes,
87 | 'id' : id_bytes,
88 | 'attr' : attr_bytes,
89 | 'point' : point_bytes
90 | }
91 | """
92 | assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \
93 | 'Feature names are inconsistent with the given infos!'
94 |
95 | self._new_tfrecord_check()
96 |
97 | self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString())
98 | self._data_num += 1
99 |
100 | def _new_tfrecord_check(self):
101 | if self._data_num // self._size_each == self._tfrecord_num:
102 | self._tfrecord_num += 1
103 |
104 | if self._writer:
105 | self._writer.close()
106 |
107 | tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1))
108 | self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options)
109 |
110 | def _add_info(self, name, dtype_or_format, shape):
111 | assert name not in self._info_names, 'Info name "%s" is duplicated!' % name
112 | assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES)
113 | self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape))
114 | self._info_names.append(name)
115 |
116 | @staticmethod
117 | def _bytes_feature(values):
118 | """Return a TF-Feature of bytes.
119 |
120 | Arguments:
121 | values : A byte string or list of byte strings.
122 |
123 | Returns:
124 | A TF-Feature.
125 | """
126 | if not isinstance(values, (tuple, list)):
127 | values = [values]
128 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
129 |
130 | @staticmethod
131 | def _bytes_tfexample(bytes_dict):
132 | """Convert bytes to tfexample.
133 |
134 | `bytes_dict` example:
135 | bytes_dict = {
136 | 'img' : img_bytes,
137 | 'id' : id_bytes,
138 | 'attr' : attr_bytes,
139 | 'point' : point_bytes
140 | }
141 | """
142 | feature_dict = {}
143 | for key, value in bytes_dict.items():
144 | feature_dict[key] = BytesTfrecordCreator._bytes_feature(value)
145 | return tf.train.Example(features=tf.train.Features(feature=feature_dict))
146 |
147 |
148 | class DataLablePairTfrecordCreator(BytesTfrecordCreator):
149 | """DataLablePairTfrecordCreator.
150 |
151 | If `data_shape` is None, then the `data` to be added should be a
152 | numpy array, and the shape and dtype of `data` will be inferred.
153 | If `data_shape` is not None, `data` should be given as byte string,
154 | and `data_dtype_or_format` should also be given.
155 |
156 | `compression_type`:
157 | 0 : NONE
158 | 1 : ZLIB
159 | 2 : GZIP
160 | """
161 |
162 | def __init__(self,
163 | save_path,
164 | data_shape=None,
165 | data_dtype_or_format=None,
166 | data_name='data',
167 | size_each=None,
168 | compression_type=0,
169 | overwrite_existence=False):
170 | super(DataLablePairTfrecordCreator, self).__init__(save_path,
171 | [],
172 | size_each,
173 | compression_type,
174 | overwrite_existence)
175 |
176 | if data_shape:
177 | assert data_dtype_or_format, '`data_dtype_or_format` should be given when `data_shape` is given!'
178 | self._is_data_bytes = True
179 | else:
180 | self._is_data_bytes = False
181 |
182 | self._data_shape = data_shape
183 | self._data_dtype_or_format = data_dtype_or_format
184 | self._data_name = data_name
185 | self._label_shape_dict = {}
186 | self._label_dtype_dict = {}
187 |
188 | self._info_built = False
189 |
190 | def add(self, data, label_dict):
191 | """Add example.
192 |
193 | `label_dict` example:
194 | label_dict = {
195 | 'id' : id_ndarray,
196 | 'attr' : attr_ndarray,
197 | 'point' : point_ndarray
198 | }
199 | """
200 | self._check_and_build(data, label_dict)
201 |
202 | if not self._is_data_bytes:
203 | data = data.tobytes()
204 |
205 | feature_dict = {self._data_name: data}
206 | for name, label in label_dict.items():
207 | feature_dict[name] = label.tobytes()
208 |
209 | super(DataLablePairTfrecordCreator, self).add(feature_dict)
210 |
211 | def _check_and_build(self, data, label_dict):
212 | # check type
213 | if self._is_data_bytes:
214 | assert isinstance(data, (str, bytes)), '`data` should be a byte string!'
215 | else:
216 | assert isinstance(data, np.ndarray), '`data` should be a numpy array!'
217 | for label in label_dict.values():
218 | assert isinstance(label, np.ndarray), 'labels should be numpy arrays!'
219 |
220 | # check shape and dtype, or build info on the first add
221 | if self._info_built:
222 | if not self._is_data_bytes:
223 | assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!'
224 | assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!'
225 | for name, label in label_dict.items():
226 | assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name
227 | assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' % name
228 | else:
229 | if not self._is_data_bytes:
230 | self._data_shape = data.shape
231 | self._data_dtype_or_format = data.dtype.name
232 | self._add_info(self._data_name, self._data_dtype_or_format, self._data_shape)
233 |
234 | for name, label in label_dict.items():
235 | self._label_shape_dict[name] = label.shape
236 | self._label_dtype_dict[name] = label.dtype.name
237 | self._add_info(name, label.dtype.name, label.shape)
238 |
239 | self._info_built = True
240 |
241 |
242 | class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator):
243 | """ImageLablePairTfrecordCreator.
244 |
245 | Arguments:
246 | encode_type : One of [None, 'png', 'jpg'].
247 | quality : JPEG quality, used only when encode_type is 'jpg'.
248 | compression_type :
249 | 0 : NONE
250 | 1 : ZLIB
251 | 2 : GZIP
252 | """
253 |
254 | def __init__(self,
255 | save_path,
256 | encode_type='png',
257 | quality=95,
258 | data_name='img',
259 | size_each=None,
260 | compression_type=0,
261 | overwrite_existence=False):
262 | super(ImageLablePairTfrecordCreator, self).__init__(save_path,
263 | None,
264 | None,
265 | data_name,
266 | size_each,
267 | compression_type,
268 | overwrite_existence)
269 |
270 | assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!"
271 |
272 | self._encode_type = encode_type
273 | self._quality = quality
274 |
275 | self._data_shape = None
276 | self._data_dtype_or_format = None
277 | self._is_data_bytes = True
278 |
279 | def add(self, image, label_dict):
280 | """Add example.
281 |
282 | `image`: An H * W (* C) uint8 numpy array.
283 |
284 | `label_dict` example:
285 | label_dict = {
286 | 'id' : id_ndarray,
287 | 'attr' : attr_ndarray,
288 | 'point' : point_ndarray
289 | }
290 | """
291 | self._check(image)
292 | image_bytes = self._encode(image)
293 | super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict)
294 |
295 | def _check(self, image):
296 | if not self._data_shape:
297 | assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \
298 | '`image` should be an H * W (* C) uint8 numpy array!'
299 | if self._encode_type and image.ndim == 3 and image.shape[-1] != 3:
300 | raise Exception('Only images with 1 or 3 channels are allowed to be encoded!')
301 |
302 | if image.ndim == 2:
303 | self._data_shape = image.shape + (1,)
304 | else:
305 | self._data_shape = image.shape
306 | self._data_dtype_or_format = self._encode_type if self._encode_type else 'uint8'
307 | else:
308 | sp = image.shape
309 | if image.ndim == 2:
310 | sp = sp + (1,)
311 | assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!'
312 | assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!'
313 |
314 | def _encode(self, image):
315 | if self._encode_type:
316 | if image.shape[-1] == 1:
317 | image.shape = image.shape[:2]
318 | byte = io.BytesIO()
319 | image = Image.fromarray(image)
320 | if self._encode_type == 'jpg':
321 | image.save(byte, 'JPEG', quality=self._quality)
322 | elif self._encode_type == 'png':
323 | image.save(byte, 'PNG')
324 | image_bytes = byte.getvalue()
325 | else:
326 | image_bytes = image.tobytes()
327 | return image_bytes
328 |
--------------------------------------------------------------------------------
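
A minimal usage sketch of `ImageLablePairTfrecordCreator` above (class name spelled as in this repo). The image shape, the single `attr` label, and `size_each` are illustrative assumptions, not values taken from the repo; deleting the creator triggers `__del__`, which closes the last shard and writes `info.json`:

```python
# Hedged usage sketch: shapes, the 'attr' label, and size_each are
# illustrative assumptions, not values from this repo.
import numpy as np
from tflib.data.tfrecord_creator import ImageLablePairTfrecordCreator

creator = ImageLablePairTfrecordCreator(save_path='./tfrecord_demo',
                                        encode_type='png',
                                        size_each=10000,  # examples per shard
                                        overwrite_existence=True)
for _ in range(100):
    img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
    attr = np.random.randint(0, 2, size=(40,), dtype=np.int64)
    creator.add(img, {'attr': attr})
del creator  # __del__ closes the writer and dumps info.json
```
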
/tflib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.layers.layers import *
6 |
--------------------------------------------------------------------------------
/tflib/layers/layers.py:
--------------------------------------------------------------------------------
1 | # functions compatible with tensorflow.contrib
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import six
8 |
9 | import tensorflow as tf
10 |
11 | from tensorflow.contrib.framework.python.ops import add_arg_scope
12 | from tensorflow.contrib.framework.python.ops import variables
13 | from tensorflow.contrib.layers.python import layers
14 | from tensorflow.contrib.layers.python.layers import initializers
15 | from tensorflow.contrib.layers.python.layers import utils
16 |
17 | from tensorflow.python.framework import ops
18 | from tensorflow.python.ops import array_ops
19 | from tensorflow.python.ops import init_ops
20 | from tensorflow.python.ops import nn
21 | from tensorflow.python.ops import standard_ops
22 | from tensorflow.python.ops import variable_scope
23 |
24 |
25 | @add_arg_scope
26 | def fully_connected(inputs,
27 | num_outputs,
28 | activation_fn=nn.relu,
29 | normalizer_fn=None,
30 | normalizer_params=None,
31 | weights_normalizer_fn=None,
32 | weights_normalizer_params=None,
33 | weights_initializer=initializers.xavier_initializer(),
34 | weights_regularizer=None,
35 | biases_initializer=init_ops.zeros_initializer(),
36 | biases_regularizer=None,
37 | reuse=None,
38 | variables_collections=None,
39 | outputs_collections=None,
40 | trainable=True,
41 | scope=None):
42 | # Copied and modified from tensorflow-0.12.0.contrib.layers.fully_connected;
43 | # adds the weights_normalizer_* options.
44 | """Adds a fully connected layer.
45 |
46 | `fully_connected` creates a variable called `weights`, representing a fully
47 | connected weight matrix, which is multiplied by the `inputs` to produce a
48 | `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
49 | `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
50 | None and a `biases_initializer` is provided then a `biases` variable would be
51 | created and added to the hidden units. Finally, if `activation_fn` is not `None`,
52 | it is applied to the hidden units as well.
53 |
54 | Note that if `inputs` has a rank greater than 2, then `inputs` is flattened
55 | prior to the initial matrix multiply by `weights`.
56 |
57 | Args:
58 | inputs: A tensor of at least rank 2 with a static value for the last dimension,
59 | i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
60 | num_outputs: Integer or long, the number of output units in the layer.
61 | activation_fn: activation function, set to None to skip it and maintain
62 | a linear activation.
63 | normalizer_fn: normalization function to use instead of `biases`. If
64 | `normalizer_fn` is provided then `biases_initializer` and
65 | `biases_regularizer` are ignored and `biases` are neither created nor added.
66 | Default is None, i.e. no normalizer function.
67 | normalizer_params: normalization function parameters.
68 | weights_normalizer_fn: weights normalization function.
69 | weights_normalizer_params: weights normalization function parameters.
70 | weights_initializer: An initializer for the weights.
71 | weights_regularizer: Optional regularizer for the weights.
72 | biases_initializer: An initializer for the biases. If None skip biases.
73 | biases_regularizer: Optional regularizer for the biases.
74 | reuse: whether or not the layer and its variables should be reused. To be
75 | able to reuse the layer, `scope` must be given.
76 | variables_collections: Optional list of collections for all the variables or
77 | a dictionary containing a different list of collections per variable.
78 | outputs_collections: collection to add the outputs.
79 | trainable: If `True` also add variables to the graph collection
80 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
81 | scope: Optional scope for variable_scope.
82 |
83 | Returns:
84 | the tensor variable representing the result of the series of operations.
85 |
86 | Raises:
87 | ValueError: if x has rank less than 2 or if its last dimension is not set.
88 | """
89 | if not isinstance(num_outputs, six.integer_types):
90 | raise ValueError('num_outputs should be int or long, got %s.' % (num_outputs,))
91 | with variable_scope.variable_scope(scope, 'fully_connected', [inputs],
92 | reuse=reuse) as sc:
93 | inputs = ops.convert_to_tensor(inputs)
94 | dtype = inputs.dtype.base_dtype
95 | inputs_shape = inputs.get_shape()
96 | num_input_units = utils.last_dimension(inputs_shape, min_rank=2)
97 |
98 | static_shape = inputs_shape.as_list()
99 | static_shape[-1] = num_outputs
100 |
101 | out_shape = array_ops.unstack(array_ops.shape(inputs), len(static_shape))
102 | out_shape[-1] = num_outputs
103 |
104 | weights_shape = [num_input_units, num_outputs]
105 | weights_collections = utils.get_variable_collections(
106 | variables_collections, 'weights')
107 | weights = variables.model_variable('weights',
108 | shape=weights_shape,
109 | dtype=dtype,
110 | initializer=weights_initializer,
111 | regularizer=weights_regularizer,
112 | collections=weights_collections,
113 | trainable=trainable)
114 | if weights_normalizer_fn is not None:
115 | weights_normalizer_params = weights_normalizer_params or {}
116 | weights = weights_normalizer_fn(weights, **weights_normalizer_params)
117 | if len(static_shape) > 2:
118 | # Reshape inputs
119 | inputs = array_ops.reshape(inputs, [-1, num_input_units])
120 | outputs = standard_ops.matmul(inputs, weights)
121 | if normalizer_fn is not None:
122 | normalizer_params = normalizer_params or {}
123 | outputs = normalizer_fn(outputs, **normalizer_params)
124 | else:
125 | if biases_initializer is not None:
126 | biases_collections = utils.get_variable_collections(
127 | variables_collections, 'biases')
128 | biases = variables.model_variable('biases',
129 | shape=[num_outputs, ],
130 | dtype=dtype,
131 | initializer=biases_initializer,
132 | regularizer=biases_regularizer,
133 | collections=biases_collections,
134 | trainable=trainable)
135 | outputs = nn.bias_add(outputs, biases)
136 | if activation_fn is not None:
137 | outputs = activation_fn(outputs)
138 | if len(static_shape) > 2:
139 | # Reshape back outputs
140 | outputs = array_ops.reshape(outputs, array_ops.stack(out_shape))
141 | outputs.set_shape(static_shape)
142 | return utils.collect_named_outputs(outputs_collections,
143 | sc.original_name_scope, outputs)
144 |
145 |
146 | @add_arg_scope
147 | def flatten_fully_connected(inputs,
148 | num_outputs,
149 | activation_fn=nn.relu,
150 | normalizer_fn=None,
151 | normalizer_params=None,
152 | weights_normalizer_fn=None,
153 | weights_normalizer_params=None,
154 | weights_initializer=initializers.xavier_initializer(),
155 | weights_regularizer=None,
156 | biases_initializer=init_ops.zeros_initializer(),
157 | biases_regularizer=None,
158 | reuse=None,
159 | variables_collections=None,
160 | outputs_collections=None,
161 | trainable=True,
162 | scope=None):
163 | with variable_scope.variable_scope(scope, 'flatten_fully_connected'):
164 | if inputs.shape.ndims > 2:
165 | inputs = layers.flatten(inputs)
166 | return fully_connected(inputs=inputs,
167 | num_outputs=num_outputs,
168 | activation_fn=activation_fn,
169 | normalizer_fn=normalizer_fn,
170 | normalizer_params=normalizer_params,
171 | weights_normalizer_fn=weights_normalizer_fn,
172 | weights_normalizer_params=weights_normalizer_params,
173 | weights_initializer=weights_initializer,
174 | weights_regularizer=weights_regularizer,
175 | biases_initializer=biases_initializer,
176 | biases_regularizer=biases_regularizer,
177 | reuse=reuse,
178 | variables_collections=variables_collections,
179 | outputs_collections=outputs_collections,
180 | trainable=trainable,
181 | scope=scope)
182 |
183 | flatten_dense = flatten_fully_connected
184 |
185 |
186 | @add_arg_scope
187 | def convolution(inputs,
188 | num_outputs,
189 | kernel_size,
190 | stride=1,
191 | padding='SAME',
192 | data_format=None,
193 | rate=1,
194 | activation_fn=nn.relu,
195 | normalizer_fn=None,
196 | normalizer_params=None,
197 | weights_normalizer_fn=None,
198 | weights_normalizer_params=None,
199 | weights_initializer=initializers.xavier_initializer(),
200 | weights_regularizer=None,
201 | biases_initializer=init_ops.zeros_initializer(),
202 | biases_regularizer=None,
203 | reuse=None,
204 | variables_collections=None,
205 | outputs_collections=None,
206 | trainable=True,
207 | scope=None):
208 | # Copied and modified from tensorflow-0.12.0.contrib.layers.convolution;
209 | # adds the weights_normalizer_* options.
210 | """Adds an N-D convolution followed by an optional batch_norm layer.
211 |
212 | It is required that 1 <= N <= 3.
213 |
214 | `convolution` creates a variable called `weights`, representing the
215 | convolutional kernel, that is convolved (actually cross-correlated) with the
216 | `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
217 | provided (such as `batch_norm`), it is then applied. Otherwise, if
218 | `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
219 | variable would be created and added to the activations. Finally, if
220 | `activation_fn` is not `None`, it is applied to the activations as well.
221 |
222 | Performs atrous convolution with input stride/dilation rate equal to `rate`
223 | if a value > 1 for any dimension of `rate` is specified. In this case
224 | `stride` values != 1 are not supported.
225 |
226 | Args:
227 | inputs: a Tensor of rank N+2 of shape
228 | `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
229 | not start with "NC" (default), or
230 | `[batch_size, in_channels] + input_spatial_shape` if data_format starts
231 | with "NC".
232 | num_outputs: integer, the number of output filters.
233 | kernel_size: a sequence of N positive integers specifying the spatial
234 | dimensions of of the filters. Can be a single integer to specify the same
235 | value for all spatial dimensions.
236 | stride: a sequence of N positive integers specifying the stride at which to
237 | compute output. Can be a single integer to specify the same value for all
238 | spatial dimensions. Specifying any `stride` value != 1 is incompatible
239 | with specifying any `rate` value != 1.
240 | padding: one of `"VALID"` or `"SAME"`.
241 | data_format: A string or None. Specifies whether the channel dimension of
242 | the `input` and output is the last dimension (default, or if `data_format`
243 | does not start with "NC"), or the second dimension (if `data_format`
244 | starts with "NC"). For N=1, the valid values are "NWC" (default) and
245 | "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
246 | N=3, currently the only valid value is "NDHWC".
247 | rate: a sequence of N positive integers specifying the dilation rate to use
248 | for atrous convolution. Can be a single integer to specify the same
249 | value for all spatial dimensions. Specifying any `rate` value != 1 is
250 | incompatible with specifying any `stride` value != 1.
251 | activation_fn: activation function, set to None to skip it and maintain
252 | a linear activation.
253 | normalizer_fn: normalization function to use instead of `biases`. If
254 | `normalizer_fn` is provided then `biases_initializer` and
255 | `biases_regularizer` are ignored and `biases` are neither created nor added.
256 | Default is None, i.e. no normalizer function.
257 | normalizer_params: normalization function parameters.
258 | weights_normalizer_fn: weights normalization function.
259 | weights_normalizer_params: weights normalization function parameters.
260 | weights_initializer: An initializer for the weights.
261 | weights_regularizer: Optional regularizer for the weights.
262 | biases_initializer: An initializer for the biases. If None skip biases.
263 | biases_regularizer: Optional regularizer for the biases.
264 | reuse: whether or not the layer and its variables should be reused. To be
265 | able to reuse the layer, `scope` must be given.
266 | variables_collections: optional list of collections for all the variables or
267 | a dictionary containing a different list of collections per variable.
268 | outputs_collections: collection to add the outputs.
269 | trainable: If `True` also add variables to the graph collection
270 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
271 | scope: Optional scope for `variable_scope`.
272 |
273 | Returns:
274 | a tensor representing the output of the operation.
275 |
276 | Raises:
277 | ValueError: if `data_format` is invalid.
278 | ValueError: if both `rate` and `stride` are not uniformly 1.
279 | """
280 | if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC']:
281 | raise ValueError('Invalid data_format: %r' % (data_format,))
282 | with variable_scope.variable_scope(scope, 'Conv', [inputs],
283 | reuse=reuse) as sc:
284 | inputs = ops.convert_to_tensor(inputs)
285 | dtype = inputs.dtype.base_dtype
286 | input_rank = inputs.get_shape().ndims
287 | if input_rank is None:
288 | raise ValueError('Rank of inputs must be known')
289 | if input_rank < 3 or input_rank > 5:
290 | raise ValueError('Rank of inputs is %d, which is not >= 3 and <= 5' %
291 | input_rank)
292 | conv_dims = input_rank - 2
293 | kernel_size = utils.n_positive_integers(conv_dims, kernel_size)
294 | stride = utils.n_positive_integers(conv_dims, stride)
295 | rate = utils.n_positive_integers(conv_dims, rate)
296 |
297 | if data_format is None or data_format.endswith('C'):
298 | num_input_channels = inputs.get_shape()[input_rank - 1].value
299 | elif data_format.startswith('NC'):
300 | num_input_channels = inputs.get_shape()[1].value
301 | else:
302 | raise ValueError('Invalid data_format')
303 |
304 | if num_input_channels is None:
305 | raise ValueError('Number of in_channels must be known.')
306 |
307 | weights_shape = (
308 | list(kernel_size) + [num_input_channels, num_outputs])
309 | weights_collections = utils.get_variable_collections(variables_collections,
310 | 'weights')
311 | weights = variables.model_variable('weights',
312 | shape=weights_shape,
313 | dtype=dtype,
314 | initializer=weights_initializer,
315 | regularizer=weights_regularizer,
316 | collections=weights_collections,
317 | trainable=trainable)
318 | if weights_normalizer_fn is not None:
319 | weights_normalizer_params = weights_normalizer_params or {}
320 | weights = weights_normalizer_fn(weights, **weights_normalizer_params)
321 | outputs = nn.convolution(input=inputs,
322 | filter=weights,
323 | dilation_rate=rate,
324 | strides=stride,
325 | padding=padding,
326 | data_format=data_format)
327 | if normalizer_fn is not None:
328 | normalizer_params = normalizer_params or {}
329 | outputs = normalizer_fn(outputs, **normalizer_params)
330 | else:
331 | if biases_initializer is not None:
332 | biases_collections = utils.get_variable_collections(
333 | variables_collections, 'biases')
334 | biases = variables.model_variable('biases',
335 | shape=[num_outputs],
336 | dtype=dtype,
337 | initializer=biases_initializer,
338 | regularizer=biases_regularizer,
339 | collections=biases_collections,
340 | trainable=trainable)
341 | outputs = nn.bias_add(outputs, biases, data_format=data_format)
342 | if activation_fn is not None:
343 | outputs = activation_fn(outputs)
344 | return utils.collect_named_outputs(outputs_collections,
345 | sc.original_name_scope, outputs)
346 |
347 |
348 | convolution2d = convolution
349 | convolution3d = convolution
350 |
351 |
352 | @add_arg_scope
353 | def spectral_normalization(weights,
354 | num_iterations=1,
355 | epsilon=1e-12,
356 | u_initializer=tf.random_normal_initializer(),
357 | updates_collections=tf.GraphKeys.UPDATE_OPS,
358 | is_training=True,
359 | reuse=None,
360 | variables_collections=None,
361 | outputs_collections=None,
362 | scope=None):
363 | with tf.variable_scope(scope, 'SpectralNorm', [weights], reuse=reuse) as sc:
364 | weights = tf.convert_to_tensor(weights)
365 |
366 | dtype = weights.dtype.base_dtype
367 |
368 | w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]])
369 | w = tf.transpose(w_t)
370 | m, n = w.shape.as_list()
371 |
372 | u_collections = utils.get_variable_collections(variables_collections, 'u')
373 | u = tf.get_variable("u",
374 | shape=[m, 1],
375 | dtype=dtype,
376 | initializer=u_initializer,
377 | trainable=False,
378 | collections=u_collections,)
379 | sigma_collections = utils.get_variable_collections(variables_collections, 'sigma')
380 | sigma = tf.get_variable('sigma',
381 | shape=[],
382 | dtype=dtype,
383 | initializer=tf.zeros_initializer(),
384 | trainable=False,
385 | collections=sigma_collections)
386 |
387 | def _power_iteration(i, u, v):
388 | v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon)
389 | u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon)
390 | return i + 1, u_, v_
391 |
392 | _, u_, v_ = tf.while_loop(
393 | cond=lambda i, _1, _2: i < num_iterations,
394 | body=_power_iteration,
395 | loop_vars=[tf.constant(0), u, tf.zeros(shape=[n, 1], dtype=tf.float32)]
396 | )
397 | u_ = tf.stop_gradient(u_)
398 | v_ = tf.stop_gradient(v_)
399 | sigma_ = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0]
400 |
401 | update_u = u.assign(u_)
402 | update_sigma = sigma.assign(sigma_)
403 | if updates_collections is None:
404 | def _force_update():
405 | with tf.control_dependencies([update_u, update_sigma]):
406 | return tf.identity(sigma_)
407 |
408 | sigma_ = utils.smart_cond(is_training, _force_update, lambda: sigma)
409 | weights_sn = weights / sigma_
410 | else:
411 | sigma_ = utils.smart_cond(is_training, lambda: sigma_, lambda: sigma)
412 | weights_sn = weights / sigma_
413 | tf.add_to_collections(updates_collections, update_u)
414 | tf.add_to_collections(updates_collections, update_sigma)
415 |
416 | return utils.collect_named_outputs(outputs_collections, sc.name, weights_sn)
417 |
418 |
419 | # Simple alias.
420 | conv2d = convolution2d
421 | conv3d = convolution3d
422 |
--------------------------------------------------------------------------------
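
The `spectral_normalization` above estimates the top singular value sigma of the reshaped kernel by power iteration (v <- W^T u / ||W^T u||, u <- W v / ||W v||, sigma ~ u^T W v, with W the kernel flattened to [num_outputs, -1]) and divides the weights by it. It plugs into the layers above through the `weights_normalizer_*` arguments; a minimal sketch, where the input shape and hyper-parameters are illustrative:

```python
# Hedged sketch: applies spectral_normalization as the weights normalizer
# of the conv2d defined above; shapes and hyper-parameters are illustrative.
import tensorflow as tf
from tflib.layers import conv2d, spectral_normalization

x = tf.placeholder(tf.float32, [None, 64, 64, 3])
y = conv2d(x, num_outputs=64, kernel_size=4, stride=2,
           weights_normalizer_fn=spectral_normalization,
           weights_normalizer_params={'num_iterations': 1})
```
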
/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.ops.ops import *
6 |
--------------------------------------------------------------------------------
/tflib/ops/ops.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def minmax_norm(x, epsilon=1e-12):
9 | x = tf.to_float(x)
10 | min_val = tf.reduce_min(x)
11 | max_val = tf.reduce_max(x)
12 | x_norm = (x - min_val) / tf.maximum((max_val - min_val), epsilon)
13 | return x_norm
14 |
--------------------------------------------------------------------------------
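
`minmax_norm` rescales a tensor linearly into [0, 1]: x_norm = (x - min x) / max(max x - min x, epsilon). A quick check with illustrative values:

```python
import tensorflow as tf
from tflib.ops import minmax_norm

with tf.Session() as sess:
    print(sess.run(minmax_norm([2.0, 4.0, 6.0])))  # -> [0.  0.5 1. ]
```
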
/tflib/parallel.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def average_gradients(tower_grads):
9 | """Calculate the average gradient for each shared variable across all towers.
10 | Copied from https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py.
11 | Note that this function provides a synchronization point across all towers.
12 | Args:
13 | tower_grads: List of lists of (gradient, variable) tuples. The outer list
14 | is over individual gradients. The inner list is over the gradient
15 | calculation for each tower.
16 | Returns:
17 | List of pairs of (gradient, variable) where the gradient has been averaged
18 | across all towers.
19 | """
20 | average_grads = []
21 | for grad_and_vars in zip(*tower_grads):
22 | # Note that each grad_and_vars looks like the following:
23 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
24 | grads = []
25 | for g, _ in grad_and_vars:
26 | # Add 0 dimension to the gradients to represent the tower.
27 | expanded_g = tf.expand_dims(g, 0)
28 |
29 | # Append on a 'tower' dimension which we will average over below.
30 | grads.append(expanded_g)
31 |
32 | # Average over the 'tower' dimension.
33 | grad = tf.concat(axis=0, values=grads)
34 | grad = tf.reduce_mean(grad, 0)
35 |
36 | # Keep in mind that the Variables are redundant because they are shared
37 | # across towers. So .. we will just return the first tower's pointer to
38 | # the Variable.
39 | v = grad_and_vars[0][1]
40 | grad_and_var = (grad, v)
41 | average_grads.append(grad_and_var)
42 | return average_grads
43 |
--------------------------------------------------------------------------------
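
`average_gradients` expects one `compute_gradients` result per tower and averages the gradient tensors position-wise, returning pairs bound to the first tower's (shared) variables. A two-tower sketch; the toy loss and device strings are illustrative stand-ins, not this repo's model:

```python
import tensorflow as tf
from tflib.parallel import average_gradients

opt = tf.train.GradientDescentOptimizer(0.1)
w = tf.get_variable('w', initializer=1.0)

tower_grads = []
for i in range(2):
    with tf.device('/gpu:%d' % i):
        loss = (w - float(i)) ** 2  # per-tower loss on the shared variable
        tower_grads.append(opt.compute_gradients(loss, var_list=[w]))

train_op = opt.apply_gradients(average_gradients(tower_grads))
```
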
/tflib/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def session(graph=None,
9 | allow_soft_placement=True,
10 | log_device_placement=False,
11 | allow_growth=True):
12 | """Return a Session with simple config."""
13 | config = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
14 | log_device_placement=log_device_placement)
15 | config.gpu_options.allow_growth = allow_growth
16 | return tf.Session(graph=graph, config=config)
17 |
18 |
19 | def print_tensor(tensors):
20 | if not isinstance(tensors, (list, tuple)):
21 | tensors = [tensors]
22 |
23 | for i, tensor in enumerate(tensors):
24 | ctype = str(type(tensor))
25 | if 'Tensor' in ctype:
26 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' %
27 | (i, 'Tensor', tensor.name, tensor.shape, tensor.dtype.name, tensor.device))
28 | elif 'Variable' in ctype:
29 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' %
30 | (i, 'Variable', tensor.name, tensor.shape, tensor.dtype.name, tensor.device))
31 | elif 'Operation' in ctype:
32 | print('%d: %s("%s", device=%s)' %
33 | (i, 'Operation', tensor.name, tensor.device))
34 | else:
35 | raise Exception('Not a Tensor, Variable or Operation!')
36 |
37 |
38 | prt = print_tensor
39 |
40 |
41 | def shape(tensor):
42 | sp = tensor.get_shape().as_list()
43 | return [num if num is not None else -1 for num in sp]
44 |
45 |
46 | def summary(tensor_collection,
47 | summary_type=['mean', 'stddev', 'max', 'min', 'sparsity', 'histogram'],
48 | scope=None):
49 | """Summary.
50 |
51 | Usage:
52 | 1. summary(tensor)
53 | 2. summary([tensor_a, tensor_b])
54 | 3. summary({tensor_a: 'a', tensor_b: 'b'})
55 | """
56 | def _summary(tensor, name, summary_type):
57 | """Attach a lot of summaries to a Tensor."""
58 | if name is None:
59 | name = tensor.name
60 |
61 | summaries = []
62 | if len(tensor.shape) == 0:
63 | summaries.append(tf.summary.scalar(name, tensor))
64 | else:
65 | if 'mean' in summary_type:
66 | mean = tf.reduce_mean(tensor)
67 | summaries.append(tf.summary.scalar(name + '/mean', mean))
68 | if 'stddev' in summary_type:
69 | mean = tf.reduce_mean(tensor)
70 | stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean)))
71 | summaries.append(tf.summary.scalar(name + '/stddev', stddev))
72 | if 'max' in summary_type:
73 | summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor)))
74 | if 'min' in summary_type:
75 | summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor)))
76 | if 'sparsity' in summary_type:
77 | summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor)))
78 | if 'histogram' in summary_type:
79 | summaries.append(tf.summary.histogram(name, tensor))
80 | return tf.summary.merge(summaries)
81 |
82 | if not isinstance(tensor_collection, (list, tuple, dict)):
83 | tensor_collection = [tensor_collection]
84 |
85 | with tf.name_scope(scope, 'summary'):
86 | summaries = []
87 | if isinstance(tensor_collection, (list, tuple)):
88 | for tensor in tensor_collection:
89 | summaries.append(_summary(tensor, None, summary_type))
90 | else:
91 | for tensor, name in tensor_collection.items():
92 | summaries.append(_summary(tensor, name, summary_type))
93 | return tf.summary.merge(summaries)
94 |
95 |
96 | def counter(start=0, scope=None):
97 | with tf.variable_scope(scope, 'counter'):
98 | counter = tf.get_variable(name='counter',
99 | initializer=tf.constant_initializer(start),
100 | shape=(),
101 | dtype=tf.int64)
102 | update_cnt = tf.assign(counter, tf.add(counter, 1))
103 | return counter, update_cnt
104 |
--------------------------------------------------------------------------------
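
The helpers above compose as follows; a short sketch where the dummy tensor and log directory are illustrative:

```python
import tensorflow as tf
from tflib.utils import counter, session, summary

acts = tf.random_normal([8])  # stand-in activations
merged = summary({acts: 'acts'}, summary_type=['mean', 'histogram'])
cnt, update_cnt = counter(start=1)

with session() as sess:
    sess.run(tf.global_variables_initializer())
    writer = tf.summary.FileWriter('/tmp/log', sess.graph)
    s, step = sess.run([merged, update_cnt])  # update_cnt returns the new count
    writer.add_summary(s, global_step=step)
```
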
/tflib/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from tflib.vision.dataset.mnist import *
6 |
--------------------------------------------------------------------------------
/tflib/vision/dataset/mnist.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import gzip
6 | import multiprocessing
7 | import os
8 | import struct
9 | import subprocess
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from tflib.data.memory_data import MemoryData
15 |
16 |
17 | _N_CPU = multiprocessing.cpu_count()
18 |
19 |
20 | def mnist_download(download_dir):
21 | url_base = 'http://yann.lecun.com/exdb/mnist/'
22 | file_names = ['train-images-idx3-ubyte.gz',
23 | 'train-labels-idx1-ubyte.gz',
24 | 't10k-images-idx3-ubyte.gz',
25 | 't10k-labels-idx1-ubyte.gz']
26 | for file_name in file_names:
27 | url = url_base + file_name
28 | save_path = os.path.join(download_dir, file_name)
29 | cmd = ['curl', url, '-o', save_path]
30 | if not os.path.exists(save_path):
31 | print('Downloading %s' % file_name)
32 | subprocess.call(cmd)
33 | else:
34 | print('%s exists, skip!' % file_name)
35 |
36 |
37 | def mnist_load(data_dir, split='train'):
38 | """Load MNIST dataset, modified from https://gist.github.com/akesling/5358964.
39 |
40 | Returns:
41 | `imgs`, `lbls`, `num`.
42 |
43 | `imgs` : [-1.0, 1.0] float64 images of shape (N * H * W).
44 | `lbls` : Int labels of shape (N,).
45 | `num` : # of examples.
46 | """
47 | mnist_download(data_dir)
48 |
49 | if split == 'train':
50 | fname_img = os.path.join(data_dir, 'train-images-idx3-ubyte')
51 | fname_lbl = os.path.join(data_dir, 'train-labels-idx1-ubyte')
52 | elif split == 'test':
53 | fname_img = os.path.join(data_dir, 't10k-images-idx3-ubyte')
54 | fname_lbl = os.path.join(data_dir, 't10k-labels-idx1-ubyte')
55 | else:
56 | raise ValueError("`split` must be 'test' or 'train'")
57 |
58 | def _unzip_gz(file_name):
59 | unzip_name = file_name.replace('.gz', '')
60 | gz_file = gzip.GzipFile(file_name)
61 | open(unzip_name, 'wb').write(gz_file.read())
62 | gz_file.close()
63 |
64 | if not os.path.exists(fname_img):
65 | _unzip_gz(fname_img + '.gz')
66 | if not os.path.exists(fname_lbl):
67 | _unzip_gz(fname_lbl + '.gz')
68 |
69 | with open(fname_lbl, 'rb') as flbl:
70 | struct.unpack('>II', flbl.read(8))
71 | lbls = np.fromfile(flbl, dtype=np.int8)
72 |
73 | with open(fname_img, 'rb') as fimg:
74 | _, _, rows, cols = struct.unpack('>IIII', fimg.read(16))
75 | imgs = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbls), rows, cols)
76 | imgs = imgs / 127.5 - 1
77 |
78 | return imgs, lbls, len(lbls)
79 |
80 |
81 | class Mnist(MemoryData):
82 |
83 | def __init__(self,
84 | data_dir,
85 | batch_size,
86 | split='train',
87 | prefetch_batch=_N_CPU + 1,
88 | drop_remainder=True,
89 | filter=None,
90 | map_func=None,
91 | num_threads=_N_CPU,
92 | shuffle=True,
93 | buffer_size=None,
94 | repeat=-1,
95 | sess=None):
96 | imgs, lbls, _ = mnist_load(data_dir, split)
97 | imgs.shape = imgs.shape + (1,)
98 |
99 | imgs_pl = tf.placeholder(tf.float32, imgs.shape)
100 | lbls_pl = tf.placeholder(tf.int64, lbls.shape)
101 |
102 | memory_data_dict = {'img': imgs_pl, 'lbl': lbls_pl}
103 |
104 | self.feed_dict = {imgs_pl: imgs, lbls_pl: lbls}
105 | super(Mnist, self).__init__(memory_data_dict,
106 | batch_size,
107 | prefetch_batch,
108 | drop_remainder,
109 | filter,
110 | map_func,
111 | num_threads,
112 | shuffle,
113 | buffer_size,
114 | repeat,
115 | sess)
116 |
117 | def reset(self):
118 | super(Mnist, self).reset(self.feed_dict)
119 |
120 | if __name__ == '__main__':
121 | import imlib as im
122 | from tflib import session
123 | sess = session()
124 | mnist = Mnist('/tmp', 5000, repeat=1, sess=sess)
125 | print(len(mnist))
126 | for batch in mnist:
127 | print(batch['lbl'][-1])
128 | im.imshow(batch['img'][-1].squeeze())
129 | im.show()
130 | sess.close()
131 |
--------------------------------------------------------------------------------
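
`mnist_load` parses the raw IDX files directly: each header is big-endian (`>`), a magic number plus dimension sizes, followed by the raw pixel or label bytes. Standalone, the image-file header reads as below (the path is illustrative):

```python
import struct

with open('./data/train-images-idx3-ubyte', 'rb') as f:
    magic, n, rows, cols = struct.unpack('>IIII', f.read(16))
# magic is 2051 for image files (2049 for label files); n * rows * cols
# uint8 pixels follow in row-major order.
```
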
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from functools import partial
7 | import json
8 | import traceback
9 |
10 | import data
11 | import imlib as im
12 | import model
13 | import numpy as np
14 | import pylib
15 | import tensorflow as tf
16 | import tflib as tl
17 |
18 |
19 | # ==============================================================================
20 | # = param =
21 | # ==============================================================================
22 |
23 | # argument
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--att', dest='att', default='', choices=list(data.Celeba.att_dict.keys()) + [''])
26 | parser.add_argument('--ks', dest='ks', type=int, default=[2, 3, 3], nargs='+', help='k each layer')
27 | parser.add_argument('--lambdas', dest='lambdas', type=float, default=[1., 1., 1.], nargs='+', help='loss weight of each layer')
28 | parser.add_argument('--continuous_last', dest='continuous_last', action='store_true')
29 | parser.add_argument('--half_acgan', dest='half_acgan', action='store_true')
30 |
31 | parser.add_argument('--epoch', dest='epoch', type=int, default=100)
32 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64)
33 | parser.add_argument('--lr_d', dest='lr_d', type=float, default=0.0002, help='learning rate of d')
34 | parser.add_argument('--lr_g', dest='lr_g', type=float, default=0.0002, help='learning rate of g')
35 | parser.add_argument('--n_d', dest='n_d', type=int, default=1)
36 | parser.add_argument('--n_g', dest='n_g', type=int, default=1)
37 | parser.add_argument('--n_d_pre', dest='n_d_pre', type=int, default=0)
38 | parser.add_argument('--optimizer', dest='optimizer', default='adam', choices=['adam', 'rmsprop'])
39 |
40 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=100, help='dimension of latent')
41 | parser.add_argument('--loss_mode', dest='loss_mode', default='gan', choices=['gan', 'lsgan', 'wgan', 'hinge'])
42 | parser.add_argument('--gp_mode', dest='gp_mode', default='none', choices=['none', 'dragan', 'wgan-gp'], help='type of gradient penalty')
43 | parser.add_argument('--norm', dest='norm', default='batch_norm', choices=['batch_norm', 'instance_norm', 'layer_norm', 'none'])
44 |
45 | parser.add_argument('--experiment_name', dest='experiment_name', default='default')
46 |
47 | args = parser.parse_args()
48 |
49 | att = args.att
50 | ks = args.ks
51 | if att != '':
52 | ks[0] = 2
53 | lambdas = args.lambdas
54 | assert len(ks) == len(lambdas), 'The lengths of `ks` and `lambdas` should be the same!'
55 | continuous_last = args.continuous_last
56 | if len(ks) == 1 and att != '':
57 | continuous_last = False
58 | half_acgan = args.half_acgan
59 |
60 | epoch = args.epoch
61 | batch_size = args.batch_size
62 | lr_d = args.lr_d
63 | lr_g = args.lr_g
64 | n_d = args.n_d
65 | n_g = args.n_g
66 | n_d_pre = args.n_d_pre
67 | optimizer = args.optimizer
68 |
69 | z_dim = args.z_dim
70 | loss_mode = args.loss_mode
71 | gp_mode = args.gp_mode
72 | norm = args.norm
73 |
74 | experiment_name = args.experiment_name
75 |
76 | pylib.mkdir('./output/%s' % experiment_name)
77 | with open('./output/%s/setting.txt' % experiment_name, 'w') as f:
78 | f.write(json.dumps(vars(args), indent=4, separators=(',', ':')))
79 |
80 | img_size = 64
81 |
82 | # dataset
83 | dataset = data.Celeba('./data', ['Bangs' if att == '' else att], img_size, batch_size)
84 |
85 |
86 | # ==============================================================================
87 | # = graph =
88 | # ==============================================================================
89 |
90 | # models
91 | c_dim = len(model.sample_c(ks)[0])
92 | D = partial(model.D, c_dim=c_dim, norm_name=norm)
93 | G = model.G
94 |
95 | # optimizer
96 | if optimizer == 'adam':
97 | optim = partial(tf.train.AdamOptimizer, beta1=0.5)
98 | elif optimizer == 'rmsprop':
99 | optim = tf.train.RMSPropOptimizer
100 |
101 | # loss func
102 | d_loss_fn, g_loss_fn = model.get_loss_fn(loss_mode)
103 | tree_loss_fn = partial(model.tree_loss, ks=ks, continuous_last=continuous_last)
104 |
105 | # inputs
106 | real = tf.placeholder(tf.float32, [None, img_size, img_size, 3])
107 | z = tf.placeholder(tf.float32, [None, z_dim])
108 | c = tf.placeholder(tf.float32, [None, c_dim])
109 | mask = tf.placeholder(tf.float32, [None, c_dim])
110 |
111 | counter = tf.placeholder(tf.int64, [])
112 | layer_mask = tf.constant(np.tril(np.ones(len(ks))), dtype=tf.float32)[counter // (epoch // len(ks))]  # curriculum: one more tree layer every epoch // len(ks) epochs
113 |
114 | # generate
115 | fake = G(z, c)
116 |
117 | # discriminate
118 | r_logit, r_c_logit = D(real)
119 | f_logit, f_c_logit = D(fake)
120 |
121 | # d loss
122 | d_r_loss, d_f_loss = d_loss_fn(r_logit, f_logit)
123 | d_f_tree_losses = tree_loss_fn(f_c_logit, c, mask)
124 | if att != '':
125 | d_r_tree_losses = tree_loss_fn(r_c_logit, c, mask)
126 | start = 1 if half_acgan else 0
127 | d_tree_loss = sum([d_f_tree_losses[i] * lambdas[i] * layer_mask[i] for i in range(start, len(lambdas))])
128 | d_tree_loss += d_r_tree_losses[0] * lambdas[0] * layer_mask[0]
129 | else:
130 | d_tree_loss = sum([d_f_tree_losses[i] * lambdas[i] for i in range(len(lambdas))])
131 | gp = model.gradient_penalty(D, real, fake, gp_mode)
132 | d_loss = d_r_loss + d_f_loss + d_tree_loss + gp * 10.0
133 |
134 | # g loss
135 | g_f_loss = g_loss_fn(f_logit)
136 | g_f_tree_losses = tree_loss_fn(f_c_logit, c, mask)
137 | g_tree_loss = sum([g_f_tree_losses[i] * lambdas[i] * layer_mask[i] for i in range(len(lambdas))])
138 | g_loss = g_f_loss + g_tree_loss
139 |
140 | # optims
141 | d_step = optim(learning_rate=lr_d).minimize(d_loss, var_list=tl.trainable_variables(includes='D'))
142 | g_step = optim(learning_rate=lr_g).minimize(g_loss, var_list=tl.trainable_variables(includes='G'))
143 |
144 | # summaries
145 | d_summary = tl.summary({d_r_loss: 'd_r_loss',
146 | d_f_loss: 'd_f_loss',
147 | d_r_loss + d_f_loss: 'd_loss',
148 | gp: 'gp'}, scope='D')
149 | tmp = {l: 'd_f_tree_loss_%d' % i for i, l in enumerate(d_f_tree_losses)}
150 | if att != '':
151 | tmp.update({d_r_tree_losses[0]: 'd_r_tree_loss_0'})
152 | d_tree_summary = tl.summary(tmp, scope='D_Tree')
153 | d_summary = tf.summary.merge([d_summary, d_tree_summary])
154 |
155 | g_summary = tl.summary({g_f_loss: 'g_f_loss'}, scope='G')
156 | g_tree_summary = tl.summary({l: 'g_f_tree_loss_%d' % i for i, l in enumerate(g_f_tree_losses)}, scope='G_Tree')
157 | g_summary = tf.summary.merge([g_summary, g_tree_summary])
158 |
159 | # sample
160 | z_sample = tf.placeholder(tf.float32, [None, z_dim])
161 | c_sample = tf.placeholder(tf.float32, [None, c_dim])
162 | f_sample = G(z_sample, c_sample, is_training=False)
163 |
164 |
165 | # ==============================================================================
166 | # = train =
167 | # ==============================================================================
168 |
169 | # epoch counter
170 | ep_cnt, update_cnt = tl.counter(start=1)
171 |
172 | # session
173 | sess = tl.session()
174 |
175 | # saver
176 | saver = tf.train.Saver(max_to_keep=1)
177 |
178 | # summary writer
179 | summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph)
180 |
181 | # initialization
182 | ckpt_dir = './output/%s/checkpoints' % experiment_name
183 | pylib.mkdir(ckpt_dir)
184 | try:
185 | tl.load_checkpoint(ckpt_dir, sess)
186 | except:
187 | sess.run(tf.global_variables_initializer())
188 |
189 | # train
190 | try:
191 | c_ipt_sample = np.stack(model.traversal_trees(ks, continuous_last=continuous_last)[1])
192 | z_ipt_samples = [np.stack([np.random.normal(size=[z_dim])] * len(c_ipt_sample)) for i in range(15)]
193 |
194 | it = 0
195 | it_per_epoch = len(dataset) // (batch_size * n_d)
196 | for ep in range(sess.run(ep_cnt), epoch + 1):
197 | sess.run(update_cnt)
198 |
199 | dataset.reset()
200 | for i in range(it_per_epoch):
201 | it += 1
202 |
203 | # train D
204 | if n_d_pre > 0 and it <= 25:
205 | n_d_ = n_d_pre
206 | else:
207 | n_d_ = n_d
208 | for _ in range(n_d_):
209 | # batch data
210 | real_ipt, att_ipt = dataset.get_next()
211 | c_ipt = []
212 | mask_ipt = []
213 | for idx in range(batch_size):
214 | if att == '':
215 | c_1 = None
216 | else:
217 | if att_ipt[idx] == 1:
218 | c_1 = np.array([1.0, 0])
219 | else:
220 | c_1 = np.array([0, 1.0])
221 | c_tmp, mask_tmp, _, _ = model.sample_c(ks, c_1, continuous_last)
222 | c_ipt.append(c_tmp)
223 | mask_ipt.append(mask_tmp)
224 | c_ipt = np.stack(c_ipt)
225 | mask_ipt = np.stack(mask_ipt)
226 | z_ipt = np.random.normal(size=[batch_size, z_dim])
227 |
228 | d_summary_opt, _ = sess.run([d_summary, d_step], feed_dict={real: real_ipt, z: z_ipt, c: c_ipt, mask: mask_ipt, counter: ep})
229 | summary_writer.add_summary(d_summary_opt, it)
230 |
231 | # train G
232 | for _ in range(n_g):
233 | # batch data
234 | z_ipt = np.random.normal(size=[batch_size, z_dim])
235 |
236 | g_summary_opt, _ = sess.run([g_summary, g_step], feed_dict={z: z_ipt, c: c_ipt, mask: mask_ipt, counter: ep})
237 | summary_writer.add_summary(g_summary_opt, it)
238 |
239 | # display
240 | if it % 1 == 0:  # print every iteration; raise the modulus to log less often
241 | print("Epoch: (%3d) (%5d/%5d)" % (ep, i + 1, it_per_epoch))
242 |
243 | # sample
244 | if it % 100 == 0:
245 | merge = []
246 | for z_ipt_sample in z_ipt_samples:
247 | f_sample_opt = sess.run(f_sample, feed_dict={z_sample: z_ipt_sample, c_sample: c_ipt_sample}).squeeze()
248 |
249 | k_prod = 1
250 | for k in ks:
251 | k_prod *= k
252 | f_sample_opts_k = list(f_sample_opt)
253 | for idx in range(len(f_sample_opts_k)):
254 | if idx % (len(f_sample_opts_k) / k_prod) != 0:
255 | f_sample_opts_k[idx] = np.zeros_like(f_sample_opts_k[idx])
256 | merge.append(np.concatenate(f_sample_opts_k, axis=1))
257 | merge = np.concatenate(merge, axis=0)
258 |
259 | save_dir = './output/%s/sample_training' % experiment_name
260 | pylib.mkdir(save_dir)
261 | im.imwrite(merge, '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, ep, i + 1, it_per_epoch))
262 |
263 | save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep))
264 | print('Model is saved in file: %s' % save_path)
265 | except:
266 | traceback.print_exc()
267 | finally:
268 | sess.close()
269 |
--------------------------------------------------------------------------------
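
One line in train.py deserves unpacking: `layer_mask` indexes a lower-triangular matrix by training phase, so the per-layer tree losses are switched on progressively, one extra layer every `epoch // len(ks)` epochs. In plain numpy, using the parser's default `ks` and `epoch` (values below are just that illustration):

```python
import numpy as np

ks, epoch = [2, 3, 3], 100            # defaults from the argument parser
schedule = np.tril(np.ones(len(ks)))  # row = training phase, column = tree layer
for ep in [1, 40, 80]:
    print(ep, schedule[ep // (epoch // len(ks))])
# 1  [1. 0. 0.]   -> only layer 0's loss is active
# 40 [1. 1. 0.]
# 80 [1. 1. 1.]
```

Note that when `epoch` is not a multiple of `len(ks)` (e.g. the default 100 with 3 layers), the last few epochs appear to index past the final row of the schedule, so choosing `epoch` divisible by `len(ks)` seems the safe configuration.
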