├── assests
├── men.png
├── logo.jpg
├── women.png
├── overview.PNG
├── generator.PNG
└── discriminator.PNG
├── dataset
└── celebA
│ ├── train
│ └── 000001.jpg
│ └── test
│ └── your_test_image.png
├── docker-compose.yml
├── Dockerfile.gpu
├── LICENSE
├── README.md
├── .gitignore
├── download.py
├── main.py
├── ops.py
├── utils.py
└── StarGAN.py
/assests/men.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/men.png
--------------------------------------------------------------------------------
/assests/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/logo.jpg
--------------------------------------------------------------------------------
/assests/women.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/women.png
--------------------------------------------------------------------------------
/assests/overview.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/overview.PNG
--------------------------------------------------------------------------------
/assests/generator.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/generator.PNG
--------------------------------------------------------------------------------
/assests/discriminator.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/discriminator.PNG
--------------------------------------------------------------------------------
/dataset/celebA/train/000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/dataset/celebA/train/000001.jpg
--------------------------------------------------------------------------------
/dataset/celebA/test/your_test_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/dataset/celebA/test/your_test_image.png
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '2'
2 | services:
3 | tensorflow:
4 | image: slamhan/stargan-gpu:1.8.0-gpu-py3-jupyter
5 | volumes:
6 | - /notebooks:/notebooks
7 | ports:
8 | - 8888:8888/tcp
9 |
--------------------------------------------------------------------------------
/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.8.0-gpu-py3
2 |
3 | RUN apt-get update -qq -y \
4 | && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk\
5 | && apt-get clean \
6 | && rm -rf /var/lib/apt/lists/*
7 |
8 | COPY requirements.txt /opt/
9 | RUN pip3 install --upgrade pip
10 | RUN pip3 install cmake
11 | RUN pip3 install dlib
12 | #RUN pip3 install torchvision
13 | RUN pip3 --no-cache-dir install -r /opt/requirements.txt && rm /opt/requirements.txt
14 | RUN pip3 install jupyter matplotlib
15 | RUN pip3 install jupyter_http_over_ws
16 | RUN jupyter serverextension enable --py jupyter_http_over_ws
17 | # patch for tensorflow:1.8.0-gpu-py3 image
18 | RUN cd /usr/local/cuda/lib64 \
19 | && mv stubs/libcuda.so ./ \
20 | && ln -s libcuda.so libcuda.so.1 \
21 | && ldconfig
22 |
23 | WORKDIR "/notebooks"
24 | CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"]
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 |
3 | --------------------------------------------------------------------------------
4 | ## Requirements
5 | * Tensorflow 1.8
6 | * Python 3.6
7 |
8 | ## Usage
9 | ### Downloading the dataset
10 | ```bash
11 | > python download.py celebA
12 | ```
13 |
14 | ```
15 | ├── dataset
16 | └── celebA
17 | ├── train
18 | ├── 000001.jpg
19 | ├── 000002.jpg
20 | └── ...
21 | ├── test (your own images; not part of CelebA)
22 | ├── a.jpg (any image you want to translate)
23 | ├── b.png
24 | └── ...
25 | ├── list_attr_celeba.txt (For attribute information)
26 | ```
27 |
28 | ### Train
29 | * python main.py --phase train
30 |
31 | ### Test
32 | * python main.py --phase test
33 | * Runs on both the CelebA test split and your own images in `dataset/celebA/test` in the same pass
34 |
35 | ### Pretrained model
36 | * Download [checkpoint for 128x128](https://drive.google.com/open?id=1ezwtU1O_rxgNXgJaHcAynVX8KjMt0Ua-)
37 |
38 | ## Summary
39 | 
40 |
41 | ## Results (128x128, wgan-gp)
42 | ### Women
43 | 
44 |
45 | ### Men
46 | 
47 |
48 | ## Related works
49 | * [CycleGAN-Tensorflow](https://github.com/taki0112/CycleGAN-Tensorflow)
50 | * [DiscoGAN-Tensorflow](https://github.com/taki0112/DiscoGAN-Tensorflow)
51 | * [UNIT-Tensorflow](https://github.com/taki0112/UNIT-Tensorflow)
52 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow)
53 |
54 | ## Reference
55 | * [StarGAN paper](https://arxiv.org/abs/1711.09020)
56 | * [Author pytorch code](https://github.com/yunjey/StarGAN)
57 |
58 | ## Author
59 | Junho Kim
60 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import argparse
4 | import requests
5 |
6 | from tqdm import tqdm
7 |
# Command-line interface: one or more dataset names (currently only 'celebA').
parser = argparse.ArgumentParser(description='Download dataset for StarGAN')
parser.add_argument(
    'dataset',
    nargs='+',
    type=str,
    metavar='N',
    choices=['celebA'],
    help='name of dataset to download [celebA]',
)
11 |
12 |
def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to ``destination``.

    Large files first return a "can't scan for viruses" page. That page sets a
    ``download_warning*`` cookie, and the file is re-requested with that
    cookie's value passed as ``confirm``.

    Args:
        id: Google Drive file id.
        destination: local path the response body is written to.

    Raises:
        requests.HTTPError: if the server answers with an HTTP error status,
            instead of silently saving the error page as the dataset.
    """
    URL = "https://docs.google.com/uc?export=download"
    # Closing the session releases its pooled connections once the download ends.
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        response.raise_for_status()
        save_response_content(response, destination)
25 |
26 |
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, or None.

    Google Drive sets this cookie on the interstitial page served for large
    files; its value is the confirmation token for the real download.
    """
    return next(
        (value for name, value in response.cookies.items()
         if name.startswith('download_warning')),
        None,
    )
32 |
33 |
def save_response_content(response, destination, chunk_size=32 * 1024):
    """Stream a (streaming) ``requests`` response body into ``destination``.

    Args:
        response: response object obtained with ``stream=True``.
        destination: output file path; also used as the progress-bar label.
        chunk_size: number of bytes requested per ``iter_content`` step.
    """
    total_size = int(response.headers.get('content-length', 0))
    # The bar counts bytes, so it is advanced by each chunk's length rather than
    # once per chunk. A missing content-length gives an open-ended bar.
    with open(destination, "wb") as f, tqdm(total=total_size or None, unit='B',
                                            unit_scale=True, desc=destination) as progress:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress.update(len(chunk))
41 |
42 |
def download_celeb_a(dirpath):
    """Download CelebA (aligned images + attribute list) below ``dirpath``.

    Resulting layout:
        <dirpath>/img_align_celeba.zip       (kept, so re-runs skip the download)
        <dirpath>/celebA/list_attr_celeba.txt
        <dirpath>/celebA/train/              (images extracted from the zip)
        <dirpath>/celebA/test/               (created empty, for your own images)

    Files that already exist are not downloaded again.
    """
    data_dir = 'celebA'
    celebA_dir = os.path.join(dirpath, data_dir)
    prepare_data_dir(celebA_dir)

    file_name, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    txt_name, txt_drive_id = "list_attr_celeba.txt", "0B7EVK8r0v71pblRyaVFSWGxPY0U"

    save_path = os.path.join(dirpath, file_name)
    txt_save_path = os.path.join(celebA_dir, txt_name)

    if os.path.exists(txt_save_path):
        print('[*] {} already exists'.format(txt_save_path))
    else:
        # The attribute list has its own Drive id. Using the zip's id here would
        # store the image archive under the .txt name.
        download_file_from_google_drive(txt_drive_id, txt_save_path)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(celebA_dir)

    extracted_dir = os.path.join(celebA_dir, 'img_align_celeba')
    train_dir = os.path.join(celebA_dir, 'train')
    if not os.path.exists(train_dir):
        os.rename(extracted_dir, train_dir)
    else:
        # 'train' may already exist (e.g. from an earlier run). os.rename cannot
        # replace a non-empty directory, so move the images in one by one.
        for name in os.listdir(extracted_dir):
            os.replace(os.path.join(extracted_dir, name), os.path.join(train_dir, name))
        os.rmdir(extracted_dir)

    custom_data_dir = os.path.join(celebA_dir, 'test')
    prepare_data_dir(custom_data_dir)
72 |
73 |
def prepare_data_dir(path='./dataset'):
    """Make sure directory ``path`` (and any missing parents) exists."""
    if os.path.exists(path):
        return
    os.makedirs(path)
77 |
78 |
if __name__ == '__main__':
    args = parser.parse_args()
    prepare_data_dir()

    # argparse `choices` only admits 'celebA', so a single membership test suffices.
    if 'celebA' in args.dataset:
        download_celeb_a('./dataset')
85 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from StarGAN import StarGAN
2 | import argparse
3 | from utils import *
4 |
5 | """parsing and configuration"""
6 | def parse_args():
7 | desc = "Tensorflow implementation of StarGAN"
8 | parser = argparse.ArgumentParser(description=desc)
9 | parser.add_argument('--phase', type=str, default='train', help='train or test ?')
10 | parser.add_argument('--dataset', type=str, default='celebA', help='dataset_name')
11 |
12 | parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run')
13 | parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
14 | parser.add_argument('--batch_size', type=int, default=16, help='The size of batch size')
15 | parser.add_argument('--print_freq', type=int, default=1000, help='The number of image_print_freq')
16 | parser.add_argument('--save_freq', type=int, default=1000, help='The number of ckpt_save_freq')
17 | parser.add_argument('--decay_flag', type=str2bool, default=True, help='The decay_flag')
18 | parser.add_argument('--decay_epoch', type=int, default=10, help='decay epoch')
19 |
20 | parser.add_argument('--lr', type=float, default=0.0001, help='The learning rate')
21 | parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
22 | parser.add_argument('--adv_weight', type=float, default=1, help='Weight about GAN')
23 | parser.add_argument('--rec_weight', type=float, default=10, help='Weight about Reconstruction')
24 | parser.add_argument('--cls_weight', type=float, default=10, help='Weight about Classification')
25 |
26 | parser.add_argument('--gan_type', type=str, default='wgan-gp', help='gan / lsgan / wgan-gp / wgan-lp / dragan / hinge')
27 | parser.add_argument('--selected_attrs', type=str, nargs='+', help='selected attributes for the CelebA dataset',
28 | default=['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'])
29 |
30 | parser.add_argument('--custom_label', type=int, nargs='+', help='custom label about selected attributes',
31 | default=[1, 0, 0, 0, 0])
32 | # If your image is "Young, Man, Black Hair" = [1, 0, 0, 1, 1]
33 |
34 | parser.add_argument('--ch', type=int, default=64, help='base channel number per layer')
35 | parser.add_argument('--n_res', type=int, default=6, help='The number of resblock')
36 | parser.add_argument('--n_dis', type=int, default=6, help='The number of discriminator layer')
37 | parser.add_argument('--n_critic', type=int, default=5, help='The number of critic')
38 |
39 | parser.add_argument('--img_size', type=int, default=128, help='The size of image')
40 | parser.add_argument('--img_ch', type=int, default=3, help='The size of image channel')
41 | parser.add_argument('--augment_flag', type=str2bool, default=True, help='Image augmentation use or not')
42 |
43 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
44 | help='Directory name to save the checkpoints')
45 | parser.add_argument('--result_dir', type=str, default='results',
46 | help='Directory name to save the generated images')
47 | parser.add_argument('--log_dir', type=str, default='logs',
48 | help='Directory name to save training logs')
49 | parser.add_argument('--sample_dir', type=str, default='samples',
50 | help='Directory name to save the samples on training')
51 |
52 | return check_args(parser.parse_args())
53 |
54 | """checking arguments"""
55 | def check_args(args):
56 | # --checkpoint_dir
57 | check_folder(args.checkpoint_dir)
58 |
59 | # --result_dir
60 | check_folder(args.result_dir)
61 |
62 | # --result_dir
63 | check_folder(args.log_dir)
64 |
65 | # --sample_dir
66 | check_folder(args.sample_dir)
67 |
68 | # --epoch
69 | try:
70 | assert args.epoch >= 1
71 | except:
72 | print('number of epochs must be larger than or equal to one')
73 |
74 | # --batch_size
75 | try:
76 | assert args.batch_size >= 1
77 | except:
78 | print('batch size must be larger than or equal to one')
79 | return args
80 |
81 | """main"""
82 | def main():
83 | # parse arguments
84 | args = parse_args()
85 | if args is None:
86 | exit()
87 |
88 | # open session
89 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
90 | gan = StarGAN(sess, args)
91 |
92 | # build graph
93 | gan.build_model()
94 |
95 | # show network architecture
96 | show_all_variables()
97 |
98 | if args.phase == 'train' :
99 | gan.train()
100 | print(" [*] Training finished!")
101 |
102 | if args.phase == 'test' :
103 | gan.test()
104 | print(" [*] Test finished!")
105 |
106 | if __name__ == '__main__':
107 | main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow.contrib as tf_contrib
3 |
4 |
# Initializer options (assign one of these to weight_init):
#   Xavier : tf_contrib.layers.xavier_initializer()
#   He     : tf_contrib.layers.variance_scaling_initializer()
#   Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# Regularizer option (assign to weight_regularizer):
#   l2_decay : tf_contrib.layers.l2_regularizer(0.0001)

weight_init = tf_contrib.layers.xavier_initializer()  # kernel initializer for conv / deconv
weight_regularizer = None  # no kernel regularization by default
12 |
13 | ##################################################################################
14 | # Layer
15 | ##################################################################################
16 |
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, scope='conv_0'):
    """2-D convolution with explicit spatial padding, inside variable scope ``scope``.

    ``pad`` rows/columns are added on both sides of axes 1 and 2 of ``x``
    (presumably H and W of an NHWC tensor; confirm with callers). ``pad_type``
    'zero' pads with zeros, 'reflect' mirrors the border, and any other value
    applies no padding. The convolution itself uses tf.layers.conv2d's default
    'valid' padding.
    """
    paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    with tf.variable_scope(scope):
        if pad_type == 'zero':
            x = tf.pad(x, paddings)
        elif pad_type == 'reflect':
            x = tf.pad(x, paddings, mode='REFLECT')

        return tf.layers.conv2d(
            inputs=x,
            filters=channels,
            kernel_size=kernel,
            strides=stride,
            use_bias=use_bias,
            kernel_initializer=weight_init,
            kernel_regularizer=weight_regularizer,
        )
30 |
31 |
def deconv(x, channels, kernel=4, stride=2, use_bias=True, scope='deconv_0'):
    """Transposed 2-D convolution with 'SAME' padding inside variable scope ``scope``.

    With 'SAME' padding the spatial size of the output is the input size
    multiplied by ``stride``.
    """
    with tf.variable_scope(scope):
        return tf.layers.conv2d_transpose(
            inputs=x,
            filters=channels,
            kernel_size=kernel,
            strides=stride,
            padding='SAME',
            use_bias=use_bias,
            kernel_initializer=weight_init,
            kernel_regularizer=weight_regularizer,
        )
39 |
def flatten(x):
    """Collapse every axis except the first (batch) axis into one."""
    flat = tf.layers.flatten(inputs=x)
    return flat
42 |
43 | ##################################################################################
44 | # Residual-block
45 | ##################################################################################
46 |
def resblock(x_init, channels, use_bias=True, scope='resblock'):
    """Residual block: two (3x3 conv -> instance norm) stages plus an identity skip.

    Only the first stage ('res1') is followed by ReLU. The output is
    ``stage2(stage1(x_init)) + x_init``, so ``channels`` must match the channel
    count of ``x_init`` for the addition to be valid.
    """
    with tf.variable_scope(scope):
        h = x_init
        for sub_scope, activate in (('res1', True), ('res2', False)):
            with tf.variable_scope(sub_scope):
                h = conv(h, channels, kernel=3, stride=1, pad=1, use_bias=use_bias)
                h = instance_norm(h)
                if activate:
                    h = relu(h)
        return h + x_init
59 |
60 |
61 | ##################################################################################
62 | # Activation function
63 | ##################################################################################
64 |
def lrelu(x, alpha=0.2):
    """Leaky ReLU: ``x`` for positive inputs, ``alpha * x`` otherwise."""
    return tf.nn.leaky_relu(x, alpha=alpha)
67 |
68 |
def relu(x):
    """Rectified linear unit, ``max(x, 0)`` elementwise."""
    return tf.nn.relu(features=x)
71 |
72 |
def tanh(x):
    """Elementwise hyperbolic tangent, mapping values into (-1, 1)."""
    squashed = tf.tanh(x)
    return squashed
75 |
76 | ##################################################################################
77 | # Normalization function
78 | ##################################################################################
79 |
def instance_norm(x, scope='instance_norm'):
    """Instance normalization with a learned scale and offset (epsilon 1e-5)."""
    return tf_contrib.layers.instance_norm(
        x,
        center=True,
        scale=True,
        epsilon=1e-05,
        scope=scope,
    )
85 |
86 | ##################################################################################
87 | # Loss function
88 | ##################################################################################
89 |
def discriminator_loss(loss_func, real, fake):
    """Adversarial loss for the discriminator, given its logits on real and fake inputs.

    Any ``loss_func`` containing 'wgan' uses the Wasserstein critic loss. The
    other options are 'lsgan', 'gan'/'dragan' (sigmoid cross-entropy) and
    'hinge'. An unrecognised value yields 0.
    """
    if 'wgan' in loss_func:
        real_loss, fake_loss = -tf.reduce_mean(real), tf.reduce_mean(fake)
    elif loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
    elif loss_func in ('gan', 'dragan'):
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
    elif loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))
    else:
        real_loss = fake_loss = 0

    return real_loss + fake_loss
113 |
def generator_loss(loss_func, fake):
    """Adversarial loss for the generator, given the discriminator's logits on fakes.

    'wgan*' and 'hinge' both use ``-mean(fake)``. 'lsgan' pulls the logits
    toward 1, and 'gan'/'dragan' use sigmoid cross-entropy against label 1.
    An unrecognised value yields 0.
    """
    if 'wgan' in loss_func or loss_func == 'hinge':
        return -tf.reduce_mean(fake)
    if loss_func == 'lsgan':
        return tf.reduce_mean(tf.squared_difference(fake, 1.0))
    if loss_func in ('gan', 'dragan'):
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))
    return 0
132 |
def classification_loss(logit, label):
    """Mean sigmoid cross-entropy, treating each attribute as an independent binary label."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=label)
    return tf.reduce_mean(per_element)
137 |
def L1_loss(x, y):
    """Mean absolute difference between ``x`` and ``y``."""
    diff = x - y
    return tf.reduce_mean(tf.abs(diff))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import scipy.misc
2 | import numpy as np
3 | import os
4 | from scipy import misc
5 |
6 | import tensorflow as tf
7 | import tensorflow.contrib.slim as slim
8 | import random
9 |
class ImageData:
    """CelebA attribute parsing plus the per-image TensorFlow loading function.

    ``preprocess`` reads ``list_attr_celeba.txt``: ``lines[1]`` holds the
    attribute names, and each line from ``lines[2:]`` is
    ``<filename> <value> ...``. The entries are shuffled with a fixed seed and
    split into the first 2000 (test) and the rest (train). ``image_processing``
    builds the TF ops that read, resize, normalize and optionally augment one
    image.
    """

    def __init__(self, load_size, channels, data_path, selected_attrs, augment_flag=False):
        """
        Args:
            load_size: side length every image is resized to.
            channels: number of channels requested from the JPEG decoder.
            data_path: dataset root holding ``train/`` and ``list_attr_celeba.txt``.
            selected_attrs: attribute names that make up each label vector.
            augment_flag: whether ``image_processing`` may apply random augmentation.
        """
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag
        self.selected_attrs = selected_attrs

        self.data_path = os.path.join(data_path, 'train')
        check_folder(self.data_path)
        # Read the attribute file eagerly and close the handle immediately.
        with open(os.path.join(data_path, 'list_attr_celeba.txt'), 'r') as attr_file:
            self.lines = attr_file.readlines()

        self.train_dataset = []
        self.train_dataset_label = []
        self.train_dataset_fix_label = []

        self.test_dataset = []
        self.test_dataset_label = []
        self.test_dataset_fix_label = []

        self.attr2idx = {}
        self.idx2attr = {}

    def image_processing(self, filename, label, fix_label):
        """Build ops that load one JPEG as a float image in [-1, 1].

        The image is resized to ``load_size`` x ``load_size``. When
        ``augment_flag`` is set, it is augmented with probability 0.5.
        ``label`` and ``fix_label`` are passed through unchanged.
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            augment_size = self.load_size + (30 if self.load_size == 256 else 15)
            # The coin flip must be a graph op so it is re-drawn for every image.
            # A Python random.random() runs once, while the graph is built, which
            # would augment either every image or none of them.
            base = img
            p = tf.random_uniform(shape=[], minval=0.0, maxval=1.0)
            img = tf.cond(p > 0.5,
                          lambda: augmentation(base, augment_size),
                          lambda: base)
            # Both branches yield the original size (augmentation crops back to
            # the input shape); restore the static shape tf.cond may drop.
            img.set_shape([self.load_size, self.load_size, self.channels])

        return img, label, fix_label

    def preprocess(self):
        """Fill the train/test file lists and their 0/1 label vectors."""
        all_attr_names = self.lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = self.lines[2:]
        # Fixed seed -> the same train/test split on every run.
        random.seed(1234)
        random.shuffle(lines)

        for i, line in enumerate(lines):
            split = line.split()
            filename = os.path.join(self.data_path, split[0])
            values = split[1:]

            # 1.0 where the attribute value is '1', otherwise 0.0.
            label = [1.0 if values[self.attr2idx[attr_name]] == '1' else 0.0
                     for attr_name in self.selected_attrs]

            if i < 2000:
                self.test_dataset.append(filename)
                self.test_dataset_label.append(label)
            else:
                self.train_dataset.append(filename)
                self.train_dataset_label.append(label)
            # e.g. ['./dataset/celebA/train/019932.jpg', [1, 0, 0, 0, 1]]

        self.test_dataset_fix_label = create_labels(self.test_dataset_label, self.selected_attrs)
        self.train_dataset_fix_label = create_labels(self.train_dataset_label, self.selected_attrs)

        print('\n Finished preprocessing the CelebA dataset...')
87 |
def load_test_data(image_path, size=128):
    """Read an image as RGB, resize it to ``size`` x ``size`` and return a
    normalized batch of one, shaped (1, size, size, 3).

    NOTE(review): scipy.misc.imread/imresize were removed from newer SciPy
    releases; this relies on the pinned (old) SciPy version.
    """
    rgb = misc.imread(image_path, mode='RGB')
    resized = misc.imresize(rgb, [size, size])
    batch = resized[np.newaxis, ...]
    return normalize(batch)
95 |
def augmentation(image, aug_size):
    """Randomly flip ``image`` left/right, upscale it to ``aug_size`` x ``aug_size``,
    then randomly crop it back to its original shape.

    The flip and the crop share one op-level seed, drawn from Python's
    ``random`` when the ops are built.
    """
    op_seed = random.randint(0, 2 ** 31 - 1)
    original_shape = tf.shape(image)
    flipped = tf.image.random_flip_left_right(image, seed=op_seed)
    enlarged = tf.image.resize_images(flipped, [aug_size, aug_size])
    return tf.random_crop(enlarged, original_shape, seed=op_seed)
103 |
def normalize(x):
    """Map pixel values from [0, 255] to [-1, 1]."""
    scaled = x / 127.5
    return scaled - 1
106 |
def save_images(images, size, image_path):
    """Rescale a batch from [-1, 1] to [0, 1] and write it to ``image_path`` as one grid image.

    ``size`` is the (rows, cols) grid layout.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)
109 |
def merge(images, size):
    """Tile a batch of images (N, H, W, C) into a single grid image.

    ``size`` is (rows, cols). Images are placed row-major. C must be 1, 3 or 4:
    a 1-channel batch produces a 2-D (rows*H, cols*W) canvas, otherwise the
    result is (rows*H, cols*W, C).

    Raises:
        ValueError: for any other channel count.
    """
    height, width, n_ch = images.shape[1], images.shape[2], images.shape[3]

    if n_ch in (3, 4):
        canvas = np.zeros((height * size[0], width * size[1], n_ch))
    elif n_ch == 1:
        canvas = np.zeros((height * size[0], width * size[1]))
    else:
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

    for idx, image in enumerate(images):
        row, col = divmod(idx, size[1])
        tile = image[:, :, 0] if n_ch == 1 else image
        canvas[row * height:(row + 1) * height, col * width:(col + 1) * width] = tile
    return canvas
132 |
133 |
def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write the result to ``path``."""
    grid = merge(images, size)
    return scipy.misc.imsave(path, grid)
136 |
def inverse_transform(images):
    """Map values from [-1, 1] back to [0, 1]."""
    shifted = images + 1.
    return shifted / 2.
139 |
def check_folder(log_dir):
    """Create ``log_dir`` (with any missing parents) if needed and return it unchanged."""
    already_present = os.path.exists(log_dir)
    if not already_present:
        os.makedirs(log_dir)
    return log_dir
144 |
def show_all_variables():
    """Print slim's model-analyzer summary of every trainable variable in the default graph."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
148 |
def str2bool(x):
    """Parse a command-line boolean.

    Accepts (case-insensitively, surrounding whitespace ignored)
    true/t/yes/y/1 as True and false/f/no/n/0 as False.

    Raises:
        ValueError: for anything else. When used as an argparse ``type``, this
            surfaces as an "invalid str2bool value" usage error instead of a
            silent guess.
    """
    # Explicit tuples: a bare ('true') is just a string, and `in` on a string is
    # a substring test (so '', 'e' or 'ru' would count as True).
    value = x.strip().lower()
    if value in ('true', 't', 'yes', 'y', '1'):
        return True
    if value in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('Boolean value expected, got {!r}'.format(x))
151 |
def create_labels(c_org, selected_attrs=None):
    """Generate target-domain labels for debugging and testing.

    For each attribute position ``i`` a copy of ``c_org`` is made in which:
      * if attribute ``i`` is a hair colour (Black/Blond/Brown/Gray_Hair), that
        colour is set to 1 and every other selected hair colour to 0;
      * otherwise column ``i`` is flipped (0 becomes 1, any non-zero becomes 0).

    Args:
        c_org: per-image label rows, array-like shaped (batch, len(selected_attrs)).
            An empty sequence is accepted and yields an empty result.
        selected_attrs: attribute names in label-column order (required despite
            the ``None`` default).

    Returns:
        np.ndarray shaped (batch, len(selected_attrs), len(selected_attrs)),
        where ``result[b, i]`` is the target label vector for image ``b`` and
        attribute ``i``.
    """
    hair_colors = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')

    c_org = np.asarray(c_org)
    if c_org.size == 0:
        # np.asarray([]) is 1-D; give it a column axis so the column indexing
        # below works and an empty batch produces an empty result.
        c_org = c_org.reshape(0, len(selected_attrs))

    hair_color_indices = [i for i, attr_name in enumerate(selected_attrs) if attr_name in hair_colors]

    c_trg_list = []
    for i in range(len(selected_attrs)):
        c_trg = c_org.copy()

        if i in hair_color_indices:  # Set one hair color to 1 and the rest to 0.
            c_trg[:, i] = 1.0
            for j in hair_color_indices:
                if j != i:
                    c_trg[:, j] = 0.0
        else:
            c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.

        c_trg_list.append(c_trg)

    # Stacked as [c_dim, bs, c_dim]; move the batch axis first -> [bs, c_dim, c_dim].
    return np.transpose(c_trg_list, axes=[1, 0, 2])
--------------------------------------------------------------------------------
/StarGAN.py:
--------------------------------------------------------------------------------
1 | from ops import *
2 | from utils import *
3 | import time
4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
5 | import numpy as np
6 | from glob import glob
7 |
8 | class StarGAN(object) :
def __init__(self, sess, args):
    """Copy hyper-parameters from ``args`` onto the model and print a summary.

    Args:
        sess: TensorFlow session the model runs in (kept as ``self.sess``).
        args: parsed command-line namespace (built by main.py's parse_args).
    """
    self.model_name = 'StarGAN'
    self.sess = sess

    # I/O locations
    self.checkpoint_dir = args.checkpoint_dir
    self.sample_dir = args.sample_dir
    self.result_dir = args.result_dir
    self.log_dir = args.log_dir
    self.dataset_name = args.dataset
    self.dataset_path = os.path.join('./dataset', self.dataset_name)
    self.augment_flag = args.augment_flag

    # Training schedule
    self.epoch, self.iteration = args.epoch, args.iteration
    self.decay_flag, self.decay_epoch = args.decay_flag, args.decay_epoch
    self.gan_type = args.gan_type
    self.batch_size = args.batch_size
    self.print_freq, self.save_freq = args.print_freq, args.save_freq
    self.init_lr = args.lr
    self.ch = args.ch

    # Attribute domains; custom_label becomes a (1, len(custom_label)) array.
    self.selected_attrs = args.selected_attrs
    self.custom_label = np.expand_dims(args.custom_label, axis=0)
    self.c_dim = len(self.selected_attrs)

    # Loss weights
    self.adv_weight, self.rec_weight, self.cls_weight = args.adv_weight, args.rec_weight, args.cls_weight
    self.ld = args.ld  # gradient-penalty lambda

    # Architecture
    self.n_res = args.n_res  # generator residual blocks
    self.n_dis = args.n_dis  # discriminator conv layers
    self.n_critic = args.n_critic
    self.img_size, self.img_ch = args.img_size, args.img_ch

    summary = (
        ("##### Information #####", (
            ("# gan type : ", self.gan_type),
            ("# selected_attrs : ", self.selected_attrs),
            ("# dataset : ", self.dataset_name),
            ("# batch_size : ", self.batch_size),
            ("# epoch : ", self.epoch),
            ("# iteration per epoch : ", self.iteration),
        )),
        ("##### Generator #####", (
            ("# residual blocks : ", self.n_res),
        )),
        ("##### Discriminator #####", (
            ("# discriminator layer : ", self.n_dis),
            ("# the number of critic : ", self.n_critic),
        )),
    )
    for header, rows in summary:
        print()
        print(header)
        for label, value in rows:
            print(label, value)
73 |
74 | ##################################################################################
75 | # Generator
76 | ##################################################################################
77 |
def generator(self, x_init, c, reuse=False, scope="generator"):
    """Translate ``x_init`` toward the target attribute vector ``c``.

    ``c`` is reshaped to (batch, 1, 1, c_dim), tiled over the spatial grid of
    ``x_init`` (axes 1 and 2) and concatenated as extra input channels. The
    network is: 7x7 conv -> two stride-2 down-sampling convs -> ``n_res``
    residual blocks -> two stride-2 transposed convs -> 7x7 conv -> tanh.
    Variable scope names determine checkpoint variable names, so they are
    kept stable.
    """
    channel = self.ch
    c = tf.cast(tf.reshape(c, shape=[-1, 1, 1, c.shape[-1]]), tf.float32)
    c = tf.tile(c, [1, x_init.shape[1], x_init.shape[2], 1])
    x = tf.concat([x_init, c], axis=-1)

    with tf.variable_scope(scope, reuse=reuse):
        x = conv(x, channel, kernel=7, stride=1, pad=3, use_bias=False, scope='conv')
        x = instance_norm(x, scope='ins_norm')
        x = relu(x)

        # Down-Sampling
        for i in range(2):
            x = conv(x, channel * 2, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_' + str(i))
            x = instance_norm(x, scope='down_ins_norm_' + str(i))
            x = relu(x)

            channel = channel * 2

        # Bottleneck
        for i in range(self.n_res):
            x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i))

        # Up-Sampling
        for i in range(2):
            x = deconv(x, channel // 2, kernel=4, stride=2, use_bias=False, scope='deconv_' + str(i))
            x = instance_norm(x, scope='up_ins_norm' + str(i))
            x = relu(x)

            channel = channel // 2

        # Emit as many channels as the configured image channel count
        # (previously hard-coded to 3; identical for the default img_ch=3).
        x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, use_bias=False, scope='G_logit')
        x = tanh(x)

        return x
114 |
115 | ##################################################################################
116 | # Discriminator
117 | ##################################################################################
118 |
def discriminator(self, x_init, reuse=False, scope="discriminator"):
    """Return (real/fake logit map, per-attribute classification logits).

    ``n_dis`` stride-2 4x4 convs with leaky ReLU (slope 0.01) shrink the input;
    channels start at ``ch`` and double after the first layer. A 3x3 conv gives
    a 1-channel spatial logit map. A conv whose kernel spans the remaining
    feature map gives ``c_dim`` logits, reshaped to (batch, c_dim).
    """
    with tf.variable_scope(scope, reuse=reuse):
        channel = self.ch
        x = lrelu(conv(x_init, channel, kernel=4, stride=2, pad=1, use_bias=True, scope='conv_0'), 0.01)

        for i in range(1, self.n_dis):
            channel *= 2
            x = conv(x, channel, kernel=4, stride=2, pad=1, use_bias=True, scope='conv_' + str(i))
            x = lrelu(x, 0.01)

        # After n_dis stride-2 layers the feature map is img_size / 2**n_dis wide.
        label_kernel = int(self.img_size / np.power(2, self.n_dis))

        logit = conv(x, channels=1, kernel=3, stride=1, pad=1, use_bias=False, scope='D_logit')
        cls = conv(x, channels=self.c_dim, kernel=label_kernel, stride=1, use_bias=False, scope='D_label')
        cls = tf.reshape(cls, shape=[-1, self.c_dim])

        return logit, cls
138 |
139 | ##################################################################################
140 | # Model
141 | ##################################################################################
142 |
def gradient_panalty(self, real, fake, scope="discriminator"):
    """Return the gradient-penalty term that is added to the discriminator loss.

    Penalty points ``interpolated`` are built from the batch:

    * ``dragan``: ``real`` plus two-sided noise whose scale is half the
      standard deviation of all elements of ``real``, clipped to ``[-1, 1]``;
    * otherwise: per-sample random convex combinations of ``real`` and
      ``fake``.

    The discriminator (reused, ``reuse=True``) is evaluated on these points.
    The per-sample L2 norm ``||g||`` of the gradient of its adversarial logit
    with respect to the input is then penalised:

    * ``wgan-lp``: ``ld * mean(max(0, ||g|| - 1)^2)`` (one-sided);
    * ``wgan-gp`` / ``dragan``: ``ld * mean((||g|| - 1)^2)`` (two-sided);
    * any other ``gan_type``: ``0``.

    Args:
        real: batch of real images (4-D, as required by the moments axes below).
        fake: batch of generated images, same shape as ``real``.
        scope: variable scope of the discriminator to reuse.

    Returns:
        A scalar tensor, or the Python int ``0`` for other ``gan_type`` values.
    """
    shape = tf.shape(real)
    # Dynamic batch size, used by both branches, so the penalty is correct
    # for any batch size rather than only for ``self.batch_size``.
    batch = shape[0]

    if self.gan_type == 'dragan':
        eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
        _, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
        x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
        noise = 0.5 * x_std * eps  # delta in paper

        # The original paper samples alpha from U[0, 1], but the author
        # acknowledged this as a bug (https://github.com/kodalinaveen3/DRAGAN):
        # it should be two-sided.
        alpha = tf.random_uniform(shape=[batch, 1, 1, 1], minval=-1., maxval=1.)
        interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should be in the space of X
    else:
        alpha = tf.random_uniform(shape=[batch, 1, 1, 1], minval=0., maxval=1.)
        interpolated = alpha * real + (1. - alpha) * fake

    logit, _ = self.discriminator(interpolated, reuse=True, scope=scope)

    grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
    grad_norm = tf.norm(flatten(grad), axis=1)  # per-sample L2 norm

    if self.gan_type == 'wgan-lp':
        return self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
    if self.gan_type in ('wgan-gp', 'dragan'):
        return self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
    return 0
177 |
def build_model(self):
    """Build the training, sampling and test graph and store its handles on ``self``.

    Sets, among others:

    * ``self.lr``: learning-rate placeholder;
    * ``self.x_real`` / ``self.x_test``: batches from the train/test pipelines;
    * ``self.custom_image``: placeholder for a single user image;
    * ``self.d_loss`` / ``self.g_loss``: total discriminator/generator losses;
    * ``self.x_fake_list``, ``self.x_test_fake_list``, ``self.custom_fake_image``:
      translations of the respective inputs to every fixed label set;
    * ``self.d_optimizer`` / ``self.g_optimizer``: Adam update ops;
    * ``self.g_summary_loss`` / ``self.d_summary_loss``: merged summaries.

    Statement order matters: the first ``generator``/``discriminator`` calls
    create the variables and every later call reuses them.
    """
    self.lr = tf.placeholder(tf.float32, name='learning_rate')

    # ---- Input images ----------------------------------------------------
    image_data = ImageData(load_size=self.img_size, channels=self.img_ch, data_path=self.dataset_path,
                           selected_attrs=self.selected_attrs, augment_flag=self.augment_flag)
    image_data.preprocess()

    train_dataset_num = len(image_data.train_dataset)
    test_dataset_num = len(image_data.test_dataset)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (image_data.train_dataset, image_data.train_dataset_label, image_data.train_dataset_fix_label))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (image_data.test_dataset, image_data.test_dataset_label, image_data.test_dataset_fix_label))

    gpu_device = '/gpu:0'
    # Shuffle with a buffer as large as the dataset and repeat, then map and
    # batch in parallel (dropping any partial batch) and prefetch onto the GPU.
    train_dataset = (train_dataset
                     .apply(shuffle_and_repeat(train_dataset_num))
                     .apply(map_and_batch(image_data.image_processing, self.batch_size,
                                          num_parallel_batches=8, drop_remainder=True))
                     .apply(prefetch_to_device(gpu_device, self.batch_size)))

    test_dataset = (test_dataset
                    .apply(shuffle_and_repeat(test_dataset_num))
                    .apply(map_and_batch(image_data.image_processing, self.batch_size,
                                         num_parallel_batches=8, drop_remainder=True))
                    .apply(prefetch_to_device(gpu_device, self.batch_size)))

    train_dataset_iterator = train_dataset.make_one_shot_iterator()
    test_dataset_iterator = test_dataset.make_one_shot_iterator()

    # Input image / original domain labels / fixed sample labels
    self.x_real, label_org, label_fix_list = train_dataset_iterator.get_next()
    label_trg = tf.random_shuffle(label_org)  # target domain labels: original labels permuted across the batch
    # NOTE(review): assumes fixed labels arrive as [batch, n_fix, c_dim]; the
    # transpose puts the label-set axis first so map_fn iterates over it.
    label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

    self.x_test, _, test_label_fix_list = test_dataset_iterator.get_next()
    test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

    self.custom_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='custom_image')  # Custom Image
    custom_label_fix_list = tf.transpose(create_labels(self.custom_label, self.selected_attrs), perm=[1, 0, 2])

    # ---- Generator / discriminator ---------------------------------------
    x_fake = self.generator(self.x_real, label_trg)  # real -> target domain (creates G variables)
    x_recon = self.generator(x_fake, label_org, reuse=True)  # fake -> original domain (cycle reconstruction)

    real_logit, real_cls = self.discriminator(self.x_real)  # creates D variables
    fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

    # ---- Losses ----------------------------------------------------------
    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
        GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
    else:
        GP = 0

    g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
    g_cls_loss = classification_loss(logit=fake_cls, label=label_trg)
    g_rec_loss = L1_loss(self.x_real, x_recon)

    d_adv_loss = discriminator_loss(loss_func=self.gan_type, real=real_logit, fake=fake_logit) + GP
    d_cls_loss = classification_loss(logit=real_cls, label=label_org)

    self.d_loss = self.adv_weight * d_adv_loss + self.cls_weight * d_cls_loss
    self.g_loss = self.adv_weight * g_adv_loss + self.cls_weight * g_cls_loss + self.rec_weight * g_rec_loss

    # ---- Result images (one translated batch per fixed label set) --------
    self.x_fake_list = tf.map_fn(lambda x: self.generator(self.x_real, x, reuse=True), label_fix_list, dtype=tf.float32)

    # ---- Test images -----------------------------------------------------
    self.x_test_fake_list = tf.map_fn(lambda x: self.generator(self.x_test, x, reuse=True), test_label_fix_list, dtype=tf.float32)
    self.custom_fake_image = tf.map_fn(lambda x: self.generator(self.custom_image, x, reuse=True), custom_label_fix_list, dtype=tf.float32)

    # ---- Training ops ----------------------------------------------------
    t_vars = tf.trainable_variables()
    G_vars = [var for var in t_vars if 'generator' in var.name]
    D_vars = [var for var in t_vars if 'discriminator' in var.name]

    self.g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars)
    self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars)

    # ---- Summaries -------------------------------------------------------
    # Note: self.g_adv_loss etc. hold the summary ops, not the loss tensors.
    self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
    self.Discriminator_loss = tf.summary.scalar("Discriminator_loss", self.d_loss)

    self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
    self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
    self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

    self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
    self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

    self.g_summary_loss = tf.summary.merge([self.Generator_loss, self.g_adv_loss, self.g_cls_loss, self.g_rec_loss])
    self.d_summary_loss = tf.summary.merge([self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss])
272 |
273 |
def train(self):
    """Run the training loop, resuming from the latest checkpoint if one exists.

    Each iteration runs one discriminator update. The generator is updated on
    iterations where ``(counter - 1) % self.n_critic == 0``, i.e. once every
    ``n_critic`` iterations. Sample grids are written every
    ``self.print_freq`` iterations. Checkpoints are written every
    ``self.save_freq`` iterations and once more after the last epoch.

    ``counter`` starts at 1 and is incremented after every iteration. A
    checkpoint tagged ``N`` therefore contains ``N - 1`` finished iterations.
    """
    # initialize all variables
    tf.global_variables_initializer().run()

    # saver to save model
    self.saver = tf.train.Saver()

    # summary writer
    self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

    # restore check-point if it exists
    could_load, checkpoint_counter = self.load(self.checkpoint_dir)
    if could_load:
        # Resume at the first *unfinished* iteration (see docstring).
        finished = max(checkpoint_counter - 1, 0)
        start_epoch, start_batch_id = divmod(finished, self.iteration)
        counter = finished + 1
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_id = 0
        counter = 1
        print(" [!] Load failed...")

    self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
    check_folder(self.sample_dir)

    # loop for epoch
    start_time = time.time()
    past_g_loss = -1.
    # Outputs of the most recent generator step. They stay None until the first
    # G update of this run, so sample saving is skipped before then.
    real_images = fake_images = None
    lr = self.init_lr
    for epoch in range(start_epoch, self.epoch):
        if self.decay_flag:
            # constant until decay_epoch, then linear decay towards 0 at self.epoch
            lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch)

        for idx in range(start_batch_id, self.iteration):
            train_feed_dict = {self.lr: lr}

            # Update D
            _, d_loss, summary_str = self.sess.run([self.d_optimizer, self.d_loss, self.d_summary_loss],
                                                   feed_dict=train_feed_dict)
            self.writer.add_summary(summary_str, counter)

            # Update G (once every n_critic iterations)
            g_loss = None
            if (counter - 1) % self.n_critic == 0:
                real_images, fake_images, _, g_loss, summary_str = self.sess.run(
                    [self.x_real, self.x_fake_list, self.g_optimizer, self.g_loss, self.g_summary_loss],
                    feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)
                past_g_loss = g_loss

            # display training status
            counter += 1
            if g_loss is None:
                g_loss = past_g_loss

            print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

            if (idx + 1) % self.print_freq == 0 and real_images is not None:
                real_image = np.expand_dims(real_images[0], axis=0)
                fake_image = np.transpose(fake_images, axes=[1, 0, 2, 3, 4])[0]  # after transpose: [bs, c_dim, h, w, ch]

                save_images(real_image, [1, 1],
                            './{}/real_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx + 1))

                save_images(fake_image, [1, self.c_dim],
                            './{}/fake_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx + 1))

            if (idx + 1) % self.save_freq == 0:
                self.save(self.checkpoint_dir, counter)

        # After an epoch, start_batch_id is set to zero;
        # a non-zero value is only used for the first epoch after resuming.
        start_batch_id = 0

    # save model for final step
    self.save(self.checkpoint_dir, counter)
350 |
@property
def model_dir(self):
    """Per-configuration sub-directory name used for logs, samples, checkpoints and results.

    Format: ``<model_name>_<dataset_name>_<gan_type>_<n_res>resblock_<n_dis>dis``.
    """
    parts = [
        self.model_name,
        self.dataset_name,
        self.gan_type,
        '{}resblock'.format(self.n_res),
        '{}dis'.format(self.n_dis),
    ]
    return '_'.join('{}'.format(part) for part in parts)
359 |
def save(self, checkpoint_dir, step):
    """Save ``self.sess`` through ``self.saver`` as a checkpoint tagged with ``step``.

    The checkpoint prefix is ``<checkpoint_dir>/<model_dir>/<model_name>.model``
    and it is passed to the saver with ``global_step=step``. The directory is
    created if it does not exist yet.
    """
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
    # exist_ok avoids the race between an os.path.exists check and makedirs.
    os.makedirs(checkpoint_dir, exist_ok=True)

    self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
367 |
def load(self, checkpoint_dir):
    """Restore the latest checkpoint from ``<checkpoint_dir>/<model_dir>``, if any.

    The checkpoint is restored into ``self.sess`` with ``self.saver``.

    Returns:
        ``(True, step)``, where ``step`` is the last run of digits in the
        checkpoint file name, or ``(False, 0)`` when no checkpoint is found.

    Raises:
        ValueError: if the checkpoint file name contains no digits.
    """
    import re
    print(" [*] Reading checkpoints...")
    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print(" [*] Failed to find a checkpoint")
        return False, 0

    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    match = re.search(r"(\d+)(?!.*\d)", ckpt_name)
    if match is None:
        raise ValueError("cannot parse a step number from checkpoint name {!r}".format(ckpt_name))
    counter = int(match.group(0))
    print(" [*] Success to read {}".format(ckpt_name))
    return True, counter
383 |
def test(self):
    """Translate test images with the restored model and write an HTML comparison page.

    Everything is written under ``<result_dir>/<model_dir>``; note that
    ``self.result_dir`` is rebound to that directory:

    * ``images/<name>``: for every file matching ``<dataset_path>/test/*.*``
      (in sorted order), the output of ``self.custom_fake_image`` saved as a
      ``1 x self.c_dim`` grid;
    * ``images/real_<i>.png`` / ``images/fake_<i>.png``: one batch from the
      test pipeline (``self.x_test``) and its translations;
    * ``index.html``: a table with one row per result (name, input, output).
    """
    tf.global_variables_initializer().run()
    test_path = os.path.join(self.dataset_path, 'test')
    check_folder(test_path)
    test_files = sorted(glob(os.path.join(test_path, '*.*')))

    self.saver = tf.train.Saver()
    could_load, _ = self.load(self.checkpoint_dir)
    self.result_dir = os.path.join(self.result_dir, self.model_dir)
    check_folder(self.result_dir)

    image_folder = os.path.join(self.result_dir, 'images')
    check_folder(image_folder)

    if could_load:
        print(" [*] Load SUCCESS")
    else:
        print(" [!] Load failed...")

    # write html for visual comparison; `with` closes the file even on errors
    index_path = os.path.join(self.result_dir, 'index.html')
    with open(index_path, 'w') as index:
        index.write("<html><body><table><tr>")
        index.write("<th>name</th><th>input</th><th>output</th></tr>")

        # Custom images
        for sample_file in test_files:
            print("Processing image: " + sample_file)
            sample_image = np.asarray(load_test_data(sample_file, size=self.img_size))
            image_path = os.path.join(image_folder, '{}'.format(os.path.basename(sample_file)))

            fake_image = self.sess.run(self.custom_fake_image, feed_dict={self.custom_image: sample_image})
            fake_image = np.transpose(fake_image, axes=[1, 0, 2, 3, 4])[0]
            save_images(fake_image, [1, self.c_dim], image_path)

            self._write_index_row(index, image_path, sample_file, image_path)

        # CelebA test batch
        real_images, fake_images = self.sess.run([self.x_test, self.x_test_fake_list])
        fake_images = np.transpose(fake_images, axes=[1, 0, 2, 3, 4])

        for i, (real, fakes) in enumerate(zip(real_images, fake_images)):
            print("{} / {}".format(i, len(real_images)))
            real_path = os.path.join(image_folder, 'real_{}.png'.format(i))
            fake_path = os.path.join(image_folder, 'fake_{}.png'.format(i))

            save_images(np.expand_dims(real, axis=0), [1, 1], real_path)
            save_images(fakes, [1, self.c_dim], fake_path)

            self._write_index_row(index, real_path, real_path, fake_path)

        index.write("</table></body></html>")

def _write_index_row(self, index, name_path, input_path, output_path):
    """Append one table row (name, input image, output grid) to the open HTML index.

    Relative paths get a ``../..`` prefix, presumably because ``index.html``
    lives two directories below the working directory (``<result_dir>/<model_dir>``).
    Paths are HTML-escaped so quotes or ampersands in file names do not
    break the markup.
    """
    import html

    def src(path):
        return path if os.path.isabs(path) else '../..' + os.path.sep + path

    index.write("<tr><td>%s</td>" % html.escape(os.path.basename(name_path)))
    index.write("<td><img src='%s' width='%d' height='%d'></td>" % (
        html.escape(src(input_path), quote=True), self.img_size, self.img_size))
    index.write("<td><img src='%s' width='%d' height='%d'></td>" % (
        html.escape(src(output_path), quote=True), self.img_size * self.c_dim, self.img_size))
    index.write("</tr>")
--------------------------------------------------------------------------------