├── .gitignore
├── README.md
├── download_dataset.sh
├── imgs
│   ├── AtoB_n02381460_4530.jpg
│   ├── AtoB_n02381460_4660.jpg
│   ├── AtoB_n02381460_510.jpg
│   ├── AtoB_n02381460_8980.jpg
│   ├── BtoA_n02391049_1760.jpg
│   ├── BtoA_n02391049_3070.jpg
│   ├── BtoA_n02391049_5100.jpg
│   ├── BtoA_n02391049_7150.jpg
│   ├── n02381460_4530.jpg
│   ├── n02381460_4660.jpg
│   ├── n02381460_510.jpg
│   ├── n02381460_8980.jpg
│   ├── n02391049_1760.jpg
│   ├── n02391049_3070.jpg
│   ├── n02391049_5100.jpg
│   ├── n02391049_7150.jpg
│   └── teaser.jpg
├── main.py
├── model.py
├── module.py
├── ops.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea/*
3 | logs/*
4 | checkpoint/*
5 | datasets/*
6 | test/*
7 | sample/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
5 | # CycleGAN
6 |
7 | TensorFlow implementation for learning image-to-image translation **without** input-output pairs.
8 | The method was proposed by [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/) et al. in
9 | [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/pdf/1703.10593.pdf).
10 | Example results from the paper:
11 | 
12 | ![teaser](./imgs/teaser.jpg)
13 | 
14 | 
32 |
33 | ## Results
34 | The results of this implementation:
35 | 
36 | - Horses -> Zebras <br>
37 |   ![horse](./imgs/n02381460_510.jpg) ![horse2zebra](./imgs/AtoB_n02381460_510.jpg)
38 | 
39 | - Zebras -> Horses <br>
40 |   ![zebra](./imgs/n02391049_1760.jpg) ![zebra2horse](./imgs/BtoA_n02391049_1760.jpg)
41 | 
42 | You can download a pretrained model from [this URL](https://1drv.ms/u/s!AroAdu0uts_gj5tA93GnwyfRpvBIDA)
43 | and extract the RAR archive into `./checkpoint/`.
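For example, with `unrar` installed (the archive filename below is illustrative; use whatever name the download carries):
```bash
mkdir -p ./checkpoint
# hypothetical archive name -- substitute the file you actually downloaded
unrar x cyclegan_pretrained.rar ./checkpoint/
```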
44 |
45 |
46 | ## Prerequisites
47 | - tensorflow r1.1
48 | - numpy 1.11.0
49 | - scipy 0.17.0
50 | - pillow 3.3.0
51 |
52 | ## Getting Started
53 | ### Installation
54 | - Install TensorFlow from https://github.com/tensorflow/tensorflow
55 | - Clone this repo:
56 | ```bash
57 | git clone https://github.com/xhujoy/CycleGAN-tensorflow
58 | cd CycleGAN-tensorflow
59 | ```
60 |
61 | ### Train
62 | - Download a dataset (e.g. zebra and horse images from ImageNet):
63 | ```bash
64 | bash ./download_dataset.sh horse2zebra
65 | ```
66 | - Train a model:
67 | ```bash
68 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra
69 | ```
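- Resume training from the most recent checkpoint instead of starting from scratch (a sketch using the `--continue_train` flag defined in `main.py`):
```bash
CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra --continue_train=True
```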
70 | - Use TensorBoard to visualize the training details:
71 | ```bash
72 | tensorboard --logdir=./logs
73 | ```
74 |
75 | ### Test
76 | - Finally, test the model:
77 | ```bash
78 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=horse2zebra --phase=test --which_direction=AtoB
79 | ```
80 |
81 | ## Training and Test Details
82 | To train a model,
83 | ```bash
84 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=/path/to/data/
85 | ```
86 | Models are saved to `./checkpoint/` by default (change this by passing `--checkpoint_dir=your_dir`).
87 |
88 | To test the model,
89 | ```bash
90 | CUDA_VISIBLE_DEVICES=0 python main.py --dataset_dir=/path/to/data/ --phase=test --which_direction=AtoB  # or BtoA
91 | ```
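Test outputs are written to `./test/` by default (override with `--test_dir`): each translated image plus an HTML page for side-by-side comparison. After an `AtoB` run you should see something like this (exact filenames depend on your test images):
```bash
ls ./test/
# AtoB_index.html  AtoB_n02381460_510.jpg  AtoB_n02381460_4530.jpg  ...
```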
92 |
93 | ## Datasets
94 | Download the datasets using the following script:
95 | ```bash
96 | bash ./download_dataset.sh dataset_name
97 | ```
98 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/).
99 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/).
100 | - `maps`: 1096 training images scraped from Google Maps.
101 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using keywords `wild horse` and `zebra`.
102 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using keywords `apple` and `navel orange`.
103 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images, downloaded using the Flickr API. See the CycleGAN paper for more details.
104 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using a combination of the tags *landscape* and *landscapephotography*. The training set sizes are Monet: 1074, Cezanne: 584, Van Gogh: 401, Ukiyo-e: 1433, Photographs: 6853.
105 | - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set sizes are iPhone: 1813, DSLR: 3316. See the CycleGAN paper for more details.
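
Whichever dataset you use, the training and test code reads images from fixed subdirectories under `./datasets/` (see the `glob` calls in `model.py`), so a custom dataset should follow the same layout:
```
datasets/your_dataset/
├── trainA/   # training images for domain A
├── trainB/   # training images for domain B
├── testA/    # test images for domain A
└── testB/    # test images for domain B
```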
106 |
107 |
108 | ## Reference
109 | - The Torch implementation of CycleGAN: https://github.com/junyanz/CycleGAN
110 | - The TensorFlow implementation of pix2pix: https://github.com/yenchenlin/pix2pix-tensorflow
111 |
--------------------------------------------------------------------------------
/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | FILE=$1
3 | 
4 | if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
5 |     echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
6 |     exit 1
7 | fi
8 | 
9 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
10 | ZIP_FILE=./datasets/$FILE.zip
11 | TARGET_DIR=./datasets/$FILE/
12 | mkdir -p "$TARGET_DIR"
13 | wget -N "$URL" -O "$ZIP_FILE"
14 | unzip "$ZIP_FILE" -d ./datasets/
15 | rm "$ZIP_FILE"
16 |
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_4530.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_4530.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_4660.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_4660.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_510.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_510.jpg
--------------------------------------------------------------------------------
/imgs/AtoB_n02381460_8980.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/AtoB_n02381460_8980.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_1760.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_1760.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_3070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_3070.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_5100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_5100.jpg
--------------------------------------------------------------------------------
/imgs/BtoA_n02391049_7150.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/BtoA_n02391049_7150.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_4530.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_4530.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_4660.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_4660.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_510.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_510.jpg
--------------------------------------------------------------------------------
/imgs/n02381460_8980.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02381460_8980.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_1760.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_1760.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_3070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_3070.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_5100.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_5100.jpg
--------------------------------------------------------------------------------
/imgs/n02391049_7150.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/n02391049_7150.jpg
--------------------------------------------------------------------------------
/imgs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaowei-hu/CycleGAN-tensorflow/bd19873a2eca15383ab70786ec0beadcd6e00c7f/imgs/teaser.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tensorflow as tf
4 | tf.set_random_seed(19)
5 | from model import cyclegan
6 |
7 | parser = argparse.ArgumentParser(description='')
8 | parser.add_argument('--dataset_dir', dest='dataset_dir', default='horse2zebra', help='path of the dataset')
9 | parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epochs')
10 | parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=100, help='epoch from which to linearly decay lr to zero')
11 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
12 | parser.add_argument('--train_size', dest='train_size', type=int, default=int(1e8), help='max # images used to train')
13 | parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
14 | parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
15 | parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
16 | parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discriminator filters in first conv layer')
17 | parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
18 | parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
19 | parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
20 | parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
21 | parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
22 | parser.add_argument('--phase', dest='phase', default='train', help='train, test')
23 | parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, help='save a model every save_freq iterations')
24 | parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, help='print the debug information every print_freq iterations')
25 | parser.add_argument('--continue_train', dest='continue_train', type=lambda x: str(x).lower() in ('true', '1'), default=False, help='if True, load the latest checkpoint and continue training')
26 | parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
27 | parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='samples are saved here')
28 | parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test samples are saved here')
29 | parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective')
30 | parser.add_argument('--use_resnet', dest='use_resnet', type=lambda x: str(x).lower() in ('true', '1'), default=True, help='if True, the generator uses residual blocks; otherwise U-Net')
31 | parser.add_argument('--use_lsgan', dest='use_lsgan', type=lambda x: str(x).lower() in ('true', '1'), default=True, help='if True, use least-squares GAN loss; otherwise sigmoid cross-entropy')
32 | parser.add_argument('--max_size', dest='max_size', type=int, default=50, help='max size of image pool, 0 means do not use image pool')
33 |
34 | args = parser.parse_args()
35 |
36 |
37 | def main(_):
38 | if not os.path.exists(args.checkpoint_dir):
39 | os.makedirs(args.checkpoint_dir)
40 | if not os.path.exists(args.sample_dir):
41 | os.makedirs(args.sample_dir)
42 | if not os.path.exists(args.test_dir):
43 | os.makedirs(args.test_dir)
44 |
45 | tfconfig = tf.ConfigProto(allow_soft_placement=True)
46 | tfconfig.gpu_options.allow_growth = True
47 | with tf.Session(config=tfconfig) as sess:
48 | model = cyclegan(sess, args)
49 | model.train(args) if args.phase == 'train' \
50 | else model.test(args)
51 |
52 | if __name__ == '__main__':
53 | tf.app.run()
54 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import time
4 | from glob import glob
5 | import tensorflow as tf
6 | import numpy as np
7 | from collections import namedtuple
8 |
9 | from module import *
10 | from utils import *
11 |
12 |
13 | class cyclegan(object):
14 | def __init__(self, sess, args):
15 | self.sess = sess
16 | self.batch_size = args.batch_size
17 | self.image_size = args.fine_size
18 | self.input_c_dim = args.input_nc
19 | self.output_c_dim = args.output_nc
20 | self.L1_lambda = args.L1_lambda
21 | self.dataset_dir = args.dataset_dir
22 |
23 | self.discriminator = discriminator
24 | if args.use_resnet:
25 | self.generator = generator_resnet
26 | else:
27 | self.generator = generator_unet
28 | if args.use_lsgan:
29 | self.criterionGAN = mae_criterion
30 | else:
31 | self.criterionGAN = sce_criterion
32 |
33 | OPTIONS = namedtuple('OPTIONS', 'batch_size image_size \
34 | gf_dim df_dim output_c_dim is_training')
35 | self.options = OPTIONS._make((args.batch_size, args.fine_size,
36 | args.ngf, args.ndf, args.output_nc,
37 | args.phase == 'train'))
38 |
39 | self._build_model()
40 | self.saver = tf.train.Saver()
41 | self.pool = ImagePool(args.max_size)
42 |
43 | def _build_model(self):
44 | self.real_data = tf.placeholder(tf.float32,
45 | [None, self.image_size, self.image_size,
46 | self.input_c_dim + self.output_c_dim],
47 | name='real_A_and_B_images')
48 |
49 | self.real_A = self.real_data[:, :, :, :self.input_c_dim]
50 | self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]
51 |
52 |         self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")   # G: A -> B
53 |         self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")  # F(G(A)): cycle back to A
54 |         self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")    # F: B -> A (reuse shares F's weights)
55 |         self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")   # G(F(B)): cycle back to B
56 |
57 | self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB")
58 | self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA")
59 | self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
60 | + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
61 | + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
62 | self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
63 | + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
64 | + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
65 | self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
66 | + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
67 | + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
68 | + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
69 |
70 | self.fake_A_sample = tf.placeholder(tf.float32,
71 | [None, self.image_size, self.image_size,
72 | self.input_c_dim], name='fake_A_sample')
73 | self.fake_B_sample = tf.placeholder(tf.float32,
74 | [None, self.image_size, self.image_size,
75 | self.output_c_dim], name='fake_B_sample')
76 | self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB")
77 | self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA")
78 | self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB")
79 | self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA")
80 |
81 | self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
82 | self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
83 | self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
84 | self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
85 | self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
86 | self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
87 | self.d_loss = self.da_loss + self.db_loss
88 |
89 | self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
90 | self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
91 | self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
92 | self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])
93 | self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
94 | self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
95 | self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
96 | self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
97 | self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
98 | self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
99 | self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
100 | self.d_sum = tf.summary.merge(
101 | [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
102 | self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
103 | self.d_loss_sum]
104 | )
105 |
106 | self.test_A = tf.placeholder(tf.float32,
107 | [None, self.image_size, self.image_size,
108 | self.input_c_dim], name='test_A')
109 | self.test_B = tf.placeholder(tf.float32,
110 | [None, self.image_size, self.image_size,
111 | self.output_c_dim], name='test_B')
112 | self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
113 | self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")
114 |
115 | t_vars = tf.trainable_variables()
116 | self.d_vars = [var for var in t_vars if 'discriminator' in var.name]
117 | self.g_vars = [var for var in t_vars if 'generator' in var.name]
118 | for var in t_vars: print(var.name)
119 |
120 | def train(self, args):
121 | """Train cyclegan"""
122 | self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
123 | self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
124 | .minimize(self.d_loss, var_list=self.d_vars)
125 | self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
126 | .minimize(self.g_loss, var_list=self.g_vars)
127 |
128 | init_op = tf.global_variables_initializer()
129 | self.sess.run(init_op)
130 | self.writer = tf.summary.FileWriter("./logs", self.sess.graph)
131 |
132 | counter = 1
133 | start_time = time.time()
134 |
135 | if args.continue_train:
136 | if self.load(args.checkpoint_dir):
137 | print(" [*] Load SUCCESS")
138 | else:
139 | print(" [!] Load failed...")
140 |
141 | for epoch in range(args.epoch):
142 | dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainA'))
143 | dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainB'))
144 | np.random.shuffle(dataA)
145 | np.random.shuffle(dataB)
146 | batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
147 | lr = args.lr if epoch < args.epoch_step else args.lr*(args.epoch-epoch)/(args.epoch-args.epoch_step)
148 |
149 | for idx in range(0, batch_idxs):
150 | batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
151 | dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
152 | batch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]
153 | batch_images = np.array(batch_images).astype(np.float32)
154 |
155 | # Update G network and record fake outputs
156 | fake_A, fake_B, _, summary_str = self.sess.run(
157 | [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
158 | feed_dict={self.real_data: batch_images, self.lr: lr})
159 | self.writer.add_summary(summary_str, counter)
160 | [fake_A, fake_B] = self.pool([fake_A, fake_B])
161 |
162 | # Update D network
163 | _, summary_str = self.sess.run(
164 | [self.d_optim, self.d_sum],
165 | feed_dict={self.real_data: batch_images,
166 | self.fake_A_sample: fake_A,
167 | self.fake_B_sample: fake_B,
168 | self.lr: lr})
169 | self.writer.add_summary(summary_str, counter)
170 |
171 | counter += 1
172 | print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
173 | epoch, idx, batch_idxs, time.time() - start_time)))
174 |
175 | if np.mod(counter, args.print_freq) == 1:
176 | self.sample_model(args.sample_dir, epoch, idx)
177 |
178 | if np.mod(counter, args.save_freq) == 2:
179 | self.save(args.checkpoint_dir, counter)
180 |
181 | def save(self, checkpoint_dir, step):
182 | model_name = "cyclegan.model"
183 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
184 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
185 |
186 | if not os.path.exists(checkpoint_dir):
187 | os.makedirs(checkpoint_dir)
188 |
189 | self.saver.save(self.sess,
190 | os.path.join(checkpoint_dir, model_name),
191 | global_step=step)
192 |
193 | def load(self, checkpoint_dir):
194 | print(" [*] Reading checkpoint...")
195 |
196 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
197 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
198 |
199 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
200 | if ckpt and ckpt.model_checkpoint_path:
201 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
202 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
203 | return True
204 | else:
205 | return False
206 |
207 | def sample_model(self, sample_dir, epoch, idx):
208 | dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
209 | dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
210 | np.random.shuffle(dataA)
211 | np.random.shuffle(dataB)
212 | batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))
213 | sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]
214 | sample_images = np.array(sample_images).astype(np.float32)
215 |
216 | fake_A, fake_B = self.sess.run(
217 | [self.fake_A, self.fake_B],
218 | feed_dict={self.real_data: sample_images}
219 | )
220 | save_images(fake_A, [self.batch_size, 1],
221 | './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
222 | save_images(fake_B, [self.batch_size, 1],
223 | './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
224 |
225 | def test(self, args):
226 | """Test cyclegan"""
227 | init_op = tf.global_variables_initializer()
228 | self.sess.run(init_op)
229 | if args.which_direction == 'AtoB':
230 | sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
231 | elif args.which_direction == 'BtoA':
232 | sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
233 | else:
234 | raise Exception('--which_direction must be AtoB or BtoA')
235 |
236 | if self.load(args.checkpoint_dir):
237 | print(" [*] Load SUCCESS")
238 | else:
239 | print(" [!] Load failed...")
240 |
241 | # write html for visual comparison
242 | index_path = os.path.join(args.test_dir, '{0}_index.html'.format(args.which_direction))
243 |         index = open(index_path, "w")
244 |         index.write("<html><body><table><tr>")
245 |         index.write("<th>name</th><th>input</th><th>output</th></tr>")
246 | 
247 |         out_var, in_var = (self.testB, self.test_A) if args.which_direction == 'AtoB' else (
248 |             self.testA, self.test_B)
249 | 
250 |         for sample_file in sample_files:
251 |             print('Processing image: ' + sample_file)
252 |             sample_image = [load_test_data(sample_file, args.fine_size)]
253 |             sample_image = np.array(sample_image).astype(np.float32)
254 |             image_path = os.path.join(args.test_dir,
255 |                                       '{0}_{1}'.format(args.which_direction, os.path.basename(sample_file)))
256 |             fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
257 |             save_images(fake_img, [1, 1], image_path)
258 |             index.write("<td>%s</td>" % os.path.basename(image_path))
259 |             index.write("<td><img src='%s'></td>" % (sample_file if os.path.isabs(sample_file) else ('..' + os.path.sep + sample_file)))
260 |             index.write("<td><img src='%s'></td>" % (image_path if os.path.isabs(image_path) else ('..' + os.path.sep + image_path)))
261 |             index.write("</tr>")
262 |         index.close()

--------------------------------------------------------------------------------