├── CONTRIBUTING.md
├── LICENSE
├── README.md
└── vae-gan
    ├── config.yaml
    ├── data
    │   ├── build_image_data.py
    │   ├── create_model.sh
    │   ├── create_random_embedding.py
    │   ├── display_image.py
    │   ├── generate_image.sh
    │   └── run_training.sh
    └── trainer
        ├── __init__.py
        ├── model.py
        ├── task.py
        └── util.py

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2017 Google Inc.
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generative Machine Learning on the Cloud
2 | 
3 | This tool uses the [Google Cloud Machine Learning
4 | API](https://cloud.google.com/ml) and [Tensorflow](https://tensorflow.org).
5 | 
6 | Generative Machine Learning on the Cloud is a cloud-based tool to aid in
7 | generative art and synthetic image generation. The end-to-end system design
8 | allows a user to train a Variational Autoencoder Generative Adversarial
9 | Network (VAE-GAN) model on Cloud ML on their own dataset of images. The
10 | trained model is then deployed to the cloud, where the user can submit an
11 | embedding to generate synthetic images from their dataset, or submit an image
12 | to get back its embedding vector.
13 | 
14 | ## How To Use the Tool
15 | 
16 | ### Pre-steps:
17 | 
18 | 1. [Install Tensorflow](https://www.tensorflow.org/install/)
19 |     * We strongly recommend the virtualenv install
20 |     * Verify that numpy is installed
21 | 2. [Set Up the Cloud
22 |    Environment](https://cloud.google.com/ml-engine/docs/quickstarts/command-line)
23 |     * Create a Cloud Platform Project
24 |     * Enable Billing
25 |     * Enable Cloud ML Engine and Compute Engine APIs
26 | 3. 
Clone this repo 27 | * If using a TensorFlow virtualenv, make sure to clone into a subdirectory 28 | of the virtualenv directory 29 | 30 | ### How To: Run a Training Job 31 | 32 | A training job will train the VAE-GAN on your own training data! 33 | 34 | Important: You will be using billable components of the Cloud Platform and will 35 | incur charges when running a training job. 36 | 37 | 1. cd into the data directory of the source code you just cloned. Make sure to 38 | activate the tensorflow virtualenv (if that is the method you chose to 39 | install TensorFlow). 40 | 41 | 2. Run the training script \ 42 | \ 43 | Dataset Tips: 44 | 45 | * Read [how Cloud ML interacts with the 46 | data](https://cloud.google.com/ml-engine/docs/how-tos/working-with-data). 47 | * Accepted image formats: .jpg or .png 48 | * The larger your image set, the less chance of overfitting! 49 | * One rule of thumb is at least ~1000 images per class. 50 | - If you are trying to synthesize faces, try to have at least 1000 51 | face images. 52 | - If you are trying to generate both cat and dog images, try to have 53 | at least 1000 cats and 1000 dogs. 54 | * The model will crop / resize your images to 64x64 squares. 55 | - Use the -c flag to specify centered cropping (or else it will random 56 | crop). 57 | - The image is cropped to a bounding box of side lengths of 58 | minimum(original_height, original_width). 59 | - The image is resized to 64x64 (using 60 | [tf.image.resize_images](https://www.tensorflow.org/api_docs/python/tf/image/resize_images) 61 | to either downsample or upsample using bilinear interpolation). 62 | * This script will turn your image files into [TFRecords file 63 | format](https://www.tensorflow.org/versions/r1.0/api_guides/python/python_io) 64 | with Example protos and saves them to your GCS bucket. It partitions 65 | your data into a training dataset and a validation dataset. 66 | * For efficient throughput, image files should not exceed 4 MB. Reducing 67 | image size can increase throughput. 68 | 69 | Example: 70 | 71 | ```shell 72 | sh run_training.sh -d $PATH_TO_TRAINING_DIR -c 73 | ``` 74 | 75 | **Flags:** \ 76 | \[-d PATH_TO_TRAINING_DIR\] : required, supplies image directory of .jpg or 77 | .png images \ 78 | \[-c\] : optional, if present images will be center-cropped, if absent 79 | images will be randomly cropped. \ 80 | \[-p\] : optional, port on which to start TensorBoard instance. 81 | 82 | 3. Monitor your training job using the TensorBoard you started or the Cloud 83 | dashboard 84 | 85 | * TensorBoard: Starts at http://0.0.0.0:6006 by default, unless port 86 | specified. 87 | * Job Logs: http://console.cloud.google.com -> Big Data -> ML Engine -> 88 | Jobs 89 | 90 | ### How To: Create and Deploy Model 91 | 92 | Now that we have a trained model saved on GCS, lets deploy it on Cloud ML! 93 | 94 | 1. cd into the data directory of the source code. 95 | 2. Run create model script (if you don't know your job name, use the -l flag) \ 96 | Example: 97 | 98 | ```shell 99 | sh create_model.sh -j $JOB_NAME 100 | ``` 101 | 102 | **Flags:** \ 103 | \[-j JOB_NAME\] : required unless -l flag present, supplies job name \ 104 | \[-l\]: optional, if present lists 10 most recent jobs created by user 105 | 106 | 3. Look at your deployed model on the cloud dashboard under Cloud ML Engine! 
107 | 108 | * Model: http://console.cloud.google.com -> Big Data -> ML Engine -> 109 | Models 110 | 111 | ### How To: Run an Inference Job 112 | 113 | Now that we have a deployed model trained with your own data, we can use it to 114 | generate new samples. 115 | 116 | 1. Generate an Image! 117 | 118 | * I've provided a script to randomly generate an image from your model and 119 | display it: 120 | 121 | ```shell 122 | sh generate_image.sh -m $MODEL_NAME 123 | ``` 124 | 125 | **Flags:** \ 126 | \[-m MODEL_NAME\] : required unless -l flag present, specifies model to 127 | generate image. \ 128 | \[-l\] : optional, if present lists all models associated with user. \ 129 | \[-d TEMP_DIR\] : optional, directory to which to write json file. 130 | 131 | * Assumes [PIL is installed](https://pypi.python.org/pypi/Pillow/2.2.1) 132 | 133 | ```shell 134 | $pip install Pillow 135 | ``` 136 | 137 | 2. Embedding to Image generation 138 | 139 | * Use the command line & a json file! 140 | 141 | * Example format: 142 | 143 | ```json 144 | json format -- list truncated to length 4 instead of 100: 145 | {"embeddings": [5,10,-1.6,7.8], "key": "0"} 146 | ``` 147 | 148 | * Embedding array must have dimension of 100 (if using current 149 | vae-gan) or whatever was specified in the code: 150 | 151 | ```python 152 | model.py:32 EMBEDDING_SIZE = 100 153 | ``` 154 | 155 | * Example command: 156 | 157 | ```shell 158 | gcloud ml-engine predict --model $MODEL_NAME --json-instances $JSON_FILE 159 | ``` 160 | 161 | * Batch Prediction Job 162 | 163 | * Example format: 164 | 165 | ```json 166 | json format example -- embedding lists truncated to length 9: 167 | 168 | {"embeddings": [0.1,2.3,-4.6,6.5,0,4.4,-0.9,-0.9,2.2], "key": "0"} 169 | {"embeddings": [0.1,2.3,-4.6,6.5,1,4.4,-0.9,-0.9,2.2], "key": "1"} 170 | {"embeddings": [0.1,2.3,-4.6,6.5,2,4.4,-0.9,-0.9,2.2], "key": "2"} 171 | {"embeddings": [0.1,2.3,-4.6,6.5,3,4.4,-0.9,-0.9,2.2], "key": "3"} 172 | {"embeddings": [0.1,2.3,-4.6,6.5,4,4.4,-0.9,-0.9,2.2], "key": "4"} 173 | {"embeddings": [0.1,2.3,-4.6,6.5,5,4.4,-0.9,-0.9,2.2], "key": "5"} 174 | ``` 175 | 176 | * Json file must be on GCS 177 | 178 | * Example command: 179 | 180 | ```shell 181 | gcloud ml-engine jobs submit prediction $JOB_NAME --model 182 | $MODEL_NAME --input-paths "gs://BUCKET/request.json" --output-path 183 | "gs://BUCKET/output" --region us-east1 --data-format "TEXT" 184 | ``` 185 | 186 | * Use python API 187 | 188 | * Documentation 189 | [here](https://cloud.google.com/ml-engine/docs/tutorials/python-guide) 190 | 191 | * Setup project and execute request 192 | 193 | ```python 194 | credentials = GoogleCredentials.get_application_default() 195 | ml = discovery.build('ml', 'v1', credentials=credentials) 196 | request_dict = {'instances': [{'embeddings': embeds.tolist(), 'key': '0'}]} 197 | request = ml.projects().predict(name=model_name, body=request_dict) 198 | response_image = request.execute() 199 | ``` 200 | 201 | 3. Image to Embedding generation 202 | 203 | * Use the command line & a json file! 
204 | 205 | * Image has to be base64 encoded jpeg 206 | * Example format: 207 | 208 | ```json 209 | json format: 210 | {"image_bytes": {"b64":"/9j/4AAQSkZJAAQABX...zW0=="}, "key": "0"} 211 | ``` 212 | 213 | * Batch Prediction 214 | 215 | * Same as for embedding to image, but with image format json 216 | 217 | * Python API 218 | 219 | * Same as for embedding to image, but request_dict: 220 | 221 | ```python 222 | request_dict = {'instances': [{'image_bytes': {'b64': img}, 'key': '0'}]} 223 | ``` 224 | 225 | Where img is a base64 encoded jpeg 226 | 227 | ## Acknowledgements 228 | 229 | Huge shoutout to this awesome 230 | [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow). After much trial error, 231 | the architecture from this network was the one that produced the greatest 232 | generative results and ended up as the network architecture in the final version 233 | of this tool. 234 | 235 | ## Disclaimer 236 | 237 | This is not an official Google product. 238 | -------------------------------------------------------------------------------- /vae-gan/config.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | masterType: complex_model_m_gpu 4 | workerType: standard 5 | parameterServerType: standard 6 | workerCount: 1 7 | parameterServerCount: 1 8 | -------------------------------------------------------------------------------- /vae-gan/data/build_image_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converts image data to TFRecords file format with Example protos. 16 | 17 | The image data set is expected to reside in JPEG files located in the 18 | following directory structure. 19 | 20 | data_directory/image0.jpeg 21 | data_directory/image1.jpg 22 | ... 23 | data_directory/weird-image.jpeg 24 | data_directory/my-image.jpeg 25 | ... 26 | 27 | where the data_directory contains all images in the dataset. 28 | 29 | This TensorFlow script converts the data into a sharded training dataset and a 30 | sharded validation dataset consisting of TFRecord files 31 | 32 | output_directory/train-00000-of-01024 33 | output_directory/train-00001-of-01024 34 | ... 35 | output_directory/train-001023-of-01024 36 | 37 | and 38 | 39 | output_directory/validation-00000-of-00128 40 | output_directory/validation-00001-of-00128 41 | ... 42 | output_directory/validation-00127-of-00128 43 | 44 | where we have selected 1024 and 128 shards for each data set. Each record 45 | within the TFRecord file is a serialized Example proto. 
The Example proto 46 | contains the following fields: 47 | 48 | image/encoded: string containing JPEG encoded image in RGB colorspace 49 | image/height: integer, image height in pixels 50 | image/width: integer, image width in pixels 51 | image/colorspace: string, specifying the colorspace, always 'RGB' 52 | image/channels: integer, specifying the number of channels, always 3 53 | image/format: string, specifying the format, always'JPEG' 54 | 55 | image/filename: string containing the basename of the image file 56 | e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG' 57 | """ 58 | from __future__ import absolute_import 59 | from __future__ import division 60 | from __future__ import print_function 61 | 62 | from datetime import datetime 63 | import logging 64 | import os 65 | import random 66 | import sys 67 | import threading 68 | 69 | import numpy as np 70 | import tensorflow as tf 71 | 72 | tf.app.flags.DEFINE_string('data_directory', None, 'Data directory') 73 | tf.app.flags.DEFINE_string('output_directory', '', 'Output data directory') 74 | tf.app.flags.DEFINE_integer('num_shards', 2, 75 | 'Number of shards in TFRecord files.') 76 | tf.app.flags.DEFINE_integer('num_threads', 2, 77 | 'Number of threads to preprocess the images.') 78 | 79 | FLAGS = tf.app.flags.FLAGS 80 | 81 | VALIDATION_SIZE = 0.1 82 | 83 | 84 | def _int64_feature(value): 85 | """Wrapper for inserting int64 features into Example proto.""" 86 | if not isinstance(value, list): 87 | value = [value] 88 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 89 | 90 | 91 | def _bytes_feature(value): 92 | """Wrapper for inserting bytes features into Example proto.""" 93 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 94 | 95 | 96 | def _convert_to_example(filename, image_buffer, height, width): 97 | """Build an Example proto for an example. 98 | 99 | Args: 100 | filename: string, path to an image file, e.g., '/path/to/example.JPG' 101 | image_buffer: string, JPEG encoding of RGB image 102 | height: integer, image height in pixels 103 | width: integer, image width in pixels 104 | Returns: 105 | Example proto 106 | """ 107 | 108 | colorspace = 'RGB' 109 | channels = 3 110 | image_format = 'JPEG' 111 | 112 | example = tf.train.Example(features=tf.train.Features(feature={ 113 | 'image/height': 114 | _int64_feature(height), 115 | 'image/width': 116 | _int64_feature(width), 117 | 'image/colorspace': 118 | _bytes_feature(tf.compat.as_bytes(colorspace)), 119 | 'image/channels': 120 | _int64_feature(channels), 121 | 'image/format': 122 | _bytes_feature(tf.compat.as_bytes(image_format)), 123 | 'image/filename': 124 | _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))), 125 | 'image/encoded': 126 | _bytes_feature(tf.compat.as_bytes(image_buffer)) 127 | })) 128 | return example 129 | 130 | 131 | class ImageCoder(object): 132 | """Helper class that provides TensorFlow image coding utilities.""" 133 | 134 | def __init__(self): 135 | # Create a single Session to run all image coding calls. 136 | self._sess = tf.Session() 137 | 138 | # Initializes function that converts PNG to JPEG data. 139 | self._png_data = tf.placeholder(dtype=tf.string) 140 | image = tf.image.decode_png(self._png_data, channels=3) 141 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) 142 | 143 | # Initializes function that decodes RGB JPEG data. 
144 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string) 145 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 146 | 147 | def png_to_jpeg(self, image_data): 148 | return self._sess.run( 149 | self._png_to_jpeg, feed_dict={self._png_data: image_data}) 150 | 151 | def decode_jpeg(self, image_data): 152 | image = self._sess.run( 153 | self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data}) 154 | assert len(image.shape) == 3 155 | assert image.shape[2] == 3 156 | return image 157 | 158 | 159 | def _is_png(filename): 160 | """Determine if a file contains a PNG format image. 161 | 162 | Args: 163 | filename: string, path of the image file. 164 | 165 | Returns: 166 | boolean indicating if the image is a PNG. 167 | """ 168 | return '.png' in filename 169 | 170 | 171 | def _process_image(filename, coder): 172 | """Process a single image file. 173 | 174 | Args: 175 | filename: string, path to an image file e.g., '/path/to/example.JPG'. 176 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 177 | Returns: 178 | image_buffer: string, JPEG encoding of RGB image. 179 | height: integer, image height in pixels. 180 | width: integer, image width in pixels. 181 | """ 182 | # Read the image file. 183 | image_data = tf.gfile.FastGFile(filename, 'r').read() 184 | 185 | # Convert any PNG to JPEG's for consistency. 186 | if _is_png(filename): 187 | print('Converting PNG to JPEG for %s' % filename) 188 | image_data = coder.png_to_jpeg(image_data) 189 | 190 | # Decode the RGB JPEG. 191 | image = coder.decode_jpeg(image_data) 192 | 193 | # Check that image converted to RGB 194 | assert len(image.shape) == 3 195 | height = image.shape[0] 196 | width = image.shape[1] 197 | assert image.shape[2] == 3 198 | 199 | return image_data, height, width 200 | 201 | 202 | def _process_image_files_batch(coder, thread_index, ranges, name, filenames, 203 | num_shards): 204 | """Processes and saves list of images as TFRecord in 1 thread. 205 | 206 | Args: 207 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 208 | thread_index: integer, unique batch to run index is within [0, len(ranges)). 209 | ranges: list of pairs of integers specifying ranges of each batches to 210 | analyze in parallel. 211 | name: string, unique identifier specifying the data set 212 | filenames: list of strings; each string is a path to an image file 213 | num_shards: integer number of shards for this data set. 214 | """ 215 | # Each thread produces N shards where N = int(num_shards / num_threads). 216 | # For instance, if num_shards = 128, and the num_threads = 2, then the first 217 | # thread would produce shards [0, 64). 218 | num_threads = len(ranges) 219 | assert not num_shards % num_threads 220 | num_shards_per_batch = int(num_shards / num_threads) 221 | 222 | shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1], 223 | num_shards_per_batch + 1).astype(int) 224 | num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0] 225 | 226 | counter = 0 227 | for s in xrange(num_shards_per_batch): 228 | # Generate a sharded version of the file name, e.g. 
'train-00002-of-00010' 229 | shard = thread_index * num_shards_per_batch + s 230 | output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards) 231 | output_file = os.path.join(FLAGS.output_directory, output_filename) 232 | writer = tf.python_io.TFRecordWriter(output_file) 233 | 234 | shard_counter = 0 235 | files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int) 236 | for i in files_in_shard: 237 | filename = filenames[i] 238 | 239 | image_buffer, height, width = _process_image(filename, coder) 240 | 241 | example = _convert_to_example(filename, image_buffer, height, width) 242 | writer.write(example.SerializeToString()) 243 | shard_counter += 1 244 | counter += 1 245 | 246 | if not counter % 1000: 247 | logging.info( 248 | '%s [thread %d]: Processed %d of %d images in thread batch.', 249 | datetime.now(), thread_index, counter, num_files_in_thread) 250 | sys.stdout.flush() 251 | 252 | writer.close() 253 | logging.info('%s [thread %d]: Wrote %d images to %s', 254 | datetime.now(), thread_index, shard_counter, output_file) 255 | sys.stdout.flush() 256 | shard_counter = 0 257 | logging.info('%s [thread %d]: Wrote %d images to %d shards.', 258 | datetime.now(), thread_index, counter, num_files_in_thread) 259 | sys.stdout.flush() 260 | 261 | 262 | def _process_image_files(name, filenames, num_shards): 263 | """Process and save list of images as TFRecord of Example protos. 264 | 265 | Args: 266 | name: string, unique identifier specifying the data set 267 | filenames: list of strings; each string is a path to an image file 268 | num_shards: integer number of shards for this data set. 269 | """ 270 | 271 | # Break all images into batches with a [ranges[i][0], ranges[i][1]]. 272 | spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int) 273 | ranges = [] 274 | threads = [] 275 | for i in xrange(len(spacing) - 1): 276 | ranges.append([spacing[i], spacing[i + 1]]) 277 | 278 | # Launch a thread for each batch. 279 | logging.info('Launching %d threads for spacings: %s', FLAGS.num_threads, 280 | ranges) 281 | sys.stdout.flush() 282 | 283 | # Create a mechanism for monitoring when all threads are finished. 284 | coord = tf.train.Coordinator() 285 | 286 | # Create a generic TensorFlow-based utility for converting all image codings. 287 | coder = ImageCoder() 288 | 289 | threads = [] 290 | for thread_index in xrange(len(ranges)): 291 | args = (coder, thread_index, ranges, name, filenames, num_shards) 292 | t = threading.Thread(target=_process_image_files_batch, args=args) 293 | t.start() 294 | threads.append(t) 295 | 296 | # Wait for all the threads to terminate. 297 | coord.join(threads) 298 | logging.info('%s: Finished writing all %d images in data set.', 299 | datetime.now(), len(filenames)) 300 | sys.stdout.flush() 301 | 302 | 303 | def _find_image_files(data_dir): 304 | """Build a list of all images files and labels in the data set. 305 | 306 | Args: 307 | data_dir: string, path to the root directory of images. 308 | 309 | Assumes that the image data set resides in JPEG files located in 310 | the following directory structure. 311 | 312 | data_dir/image.jpg 313 | 314 | 315 | Returns: 316 | filenames: list of strings; each string is a path to an image file. 
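      Only files whose extension is .jpeg, .jpg, or .png (upper- or lower-case)
      and that sit directly inside data_dir are matched; subdirectories are not
      searched.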
317 | """ 318 | filenames = [] 319 | 320 | file_extensions = ['.jpeg', '.jpg', '.png'] 321 | file_extensions += [ext.upper() for ext in file_extensions] 322 | 323 | for ext in file_extensions: 324 | file_path = '%s/*%s' % (data_dir, ext) 325 | matching_files = tf.gfile.Glob(file_path) 326 | filenames.extend(matching_files) 327 | 328 | # Shuffle the ordering of all image files in order to guarantee 329 | # random ordering of the images saved in the TFRecord files. 330 | # Make the randomization repeatable. 331 | shuffled_index = range(len(filenames)) 332 | random.seed(12345) 333 | random.shuffle(shuffled_index) 334 | 335 | filenames = [filenames[i] for i in shuffled_index] 336 | 337 | logging.info('Found %d JPEG files inside %s.', len(filenames), data_dir) 338 | return filenames 339 | 340 | 341 | def _process_datasets(directory, num_shards): 342 | """Process a complete data set and save it as a TFRecord. 343 | 344 | Args: 345 | directory: string, root path to the data set. 346 | num_shards: integer number of shards for this data set. 347 | """ 348 | filenames = _find_image_files(directory) 349 | 350 | num_train_files = int(len(filenames) * (1 - VALIDATION_SIZE)) 351 | num_validation_files = len(filenames) - num_train_files 352 | 353 | train_filenames = filenames[:num_train_files] 354 | validation_filenames = filenames[-num_validation_files:] 355 | 356 | _process_image_files('train', train_filenames, num_shards) 357 | _process_image_files('validation', validation_filenames, num_shards) 358 | 359 | 360 | def main(unused_argv): 361 | assert not FLAGS.num_shards % FLAGS.num_threads, ( 362 | 'Please make the FLAGS.num_threads commensurate with FLAGS.num_shards') 363 | logging.info('Saving results to %s', FLAGS.output_directory) 364 | _process_datasets(FLAGS.data_directory, FLAGS.num_shards) 365 | 366 | 367 | if __name__ == '__main__': 368 | logging.basicConfig(level=logging.INFO) 369 | tf.app.run() 370 | -------------------------------------------------------------------------------- /vae-gan/data/create_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Creates and deploys model to Cloud ML engine from training job. 17 | # 18 | # Assumes user has Cloud Platform project setup and cloud sdk installed. 19 | # Flags: 20 | # 21 | # Required if -l not set: 22 | # [-j JOB_NAME] : specifies training job from which to create / deploy model. 23 | # Optional: 24 | # [-l ] : if set, lists 10 most recent jobs run by user. 
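#
# Example usage (the job name below is a placeholder; run_training.sh creates
# jobs named generative_<user>_<timestamp>):
#   sh create_model.sh -j generative_alice_20170701_120000
#   sh create_model.sh -l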
25 | 26 | JOB_NAME='' 27 | LIST_JOBS=false 28 | 29 | while getopts 'j:l' flag; do 30 | case "${flag}" in 31 | j) JOB_NAME="${OPTARG}" 32 | ;; 33 | l) LIST_JOBS=true 34 | esac 35 | done 36 | 37 | if [[ -z "${JOB_NAME}" && "${LIST_JOBS}" == false ]]; then 38 | echo "Error: -j flag required" 39 | echo "Usage: [-j JOB_NAME]" 40 | echo "Specifies job name to generate model from" 41 | exit 1 42 | fi 43 | 44 | if [[ "${LIST_JOBS}" == true && -z "${JOB_NAME}" ]]; then 45 | echo "Most recent jobs" 46 | gcloud ml-engine jobs list --limit=10 47 | echo "Exiting...." 48 | exit 1 49 | fi 50 | 51 | readonly PROJECT=$(gcloud config list project --format "value(core.project)") 52 | readonly JOB_ID="${JOB_NAME}" 53 | readonly BUCKET="gs://${PROJECT}" 54 | readonly GCS_PATH="${BUCKET}/${USER}/${JOB_ID}" 55 | 56 | readonly EMBED_MODEL_NAME="${JOB_ID}_embed_to_image" 57 | readonly IMAGE_MODEL_NAME="${JOB_ID}_image_to_embed" 58 | readonly VERSION_NAME=v1 59 | 60 | OUTPUT=$(gcloud ml-engine jobs describe "${JOB_ID}") 61 | 62 | if [[ $(echo "${OUTPUT}" | grep -i 'state: running') ]]; then 63 | echo 64 | echo "Training task is running." 65 | echo "Please wait for task to succeeed before creating model." 66 | echo "Exiting..." 67 | exit 1 68 | elif [[ $(echo "${OUTPUT}" | grep -i 'state: failed') ]]; then 69 | echo 70 | echo "Training task failed. Please rerun training job." 71 | echo "Exiting..." 72 | exit 1 73 | elif [[ $(echo "${OUTPUT}" | grep -i 'state: cancelled') ]]; then 74 | echo 75 | echo "Training task cancelled. Please rerun training job." 76 | echo "Exiting..." 77 | exit 1 78 | elif [[ $(echo "${OUTPUT}" | grep -i 'state: succeeded') ]]; then 79 | echo 80 | echo "Training task succeeded." 81 | echo "Creating embedding to image model...." 82 | gcloud ml-engine models create "${EMBED_MODEL_NAME}" \ 83 | --regions us-central1 84 | 85 | echo "Deploying embedding to image model...." 86 | 87 | gcloud ml-engine versions create "${VERSION_NAME}" \ 88 | --model "${EMBED_MODEL_NAME}" \ 89 | --origin "${BUCKET}/${JOB_ID}/output/model/saved_model_embed_in" \ 90 | --runtime-version=1.0 91 | 92 | echo 93 | echo "Training task succeeded." 94 | echo "Creating image to embedding model...." 95 | gcloud ml-engine models create "${IMAGE_MODEL_NAME}" \ 96 | --regions us-central1 97 | 98 | echo "Deploying image to embedding model...." 99 | 100 | gcloud ml-engine versions create "${VERSION_NAME}" \ 101 | --model "${IMAGE_MODEL_NAME}" \ 102 | --origin "${BUCKET}/${JOB_ID}/output/model/saved_model_image_in" \ 103 | --runtime-version=1.0 104 | else 105 | echo 106 | echo "Task in unknown state. Please check cloud console." 107 | echo "Use -l to list 10 most recent jobs" 108 | echo "Exiting..." 109 | exit 1 110 | fi 111 | 112 | -------------------------------------------------------------------------------- /vae-gan/data/create_random_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Creates random embedding."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import json
22 | import numpy as np
23 | 
24 | embeds = (5 * np.squeeze(np.random.randn(1, 100))).tolist()
25 | json_object = {'key': '0', 'embeddings': embeds}
26 | print(json.dumps(json_object))
27 | 
--------------------------------------------------------------------------------
/vae-gan/data/display_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Displays randomly generated image."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import argparse
22 | import cStringIO
23 | from PIL import Image
24 | 
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--base64_image', type=str, default='')
27 | args, _ = parser.parse_known_args()
28 | 
29 | im = args.base64_image.replace('_', '/').replace('-', '+')
30 | 
31 | missing_base64_padding = len(im) % 4
32 | if missing_base64_padding != 0:
33 |   im += ('=' * (4 - missing_base64_padding))
34 | 
35 | img = Image.open(cStringIO.StringIO(im.decode('base64')))
36 | img.show()
37 | 
--------------------------------------------------------------------------------
/vae-gan/data/generate_image.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2017 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # Generates and displays a random image.
17 | #
18 | # Assumes user has Cloud Platform project setup and cloud sdk installed.
19 | # Flags:
20 | #
21 | # Required if -l not set:
22 | # [-m MODEL_NAME] : specifies model to generate image.
23 | # Optional:
24 | # [-l ] : if set, lists all models associated with user.
25 | # [-d TEMP_DIR] : directory to which to write json file.
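#
# Example usage (the model name below is a placeholder; create_model.sh deploys
# the embedding-to-image model as <job_name>_embed_to_image):
#   sh generate_image.sh -m generative_alice_20170701_120000_embed_to_image
#   sh generate_image.sh -l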
26 | 
27 | MODEL_NAME=''
28 | LIST_MODELS=false
29 | TMP_DIR='/tmp'
30 | 
31 | while getopts 'm:ld:' flag; do
32 |   case "${flag}" in
33 |     m) MODEL_NAME="${OPTARG}"
34 |        ;;
35 |     l) LIST_MODELS=true
36 |        ;;
37 |     d) TMP_DIR="${OPTARG%/}"
38 |   esac
39 | done
40 | 
41 | TMP_FILE="${TMP_DIR}/temp.json"
42 | 
43 | if [[ -z "${MODEL_NAME}" && "${LIST_MODELS}" == false ]]; then
44 |   echo "Error: -m flag required"
45 |   echo "Usage: [-m MODEL_NAME]"
46 |   echo "Specifies the model name used to generate an image"
47 |   echo "If model name unknown, use -l flag to list model names"
48 |   exit 1
49 | fi
50 | 
51 | if [[ "${LIST_MODELS}" == true && -z "${MODEL_NAME}" ]]; then
52 |   echo "Your models"
53 |   gcloud ml-engine models list
54 |   echo "Exiting...."
55 |   exit 1
56 | fi
57 | 
58 | JSON_OBJ=$(python create_random_embedding.py)
59 | 
60 | echo -e "${JSON_OBJ}" > "${TMP_FILE}"
61 | 
62 | OUTPUT=$(gcloud ml-engine predict --model "${MODEL_NAME}" --json-instances "${TMP_FILE}")
63 | IMAGE=$(echo "${OUTPUT}" | awk 'NR==2 {print $2}')
64 | python display_image.py --base64_image "${IMAGE}"
65 | 
--------------------------------------------------------------------------------
/vae-gan/data/run_training.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2017 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | #
16 | # Runs preprocessing of user images and starts a training job for the
17 | # generative model.
18 | #
19 | # Assumes user has Cloud Platform project setup and cloud sdk installed.
20 | # Flags:
21 | #
22 | # Required:
23 | # [-d DATA_DIRECTORY] : specifies directory containing images (jpg / png)
24 | # Optional:
25 | # [-c ] : if set, center-crops images. If not set, randomly crops images.
26 | # [-p PORT] : port to start Tensorboard monitoring.
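#
# Example usage (the image directory below is a placeholder):
#   sh run_training.sh -d /path/to/training/images -c -p 6006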
27 | 28 | ORIGINAL_DATA_DIRECTORY='' 29 | DATA_DIR_PRESENT=false 30 | CENTER_CROP='False' 31 | PORT=6006 32 | 33 | while getopts 'd::cp:' flag; do 34 | case "${flag}" in 35 | d) ORIGINAL_DATA_DIRECTORY="${OPTARG%/}" 36 | ;; 37 | c) CENTER_CROP='True' 38 | ;; 39 | p) PORT="${OPTARG}" 40 | esac 41 | done 42 | 43 | readonly ORIGINAL_DATA_DIRECTORY 44 | readonly CENTER_CROP 45 | readonly PORT 46 | 47 | if [[ -z "${ORIGINAL_DATA_DIRECTORY}" ]]; then 48 | echo "Error: -d flag required" 49 | echo "Usage: [-d DATA_DIRECTORY]" 50 | echo "Specifies directory containing image files" 51 | exit 1 52 | fi 53 | 54 | readonly PROJECT=$(gcloud config list project --format "value(core.project)") 55 | readonly JOB_ID="generative_${USER}_$(date +%Y%m%d_%H%M%S)" 56 | readonly BUCKET="gs://${PROJECT}" 57 | readonly GCS_PATH="${BUCKET}/${USER}/${JOB_ID}" 58 | 59 | echo 60 | echo "Using job id: ${JOB_ID}" 61 | set -e 62 | 63 | python build_image_data.py \ 64 | --data_directory "${ORIGINAL_DATA_DIRECTORY}" \ 65 | --output_directory "${BUCKET}/${JOB_ID}/test_output" 66 | 67 | gcloud ml-engine jobs submit training "${JOB_ID}" \ 68 | --stream-logs \ 69 | --module-name trainer.task \ 70 | --package-path "../trainer" \ 71 | --staging-bucket "${BUCKET}" \ 72 | --region us-east1 \ 73 | --config "../config.yaml" \ 74 | -- \ 75 | --batch_size 64 \ 76 | --data_dir "${BUCKET}/${JOB_ID}/test_output" \ 77 | --output_path "${BUCKET}/${JOB_ID}/output" \ 78 | --center_crop "${CENTER_CROP}" & 79 | 80 | tensorboard \ 81 | --logdir "${BUCKET}/${JOB_ID}/output" \ 82 | --port="${PORT}" 83 | 84 | wait && echo "Finished training" 85 | -------------------------------------------------------------------------------- /vae-gan/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /vae-gan/trainer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Rgb VAE - GAN implementation for CloudML.""" 15 | 16 | import argparse 17 | import os 18 | 19 | import tensorflow as tf 20 | 21 | from tensorflow.python.saved_model import builder 22 | from tensorflow.python.saved_model import signature_constants 23 | from tensorflow.python.saved_model import signature_def_utils 24 | from tensorflow.python.saved_model import tag_constants 25 | from tensorflow.python.saved_model import utils as saved_model_utils 26 | 27 | import util 28 | from util import override_if_not_in_args 29 | 30 | # Global constants for Rgb dataset 31 | EMBEDDING_DIMENSION = 100 32 | LAYER_DIM = 64 33 | TRAIN, EVAL = 'TRAIN', 'EVAL' 34 | PREDICT_EMBED_IN, PREDICT_IMAGE_IN = 'PREDICT_EMBED_IN', 'PREDICT_IMAGE_IN' 35 | 36 | 37 | def build_signature(inputs, outputs): 38 | """Build the signature for use when exporting the graph. 
39 | 40 | Args: 41 | inputs: a dictionary from tensor name to tensor 42 | outputs: a dictionary from tensor name to tensor 43 | Returns: 44 | The signature, a SignatureDef proto, specifies the input/output tensors 45 | to bind when running prediction. 46 | """ 47 | signature_inputs = { 48 | key: saved_model_utils.build_tensor_info(tensor) 49 | for key, tensor in inputs.items() 50 | } 51 | signature_outputs = { 52 | key: saved_model_utils.build_tensor_info(tensor) 53 | for key, tensor in outputs.items() 54 | } 55 | 56 | signature_def = signature_def_utils.build_signature_def( 57 | signature_inputs, signature_outputs, 58 | signature_constants.PREDICT_METHOD_NAME) 59 | 60 | return signature_def 61 | 62 | 63 | def create_model(): 64 | """Factory method that creates model to be used by generic task.py.""" 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--learning_rate', type=float, default=0.0002) 67 | parser.add_argument('--dropout', type=float, default=0.5) 68 | parser.add_argument('--beta1', type=float, default=0.5) 69 | parser.add_argument('--resized_image_size', type=int, default=64) 70 | parser.add_argument('--crop_image_dimension', type=int, default=None) 71 | parser.add_argument( 72 | '--center_crop', dest='center_crop', default=False, action='store_true') 73 | args, task_args = parser.parse_known_args() 74 | override_if_not_in_args('--max_steps', '80000', task_args) 75 | override_if_not_in_args('--batch_size', '64', task_args) 76 | override_if_not_in_args('--eval_set_size', '370', task_args) 77 | override_if_not_in_args('--eval_interval_secs', '2', task_args) 78 | override_if_not_in_args('--log_interval_secs', '2', task_args) 79 | override_if_not_in_args('--min_train_eval_rate', '2', task_args) 80 | 81 | return Model(args.learning_rate, args.dropout, args.beta1, 82 | args.resized_image_size, args.crop_image_dimension, 83 | args.center_crop), task_args 84 | 85 | 86 | class GraphReferences(object): 87 | """Holder of base tensors used for training model using common task.""" 88 | 89 | def __init__(self): 90 | self.image = None 91 | self.label = None 92 | self.global_step = None 93 | self.keys = None 94 | self.predictions = None 95 | self.embeddings = [] 96 | self.cost_encoder = None 97 | self.cost_generator = None 98 | self.cost_discriminator = None 99 | self.cost_balance = None 100 | self.prediction_image = None 101 | self.dis_real = None 102 | self.dis_fake = None 103 | self.encoder_optimizer = None 104 | self.generator_optimizer = None 105 | self.discriminator_optimizer = None 106 | 107 | 108 | class Model(object): 109 | """Tensorflow model for Rgb VAE-GAN.""" 110 | 111 | def __init__(self, learning_rate, dropout, beta1, resized_image_size, 112 | crop_image_dimension, center_crop): 113 | """Initializes VAE-GAN. DCGAN architecture: https://arxiv.org/abs/1511.06434 114 | 115 | Args: 116 | learning_rate: The learning rate for the three networks. 117 | dropout: The dropout rate for training the network. 118 | beta1: Exponential decay rate for the 1st moment estimates. 119 | resized_image_size: Desired size of resized image. 120 | crop_image_dimension: Square size of the bounding box. 121 | center_crop: True iff images should be center cropped. 
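    (Per the README, when crop_image_dimension is None the crop box side
    defaults to min(original image height, original image width); the cropping
    itself is done in util.read_and_decode.)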
122 | """ 123 | self.learning_rate = learning_rate 124 | self.dropout = dropout 125 | self.beta1 = beta1 126 | self.resized_image_size = resized_image_size 127 | self.crop_image_dimension = crop_image_dimension 128 | self.center_crop = center_crop 129 | self.has_exported_embed_in = False 130 | self.has_exported_image_in = False 131 | self.batch_size = 0 132 | 133 | def leaky_relu(self, x, name, leak=0.2): 134 | """Leaky relu activation function. 135 | 136 | Args: 137 | x: input into layer. 138 | name: name scope of layer. 139 | leak: slope that provides non-zero y when x < 0. 140 | 141 | Returns: 142 | The leaky relu activation. 143 | """ 144 | return tf.maximum(x, leak * x, name=name) 145 | 146 | def build_graph(self, data_dir, batch_size, mode): 147 | """Builds the VAE-GAN network. 148 | 149 | Args: 150 | data_dir: Locations of input data. 151 | batch_size: Batch size of input data. 152 | mode: Mode of the graph (TRAINING, EVAL, or PREDICT) 153 | 154 | Returns: 155 | The tensors used in training the model. 156 | """ 157 | tensors = GraphReferences() 158 | assert batch_size > 0 159 | self.batch_size = batch_size 160 | if mode is PREDICT_EMBED_IN: 161 | # Input embeddings to send through decoder/generator network. 162 | tensors.embeddings = tf.placeholder( 163 | tf.float32, shape=(None, EMBEDDING_DIMENSION), name='input') 164 | elif mode is PREDICT_IMAGE_IN: 165 | tensors.prediction_image = tf.placeholder( 166 | tf.string, shape=(None,), name='input') 167 | tensors.image = tf.map_fn( 168 | self.process_image, tensors.prediction_image, dtype=tf.float32) 169 | 170 | if mode in (TRAIN, EVAL): 171 | mode_string = 'train' 172 | if mode is EVAL: 173 | mode_string = 'validation' 174 | 175 | tensors.image = util.read_and_decode( 176 | data_dir, batch_size, mode_string, self.resized_image_size, 177 | self.crop_image_dimension, self.center_crop) 178 | 179 | tensors.image = tf.reshape(tensors.image, [ 180 | -1, self.resized_image_size, self.resized_image_size, 3 181 | ]) 182 | 183 | tf.summary.image('original_images', tensors.image, 1) 184 | 185 | tensors.embeddings, y_mean, y_stddev = self.encode(tensors.image) 186 | 187 | if mode is PREDICT_IMAGE_IN: 188 | tensors.image = tf.reshape(tensors.image, [ 189 | -1, self.resized_image_size, self.resized_image_size, 3 190 | ]) 191 | tensors.embeddings, y_mean, _ = self.encode(tensors.image, False) 192 | tensors.predictions = tensors.embeddings 193 | return tensors 194 | 195 | decoded_images = self.decode(tensors.embeddings) 196 | 197 | if mode is TRAIN: 198 | tf.summary.image('decoded_images', decoded_images, 1) 199 | 200 | if mode is PREDICT_EMBED_IN: 201 | decoded_images = self.decode(tensors.embeddings, False, True) 202 | output_images = (decoded_images + 1.0) / 2.0 203 | output_img = tf.image.convert_image_dtype( 204 | output_images, dtype=tf.uint8, saturate=True)[0] 205 | output_data = tf.image.encode_png(output_img) 206 | output = tf.encode_base64(output_data) 207 | 208 | tensors.predictions = output 209 | 210 | return tensors 211 | 212 | tensors.dis_fake = self.discriminate(decoded_images, self.dropout) 213 | tensors.dis_real = self.discriminate( 214 | tensors.image, self.dropout, reuse=True) 215 | 216 | tensors.cost_encoder = self.loss_encoder(tensors.image, decoded_images, 217 | y_mean, y_stddev) 218 | tensors.cost_generator = self.loss_generator(tensors.dis_fake) 219 | tensors.cost_discriminator = self.loss_discriminator( 220 | tensors.dis_real, tensors.dis_fake) 221 | 222 | if mode in (TRAIN, EVAL): 223 | tf.summary.scalar('cost_encoder', 
tensors.cost_encoder) 224 | tf.summary.scalar('cost_generator', tensors.cost_generator) 225 | tf.summary.scalar('cost_discriminator', tensors.cost_discriminator) 226 | tf.summary.tensor_summary('disc_fake', tensors.dis_fake) 227 | tf.summary.tensor_summary('disc_real', tensors.dis_real) 228 | tf.summary.scalar('mean_disc_fake', tf.reduce_mean(tensors.dis_fake)) 229 | tf.summary.scalar('mean_disc_real', tf.reduce_mean(tensors.dis_real)) 230 | 231 | # Cost of Decoder/Generator is VAE network cost and cost of generator 232 | # being detected by the discriminator. 233 | enc_weight = 1 234 | gen_weight = 1 235 | tensors.cost_balance = ( 236 | enc_weight * tensors.cost_encoder + gen_weight * tensors.cost_generator) 237 | 238 | tensors.global_step = tf.Variable(0, name='global_step', trainable=False) 239 | t_vars = tf.trainable_variables() 240 | 241 | with tf.variable_scope(tf.get_variable_scope(), reuse=None): 242 | encoder_vars = [var for var in t_vars if var.name.startswith('enc_')] 243 | generator_vars = [var for var in t_vars if var.name.startswith('gen_')] 244 | discriminator_vars = [ 245 | var for var in t_vars if var.name.startswith('disc_') 246 | ] 247 | vae_vars = encoder_vars + generator_vars 248 | 249 | # Create optimizers for each network. 250 | tensors.encoder_optimizer = tf.train.AdamOptimizer( 251 | learning_rate=self.learning_rate, beta1=self.beta1).minimize( 252 | tensors.cost_encoder, 253 | var_list=vae_vars, 254 | global_step=tensors.global_step) 255 | tensors.generator_optimizer = tf.train.AdamOptimizer( 256 | learning_rate=self.learning_rate, beta1=self.beta1).minimize( 257 | tensors.cost_balance, 258 | var_list=vae_vars, 259 | global_step=tensors.global_step) 260 | tensors.discriminator_optimizer = tf.train.AdamOptimizer( 261 | learning_rate=self.learning_rate, beta1=self.beta1).minimize( 262 | tensors.cost_discriminator, 263 | var_list=discriminator_vars, 264 | global_step=tensors.global_step) 265 | 266 | return tensors 267 | 268 | def build_train_graph(self, data_paths, batch_size): 269 | """Builds the training VAE-GAN graph. 270 | 271 | Args: 272 | data_paths: Locations of input data. 273 | batch_size: Batch size of input data. 274 | 275 | Returns: 276 | The tensors used in training the model. 277 | """ 278 | return self.build_graph(data_paths, batch_size, mode=TRAIN) 279 | 280 | def build_eval_graph(self, data_paths, batch_size): 281 | """Builds the evaluation VAE-GAN graph. 282 | 283 | Args: 284 | data_paths: Locations of input data. 285 | batch_size: Batch size of input data. 286 | 287 | Returns: 288 | The tensors used in training the model. 289 | """ 290 | return self.build_graph(data_paths, batch_size, mode=EVAL) 291 | 292 | def build_prediction_embedding_graph(self): 293 | """Builds the prediction VAE-GAN graph for embedding input. 294 | 295 | Returns: 296 | The inputs and outputs of the prediction. 297 | """ 298 | tensors = self.build_graph(None, 1, PREDICT_EMBED_IN) 299 | 300 | keys_p = tf.placeholder(tf.string, shape=[None]) 301 | inputs = {'key': keys_p, 'embeddings': tensors.embeddings} 302 | keys = tf.identity(keys_p) 303 | outputs = {'key': keys, 'prediction': tensors.predictions} 304 | 305 | return inputs, outputs 306 | 307 | def build_prediction_image_graph(self): 308 | """Builds the prediction VAE-GAN graph for image input. 309 | 310 | Returns: 311 | The inputs and outputs of the prediction. 
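      The inputs dict is keyed 'key' and 'image_bytes', and the outputs dict is
      keyed 'key' and 'prediction', matching the JSON request format shown in
      the README.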
312 | """ 313 | tensors = self.build_graph(None, 1, PREDICT_IMAGE_IN) 314 | 315 | keys_p = tf.placeholder(tf.string, shape=[None]) 316 | inputs = {'key': keys_p, 'image_bytes': tensors.prediction_image} 317 | keys = tf.identity(keys_p) 318 | outputs = {'key': keys, 'prediction': tensors.predictions} 319 | 320 | return inputs, outputs 321 | 322 | def encode(self, images, is_training=True, reuse=None): 323 | """Encoder network for VAE. 324 | 325 | Args: 326 | images: Images to encode to latent space vector. 327 | is_training: True iff in training mode. 328 | reuse: True iff variables should be reused. 329 | 330 | Returns: 331 | The embedding vector, mean and standard deviation vectors. 332 | """ 333 | with tf.variable_scope(tf.get_variable_scope(), reuse=None): 334 | # Convolution Layer 1 335 | conv = self.leaky_relu( 336 | tf.layers.conv2d( 337 | inputs=images, 338 | filters=LAYER_DIM, 339 | kernel_size=4, 340 | strides=(2, 2), 341 | padding='same', 342 | name='enc_conv0'), 'enc_r0') 343 | 344 | layers = [conv] 345 | for i, filters in enumerate([LAYER_DIM * 2, LAYER_DIM * 4, 346 | LAYER_DIM * 8]): 347 | # Convolutional Layer 348 | conv = tf.layers.conv2d( 349 | inputs=layers[-1], 350 | filters=filters, 351 | kernel_size=4, 352 | strides=(2, 2), 353 | padding='same', 354 | name='enc_conv' + str(i + 1)) 355 | 356 | # Batch Norm Layer 357 | bn = tf.contrib.layers.batch_norm( 358 | conv, 359 | decay=0.9, 360 | updates_collections=None, 361 | epsilon=1e-5, 362 | scale=True, 363 | reuse=reuse, 364 | scope='enc_bn' + str(i + 1), 365 | is_training=is_training) 366 | 367 | # ReLU activation 368 | relu = self.leaky_relu(bn, name='enc_rl' + str(i + 1)) 369 | layers.append(relu) 370 | 371 | # Fully Connected Layer 372 | conv4_flat = tf.reshape(layers[-1], [-1, 4 * 4 * LAYER_DIM * 8]) 373 | 374 | # Get Mean and Standard Deviation Vectors 375 | y_mean = tf.layers.dense( 376 | inputs=conv4_flat, 377 | units=EMBEDDING_DIMENSION, 378 | activation=None, 379 | name='enc_y_mean') 380 | y_stddev = tf.layers.dense( 381 | inputs=conv4_flat, 382 | units=EMBEDDING_DIMENSION, 383 | activation=None, 384 | name='enc_y_stddev') 385 | samples = tf.random_normal( 386 | [self.batch_size, EMBEDDING_DIMENSION], 0, 1, dtype=tf.float32) 387 | 388 | y_vector = y_mean + (y_stddev * samples) 389 | return y_vector, y_mean, y_stddev 390 | 391 | def decode(self, embeddings, is_training=True, reuse=False): 392 | """Decoder network for VAE / Generator network for GAN. 393 | 394 | Args: 395 | embeddings: Vector to decode into images. 396 | is_training: True iff in training mode. 397 | reuse: True iff vars should be reused. 398 | 399 | Returns: 400 | The decoded images. 
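      Pixel values lie in [-1, 1] because the final transposed convolution uses
      a tanh activation; the embedding-to-image prediction path rescales them to
      [0, 1] before PNG encoding.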
401 | """ 402 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 403 | # Fully Connected Layers 404 | fc3 = tf.layers.dense( 405 | inputs=embeddings, 406 | units=4 * 4 * LAYER_DIM * 8, 407 | activation=None, 408 | name='gen_fc3') 409 | fc3_reshaped = tf.reshape(fc3, [-1, 4, 4, LAYER_DIM * 8]) 410 | 411 | layers = [fc3_reshaped] 412 | for i, filters in enumerate([LAYER_DIM * 4, LAYER_DIM * 2, LAYER_DIM]): 413 | # Batch Norm Layer 414 | bn = tf.contrib.layers.batch_norm( 415 | layers[-1], 416 | decay=0.9, 417 | updates_collections=None, 418 | epsilon=1e-5, 419 | scale=True, 420 | reuse=reuse, 421 | scope='gen_bn' + str(i), 422 | is_training=is_training) 423 | 424 | # ReLU activation 425 | relu = tf.nn.relu(bn, name='gen_rl' + str(i)) 426 | 427 | # "Deconvolution" Layer 428 | deconv = tf.layers.conv2d_transpose( 429 | inputs=relu, 430 | filters=filters, 431 | kernel_size=4, 432 | strides=(2, 2), 433 | padding='same', 434 | name='gen_deconv' + str(i)) 435 | layers.append(deconv) 436 | 437 | # Batch norm 438 | bn = tf.nn.relu( 439 | tf.contrib.layers.batch_norm( 440 | layers[-1], 441 | decay=0.9, 442 | updates_collections=None, 443 | epsilon=1e-5, 444 | scale=True, 445 | reuse=None, 446 | scope='gen_bn3', 447 | is_training=is_training), 448 | name='gen_rl3') 449 | 450 | # "Deconvolution" Layer 3 451 | deconv = tf.layers.conv2d_transpose( 452 | inputs=bn, 453 | filters=3, 454 | kernel_size=4, 455 | strides=(2, 2), 456 | padding='same', 457 | activation=tf.nn.tanh, 458 | name='gen_deconv3') 459 | 460 | return deconv 461 | 462 | def discriminate(self, input_images, dropout=0.5, reuse=False): 463 | """Decoder network for VAE / Generator network for GAN. 464 | 465 | Args: 466 | input_images: Input images to discriminate. 467 | dropout: Dropout used for training. 468 | reuse: True iff variables should be in reuse mode. 469 | 470 | Returns: 471 | Whether the images are real or fake. 472 | """ 473 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 474 | # Convolution Layer 1 475 | conv = self.leaky_relu( 476 | tf.layers.conv2d( 477 | inputs=input_images, 478 | filters=LAYER_DIM, 479 | kernel_size=4, 480 | strides=(2, 2), 481 | padding='same', 482 | name='disc_conv0'), 'disc_r0') 483 | 484 | layers = [conv] 485 | for i, filters in enumerate([LAYER_DIM * 2, LAYER_DIM * 4, 486 | LAYER_DIM * 8]): 487 | # Convolutional Layer 488 | conv = tf.layers.conv2d( 489 | inputs=layers[-1], 490 | filters=filters, 491 | kernel_size=4, 492 | strides=(2, 2), 493 | padding='same', 494 | name='disc_conv' + str(i + 1)) 495 | 496 | # Batch Norm Layer 497 | bn = tf.contrib.layers.batch_norm( 498 | conv, 499 | decay=0.9, 500 | updates_collections=None, 501 | epsilon=1e-5, 502 | scale=True, 503 | reuse=reuse, 504 | scope='disc_bn' + str(i + 1), 505 | is_training=True) 506 | 507 | # ReLU activation 508 | relu = self.leaky_relu(bn, name='disc_rl' + str(i + 1)) 509 | layers.append(relu) 510 | 511 | # Fully Connected Layer 512 | conv_flat = tf.reshape(layers[-1], [-1, 4 * 4 * LAYER_DIM * 8]) 513 | dropout_output = tf.nn.dropout(conv_flat, dropout) 514 | fc = tf.layers.dense( 515 | inputs=dropout_output, 516 | units=1, 517 | activation=tf.nn.sigmoid, 518 | name='disc_fc0') 519 | 520 | return fc 521 | 522 | def loss_encoder(self, images, d_images, mean, stddev): 523 | """Computes the loss of the VAE. 524 | 525 | Args: 526 | images: The input images to the VAE. 527 | d_images: The decoded images produced by the VAE. 528 | mean: The mean vector output by the encoder. 
529 | stddev: The sttdev vector output by the encoder. 530 | 531 | Returns: 532 | The cost of the VAE. 533 | """ 534 | cost_reconstruct = tf.reduce_sum(tf.square(images - d_images)) 535 | 536 | cost_latent = 0.5 * tf.reduce_sum( 537 | tf.square(mean) + tf.square(stddev) - 538 | tf.log(tf.maximum(tf.square(stddev), 1e-10)) - 1, 1) 539 | 540 | cost_encoder = tf.reduce_mean(cost_latent + cost_reconstruct) 541 | return cost_encoder / (self.resized_image_size * self.resized_image_size) 542 | 543 | def loss_generator(self, dis_fake): 544 | """Computes the loss of the generator network. 545 | 546 | Args: 547 | dis_fake: The output of the discriminator for the fake images. 548 | 549 | Returns: 550 | The cost of the generator. 551 | """ 552 | return tf.reduce_mean(-1 * tf.log(tf.clip_by_value(dis_fake, 1e-10, 1.0))) 553 | 554 | def loss_discriminator(self, dis_real, dis_fake): 555 | """Computes the loss of the discriminator network. 556 | 557 | Args: 558 | dis_real: The output of the discriminator for the real images. 559 | dis_fake: The output of the discriminator for the fake images. 560 | 561 | Returns: 562 | The cost of the discriminator. 563 | """ 564 | return tf.reduce_mean(-1 * 565 | (tf.log(tf.clip_by_value(dis_real, 1e-10, 1.0)) + 566 | tf.log(tf.clip_by_value(1 - dis_fake, 1e-10, 1.0)))) 567 | 568 | def export(self, last_checkpoint, output_dir): 569 | """Exports the prediction graph. 570 | 571 | Args: 572 | last_checkpoint: The last checkpoint saved. 573 | output_dir: Directory to save graph. 574 | """ 575 | if not self.has_exported_embed_in: 576 | with tf.Session(graph=tf.Graph()) as sess: 577 | inputs, outputs = self.build_prediction_embedding_graph() 578 | init_op = tf.global_variables_initializer() 579 | sess.run(init_op) 580 | trained_saver = tf.train.Saver() 581 | trained_saver.restore(sess, last_checkpoint) 582 | 583 | predict_signature_def = build_signature(inputs, outputs) 584 | # Create a saver for writing SavedModel training checkpoints. 585 | build = builder.SavedModelBuilder( 586 | os.path.join(output_dir, 'saved_model_embed_in')) 587 | build.add_meta_graph_and_variables( 588 | sess, [tag_constants.SERVING], 589 | signature_def_map={ 590 | signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 591 | predict_signature_def 592 | }, 593 | assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)) 594 | self.has_exported_embed_in = True 595 | build.save() 596 | 597 | if not self.has_exported_image_in: 598 | with tf.Session(graph=tf.Graph()) as sess: 599 | inputs, outputs = self.build_prediction_image_graph() 600 | init_op = tf.global_variables_initializer() 601 | sess.run(init_op) 602 | trained_saver = tf.train.Saver() 603 | trained_saver.restore(sess, last_checkpoint) 604 | 605 | predict_signature_def = build_signature(inputs, outputs) 606 | # Create a saver for writing SavedModel training checkpoints. 
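        # The 'saved_model_image_in' directory written below is the export that
        # create_model.sh passes to "gcloud ml-engine versions create --origin"
        # when deploying the image-to-embedding model.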
607 |         build = builder.SavedModelBuilder(
608 |             os.path.join(output_dir, 'saved_model_image_in'))
609 |         build.add_meta_graph_and_variables(
610 |             sess, [tag_constants.SERVING],
611 |             signature_def_map={
612 |                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
613 |                     predict_signature_def
614 |             },
615 |             assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
616 |         self.has_exported_image_in = True
617 |         build.save()
618 | 
619 |   def process_image(self, input_img):
620 |     image = tf.image.decode_jpeg(input_img, channels=3)
621 |     image = tf.image.central_crop(image, 0.75)
622 |     image = tf.image.resize_images(
623 |         image, [self.resized_image_size, self.resized_image_size])
624 |     image.set_shape((self.resized_image_size, self.resized_image_size, 3))
625 | 
626 |     image = tf.cast(image, tf.float32) / 127.5 - 1
627 |     return image
628 | 
629 | 
--------------------------------------------------------------------------------
/vae-gan/trainer/task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Example implementation of code to run on the Cloud ML service.
15 | 
16 | This file is generic and can be reused by other models without modification.
17 | The only assumption this module makes is that a model module exists and
18 | implements a create_model() function. That function returns a model object
19 | providing problem-specific implementations of build_train_graph(),
20 | build_eval_graph(), build_prediction_graph() and format_metric_values().
21 | """
22 | 
23 | import argparse
24 | import json
25 | import logging
26 | import os
27 | import shutil
28 | import subprocess
29 | import time
30 | import uuid
31 | 
32 | import model as model_lib
33 | import tensorflow as tf
34 | from tensorflow.python.lib.io import file_io
35 | 
36 | # Global constants
37 | LOWER_THRESHOLD = 0.6
38 | UPPER_THRESHOLD = 0.75
39 | 
40 | 
41 | class Evaluator(object):
42 |   """Loads variables from latest checkpoint and performs model evaluation."""
43 | 
44 |   def __init__(self, args, model, data_dir, dataset='eval'):
45 |     self.eval_batch_size = args.eval_batch_size
46 |     self.num_eval_batches = args.eval_set_size // self.eval_batch_size
47 |     self.batch_of_examples = []
48 |     self.checkpoint_path = eval_dir(args.output_path)
49 |     self.output_path = os.path.join(args.output_path, dataset)
50 |     self.data_dir = data_dir
51 |     self.batch_size = args.batch_size
52 |     self.stream = args.streaming_eval
53 |     self.model = model
54 |     self.summary_op = None
55 |     self.saver = None
56 | 
57 |   def evaluate(self, num_eval_batches=None):
58 |     """Runs one round of evaluation and writes summaries for the latest checkpoint."""
59 | 
60 |     num_eval_batches = num_eval_batches or self.num_eval_batches
61 |     with tf.Graph().as_default() as graph:
62 |       self.tensors = self.model.build_eval_graph(self.data_dir,
63 |                                                  self.eval_batch_size)
64 | 
65 |       self.summary_op = tf.summary.merge_all()
66 | 
67 |       self.saver = tf.train.Saver()
68 | 
69 |       # Remove this try/except once TensorFlow 0.12+ is standard.
70 |       try:
71 |         self.summary_writer = tf.summary.FileWriter(self.output_path)
72 |       except AttributeError:
73 |         self.summary_writer = tf.train.SummaryWriter(self.output_path)
74 |       self.sv = tf.train.Supervisor(
75 |           graph=graph,
76 |           logdir=self.output_path,
77 |           summary_writer=self.summary_writer,
78 |           summary_op=None,
79 |           global_step=None,
80 |           saver=self.saver)
81 | 
82 |     last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
83 |     with self.sv.managed_session(
84 |         master='', start_standard_services=False) as session:
85 |       self.sv.saver.restore(session, last_checkpoint)
86 | 
87 |       if self.stream:
88 |         self.sv.start_queue_runners(session)
89 |       else:
90 |         if not self.batch_of_examples:
91 |           self.sv.start_queue_runners(session)
92 |           for _ in range(num_eval_batches):
93 |             self.batch_of_examples.append(session.run(self.tensors.examples))
94 | 
95 |       global_step = tf.train.global_step(session, self.tensors.global_step)
96 |       summary = session.run(self.summary_op)
97 |       self.summary_writer.add_summary(summary, global_step)
98 |       self.summary_writer.flush()
99 | 
100 |   def write_predictions(self):
101 |     """Run one round of predictions and write predictions to csv file."""
102 |     num_eval_batches = self.num_eval_batches + 1
103 |     with tf.Graph().as_default() as graph:
104 |       self.tensors = self.model.build_eval_graph(self.data_dir, self.batch_size)
105 |       self.saver = tf.train.Saver()
106 |       self.sv = tf.train.Supervisor(
107 |           graph=graph,
108 |           logdir=self.output_path,
109 |           summary_op=None,
110 |           global_step=None,
111 |           saver=self.saver)
112 | 
113 |     last_checkpoint = tf.train.latest_checkpoint(self.checkpoint_path)
114 |     with self.sv.managed_session(
115 |         master='', start_standard_services=False) as session:
116 |       self.sv.saver.restore(session, last_checkpoint)
117 | 
118 |       with open(os.path.join(self.output_path, 'predictions.csv'), 'wb') as f:
119 |         to_run = [self.tensors.keys] + self.tensors.predictions
120 |         self.sv.start_queue_runners(session)
121 |         last_log_progress = 0
122 |         for i in range(num_eval_batches):
123 |           progress = i * 100
// num_eval_batches 124 | if progress > last_log_progress: 125 | logging.info('%3d%% predictions processed', progress) 126 | last_log_progress = progress 127 | 128 | res = session.run(to_run) 129 | for element in range(len(res[0])): 130 | f.write('%s' % res[0][element]) 131 | for i in range(len(self.tensors.predictions)): 132 | f.write(',') 133 | f.write(self.model.format_prediction_values(res[i + 1][element])) 134 | f.write('\n') 135 | 136 | 137 | class Trainer(object): 138 | """Performs model training and optionally evaluation.""" 139 | 140 | def __init__(self, args, model, cluster, task): 141 | self.args = args 142 | self.model = model 143 | self.cluster = cluster 144 | self.task = task 145 | self.evaluator = Evaluator(self.args, self.model, self.args.data_dir, 146 | 'eval_set') 147 | self.train_evaluator = Evaluator(self.args, self.model, self.args.data_dir, 148 | 'train_set') 149 | self.min_train_eval_rate = args.min_train_eval_rate 150 | 151 | def run_training(self): 152 | """Runs a Master.""" 153 | ensure_output_path(self.args.output_path) 154 | self.train_path = train_dir(self.args.output_path) 155 | self.model_path = model_dir(self.args.output_path) 156 | self.is_master = self.task.type != 'worker' 157 | log_interval = self.args.log_interval_secs 158 | self.eval_interval = self.args.eval_interval_secs 159 | if self.is_master and self.task.index > 0: 160 | raise StandardError('Only one replica of master expected') 161 | 162 | if self.cluster: 163 | logging.info('Starting %s/%d', self.task.type, self.task.index) 164 | server = start_server(self.cluster, self.task) 165 | target = server.target 166 | device_fn = tf.train.replica_device_setter( 167 | ps_device='/job:ps', 168 | worker_device='/job:%s/task:%d' % (self.task.type, self.task.index), 169 | cluster=self.cluster) 170 | # We use a device_filter to limit the communication between this job 171 | # and the parameter servers, i.e., there is no need to directly 172 | # communicate with the other workers; attempting to do so can result 173 | # in reliability problems. 174 | device_filters = [ 175 | '/job:ps', 176 | '/job:%s/task:%d' % (self.task.type, self.task.index) 177 | ] 178 | config = tf.ConfigProto(device_filters=device_filters) 179 | else: 180 | target = '' 181 | device_fn = '' 182 | config = None 183 | 184 | with tf.Graph().as_default() as graph: 185 | with tf.device(device_fn): 186 | # Build the training graph. 187 | self.tensors = self.model.build_train_graph(self.args.data_dir, 188 | self.args.batch_size) 189 | 190 | # Add the variable initializer Op. 191 | # Remove this if once Tensorflow 0.12 is standard. 192 | try: 193 | init_op = tf.global_variables_initializer() 194 | except AttributeError: 195 | init_op = tf.initialize_all_variables() 196 | 197 | # Create a saver for writing training checkpoints. 198 | self.saver = tf.train.Saver() 199 | self.summary_writer = tf.summary.FileWriter( 200 | self.args.output_path, max_queue=500, flush_secs=5) 201 | 202 | self.summary_op = tf.summary.merge_all() 203 | # Create a "supervisor", which oversees the training process. 204 | self.sv = tf.train.Supervisor( 205 | graph, 206 | is_chief=self.is_master, 207 | logdir=self.train_path, 208 | init_op=init_op, 209 | saver=self.saver, 210 | summary_writer=self.summary_writer, 211 | # Write summary_ops by hand. 212 | summary_op=None, 213 | global_step=self.tensors.global_step, 214 | # No saving; we do it manually in order to easily evaluate immediately 215 | # afterwards. 
216 | save_model_secs=0) 217 | 218 | should_retry = True 219 | 220 | while should_retry: 221 | try: 222 | should_retry = False 223 | with self.sv.managed_session(target, config=config) as session: 224 | self.start_time = start_time = time.time() 225 | self.last_save = self.last_log = 0 226 | self.global_step = self.last_global_step = 0 227 | self.local_step = self.last_local_step = 0 228 | self.last_global_time = self.last_local_time = start_time 229 | self.last_step_summary = 0 230 | self.summary_interval = 50 231 | 232 | # Loop until the supervisor shuts down or args.max_steps have 233 | # completed. 234 | max_steps = self.args.max_steps 235 | 236 | while not self.sv.should_stop() and self.global_step < max_steps: 237 | try: 238 | # Run one step of the model. 239 | # Partial Train Correct discriminator training taken from: 240 | # https://github.com/hardmaru/cppn-gan-vae-tensorflow/blob/master/model.py 241 | for _ in range(4): 242 | _, self.global_step = session.run( 243 | [self.tensors.encoder_optimizer, self.tensors.global_step]) 244 | 245 | for _ in range(2): 246 | _, gen_result, self.global_step = session.run([ 247 | self.tensors.generator_optimizer, 248 | self.tensors.cost_generator, self.tensors.global_step 249 | ]) 250 | if gen_result < LOWER_THRESHOLD: 251 | break 252 | 253 | disc_result, self.global_step = session.run([ 254 | self.tensors.cost_discriminator, 255 | self.tensors.global_step, 256 | ]) 257 | 258 | if disc_result > LOWER_THRESHOLD and gen_result < UPPER_THRESHOLD: 259 | _, self.global_step = session.run([ 260 | self.tensors.discriminator_optimizer, 261 | self.tensors.global_step 262 | ]) 263 | 264 | self.local_step += 1 265 | 266 | self.now = time.time() 267 | is_time_to_eval = (self.now - self.last_save) > self.eval_interval 268 | is_time_to_log = (self.now - self.last_log) > log_interval 269 | should_eval = self.is_master and is_time_to_eval 270 | should_log = is_time_to_log or should_eval 271 | should_write_summaries = self.is_master and ( 272 | self.global_step - 273 | self.last_step_summary) > self.summary_interval 274 | 275 | if should_log: 276 | self.log() 277 | 278 | if should_write_summaries: 279 | self.save_summaries(session) 280 | 281 | if should_eval: 282 | self.eval(session) 283 | 284 | except tf.errors.AbortedError: 285 | should_retry = True 286 | 287 | if self.is_master: 288 | # Take the final checkpoint and compute the final accuracy. 289 | self.eval(session) 290 | 291 | # Export the model for inference. 292 | self.model.export( 293 | tf.train.latest_checkpoint(self.train_path), self.model_path) 294 | except tf.errors.AbortedError: 295 | should_retry = True 296 | 297 | # Ask for all the services to stop. 
298 | self.sv.stop() 299 | 300 | def log(self): 301 | """Logs training progress.""" 302 | logging.info('Train [%s/%d], step %d (%.3f sec) %.1f ' 303 | 'global steps/s, %.1f local steps/s', self.task.type, 304 | self.task.index, self.global_step, 305 | (self.now - self.start_time), 306 | (self.global_step - self.last_global_step) / 307 | (self.now - self.last_global_time), 308 | (self.local_step - self.last_local_step) / 309 | (self.now - self.last_local_time)) 310 | 311 | self.last_log = self.now 312 | self.last_global_step, self.last_global_time = self.global_step, self.now 313 | self.last_local_step, self.last_local_time = self.local_step, self.now 314 | 315 | def eval(self, session): 316 | """Runs evaluation loop.""" 317 | eval_start = time.time() 318 | self.saver.save(session, self.sv.save_path, self.tensors.global_step) 319 | 320 | now = time.time() 321 | 322 | # Make sure eval doesn't consume too much of total time. 323 | eval_time = now - eval_start 324 | train_eval_rate = self.eval_interval / eval_time 325 | if train_eval_rate < self.min_train_eval_rate and self.last_save > 0: 326 | logging.info('Adjusting eval interval from %.2fs to %.2fs', 327 | self.eval_interval, self.min_train_eval_rate * eval_time) 328 | self.eval_interval = self.min_train_eval_rate * eval_time 329 | 330 | self.last_save = now 331 | self.last_log = now 332 | 333 | def save_summaries(self, session): 334 | self.last_step_summary = self.global_step 335 | self.sv.summary_computed(session, 336 | session.run(self.summary_op), self.global_step) 337 | self.sv.summary_writer.flush() 338 | 339 | 340 | def main(_): 341 | model, argv = model_lib.create_model() 342 | run(model, argv) 343 | 344 | 345 | def run(model, argv): 346 | """Runs the training loop.""" 347 | parser = argparse.ArgumentParser() 348 | parser.add_argument( 349 | '--data_dir', 350 | type=str, 351 | action='append', 352 | help='The directory containing the data files.') 353 | parser.add_argument( 354 | '--output_path', 355 | type=str, 356 | help='The path to which checkpoints and other outputs ' 357 | 'should be saved. This can be either a local or GCS ' 358 | 'path.') 359 | parser.add_argument( 360 | '--max_steps', 361 | type=int,) 362 | parser.add_argument( 363 | '--batch_size', 364 | type=int, 365 | help='Number of examples to be processed per mini-batch.') 366 | parser.add_argument( 367 | '--eval_set_size', type=int, help='Number of examples in the eval set.') 368 | parser.add_argument( 369 | '--eval_batch_size', type=int, help='Number of examples per eval batch.') 370 | parser.add_argument( 371 | '--eval_interval_secs', 372 | type=float, 373 | default=5, 374 | help='Minimal interval between calculating evaluation metrics and saving' 375 | ' evaluation summaries.') 376 | parser.add_argument( 377 | '--log_interval_secs', 378 | type=float, 379 | default=5, 380 | help='Minimal interval between logging training metrics and saving ' 381 | 'training summaries.') 382 | parser.add_argument( 383 | '--write_predictions', 384 | action='store_true', 385 | default=False, 386 | help='If set, model is restored from latest checkpoint ' 387 | 'and predictions are written to a csv file and no training is performed.') 388 | parser.add_argument( 389 | '--min_train_eval_rate', 390 | type=int, 391 | default=20, 392 | help='Minimal train / eval time ratio on master. ' 393 | 'Default value 20 means that 20x more time is used for training than ' 394 | 'for evaluation. 
If evaluation takes more time the eval_interval_secs ' 395 | 'is increased.') 396 | parser.add_argument( 397 | '--write_to_tmp', 398 | action='store_true', 399 | default=False, 400 | help='If set, all checkpoints and summaries are written to ' 401 | 'local filesystem (/tmp/) and copied to gcs once training is done. ' 402 | 'This can speed up training but if training job fails all the summaries ' 403 | 'and checkpoints are lost.') 404 | parser.add_argument( 405 | '--copy_train_data_to_tmp', 406 | action='store_true', 407 | default=False, 408 | help='If set, training data is copied to local filesystem ' 409 | '(/tmp/). This can speed up training but requires extra space on the ' 410 | 'local filesystem.') 411 | parser.add_argument( 412 | '--copy_eval_data_to_tmp', 413 | action='store_true', 414 | default=False, 415 | help='If set, evaluation data is copied to local filesystem ' 416 | '(/tmp/). This can speed up training but requires extra space on the ' 417 | 'local filesystem.') 418 | parser.add_argument( 419 | '--streaming_eval', 420 | action='store_true', 421 | default=False, 422 | help='If set to True the evaluation is performed in streaming mode. ' 423 | 'During each eval cycle the evaluation data is read and parsed from ' 424 | 'files. This allows for having very large evaluation set. ' 425 | 'If set to False (default) evaluation data is read once and cached in ' 426 | 'memory. This results in faster evaluation cycle but can potentially ' 427 | 'use more memory (in streaming mode large per-file read-ahead buffer is ' 428 | 'used - which may exceed eval data size).') 429 | 430 | args, _ = parser.parse_known_args(argv) 431 | 432 | env = json.loads(os.environ.get('TF_CONFIG', '{}')) 433 | 434 | # Print the job data as provided by the service. 435 | logging.info('Original job data: %s', env.get('job', {})) 436 | 437 | # First find out if there's a task value on the environment variable. 438 | # If there is none or it is empty define a default one. 
439 | task_data = env.get('task', None) or {'type': 'master', 'index': 0} 440 | task = type('TaskSpec', (object,), task_data) 441 | trial = task_data.get('trial') 442 | if trial is not None: 443 | args.output_path = os.path.join(args.output_path, trial) 444 | if args.write_to_tmp and args.output_path.startswith('gs://'): 445 | output_path = args.output_path 446 | args.output_path = os.path.join('/tmp/', str(uuid.uuid4())) 447 | os.makedirs(args.output_path) 448 | else: 449 | output_path = None 450 | 451 | if not args.eval_batch_size: 452 | # If eval_batch_size not set, use min of batch_size and eval_set_size 453 | args.eval_batch_size = min(args.batch_size, args.eval_set_size) 454 | logging.info('setting eval batch size to %s', args.eval_batch_size) 455 | 456 | cluster_data = env.get('cluster', None) 457 | cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None 458 | if args.write_predictions: 459 | write_predictions(args, model, cluster, task) 460 | else: 461 | dispatch(args, model, cluster, task) 462 | 463 | if output_path and (not cluster or not task or task.type == 'master'): 464 | subprocess.check_call([ 465 | 'gsutil', '-m', '-q', 'cp', '-r', args.output_path + '/*', output_path 466 | ]) 467 | shutil.rmtree(args.output_path, ignore_errors=True) 468 | 469 | 470 | def copy_data_to_tmp(input_files): 471 | """Copies data to /tmp/ and returns glob matching the files.""" 472 | files = [] 473 | for e in input_files: 474 | for path in e.split(','): 475 | files.extend(file_io.get_matching_files(path)) 476 | 477 | for path in files: 478 | if not path.startswith('gs://'): 479 | return input_files 480 | 481 | tmp_path = os.path.join('/tmp/', str(uuid.uuid4())) 482 | os.makedirs(tmp_path) 483 | subprocess.check_call(['gsutil', '-m', '-q', 'cp', '-r'] + files + [tmp_path]) 484 | return [os.path.join(tmp_path, '*')] 485 | 486 | 487 | def write_predictions(args, model, cluster, task): 488 | if not cluster or not task or task.type == 'master': 489 | pass # Run locally. 490 | else: 491 | raise ValueError('invalid task_type %s' % (task.type,)) 492 | 493 | logging.info('Starting to write predictions on %s/%d', task.type, task.index) 494 | evaluator = Evaluator(args, model, None) 495 | evaluator.write_predictions() 496 | logging.info('Done writing predictions on %s/%d', task.type, task.index) 497 | 498 | 499 | def dispatch(args, model, cluster, task): 500 | if not cluster or not task or task.type == 'master': 501 | # Run locally. 502 | Trainer(args, model, cluster, task).run_training() 503 | elif task.type == 'ps': 504 | run_parameter_server(cluster, task) 505 | elif task.type == 'worker': 506 | Trainer(args, model, cluster, task).run_training() 507 | else: 508 | raise ValueError('invalid task_type %s' % (task.type,)) 509 | 510 | 511 | def run_parameter_server(cluster, task): 512 | logging.info('Starting parameter server %d', task.index) 513 | server = start_server(cluster, task) 514 | server.join() 515 | 516 | 517 | def start_server(cluster, task): 518 | if not task.type: 519 | raise ValueError('--task_type must be specified.') 520 | if task.index is None: 521 | raise ValueError('--task_index must be specified.') 522 | 523 | # Create and start a server. 524 | return tf.train.Server( 525 | tf.train.ClusterSpec(cluster), 526 | protocol='grpc', 527 | job_name=task.type, 528 | task_index=task.index) 529 | 530 | 531 | def ensure_output_path(output_path): 532 | if not output_path: 533 | raise ValueError('output_path must be specified') 534 | 535 | # GCS doesn't have real directories. 
536 | if output_path.startswith('gs://'): 537 | return 538 | 539 | ensure_dir(output_path) 540 | 541 | 542 | def ensure_dir(path): 543 | try: 544 | os.makedirs(path) 545 | except OSError as e: 546 | # If the directory already existed, ignore the error. 547 | if e.args[0] == 17: 548 | pass 549 | else: 550 | raise 551 | 552 | 553 | def train_dir(output_path): 554 | return os.path.join(output_path, 'train') 555 | 556 | 557 | def eval_dir(output_path): 558 | return os.path.join(output_path, 'eval') 559 | 560 | 561 | def model_dir(output_path): 562 | return os.path.join(output_path, 'model') 563 | 564 | 565 | if __name__ == '__main__': 566 | logging.basicConfig(level=logging.INFO) 567 | tf.app.run() 568 | -------------------------------------------------------------------------------- /vae-gan/trainer/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utility file for VAE-GAN.""" 15 | 16 | import os 17 | import tensorflow as tf 18 | 19 | 20 | def read_and_decode(data_directory, batch_size, mode, resized_image_size, 21 | crop_image_size, center_crop): 22 | """Reads and decodes TF Record files. 23 | 24 | Based on example for reading/decoding MNIST digits: 25 | https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py 26 | 27 | Args: 28 | data_directory: directory containing tfrecord files 29 | batch_size: Batch size for input data. 30 | mode: defines Train or Validation modes. 31 | resized_image_size: Desired size of image. 32 | crop_image_size: Original size to crop image. 33 | center_crop: True iff image should be center cropped. 34 | 35 | Returns: 36 | Batch of input images to train/validate model. 37 | """ 38 | tf_record_pattern = os.path.join(data_directory[0], '%s-*' % mode) 39 | data_files = tf.gfile.Glob(tf_record_pattern) 40 | 41 | queue = tf.train.string_input_producer(data_files) 42 | reader = tf.TFRecordReader() 43 | _, serialized_example = reader.read(queue) 44 | 45 | features = tf.parse_single_example( 46 | serialized_example, 47 | features={ 48 | 'image/encoded': tf.FixedLenFeature([], tf.string), 49 | 'image/height': tf.FixedLenFeature([], tf.int64), 50 | 'image/width': tf.FixedLenFeature([], tf.int64), 51 | }) 52 | 53 | image = tf.image.decode_jpeg(features['image/encoded'], channels=3) 54 | original_image_height = tf.cast(features['image/height'], tf.int32) 55 | original_image_width = tf.cast(features['image/width'], tf.int32) 56 | 57 | if crop_image_size is None: 58 | crop_image_size = tf.cast( 59 | tf.minimum(original_image_width, original_image_height), tf.int32) 60 | 61 | # Crop rectangular image to centered bounding box. 
62 | tf.assert_greater_equal(original_image_height, crop_image_size) 63 | tf.assert_greater_equal(original_image_width, crop_image_size) 64 | if center_crop: 65 | image = tf.image.crop_to_bounding_box( 66 | image, (original_image_height - crop_image_size) / 2, 67 | (original_image_width - crop_image_size) / 2, crop_image_size, 68 | crop_image_size) 69 | else: 70 | image = tf.image.crop_to_bounding_box( 71 | image, 72 | tf.random_uniform( 73 | [], 74 | dtype=tf.int32, 75 | maxval=(original_image_height - crop_image_size)), 76 | tf.random_uniform( 77 | [], dtype=tf.int32, 78 | maxval=(original_image_width - crop_image_size)), crop_image_size, 79 | crop_image_size) 80 | 81 | # Resize image to desired pixel dimensions. 82 | image = tf.image.resize_images(image, 83 | [resized_image_size, resized_image_size]) 84 | image.set_shape((resized_image_size, resized_image_size, 3)) 85 | 86 | image = tf.cast(image, tf.float32) * (1. / 127.5) - 1 87 | images = tf.train.shuffle_batch( 88 | [image], 89 | batch_size=batch_size, 90 | num_threads=1, 91 | capacity=1000 + 3 * batch_size, 92 | min_after_dequeue=1000) 93 | 94 | return images 95 | 96 | 97 | def override_if_not_in_args(flag, argument, args): 98 | """Checks if flags is in args, and if not it adds the flag to args.""" 99 | if flag not in args: 100 | args.extend([flag, argument]) 101 | --------------------------------------------------------------------------------
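The listing above ends with `util.read_and_decode`, which builds the queued TFRecord input pipeline the trainer consumes. As a quick orientation, here is a minimal, hypothetical sketch (not part of the repo) of how that helper could be exercised on its own under TensorFlow 1.x. The data directory, the `'train'` shard prefix, the batch size, and the `from trainer import util` import path are assumptions for illustration, not values taken from the source.

```python
# Minimal sketch: pull one batch of preprocessed images from read_and_decode.
# Assumes a directory of TFRecord shards named like 'train-*' already exists
# (e.g. produced by build_image_data.py) at the hypothetical path below.
import tensorflow as tf

from trainer import util

DATA_DIRS = ['/tmp/vae-gan-data']  # hypothetical local TFRecord directory
BATCH_SIZE = 32
RESIZED_IMAGE_SIZE = 64

# crop_image_size=None lets the helper derive a square crop from
# min(height, width); center_crop=True mirrors the -c flag in run_training.sh.
images = util.read_and_decode(
    data_directory=DATA_DIRS,
    batch_size=BATCH_SIZE,
    mode='train',
    resized_image_size=RESIZED_IMAGE_SIZE,
    crop_image_size=None,
    center_crop=True)

with tf.Session() as sess:
  # The pipeline is queue-based, so queue runners must be started explicitly.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  batch = sess.run(images)
  print(batch.shape)  # expected: (32, 64, 64, 3)
  coord.request_stop()
  coord.join(threads)
```

The returned batch is scaled to [-1, 1] (the helper divides by 127.5 and subtracts 1), which matches the tanh output range of the generator's final layer in model.py.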