├── .gitignore
├── LICENSE
├── README.md
├── data.py
├── examples.md
├── imlib
│   ├── __init__.py
│   ├── basic.py
│   ├── dtype.py
│   └── transform.py
├── module.py
├── pics
│   ├── first_view.png
│   ├── sample_validation.jpg
│   ├── sample_validation_256x.jpg
│   ├── sample_validation_384x_hd.jpg
│   ├── sample_validation_40.jpg
│   ├── schema.jpg
│   ├── slide.png
│   └── style.jpg
├── pylib
│   ├── __init__.py
│   ├── argument.py
│   ├── path.py
│   ├── processing.py
│   ├── serialization.py
│   └── timer.py
├── results.md
├── scripts
│   ├── align.py
│   ├── cropper.py
│   └── split_CelebA-HQ.py
├── test.py
├── test_multi.py
├── test_slide.py
├── tflib
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── image
│   │   ├── __init__.py
│   │   ├── filter.py
│   │   └── image.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── layers.py
│   │   └── layers_slim.py
│   ├── losses
│   │   ├── __init__.py
│   │   └── losses.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── metrics.py
│   ├── ops
│   │   ├── __init__.py
│   │   └── ops.py
│   └── utils
│       ├── __init__.py
│       ├── collection.py
│       ├── distribute.py
│       ├── learning_rate.py
│       └── utils.py
├── tfprob
│   ├── __init__.py
│   └── gan
│       ├── __init__.py
│       ├── gradient_penalty.py
│       └── loss.py
├── to_pb.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__/
3 | /data/
4 | /output/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Zhenliang He
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
8 |
9 |
10 |
11 |
12 | > **[AttGAN: Facial Attribute Editing by Only Changing What You Want](https://kdocs.cn/l/cpp7J2ZsFUW7)** \
13 | > [Zhenliang He](https://lynnho.github.io)1,2, [Wangmeng Zuo](https://scholar.google.com/citations?user=rUOpCEYAAAAJ)4, [Meina Kan](https://scholar.google.com/citations?user=4AKCKKEAAAAJ)1, [Shiguang Shan](https://scholar.google.com/citations?user=Vkzd7MIAAAAJ)1,3, [Xilin Chen](https://scholar.google.com/citations?user=vVx2v20AAAAJ)1 \
14 | > 1Key Lab of Intelligent Information Processing, Institute of Computing Technology, CAS, China \
15 | > 2University of Chinese Academy of Sciences, China \
16 | > 3CAS Center for Excellence in Brain Science and Intelligence Technology, China \
17 | > 4School of Computer Science and Technology, Harbin Institute of Technology, China
18 |
19 |
20 |
21 | ## Related
22 |
23 | - Other implementations of AttGAN
24 |
25 | - [AttGAN-PyTorch](https://github.com/elvisyjlin/AttGAN-PyTorch) by Yu-Jing Lin
26 |
27 | - [AttGAN-PaddlePaddle](https://github.com/PaddlePaddle/models/tree/release/1.7/PaddleCV/gan) by ceci3 and zhumanyu (**AttGAN is one of the officially reproduced models of [PaddlePaddle](https://github.com/PaddlePaddle?type=source)**)
28 |
29 | - Closely related works
30 |
31 | - **An excellent work built upon our code - [STGAN](https://github.com/csmliu/STGAN) (CVPR 2019) by Ming Liu**
32 |
33 | - [Changing-the-Memorability](https://github.com/acecreamu/Changing-the-Memorability) (CVPR 2019 MBCCV Workshop) by acecreamu
34 |
35 | - [Fashion-AttGAN](https://github.com/ChanningPing/Fashion_Attribute_Editing) (CVPR 2019 FSS-USAD Workshop) by Qing Ping
36 |
37 | - An unofficial [demo video](https://www.youtube.com/watch?v=gnN4ZjEWe-8) of AttGAN by 王一凡
38 |
39 | ## Exemplar Results
40 |
41 | - See [results.md](./results.md) for more results, where we try **higher resolution** and **more attributes** (all **40** attributes!!!)
42 |
43 | - Inverting each of 13 attributes individually
44 |
45 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young*
46 |
47 |
48 |
49 | ## Usage
50 |
51 | - Environment
52 |
53 | - Python 3.6
54 |
55 | - TensorFlow 1.15
56 |
57 | - OpenCV, scikit-image, tqdm, oyaml
58 |
59 | - *we recommend [Anaconda](https://www.anaconda.com/distribution/#download-section) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html#linux-installers); with either, you can create the AttGAN environment using the commands below*
60 |
61 | ```console
62 | conda create -n AttGAN python=3.6
63 |
64 | source activate AttGAN
65 |
66 | conda install opencv scikit-image tqdm tensorflow-gpu=1.15
67 |
68 | conda install -c conda-forge oyaml
69 | ```
70 |
71 | - *NOTICE: if you create a new conda environment, remember to activate it before any other command*
72 |
73 | ```console
74 | source activate AttGAN
75 | ```
76 |
77 | - Data Preparation
78 |
79 | - Option 1: [CelebA](http://openaccess.thecvf.com/content_iccv_2015/papers/Liu_Deep_Learning_Face_ICCV_2015_paper.pdf)-unaligned (higher quality than the aligned data, 10.2GB)
80 |
81 | - download the dataset
82 |
83 | - img_celeba.7z (move to **./data/img_celeba/img_celeba.7z**): [Google Drive](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg?resourcekey=0-rJlzl934LzC-Xp28GeIBzQ) or [Baidu Netdisk](https://pan.baidu.com/s/1CRxxhoQ97A5qbsKO7iaAJg) (password rp0s)
84 |
85 | - annotations.zip (move to **./data/img_celeba/annotations.zip**): [Google Drive](https://drive.google.com/file/d/1xd-d1WRnbt3yJnwh5ORGZI3g-YS-fKM9/view?usp=sharing)
86 |
87 | - unzip and process the data
88 |
89 | ```console
90 | 7z x ./data/img_celeba/img_celeba.7z/img_celeba.7z.001 -o./data/img_celeba/
91 |
92 | unzip ./data/img_celeba/annotations.zip -d ./data/img_celeba/
93 |
94 | python ./scripts/align.py
95 | ```
96 |
97 | - Option 2: CelebA-HQ (we use the data from [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ), 3.2GB)
98 |
99 | - CelebAMask-HQ.zip (move to **./data/CelebAMask-HQ.zip**): [Google Drive](https://drive.google.com/open?id=1badu11NqxGf6qM3PTTooQDJvQbejgbTv) or [Baidu Netdisk](https://pan.baidu.com/s/1wN1E-B1bJ7mE1mrn9loj5g)
100 |
101 | - unzip and process the data
102 |
103 | ```console
104 | unzip ./data/CelebAMask-HQ.zip -d ./data/
105 |
106 | python ./scripts/split_CelebA-HQ.py
107 | ```
108 |
109 | - Run AttGAN
110 |
111 | - training (see [examples.md](./examples.md) for more training commands)
112 |
113 | ```console
114 | # for CelebA
115 | CUDA_VISIBLE_DEVICES=0 \
116 | python train.py \
117 | --load_size 143 \
118 | --crop_size 128 \
119 | --model model_128 \
120 | --experiment_name AttGAN_128
121 |
122 | # for CelebA-HQ
123 | CUDA_VISIBLE_DEVICES=0 \
124 | python train.py \
125 | --img_dir ./data/CelebAMask-HQ/CelebA-HQ-img \
126 | --train_label_path ./data/CelebAMask-HQ/train_label.txt \
127 | --val_label_path ./data/CelebAMask-HQ/val_label.txt \
128 | --load_size 128 \
129 | --crop_size 128 \
130 | --n_epochs 200 \
131 | --epoch_start_decay 100 \
132 | --model model_128 \
133 | --experiment_name AttGAN_128_CelebA-HQ
134 | ```
135 |
136 | - testing
137 |
138 | - **single** attribute editing (inversion)
139 |
140 | ```console
141 | # for CelebA
142 | CUDA_VISIBLE_DEVICES=0 \
143 | python test.py \
144 | --experiment_name AttGAN_128
145 |
146 | # for CelebA-HQ
147 | CUDA_VISIBLE_DEVICES=0 \
148 | python test.py \
149 | --img_dir ./data/CelebAMask-HQ/CelebA-HQ-img \
150 | --test_label_path ./data/CelebAMask-HQ/test_label.txt \
151 | --experiment_name AttGAN_128_CelebA-HQ
152 | ```
153 |
154 |
155 | - **multiple** attribute editing (inversion) example
156 |
157 | ```console
158 | # for CelebA
159 | CUDA_VISIBLE_DEVICES=0 \
160 | python test_multi.py \
161 | --test_att_names Bushy_Eyebrows Pale_Skin \
162 | --experiment_name AttGAN_128
163 | ```
164 |
165 | - attribute sliding example
166 |
167 | ```console
168 | # for CelebA
169 | CUDA_VISIBLE_DEVICES=0 \
170 | python test_slide.py \
171 | --test_att_name Pale_Skin \
172 | --test_int_min -2 \
173 | --test_int_max 2 \
174 | --test_int_step 0.5 \
175 | --experiment_name AttGAN_128
176 | ```
177 |
178 | - loss visualization
179 |
180 | ```console
181 | CUDA_VISIBLE_DEVICES='' \
182 | tensorboard \
183 | --logdir ./output/AttGAN_128/summaries \
184 | --port 6006
185 | ```
186 |
187 | - convert trained model to .pb file
188 |
189 | ```console
190 | python to_pb.py --experiment_name AttGAN_128
191 | ```
192 |
193 | - Using Trained Weights
194 |
195 | - alternatively, download the trained weights below (move the zip files to **./output/\*.zip**)
196 |
197 | - [AttGAN_128.zip](https://drive.google.com/file/d/1Oy4F1xtYdxj4iyiLyaEd-dkGIJ0mwo41/view?usp=sharing) (987.5MB)
198 |
199 | - *including G, D, and the state of the optimizer*
200 |
201 | - [AttGAN_128_generator_only.zip](https://drive.google.com/file/d/1lcQ-ijNrGD4919eJ5Dv-7ja5rsx5p0Tp/view?usp=sharing) (161.5MB)
202 |
203 | - *G only*
204 |
205 | - [AttGAN_384_generator_only.zip](https://drive.google.com/open?id=1scaKWcWIpTfsV0yrWCI-wg_JDmDsKKm1) (91.1MB)
206 |
207 |
208 | - unzip the file (AttGAN_128.zip for example)
209 |
210 | ```console
211 | unzip ./output/AttGAN_128.zip -d ./output/
212 | ```
213 |
214 | - testing (see above)
215 |
216 |
217 | - Example for Custom Dataset
218 |
219 | - [AttGAN-Cartoon](https://github.com/LynnHo/AttGAN-Cartoon-Tensorflow)
220 |
221 | ## Citation
222 |
223 | If you find [AttGAN](https://kdocs.cn/l/cpp7J2ZsFUW7) useful in your research work, please consider citing:
224 |
225 | @ARTICLE{8718508,
226 | author={Z. {He} and W. {Zuo} and M. {Kan} and S. {Shan} and X. {Chen}},
227 | journal={IEEE Transactions on Image Processing},
228 | title={AttGAN: Facial Attribute Editing by Only Changing What You Want},
229 | year={2019},
230 | volume={28},
231 | number={11},
232 | pages={5464-5478},
233 | keywords={Face;Facial features;Task analysis;Decoding;Image reconstruction;Hair;Gallium nitride;Facial attribute editing;attribute style manipulation;adversarial learning},
234 | doi={10.1109/TIP.2019.2916751},
235 | ISSN={1057-7149},
236 | month={Nov},}
237 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pylib as py
3 | import tensorflow as tf
4 | import tflib as tl
5 |
6 |
7 | ATT_ID = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2,
8 | 'Bags_Under_Eyes': 3, 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6,
9 | 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, 'Blurry': 10,
10 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13,
11 | 'Double_Chin': 14, 'Eyeglasses': 15, 'Goatee': 16,
12 | 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19,
13 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22,
14 | 'Narrow_Eyes': 23, 'No_Beard': 24, 'Oval_Face': 25,
15 | 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28,
16 | 'Rosy_Cheeks': 29, 'Sideburns': 30, 'Smiling': 31,
17 | 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34,
18 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36,
19 | 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}
20 | ID_ATT = {v: k for k, v in ATT_ID.items()}
21 |
22 |
23 | def make_celeba_dataset(img_dir,
24 | label_path,
25 | att_names,
26 | batch_size,
27 | load_size=286,
28 | crop_size=256,
29 | training=True,
30 | drop_remainder=True,
31 | shuffle=True,
32 | repeat=1):
33 | img_names = np.genfromtxt(label_path, dtype=str, usecols=0)
34 | img_paths = np.array([py.join(img_dir, img_name) for img_name in img_names])
35 | labels = np.genfromtxt(label_path, dtype=int, usecols=range(1, 41))
36 | labels = labels[:, np.array([ATT_ID[att_name] for att_name in att_names])]
37 |
38 | if shuffle:
39 | idx = np.random.permutation(len(img_paths))
40 | img_paths = img_paths[idx]
41 | labels = labels[idx]
42 |
43 | if training:
44 | def map_fn_(img, label):
45 | img = tf.image.resize(img, [load_size, load_size])
46 | # img = tl.random_rotate(img, 5)
47 | img = tf.image.random_flip_left_right(img)
48 | img = tf.image.random_crop(img, [crop_size, crop_size, 3])
49 | # img = tl.color_jitter(img, 25, 0.2, 0.2, 0.1)
50 | # img = tl.random_grayscale(img, p=0.3)
51 |             img = tf.clip_by_value(img, 0, 255) / 127.5 - 1  # scale to [-1, 1]
52 |             label = (label + 1) // 2  # map {-1, 1} annotations to {0, 1}
53 | return img, label
54 | else:
55 | def map_fn_(img, label):
56 | img = tf.image.resize(img, [load_size, load_size])
57 | img = tl.center_crop(img, size=crop_size)
58 | img = tf.clip_by_value(img, 0, 255) / 127.5 - 1
59 | label = (label + 1) // 2
60 | return img, label
61 |
62 | dataset = tl.disk_image_batch_dataset(img_paths,
63 | batch_size,
64 | labels=labels,
65 | drop_remainder=drop_remainder,
66 | map_fn=map_fn_,
67 | shuffle=shuffle,
68 | repeat=repeat)
69 |
70 | if drop_remainder:
71 | len_dataset = len(img_paths) // batch_size
72 | else:
73 | len_dataset = int(np.ceil(len(img_paths) / batch_size))
74 |
75 | return dataset, len_dataset
76 |
77 |
78 | def check_attribute_conflict(att_batch, att_name, att_names):
79 | def _set(att, value, att_name):
80 | if att_name in att_names:
81 | att[att_names.index(att_name)] = value
82 |
83 | idx = att_names.index(att_name)
84 |
85 | for att in att_batch:
86 | if att_name in ['Bald', 'Receding_Hairline'] and att[idx] == 1:
87 | _set(att, 0, 'Bangs')
88 | elif att_name == 'Bangs' and att[idx] == 1:
89 | _set(att, 0, 'Bald')
90 | _set(att, 0, 'Receding_Hairline')
91 | elif att_name in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair'] and att[idx] == 1:
92 | for n in ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair']:
93 | if n != att_name:
94 | _set(att, 0, n)
95 | elif att_name in ['Straight_Hair', 'Wavy_Hair'] and att[idx] == 1:
96 | for n in ['Straight_Hair', 'Wavy_Hair']:
97 | if n != att_name:
98 | _set(att, 0, n)
99 |         # elif att_name in ['Mustache', 'No_Beard'] and att[idx] == 1:  # enabling this part helps to learn `Mustache`
100 | # for n in ['Mustache', 'No_Beard']:
101 | # if n != att_name:
102 | # _set(att, 0, n)
103 |
104 | return att_batch
105 |
--------------------------------------------------------------------------------
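A minimal usage sketch of the two helpers above. The paths and attribute subset here are illustrative (the real defaults come from train.py's arguments); `make_celeba_dataset` builds a TF 1.x dataset, while `check_attribute_conflict` operates on plain numpy label batches:

```python
import numpy as np
import data

att_names = ['Bald', 'Bangs', 'Black_Hair']

# Hypothetical paths, for illustration only.
dataset, n_batches = data.make_celeba_dataset(
    img_dir='./data/img_celeba/aligned/data',
    label_path='./data/img_celeba/train_label.txt',
    att_names=att_names, batch_size=32,
    load_size=143, crop_size=128)

# Bald=1 conflicts with Bangs=1, so Bangs gets cleared:
att_batch = np.array([[1, 1, 0]])
att_batch = data.check_attribute_conflict(att_batch, 'Bald', att_names)
print(att_batch)  # -> [[1 0 0]]
```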
/examples.md:
--------------------------------------------------------------------------------
1 | # [AttGAN](https://ieeexplore.ieee.org/document/8718508?source=authoralert) Usage
2 |
3 | - training
4 |
5 | - for 128x128 images
6 |
7 | ```console
8 | CUDA_VISIBLE_DEVICES=0 \
9 | python train.py \
10 | --load_size 143 \
11 | --crop_size 128 \
12 | --model model_128 \
13 | --experiment_name AttGAN_128
14 | ```
15 |
16 | - for 128x128 images with all **40** attributes!
17 |
18 | ```console
19 | CUDA_VISIBLE_DEVICES=0 \
20 | python train.py \
21 | --load_size 143 \
22 | --crop_size 128 \
23 | --model model_128 \
24 | --experiment_name AttGAN_128_40 \
25 | --att_names 5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes \
26 | Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry \
27 | Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee \
28 | Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open \
29 | Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose \
30 | Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair \
31 | Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick \
32 | Wearing_Necklace Wearing_Necktie Young
33 | ```
34 |
35 | - for 256x256 images
36 |
37 | ```console
38 | CUDA_VISIBLE_DEVICES=0 \
39 | python train.py \
40 | --load_size 286 \
41 | --crop_size 256 \
42 | --model model_256 \
43 | --experiment_name AttGAN_256
44 | ```
45 |
46 | - for 384x384 images
47 |
48 | ```console
49 | CUDA_VISIBLE_DEVICES=0 \
50 | python train.py \
51 | --load_size 429 \
52 | --crop_size 384 \
53 | --model model_384 \
54 | --experiment_name AttGAN_384
55 | ```
--------------------------------------------------------------------------------
/imlib/__init__.py:
--------------------------------------------------------------------------------
1 | from imlib.basic import *
2 | from imlib.dtype import *
3 | from imlib.transform import *
4 |
--------------------------------------------------------------------------------
/imlib/basic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage.io as iio
3 |
4 | from imlib import dtype
5 |
6 |
7 | def imread(path, as_gray=False, **kwargs):
8 | """Return a float64 image in [-1.0, 1.0]."""
9 | image = iio.imread(path, as_gray, **kwargs)
10 | if image.dtype == np.uint8:
11 | image = image / 127.5 - 1
12 | elif image.dtype == np.uint16:
13 | image = image / 32767.5 - 1
14 | elif image.dtype in [np.float32, np.float64]:
15 | image = image * 2 - 1.0
16 | else:
17 |         raise Exception("Unsupported image dtype: %s!" % image.dtype)
18 | return image
19 |
20 |
21 | def imwrite(image, path, quality=95, **plugin_args):
22 | """Save a [-1.0, 1.0] image."""
23 | iio.imsave(path, dtype.im2uint(image), quality=quality, **plugin_args)
24 |
25 |
26 | def imshow(image):
27 | """Show a [-1.0, 1.0] image."""
28 | iio.imshow(dtype.im2uint(image))
29 |
30 |
31 | show = iio.show
32 |
--------------------------------------------------------------------------------
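A quick round-trip sketch of the I/O helpers above ('face.jpg' is a hypothetical input file): `imread` returns float64 in [-1.0, 1.0], and `imwrite` converts back to uint8 before saving.

```python
import imlib as im

img = im.imread('face.jpg')       # float64 in [-1.0, 1.0]
im.imwrite(img, 'face_copy.jpg')  # im2uint is applied internally
```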
/imlib/dtype.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
5 | # check type
6 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!'
7 |
8 | # check dtype
9 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes]
10 |     assert images.dtype in dtypes, 'dtype of `images` should be one of %s!' % dtypes
11 |
12 | # check nan and inf
13 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!'
14 |
15 | # check value
16 | if min_value not in [None, -np.inf]:
17 | l = '[' + str(min_value)
18 | else:
19 | l = '(-inf'
20 | min_value = -np.inf
21 | if max_value not in [None, np.inf]:
22 | r = str(max_value) + ']'
23 | else:
24 | r = 'inf)'
25 | max_value = np.inf
26 | assert np.min(images) >= min_value and np.max(images) <= max_value, \
27 | '`images` should be in the range of %s!' % (l + ',' + r)
28 |
29 |
30 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
31 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype."""
32 | _check(images, [np.float32, np.float64], -1.0, 1.0)
33 | dtype = dtype if dtype else images.dtype
34 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)
35 |
36 |
37 | def float2im(images):
38 | """Transform images from [0, 1.0] to [-1.0, 1.0]."""
39 | _check(images, [np.float32, np.float64], 0.0, 1.0)
40 | return images * 2 - 1.0
41 |
42 |
43 | def float2uint(images):
44 | """Transform images from [0, 1.0] to uint8."""
45 |     _check(images, [np.float32, np.float64], 0.0, 1.0)
46 | return (images * 255).astype(np.uint8)
47 |
48 |
49 | def im2uint(images):
50 | """Transform images from [-1.0, 1.0] to uint8."""
51 | return to_range(images, 0, 255, np.uint8)
52 |
53 |
54 | def im2float(images):
55 | """Transform images from [-1.0, 1.0] to [0.0, 1.0]."""
56 | return to_range(images, 0.0, 1.0)
57 |
58 |
59 | def uint2im(images):
60 | """Transform images from uint8 to [-1.0, 1.0] of float64."""
61 | _check(images, np.uint8)
62 | return images / 127.5 - 1.0
63 |
64 |
65 | def uint2float(images):
66 | """Transform images from uint8 to [0.0, 1.0] of float64."""
67 | _check(images, np.uint8)
68 | return images / 255.0
69 |
70 |
71 | def cv2im(images):
72 | """Transform opencv images to [-1.0, 1.0]."""
73 | images = uint2im(images)
74 | return images[..., ::-1]
75 |
76 |
77 | def im2cv(images):
78 | """Transform images from [-1.0, 1.0] to opencv images."""
79 | images = im2uint(images)
80 | return images[..., ::-1]
81 |
--------------------------------------------------------------------------------
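A sketch of the uint8 <-> [-1, 1] round trip these converters implement; values are recovered up to floating-point rounding:

```python
import numpy as np
from imlib import dtype

u = np.array([[0, 127, 255]], dtype=np.uint8)
f = dtype.uint2im(u)   # u / 127.5 - 1 -> float64 in [-1.0, 1.0]
u2 = dtype.im2uint(f)  # (f + 1) / 2 * 255 -> back to uint8
```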
/imlib/transform.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import skimage.color as color
5 | import skimage.transform as transform
6 |
7 |
8 | rgb2gray = color.rgb2gray
9 | gray2rgb = color.gray2rgb
10 |
11 | imresize = transform.resize
12 | imrescale = transform.rescale
13 |
14 |
15 | def immerge(images, n_rows=None, n_cols=None, padding=0, pad_value=0):
16 |     """Merge images into one image of size (n_rows * h) * (n_cols * w).
17 | 
18 |     Parameters
19 |     ----------
20 |     images : numpy.array or object which can be converted to numpy.array
21 |         Images of shape N * H * W (* C, where C is 1 or 3).
22 |
23 | """
24 | images = np.array(images)
25 | n = images.shape[0]
26 | if n_rows:
27 | n_rows = max(min(n_rows, n), 1)
28 | n_cols = int(n - 0.5) // n_rows + 1
29 | elif n_cols:
30 | n_cols = max(min(n_cols, n), 1)
31 | n_rows = int(n - 0.5) // n_cols + 1
32 | else:
33 | n_rows = int(n ** 0.5)
34 | n_cols = int(n - 0.5) // n_rows + 1
35 |
36 | h, w = images.shape[1], images.shape[2]
37 | shape = (h * n_rows + padding * (n_rows - 1),
38 | w * n_cols + padding * (n_cols - 1))
39 | if images.ndim == 4:
40 | shape += (images.shape[3],)
41 | img = np.full(shape, pad_value, dtype=images.dtype)
42 |
43 | for idx, image in enumerate(images):
44 | i = idx % n_cols
45 | j = idx // n_cols
46 | img[j * (h + padding):j * (h + padding) + h,
47 | i * (w + padding):i * (w + padding) + w, ...] = image
48 |
49 | return img
50 |
51 |
52 | def grid_split(image, h, w):
53 | n_rows = math.ceil(image.shape[0] / h)
54 | n_cols = math.ceil(image.shape[1] / w)
55 |
56 | rows = []
57 | for r in range(n_rows):
58 | cols = []
59 | for c in range(n_cols):
60 | cols.append(image[r * h: (r + 1) * h, c * w: (c + 1) * w, ...])
61 | rows.append(cols)
62 |
63 | return rows
64 |
65 |
66 | def grid_merge(grid, padding=(0, 0), pad_value=(0, 0)):
67 | padding = padding if isinstance(padding, (list, tuple)) else [padding, padding]
68 | pad_value = pad_value if isinstance(pad_value, (list, tuple)) else [pad_value, pad_value]
69 |
70 | new_rows = []
71 | for r, row in enumerate(grid):
72 | new_cols = []
73 | for c, col in enumerate(row):
74 | if c != 0:
75 | new_cols.append(np.full([col.shape[0], padding[1], col.shape[2]], pad_value[1], dtype=col.dtype))
76 | new_cols.append(col)
77 |
78 | new_cols = np.concatenate(new_cols, axis=1)
79 | if r != 0:
80 | new_rows.append(np.full([padding[0], new_cols.shape[1], new_cols.shape[2]], pad_value[0], dtype=new_cols.dtype))
81 | new_rows.append(new_cols)
82 |
83 | grid_merged = np.concatenate(new_rows, axis=0)
84 |
85 | return grid_merged
86 |
--------------------------------------------------------------------------------
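A sketch of `immerge` tiling a batch into a grid; the shape arithmetic follows the code above (6 images with `n_rows=2` gives 2 rows and 3 columns, plus 2 px of padding between cells):

```python
import numpy as np
from imlib import transform

images = np.random.uniform(-1, 1, size=(6, 64, 64, 3))
grid = transform.immerge(images, n_rows=2, padding=2, pad_value=-1)
print(grid.shape)  # (130, 196, 3): 2*64 + 1*2 rows, 3*64 + 2*2 cols
```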
/module.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import tensorflow as tf
4 | import tensorflow.contrib.slim as slim
5 |
6 | import utils
7 |
8 |
9 | conv = functools.partial(slim.conv2d, activation_fn=None)
10 | dconv = functools.partial(slim.conv2d_transpose, activation_fn=None)
11 | fc = functools.partial(slim.fully_connected, activation_fn=None)
12 |
13 |
14 | class UNetGenc:
15 |
16 | def __call__(self, x, dim=64, n_downsamplings=5, weight_decay=0.0,
17 | norm_name='batch_norm', training=True, scope='UNetGenc'):
18 | MAX_DIM = 1024
19 |
20 | conv_ = functools.partial(conv, weights_regularizer=slim.l2_regularizer(weight_decay))
21 | norm = utils.get_norm_layer(norm_name, training, updates_collections=None)
22 |
23 | conv_norm_lrelu = functools.partial(conv_, normalizer_fn=norm, activation_fn=tf.nn.leaky_relu)
24 |
25 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
26 | z = x
27 | zs = []
28 | for i in range(n_downsamplings):
29 | d = min(dim * 2**i, MAX_DIM)
30 | z = conv_norm_lrelu(z, d, 4, 2)
31 | zs.append(z)
32 |
33 | # variables and update operations
34 | self.variables = tf.global_variables(scope)
35 | self.trainable_variables = tf.trainable_variables(scope)
36 | self.reg_losses = tf.losses.get_regularization_losses(scope)
37 |
38 | return zs
39 |
40 |
41 | class UNetGdec:
42 |
43 | def __call__(self, zs, a, dim=64, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=0.0,
44 | norm_name='batch_norm', training=True, scope='UNetGdec'):
45 | MAX_DIM = 1024
46 |
47 | dconv_ = functools.partial(dconv, weights_regularizer=slim.l2_regularizer(weight_decay))
48 | norm = utils.get_norm_layer(norm_name, training, updates_collections=None)
49 |
50 | dconv_norm_relu = functools.partial(dconv_, normalizer_fn=norm, activation_fn=tf.nn.relu)
51 |
52 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
53 | a = tf.to_float(a)
54 |
55 | z = utils.tile_concat(zs[-1], a)
56 | for i in range(n_upsamplings - 1):
57 | d = min(dim * 2**(n_upsamplings - 1 - i), MAX_DIM)
58 | z = dconv_norm_relu(z, d, 4, 2)
59 | if shortcut_layers > i:
60 | z = utils.tile_concat([z, zs[-2 - i]])
61 | if inject_layers > i:
62 | z = utils.tile_concat(z, a)
63 | x = tf.nn.tanh(dconv_(z, 3, 4, 2))
64 |
65 | # variables and update operations
66 | self.variables = tf.global_variables(scope)
67 | self.trainable_variables = tf.trainable_variables(scope)
68 | self.reg_losses = tf.losses.get_regularization_losses(scope)
69 |
70 | return x
71 |
72 |
73 | class ConvD:
74 |
75 | def __call__(self, x, n_atts, dim=64, fc_dim=1024, n_downsamplings=5, weight_decay=0.0,
76 | norm_name='instance_norm', training=True, scope='ConvD'):
77 | MAX_DIM = 1024
78 |
79 | conv_ = functools.partial(conv, weights_regularizer=slim.l2_regularizer(weight_decay))
80 | fc_ = functools.partial(fc, weights_regularizer=slim.l2_regularizer(weight_decay))
81 | norm = utils.get_norm_layer(norm_name, training, updates_collections=None)
82 |
83 | conv_norm_lrelu = functools.partial(conv_, normalizer_fn=norm, activation_fn=tf.nn.leaky_relu)
84 |
85 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
86 | z = x
87 | for i in range(n_downsamplings):
88 | d = min(dim * 2**i, MAX_DIM)
89 | z = conv_norm_lrelu(z, d, 4, 2)
90 | z = slim.flatten(z)
91 |
92 | logit_gan = tf.nn.leaky_relu(fc_(z, fc_dim))
93 | logit_gan = fc_(logit_gan, 1)
94 |
95 | logit_att = tf.nn.leaky_relu(fc_(z, fc_dim))
96 | logit_att = fc_(logit_att, n_atts)
97 |
98 | # variables and update operations
99 | self.variables = tf.global_variables(scope)
100 | self.trainable_variables = tf.trainable_variables(scope)
101 | self.reg_losses = tf.losses.get_regularization_losses(scope)
102 |
103 | return logit_gan, logit_att
104 |
105 |
106 | def get_model(name, n_atts, weight_decay=0.0):
107 | if name in ['model_128', 'model_256']:
108 | Genc = functools.partial(UNetGenc(), dim=64, n_downsamplings=5, weight_decay=weight_decay)
109 | Gdec = functools.partial(UNetGdec(), dim=64, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=weight_decay)
110 | D = functools.partial(ConvD(), n_atts=n_atts, dim=64, fc_dim=1024, n_downsamplings=5, weight_decay=weight_decay)
111 | elif name == 'model_384':
112 | Genc = functools.partial(UNetGenc(), dim=48, n_downsamplings=5, weight_decay=weight_decay)
113 | Gdec = functools.partial(UNetGdec(), dim=48, n_upsamplings=5, shortcut_layers=1, inject_layers=1, weight_decay=weight_decay)
114 | D = functools.partial(ConvD(), n_atts=n_atts, dim=48, fc_dim=512, n_downsamplings=5, weight_decay=weight_decay)
115 |     else:
116 |         raise ValueError('Invalid `name`! Allowed: %s!' % ['model_128', 'model_256', 'model_384'])
117 |     return Genc, Gdec, D
116 |
--------------------------------------------------------------------------------
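A minimal sketch of how these pieces compose in TF 1.x graph mode, mirroring the calls in test.py (13 attributes matches the default set used in this repo):

```python
import tensorflow as tf
import module

Genc, Gdec, D = module.get_model('model_128', n_atts=13)

xa = tf.placeholder(tf.float32, [None, 128, 128, 3])  # input images in [-1, 1]
b = tf.placeholder(tf.float32, [None, 13])            # target attribute vector

zs = Genc(xa, training=False)      # list of encoder feature maps
xb = Gdec(zs, b, training=False)   # edited image in [-1, 1]
logit_gan, logit_att = D(xb)       # adversarial and attribute logits
```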
/pics/first_view.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/first_view.png
--------------------------------------------------------------------------------
/pics/sample_validation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation.jpg
--------------------------------------------------------------------------------
/pics/sample_validation_256x.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_256x.jpg
--------------------------------------------------------------------------------
/pics/sample_validation_384x_hd.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_384x_hd.jpg
--------------------------------------------------------------------------------
/pics/sample_validation_40.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/sample_validation_40.jpg
--------------------------------------------------------------------------------
/pics/schema.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/schema.jpg
--------------------------------------------------------------------------------
/pics/slide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/slide.png
--------------------------------------------------------------------------------
/pics/style.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynnHo/AttGAN-Tensorflow/f1231fcc581133c9a7d8451e26a8bcd293fe03d5/pics/style.jpg
--------------------------------------------------------------------------------
/pylib/__init__.py:
--------------------------------------------------------------------------------
1 | from pylib.argument import *
2 | from pylib.processing import *
3 | from pylib.path import *
4 | from pylib.serialization import *
5 | from pylib.timer import *
6 |
7 | import pprint
8 |
9 | pp = pprint.pprint
10 |
--------------------------------------------------------------------------------
/pylib/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import json
4 |
5 | from pylib import serialization
6 |
7 |
8 | GLOBAL_COMMAND_PARSER = argparse.ArgumentParser()
9 |
10 |
11 | def _serialization_wrapper(func):
12 | @functools.wraps(func)
13 | def _wrapper(*args, **kwargs):
14 | to_json = kwargs.pop("to_json", None)
15 | to_yaml = kwargs.pop("to_yaml", None)
16 | namespace = func(*args, **kwargs)
17 | if to_json:
18 | args_to_json(to_json, namespace)
19 | if to_yaml:
20 | args_to_yaml(to_yaml, namespace)
21 | return namespace
22 | return _wrapper
23 |
24 |
25 | def str2bool(v):
26 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
27 | return True
28 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
29 | return False
30 | else:
31 | raise argparse.ArgumentTypeError('Boolean value expected!')
32 |
33 |
34 | def argument(*args, **kwargs):
35 | """Wrap argparse.add_argument."""
36 |     if 'type' in kwargs:
37 | if issubclass(kwargs['type'], bool):
38 | kwargs['type'] = str2bool
39 | elif issubclass(kwargs['type'], dict):
40 | kwargs['type'] = json.loads
41 | return GLOBAL_COMMAND_PARSER.add_argument(*args, **kwargs)
42 |
43 |
44 | arg = argument
45 |
46 |
47 | @_serialization_wrapper
48 | def args(args=None, namespace=None):
49 | """Parse args using the global parser."""
50 | namespace = GLOBAL_COMMAND_PARSER.parse_args(args=args, namespace=namespace)
51 | return namespace
52 |
53 |
54 | @_serialization_wrapper
55 | def args_from_xxx(obj, parser, check=True):
56 |     """Load args from a given source, ignoring `type` and `choices`; default values remain valid.
57 |
58 | Parameters
59 | ----------
60 | parser: function
61 | Should return a dict.
62 |
63 | """
64 | dict_ = parser(obj)
65 |     namespace = argparse.ArgumentParser().parse_args(args='')  # args='' so that no command-line args are consumed
66 | for k, v in dict_.items():
67 | namespace.__setattr__(k, v)
68 | return namespace
69 |
70 |
71 | args_from_dict = functools.partial(args_from_xxx, parser=lambda x: x)
72 | args_from_json = functools.partial(args_from_xxx, parser=serialization.load_json)
73 | args_from_yaml = functools.partial(args_from_xxx, parser=serialization.load_yaml)
74 |
75 |
76 | def args_to_json(path, namespace, **kwargs):
77 |     serialization.save_json(path, vars(namespace), **kwargs)
78 |
79 |
80 | def args_to_yaml(path, namespace, **kwargs):
81 |     serialization.save_yaml(path, vars(namespace), **kwargs)
82 |
--------------------------------------------------------------------------------
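A sketch of the global-parser helpers, in the same style train.py and test.py use them; the YAML path is illustrative and its directory must already exist:

```python
import pylib as py

py.arg('--crop_size', type=int, default=128)
py.arg('--use_gpu', type=bool, default=True)     # bool goes through str2bool
args = py.args(to_yaml='./output/settings.yml')  # parse, then dump to YAML
print(args.crop_size, args.use_gpu)
```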
/pylib/path.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import fnmatch
3 | import os
4 | import glob as _glob
5 | import sys
6 |
7 |
8 | def add_path(paths):
9 | if not isinstance(paths, (list, tuple)):
10 | paths = [paths]
11 | for path in paths:
12 | if path not in sys.path:
13 | sys.path.insert(0, path)
14 |
15 |
16 | def mkdir(paths):
17 | if not isinstance(paths, (list, tuple)):
18 | paths = [paths]
19 | for path in paths:
20 | if not os.path.exists(path):
21 | os.makedirs(path)
22 |
23 |
24 | def split(path):
25 | """Return dir, name, ext."""
26 | dir, name_ext = os.path.split(path)
27 | name, ext = os.path.splitext(name_ext)
28 | return dir, name, ext
29 |
30 |
31 | def directory(path):
32 | return split(path)[0]
33 |
34 |
35 | def name(path):
36 | return split(path)[1]
37 |
38 |
39 | def ext(path):
40 | return split(path)[2]
41 |
42 |
43 | def name_ext(path):
44 | return ''.join(split(path)[1:])
45 |
46 |
47 | def change_ext(path, ext):
48 | if ext[0] == '.':
49 | ext = ext[1:]
50 | return os.path.splitext(path)[0] + '.' + ext
51 |
52 |
53 | abspath = asbpath = os.path.abspath  # 'asbpath' kept as a backward-compatible alias of the original misspelling
54 |
55 |
56 | join = os.path.join
57 |
58 |
59 | def prefix(path, prefixes, sep='-'):
60 | prefixes = prefixes if isinstance(prefixes, (list, tuple)) else [prefixes]
61 | dir, name, ext = split(path)
62 | return join(dir, sep.join(prefixes) + sep + name + ext)
63 |
64 |
65 | def suffix(path, suffixes, sep='-'):
66 | suffixes = suffixes if isinstance(suffixes, (list, tuple)) else [suffixes]
67 | dir, name, ext = split(path)
68 | return join(dir, name + sep + sep.join(suffixes) + ext)
69 |
70 |
71 | def prefix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'):
72 | return prefix(path, prefixes=datetime.datetime.now().strftime(fmt), sep=sep)
73 |
74 |
75 | def suffix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'):
76 | return suffix(path, suffixes=datetime.datetime.now().strftime(fmt), sep=sep)
77 |
78 |
79 | def glob(dir, pats, recursive=False): # faster than match, python3 only
80 | pats = pats if isinstance(pats, (list, tuple)) else [pats]
81 | matches = []
82 | for pat in pats:
83 | matches += _glob.glob(os.path.join(dir, pat), recursive=recursive)
84 | return matches
85 |
86 |
87 | def match(dir, pats, recursive=False): # slow
88 | pats = pats if isinstance(pats, (list, tuple)) else [pats]
89 |
90 | iterator = list(os.walk(dir))
91 | if not recursive:
92 | iterator = iterator[0:1]
93 |
94 | matches = []
95 | for pat in pats:
96 | for root, _, file_names in iterator:
97 | for file_name in fnmatch.filter(file_names, pat):
98 | matches.append(os.path.join(root, file_name))
99 |
100 | return matches
101 |
--------------------------------------------------------------------------------
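A few of the path helpers above in action (paths are illustrative):

```python
import pylib as py

print(py.split('/tmp/face.jpg'))              # ('/tmp', 'face', '.jpg')
print(py.name_ext('/tmp/face.jpg'))           # 'face.jpg'
print(py.suffix('/tmp/face.jpg', '128'))      # '/tmp/face-128.jpg'
print(py.change_ext('/tmp/face.jpg', 'png'))  # '/tmp/face.png'
```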
/pylib/processing.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import functools
3 | import multiprocessing
4 |
5 |
6 | def run_parallels(work_fn, iterable, max_workers=None, chunksize=1, processing_bar=True, backend_executor=multiprocessing.Pool, debug=False):
7 | if not debug:
8 | with backend_executor(max_workers) as executor:
9 | try:
10 | works = executor.imap(work_fn, iterable, chunksize=chunksize) # for multiprocessing.Pool
11 | except:
12 | works = executor.map(work_fn, iterable, chunksize=chunksize)
13 |
14 | if processing_bar:
15 | try:
16 | import tqdm
17 | try:
18 | total = len(iterable)
19 | except:
20 | total = None
21 | works = tqdm.tqdm(works, total=total)
22 | except ImportError:
23 |                 print('`import tqdm` failed! Running without progress bar!')
24 |
25 | results = list(works)
26 | else:
27 | results = [work_fn(i) for i in iterable]
28 | return results
29 |
30 | run_parallels_mp = run_parallels
31 | run_parallels_cfprocess = functools.partial(run_parallels, backend_executor=concurrent.futures.ProcessPoolExecutor)
32 | run_parallels_cfthread = functools.partial(run_parallels, backend_executor=concurrent.futures.ThreadPoolExecutor)
33 |
34 |
35 | if __name__ == '__main__':
36 | import time
37 |
38 | def work(i):
39 | time.sleep(0.0001)
40 | i**i
41 | return i
42 |
43 | t = time.time()
44 | results = run_parallels_mp(work, range(10000), max_workers=2, chunksize=1, processing_bar=True, debug=False)
45 | for i in results:
46 | print(i)
47 | print(time.time() - t)
48 |
--------------------------------------------------------------------------------
/pylib/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 |
5 |
6 | def _check_ext(path, default_ext):
7 | name, ext = os.path.splitext(path)
8 | if ext == '':
9 | if default_ext[0] == '.':
10 | default_ext = default_ext[1:]
11 | path = name + '.' + default_ext
12 | return path
13 |
14 |
15 | def save_json(path, obj, **kwargs):
16 | # default
17 | if 'indent' not in kwargs:
18 | kwargs['indent'] = 4
19 | if 'separators' not in kwargs:
20 | kwargs['separators'] = (',', ': ')
21 |
22 | path = _check_ext(path, 'json')
23 |
24 | # wrap json.dump
25 | with open(path, 'w') as f:
26 | json.dump(obj, f, **kwargs)
27 |
28 |
29 | def load_json(path, **kwargs):
30 | # wrap json.load
31 | with open(path) as f:
32 | return json.load(f, **kwargs)
33 |
34 |
35 | def save_yaml(path, data, **kwargs):
36 | import oyaml as yaml
37 |
38 | path = _check_ext(path, 'yml')
39 |
40 | with open(path, 'w') as f:
41 | yaml.dump(data, f, **kwargs)
42 |
43 |
44 | def load_yaml(path, **kwargs):
45 | import oyaml as yaml
46 | with open(path) as f:
47 | return yaml.load(f, **kwargs)
48 |
49 |
50 | def save_pickle(path, obj, **kwargs):
51 |
52 | path = _check_ext(path, 'pkl')
53 |
54 | # wrap pickle.dump
55 | with open(path, 'wb') as f:
56 | pickle.dump(obj, f, **kwargs)
57 |
58 |
59 | def load_pickle(path, **kwargs):
60 | # wrap pickle.load
61 | with open(path, 'rb') as f:
62 | return pickle.load(f, **kwargs)
63 |
--------------------------------------------------------------------------------
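A sketch of the save/load wrappers; note that `_check_ext` appends the default extension when none is given (file names here are illustrative, and YAML requires the oyaml package):

```python
import pylib as py

cfg = {'crop_size': 128, 'att_names': ['Bald', 'Bangs']}
py.save_json('./settings', cfg)  # writes ./settings.json
py.save_yaml('./settings', cfg)  # writes ./settings.yml
print(py.load_json('./settings.json'))
```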
/pylib/timer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import timeit
3 |
4 |
5 | class Timer: # deprecated, use tqdm instead
6 | """A timer as a context manager.
7 |
8 | Wraps around a timer. A custom timer can be passed
9 | to the constructor. The default timer is timeit.default_timer.
10 |
11 | Note that the latter measures wall clock time, not CPU time!
12 | On Unix systems, it corresponds to time.time.
13 | On Windows systems, it corresponds to time.clock.
14 |
15 | Parameters
16 | ----------
17 | print_at_exit : boolean
18 | If True, print when exiting context.
19 | format : str
20 | `ms`, `s` or `datetime`.
21 |
22 | References
23 | ----------
24 | - https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py.
25 |
26 |
27 | """
28 |
29 | def __init__(self, fmt='s', print_at_exit=True, timer=timeit.default_timer):
30 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!"
31 | self._fmt = fmt
32 | self._print_at_exit = print_at_exit
33 | self._timer = timer
34 | self.start()
35 |
36 | def __enter__(self):
37 | """Start the timer in the context manager scope."""
38 | self.restart()
39 | return self
40 |
41 | def __exit__(self, exc_type, exc_value, exc_traceback):
42 | """Print the end time."""
43 | if self._print_at_exit:
44 | print(str(self))
45 |
46 | def __str__(self):
47 | return self.fmt(self.elapsed)[1]
48 |
49 | def start(self):
50 | self.start_time = self._timer()
51 |
52 | restart = start
53 |
54 | @property
55 | def elapsed(self):
56 | """Return the current elapsed time since last (re)start."""
57 | return self._timer() - self.start_time
58 |
59 | def fmt(self, second):
60 | if self._fmt == 'ms':
61 | time_fmt = second * 1000
62 | time_str = '%s %s' % (time_fmt, self._fmt)
63 | elif self._fmt == 's':
64 | time_fmt = second
65 | time_str = '%s %s' % (time_fmt, self._fmt)
66 | elif self._fmt == 'datetime':
67 | time_fmt = datetime.timedelta(seconds=second)
68 | time_str = str(time_fmt)
69 | return time_fmt, time_str
70 |
71 |
72 | def timeit(run_times=1, **timer_kwargs):
73 | """Function decorator displaying the function execution time.
74 |
75 | All kwargs are the arguments taken by the Timer class constructor.
76 |
77 | """
78 | # store Timer kwargs in local variable so the namespace isn't polluted
79 | # by different level args and kwargs
80 |
81 | def decorator(f):
82 | def wrapper(*args, **kwargs):
83 | timer_kwargs.update(print_at_exit=False)
84 | with Timer(**timer_kwargs) as t:
85 | for _ in range(run_times):
86 | out = f(*args, **kwargs)
87 | fmt = '[*] Execution time of function "%(function_name)s" for %(run_times)d runs is %(execution_time)s = %(execution_time_each)s * %(run_times)d [*]'
88 | context = {'function_name': f.__name__, 'run_times': run_times, 'execution_time': t, 'execution_time_each': t.fmt(t.elapsed / run_times)[1]}
89 | print(fmt % context)
90 | return out
91 | return wrapper
92 |
93 | return decorator
94 |
95 |
96 | if __name__ == "__main__":
97 | import time
98 |
99 | # 1
100 | print(1)
101 | with Timer() as t:
102 | time.sleep(1)
103 | print(t)
104 | time.sleep(1)
105 |
106 | with Timer(fmt='datetime') as t:
107 | time.sleep(1)
108 |
109 | # 2
110 | print(2)
111 | t = Timer(fmt='ms')
112 | time.sleep(2)
113 | print(t)
114 |
115 | t = Timer(fmt='datetime')
116 | time.sleep(1)
117 | print(t)
118 |
119 | # 3
120 | print(3)
121 |
122 | @timeit(run_times=5, fmt='s')
123 | def blah():
124 | time.sleep(2)
125 |
126 | blah()
127 |
--------------------------------------------------------------------------------
/results.md:
--------------------------------------------------------------------------------
1 | # [AttGAN](https://ieeexplore.ieee.org/document/8718508?source=authoralert) Results
2 |
3 | - Results of the **40**-attribute model!!! It's amazing that AttGAN still works, although some attributes are not handled well enough
4 |
5 | from left to right: *Input, Reconstruction, 5_o_Clock_Shadow, Arched_Eyebrows, Attractive, Bags_Under_Eyes, Bald, Bangs, Big_Lips, Big_Nose, Black_Hair, Blond_Hair, Blurry, Brown_Hair, Bushy_Eyebrows, Chubby, Double_Chin, Eyeglasses, Goatee, Gray_Hair, Heavy_Makeup, High_Cheekbones, Male, Mouth_Slightly_Open, Mustache, Narrow_Eyes, No_Beard, Oval_Face, Pale_Skin, Pointy_Nose, Receding_Hairline, Rosy_Cheeks, Sideburns, Smiling, Straight_Hair, Wavy_Hair, Wearing_Earrings, Wearing_Hat, Wearing_Lipstick, Wearing_Necklace, Wearing_Necktie, Young*
6 |
7 |
8 |
9 | - Results of **256x256**-input model
10 |
11 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young*
12 |
13 |
14 |
15 | - Results of **384x384**-input model
16 |
17 | from left to right: *Input, Reconstruction, Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses, Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young*
18 |
19 |
--------------------------------------------------------------------------------
/scripts/align.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import argparse
6 | from functools import partial
7 | from multiprocessing import Pool
8 | import os
9 | import re
10 |
11 | import cropper
12 | import numpy as np
13 | import tqdm
14 |
15 |
16 | # ==============================================================================
17 | # = param =
18 | # ==============================================================================
19 |
20 | parser = argparse.ArgumentParser()
21 | # main
22 | parser.add_argument('--img_dir', dest='img_dir', default='./data/img_celeba/img_celeba')
23 | parser.add_argument('--save_dir', dest='save_dir', default='./data/img_celeba/aligned')
24 | parser.add_argument('--landmark_file', dest='landmark_file', default='./data/img_celeba/landmark.txt')
25 | parser.add_argument('--standard_landmark_file', dest='standard_landmark_file', default='./data/img_celeba/standard_landmark_68pts.txt')
26 | parser.add_argument('--crop_size_h', dest='crop_size_h', type=int, default=572)
27 | parser.add_argument('--crop_size_w', dest='crop_size_w', type=int, default=572)
28 | parser.add_argument('--move_h', dest='move_h', type=float, default=0.25)
29 | parser.add_argument('--move_w', dest='move_w', type=float, default=0.)
30 | parser.add_argument('--save_format', dest='save_format', choices=['jpg', 'png'], default='jpg')
31 | parser.add_argument('--n_worker', dest='n_worker', type=int, default=8)
32 | # others
33 | parser.add_argument('--face_factor', dest='face_factor', type=float, help='The factor of face area relative to the output image.', default=0.45)
34 | parser.add_argument('--align_type', dest='align_type', choices=['affine', 'similarity'], default='similarity')
35 | parser.add_argument('--order', dest='order', type=int, choices=[0, 1, 2, 3, 4, 5], help='The order of interpolation.', default=3)
36 | parser.add_argument('--mode', dest='mode', choices=['constant', 'edge', 'symmetric', 'reflect', 'wrap'], default='edge')
37 | args = parser.parse_args()
38 |
39 |
40 | # ==============================================================================
41 | # = opencv first =
42 | # ==============================================================================
43 |
44 | _DEFAULT_JPG_QUALITY = 95
45 | try:
46 |     import cv2
47 |     imread = cv2.imread
48 |     imwrite = partial(cv2.imwrite, params=[int(cv2.IMWRITE_JPEG_QUALITY), _DEFAULT_JPG_QUALITY])
49 |     align_crop = cropper.align_crop_opencv
50 |     print('Use OpenCV')
51 | except:
52 |     import skimage.io as io
53 |     imread = io.imread
54 |     imwrite = partial(io.imsave, quality=_DEFAULT_JPG_QUALITY)
55 |     align_crop = cropper.align_crop_skimage
56 |     print('Importing OpenCV failed. Use scikit-image.')
57 |
58 |
59 | # ==============================================================================
60 | # = run =
61 | # ==============================================================================
62 |
63 | # count landmarks
64 | with open(args.landmark_file) as f:
65 | line = f.readline()
66 | n_landmark = len(re.split('[ ]+', line)[1:]) // 2
67 |
68 | # read data
69 | img_names = np.genfromtxt(args.landmark_file, dtype=str, usecols=0)  # builtin str/float (np.str/np.float are removed in modern NumPy)
70 | landmarks = np.genfromtxt(args.landmark_file, dtype=float, usecols=range(1, n_landmark * 2 + 1)).reshape(-1, n_landmark, 2)
71 | standard_landmark = np.genfromtxt(args.standard_landmark_file, dtype=float).reshape(n_landmark, 2)
72 | standard_landmark[:, 0] += args.move_w
73 | standard_landmark[:, 1] += args.move_h
74 |
75 | # data dir
76 | save_dir = os.path.join(args.save_dir, 'align_size(%d,%d)_move(%.3f,%.3f)_face_factor(%.3f)_%s' % (args.crop_size_h, args.crop_size_w, args.move_h, args.move_w, args.face_factor, args.save_format))
77 | data_dir = os.path.join(save_dir, 'data')
78 | if not os.path.isdir(data_dir):
79 | os.makedirs(data_dir)
80 |
81 |
82 | def work(i): # a single work
83 | for _ in range(3): # try three times
84 | try:
85 | img = imread(os.path.join(args.img_dir, img_names[i]))
86 | img_crop, tformed_landmarks = align_crop(img,
87 | landmarks[i],
88 | standard_landmark,
89 | crop_size=(args.crop_size_h, args.crop_size_w),
90 | face_factor=args.face_factor,
91 | align_type=args.align_type,
92 | order=args.order,
93 | mode=args.mode)
94 |
95 | name = os.path.splitext(img_names[i])[0] + '.' + args.save_format
96 | path = os.path.join(data_dir, name)
97 | if not os.path.isdir(os.path.split(path)[0]):
98 | os.makedirs(os.path.split(path)[0])
99 | imwrite(path, img_crop)
100 |
101 | tformed_landmarks.shape = -1
102 | name_landmark_str = ('%s' + ' %.1f' * n_landmark * 2) % ((name, ) + tuple(tformed_landmarks))
103 | succeed = True
104 | break
105 | except:
106 | succeed = False
107 | if succeed:
108 | return name_landmark_str
109 | else:
110 | print('%s fails!' % img_names[i])
111 |
112 |
113 | if __name__ == '__main__':
114 | pool = Pool(args.n_worker)
115 | name_landmark_strs = list(tqdm.tqdm(pool.imap(work, range(len(img_names))), total=len(img_names)))
116 | pool.close()
117 | pool.join()
118 |
119 | landmarks_path = os.path.join(save_dir, 'landmark.txt')
120 | with open(landmarks_path, 'w') as f:
121 | for name_landmark_str in name_landmark_strs:
122 | if name_landmark_str:
123 | f.write(name_landmark_str + '\n')
124 |
--------------------------------------------------------------------------------
/scripts/cropper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def align_crop_opencv(img,
5 | src_landmarks,
6 | standard_landmarks,
7 | crop_size=512,
8 | face_factor=0.7,
9 | align_type='similarity',
10 | order=3,
11 | mode='edge'):
12 | """Align and crop a face image by landmarks.
13 |
14 | Arguments:
15 | img : Face image to be aligned and cropped.
16 | src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
17 | standard_landmarks : Standard shape, should be normalized.
18 | crop_size : Output image size, should be 1. int for (crop_size, crop_size)
19 | or 2. (int, int) for (crop_size_h, crop_size_w).
20 | face_factor : The factor of face area relative to the output image.
21 | align_type : 'similarity' or 'affine'.
22 | order : The order of interpolation. The order has to be in the range 0-5:
23 | - 0: INTER_NEAREST
24 | - 1: INTER_LINEAR
25 | - 2: INTER_AREA
26 | - 3: INTER_CUBIC
27 | - 4: INTER_LANCZOS4
28 | - 5: INTER_LANCZOS4
29 | mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
30 | Points outside the boundaries of the input are filled according
31 | to the given mode.
32 | """
33 | # set OpenCV
34 | import cv2
35 | inter = {0: cv2.INTER_NEAREST, 1: cv2.INTER_LINEAR, 2: cv2.INTER_AREA,
36 | 3: cv2.INTER_CUBIC, 4: cv2.INTER_LANCZOS4, 5: cv2.INTER_LANCZOS4}
37 | border = {'constant': cv2.BORDER_CONSTANT, 'edge': cv2.BORDER_REPLICATE,
38 | 'symmetric': cv2.BORDER_REFLECT, 'reflect': cv2.BORDER_REFLECT101,
39 | 'wrap': cv2.BORDER_WRAP}
40 |
41 | # check
42 | assert align_type in ['affine', 'similarity'], 'Invalid `align_type`! Allowed: %s!' % ['affine', 'similarity']
43 | assert order in [0, 1, 2, 3, 4, 5], 'Invalid `order`! Allowed: %s!' % [0, 1, 2, 3, 4, 5]
44 | assert mode in ['constant', 'edge', 'symmetric', 'reflect', 'wrap'], 'Invalid `mode`! Allowed: %s!' % ['constant', 'edge', 'symmetric', 'reflect', 'wrap']
45 |
46 | # crop size
47 | if isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:
48 | crop_size_h = crop_size[0]
49 | crop_size_w = crop_size[1]
50 | elif isinstance(crop_size, int):
51 | crop_size_h = crop_size_w = crop_size
52 | else:
53 | raise Exception('Invalid `crop_size`! `crop_size` should be 1. int for (crop_size, crop_size) or 2. (int, int) for (crop_size_h, crop_size_w)!')
54 |
55 | # estimate transform matrix
56 | trg_landmarks = standard_landmarks * max(crop_size_h, crop_size_w) * face_factor + np.array([crop_size_w // 2, crop_size_h // 2])
57 | if align_type == 'affine':
58 | tform = cv2.estimateAffine2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0]
59 | else:
60 | tform = cv2.estimateAffinePartial2D(trg_landmarks, src_landmarks, ransacReprojThreshold=np.Inf)[0]
61 |
62 | # warp image by given transform
63 | output_shape = (crop_size_h, crop_size_w)
64 | img_crop = cv2.warpAffine(img, tform, output_shape[::-1], flags=cv2.WARP_INVERSE_MAP + inter[order], borderMode=border[mode])
65 |
66 | # get transformed landmarks
67 | tformed_landmarks = cv2.transform(np.expand_dims(src_landmarks, axis=0), cv2.invertAffineTransform(tform))[0]
68 |
69 | return img_crop, tformed_landmarks
70 |
71 |
72 | def align_crop_skimage(img,
73 | src_landmarks,
74 | standard_landmarks,
75 | crop_size=512,
76 | face_factor=0.7,
77 | align_type='similarity',
78 | order=3,
79 | mode='edge'):
80 | """Align and crop a face image by landmarks.
81 |
82 | Arguments:
83 | img : Face image to be aligned and cropped.
84 | src_landmarks : [[x_1, y_1], ..., [x_n, y_n]].
85 | standard_landmarks : Standard shape, should be normalized.
86 | crop_size : Output image size, should be 1. int for (crop_size, crop_size)
87 | or 2. (int, int) for (crop_size_h, crop_size_w).
88 | face_factor : The factor of face area relative to the output image.
89 | align_type : 'similarity' or 'affine'.
90 | order : The order of interpolation. The order has to be in the range 0-5:
91 | - 0: INTER_NEAREST
92 | - 1: INTER_LINEAR
93 | - 2: INTER_AREA
94 | - 3: INTER_CUBIC
95 | - 4: INTER_LANCZOS4
96 | - 5: INTER_LANCZOS4
97 | mode : One of ['constant', 'edge', 'symmetric', 'reflect', 'wrap'].
98 | Points outside the boundaries of the input are filled according
99 | to the given mode.
100 | """
101 | raise NotImplementedError("'align_crop_skimage' is not implemented!")
102 |
--------------------------------------------------------------------------------
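A hypothetical sketch of calling `align_crop_opencv`; the 5-point landmarks below are invented purely for illustration (real ones come from landmark.txt, and the standard shape from standard_landmark_68pts.txt):

```python
import numpy as np
import cropper

img = np.zeros((218, 178, 3), dtype=np.uint8)  # dummy image
src = np.array([[69, 111], [108, 111], [88, 133],
                [72, 152], [104, 152]], dtype=np.float64)  # made-up landmarks
std = (src - src.mean(axis=0)) / 200.0  # crudely normalized standard shape
crop, lm = cropper.align_crop_opencv(img, src, std,
                                     crop_size=128, face_factor=0.45)
print(crop.shape, lm.shape)  # (128, 128, 3) (5, 2)
```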
/scripts/split_CelebA-HQ.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | # ==============================================================================
5 | # = param =
6 | # ==============================================================================
7 |
8 | label_file = './data/CelebAMask-HQ/CelebAMask-HQ-attribute-anno.txt'
9 | save_dir = './data/CelebAMask-HQ'
10 |
11 |
12 | # ==============================================================================
13 | # = run =
14 | # ==============================================================================
15 |
16 | with open(label_file, 'r') as f:
17 | lines = f.readlines()[2:]
18 |
19 | random.seed(100)
20 | random.shuffle(lines)
21 |
22 | lines_train = lines[:26500]
23 | lines_val = lines[26500:27000]
24 | lines_test = lines[27000:]
25 |
26 | with open(os.path.join(save_dir, 'train_label.txt'), 'w') as f:
27 | f.writelines(lines_train)
28 |
29 | with open(os.path.join(save_dir, 'val_label.txt'), 'w') as f:
30 | f.writelines(lines_val)
31 |
32 | with open(os.path.join(save_dir, 'test_label.txt'), 'w') as f:
33 | f.writelines(lines_test)
34 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import imlib as im
4 | import numpy as np
5 | import pylib as py
6 | import tensorflow as tf
7 | import tflib as tl
8 | import tqdm
9 |
10 | import data
11 | import module
12 |
13 |
14 | # ==============================================================================
15 | # = param =
16 | # ==============================================================================
17 |
18 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data')
19 | py.arg('--test_label_path', default='./data/img_celeba/test_label.txt')
20 | py.arg('--test_int', type=float, default=2)
21 |
22 |
23 | py.arg('--experiment_name', default='default')
24 | args_ = py.args()
25 |
26 | # output_dir
27 | output_dir = py.join('output', args_.experiment_name)
28 |
29 | # load the saved training settings and override them with command-line args
30 | args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
31 | args.__dict__.update(args_.__dict__)
32 |
33 | # others
34 | n_atts = len(args.att_names)
35 |
36 | sess = tl.session()
37 | sess.__enter__() # make default
38 |
39 |
40 | # ==============================================================================
41 | # = data and model =
42 | # ==============================================================================
43 |
44 | # data
45 | test_dataset, len_test_dataset = data.make_celeba_dataset(args.img_dir, args.test_label_path, args.att_names, args.n_samples,
46 | load_size=args.load_size, crop_size=args.crop_size,
47 | training=False, drop_remainder=False, shuffle=False, repeat=None)
48 | test_iter = test_dataset.make_one_shot_iterator()
49 |
50 |
51 | # ==============================================================================
52 | # = graph =
53 | # ==============================================================================
54 |
55 | def sample_graph():
56 | # ======================================
57 | # = graph =
58 | # ======================================
59 |
60 | test_next = test_iter.get_next()
61 |
62 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
63 | # model
64 | Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)
65 |
66 | # placeholders & inputs
67 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
68 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts])
69 |
70 | # sample graph
71 | x = Gdec(Genc(xa, training=False), b_, training=False)
72 | else:
73 |         # load frozen model
74 | with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'rb') as f:
75 | graph_def = tf.GraphDef()
76 | graph_def.ParseFromString(f.read())
77 | tf.import_graph_def(graph_def, name='generator')
78 |
79 | # placeholders & inputs
80 | xa = sess.graph.get_tensor_by_name('generator/xa:0')
81 | b_ = sess.graph.get_tensor_by_name('generator/b_:0')
82 |
83 | # sample graph
84 | x = sess.graph.get_tensor_by_name('generator/xb:0')
85 |
86 | # ======================================
87 | # = run function =
88 | # ======================================
89 |
90 | save_dir = './output/%s/samples_testing_%s' % (args.experiment_name, '{:g}'.format(args.test_int))
91 | py.mkdir(save_dir)
92 |
93 | def run():
94 | cnt = 0
95 | for _ in tqdm.trange(len_test_dataset):
96 | # data for sampling
97 | xa_ipt, a_ipt = sess.run(test_next)
98 | b_ipt_list = [a_ipt] # the first is for reconstruction
99 | for i in range(n_atts):
100 | tmp = np.array(a_ipt, copy=True)
101 |                 tmp[:, i] = 1 - tmp[:, i]   # invert the i-th attribute
102 | tmp = data.check_attribute_conflict(tmp, args.att_names[i], args.att_names)
103 | b_ipt_list.append(tmp)
104 |
105 | x_opt_list = [xa_ipt]
106 | for i, b_ipt in enumerate(b_ipt_list):
107 |                 b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # map {0, 1} labels to {-1, 1}
108 | if i > 0: # i == 0 is for reconstruction
109 | b__ipt[..., i - 1] = b__ipt[..., i - 1] * args.test_int
110 | x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt})
111 | x_opt_list.append(x_opt)
112 | sample = np.transpose(x_opt_list, (1, 2, 0, 3, 4))
113 | sample = np.reshape(sample, (sample.shape[0], -1, sample.shape[2] * sample.shape[3], sample.shape[4]))
114 |
115 | for s in sample:
116 | cnt += 1
117 | im.imwrite(s, '%s/%d.jpg' % (save_dir, cnt))
118 |
119 | return run
120 |
121 |
122 | sample = sample_graph()
123 |
124 |
125 | # ==============================================================================
126 | # = test =
127 | # ==============================================================================
128 |
129 | # checkpoint
130 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
131 | checkpoint = tl.Checkpoint(
132 | {v.name: v for v in tf.global_variables()},
133 | py.join(output_dir, 'checkpoints'),
134 | max_to_keep=1
135 | )
136 | checkpoint.restore().run_restore_ops()
137 |
138 | sample()
139 |
140 | sess.close()
141 |
--------------------------------------------------------------------------------
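Assuming a trained experiment (with its settings.yml and checkpoints) under ./output/default, a typical invocation is:

CUDA_VISIBLE_DEVICES=0 python test.py --experiment_name default --test_int 2

Each output row then holds the input image, its reconstruction, and one single-attribute edit per remaining column.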
/test_multi.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import imlib as im
4 | import numpy as np
5 | import pylib as py
6 | import tensorflow as tf
7 | import tflib as tl
8 | import tqdm
9 |
10 | import data
11 | import module
12 |
13 |
14 | # ==============================================================================
15 | # = param =
16 | # ==============================================================================
17 |
18 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data')
19 | py.arg('--test_label_path', default='./data/img_celeba/test_label.txt')
20 | py.arg('--test_att_names', choices=data.ATT_ID.keys(), nargs='+', default=['Bangs', 'Mustache'])
21 | py.arg('--test_ints', type=float, nargs='+', default=2)
22 |
23 | py.arg('--experiment_name', default='default')
24 | args_ = py.args()
25 |
26 | # output_dir
27 | output_dir = py.join('output', args_.experiment_name)
28 |
29 | # load the saved training settings
30 | args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
31 | args.__dict__.update(args_.__dict__)
32 |
33 | # others
34 | n_atts = len(args.att_names)
35 | if not isinstance(args.test_ints, list):
36 | args.test_ints = [args.test_ints] * len(args.test_att_names)
37 | elif len(args.test_ints) == 1:
38 | args.test_ints = args.test_ints * len(args.test_att_names)
39 |
40 | sess = tl.session()
41 | sess.__enter__() # make default
42 |
43 |
44 | # ==============================================================================
45 | # = data and model =
46 | # ==============================================================================
47 |
48 | # data
49 | test_dataset, len_test_dataset = data.make_celeba_dataset(args.img_dir, args.test_label_path, args.att_names, args.n_samples,
50 | load_size=args.load_size, crop_size=args.crop_size,
51 | training=False, drop_remainder=False, shuffle=False, repeat=None)
52 | test_iter = test_dataset.make_one_shot_iterator()
53 |
54 |
55 | # ==============================================================================
56 | # = graph =
57 | # ==============================================================================
58 |
59 | def sample_graph():
60 | # ======================================
61 | # = graph =
62 | # ======================================
63 |
64 | test_next = test_iter.get_next()
65 |
66 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
67 | # model
68 | Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)
69 |
70 | # placeholders & inputs
71 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
72 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts])
73 |
74 | # sample graph
75 | x = Gdec(Genc(xa, training=False), b_, training=False)
76 | else:
77 |         # load frozen model
78 | with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'rb') as f:
79 | graph_def = tf.GraphDef()
80 | graph_def.ParseFromString(f.read())
81 | tf.import_graph_def(graph_def, name='generator')
82 |
83 | # placeholders & inputs
84 | xa = sess.graph.get_tensor_by_name('generator/xa:0')
85 | b_ = sess.graph.get_tensor_by_name('generator/b_:0')
86 |
87 | # sample graph
88 | x = sess.graph.get_tensor_by_name('generator/xb:0')
89 |
90 | # ======================================
91 | # = run function =
92 | # ======================================
93 |
94 | save_dir = './output/%s/samples_testing_multi' % args.experiment_name
95 | tmp = ''
96 | for test_att_name, test_int in zip(args.test_att_names, args.test_ints):
97 | tmp += '_%s_%s' % (test_att_name, '{:g}'.format(test_int))
98 | save_dir = py.join(save_dir, tmp[1:])
99 | py.mkdir(save_dir)
100 |
101 | def run():
102 | cnt = 0
103 | for _ in tqdm.trange(len_test_dataset):
104 | # data for sampling
105 | xa_ipt, a_ipt = sess.run(test_next)
106 | b_ipt = np.copy(a_ipt)
107 | for test_att_name in args.test_att_names:
108 | i = args.att_names.index(test_att_name)
109 | b_ipt[..., i] = 1 - b_ipt[..., i]
110 | b_ipt = data.check_attribute_conflict(b_ipt, test_att_name, args.att_names)
111 |
112 |             b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # map {0, 1} labels to {-1, 1}
113 | for test_att_name, test_int in zip(args.test_att_names, args.test_ints):
114 | i = args.att_names.index(test_att_name)
115 | b__ipt[..., i] = b__ipt[..., i] * test_int
116 |
117 | x_opt_list = [xa_ipt]
118 | x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt})
119 | x_opt_list.append(x_opt)
120 | sample = np.transpose(x_opt_list, (1, 2, 0, 3, 4))
121 | sample = np.reshape(sample, (sample.shape[0], -1, sample.shape[2] * sample.shape[3], sample.shape[4]))
122 |
123 | for s in sample:
124 | cnt += 1
125 | im.imwrite(s, '%s/%d.jpg' % (save_dir, cnt))
126 |
127 | return run
128 |
129 |
130 | sample = sample_graph()
131 |
132 |
133 | # ==============================================================================
134 | # = test =
135 | # ==============================================================================
136 |
137 | # checkpoint
138 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
139 | checkpoint = tl.Checkpoint(
140 | {v.name: v for v in tf.global_variables()},
141 | py.join(output_dir, 'checkpoints'),
142 | max_to_keep=1
143 | )
144 | checkpoint.restore().run_restore_ops()
145 |
146 | sample()
147 |
148 | sess.close()
149 |
--------------------------------------------------------------------------------
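This variant edits several attributes at once, with per-attribute intensities matched positionally to --test_att_names, e.g.:

CUDA_VISIBLE_DEVICES=0 python test_multi.py --experiment_name default --test_att_names Bangs Mustache --test_ints 2 1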
/test_slide.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import imlib as im
4 | import numpy as np
5 | import pylib as py
6 | import tensorflow as tf
7 | import tflib as tl
8 | import tqdm
9 |
10 | import data
11 | import module
12 |
13 |
14 | # ==============================================================================
15 | # = param =
16 | # ==============================================================================
17 |
18 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data')
19 | py.arg('--test_label_path', default='./data/img_celeba/test_label.txt')
20 | py.arg('--test_att_name', choices=data.ATT_ID.keys(), default='Pale_Skin')
21 | py.arg('--test_int_min', type=float, default=-2)
22 | py.arg('--test_int_max', type=float, default=2)
23 | py.arg('--test_int_step', type=float, default=0.5)
24 |
25 | py.arg('--experiment_name', default='default')
26 | args_ = py.args()
27 |
28 | # output_dir
29 | output_dir = py.join('output', args_.experiment_name)
30 |
31 | # load the saved training settings
32 | args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
33 | args.__dict__.update(args_.__dict__)
34 |
35 | # others
36 | n_atts = len(args.att_names)
37 |
38 | sess = tl.session()
39 | sess.__enter__() # make default
40 |
41 |
42 | # ==============================================================================
43 | # = data and model =
44 | # ==============================================================================
45 |
46 | # data
47 | test_dataset, len_test_dataset = data.make_celeba_dataset(args.img_dir, args.test_label_path, args.att_names, args.n_samples,
48 | load_size=args.load_size, crop_size=args.crop_size,
49 | training=False, drop_remainder=False, shuffle=False, repeat=None)
50 | test_iter = test_dataset.make_one_shot_iterator()
51 |
52 |
53 | # ==============================================================================
54 | # = graph =
55 | # ==============================================================================
56 |
57 | def sample_graph():
58 | # ======================================
59 | # = graph =
60 | # ======================================
61 |
62 | test_next = test_iter.get_next()
63 |
64 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
65 | # model
66 | Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)
67 |
68 | # placeholders & inputs
69 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
70 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts])
71 |
72 | # sample graph
73 | x = Gdec(Genc(xa, training=False), b_, training=False)
74 | else:
75 |         # load frozen model
76 | with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'rb') as f:
77 | graph_def = tf.GraphDef()
78 | graph_def.ParseFromString(f.read())
79 | tf.import_graph_def(graph_def, name='generator')
80 |
81 | # placeholders & inputs
82 | xa = sess.graph.get_tensor_by_name('generator/xa:0')
83 | b_ = sess.graph.get_tensor_by_name('generator/b_:0')
84 |
85 | # sample graph
86 | x = sess.graph.get_tensor_by_name('generator/xb:0')
87 |
88 | # ======================================
89 | # = run function =
90 | # ======================================
91 |
92 | save_dir = './output/%s/samples_testing_slide/%s_%s_%s_%s' % \
93 | (args.experiment_name, args.test_att_name, '{:g}'.format(args.test_int_min), '{:g}'.format(args.test_int_max), '{:g}'.format(args.test_int_step))
94 | py.mkdir(save_dir)
95 |
96 | def run():
97 | cnt = 0
98 | for _ in tqdm.trange(len_test_dataset):
99 | # data for sampling
100 | xa_ipt, a_ipt = sess.run(test_next)
101 | b_ipt = np.copy(a_ipt)
102 |             b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # map {0, 1} labels to {-1, 1}
103 |
104 | x_opt_list = [xa_ipt]
105 | for test_int in np.arange(args.test_int_min, args.test_int_max + 1e-5, args.test_int_step):
106 | b__ipt[:, args.att_names.index(args.test_att_name)] = test_int
107 | x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt})
108 | x_opt_list.append(x_opt)
109 | sample = np.transpose(x_opt_list, (1, 2, 0, 3, 4))
110 | sample = np.reshape(sample, (sample.shape[0], -1, sample.shape[2] * sample.shape[3], sample.shape[4]))
111 |
112 | for s in sample:
113 | cnt += 1
114 | im.imwrite(s, '%s/%d.jpg' % (save_dir, cnt))
115 |
116 | return run
117 |
118 |
119 | sample = sample_graph()
120 |
121 |
122 | # ==============================================================================
123 | # = test =
124 | # ==============================================================================
125 |
126 | # checkpoint
127 | if not os.path.exists(py.join(output_dir, 'generator.pb')):
128 | checkpoint = tl.Checkpoint(
129 | {v.name: v for v in tf.global_variables()},
130 | py.join(output_dir, 'checkpoints'),
131 | max_to_keep=1
132 | )
133 | checkpoint.restore().run_restore_ops()
134 |
135 | sample()
136 |
137 | sess.close()
138 |
--------------------------------------------------------------------------------
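This variant sweeps one attribute's intensity from --test_int_min to --test_int_max in steps of --test_int_step, e.g.:

CUDA_VISIBLE_DEVICES=0 python test_slide.py --experiment_name default --test_att_name Pale_Skin --test_int_min -2 --test_int_max 2 --test_int_step 0.5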
/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.data import *
2 | from tflib.image import *
3 | from tflib.layers import *
4 | from tflib.losses import *
5 | from tflib.metrics import *
6 | from tflib.ops import *
7 | from tflib.utils import *
8 |
--------------------------------------------------------------------------------
/tflib/data/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.data.dataset import *
2 |
--------------------------------------------------------------------------------
/tflib/data/dataset.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 |
3 | import tensorflow as tf
4 |
5 |
6 | def batch_dataset(dataset,
7 | batch_size,
8 | drop_remainder=True,
9 | n_prefetch_batch=1,
10 | filter_fn=None,
11 | map_fn=None,
12 | n_map_threads=None,
13 | filter_after_map=False,
14 | shuffle=True,
15 | shuffle_buffer_size=None,
16 | repeat=None):
17 | # set defaults
18 | if n_map_threads is None:
19 | n_map_threads = multiprocessing.cpu_count()
20 | if shuffle and shuffle_buffer_size is None:
21 | shuffle_buffer_size = max(batch_size * 128, 2048) # set the minimum buffer size as 2048
22 |
23 |     # [*] it is more efficient to `shuffle` before `map`/`filter`, since `map`/`filter` can be costly
24 | if shuffle:
25 | dataset = dataset.shuffle(shuffle_buffer_size)
26 |
27 | if not filter_after_map:
28 | if filter_fn:
29 | dataset = dataset.filter(filter_fn)
30 |
31 | if map_fn:
32 | dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
33 |
34 | else: # [*] this is slower
35 | if map_fn:
36 | dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
37 |
38 | if filter_fn:
39 | dataset = dataset.filter(filter_fn)
40 |
41 | dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
42 |
43 | dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)
44 |
45 | return dataset
46 |
47 |
48 | def memory_data_batch_dataset(memory_data,
49 | batch_size,
50 | drop_remainder=True,
51 | n_prefetch_batch=1,
52 | filter_fn=None,
53 | map_fn=None,
54 | n_map_threads=None,
55 | filter_after_map=False,
56 | shuffle=True,
57 | shuffle_buffer_size=None,
58 | repeat=None):
59 | """Batch dataset of memory data.
60 |
61 | Parameters
62 | ----------
63 | memory_data : nested structure of tensors/ndarrays/lists
64 |
65 | """
66 | dataset = tf.data.Dataset.from_tensor_slices(memory_data)
67 | dataset = batch_dataset(dataset,
68 | batch_size,
69 | drop_remainder=drop_remainder,
70 | n_prefetch_batch=n_prefetch_batch,
71 | filter_fn=filter_fn,
72 | map_fn=map_fn,
73 | n_map_threads=n_map_threads,
74 | filter_after_map=filter_after_map,
75 | shuffle=shuffle,
76 | shuffle_buffer_size=shuffle_buffer_size,
77 | repeat=repeat)
78 | return dataset
79 |
80 |
81 | def disk_image_batch_dataset(img_paths,
82 | batch_size,
83 | labels=None,
84 | drop_remainder=True,
85 | n_prefetch_batch=1,
86 | filter_fn=None,
87 | map_fn=None,
88 | n_map_threads=None,
89 | filter_after_map=False,
90 | shuffle=True,
91 | shuffle_buffer_size=None,
92 | repeat=None):
93 | """Batch dataset of disk image for PNG and JPEG.
94 |
95 | Parameters
96 | ----------
97 | img_paths : 1d-tensor/ndarray/list of str
98 | labels : nested structure of tensors/ndarrays/lists
99 |
100 | """
101 | if labels is None:
102 | memory_data = img_paths
103 | else:
104 | memory_data = (img_paths, labels)
105 |
106 | def parse_fn(path, *label):
107 | img = tf.io.read_file(path)
108 |         img = tf.image.decode_png(img, 3)  # fix channels to 3 (TF's decode_png also decodes JPEG)
109 | return (img,) + label
110 |
111 | if map_fn: # fuse `map_fn` and `parse_fn`
112 | def map_fn_(*args):
113 | return map_fn(*parse_fn(*args))
114 | else:
115 | map_fn_ = parse_fn
116 |
117 | dataset = memory_data_batch_dataset(memory_data,
118 | batch_size,
119 | drop_remainder=drop_remainder,
120 | n_prefetch_batch=n_prefetch_batch,
121 | filter_fn=filter_fn,
122 | map_fn=map_fn_,
123 | n_map_threads=n_map_threads,
124 | filter_after_map=filter_after_map,
125 | shuffle=shuffle,
126 | shuffle_buffer_size=shuffle_buffer_size,
127 | repeat=repeat)
128 |
129 | return dataset
130 |
--------------------------------------------------------------------------------
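A minimal usage sketch of disk_image_batch_dataset (the folder, label array, and sizes are placeholders), showing how decoding, an optional map_fn, shuffling, and batching compose:

import glob

import tensorflow as tf
import tflib as tl

img_paths = sorted(glob.glob('./data/img_celeba/*.jpg'))  # hypothetical image folder
labels = [[0] * 13 for _ in img_paths]                    # hypothetical attribute labels

def map_fn(img, label):
    # `img` arrives as the uint8 tensor produced by `parse_fn` above
    img = tf.image.resize_images(img, [128, 128])
    img = img / 127.5 - 1  # scale to [-1, 1]
    return img, label

# repeat=None repeats the dataset indefinitely
dataset = tl.disk_image_batch_dataset(img_paths, 32, labels=labels,
                                      map_fn=map_fn, shuffle=True, repeat=None)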
/tflib/image/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.image.image import *
2 | from tflib.image.filter import *
3 |
--------------------------------------------------------------------------------
/tflib/image/filter.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def gaussian_kernel2d(kernel_radius, std):
5 |     d = tf.distributions.Normal(0.0, float(std))
6 |     vals = d.prob(tf.range(start=-kernel_radius, limit=kernel_radius + 1, dtype=tf.float32))
7 | kernel = vals[:, None] * vals[None, :]
8 | kernel /= tf.reduce_sum(kernel)
9 | return kernel
10 |
11 |
12 | def filter2d(image, kernel, padding, data_format=None):
13 | kernel = kernel[:, :, None, None]
14 | if data_format is None or data_format == "NHWC":
15 | kernel = tf.tile(kernel, [1, 1, image.shape[3], 1])
16 | elif data_format == "NCHW":
17 | kernel = tf.tile(kernel, [1, 1, image.shape[1], 1])
18 | return tf.nn.depthwise_conv2d(image, kernel, strides=[1, 1, 1, 1], padding=padding, data_format=data_format)
19 |
20 |
21 | def gaussian_filter2d(image, kernel_radius, std, padding, data_format=None):
22 |     kernel = gaussian_kernel2d(kernel_radius, std)
23 |     return filter2d(image, kernel, padding, data_format=data_format)
24 |
--------------------------------------------------------------------------------
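A small usage sketch, blurring a batch of NHWC images (radius and std chosen arbitrarily):

import tensorflow as tf
import tflib as tl

images = tf.placeholder(tf.float32, [None, 128, 128, 3])
blurred = tl.gaussian_filter2d(images, kernel_radius=3, std=1.5, padding='SAME')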
/tflib/image/image.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 | import random
4 |
5 | import tensorflow as tf
6 |
7 |
8 | def center_crop(image, size):
9 | # for image of shape [batch, height, width, channels] or [height, width, channels]
10 | if not isinstance(size, (tuple, list)):
11 | size = [size, size]
12 | offset_height = (tf.shape(image)[-3] - size[0]) // 2
13 | offset_width = (tf.shape(image)[-2] - size[1]) // 2
14 | return tf.image.crop_to_bounding_box(image, offset_height, offset_width, size[0], size[1])
15 |
16 |
17 | def color_jitter(image, brightness=0, contrast=0, saturation=0, hue=0):
18 | """Color jitter.
19 |
20 | Examples
21 | --------
22 | >>> color_jitter(img, 25, 0.2, 0.2, 0.1)
23 |
24 | """
25 | tforms = []
26 | if brightness > 0:
27 | tforms.append(functools.partial(tf.image.random_brightness, max_delta=brightness))
28 | if contrast > 0:
29 | tforms.append(functools.partial(tf.image.random_contrast, lower=max(0, 1 - contrast), upper=1 + contrast))
30 | if saturation > 0:
31 | tforms.append(functools.partial(tf.image.random_saturation, lower=max(0, 1 - saturation), upper=1 + saturation))
32 | if hue > 0:
33 | tforms.append(functools.partial(tf.image.random_hue, max_delta=hue))
34 |
35 | random.shuffle(tforms)
36 | for tform in tforms:
37 | image = tform(image)
38 |
39 | return image
40 |
41 |
42 | def random_grayscale(image, p=0.1):
43 | return tf.cond(pred=tf.random.uniform(()) < p,
44 | true_fn=lambda: tf.image.adjust_saturation(image, 0),
45 | false_fn=lambda: image)
46 |
47 |
48 | def random_rotate(images, max_degrees, interpolation='BILINEAR'):
49 | # Randomly rotate image(s) counterclockwise by the angle(s) uniformly chosen from [-max_degree(s), max_degree(s)].
50 | max_degrees = tf.convert_to_tensor(max_degrees, dtype=tf.float32)
51 | angles = tf.random.uniform(tf.shape(max_degrees), minval=-1.0, maxval=1.0) * max_degrees / 180.0 * math.pi
52 | return tf.contrib.image.rotate(images, angles, interpolation=interpolation)
53 |
--------------------------------------------------------------------------------
/tflib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.layers.layers import *
2 | from tflib.layers.layers_slim import *
3 |
--------------------------------------------------------------------------------
/tflib/layers/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def adaptive_instance_normalization(x, gamma, beta, epsilon=1e-5):
5 | # modified from https://github.com/taki0112/MUNIT-Tensorflow/blob/master/ops.py
6 | # x: (N, H, W, C), gamma: (N, 1, 1, C), beta: (N, 1, 1, C)
7 |
8 | c_mean, c_var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
9 | c_std = tf.sqrt(c_var + epsilon)
10 |
11 | return gamma * ((x - c_mean) / c_std) + beta
12 |
--------------------------------------------------------------------------------
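A quick sketch of feeding AdaIN with style statistics (all tensors illustrative; in MUNIT-style models gamma/beta are predicted by a style encoder):

import tensorflow as tf
import tflib as tl

content = tf.random.uniform([4, 32, 32, 256])  # content feature map (N, H, W, C)
gamma = tf.random.uniform([4, 1, 1, 256])      # per-channel style scale
beta = tf.random.uniform([4, 1, 1, 256])       # per-channel style shift
stylized = tl.adaptive_instance_normalization(content, gamma, beta)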
/tflib/layers/layers_slim.py:
--------------------------------------------------------------------------------
1 | # functions compatible with tensorflow.contrib
2 |
3 | import six
4 |
5 | import tensorflow as tf
6 |
7 | from tensorflow.contrib.framework.python.ops import add_arg_scope
8 | from tensorflow.contrib.framework.python.ops import variables
9 | from tensorflow.contrib.layers.python.layers import initializers
10 | from tensorflow.contrib.layers.python.layers import utils
11 |
12 | from tensorflow.python.framework import ops
13 | from tensorflow.python.ops import array_ops
14 | from tensorflow.python.ops import init_ops
15 | from tensorflow.python.ops import nn
16 | from tensorflow.python.ops import standard_ops
17 | from tensorflow.python.ops import variable_scope
18 |
19 |
20 | @add_arg_scope
21 | def fully_connected(inputs,
22 | num_outputs,
23 | activation_fn=nn.relu,
24 | normalizer_fn=None,
25 | normalizer_params=None,
26 | weights_normalizer_fn=None,
27 | weights_normalizer_params=None,
28 | weights_initializer=initializers.xavier_initializer(),
29 | weights_regularizer=None,
30 | biases_initializer=init_ops.zeros_initializer(),
31 | biases_regularizer=None,
32 | reuse=None,
33 | variables_collections=None,
34 | outputs_collections=None,
35 | trainable=True,
36 | scope=None):
37 |   # Copied and modified from tensorflow-0.12.0 contrib.layers.fully_connected;
38 |   # adds the weights_normalizer_* options.
39 | """Adds a fully connected layer.
40 |
41 | `fully_connected` creates a variable called `weights`, representing a fully
42 | connected weight matrix, which is multiplied by the `inputs` to produce a
43 | `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
44 | `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
45 | None and a `biases_initializer` is provided then a `biases` variable would be
46 |   created and added to the hidden units. Finally, if `activation_fn` is not `None`,
47 | it is applied to the hidden units as well.
48 |
49 |   Note: if `inputs` has a rank greater than 2, then `inputs` is flattened
50 | prior to the initial matrix multiply by `weights`.
51 |
52 | Args:
53 |     inputs: A tensor with at least rank 2 and a static value for the last dimension,
54 | i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
55 | num_outputs: Integer or long, the number of output units in the layer.
56 | activation_fn: activation function, set to None to skip it and maintain
57 | a linear activation.
58 | normalizer_fn: normalization function to use instead of `biases`. If
59 | `normalizer_fn` is provided then `biases_initializer` and
60 | `biases_regularizer` are ignored and `biases` are not created nor added.
61 |       Default is None, meaning no normalizer function.
62 | normalizer_params: normalization function parameters.
63 | weights_normalizer_fn: weights normalization function.
64 | weights_normalizer_params: weights normalization function parameters.
65 | weights_initializer: An initializer for the weights.
66 | weights_regularizer: Optional regularizer for the weights.
67 | biases_initializer: An initializer for the biases. If None skip biases.
68 | biases_regularizer: Optional regularizer for the biases.
69 | reuse: whether or not the layer and its variables should be reused. To be
70 | able to reuse the layer scope must be given.
71 | variables_collections: Optional list of collections for all the variables or
72 | a dictionary containing a different list of collections per variable.
73 | outputs_collections: collection to add the outputs.
74 | trainable: If `True` also add variables to the graph collection
75 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
76 | scope: Optional scope for variable_scope.
77 |
78 | Returns:
79 | the tensor variable representing the result of the series of operations.
80 |
81 | Raises:
82 | ValueError: if x has rank less than 2 or if its last dimension is not set.
83 | """
84 | if not (isinstance(num_outputs, six.integer_types)):
85 |     raise ValueError('num_outputs should be int or long, got %s.' % (num_outputs,))
86 | with variable_scope.variable_scope(scope, 'fully_connected', [inputs],
87 | reuse=reuse) as sc:
88 | inputs = ops.convert_to_tensor(inputs)
89 | dtype = inputs.dtype.base_dtype
90 | inputs_shape = inputs.get_shape()
91 | num_input_units = utils.last_dimension(inputs_shape, min_rank=2)
92 |
93 | static_shape = inputs_shape.as_list()
94 | static_shape[-1] = num_outputs
95 |
96 |     out_shape = array_ops.unstack(array_ops.shape(inputs), len(static_shape))
97 | out_shape[-1] = num_outputs
98 |
99 | weights_shape = [num_input_units, num_outputs]
100 | weights_collections = utils.get_variable_collections(
101 | variables_collections, 'weights')
102 | weights = variables.model_variable('weights',
103 | shape=weights_shape,
104 | dtype=dtype,
105 | initializer=weights_initializer,
106 | regularizer=weights_regularizer,
107 | collections=weights_collections,
108 | trainable=trainable)
109 | if weights_normalizer_fn is not None:
110 | weights_normalizer_params = weights_normalizer_params or {}
111 | weights = weights_normalizer_fn(weights, **weights_normalizer_params)
112 | if len(static_shape) > 2:
113 | # Reshape inputs
114 | inputs = array_ops.reshape(inputs, [-1, num_input_units])
115 | outputs = standard_ops.matmul(inputs, weights)
116 | if normalizer_fn is not None:
117 | normalizer_params = normalizer_params or {}
118 | outputs = normalizer_fn(outputs, **normalizer_params)
119 | else:
120 | if biases_initializer is not None:
121 | biases_collections = utils.get_variable_collections(
122 | variables_collections, 'biases')
123 | biases = variables.model_variable('biases',
124 | shape=[num_outputs, ],
125 | dtype=dtype,
126 | initializer=biases_initializer,
127 | regularizer=biases_regularizer,
128 | collections=biases_collections,
129 | trainable=trainable)
130 | outputs = nn.bias_add(outputs, biases)
131 | if activation_fn is not None:
132 | outputs = activation_fn(outputs)
133 | if len(static_shape) > 2:
134 | # Reshape back outputs
135 |       outputs = array_ops.reshape(outputs, array_ops.stack(out_shape))
136 | outputs.set_shape(static_shape)
137 | return utils.collect_named_outputs(outputs_collections,
138 | sc.original_name_scope, outputs)
139 |
140 |
141 | @add_arg_scope
142 | def convolution(inputs,
143 | num_outputs,
144 | kernel_size,
145 | stride=1,
146 | padding='SAME',
147 | data_format=None,
148 | rate=1,
149 | activation_fn=nn.relu,
150 | normalizer_fn=None,
151 | normalizer_params=None,
152 | weights_normalizer_fn=None,
153 | weights_normalizer_params=None,
154 | weights_initializer=initializers.xavier_initializer(),
155 | weights_regularizer=None,
156 | biases_initializer=init_ops.zeros_initializer(),
157 | biases_regularizer=None,
158 | reuse=None,
159 | variables_collections=None,
160 | outputs_collections=None,
161 | trainable=True,
162 | scope=None):
163 |   # Copied and modified from tensorflow-0.12.0 contrib.layers.convolution;
164 |   # adds the weights_normalizer_* options.
165 | """Adds an N-D convolution followed by an optional batch_norm layer.
166 |
167 | It is required that 1 <= N <= 3.
168 |
169 | `convolution` creates a variable called `weights`, representing the
170 | convolutional kernel, that is convolved (actually cross-correlated) with the
171 | `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
172 | provided (such as `batch_norm`), it is then applied. Otherwise, if
173 | `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
174 |   variable would be created and added to the activations. Finally, if
175 | `activation_fn` is not `None`, it is applied to the activations as well.
176 |
177 |   Performs atrous convolution with input stride/dilation rate equal to `rate`
178 | if a value > 1 for any dimension of `rate` is specified. In this case
179 | `stride` values != 1 are not supported.
180 |
181 | Args:
182 | inputs: a Tensor of rank N+2 of shape
183 | `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
184 | not start with "NC" (default), or
185 | `[batch_size, in_channels] + input_spatial_shape` if data_format starts
186 | with "NC".
187 | num_outputs: integer, the number of output filters.
188 | kernel_size: a sequence of N positive integers specifying the spatial
189 |       dimensions of the filters. Can be a single integer to specify the same
190 | value for all spatial dimensions.
191 | stride: a sequence of N positive integers specifying the stride at which to
192 | compute output. Can be a single integer to specify the same value for all
193 | spatial dimensions. Specifying any `stride` value != 1 is incompatible
194 | with specifying any `rate` value != 1.
195 | padding: one of `"VALID"` or `"SAME"`.
196 | data_format: A string or None. Specifies whether the channel dimension of
197 | the `input` and output is the last dimension (default, or if `data_format`
198 | does not start with "NC"), or the second dimension (if `data_format`
199 | starts with "NC"). For N=1, the valid values are "NWC" (default) and
200 | "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". For
201 | N=3, currently the only valid value is "NDHWC".
202 | rate: a sequence of N positive integers specifying the dilation rate to use
203 |       for atrous convolution. Can be a single integer to specify the same
204 | value for all spatial dimensions. Specifying any `rate` value != 1 is
205 | incompatible with specifying any `stride` value != 1.
206 | activation_fn: activation function, set to None to skip it and maintain
207 | a linear activation.
208 | normalizer_fn: normalization function to use instead of `biases`. If
209 | `normalizer_fn` is provided then `biases_initializer` and
210 | `biases_regularizer` are ignored and `biases` are not created nor added.
211 |       Default is None, meaning no normalizer function.
212 | normalizer_params: normalization function parameters.
213 | weights_normalizer_fn: weights normalization function.
214 | weights_normalizer_params: weights normalization function parameters.
215 | weights_initializer: An initializer for the weights.
216 | weights_regularizer: Optional regularizer for the weights.
217 | biases_initializer: An initializer for the biases. If None skip biases.
218 | biases_regularizer: Optional regularizer for the biases.
219 | reuse: whether or not the layer and its variables should be reused. To be
220 | able to reuse the layer scope must be given.
221 | variables_collections: optional list of collections for all the variables or
222 |       a dictionary containing a different list of collections per variable.
223 | outputs_collections: collection to add the outputs.
224 | trainable: If `True` also add variables to the graph collection
225 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
226 | scope: Optional scope for `variable_scope`.
227 |
228 | Returns:
229 | a tensor representing the output of the operation.
230 |
231 | Raises:
232 | ValueError: if `data_format` is invalid.
233 |     ValueError: if both `rate` and `stride` are not uniformly 1.
234 | """
235 | if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC']:
236 | raise ValueError('Invalid data_format: %r' % (data_format,))
237 | with variable_scope.variable_scope(scope, 'Conv', [inputs],
238 | reuse=reuse) as sc:
239 | inputs = ops.convert_to_tensor(inputs)
240 | dtype = inputs.dtype.base_dtype
241 | input_rank = inputs.get_shape().ndims
242 | if input_rank is None:
243 | raise ValueError('Rank of inputs must be known')
244 | if input_rank < 3 or input_rank > 5:
245 | raise ValueError('Rank of inputs is %d, which is not >= 3 and <= 5' %
246 | input_rank)
247 | conv_dims = input_rank - 2
248 | kernel_size = utils.n_positive_integers(conv_dims, kernel_size)
249 | stride = utils.n_positive_integers(conv_dims, stride)
250 | rate = utils.n_positive_integers(conv_dims, rate)
251 |
252 | if data_format is None or data_format.endswith('C'):
253 | num_input_channels = inputs.get_shape()[input_rank - 1].value
254 | elif data_format.startswith('NC'):
255 | num_input_channels = inputs.get_shape()[1].value
256 | else:
257 | raise ValueError('Invalid data_format')
258 |
259 | if num_input_channels is None:
260 | raise ValueError('Number of in_channels must be known.')
261 |
262 | weights_shape = (
263 | list(kernel_size) + [num_input_channels, num_outputs])
264 | weights_collections = utils.get_variable_collections(variables_collections,
265 | 'weights')
266 | weights = variables.model_variable('weights',
267 | shape=weights_shape,
268 | dtype=dtype,
269 | initializer=weights_initializer,
270 | regularizer=weights_regularizer,
271 | collections=weights_collections,
272 | trainable=trainable)
273 | if weights_normalizer_fn is not None:
274 | weights_normalizer_params = weights_normalizer_params or {}
275 | weights = weights_normalizer_fn(weights, **weights_normalizer_params)
276 | outputs = nn.convolution(input=inputs,
277 | filter=weights,
278 | dilation_rate=rate,
279 | strides=stride,
280 | padding=padding,
281 | data_format=data_format)
282 | if normalizer_fn is not None:
283 | normalizer_params = normalizer_params or {}
284 | outputs = normalizer_fn(outputs, **normalizer_params)
285 | else:
286 | if biases_initializer is not None:
287 | biases_collections = utils.get_variable_collections(
288 | variables_collections, 'biases')
289 | biases = variables.model_variable('biases',
290 | shape=[num_outputs],
291 | dtype=dtype,
292 | initializer=biases_initializer,
293 | regularizer=biases_regularizer,
294 | collections=biases_collections,
295 | trainable=trainable)
296 | outputs = nn.bias_add(outputs, biases, data_format=data_format)
297 | if activation_fn is not None:
298 | outputs = activation_fn(outputs)
299 | return utils.collect_named_outputs(outputs_collections,
300 | sc.original_name_scope, outputs)
301 |
302 |
303 | convolution2d = convolution
304 | convolution3d = convolution
305 |
306 |
307 | @add_arg_scope
308 | def spectral_normalization(weights,
309 | num_iterations=1,
310 | epsilon=1e-12,
311 | u_initializer=tf.random_normal_initializer(),
312 | updates_collections=tf.GraphKeys.UPDATE_OPS,
313 | is_training=True,
314 | reuse=None,
315 | variables_collections=None,
316 | outputs_collections=None,
317 | scope=None):
318 | with tf.variable_scope(scope, 'SpectralNorm', [weights], reuse=reuse) as sc:
319 | weights = tf.convert_to_tensor(weights)
320 |
321 | dtype = weights.dtype.base_dtype
322 |
323 | w_t = tf.reshape(weights, [-1, weights.shape.as_list()[-1]])
324 | w = tf.transpose(w_t)
325 | m, n = w.shape.as_list()
326 |
327 | u_collections = utils.get_variable_collections(variables_collections, 'u')
328 | u = tf.get_variable("u",
329 | shape=[m, 1],
330 | dtype=dtype,
331 | initializer=u_initializer,
332 | trainable=False,
333 | collections=u_collections,)
334 | sigma_collections = utils.get_variable_collections(variables_collections, 'sigma')
335 | sigma = tf.get_variable('sigma',
336 | shape=[],
337 | dtype=dtype,
338 | initializer=tf.zeros_initializer(),
339 | trainable=False,
340 | collections=sigma_collections)
341 |
342 | def _power_iteration(i, u, v):
343 | v_ = tf.nn.l2_normalize(tf.matmul(w_t, u), epsilon=epsilon)
344 | u_ = tf.nn.l2_normalize(tf.matmul(w, v_), epsilon=epsilon)
345 | return i + 1, u_, v_
346 |
347 | _, u_, v_ = tf.while_loop(
348 | cond=lambda i, _1, _2: i < num_iterations,
349 | body=_power_iteration,
350 | loop_vars=[tf.constant(0), u, tf.zeros(shape=[n, 1], dtype=tf.float32)]
351 | )
352 | u_ = tf.stop_gradient(u_)
353 | v_ = tf.stop_gradient(v_)
354 | sigma_ = tf.matmul(tf.transpose(u_), tf.matmul(w, v_))[0, 0]
355 |
356 | update_u = u.assign(u_)
357 | update_sigma = sigma.assign(sigma_)
358 | if updates_collections is None:
359 | def _force_update():
360 | with tf.control_dependencies([update_u, update_sigma]):
361 | return tf.identity(sigma_)
362 |
363 | sigma_ = utils.smart_cond(is_training, _force_update, lambda: sigma)
364 | weights_sn = weights / sigma_
365 | else:
366 | sigma_ = utils.smart_cond(is_training, lambda: sigma_, lambda: sigma)
367 | weights_sn = weights / sigma_
368 | tf.add_to_collections(updates_collections, update_u)
369 | tf.add_to_collections(updates_collections, update_sigma)
370 |
371 | return utils.collect_named_outputs(outputs_collections, sc.name, weights_sn)
372 |
373 |
374 | # Simple alias.
375 | conv2d = convolution2d
376 | conv3d = convolution3d
377 |
--------------------------------------------------------------------------------
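Because these layers expose weights_normalizer_fn, spectral normalization plugs in per layer; a hedged sketch (shapes and hyper-parameters illustrative):

import functools

import tensorflow as tf
import tflib as tl

x = tf.placeholder(tf.float32, [None, 64, 64, 3])
h = tl.conv2d(x, num_outputs=64, kernel_size=4, stride=2,
              weights_normalizer_fn=functools.partial(tl.spectral_normalization,
                                                      is_training=True))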
/tflib/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.losses.losses import *
2 |
--------------------------------------------------------------------------------
/tflib/losses/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def center_loss(features, labels, num_classes, alpha=0.5, updates_collections=tf.GraphKeys.UPDATE_OPS, scope=None):
5 | # modified from https://github.com/EncodeTS/TensorFlow_Center_Loss/blob/master/center_loss.py
6 |
7 | assert features.shape.ndims == 2, 'The rank of `features` should be 2!'
8 | assert 0 <= alpha <= 1, '`alpha` should be in [0, 1]!'
9 |
10 | with tf.variable_scope(scope, 'center_loss', [features, labels]):
11 | centers = tf.get_variable('centers', shape=[num_classes, features.get_shape()[-1]], dtype=tf.float32,
12 | initializer=tf.constant_initializer(0), trainable=False)
13 |
14 | centers_batch = tf.gather(centers, labels)
15 | diff = centers_batch - features
16 | _, unique_idx, unique_count = tf.unique_with_counts(labels)
17 | appear_times = tf.gather(unique_count, unique_idx)
18 | appear_times = tf.reshape(appear_times, [-1, 1])
19 | diff = diff / tf.cast((1 + appear_times), tf.float32)
20 | diff = alpha * diff
21 | update_centers = tf.scatter_sub(centers, labels, diff)
22 |
23 | center_loss = 0.5 * tf.reduce_mean(tf.reduce_sum((centers_batch - features)**2, axis=-1))
24 |
25 | if updates_collections is None:
26 | with tf.control_dependencies([update_centers]):
27 | center_loss = tf.identity(center_loss)
28 | else:
29 | tf.add_to_collections(updates_collections, update_centers)
30 |
31 | return center_loss, centers
32 |
33 |
34 | def sigmoid_focal_loss(multi_class_labels, logits, gamma=2.0):
35 | epsilon = 1e-8
36 | multi_class_labels = tf.cast(multi_class_labels, logits.dtype)
37 |
38 | p = tf.sigmoid(logits)
39 | pt = p * multi_class_labels + (1 - p) * (1 - multi_class_labels)
40 | focal_loss = tf.reduce_mean(- (1 - pt)**gamma * tf.log(tf.maximum(pt, epsilon)))
41 |
42 | return focal_loss
43 |
--------------------------------------------------------------------------------
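The loss above is the focal loss FL = mean(-(1 - p_t)^gamma * log(p_t)), with p_t the predicted probability of the true label; a sketch over a hypothetical 13-attribute batch:

import tensorflow as tf
import tflib as tl

labels = tf.placeholder(tf.float32, [None, 13])  # multi-hot attribute labels
logits = tf.placeholder(tf.float32, [None, 13])
cls_loss = tl.sigmoid_focal_loss(labels, logits, gamma=2.0)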
/tflib/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.metrics.metrics import *
2 |
--------------------------------------------------------------------------------
/tflib/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def resettable_metric(metric_fn, metric_params, scope=None):
5 | with tf.variable_scope(scope, 'resettable_metric') as sc:
6 | metric_returns = metric_fn(**metric_params)
7 | reset_op = tf.variables_initializer(tf.local_variables(sc.name))
8 | return metric_returns + (reset_op,)
9 |
10 |
11 | def make_resettable(metric_fn, scope=None):
12 | def resettable_metric_fn(*args, **kwargs):
13 | with tf.variable_scope(scope, 'resettable_metric') as sc:
14 | metric_returns = metric_fn(*args, **kwargs)
15 | reset_op = tf.variables_initializer(tf.local_variables(sc.name))
16 | return metric_returns + (reset_op,)
17 | return resettable_metric_fn
18 |
--------------------------------------------------------------------------------
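A usage sketch of make_resettable around tf.metrics.accuracy, so the metric's internal counters can be cleared between evaluation passes (tensors hypothetical):

import tensorflow as tf
import tflib as tl

labels = tf.placeholder(tf.int64, [None])
predictions = tf.placeholder(tf.int64, [None])
acc, update_op, reset_op = tl.make_resettable(tf.metrics.accuracy)(labels, predictions)
# run `reset_op` before each evaluation pass and `update_op` once per batch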
/tflib/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.ops.ops import *
2 |
--------------------------------------------------------------------------------
/tflib/ops/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def minmax_norm(x, epsilon=1e-12):
5 | x = tf.to_float(x)
6 | min_val = tf.reduce_min(x)
7 | max_val = tf.reduce_max(x)
8 | norm_x = (x - min_val) / tf.maximum((max_val - min_val), epsilon)
9 | return norm_x
10 |
11 |
12 | def reshape(x, shape):
13 | x = tf.convert_to_tensor(x)
14 | shape = [x.shape[i] if shape[i] == 0 else shape[i] for i in range(len(shape))]
15 | shape = [tf.shape(x)[i] if shape[i] is None else shape[i] for i in range(len(shape))]
16 | return tf.reshape(x, shape)
17 |
--------------------------------------------------------------------------------
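In this reshape, a 0 entry keeps the corresponding static dimension and a None entry falls back to the dynamic tf.shape value, e.g.:

import tensorflow as tf
import tflib as tl

x = tf.placeholder(tf.float32, [None, 8, 8, 64])
y = tl.reshape(x, [None, 0, 8 * 64])  # -> (dynamic batch, 8, 512)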
/tflib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from tflib.utils.collection import *
2 | from tflib.utils.learning_rate import *
3 | from tflib.utils.distribute import *
4 | from tflib.utils.utils import *
5 |
--------------------------------------------------------------------------------
/tflib/utils/collection.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import tensorflow as tf
4 |
5 |
6 | def tensors_filter(tensors,
7 | includes='',
8 | includes_combine_type='or',
9 | excludes=None,
10 | excludes_combine_type='or'):
11 | # NOTICE: `includes` = [] means nothing to be included, and `excludes` = [] means nothing to be excluded
12 |
13 | if excludes is None:
14 | excludes = []
15 |
16 |     assert isinstance(tensors, (list, tuple)), '`tensors` should be a list or tuple!'
17 | assert isinstance(includes, (str, list, tuple)), '`includes` should be a string or a list(tuple) of strings!'
18 | assert includes_combine_type in ['or', 'and'], "`includes_combine_type` should be 'or' or 'and'!"
19 | assert isinstance(excludes, (str, list, tuple)), '`excludes` should be a string or a list(tuple) of strings!'
20 | assert excludes_combine_type in ['or', 'and'], "`excludes_combine_type` should be 'or' or 'and'!"
21 |
22 | def _select(filters, combine_type):
23 |         if filters in [[], ()]:
24 | return []
25 |
26 | filters = filters if isinstance(filters, (list, tuple)) else [filters]
27 |
28 | selected = []
29 | for t in tensors:
30 | if combine_type == 'or':
31 | for filt in filters:
32 | if filt in t.name:
33 | selected.append(t)
34 | break
35 | elif combine_type == 'and':
36 | for filt in filters:
37 | if filt not in t.name:
38 | break
39 | else:
40 | selected.append(t)
41 |
42 | return selected
43 |
44 | include_set = _select(includes, includes_combine_type)
45 | exclude_set = _select(excludes, excludes_combine_type)
46 | select_set = [t for t in include_set if t not in exclude_set]
47 |
48 | return select_set
49 |
50 |
51 | def get_collection(key,
52 | includes='',
53 | includes_combine_type='or',
54 | excludes=None,
55 | excludes_combine_type='or'):
56 | tensors = tf.get_collection(key)
57 | return tensors_filter(tensors,
58 | includes,
59 | includes_combine_type,
60 | excludes,
61 | excludes_combine_type)
62 |
63 |
64 | global_variables = partial(get_collection, key=tf.GraphKeys.GLOBAL_VARIABLES)
65 | local_variables = partial(get_collection, key=tf.GraphKeys.LOCAL_VARIABLES)
66 | trainable_variables = partial(get_collection, key=tf.GraphKeys.TRAINABLE_VARIABLES)
67 | update_ops = partial(get_collection, key=tf.GraphKeys.UPDATE_OPS)
68 |
--------------------------------------------------------------------------------
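A sketch of the derived helpers, selecting variables by substrings of their names (scope names hypothetical):

import tflib as tl

g_vars = tl.trainable_variables(includes='Generator')  # names containing 'Generator'
d_vars = tl.trainable_variables(includes='Discriminator',
                                excludes=['moving_mean', 'moving_variance'])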
/tflib/utils/distribute.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def average_gradients(tower_grads):
5 | """Calculate the average gradient for each shared variable across all towers.
6 |
7 | Note that this function provides a synchronization point across all towers.
8 |
9 | Parameters
10 | ----------
11 | tower_grads:
12 | List of lists of (gradient, variable) tuples. The outer list
13 | is over individual gradients. The inner list is over the gradient
14 | calculation for each tower.
15 |
16 | Returns
17 | -------
18 | List of pairs of (gradient, variable) where the gradient has been averaged
19 | across all towers.
20 |
21 | """
22 | average_grads = []
23 | for grad_and_vars in zip(*tower_grads):
24 | # Note that each grad_and_vars looks like the following:
25 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
26 | grads = []
27 | for g, _ in grad_and_vars:
28 | # Add 0 dimension to the gradients to represent the tower.
29 | expanded_g = tf.expand_dims(g, 0)
30 |
31 | # Append on a 'tower' dimension which we will average over below.
32 | grads.append(expanded_g)
33 |
34 | # Average over the 'tower' dimension.
35 | grad = tf.concat(axis=0, values=grads)
36 | grad = tf.reduce_mean(grad, 0)
37 |
38 | # Keep in mind that the Variables are redundant because they are shared
39 | # across towers. So .. we will just return the first tower's pointer to
40 | # the Variable.
41 | v = grad_and_vars[0][1]
42 | grad_and_var = (grad, v)
43 | average_grads.append(grad_and_var)
44 | return average_grads
45 |
--------------------------------------------------------------------------------
/tflib/utils/learning_rate.py:
--------------------------------------------------------------------------------
1 | class LinearDecayLR:
2 |     # if `step` < `step_start_decay`: use the fixed initial learning rate
3 |     # else: linearly decay the learning rate toward zero
4 |
5 |     # `step` should run from 0 (inclusive) to `steps` (exclusive)
6 |
7 | def __init__(self, initial_learning_rate, steps, step_start_decay):
8 | self._initial_learning_rate = initial_learning_rate
9 | self._steps = steps
10 | self._step_start_decay = step_start_decay
11 | self.current_learning_rate = initial_learning_rate
12 |
13 | def __call__(self, step):
14 | if step >= self._step_start_decay:
15 | self.current_learning_rate = self._initial_learning_rate * (1 - 1 / (self._steps - self._step_start_decay + 1) * (step - self._step_start_decay + 1))
16 | else:
17 | self.current_learning_rate = self._initial_learning_rate
18 | return self.current_learning_rate
19 |
20 |
21 | class StepDecayLR:
22 |
23 | def __init__(self, initial_learning_rate, step_size, decay_rate):
24 | super(StepDecayLR, self).__init__()
25 | self._initial_learning_rate = initial_learning_rate
26 | self._step_size = step_size
27 | self._decay_rate = decay_rate
28 | self.current_learning_rate = initial_learning_rate
29 |
30 | def __call__(self, step):
31 | self.current_learning_rate = self._initial_learning_rate * self._decay_rate ** (step // self._step_size)
32 | return self.current_learning_rate
33 |
--------------------------------------------------------------------------------
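A worked example of LinearDecayLR: with initial_learning_rate=2e-4, steps=200, and step_start_decay=100, the divisor (steps - step_start_decay + 1) is 101, so:

lr_fn = LinearDecayLR(2e-4, steps=200, step_start_decay=100)
lr_fn(0)    # 2e-4 (constant while step < 100)
lr_fn(100)  # 2e-4 * (1 - 1/101)   ~= 1.98e-4 (first decayed step)
lr_fn(199)  # 2e-4 * (1 - 100/101) ~= 1.98e-6 (last step, near zero)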
/tflib/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | def session(graph=None,
6 | allow_soft_placement=True,
7 | log_device_placement=False,
8 | allow_growth=True):
9 | """Return a Session with simple config."""
10 | config = tf.ConfigProto(allow_soft_placement=allow_soft_placement,
11 | log_device_placement=log_device_placement)
12 | config.gpu_options.allow_growth = allow_growth
13 | return tf.Session(graph=graph, config=config)
14 |
15 |
16 | class Checkpoint:
17 | """Enhanced "tf.train.Checkpoint"."""
18 |
19 | def __init__(self,
20 | checkpoint_kwargs, # for "tf.train.Checkpoint"
21 | directory, # for "tf.train.CheckpointManager"
22 | max_to_keep=5,
23 | keep_checkpoint_every_n_hours=None):
24 | self.checkpoint = tf.train.Checkpoint(**checkpoint_kwargs)
25 | self.manager = tf.train.CheckpointManager(self.checkpoint, directory, max_to_keep, keep_checkpoint_every_n_hours)
26 |
27 | def restore(self, save_path=None):
28 | save_path = self.manager.latest_checkpoint if save_path is None else save_path
29 | return self.checkpoint.restore(save_path)
30 |
31 | def save(self, file_prefix_or_checkpoint_number=None, session=None):
32 | if isinstance(file_prefix_or_checkpoint_number, str):
33 | return self.checkpoint.save(file_prefix_or_checkpoint_number, session=session)
34 | else:
35 | return self.manager.save(checkpoint_number=file_prefix_or_checkpoint_number)
36 |
37 | def __getattr__(self, attr):
38 | if hasattr(self.checkpoint, attr):
39 | return getattr(self.checkpoint, attr)
40 | elif hasattr(self.manager, attr):
41 | return getattr(self.manager, attr)
42 | else:
43 |             self.__getattribute__(attr)  # raises AttributeError for unknown attributes
44 |
45 |
46 | def summary(name_data_dict,
47 | types=['mean', 'std', 'max', 'min', 'sparsity', 'histogram'],
48 | name='summary'):
49 | """Summary.
50 |
51 | Examples
52 | --------
53 | >>> summary({'a': data_a, 'b': data_b})
54 |
55 | """
56 | def _summary(name, data):
57 | summaries = []
58 | if data.shape == ():
59 | summaries.append(tf.summary.scalar(name, data))
60 | else:
61 | if 'mean' in types:
62 | summaries.append(tf.summary.scalar(name + '-mean', tf.math.reduce_mean(data)))
63 | if 'std' in types:
64 | summaries.append(tf.summary.scalar(name + '-std', tf.math.reduce_std(data)))
65 | if 'max' in types:
66 | summaries.append(tf.summary.scalar(name + '-max', tf.math.reduce_max(data)))
67 | if 'min' in types:
68 | summaries.append(tf.summary.scalar(name + '-min', tf.math.reduce_min(data)))
69 | if 'sparsity' in types:
70 | summaries.append(tf.summary.scalar(name + '-sparsity', tf.math.zero_fraction(data)))
71 | if 'histogram' in types:
72 | summaries.append(tf.summary.histogram(name, data))
73 | return tf.summary.merge(summaries)
74 |
75 | with tf.name_scope(name):
76 | summaries = []
77 | for name, data in name_data_dict.items():
78 | summaries.append(_summary(name, data))
79 | return tf.summary.merge(summaries)
80 |
81 |
82 | def summary_v2(name_data_dict,
83 | step,
84 | types=['mean', 'std', 'max', 'min', 'sparsity', 'histogram'],
85 | name='summary'):
86 | """Summary.
87 |
88 | Examples
89 | --------
90 |     >>> summary_v2({'a': data_a, 'b': data_b}, tf.train.get_global_step())
91 |
92 | """
93 | def _summary(name, data):
94 | summaries = []
95 | if data.shape == ():
96 | summaries.append(tf.contrib.summary.scalar(name, data, step=step))
97 | else:
98 | if 'mean' in types:
99 | summaries.append(tf.contrib.summary.scalar(name + '-mean', tf.math.reduce_mean(data), step=step))
100 | if 'std' in types:
101 | summaries.append(tf.contrib.summary.scalar(name + '-std', tf.math.reduce_std(data), step=step))
102 | if 'max' in types:
103 | summaries.append(tf.contrib.summary.scalar(name + '-max', tf.math.reduce_max(data), step=step))
104 | if 'min' in types:
105 | summaries.append(tf.contrib.summary.scalar(name + '-min', tf.math.reduce_min(data), step=step))
106 | if 'sparsity' in types:
107 | summaries.append(tf.contrib.summary.scalar(name + '-sparsity', tf.math.zero_fraction(data), step=step))
108 | if 'histogram' in types:
109 | summaries.append(tf.contrib.summary.histogram(name, data, step=step))
110 | return summaries
111 |
112 | with tf.name_scope(name):
113 | summaries = {}
114 | for name, data in name_data_dict.items():
115 | summaries[name] = _summary(name, data)
116 | return summaries
117 |
118 |
119 | def counter(start=0, scope=None):
120 | with tf.variable_scope(scope, 'counter'):
121 | counter = tf.get_variable(name='counter',
122 | initializer=tf.constant_initializer(start),
123 | shape=(),
124 | dtype=tf.int64)
125 | update_cnt = tf.assign(counter, tf.add(counter, 1))
126 | return counter, update_cnt
127 |
128 |
129 | def receptive_field(convnet_fn):
130 | # TODO(Lynn): too ugly ...
131 | g = tf.Graph()
132 | with g.as_default():
133 | img = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='img')
134 | convnet_fn(img)
135 |
136 | node_names = [node.name for node in g.as_graph_def().node
137 | if 'img' != node.name
138 | # for Conv
139 | if 'weights' not in node.name
140 | if 'biases' not in node.name
141 | if 'dilation_rate' not in node.name
142 | # for BatchNorm
143 | if 'beta' not in node.name
144 | if 'gamma' not in node.name
145 | if 'moving_mean' not in node.name
146 | if 'moving_variance' not in node.name
147 | if 'AssignMovingAvg' not in node.name]
148 |
149 | results = []
150 | for node_name in node_names:
151 | try:
152 | rf_x, rf_y, eff_stride_x, eff_stride_y, eff_pad_x, eff_pad_y = \
153 | tf.contrib.receptive_field.compute_receptive_field_from_graph_def(g.as_graph_def(), 'img', node_name)
154 | results.append((
155 | node_name,
156 | {'receptive_field_x': rf_x,
157 | 'receptive_field_y': rf_y,
158 | 'effective_stride_x': eff_stride_x,
159 | 'effective_stride_y': eff_stride_y,
160 | 'effective_padding_x': eff_pad_x,
161 | 'effective_padding_y': eff_pad_y}
162 | ))
163 | except ValueError as e:
164 | if str(e) != "Output node was not found":
165 | raise e
166 |
167 | return results
168 |
169 |
170 | def count_parameters(variables):
171 | variables = variables if isinstance(variables, (list, tuple)) else [variables]
172 | n_params = np.sum([np.prod(v.shape.as_list()) for v in variables])
173 | n_bytes = np.sum([np.prod(v.shape.as_list()) * v.dtype.size for v in variables])
174 | return n_params, n_bytes
175 |
176 |
177 | def print_tensor(tensors):
178 | if not isinstance(tensors, (list, tuple)):
179 | tensors = [tensors]
180 |
181 | for i, tensor in enumerate(tensors):
182 | ctype = str(type(tensor))
183 | if 'Tensor' in ctype:
184 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' %
185 | (i, 'Tensor', tensor.name, tensor.shape, tensor.dtype.name, tensor.device))
186 | elif 'Variable' in ctype:
187 | print('%d: %s("%s", shape=%s, dtype=%s, device=%s)' %
188 | (i, 'Variable', tensor.name, tensor.shape, tensor.dtype.name, tensor.device))
189 | elif 'Operation' in ctype:
190 | print('%d: %s("%s", device=%s)' %
191 | (i, 'Operation', tensor.name, tensor.device))
192 | else:
193 |             raise TypeError('Not a Tensor, Variable or Operation!')
194 |
195 |
196 | prt = print_tensor
197 |
--------------------------------------------------------------------------------
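
Note: a minimal usage sketch of the helpers above (`counter`, `count_parameters`), assuming TF 1.x with `tflib` imported as `tl` as in `train.py`; it is an illustration, not part of the repo:

    import tensorflow as tf
    import tflib as tl

    # a 3x3 conv kernel with 64 input and 128 output channels
    w = tf.get_variable('w', shape=(3, 3, 64, 128), dtype=tf.float32)
    n_params, n_bytes = tl.count_parameters([w])
    print(n_params, n_bytes)  # 73728 294912 (float32 -> 4 bytes each)

    # an int64 step counter and its increment op
    step_cnt, update_cnt = tl.counter()
    with tl.session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(update_cnt)       # counter becomes 1
        print(sess.run(step_cnt))  # 1
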
/tfprob/__init__.py:
--------------------------------------------------------------------------------
1 | from tfprob.gan import *
2 |
--------------------------------------------------------------------------------
/tfprob/gan/__init__.py:
--------------------------------------------------------------------------------
1 | from tfprob.gan.gradient_penalty import *
2 | from tfprob.gan.loss import *
3 |
--------------------------------------------------------------------------------
/tfprob/gan/gradient_penalty.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | # ======================================
5 | # = sample method =
6 | # ======================================
7 |
8 | def _sample_line(real, fake):
9 | shape = [tf.shape(real)[0]] + [1] * (real.shape.ndims - 1)
10 | alpha = tf.random.uniform(shape=shape, minval=0, maxval=1)
11 | sample = real + alpha * (fake - real)
12 | sample.set_shape(real.shape)
13 | return sample
14 |
15 |
16 | def _sample_DRAGAN(real, fake):  # `fake` is ignored; DRAGAN perturbs the real data instead
17 | beta = tf.random.uniform(shape=tf.shape(real), minval=0, maxval=1)
18 | fake = real + 0.5 * tf.math.reduce_std(real) * beta
19 | sample = _sample_line(real, fake)
20 | return sample
21 |
22 |
23 | # ======================================
24 | # = gradient penalty method =
25 | # ======================================
26 |
27 | def _norm(x):
28 | norm = tf.norm(tf.reshape(x, [tf.shape(x)[0], -1]), axis=1)
29 | return norm
30 |
31 |
32 | def _one_mean_gp(grad):
33 | norm = _norm(grad)
34 | gp = tf.reduce_mean((norm - 1)**2)
35 | return gp
36 |
37 |
38 | def _zero_mean_gp(grad):
39 | norm = _norm(grad)
40 | gp = tf.reduce_mean(norm**2)
41 | return gp
42 |
43 |
44 | def _lipschitz_penalty(grad):
45 | norm = _norm(grad)
46 | gp = tf.reduce_mean(tf.maximum(norm - 1, 0)**2)
47 | return gp
48 |
49 |
50 | def gradient_penalty(f, real, fake, gp_mode, sample_mode):
51 | sample_fns = {
52 | 'line': _sample_line,
53 | 'real': lambda real, fake: real,
54 | 'fake': lambda real, fake: fake,
55 | 'dragan': _sample_DRAGAN,
56 | }
57 |
58 | gp_fns = {
59 | '1-gp': _one_mean_gp,
60 | '0-gp': _zero_mean_gp,
61 | 'lp': _lipschitz_penalty,
62 | }
63 |
64 | if gp_mode == 'none':
65 | gp = tf.constant(0, dtype=real.dtype)
66 | else:
67 | x = sample_fns[sample_mode](real, fake)
68 | grad = tf.gradients(f(x), x)[0]
69 | gp = gp_fns[gp_mode](grad)
70 |
71 | return gp
72 |
--------------------------------------------------------------------------------
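
Note: the penalty above plugs into any critic via the `f` argument; a minimal TF 1.x sketch with a hypothetical one-layer critic (`critic`, `real`, `fake` are illustration-only names):

    import tensorflow as tf
    from tfprob.gan import gradient_penalty

    def critic(x):  # toy critic, reused across real/fake/interpolated inputs
        with tf.variable_scope('critic', reuse=tf.AUTO_REUSE):
            return tf.layers.dense(tf.layers.flatten(x), 1)

    real = tf.random.normal([8, 32, 32, 3])
    fake = tf.random.normal([8, 32, 32, 3])

    # WGAN-GP ('1-gp'): push gradient norms toward 1 on points sampled
    # along the real-fake line; 'lp' would penalize only norms above 1
    gp = gradient_penalty(critic, real, fake, '1-gp', 'line')
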
/tfprob/gan/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def get_gan_losses_fn():
5 | bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
6 |
7 | def d_loss_fn(r_logit, f_logit):
8 | r_loss = bce(tf.ones_like(r_logit), r_logit)
9 | f_loss = bce(tf.zeros_like(f_logit), f_logit)
10 | return r_loss, f_loss
11 |
12 | def g_loss_fn(f_logit):
13 | f_loss = bce(tf.ones_like(f_logit), f_logit)
14 | return f_loss
15 |
16 | return d_loss_fn, g_loss_fn
17 |
18 |
19 | def get_hinge_v1_losses_fn():  # hinge loss for both D and G
20 | def d_loss_fn(r_logit, f_logit):
21 | r_loss = tf.reduce_mean(tf.maximum(1 - r_logit, 0))
22 | f_loss = tf.reduce_mean(tf.maximum(1 + f_logit, 0))
23 | return r_loss, f_loss
24 |
25 | def g_loss_fn(f_logit):
26 | f_loss = tf.reduce_mean(tf.maximum(1 - f_logit, 0))
27 | return f_loss
28 |
29 | return d_loss_fn, g_loss_fn
30 |
31 |
32 | def get_hinge_v2_losses_fn():  # hinge loss for D, linear (Wasserstein-style) loss for G, as in SAGAN
33 | def d_loss_fn(r_logit, f_logit):
34 | r_loss = tf.reduce_mean(tf.maximum(1 - r_logit, 0))
35 | f_loss = tf.reduce_mean(tf.maximum(1 + f_logit, 0))
36 | return r_loss, f_loss
37 |
38 | def g_loss_fn(f_logit):
39 | f_loss = tf.reduce_mean(- f_logit)
40 | return f_loss
41 |
42 | return d_loss_fn, g_loss_fn
43 |
44 |
45 | def get_lsgan_losses_fn():
46 | mse = tf.keras.losses.MeanSquaredError()
47 |
48 | def d_loss_fn(r_logit, f_logit):
49 | r_loss = mse(tf.ones_like(r_logit), r_logit)
50 | f_loss = mse(tf.zeros_like(f_logit), f_logit)
51 | return r_loss, f_loss
52 |
53 | def g_loss_fn(f_logit):
54 | f_loss = mse(tf.ones_like(f_logit), f_logit)
55 | return f_loss
56 |
57 | return d_loss_fn, g_loss_fn
58 |
59 |
60 | def get_wgan_losses_fn():
61 | def d_loss_fn(r_logit, f_logit):
62 | r_loss = - tf.reduce_mean(r_logit)
63 | f_loss = tf.reduce_mean(f_logit)
64 | return r_loss, f_loss
65 |
66 | def g_loss_fn(f_logit):
67 | f_loss = - tf.reduce_mean(f_logit)
68 | return f_loss
69 |
70 | return d_loss_fn, g_loss_fn
71 |
72 |
73 | def get_adversarial_losses_fn(mode):
74 | if mode == 'gan':
75 | return get_gan_losses_fn()
76 | elif mode == 'hinge_v1':
77 | return get_hinge_v1_losses_fn()
78 | elif mode == 'hinge_v2':
79 | return get_hinge_v2_losses_fn()
80 | elif mode == 'lsgan':
81 | return get_lsgan_losses_fn()
82 | elif mode == 'wgan':
83 |         return get_wgan_losses_fn()
84 |     else:
85 |         raise ValueError('Unknown adversarial loss mode: %s' % mode)
86 |
--------------------------------------------------------------------------------
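
Note: all five modes share one interface, so the training code can switch losses with a flag alone; a minimal sketch with random logits:

    import tensorflow as tf
    from tfprob.gan import get_adversarial_losses_fn

    d_loss_fn, g_loss_fn = get_adversarial_losses_fn('wgan')

    r_logit = tf.random.normal([16, 1])  # discriminator outputs on real images
    f_logit = tf.random.normal([16, 1])  # discriminator outputs on fake images

    r_loss, f_loss = d_loss_fn(r_logit, f_logit)
    d_loss = r_loss + f_loss     # discriminator objective
    g_loss = g_loss_fn(f_logit)  # generator objective
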
/to_pb.py:
--------------------------------------------------------------------------------
1 | import pylib as py
2 | import tensorflow as tf
3 | import tflib as tl
4 |
5 | import module
6 |
7 | from tensorflow.python.framework import graph_util
8 |
9 |
10 | # ==============================================================================
11 | # = param =
12 | # ==============================================================================
13 |
14 | py.arg('--experiment_name', default='default')
15 | args_ = py.args()
16 |
17 | # output_dir
18 | output_dir = py.join('output', args_.experiment_name)
19 |
20 | # save settings
21 | args = py.args_from_yaml(py.join(output_dir, 'settings.yml'))
22 | args.__dict__.update(args_.__dict__)
23 |
24 | # others
25 | n_atts = len(args.att_names)
26 |
27 | sess = tl.session()
28 | sess.__enter__() # make default
29 |
30 |
31 | # ==============================================================================
32 | # = graph =
33 | # ==============================================================================
34 |
35 | def sample_graph():
36 | # model
37 | Genc, Gdec, _ = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)
38 |
39 | # placeholders & inputs
40 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3], name='xa')
41 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts], name='b_')
42 |
43 | # sample graph
44 | x = Gdec(Genc(xa, training=False), b_, training=False)
45 | x = tf.identity(x, name='xb')
46 |
47 |
48 | sample_graph()  # build the inference graph; the freeze step below finds tensors by name
49 |
50 |
51 | # ==============================================================================
52 | # = freeze =
53 | # ==============================================================================
54 |
55 | # checkpoint
56 | checkpoint = tl.Checkpoint(
57 | {v.name: v for v in tf.global_variables()},
58 | py.join(output_dir, 'checkpoints'),
59 | max_to_keep=1
60 | )
61 | checkpoint.restore().run_restore_ops()
62 |
63 | with tf.gfile.GFile(py.join(output_dir, 'generator.pb'), 'wb') as f:
64 | constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['xb'])
65 | f.write(constant_graph.SerializeToString())
66 |
67 | sess.close()
68 |
--------------------------------------------------------------------------------
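
Note: the frozen `generator.pb` can be reloaded without any of the training code; a sketch of the round trip (tensor names follow the placeholders above; the path assumes the default `experiment_name`, and the 128/13 shapes assume the default `crop_size` and attribute list):

    import numpy as np
    import tensorflow as tf

    with tf.gfile.GFile('output/default/generator.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as g:
        tf.import_graph_def(graph_def, name='')
        xa = g.get_tensor_by_name('xa:0')  # input images
        b_ = g.get_tensor_by_name('b_:0')  # target attributes in [-1, 1]
        xb = g.get_tensor_by_name('xb:0')  # edited images

    with tf.Session(graph=g) as sess:
        img = np.zeros([1, 128, 128, 3], np.float32)
        att = -np.ones([1, 13], np.float32)
        out = sess.run(xb, feed_dict={xa: img, b_: att})
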
/train.py:
--------------------------------------------------------------------------------
1 | import traceback
2 |
3 | import imlib as im
4 | import numpy as np
5 | import pylib as py
6 | import tensorflow as tf
7 | import tflib as tl
8 | import tfprob
9 | import tqdm
10 |
11 | import data
12 | import module
13 |
14 |
15 | # ==============================================================================
16 | # = param =
17 | # ==============================================================================
18 |
19 | default_att_names = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', 'Eyeglasses',
20 | 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young']
21 | py.arg('--att_names', choices=data.ATT_ID.keys(), nargs='+', default=default_att_names)
22 |
23 | py.arg('--img_dir', default='./data/img_celeba/aligned/align_size(572,572)_move(0.250,0.000)_face_factor(0.450)_jpg/data')
24 | py.arg('--train_label_path', default='./data/img_celeba/train_label.txt')
25 | py.arg('--val_label_path', default='./data/img_celeba/val_label.txt')
26 | py.arg('--load_size', type=int, default=143)
27 | py.arg('--crop_size', type=int, default=128)
28 |
29 | py.arg('--n_epochs', type=int, default=60)
30 | py.arg('--epoch_start_decay', type=int, default=30)
31 | py.arg('--batch_size', type=int, default=32)
32 | py.arg('--learning_rate', type=float, default=2e-4)
33 | py.arg('--beta_1', type=float, default=0.5)
34 |
35 | py.arg('--model', default='model_128', choices=['model_128', 'model_256', 'model_384'])
36 |
37 | py.arg('--n_d', type=int, default=5)  # number of D updates per G update
38 | py.arg('--adversarial_loss_mode', choices=['gan', 'hinge_v1', 'hinge_v2', 'lsgan', 'wgan'], default='wgan')
39 | py.arg('--gradient_penalty_mode', choices=['none', '1-gp', '0-gp', 'lp'], default='1-gp')
40 | py.arg('--gradient_penalty_sample_mode', choices=['line', 'real', 'fake', 'dragan'], default='line')
41 | py.arg('--d_gradient_penalty_weight', type=float, default=10.0)
42 | py.arg('--d_attribute_loss_weight', type=float, default=1.0)
43 | py.arg('--g_attribute_loss_weight', type=float, default=10.0)
44 | py.arg('--g_reconstruction_loss_weight', type=float, default=100.0)
45 | py.arg('--weight_decay', type=float, default=0.0)
46 |
47 | py.arg('--n_samples', type=int, default=12)
48 | py.arg('--test_int', type=float, default=2.0)
49 |
50 | py.arg('--experiment_name', default='default')
51 | args = py.args()
52 |
53 | # output_dir
54 | output_dir = py.join('output', args.experiment_name)
55 | py.mkdir(output_dir)
56 |
57 | # save settings
58 | py.args_to_yaml(py.join(output_dir, 'settings.yml'), args)
59 |
60 | # others
61 | n_atts = len(args.att_names)
62 |
63 | sess = tl.session()
64 | sess.__enter__() # make default
65 |
66 |
67 | # ==============================================================================
68 | # = data and model =
69 | # ==============================================================================
70 |
71 | # data
72 | train_dataset, len_train_dataset = data.make_celeba_dataset(args.img_dir, args.train_label_path, args.att_names, args.batch_size,
73 | load_size=args.load_size, crop_size=args.crop_size,
74 | training=True, shuffle=True, repeat=None)
75 | val_dataset, len_val_dataset = data.make_celeba_dataset(args.img_dir, args.val_label_path, args.att_names, args.n_samples,
76 | load_size=args.load_size, crop_size=args.crop_size,
77 | training=False, shuffle=True, repeat=None)
78 | train_iter = train_dataset.make_one_shot_iterator()
79 | val_iter = val_dataset.make_one_shot_iterator()
80 |
81 | # model
82 | Genc, Gdec, D = module.get_model(args.model, n_atts, weight_decay=args.weight_decay)
83 |
84 | # loss functions
85 | d_loss_fn, g_loss_fn = tfprob.get_adversarial_losses_fn(args.adversarial_loss_mode)
86 |
87 |
88 | # ==============================================================================
89 | # = graph =
90 | # ==============================================================================
91 |
92 | def D_train_graph():
93 | # ======================================
94 | # = graph =
95 | # ======================================
96 |
97 | # placeholders & inputs
98 | lr = tf.placeholder(dtype=tf.float32, shape=[])
99 |
100 | xa, a = train_iter.get_next()
101 | b = tf.random_shuffle(a)
102 | b_ = b * 2 - 1
103 |
104 | # generate
105 | z = Genc(xa)
106 | xb_ = Gdec(z, b_)
107 |
108 | # discriminate
109 | xa_logit_gan, xa_logit_att = D(xa)
110 | xb__logit_gan, xb__logit_att = D(xb_)
111 |
112 | # discriminator losses
113 | xa_loss_gan, xb__loss_gan = d_loss_fn(xa_logit_gan, xb__logit_gan)
114 | gp = tfprob.gradient_penalty(lambda x: D(x)[0], xa, xb_, args.gradient_penalty_mode, args.gradient_penalty_sample_mode)
115 | xa_loss_att = tf.losses.sigmoid_cross_entropy(a, xa_logit_att)
116 | reg_loss = tf.reduce_sum(D.func.reg_losses)
117 |
118 | loss = (xa_loss_gan + xb__loss_gan +
119 | gp * args.d_gradient_penalty_weight +
120 | xa_loss_att * args.d_attribute_loss_weight +
121 | reg_loss)
122 |
123 | # optim
124 | step_cnt, _ = tl.counter()
125 | step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(loss, global_step=step_cnt, var_list=D.func.trainable_variables)
126 |
127 | # summary
128 | with tf.contrib.summary.create_file_writer('./output/%s/summaries/D' % args.experiment_name).as_default(),\
129 | tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
130 | summary = [
131 | tl.summary_v2({
132 | 'loss_gan': xa_loss_gan + xb__loss_gan,
133 | 'gp': gp,
134 | 'xa_loss_att': xa_loss_att,
135 | 'reg_loss': reg_loss
136 | }, step=step_cnt, name='D'),
137 | tl.summary_v2({'lr': lr}, step=step_cnt, name='learning_rate')
138 | ]
139 |
140 | # ======================================
141 | # = run function =
142 | # ======================================
143 |
144 | def run(**pl_ipts):
145 | sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})
146 |
147 | return run
148 |
149 |
150 | def G_train_graph():
151 | # ======================================
152 | # = graph =
153 | # ======================================
154 |
155 | # placeholders & inputs
156 | lr = tf.placeholder(dtype=tf.float32, shape=[])
157 |
158 | xa, a = train_iter.get_next()
159 | b = tf.random_shuffle(a)
160 | a_ = a * 2 - 1
161 | b_ = b * 2 - 1
162 |
163 | # generate
164 | z = Genc(xa)
165 | xa_ = Gdec(z, a_)
166 | xb_ = Gdec(z, b_)
167 |
168 | # discriminate
169 | xb__logit_gan, xb__logit_att = D(xb_)
170 |
171 | # generator losses
172 | xb__loss_gan = g_loss_fn(xb__logit_gan)
173 | xb__loss_att = tf.losses.sigmoid_cross_entropy(b, xb__logit_att)
174 | xa__loss_rec = tf.losses.absolute_difference(xa, xa_)
175 | reg_loss = tf.reduce_sum(Genc.func.reg_losses + Gdec.func.reg_losses)
176 |
177 | loss = (xb__loss_gan +
178 | xb__loss_att * args.g_attribute_loss_weight +
179 | xa__loss_rec * args.g_reconstruction_loss_weight +
180 | reg_loss)
181 |
182 | # optim
183 | step_cnt, _ = tl.counter()
184 | step = tf.train.AdamOptimizer(lr, beta1=args.beta_1).minimize(loss, global_step=step_cnt, var_list=Genc.func.trainable_variables + Gdec.func.trainable_variables)
185 |
186 | # summary
187 | with tf.contrib.summary.create_file_writer('./output/%s/summaries/G' % args.experiment_name).as_default(),\
188 | tf.contrib.summary.record_summaries_every_n_global_steps(10, global_step=step_cnt):
189 | summary = tl.summary_v2({
190 | 'xb__loss_gan': xb__loss_gan,
191 | 'xb__loss_att': xb__loss_att,
192 | 'xa__loss_rec': xa__loss_rec,
193 | 'reg_loss': reg_loss
194 | }, step=step_cnt, name='G')
195 |
196 | # ======================================
197 | # = generator size =
198 | # ======================================
199 |
200 | n_params, n_bytes = tl.count_parameters(Genc.func.variables + Gdec.func.variables)
201 |     print('Generator Size: n_parameters = %d (%.2f MB)' % (n_params, n_bytes / 1024 / 1024))
202 |
203 | # ======================================
204 | # = run function =
205 | # ======================================
206 |
207 | def run(**pl_ipts):
208 | sess.run([step, summary], feed_dict={lr: pl_ipts['lr']})
209 |
210 | return run
211 |
212 |
213 | def sample_graph():
214 | # ======================================
215 | # = graph =
216 | # ======================================
217 |
218 | # placeholders & inputs
219 | val_next = val_iter.get_next()
220 | xa = tf.placeholder(tf.float32, shape=[None, args.crop_size, args.crop_size, 3])
221 | b_ = tf.placeholder(tf.float32, shape=[None, n_atts])
222 |
223 | # sample graph
224 | x = Gdec(Genc(xa, training=False), b_, training=False)
225 |
226 | # ======================================
227 | # = run function =
228 | # ======================================
229 |
230 | save_dir = './output/%s/samples_training' % args.experiment_name
231 | py.mkdir(save_dir)
232 |
233 | def run(epoch, iter):
234 | # data for sampling
235 | xa_ipt, a_ipt = sess.run(val_next)
236 | b_ipt_list = [a_ipt] # the first is for reconstruction
237 | for i in range(n_atts):
238 | tmp = np.array(a_ipt, copy=True)
239 | tmp[:, i] = 1 - tmp[:, i] # inverse attribute
240 | tmp = data.check_attribute_conflict(tmp, args.att_names[i], args.att_names)
241 | b_ipt_list.append(tmp)
242 |
243 | x_opt_list = [xa_ipt]
244 | for i, b_ipt in enumerate(b_ipt_list):
245 |             b__ipt = (b_ipt * 2 - 1).astype(np.float32)  # map {0, 1} labels to {-1, +1}
246 | if i > 0: # i == 0 is for reconstruction
247 | b__ipt[..., i - 1] = b__ipt[..., i - 1] * args.test_int
248 | x_opt = sess.run(x, feed_dict={xa: xa_ipt, b_: b__ipt})
249 | x_opt_list.append(x_opt)
250 |         sample = np.transpose(x_opt_list, (1, 2, 0, 3, 4))  # (n_col, N, H, W, C) -> (N, H, n_col, W, C)
251 |         sample = np.reshape(sample, (-1, sample.shape[2] * sample.shape[3], sample.shape[4]))  # -> (N * H, n_col * W, C)
252 | im.imwrite(sample, '%s/Epoch-%d_Iter-%d.jpg' % (save_dir, epoch, iter))
253 |
254 | return run
255 |
256 |
257 | D_train_step = D_train_graph()
258 | G_train_step = G_train_graph()
259 | sample = sample_graph()
260 |
261 |
262 | # ==============================================================================
263 | # = train =
264 | # ==============================================================================
265 |
266 | # step counter
267 | step_cnt, update_cnt = tl.counter()
268 |
269 | # checkpoint
270 | checkpoint = tl.Checkpoint(
271 | {v.name: v for v in tf.global_variables()},
272 | py.join(output_dir, 'checkpoints'),
273 | max_to_keep=1
274 | )
275 | checkpoint.restore().initialize_or_restore()
276 |
277 | # summary
278 | sess.run(tf.contrib.summary.summary_writer_initializer_op())
279 |
280 | # learning rate schedule
281 | lr_fn = tl.LinearDecayLR(args.learning_rate, args.n_epochs, args.epoch_start_decay)
282 |
283 | # train
284 | try:
285 | for ep in tqdm.trange(args.n_epochs, desc='Epoch Loop'):
286 | # learning rate
287 | lr_ipt = lr_fn(ep)
288 |
289 | for it in tqdm.trange(len_train_dataset, desc='Inner Epoch Loop'):
290 | if it + ep * len_train_dataset < sess.run(step_cnt):
291 | continue
292 | step = sess.run(update_cnt)
293 |
294 | # train D
295 | if step % (args.n_d + 1) != 0:
296 | D_train_step(lr=lr_ipt)
297 | # train G
298 | else:
299 | G_train_step(lr=lr_ipt)
300 |
301 | # save
302 | if step % (1000 * (args.n_d + 1)) == 0:
303 | checkpoint.save(step)
304 |
305 | # sample
306 | if step % (100 * (args.n_d + 1)) == 0:
307 | sample(ep, it)
308 | except Exception:
309 | traceback.print_exc()
310 | finally:
311 |     checkpoint.save(sess.run(step_cnt))  # read the counter itself; `step` may be unbound if training failed before the first iteration
312 | sess.close()
313 |
--------------------------------------------------------------------------------
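
Note: with the default `n_d=5`, the shared step counter trains D on five of every six steps and G on the sixth (`step % 6 == 0`). `LinearDecayLR` itself lives in `tflib`; judging by its constructor arguments it behaves like the sketch below (an assumption, not the repo's exact code): constant until `epoch_start_decay`, then linear decay to 0 at `n_epochs`.

    def linear_decay_lr(initial_lr, n_epochs, epoch_start_decay):
        def fn(epoch):
            if epoch < epoch_start_decay:
                return initial_lr
            return initial_lr * (n_epochs - epoch) / (n_epochs - epoch_start_decay)
        return fn

    lr_fn = linear_decay_lr(2e-4, 60, 30)          # the training defaults
    assert lr_fn(0) == 2e-4 and lr_fn(45) == 1e-4  # halves midway through decay
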
/utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import tensorflow as tf
4 | import tensorflow.contrib.slim as slim
5 |
6 |
7 | # ==============================================================================
8 | # = operations =
9 | # ==============================================================================
10 |
11 | def tile_concat(a_list, b_list=()):  # immutable default; both arguments are copied below
12 | # tile all elements of `b_list` and then concat `a_list + b_list` along the channel axis
13 | # `a` shape: (N, H, W, C_a)
14 | # `b` shape: can be (N, 1, 1, C_b) or (N, C_b)
15 | a_list = list(a_list) if isinstance(a_list, (list, tuple)) else [a_list]
16 | b_list = list(b_list) if isinstance(b_list, (list, tuple)) else [b_list]
17 | for i, b in enumerate(b_list):
18 | b = tf.reshape(b, [-1, 1, 1, b.shape[-1]])
19 | b = tf.tile(b, [1, a_list[0].shape[1], a_list[0].shape[2], 1])
20 | b_list[i] = b
21 | return tf.concat(a_list + b_list, axis=-1)
22 |
23 |
24 | # ==============================================================================
25 | # = others =
26 | # ==============================================================================
27 |
28 | def get_norm_layer(norm, training, updates_collections=None):
29 | if norm == 'none':
30 | return lambda x: x
31 | elif norm == 'batch_norm':
32 | return functools.partial(slim.batch_norm, scale=True, is_training=training, updates_collections=updates_collections)
33 | elif norm == 'instance_norm':
34 | return slim.instance_norm
35 | elif norm == 'layer_norm':
36 |         return slim.layer_norm
37 |     else:
38 |         raise ValueError('Unknown norm layer: %s' % norm)
39 |
--------------------------------------------------------------------------------
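
Note: a shape-level sketch of `tile_concat`, which is how the decoder injects per-image attribute vectors into a feature map (illustration only):

    import tensorflow as tf
    from utils import tile_concat

    feat = tf.zeros([4, 16, 16, 64])  # (N, H, W, C_a) feature map
    att = tf.zeros([4, 13])           # (N, C_b) attribute vector per image

    out = tile_concat([feat], [att])  # att is tiled to (4, 16, 16, 13)
    print(out.shape)                  # (4, 16, 16, 77)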