├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── imlib ├── __init__.py ├── basic.py ├── dtype.py ├── encode.py └── transform.py ├── model.py ├── pics ├── bangs.png └── eyeglasses.png ├── pylib ├── __init__.py ├── path.py └── timer.py ├── tflib ├── __init__.py ├── checkpoint.py ├── collection.py ├── data │ ├── __init__.py │ ├── dataset.py │ ├── disk_image.py │ ├── memory_data.py │ ├── tfrecord.py │ └── tfrecord_creator.py ├── layers │ ├── __init__.py │ └── layers.py ├── ops │ ├── __init__.py │ └── ops.py ├── parallel.py ├── utils.py └── vision │ ├── __init__.py │ └── dataset │ ├── __init__.py │ └── mnist.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | /data/ 4 | /output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018 hezhenliang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a 4 | copy of this software and associated documentation files (the "Software"), 5 | to deal in the Software without restriction, including without limitation 6 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | and/or sell copies of the Software, and to permit persons to whom the 8 | Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 19 | DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #

DTLC-GAN

2 | 3 | Tensorflow implementation of [DTLC-GAN (CVPR 2018): 4 | Generative Adversarial Image Synthesis with Decision Tree Latent Controller](https://arxiv.org/abs/1805.10603). 5 | 6 | ## Usage 7 | 8 | - Prerequisites 9 | - Tensorflow 1.9 10 | - Python 3.6 11 | 12 | - Training 13 | - Important Arguments (See the others in [train.py](train.py)) 14 | - `att`: attribute to learn (default: `''`) 15 | - `ks`: # of outputs of each node of each layer (default: `[2, 3, 3]`) 16 | - `lambdas`: loss weights of each layer (default: `[1.0, 1.0, 1.0]`) 17 | - `--n_d`: # of d steps in each iteration (default: `1`) 18 | - `--n_g`: # of g steps in each iteration (default: `1`) 19 | - `--loss_mode`: gan loss (choices: `[gan, lsgan, wgan, hinge]`, default: `gan`) 20 | - `--gp_mode`: type of gradient penalty (choices: `[none, dragan, wgan-gp]`, default: `none`) 21 | - `--norm`: normalization (choices: `[batch_norm, instance_norm, layer_norm, none]`, default: `batch_norm`) 22 | - `--experiment_name`: name for current experiment (default: `default`) 23 | - Example 24 | ```console 25 | CUDA_VISIBLE_DEVICES=0 \ 26 | python train.py \ 27 | --att Eyeglasses \ 28 | --ks 2 3 3 \ 29 | --lambdas 1 1 1 \ 30 | --n_d 1 \ 31 | --n_g 1 \ 32 | --loss_mode hinge \ 33 | --gp_mode dragan \ 34 | --norm layer_norm \ 35 | --experiment_name att{Eyeglasses}_ks{2-3-3}_lambdas{1-1-1}_continuous_last{False}_loss{hinge}_gp{dragan}_norm{layer_norm} 36 | ``` 37 | 38 | ## Dataset 39 | 40 | - [Celeba](http://openaccess.thecvf.com/content_iccv_2015/papers/Liu_Deep_Learning_Face_ICCV_2015_paper.pdf) dataset 41 | - [Images](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADSNUu0bseoCKuxuI5ZeTl1a/Img?dl=0&preview=img_align_celeba.zip) should be placed in ***./data/img_align_celeba/\*.jpg*** 42 | - [Attribute labels](https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AAA8YmAHNNU6BEfWMPMfM6r9a/Anno?dl=0&preview=list_attr_celeba.txt) should be placed in ***./data/list_attr_celeba.txt*** 43 | - the above links might be inaccessible, the alternatives are 44 | - ***img_align_celeba.zip*** 45 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg or 46 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg 47 | - ***list_attr_celeba.txt*** 48 | - https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FAnno&parentPath=%2F or 49 | - https://drive.google.com/drive/folders/0B7EVK8r0v71pOC0wOVZlQnFfaGs 50 | 51 | ## Exemplar Results 52 | 53 | 1. Eyeglasses, 3 layers 54 |

55 | 56 | 2. Bangs, 3 layers 57 |

-------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import multiprocessing 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import tflib as tl 11 | 12 | 13 | _N_CPU = multiprocessing.cpu_count() 14 | 15 | 16 | class Celeba(tl.Dataset): 17 | 18 | att_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 19 | 'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 20 | 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10, 21 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 22 | 'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16, 23 | 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, 24 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 25 | 'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25, 26 | 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 27 | 'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31, 28 | 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, 29 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 30 | 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 31 | 32 | def __init__(self, 33 | data_dir, 34 | atts, 35 | img_resize, 36 | batch_size, 37 | prefetch_batch=_N_CPU + 1, 38 | drop_remainder=True, 39 | num_threads=_N_CPU, 40 | shuffle=True, 41 | shuffle_buffer_size=None, 42 | repeat=-1, 43 | sess=None, 44 | split='train', 45 | crop=True): 46 | super(Celeba, self).__init__() 47 | 48 | list_file = os.path.join(data_dir, 'list_attr_celeba.txt') 49 | if crop: 50 | img_dir_jpg = os.path.join(data_dir, 'img_align_celeba') 51 | img_dir_png = os.path.join(data_dir, 'img_align_celeba_png') 52 | else: 53 | img_dir_jpg = os.path.join(data_dir, 'img_crop_celeba') 54 | img_dir_png = os.path.join(data_dir, 'img_crop_celeba_png') 55 | 56 | names = np.loadtxt(list_file, skiprows=2, usecols=[0], dtype=np.str) 57 | if os.path.exists(img_dir_png): 58 | img_paths = [os.path.join(img_dir_png, name.replace('jpg', 'png')) for name in names] 59 | elif os.path.exists(img_dir_jpg): 60 | img_paths = [os.path.join(img_dir_jpg, name) for name in names] 61 | 62 | att_id = [Celeba.att_dict[att] + 1 for att in atts] 63 | labels = np.loadtxt(list_file, skiprows=2, usecols=att_id, dtype=np.int64) 64 | 65 | if img_resize == 64: 66 | # crop as how VAE/GAN do 67 | offset_h = 40 68 | offset_w = 15 69 | img_size = 148 70 | else: 71 | offset_h = 26 72 | offset_w = 3 73 | img_size = 170 74 | 75 | def _map_func(img, label): 76 | if crop: 77 | img = tf.image.crop_to_bounding_box(img, offset_h, offset_w, img_size, img_size) 78 | # img = tf.image.resize_images(img, [img_resize, img_resize]) / 127.5 - 1 79 | # or 80 | img = tf.image.resize_images(img, [img_resize, img_resize], tf.image.ResizeMethod.BICUBIC) 81 | img = tf.clip_by_value(img, 0, 255) / 127.5 - 1 82 | label = (label + 1) // 2 83 | return img, label 84 | 85 | if split == 'test': 86 | drop_remainder = False 87 | shuffle = False 88 | repeat = 1 89 | img_paths = img_paths[182637:] 90 | labels = labels[182637:] 91 | elif split == 'val': 92 | img_paths = img_paths[182000:182637] 93 | labels = labels[182000:182637] 94 | else: 95 | img_paths = img_paths[:182000] 96 | labels = labels[:182000] 97 | 98 | dataset = tl.disk_image_batch_dataset(img_paths=img_paths, 99 | labels=labels, 100 | batch_size=batch_size, 101 | prefetch_batch=prefetch_batch, 102 | drop_remainder=drop_remainder, 103 | map_func=_map_func, 104 | num_threads=num_threads, 105 | shuffle=shuffle, 106 | shuffle_buffer_size=shuffle_buffer_size, 107 | repeat=repeat) 108 | self._bulid(dataset, sess) 109 | 110 | self._img_num = len(img_paths) 111 | 112 | def __len__(self): 113 | return self._img_num 114 | 115 | 116 | if __name__ == '__main__': 117 | import imlib as im 118 | atts = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses', 119 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'] 120 | data = Celeba('./data', atts, 128, 32, split='val') 121 | batch = data.get_next() 122 | print(len(data)) 123 | print(batch[1][1], batch[1].dtype) 124 | print(batch[0].min(), batch[1].max(), batch[0].dtype) 125 | im.imshow(batch[0][1]) 126 | im.show() 127 | -------------------------------------------------------------------------------- /imlib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from imlib.basic import * 6 | from imlib.dtype import * 7 | from imlib.encode import * 8 | from imlib.transform import * 9 | -------------------------------------------------------------------------------- /imlib/basic.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import skimage.io as iio 7 | 8 | from imlib.dtype import im2float 9 | 10 | 11 | def imread(path, as_gray=False): 12 | """Read image. 13 | 14 | Returns: 15 | Float64 image in [-1.0, 1.0]. 16 | """ 17 | image = iio.imread(path, as_gray) 18 | if image.dtype == np.uint8: 19 | image = image / 127.5 - 1 20 | elif image.dtype == np.uint16: 21 | image = image / 32767.5 - 1 22 | return image 23 | 24 | 25 | def imwrite(image, path): 26 | """Save an [-1.0, 1.0] image.""" 27 | iio.imsave(path, im2float(image)) 28 | 29 | 30 | def imshow(image): 31 | """Show a [-1.0, 1.0] image.""" 32 | iio.imshow(im2float(image)) 33 | 34 | 35 | show = iio.show 36 | -------------------------------------------------------------------------------- /imlib/dtype.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | 7 | 8 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf): 9 | # check type 10 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!' 11 | 12 | # check dtype 13 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes] 14 | assert images.dtype in dtypes, 'dtype of `images` shoud be one of %s!' % dtypes 15 | 16 | # check nan and inf 17 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!' 18 | 19 | # check value 20 | if min_value not in [None, -np.inf]: 21 | l = '[' + str(min_value) 22 | else: 23 | l = '(-inf' 24 | min_value = -np.inf 25 | if max_value not in [None, np.inf]: 26 | r = str(max_value) + ']' 27 | else: 28 | r = 'inf)' 29 | max_value = np.inf 30 | assert np.min(images) >= min_value - 1e-5 and np.max(images) <= max_value + 1e-5, \ 31 | '`images` should be in the range of %s!' % (l + ',' + r) 32 | 33 | 34 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None): 35 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.""" 36 | _check(images, [np.float32, np.float64], -1.0, 1.0) 37 | dtype = dtype if dtype else images.dtype 38 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype) 39 | 40 | 41 | def float2im(images): 42 | """Transform images from [0, 1.0] to [-1.0, 1.0].""" 43 | _check(images, [np.float32, np.float64], 0.0, 1.0) 44 | return images * 2 - 1.0 45 | 46 | 47 | def float2uint(images): 48 | """Transform images from [0, 1.0] to uint8.""" 49 | _check(images, [np.float32, np.float64], -0.0, 1.0) 50 | return (images * 255).astype(np.uint8) 51 | 52 | 53 | def im2uint(images): 54 | """Transform images from [-1.0, 1.0] to uint8.""" 55 | return to_range(images, 0, 255, np.uint8) 56 | 57 | 58 | def im2float(images): 59 | """Transform images from [-1.0, 1.0] to [0.0, 1.0].""" 60 | return to_range(images, 0.0, 1.0) 61 | 62 | 63 | def uint2im(images): 64 | """Transform images from uint8 to [-1.0, 1.0] of float64.""" 65 | _check(images, np.uint8) 66 | return images / 127.5 - 1.0 67 | 68 | 69 | def uint2float(images): 70 | """Transform images from uint8 to [0.0, 1.0] of float64.""" 71 | _check(images, np.uint8) 72 | return images / 255.0 73 | -------------------------------------------------------------------------------- /imlib/encode.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import io 6 | 7 | import numpy as np 8 | 9 | from imlib.dtype import im2uint, uint2im 10 | from PIL import Image 11 | 12 | 13 | def imencode(image, format='PNG', quality=95): 14 | """Encode an [-1.0, 1.0] image into byte string. 15 | 16 | Args: 17 | format : 'PNG' or 'JPEG'. 18 | quality : Only for 'JPEG'. 19 | 20 | Returns: 21 | Byte string. 22 | """ 23 | byte_io = io.BytesIO() 24 | image = Image.fromarray(im2uint(image)) 25 | image.save(byte_io, format=format, quality=quality) 26 | bytes = byte_io.getvalue() 27 | return bytes 28 | 29 | 30 | def imdecode(bytes): 31 | """Decode byte string to float64 image in [-1.0, 1.0]. 32 | 33 | Args: 34 | bytes: Byte string. 35 | 36 | Returns: 37 | A float64 image in [-1.0, 1.0]. 38 | """ 39 | byte_io = io.BytesIO() 40 | byte_io.write(bytes) 41 | image = np.array(Image.open(byte_io)) 42 | image = uint2im(image) 43 | return image 44 | -------------------------------------------------------------------------------- /imlib/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import skimage.color as color 7 | import skimage.transform as transform 8 | 9 | 10 | rgb2gray = color.rgb2gray 11 | gray2rgb = color.gray2rgb 12 | 13 | imresize = transform.resize 14 | imrescale = transform.rescale 15 | 16 | 17 | def immerge(images, n_row=None, n_col=None, padding=0, pad_value=0): 18 | """Merge images to an image with (n_row * h) * (n_col * w). 19 | 20 | `images` is in shape of N * H * W(* C=1 or 3) 21 | """ 22 | n = images.shape[0] 23 | if n_row: 24 | n_row = max(min(n_row, n), 1) 25 | n_col = int(n - 0.5) // n_row + 1 26 | elif n_col: 27 | n_col = max(min(n_col, n), 1) 28 | n_row = int(n - 0.5) // n_col + 1 29 | else: 30 | n_row = int(n ** 0.5) 31 | n_col = int(n - 0.5) // n_row + 1 32 | 33 | h, w = images.shape[1], images.shape[2] 34 | shape = (h * n_row + padding * (n_row - 1), 35 | w * n_col + padding * (n_col - 1)) 36 | if images.ndim == 4: 37 | shape += (images.shape[3],) 38 | img = np.full(shape, pad_value, dtype=images.dtype) 39 | 40 | for idx, image in enumerate(images): 41 | i = idx % n_col 42 | j = idx // n_col 43 | img[j * (h + padding):j * (h + padding) + h, 44 | i * (w + padding):i * (w + padding) + w, ...] = image 45 | 46 | return img 47 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import tensorflow.contrib.slim as slim 10 | import tflib as tl 11 | 12 | 13 | # ============================================================================== 14 | # = alias = 15 | # ============================================================================== 16 | 17 | conv = partial(slim.conv2d, activation_fn=None) 18 | dconv = partial(slim.conv2d_transpose, activation_fn=None) 19 | fc = partial(tl.flatten_fully_connected, activation_fn=None) 20 | relu = tf.nn.relu 21 | lrelu = tf.nn.leaky_relu 22 | # batch_norm = partial(slim.batch_norm, scale=True) 23 | batch_norm = partial(slim.batch_norm, scale=True, updates_collections=None) 24 | layer_norm = slim.layer_norm 25 | instance_norm = slim.instance_norm 26 | 27 | 28 | # ============================================================================== 29 | # = models = 30 | # ============================================================================== 31 | 32 | def _get_norm_fn(norm_name, is_training): 33 | if norm_name == 'none': 34 | norm = None 35 | elif norm_name == 'batch_norm': 36 | norm = partial(batch_norm, is_training=is_training) 37 | elif norm_name == 'instance_norm': 38 | norm = instance_norm 39 | elif norm_name == 'layer_norm': 40 | norm = layer_norm 41 | return norm 42 | 43 | 44 | def G(z, c, dim=64, is_training=True): 45 | norm = _get_norm_fn('batch_norm', is_training) 46 | fc_norm_relu = partial(fc, normalizer_fn=norm, activation_fn=relu) 47 | dconv_norm_relu = partial(dconv, normalizer_fn=norm, activation_fn=relu) 48 | 49 | with tf.variable_scope('G', reuse=tf.AUTO_REUSE): 50 | y = tf.concat([z, c], axis=1) 51 | y = fc_norm_relu(y, 4 * 4 * dim * 8) 52 | y = tf.reshape(y, [-1, 4, 4, dim * 8]) 53 | y = dconv_norm_relu(y, dim * 4, 4, 2) 54 | y = dconv_norm_relu(y, dim * 2, 4, 2) 55 | y = dconv_norm_relu(y, dim * 1, 4, 2) 56 | x = tf.tanh(dconv(y, 3, 4, 2)) 57 | return x 58 | 59 | 60 | def D(x, c_dim, dim=64, norm_name='batch_norm', is_training=True): 61 | norm = _get_norm_fn(norm_name, is_training) 62 | conv_norm_lrelu = partial(conv, normalizer_fn=norm, activation_fn=lrelu) 63 | 64 | with tf.variable_scope('D', reuse=tf.AUTO_REUSE): 65 | y = conv_norm_lrelu(x, dim, 4, 2) 66 | y = conv_norm_lrelu(y, dim * 2, 4, 2) 67 | y = conv_norm_lrelu(y, dim * 4, 4, 2) 68 | y = conv_norm_lrelu(y, dim * 8, 4, 2) 69 | logit = fc(y, 1) 70 | c_logit = fc(y, c_dim) 71 | return logit, c_logit 72 | 73 | 74 | # ============================================================================== 75 | # = loss function = 76 | # ============================================================================== 77 | 78 | def get_loss_fn(mode): 79 | if mode == 'gan': 80 | def d_loss_fn(r_logit, f_logit): 81 | r_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(r_logit), r_logit) 82 | f_loss = tf.losses.sigmoid_cross_entropy(tf.zeros_like(f_logit), f_logit) 83 | return r_loss, f_loss 84 | 85 | def g_loss_fn(f_logit): 86 | f_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(f_logit), f_logit) 87 | return f_loss 88 | 89 | elif mode == 'lsgan': 90 | def d_loss_fn(r_logit, f_logit): 91 | r_loss = tf.losses.mean_squared_error(tf.ones_like(r_logit), r_logit) 92 | f_loss = tf.losses.mean_squared_error(tf.zeros_like(f_logit), f_logit) 93 | return r_loss, f_loss 94 | 95 | def g_loss_fn(f_logit): 96 | f_loss = tf.losses.mean_squared_error(tf.ones_like(f_logit), f_logit) 97 | return f_loss 98 | 99 | elif mode == 'wgan': 100 | def d_loss_fn(r_logit, f_logit): 101 | r_loss = - tf.reduce_mean(r_logit) 102 | f_loss = tf.reduce_mean(f_logit) 103 | return r_loss, f_loss 104 | 105 | def g_loss_fn(f_logit): 106 | f_loss = - tf.reduce_mean(f_logit) 107 | return f_loss 108 | 109 | elif mode == 'hinge': 110 | def d_loss_fn(r_logit, f_logit): 111 | r_loss = tf.reduce_mean(tf.maximum(1 - r_logit, 0)) 112 | f_loss = tf.reduce_mean(tf.maximum(1 + f_logit, 0)) 113 | return r_loss, f_loss 114 | 115 | def g_loss_fn(f_logit): 116 | # f_loss = tf.reduce_mean(tf.maximum(1 - f_logit, 0)) 117 | f_loss = tf.reduce_mean(- f_logit) 118 | return f_loss 119 | 120 | return d_loss_fn, g_loss_fn 121 | 122 | 123 | # ============================================================================== 124 | # = others = 125 | # ============================================================================== 126 | 127 | def gradient_penalty(f, real, fake, mode): 128 | def _gradient_penalty(f, real, fake=None): 129 | def _interpolate(a, b=None): 130 | with tf.name_scope('interpolate'): 131 | if b is None: # interpolation in DRAGAN 132 | beta = tf.random_uniform(shape=tf.shape(a), minval=0., maxval=1.) 133 | _, variance = tf.nn.moments(a, list(range(a.shape.ndims))) 134 | b = a + 0.5 * tf.sqrt(variance) * beta 135 | shape = [tf.shape(a)[0]] + [1] * (a.shape.ndims - 1) 136 | alpha = tf.random_uniform(shape=shape, minval=0., maxval=1.) 137 | inter = a + alpha * (b - a) 138 | inter.set_shape(a.get_shape().as_list()) 139 | return inter 140 | 141 | with tf.name_scope('gradient_penalty'): 142 | x = _interpolate(real, fake) 143 | pred = f(x) 144 | if isinstance(pred, tuple): 145 | pred = pred[0] 146 | grad = tf.gradients(pred, x)[0] 147 | norm = tf.norm(slim.flatten(grad), axis=1) 148 | gp = tf.reduce_mean((norm - 1.)**2) 149 | return gp 150 | 151 | if mode == 'none': 152 | gp = tf.constant(0, dtype=tf.float32) 153 | elif mode == 'wgan-gp': 154 | gp = _gradient_penalty(f, real, fake) 155 | elif mode == 'dragan': 156 | gp = _gradient_penalty(f, real) 157 | 158 | return gp 159 | 160 | 161 | def sample_c(ks, c_1=None, continuous_last=False): 162 | assert c_1 is None or ks[0] == len(c_1), '`ks[0]` is inconsistent with `c_1`!' 163 | 164 | c_tree = [[np.array([1.])]] 165 | mask_tree = [] 166 | 167 | for l, k in enumerate(ks): 168 | if c_1 is not None and l == 0: 169 | c_l = [c_1] 170 | mask_l = [np.ones_like(c_1)] 171 | else: 172 | c_l = [] 173 | mask_l = [] 174 | for i in range(len(c_tree[-1])): 175 | for j in range(len(c_tree[-1][-1])): 176 | if c_tree[-1][i][j] == 1.: 177 | if continuous_last is True and l == len(ks) - 1: 178 | c_l.append(np.random.uniform(-1, 1, size=[k])) 179 | else: 180 | c_l.append(np.eye(k)[np.random.randint(k)]) 181 | mask_l.append(np.ones([k])) 182 | else: 183 | c_l.append(np.zeros([k])) 184 | mask_l.append(np.zeros([k])) 185 | c_tree.append(c_l) 186 | mask_tree.append(mask_l) 187 | 188 | c_tree[0:1] = [] 189 | c = np.concatenate([k for l in c_tree for k in l]) 190 | mask = np.concatenate([k for l in mask_tree for k in l]) 191 | 192 | return c, mask, c_tree, mask_tree 193 | 194 | 195 | def traversal_trees(ks, continuous_last=False): 196 | trees = [] 197 | if len(ks) == 1: 198 | if continuous_last: 199 | trees.append([[np.random.uniform(-1, 1, size=[ks[0]])]]) 200 | else: 201 | for i in range(ks[0]): 202 | trees.append([[np.eye(ks[0])[i]]]) 203 | else: 204 | def _merge_trees(trees): 205 | tree = [] 206 | for l in range(len(trees[0])): 207 | tree_l = [] 208 | for t in trees: 209 | tree_l += t[l] 210 | tree.append(tree_l) 211 | return tree 212 | 213 | def _zero_tree(tree): 214 | zero_tree = [] 215 | for l in tree: 216 | zero_tree_l = [] 217 | for i in l: 218 | zero_tree_l.append(i * 0.) 219 | zero_tree.append(zero_tree_l) 220 | return zero_tree 221 | 222 | for i in range(ks[0]): 223 | trees_i = [] 224 | sub_trees, _ = traversal_trees(ks[1:], continuous_last=continuous_last) 225 | for j, s_t in enumerate(sub_trees): 226 | to_merge = [_zero_tree(s_t)] * ks[0] 227 | to_merge[i] = s_t 228 | sub_trees[j] = _merge_trees(to_merge) 229 | for s_t in sub_trees: 230 | trees_i.append([[np.eye(ks[0])[i]]] + s_t) 231 | trees += trees_i 232 | 233 | cs = [] 234 | for t in trees: 235 | cs.append(np.concatenate([k for l in t for k in l])) 236 | return trees, cs 237 | 238 | 239 | def to_tree(x, ks): 240 | size_splits = [] 241 | n_l = 1 242 | for k in ks: 243 | for _ in range(n_l): 244 | size_splits.append(k) 245 | n_l *= k 246 | 247 | splits = tf.split(x, size_splits, axis=1) 248 | 249 | tree = [] 250 | n_l = 1 251 | i = 0 252 | for k in ks: 253 | tree_l = [] 254 | for _ in range(n_l): 255 | tree_l.append(splits[i]) 256 | i += 1 257 | n_l *= k 258 | tree.append(tree_l) 259 | 260 | return tree 261 | 262 | 263 | def tree_loss(logits, c, mask, ks, continuous_last=False): 264 | logits_tree = to_tree(logits, ks) 265 | c_tree = to_tree(c, ks) 266 | mask_tree = to_tree(mask, ks) 267 | 268 | losses = [] 269 | for l, logits_l, c_l, mask_l in zip(range(len(logits_tree)), logits_tree, c_tree, mask_tree): 270 | loss_l = 0 271 | for lo, c, m in zip(logits_l, c_l, mask_l): 272 | weights = tf.reduce_mean(m, axis=1) 273 | if continuous_last is True and l == len(ks) - 1: 274 | loss_l += tf.losses.mean_squared_error(c, lo, weights=weights) 275 | else: 276 | loss_l += tf.losses.softmax_cross_entropy(c, lo, weights=weights) 277 | losses.append(loss_l) 278 | return losses 279 | 280 | if __name__ == '__main__': 281 | from pprint import pprint as pp 282 | s = tf.Session() 283 | ks = [2, 2, 2] 284 | c_1 = np.array([1., 0]) 285 | continuous_last = False 286 | if len(ks) == 1 and c_1 is not None: 287 | continuous_last = False 288 | 289 | c, mask, c_tree, mask_tree = sample_c(ks, c_1, continuous_last) 290 | pp(c) 291 | pp(mask) 292 | pp(c_tree) 293 | pp(mask_tree) 294 | 295 | tree = to_tree(tf.constant(np.array([c]), dtype=tf.float32), ks) 296 | pp(tree) 297 | pp(s.run(tree)) 298 | 299 | pp(tree_loss(tf.constant(np.array([c]), dtype=tf.float32), 300 | tf.constant(np.array([c]), dtype=tf.float32), 301 | tf.constant(np.array([mask]), dtype=tf.float32), 302 | ks, 303 | continuous_last)) 304 | 305 | pp(s.run(tf.losses.softmax_cross_entropy([[0, 10]], [[0.5, 0.5]]))) 306 | pp(s.run(tf.losses.mean_squared_error([[2, 1]], [[0, 0]]))) 307 | 308 | for tree, c in zip(*traversal_trees(ks, continuous_last=continuous_last)): 309 | pp(tree) 310 | pp(c) 311 | -------------------------------------------------------------------------------- /pics/bangs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/DTLC-GAN-Tensorflow/ca053af68a47e4678c172d756d64e3576a51e009/pics/bangs.png -------------------------------------------------------------------------------- /pics/eyeglasses.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/DTLC-GAN-Tensorflow/ca053af68a47e4678c172d756d64e3576a51e009/pics/eyeglasses.png -------------------------------------------------------------------------------- /pylib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pylib.path import * 6 | from pylib.timer import * 7 | -------------------------------------------------------------------------------- /pylib/path.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import fnmatch 6 | import os 7 | import sys 8 | 9 | 10 | def add_path(paths): 11 | if not isinstance(paths, (list, tuple)): 12 | paths = [paths] 13 | for path in paths: 14 | if path not in sys.path: 15 | sys.path.insert(0, path) 16 | 17 | 18 | def mkdir(paths): 19 | if not isinstance(paths, (list, tuple)): 20 | paths = [paths] 21 | for path in paths: 22 | if not os.path.isdir(path): 23 | os.makedirs(path) 24 | 25 | 26 | def split(path): 27 | dir, name_ext = os.path.split(path) 28 | name, ext = os.path.splitext(name_ext) 29 | return dir, name, ext 30 | 31 | 32 | def directory(path): 33 | return split(path)[0] 34 | 35 | 36 | def name(path): 37 | return split(path)[1] 38 | 39 | 40 | def ext(path): 41 | return split(path)[2] 42 | 43 | 44 | def name_ext(path): 45 | return ''.join(split(path)[1:]) 46 | 47 | 48 | asbpath = os.path.abspath 49 | 50 | 51 | join = os.path.join 52 | 53 | 54 | def match(dir, pat, recursive=False): 55 | if recursive: 56 | iterator = os.walk(dir) 57 | else: 58 | iterator = [next(os.walk(dir))] 59 | matches = [] 60 | for root, _, file_names in iterator: 61 | for file_name in fnmatch.filter(file_names, pat): 62 | matches.append(os.path.join(root, file_name)) 63 | return matches 64 | -------------------------------------------------------------------------------- /pylib/timer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import datetime 6 | import timeit 7 | 8 | 9 | class Timer(object): 10 | """A timer as a context manager. 11 | 12 | Modified from https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py. 13 | 14 | Wraps around a timer. A custom timer can be passed 15 | to the constructor. The default timer is timeit.default_timer. 16 | 17 | Note that the latter measures wall clock time, not CPU time! 18 | On Unix systems, it corresponds to time.time. 19 | On Windows systems, it corresponds to time.clock. 20 | 21 | Arguments: 22 | is_output : If True, print output after exiting context. 23 | format : 'ms', 's' or 'datetime' 24 | """ 25 | 26 | def __init__(self, timer=timeit.default_timer, is_output=True, fmt='s'): 27 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!" 28 | self._timer = timer 29 | self._is_output = is_output 30 | self._fmt = fmt 31 | 32 | def __enter__(self): 33 | """Start the timer in the context manager scope.""" 34 | self.start() 35 | return self 36 | 37 | def __exit__(self, exc_type, exc_value, exc_traceback): 38 | """Set the end time.""" 39 | if self._is_output: 40 | print(str(self)) 41 | 42 | def __str__(self): 43 | if self._fmt != 'datetime': 44 | return '%s %s' % (self.elapsed, self._fmt) 45 | else: 46 | return str(self.elapsed) 47 | 48 | def start(self): 49 | self.start_time = self._timer() 50 | 51 | @property 52 | def elapsed(self): 53 | """Return the current elapsed time since start.""" 54 | e = self._timer() - self.start_time 55 | 56 | if self._fmt == 'ms': 57 | return e * 1000 58 | elif self._fmt == 's': 59 | return e 60 | elif self._fmt == 'datetime': 61 | return datetime.timedelta(seconds=e) 62 | 63 | 64 | def timer(**timer_kwargs): 65 | """Function decorator displaying the function execution time. 66 | 67 | All kwargs are the arguments taken by the Timer class constructor. 68 | """ 69 | # store Timer kwargs in local variable so the namespace isn't polluted 70 | # by different level args and kwargs 71 | 72 | def wrapped_f(f): 73 | def wrapped(*args, **kwargs): 74 | fmt = '[*] function "%(function_name)s" execution time: %(execution_time)s [*]' 75 | with Timer(**timer_kwargs) as t: 76 | out = f(*args, **kwargs) 77 | context = {'function_name': f.__name__, 'execution_time': str(t)} 78 | print(fmt % context) 79 | return out 80 | return wrapped 81 | 82 | return wrapped_f 83 | 84 | if __name__ == "__main__": 85 | import time 86 | 87 | # 1 88 | print(1) 89 | with Timer() as t: 90 | time.sleep(1) 91 | print(t) 92 | time.sleep(1) 93 | 94 | with Timer(fmt='datetime') as t: 95 | time.sleep(1) 96 | 97 | # 2 98 | print(2) 99 | t = Timer(fmt='ms') 100 | t.start() 101 | time.sleep(2) 102 | print(t) 103 | 104 | t = Timer(fmt='datetime') 105 | t.start() 106 | time.sleep(1) 107 | print(t) 108 | 109 | # 3 110 | print(3) 111 | 112 | @timer(fmt='ms') 113 | def blah(): 114 | time.sleep(2) 115 | 116 | blah() 117 | -------------------------------------------------------------------------------- /tflib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.checkpoint import * 6 | from tflib.collection import * 7 | from tflib.data import * 8 | from tflib.layers import * 9 | from tflib.parallel import * 10 | from tflib.utils import * 11 | from tflib.vision import * 12 | -------------------------------------------------------------------------------- /tflib/checkpoint.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def load_checkpoint(ckpt_dir_or_file, session, var_list=None): 11 | """Load checkpoint. 12 | 13 | This function add some useless ops to the graph. It is better 14 | to use tf.train.init_from_checkpoint(...). 15 | """ 16 | if os.path.isdir(ckpt_dir_or_file): 17 | ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file) 18 | 19 | restorer = tf.train.Saver(var_list) 20 | restorer.restore(session, ckpt_dir_or_file) 21 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file) 22 | 23 | 24 | def init_from_checkpoint(ckpt_dir_or_file, assignment_map={'/': '/'}): 25 | # Use the checkpoint values for the variables' initializers. Note that this 26 | # function just changes the initializers but does not actually run them, and 27 | # you should still run the initializers manually. 28 | tf.train.init_from_checkpoint(ckpt_dir_or_file, assignment_map) 29 | print(' [*] Loading checkpoint succeeds! Copy variables from % s!' % ckpt_dir_or_file) 30 | -------------------------------------------------------------------------------- /tflib/collection.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from functools import partial 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def tensors_filter(tensors, 11 | includes='', 12 | includes_combine_type='or', 13 | excludes=[], 14 | excludes_combine_type='or'): 15 | assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!' 16 | assert isinstance(includes, (str, list, tuple)), '`includes` should be a string or a list(tuple) of strings!' 17 | assert includes_combine_type in ['or', 'and'], "`includes_combine_type` should be 'or' or 'and'!" 18 | assert isinstance(excludes, (str, list, tuple)), '`excludes` should be a string or a list(tuple) of strings!' 19 | assert excludes_combine_type in ['or', 'and'], "`excludes_combine_type` should be 'or' or 'and'!" 20 | 21 | def _select(filters, combine_type): 22 | if isinstance(filters, str): 23 | filters = [filters] 24 | 25 | selected = [] 26 | for t in tensors: 27 | if combine_type == 'or': 28 | for filt in filters: 29 | if filt in t.name: 30 | selected.append(t) 31 | break 32 | elif combine_type == 'and': 33 | all_pass = True and filters # for fiters == [] 34 | for filt in filters: 35 | if filt not in t.name: 36 | all_pass = False 37 | break 38 | if all_pass: 39 | selected.append(t) 40 | 41 | return selected 42 | 43 | include_set = _select(includes, includes_combine_type) 44 | exclude_set = _select(excludes, excludes_combine_type) 45 | select_set = [t for t in include_set if t not in exclude_set] 46 | 47 | return select_set 48 | 49 | 50 | def get_collection(key, 51 | includes='', 52 | includes_combine_type='or', 53 | excludes=[], 54 | excludes_combine_type='and'): 55 | tensors = tf.get_collection(key) 56 | return tensors_filter(tensors, 57 | includes, 58 | includes_combine_type, 59 | excludes, 60 | excludes_combine_type) 61 | 62 | global_variables = partial(get_collection, key=tf.GraphKeys.GLOBAL_VARIABLES) 63 | trainable_variables = partial(get_collection, key=tf.GraphKeys.TRAINABLE_VARIABLES) 64 | update_ops = partial(get_collection, key=tf.GraphKeys.UPDATE_OPS) 65 | -------------------------------------------------------------------------------- /tflib/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.data.dataset import * 6 | from tflib.data.disk_image import * 7 | from tflib.data.memory_data import * 8 | from tflib.data.tfrecord import * 9 | from tflib.data.tfrecord_creator import * 10 | -------------------------------------------------------------------------------- /tflib/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | import tensorflow.contrib.eager as tfe 9 | 10 | from tflib.utils import session 11 | 12 | 13 | _N_CPU = multiprocessing.cpu_count() 14 | 15 | 16 | def batch_dataset(dataset, 17 | batch_size, 18 | prefetch_batch=_N_CPU + 1, 19 | drop_remainder=True, 20 | filter=None, 21 | map_func=None, 22 | num_threads=_N_CPU, 23 | shuffle=True, 24 | shuffle_buffer_size=None, 25 | repeat=-1): 26 | if filter: 27 | dataset = dataset.filter(filter) 28 | 29 | if map_func: 30 | dataset = dataset.map(map_func, num_parallel_calls=num_threads) 31 | 32 | if shuffle: 33 | if shuffle_buffer_size is None: 34 | shuffle_buffer_size = batch_size * 100 35 | dataset = dataset.shuffle(shuffle_buffer_size) 36 | 37 | if drop_remainder: 38 | dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) 39 | else: 40 | dataset = dataset.batch(batch_size) 41 | 42 | dataset = dataset.repeat(repeat).prefetch(prefetch_batch) 43 | 44 | return dataset 45 | 46 | 47 | class Dataset(object): 48 | 49 | def __init__(self): 50 | self._dataset = None 51 | self._iterator = None 52 | self._batch_op = None 53 | self._sess = None 54 | 55 | self._is_eager = tf.executing_eagerly() 56 | self._eager_iterator = None 57 | 58 | def __del__(self): 59 | if self._sess: 60 | self._sess.close() 61 | 62 | def __iter__(self): 63 | return self 64 | 65 | def __next__(self): 66 | try: 67 | b = self.get_next() 68 | except: 69 | raise StopIteration 70 | else: 71 | return b 72 | 73 | next = __next__ 74 | 75 | def get_next(self): 76 | if self._is_eager: 77 | return self._eager_iterator.get_next() 78 | else: 79 | return self._sess.run(self._batch_op) 80 | 81 | def reset(self, feed_dict={}): 82 | if self._is_eager: 83 | self._eager_iterator = tfe.Iterator(self._dataset) 84 | else: 85 | self._sess.run(self._iterator.initializer, feed_dict=feed_dict) 86 | 87 | def _bulid(self, dataset, sess=None): 88 | self._dataset = dataset 89 | 90 | if self._is_eager: 91 | self._eager_iterator = tfe.Iterator(dataset) 92 | else: 93 | self._iterator = dataset.make_initializable_iterator() 94 | self._batch_op = self._iterator.get_next() 95 | if sess: 96 | self._sess = sess 97 | else: 98 | self._sess = session() 99 | 100 | try: 101 | self.reset() 102 | except: 103 | pass 104 | 105 | @property 106 | def dataset(self): 107 | return self._dataset 108 | 109 | @property 110 | def iterator(self): 111 | return self._iterator 112 | 113 | @property 114 | def batch_op(self): 115 | return self._batch_op 116 | -------------------------------------------------------------------------------- /tflib/data/disk_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import tensorflow as tf 8 | 9 | from tflib.data.dataset import batch_dataset, Dataset 10 | 11 | 12 | _N_CPU = multiprocessing.cpu_count() 13 | 14 | 15 | def disk_image_batch_dataset(img_paths, 16 | batch_size, 17 | labels=None, 18 | prefetch_batch=_N_CPU + 1, 19 | drop_remainder=True, 20 | filter=None, 21 | map_func=None, 22 | num_threads=_N_CPU, 23 | shuffle=True, 24 | shuffle_buffer_size=None, 25 | repeat=-1): 26 | """Disk image batch dataset. 27 | 28 | This function is suitable for jpg and png files 29 | 30 | Arguments: 31 | img_paths : String list or 1-D tensor, each of which is an iamge path 32 | labels : Label list/tuple_of_list or tensor/tuple_of_tensor, each of which is a corresponding label 33 | """ 34 | if labels is None: 35 | dataset = tf.data.Dataset.from_tensor_slices(img_paths) 36 | elif isinstance(labels, tuple): 37 | dataset = tf.data.Dataset.from_tensor_slices((img_paths,) + tuple(labels)) 38 | else: 39 | dataset = tf.data.Dataset.from_tensor_slices((img_paths, labels)) 40 | 41 | def parse_func(path, *label): 42 | img = tf.read_file(path) 43 | img = tf.image.decode_png(img, 3) 44 | return (img,) + label 45 | 46 | if map_func: 47 | def map_func_(*args): 48 | return map_func(*parse_func(*args)) 49 | else: 50 | map_func_ = parse_func 51 | 52 | # dataset = dataset.map(parse_func, num_parallel_calls=num_threads) is slower 53 | 54 | dataset = batch_dataset(dataset, 55 | batch_size, 56 | prefetch_batch, 57 | drop_remainder, 58 | filter, 59 | map_func_, 60 | num_threads, 61 | shuffle, 62 | shuffle_buffer_size, 63 | repeat) 64 | 65 | return dataset 66 | 67 | 68 | class DiskImageData(Dataset): 69 | """DiskImageData. 70 | 71 | This class is suitable for jpg and png files 72 | 73 | Arguments: 74 | img_paths : String list or 1-D tensor, each of which is an iamge path 75 | labels : Label list or tensor, each of which is a corresponding label 76 | """ 77 | 78 | def __init__(self, 79 | img_paths, 80 | batch_size, 81 | labels=None, 82 | prefetch_batch=_N_CPU + 1, 83 | drop_remainder=True, 84 | filter=None, 85 | map_func=None, 86 | num_threads=_N_CPU, 87 | shuffle=True, 88 | shuffle_buffer_size=None, 89 | repeat=-1, 90 | sess=None): 91 | super(DiskImageData, self).__init__() 92 | dataset = disk_image_batch_dataset(img_paths, 93 | batch_size, 94 | labels, 95 | prefetch_batch, 96 | drop_remainder, 97 | filter, 98 | map_func, 99 | num_threads, 100 | shuffle, 101 | shuffle_buffer_size, 102 | repeat) 103 | self._bulid(dataset, sess) 104 | self._n_data = len(img_paths) 105 | 106 | def __len__(self): 107 | return self._n_data 108 | 109 | 110 | if __name__ == '__main__': 111 | import glob 112 | 113 | import imlib as im 114 | import numpy as np 115 | import pylib 116 | 117 | paths = glob.glob('/home/hezhenliang/Resource/face/CelebA/origin/origin/processed_by_hezhenliang/align_celeba/img_align_celeba/*.jpg') 118 | paths = sorted(paths)[182637:] 119 | labels = list(range(len(paths))) 120 | 121 | def filter(x, y, *args): 122 | return tf.cond(y > 1, lambda: tf.constant(True), lambda: tf.constant(False)) 123 | 124 | def map_func(x, *args): 125 | x = tf.image.resize_images(x, [256, 256]) 126 | x = tf.to_float((x - tf.reduce_min(x)) / (tf.reduce_max(x) - tf.reduce_min(x)) * 2 - 1) 127 | return (x,) + args 128 | 129 | tf.enable_eager_execution() 130 | 131 | s = tf.Session() 132 | 133 | data = DiskImageData(paths, 32, (labels, labels), filter=None, map_func=None, shuffle=True, sess=s) 134 | 135 | for _ in range(1000): 136 | with pylib.Timer(): 137 | for i in range(100): 138 | b = data.get_next() 139 | # print(b[1][0]) 140 | # print(b[2][0]) 141 | # im.imshow(np.array(b[0][0])) 142 | # im.show() 143 | # data.reset() 144 | -------------------------------------------------------------------------------- /tflib/data/memory_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import multiprocessing 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from tflib.data.dataset import batch_dataset, Dataset 11 | 12 | 13 | _N_CPU = multiprocessing.cpu_count() 14 | 15 | 16 | def memory_data_batch_dataset(memory_data_dict, 17 | batch_size, 18 | prefetch_batch=_N_CPU + 1, 19 | drop_remainder=True, 20 | filter=None, 21 | map_func=None, 22 | num_threads=_N_CPU, 23 | shuffle=True, 24 | shuffle_buffer_size=None, 25 | repeat=-1): 26 | """Memory data batch dataset. 27 | 28 | `memory_data_dict` example: 29 | {'img': img_ndarray, 'label': label_ndarray} or 30 | {'img': img_tftensor, 'label': label_tftensor} 31 | * The value of each item of `memory_data_dict` is in shape of (N, ...). 32 | """ 33 | dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict) 34 | dataset = batch_dataset(dataset, 35 | batch_size, 36 | prefetch_batch, 37 | drop_remainder, 38 | filter, 39 | map_func, 40 | num_threads, 41 | shuffle, 42 | shuffle_buffer_size, 43 | repeat) 44 | return dataset 45 | 46 | 47 | class MemoryData(Dataset): 48 | """MemoryData. 49 | 50 | `memory_data_dict` example: 51 | {'img': img_ndarray, 'label': label_ndarray} or 52 | {'img': img_tftensor, 'label': label_tftensor} 53 | * The value of each item of `memory_data_dict` is in shape of (N, ...). 54 | """ 55 | 56 | def __init__(self, 57 | memory_data_dict, 58 | batch_size, 59 | prefetch_batch=_N_CPU + 1, 60 | drop_remainder=True, 61 | filter=None, 62 | map_func=None, 63 | num_threads=_N_CPU, 64 | shuffle=True, 65 | shuffle_buffer_size=None, 66 | repeat=-1, 67 | sess=None): 68 | super(MemoryData, self).__init__() 69 | dataset = memory_data_batch_dataset(memory_data_dict, 70 | batch_size, 71 | prefetch_batch, 72 | drop_remainder, 73 | filter, 74 | map_func, 75 | num_threads, 76 | shuffle, 77 | shuffle_buffer_size, 78 | repeat) 79 | self._bulid(dataset, sess) 80 | if isinstance(memory_data_dict.values()[0], np.ndarray): 81 | self._n_data = len(memory_data_dict.values()[0]) 82 | else: 83 | self._n_data = memory_data_dict.values()[0].get_shape().as_list()[0] 84 | 85 | def __len__(self): 86 | return self._n_data 87 | 88 | if __name__ == '__main__': 89 | data = {'a': np.array([1.0, 2, 3, 4, 5]), 90 | 'b': np.array([[1, 2], 91 | [2, 3], 92 | [3, 4], 93 | [4, 5], 94 | [5, 6]])} 95 | 96 | def filter(x): 97 | return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False)) 98 | 99 | def map_func(x): 100 | x['a'] = x['a'] * 10 101 | return x 102 | 103 | # tf.enable_eager_execution() 104 | 105 | s = tf.Session() 106 | 107 | dataset = MemoryData(data, 108 | 2, 109 | filter=None, 110 | map_func=map_func, 111 | shuffle=True, 112 | shuffle_buffer_size=None, 113 | drop_remainder=True, 114 | repeat=4, 115 | sess=s) 116 | 117 | for i in range(5): 118 | print(map(dataset.get_next().__getitem__, ['b', 'a'])) 119 | 120 | print([n.name for n in tf.get_default_graph().as_graph_def().node]) 121 | -------------------------------------------------------------------------------- /tflib/data/tfrecord.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import glob 7 | import json 8 | import multiprocessing 9 | import os 10 | 11 | import tensorflow as tf 12 | 13 | from tflib.data.dataset import batch_dataset, Dataset 14 | 15 | 16 | _N_CPU = multiprocessing.cpu_count() 17 | 18 | _DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024 19 | 20 | _DECODERS = { 21 | 'png': {'decoder': tf.image.decode_png, 'decode_param': dict()}, 22 | 'jpg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()}, 23 | 'jpeg': {'decoder': tf.image.decode_jpeg, 'decode_param': dict()}, 24 | 'uint8': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.uint8)}, 25 | 'int64': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.int64)}, 26 | 'float32': {'decoder': tf.decode_raw, 'decode_param': dict(out_type=tf.float32)}, 27 | } 28 | 29 | 30 | def tfrecord_batch_dataset(tfrecord_files, 31 | infos, 32 | compression_type, 33 | batch_size, 34 | prefetch_batch=_N_CPU + 1, 35 | drop_remainder=True, 36 | filter=None, 37 | map_func=None, 38 | num_threads=_N_CPU, 39 | shuffle=True, 40 | shuffle_buffer_size=None, 41 | repeat=-1): 42 | """Tfrecord batch dataset. 43 | 44 | `infos` example: 45 | [{'name': 'img', 'decoder': tf.image.decode_png, 'decode_param': {}, 'shape': [112, 112, 1]}, 46 | {'name': 'point', 'decoder': tf.decode_raw, 'decode_param': dict(out_type = tf.float32), 'shape':[136]}] 47 | """ 48 | dataset = tf.data.TFRecordDataset(tfrecord_files, 49 | compression_type=compression_type, 50 | buffer_size=_DEFAULT_READER_BUFFER_SIZE_BYTES) 51 | 52 | features = {} 53 | for info in infos: 54 | features[info['name']] = tf.FixedLenFeature([], tf.string) 55 | 56 | def parse_func(serialized_example): 57 | example = tf.parse_single_example(serialized_example, features=features) 58 | 59 | feature_dict = {} 60 | for info in infos: 61 | name = info['name'] 62 | decoder = info['decoder'] 63 | decode_param = info['decode_param'] 64 | shape = info['shape'] 65 | 66 | feature = decoder(example[name], **decode_param) 67 | feature = tf.reshape(feature, shape) 68 | feature_dict[name] = feature 69 | 70 | return feature_dict 71 | 72 | dataset = dataset.map(parse_func, num_parallel_calls=num_threads) 73 | 74 | dataset = batch_dataset(dataset, 75 | batch_size, 76 | prefetch_batch, 77 | drop_remainder, 78 | filter, 79 | map_func, 80 | num_threads, 81 | shuffle, 82 | shuffle_buffer_size, 83 | repeat) 84 | 85 | return dataset 86 | 87 | 88 | class TfrecordData(Dataset): 89 | 90 | def __init__(self, 91 | tfrecord_path, 92 | batch_size, 93 | prefetch_batch=_N_CPU + 1, 94 | drop_remainder=True, 95 | filter=None, 96 | map_func=None, 97 | num_threads=_N_CPU, 98 | shuffle=True, 99 | shuffle_buffer_size=None, 100 | repeat=-1, 101 | sess=None): 102 | super(TfrecordData, self).__init__() 103 | 104 | info_file = os.path.join(tfrecord_path, 'info.json') 105 | infos, self._data_num, compression_type = self._parse_json(info_file) 106 | 107 | self._shapes = {info['name']: tuple(info['shape']) for info in infos} 108 | 109 | tfrecord_files = sorted(glob.glob(os.path.join(tfrecord_path, '*.tfrecord'))) 110 | dataset = tfrecord_batch_dataset(tfrecord_files, 111 | infos, 112 | compression_type, 113 | batch_size, 114 | prefetch_batch, 115 | drop_remainder, 116 | filter, 117 | map_func, 118 | num_threads, 119 | shuffle, 120 | shuffle_buffer_size, 121 | repeat) 122 | 123 | self._bulid(dataset, sess) 124 | 125 | def __len__(self): 126 | return self._data_num 127 | 128 | @property 129 | def shape(self): 130 | return self._shapes 131 | 132 | @staticmethod 133 | def _parse_old(json_file): 134 | with open(json_file.replace('info.json', 'info.txt')) as f: 135 | try: # older version 1 136 | infos = json.load(f) 137 | for info in infos[0:-1]: 138 | info['decoder'] = _DECODERS[info['dtype_or_format']]['decoder'] 139 | info['decode_param'] = _DECODERS[info['dtype_or_format']]['decode_param'] 140 | except: # older version 2 141 | f.seek(0) 142 | infos = '' 143 | for line in f.readlines(): 144 | infos += line.strip('\n') 145 | infos = eval(infos) 146 | 147 | data_num = infos[-1]['data_num'] 148 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[infos[-1]['compression_type']] 149 | infos[-1:] = [] 150 | 151 | return infos, data_num, compression_type 152 | 153 | @staticmethod 154 | def _parse_json(json_file): 155 | try: 156 | with open(json_file) as f: 157 | info = json.load(f) 158 | infos = info['item'] 159 | for i in infos: 160 | i['decoder'] = _DECODERS[i['dtype_or_format']]['decoder'] 161 | i['decode_param'] = _DECODERS[i['dtype_or_format']]['decode_param'] 162 | data_num = info['info']['data_num'] 163 | compression_type = tf.python_io.TFRecordOptions.compression_type_map[info['info']['compression_type']] 164 | except: # for older version 165 | infos, data_num, compression_type = TfrecordData._parse_old(json_file) 166 | 167 | return infos, data_num, compression_type 168 | -------------------------------------------------------------------------------- /tflib/data/tfrecord_creator.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import io 6 | import json 7 | import os 8 | import shutil 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from PIL import Image 14 | from tflib.data import tfrecord 15 | 16 | __metaclass__ = type 17 | 18 | 19 | _ALLOWED_TYPES = tfrecord._DECODERS.keys() 20 | 21 | 22 | class BytesTfrecordCreator(object): 23 | """BytesTfrecordCreator. 24 | 25 | `infos` example: 26 | infos = [ 27 | ['img', 'jpg', (64, 64, 3)], 28 | ['id', 'int64', ()], 29 | ['attr', 'int64', (40,)], 30 | ['point', 'float32', (5, 2)] 31 | ] 32 | 33 | `compression_type`: 34 | 0 : NONE 35 | 1 : ZLIB 36 | 2 : GZIP 37 | """ 38 | 39 | def __init__(self, 40 | save_path, 41 | infos, 42 | size_each=None, 43 | compression_type=0, 44 | overwrite_existence=False): 45 | # overwrite existence 46 | if os.path.exists(save_path): 47 | if not overwrite_existence: 48 | raise Exception('%s exists!' % save_path) 49 | else: 50 | shutil.rmtree(save_path) 51 | os.makedirs(save_path) 52 | else: 53 | os.makedirs(save_path) 54 | 55 | self._save_path = save_path 56 | 57 | # add info 58 | self._infos = [] 59 | self._info_names = [] 60 | for info in infos: 61 | self._add_info(*info) 62 | 63 | self._data_num = 0 64 | self._tfrecord_num = 0 65 | self._size_each = [size_each, 2147483647][not size_each] 66 | self._writer = None 67 | 68 | self._compression_type = compression_type 69 | self._options = tf.python_io.TFRecordOptions(compression_type) 70 | 71 | def __del__(self): 72 | info = {'item': self._infos, 'info': {'data_num': self._data_num, 'compression_type': self._compression_type}} 73 | info_str = json.dumps(info, indent=4, separators=(',', ':')) 74 | 75 | with open(os.path.join(self._save_path, 'info.json'), 'w') as info_f: 76 | info_f.write(info_str) 77 | 78 | if self._writer: 79 | self._writer.close() 80 | 81 | def add(self, feature_bytes_dict): 82 | """Add example. 83 | 84 | `feature_bytes_dict` example: 85 | feature_bytes_dict = { 86 | 'img' : img_bytes, 87 | 'id' : id_bytes, 88 | 'attr' : attr_bytes, 89 | 'point' : point_bytes 90 | } 91 | """ 92 | assert sorted(self._info_names) == sorted(feature_bytes_dict.keys()), \ 93 | 'Feature names are inconsistent with the givens!' 94 | 95 | self._new_tfrecord_check() 96 | 97 | self._writer.write(self._bytes_tfexample(feature_bytes_dict).SerializeToString()) 98 | self._data_num += 1 99 | 100 | def _new_tfrecord_check(self): 101 | if self._data_num // self._size_each == self._tfrecord_num: 102 | self._tfrecord_num += 1 103 | 104 | if self._writer: 105 | self._writer.close() 106 | 107 | tfrecord_path = os.path.join(self._save_path, 'data_%06d.tfrecord' % (self._tfrecord_num - 1)) 108 | self._writer = tf.python_io.TFRecordWriter(tfrecord_path, self._options) 109 | 110 | def _add_info(self, name, dtype_or_format, shape): 111 | assert name not in self._info_names, 'Info name "%s" is duplicated!' % name 112 | assert dtype_or_format in _ALLOWED_TYPES, 'Allowed data types: %s!' % str(_ALLOWED_TYPES) 113 | self._infos.append(dict(name=name, dtype_or_format=dtype_or_format, shape=shape)) 114 | self._info_names.append(name) 115 | 116 | @staticmethod 117 | def _bytes_feature(values): 118 | """Return a TF-Feature of bytes. 119 | 120 | Arguments: 121 | values : A byte string or list of byte strings. 122 | 123 | Returns: 124 | A TF-Feature. 125 | """ 126 | if not isinstance(values, (tuple, list)): 127 | values = [values] 128 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) 129 | 130 | @staticmethod 131 | def _bytes_tfexample(bytes_dict): 132 | """Convert bytes to tfexample. 133 | 134 | `bytes_dict` example: 135 | bytes_dict = { 136 | 'img' : img_bytes, 137 | 'id' : id_bytes, 138 | 'attr' : attr_bytes, 139 | 'point' : point_bytes 140 | } 141 | """ 142 | feature_dict = {} 143 | for key, value in bytes_dict.items(): 144 | feature_dict[key] = BytesTfrecordCreator._bytes_feature(value) 145 | return tf.train.Example(features=tf.train.Features(feature=feature_dict)) 146 | 147 | 148 | class DataLablePairTfrecordCreator(BytesTfrecordCreator): 149 | """DataLablePairTfrecordCreator. 150 | 151 | If `data_shape` is None, then the `data` to be added should be a 152 | numpy array, and the shape and dtype of `data` will be inferred. 153 | If `data_shape` is not None, `data` should be given as byte string, 154 | and `data_dtype_or_format` should also be given. 155 | 156 | `compression_type`: 157 | 0 : NONE 158 | 1 : ZLIB 159 | 2 : GZIP 160 | """ 161 | 162 | def __init__(self, 163 | save_path, 164 | data_shape=None, 165 | data_dtype_or_format=None, 166 | data_name='data', 167 | size_each=None, 168 | compression_type=0, 169 | overwrite_existence=False): 170 | super(DataLablePairTfrecordCreator, self).__init__(save_path, 171 | [], 172 | size_each, 173 | compression_type, 174 | overwrite_existence) 175 | 176 | if data_shape: 177 | assert data_dtype_or_format, '`data_dtype_or_format` should be given when `data_shape` is given!' 178 | self._is_data_bytes = True 179 | else: 180 | self._is_data_bytes = False 181 | 182 | self._data_shape = data_shape 183 | self._data_dtype_or_format = data_dtype_or_format 184 | self._data_name = data_name 185 | self._label_shape_dict = {} 186 | self._label_dtype_dict = {} 187 | 188 | self._info_built = False 189 | 190 | def add(self, data, label_dict): 191 | """Add example. 192 | 193 | `label_dict` example: 194 | label_dict = { 195 | 'id' : id_ndarray, 196 | 'attr' : attr_ndarray, 197 | 'point' : point_ndarray 198 | } 199 | """ 200 | self._check_and_build(data, label_dict) 201 | 202 | if not self._is_data_bytes: 203 | data = data.tobytes() 204 | 205 | feature_dict = {self._data_name: data} 206 | for name, label in label_dict.items(): 207 | feature_dict[name] = label.tobytes() 208 | 209 | super(DataLablePairTfrecordCreator, self).add(feature_dict) 210 | 211 | def _check_and_build(self, data, label_dict): 212 | # check type 213 | if self._is_data_bytes: 214 | assert isinstance(data, (str, bytes)), '`data` should be a byte string!' 215 | else: 216 | assert isinstance(data, np.ndarray), '`data` should be a numpy array!' 217 | for label in label_dict.values(): 218 | assert isinstance(label, np.ndarray), 'labels should be numpy arrays!' 219 | 220 | # check shape and dtype or bulid info at first adding 221 | if self._info_built: 222 | if not self._is_data_bytes: 223 | assert data.shape == tuple(self._data_shape), 'Shapes of `data`s are inconsistent!' 224 | assert data.dtype.name == self._data_dtype_or_format, 'Dtypes of `data`s are inconsistent!' 225 | for name, label in label_dict.items(): 226 | assert label.shape == self._label_shape_dict[name], 'Shapes of `%s`s are inconsistent!' % name 227 | assert label.dtype.name == self._label_dtype_dict[name], 'Dtypes of `%s`s are inconsistent!' % name 228 | else: 229 | if not self._is_data_bytes: 230 | self._data_shape = data.shape 231 | self._data_dtype_or_format = data.dtype.name 232 | self._add_info(self._data_name, self._data_dtype_or_format, self._data_shape) 233 | 234 | for name, label in label_dict.items(): 235 | self._label_shape_dict[name] = label.shape 236 | self._label_dtype_dict[name] = label.dtype.name 237 | self._add_info(name, label.dtype.name, label.shape) 238 | 239 | self._info_built = True 240 | 241 | 242 | class ImageLablePairTfrecordCreator(DataLablePairTfrecordCreator): 243 | """ImageLablePairTfrecordCreator. 244 | 245 | Arguments: 246 | encode_type : One of [None, 'png', 'jpg']. 247 | quality : For 'jpg'. 248 | compression_type : 249 | 0 : NONE 250 | 1 : ZLIB 251 | 2 : GZIP 252 | """ 253 | 254 | def __init__(self, 255 | save_path, 256 | encode_type='png', 257 | quality=95, 258 | data_name='img', 259 | size_each=None, 260 | compression_type=0, 261 | overwrite_existence=False): 262 | super(ImageLablePairTfrecordCreator, self).__init__(save_path, 263 | None, 264 | None, 265 | data_name, 266 | size_each, 267 | compression_type, 268 | overwrite_existence) 269 | 270 | assert encode_type in [None, 'png', 'jpg'], "`encode_type` should be in the list of [None, 'png', 'jpg']!" 271 | 272 | self._encode_type = encode_type 273 | self._quality = quality 274 | 275 | self._data_shape = None 276 | self._data_dtype_or_format = None 277 | self._is_data_bytes = True 278 | 279 | def add(self, image, label_dict): 280 | """Add example. 281 | 282 | `image`: An H * W (* C) uint8 numpy array. 283 | 284 | `label_dict` example: 285 | label_dict = { 286 | 'id' : id_ndarray, 287 | 'attr' : attr_ndarray, 288 | 'point' : point_ndarray 289 | } 290 | """ 291 | self._check(image) 292 | image_bytes = self._encode(image) 293 | super(ImageLablePairTfrecordCreator, self).add(image_bytes, label_dict) 294 | 295 | def _check(self, image): 296 | if not self._data_shape: 297 | assert isinstance(image, np.ndarray) and image.dtype == np.uint8 and image.ndim in [2, 3], \ 298 | '`image` should be an H * W (* C) uint8 numpy array!' 299 | if self._encode_type and image.ndim == 3 and image.shape[-1] != 3: 300 | raise Exception('Only images with 1 or 3 channels are allowed to be encoded!') 301 | 302 | if image.ndim == 2: 303 | self._data_shape = image.shape + (1,) 304 | else: 305 | self._data_shape = image.shape 306 | self._data_dtype_or_format = [self._encode_type, 'uint8'][not self._encode_type] 307 | else: 308 | sp = image.shape 309 | if image.ndim == 2: 310 | sp = sp + (1,) 311 | assert sp == self._data_shape, 'Shapes of `image`s are inconsistent!' 312 | assert image.dtype == np.uint8, 'Dtypes of `image`s are inconsistent!' 313 | 314 | def _encode(self, image): 315 | if self._encode_type: 316 | if image.shape[-1] == 1: 317 | image.shape = image.shape[:2] 318 | byte = io.BytesIO() 319 | image = Image.fromarray(image) 320 | if self._encode_type == 'jpg': 321 | image.save(byte, 'JPEG', quality=self._quality) 322 | elif self._encode_type == 'png': 323 | image.save(byte, 'PNG') 324 | image_bytes = byte.getvalue() 325 | else: 326 | image_bytes = image.tobytes() 327 | return image_bytes 328 | -------------------------------------------------------------------------------- /tflib/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.layers.layers import * 6 | -------------------------------------------------------------------------------- /tflib/layers/layers.py: -------------------------------------------------------------------------------- 1 | # functions compatible with tensorflow.contrib 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import six 8 | 9 | import tensorflow as tf 10 | 11 | from tensorflow.contrib.framework.python.ops import add_arg_scope 12 | from tensorflow.contrib.framework.python.ops import variables 13 | from tensorflow.contrib.layers.python import layers 14 | from tensorflow.contrib.layers.python.layers import initializers 15 | from tensorflow.contrib.layers.python.layers import utils 16 | 17 | from tensorflow.python.framework import ops 18 | from tensorflow.python.ops import array_ops 19 | from tensorflow.python.ops import init_ops 20 | from tensorflow.python.ops import nn 21 | from tensorflow.python.ops import standard_ops 22 | from tensorflow.python.ops import variable_scope 23 | 24 | 25 | @add_arg_scope 26 | def fully_connected(inputs, 27 | num_outputs, 28 | activation_fn=nn.relu, 29 | normalizer_fn=None, 30 | normalizer_params=None, 31 | weights_normalizer_fn=None, 32 | weights_normalizer_params=None, 33 | weights_initializer=initializers.xavier_initializer(), 34 | weights_regularizer=None, 35 | biases_initializer=init_ops.zeros_initializer(), 36 | biases_regularizer=None, 37 | reuse=None, 38 | variables_collections=None, 39 | outputs_collections=None, 40 | trainable=True, 41 | scope=None): 42 | # Be copied and modified from tensorflow-0.12.0.contrib.layer.fully_connected, 43 | # add weights_nomalizer_* options. 44 | """Adds a fully connected layer. 45 | 46 | `fully_connected` creates a variable called `weights`, representing a fully 47 | connected weight matrix, which is multiplied by the `inputs` to produce a 48 | `Tensor` of hidden units. If a `normalizer_fn` is provided (such as 49 | `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is 50 | None and a `biases_initializer` is provided then a `biases` variable would be 51 | created and added the hidden units. Finally, if `activation_fn` is not `None`, 52 | it is applied to the hidden units as well. 53 | 54 | Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened 55 | prior to the initial matrix multiply by `weights`. 56 | 57 | Args: 58 | inputs: A tensor of with at least rank 2 and value for the last dimension, 59 | i.e. `[batch_size, depth]`, `[None, None, None, channels]`. 60 | num_outputs: Integer or long, the number of output units in the layer. 61 | activation_fn: activation function, set to None to skip it and maintain 62 | a linear activation. 63 | normalizer_fn: normalization function to use instead of `biases`. If 64 | `normalizer_fn` is provided then `biases_initializer` and 65 | `biases_regularizer` are ignored and `biases` are not created nor added. 66 | default set to None for no normalizer function 67 | normalizer_params: normalization function parameters. 68 | weights_normalizer_fn: weights normalization function. 69 | weights_normalizer_params: weights normalization function parameters. 70 | weights_initializer: An initializer for the weights. 71 | weights_regularizer: Optional regularizer for the weights. 72 | biases_initializer: An initializer for the biases. If None skip biases. 73 | biases_regularizer: Optional regularizer for the biases. 74 | reuse: whether or not the layer and its variables should be reused. To be 75 | able to reuse the layer scope must be given. 76 | variables_collections: Optional list of collections for all the variables or 77 | a dictionary containing a different list of collections per variable. 78 | outputs_collections: collection to add the outputs. 79 | trainable: If `True` also add variables to the graph collection 80 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 81 | scope: Optional scope for variable_scope. 82 | 83 | Returns: 84 | the tensor variable representing the result of the series of operations. 85 | 86 | Raises: 87 | ValueError: if x has rank less than 2 or if its last dimension is not set. 88 | """ 89 | if not (isinstance(num_outputs, six.integer_types)): 90 | raise ValueError('num_outputs should be int or long, got %s.', num_outputs) 91 | with variable_scope.variable_scope(scope, 'fully_connected', [inputs], 92 | reuse=reuse) as sc: 93 | inputs = ops.convert_to_tensor(inputs) 94 | dtype = inputs.dtype.base_dtype 95 | inputs_shape = inputs.get_shape() 96 | num_input_units = utils.last_dimension(inputs_shape, min_rank=2) 97 | 98 | static_shape = inputs_shape.as_list() 99 | static_shape[-1] = num_outputs 100 | 101 | out_shape = array_ops.unpack(array_ops.shape(inputs), len(static_shape)) 102 | out_shape[-1] = num_outputs 103 | 104 | weights_shape = [num_input_units, num_outputs] 105 | weights_collections = utils.get_variable_collections( 106 | variables_collections, 'weights') 107 | weights = variables.model_variable('weights', 108 | shape=weights_shape, 109 | dtype=dtype, 110 | initializer=weights_initializer, 111 | regularizer=weights_regularizer, 112 | collections=weights_collections, 113 | trainable=trainable) 114 | if weights_normalizer_fn is not None: 115 | weights_normalizer_params = weights_normalizer_params or {} 116 | weights = weights_normalizer_fn(weights, **weights_normalizer_params) 117 | if len(static_shape) > 2: 118 | # Reshape inputs 119 | inputs = array_ops.reshape(inputs, [-1, num_input_units]) 120 | outputs = standard_ops.matmul(inputs, weights) 121 | if normalizer_fn is not None: 122 | normalizer_params = normalizer_params or {} 123 | outputs = normalizer_fn(outputs, **normalizer_params) 124 | else: 125 | if biases_initializer is not None: 126 | biases_collections = utils.get_variable_collections( 127 | variables_collections, 'biases') 128 | biases = variables.model_variable('biases', 129 | shape=[num_outputs, ], 130 | dtype=dtype, 131 | initializer=biases_initializer, 132 | regularizer=biases_regularizer, 133 | collections=biases_collections, 134 | trainable=trainable) 135 | outputs = nn.bias_add(outputs, biases) 136 | if activation_fn is not None: 137 | outputs = activation_fn(outputs) 138 | if len(static_shape) > 2: 139 | # Reshape back outputs 140 | outputs = array_ops.reshape(outputs, array_ops.pack(out_shape)) 141 | outputs.set_shape(static_shape) 142 | return utils.collect_named_outputs(outputs_collections, 143 | sc.original_name_scope, outputs) 144 | 145 | 146 | @add_arg_scope 147 | def flatten_fully_connected(inputs, 148 | num_outputs, 149 | activation_fn=nn.relu, 150 | normalizer_fn=None, 151 | normalizer_params=None, 152 | weights_normalizer_fn=None, 153 | weights_normalizer_params=None, 154 | weights_initializer=initializers.xavier_initializer(), 155 | weights_regularizer=None, 156 | biases_initializer=init_ops.zeros_initializer(), 157 | biases_regularizer=None, 158 | reuse=None, 159 | variables_collections=None, 160 | outputs_collections=None, 161 | trainable=True, 162 | scope=None): 163 | with variable_scope.variable_scope(scope, 'flatten_fully_connected'): 164 | if inputs.shape.ndims > 2: 165 | inputs = layers.flatten(inputs) 166 | return fully_connected(inputs=inputs, 167 | num_outputs=num_outputs, 168 | activation_fn=activation_fn, 169 | normalizer_fn=normalizer_fn, 170 | normalizer_params=normalizer_params, 171 | weights_normalizer_fn=weights_normalizer_fn, 172 | weights_normalizer_params=weights_normalizer_params, 173 | weights_initializer=weights_initializer, 174 | weights_regularizer=weights_regularizer, 175 | biases_initializer=biases_initializer, 176 | biases_regularizer=biases_regularizer, 177 | reuse=reuse, 178 | variables_collections=variables_collections, 179 | outputs_collections=outputs_collections, 180 | trainable=trainable, 181 | scope=scope) 182 | 183 | flatten_dense = flatten_fully_connected 184 | 185 | 186 | @add_arg_scope 187 | def convolution(inputs, 188 | num_outputs, 189 | kernel_size, 190 | stride=1, 191 | padding='SAME', 192 | data_format=None, 193 | rate=1, 194 | activation_fn=nn.relu, 195 | normalizer_fn=None, 196 | normalizer_params=None, 197 | weights_normalizer_fn=None, 198 | weights_normalizer_params=None, 199 | weights_initializer=initializers.xavier_initializer(), 200 | weights_regularizer=None, 201 | biases_initializer=init_ops.zeros_initializer(), 202 | biases_regularizer=None, 203 | reuse=None, 204 | variables_collections=None, 205 | outputs_collections=None, 206 | trainable=True, 207 | scope=None): 208 | # Be copied and modified from tensorflow-0.12.0.contrib.layer.convolution, 209 | # add weights_nomalizer_* options. 210 | """Adds an N-D convolution followed by an optional batch_norm layer. 211 | 212 | It is required that 1 <= N <= 3. 213 | 214 | `convolution` creates a variable called `weights`, representing the 215 | convolutional kernel, that is convolved (actually cross-correlated) with the 216 | `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is 217 | provided (such as `batch_norm`), it is then applied. Otherwise, if 218 | `normalizer_fn` is None and a `biases_initializer` is provided then a `biases` 219 | variable would be created and added the activations. Finally, if 220 | `activation_fn` is not `None`, it is applied to the activations as well. 221 | 222 | Performs a'trous convolution with input stride/dilation rate equal to `rate` 223 | if a value > 1 for any dimension of `rate` is specified. In this case 224 | `stride` values != 1 are not supported. 225 | 226 | Args: 227 | inputs: a Tensor of rank N+2 of shape 228 | `[batch_size] + input_spatial_shape + [in_channels]` if data_format does 229 | not start with "NC" (default), or 230 | `[batch_size, in_channels] + input_spatial_shape` if data_format starts 231 | with "NC". 232 | num_outputs: integer, the number of output filters. 233 | kernel_size: a sequence of N positive integers specifying the spatial 234 | dimensions of of the filters. Can be a single integer to specify the same 235 | value for all spatial dimensions. 236 | stride: a sequence of N positive integers specifying the stride at which to 237 | compute output. Can be a single integer to specify the same value for all 238 | spatial dimensions. Specifying any `stride` value != 1 is incompatible 239 | with specifying any `rate` value != 1. 240 | padding: one of `"VALID"` or `"SAME"`. 241 | data_format: A string or None. Specifies whether the channel dimension of 242 | the `input` and output is the last dimension (default, or if `data_format` 243 | does not start with "NC"), or the second dimension (if `data_format` 244 | starts with "NC"). For N=1, the valid values are "NWC" (default) and 245 | "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For 246 | N=3, currently the only valid value is "NDHWC". 247 | rate: a sequence of N positive integers specifying the dilation rate to use 248 | for a'trous convolution. Can be a single integer to specify the same 249 | value for all spatial dimensions. Specifying any `rate` value != 1 is 250 | incompatible with specifying any `stride` value != 1. 251 | activation_fn: activation function, set to None to skip it and maintain 252 | a linear activation. 253 | normalizer_fn: normalization function to use instead of `biases`. If 254 | `normalizer_fn` is provided then `biases_initializer` and 255 | `biases_regularizer` are ignored and `biases` are not created nor added. 256 | default set to None for no normalizer function 257 | normalizer_params: normalization function parameters. 258 | weights_normalizer_fn: weights normalization function. 259 | weights_normalizer_params: weights normalization function parameters. 260 | weights_initializer: An initializer for the weights. 261 | weights_regularizer: Optional regularizer for the weights. 262 | biases_initializer: An initializer for the biases. If None skip biases. 263 | biases_regularizer: Optional regularizer for the biases. 264 | reuse: whether or not the layer and its variables should be reused. To be 265 | able to reuse the layer scope must be given. 266 | variables_collections: optional list of collections for all the variables or 267 | a dictionary containing a different list of collection per variable. 268 | outputs_collections: collection to add the outputs. 269 | trainable: If `True` also add variables to the graph collection 270 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 271 | scope: Optional scope for `variable_scope`. 272 | 273 | Returns: 274 | a tensor representing the output of the operation. 275 | 276 | Raises: 277 | ValueError: if `data_format` is invalid. 278 | ValueError: both 'rate' and `stride` are not uniformly 1. 279 | """ 280 | if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC']: 281 | raise ValueError('Invalid data_format: %r' % (data_format,)) 282 | with variable_scope.variable_scope(scope, 'Conv', [inputs], 283 | reuse=reuse) as sc: 284 | inputs = ops.convert_to_tensor(inputs) 285 | dtype = inputs.dtype.base_dtype 286 | input_rank = inputs.get_shape().ndims 287 | if input_rank is None: 288 | raise ValueError('Rank of inputs must be known') 289 | if input_rank < 3 or input_rank > 5: 290 | raise ValueError('Rank of inputs is %d, which is not >= 3 and <= 5' % 291 | input_rank) 292 | conv_dims = input_rank - 2 293 | kernel_size = utils.n_positive_integers(conv_dims, kernel_size) 294 | stride = utils.n_positive_integers(conv_dims, stride) 295 | rate = utils.n_positive_integers(conv_dims, rate) 296 | 297 | if data_format is None or data_format.endswith('C'): 298 | num_input_channels = inputs.get_shape()[input_rank - 1].value 299 | elif data_format.startswith('NC'): 300 | num_input_channels = inputs.get_shape()[1].value 301 | else: 302 | raise ValueError('Invalid data_format') 303 | 304 | if num_input_channels is None: 305 | raise ValueError('Number of in_channels must be known.') 306 | 307 | weights_shape = ( 308 | list(kernel_size) + [num_input_channels, num_outputs]) 309 | weights_collections = utils.get_variable_collections(variables_collections, 310 | 'weights') 311 | weights = variables.model_variable('weights', 312 | shape=weights_shape, 313 | dtype=dtype, 314 | initializer=weights_initializer, 315 | regularizer=weights_regularizer, 316 | collections=weights_collections, 317 | trainable=trainable) 318 | if weights_normalizer_fn is not None: 319 | weights_normalizer_params = weights_normalizer_params or {} 320 | weights = weights_normalizer_fn(weights, **weights_normalizer_params) 321 | outputs = nn.convolution(input=inputs, 322 | filter=weights, 323 | dilation_rate=rate, 324 | strides=stride, 325 | padding=padding, 326 | data_format=data_format) 327 | if normalizer_fn is not None: 328 | normalizer_params = normalizer_params or {} 329 | outputs = normalizer_fn(outputs, **normalizer_params) 330 | else: 331 | if biases_initializer is not None: 332 | biases_collections = utils.get_variable_collections( 333 | variables_collections, 'biases') 334 | biases = variables.model_variable('biases', 335 | shape=[num_outputs], 336 | dtype=dtype, 337 | initializer=biases_initializer, 338 | regularizer=biases_regularizer, 339 | collections=biases_collections, 340 | trainable=trainable) 341 | outputs = nn.bias_add(outputs, biases, data_format=data_format) 342 | if activation_fn is not None: 343 | outputs = activation_fn(outputs) 344 | return utils.collect_named_outputs(outputs_collections, 345 | sc.original_name_scope, outputs) 346 | 347 | 348 | convolution2d = convolution 349 | convolution3d = convolution 350 | 351 | 352 | @add_arg_scope 353 | def spectral_normalization(weights, 354 | num_iterations=1, 355 | epsilon=1e-12, 356 | u_initializer=tf.random_normal_initializer(), 357 | updates_collections=tf.GraphKeys.UPDATE_OPS, 358 | is_training=True, 359 | reuse=None, 360 | variables_collections=None, 361 | outputs_collections=None, 362 | scope=None): 363 | with tf.variable_scope(scope, 'SpectralNorm', [weights], reuse=reuse) as sc: 364 | weights = tf.convert_to_tensor(weights) 365 | 366 | dtype = weights.dtype.base_dtype 367 | 368 | w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]]) 369 | w = tf.transpose(w_t) 370 | m, n = w.shape.as_list() 371 | 372 | u_collections = utils.get_variable_collections(variables_collections, 'u') 373 | u = tf.get_variable("u", 374 | shape=[m, 1], 375 | dtype=dtype, 376 | initializer=u_initializer, 377 | trainable=False, 378 | collections=u_collections,) 379 | sigma_collections = utils.get_variable_collections(variables_collections, 'sigma') 380 | sigma = tf.get_variable('sigma', 381 | shape=[], 382 | dtype=dtype, 383 | initializer=tf.zeros_initializer(), 384 | trainable=False, 385 | collections=sigma_collections) 386 | 387 | def _power_iteration(i, u, v): 388 | v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon) 389 | u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon) 390 | return i + 1, u_, v_ 391 | 392 | _, u_, v_ = tf.while_loop( 393 | cond=lambda i, _1, _2: i < num_iterations, 394 | body=_power_iteration, 395 | loop_vars=[tf.constant(0), u, tf.zeros(shape=[n, 1], dtype=tf.float32)] 396 | ) 397 | u_ = tf.stop_gradient(u_) 398 | v_ = tf.stop_gradient(v_) 399 | sigma_ = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0] 400 | 401 | update_u = u.assign(u_) 402 | update_sigma = sigma.assign(sigma_) 403 | if updates_collections is None: 404 | def _force_update(): 405 | with tf.control_dependencies([update_u, update_sigma]): 406 | return tf.identity(sigma_) 407 | 408 | sigma_ = utils.smart_cond(is_training, _force_update, lambda: sigma) 409 | weights_sn = weights / sigma_ 410 | else: 411 | sigma_ = utils.smart_cond(is_training, lambda: sigma_, lambda: sigma) 412 | weights_sn = weights / sigma_ 413 | tf.add_to_collections(updates_collections, update_u) 414 | tf.add_to_collections(updates_collections, update_sigma) 415 | 416 | return utils.collect_named_outputs(outputs_collections, sc.name, weights_sn) 417 | 418 | 419 | # Simple alias. 420 | conv2d = convolution2d 421 | conv3d = convolution3d 422 | -------------------------------------------------------------------------------- /tflib/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.ops.ops import * 6 | -------------------------------------------------------------------------------- /tflib/ops/ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def minmax_norm(x, epsilon=1e-12): 9 | x = tf.to_float(x) 10 | min_val = tf.reduce_min(x) 11 | max_val = tf.reduce_max(x) 12 | x_norm = (x - min_val) / tf.maximum((max_val - min_val), epsilon) 13 | return x_norm 14 | -------------------------------------------------------------------------------- /tflib/parallel.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def average_gradients(tower_grads): 9 | """Calculate the average gradient for each shared variable across all towers. 10 | Copied from https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_multi_gpu_train.py. 11 | Note that this function provides a synchronization point across all towers. 12 | Args: 13 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 14 | is over individual gradients. The inner list is over the gradient 15 | calculation for each tower. 16 | Returns: 17 | List of pairs of (gradient, variable) where the gradient has been averaged 18 | across all towers. 19 | """ 20 | average_grads = [] 21 | for grad_and_vars in zip(*tower_grads): 22 | # Note that each grad_and_vars looks like the following: 23 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 24 | grads = [] 25 | for g, _ in grad_and_vars: 26 | # Add 0 dimension to the gradients to represent the tower. 27 | expanded_g = tf.expand_dims(g, 0) 28 | 29 | # Append on a 'tower' dimension which we will average over below. 30 | grads.append(expanded_g) 31 | 32 | # Average over the 'tower' dimension. 33 | grad = tf.concat(axis=0, values=grads) 34 | grad = tf.reduce_mean(grad, 0) 35 | 36 | # Keep in mind that the Variables are redundant because they are shared 37 | # across towers. So .. we will just return the first tower's pointer to 38 | # the Variable. 39 | v = grad_and_vars[0][1] 40 | grad_and_var = (grad, v) 41 | average_grads.append(grad_and_var) 42 | return average_grads 43 | -------------------------------------------------------------------------------- /tflib/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | def session(graph=None, 9 | allow_soft_placement=True, 10 | log_device_placement=False, 11 | allow_growth=True): 12 | """Return a Session with simple config.""" 13 | config = tf.ConfigProto(allow_soft_placement=allow_soft_placement, 14 | log_device_placement=log_device_placement) 15 | config.gpu_options.allow_growth = allow_growth 16 | return tf.Session(graph=graph, config=config) 17 | 18 | 19 | def print_tensor(tensors): 20 | if not isinstance(tensors, (list, tuple)): 21 | tensors = [tensors] 22 | 23 | for i, tensor in enumerate(tensors): 24 | ctype = str(type(tensor)) 25 | if 'Tensor' in ctype: 26 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' % 27 | (i, 'Tensor', tensor.name, tensor.shape, tensor.dtype.name, tensor.device)) 28 | elif 'Variable' in ctype: 29 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' % 30 | (i, 'Variable', tensor.name, tensor.shape, tensor.dtype.name, tensor.device)) 31 | elif 'Operation' in ctype: 32 | print('%d: %s("%s", device=%s)' % 33 | (i, 'Operation', tensor.name, tensor.device)) 34 | else: 35 | raise Exception('Not a Tensor, Variable or Operation!') 36 | 37 | 38 | prt = print_tensor 39 | 40 | 41 | def shape(tensor): 42 | sp = tensor.get_shape().as_list() 43 | return [num if num is not None else -1 for num in sp] 44 | 45 | 46 | def summary(tensor_collection, 47 | summary_type=['mean', 'stddev', 'max', 'min', 'sparsity', 'histogram'], 48 | scope=None): 49 | """Summary. 50 | 51 | Usage: 52 | 1. summary(tensor) 53 | 2. summary([tensor_a, tensor_b]) 54 | 3. summary({tensor_a: 'a', tensor_b: 'b}) 55 | """ 56 | def _summary(tensor, name, summary_type): 57 | """Attach a lot of summaries to a Tensor.""" 58 | if name is None: 59 | name = tensor.name 60 | 61 | summaries = [] 62 | if len(tensor.shape) == 0: 63 | summaries.append(tf.summary.scalar(name, tensor)) 64 | else: 65 | if 'mean' in summary_type: 66 | mean = tf.reduce_mean(tensor) 67 | summaries.append(tf.summary.scalar(name + '/mean', mean)) 68 | if 'stddev' in summary_type: 69 | mean = tf.reduce_mean(tensor) 70 | stddev = tf.sqrt(tf.reduce_mean(tf.square(tensor - mean))) 71 | summaries.append(tf.summary.scalar(name + '/stddev', stddev)) 72 | if 'max' in summary_type: 73 | summaries.append(tf.summary.scalar(name + '/max', tf.reduce_max(tensor))) 74 | if 'min' in summary_type: 75 | summaries.append(tf.summary.scalar(name + '/min', tf.reduce_min(tensor))) 76 | if 'sparsity' in summary_type: 77 | summaries.append(tf.summary.scalar(name + '/sparsity', tf.nn.zero_fraction(tensor))) 78 | if 'histogram' in summary_type: 79 | summaries.append(tf.summary.histogram(name, tensor)) 80 | return tf.summary.merge(summaries) 81 | 82 | if not isinstance(tensor_collection, (list, tuple, dict)): 83 | tensor_collection = [tensor_collection] 84 | 85 | with tf.name_scope(scope, 'summary'): 86 | summaries = [] 87 | if isinstance(tensor_collection, (list, tuple)): 88 | for tensor in tensor_collection: 89 | summaries.append(_summary(tensor, None, summary_type)) 90 | else: 91 | for tensor, name in tensor_collection.items(): 92 | summaries.append(_summary(tensor, name, summary_type)) 93 | return tf.summary.merge(summaries) 94 | 95 | 96 | def counter(start=0, scope=None): 97 | with tf.variable_scope(scope, 'counter'): 98 | counter = tf.get_variable(name='counter', 99 | initializer=tf.constant_initializer(start), 100 | shape=(), 101 | dtype=tf.int64) 102 | update_cnt = tf.assign(counter, tf.add(counter, 1)) 103 | return counter, update_cnt 104 | -------------------------------------------------------------------------------- /tflib/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.vision.dataset import * 6 | -------------------------------------------------------------------------------- /tflib/vision/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from tflib.vision.dataset.mnist import * 6 | -------------------------------------------------------------------------------- /tflib/vision/dataset/mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gzip 6 | import multiprocessing 7 | import os 8 | import struct 9 | import subprocess 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from tflib.data.memory_data import MemoryData 15 | 16 | 17 | _N_CPU = multiprocessing.cpu_count() 18 | 19 | 20 | def mnist_download(download_dir): 21 | url_base = 'http://yann.lecun.com/exdb/mnist/' 22 | file_names = ['train-images-idx3-ubyte.gz', 23 | 'train-labels-idx1-ubyte.gz', 24 | 't10k-images-idx3-ubyte.gz', 25 | 't10k-labels-idx1-ubyte.gz'] 26 | for file_name in file_names: 27 | url = url_base + file_name 28 | save_path = os.path.join(download_dir, file_name) 29 | cmd = ['curl', url, '-o', save_path] 30 | print('Downloading ', file_name) 31 | if not os.path.exists(save_path): 32 | subprocess.call(cmd) 33 | else: 34 | print('%s exists, skip!' % file_name) 35 | 36 | 37 | def mnist_load(data_dir, split='train'): 38 | """Load MNIST dataset, modified from https://gist.github.com/akesling/5358964. 39 | 40 | Returns: 41 | `imgs`, `lbls`, `num`. 42 | 43 | `imgs` : [-1.0, 1.0] float64 images of shape (N * H * W). 44 | `lbls` : Int labels of shape (N,). 45 | `num` : # of datas. 46 | """ 47 | mnist_download(data_dir) 48 | 49 | if split == 'train': 50 | fname_img = os.path.join(data_dir, 'train-images-idx3-ubyte') 51 | fname_lbl = os.path.join(data_dir, 'train-labels-idx1-ubyte') 52 | elif split == 'test': 53 | fname_img = os.path.join(data_dir, 't10k-images-idx3-ubyte') 54 | fname_lbl = os.path.join(data_dir, 't10k-labels-idx1-ubyte') 55 | else: 56 | raise ValueError("`split` must be 'test' or 'train'") 57 | 58 | def _unzip_gz(file_name): 59 | unzip_name = file_name.replace('.gz', '') 60 | gz_file = gzip.GzipFile(file_name) 61 | open(unzip_name, 'w+').write(gz_file.read()) 62 | gz_file.close() 63 | 64 | if not os.path.exists(fname_img): 65 | _unzip_gz(fname_img + '.gz') 66 | if not os.path.exists(fname_lbl): 67 | _unzip_gz(fname_lbl + '.gz') 68 | 69 | with open(fname_lbl, 'rb') as flbl: 70 | struct.unpack('>II', flbl.read(8)) 71 | lbls = np.fromfile(flbl, dtype=np.int8) 72 | 73 | with open(fname_img, 'rb') as fimg: 74 | _, _, rows, cols = struct.unpack('>IIII', fimg.read(16)) 75 | imgs = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbls), rows, cols) 76 | imgs = imgs / 127.5 - 1 77 | 78 | return imgs, lbls, len(lbls) 79 | 80 | 81 | class Mnist(MemoryData): 82 | 83 | def __init__(self, 84 | data_dir, 85 | batch_size, 86 | split='train', 87 | prefetch_batch=_N_CPU + 1, 88 | drop_remainder=True, 89 | filter=None, 90 | map_func=None, 91 | num_threads=_N_CPU, 92 | shuffle=True, 93 | buffer_size=None, 94 | repeat=-1, 95 | sess=None): 96 | imgs, lbls, _ = mnist_load(data_dir, split) 97 | imgs.shape = imgs.shape + (1,) 98 | 99 | imgs_pl = tf.placeholder(tf.float32, imgs.shape) 100 | lbls_pl = tf.placeholder(tf.int64, lbls.shape) 101 | 102 | memory_data_dict = {'img': imgs_pl, 'lbl': lbls_pl} 103 | 104 | self.feed_dict = {imgs_pl: imgs, lbls_pl: lbls} 105 | super(Mnist, self).__init__(memory_data_dict, 106 | batch_size, 107 | prefetch_batch, 108 | drop_remainder, 109 | filter, 110 | map_func, 111 | num_threads, 112 | shuffle, 113 | buffer_size, 114 | repeat, 115 | sess) 116 | 117 | def reset(self): 118 | super(Mnist, self).reset(self.feed_dict) 119 | 120 | if __name__ == '__main__': 121 | import imlib as im 122 | from tflib import session 123 | sess = session() 124 | mnist = Mnist('/tmp', 5000, repeat=1, sess=sess) 125 | print(len(mnist)) 126 | for batch in mnist: 127 | print(batch['lbl'][-1]) 128 | im.imshow(batch['img'][-1].squeeze()) 129 | im.show() 130 | sess.close() 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from functools import partial 7 | import json 8 | import traceback 9 | 10 | import data 11 | import imlib as im 12 | import model 13 | import numpy as np 14 | import pylib 15 | import tensorflow as tf 16 | import tflib as tl 17 | 18 | 19 | # ============================================================================== 20 | # = param = 21 | # ============================================================================== 22 | 23 | # argument 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--att', dest='att', default='', choices=list(data.Celeba.att_dict.keys()) + ['']) 26 | parser.add_argument('--ks', dest='ks', type=int, default=[2, 3, 3], nargs='+', help='k each layer') 27 | parser.add_argument('--lambdas', dest='lambdas', type=float, default=[1., 1., 1.], nargs='+', help='loss weight of each layer') 28 | parser.add_argument('--continuous_last', dest='continuous_last', action='store_true') 29 | parser.add_argument('--half_acgan', dest='half_acgan', action='store_true') 30 | 31 | parser.add_argument('--epoch', dest='epoch', type=int, default=100) 32 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=64) 33 | parser.add_argument('--lr_d', dest='lr_d', type=float, default=0.0002, help='learning rate of d') 34 | parser.add_argument('--lr_g', dest='lr_g', type=float, default=0.0002, help='learning rate of g') 35 | parser.add_argument('--n_d', dest='n_d', type=int, default=1) 36 | parser.add_argument('--n_g', dest='n_g', type=int, default=1) 37 | parser.add_argument('--n_d_pre', dest='n_d_pre', type=int, default=0) 38 | parser.add_argument('--optimizer', dest='optimizer', default='adam', choices=['adam', 'rmsprop']) 39 | 40 | parser.add_argument('--z_dim', dest='z_dim', type=int, default=100, help='dimension of latent') 41 | parser.add_argument('--loss_mode', dest='loss_mode', default='gan', choices=['gan', 'lsgan', 'wgan', 'hinge']) 42 | parser.add_argument('--gp_mode', dest='gp_mode', default='none', choices=['none', 'dragan', 'wgan-gp'], help='type of gradient penalty') 43 | parser.add_argument('--norm', dest='norm', default='batch_norm', choices=['batch_norm', 'instance_norm', 'layer_norm', 'none']) 44 | 45 | parser.add_argument('--experiment_name', dest='experiment_name', default='default') 46 | 47 | args = parser.parse_args() 48 | 49 | att = args.att 50 | ks = args.ks 51 | if att != '': 52 | ks[0] = 2 53 | lambdas = args.lambdas 54 | assert len(ks) == len(lambdas), 'The lens of `ks` and `lambdas` should be the same!' 55 | continuous_last = args.continuous_last 56 | if len(ks) == 1 and att != '': 57 | continuous_last = False 58 | half_acgan = args.half_acgan 59 | 60 | epoch = args.epoch 61 | batch_size = args.batch_size 62 | lr_d = args.lr_d 63 | lr_g = args.lr_g 64 | n_d = args.n_d 65 | n_g = args.n_g 66 | n_d_pre = args.n_d_pre 67 | optimizer = args.optimizer 68 | 69 | z_dim = args.z_dim 70 | loss_mode = args.loss_mode 71 | gp_mode = args.gp_mode 72 | norm = args.norm 73 | 74 | experiment_name = args.experiment_name 75 | 76 | pylib.mkdir('./output/%s' % experiment_name) 77 | with open('./output/%s/setting.txt' % experiment_name, 'w') as f: 78 | f.write(json.dumps(vars(args), indent=4, separators=(',', ':'))) 79 | 80 | img_size = 64 81 | 82 | # dataset 83 | dataset = data.Celeba('./data', ['Bangs' if att == '' else att], img_size, batch_size) 84 | 85 | 86 | # ============================================================================== 87 | # = graph = 88 | # ============================================================================== 89 | 90 | # models 91 | c_dim = len(model.sample_c(ks)[0]) 92 | D = partial(model.D, c_dim=c_dim, norm_name=norm) 93 | G = model.G 94 | 95 | # otpimizer 96 | if optimizer == 'adam': 97 | optim = partial(tf.train.AdamOptimizer, beta1=0.5) 98 | elif optimizer == 'rmsprop': 99 | optim = tf.train.RMSPropOptimizer 100 | 101 | # loss func 102 | d_loss_fn, g_loss_fn = model.get_loss_fn(loss_mode) 103 | tree_loss_fn = partial(model.tree_loss, ks=ks, continuous_last=continuous_last) 104 | 105 | # inputs 106 | real = tf.placeholder(tf.float32, [None] + [img_size, img_size, 3]) 107 | z = tf.placeholder(tf.float32, [None, z_dim]) 108 | c = tf.placeholder(tf.float32, [None, c_dim]) 109 | mask = tf.placeholder(tf.float32, [None, c_dim]) 110 | 111 | counter = tf.placeholder(tf.int64, []) 112 | layer_mask = tf.constant(np.tril(np.ones(len(ks))), dtype=tf.float32)[counter // (epoch // len(ks))] 113 | 114 | # generate 115 | fake = G(z, c) 116 | 117 | # dicriminate 118 | r_logit, r_c_logit = D(real) 119 | f_logit, f_c_logit = D(fake) 120 | 121 | # d loss 122 | d_r_loss, d_f_loss = d_loss_fn(r_logit, f_logit) 123 | d_f_tree_losses = tree_loss_fn(f_c_logit, c, mask) 124 | if att != '': 125 | d_r_tree_losses = tree_loss_fn(r_c_logit, c, mask) 126 | start = 1 if half_acgan else 0 127 | d_tree_loss = sum([d_f_tree_losses[i] * lambdas[i] * layer_mask[i] for i in range(start, len(lambdas))]) 128 | d_tree_loss += d_r_tree_losses[0] * lambdas[0] * layer_mask[0] 129 | else: 130 | d_tree_loss = sum([d_f_tree_losses[i] * lambdas[i] for i in range(len(lambdas))]) 131 | gp = model.gradient_penalty(D, real, fake, gp_mode) 132 | d_loss = d_r_loss + d_f_loss + d_tree_loss + gp * 10.0 133 | 134 | # g loss 135 | g_f_loss = g_loss_fn(f_logit) 136 | g_f_tree_losses = tree_loss_fn(f_c_logit, c, mask) 137 | g_tree_loss = sum([g_f_tree_losses[i] * lambdas[i] * layer_mask[i] for i in range(len(lambdas))]) 138 | g_loss = g_f_loss + g_tree_loss 139 | 140 | # optims 141 | d_step = optim(learning_rate=lr_d).minimize(d_loss, var_list=tl.trainable_variables(includes='D')) 142 | g_step = optim(learning_rate=lr_g).minimize(g_loss, var_list=tl.trainable_variables(includes='G')) 143 | 144 | # summaries 145 | d_summary = tl.summary({d_r_loss: 'd_r_loss', 146 | d_f_loss: 'd_f_loss', 147 | d_r_loss + d_f_loss: 'd_loss', 148 | gp: 'gp'}, scope='D') 149 | tmp = {l: 'd_f_tree_loss_%d' % i for i, l in enumerate(d_f_tree_losses)} 150 | if att != '': 151 | tmp.update({d_r_tree_losses[0]: 'd_r_tree_loss_0'}) 152 | d_tree_summary = tl.summary(tmp, scope='D_Tree') 153 | d_summary = tf.summary.merge([d_summary, d_tree_summary]) 154 | 155 | g_summary = tl.summary({g_f_loss: 'g_f_loss'}, scope='G') 156 | g_tree_summary = tl.summary({l: 'g_f_tree_loss_%d' % i for i, l in enumerate(g_f_tree_losses)}, scope='G_Tree') 157 | g_summary = tf.summary.merge([g_summary, g_tree_summary]) 158 | 159 | # sample 160 | z_sample = tf.placeholder(tf.float32, [None, z_dim]) 161 | c_sample = tf.placeholder(tf.float32, [None, c_dim]) 162 | f_sample = G(z_sample, c_sample, is_training=False) 163 | 164 | 165 | # ============================================================================== 166 | # = train = 167 | # ============================================================================== 168 | 169 | # epoch counter 170 | ep_cnt, update_cnt = tl.counter(start=1) 171 | 172 | # session 173 | sess = tl.session() 174 | 175 | # saver 176 | saver = tf.train.Saver(max_to_keep=1) 177 | 178 | # summary writer 179 | summary_writer = tf.summary.FileWriter('./output/%s/summaries' % experiment_name, sess.graph) 180 | 181 | # initialization 182 | ckpt_dir = './output/%s/checkpoints' % experiment_name 183 | pylib.mkdir(ckpt_dir) 184 | try: 185 | tl.load_checkpoint(ckpt_dir, sess) 186 | except: 187 | sess.run(tf.global_variables_initializer()) 188 | 189 | # train 190 | try: 191 | c_ipt_sample = np.stack(model.traversal_trees(ks, continuous_last=continuous_last)[1]) 192 | z_ipt_samples = [np.stack([np.random.normal(size=[z_dim])] * len(c_ipt_sample)) for i in range(15)] 193 | 194 | it = 0 195 | it_per_epoch = len(dataset) // (batch_size * n_d) 196 | for ep in range(sess.run(ep_cnt), epoch + 1): 197 | sess.run(update_cnt) 198 | 199 | dataset.reset() 200 | for i in range(it_per_epoch): 201 | it += 1 202 | 203 | # train D 204 | if n_d_pre > 0 and it <= 25: 205 | n_d_ = n_d_pre 206 | else: 207 | n_d_ = n_d 208 | for _ in range(n_d_): 209 | # batch data 210 | real_ipt, att_ipt = dataset.get_next() 211 | c_ipt = [] 212 | mask_ipt = [] 213 | for idx in range(batch_size): 214 | if att == '': 215 | c_1 = None 216 | else: 217 | if att_ipt[idx] == 1: 218 | c_1 = np.array([1.0, 0]) 219 | else: 220 | c_1 = np.array([0, 1.0]) 221 | c_tmp, mask_tmp, _, _ = model.sample_c(ks, c_1, continuous_last) 222 | c_ipt.append(c_tmp) 223 | mask_ipt.append(mask_tmp) 224 | c_ipt = np.stack(c_ipt) 225 | mask_ipt = np.stack(mask_ipt) 226 | z_ipt = np.random.normal(size=[batch_size, z_dim]) 227 | 228 | d_summary_opt, _ = sess.run([d_summary, d_step], feed_dict={real: real_ipt, z: z_ipt, c: c_ipt, mask: mask_ipt, counter: ep}) 229 | summary_writer.add_summary(d_summary_opt, it) 230 | 231 | # train G 232 | for _ in range(n_g): 233 | # batch data 234 | z_ipt = np.random.normal(size=[batch_size, z_dim]) 235 | 236 | g_summary_opt, _ = sess.run([g_summary, g_step], feed_dict={z: z_ipt, c: c_ipt, mask: mask_ipt, counter: ep}) 237 | summary_writer.add_summary(g_summary_opt, it) 238 | 239 | # display 240 | if it % 1 == 0: 241 | print("Epoch: (%3d) (%5d/%5d)" % (ep, i + 1, it_per_epoch)) 242 | 243 | # sample 244 | if it % 100 == 0: 245 | merge = [] 246 | for z_ipt_sample in z_ipt_samples: 247 | f_sample_opt = sess.run(f_sample, feed_dict={z_sample: z_ipt_sample, c_sample: c_ipt_sample}).squeeze() 248 | 249 | k_prod = 1 250 | for k in ks: 251 | k_prod *= k 252 | f_sample_opts_k = list(f_sample_opt) 253 | for idx in range(len(f_sample_opts_k)): 254 | if idx % (len(f_sample_opts_k) / k_prod) != 0: 255 | f_sample_opts_k[idx] = np.zeros_like(f_sample_opts_k[idx]) 256 | merge.append(np.concatenate(f_sample_opts_k, axis=1)) 257 | merge = np.concatenate(merge, axis=0) 258 | 259 | save_dir = './output/%s/sample_training' % experiment_name 260 | pylib.mkdir(save_dir) 261 | im.imwrite(merge, '%s/Epoch_(%d)_(%dof%d).jpg' % (save_dir, ep, i + 1, it_per_epoch)) 262 | 263 | save_path = saver.save(sess, '%s/Epoch_%d.ckpt' % (ckpt_dir, ep)) 264 | print('Model is saved in file: %s' % save_path) 265 | except: 266 | traceback.print_exc() 267 | finally: 268 | sess.close() 269 | --------------------------------------------------------------------------------