├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── examples.md ├── imlib ├── __init__.py ├── basic.py ├── dtype.py └── transform.py ├── module.py ├── pics ├── first_view.png ├── sample_validation.jpg ├── sample_validation_256x.jpg ├── sample_validation_384x_hd.jpg ├── sample_validation_40.jpg ├── schema.jpg ├── slide.png └── style.jpg ├── pylib ├── __init__.py ├── argument.py ├── path.py ├── processing.py ├── serialization.py └── timer.py ├── results.md ├── scripts ├── align.py ├── cropper.py └── split_CelebA-HQ.py ├── test.py ├── test_multi.py ├── test_slide.py ├── tflib ├── __init__.py ├── data │ ├── __init__.py │ └── dataset.py ├── image │ ├── __init__.py │ ├── filter.py │ └── image.py ├── layers │ ├── __init__.py │ ├── layers.py │ └── layers_slim.py ├── losses │ ├── __init__.py │ └── losses.py ├── metrics │ ├── __init__.py │ └── metrics.py ├── ops │ ├── __init__.py │ └── ops.py └── utils │ ├── __init__.py │ ├── collection.py │ ├── distribute.py │ ├── learning_rate.py │ └── utils.py ├── tfprob ├── __init__.py └── gan │ ├── __init__.py │ ├── gradient_penalty.py │ └── loss.py ├── to_pb.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | /data/ 4 | /output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zhenliang He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above 
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

     

2 | 3 |

4 | 5 |

6 | AttGAN 7 |

8 | 9 | 10 | 11 | 12 | > **[AttGAN: Facial Attribute Editing by Only Changing What You Want](https://kdocs.cn/l/cpp7J2ZsFUW7)** \ 13 | > [Zhenliang He](https://lynnho.github.io)1,2, [Wangmeng Zuo](https://scholar.google.com/citations?user=rUOpCEYAAAAJ)4, [Meina Kan](https://scholar.google.com/citations?user=4AKCKKEAAAAJ)1, [Shiguang Shan](https://scholar.google.com/citations?user=Vkzd7MIAAAAJ)1,3, [Xilin Chen](https://scholar.google.com/citations?user=vVx2v20AAAAJ)1 \ 14 | > 1Key Lab of Intelligent Information Processing, Institute of Computing Technology, CAS, China \ 15 | > 2University of Chinese Academy of Sciences, China \ 16 | > 3CAS Center for Excellence in Brain Science and Intelligence Technology, China \ 17 | > 4School of Computer Science and Technology, Harbin Institute of Technology, China 18 | 19 |

20 | 21 | ## Related 22 | 23 | - Other implementations of AttGAN 24 | 25 | - [AttGAN-PyTorch](https://github.com/elvisyjlin/AttGAN-PyTorch) by Yu-Jing Lin 26 | 27 | - [AttGAN-PaddlePaddle](https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleCV/gan) by ceci3 and zhumanyu (**AttGAN is one of the official reproduced models of [PaddlePaddle](https://github.com/PaddlePaddle?type=source)**) 28 | 29 | - Closely related works 30 | 31 | - **An excellent work built upon our code - [STGAN](https://github.com/csmliu/STGAN) (CVPR 2019) by Ming Liu** 32 | 33 | - [Changing-the-Memorability](https://github.com/acecreamu/Changing-the-Memorability) (CVPR 2019 MBCCV Workshop) by acecreamu 34 | 35 | - [Fashion-AttGAN](https://github.com/ChanningPing/Fashion_Attribute_Editing) (CVPR 2019 FSS-USAD Workshop) by Qing Ping 36 | 37 | - An unofficial [demo video](https://www.youtube.com/watch?v=gnN4ZjEWe-8) of AttGAN by 王一凡 38 | 39 | ## Exemplar Results 40 | 41 | - See [results.md](./results.md) for more results, we try **higher resolution** and **more attributes** (all **40** attributes!!!) 
42 | 43 | - Inverting 13 attributes respectively 44 | 45 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young* 46 | 47 | 48 | 49 | ## Usage 50 | 51 | - Environment 52 | 53 | - Python 3.6 54 | 55 | - TensorFlow 1.15 56 | 57 | - OpenCV, scikit-image, tqdm, oyaml 58 | 59 | - *we recommend [Anaconda](https://www.anaconda.com/distribution/#download-section) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html#linux-installers), then you can create the AttGAN environment with commands below* 60 | 61 | ```console 62 | conda create -n AttGAN python=3.6 63 | 64 | source activate AttGAN 65 | 66 | conda install opencv scikit-image tqdm tensorflow-gpu=1.15 67 | 68 | conda install -c conda-forge oyaml 69 | ``` 70 | 71 | - *NOTICE: if you create a new conda environment, remember to activate it before any other command* 72 | 73 | ```console 74 | source activate AttGAN 75 | ``` 76 | 77 | - Data Preparation 78 | 79 | - Option 1: [CelebA](http://openaccess.thecvf.com/content_iccv_2015/papers/Liu_Deep_Learning_Face_ICCV_2015_paper.pdf)-unaligned (higher quality than the aligned data, 10.2GB) 80 | 81 | - download the dataset 82 | 83 | - img_celeba.7z (move to **./data/img_celeba/img_celeba.7z**): [Google Drive](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg?resourcekey=0-rJlzl934LzC-Xp28GeIBzQ) or [Baidu Netdisk](https://pan.baidu.com/s/1CRxxhoQ97A5qbsKO7iaAJg) (password rp0s) 84 | 85 | - annotations.zip (move to **./data/img_celeba/annotations.zip**): [Google Drive](https://drive.google.com/file/d/1xd-d1WRnbt3yJnwh5ORGZI3g-YS-fKM9/view?usp=sharing) 86 | 87 | - unzip and process the data 88 | 89 | ```console 90 | 7z x ./data/img_celeba/img_celeba.7z/img_celeba.7z.001 -o./data/img_celeba/ 91 | 92 | unzip ./data/img_celeba/annotations.zip -d ./data/img_celeba/ 93 | 94 | python ./scripts/align.py 95 | ``` 96 | 97 | - Option 2: 
CelebA-HQ (we use the data from [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ), 3.2GB) 98 | 99 | - CelebAMask-HQ.zip (move to **./data/CelebAMask-HQ.zip**): [Google Drive](https://drive.google.com/open?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv) or [Baidu Netdisk](https://pan.baidu.com/s/1wN1E-B1bJ7mE1mrn9loj5g) 100 | 101 | - unzip and process the data 102 | 103 | ```console 104 | unzip ./data/CelebAMask-HQ.zip -d ./data/ 105 | 106 | python ./scripts/split_CelebA-HQ.py 107 | ``` 108 | 109 | - Run AttGAN 110 | 111 | - training (see [examples.md](./examples.md) for more training commands) 112 | 113 | ```console 114 | \\ for CelebA 115 | CUDA_VISIBLE_DEVICES=0 \ 116 | python train.py \ 117 | --load_size 143 \ 118 | --crop_size 128 \ 119 | --model model_128 \ 120 | --experiment_name AttGAN_128 121 | 122 | \\ for CelebA-HQ 123 | CUDA_VISIBLE_DEVICES=0 \ 124 | python train.py \ 125 | --img_dir ./data/CelebAMask-HQ/CelebA-HQ-img \ 126 | --train_label_path ./data/CelebAMask-HQ/train_label.txt \ 127 | --val_label_path ./data/CelebAMask-HQ/val_label.txt \ 128 | --load_size 128 \ 129 | --crop_size 128 \ 130 | --n_epochs 200 \ 131 | --epoch_start_decay 100 \ 132 | --model model_128 \ 133 | --experiment_name AttGAN_128_CelebA-HQ 134 | ``` 135 | 136 | - testing 137 | 138 | - **single** attribute editing (inversion) 139 | 140 | ```console 141 | \\ for CelebA 142 | CUDA_VISIBLE_DEVICES=0 \ 143 | python test.py \ 144 | --experiment_name AttGAN_128 145 | 146 | \\ for CelebA-HQ 147 | CUDA_VISIBLE_DEVICES=0 \ 148 | python test.py \ 149 | --img_dir ./data/CelebAMask-HQ/CelebA-HQ-img \ 150 | --test_label_path ./data/CelebAMask-HQ/test_label.txt \ 151 | --experiment_name AttGAN_128_CelebA-HQ 152 | ``` 153 | 154 | 155 | - **multiple** attribute editing (inversion) example 156 | 157 | ```console 158 | \\ for CelebA 159 | CUDA_VISIBLE_DEVICES=0 \ 160 | python test_multi.py \ 161 | --test_att_names Bushy_Eyebrows Pale_Skin \ 162 | --experiment_name AttGAN_128 163 | ``` 164 | 165 
| - attribute sliding example 166 | 167 | ```console 168 | \\ for CelebA 169 | CUDA_VISIBLE_DEVICES=0 \ 170 | python test_slide.py \ 171 | --test_att_name Pale_Skin \ 172 | --test_int_min -2 \ 173 | --test_int_max 2 \ 174 | --test_int_step 0.5 \ 175 | --experiment_name AttGAN_128 176 | ``` 177 | 178 | - loss visualization 179 | 180 | ```console 181 | CUDA_VISIBLE_DEVICES='' \ 182 | tensorboard \ 183 | --logdir ./output/AttGAN_128/summaries \ 184 | --port 6006 185 | ``` 186 | 187 | - convert trained model to .pb file 188 | 189 | ```console 190 | python to_pb.py --experiment_name AttGAN_128 191 | ``` 192 | 193 | - Using Trained Weights 194 | 195 | - alternative trained weights (move to **./output/\*.zip**) 196 | 197 | - [AttGAN_128.zip](https://drive.google.com/file/d/1Oy4F1xtYdxj4iyiLyaEd-dkGIJ0mwo41/view?usp=sharing) (987.5MB) 198 | 199 | - *including G, D, and the state of the optimizer* 200 | 201 | - [AttGAN_128_generator_only.zip](https://drive.google.com/file/d/1lcQ-ijNrGD4919eJ5Dv-7ja5rsx5p0Tp/view?usp=sharing) (161.5MB) 202 | 203 | - *G only* 204 | 205 | - [AttGAN_384_generator_only.zip](https://drive.google.com/open?id=1scaKWcWIpTfsV0yrWCI-wg_JDmDsKKm1) (91.1MB) 206 | 207 | 208 | - unzip the file (AttGAN_128.zip for example) 209 | 210 | ```console 211 | unzip ./output/AttGAN_128.zip -d ./output/ 212 | ``` 213 | 214 | - testing (see above) 215 | 216 | 217 | - Example for Custom Dataset 218 | 219 | - [AttGAN-Cartoon](https://github.com/LynnHo/AttGAN-Cartoon-Tensorflow) 220 | 221 | ## Citation 222 | 223 | If you find [AttGAN](https://kdocs.cn/l/cpp7J2ZsFUW7) useful in your research work, please consider citing: 224 | 225 | @ARTICLE{8718508, 226 | author={Z. {He} and W. {Zuo} and M. {Kan} and S. {Shan} and X. 
{Chen}}, 227 | journal={IEEE Transactions on Image Processing}, 228 | title={AttGAN: Facial Attribute Editing by Only Changing What You Want}, 229 | year={2019}, 230 | volume={28}, 231 | number={11}, 232 | pages={5464-5478}, 233 | keywords={Face;Facial features;Task analysis;Decoding;Image reconstruction;Hair;Gallium nitride;Facial attribute editing;attribute style manipulation;adversarial learning}, 234 | doi={10.1109/TIP.2019.2916751}, 235 | ISSN={1057-7149}, 236 | month={Nov},} 237 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pylib as py 3 | import tensorflow as tf 4 | import tflib as tl 5 | 6 | 7 | ATT_ID = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 8 | 'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 9 | 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10, 10 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 11 | 'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16, 12 | 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, 13 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 14 | 'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25, 15 | 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 16 | 'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31, 17 | 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, 18 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 19 | 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 20 | ID_ATT = {v: k for k, v in ATT_ID.items()} 21 | 22 | 23 | def make_celeba_dataset(img_dir, 24 | label_path, 25 | att_names, 26 | batch_size, 27 | load_size=286, 28 | crop_size=256, 29 | training=True, 30 | drop_remainder=True, 31 | shuffle=True, 32 | repeat=1): 33 | img_names = np.genfromtxt(label_path, dtype=str, usecols=0) 34 | img_paths = np.array([py.join(img_dir, img_name) for img_name in img_names]) 35 | 
def check_attribute_conflict(att_batch, att_name, att_names):
    """Clear attributes that contradict a freshly switched-on attribute.

    For every attribute vector in `att_batch` whose entry for `att_name`
    equals 1, the entries of the mutually exclusive attributes are set to 0:

    * `Bald` / `Receding_Hairline` clear `Bangs`;
    * `Bangs` clears `Bald` and `Receding_Hairline`;
    * the hair colours (`Black_Hair`, `Blond_Hair`, `Brown_Hair`,
      `Gray_Hair`) exclude each other;
    * `Straight_Hair` and `Wavy_Hair` exclude each other.

    Conflicting attributes that do not occur in `att_names` are ignored.

    Parameters
    ----------
    att_batch : iterable of mutable sequences
        Attribute vectors (e.g. lists or rows of a numpy array), each indexed
        in the order given by `att_names`. The vectors are modified IN PLACE.
    att_name : str
        The attribute that has just been edited.
    att_names : sequence of str
        Names of the attributes, in the order used by the vectors.

    Returns
    -------
    The same `att_batch` object that was passed in.

    Raises
    ------
    ValueError
        If `att_name` is not contained in `att_names`.

    """
    hair_colors = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']
    hair_shapes = ['Straight_Hair', 'Wavy_Hair']

    # Which attributes have to be switched off depends only on `att_name`,
    # so it is resolved once here instead of once per vector.
    if att_name in ['Bald', 'Receding_Hairline']:
        conflicts = ['Bangs']
    elif att_name == 'Bangs':
        conflicts = ['Bald', 'Receding_Hairline']
    elif att_name in hair_colors:
        conflicts = [n for n in hair_colors if n != att_name]
    elif att_name in hair_shapes:
        conflicts = [n for n in hair_shapes if n != att_name]
    else:
        # NOTE: treating `Mustache` / `No_Beard` as a further exclusive pair
        # is deliberately left disabled; enabling it helps to learn `Mustache`.
        conflicts = []

    idx = att_names.index(att_name)
    conflict_ids = [att_names.index(n) for n in conflicts if n in att_names]

    for att in att_batch:
        if att[idx] == 1:
            for i in conflict_ids:
                att[i] = 0

    return att_batch

[AttGAN](https://ieeexplore.ieee.org/document/8718508?source=authoralert) Usage

2 | 3 | - training 4 | 5 | - for 128x128 images 6 | 7 | ```console 8 | CUDA_VISIBLE_DEVICES=0 \ 9 | python train.py \ 10 | --load_size 143 \ 11 | --crop_size 128 \ 12 | --model model_128 \ 13 | --experiment_name AttGAN_128 14 | ``` 15 | 16 | - for 128x128 images with all **40** attributes! 17 | 18 | ```console 19 | CUDA_VISIBLE_DEVICES=0 \ 20 | python train.py \ 21 | --load_size 143 \ 22 | --crop_size 128 \ 23 | --model model_128 \ 24 | --experiment_name AttGAN_128_40 \ 25 | --att_names 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes \ 26 | Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry \ 27 | Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee \ 28 | Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open \ 29 | Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose \ 30 | Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair \ 31 | Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick \ 32 | Wearing_Necklace Wearing_Necktie Young 33 | ``` 34 | 35 | - for 256x256 images 36 | 37 | ```console 38 | CUDA_VISIBLE_DEVICES=0 \ 39 | python train.py \ 40 | --load_size 286 \ 41 | --crop_size 256 \ 42 | --model model_256 \ 43 | --experiment_name AttGAN_256 44 | ``` 45 | 46 | - for 384x384 images 47 | 48 | ```console 49 | CUDA_VISIBLE_DEVICES=0 \ 50 | python train.py \ 51 | --load_size 429 \ 52 | --crop_size 384 \ 53 | --model model_384 \ 54 | --experiment_name AttGAN_384 55 | ``` -------------------------------------------------------------------------------- /imlib/__init__.py: -------------------------------------------------------------------------------- 1 | from imlib.basic import * 2 | from imlib.dtype import * 3 | from imlib.transform import * 4 | -------------------------------------------------------------------------------- /imlib/basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.io as iio 3 | 4 | from imlib import dtype 5 | 
import numpy as np


def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
    """Assert that `images` is a finite ndarray of an allowed dtype and range.

    Parameters
    ----------
    images : numpy.ndarray
        Array to validate.
    dtypes : dtype or list/tuple of dtypes
        Allowed dtype(s) of `images`.
    min_value, max_value : number or None
        Inclusive bounds for the values; `None` or +/-inf disables a bound.

    Raises
    ------
    AssertionError
        If any of the checks fails.

    """
    # check type
    assert isinstance(images, np.ndarray), '`images` should be np.ndarray!'

    # check dtype
    # Always work on a list: formatting a tuple with `%` would unpack it and
    # raise TypeError instead of the intended AssertionError.
    dtypes = list(dtypes) if isinstance(dtypes, (list, tuple)) else [dtypes]
    assert images.dtype in dtypes, 'dtype of `images` should be one of %s!' % dtypes

    # check nan and inf
    assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!'

    # check value
    if min_value not in [None, -np.inf]:
        lower = '[' + str(min_value)
    else:
        lower = '(-inf'
        min_value = -np.inf
    if max_value not in [None, np.inf]:
        upper = str(max_value) + ']'
    else:
        upper = 'inf)'
        max_value = np.inf
    # An empty array has no minimum/maximum (np.min would raise ValueError);
    # it trivially satisfies any range.
    if images.size > 0:
        assert np.min(images) >= min_value and np.max(images) <= max_value, \
            '`images` should be in the range of %s!' % (lower + ',' + upper)


def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
    """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.

    If `dtype` is None the dtype of `images` is kept. The conversion uses
    `astype`, i.e. values are truncated (not rounded) for integer dtypes.
    """
    _check(images, [np.float32, np.float64], -1.0, 1.0)
    # Compare against None explicitly: an `np.dtype` instance can evaluate as
    # falsy, which used to make an explicitly requested dtype being ignored.
    dtype = images.dtype if dtype is None else dtype
    return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)


def float2im(images):
    """Transform images from [0, 1.0] to [-1.0, 1.0]."""
    _check(images, [np.float32, np.float64], 0.0, 1.0)
    return images * 2 - 1.0


def float2uint(images):
    """Transform images from [0, 1.0] to uint8."""
    _check(images, [np.float32, np.float64], 0.0, 1.0)
    return (images * 255).astype(np.uint8)


def im2uint(images):
    """Transform images from [-1.0, 1.0] to uint8."""
    return to_range(images, 0, 255, np.uint8)


def im2float(images):
    """Transform images from [-1.0, 1.0] to [0.0, 1.0]."""
    return to_range(images, 0.0, 1.0)


def uint2im(images):
    """Transform images from uint8 to [-1.0, 1.0] of float64."""
    _check(images, np.uint8)
    return images / 127.5 - 1.0


def uint2float(images):
    """Transform images from uint8 to [0.0, 1.0] of float64."""
    _check(images, np.uint8)
    return images / 255.0


def cv2im(images):
    """Transform opencv images to [-1.0, 1.0].

    The last axis is reversed (presumably BGR -> RGB channel order).
    """
    images = uint2im(images)
    return images[..., ::-1]


def im2cv(images):
    """Transform images from [-1.0, 1.0] to opencv images.

    The last axis is reversed (presumably RGB -> BGR channel order).
    """
    images = im2uint(images)
    return images[..., ::-1]
def immerge(images, n_rows=None, n_cols=None, padding=0, pad_value=0):
    """Merge images to an image with (n_rows * h) * (n_cols * w).

    The images are placed row by row. If neither `n_rows` nor `n_cols` is
    given, `n_rows` defaults to floor(sqrt(N)); the other dimension is always
    derived so that all N images fit. Unused cells and the `padding` pixels
    between neighbouring images are filled with `pad_value`.

    Parameters
    ----------
    images : numpy.array or object which can be converted to numpy.array
        Images in shape of N * H * W(* C=1 or 3).
    n_rows, n_cols : int, optional
        Requested grid size; `n_rows` takes precedence if both are given.
    padding : int
        Gap in pixels between neighbouring images.
    pad_value : scalar
        Fill value for gaps and unused cells.

    Returns
    -------
    numpy.ndarray
        The merged image, with the dtype of `images`.

    """
    images = np.array(images)
    n = images.shape[0]

    def _cells_needed(n_lines):
        # ceil(n / n_lines), but never fewer than one cell
        return max(-(-n // n_lines), 1)

    if n_rows:
        n_rows = max(min(n_rows, n), 1)
        n_cols = _cells_needed(n_rows)
    elif n_cols:
        n_cols = max(min(n_cols, n), 1)
        n_rows = _cells_needed(n_cols)
    else:
        # clamp to 1 like the explicit branches, so that an empty batch does
        # not end in a division by zero
        n_rows = max(int(n ** 0.5), 1)
        n_cols = _cells_needed(n_rows)

    h, w = images.shape[1], images.shape[2]
    shape = (h * n_rows + padding * (n_rows - 1),
             w * n_cols + padding * (n_cols - 1))
    if images.ndim == 4:
        shape += (images.shape[3],)
    img = np.full(shape, pad_value, dtype=images.dtype)

    for idx, image in enumerate(images):
        j, i = divmod(idx, n_cols)  # j: grid row, i: grid column
        top = j * (h + padding)
        left = i * (w + padding)
        img[top:top + h, left:left + w, ...] = image

    return img


def grid_split(image, h, w):
    """Split an image into a grid of tiles of at most `h` x `w` pixels.

    Returns a list of rows, each being a list of tiles (views of `image`);
    tiles at the bottom/right border are smaller if the image size is not a
    multiple of the tile size.
    """
    n_rows = math.ceil(image.shape[0] / h)
    n_cols = math.ceil(image.shape[1] / w)

    return [[image[r * h: (r + 1) * h, c * w: (c + 1) * w, ...] for c in range(n_cols)]
            for r in range(n_rows)]


def grid_merge(grid, padding=(0, 0), pad_value=(0, 0)):
    """Merge a grid of tiles (as produced by `grid_split`) into one image.

    Parameters
    ----------
    grid : list of lists of numpy.ndarray
        Rows of tiles in shape of H * W(* C).
    padding : int or (int, int)
        Gap between rows (`padding[0]`) and between columns (`padding[1]`).
    pad_value : scalar or (scalar, scalar)
        Fill values for the row gaps and the column gaps respectively.

    """
    padding = padding if isinstance(padding, (list, tuple)) else [padding, padding]
    pad_value = pad_value if isinstance(pad_value, (list, tuple)) else [pad_value, pad_value]

    merged_rows = []
    for r, row in enumerate(grid):
        pieces = []
        for c, tile in enumerate(row):
            if c != 0:
                # `shape[2:]` is empty for 2-D (grayscale) tiles, so those are
                # supported as well as H * W * C tiles
                gap_shape = (tile.shape[0], padding[1]) + tuple(tile.shape[2:])
                pieces.append(np.full(gap_shape, pad_value[1], dtype=tile.dtype))
            pieces.append(tile)

        merged_row = np.concatenate(pieces, axis=1)
        if r != 0:
            gap_shape = (padding[0], merged_row.shape[1]) + tuple(merged_row.shape[2:])
            merged_rows.append(np.full(gap_shape, pad_value[0], dtype=merged_row.dtype))
        merged_rows.append(merged_row)

    grid_merged = np.concatenate(merged_rows, axis=0)

    return grid_merged
class UNetGdec:
    """Decoder half of the U-Net generator.

    Calling an instance builds the decoder graph inside variable scope
    `scope` (variables are re-used on repeated calls via `tf.AUTO_REUSE`)
    and afterwards records the variables and regularization losses found
    under that scope on the instance.
    """

    def __call__(self, zs, a, dim=64, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=0.0,
                 norm_name='batch_norm', training=True, scope='UNetGdec'):
        """Decode the encoder features `zs` conditioned on the attributes `a`.

        Parameters
        ----------
        zs : list of tensors
            Encoder feature maps; `zs[-1]` is decoded, and `zs[-2 - i]` is
            concatenated after the i-th upsampling for `i < shortcut_layers`.
            (Presumably ordered from shallow to deep -- confirm with the
            encoder.)
        a : tensor
            Attribute vector(s); cast to float32 and combined with the
            features through `utils.tile_concat`.
        dim : int
            Base number of channels; layer i uses
            `min(dim * 2**(n_upsamplings - 1 - i), 1024)` channels.
        n_upsamplings : int
            Total number of stride-2 transposed convolutions, including the
            final 3-channel output layer.
        shortcut_layers, inject_layers : int
            Number of leading upsampling layers after which an encoder
            feature / the attributes are concatenated again.
        weight_decay : float
            Scale of the L2 weight regularizer.
        norm_name, training
            Forwarded to `utils.get_norm_layer`.
        scope : str
            Variable scope name.

        Returns
        -------
        The output of the last transposed convolution (3 channels) after
        `tanh`.

        """
        MAX_DIM = 1024

        dconv_ = functools.partial(dconv, weights_regularizer=slim.l2_regularizer(weight_decay))
        norm = utils.get_norm_layer(norm_name, training, updates_collections=None)

        dconv_norm_relu = functools.partial(dconv_, normalizer_fn=norm, activation_fn=tf.nn.relu)

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            # `tf.to_float` is deprecated; `tf.cast(..., tf.float32)` is its
            # documented replacement.
            a = tf.cast(a, tf.float32)

            z = utils.tile_concat(zs[-1], a)
            for i in range(n_upsamplings - 1):
                d = min(dim * 2**(n_upsamplings - 1 - i), MAX_DIM)
                z = dconv_norm_relu(z, d, 4, 2)
                if shortcut_layers > i:
                    z = utils.tile_concat([z, zs[-2 - i]])
                if inject_layers > i:
                    z = utils.tile_concat(z, a)
            # last layer: no normalization, tanh as the output activation
            x = tf.nn.tanh(dconv_(z, 3, 4, 2))

        # variables and update operations
        self.variables = tf.global_variables(scope)
        self.trainable_variables = tf.trainable_variables(scope)
        self.reg_losses = tf.losses.get_regularization_losses(scope)

        return x
def get_model(name, n_atts, weight_decay=0.0):
    """Return the pre-configured network builders for the model `name`.

    Parameters
    ----------
    name : str
        One of 'model_128', 'model_256' (64 base channels, 1024 fully
        connected units) or 'model_384' (48 base channels, 512 fully
        connected units).
    n_atts : int
        Number of attributes, passed to the discriminator as `n_atts`.
    weight_decay : float
        Weight decay passed to all three networks.

    Returns
    -------
    (Genc, Gdec, D)
        `functools.partial` objects wrapping fresh `UNetGenc`, `UNetGdec`
        and `ConvD` instances with the hyper-parameters bound.

    Raises
    ------
    ValueError
        If `name` is not a known model name.

    """
    if name in ['model_128', 'model_256']:
        dim, fc_dim = 64, 1024
    elif name == 'model_384':
        dim, fc_dim = 48, 512
    else:
        # Fail early with a clear message instead of running into an
        # UnboundLocalError at the return statement.
        raise ValueError("Unknown model name %r, expected one of 'model_128', 'model_256', 'model_384'!" % (name,))

    Genc = functools.partial(UNetGenc(), dim=dim, n_downsamplings=5, weight_decay=weight_decay)
    Gdec = functools.partial(UNetGdec(), dim=dim, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=weight_decay)
    D = functools.partial(ConvD(), n_atts=n_atts, dim=dim, fc_dim=fc_dim, n_downsamplings=5, weight_decay=weight_decay)
    return Genc, Gdec, D
/pics/first_view.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/first_view.png -------------------------------------------------------------------------------- /pics/sample_validation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation.jpg -------------------------------------------------------------------------------- /pics/sample_validation_256x.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_256x.jpg -------------------------------------------------------------------------------- /pics/sample_validation_384x_hd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_384x_hd.jpg -------------------------------------------------------------------------------- /pics/sample_validation_40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_40.jpg -------------------------------------------------------------------------------- /pics/schema.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/schema.jpg -------------------------------------------------------------------------------- /pics/slide.png: 
import argparse


def str2bool(v):
    """Convert a command line string to a bool (for use as argparse `type`).

    Accepts 'yes', 'true', 't', 'y', '1' as True and 'no', 'false', 'f',
    'n', '0' as False; the comparison ignores case and surrounding
    whitespace. A value that already is a bool is returned unchanged.

    Raises
    ------
    argparse.ArgumentTypeError
        If the string is none of the accepted spellings.

    """
    if isinstance(v, bool):
        return v

    value = v.strip().lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected!')
issubclass(kwargs['type'], bool): 38 | kwargs['type'] = str2bool 39 | elif issubclass(kwargs['type'], dict): 40 | kwargs['type'] = json.loads 41 | return GLOBAL_COMMAND_PARSER.add_argument(*args, **kwargs) 42 | 43 | 44 | arg = argument 45 | 46 | 47 | @_serialization_wrapper 48 | def args(args=None, namespace=None): 49 | """Parse args using the global parser.""" 50 | namespace = GLOBAL_COMMAND_PARSER.parse_args(args=args, namespace=namespace) 51 | return namespace 52 | 53 | 54 | @_serialization_wrapper 55 | def args_from_xxx(obj, parser, check=True): 56 | """Load args from xxx ignoring type and choices with default still valid. 57 | 58 | Parameters 59 | ---------- 60 | parser: function 61 | Should return a dict. 62 | 63 | """ 64 | dict_ = parser(obj) 65 | namespace = argparse.ArgumentParser().parse_args(args='') # '' for not to accept command line args 66 | for k, v in dict_.items(): 67 | namespace.__setattr__(k, v) 68 | return namespace 69 | 70 | 71 | args_from_dict = functools.partial(args_from_xxx, parser=lambda x: x) 72 | args_from_json = functools.partial(args_from_xxx, parser=serialization.load_json) 73 | args_from_yaml = functools.partial(args_from_xxx, parser=serialization.load_yaml) 74 | 75 | 76 | def args_to_json(path, namespace, **kwagrs): 77 | serialization.save_json(path, vars(namespace), **kwagrs) 78 | 79 | 80 | def args_to_yaml(path, namespace, **kwagrs): 81 | serialization.save_yaml(path, vars(namespace), **kwagrs) 82 | -------------------------------------------------------------------------------- /pylib/path.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import fnmatch 3 | import os 4 | import glob as _glob 5 | import sys 6 | 7 | 8 | def add_path(paths): 9 | if not isinstance(paths, (list, tuple)): 10 | paths = [paths] 11 | for path in paths: 12 | if path not in sys.path: 13 | sys.path.insert(0, path) 14 | 15 | 16 | def mkdir(paths): 17 | if not isinstance(paths, (list, tuple)): 18 | 
paths = [paths] 19 | for path in paths: 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def split(path): 25 | """Return dir, name, ext.""" 26 | dir, name_ext = os.path.split(path) 27 | name, ext = os.path.splitext(name_ext) 28 | return dir, name, ext 29 | 30 | 31 | def directory(path): 32 | return split(path)[0] 33 | 34 | 35 | def name(path): 36 | return split(path)[1] 37 | 38 | 39 | def ext(path): 40 | return split(path)[2] 41 | 42 | 43 | def name_ext(path): 44 | return ''.join(split(path)[1:]) 45 | 46 | 47 | def change_ext(path, ext): 48 | if ext[0] == '.': 49 | ext = ext[1:] 50 | return os.path.splitext(path)[0] + '.' + ext 51 | 52 | 53 | asbpath = os.path.abspath 54 | 55 | 56 | join = os.path.join 57 | 58 | 59 | def prefix(path, prefixes, sep='-'): 60 | prefixes = prefixes if isinstance(prefixes, (list, tuple)) else [prefixes] 61 | dir, name, ext = split(path) 62 | return join(dir, sep.join(prefixes) + sep + name + ext) 63 | 64 | 65 | def suffix(path, suffixes, sep='-'): 66 | suffixes = suffixes if isinstance(suffixes, (list, tuple)) else [suffixes] 67 | dir, name, ext = split(path) 68 | return join(dir, name + sep + sep.join(suffixes) + ext) 69 | 70 | 71 | def prefix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'): 72 | return prefix(path, prefixes=datetime.datetime.now().strftime(fmt), sep=sep) 73 | 74 | 75 | def suffix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'): 76 | return suffix(path, suffixes=datetime.datetime.now().strftime(fmt), sep=sep) 77 | 78 | 79 | def glob(dir, pats, recursive=False): # faster than match, python3 only 80 | pats = pats if isinstance(pats, (list, tuple)) else [pats] 81 | matches = [] 82 | for pat in pats: 83 | matches += _glob.glob(os.path.join(dir, pat), recursive=recursive) 84 | return matches 85 | 86 | 87 | def match(dir, pats, recursive=False): # slow 88 | pats = pats if isinstance(pats, (list, tuple)) else [pats] 89 | 90 | iterator = list(os.walk(dir)) 91 | if not recursive: 92 | iterator = iterator[0:1] 93 
| 94 | matches = [] 95 | for pat in pats: 96 | for root, _, file_names in iterator: 97 | for file_name in fnmatch.filter(file_names, pat): 98 | matches.append(os.path.join(root, file_name)) 99 | 100 | return matches 101 | -------------------------------------------------------------------------------- /pylib/processing.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import functools 3 | import multiprocessing 4 | 5 | 6 | def run_parallels(work_fn, iterable, max_workers=None, chunksize=1, processing_bar=True, backend_executor=multiprocessing.Pool, debug=False): 7 | if not debug: 8 | with backend_executor(max_workers) as executor: 9 | try: 10 | works = executor.imap(work_fn, iterable, chunksize=chunksize) # for multiprocessing.Pool 11 | except: 12 | works = executor.map(work_fn, iterable, chunksize=chunksize) 13 | 14 | if processing_bar: 15 | try: 16 | import tqdm 17 | try: 18 | total = len(iterable) 19 | except: 20 | total = None 21 | works = tqdm.tqdm(works, total=total) 22 | except ImportError: 23 | print('`import tqdm` fails! 
Run without processing bar!') 24 | 25 | results = list(works) 26 | else: 27 | results = [work_fn(i) for i in iterable] 28 | return results 29 | 30 | run_parallels_mp = run_parallels 31 | run_parallels_cfprocess = functools.partial(run_parallels, backend_executor=concurrent.futures.ProcessPoolExecutor) 32 | run_parallels_cfthread = functools.partial(run_parallels, backend_executor=concurrent.futures.ThreadPoolExecutor) 33 | 34 | 35 | if __name__ == '__main__': 36 | import time 37 | 38 | def work(i): 39 | time.sleep(0.0001) 40 | i**i 41 | return i 42 | 43 | t = time.time() 44 | results = run_parallels_mp(work, range(10000), max_workers=2, chunksize=1, processing_bar=True, debug=False) 45 | for i in results: 46 | print(i) 47 | print(time.time() - t) 48 | -------------------------------------------------------------------------------- /pylib/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | 6 | def _check_ext(path, default_ext): 7 | name, ext = os.path.splitext(path) 8 | if ext == '': 9 | if default_ext[0] == '.': 10 | default_ext = default_ext[1:] 11 | path = name + '.' 
+ default_ext 12 | return path 13 | 14 | 15 | def save_json(path, obj, **kwargs): 16 | # default 17 | if 'indent' not in kwargs: 18 | kwargs['indent'] = 4 19 | if 'separators' not in kwargs: 20 | kwargs['separators'] = (',', ': ') 21 | 22 | path = _check_ext(path, 'json') 23 | 24 | # wrap json.dump 25 | with open(path, 'w') as f: 26 | json.dump(obj, f, **kwargs) 27 | 28 | 29 | def load_json(path, **kwargs): 30 | # wrap json.load 31 | with open(path) as f: 32 | return json.load(f, **kwargs) 33 | 34 | 35 | def save_yaml(path, data, **kwargs): 36 | import oyaml as yaml 37 | 38 | path = _check_ext(path, 'yml') 39 | 40 | with open(path, 'w') as f: 41 | yaml.dump(data, f, **kwargs) 42 | 43 | 44 | def load_yaml(path, **kwargs): 45 | import oyaml as yaml 46 | with open(path) as f: 47 | return yaml.load(f, **kwargs) 48 | 49 | 50 | def save_pickle(path, obj, **kwargs): 51 | 52 | path = _check_ext(path, 'pkl') 53 | 54 | # wrap pickle.dump 55 | with open(path, 'wb') as f: 56 | pickle.dump(obj, f, **kwargs) 57 | 58 | 59 | def load_pickle(path, **kwargs): 60 | # wrap pickle.load 61 | with open(path, 'rb') as f: 62 | return pickle.load(f, **kwargs) 63 | -------------------------------------------------------------------------------- /pylib/timer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import timeit 3 | 4 | 5 | class Timer: # deprecated, use tqdm instead 6 | """A timer as a context manager. 7 | 8 | Wraps around a timer. A custom timer can be passed 9 | to the constructor. The default timer is timeit.default_timer. 10 | 11 | Note that the latter measures wall clock time, not CPU time! 12 | On Unix systems, it corresponds to time.time. 13 | On Windows systems, it corresponds to time.clock. 14 | 15 | Parameters 16 | ---------- 17 | print_at_exit : boolean 18 | If True, print when exiting context. 19 | format : str 20 | `ms`, `s` or `datetime`. 
21 | 22 | References 23 | ---------- 24 | - https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py. 25 | 26 | 27 | """ 28 | 29 | def __init__(self, fmt='s', print_at_exit=True, timer=timeit.default_timer): 30 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!" 31 | self._fmt = fmt 32 | self._print_at_exit = print_at_exit 33 | self._timer = timer 34 | self.start() 35 | 36 | def __enter__(self): 37 | """Start the timer in the context manager scope.""" 38 | self.restart() 39 | return self 40 | 41 | def __exit__(self, exc_type, exc_value, exc_traceback): 42 | """Print the end time.""" 43 | if self._print_at_exit: 44 | print(str(self)) 45 | 46 | def __str__(self): 47 | return self.fmt(self.elapsed)[1] 48 | 49 | def start(self): 50 | self.start_time = self._timer() 51 | 52 | restart = start 53 | 54 | @property 55 | def elapsed(self): 56 | """Return the current elapsed time since last (re)start.""" 57 | return self._timer() - self.start_time 58 | 59 | def fmt(self, second): 60 | if self._fmt == 'ms': 61 | time_fmt = second * 1000 62 | time_str = '%s %s' % (time_fmt, self._fmt) 63 | elif self._fmt == 's': 64 | time_fmt = second 65 | time_str = '%s %s' % (time_fmt, self._fmt) 66 | elif self._fmt == 'datetime': 67 | time_fmt = datetime.timedelta(seconds=second) 68 | time_str = str(time_fmt) 69 | return time_fmt, time_str 70 | 71 | 72 | def timeit(run_times=1, **timer_kwargs): 73 | """Function decorator displaying the function execution time. 74 | 75 | All kwargs are the arguments taken by the Timer class constructor. 
76 | 77 | """ 78 | # store Timer kwargs in local variable so the namespace isn't polluted 79 | # by different level args and kwargs 80 | 81 | def decorator(f): 82 | def wrapper(*args, **kwargs): 83 | timer_kwargs.update(print_at_exit=False) 84 | with Timer(**timer_kwargs) as t: 85 | for _ in range(run_times): 86 | out = f(*args, **kwargs) 87 | fmt = '[*] Execution time of function "%(function_name)s" for %(run_times)d runs is %(execution_time)s = %(execution_time_each)s * %(run_times)d [*]' 88 | context = {'function_name': f.__name__, 'run_times': run_times, 'execution_time': t, 'execution_time_each': t.fmt(t.elapsed / run_times)[1]} 89 | print(fmt % context) 90 | return out 91 | return wrapper 92 | 93 | return decorator 94 | 95 | 96 | if __name__ == "__main__": 97 | import time 98 | 99 | # 1 100 | print(1) 101 | with Timer() as t: 102 | time.sleep(1) 103 | print(t) 104 | time.sleep(1) 105 | 106 | with Timer(fmt='datetime') as t: 107 | time.sleep(1) 108 | 109 | # 2 110 | print(2) 111 | t = Timer(fmt='ms') 112 | time.sleep(2) 113 | print(t) 114 | 115 | t = Timer(fmt='datetime') 116 | time.sleep(1) 117 | print(t) 118 | 119 | # 3 120 | print(3) 121 | 122 | @timeit(run_times=5, fmt='s') 123 | def blah(): 124 | time.sleep(2) 125 | 126 | blah() 127 | -------------------------------------------------------------------------------- /results.md: -------------------------------------------------------------------------------- 1 | #

[AttGAN](https://ieeexplore.ieee.org/document/8718508?source=authoralert) Results

2 | 3 | - Results of **40**-attribute model!!! It's amazing that AttGAN still works although some attributes are not good enough 4 | 5 | from left to right: *Input, Reconstruction, 5_o_Clock_Shadow, Arched_Eyebrows, Attractive, Bags_Under_Eyes, Bald, Bangs, Big_Lips, Big_Nose, Black_Hair, Blond_Hair, Blurry, Brown_Hair, Bushy_Eyebrows, Chubby, Double_Chin, Eyeglasses, Goatee, Gray_Hair, Heavy_Makeup, High_Cheekbones, Male, Mouth_Slightly_Open, Mustache, Narrow_Eyes, No_Beard, Oval_Face, Pale_Skin, Pointy_Nose, Receding_Hairline, Rosy_Cheeks, Sideburns, Smiling, Straight_Hair, Wavy_Hair, Wearing_Earrings, Wearing_Hat, Wearing_Lipstick, Wearing_Necklace, Wearing_Necktie, Young* 6 | 7 | 8 | 9 | - Results of **256x256**-input model 10 | 11 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young* 12 | 13 | 14 | 15 | - Results of **384x384**-input model 16 | 17 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young* 18 | 19 | -------------------------------------------------------------------------------- /scripts/align.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | from functools import partial 7 | from multiprocessing import Pool 8 | import os 9 | import re 10 | 11 | import cropper 12 | import numpy as np 13 | import tqdm 14 | 15 | 16 | # ============================================================================== 17 | # = param = 18 | # ============================================================================== 19 | 20 | parser = argparse.ArgumentParser() 21 | # main 22 | parser.add_argument('--img_dir', dest='img_dir', 
default='./data/img_celeba/img_celeba') 23 | parser.add_argument('--save_dir', dest='save_dir', default='./data/img_celeba/aligned') 24 | parser.add_argument('--landmark_file', dest='landmark_file', default='./data/img_celeba/landmark.txt') 25 | parser.add_argument('--standard_landmark_file', dest='standard_landmark_file', default='./data/img_celeba/standard_landmark_68pts.txt') 26 | parser.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572) 27 | parser.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572) 28 | parser.add_argument('--move_h', dest='move_h', type=float, default=0.25) 29 | parser.add_argument('--move_w', dest='move_w', type=float, default=0.) 30 | parser.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg') 31 | parser.add_argument('--n_worker', dest='n_worker', type=int, default=8) 32 | # others 33 | parser.add_argument('--face_factor', dest='face_factor', type=float, help='The factor of face area relative to the output image.', default=0.45) 34 | parser.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'], default='similarity') 35 | parser.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5], help='The order of interpolation.', default=3) 36 | parser.add_argument('--mode', dest='mode', choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge') 37 | args = parser.parse_args() 38 | 39 | 40 | # ============================================================================== 41 | # = opencv first = 42 | # ============================================================================== 43 | 44 | _DEAFAULT_JPG_QUALITY = 95 45 | try: 46 | import cv2 47 | imread = cv2.imread 48 | imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEAFAULT_JPG_QUALITY]) 49 | align_crop = cropper.align_crop_opencv 50 | print('Use OpenCV') 51 | except: 52 | import skimage.io as io 53 | imread = io.imread 54 | imwrite = 
partial(io.imsave, quality=_DEAFAULT_JPG_QUALITY) 55 | align_crop = cropper.align_crop_skimage 56 | print('Importing OpenCv fails. Use scikit-image') 57 | 58 | 59 | # ============================================================================== 60 | # = run = 61 | # ============================================================================== 62 | 63 | # count landmarks 64 | with open(args.landmark_file) as f: 65 | line = f.readline() 66 | n_landmark = len(re.split('[ ]+', line)[1:]) // 2 67 | 68 | # read data 69 | img_names = np.genfromtxt(args.landmark_file, dtype=np.str, usecols=0) 70 | landmarks = np.genfromtxt(args.landmark_file, dtype=np.float, usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2) 71 | standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=np.float).reshape(n_landmark, 2) 72 | standard_landmark[:, 0] += args.move_w 73 | standard_landmark[:, 1] += args.move_h 74 | 75 | # data dir 76 | save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format)) 77 | data_dir = os.path.join(save_dir, 'data') 78 | if not os.path.isdir(data_dir): 79 | os.makedirs(data_dir) 80 | 81 | 82 | def work(i): # a single work 83 | for _ in range(3): # try three times 84 | try: 85 | img = imread(os.path.join(args.img_dir, img_names[i])) 86 | img_crop, tformed_landmarks = align_crop(img, 87 | landmarks[i], 88 | standard_landmark, 89 | crop_size=(args.crop_size_h, args.crop_size_w), 90 | face_factor=args.face_factor, 91 | align_type=args.align_type, 92 | order=args.order, 93 | mode=args.mode) 94 | 95 | name = os.path.splitext(img_names[i])[0] + '.' 
+ args.save_format 96 | path = os.path.join(data_dir, name) 97 | if not os.path.isdir(os.path.split(path)[0]): 98 | os.makedirs(os.path.split(path)[0]) 99 | imwrite(path, img_crop) 100 | 101 | tformed_landmarks.shape = -1 102 | name_landmark_str = ('%s' + ' %.1f' * n_landmark * 2) % ((name, ) + tuple(tformed_landmarks)) 103 | succeed = True 104 | break 105 | except: 106 | succeed = False 107 | if succeed: 108 | return name_landmark_str 109 | else: 110 | print('%s fails!' % img_names[i]) 111 | 112 | 113 | if __name__ == '__main__': 114 | pool = Pool(args.n_worker) 115 | name_landmark_strs = list(tqdm.tqdm(pool.imap(work, range(len(img_names))), total=len(img_names))) 116 | pool.close() 117 | pool.join() 118 | 119 | landmarks_path = os.path.join(save_dir, 'landmark.txt') 120 | with open(landmarks_path, 'w') as f: 121 | for name_landmark_str in name_landmark_strs: 122 | if name_landmark_str: 123 | f.write(name_landmark_str + '\n') 124 | -------------------------------------------------------------------------------- /scripts/cropper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def align_crop_opencv(img, 5 | src_landmarks, 6 | standard_landmarks, 7 | crop_size=512, 8 | face_factor=0.7, 9 | align_type='similarity', 10 | order=3, 11 | mode='edge'): 12 | """Align and crop a face image by landmarks. 13 | 14 | Arguments: 15 | img : Face image to be aligned and cropped. 16 | src_landmarks : [[x_1, y_1], ..., [x_n, y_n]]. 17 | standard_landmarks : Standard shape, should be normalized. 18 | crop_size : Output image size, should be 1. int for (crop_size, crop_size) 19 | or 2. (int, int) for (crop_size_h, crop_size_w). 20 | face_factor : The factor of face area relative to the output image. 21 | align_type : 'similarity' or 'affine'. 22 | order : The order of interpolation. 
The order has to be in the range 0-5: 23 | - 0: INTER_NEAREST 24 | - 1: INTER_LINEAR 25 | - 2: INTER_AREA 26 | - 3: INTER_CUBIC 27 | - 4: INTER_LANCZOS4 28 | - 5: INTER_LANCZOS4 29 | mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap']. 30 | Points outside the boundaries of the input are filled according 31 | to the given mode. 32 | """ 33 | # set OpenCV 34 | import cv2 35 | inter = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA, 36 | 3: cv2.INTER_CUBIC, 4: cv2.INTER_LANCZOS4, 5: cv2.INTER_LANCZOS4} 37 | border = {'constant': cv2.BORDER_CONSTANT, 'edge': cv2.BORDER_REPLICATE, 38 | 'symmetric': cv2.BORDER_REFLECT, 'reflect': cv2.BORDER_REFLECT101, 39 | 'wrap': cv2.BORDER_WRAP} 40 | 41 | # check 42 | assert align_type in ['affine', 'similarity'], 'Invalid `align_type`! Allowed: %s!' % ['affine', 'similarity'] 43 | assert order in [0, 1, 2, 3, 4, 5], 'Invalid `order`! Allowed: %s!' % [0, 1, 2, 3, 4, 5] 44 | assert mode in ['constant', 'edge', 'symmetric', 'reflect', 'wrap'], 'Invalid `mode`! Allowed: %s!' % ['constant', 'edge', 'symmetric', 'reflect', 'wrap'] 45 | 46 | # crop size 47 | if isinstance(crop_size, (list, tuple)) and len(crop_size) == 2: 48 | crop_size_h = crop_size[0] 49 | crop_size_w = crop_size[1] 50 | elif isinstance(crop_size, int): 51 | crop_size_h = crop_size_w = crop_size 52 | else: 53 | raise Exception('Invalid `crop_size`! `crop_size` should be 1. int for (crop_size, crop_size) or 2. 
(int, int) for (crop_size_h, crop_size_w)!') 54 | 55 | # estimate transform matrix 56 | trg_landmarks = standard_landmarks * max(crop_size_h, crop_size_w) * face_factor + np.array([crop_size_w // 2, crop_size_h // 2]) 57 | if align_type == 'affine': 58 | tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0] 59 | else: 60 | tform = cv2.estimateAffinePartial2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0] 61 | 62 | # warp image by given transform 63 | output_shape = (crop_size_h, crop_size_w) 64 | img_crop = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + inter[order], borderMode=border[mode]) 65 | 66 | # get transformed landmarks 67 | tformed_landmarks = cv2.transform(np.expand_dims(src_landmarks, axis=0), cv2.invertAffineTransform(tform))[0] 68 | 69 | return img_crop, tformed_landmarks 70 | 71 | 72 | def align_crop_skimage(img, 73 | src_landmarks, 74 | standard_landmarks, 75 | crop_size=512, 76 | face_factor=0.7, 77 | align_type='similarity', 78 | order=3, 79 | mode='edge'): 80 | """Align and crop a face image by landmarks. 81 | 82 | Arguments: 83 | img : Face image to be aligned and cropped. 84 | src_landmarks : [[x_1, y_1], ..., [x_n, y_n]]. 85 | standard_landmarks : Standard shape, should be normalized. 86 | crop_size : Output image size, should be 1. int for (crop_size, crop_size) 87 | or 2. (int, int) for (crop_size_h, crop_size_w). 88 | face_factor : The factor of face area relative to the output image. 89 | align_type : 'similarity' or 'affine'. 90 | order : The order of interpolation. The order has to be in the range 0-5: 91 | - 0: INTER_NEAREST 92 | - 1: INTER_LINEAR 93 | - 2: INTER_AREA 94 | - 3: INTER_CUBIC 95 | - 4: INTER_LANCZOS4 96 | - 5: INTER_LANCZOS4 97 | mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap']. 98 | Points outside the boundaries of the input are filled according 99 | to the given mode. 
100 | """ 101 | raise NotImplementedError("'align_crop_skimage' is not implemented!") 102 | -------------------------------------------------------------------------------- /scripts/split_CelebA-HQ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | # ============================================================================== 5 | # = param = 6 | # ============================================================================== 7 | 8 | label_file = './data/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt' 9 | save_dir = './data/CelebAMask-HQ' 10 | 11 | 12 | # ============================================================================== 13 | # = run = 14 | # ============================================================================== 15 | 16 | with open(label_file, 'r') as f: 17 | lines = f.readlines()[2:] 18 | 19 | random.seed(100) 20 | random.shuffle(lines) 21 | 22 | lines_train = lines[:26500] 23 | lines_val = lines[26500:27000] 24 | lines_test = lines[27000:] 25 | 26 | with open(os.path.join(save_dir, 'train_label.txt'), 'w') as f: 27 | f.writelines(lines_train) 28 | 29 | with open(os.path.join(save_dir, 'val_label.txt'), 'w') as f: 30 | f.writelines(lines_val) 31 | 32 | with open(os.path.join(save_dir, 'test_label.txt'), 'w') as f: 33 | f.writelines(lines_test) 34 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import imlib as im 4 | import numpy as np 5 | import pylib as py 6 | import tensorflow as tf 7 | import tflib as tl 8 | import tqdm 9 | 10 | import data 11 | import module 12 | 13 | 14 | # ============================================================================== 15 | # = param = 16 | # ============================================================================== 17 | 18 | py.arg('--img_dir', 
default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data') 19 | py.arg('--test_label_path', default='./data/img_celeba/test_label.txt') 20 | py.arg('--test_int', type=float, default=2) 21 | 22 | 23 | py.arg('--experiment_name', default='default') 24 | args_ = py.args() 25 | 26 | # output_dir 27 | output_dir = py.join('output', args_.experiment_name) 28 | 29 | # save settings 30 | args = py.args_from_yaml(py.join(output_dir, 'settings.yml')) 31 | args.__dict__.update(args_.__dict__) 32 | 33 | # others 34 | n_atts = len(args.att_names) 35 | 36 | sess = tl.session() 37 | sess.__enter__() # make default 38 | 39 | 40 | # ============================================================================== 41 | # = data and model = 42 | # ============================================================================== 43 | 44 | # data 45 | test_dataset, len_test_dataset = data.make_celeba_dataset(args.img_dir, args.test_label_path, args.att_names, args.n_samples, 46 | load_size=args.load_size, crop_size=args.crop_size, 47 | training=False, drop_remainder=False, shuffle=False, repeat=None) 48 | test_iter = test_dataset.make_one_shot_iterator() 49 | 50 | 51 | # ============================================================================== 52 | # = graph = 53 | # ============================================================================== 54 | 55 | def sample_graph(): 56 | # ====================================== 57 | # = graph = 58 | # ====================================== 59 | 60 | test_next = test_iter.get_next() 61 | 62 | if not os.path.exists(py.join(output_dir, 'generator.pb')): 63 | # model 64 | Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay) 65 | 66 | # placeholders & inputs 67 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3]) 68 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts]) 69 | 70 | # sample graph 71 | x = Gdec(Genc(xa, training=False), b_, 
training=False) 72 | else: 73 | # load freezed model 74 | with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'rb') as f: 75 | graph_def = tf.GraphDef() 76 | graph_def.ParseFromString(f.read()) 77 | tf.import_graph_def(graph_def, name='generator') 78 | 79 | # placeholders & inputs 80 | xa = sess.graph.get_tensor_by_name('generator/xa:0') 81 | b_ = sess.graph.get_tensor_by_name('generator/b_:0') 82 | 83 | # sample graph 84 | x = sess.graph.get_tensor_by_name('generator/xb:0') 85 | 86 | # ====================================== 87 | # = run function = 88 | # ====================================== 89 | 90 | save_dir = './output/%s/samples_testing_%s' % (args.experiment_name, '{:g}'.format(args.test_int)) 91 | py.mkdir(save_dir) 92 | 93 | def run(): 94 | cnt = 0 95 | for _ in tqdm.trange(len_test_dataset): 96 | # data for sampling 97 | xa_ipt, a_ipt = sess.run(test_next) 98 | b_ipt_list = [a_ipt] # the first is for reconstruction 99 | for i in range(n_atts): 100 | tmp = np.array(a_ipt, copy=True) 101 | tmp[:, i] = 1 - tmp[:, i] # inverse attribute 102 | tmp = data.check_attribute_conflict(tmp, args.att_names[i], args.att_names) 103 | b_ipt_list.append(tmp) 104 | 105 | x_opt_list = [xa_ipt] 106 | for i, b_ipt in enumerate(b_ipt_list): 107 | b__ipt = (b_ipt * 2 - 1).astype(np.float32) # !!! 
108 | if i > 0: # i == 0 is for reconstruction 109 | b__ipt[..., i - 1] = b__ipt[..., i - 1] * args.test_int 110 | x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt}) 111 | x_opt_list.append(x_opt) 112 | sample = np.transpose(x_opt_list, (1, 2, 0, 3, 4)) 113 | sample = np.reshape(sample, (sample.shape[0], -1, sample.shape[2] * sample.shape[3], sample.shape[4])) 114 | 115 | for s in sample: 116 | cnt += 1 117 | im.imwrite(s, '%s/%d.jpg' % (save_dir, cnt)) 118 | 119 | return run 120 | 121 | 122 | sample = sample_graph() 123 | 124 | 125 | # ============================================================================== 126 | # = test = 127 | # ============================================================================== 128 | 129 | # checkpoint 130 | if not os.path.exists(py.join(output_dir, 'generator.pb')): 131 | checkpoint = tl.Checkpoint( 132 | {v.name: v for v in tf.global_variables()}, 133 | py.join(output_dir, 'checkpoints'), 134 | max_to_keep=1 135 | ) 136 | checkpoint.restore().run_restore_ops() 137 | 138 | sample() 139 | 140 | sess.close() 141 | -------------------------------------------------------------------------------- /test_multi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import imlib as im 4 | import numpy as np 5 | import pylib as py 6 | import tensorflow as tf 7 | import tflib as tl 8 | import tqdm 9 | 10 | import data 11 | import module 12 | 13 | 14 | # ============================================================================== 15 | # = param = 16 | # ============================================================================== 17 | 18 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data') 19 | py.arg('--test_label_path', default='./data/img_celeba/test_label.txt') 20 | py.arg('--test_att_names', choices=data.ATT_ID.keys(), nargs='+', default=['Bangs', 'Mustache']) 21 | py.arg('--test_ints', type=float, 
def sample_graph():
    """Build the multi-attribute sampling graph and return a function that runs it.

    The generator comes from ``<output_dir>/generator.pb`` when that frozen
    graph exists; otherwise the encoder/decoder are built with
    ``module.get_model`` (the script restores their weights from the
    checkpoint after this function returns).

    Returns
    -------
    run : callable
        Zero-argument function that iterates over the test dataset, flips
        every attribute in ``args.test_att_names``, scales it by the matching
        entry of ``args.test_ints`` and writes one JPEG per input image
        (input and edited result side by side) into the sample directory.

    """
    # ======================================
    # =               graph                =
    # ======================================

    test_next = test_iter.get_next()
    frozen_path = py.join(output_dir, 'generator.pb')

    if os.path.exists(frozen_path):
        # load freezed model
        with tf.gfile.GFile(frozen_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='generator')

        # placeholders & inputs
        xa = sess.graph.get_tensor_by_name('generator/xa:0')
        b_ = sess.graph.get_tensor_by_name('generator/b_:0')

        # sample graph
        x = sess.graph.get_tensor_by_name('generator/xb:0')
    else:
        # model
        Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)

        # placeholders & inputs
        xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b_ = tf.placeholder(tf.float32, shape=[None, n_atts])

        # sample graph
        x = Gdec(Genc(xa, training=False), b_, training=False)

    # ======================================
    # =            run function            =
    # ======================================

    # directory name encodes every (attribute, intensity) pair, e.g. "Bangs_1_Male_2"
    suffix = '_'.join('%s_%s' % (att_name, '{:g}'.format(intensity))
                      for att_name, intensity in zip(args.test_att_names, args.test_ints))
    save_dir = py.join('./output/%s/samples_testing_multi' % args.experiment_name, suffix)
    py.mkdir(save_dir)

    def run():
        n_saved = 0
        for _ in tqdm.trange(len_test_dataset):
            # data for sampling
            xa_ipt, a_ipt = sess.run(test_next)

            # flip each requested attribute, resolving conflicts after every flip
            b_ipt = np.copy(a_ipt)
            for att_name in args.test_att_names:
                idx = args.att_names.index(att_name)
                b_ipt[..., idx] = 1 - b_ipt[..., idx]
                b_ipt = data.check_attribute_conflict(b_ipt, att_name, args.att_names)

            # map {0, 1} labels to {-1, 1}, then scale the edited attributes
            b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # !!!
            for att_name, intensity in zip(args.test_att_names, args.test_ints):
                idx = args.att_names.index(att_name)
                b__ipt[..., idx] *= intensity

            x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt})

            # stack (input, output) and lay them out side by side along the width
            sample = np.transpose([xa_ipt, x_opt], (1, 2, 0, 3, 4))
            n_img, _, n_col, width, n_ch = sample.shape
            sample = np.reshape(sample, (n_img, -1, n_col * width, n_ch))

            for s in sample:
                n_saved += 1
                im.imwrite(s, '%s/%d.jpg' % (save_dir, n_saved))

    return run
def sample_graph():
    """Build the attribute-sliding sampling graph and return a function that runs it.

    The generator comes from ``<output_dir>/generator.pb`` when that frozen
    graph exists; otherwise the encoder/decoder are built with
    ``module.get_model`` (the script restores their weights from the
    checkpoint after this function returns).

    Returns
    -------
    run : callable
        Zero-argument function that, for every test batch, sweeps the value of
        ``args.test_att_name`` from ``args.test_int_min`` to
        ``args.test_int_max`` in steps of ``args.test_int_step`` and writes one
        JPEG per input image (input followed by every swept result, side by
        side) into the sample directory.

    """
    # ======================================
    # =               graph                =
    # ======================================

    test_next = test_iter.get_next()
    frozen_path = py.join(output_dir, 'generator.pb')

    if os.path.exists(frozen_path):
        # load freezed model
        with tf.gfile.GFile(frozen_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='generator')

        # placeholders & inputs
        xa = sess.graph.get_tensor_by_name('generator/xa:0')
        b_ = sess.graph.get_tensor_by_name('generator/b_:0')

        # sample graph
        x = sess.graph.get_tensor_by_name('generator/xb:0')
    else:
        # model
        Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)

        # placeholders & inputs
        xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
        b_ = tf.placeholder(tf.float32, shape=[None, n_atts])

        # sample graph
        x = Gdec(Genc(xa, training=False), b_, training=False)

    # ======================================
    # =            run function            =
    # ======================================

    # directory name encodes the attribute and the (min, max, step) of the sweep
    sweep_tag = '_'.join('{:g}'.format(v)
                         for v in (args.test_int_min, args.test_int_max, args.test_int_step))
    save_dir = './output/%s/samples_testing_slide/%s_%s' % (args.experiment_name, args.test_att_name, sweep_tag)
    py.mkdir(save_dir)

    def run():
        n_saved = 0
        for _ in tqdm.trange(len_test_dataset):
            # data for sampling
            xa_ipt, a_ipt = sess.run(test_next)
            # map {0, 1} labels to {-1, 1}
            b__ipt = (a_ipt * 2 - 1).astype(np.float32)  # !!!

            # first column is the input, then one column per swept intensity
            columns = [xa_ipt]
            # the small epsilon makes the upper bound inclusive
            for intensity in np.arange(args.test_int_min, args.test_int_max + 1e-5, args.test_int_step):
                b__ipt[:, args.att_names.index(args.test_att_name)] = intensity
                columns.append(sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt}))

            # lay the columns out side by side along the width
            sample = np.transpose(columns, (1, 2, 0, 3, 4))
            n_img, _, n_col, width, n_ch = sample.shape
            sample = np.reshape(sample, (n_img, -1, n_col * width, n_ch))

            for s in sample:
                n_saved += 1
                im.imwrite(s, '%s/%d.jpg' % (save_dir, n_saved))

    return run
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Shuffle, filter/map, batch, repeat and prefetch `dataset`, in that order.

    Parameters
    ----------
    dataset : object exposing `shuffle`/`filter`/`map`/`batch`/`repeat`/`prefetch`
    batch_size : int
    drop_remainder : passed to `dataset.batch`
    n_prefetch_batch : number of batches to prefetch
    filter_fn, map_fn : optional callables; a falsy value skips the step
    n_map_threads : parallel calls for `map`; defaults to the CPU count
    filter_after_map : run `map` before `filter` (slower) instead of after
    shuffle : whether to shuffle before any other step
    shuffle_buffer_size : defaults to `max(batch_size * 128, 2048)`
    repeat : passed to `dataset.repeat`

    """
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()

    # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
    if shuffle:
        if shuffle_buffer_size is None:
            shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048
        dataset = dataset.shuffle(shuffle_buffer_size)

    def apply_filter(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def apply_map(ds):
        return ds.map(map_fn, num_parallel_calls=n_map_threads) if map_fn else ds

    # [*] mapping first is slower because `map_fn` also runs on elements that get filtered out
    stages = (apply_map, apply_filter) if filter_after_map else (apply_filter, apply_map)
    for stage in stages:
        dataset = stage(dataset)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    return dataset.repeat(repeat).prefetch(n_prefetch_batch)


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of memory data.

    Slices `memory_data` with `tf.data.Dataset.from_tensor_slices` and hands
    the result to `batch_dataset` together with all remaining options.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists

    """
    options = dict(drop_remainder=drop_remainder,
                   n_prefetch_batch=n_prefetch_batch,
                   filter_fn=filter_fn,
                   map_fn=map_fn,
                   n_map_threads=n_map_threads,
                   filter_after_map=filter_after_map,
                   shuffle=shuffle,
                   shuffle_buffer_size=shuffle_buffer_size,
                   repeat=repeat)
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(sliced, batch_size, **options)


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of disk image for PNG and JPEG.

    Each path is read and decoded to a 3-channel image before the optional
    user `map_fn` is applied; labels, when given, are passed through next to
    the image.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists

    """
    memory_data = img_paths if labels is None else (img_paths, labels)

    def parse_fn(path, *label):
        raw = tf.io.read_file(path)
        decoded = tf.image.decode_png(raw, 3)  # fix channels to 3
        return (decoded,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def fused_fn(*args):
            return map_fn(*parse_fn(*args))
    else:
        fused_fn = parse_fn

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=fused_fn,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)
# ---- tflib/image/filter.py ----

def gaussian_kernel2d(kernel_radias, std):
    """Return a normalized 2-D Gaussian kernel of side `2 * kernel_radias + 1`.

    The kernel is the outer product of a zero-mean normal density sampled at
    the integers in `[-kernel_radias, kernel_radias]`, divided by its sum so
    that it adds up to 1.
    """
    d = tf.distributions.Normal(0.0, float(std))
    vals = d.prob(tf.range(start=-kernel_radias, limit=kernel_radias + 1, dtype=tf.float32))
    kernel = vals[:, None] * vals[None, :]
    kernel /= tf.reduce_sum(kernel)
    return kernel


def filter2d(image, kernel, padding, data_format=None):
    """Apply the same 2-D `kernel` to every channel of `image`.

    Parameters
    ----------
    image : 4-D tensor laid out according to `data_format`
    kernel : 2-D tensor, tiled over the channel axis and used as a depthwise filter
    padding : passed to `tf.nn.depthwise_conv2d`
    data_format : None / "NHWC" (channels last) or "NCHW" (channels second)

    Raises
    ------
    ValueError : if `data_format` is not one of the values above.

    """
    if data_format is None or data_format == "NHWC":
        n_channels = image.shape[3]
    elif data_format == "NCHW":
        n_channels = image.shape[1]
    else:
        # fail here with a clear message instead of convolving with an untiled kernel
        raise ValueError('Invalid data_format: %r' % (data_format,))

    # [kh, kw] -> [kh, kw, channels, 1]: one copy of the kernel per input channel
    kernel = tf.tile(kernel[:, :, None, None], [1, 1, n_channels, 1])
    return tf.nn.depthwise_conv2d(image, kernel, strides=[1, 1, 1, 1], padding=padding, data_format=data_format)


def gaussian_filter2d(image, kernel_radias, std, padding, data_format=None):
    """Blur `image` with a Gaussian kernel built by `gaussian_kernel2d`.

    `data_format` is forwarded to `filter2d`, so "NCHW" inputs are filtered
    along the correct axes.
    """
    kernel = gaussian_kernel2d(kernel_radias, std)
    return filter2d(image, kernel, padding, data_format=data_format)


# ---- tflib/image/image.py ----

def center_crop(image, size):
    """Crop the central `size` window out of `image`.

    `image` has shape [batch, height, width, channels] or
    [height, width, channels]; `size` is an int (square crop) or a
    (height, width) pair.
    """
    if not isinstance(size, (tuple, list)):
        size = [size, size]
    offset_height = (tf.shape(image)[-3] - size[0]) // 2
    offset_width = (tf.shape(image)[-2] - size[1]) // 2
    return tf.image.crop_to_bounding_box(image, offset_height, offset_width, size[0], size[1])
# ---- tflib/image/image.py ----

def color_jitter(image, brightness=0, contrast=0, saturation=0, hue=0):
    """Color jitter.

    Applies, in random order, one random adjustment for every argument that
    is greater than 0; with all arguments at 0 the image is returned as is.

    Examples
    --------
    >>> color_jitter(img, 25, 0.2, 0.2, 0.1)

    """
    # (strength, factory) pairs; the factories are only invoked for enabled adjustments
    candidates = (
        (brightness, lambda: functools.partial(tf.image.random_brightness, max_delta=brightness)),
        (contrast, lambda: functools.partial(tf.image.random_contrast, lower=max(0, 1 - contrast), upper=1 + contrast)),
        (saturation, lambda: functools.partial(tf.image.random_saturation, lower=max(0, 1 - saturation), upper=1 + saturation)),
        (hue, lambda: functools.partial(tf.image.random_hue, max_delta=hue)),
    )
    tforms = [make() for strength, make in candidates if strength > 0]

    random.shuffle(tforms)
    return functools.reduce(lambda img, tform: tform(img), tforms, image)


def random_grayscale(image, p=0.1):
    """Return `image` with its saturation set to 0 with probability `p`, else unchanged."""
    def to_gray():
        return tf.image.adjust_saturation(image, 0)

    def keep():
        return image

    return tf.cond(pred=tf.random.uniform(()) < p, true_fn=to_gray, false_fn=keep)


def random_rotate(images, max_degrees, interpolation='BILINEAR'):
    """Randomly rotate image(s) counterclockwise.

    The angle(s) are uniformly chosen from [-max_degree(s), max_degree(s)];
    one angle is drawn per entry of `max_degrees`.
    """
    max_degrees = tf.convert_to_tensor(max_degrees, dtype=tf.float32)
    unit = tf.random.uniform(tf.shape(max_degrees), minval=-1.0, maxval=1.0)
    angles = unit * max_degrees / 180.0 * math.pi  # degrees -> radians
    return tf.contrib.image.rotate(images, angles, interpolation=interpolation)


# ---- tflib/layers/layers.py ----

def adaptive_instance_normalization(x, gamma, beta, epsilon=1e-5):
    """Normalize `x` over its spatial axes, then scale by `gamma` and shift by `beta`."""
    # modified from https://github.com/taki0112/MUNIT-Tensorflow/blob/master/ops.py
    # x: (N, H, W, C), gamma: (N, 1, 1, C), beta: (N, 1, 1, C)

    mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
    normalized = (x - mean) / tf.sqrt(variance + epsilon)

    return gamma * normalized + beta
@add_arg_scope
def fully_connected(inputs,
                    num_outputs,
                    activation_fn=nn.relu,
                    normalizer_fn=None,
                    normalizer_params=None,
                    weights_normalizer_fn=None,
                    weights_normalizer_params=None,
                    weights_initializer=initializers.xavier_initializer(),
                    weights_regularizer=None,
                    biases_initializer=init_ops.zeros_initializer(),
                    biases_regularizer=None,
                    reuse=None,
                    variables_collections=None,
                    outputs_collections=None,
                    trainable=True,
                    scope=None):
    # Be copied and modified from tensorflow-0.12.0.contrib.layer.fully_connected,
    # add weights_nomalizer_* options.
    """Adds a fully connected layer.

    `fully_connected` creates a variable called `weights`, representing a fully
    connected weight matrix, which is multiplied by the `inputs` to produce a
    `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
    `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
    None and a `biases_initializer` is provided then a `biases` variable would be
    created and added the hidden units. Finally, if `activation_fn` is not `None`,
    it is applied to the hidden units as well.

    Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened
    prior to the initial matrix multiply by `weights`.

    Args:
      inputs: A tensor of with at least rank 2 and value for the last dimension,
        i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
      num_outputs: Integer or long, the number of output units in the layer.
      activation_fn: activation function, set to None to skip it and maintain
        a linear activation.
      normalizer_fn: normalization function to use instead of `biases`. If
        `normalizer_fn` is provided then `biases_initializer` and
        `biases_regularizer` are ignored and `biases` are not created nor added.
        default set to None for no normalizer function
      normalizer_params: normalization function parameters.
      weights_normalizer_fn: weights normalization function.
      weights_normalizer_params: weights normalization function parameters.
      weights_initializer: An initializer for the weights.
      weights_regularizer: Optional regularizer for the weights.
      biases_initializer: An initializer for the biases. If None skip biases.
      biases_regularizer: Optional regularizer for the biases.
      reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer scope must be given.
      variables_collections: Optional list of collections for all the variables or
        a dictionary containing a different list of collections per variable.
      outputs_collections: collection to add the outputs.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
      scope: Optional scope for variable_scope.

    Returns:
      the tensor variable representing the result of the series of operations.

    Raises:
      ValueError: if x has rank less than 2 or if its last dimension is not set,
        or if `num_outputs` is not an integer.
    """
    if not isinstance(num_outputs, six.integer_types):
        # interpolate the offending value into the message (it used to be passed
        # as a second exception argument and never formatted)
        raise ValueError('num_outputs should be int or long, got %s.' % (num_outputs,))
    with variable_scope.variable_scope(scope, 'fully_connected', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        dtype = inputs.dtype.base_dtype
        inputs_shape = inputs.get_shape()
        num_input_units = utils.last_dimension(inputs_shape, min_rank=2)

        # static output shape: input shape with the last dimension replaced
        static_shape = inputs_shape.as_list()
        static_shape[-1] = num_outputs

        # dynamic output shape, used to restore the leading dimensions after the matmul;
        # `unstack`/`stack` replace `unpack`/`pack`, which were removed in TensorFlow 1.0
        out_shape = array_ops.unstack(array_ops.shape(inputs), len(static_shape))
        out_shape[-1] = num_outputs

        weights_shape = [num_input_units, num_outputs]
        weights_collections = utils.get_variable_collections(
            variables_collections, 'weights')
        weights = variables.model_variable('weights',
                                           shape=weights_shape,
                                           dtype=dtype,
                                           initializer=weights_initializer,
                                           regularizer=weights_regularizer,
                                           collections=weights_collections,
                                           trainable=trainable)
        if weights_normalizer_fn is not None:
            weights_normalizer_params = weights_normalizer_params or {}
            weights = weights_normalizer_fn(weights, **weights_normalizer_params)
        if len(static_shape) > 2:
            # Reshape inputs
            inputs = array_ops.reshape(inputs, [-1, num_input_units])
        outputs = standard_ops.matmul(inputs, weights)
        if normalizer_fn is not None:
            normalizer_params = normalizer_params or {}
            outputs = normalizer_fn(outputs, **normalizer_params)
        else:
            if biases_initializer is not None:
                biases_collections = utils.get_variable_collections(
                    variables_collections, 'biases')
                biases = variables.model_variable('biases',
                                                  shape=[num_outputs, ],
                                                  dtype=dtype,
                                                  initializer=biases_initializer,
                                                  regularizer=biases_regularizer,
                                                  collections=biases_collections,
                                                  trainable=trainable)
                outputs = nn.bias_add(outputs, biases)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        if len(static_shape) > 2:
            # Reshape back outputs
            outputs = array_ops.reshape(outputs, array_ops.stack(out_shape))
            outputs.set_shape(static_shape)
        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)


@add_arg_scope
def convolution(inputs,
                num_outputs,
                kernel_size,
                stride=1,
                padding='SAME',
                data_format=None,
                rate=1,
                activation_fn=nn.relu,
                normalizer_fn=None,
                normalizer_params=None,
                weights_normalizer_fn=None,
                weights_normalizer_params=None,
                weights_initializer=initializers.xavier_initializer(),
                weights_regularizer=None,
                biases_initializer=init_ops.zeros_initializer(),
                biases_regularizer=None,
                reuse=None,
                variables_collections=None,
                outputs_collections=None,
                trainable=True,
                scope=None):
    # Be copied and modified from tensorflow-0.12.0.contrib.layer.convolution,
    # add weights_nomalizer_* options.
    """Adds an N-D convolution followed by an optional batch_norm layer.

    It is required that 1 <= N <= 3.

    `convolution` creates a variable called `weights`, representing the
    convolutional kernel, that is convolved (actually cross-correlated) with the
    `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
    provided (such as `batch_norm`), it is then applied. Otherwise, if
    `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
    variable would be created and added the activations. Finally, if
    `activation_fn` is not `None`, it is applied to the activations as well.

    Performs a'trous convolution with input stride/dilation rate equal to `rate`
    if a value > 1 for any dimension of `rate` is specified.  In this case
    `stride` values != 1 are not supported.

    Args:
      inputs: a Tensor of rank N+2 of shape
        `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
        not start with "NC" (default), or
        `[batch_size, in_channels] + input_spatial_shape` if data_format starts
        with "NC".
      num_outputs: integer, the number of output filters.
      kernel_size: a sequence of N positive integers specifying the spatial
        dimensions of the filters.  Can be a single integer to specify the same
        value for all spatial dimensions.
      stride: a sequence of N positive integers specifying the stride at which to
        compute output.  Can be a single integer to specify the same value for all
        spatial dimensions.  Specifying any `stride` value != 1 is incompatible
        with specifying any `rate` value != 1.
      padding: one of `"VALID"` or `"SAME"`.
      data_format: A string or None.  Specifies whether the channel dimension of
        the `input` and output is the last dimension (default, or if `data_format`
        does not start with "NC"), or the second dimension (if `data_format`
        starts with "NC").  For N=1, the valid values are "NWC" (default) and
        "NCW".  For N=2, the valid values are "NHWC" (default) and "NCHW".  For
        N=3, currently the only valid value is "NDHWC".
      rate: a sequence of N positive integers specifying the dilation rate to use
        for a'trous convolution.  Can be a single integer to specify the same
        value for all spatial dimensions.  Specifying any `rate` value != 1 is
        incompatible with specifying any `stride` value != 1.
      activation_fn: activation function, set to None to skip it and maintain
        a linear activation.
      normalizer_fn: normalization function to use instead of `biases`. If
        `normalizer_fn` is provided then `biases_initializer` and
        `biases_regularizer` are ignored and `biases` are not created nor added.
        default set to None for no normalizer function
      normalizer_params: normalization function parameters.
      weights_normalizer_fn: weights normalization function.
      weights_normalizer_params: weights normalization function parameters.
      weights_initializer: An initializer for the weights.
      weights_regularizer: Optional regularizer for the weights.
      biases_initializer: An initializer for the biases. If None skip biases.
      biases_regularizer: Optional regularizer for the biases.
      reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer scope must be given.
      variables_collections: optional list of collections for all the variables or
        a dictionary containing a different list of collection per variable.
      outputs_collections: collection to add the outputs.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
      scope: Optional scope for `variable_scope`.

    Returns:
      a tensor representing the output of the operation.

    Raises:
      ValueError: if `data_format` is invalid.
      ValueError: both 'rate' and `stride` are not uniformly 1.
    """
    if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC']:
        raise ValueError('Invalid data_format: %r' % (data_format,))
    with variable_scope.variable_scope(scope, 'Conv', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        dtype = inputs.dtype.base_dtype
        input_rank = inputs.get_shape().ndims
        if input_rank is None:
            raise ValueError('Rank of inputs must be known')
        if input_rank < 3 or input_rank > 5:
            raise ValueError('Rank of inputs is %d, which is not >= 3 and <= 5' %
                             input_rank)
        # rank N+2 input -> N spatial dimensions
        conv_dims = input_rank - 2
        kernel_size = utils.n_positive_integers(conv_dims, kernel_size)
        stride = utils.n_positive_integers(conv_dims, stride)
        rate = utils.n_positive_integers(conv_dims, rate)

        # locate the channel axis: last for channels-last formats, second for "NC*"
        if data_format is None or data_format.endswith('C'):
            num_input_channels = inputs.get_shape()[input_rank - 1].value
        elif data_format.startswith('NC'):
            num_input_channels = inputs.get_shape()[1].value
        else:
            raise ValueError('Invalid data_format')

        if num_input_channels is None:
            raise ValueError('Number of in_channels must be known.')

        weights_shape = (
            list(kernel_size) + [num_input_channels, num_outputs])
        weights_collections = utils.get_variable_collections(variables_collections,
                                                             'weights')
        weights = variables.model_variable('weights',
                                           shape=weights_shape,
                                           dtype=dtype,
                                           initializer=weights_initializer,
                                           regularizer=weights_regularizer,
                                           collections=weights_collections,
                                           trainable=trainable)
        if weights_normalizer_fn is not None:
            weights_normalizer_params = weights_normalizer_params or {}
            weights = weights_normalizer_fn(weights, **weights_normalizer_params)
        outputs = nn.convolution(input=inputs,
                                 filter=weights,
                                 dilation_rate=rate,
                                 strides=stride,
                                 padding=padding,
                                 data_format=data_format)
        if normalizer_fn is not None:
            normalizer_params = normalizer_params or {}
            outputs = normalizer_fn(outputs, **normalizer_params)
        else:
            if biases_initializer is not None:
                biases_collections = utils.get_variable_collections(
                    variables_collections, 'biases')
                biases = variables.model_variable('biases',
                                                  shape=[num_outputs],
                                                  dtype=dtype,
                                                  initializer=biases_initializer,
                                                  regularizer=biases_regularizer,
                                                  collections=biases_collections,
                                                  trainable=trainable)
                outputs = nn.bias_add(outputs, biases, data_format=data_format)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)


convolution2d = convolution
convolution3d = convolution


@add_arg_scope
def spectral_normalization(weights,
                           num_iterations=1,
                           epsilon=1e-12,
                           u_initializer=tf.random_normal_initializer(),
                           updates_collections=tf.GraphKeys.UPDATE_OPS,
                           is_training=True,
                           reuse=None,
                           variables_collections=None,
                           outputs_collections=None,
                           scope=None):
    """Divides `weights` by a power-iteration estimate of their spectral norm.

    `weights` is viewed as a matrix whose rows are indexed by its last
    dimension. A non-trainable vector `u` and scalar `sigma` are kept as
    variables: every call runs `num_iterations` power-iteration steps starting
    from `u`, computes `sigma_ = u^T W v` and creates ops that store the new
    `u` and `sigma_`.

    Args:
      weights: tensor to normalize; its last dimension must be statically known.
      num_iterations: number of power-iteration steps per call.
      epsilon: passed to `tf.nn.l2_normalize`.
      u_initializer: initializer for the persistent vector `u`.
      updates_collections: collections to which the `u`/`sigma` update ops are
        added. If None, the updates are forced to run (via control
        dependencies) whenever the training branch is evaluated.
      is_training: when true the freshly estimated `sigma_` is used, otherwise
        the stored `sigma` variable.
      reuse: whether or not the variables should be reused.
      variables_collections: optional collections for the `u` and `sigma`
        variables.
      outputs_collections: collection to add the output.
      scope: Optional scope for `variable_scope`.

    Returns:
      `weights` divided by the selected sigma.
    """
    with tf.variable_scope(scope, 'SpectralNorm', [weights], reuse=reuse) as sc:
        weights = tf.convert_to_tensor(weights)

        dtype = weights.dtype.base_dtype

        # w_t: [n, m] with m = size of the last dimension of `weights`; w: [m, n]
        w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]])
        w = tf.transpose(w_t)
        m, n = w.shape.as_list()

        u_collections = utils.get_variable_collections(variables_collections, 'u')
        u = tf.get_variable("u",
                            shape=[m, 1],
                            dtype=dtype,
                            initializer=u_initializer,
                            trainable=False,
                            collections=u_collections,)
        sigma_collections = utils.get_variable_collections(variables_collections, 'sigma')
        sigma = tf.get_variable('sigma',
                                shape=[],
                                dtype=dtype,
                                initializer=tf.zeros_initializer(),
                                trainable=False,
                                collections=sigma_collections)

        def _power_iteration(i, u, v):
            v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon)
            u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon)
            return i + 1, u_, v_

        # the initial `v` must have the weights' dtype: `tf.while_loop` requires
        # loop variables to keep their dtype across iterations
        _, u_, v_ = tf.while_loop(
            cond=lambda i, _1, _2: i < num_iterations,
            body=_power_iteration,
            loop_vars=[tf.constant(0), u, tf.zeros(shape=[n, 1], dtype=dtype)]
        )
        # gradients flow through `w` only, not through the iteration vectors
        u_ = tf.stop_gradient(u_)
        v_ = tf.stop_gradient(v_)
        sigma_est = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0]

        update_u = u.assign(u_)
        update_sigma = sigma.assign(sigma_est)
        if updates_collections is None:
            def _force_update():
                with tf.control_dependencies([update_u, update_sigma]):
                    return tf.identity(sigma_est)

            sigma_used = utils.smart_cond(is_training, _force_update, lambda: sigma)
            weights_sn = weights / sigma_used
        else:
            sigma_used = utils.smart_cond(is_training, lambda: sigma_est, lambda: sigma)
            weights_sn = weights / sigma_used
            tf.add_to_collections(updates_collections, update_u)
            tf.add_to_collections(updates_collections, update_sigma)

        return utils.collect_named_outputs(outputs_collections, sc.name, weights_sn)


# Simple alias.
375 | conv2d = convolution2d 376 | conv3d = convolution3d 377 | -------------------------------------------------------------------------------- /tflib/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from tflib.losses.losses import * 2 | -------------------------------------------------------------------------------- /tflib/losses/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def center_loss(features, labels, num_classes, alpha=0.5, updates_collections=tf.GraphKeys.UPDATE_OPS, scope=None): 5 | # modified from https://github.com/EncodeTS/TensorFlow_Center_Loss/blob/master/center_loss.py 6 | 7 | assert features.shape.ndims == 2, 'The rank of `features` should be 2!' 8 | assert 0 <= alpha <= 1, '`alpha` should be in [0, 1]!' 9 | 10 | with tf.variable_scope(scope, 'center_loss', [features, labels]): 11 | centers = tf.get_variable('centers', shape=[num_classes, features.get_shape()[-1]], dtype=tf.float32, 12 | initializer=tf.constant_initializer(0), trainable=False) 13 | 14 | centers_batch = tf.gather(centers, labels) 15 | diff = centers_batch - features 16 | _, unique_idx, unique_count = tf.unique_with_counts(labels) 17 | appear_times = tf.gather(unique_count, unique_idx) 18 | appear_times = tf.reshape(appear_times, [-1, 1]) 19 | diff = diff / tf.cast((1 + appear_times), tf.float32) 20 | diff = alpha * diff 21 | update_centers = tf.scatter_sub(centers, labels, diff) 22 | 23 | center_loss = 0.5 * tf.reduce_mean(tf.reduce_sum((centers_batch - features)**2, axis=-1)) 24 | 25 | if updates_collections is None: 26 | with tf.control_dependencies([update_centers]): 27 | center_loss = tf.identity(center_loss) 28 | else: 29 | tf.add_to_collections(updates_collections, update_centers) 30 | 31 | return center_loss, centers 32 | 33 | 34 | def sigmoid_focal_loss(multi_class_labels, logits, gamma=2.0): 35 | epsilon = 1e-8 36 | 
def make_resettable(metric_fn, scope=None):
    """Return a wrapper of `metric_fn` whose outputs are followed by a reset op.

    The wrapper calls `metric_fn` inside a variable scope (default name
    'resettable_metric'), then appends to the returned tuple an initializer
    op for the local variables collected under that scope's name.

    Parameters
    ----------
    metric_fn:
        Callable returning a tuple (the reset op is concatenated to it).
    scope:
        Optional variable scope passed to `tf.variable_scope`.

    """
    def resettable_metric_fn(*args, **kwargs):
        with tf.variable_scope(scope, 'resettable_metric') as metric_scope:
            outputs = metric_fn(*args, **kwargs)
            scoped_locals = tf.local_variables(metric_scope.name)
            reset_op = tf.variables_initializer(scoped_locals)
        return outputs + (reset_op,)

    return resettable_metric_fn
def tensors_filter(tensors,
                   includes='',
                   includes_combine_type='or',
                   excludes=None,
                   excludes_combine_type='or'):
    """Select the items of `tensors` whose `.name` matches the given substrings.

    NOTICE: `includes` = [] means nothing to be included, and
    `excludes` = [] means nothing to be excluded.

    Parameters
    ----------
    tensors:
        List or tuple of objects exposing a `name` attribute.
    includes:
        Substring, or list/tuple of substrings, a name must contain to be
        kept. The default '' is contained in every name, so it keeps all.
    includes_combine_type:
        'or' keeps an item matching any substring, 'and' requires all.
    excludes:
        Substring, or list/tuple of substrings, that removes an item from
        the result. `None` is treated as [].
    excludes_combine_type:
        'or' / 'and', same meaning as for `includes`.

    Returns
    -------
    List of the included items that are not excluded, in input order.

    """
    if excludes is None:
        excludes = []

    assert isinstance(tensors, (list, tuple)), '`tensors` shoule be a list or tuple!'
    assert isinstance(includes, (str, list, tuple)), '`includes` should be a string or a list(tuple) of strings!'
    assert includes_combine_type in ['or', 'and'], "`includes_combine_type` should be 'or' or 'and'!"
    assert isinstance(excludes, (str, list, tuple)), '`excludes` should be a string or a list(tuple) of strings!'
    assert excludes_combine_type in ['or', 'and'], "`excludes_combine_type` should be 'or' or 'and'!"

    def _select(filters, combine_type):
        # An empty list/tuple of filters selects nothing. This guard must test
        # the `filters` argument (not the builtin `filter`); without it the
        # 'and' mode would select every tensor, because `all()` over no
        # filters is vacuously true.
        if isinstance(filters, (list, tuple)) and len(filters) == 0:
            return []

        filters = filters if isinstance(filters, (list, tuple)) else [filters]

        if combine_type == 'or':
            matches = any
        elif combine_type == 'and':
            matches = all
        else:  # only reachable when the assertions above are disabled
            return []

        return [t for t in tensors if matches(filt in t.name for filt in filters)]

    include_set = _select(includes, includes_combine_type)
    exclude_set = _select(excludes, excludes_combine_type)
    select_set = [t for t in include_set if t not in exclude_set]

    return select_set
class LinearDecayLR:
    """Learning rate that stays fixed, then decays linearly toward zero.

    While `step` < `step_start_decay` the initial learning rate is used;
    afterwards the rate shrinks linearly and would reach zero at
    `step` == `steps`.

    `step` should start from 0 (included) and go up to `steps` (excluded).
    The most recently computed value is kept in `current_learning_rate`.
    """

    def __init__(self, initial_learning_rate, steps, step_start_decay):
        self._initial_learning_rate = initial_learning_rate
        self._steps = steps
        self._step_start_decay = step_start_decay
        self.current_learning_rate = initial_learning_rate

    def __call__(self, step):
        """Compute, store and return the learning rate for `step`."""
        base_lr = self._initial_learning_rate
        decaying = step >= self._step_start_decay

        if decaying:
            # Length of the decay phase (+1 so the last real step is non-zero)
            # and how far into that phase `step` is.
            decay_span = self._steps - self._step_start_decay + 1
            progressed = step - self._step_start_decay + 1
            new_lr = base_lr * (1 - 1 / decay_span * progressed)
        else:
            new_lr = base_lr

        self.current_learning_rate = new_lr
        return self.current_learning_rate
def summary(name_data_dict,
            types=['mean', 'std', 'max', 'min', 'sparsity', 'histogram'],
            name='summary'):
    """Build one merged summary op for a dict of named tensors.

    A tensor whose shape is `()` is logged as a single scalar. Any other
    tensor is logged with the statistics requested in `types` (scalars
    tagged '<name>-mean', '<name>-std', '<name>-max', '<name>-min',
    '<name>-sparsity', plus an optional histogram tagged '<name>').

    Examples
    --------
    >>> summary({'a': data_a, 'b': data_b})

    """
    # (statistic, name of the `tf.math` reducer), in emission order. The
    # reducer is looked up only when its statistic is actually requested.
    scalar_stats = (
        ('mean', 'reduce_mean'),
        ('std', 'reduce_std'),
        ('max', 'reduce_max'),
        ('min', 'reduce_min'),
        ('sparsity', 'zero_fraction'),
    )

    def _summary(tag, data):
        if data.shape == ():
            ops = [tf.summary.scalar(tag, data)]
        else:
            ops = [tf.summary.scalar(tag + '-' + stat, getattr(tf.math, reducer)(data))
                   for stat, reducer in scalar_stats
                   if stat in types]
            if 'histogram' in types:
                ops.append(tf.summary.histogram(tag, data))
        return tf.summary.merge(ops)

    with tf.name_scope(name):
        merged_per_tensor = [_summary(tag, data) for tag, data in name_data_dict.items()]
        return tf.summary.merge(merged_per_tensor)
def counter(start=0, scope=None):
    """Create an int64 scalar counter variable and an op that increments it.

    Parameters
    ----------
    start:
        Value the counter variable is initialized with.
    scope:
        Optional variable scope (default name 'counter').

    Returns
    -------
    Tuple `(counter_variable, increment_op)`; running `increment_op` assigns
    `counter + 1` to the variable.

    """
    with tf.variable_scope(scope, 'counter'):
        count_var = tf.get_variable(name='counter',
                                    initializer=tf.constant_initializer(start),
                                    shape=(),
                                    dtype=tf.int64)
        incremented = tf.add(count_var, 1)
        increment_op = tf.assign(count_var, incremented)
    return count_var, increment_op
def count_parameters(variables):
    """Count the elements and the bytes held by one or more variables.

    Parameters
    ----------
    variables:
        A single variable or a list/tuple of variables. Each must expose
        `shape.as_list()` and `dtype.size` (bytes per element).

    Returns
    -------
    Tuple `(n_params, n_bytes)`.

    """
    if not isinstance(variables, (list, tuple)):
        variables = [variables]

    # Element count of every variable, computed once and reused for bytes.
    element_counts = [np.prod(var.shape.as_list()) for var in variables]
    byte_counts = [count * var.dtype.size for count, var in zip(element_counts, variables)]

    n_params = np.sum(element_counts)
    n_bytes = np.sum(byte_counts)
    return n_params, n_bytes
def _sample_line(real, fake):
    """Sample points on the straight lines between `real` and `fake`.

    One uniform coefficient in [0, 1) is drawn per leading-axis entry and
    broadcast over the remaining axes; the result is
    `real + coeff * (fake - real)` with its static shape set to `real`'s.
    """
    leading = tf.shape(real)[0]
    broadcast_axes = [1] * (real.shape.ndims - 1)
    coeff = tf.random.uniform(shape=[leading] + broadcast_axes, minval=0, maxval=1)

    interpolated = real + coeff * (fake - real)
    # Restore the static shape explicitly on the interpolated tensor.
    interpolated.set_shape(real.shape)
    return interpolated
def gradient_penalty(f, real, fake, gp_mode, sample_mode):
    """Compute a gradient penalty of `f` at points sampled from `real`/`fake`.

    Parameters
    ----------
    f:
        Callable mapping a tensor to the output that is differentiated.
    real, fake:
        Tensors handed to the sampling function.
    gp_mode:
        'none' (constant zero penalty), '1-gp', '0-gp' or 'lp'.
    sample_mode:
        'line', 'real', 'fake' or 'dragan'; ignored when `gp_mode` is 'none'.

    Returns
    -------
    Scalar penalty tensor.

    """
    samplers = {
        'line': _sample_line,
        'real': lambda real, fake: real,
        'fake': lambda real, fake: fake,
        'dragan': _sample_DRAGAN,
    }

    penalties = {
        '1-gp': _one_mean_gp,
        '0-gp': _zero_mean_gp,
        'lp': _lipschitz_penalty,
    }

    # Penalty disabled: return a zero of the input dtype.
    if gp_mode == 'none':
        return tf.constant(0, dtype=real.dtype)

    x = samplers[sample_mode](real, fake)
    grad = tf.gradients(f(x), x)[0]
    return penalties[gp_mode](grad)
def get_adversarial_losses_fn(mode):
    """Return the `(d_loss_fn, g_loss_fn)` pair for the adversarial loss `mode`.

    Parameters
    ----------
    mode:
        One of 'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'.

    Returns
    -------
    Tuple `(d_loss_fn, g_loss_fn)` built by the matching `get_*_losses_fn`.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported names. Without this check the
        function would silently return `None` and the failure would only
        surface later, when the caller unpacks the result.

    """
    if mode == 'gan':
        return get_gan_losses_fn()
    elif mode == 'hinge_v1':
        return get_hinge_v1_losses_fn()
    elif mode == 'hinge_v2':
        return get_hinge_v2_losses_fn()
    elif mode == 'lsgan':
        return get_lsgan_losses_fn()
    elif mode == 'wgan':
        return get_wgan_losses_fn()

    raise ValueError(
        "Unknown adversarial loss mode %r; expected one of "
        "'gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'." % (mode,))
def sample_graph():
    """Build the inference graph that is later frozen to a protobuf.

    Creates the placeholders named 'xa' (input images) and 'b_' (target
    attribute vector), runs the encoder and decoder with `training=False`,
    and exposes the decoder output under the op name 'xb'. Nothing is
    returned; the ops are looked up by name afterwards.
    """
    # model (the third element returned by get_model is not needed here)
    Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)

    # placeholders & inputs
    image_shape = [None, args.crop_size, args.crop_size, 3]
    xa = tf.placeholder(tf.float32, shape=image_shape, name='xa')
    b_ = tf.placeholder(tf.float32, shape=[None, n_atts], name='b_')

    # sample graph
    z = Genc(xa, training=False)
    xb = Gdec(z, b_, training=False)
    tf.identity(xb, name='xb')
import tqdm 10 | 11 | import data 12 | import module 13 | 14 | 15 | # ============================================================================== 16 | # = param = 17 | # ============================================================================== 18 | 19 | default_att_names = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses', 20 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young'] 21 | py.arg('--att_names', choices=data.ATT_ID.keys(), nargs='+', default=default_att_names) 22 | 23 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data') 24 | py.arg('--train_label_path', default='./data/img_celeba/train_label.txt') 25 | py.arg('--val_label_path', default='./data/img_celeba/val_label.txt') 26 | py.arg('--load_size', type=int, default=143) 27 | py.arg('--crop_size', type=int, default=128) 28 | 29 | py.arg('--n_epochs', type=int, default=60) 30 | py.arg('--epoch_start_decay', type=int, default=30) 31 | py.arg('--batch_size', type=int, default=32) 32 | py.arg('--learning_rate', type=float, default=2e-4) 33 | py.arg('--beta_1', type=float, default=0.5) 34 | 35 | py.arg('--model', default='model_128', choices=['model_128', 'model_256', 'model_384']) 36 | 37 | py.arg('--n_d', type=int, default=5) # # d updates per g update 38 | py.arg('--adversarial_loss_mode', choices=['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'], default='wgan') 39 | py.arg('--gradient_penalty_mode', choices=['none', '1-gp', '0-gp', 'lp'], default='1-gp') 40 | py.arg('--gradient_penalty_sample_mode', choices=['line', 'real', 'fake', 'dragan'], default='line') 41 | py.arg('--d_gradient_penalty_weight', type=float, default=10.0) 42 | py.arg('--d_attribute_loss_weight', type=float, default=1.0) 43 | py.arg('--g_attribute_loss_weight', type=float, default=10.0) 44 | py.arg('--g_reconstruction_loss_weight', type=float, default=100.0) 45 | py.arg('--weight_decay', 
def D_train_graph():
    """Build the discriminator training graph and return its step function.

    The graph draws a batch from `train_iter`, generates images for shuffled
    target attributes, and minimizes the sum of the adversarial loss, the
    weighted gradient penalty, the weighted attribute classification loss
    on real images and the discriminator's regularization losses, updating
    only `D.func.trainable_variables`.

    Returns
    -------
    A function `run(**pl_ipts)` that executes one optimizer step (and the
    summaries) with the learning rate taken from `pl_ipts['lr']`.

    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    lr = tf.placeholder(dtype=tf.float32, shape=[])

    xa, a = train_iter.get_next()
    b = tf.random_shuffle(a)  # target attributes: the batch's labels, shuffled
    b_ = b * 2 - 1

    # generate
    xb_ = Gdec(Genc(xa), b_)

    # discriminate (D returns an adversarial logit and attribute logits)
    xa_logit_gan, xa_logit_att = D(xa)
    xb__logit_gan, xb__logit_att = D(xb_)

    # discriminator losses
    xa_loss_gan, xb__loss_gan = d_loss_fn(xa_logit_gan, xb__logit_gan)

    def _gan_logit(x):
        # Only the adversarial head is penalized.
        return D(x)[0]

    gp = tfprob.gradient_penalty(_gan_logit, xa, xb_, args.gradient_penalty_mode, args.gradient_penalty_sample_mode)
    xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
    reg_loss = tf.reduce_sum(D.func.reg_losses)

    loss = (xa_loss_gan + xb__loss_gan +
            gp * args.d_gradient_penalty_weight +
            xa_loss_att * args.d_attribute_loss_weight +
            reg_loss)

    # optim
    step_cnt, _ = tl.counter()
    optimizer = tf.train.AdamOptimizer(lr, beta1=args.beta_1)
    step = optimizer.minimize(loss, global_step=step_cnt, var_list=D.func.trainable_variables)

    # summary
    writer = tf.contrib.summary.create_file_writer('./output/%s/summaries/D' % args.experiment_name)
    with writer.as_default():
        with tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
            loss_scalars = {
                'loss_gan': xa_loss_gan + xb__loss_gan,
                'gp': gp,
                'xa_loss_att': xa_loss_att,
                'reg_loss': reg_loss
            }
            summary = [
                tl.summary_v2(loss_scalars, step=step_cnt, name='D'),
                tl.summary_v2({'lr': lr}, step=step_cnt, name='learning_rate')
            ]

    # ======================================
    # =            run function            =
    # ======================================

    def run(**pl_ipts):
        feeds = {lr: pl_ipts['lr']}
        sess.run([step, summary], feed_dict=feeds)

    return run
def sample_graph():
    """Build the validation sampling graph and return a function that saves a sample grid.

    Returns
    -------
    A function `run(epoch, iter)` that fetches one validation batch, renders
    the input, its reconstruction and one edited image per attribute (that
    attribute inverted and scaled by `args.test_int`), and writes the grid
    to '<save_dir>/Epoch-<epoch>_Iter-<iter>.jpg'.

    """
    # ======================================
    # =               graph                =
    # ======================================

    # placeholders & inputs
    val_next = val_iter.get_next()
    xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
    b_ = tf.placeholder(tf.float32, shape=[None, n_atts])

    # sample graph
    z = Genc(xa, training=False)
    x = Gdec(z, b_, training=False)

    # ======================================
    # =            run function            =
    # ======================================

    save_dir = './output/%s/samples_training' % args.experiment_name
    py.mkdir(save_dir)

    def _inverted(a_ipt, att_idx):
        # Copy of the labels with one attribute flipped and conflicts resolved.
        flipped = np.array(a_ipt, copy=True)
        flipped[:, att_idx] = 1 - flipped[:, att_idx]
        return data.check_attribute_conflict(flipped, args.att_names[att_idx], args.att_names)

    def run(epoch, iter):
        # data for sampling
        xa_ipt, a_ipt = sess.run(val_next)

        # the first target is the original labels (reconstruction)
        b_ipt_list = [a_ipt] + [_inverted(a_ipt, att_idx) for att_idx in range(n_atts)]

        x_opt_list = [xa_ipt]
        for pos, b_ipt in enumerate(b_ipt_list):
            b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # map labels to the generator's input range
            if pos > 0:  # pos == 0 is for reconstruction
                att_idx = pos - 1
                b__ipt[..., att_idx] = b__ipt[..., att_idx] * args.test_int
            x_opt_list.append(sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt}))

        # Move the per-target axis next to the width axis and fold it into
        # the width, so each sample becomes one row of side-by-side images.
        grid = np.transpose(x_opt_list, (1, 2, 0, 3, 4))
        row_width = grid.shape[2] * grid.shape[3]
        grid = np.reshape(grid, (-1, row_width, grid.shape[4]))
        im.imwrite(grid, '%s/Epoch-%d_Iter-%d.jpg' % (save_dir, epoch, iter))

    return run
finally: 311 | checkpoint.save(step) 312 | sess.close() 313 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | 6 | 7 | # ============================================================================== 8 | # = operations = 9 | # ============================================================================== 10 | 11 | def tile_concat(a_list, b_list=[]): 12 | # tile all elements of `b_list` and then concat `a_list + b_list` along the channel axis 13 | # `a` shape: (N, H, W, C_a) 14 | # `b` shape: can be (N, 1, 1, C_b) or (N, C_b) 15 | a_list = list(a_list) if isinstance(a_list, (list, tuple)) else [a_list] 16 | b_list = list(b_list) if isinstance(b_list, (list, tuple)) else [b_list] 17 | for i, b in enumerate(b_list): 18 | b = tf.reshape(b, [-1, 1, 1, b.shape[-1]]) 19 | b = tf.tile(b, [1, a_list[0].shape[1], a_list[0].shape[2], 1]) 20 | b_list[i] = b 21 | return tf.concat(a_list + b_list, axis=-1) 22 | 23 | 24 | # ============================================================================== 25 | # = others = 26 | # ============================================================================== 27 | 28 | def get_norm_layer(norm, training, updates_collections=None): 29 | if norm == 'none': 30 | return lambda x: x 31 | elif norm == 'batch_norm': 32 | return functools.partial(slim.batch_norm, scale=True, is_training=training, updates_collections=updates_collections) 33 | elif norm == 'instance_norm': 34 | return slim.instance_norm 35 | elif norm == 'layer_norm': 36 | return slim.layer_norm 37 | --------------------------------------------------------------------------------