├── assests ├── men.png ├── logo.jpg ├── women.png ├── overview.PNG ├── generator.PNG └── discriminator.PNG ├── dataset └── celebA │ ├── train │ └── 000001.jpg │ └── test │ └── your_test_image.png ├── docker-compose.yml ├── Dockerfile.gpu ├── LICENSE ├── README.md ├── .gitignore ├── download.py ├── main.py ├── ops.py ├── utils.py └── StarGAN.py /assests/men.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/men.png -------------------------------------------------------------------------------- /assests/logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/logo.jpg -------------------------------------------------------------------------------- /assests/women.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/women.png -------------------------------------------------------------------------------- /assests/overview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/overview.PNG -------------------------------------------------------------------------------- /assests/generator.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/generator.PNG -------------------------------------------------------------------------------- /assests/discriminator.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/assests/discriminator.PNG -------------------------------------------------------------------------------- 
/dataset/celebA/train/000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/dataset/celebA/train/000001.jpg -------------------------------------------------------------------------------- /dataset/celebA/test/your_test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/StarGAN-Tensorflow/HEAD/dataset/celebA/test/your_test_image.png -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | tensorflow: 4 | image: slamhan/stargan-gpu:1.8.0-gpu-py3-jupyter 5 | volumes: 6 | - /notebooks:/notebooks 7 | ports: 8 | - 8888:8888/tcp 9 | -------------------------------------------------------------------------------- /Dockerfile.gpu: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.8.0-gpu-py3 2 | 3 | RUN apt-get update -qq -y \ 4 | && apt-get install -y libsm6 libxrender1 libxext-dev python3-tk\ 5 | && apt-get clean \ 6 | && rm -rf /var/lib/apt/lists/* 7 | 8 | COPY requirements.txt /opt/ 9 | RUN pip3 install --upgrade pip 10 | RUN pip3 install cmake 11 | RUN pip3 install dlib 12 | #RUN pip3 install torchvision 13 | RUN pip3 --no-cache-dir install -r /opt/requirements.txt && rm /opt/requirements.txt 14 | RUN pip3 install jupyter matplotlib 15 | RUN pip3 install jupyter_http_over_ws 16 | RUN jupyter serverextension enable --py jupyter_http_over_ws 17 | # patch for tensorflow:1.8.0-gpu-py3 image 18 | RUN cd /usr/local/cuda/lib64 \ 19 | && mv stubs/libcuda.so ./ \ 20 | && ln -s libcuda.so libcuda.so.1 \ 21 | && ldconfig 22 | 23 | WORKDIR "/notebooks" 24 | CMD ["jupyter-notebook", "--allow-root" ,"--port=8888" ,"--no-browser" ,"--ip=0.0.0.0"] 25 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim (1993.01.12) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | -------------------------------------------------------------------------------- 4 | ## Requirements 5 | * Tensorflow 1.8 6 | * Python 3.6 7 | 8 | ## Usage 9 | ### Downloading the dataset 10 | ```python 11 | > python download.py celebA 12 | ``` 13 | 14 | ``` 15 | ├── dataset 16 |    └── celebA 17 |    ├── train 18 |           ├── 000001.jpg 19 | ├── 000002.jpg 20 | └── ... 21 | ├── test (It is not celebA) 22 | ├── a.jpg (The test image that you wanted) 23 | ├── b.png 24 | └── ... 25 | ├── list_attr_celeba.txt (For attribute information) 26 | ``` 27 | 28 | ### Train 29 | * python main.py --phase train 30 | 31 | ### Test 32 | * python main.py --phase test 33 | * The celebA test image and the image you wanted run simultaneously 34 | 35 | ### Pretrained model 36 | * Download [checkpoint for 128x128](https://drive.google.com/open?id=1ezwtU1O_rxgNXgJaHcAynVX8KjMt0Ua-) 37 | 38 | ## Summary 39 | ![overview](./assests/overview.PNG) 40 | 41 | ## Results (128x128, wgan-gp) 42 | ### Women 43 | ![women](./assests/women.png) 44 | 45 | ### Men 46 | ![men](./assests/men.png) 47 | 48 | ## Related works 49 | * [CycleGAN-Tensorflow](https://github.com/taki0112/CycleGAN-Tensorflow) 50 | * [DiscoGAN-Tensorflow](https://github.com/taki0112/DiscoGAN-Tensorflow) 51 | * [UNIT-Tensorflow](https://github.com/taki0112/UNIT-Tensorflow) 52 | * [MUNIT-Tensorflow](https://github.com/taki0112/MUNIT-Tensorflow) 53 | 54 | ## Reference 55 | * [StarGAN paper](https://arxiv.org/abs/1711.09020) 56 | * [Author pytorch code](https://github.com/yunjey/StarGAN) 57 | 58 | ## Author 59 | Junho Kim 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 
| eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import argparse 4 | import requests 5 | 6 | from tqdm import tqdm 7 | 8 | parser = argparse.ArgumentParser(description='Download dataset for StarGAN') 9 | parser.add_argument('dataset', metavar='N', type=str, nargs='+', choices=['celebA'], 10 | help='name of dataset to download [celebA]') 11 | 12 | 13 | def 
def download_celeb_a(dirpath):
    """Download the CelebA archive and attribute list and unpack them under *dirpath*.

    Resulting layout::

        <dirpath>/img_align_celeba.zip          downloaded archive (kept on disk)
        <dirpath>/celebA/list_attr_celeba.txt   attribute annotations
        <dirpath>/celebA/train/                 extracted images
        <dirpath>/celebA/test/                  created empty, for custom test images

    Files that already exist on disk are not downloaded again.

    Args:
        dirpath: Root dataset directory (for example ``'./dataset'``).
    """
    data_dir = 'celebA'
    celebA_dir = os.path.join(dirpath, data_dir)
    prepare_data_dir(celebA_dir)

    file_name, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    txt_name, txt_drive_id = "list_attr_celeba.txt", "0B7EVK8r0v71pblRyaVFSWGxPY0U"

    save_path = os.path.join(dirpath, file_name)
    txt_save_path = os.path.join(celebA_dir, txt_name)

    if os.path.exists(txt_save_path):
        print('[*] {} already exists'.format(txt_save_path))
    else:
        # The attribute list has its own Drive id; using the archive id here
        # would store the image zip under the .txt name.
        download_file_from_google_drive(txt_drive_id, txt_save_path)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    with zipfile.ZipFile(save_path) as zf:
        zf.extractall(celebA_dir)

    extracted_dir = os.path.join(celebA_dir, 'img_align_celeba')
    train_dir = os.path.join(celebA_dir, 'train')
    if os.path.isdir(train_dir):
        # 'train' may already exist (previous run, or a checkout that ships a
        # sample image); renaming a directory onto it would fail, so move the
        # extracted files into it one by one instead.
        for name in os.listdir(extracted_dir):
            os.replace(os.path.join(extracted_dir, name), os.path.join(train_dir, name))
        os.rmdir(extracted_dir)
    else:
        os.rename(extracted_dir, train_dir)

    custom_data_dir = os.path.join(celebA_dir, 'test')
    prepare_data_dir(custom_data_dir)
def check_args(args):
    """Create the output directories and validate the numeric arguments.

    Args:
        args: Parsed ``argparse`` namespace providing ``checkpoint_dir``,
            ``result_dir``, ``log_dir``, ``sample_dir``, ``epoch`` and
            ``batch_size``.

    Returns:
        ``args`` unchanged when every check passes, otherwise ``None``
        (``main()`` exits when it receives ``None``).
    """
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # Plain comparisons instead of `assert` inside a bare `except`: asserts are
    # stripped under `python -O`, and a bare except would hide unrelated errors.
    valid = True

    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        valid = False

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        valid = False

    return args if valid else None
def resblock(x_init, channels, use_bias=True, scope='resblock'):
    """Apply two 3x3 conv + instance-norm stages and add the input back.

    The first stage (variable scope ``res1``) is followed by ReLU; the second
    (``res2``) is not. The result is ``x_init`` plus the second stage's output.

    Args:
        x_init: Input tensor; it is added to the block output, so its shape
            must match what the two convolutions produce.
        channels: Number of filters for both convolutions.
        use_bias: Whether the convolutions use a bias term.
        scope: Name of the enclosing variable scope.
    """
    # Both convolutions share the same geometry: 3x3, stride 1, padding 1.
    conv_kwargs = dict(kernel=3, stride=1, pad=1, use_bias=use_bias)

    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            hidden = conv(x_init, channels, **conv_kwargs)
            hidden = relu(instance_norm(hidden))

        with tf.variable_scope('res2'):
            residual = conv(hidden, channels, **conv_kwargs)
            residual = instance_norm(residual)

        return residual + x_init
def discriminator_loss(loss_func, real, fake):
    """Return the discriminator's adversarial loss for the selected objective.

    Args:
        loss_func: Objective name. Any name containing ``'wgan'`` uses the
            Wasserstein form; ``'lsgan'``, ``'gan'``/``'dragan'`` and
            ``'hinge'`` select the corresponding losses.
        real: Discriminator logits for real images.
        fake: Discriminator logits for generated images.

    Returns:
        ``real_loss + fake_loss``. For an unrecognized ``loss_func`` both
        terms stay at their initial value, so the result is the integer 0.
    """
    real_loss = 0
    fake_loss = 0

    # The cases are mutually exclusive, so a single if/elif chain is enough.
    if 'wgan' in loss_func:
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)

    elif loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))

    elif loss_func in ('gan', 'dragan'):
        real_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))

    elif loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))

    return real_loss + fake_loss
class ImageData:
    """Build CelebA file/label lists and provide the per-image tf.data map function."""

    def __init__(self, load_size, channels, data_path, selected_attrs, augment_flag=False):
        """Read the attribute file and initialise empty train/test lists.

        Args:
            load_size: Side length images are resized to.
            channels: Number of channels requested from the JPEG decoder.
            data_path: Dataset root; images are read from ``<data_path>/train``
                and attributes from ``<data_path>/list_attr_celeba.txt``.
            selected_attrs: Attribute names that make up each label vector.
            augment_flag: Whether ``image_processing`` applies random augmentation.
        """
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag
        self.selected_attrs = selected_attrs

        self.data_path = os.path.join(data_path, 'train')
        check_folder(self.data_path)
        # Context manager so the attribute file handle is closed right away.
        with open(os.path.join(data_path, 'list_attr_celeba.txt'), 'r') as attr_file:
            self.lines = attr_file.readlines()

        self.train_dataset = []
        self.train_dataset_label = []
        self.train_dataset_fix_label = []

        self.test_dataset = []
        self.test_dataset_label = []
        self.test_dataset_fix_label = []

        self.attr2idx = {}
        self.idx2attr = {}

    def image_processing(self, filename, label, fix_label):
        """Load one JPEG, resize it, scale it to [-1, 1] and optionally augment it.

        ``label`` and ``fix_label`` are passed through unchanged.
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            augment_size = self.load_size + (30 if self.load_size == 256 else 15)
            plain_img = img

            # The coin flip must be a graph op: tf.data traces this function
            # once, so a Python-level random.random() here would make a single
            # decision for every image in the dataset.
            img = tf.cond(tf.random_uniform([]) > 0.5,
                          lambda: augmentation(plain_img, augment_size),
                          lambda: plain_img)

        return img, label, fix_label

    def preprocess(self):
        """Parse the attribute file into train/test file lists and label vectors.

        Line 1 of the file is read as the attribute names and lines 2+ as
        ``<filename> <value> ...`` rows. Rows are shuffled with a fixed seed;
        the first 2000 become the test split and the rest the train split.
        """
        all_attr_names = self.lines[1].split()
        for i, attr_name in enumerate(all_attr_names):
            self.attr2idx[attr_name] = i
            self.idx2attr[i] = attr_name

        lines = self.lines[2:]
        random.seed(1234)
        random.shuffle(lines)

        for i, line in enumerate(lines):
            split = line.split()
            filename = os.path.join(self.data_path, split[0])
            values = split[1:]

            # 1.0 where the selected attribute is marked '1', otherwise 0.0.
            label = [1.0 if values[self.attr2idx[attr_name]] == '1' else 0.0
                     for attr_name in self.selected_attrs]

            if i < 2000:
                self.test_dataset.append(filename)
                self.test_dataset_label.append(label)
            else:
                self.train_dataset.append(filename)
                self.train_dataset_label.append(label)
            # e.g. ['./dataset/celebA/train/019932.jpg', [1, 0, 0, 0, 1]]

        self.test_dataset_fix_label = create_labels(self.test_dataset_label, self.selected_attrs)
        self.train_dataset_fix_label = create_labels(self.train_dataset_label, self.selected_attrs)

        print('\n Finished preprocessing the CelebA dataset...')
def str2bool(x):
    """Parse a command-line flag value into a bool.

    Returns True only when ``x`` equals ``'true'`` ignoring case; every other
    string yields False.
    """
    # Compare for equality. `x.lower() in ('true')` tested against the plain
    # string 'true', i.e. a substring check that accepted '', 'r', 'ru', ...
    return x.lower() == 'true'
173 | 174 | c_trg_list.append(c_trg) 175 | 176 | c_trg_list = np.transpose(c_trg_list, axes=[1, 0, 2]) # [c_dim, bs, ch] 177 | 178 | return c_trg_list -------------------------------------------------------------------------------- /StarGAN.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | from utils import * 3 | import time 4 | from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch 5 | import numpy as np 6 | from glob import glob 7 | 8 | class StarGAN(object) : 9 | def __init__(self, sess, args): 10 | self.model_name = 'StarGAN' 11 | self.sess = sess 12 | self.checkpoint_dir = args.checkpoint_dir 13 | self.sample_dir = args.sample_dir 14 | self.result_dir = args.result_dir 15 | self.log_dir = args.log_dir 16 | self.dataset_name = args.dataset 17 | self.dataset_path = os.path.join('./dataset', self.dataset_name) 18 | self.augment_flag = args.augment_flag 19 | 20 | self.epoch = args.epoch 21 | self.iteration = args.iteration 22 | self.decay_flag = args.decay_flag 23 | self.decay_epoch = args.decay_epoch 24 | 25 | self.gan_type = args.gan_type 26 | 27 | self.batch_size = args.batch_size 28 | self.print_freq = args.print_freq 29 | self.save_freq = args.save_freq 30 | 31 | self.init_lr = args.lr 32 | self.ch = args.ch 33 | self.selected_attrs = args.selected_attrs 34 | self.custom_label = np.expand_dims(args.custom_label, axis=0) 35 | self.c_dim = len(self.selected_attrs) 36 | 37 | """ Weight """ 38 | self.adv_weight = args.adv_weight 39 | self.rec_weight = args.rec_weight 40 | self.cls_weight = args.cls_weight 41 | self.ld = args.ld 42 | 43 | """ Generator """ 44 | self.n_res = args.n_res 45 | 46 | """ Discriminator """ 47 | self.n_dis = args.n_dis 48 | self.n_critic = args.n_critic 49 | 50 | self.img_size = args.img_size 51 | self.img_ch = args.img_ch 52 | 53 | print() 54 | 55 | print("##### Information #####") 56 | print("# gan type : ", self.gan_type) 57 | print("# 
def generator(self, x_init, c, reuse=False, scope="generator"):
    """Translate ``x_init`` conditioned on the label vector ``c``.

    ``c`` is reshaped to ``[-1, 1, 1, c_dim]``, tiled over the height and
    width of ``x_init`` and concatenated to it on the channel axis. The
    network is a 7x7 conv, two stride-2 downsampling convs, ``self.n_res``
    residual blocks, two stride-2 transposed convs and a final 7x7 conv
    followed by tanh.

    Args:
        x_init: Input image batch. NOTE(review): ``x_init.shape[1]`` and
            ``x_init.shape[2]`` are passed to ``tf.tile``, so the spatial
            dimensions presumably must be statically known - confirm.
        c: Label batch whose last dimension is the number of attributes.
        reuse: Forwarded to ``tf.variable_scope`` to share existing weights.
        scope: Variable scope name.

    Returns:
        Tensor with ``self.img_ch`` channels and values in (-1, 1).
    """
    channel = self.ch
    c = tf.cast(tf.reshape(c, shape=[-1, 1, 1, c.shape[-1]]), tf.float32)
    c = tf.tile(c, [1, x_init.shape[1], x_init.shape[2], 1])
    x = tf.concat([x_init, c], axis=-1)

    with tf.variable_scope(scope, reuse=reuse):
        x = conv(x, channel, kernel=7, stride=1, pad=3, use_bias=False, scope='conv')
        x = instance_norm(x, scope='ins_norm')
        x = relu(x)

        # Down-Sampling
        for i in range(2):
            x = conv(x, channel * 2, kernel=4, stride=2, pad=1, use_bias=False, scope='conv_' + str(i))
            x = instance_norm(x, scope='down_ins_norm_' + str(i))
            x = relu(x)

            channel = channel * 2

        # Bottleneck
        for i in range(self.n_res):
            x = resblock(x, channel, use_bias=False, scope='resblock_' + str(i))

        # Up-Sampling
        for i in range(2):
            x = deconv(x, channel // 2, kernel=4, stride=2, use_bias=False, scope='deconv_' + str(i))
            # Scope name intentionally left without the underscore used by
            # 'down_ins_norm_': renaming it would change variable names and
            # break loading of existing checkpoints.
            x = instance_norm(x, scope='up_ins_norm' + str(i))
            x = relu(x)

            channel = channel // 2

        # Emit as many channels as the input images have (--img_ch) rather
        # than a hard-coded 3, so the output matches x_real for the
        # reconstruction loss and the discriminator; identical for the
        # default img_ch=3.
        x = conv(x, channels=self.img_ch, kernel=7, stride=1, pad=3, use_bias=False, scope='G_logit')
        x = tanh(x)

        return x
def discriminator(self, x_init, reuse=False, scope="discriminator"):
    """Score an image batch and predict its attribute logits.

    A stack of ``self.n_dis`` stride-2 4x4 convolutions (each followed by
    leaky ReLU with slope 0.01, channel count doubling after the first
    layer) feeds two bias-free heads: a 1-channel 3x3 conv (``D_logit``) and
    a ``self.c_dim``-channel conv (``D_label``) whose kernel size is
    ``img_size / 2**n_dis``.

    Args:
        x_init: Input image batch.
        reuse: Forwarded to ``tf.variable_scope`` to share existing weights.
        scope: Variable scope name.

    Returns:
        Tuple ``(logit, c)``: the ``D_logit`` output, and the ``D_label``
        output reshaped to ``[-1, self.c_dim]``.
    """
    with tf.variable_scope(scope, reuse=reuse):
        width = self.ch

        features = conv(x_init, width, kernel=4, stride=2, pad=1, use_bias=True, scope='conv_0')
        features = lrelu(features, 0.01)

        for layer_idx in range(1, self.n_dis):
            width = width * 2
            features = lrelu(
                conv(features, width, kernel=4, stride=2, pad=1, use_bias=True,
                     scope='conv_' + str(layer_idx)),
                0.01)

        # Kernel size for the attribute head, derived from the image size and
        # the number of stride-2 layers.
        label_kernel = int(self.img_size / np.power(2, self.n_dis))

        src_logit = conv(features, channels=1, kernel=3, stride=1, pad=1, use_bias=False, scope='D_logit')
        cls_logit = conv(features, channels=self.c_dim, kernel=label_kernel, stride=1, use_bias=False,
                         scope='D_label')
        cls_logit = tf.reshape(cls_logit, shape=[-1, self.c_dim])

        return src_logit, cls_logit
# x_hat should be in the space of X 156 | 157 | else : 158 | alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.) 159 | interpolated = alpha*real + (1. - alpha)*fake 160 | 161 | logit, _ = self.discriminator(interpolated, reuse=True, scope=scope) 162 | 163 | 164 | GP = 0 165 | 166 | grad = tf.gradients(logit, interpolated)[0] # gradient of D(interpolated) 167 | grad_norm = tf.norm(flatten(grad), axis=1) # l2 norm 168 | 169 | # WGAN - LP 170 | if self.gan_type == 'wgan-lp' : 171 | GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.))) 172 | 173 | elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan': 174 | GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.)) 175 | 176 | return GP 177 | 178 | def build_model(self): 179 | self.lr = tf.placeholder(tf.float32, name='learning_rate') 180 | 181 | """ Input Image""" 182 | Image_data_class = ImageData(load_size=self.img_size, channels=self.img_ch, data_path=self.dataset_path, selected_attrs=self.selected_attrs, augment_flag=self.augment_flag) 183 | Image_data_class.preprocess() 184 | 185 | train_dataset_num = len(Image_data_class.train_dataset) 186 | test_dataset_num = len(Image_data_class.test_dataset) 187 | 188 | train_dataset = tf.data.Dataset.from_tensor_slices((Image_data_class.train_dataset, Image_data_class.train_dataset_label, Image_data_class.train_dataset_fix_label)) 189 | test_dataset = tf.data.Dataset.from_tensor_slices((Image_data_class.test_dataset, Image_data_class.test_dataset_label, Image_data_class.test_dataset_fix_label)) 190 | 191 | gpu_device = '/gpu:0' 192 | train_dataset = train_dataset.\ 193 | apply(shuffle_and_repeat(train_dataset_num)).\ 194 | apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\ 195 | apply(prefetch_to_device(gpu_device, self.batch_size)) 196 | 197 | test_dataset = test_dataset.\ 198 | apply(shuffle_and_repeat(test_dataset_num)).\ 199 | 
apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\ 200 | apply(prefetch_to_device(gpu_device, self.batch_size)) 201 | 202 | train_dataset_iterator = train_dataset.make_one_shot_iterator() 203 | test_dataset_iterator = test_dataset.make_one_shot_iterator() 204 | 205 | 206 | self.x_real, label_org, label_fix_list = train_dataset_iterator.get_next() # Input image / Original domain labels 207 | label_trg = tf.random_shuffle(label_org) # Target domain labels 208 | label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2]) 209 | 210 | self.x_test, test_label_org, test_label_fix_list = test_dataset_iterator.get_next() # Input image / Original domain labels 211 | test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2]) 212 | 213 | self.custom_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='custom_image') # Custom Image 214 | custom_label_fix_list = tf.transpose(create_labels(self.custom_label, self.selected_attrs), perm=[1, 0, 2]) 215 | 216 | """ Define Generator, Discriminator """ 217 | x_fake = self.generator(self.x_real, label_trg) # real a 218 | x_recon = self.generator(x_fake, label_org, reuse=True) # real b 219 | 220 | real_logit, real_cls = self.discriminator(self.x_real) 221 | fake_logit, fake_cls = self.discriminator(x_fake, reuse=True) 222 | 223 | 224 | """ Define Loss """ 225 | if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' : 226 | GP = self.gradient_panalty(real=self.x_real, fake=x_fake) 227 | else : 228 | GP = 0 229 | 230 | g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit) 231 | g_cls_loss = classification_loss(logit=fake_cls, label=label_trg) 232 | g_rec_loss = L1_loss(self.x_real, x_recon) 233 | 234 | d_adv_loss = discriminator_loss(loss_func=self.gan_type, real=real_logit, fake=fake_logit) + GP 235 | d_cls_loss = classification_loss(logit=real_cls, label=label_org) 236 | 237 | self.d_loss = 
self.adv_weight * d_adv_loss + self.cls_weight * d_cls_loss 238 | self.g_loss = self.adv_weight * g_adv_loss + self.cls_weight * g_cls_loss + self.rec_weight * g_rec_loss 239 | 240 | 241 | """ Result Image """ 242 | self.x_fake_list = tf.map_fn(lambda x : self.generator(self.x_real, x, reuse=True), label_fix_list, dtype=tf.float32) 243 | 244 | 245 | """ Test Image """ 246 | self.x_test_fake_list = tf.map_fn(lambda x : self.generator(self.x_test, x, reuse=True), test_label_fix_list, dtype=tf.float32) 247 | self.custom_fake_image = tf.map_fn(lambda x : self.generator(self.custom_image, x, reuse=True), custom_label_fix_list, dtype=tf.float32) 248 | 249 | 250 | """ Training """ 251 | t_vars = tf.trainable_variables() 252 | G_vars = [var for var in t_vars if 'generator' in var.name] 253 | D_vars = [var for var in t_vars if 'discriminator' in var.name] 254 | 255 | self.g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars) 256 | self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars) 257 | 258 | 259 | """" Summary """ 260 | self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss) 261 | self.Discriminator_loss = tf.summary.scalar("Discriminator_loss", self.d_loss) 262 | 263 | self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss) 264 | self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss) 265 | self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss) 266 | 267 | self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss) 268 | self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss) 269 | 270 | self.g_summary_loss = tf.summary.merge([self.Generator_loss, self.g_adv_loss, self.g_cls_loss, self.g_rec_loss]) 271 | self.d_summary_loss = tf.summary.merge([self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss]) 272 | 273 | 274 | def train(self): 275 | # initialize all variables 276 | tf.global_variables_initializer().run() 
277 | 278 | # saver to save model 279 | self.saver = tf.train.Saver() 280 | 281 | # summary writer 282 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph) 283 | 284 | # restore check-point if it exits 285 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 286 | if could_load: 287 | start_epoch = (int)(checkpoint_counter / self.iteration) 288 | start_batch_id = checkpoint_counter - start_epoch * self.iteration 289 | counter = checkpoint_counter 290 | print(" [*] Load SUCCESS") 291 | else: 292 | start_epoch = 0 293 | start_batch_id = 0 294 | counter = 1 295 | print(" [!] Load failed...") 296 | 297 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 298 | check_folder(self.sample_dir) 299 | 300 | # loop for epoch 301 | start_time = time.time() 302 | past_g_loss = -1. 303 | lr = self.init_lr 304 | for epoch in range(start_epoch, self.epoch): 305 | if self.decay_flag : 306 | lr = self.init_lr if epoch < self.decay_epoch else self.init_lr * (self.epoch - epoch) / (self.epoch - self.decay_epoch) # linear decay 307 | 308 | for idx in range(start_batch_id, self.iteration): 309 | train_feed_dict = { 310 | self.lr : lr 311 | } 312 | 313 | # Update D 314 | _, d_loss, summary_str = self.sess.run([self.d_optimizer, self.d_loss, self.d_summary_loss], feed_dict = train_feed_dict) 315 | self.writer.add_summary(summary_str, counter) 316 | 317 | # Update G 318 | g_loss = None 319 | if (counter - 1) % self.n_critic == 0 : 320 | real_images, fake_images, _, g_loss, summary_str = self.sess.run([self.x_real, self.x_fake_list, self.g_optimizer, self.g_loss, self.g_summary_loss], feed_dict = train_feed_dict) 321 | self.writer.add_summary(summary_str, counter) 322 | past_g_loss = g_loss 323 | 324 | # display training status 325 | counter += 1 326 | if g_loss == None : 327 | g_loss = past_g_loss 328 | 329 | print("Epoch: [%2d] [%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (epoch, idx, self.iteration, time.time() - 
start_time, d_loss, g_loss)) 330 | 331 | if np.mod(idx+1, self.print_freq) == 0 : 332 | real_image = np.expand_dims(real_images[0], axis=0) 333 | fake_image = np.transpose(fake_images, axes=[1, 0, 2, 3, 4])[0] # [bs, c_dim, h, w, ch] 334 | 335 | save_images(real_image, [1, 1], 336 | './{}/real_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 337 | 338 | save_images(fake_image, [1, self.c_dim], 339 | './{}/fake_{:03d}_{:05d}.png'.format(self.sample_dir, epoch, idx+1)) 340 | 341 | if np.mod(idx + 1, self.save_freq) == 0: 342 | self.save(self.checkpoint_dir, counter) 343 | 344 | # After an epoch, start_batch_id is set to zero 345 | # non-zero value is only for the first epoch after loading pre-trained model 346 | start_batch_id = 0 347 | 348 | # save model for final step 349 | self.save(self.checkpoint_dir, counter) 350 | 351 | @property 352 | def model_dir(self): 353 | n_res = str(self.n_res) + 'resblock' 354 | n_dis = str(self.n_dis) + 'dis' 355 | 356 | return "{}_{}_{}_{}_{}".format(self.model_name, self.dataset_name, 357 | self.gan_type, 358 | n_res, n_dis) 359 | 360 | def save(self, checkpoint_dir, step): 361 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 362 | 363 | if not os.path.exists(checkpoint_dir): 364 | os.makedirs(checkpoint_dir) 365 | 366 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step) 367 | 368 | def load(self, checkpoint_dir): 369 | import re 370 | print(" [*] Reading checkpoints...") 371 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 372 | 373 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 374 | if ckpt and ckpt.model_checkpoint_path: 375 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 376 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 377 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0)) 378 | print(" [*] Success to read {}".format(ckpt_name)) 379 | return True, counter 380 | else: 381 
| print(" [*] Failed to find a checkpoint") 382 | return False, 0 383 | 384 | def test(self): 385 | tf.global_variables_initializer().run() 386 | test_path = os.path.join(self.dataset_path, 'test') 387 | check_folder(test_path) 388 | test_files = glob(os.path.join(test_path, '*.*')) 389 | 390 | self.saver = tf.train.Saver() 391 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 392 | self.result_dir = os.path.join(self.result_dir, self.model_dir) 393 | check_folder(self.result_dir) 394 | 395 | image_folder = os.path.join(self.result_dir, 'images') 396 | check_folder(image_folder) 397 | 398 | if could_load : 399 | print(" [*] Load SUCCESS") 400 | else : 401 | print(" [!] Load failed...") 402 | 403 | # write html for visual comparison 404 | index_path = os.path.join(self.result_dir, 'index.html') 405 | index = open(index_path, 'w') 406 | index.write("") 407 | index.write("") 408 | 409 | # Custom Image 410 | for sample_file in test_files: 411 | print("Processing image: " + sample_file) 412 | sample_image = np.asarray(load_test_data(sample_file, size=self.img_size)) 413 | image_path = os.path.join(image_folder, '{}'.format(os.path.basename(sample_file))) 414 | 415 | fake_image = self.sess.run(self.custom_fake_image, feed_dict = {self.custom_image : sample_image}) 416 | fake_image = np.transpose(fake_image, axes=[1, 0, 2, 3, 4])[0] 417 | save_images(fake_image, [1, self.c_dim], image_path) 418 | 419 | index.write("" % os.path.basename(image_path)) 420 | index.write("" % (sample_file if os.path.isabs(sample_file) else ( 421 | '../..' + os.path.sep + sample_file), self.img_size, self.img_size)) 422 | 423 | index.write("" % (image_path if os.path.isabs(image_path) else ( 424 | '../..' 
+ os.path.sep + image_path), self.img_size * self.c_dim, self.img_size)) 425 | index.write("") 426 | 427 | # CelebA 428 | real_images, fake_images = self.sess.run([self.x_test, self.x_test_fake_list]) 429 | fake_images = np.transpose(fake_images, axes=[1, 0, 2, 3, 4]) 430 | 431 | for i in range(len(real_images)) : 432 | print("{} / {}".format(i, len(real_images))) 433 | real_path = os.path.join(image_folder, 'real_{}.png'.format(i)) 434 | fake_path = os.path.join(image_folder, 'fake_{}.png'.format(i)) 435 | 436 | real_image = np.expand_dims(real_images[i], axis=0) 437 | fake_image = fake_images[i] 438 | save_images(real_image, [1, 1], real_path) 439 | save_images(fake_image, [1, self.c_dim], fake_path) 440 | 441 | index.write("" % os.path.basename(real_path)) 442 | index.write("" % (real_path if os.path.isabs(real_path) else ( 443 | '../..' + os.path.sep + real_path), self.img_size, self.img_size)) 444 | 445 | index.write("" % (fake_path if os.path.isabs(fake_path) else ( 446 | '../..' + os.path.sep + fake_path), self.img_size * self.c_dim, self.img_size)) 447 | index.write("") 448 | 449 | index.close() --------------------------------------------------------------------------------
nameinputoutput
%s
%s