├── .gitignore
├── LICENSE
├── README.md
├── download.py
├── examples
│   ├── input_1.jpg
│   ├── input_2.jpg
│   ├── input_3.jpg
│   ├── input_4.jpg
│   ├── output_1.jpg
│   ├── output_2.jpg
│   ├── output_3.jpg
│   ├── output_4.jpg
│   ├── truth_1.jpg
│   ├── truth_2.jpg
│   ├── truth_3.jpg
│   └── truth_4.jpg
├── input.py
├── main.py
└── model.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 seungjooli
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConditionalGAN - sketch2face
2 | This is a TensorFlow implementation based on [pix2pix](https://phillipi.github.io/pix2pix/). I specifically trained the model to convert a human face sketch image into a photo-like image.
3 |
4 | Needs improvement:
5 | It works for edges extracted from real photos, but not very well for hand-drawn sketches.
6 |
7 | Differences from the paper:
8 | - To avoid fast convergence of the discriminator, the generator is updated twice for each training step, as illustrated below. (Idea borrowed from https://github.com/carpedm20/DCGAN-tensorflow)
9 |
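A minimal sketch of that schedule, together with the losses as defined in `model.py` (`d_real` and `d_fake` are sigmoid outputs, so a small epsilon keeps the logs finite):

    # non-saturating GAN loss plus a weighted L1 term (see model.py)
    loss_g = tf.reduce_mean(-tf.log(d_fake + 1e-8)) + l1_weight * tf.reduce_mean(tf.abs(targets - g))
    loss_d = tf.reduce_mean(-tf.log(d_real + 1e-8)) + tf.reduce_mean(-tf.log(1 - d_fake + 1e-8))

    # one discriminator update, then two generator updates per training step
    sess.run(train_d, feed_dict=feed_dict)
    sess.run(train_g, feed_dict=feed_dict)
    sess.run(train_g, feed_dict=feed_dict)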
10 | ## Output Example
11 | After 2 epochs of training, on randomly chosen test-set images:
12 |
13 | | input | output | ground truth |
14 | | :---: | :---: | :---: |
15 | | ![](examples/input_1.jpg) | ![](examples/output_1.jpg) | ![](examples/truth_1.jpg) |
16 | | ![](examples/input_2.jpg) | ![](examples/output_2.jpg) | ![](examples/truth_2.jpg) |
17 | | ![](examples/input_3.jpg) | ![](examples/output_3.jpg) | ![](examples/truth_3.jpg) |
18 | | ![](examples/input_4.jpg) | ![](examples/output_4.jpg) | ![](examples/truth_4.jpg) |
19 |
20 |
21 | ## Data
22 | [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
23 | Used the first 99% (200,573 images) as the training set and the rest (2,026 images) as the test set.
24 |
25 | 1. Crop and resize each image to 256 x 256 (target image)
26 | 2. Extract edges from each target image with OpenCV's Canny detector to get the corresponding sketch (input image), as illustrated below
27 |
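A rough OpenCV sketch of this pipeline, mirroring the crop box in `input.py` and the Otsu-thresholded Canny in `model.py` (CelebA aligned images are 178 x 218, so the crop keeps the 178 x 178 face region):

    import cv2

    # image: CelebA-aligned RGB uint8 array of shape (218, 178, 3)
    def make_pair(image):
        face = image[30:30 + 178, 0:178]               # same crop box as input.py
        target = cv2.resize(face, (256, 256))          # 256 x 256 target image
        blurred = cv2.GaussianBlur(target, (5, 5), 0)  # smooth before edge detection
        gray = cv2.cvtColor(blurred, cv2.COLOR_RGB2GRAY)
        otsu, _ = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        edges = cv2.Canny(blurred, max(0, int(otsu * 0.5)), min(255, int(otsu)))
        return edges, target                           # (sketch input, ground truth)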
28 | ## Usage
29 | To download the dataset and start training:
30 |
31 | $ python main.py --download_data=True
32 |
33 | To train a model with downloaded data:
34 |
35 | $ python main.py
36 |
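Training progress (losses, discriminator histograms, sample images) is logged as TensorFlow summaries to `./logs/` by default (see the `log_dir` flag) and can be monitored with:

    $ tensorboard --logdir=./logs/
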
37 | To test the model with a canvas UI (draw with the left mouse button, right-click to clear the canvas, press Esc to quit):
38 |
39 | $ python main.py --mode=test
40 |
41 | ## Environment
42 | - python (3.5.3)
43 | - tensorflow-gpu (1.1.0)
44 | - opencv-python (3.2.0)
45 |
46 | - google-api-python-client (1.6.2)
47 | - requests (2.13.0)
48 | - tqdm (4.11.2)
49 |
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import tarfile
4 | import urllib.request
5 | import zipfile
6 | from tqdm import tqdm
7 |
8 |
9 | def maybe_download_from_url(url, download_dir):
10 | """
11 | Download the data from url, unless it's already here.
12 |
13 | Args:
14 |         url: string, URL to download from
15 |         download_dir: string, path to download directory
16 |
17 | Returns:
18 | Path to the downloaded file
19 | """
20 | filename = url.split('/')[-1]
21 | filepath = os.path.join(download_dir, filename)
22 |
23 | os.makedirs(download_dir, exist_ok=True)
24 |
25 | if not os.path.isfile(filepath):
26 | print('Downloading: "{}"'.format(filepath))
27 | urllib.request.urlretrieve(url, filepath)
28 | size = os.path.getsize(filepath)
29 | print('Download complete ({} bytes)'.format(size))
30 | else:
31 | print('File already exists: "{}"'.format(filepath))
32 |
33 | return filepath
34 |
35 |
36 | def maybe_download_from_google_drive(id, filepath):
37 | def get_confirm_token(response):
38 | for key, value in response.cookies.items():
39 | if key.startswith('download_warning'):
40 | return value
41 | return None
42 |
43 |     def save_response_content(response, filepath, chunk_size=32 * 1024):
44 |         total_size = int(response.headers.get('content-length', 0))  # 0 when the header is missing
45 |         with open(filepath, "wb") as f, tqdm(total=total_size, unit='B', unit_scale=True, desc=filepath) as pbar:
46 |             for chunk in response.iter_content(chunk_size):
47 |                 if chunk:  # filter out keep-alive chunks
48 |                     f.write(chunk)
49 |                     pbar.update(len(chunk))
50 |
51 | if not os.path.isfile(filepath):
52 | print('Downloading: "{}"'.format(filepath))
53 | URL = "https://docs.google.com/uc?export=download"
54 | session = requests.Session()
55 |
56 | response = session.get(URL, params={'id': id}, stream=True)
57 | token = get_confirm_token(response)
58 |
59 | if token:
60 | params = {'id': id, 'confirm': token}
61 | response = session.get(URL, params=params, stream=True)
62 |
63 | save_response_content(response, filepath)
64 | size = os.path.getsize(filepath)
65 | print('Download complete ({} bytes)'.format(size))
66 | else:
67 | print('File already exists: "{}"'.format(filepath))
68 |
69 | return filepath
70 |
71 |
72 | def maybe_extract(compressed_filepath, train_dir, test_dir):
73 | def is_image(filepath):
74 | extensions = ('.jpg', '.jpeg', '.png', '.gif')
75 | return any(filepath.endswith(ext) for ext in extensions)
76 |
77 | os.makedirs(train_dir, exist_ok=True)
78 | os.makedirs(test_dir, exist_ok=True)
79 | print('Extracting: "{}"'.format(compressed_filepath))
80 |
81 | if zipfile.is_zipfile(compressed_filepath):
82 | with zipfile.ZipFile(compressed_filepath) as zf:
83 | files = [member for member in zf.infolist() if is_image(member.filename)]
84 | count = len(files)
85 | train_test_boundary = int(count * 0.99)
86 | for i in range(count):
87 | if i < train_test_boundary:
88 | extract_dir = train_dir
89 | else:
90 | extract_dir = test_dir
91 |
92 | if not os.path.exists(os.path.join(extract_dir, files[i].filename)):
93 | zf.extract(files[i], extract_dir)
94 | elif tarfile.is_tarfile(compressed_filepath):
95 | with tarfile.open(compressed_filepath) as tar:
96 | files = [member for member in tar if is_image(member.name)]
97 | count = len(files)
98 | train_test_boundary = int(count * 0.99)
99 | for i in range(count):
100 | if i < train_test_boundary:
101 | extract_dir = train_dir
102 | else:
103 | extract_dir = test_dir
104 |
105 | if not os.path.exists(os.path.join(extract_dir, files[i].name)):
106 | tar.extract(files[i], extract_dir)
107 | else:
108 |         raise NotImplementedError('unsupported archive format: {}'.format(compressed_filepath))
109 |
110 | print('Extraction complete')
111 |
--------------------------------------------------------------------------------
/examples/input_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_1.jpg
--------------------------------------------------------------------------------
/examples/input_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_2.jpg
--------------------------------------------------------------------------------
/examples/input_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_3.jpg
--------------------------------------------------------------------------------
/examples/input_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_4.jpg
--------------------------------------------------------------------------------
/examples/output_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_1.jpg
--------------------------------------------------------------------------------
/examples/output_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_2.jpg
--------------------------------------------------------------------------------
/examples/output_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_3.jpg
--------------------------------------------------------------------------------
/examples/output_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_4.jpg
--------------------------------------------------------------------------------
/examples/truth_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_1.jpg
--------------------------------------------------------------------------------
/examples/truth_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_2.jpg
--------------------------------------------------------------------------------
/examples/truth_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_3.jpg
--------------------------------------------------------------------------------
/examples/truth_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_4.jpg
--------------------------------------------------------------------------------
/input.py:
--------------------------------------------------------------------------------
1 | import imghdr
2 | import os
3 | import tensorflow as tf
4 |
5 |
6 | def is_image_valid(filepath):
7 | return imghdr.what(filepath) is not None
8 |
9 |
10 | def get_image_paths(image_dir):
11 | image_paths = []
12 | for root, directories, filenames in os.walk(image_dir):
13 | image_paths += [os.path.join(root, filename) for filename in filenames]
14 | image_paths = [filepath for filepath in image_paths if is_image_valid(filepath)]
15 |
16 | return image_paths
17 |
18 |
19 | def inputs(image_dir, batch_size, min_queue_examples, input_height, input_width):
20 | def read_images(image_paths):
21 | filename_queue = tf.train.string_input_producer(image_paths)
22 | reader = tf.WholeFileReader()
23 | key, value = reader.read(filename_queue)
24 | image = tf.image.decode_image(value)
25 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
26 | image.set_shape([None, None, 3])
27 |
28 | return image
29 |
30 | image_paths = get_image_paths(image_dir)
31 | images = read_images(image_paths)
32 |     images = tf.image.crop_to_bounding_box(images, 30, 0, 178, 178)  # crop aligned CelebA (178x218) to the 178x178 face region
33 | # images = tf.image.random_flip_left_right(images)
34 | images = tf.image.resize_images(images, [input_height, input_width])
35 |
36 | total_image_count = len(image_paths)
37 | input_batch = tf.train.shuffle_batch([images],
38 | batch_size=batch_size,
39 | num_threads=16,
40 | capacity=min_queue_examples + 3 * batch_size,
41 | min_after_dequeue=min_queue_examples)
42 |
43 | return input_batch, total_image_count
44 |
45 |
46 | if __name__ == '__main__':
47 | pass
48 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import cv2
5 | import download
6 | import input
7 | import model
8 |
9 | flags = tf.app.flags
10 | FLAGS = flags.FLAGS
11 |
12 | flags.DEFINE_boolean('download_data', False, 'whether to download and extract image data')
13 | flags.DEFINE_string('download_dir', './downloads/', 'directory path to download data')
14 | flags.DEFINE_string('train_dir', './images/train/', 'directory path to training set')
15 | flags.DEFINE_string('test_dir', './images/test/', 'directory path to test set')
16 |
17 | flags.DEFINE_integer('input_height', 256, 'resized image height, model input')
18 | flags.DEFINE_integer('input_width', 256, 'resized image width, model input')
19 |
20 | flags.DEFINE_string('mode', 'train', 'train or test')
21 | flags.DEFINE_boolean('load_ckpt', True, 'whether to try restoring model from checkpoint')
22 | flags.DEFINE_string('ckpt_dir', './checkpoints/', 'directory path to checkpoint files')
23 |
24 | flags.DEFINE_integer('epoch', 10, 'total number of epochs to train')
25 | flags.DEFINE_integer('batch_size', 4, 'size of batch')
26 | flags.DEFINE_integer('min_queue_examples', 1000, 'minimum number of elements in batch queue')
27 | flags.DEFINE_float('learning_rate', 0.0001, 'learning rate')
28 | flags.DEFINE_float('l1_weight', 100, 'weight on L1 term for generator')
29 | flags.DEFINE_float('beta1', 0.5, 'adam optimizer beta1 parameter')
30 | flags.DEFINE_string('log_dir', './logs/', 'directory path to write summary')
31 |
32 |
33 | def main(argv):
34 | m = model.Model(FLAGS.log_dir, FLAGS.ckpt_dir, FLAGS.load_ckpt, FLAGS.input_height, FLAGS.input_width)
35 | if FLAGS.mode == 'train':
36 | train(m)
37 | elif FLAGS.mode == 'test':
38 | test(m)
39 | else:
40 |         print('Unexpected mode: {}. Choose \'train\' or \'test\'.'.format(FLAGS.mode))
41 | m.close()
42 |
43 |
44 | def train(m):
45 | if FLAGS.download_data:
46 | google_drive_file_id = '0B7EVK8r0v71pZjFTYXZWM3FlRnM'
47 | download_path = os.path.join(FLAGS.download_dir, 'img_align_celeba.zip')
48 | download.maybe_download_from_google_drive(google_drive_file_id, download_path)
49 | download.maybe_extract(download_path, FLAGS.train_dir, FLAGS.test_dir)
50 |
51 | training_inputs, count = input.inputs(FLAGS.train_dir, FLAGS.batch_size, FLAGS.min_queue_examples,
52 | FLAGS.input_height, FLAGS.input_width)
53 | steps_per_epoch = int(count / FLAGS.batch_size)
54 |
55 | test_inputs, _ = input.inputs(FLAGS.test_dir, FLAGS.batch_size, 0, FLAGS.input_height, FLAGS.input_width)
56 |
57 | m.train(training_inputs, test_inputs,
58 | FLAGS.epoch, steps_per_epoch, FLAGS.learning_rate, FLAGS.l1_weight, FLAGS.beta1, FLAGS.load_ckpt)
59 |
60 |
61 | def test(m):
62 | class DrawingState:
63 | def __init__(self):
64 | self.x_prev = 0
65 | self.y_prev = 0
66 | self.drawing = False
67 | self.update = True
68 |
69 | def interactive_drawing(event, x, y, flags, param):
70 | image = param[0]
71 | state = param[1]
72 | if event == cv2.EVENT_LBUTTONDOWN:
73 | state.drawing = True
74 | state.x_prev, state.y_prev = x, y
75 | elif event == cv2.EVENT_MOUSEMOVE:
76 | if state.drawing:
77 | cv2.line(image, (state.x_prev, state.y_prev), (x, y), (1, 1, 1), 1)
78 | state.x_prev = x
79 | state.y_prev = y
80 | state.update = True
81 | elif event == cv2.EVENT_LBUTTONUP:
82 | state.drawing = False
83 | elif event == cv2.EVENT_RBUTTONDOWN:
84 | image.fill(0)
85 | state.update = True
86 |
87 | cv2.namedWindow('Canvas')
88 | image_input = np.zeros((FLAGS.input_height, FLAGS.input_width, 3), np.float32)
89 | state = DrawingState()
90 | cv2.setMouseCallback('Canvas', interactive_drawing, [image_input, state])
91 | while cv2.getWindowProperty('Canvas', 0) >= 0:
92 | if state.update:
93 | reshaped_image_input = np.array([image_input])
94 | image_output = m.test(reshaped_image_input)
95 | concatenated = np.concatenate((image_input, image_output[0]), axis=1)
96 | color_converted = cv2.cvtColor(concatenated, cv2.COLOR_RGB2BGR)
97 | cv2.imshow('Canvas', color_converted)
98 | state.update = False
99 |
100 | k = cv2.waitKey(1) & 0xFF
101 | if k == 27: # esc
102 | break
103 | cv2.destroyAllWindows()
104 |
105 |
106 | if __name__ == '__main__':
107 | tf.app.run()
108 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import cv2
5 |
6 |
7 | class Model:
8 | def __init__(self, log_dir, ckpt_dir, load_ckpt, image_height, image_width, image_channel=3):
9 | self.log_dir = log_dir
10 | self.ckpt_dir = ckpt_dir
11 | self.image_channel = image_channel
12 | self.sess = tf.Session()
13 | self.trained_step = 0
14 |
15 | self.inputs = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel])
16 | self.targets = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel])
17 | self.is_training = tf.placeholder(tf.bool)
18 | self.g = self.create_generator(self.inputs, self.image_channel, self.is_training)
19 | self.d_fake = self.create_discriminator(self.inputs, self.g, self.is_training)
20 | self.d_real = self.create_discriminator(self.inputs, self.targets, self.is_training, reuse=True)
21 |
22 | self.sess.run(tf.global_variables_initializer())
23 | if load_ckpt:
24 | self.trained_step = self.load()
25 |
26 | def close(self):
27 | self.sess.close()
28 |
29 | @staticmethod
30 | def lrelu(input_, leak=0.2):
31 | with tf.name_scope('lrelu'):
32 | return tf.maximum(input_, leak * input_)
33 |
34 | @staticmethod
35 | def batch_norm(input_, is_training):
36 | with tf.name_scope('batchnorm'):
37 | return tf.contrib.layers.batch_norm(input_, is_training=is_training)
38 |
39 | @staticmethod
40 | def conv(input_, output_channels, filter_size=4, stride=2, stddev=3e-2):
41 | with tf.variable_scope('conv'):
42 | in_channels = input_.get_shape()[-1]
43 | filter_ = tf.get_variable(
44 | name='filter',
45 | shape=[filter_size, filter_size, in_channels, output_channels],
46 | initializer=tf.truncated_normal_initializer(stddev=stddev),
47 | )
48 | conv = tf.nn.conv2d(input_, filter_, [1, stride, stride, 1], padding='SAME')
49 | return conv
50 |
51 | @staticmethod
52 | def deconv(input_, out_height, out_width, out_channels, filter_size=4, stride=2, stddev=3e-2):
53 | with tf.variable_scope("deconv"):
54 | in_channels = input_.get_shape().as_list()[-1]
55 | filter_ = tf.get_variable(
56 | name='filter',
57 | shape=[filter_size, filter_size, out_channels, in_channels],
58 | initializer=tf.truncated_normal_initializer(stddev=stddev),
59 | )
60 |
61 | batch_dynamic = tf.shape(input_)[0]
62 | output_shape = tf.stack([batch_dynamic, out_height, out_width, out_channels])
63 | conv = tf.nn.conv2d_transpose(input_, filter_, output_shape, [1, stride, stride, 1], padding="SAME")
64 | conv = tf.reshape(conv, [-1, out_height, out_width, out_channels])
65 | return conv
66 |
67 | @staticmethod
68 | def detect_edges(images):
69 | def blur(image):
70 | return cv2.GaussianBlur(image, (5, 5), 0)
71 |
72 | def canny_otsu(image):
73 | scale_factor = 255
74 | scaled_image = np.uint8(image * scale_factor)
75 |
76 | otsu_threshold = cv2.threshold(
77 | cv2.cvtColor(scaled_image, cv2.COLOR_RGB2GRAY), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[0]
78 | lower_threshold = max(0, int(otsu_threshold * 0.5))
79 | upper_threshold = min(255, int(otsu_threshold))
80 | edges = cv2.Canny(scaled_image, lower_threshold, upper_threshold)
81 | edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
82 |
83 | return np.float32(edges) * (1 / scale_factor)
84 |
85 | blurred = [blur(image) for image in images]
86 | canny_applied = [canny_otsu(image) for image in blurred]
87 |
88 | return canny_applied
89 |
90 | @staticmethod
91 | def create_generator(input_, generator_output_channels, is_training):
92 | class Encoder:
93 | def __init__(self, name, out_channels, is_training, use_batch_norm=True):
94 | self.name = name
95 | self.out_channels = out_channels
96 | self.is_training = is_training
97 | self.use_batch_norm = use_batch_norm
98 | self.in_height = None
99 | self.in_width = None
100 | self.output = None
101 |
102 | def encode(self, input_):
103 | with tf.variable_scope(self.name):
104 | output = Model.conv(input_, self.out_channels)
105 | if self.use_batch_norm:
106 | output = Model.batch_norm(output, self.is_training)
107 | output = Model.lrelu(output)
108 |
109 | input_shape = input_.get_shape().as_list()
110 | self.in_height = input_shape[1]
111 | self.in_width = input_shape[2]
112 | self.output = output
113 | return output
114 |
115 | class Decoder:
116 | def __init__(self, name, out_channels, is_training, use_batch_norm=True, dropout=None):
117 | self.name = name
118 | self.out_channels = out_channels
119 | self.is_training = is_training
120 | self.use_batch_norm = use_batch_norm
121 | self.dropout = dropout
122 | self.output = None
123 |
124 | def decode(self, input_, out_height, out_width, skip_input=None):
125 | with tf.variable_scope(self.name):
126 | if skip_input is None:
127 | merged_input = input_
128 | else:
129 | merged_input = tf.concat([input_, skip_input], axis=3)
130 |
131 | output = Model.deconv(merged_input, out_height, out_width, self.out_channels)
132 | if self.use_batch_norm:
133 | output = Model.batch_norm(output, self.is_training)
134 | output = tf.nn.relu(output)
135 | if self.dropout:
136 | output = tf.nn.dropout(output, keep_prob=1 - self.dropout)
137 |
138 | self.output = output
139 | return output
140 |
141 | with tf.variable_scope('generator'):
142 | ngf = 64
143 |
144 | encoders = [
145 | Encoder('encoder_0', ngf * 1, is_training, use_batch_norm=False),
146 | Encoder('encoder_1', ngf * 2, is_training),
147 | Encoder('encoder_2', ngf * 4, is_training),
148 | Encoder('encoder_3', ngf * 8, is_training),
149 | Encoder('encoder_4', ngf * 8, is_training),
150 | Encoder('encoder_5', ngf * 8, is_training),
151 | Encoder('encoder_6', ngf * 8, is_training),
152 | Encoder('encoder_7', ngf * 8, is_training),
153 | ]
154 |
155 | for i, encoder in enumerate(encoders):
156 | if i == 0:
157 | encoder_input = input_
158 | else:
159 | encoder_input = encoders[i - 1].output
160 | encoders[i].encode(encoder_input)
161 |
162 | decoders = [
163 | Decoder('decoder_0', ngf * 8, is_training, dropout=0.5),
164 | Decoder('decoder_1', ngf * 8, is_training, dropout=0.5),
165 | Decoder('decoder_2', ngf * 8, is_training, dropout=0.5),
166 | Decoder('decoder_3', ngf * 8, is_training),
167 | Decoder('decoder_4', ngf * 4, is_training),
168 | Decoder('decoder_5', ngf * 2, is_training),
169 | Decoder('decoder_6', ngf * 1, is_training),
170 | Decoder('decoder_7', generator_output_channels, is_training),
171 | ]
172 |
173 | for i, decoder in enumerate(decoders):
174 | if i == 0:
175 | decoder_input = encoders[-1].output
176 | decoder_skip_input = None
177 | else:
178 | decoder_input = decoders[i - 1].output
179 | decoder_skip_input = encoders[-i - 1].output
180 |
181 | decoders[i].decode(decoder_input, encoders[-i - 1].in_height, encoders[-i - 1].in_width, decoder_skip_input)
182 |
183 | return tf.nn.tanh(decoders[-1].output)
184 |
185 | @staticmethod
186 | def create_discriminator(input_, target, is_training, reuse=False):
187 | class Layer:
188 | def __init__(self, name, output_channels, stride, is_training, use_batch_norm=True, use_activation=True):
189 | self.name = name
190 | self.output_channels = output_channels
191 | self.stride = stride
192 | self.is_training = is_training
193 | self.use_batch_norm = use_batch_norm
194 | self.use_activation = use_activation
195 | self.output = None
196 |
197 | def conv(self, input_):
198 | with tf.variable_scope(self.name):
199 | output = Model.conv(input_, self.output_channels, stride=self.stride)
200 | if self.use_batch_norm:
201 | output = Model.batch_norm(output, self.is_training)
202 | if self.use_activation:
203 | output = Model.lrelu(output)
204 |
205 | self.output = output
206 | return output
207 |
208 | with tf.variable_scope('discriminator') as scope:
209 | if reuse:
210 | scope.reuse_variables()
211 |
212 | ndf = 64
213 |
214 | layers = [
215 | Layer('layer_0', ndf * 1, 2, is_training, use_batch_norm=False),
216 | Layer('layer_1', ndf * 2, 2, is_training),
217 | Layer('layer_2', ndf * 4, 2, is_training),
218 | Layer('layer_3', ndf * 8, 1, is_training),
219 | Layer('layer_4', 1, 1, is_training, use_batch_norm=False, use_activation=False),
220 | ]
221 |
222 | for i, layer in enumerate(layers):
223 | if i == 0:
224 | layer_input = tf.concat([input_, target], axis=3)
225 | else:
226 | layer_input = layers[i - 1].output
227 | layers[i].conv(layer_input)
228 |
229 | return tf.nn.sigmoid(layers[-1].output)
230 |
231 | def train(self, training_image, test_image, total_epoch, steps_per_epoch, learning_rate, l1_weight, beta1, load_ckpt):
232 | epsilon = 1e-8
233 |
234 | loss_g_gan = tf.reduce_mean(-tf.log(self.d_fake + epsilon))
235 | loss_g_l1 = l1_weight * tf.reduce_mean(tf.abs(self.targets - self.g))
236 | loss_g = loss_g_gan + loss_g_l1
237 |
238 | loss_d_real = tf.reduce_mean(-tf.log(self.d_real + epsilon))
239 | loss_d_fake = tf.reduce_mean(-tf.log(tf.ones_like(self.d_fake) - self.d_fake + epsilon))
240 | loss_d = loss_d_real + loss_d_fake
241 |
242 | vars_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
243 | vars_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
244 |
245 | # update batch_norm moving_mean, moving_variance
246 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
247 | with tf.control_dependencies(update_ops):
248 | train_g = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss_g, var_list=vars_g)
249 | train_d = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss_d, var_list=vars_d)
250 |
251 | tf.summary.image('training_truth', self.targets, 4)
252 | tf.summary.image('training_input', self.inputs, 4)
253 | tf.summary.image('training_output', self.g, 4)
254 |
255 | tf.summary.histogram('D_real', self.d_real)
256 | tf.summary.histogram('D_fake', self.d_fake)
257 |
258 | tf.summary.scalar('G_loss', loss_g)
259 | tf.summary.scalar('G_loss_gan', loss_g_gan)
260 | tf.summary.scalar('G_loss_l1', loss_g_l1)
261 | tf.summary.scalar('D_loss', loss_d)
262 | tf.summary.scalar('D_loss_real', loss_d_real)
263 | tf.summary.scalar('D_loss_fake', loss_d_fake)
264 |
265 | for var in vars_g:
266 | tf.summary.histogram(var.name, var)
267 | for var in vars_d:
268 | tf.summary.histogram(var.name, var)
269 |
270 | training_summary = tf.summary.merge_all()
271 |
272 | test_summary_truth = tf.summary.image('test_truth', self.targets, 4)
273 | test_summary_input = tf.summary.image('test_input', self.inputs, 4)
274 | test_summary_output = tf.summary.image('test_output', self.g, 4)
275 | test_summary = tf.summary.merge([test_summary_input, test_summary_output, test_summary_truth])
276 |
277 | writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)
278 |
279 |         # FIXME: the optimizer variables created above did not exist at __init__ time,
280 |         # so re-initialize all variables and then restore the checkpoint again
280 | self.sess.run(tf.global_variables_initializer())
281 | if load_ckpt:
282 | self.trained_step = self.load()
283 |
284 | coord = tf.train.Coordinator()
285 | threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)
286 |
287 | print('Training start')
288 | for epoch in range(total_epoch):
289 | for step in range(steps_per_epoch):
290 | image_value = self.sess.run(training_image)
291 | edges = self.detect_edges(image_value)
292 |
293 | feed_dict = {self.inputs: edges, self.targets: image_value, self.is_training: True}
294 | self.sess.run(train_d, feed_dict=feed_dict)
295 | self.sess.run(train_g, feed_dict=feed_dict)
296 | self.sess.run(train_g, feed_dict=feed_dict)
297 |
298 | self.trained_step += 1
299 | if self.trained_step % 100 == 0:
300 | print('step: {}'.format(self.trained_step))
301 |
302 | training_summary_value = self.sess.run(training_summary, feed_dict=feed_dict)
303 | writer.add_summary(training_summary_value, self.trained_step)
304 |
305 | image_value = self.sess.run(test_image)
306 | edges = self.detect_edges(image_value)
307 |
308 | feed_dict = {self.inputs: edges, self.targets: image_value, self.is_training: False}
309 | test_summary_value = self.sess.run(test_summary, feed_dict=feed_dict)
310 | writer.add_summary(test_summary_value, self.trained_step)
311 |
312 | if self.trained_step % 1000 == 0:
313 | self.save()
314 |
315 |         coord.request_stop()  # stop the queue runner threads so join() can return
316 |         coord.join(threads)
316 |
317 | def test(self, inputs):
318 | output = self.sess.run(self.g, feed_dict={self.inputs: inputs, self.is_training: False})
319 | return output
320 |
321 | def save(self):
322 | os.makedirs(self.ckpt_dir, exist_ok=True)
323 | saver = tf.train.Saver()
324 | saver.save(self.sess, self.ckpt_dir, global_step=self.trained_step)
325 |
326 | def load(self):
327 | ckpt = tf.train.get_checkpoint_state(self.ckpt_dir)
328 | if ckpt and ckpt.model_checkpoint_path:
329 | saver = tf.train.Saver()
330 | saver.restore(self.sess, ckpt.model_checkpoint_path)
331 |
332 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
333 | trained_step = int(os.path.splitext(ckpt_name)[0][1:])
334 |
335 | return trained_step
336 | else:
337 | return 0
338 |
--------------------------------------------------------------------------------