├── .gitignore ├── LICENSE ├── README.md ├── download.py ├── examples ├── input_1.jpg ├── input_2.jpg ├── input_3.jpg ├── input_4.jpg ├── output_1.jpg ├── output_2.jpg ├── output_3.jpg ├── output_4.jpg ├── truth_1.jpg ├── truth_2.jpg ├── truth_3.jpg └── truth_4.jpg ├── input.py ├── main.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 seungjooli 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConditionalGAN - sketch2face 2 | This is tensorflow implementation based on [pix2pix](https://phillipi.github.io/pix2pix/). I specifically tried to train the model to convert a human face sketch image to a photo-like image. 3 | 4 | Need improvements: 5 | It works for edges extracted from real photos, but not very well for hand-drawn sketch. 6 | 7 | Differences from the paper: 8 | - To avoid the fast convergence of discriminator, updated generator twice for each training step.
(Borrowed the idea from https://github.com/carpedm20/DCGAN-tensorflow) 9 | 10 | ## Output Example 11 | After 2 epochs training, from randomly chosen test set: 12 | 13 | | input | output | ground truth | 14 | | :---: | :---: | :---: | 15 | | | | | 16 | | | | | 17 | | | | | 18 | | | | | 19 | 20 | 21 | ## Data 22 | [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 23 | Used first 99% of them (200,573 images) as a training set and the others (2,026 images) as a test set. 24 | 25 | 1. Crop and resize images into 256 x 256 (target image) 26 | 2. Extract edges from target images to get sketch images using opencv canny (input image) 27 | 28 | ## Usage 29 | To download dataset and start training: 30 | 31 | $ python main.py --download_data=True 32 | 33 | To train a model with downloaded data: 34 | 35 | $ python main.py 36 | 37 | To test the model with a canvas UI: 38 | 39 | $ python main.py --mode=test 40 | 41 | ## Environment 42 | - python (3.5.3) 43 | - tensorflow-gpu (1.1.0) 44 | - opencv-python (3.2.0) 45 | 46 | - google-api-python-client (1.6.2) 47 | - requests (2.13.0) 48 | - tqdm (4.11.2) 49 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tarfile 4 | import urllib.request 5 | import zipfile 6 | from tqdm import tqdm 7 | 8 | 9 | def maybe_download_from_url(url, download_dir): 10 | """ 11 | Download the data from url, unless it's already here. 12 | 13 | Args: 14 | download_dir: string, path to download directory 15 | url: url to download from 16 | 17 | Returns: 18 | Path to the downloaded file 19 | """ 20 | filename = url.split('/')[-1] 21 | filepath = os.path.join(download_dir, filename) 22 | 23 | os.makedirs(download_dir, exist_ok=True) 24 | 25 | if not os.path.isfile(filepath): 26 | print('Downloading: "{}"'.format(filepath)) 27 | urllib.request.urlretrieve(url, filepath) 28 | size = os.path.getsize(filepath) 29 | print('Download complete ({} bytes)'.format(size)) 30 | else: 31 | print('File already exists: "{}"'.format(filepath)) 32 | 33 | return filepath 34 | 35 | 36 | def maybe_download_from_google_drive(id, filepath): 37 | def get_confirm_token(response): 38 | for key, value in response.cookies.items(): 39 | if key.startswith('download_warning'): 40 | return value 41 | return None 42 | 43 | def save_response_content(response, filepath, chunk_size=32 * 1024): 44 | total_size = int(response.headers.get('content-length', 0)) 45 | with open(filepath, "wb") as f: 46 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 47 | unit='B', unit_scale=True, desc=filepath): 48 | if chunk: # filter out keep-alive new chunks 49 | f.write(chunk) 50 | 51 | if not os.path.isfile(filepath): 52 | print('Downloading: "{}"'.format(filepath)) 53 | URL = "https://docs.google.com/uc?export=download" 54 | session = requests.Session() 55 | 56 | response = session.get(URL, params={'id': id}, stream=True) 57 | token = get_confirm_token(response) 58 | 59 | if token: 60 | params = {'id': id, 'confirm': token} 61 | response = session.get(URL, params=params, stream=True) 62 | 63 | save_response_content(response, filepath) 64 | size = os.path.getsize(filepath) 65 | print('Download complete ({} bytes)'.format(size)) 66 | else: 67 | print('File already exists: "{}"'.format(filepath)) 68 | 69 | return filepath 70 | 71 | 72 | def maybe_extract(compressed_filepath, train_dir, test_dir): 73 | def is_image(filepath): 74 | extensions = ('.jpg', '.jpeg', '.png', '.gif') 75 | return any(filepath.endswith(ext) for ext in extensions) 76 | 77 | os.makedirs(train_dir, exist_ok=True) 78 | os.makedirs(test_dir, exist_ok=True) 79 | print('Extracting: "{}"'.format(compressed_filepath)) 80 | 81 | if zipfile.is_zipfile(compressed_filepath): 82 | with zipfile.ZipFile(compressed_filepath) as zf: 83 | files = [member for member in zf.infolist() if is_image(member.filename)] 84 | count = len(files) 85 | train_test_boundary = int(count * 0.99) 86 | for i in range(count): 87 | if i < train_test_boundary: 88 | extract_dir = train_dir 89 | else: 90 | extract_dir = test_dir 91 | 92 | if not os.path.exists(os.path.join(extract_dir, files[i].filename)): 93 | zf.extract(files[i], extract_dir) 94 | elif tarfile.is_tarfile(compressed_filepath): 95 | with tarfile.open(compressed_filepath) as tar: 96 | files = [member for member in tar if is_image(member.name)] 97 | count = len(files) 98 | train_test_boundary = int(count * 0.99) 99 | for i in range(count): 100 | if i < train_test_boundary: 101 | extract_dir = train_dir 102 | else: 103 | extract_dir = test_dir 104 | 105 | if not os.path.exists(os.path.join(extract_dir, files[i].name)): 106 | tar.extract(files[i], extract_dir) 107 | else: 108 | raise NotImplemented 109 | 110 | print('Extraction complete') 111 | -------------------------------------------------------------------------------- /examples/input_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_1.jpg -------------------------------------------------------------------------------- /examples/input_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_2.jpg -------------------------------------------------------------------------------- /examples/input_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_3.jpg -------------------------------------------------------------------------------- /examples/input_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/input_4.jpg -------------------------------------------------------------------------------- /examples/output_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_1.jpg -------------------------------------------------------------------------------- /examples/output_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_2.jpg -------------------------------------------------------------------------------- /examples/output_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_3.jpg -------------------------------------------------------------------------------- /examples/output_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/output_4.jpg -------------------------------------------------------------------------------- /examples/truth_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_1.jpg -------------------------------------------------------------------------------- /examples/truth_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_2.jpg -------------------------------------------------------------------------------- /examples/truth_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_3.jpg -------------------------------------------------------------------------------- /examples/truth_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/seungjooli/ConditionalGAN/b1e89ad2c1a4f765113bff049e6132129b68be25/examples/truth_4.jpg -------------------------------------------------------------------------------- /input.py: -------------------------------------------------------------------------------- 1 | import imghdr 2 | import os 3 | import tensorflow as tf 4 | 5 | 6 | def is_image_valid(filepath): 7 | return imghdr.what(filepath) is not None 8 | 9 | 10 | def get_image_paths(image_dir): 11 | image_paths = [] 12 | for root, directories, filenames in os.walk(image_dir): 13 | image_paths += [os.path.join(root, filename) for filename in filenames] 14 | image_paths = [filepath for filepath in image_paths if is_image_valid(filepath)] 15 | 16 | return image_paths 17 | 18 | 19 | def inputs(image_dir, batch_size, min_queue_examples, input_height, input_width): 20 | def read_images(image_paths): 21 | filename_queue = tf.train.string_input_producer(image_paths) 22 | reader = tf.WholeFileReader() 23 | key, value = reader.read(filename_queue) 24 | image = tf.image.decode_image(value) 25 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 26 | image.set_shape([None, None, 3]) 27 | 28 | return image 29 | 30 | image_paths = get_image_paths(image_dir) 31 | images = read_images(image_paths) 32 | images = tf.image.crop_to_bounding_box(images, 30, 0, 178, 178) 33 | # images = tf.image.random_flip_left_right(images) 34 | images = tf.image.resize_images(images, [input_height, input_width]) 35 | 36 | total_image_count = len(image_paths) 37 | input_batch = tf.train.shuffle_batch([images], 38 | batch_size=batch_size, 39 | num_threads=16, 40 | capacity=min_queue_examples + 3 * batch_size, 41 | min_after_dequeue=min_queue_examples) 42 | 43 | return input_batch, total_image_count 44 | 45 | 46 | if __name__ == '__main__': 47 | pass 48 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import cv2 5 | import download 6 | import input 7 | import model 8 | 9 | flags = tf.app.flags 10 | FLAGS = flags.FLAGS 11 | 12 | flags.DEFINE_boolean('download_data', False, 'whether to download, extract image data') 13 | flags.DEFINE_string('download_dir', './downloads/', 'directory path to download data') 14 | flags.DEFINE_string('train_dir', './images/train/', 'directory path to training set') 15 | flags.DEFINE_string('test_dir', './images/test/', 'directory path to test set') 16 | 17 | flags.DEFINE_integer('input_height', 256, 'resized image height, model input') 18 | flags.DEFINE_integer('input_width', 256, 'resized image width, model input') 19 | 20 | flags.DEFINE_string('mode', 'train', 'train or test') 21 | flags.DEFINE_boolean('load_ckpt', True, 'whether to try restoring model from checkpoint') 22 | flags.DEFINE_string('ckpt_dir', './checkpoints/', 'directory path to checkpoint files') 23 | 24 | flags.DEFINE_integer('epoch', 10, 'total number of epoch to train') 25 | flags.DEFINE_integer('batch_size', 4, 'size of batch') 26 | flags.DEFINE_integer('min_queue_examples', 1000, 'minimum number of elements in batch queue') 27 | flags.DEFINE_float('learning_rate', 0.0001, 'learning rate') 28 | flags.DEFINE_float('l1_weight', 100, 'weight on L1 term for generator') 29 | flags.DEFINE_float('beta1', 0.5, 'adam optimizer beta1 parameter') 30 | flags.DEFINE_string('log_dir', './logs/', 'directory path to write summary') 31 | 32 | 33 | def main(argv): 34 | m = model.Model(FLAGS.log_dir, FLAGS.ckpt_dir, FLAGS.load_ckpt, FLAGS.input_height, FLAGS.input_width) 35 | if FLAGS.mode == 'train': 36 | train(m) 37 | elif FLAGS.mode == 'test': 38 | test(m) 39 | else: 40 | print('Unexpected mode: {} Choose \'train\' or \'test\''.format(FLAGS.mode)) 41 | m.close() 42 | 43 | 44 | def train(m): 45 | if FLAGS.download_data: 46 | google_drive_file_id = '0B7EVK8r0v71pZjFTYXZWM3FlRnM' 47 | download_path = os.path.join(FLAGS.download_dir, 'img_align_celeba.zip') 48 | download.maybe_download_from_google_drive(google_drive_file_id, download_path) 49 | download.maybe_extract(download_path, FLAGS.train_dir, FLAGS.test_dir) 50 | 51 | training_inputs, count = input.inputs(FLAGS.train_dir, FLAGS.batch_size, FLAGS.min_queue_examples, 52 | FLAGS.input_height, FLAGS.input_width) 53 | steps_per_epoch = int(count / FLAGS.batch_size) 54 | 55 | test_inputs, _ = input.inputs(FLAGS.test_dir, FLAGS.batch_size, 0, FLAGS.input_height, FLAGS.input_width) 56 | 57 | m.train(training_inputs, test_inputs, 58 | FLAGS.epoch, steps_per_epoch, FLAGS.learning_rate, FLAGS.l1_weight, FLAGS.beta1, FLAGS.load_ckpt) 59 | 60 | 61 | def test(m): 62 | class DrawingState: 63 | def __init__(self): 64 | self.x_prev = 0 65 | self.y_prev = 0 66 | self.drawing = False 67 | self.update = True 68 | 69 | def interactive_drawing(event, x, y, flags, param): 70 | image = param[0] 71 | state = param[1] 72 | if event == cv2.EVENT_LBUTTONDOWN: 73 | state.drawing = True 74 | state.x_prev, state.y_prev = x, y 75 | elif event == cv2.EVENT_MOUSEMOVE: 76 | if state.drawing: 77 | cv2.line(image, (state.x_prev, state.y_prev), (x, y), (1, 1, 1), 1) 78 | state.x_prev = x 79 | state.y_prev = y 80 | state.update = True 81 | elif event == cv2.EVENT_LBUTTONUP: 82 | state.drawing = False 83 | elif event == cv2.EVENT_RBUTTONDOWN: 84 | image.fill(0) 85 | state.update = True 86 | 87 | cv2.namedWindow('Canvas') 88 | image_input = np.zeros((FLAGS.input_height, FLAGS.input_width, 3), np.float32) 89 | state = DrawingState() 90 | cv2.setMouseCallback('Canvas', interactive_drawing, [image_input, state]) 91 | while cv2.getWindowProperty('Canvas', 0) >= 0: 92 | if state.update: 93 | reshaped_image_input = np.array([image_input]) 94 | image_output = m.test(reshaped_image_input) 95 | concatenated = np.concatenate((image_input, image_output[0]), axis=1) 96 | color_converted = cv2.cvtColor(concatenated, cv2.COLOR_RGB2BGR) 97 | cv2.imshow('Canvas', color_converted) 98 | state.update = False 99 | 100 | k = cv2.waitKey(1) & 0xFF 101 | if k == 27: # esc 102 | break 103 | cv2.destroyAllWindows() 104 | 105 | 106 | if __name__ == '__main__': 107 | tf.app.run() 108 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import cv2 5 | 6 | 7 | class Model: 8 | def __init__(self, log_dir, ckpt_dir, load_ckpt, image_height, image_width, image_channel=3): 9 | self.log_dir = log_dir 10 | self.ckpt_dir = ckpt_dir 11 | self.image_channel = image_channel 12 | self.sess = tf.Session() 13 | self.trained_step = 0 14 | 15 | self.inputs = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel]) 16 | self.targets = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel]) 17 | self.is_training = tf.placeholder(tf.bool) 18 | self.g = self.create_generator(self.inputs, self.image_channel, self.is_training) 19 | self.d_fake = self.create_discriminator(self.inputs, self.g, self.is_training) 20 | self.d_real = self.create_discriminator(self.inputs, self.targets, self.is_training, reuse=True) 21 | 22 | self.sess.run(tf.global_variables_initializer()) 23 | if load_ckpt: 24 | self.trained_step = self.load() 25 | 26 | def close(self): 27 | self.sess.close() 28 | 29 | @staticmethod 30 | def lrelu(input_, leak=0.2): 31 | with tf.name_scope('lrelu'): 32 | return tf.maximum(input_, leak * input_) 33 | 34 | @staticmethod 35 | def batch_norm(input_, is_training): 36 | with tf.name_scope('batchnorm'): 37 | return tf.contrib.layers.batch_norm(input_, is_training=is_training) 38 | 39 | @staticmethod 40 | def conv(input_, output_channels, filter_size=4, stride=2, stddev=3e-2): 41 | with tf.variable_scope('conv'): 42 | in_channels = input_.get_shape()[-1] 43 | filter_ = tf.get_variable( 44 | name='filter', 45 | shape=[filter_size, filter_size, in_channels, output_channels], 46 | initializer=tf.truncated_normal_initializer(stddev=stddev), 47 | ) 48 | conv = tf.nn.conv2d(input_, filter_, [1, stride, stride, 1], padding='SAME') 49 | return conv 50 | 51 | @staticmethod 52 | def deconv(input_, out_height, out_width, out_channels, filter_size=4, stride=2, stddev=3e-2): 53 | with tf.variable_scope("deconv"): 54 | in_channels = input_.get_shape().as_list()[-1] 55 | filter_ = tf.get_variable( 56 | name='filter', 57 | shape=[filter_size, filter_size, out_channels, in_channels], 58 | initializer=tf.truncated_normal_initializer(stddev=stddev), 59 | ) 60 | 61 | batch_dynamic = tf.shape(input_)[0] 62 | output_shape = tf.stack([batch_dynamic, out_height, out_width, out_channels]) 63 | conv = tf.nn.conv2d_transpose(input_, filter_, output_shape, [1, stride, stride, 1], padding="SAME") 64 | conv = tf.reshape(conv, [-1, out_height, out_width, out_channels]) 65 | return conv 66 | 67 | @staticmethod 68 | def detect_edges(images): 69 | def blur(image): 70 | return cv2.GaussianBlur(image, (5, 5), 0) 71 | 72 | def canny_otsu(image): 73 | scale_factor = 255 74 | scaled_image = np.uint8(image * scale_factor) 75 | 76 | otsu_threshold = cv2.threshold( 77 | cv2.cvtColor(scaled_image, cv2.COLOR_RGB2GRAY), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[0] 78 | lower_threshold = max(0, int(otsu_threshold * 0.5)) 79 | upper_threshold = min(255, int(otsu_threshold)) 80 | edges = cv2.Canny(scaled_image, lower_threshold, upper_threshold) 81 | edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) 82 | 83 | return np.float32(edges) * (1 / scale_factor) 84 | 85 | blurred = [blur(image) for image in images] 86 | canny_applied = [canny_otsu(image) for image in blurred] 87 | 88 | return canny_applied 89 | 90 | @staticmethod 91 | def create_generator(input_, generator_output_channels, is_training): 92 | class Encoder: 93 | def __init__(self, name, out_channels, is_training, use_batch_norm=True): 94 | self.name = name 95 | self.out_channels = out_channels 96 | self.is_training = is_training 97 | self.use_batch_norm = use_batch_norm 98 | self.in_height = None 99 | self.in_width = None 100 | self.output = None 101 | 102 | def encode(self, input_): 103 | with tf.variable_scope(self.name): 104 | output = Model.conv(input_, self.out_channels) 105 | if self.use_batch_norm: 106 | output = Model.batch_norm(output, self.is_training) 107 | output = Model.lrelu(output) 108 | 109 | input_shape = input_.get_shape().as_list() 110 | self.in_height = input_shape[1] 111 | self.in_width = input_shape[2] 112 | self.output = output 113 | return output 114 | 115 | class Decoder: 116 | def __init__(self, name, out_channels, is_training, use_batch_norm=True, dropout=None): 117 | self.name = name 118 | self.out_channels = out_channels 119 | self.is_training = is_training 120 | self.use_batch_norm = use_batch_norm 121 | self.dropout = dropout 122 | self.output = None 123 | 124 | def decode(self, input_, out_height, out_width, skip_input=None): 125 | with tf.variable_scope(self.name): 126 | if skip_input is None: 127 | merged_input = input_ 128 | else: 129 | merged_input = tf.concat([input_, skip_input], axis=3) 130 | 131 | output = Model.deconv(merged_input, out_height, out_width, self.out_channels) 132 | if self.use_batch_norm: 133 | output = Model.batch_norm(output, self.is_training) 134 | output = tf.nn.relu(output) 135 | if self.dropout: 136 | output = tf.nn.dropout(output, keep_prob=1 - self.dropout) 137 | 138 | self.output = output 139 | return output 140 | 141 | with tf.variable_scope('generator'): 142 | ngf = 64 143 | 144 | encoders = [ 145 | Encoder('encoder_0', ngf * 1, is_training, use_batch_norm=False), 146 | Encoder('encoder_1', ngf * 2, is_training), 147 | Encoder('encoder_2', ngf * 4, is_training), 148 | Encoder('encoder_3', ngf * 8, is_training), 149 | Encoder('encoder_4', ngf * 8, is_training), 150 | Encoder('encoder_5', ngf * 8, is_training), 151 | Encoder('encoder_6', ngf * 8, is_training), 152 | Encoder('encoder_7', ngf * 8, is_training), 153 | ] 154 | 155 | for i, encoder in enumerate(encoders): 156 | if i == 0: 157 | encoder_input = input_ 158 | else: 159 | encoder_input = encoders[i - 1].output 160 | encoders[i].encode(encoder_input) 161 | 162 | decoders = [ 163 | Decoder('decoder_0', ngf * 8, is_training, dropout=0.5), 164 | Decoder('decoder_1', ngf * 8, is_training, dropout=0.5), 165 | Decoder('decoder_2', ngf * 8, is_training, dropout=0.5), 166 | Decoder('decoder_3', ngf * 8, is_training), 167 | Decoder('decoder_4', ngf * 4, is_training), 168 | Decoder('decoder_5', ngf * 2, is_training), 169 | Decoder('decoder_6', ngf * 1, is_training), 170 | Decoder('decoder_7', generator_output_channels, is_training), 171 | ] 172 | 173 | for i, decoder in enumerate(decoders): 174 | if i == 0: 175 | decoder_input = encoders[-1].output 176 | decoder_skip_input = None 177 | else: 178 | decoder_input = decoders[i - 1].output 179 | decoder_skip_input = encoders[-i - 1].output 180 | 181 | decoders[i].decode(decoder_input, encoders[-i - 1].in_height, encoders[-i - 1].in_width, decoder_skip_input) 182 | 183 | return tf.nn.tanh(decoders[-1].output) 184 | 185 | @staticmethod 186 | def create_discriminator(input_, target, is_training, reuse=False): 187 | class Layer: 188 | def __init__(self, name, output_channels, stride, is_training, use_batch_norm=True, use_activation=True): 189 | self.name = name 190 | self.output_channels = output_channels 191 | self.stride = stride 192 | self.is_training = is_training 193 | self.use_batch_norm = use_batch_norm 194 | self.use_activation = use_activation 195 | self.output = None 196 | 197 | def conv(self, input_): 198 | with tf.variable_scope(self.name): 199 | output = Model.conv(input_, self.output_channels, stride=self.stride) 200 | if self.use_batch_norm: 201 | output = Model.batch_norm(output, self.is_training) 202 | if self.use_activation: 203 | output = Model.lrelu(output) 204 | 205 | self.output = output 206 | return output 207 | 208 | with tf.variable_scope('discriminator') as scope: 209 | if reuse: 210 | scope.reuse_variables() 211 | 212 | ndf = 64 213 | 214 | layers = [ 215 | Layer('layer_0', ndf * 1, 2, is_training, use_batch_norm=False), 216 | Layer('layer_1', ndf * 2, 2, is_training), 217 | Layer('layer_2', ndf * 4, 2, is_training), 218 | Layer('layer_3', ndf * 8, 1, is_training), 219 | Layer('layer_4', 1, 1, is_training, use_batch_norm=False, use_activation=False), 220 | ] 221 | 222 | for i, layer in enumerate(layers): 223 | if i == 0: 224 | layer_input = tf.concat([input_, target], axis=3) 225 | else: 226 | layer_input = layers[i - 1].output 227 | layers[i].conv(layer_input) 228 | 229 | return tf.nn.sigmoid(layers[-1].output) 230 | 231 | def train(self, training_image, test_image, total_epoch, steps_per_epoch, learning_rate, l1_weight, beta1, load_ckpt): 232 | epsilon = 1e-8 233 | 234 | loss_g_gan = tf.reduce_mean(-tf.log(self.d_fake + epsilon)) 235 | loss_g_l1 = l1_weight * tf.reduce_mean(tf.abs(self.targets - self.g)) 236 | loss_g = loss_g_gan + loss_g_l1 237 | 238 | loss_d_real = tf.reduce_mean(-tf.log(self.d_real + epsilon)) 239 | loss_d_fake = tf.reduce_mean(-tf.log(tf.ones_like(self.d_fake) - self.d_fake + epsilon)) 240 | loss_d = loss_d_real + loss_d_fake 241 | 242 | vars_g = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 243 | vars_d = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 244 | 245 | # update batch_norm moving_mean, moving_variance 246 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 247 | with tf.control_dependencies(update_ops): 248 | train_g = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss_g, var_list=vars_g) 249 | train_d = tf.train.AdamOptimizer(learning_rate, beta1=beta1).minimize(loss_d, var_list=vars_d) 250 | 251 | tf.summary.image('training_truth', self.targets, 4) 252 | tf.summary.image('training_input', self.inputs, 4) 253 | tf.summary.image('training_output', self.g, 4) 254 | 255 | tf.summary.histogram('D_real', self.d_real) 256 | tf.summary.histogram('D_fake', self.d_fake) 257 | 258 | tf.summary.scalar('G_loss', loss_g) 259 | tf.summary.scalar('G_loss_gan', loss_g_gan) 260 | tf.summary.scalar('G_loss_l1', loss_g_l1) 261 | tf.summary.scalar('D_loss', loss_d) 262 | tf.summary.scalar('D_loss_real', loss_d_real) 263 | tf.summary.scalar('D_loss_fake', loss_d_fake) 264 | 265 | for var in vars_g: 266 | tf.summary.histogram(var.name, var) 267 | for var in vars_d: 268 | tf.summary.histogram(var.name, var) 269 | 270 | training_summary = tf.summary.merge_all() 271 | 272 | test_summary_truth = tf.summary.image('test_truth', self.targets, 4) 273 | test_summary_input = tf.summary.image('test_input', self.inputs, 4) 274 | test_summary_output = tf.summary.image('test_output', self.g, 4) 275 | test_summary = tf.summary.merge([test_summary_input, test_summary_output, test_summary_truth]) 276 | 277 | writer = tf.summary.FileWriter(self.log_dir, self.sess.graph) 278 | 279 | # FIXME 280 | self.sess.run(tf.global_variables_initializer()) 281 | if load_ckpt: 282 | self.trained_step = self.load() 283 | 284 | coord = tf.train.Coordinator() 285 | threads = tf.train.start_queue_runners(sess=self.sess, coord=coord) 286 | 287 | print('Training start') 288 | for epoch in range(total_epoch): 289 | for step in range(steps_per_epoch): 290 | image_value = self.sess.run(training_image) 291 | edges = self.detect_edges(image_value) 292 | 293 | feed_dict = {self.inputs: edges, self.targets: image_value, self.is_training: True} 294 | self.sess.run(train_d, feed_dict=feed_dict) 295 | self.sess.run(train_g, feed_dict=feed_dict) 296 | self.sess.run(train_g, feed_dict=feed_dict) 297 | 298 | self.trained_step += 1 299 | if self.trained_step % 100 == 0: 300 | print('step: {}'.format(self.trained_step)) 301 | 302 | training_summary_value = self.sess.run(training_summary, feed_dict=feed_dict) 303 | writer.add_summary(training_summary_value, self.trained_step) 304 | 305 | image_value = self.sess.run(test_image) 306 | edges = self.detect_edges(image_value) 307 | 308 | feed_dict = {self.inputs: edges, self.targets: image_value, self.is_training: False} 309 | test_summary_value = self.sess.run(test_summary, feed_dict=feed_dict) 310 | writer.add_summary(test_summary_value, self.trained_step) 311 | 312 | if self.trained_step % 1000 == 0: 313 | self.save() 314 | 315 | coord.join(threads) 316 | 317 | def test(self, inputs): 318 | output = self.sess.run(self.g, feed_dict={self.inputs: inputs, self.is_training: False}) 319 | return output 320 | 321 | def save(self): 322 | os.makedirs(self.ckpt_dir, exist_ok=True) 323 | saver = tf.train.Saver() 324 | saver.save(self.sess, self.ckpt_dir, global_step=self.trained_step) 325 | 326 | def load(self): 327 | ckpt = tf.train.get_checkpoint_state(self.ckpt_dir) 328 | if ckpt and ckpt.model_checkpoint_path: 329 | saver = tf.train.Saver() 330 | saver.restore(self.sess, ckpt.model_checkpoint_path) 331 | 332 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 333 | trained_step = int(os.path.splitext(ckpt_name)[0][1:]) 334 | 335 | return trained_step 336 | else: 337 | return 0 338 | --------------------------------------------------------------------------------