├── .gitignore
├── LICENSE
├── README.md
├── data_loaders
│   ├── __init__.py
│   ├── generate_tfr
│   │   ├── __init__.py
│   │   ├── generate.py
│   │   ├── imagenet_oord.py
│   │   └── lsun.py
│   ├── get_data.py
│   └── get_mnist_cifar.py
├── demo
│   ├── README.md
│   ├── __init__.py
│   ├── align_face.py
│   ├── get_manipulators.py
│   ├── model.py
│   ├── results
│   │   ├── dec.png
│   │   ├── img.png
│   │   └── smile.png
│   ├── script.sh
│   ├── server.py
│   ├── test
│   │   └── img.png
│   ├── videos.py
│   └── web
│       ├── canvas2image.js
│       ├── glowDemo.css
│       ├── glowDemo.js
│       ├── html2canvas.min.js
│       ├── index.html
│       ├── load-image.all.min.js
│       ├── load-image.all.min.js.map
│       ├── media
│       │   ├── DownloadIcon.png
│       │   ├── EditIcon.png
│       │   ├── beyonce.png
│       │   ├── cersei.png
│       │   ├── geoff.png
│       │   ├── john.png
│       │   ├── lena.png
│       │   ├── leo.png
│       │   ├── loading.png
│       │   ├── louis.png
│       │   ├── neil.png
│       │   ├── placeholder.png
│       │   ├── placeholder2.png
│       │   ├── placeholder4.png
│       │   ├── rashida.png
│       │   ├── seth.png
│       │   └── steve.png
│       └── mock.css
├── graphics.py
├── memory_saving_gradients.py
├── model.py
├── optim.py
├── requirements.txt
├── tfops.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**Status:** Archive (code is provided as-is, no updates expected)

# Glow

Code for reproducing results in ["Glow: Generative Flow with Invertible 1x1 Convolutions"](https://d4mucfpksywv.cloudfront.net/research-covers/glow/paper/glow.pdf)

To use the pretrained CelebA-HQ model, make your own manipulation vectors, and run our interactive demo, see the `demo` folder.

## Requirements

- Tensorflow (tested with v1.8.0)
- Horovod (tested with v0.13.8) and (Open)MPI

Run
```
pip install -r requirements.txt
```

To set up (Open)MPI, check the instructions on the Horovod GitHub [page](https://github.com/uber/horovod).

## Download datasets
For small-scale experiments, use MNIST/CIFAR-10 (downloaded directly by `train.py` using Keras).

For larger-scale experiments, the datasets used are available at `https://openaipublic.azureedge.net/glow-demo/data/{dataset_name}-tfr.tar`. The dataset names are listed below; for each we note the exact preprocessing/downsampling method, so that likelihoods can be compared correctly.

Quantitative results
- `imagenet-oord` - 20GB. Unconditional ImageNet 32x32 and 64x64, as described in the PixelRNN/RealNVP papers (we downloaded [this](http://image-net.org/small/download.php) processed version).
- `lsun_realnvp` - 140GB. LSUN 96x96. Random 64x64 crops taken at processing time, as described in RealNVP.

Qualitative results
- `celeba` - 4GB. CelebA-HQ 256x256 dataset, as described in Progressive Growing of GANs. For the 1024x1024 version (120GB), download `celeba-full-tfr.tar` instead.
- `imagenet` - 20GB. ImageNet 32x32 and 64x64 with class labels. Centre cropped, area downsampled.
- `lsun` - 700GB. LSUN 256x256. Centre cropped, area downsampled.

To download and extract the CelebA-HQ dataset, for example, run
```
wget https://openaipublic.azureedge.net/glow-demo/data/celeba-tfr.tar
tar -xvf celeba-tfr.tar
```
Change `hps.data_dir` in `train.py` to point to the extracted folder (or use the `--data_dir` flag when you run `train.py`).

For `lsun`, since the download is quite large, you can instead follow the instructions in `data_loaders/generate_tfr/lsun.py` to generate the tfr file directly from LSUN images. `church_outdoor` is the smallest category.
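Once extracted, you can sanity-check a shard before training. A minimal sketch using the TF 1.x record API this codebase targets; the shard path below is an illustrative assumption (the generation scripts name shards `<split>-r<log2(res)>-s-<shard>-of-<total>.tfrecords`), so substitute one of your own files:
```
import tensorflow as tf

# Illustrative path: a 256x256 (r08) CelebA-HQ train shard, 0 of 120
path = 'celeba-tfr/train/train-r08-s-0000-of-0120.tfrecords'
count = 0
for record in tf.python_io.tf_record_iterator(path):
    ex = tf.train.Example()
    ex.ParseFromString(record)  # each record holds 'shape', 'data', 'label' (and 'attr' for CelebA)
    count += 1
print('records in shard:', count)
```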

## Simple Train with 1 GPU

Run with a small depth to test:
```
CUDA_VISIBLE_DEVICES=0 python train.py --depth 1
```

## Train with multiple GPUs using MPI and Horovod

Run the default training script with 8 GPUs:
```
mpiexec -n 8 python train.py
```

##### Ablation experiments

```
mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation [0/1/2] --flow_coupling [0/1] --seed [0/1/2] --learntop --lr 0.001
```

Pretrained models, logs and samples
```
wget https://openaipublic.azureedge.net/glow-demo/logs/abl-[reverse/shuffle/1x1]-[add/aff].tar
```

##### CIFAR-10 Quantitative result

```
mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
```

##### ImageNet 32x32 Quantitative result

```
mpiexec -n 8 python train.py --problem imagenet-oord --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
```

##### ImageNet 64x64 Quantitative result
```
mpiexec -n 8 python train.py --problem imagenet-oord --image_size 64 --n_level 4 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
```

##### LSUN 64x64 Quantitative result
```
mpiexec -n 8 python train.py --problem lsun_realnvp --category [bedroom/church_outdoor/tower] --image_size 64 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
```

Pretrained models, logs and samples
```
wget https://openaipublic.azureedge.net/glow-demo/logs/lsun-rnvp-[bdr/crh/twr].tar
```

##### CelebA-HQ 256x256 Qualitative result

```
mpiexec -n 40 python train.py --problem celeba --image_size 256 --n_level 6 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5
```

##### LSUN 96x96 and 128x128 Qualitative result
```
mpiexec -n 40 python train.py --problem lsun --category [bedroom/church_outdoor/tower] --image_size [96/128] --n_level 5 --depth 64 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5
```

Logs and samples
```
wget https://openaipublic.azureedge.net/glow-demo/logs/lsun-bdr-[96/128].tar
```

##### Conditional CIFAR-10 Qualitative result
```
mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
```

##### Conditional ImageNet 32x32 Qualitative result
```
mpiexec -n 8 python train.py --problem imagenet --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
```
--------------------------------------------------------------------------------
/data_loaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/data_loaders/__init__.py
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/data_loaders/generate_tfr/__init__.py
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/generate.py:
--------------------------------------------------------------------------------
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Generate CelebA-HQ and ImageNet datasets.
For CelebA-HQ, first create the original tfrecords file using https://github.com/tkarras/progressive_growing_of_gans/blob/master/dataset_tool.py
For ImageNet, first create the original tfrecords file using https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py
Then, use this script to get our tfr file from those records.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf
import numpy as np
from tqdm import tqdm
from typing import Iterable

_NUM_CHANNELS = 3


_NUM_PARALLEL_FILE_READERS = 32
_NUM_PARALLEL_MAP_CALLS = 32
_DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR
_SHUFFLE_BUFFER = 1024


def _int64_feature(value):
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def error(msg):
    print('Error: ' + msg)
    exit(1)


def x_to_uint8(x):
    return tf.cast(tf.clip_by_value(tf.floor(x), 0, 255), 'uint8')


def centre_crop(img):
    h, w = tf.shape(img)[0], tf.shape(img)[1]
    min_side = tf.minimum(h, w)
    h_offset = (h - min_side) // 2
    w_offset = (w - min_side) // 2
    return tf.image.crop_to_bounding_box(img, h_offset, w_offset, min_side, min_side)


def downsample(img):
    return (img[0::2, 0::2, :] + img[0::2, 1::2, :] + img[1::2, 0::2, :] + img[1::2, 1::2, :]) * 0.25


def parse_image(max_res):
    def _process_image(img):
        img = centre_crop(img)
        img = tf.image.resize_images(
            img, [max_res, max_res], method=_DOWNSAMPLING)
        img = tf.cast(img, 'float32')
        resolution_log2 = int(np.log2(max_res))
        q_imgs = []
        for lod in range(resolution_log2 - 1):
            if lod:
                img = downsample(img)
            quant = x_to_uint8(img)
            q_imgs.append(quant)
        return q_imgs
    def _parse_image(example):
        feature_map = {
            'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
                                                default_value=''),
            'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
                                                    default_value=-1)
        }
        features = tf.parse_single_example(example, feature_map)
        img, label = features['image/encoded'], features['image/class/label']
        label = tf.cast(tf.reshape(label, shape=[]), dtype=tf.int32) - 1
        img = tf.image.decode_jpeg(img, channels=_NUM_CHANNELS)
        imgs = _process_image(img)
        parsed = (label, *imgs)
        return parsed

    return _parse_image


def parse_celeba_image(max_res, transpose=False):
    def _process_image(img):
        img = tf.cast(img, 'float32')
        resolution_log2 = int(np.log2(max_res))
        q_imgs = []
        for lod in range(resolution_log2 - 1):
            if lod:
                img = downsample(img)
            quant = x_to_uint8(img)
            q_imgs.append(quant)
        return q_imgs

    def _parse_image(example):
        features = tf.parse_single_example(example, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string),
            'attr': tf.FixedLenFeature([40], tf.int64)})
        shape = features['shape']
        data = features['data']
        attr = features['attr']
        data = tf.decode_raw(data, tf.uint8)
        img = tf.reshape(data, shape)
        if transpose:
            img = tf.transpose(img, (1, 2, 0))  # CHW -> HWC
        imgs = _process_image(img)
        parsed = (attr, *imgs)
        return parsed

    return _parse_image
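
# A hedged note (not in the original file) on the _process_image loops above:
# for max_res = 256, resolution_log2 - 1 = 7 quantized uint8 images are
# emitted per example, at 256, 128, 64, 32, 16, 8 and 4 pixels per side.
# Level 0 keeps the full resolution; every later level halves it by 2x2
# block-averaging (downsample). TFRecordExporter below then writes each
# level to its own set of shards (the -rXX- part of the filename).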


def get_tfr_files(data_dir, split, lgres):
    data_dir = os.path.join(data_dir, split)
    tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
    tfr_files = tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % (lgres)
    return tfr_files


def get_tfr_file(data_dir, split, lgres):
    if split:
        data_dir = os.path.join(data_dir, split)
    tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
    tfr_file = tfr_prefix + '-r%02d.tfrecords' % (lgres)
    return tfr_file


def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
    _NUM_IMAGES = {
        'train': 27000,
        'validation': 3000,
    }

    _NUM_SHARDS = {
        'train': 120,
        'validation': 40,
    }
    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')
    with tf.Session() as sess:
        print("Reading data from ", data_dir)
        if split:
            tfr_files = get_tfr_files(data_dir, split, int(np.log2(max_res)))
            files = tf.data.Dataset.list_files(tfr_files)
            dset = files.apply(tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
            transpose = False
        else:
            tfr_file = get_tfr_file(data_dir, "", int(np.log2(max_res)))
            dset = tf.data.TFRecordDataset(tfr_file, compression_type='')
            transpose = True

        parse_fn = parse_celeba_image(max_res, transpose)
        dset = dset.map(parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dset = dset.prefetch(1)
        iterator = dset.make_one_shot_iterator()
        _attr, *_imgs = iterator.get_next()
        sess.run(tf.global_variables_initializer())
        splits = [split] if split else ["validation", "train"]
        for split in splits:
            total_imgs = _NUM_IMAGES[split]
            shards = _NUM_SHARDS[split]
            with TFRecordExporter(os.path.join(tfrecord_dir, split), resolution_log2, total_imgs, shards) as tfr:
                for _ in tqdm(range(total_imgs)):
                    attr, *imgs = sess.run([_attr, *_imgs])
                    if write:
                        tfr.add_image(0, imgs, attr)
                if write:
                    assert tfr.cur_images == total_imgs, (
                        tfr.cur_images, total_imgs)

    #attr, *imgs = sess.run([_attr, *_imgs])


def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
    _NUM_IMAGES = {
        'train': 1281167,
        'validation': 50000,
    }

    _NUM_FILES = _NUM_SHARDS = {
        'train': 2000,
        'validation': 80,
    }
    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')

    with tf.Session() as sess:
        is_training = (split == 'train')
        if is_training:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'train-*-of-01024'))
        else:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'validation-*-of-00128'))

        files = files.shuffle(buffer_size=_NUM_FILES[split])

        dataset = files.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))

        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
        parse_fn = parse_image(max_res)
        dataset = dataset.map(
            parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dataset = dataset.prefetch(1)
        iterator = dataset.make_one_shot_iterator()

        _label, *_imgs = iterator.get_next()

        sess.run(tf.global_variables_initializer())

        total_imgs = _NUM_IMAGES[split]
        shards = _NUM_SHARDS[split]
        tfrecord_dir = os.path.join(tfrecord_dir, split)
        with TFRecordExporter(tfrecord_dir, resolution_log2, total_imgs, shards) as tfr:
            for _ in tqdm(range(total_imgs)):
                label, *imgs = sess.run([_label, *_imgs])
                if write:
                    tfr.add_image(label, imgs, [])
            assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs)

    #label, *imgs = sess.run([_label, *_imgs])


class TFRecordExporter:
    def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10):
        self.tfrecord_dir = tfrecord_dir
        self.tfr_prefix = os.path.join(
            self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.resolution_log2 = resolution_log2
        self.expected_images = expected_images

        self.cur_images = 0
        self.shape = None
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval
        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert (os.path.isdir(self.tfrecord_dir))
        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        for lod in range(self.resolution_log2 - 1):
            p_shard = np.array_split(
                np.random.permutation(expected_images), shards)
            img_to_shard = np.zeros(expected_images, dtype=np.int)
            writers = []
            for shard in range(shards):
                img_to_shard[p_shard[shard]] = shard
                tfr_file = self.tfr_prefix + \
                    '-r%02d-s-%04d-of-%04d.tfrecords' % (
                        self.resolution_log2 - lod, shard, shards)
                writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
            #print(np.unique(img_to_shard, return_counts=True))
            counts = np.unique(img_to_shard, return_counts=True)[1]
            assert len(counts) == shards
            print("Smallest and largest shards have size",
                  np.min(counts), np.max(counts))
            self.tfr_writers.append((writers, img_to_shard))

    def close(self):
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for (writers, _) in self.tfr_writers:
            for writer in writers:
                writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def add_image(self, label, imgs, attr):
        assert len(imgs) == len(self.tfr_writers)
        # if self.print_progress and self.cur_images % self.progress_interval == 0:
        #     print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        for lod, (writers, img_to_shard) in enumerate(self.tfr_writers):
            quant = imgs[lod]
            size = 2 ** (self.resolution_log2 - lod)
            assert quant.shape == (size, size, 3), quant.shape
            ex = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'shape': _int64_feature(quant.shape),
                        'data': _bytes_feature(quant.tostring()),
                        'label': _int64_feature(label),
                        'attr': _int64_feature(attr)
                    }
                )
            )
            writers[img_to_shard[self.cur_images]].write(
                ex.SerializeToString())
        self.cur_images += 1

    # def add_labels(self, labels):
    #     if self.print_progress:
    #         print('%-40s\r' % 'Saving labels...', end='', flush=True)
    #     assert labels.shape[0] == self.cur_images
    #     with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
    #         np.save(f, labels.astype(np.float32))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--max_res", type=int, default=256, help="Image size")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()  # parse_args errors out on mistyped flags
    #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write)
    #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)
    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
                  hps.max_res, 'validation', hps.write)
    dump_celebahq(hps.data_dir, hps.tfrecord_dir,
                  hps.max_res, 'train', hps.write)
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/imagenet_oord.py:
--------------------------------------------------------------------------------
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Generate tfrecords for ImageNet 32x32 and 64x64.

# Get images
Download the images from http://image-net.org/small/download.php and unzip them.
(Move one file from training to test to have 50000 test images)

# Get tfr file from images
Use this script to generate the tfr file.
python imagenet_oord.py --res [RES] --tfrecord_dir [OUTPUT_FOLDER] --write

"""

from __future__ import print_function

import os
import os.path

import scipy.io
import scipy.io.wavfile
import scipy.ndimage
import tensorflow as tf
import numpy as np
from tqdm import tqdm

from typing import Iterable


def _int64_feature(value):
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
    """Main converter function."""
    # fn_root = FLAGS.fn_root
    # max_res = FLAGS.max_res
    resolution_log2 = int(np.log2(max_res))
    tfr_prefix = os.path.join(tfrecord_dir, os.path.basename(tfrecord_dir))

    print("Checking in", fn_root)
    img_fn_list = os.listdir(fn_root)
    img_fn_list = [img_fn for img_fn in img_fn_list
                   if img_fn.endswith('.png')]
    num_examples = len(img_fn_list)
    print("Found", num_examples)
    assert num_examples == expected_images

    # Sharding
    tfr_opt = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.NONE)
    p_shard = np.array_split(np.random.permutation(expected_images), shards)
    img_to_shard = np.zeros(expected_images, dtype=np.int)
    writers = []
    for shard in range(shards):
        img_to_shard[p_shard[shard]] = shard
        tfr_file = tfr_prefix + \
            '-r%02d-s-%04d-of-%04d.tfrecords' % (
                resolution_log2, shard, shards)
        writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))

    # print(np.unique(img_to_shard, return_counts=True))
    counts = np.unique(img_to_shard, return_counts=True)[1]
    assert len(counts) == shards
    print("Smallest and largest shards have size",
          np.min(counts), np.max(counts))

    for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
        shard = img_to_shard[example_idx]
        img = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
        rows = img.shape[0]
        cols = img.shape[1]
        depth = img.shape[2]
        shape = (rows, cols, depth)
        img = img.astype("uint8")
        img = img.tostring()
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    "shape": _int64_feature(shape),
                    "data": _bytes_feature(img),
                    "label": _int64_feature(0)
                }
            )
        )
        if write:
            writers[shard].write(example.SerializeToString())

    print('%-40s\r' % 'Flushing data...', end='', flush=True)
    for writer in writers:
        writer.close()

    print('%-40s\r' % '', end='', flush=True)
    print('Added %d images.' % num_examples)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=32, help="Image size")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()

    # Imagenet
    _NUM_IMAGES = {
        'train': 1281148,
        'validation': 50000,
    }

    _NUM_SHARDS = {
        'train': 2000,
        'validation': 80,
    }

    _FILE = {
        'train': 'train_%dx%d' % (hps.res, hps.res),
        'validation': 'valid_%dx%d' % (hps.res, hps.res),
    }

    for split in ['validation', 'train']:
        fn_root = _FILE[split]
        tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
        total_imgs = _NUM_IMAGES[split]
        shards = _NUM_SHARDS[split]
        if not os.path.exists(tfrecord_dir):
            os.mkdir(tfrecord_dir)
        dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/lsun.py:
--------------------------------------------------------------------------------
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
LSUN dataset

# Get image files
Download the LSUN dataset as follows:
git clone https://github.com/fyu/lsun.git
cd lsun
python2.7 download.py -c [CATEGORY]
Unzip the downloaded .zip files and execute:
python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat

# Get tfr file from images
Use this script to generate the tfr file.
python lsun.py --res [RES] --category [CATEGORY] --lsun_dir [LSUN_FOLDER] --tfrecord_dir [OUTPUT_FOLDER] --write [--realnvp]
Without the realnvp flag you get 256x256 centre-cropped, area-downsampled images; with the flag you get 96x96 images with RealNVP preprocessing.
"""

from __future__ import print_function

import os
import os.path

import numpy
import skimage.transform
from PIL import Image
import tensorflow as tf
import numpy as np
from tqdm import tqdm

from typing import Iterable


def _int64_feature(value):
    if not isinstance(value, Iterable):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def centre_crop(img):
    h, w = img.shape[:2]
    crop = min(h, w)
    return img[(h - crop) // 2: (h + crop) // 2, (w - crop) // 2: (w + crop) // 2]


def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write, realnvp=False):
    """Main converter function."""
    resolution_log2 = int(np.log2(max_res))
    tfr_prefix = os.path.join(tfrecord_dir, os.path.basename(tfrecord_dir))

    print("Checking in", fn_root)
    img_fn_list = os.listdir(fn_root)
    img_fn_list = [img_fn for img_fn in img_fn_list
                   if img_fn.endswith('.webp')]
    num_examples = len(img_fn_list)
    print("Found", num_examples)
    assert num_examples == expected_images

    tfr_opt = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.NONE)
    p_shard = np.array_split(np.random.permutation(expected_images), shards)
    img_to_shard = np.zeros(expected_images, dtype=np.int)
    writers = []
    for shard in tqdm(range(shards)):
        img_to_shard[p_shard[shard]] = shard
        tfr_file = tfr_prefix + \
            '-r%02d-s-%04d-of-%04d.tfrecords' % (
                resolution_log2, shard, shards)
        writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))

    # print(np.unique(img_to_shard, return_counts=True))
    counts = np.unique(img_to_shard, return_counts=True)[1]
    assert len(counts) == shards
    print("Smallest and largest shards have size",
          np.min(counts), np.max(counts))

    for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
        shard = img_to_shard[example_idx]
        img = numpy.array(Image.open(os.path.join(fn_root, img_fn)))
        rows = img.shape[0]
        cols = img.shape[1]
        if realnvp:
            downscale = min(rows / 96., cols / 96.)
            img = skimage.transform.pyramid_reduce(img, downscale)
            img *= 255.
            img = img.astype("uint8")
        else:
            img = centre_crop(img)
            img = Image.fromarray(img, 'RGB')
            img = img.resize((max_res, max_res), Image.ANTIALIAS)
            img = np.asarray(img)
        rows = img.shape[0]
        cols = img.shape[1]
        depth = img.shape[2]
        shape = (rows, cols, depth)
        img = img.tostring()
        example = tf.train.Example(
            features=tf.train.Features(
                feature={
                    "shape": _int64_feature(shape),
                    "data": _bytes_feature(img),
                    "label": _int64_feature(0)
                }
            )
        )
        if write:
            writers[shard].write(example.SerializeToString())

    print('%-40s\r' % 'Flushing data...', end='', flush=True)
    for writer in writers:
        writer.close()

    print('%-40s\r' % '', end='', flush=True)
    print('Added %d images.' % num_examples)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--category", type=str, help="LSUN category")
    parser.add_argument("--realnvp", action='store_true',
                        help="Use this flag to do realnvp preprocessing instead of our centre-crops")
    parser.add_argument("--res", type=int, default=256, help="Image size")
    parser.add_argument("--lsun_dir", type=str,
                        required=True, help="place of lsun dir")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()

    # LSUN
    # CATEGORIES = ["bedroom", "bridge", "church_outdoor", "classroom", "conference_room", "dining_room", "kitchen", "living_room"]
    base_tfr = hps.tfrecord_dir
    res = hps.res
    for realnvp in [False, True]:
        for category in ["tower", "church_outdoor", "bedroom"]:
            hps.realnvp = realnvp
            hps.category = category
            if realnvp:
                hps.tfrecord_dir = "%s_%s/%s" % (base_tfr,
                                                 "realnvp", hps.category)
            else:
                hps.tfrecord_dir = "%s/%s" % (base_tfr, hps.category)
            print(hps.realnvp, hps.category, hps.lsun_dir, hps.tfrecord_dir)
            imgs = {
                'bedroom': 3033042,
                'bridge': 818687,
                'church_outdoor': 126227,
                'classroom': 168103,
                'conference_room': 229069,
                'dining_room': 657571,
                'kitchen': 2212277,
                'living_room': 1315802,
                'restaurant': 626331,
                'tower': 708264
            }

            _NUM_IMAGES = {
                'train': imgs[hps.category],
                'validation': 300,
            }

            _NUM_SHARDS = {
                'train': 2560,
                'validation': 1,
            }

            _FILE = {
                'train': os.path.join(hps.lsun_dir, '%s_train' % hps.category),
                'validation': os.path.join(hps.lsun_dir, '%s_val' % hps.category)

            }

            if hps.realnvp:
                res = 96
            else:
                res = hps.res

            for split in ['validation', 'train']:
                fn_root = _FILE[split]
                tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
                total_imgs = _NUM_IMAGES[split]
                shards = _NUM_SHARDS[split]
                if not os.path.exists(tfrecord_dir):
                    os.mkdir(tfrecord_dir)
                dump(fn_root, tfrecord_dir, res, total_imgs,
                     shards, hps.write, hps.realnvp)
--------------------------------------------------------------------------------
/data_loaders/get_data.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf
import numpy as np
import glob

_FILES_SHUFFLE = 1024
_SHUFFLE_FACTOR = 4


def parse_tfrecord_tf(record, res, rnd_crop):
    features = tf.parse_single_example(record, features={
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([1], tf.int64)})
    # label is always 0 if unconditional
    # to get CelebA attr, add 'attr': tf.FixedLenFeature([40], tf.int64)
    data, label, shape = features['data'], features['label'], features['shape']
    label = tf.cast(tf.reshape(label, shape=[]), dtype=tf.int32)
    img = tf.decode_raw(data, tf.uint8)
    if rnd_crop:
        # For LSUN RealNVP only - random crop
        img = tf.reshape(img, shape)
        img = tf.random_crop(img, [res, res, 3])
    img = tf.reshape(img, [res, res, 3])
    return img, label  # to get CelebA attr, also return attr
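
# A hedged usage sketch (not part of the original file) of the pipeline
# defined below, in the style train.py would drive it under Horovod; the
# hyperparameter values here are illustrative assumptions:
#
#   import horovod.tensorflow as hvd
#   hvd.init()
#   sess = tf.Session()
#   train_itr, valid_itr, data_init = get_data(
#       sess, data_dir='celeba-tfr', shards=hvd.size(), rank=hvd.rank(),
#       pmap=16, fmap=1, n_batch_train=64, n_batch_test=50,
#       n_batch_init=256, resolution=256, rnd_crop=False)
#   x, y = sess.run(train_itr.get_next())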


def input_fn(tfr_file, shards, rank, pmap, fmap, n_batch, resolution, rnd_crop, is_training):
    files = tf.data.Dataset.list_files(tfr_file)
    if ('lsun' not in tfr_file) or is_training:
        # For 'lsun' validation, there is only one shard and each machine goes over the full dataset;
        # otherwise each worker works on a subset of the data
        files = files.shard(shards, rank)
    if is_training:
        # shuffle order of files in shard
        files = files.shuffle(buffer_size=_FILES_SHUFFLE)
    dset = files.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=fmap))
    if is_training:
        dset = dset.shuffle(buffer_size=n_batch * _SHUFFLE_FACTOR)
    dset = dset.repeat()
    dset = dset.map(lambda x: parse_tfrecord_tf(
        x, resolution, rnd_crop), num_parallel_calls=pmap)
    dset = dset.batch(n_batch)
    dset = dset.prefetch(1)
    itr = dset.make_one_shot_iterator()
    return itr


def get_tfr_file(data_dir, split, res_lg2):
    data_dir = os.path.join(data_dir, split)
    tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir))
    tfr_file = tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % (res_lg2)
    files = glob.glob(tfr_file)
    assert len(files) == int(files[0].split(
        "-")[-1].split(".")[0]), "Not all tfrecords files present at %s" % tfr_prefix
    return tfr_file


def get_data(sess, data_dir, shards, rank, pmap, fmap, n_batch_train, n_batch_test, n_batch_init, resolution, rnd_crop):
    assert resolution == 2 ** int(np.log2(resolution))

    train_file = get_tfr_file(data_dir, 'train', int(np.log2(resolution)))
    valid_file = get_tfr_file(data_dir, 'validation', int(np.log2(resolution)))

    train_itr = input_fn(train_file, shards, rank, pmap,
                         fmap, n_batch_train, resolution, rnd_crop, True)
    valid_itr = input_fn(valid_file, shards, rank, pmap,
                         fmap, n_batch_test, resolution, rnd_crop, False)

    data_init = make_batch(sess, train_itr, n_batch_train, n_batch_init)

    return train_itr, valid_itr, data_init


def make_batch(sess, itr, itr_batch_size, required_batch_size):
    ib, rb = itr_batch_size, required_batch_size
    #assert rb % ib == 0
    k = int(np.ceil(rb / ib))
    xs, ys = [], []
    data = itr.get_next()
    for i in range(k):
        x, y = sess.run(data)
        xs.append(x)
        ys.append(y)
    x, y = np.concatenate(xs)[:rb], np.concatenate(ys)[:rb]
    return {'x': x, 'y': y}
--------------------------------------------------------------------------------
/data_loaders/get_mnist_cifar.py:
--------------------------------------------------------------------------------
import numpy as np


def downsample(x, resolution):
    assert x.dtype == np.float32
    assert x.shape[1] % resolution == 0
    assert x.shape[2] % resolution == 0
    if x.shape[1] == x.shape[2] == resolution:
        return x
    s = x.shape
    x = np.reshape(x, [s[0], resolution, s[1] // resolution,
                       resolution, s[2] // resolution, s[3]])
    x = np.mean(x, (2, 4))
    return x


def x_to_uint8(x):
    x = np.clip(np.floor(x), 0, 255)
    return x.astype(np.uint8)


def shard(data, shards, rank):
    # Deterministic shards
    x, y = data
    assert x.shape[0] == y.shape[0]
    assert x.shape[0] % shards == 0
    assert 0 <= rank < shards
    size = x.shape[0] // shards
    ind = rank*size
    return x[ind:ind+size], y[ind:ind+size]
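
# A quick sanity check of downsample() above (illustrative, not in the
# original file): block-averaging a float32 batch from 32x32 to 16x16
# averages each non-overlapping 2x2 patch.
#
#   x = np.random.rand(8, 32, 32, 3).astype(np.float32)
#   assert downsample(x, 16).shape == (8, 16, 16, 3)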


def get_data(problem, shards, rank, data_augmentation_level, n_batch_train, n_batch_test, n_batch_init, resolution):
    if problem == 'mnist':
        from keras.datasets import mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
        # Pad with zeros to make 32x32
        x_train = np.lib.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'minimum')
        # Pad with zeros to make 32x32
        x_test = np.lib.pad(x_test, ((0, 0), (2, 2), (2, 2)), 'minimum')
        x_train = np.tile(np.reshape(x_train, (-1, 32, 32, 1)), (1, 1, 1, 3))
        x_test = np.tile(np.reshape(x_test, (-1, 32, 32, 1)), (1, 1, 1, 3))
    elif problem == 'cifar10':
        from keras.datasets import cifar10
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
    else:
        raise Exception()

    print('n_train:', x_train.shape[0], 'n_test:', x_test.shape[0])

    # Shard before any shuffling
    x_train, y_train = shard((x_train, y_train), shards, rank)
    x_test, y_test = shard((x_test, y_test), shards, rank)

    print('n_shard_train:', x_train.shape[0], 'n_shard_test:', x_test.shape[0])

    from keras.preprocessing.image import ImageDataGenerator
    datagen_test = ImageDataGenerator()
    if data_augmentation_level == 0:
        datagen_train = ImageDataGenerator()
    else:
        if problem == 'mnist':
            datagen_train = ImageDataGenerator(
                width_shift_range=0.1,
                height_shift_range=0.1
            )
        elif problem == 'cifar10':
            if data_augmentation_level == 1:
                datagen_train = ImageDataGenerator(
                    width_shift_range=0.1,
                    height_shift_range=0.1
                )
            elif data_augmentation_level == 2:
                datagen_train = ImageDataGenerator(
                    width_shift_range=0.1,
                    height_shift_range=0.1,
                    horizontal_flip=True,
                    rotation_range=15,  # degrees rotation
                    zoom_range=0.1,
                    shear_range=0.02,
                )
            else:
                raise Exception()
        else:
            raise Exception()

    datagen_train.fit(x_train)
    datagen_test.fit(x_test)
    train_flow = datagen_train.flow(x_train, y_train, n_batch_train)
    test_flow = datagen_test.flow(x_test, y_test, n_batch_test, shuffle=False)

    def make_iterator(flow, resolution):
        def iterator():
            x_full, y = flow.next()
            x_full = x_full.astype(np.float32)
            x = downsample(x_full, resolution)
            x = x_to_uint8(x)
            return x, y

        return iterator

    #init_iterator = make_iterator(train_flow, resolution)
    train_iterator = make_iterator(train_flow, resolution)
    test_iterator = make_iterator(test_flow, resolution)

    # Get data for initialization
    data_init = make_batch(train_iterator, n_batch_train, n_batch_init)

    return train_iterator, test_iterator, data_init


def make_batch(iterator, iterator_batch_size, required_batch_size):
    ib, rb = iterator_batch_size, required_batch_size
    #assert rb % ib == 0
    k = int(np.ceil(rb / ib))
    xs, ys = [], []
    for i in range(k):
        x, y = iterator()
        xs.append(x)
        ys.append(y)
    x, y = np.concatenate(xs)[:rb], np.concatenate(ys)[:rb]
    return {'x': x, 'y': y}
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
Code for the demo used in the blog post.

# Setup
Run `./script.sh`.

The script installs pip packages and downloads pretrained model weights,
manipulation vectors, and a facial-landmark detector for aligning input faces.

# Using the pre-trained model
To use the pre-trained CelebA-HQ model for encoding/decoding/manipulating images, use `model.py`.

If your image is not aligned, use `align_face.py` to align it first.

To create videos, see `videos.py`.

# Create manipulation vectors for an attribute of your choice
Scrape images from the internet for an attribute of your choice (say, red hair vs. no red hair). Then run `get_manipulators.py` to obtain the manipulation vectors.

To see how it was done for the CelebA-HQ dataset (which has 40 attributes),
first download the input images (x.npy), their attributes (attr.npy) and their encodings (z.npy)
```
curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/x.npy > x.npy
curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/attr.npy > attr.npy
curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/z.npy > z.npy
```
Then, run `get_manipulators.py`

# Run server and client for demo
To run the server for the demo, run `python server.py`.

To run the client, run `python -m http.server` (which starts a local HTTP server on port 8000) and open `0.0.0.0:8000/web` in your browser.

To test the client, upload `test/img.png`. You should see the aligned image and be able to move the sliders.

--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/__init__.py
--------------------------------------------------------------------------------
/demo/align_face.py:
--------------------------------------------------------------------------------
# OLD USAGE
# python align_faces.py --shape-predictor shape_predictor_68_face_landmarks.dat --image images/example_01.jpg

# import the necessary packages
from imutils.face_utils import FaceAligner
from PIL import Image
import numpy as np
# import argparse
import imutils
import dlib
import cv2

# construct the argument parser and parse the arguments
# ap = argparse.ArgumentParser()
# ap.add_argument("--shape-predictor", help="path to facial landmark predictor", default='shape_predictor_68_face_landmarks.dat')
# ap.add_argument("--input", help="path to input images", default='input_raw')
# ap.add_argument("--output", help="path to input images", default='input_aligned')
# args = vars(ap.parse_args())

# initialize dlib's face detector (HOG-based) and then create
# the facial landmark predictor and the face aligner
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')
fa = FaceAligner(predictor, desiredFaceWidth=256,
                 desiredLeftEye=(0.371, 0.480))


# Input: numpy array for image with RGB channels
# Output: (numpy array, face_found)
def align_face(img):
    img = img[:, :, ::-1]  # Convert from RGB to BGR format
    img = imutils.resize(img, width=800)

    # detect faces in the grayscale image
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 2)

    if len(rects) > 0:
        # align the face using facial landmarks
        align_img = fa.align(img, gray, rects[0])[:, :, ::-1]
        align_img = np.array(Image.fromarray(align_img).convert('RGB'))
        return align_img, True
    else:
        # No face found
        return None, False


# Input: img_path
# Output: aligned_img if face_found, else None
def align(img_path):
    img = Image.open(img_path)
    img = img.convert('RGB')  # if image is RGBA or Grayscale etc
    img = np.array(img)
    x, face_found = align_face(img)
    return x
--------------------------------------------------------------------------------
/demo/get_manipulators.py:
--------------------------------------------------------------------------------
# To get x.npy, attr.npy, and z.npy, run in command line
# curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/x.npy > x.npy
# curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/attr.npy > attr.npy
# curl https://openaipublic.azureedge.net/glow-demo/celeba-hq/z.npy > z.npy

import pickle
import numpy as np
import model
from align_face import align_face
from PIL import Image
from tqdm import tqdm


# Align input images
def get_aligned(img_paths):
    xs = []
    for img_path in img_paths:
        img = Image.open(img_path)
        img = img.convert('RGB')  # if image is RGBA or Grayscale etc
        img = np.array(img)
        x, face_found = align_face(img)
        if face_found:
            xs.append(x)
    x = np.concatenate(xs, axis=0)
    return x


# Input data. 30000 aligned images of shape 256x256x3
# x = get_aligned(img_paths)
x = np.load('x.npy')
print("Loaded inputs")


# Encode all inputs
def get_z(x):
    bs = 10
    x = x.reshape((-1, bs, 256, 256, 3))
    z = []
    for _x in tqdm(x):
        z.append(model.encode(_x))
    z = np.concatenate(z, axis=0)
    return z


# z = get_z(x)
z = np.load('z.npy')
print("Got encodings")

# Get manipulation vector based on attribute
attr = np.load('attr.npy')


def get_manipulator(index):
    z_pos = [z[i] for i in range(len(x)) if attr[i][index] == 1]
    z_neg = [z[i] for i in range(len(x)) if attr[i][index] == -1]

    z_pos = np.mean(z_pos, axis=0)
    z_neg = np.mean(z_neg, axis=0)
    return z_pos - z_neg


_TAGS = "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
_TAGS = _TAGS.split()

z_manipulate = [get_manipulator(i) for i in range(len(_TAGS))]
z_manipulate = 1.6 * np.array(z_manipulate, dtype=np.float32)
print("Got manipulators")
np.save('z_manipulate.npy', z_manipulate)
--------------------------------------------------------------------------------
/demo/model.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import time
from tqdm import tqdm
from PIL import Image
from threading import Lock

lock = Lock()


def get(name):
    return tf.get_default_graph().get_tensor_by_name('import/' + name + ':0')


def tensorflow_session():
    # Init session and params
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Pin GPU to local rank (one GPU per process)
    config.gpu_options.visible_device_list = str(0)
    sess = tf.Session(config=config)
    return sess


optimized = True
if optimized:
    # Optimized model. Twice as fast, because:
    # 1. we freeze the conditional network (the label is always 0)
    # 2. we use fused kernels
    import blocksparse
    graph_path = 'graph_optimized.pb'
    inputs = {
        'dec_eps_0': 'dec_eps_0',
        'dec_eps_1': 'dec_eps_1',
        'dec_eps_2': 'dec_eps_2',
        'dec_eps_3': 'dec_eps_3',
        'dec_eps_4': 'dec_eps_4',
        'dec_eps_5': 'dec_eps_5',
        'enc_x': 'input/enc_x',
    }
    outputs = {
        'dec_x': 'model_3/Cast_1',
        'enc_eps_0': 'model_2/pool0/truediv_1',
        'enc_eps_1': 'model_2/pool1/truediv_1',
        'enc_eps_2': 'model_2/pool2/truediv_1',
        'enc_eps_3': 'model_2/pool3/truediv_1',
        'enc_eps_4': 'model_2/pool4/truediv_1',
        'enc_eps_5': 'model_2/truediv_4'
    }

    def update_feed(feed_dict, bs):
        return feed_dict
else:
    graph_path = 'graph_unoptimized.pb'
    inputs = {
        'dec_eps_0': 'Placeholder',
        'dec_eps_1': 'Placeholder_1',
        'dec_eps_2': 'Placeholder_2',
        'dec_eps_3': 'Placeholder_3',
        'dec_eps_4': 'Placeholder_4',
        'dec_eps_5': 'Placeholder_5',
        'enc_x': 'input/image',
        'enc_x_d': 'input/downsampled_image',
        'enc_y': 'input/label'
    }
    outputs = {
        'dec_x': 'model_1/Cast_1',
        'enc_eps_0': 'model/pool0/truediv_1',
        'enc_eps_1': 'model/pool1/truediv_1',
        'enc_eps_2': 'model/pool2/truediv_1',
        'enc_eps_3': 'model/pool3/truediv_1',
        'enc_eps_4': 'model/pool4/truediv_1',
        'enc_eps_5': 'model/truediv_4'
    }

    def update_feed(feed_dict, bs):
        x_d = 128 * np.ones([bs, 128, 128, 3], dtype=np.uint8)
        y = np.zeros([bs], dtype=np.int32)
        feed_dict[enc_x_d] = x_d
        feed_dict[enc_y] = y
        return feed_dict

with tf.gfile.GFile(graph_path, 'rb') as f:
    graph_def_optimized = tf.GraphDef()
    graph_def_optimized.ParseFromString(f.read())

sess = tensorflow_session()
tf.import_graph_def(graph_def_optimized)

print("Loaded model")

n_eps = 6

# Encoder
enc_x = get(inputs['enc_x'])
enc_eps = [get(outputs['enc_eps_' + str(i)]) for i in range(n_eps)]
if not optimized:
    enc_x_d = get(inputs['enc_x_d'])
    enc_y = get(inputs['enc_y'])

# Decoder
dec_x = get(outputs['dec_x'])
dec_eps = [get(inputs['dec_eps_' + str(i)]) for i in range(n_eps)]

eps_shapes = [(128, 128, 6), (64, 64, 12), (32, 32, 24),
              (16, 16, 48), (8, 8, 96), (4, 4, 384)]
eps_sizes = [np.prod(e) for e in eps_shapes]
eps_size = 256 * 256 * 3
z_manipulate = np.load('z_manipulate.npy')

_TAGS = "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young"
_TAGS = _TAGS.split()

flip_tags = ['No_Beard', 'Young']
for tag in flip_tags:
    i = _TAGS.index(tag)
    z_manipulate[i] = -z_manipulate[i]

scale_tags = ['Narrow_Eyes']
for tag in scale_tags:
    i = _TAGS.index(tag)
    z_manipulate[i] = 1.2*z_manipulate[i]

z_sq_norms = np.sum(z_manipulate**2, axis=-1, keepdims=True)
z_proj = (z_manipulate / z_sq_norms).T


def run(sess, fetches, feed_dict):
    with lock:
        # Serialize TensorFlow calls so the average server response time per user stays low
        result = sess.run(fetches, feed_dict)
    return result


def flatten_eps(eps):
    # [BS, eps_size]
    return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)


def unflatten_eps(feps):
    index = 0
    eps = []
    bs = feps.shape[0]  # feps.size // eps_size
    for shape in eps_shapes:
        eps.append(np.reshape(
            feps[:, index: index+np.prod(shape)], (bs, *shape)))
        index += np.prod(shape)
    return eps


def encode(img):
    if len(img.shape) == 3:
        img = np.expand_dims(img, 0)
    bs = img.shape[0]
    assert img.shape[1:] == (256, 256, 3)
    feed_dict = {enc_x: img}

    update_feed(feed_dict, bs)  # For unoptimized model
    return flatten_eps(run(sess, enc_eps, feed_dict))


def decode(feps):
    if len(feps.shape) == 1:
        feps = np.expand_dims(feps, 0)
    bs = feps.shape[0]
    # assert len(eps) == n_eps
    # for i in range(n_eps):
    #     shape = (BATCH_SIZE, 128 // (2 ** i), 128 // (2 ** i), 6 * (2 ** i) * (2 ** (i == (n_eps - 1))))
    #     assert eps[i].shape == shape
    eps = unflatten_eps(feps)

    feed_dict = {}
    for i in range(n_eps):
        feed_dict[dec_eps[i]] = eps[i]

    update_feed(feed_dict, bs)  # For unoptimized model
    return run(sess, dec_x, feed_dict)


def project(z):
    return np.dot(z, z_proj)


def _manipulate(z, dz, alpha):
    z = z + alpha * dz
    return decode(z), z


def _manipulate_range(z, dz, points, scale):
    z_range = np.concatenate(
        [z + scale*(pt/(points - 1)) * dz for pt in range(0, points)], axis=0)
    return decode(z_range), z_range


# alpha from [0,1]
def mix(z1, z2, alpha):
    dz = (z2 - z1)
    return _manipulate(z1, dz, alpha)


def mix_range(z1, z2, points=5):
    dz = (z2 - z1)
    return _manipulate_range(z1, dz, points, 1.)
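
# A hedged end-to-end sketch (not in the original file), mirroring test()
# at the bottom of this module: encode an aligned 256x256 face, push the
# flattened latent along the 'Smiling' direction, and decode back to pixels.
#
#   img = np.array(Image.open('test/img.png'))  # aligned (256, 256, 3) uint8
#   z = encode(img)
#   smiling, _ = manipulate(z, _TAGS.index('Smiling'), 0.66)
#   Image.fromarray(smiling[0]).save('smile.png')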
204 |
205 |
206 | # alpha in [-1,1]
207 | def manipulate(z, typ, alpha):
208 | dz = z_manipulate[typ]
209 | return _manipulate(z, dz, alpha)
210 |
211 |
212 | def manipulate_all(z, typs, alphas):
213 | dz = 0.0
214 | for i in range(len(typs)):
215 | dz += alphas[i] * z_manipulate[typs[i]]
216 | return _manipulate(z, dz, 1.0)
217 |
218 |
219 | def manipulate_range(z, typ, points=5, scale=1):
220 | dz = z_manipulate[typ]
221 | return _manipulate_range(z - dz, 2*dz, points, scale)  # with scale=1, sweeps z - dz .. z + dz
222 |
223 |
224 | def random(bs=1, eps_std=0.7):
225 | feps = np.random.normal(scale=eps_std, size=[bs, eps_size])
226 | return decode(feps), feps
227 |
228 |
229 | def test():
230 | img = Image.open('test/img.png')
231 | img = np.reshape(np.array(img), [1, 256, 256, 3])
232 |
233 | # Encoding speed
234 | eps = encode(img)
235 | t = time.time()
236 | for _ in tqdm(range(10)):
237 | eps = encode(img)
238 | print("Encoding latency {} sec/img".format((time.time() - t) / (1 * 10)))
239 |
240 | # Decoding speed
241 | dec = decode(eps)
242 | t = time.time()
243 | for _ in tqdm(range(10)):
244 | dec = decode(eps)
245 | print("Decoding latency {} sec/img".format((time.time() - t) / (1 * 10)))
246 | img = Image.fromarray(dec[0])
247 | img.save('test/dec.png')
248 |
249 | # Manipulation
250 | dec, _ = manipulate(eps, _TAGS.index('Smiling'), 0.66)
251 | img = Image.fromarray(dec[0])
252 | img.save('test/smile.png')
253 |
254 |
255 | # Warm start: run one decode and one encode so the first user request is fast
256 | _img, _z = random(1)
257 | _z = encode(_img)
258 | print("Warm started tf model")
259 |
260 | if __name__ == '__main__':
261 | test()
262 |
--------------------------------------------------------------------------------
/demo/results/dec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/results/dec.png
--------------------------------------------------------------------------------
/demo/results/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/results/img.png
--------------------------------------------------------------------------------
/demo/results/smile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/results/smile.png
--------------------------------------------------------------------------------
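Taken together, `demo/model.py` exposes a small latent-arithmetic API. A minimal sketch (not part of the repo; it assumes the weights fetched by `script.sh` below are present) of driving it directly, the same way `server.py` and `videos.py` do:

```
# Hypothetical driver script; run from the demo/ directory.
from PIL import Image
import model

imgs, z = model.random(bs=2)                # two random faces plus their latents
mixed, _ = model.mix(z[0:1], z[1:2], 0.5)   # decode the latent-space midpoint
smiled, _ = model.manipulate(z[0:1], model._TAGS.index('Smiling'), 0.66)
Image.fromarray(mixed[0]).save('mixed.png')
Image.fromarray(smiled[0]).save('smiled.png')
```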
/demo/script.sh:
--------------------------------------------------------------------------------
1 | apt update && \
2 | apt install -y locales cmake libsm6 libxext6 libxrender-dev && \
3 | locale-gen en_US.UTF-8
4 |
5 | export LC_ALL=en_US.UTF-8
6 |
7 | # Pip packages for running the server and face alignment (dlib takes a while to install)
8 | pip install flask flask_cors tqdm opencv-python imutils dlib imageio
9 |
10 | # Get model weights
11 | curl https://openaipublic.azureedge.net/glow-demo/large3/graph_optimized.pb > graph_optimized.pb
12 |
13 | # Get manipulation vectors
14 | curl https://openaipublic.azureedge.net/glow-demo/z_manipulate.npy > z_manipulate.npy
15 |
16 | # Get facial landmarks for aligning input faces
17 | curl https://openaipublic.azureedge.net/glow-demo/shape_predictor_68_face_landmarks.dat > shape_predictor_68_face_landmarks.dat
18 |
19 | # Pip package for running the optimized model with fused kernels
20 | curl https://openaipublic.azureedge.net/glow-demo/blocksparse-1.0.0-py2.py3-none-any.whl > blocksparse-1.0.0-py2.py3-none-any.whl
21 | pip install blocksparse-1.0.0-py2.py3-none-any.whl
22 |
23 | # If blocksparse doesn't install, use the unoptimized model (and set optimized=False in model.py)
24 | # curl https://openaipublic.azureedge.net/glow-demo/large3/graph_unoptimized.pb > graph_unoptimized.pb
25 |
26 |
--------------------------------------------------------------------------------
/demo/server.py:
--------------------------------------------------------------------------------
1 | import model
2 | from align_face import align_face
3 | from flask import Flask, jsonify, request
4 | from flask_cors import CORS
5 |
6 | import base64
7 | import time
8 | import numpy as np
9 | from PIL import Image
10 | from io import BytesIO
11 | app = Flask(__name__)
12 | CORS(app)
13 |
14 |
15 | def deserialise_img(img_str):
16 | img = base64.b64decode(img_str.split(",")[-1])
17 | img = Image.open(BytesIO(img))
18 | img = img.convert('RGB')
19 | img = np.array(img)
20 | return img
21 |
22 |
23 | def serialise_img(arr):
24 | img = Image.fromarray(arr)
25 | buf = BytesIO()
26 | img.save(buf, format='PNG')
27 | buf = buf.getvalue()
28 | return "data:image/png;base64," + base64.b64encode(buf).decode('utf-8')
29 |
30 |
31 | def deserialise_nparr(arr_str):
32 | arr = np.loads(base64.b64decode(arr_str))  # pickle-based np.loads; removed in newer NumPy releases
33 | return np.array(arr, dtype=np.float32)
34 |
35 |
36 | def serialise_nparr(arr):
37 | arr = np.array(arr, dtype=np.float16)
38 | return base64.b64encode(arr.dumps()).decode('utf-8')
39 |
40 |
41 | def send(result):
42 | # img, z are batches, send as list of singles
43 | img, z = result
44 | # , z=list(map(serialise_nparr, z)))
45 | return jsonify(img=list(map(serialise_img, img)))
46 |
47 |
48 | def send_proj(result, proj):
49 | # img, z are batches, send as list of singles
50 | img, z = result
51 | return jsonify(face_found=True, img=list(map(serialise_img, img)), z=list(map(serialise_nparr, z)), proj=proj.tolist())
52 |
53 |
54 | def get(request, key):
55 | return request.get_json().get(key)
56 |
57 |
58 | def get_z(request, key):
59 | # z is a single point, batch it for use
60 | z = get(request, key)
61 | return np.expand_dims(deserialise_nparr(z), axis=0)
62 |
63 |
64 | @app.route('/')
65 | def hello_world():
66 | return 'Welcome to Glow!'
67 |
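# Illustrative request sequence (not part of the original file): with the
# server running on port 5050, align and encode a face, then manipulate it.
# 31 is _TAGS.index('Smiling') in model.py; "<z>" stands for the serialised
# latent string returned by /api/align_encode.
#
#   curl -X POST http://localhost:5050/api/align_encode \
#        -H 'Content-Type: application/json' -d '{"img": "<base64 image>"}'
#   curl -X POST http://localhost:5050/api/manipulate \
#        -H 'Content-Type: application/json' \
#        -d '{"z": "<z>", "typ": 31, "alpha": 0.66}'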
68 | # Align and encode image
69 | #
70 | # args
71 | # img: Image as base64 string
72 | #
73 | # returns
74 | # json: {'face_found': face_found, 'img': [base64 img], 'z': [serialised z]}
75 | @app.route('/api/align_encode', methods=['POST'])
76 | def align_encode():
77 | r = request
78 | img = get(r, 'img')
79 | # img = parse_img(img) if in jpg etc format
80 | img = deserialise_img(img)
81 | img, face_found = align_face(img)
82 | if face_found:
83 | img = np.reshape(img, [1, 256, 256, 3])
84 | print(img.shape)
85 | z = model.encode(img)
86 | proj = model.project(z)  # get projections. Not used
87 | result = img, z
88 | # jsonify(img=serialise_img(img), z=serialise_nparr(z))
89 | return send_proj(result, proj)
90 | else:
91 | return jsonify(face_found=False)
92 |
93 | # Manipulate single attribute
94 | #
95 | # args
96 | # z: Serialised np array for encoding of image
97 | # typ: int in [0,40), representing which attribute to manipulate
98 | # alpha: float, usually in [-1,1], representing how much to manipulate. 0 gives original image
99 | #
100 | # returns
101 | # json: {'img': [img]}
102 | @app.route('/api/manipulate', methods=['POST'])
103 | def manipulate():
104 | r = request
105 | z = get_z(r, 'z')
106 | typ = get(r, 'typ')
107 | alpha = get(r, 'alpha')
108 | result = model.manipulate(z, typ, alpha)
109 | return send(result)
110 |
111 | # Manipulate multiple attributes
112 | # typs: list of typ
113 | # alphas: list of corresponding alphas
114 | @app.route('/api/manipulate_all', methods=['POST'])
115 | def manipulate_all():
116 | r = request
117 | z = get_z(r, 'z')
118 | typs = get(r, 'typs')
119 | alphas = get(r, 'alphas')
120 | result = model.manipulate_all(z, typs, alphas)
121 | return send(result)
122 |
123 | # Mix two faces
124 | #
125 | # args
126 | # z1: Serialised np array for encoding of image 1
127 | # z2: Serialised np array for encoding of image 2
128 | # alpha: float in [0,1], representing how much to mix. 0.5 gives middle image
129 | #
130 | # returns
131 | # json: {'img': [img]}
132 | @app.route('/api/mix', methods=['POST'])
133 | def mix():
134 | r = request
135 | z1 = get_z(r, 'z1')
136 | z2 = get_z(r, 'z2')
137 | alpha = get(r, 'alpha')
138 | result = model.mix(z1, z2, alpha)
139 | return send(result)
140 |
141 | # Get random image
142 | @app.route('/api/random', methods=['POST'])
143 | def random():
144 | r = request
145 | bs = get(r, 'bs')
146 | result = model.random(bs)
147 | img, z = result
148 | proj = model.project(z)
149 | return send_proj(result, proj)
150 |
151 | # Extra functions
152 | @app.route('/api/test', methods=['POST'])
153 | def test():
154 | r = request
155 | z = get_z(r, 'z')
156 | typs = get(r, 'typs')
157 | alphas = get(r, 'alphas')  # values in [-1,1]; 0 gives the original image
158 | return jsonify(z="")
159 |
160 |
161 | @app.route('/api/manipulate_range', methods=['POST'])
162 | def manipulate_range():
163 | r = request
164 | z = get_z(r, 'z')
165 | typ = get(r, 'typ')
166 | points = get(r, 'points')
167 | result = model.manipulate_range(z, typ, points)
168 | return send(result)
169 |
170 |
171 | @app.route('/api/mix_range', methods=['POST'])
172 | def mix_range():
173 | r = request
174 | z1 = get_z(r, 'z1')
175 | z2 = get_z(r, 'z2')
176 | points = get(r, 'points')
177 | result = model.mix_range(z1, z2, points)
178 | return send(result)
179 |
180 | # Legacy
181 | # @app.route('/api/encode', methods=['POST'])
182 | # def encode():
183 | # t = time.time()
184 | # r = request
185 | # img = get(r, 'img')
186 | # print("Time to read from request", time.time() - t)
187 | # t = time.time()
188 | # img = deserialise_nparr(img)
189 | # print("Time to deserialise from request", time.time() - t)
190 | # t = time.time()
191 | # z = model.encode(img)
192 | # print("Time to encode", time.time() - t)
193 | # t = time.time()
194 | # json = jsonify(z=serialise_nparr(z))
195 | # print("Time to jsonify", time.time() - t)
196 | # return json
197 | #
198 | # @app.route('/api/decode', methods=['POST'])
199 | # def decode():
200 | # t = time.time()
201 | # r = request
202 | # z = get(r, 'z')
203 | # print("Time to read from request", time.time() - t)
204 | # t = time.time()
205 | # z = deserialise_nparr(z)
206 | # print("Time to deserialise from request", time.time() - t)
207 | # t = time.time()
208 | # img = model.decode(z)
209 | # print("Time to decode", time.time() - t)
210 | # t = time.time()
211 | # json = jsonify(img=serialise_img(img))
212 | # print("Time to jsonify", time.time() - t)
213 | # return json
214 | #
215 | # @app.route('/api/align', methods=['POST'])
216 | # 
def align(): 217 | # r = request 218 | # img = get(r, 'img') 219 | # # img = parse_img(img) if in jpg etc format 220 | # img = deserialise_img(img) 221 | # img = align_face(img) 222 | # return jsonify(img=serialise_img(img)) 223 | 224 | 225 | # FaceOff! Use for manipulation and blending faces 226 | if __name__ == '__main__': 227 | print('Running Flask app...') 228 | app.run(host='0.0.0.0', port=5050) 229 | -------------------------------------------------------------------------------- /demo/test/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/test/img.png -------------------------------------------------------------------------------- /demo/videos.py: -------------------------------------------------------------------------------- 1 | from model import encode, manipulate_range, mix_range 2 | from align_face import align 3 | import numpy as np 4 | from imageio import mimwrite, get_writer 5 | from PIL import Image 6 | 7 | _TAGS = "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young" 8 | _TAGS = _TAGS.split() 9 | 10 | # Reshape multiple images to a grid 11 | # def reshape(img,h,w,mirror=False): 12 | # img = np.reshape(img, [h, w, 256, 256, 3]) 13 | # img = np.transpose(img, [0,2,1,3,4]) 14 | # if mirror: 15 | # img = img[:,:,::-1,:,:] ## reflect width wise 16 | # img = np.reshape(img, [h*256, w*256, 3]) 17 | # return img 18 | 19 | def resize(arr, res, ratio=1.): 20 | shape = (int(res*ratio),res) 21 | return np.array(Image.fromarray(arr).resize(shape, resample=Image.ANTIALIAS)) 22 | 23 | def make_loop(imgs, gap=10): 24 | return [imgs[0]]*gap + imgs + [imgs[-1]]*2*gap + imgs[::-1] + [imgs[0]]*gap 25 | 26 | def write(imgs, name, fps): 27 | writer = get_writer(name, fps=fps, quality=6) 28 | for t in range(len(imgs)): 29 | writer.append_data(imgs[t]) 30 | writer.close() 31 | 32 | def make_video(name, imgs, fps=30, res=1024): 33 | imgs = [resize(img, res) for img in imgs] 34 | write(imgs, name + '.mp4', fps) 35 | imgs = make_loop(imgs) 36 | write(imgs, name + '_loop.mp4', fps) 37 | return 38 | 39 | def get_manipulations(name, typ, points=46, scale=1.0): 40 | img = align(name) 41 | z = encode(img) 42 | imgs, _ = manipulate_range(z, typ, points, scale) 43 | return imgs 44 | 45 | def get_mixs(name1, name2, points=46): 46 | img1 = align(name1) 47 | img2 = align(name2) 48 | z1 = encode(img1) 49 | z2 = encode(img2) 50 | imgs, _ = mix_range(z1, z2, points) 51 | return imgs 52 | 53 | if __name__ == '__main__': 54 | n1 = 'web/media/geoff.png' 55 | n2 = 'web/media/leo.png' 56 | tag = 'Smiling' 57 | print('Making smiling video') 58 | imgs_manipulated = get_manipulations(n1, _TAGS.index(tag)) 59 | print('Saving smiling video') 60 | make_video('geoff_%s' % tag, imgs_manipulated) 61 | print('Making mixing video') 62 | imgs_mixed = get_mixs(n1, n2) 63 | print('Saving mixing video') 64 | make_video('geoff_leo', imgs_mixed) -------------------------------------------------------------------------------- /demo/web/canvas2image.js: 
--------------------------------------------------------------------------------
1 | // https://github.com/hongru/canvas2image
2 |
3 | /**
4 | * convert canvas to image
5 | * and save the image file
6 | */
7 |
8 | var Canvas2Image = function () {
9 |
10 | // feature detection: check what the browser supports
11 | var $support = function () {
12 | var canvas = document.createElement('canvas'),
13 | ctx = canvas.getContext('2d');
14 |
15 | return {
16 | canvas: !!ctx,
17 | imageData: !!ctx.getImageData,
18 | dataURL: !!canvas.toDataURL,
19 | btoa: !!window.btoa
20 | };
21 | }();
22 |
23 | var downloadMime = 'image/octet-stream';
24 |
25 | function scaleCanvas (canvas, width, height) {
26 | var w = canvas.width,
27 | h = canvas.height;
28 | if (width == undefined) {
29 | width = w;
30 | }
31 | if (height == undefined) {
32 | height = h;
33 | }
34 |
35 | if (window.GlowDemoCanvasCropRect) {
36 | let r = window.GlowDemoCanvasCropRect;
37 |
38 | let rx = getCropRectParam(r.x, 0);
39 | let ry = getCropRectParam(r.y, 0);
40 | let rw = getCropRectParam(r.width, canvas.width);
41 | let rh = getCropRectParam(r.height, canvas.height);
42 | console.log([rx, ry, rw, rh]);
43 | console.log('cropped image');
44 |
45 | var retCanvas = document.createElement('canvas');
46 | var retCtx = retCanvas.getContext('2d');
47 | retCanvas.width = rw;
48 | retCanvas.height = rh;
49 | retCtx.drawImage(canvas, rx, ry, rw, rh, 0, 0, rw, rh);
50 | return retCanvas;
51 | }
52 | else {
53 | console.log('will NOT crop image');
54 | var retCanvas = document.createElement('canvas');
55 | var retCtx = retCanvas.getContext('2d');
56 | retCanvas.width = width;
57 | retCanvas.height = height;
58 | retCtx.drawImage(canvas, 0, 0, w, h, 0, 0, width, height);
59 | return retCanvas;
60 | }
61 | }
62 |
63 | function getDataURL (canvas, type, width, height) {
64 | canvas = scaleCanvas(canvas, width, height);
65 | return canvas.toDataURL(type);
66 | }
67 |
68 | function saveFile (strData) {
69 | if (window.GlowDemoDownloadFileName) {
70 | var element = document.createElement('a');
71 | element.setAttribute('href', strData);
72 | element.setAttribute('download', window.GlowDemoDownloadFileName);
73 |
74 | element.style.display = 'none';
75 | document.body.appendChild(element);
76 |
77 | element.click();
78 |
79 | document.body.removeChild(element);
80 | }
81 |
82 | //document.location.href = strData;
83 | }
84 |
85 | function genImage(strData) {
86 | var img = document.createElement('img');
87 | img.src = strData;
88 | return img;
89 | }
90 | function fixType (type) {
91 | type = type.toLowerCase().replace(/jpg/i, 'jpeg');
92 | var r = type.match(/png|jpeg|bmp|gif/)[0];
93 | return 'image/' + r;
94 | }
95 | function encodeData (data) {
96 | if (!window.btoa) { throw 'btoa undefined' }
97 | var str = '';
98 | if (typeof data == 'string') {
99 | str = data;
100 | } else {
101 | for (var i = 0; i < data.length; i ++) {
102 | str += String.fromCharCode(data[i]);
103 | }
104 | }
105 |
106 | return btoa(str);
107 | }
108 | function getImageData (canvas) {
109 | console.log(window.GlowDemoCanvasCropRect);
110 | if (window.GlowDemoCanvasCropRect) {
111 | let r = window.GlowDemoCanvasCropRect;
112 |
113 | let x = getCropRectParam(r.x, 0);
114 | let y = getCropRectParam(r.y, 0);
115 | let w = getCropRectParam(r.width, canvas.width);
116 | let h = getCropRectParam(r.height, canvas.height);
117 | console.log([x, y, w, h]);
118 |
119 | return canvas.getContext('2d').getImageData(x, y, w, h); // crop using the resolved rect params
120 | }
121 | else {
122 | var w = canvas.width,
123 | h = canvas.height;
124 | return canvas.getContext('2d').getImageData(0, 0, w, h);
125 | }
126 | }
127 | function getCropRectParam(param, autoParam) {
128 | if (isFunction(param)) {
129 | return param(autoParam);
130 | }
131 | else if (param === "auto") {
132 | return autoParam;
133 | }
134 | else {
135 | return param;
136 | }
137 | }
138 | function isFunction(obj) {
139 | return !!(obj && obj.constructor && obj.call && obj.apply);
140 | }
141 | function makeURI (strData, type) {
142 | return 'data:' + type + ';base64,' + strData;
143 | }
144 |
145 |
146 | /**
147 | * create bitmap image
148 | * (builds the BMP file header and pixel data according to the format rules)
149 | */
150 | var genBitmapImage = function (oData) {
151 |
152 | //
153 | // BITMAPFILEHEADER: http://msdn.microsoft.com/en-us/library/windows/desktop/dd183374(v=vs.85).aspx
154 | // BITMAPINFOHEADER: http://msdn.microsoft.com/en-us/library/dd183376.aspx
155 | //
156 |
157 | var biWidth = oData.width;
158 | var biHeight = oData.height;
159 | var biSizeImage = biWidth * biHeight * 3;
160 | var bfSize = biSizeImage + 54; // total header size = 54 bytes
161 |
162 | //
163 | // typedef struct tagBITMAPFILEHEADER {
164 | // WORD bfType;
165 | // DWORD bfSize;
166 | // WORD bfReserved1;
167 | // WORD bfReserved2;
168 | // DWORD bfOffBits;
169 | // } BITMAPFILEHEADER;
170 | //
171 | var BITMAPFILEHEADER = [
172 | // WORD bfType -- The file type signature; must be "BM"
173 | 0x42, 0x4D,
174 | // DWORD bfSize -- The size, in bytes, of the bitmap file
175 | bfSize & 0xff, bfSize >> 8 & 0xff, bfSize >> 16 & 0xff, bfSize >> 24 & 0xff,
176 | // WORD bfReserved1 -- Reserved; must be zero
177 | 0, 0,
178 | // WORD bfReserved2 -- Reserved; must be zero
179 | 0, 0,
180 | // DWORD bfOffBits -- The offset, in bytes, from the beginning of the BITMAPFILEHEADER structure to the bitmap bits.
181 | 54, 0, 0, 0
182 | ];
183 |
184 | //
185 | // typedef struct tagBITMAPINFOHEADER {
186 | // DWORD biSize;
187 | // LONG biWidth;
188 | // LONG biHeight;
189 | // WORD biPlanes;
190 | // WORD biBitCount;
191 | // DWORD biCompression;
192 | // DWORD biSizeImage;
193 | // LONG biXPelsPerMeter;
194 | // LONG biYPelsPerMeter;
195 | // DWORD biClrUsed;
196 | // DWORD biClrImportant;
197 | // } BITMAPINFOHEADER, *PBITMAPINFOHEADER;
198 | //
199 | var BITMAPINFOHEADER = [
200 | // DWORD biSize -- The number of bytes required by the structure
201 | 40, 0, 0, 0,
202 | // LONG biWidth -- The width of the bitmap, in pixels
203 | biWidth & 0xff, biWidth >> 8 & 0xff, biWidth >> 16 & 0xff, biWidth >> 24 & 0xff,
204 | // LONG biHeight -- The height of the bitmap, in pixels
205 | biHeight & 0xff, biHeight >> 8 & 0xff, biHeight >> 16 & 0xff, biHeight >> 24 & 0xff,
206 | // WORD biPlanes -- The number of planes for the target device. This value must be set to 1
207 | 1, 0,
208 | // WORD biBitCount -- The number of bits-per-pixel, 24 bits-per-pixel -- the bitmap
209 | // has a maximum of 2^24 colors (16777216, Truecolor)
210 | 24, 0,
211 | // DWORD biCompression -- The type of compression, BI_RGB (code 0) -- uncompressed
212 | 0, 0, 0, 0,
213 | // DWORD biSizeImage -- The size, in bytes, of the image. 
This may be set to zero for BI_RGB bitmaps 214 | biSizeImage & 0xff, biSizeImage >> 8 & 0xff, biSizeImage >> 16 & 0xff, biSizeImage >> 24 & 0xff, 215 | // LONG biXPelsPerMeter, unused 216 | 0,0,0,0, 217 | // LONG biYPelsPerMeter, unused 218 | 0,0,0,0, 219 | // DWORD biClrUsed, the number of color indexes of palette, unused 220 | 0,0,0,0, 221 | // DWORD biClrImportant, unused 222 | 0,0,0,0 223 | ]; 224 | 225 | var iPadding = (4 - ((biWidth * 3) % 4)) % 4; 226 | 227 | var aImgData = oData.data; 228 | 229 | var strPixelData = ''; 230 | var biWidth4 = biWidth<<2; 231 | var y = biHeight; 232 | var fromCharCode = String.fromCharCode; 233 | 234 | do { 235 | var iOffsetY = biWidth4*(y-1); 236 | var strPixelRow = ''; 237 | for (var x = 0; x < biWidth; x++) { 238 | var iOffsetX = x<<2; 239 | strPixelRow += fromCharCode(aImgData[iOffsetY+iOffsetX+2]) + 240 | fromCharCode(aImgData[iOffsetY+iOffsetX+1]) + 241 | fromCharCode(aImgData[iOffsetY+iOffsetX]); 242 | } 243 | 244 | for (var c = 0; c < iPadding; c++) { 245 | strPixelRow += String.fromCharCode(0); 246 | } 247 | 248 | strPixelData += strPixelRow; 249 | } while (--y); 250 | 251 | var strEncoded = encodeData(BITMAPFILEHEADER.concat(BITMAPINFOHEADER)) + encodeData(strPixelData); 252 | 253 | return strEncoded; 254 | }; 255 | 256 | /** 257 | * saveAsImage 258 | * @param canvasElement 259 | * @param {String} image type 260 | * @param {Number} [optional] png width 261 | * @param {Number} [optional] png height 262 | */ 263 | var saveAsImage = function (canvas, width, height, type) { 264 | if ($support.canvas && $support.dataURL) { 265 | if (typeof canvas == "string") { canvas = document.getElementById(canvas); } 266 | if (type == undefined) { type = 'png'; } 267 | type = fixType(type); 268 | if (/bmp/.test(type)) { 269 | var data = getImageData(scaleCanvas(canvas, width, height)); 270 | var strData = genBitmapImage(data); 271 | saveFile(makeURI(strData, downloadMime)); 272 | } else { 273 | var strData = getDataURL(canvas, type, width, height); 274 | saveFile(strData.replace(type, downloadMime)); 275 | } 276 | } 277 | }; 278 | 279 | var convertToImage = function (canvas, width, height, type) { 280 | if ($support.canvas && $support.dataURL) { 281 | if (typeof canvas == "string") { canvas = document.getElementById(canvas); } 282 | if (type == undefined) { type = 'png'; } 283 | type = fixType(type); 284 | 285 | if (/bmp/.test(type)) { 286 | var data = getImageData(scaleCanvas(canvas, width, height)); 287 | var strData = genBitmapImage(data); 288 | return genImage(makeURI(strData, 'image/bmp')); 289 | } else { 290 | var strData = getDataURL(canvas, type, width, height); 291 | return genImage(strData); 292 | } 293 | } 294 | }; 295 | 296 | 297 | 298 | return { 299 | saveAsImage: saveAsImage, 300 | saveAsPNG: function (canvas, width, height) { 301 | return saveAsImage(canvas, width, height, 'png'); 302 | }, 303 | saveAsJPEG: function (canvas, width, height) { 304 | return saveAsImage(canvas, width, height, 'jpeg'); 305 | }, 306 | saveAsGIF: function (canvas, width, height) { 307 | return saveAsImage(canvas, width, height, 'gif'); 308 | }, 309 | saveAsBMP: function (canvas, width, height) { 310 | return saveAsImage(canvas, width, height, 'bmp'); 311 | }, 312 | 313 | convertToImage: convertToImage, 314 | convertToPNG: function (canvas, width, height) { 315 | return convertToImage(canvas, width, height, 'png'); 316 | }, 317 | convertToJPEG: function (canvas, width, height) { 318 | return convertToImage(canvas, width, height, 'jpeg'); 319 | }, 320 | convertToGIF: 
function (canvas, width, height) { 321 | return convertToImage(canvas, width, height, 'gif'); 322 | }, 323 | convertToBMP: function (canvas, width, height) { 324 | return convertToImage(canvas, width, height, 'bmp'); 325 | } 326 | }; 327 | 328 | }(); 329 | -------------------------------------------------------------------------------- /demo/web/glowDemo.css: -------------------------------------------------------------------------------- 1 | /* glowDemo.css 2 | * 3 | * CSS driving the Glow paper face-mixing demo. 4 | */ 5 | 6 | /* Tabs */ 7 | 8 | .GlowDemo_TabLabelContainer { 9 | display: table; 10 | margin: auto; 11 | } 12 | 13 | .GlowDemo_TabLabel { 14 | cursor: pointer; 15 | display: table-cell; 16 | padding: 10px; 17 | background-color: #fff; 18 | border-radius: 5px 5px 5px 5px; 19 | border: 1px solid #4bacff; 20 | color: #0b8dff; 21 | -moz-user-select: none; 22 | -webkit-user-select: none; 23 | -ms-user-select: none; 24 | min-width: 90px; 25 | text-align: center; 26 | } 27 | 28 | .GlowDemo_TabLabel:hover { 29 | background-color: #c3e3ff; 30 | } 31 | 32 | .GlowDemo_TabLabel:active, .GlowDemo_ActiveTab { 33 | color: #fff; 34 | background-color: #4bacff; 35 | } 36 | 37 | .GlowDemo_TabLabel:hover.GlowDemo_ActiveTab { 38 | background-color: #6dbbff; 39 | } 40 | 41 | /* Demo Container */ 42 | 43 | .GlowDemo { 44 | box-sizing: initial; 45 | font-size: initial; 46 | line-height: initial; 47 | } 48 | 49 | .GlowDemo img { 50 | display: initial; 51 | padding: initial; 52 | position: absolute; 53 | left: initial; 54 | transform: initial; 55 | } 56 | 57 | .GlowDemo_Container { 58 | margin-top: 1em; 59 | margin-left: auto; 60 | margin-right: auto; 61 | width: 375px; 62 | transform: translateX(-16px); 63 | } 64 | 65 | /* Face Sliders Demo */ 66 | 67 | .GlowDemo_FaceSlidersDemo { 68 | /* width: 404px; */ 69 | overflow: hidden; 70 | display: table; 71 | margin: 1em auto 3em; 72 | } 73 | 74 | /* Face Slider Mode */ 75 | 76 | .GlowDemo_SelectorAndOutput { 77 | display: table; 78 | } 79 | 80 | /* Image Selector (Input) */ 81 | 82 | .GlowDemo_InputLabel { 83 | display: block; 84 | text-transform: uppercase; 85 | margin: auto; 86 | color: #747c9f; 87 | font-size: 0.8em; 88 | padding-left: 5px; 89 | } 90 | 91 | .GlowDemo_SelectorFrame { 92 | display: table-cell; 93 | padding-right: 5px; 94 | } 95 | 96 | .GlowDemo_ImageSelectorNoFaceFoundOverlay { 97 | position: absolute; 98 | z-index: 7; 99 | margin-left: -16px; 100 | margin-top: -154px; 101 | width: 14em; 102 | background-color: #fbb4d7; 103 | border-radius: 4px; 104 | padding: 0.5em; 105 | box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.205); 106 | } 107 | 108 | .GlowDemo_ImageChoice { 109 | cursor: pointer; 110 | width: 59px; 111 | height: 59px; 112 | position: absolute; 113 | } 114 | 115 | /* Input & Output Images */ 116 | 117 | .GlowDemo_ImageFrame { 118 | width: 178px; 119 | height: 178px; 120 | margin: auto; 121 | 122 | padding: 2px; 123 | box-shadow: 0px 3px 6px #ddd; 124 | 125 | background-color: #f9f9f9; 126 | } 127 | 128 | /* Output Images */ 129 | 130 | .GlowDemo_OutputImage { 131 | width: 178px !important; 132 | height: 178px !important; 133 | } 134 | 135 | .GlowDemo_OutputImageFrame { 136 | /* background-color: #d1d1d1 */ 137 | } 138 | 139 | .GlowDemo_OutputLabel { 140 | display: block; 141 | text-transform: uppercase; 142 | margin: auto; 143 | color: #747c9f; 144 | font-size: 0.8em; 145 | text-align: right; 146 | padding-right: 5px; 147 | } 148 | 149 | .GlowDemo_OutputHider { 150 | z-index: 2; 151 | color: #fff; 152 | background-color: white; 153 | 
width: 187px; 154 | height: 210px; 155 | position: absolute; 156 | margin-top: -200px; 157 | margin-left: -5px; 158 | } 159 | 160 | .GlowDemo_MixingOutputFrame .GlowDemo_OutputHider { 161 | width: 248px; 162 | } 163 | 164 | .GlowDemo_FadeButton { 165 | width: 35px !important; 166 | height: 35px !important; 167 | position: absolute; 168 | border-radius: 4px; 169 | padding: 5px !important; 170 | box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.48); 171 | z-index: 5; 172 | } 173 | 174 | .GlowDemo_DownloadButton { 175 | cursor: pointer; 176 | margin-left: 128px; 177 | margin-top: 129px; 178 | background-color: #ffffffe6; 179 | } 180 | 181 | .GlowDemo_DownloadButton:hover { 182 | margin-left: 127px; 183 | margin-top: 128px; 184 | padding: 6px !important; 185 | background-color: #ffffff; 186 | } 187 | 188 | .GlowDemo_DownloadButton:active { 189 | margin-left: 130px; 190 | margin-top: 131px; 191 | padding: 3px !important; 192 | background-color: #ffffff; 193 | } 194 | 195 | .GlowDemo_UserImageEditButton { 196 | cursor: pointer; 197 | margin-left: 128px; 198 | margin-top: 129px; 199 | background-color: #ffffffe6; 200 | } 201 | 202 | .GlowDemo_UserImageEditButton:hover { 203 | margin-left: 127px; 204 | margin-top: 128px; 205 | padding: 6px !important; 206 | background-color: #ffffff; 207 | } 208 | 209 | .GlowDemo_UserImageEditButton:active { 210 | margin-left: 130px; 211 | margin-top: 131px; 212 | padding: 3px !important; 213 | background-color: #ffffff; 214 | } 215 | 216 | /* Feature Sliders */ 217 | 218 | .GlowDemo_SliderFrame { 219 | width: 160px; 220 | margin: auto; 221 | display: table; 222 | padding-top: 2.7em; 223 | } 224 | 225 | .GlowDemo_MixingSliderFrame { 226 | padding-top: 0.8em; 227 | } 228 | 229 | .GlowDemo_FaceSliderContainer { 230 | /* display: table-row; */ 231 | } 232 | 233 | .GlowDemo_FaceSliderLabel { 234 | display: table-cell; 235 | vertical-align: middle; 236 | } 237 | 238 | .GlowDemo_FaceSliderLabel p { 239 | margin: 0.5em 0.5em; 240 | min-width: 8em; 241 | font-size: 1.1em; 242 | line-height: 1.1em; 243 | } 244 | 245 | .GlowDemo_FaceSlider { 246 | /* display: inline-block; */ 247 | /* margin: 10px 0 0px 0 !important; */ 248 | width: 190px !important; 249 | vertical-align: middle; 250 | display: table-cell; 251 | } 252 | 253 | .GlowDemo_SliderHider { 254 | position: absolute; 255 | height: 43px; 256 | width: 365px; 257 | color: #fff; 258 | background-color: #ffffff; 259 | /* margin-top: -30px; */ 260 | margin-left: -365px; 261 | } 262 | 263 | /* Face Mixing Demo */ 264 | 265 | .GlowDemo_FaceMixingDemo { 266 | /* width: 404px; */ 267 | overflow: hidden; 268 | display: table; 269 | margin: 1em auto 3em; 270 | } 271 | 272 | .GlowDemo_LeftInputLabel { 273 | display: block; 274 | text-transform: uppercase; 275 | margin: auto; 276 | color: #747c9f; 277 | font-size: 0.8em; 278 | padding-left: 5px; 279 | } 280 | 281 | .GlowDemo_RightInputLabel { 282 | display: block; 283 | text-transform: uppercase; 284 | margin: auto; 285 | color: #747c9f; 286 | font-size: 0.8em; 287 | text-align: right; 288 | padding-right: 5px; 289 | } 290 | 291 | .GlowDemo_MixingInputImagesContainer { 292 | display: table; 293 | } 294 | 295 | .GlowDemo_MixingSelectorFrame { 296 | display: table-cell; 297 | } 298 | 299 | .GlowDemo_MixingSelectorFrameLeft { 300 | padding-right: 5px; 301 | } 302 | 303 | .GlowDemo_MixingSelectorFrameRight { 304 | /* display: table-cell; */ 305 | } 306 | 307 | .GlowDemo_OutputAndMixingSliderContainer { 308 | display: table; 309 | margin: auto; 310 | } 311 | 312 | .GlowDemo_MixingOutputFrame { 
313 | margin-top: 15px; 314 | } 315 | 316 | .GlowDemo_MixingOutputLabel { 317 | display: block; 318 | text-transform: uppercase; 319 | margin: auto; 320 | color: #747c9f; 321 | font-size: 0.8em; 322 | text-align: center; 323 | } 324 | 325 | .GlowDemo_MixingSliderLabel { 326 | display: table; 327 | text-transform: uppercase; 328 | margin: auto; 329 | color: #747c9f; 330 | font-size: 0.8em; 331 | text-align: center; 332 | } 333 | 334 | .GlowDemo_MixingSlider { 335 | display: table; 336 | } 337 | 338 | .GlowDemo_MixingSliderHider { 339 | margin-top: -69px; 340 | height: 60px; 341 | } 342 | 343 | /* Loading Visuals */ 344 | 345 | .GlowDemo_LoadingVisual { 346 | position: absolute; 347 | display: block; 348 | z-index: 5; 349 | width: 60px; 350 | height: 60px; 351 | margin-top: 59px; 352 | margin-left: 60px; 353 | } 354 | 355 | /* Hints */ 356 | 357 | .GlowDemo_Hint { 358 | position: absolute; 359 | z-index: 4; 360 | font-size: 1em; 361 | line-height: 1.2em; 362 | background-color: #ffffffe6; 363 | border-radius: 4px; 364 | padding: 0.1em 0.4em; 365 | color: #818181; 366 | -moz-user-select: none; 367 | -webkit-user-select: none; 368 | -ms-user-select: none; 369 | /* box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.07); */ 370 | } 371 | 372 | .GlowDemo_SelectorHint { 373 | max-width: 220px; 374 | margin-left: -173px; 375 | margin-top: 7px; 376 | } 377 | 378 | .GlowDemo_DownloadHint { 379 | max-width: 220px; 380 | margin-left: 23px; 381 | margin-top: 7px; 382 | background-color: #ffffffe6; 383 | border-radius: 4px; 384 | padding: 0.1em 0.4em; 385 | color: #818181; 386 | /* box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.07); */ 387 | -moz-user-select: none; 388 | -webkit-user-select: none; 389 | -ms-user-select: none; 390 | font-size: 1em; 391 | } 392 | 393 | .GlowDemo_MixingHint { 394 | max-width: 324px; 395 | margin-left: -36px; 396 | margin-top: 84px; 397 | } 398 | 399 | input[type=range] { 400 | /*removes default webkit styles*/ 401 | -webkit-appearance: none; 402 | 403 | /*fix for FF unable to apply focus style bug */ 404 | border: 1px solid white; 405 | 406 | /*required for proper track sizing in FF*/ 407 | width: 190px; 408 | height: 35px; 409 | 410 | /*centering*/ 411 | vertical-align: middle; 412 | display: table-cell; 413 | } 414 | input[type=range]::-webkit-slider-runnable-track { 415 | width: 190px; 416 | height: 3px; 417 | background: #ddd; 418 | border: none; 419 | border-radius: 3px; 420 | } 421 | input[type=range]::-webkit-slider-thumb { 422 | -webkit-appearance: none; 423 | border: none; 424 | height: 16px; 425 | width: 16px; 426 | border-radius: 50%; 427 | background: #4bacff; 428 | margin-top: -6px; 429 | } 430 | input[type=range]:focus { 431 | outline: none; 432 | } 433 | input[type=range]:focus::-webkit-slider-runnable-track { 434 | background: #ccc; 435 | } 436 | 437 | input[type=range]::-moz-range-track { 438 | width: 190px; 439 | height: 3px; 440 | background: #ddd; 441 | border: none; 442 | border-radius: 3px; 443 | } 444 | input[type=range]::-moz-range-thumb { 445 | border: none; 446 | height: 16px; 447 | width: 16px; 448 | border-radius: 50%; 449 | background: #4bacff; 450 | } 451 | 452 | /*hide the outline behind the border*/ 453 | input[type=range]:-moz-focusring{ 454 | outline: 1px solid white; 455 | outline-offset: -1px; 456 | } 457 | 458 | input[type=range]::-ms-track { 459 | width: 190px; 460 | height: 3px; 461 | 462 | /*remove bg colour from the track, we'll use ms-fill-lower and ms-fill-upper instead */ 463 | background: transparent; 464 | 465 | /*leave room for the larger thumb 
to overflow with a transparent border */
466 | border-color: transparent;
467 | border-width: 6px 0;
468 |
469 | /*remove default tick marks*/
470 | color: transparent;
471 | }
472 | input[type=range]::-ms-fill-lower {
473 | background: #777;
474 | border-radius: 10px;
475 | }
476 | input[type=range]::-ms-fill-upper {
477 | background: #ddd;
478 | border-radius: 10px;
479 | }
480 | input[type=range]::-ms-thumb {
481 | border: none;
482 | height: 16px;
483 | width: 16px;
484 | border-radius: 50%;
485 | background: #4bacff;
486 | }
487 | input[type=range]:focus::-ms-fill-lower {
488 | background: #888;
489 | }
490 | input[type=range]:focus::-ms-fill-upper {
491 | background: #ccc;
492 | }
493 |
--------------------------------------------------------------------------------
/demo/web/index.html:
--------------------------------------------------------------------------------
Glow: Better Reversible Generative Models

This mock-up mimics the look of the in-progress article to inform a design
that embeds the demo into the article. The relevant assets just need to be
migrated into the final article.

We’ve developed Glow, a new type of generative model which uses
invertible 1x1 convolutions to create rich, synthetic models of data,
automatically discovering features we can manipulate. The model extends
previous work on reversible generative models, simplifying the
architecture and leading to substantially better results. We’re releasing
code for the model and an online visualization tool so people can explore
and build on these results.
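The "invertible 1x1 convolution" the page refers to is compact enough to state directly. The sketch below is illustrative (it is not from the repo) and shows, in numpy, the forward pass, its exact inverse, and the log-determinant term the layer contributes to the log-likelihood:

```
import numpy as np

def invertible_1x1_conv(x, weight):
    # x: [batch, height, width, channels] activations;
    # weight: [channels, channels] with nonzero determinant.
    _, h, w, _ = x.shape
    y = x @ weight.T  # mixes channels independently at every pixel
    logdet = h * w * np.log(np.abs(np.linalg.det(weight)))
    return y, logdet

def invertible_1x1_conv_inverse(y, weight):
    return y @ np.linalg.inv(weight).T

# A random rotation is a common invertible initialisation (its log-det is 0):
rng = np.random.RandomState(0)
weight, _ = np.linalg.qr(rng.randn(4, 4))
x = rng.randn(2, 8, 8, 4)
y, logdet = invertible_1x1_conv(x, weight)
assert np.allclose(invertible_1x1_conv_inverse(y, weight), x)
```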

58 | View Code 59 | 60 | 61 | 62 |

Motivation


Generative modeling is about observing data, like a set of pictures of
faces, then learning a model of how this data was generated. Learning to
approximate the data-generating process requires learning all structure
present in the data, and successful models should be able to synthesize
outputs that look similar to the data. Accurate generative models have
broad applications, including speech synthesis, text analysis and
synthesis, semi-supervised learning and model-based control. The technique
we propose can be applied to those problems as well.

76 |
77 | 78 | -------------------------------------------------------------------------------- /demo/web/load-image.all.min.js: -------------------------------------------------------------------------------- 1 | !function(e){"use strict";function t(e,i,a){var o,n=document.createElement("img");return n.onerror=function(o){return t.onerror(n,o,e,i,a)},n.onload=function(o){return t.onload(n,o,e,i,a)},"string"==typeof e?(t.fetchBlob(e,function(i){i?(e=i,o=t.createObjectURL(e)):(o=e,a&&a.crossOrigin&&(n.crossOrigin=a.crossOrigin)),n.src=o},a),n):t.isInstanceOf("Blob",e)||t.isInstanceOf("File",e)?(o=n._objectURL=t.createObjectURL(e))?(n.src=o,n):t.readFile(e,function(e){var t=e.target;t&&t.result?n.src=t.result:i&&i(e)}):void 0}function i(e,i){!e._objectURL||i&&i.noRevoke||(t.revokeObjectURL(e._objectURL),delete e._objectURL)}var a=e.createObjectURL&&e||e.URL&&URL.revokeObjectURL&&URL||e.webkitURL&&webkitURL;t.fetchBlob=function(e,t,i){t()},t.isInstanceOf=function(e,t){return Object.prototype.toString.call(t)==="[object "+e+"]"},t.transform=function(e,t,i,a,o){i(e,o)},t.onerror=function(e,t,a,o,n){i(e,n),o&&o.call(e,t)},t.onload=function(e,a,o,n,r){i(e,r),n&&t.transform(e,r,n,o,{})},t.createObjectURL=function(e){return!!a&&a.createObjectURL(e)},t.revokeObjectURL=function(e){return!!a&&a.revokeObjectURL(e)},t.readFile=function(t,i,a){if(e.FileReader){var o=new FileReader;if(o.onload=o.onerror=i,a=a||"readAsDataURL",o[a])return o[a](t),o}return!1},"function"==typeof define&&define.amd?define(function(){return t}):"object"==typeof module&&module.exports?module.exports=t:e.loadImage=t}("undefined"!=typeof window&&window||this),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t=e.transform;e.transform=function(i,a,o,n,r){t.call(e,e.scale(i,a,r),a,o,n,r)},e.transformCoordinates=function(){},e.getTransformedOptions=function(e,t){var i,a,o,n,r=t.aspectRatio;if(!r)return t;i={};for(a in t)t.hasOwnProperty(a)&&(i[a]=t[a]);return i.crop=!0,o=e.naturalWidth||e.width,n=e.naturalHeight||e.height,o/n>r?(i.maxWidth=n*r,i.maxHeight=n):(i.maxWidth=o,i.maxHeight=o/r),i},e.renderImageToCanvas=function(e,t,i,a,o,n,r,s,l,d){return e.getContext("2d").drawImage(t,i,a,o,n,r,s,l,d),e},e.hasCanvasOption=function(e){return e.canvas||e.crop||!!e.aspectRatio},e.scale=function(t,i,a){function o(){var e=Math.max((l||v)/v,(d||P)/P);e>1&&(v*=e,P*=e)}function n(){var e=Math.min((r||v)/v,(s||P)/P);e<1&&(v*=e,P*=e)}i=i||{};var r,s,l,d,c,u,f,g,h,m,p,S=document.createElement("canvas"),b=t.getContext||e.hasCanvasOption(i)&&S.getContext,y=t.naturalWidth||t.width,x=t.naturalHeight||t.height,v=y,P=x;if(b&&(f=(i=e.getTransformedOptions(t,i,a)).left||0,g=i.top||0,i.sourceWidth?(c=i.sourceWidth,void 0!==i.right&&void 0===i.left&&(f=y-c-i.right)):c=y-f-(i.right||0),i.sourceHeight?(u=i.sourceHeight,void 0!==i.bottom&&void 0===i.top&&(g=x-u-i.bottom)):u=x-g-(i.bottom||0),v=c,P=u),r=i.maxWidth,s=i.maxHeight,l=i.minWidth,d=i.minHeight,b&&r&&s&&i.crop?(v=r,P=s,(p=c/u-r/s)<0?(u=s*c/r,void 0===i.top&&void 0===i.bottom&&(g=(x-u)/2)):p>0&&(c=r*u/s,void 0===i.left&&void 
0===i.right&&(f=(y-c)/2))):((i.contain||i.cover)&&(l=r=r||l,d=s=s||d),i.cover?(n(),o()):(o(),n())),b){if((h=i.pixelRatio)>1&&(S.style.width=v+"px",S.style.height=P+"px",v*=h,P*=h,S.getContext("2d").scale(h,h)),(m=i.downsamplingRatio)>0&&m<1&&vv;)S.width=c*m,S.height=u*m,e.renderImageToCanvas(S,t,f,g,c,u,0,0,S.width,S.height),f=0,g=0,c=S.width,u=S.height,(t=document.createElement("canvas")).width=c,t.height=u,e.renderImageToCanvas(t,S,0,0,c,u,0,0,c,u);return S.width=v,S.height=P,e.transformCoordinates(S,i),e.renderImageToCanvas(S,t,f,g,c,u,0,0,v,P)}return t.width=v,t.height=P,t}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t="undefined"!=typeof Blob&&(Blob.prototype.slice||Blob.prototype.webkitSlice||Blob.prototype.mozSlice);e.blobSlice=t&&function(){return(this.slice||this.webkitSlice||this.mozSlice).apply(this,arguments)},e.metaDataParsers={jpeg:{65505:[]}},e.parseMetaData=function(t,i,a,o){a=a||{},o=o||{};var n=this,r=a.maxMetaDataSize||262144;!!("undefined"!=typeof DataView&&t&&t.size>=12&&"image/jpeg"===t.type&&e.blobSlice)&&e.readFile(e.blobSlice.call(t,0,r),function(t){if(t.target.error)return console.log(t.target.error),void i(o);var r,s,l,d,c=t.target.result,u=new DataView(c),f=2,g=u.byteLength-4,h=f;if(65496===u.getUint16(0)){for(;f=65504&&r<=65519||65534===r);){if(s=u.getUint16(f+2)+2,f+s>u.byteLength){console.log("Invalid meta data: Invalid segment size.");break}if(l=e.metaDataParsers.jpeg[r])for(d=0;d6&&(c.slice?o.imageHead=c.slice(0,h):o.imageHead=new Uint8Array(c).subarray(0,h))}else console.log("Invalid JPEG file: Missing JPEG marker.");i(o)},"readAsArrayBuffer")||i(o)},e.hasMetaOption=function(e){return e&&e.meta};var i=e.transform;e.transform=function(t,a,o,n,r){e.hasMetaOption(a)?e.parseMetaData(n,function(r){i.call(e,t,a,o,n,r)},a,r):i.apply(e,arguments)}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";"undefined"!=typeof fetch&&"undefined"!=typeof Request&&(e.fetchBlob=function(t,i,a){if(e.hasMetaOption(a))return fetch(new Request(t,a)).then(function(e){return e.blob()}).then(i).catch(function(e){console.log(e),i()});i()})}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";e.ExifMap=function(){return this},e.ExifMap.prototype.map={Orientation:274},e.ExifMap.prototype.get=function(e){return this[e]||this[this.map[e]]},e.getExifThumbnail=function(t,i,a){if(a&&!(i+a>t.byteLength))return e.createObjectURL(new Blob([t.buffer.slice(i,i+a)]));console.log("Invalid Exif data: Invalid thumbnail data.")},e.exifTagTypes={1:{getValue:function(e,t){return e.getUint8(t)},size:1},2:{getValue:function(e,t){return String.fromCharCode(e.getUint8(t))},size:1,ascii:!0},3:{getValue:function(e,t,i){return e.getUint16(t,i)},size:2},4:{getValue:function(e,t,i){return e.getUint32(t,i)},size:4},5:{getValue:function(e,t,i){return e.getUint32(t,i)/e.getUint32(t+4,i)},size:8},9:{getValue:function(e,t,i){return e.getInt32(t,i)},size:4},10:{getValue:function(e,t,i){return 
e.getInt32(t,i)/e.getInt32(t+4,i)},size:8}},e.exifTagTypes[7]=e.exifTagTypes[1],e.getExifValue=function(t,i,a,o,n,r){var s,l,d,c,u,f,g=e.exifTagTypes[o];if(g){if(s=g.size*n,!((l=s>4?i+t.getUint32(a+8,r):a+8)+s>t.byteLength)){if(1===n)return g.getValue(t,l,r);for(d=[],c=0;ce.byteLength)console.log("Invalid Exif data: Invalid directory offset.");else{if(n=e.getUint16(i,a),!((r=i+2+12*n)+4>e.byteLength)){for(s=0;st.byteLength)console.log("Invalid Exif data: Invalid segment size.");else if(0===t.getUint16(i+8)){switch(t.getUint16(d)){case 18761:r=!0;break;case 19789:r=!1;break;default:return void console.log("Invalid Exif data: Invalid byte alignment marker.")}42===t.getUint16(d+2,r)?(s=t.getUint32(d+4,r),o.exif=new e.ExifMap,(s=e.parseExifTags(t,d,d+s,r,o))&&!n.disableExifThumbnail&&(l={exif:{}},s=e.parseExifTags(t,d,d+s,r,l),l.exif[513]&&(o.exif.Thumbnail=e.getExifThumbnail(t,d+l.exif[513],l.exif[514]))),o.exif[34665]&&!n.disableExifSub&&e.parseExifTags(t,d,d+o.exif[34665],r,o),o.exif[34853]&&!n.disableExifGps&&e.parseExifTags(t,d,d+o.exif[34853],r,o)):console.log("Invalid Exif data: Missing TIFF marker.")}else console.log("Invalid Exif data: Missing byte alignment offset.")}},e.metaDataParsers.jpeg[65505].push(e.parseExifData)}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-exif"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-exif")):e(window.loadImage)}(function(e){"use strict";e.ExifMap.prototype.tags={256:"ImageWidth",257:"ImageHeight",34665:"ExifIFDPointer",34853:"GPSInfoIFDPointer",40965:"InteroperabilityIFDPointer",258:"BitsPerSample",259:"Compression",262:"PhotometricInterpretation",274:"Orientation",277:"SamplesPerPixel",284:"PlanarConfiguration",530:"YCbCrSubSampling",531:"YCbCrPositioning",282:"XResolution",283:"YResolution",296:"ResolutionUnit",273:"StripOffsets",278:"RowsPerStrip",279:"StripByteCounts",513:"JPEGInterchangeFormat",514:"JPEGInterchangeFormatLength",301:"TransferFunction",318:"WhitePoint",319:"PrimaryChromaticities",529:"YCbCrCoefficients",532:"ReferenceBlackWhite",306:"DateTime",270:"ImageDescription",271:"Make",272:"Model",305:"Software",315:"Artist",33432:"Copyright",36864:"ExifVersion",40960:"FlashpixVersion",40961:"ColorSpace",40962:"PixelXDimension",40963:"PixelYDimension",42240:"Gamma",37121:"ComponentsConfiguration",37122:"CompressedBitsPerPixel",37500:"MakerNote",37510:"UserComment",40964:"RelatedSoundFile",36867:"DateTimeOriginal",36868:"DateTimeDigitized",37520:"SubSecTime",37521:"SubSecTimeOriginal",37522:"SubSecTimeDigitized",33434:"ExposureTime",33437:"FNumber",34850:"ExposureProgram",34852:"SpectralSensitivity",34855:"PhotographicSensitivity",34856:"OECF",34864:"SensitivityType",34865:"StandardOutputSensitivity",34866:"RecommendedExposureIndex",34867:"ISOSpeed",34868:"ISOSpeedLatitudeyyy",34869:"ISOSpeedLatitudezzz",37377:"ShutterSpeedValue",37378:"ApertureValue",37379:"BrightnessValue",37380:"ExposureBias",37381:"MaxApertureValue",37382:"SubjectDistance",37383:"MeteringMode",37384:"LightSource",37385:"Flash",37396:"SubjectArea",37386:"FocalLength",41483:"FlashEnergy",41484:"SpatialFrequencyResponse",41486:"FocalPlaneXResolution",41487:"FocalPlaneYResolution",41488:"FocalPlaneResolutionUnit",41492:"SubjectLocation",41493:"ExposureIndex",41495:"SensingMethod",41728:"FileSource",41729:"SceneType",41730:"CFAPattern",41985:"CustomRendered",41986:"ExposureMode",41987:"WhiteBalance",41988:"DigitalZoomRatio",41989:"FocalLengthIn35mmFilm",41990:"SceneCaptureType
",41991:"GainControl",41992:"Contrast",41993:"Saturation",41994:"Sharpness",41995:"DeviceSettingDescription",41996:"SubjectDistanceRange",42016:"ImageUniqueID",42032:"CameraOwnerName",42033:"BodySerialNumber",42034:"LensSpecification",42035:"LensMake",42036:"LensModel",42037:"LensSerialNumber",0:"GPSVersionID",1:"GPSLatitudeRef",2:"GPSLatitude",3:"GPSLongitudeRef",4:"GPSLongitude",5:"GPSAltitudeRef",6:"GPSAltitude",7:"GPSTimeStamp",8:"GPSSatellites",9:"GPSStatus",10:"GPSMeasureMode",11:"GPSDOP",12:"GPSSpeedRef",13:"GPSSpeed",14:"GPSTrackRef",15:"GPSTrack",16:"GPSImgDirectionRef",17:"GPSImgDirection",18:"GPSMapDatum",19:"GPSDestLatitudeRef",20:"GPSDestLatitude",21:"GPSDestLongitudeRef",22:"GPSDestLongitude",23:"GPSDestBearingRef",24:"GPSDestBearing",25:"GPSDestDistanceRef",26:"GPSDestDistance",27:"GPSProcessingMethod",28:"GPSAreaInformation",29:"GPSDateStamp",30:"GPSDifferential",31:"GPSHPositioningError"},e.ExifMap.prototype.stringValues={ExposureProgram:{0:"Undefined",1:"Manual",2:"Normal program",3:"Aperture priority",4:"Shutter priority",5:"Creative program",6:"Action program",7:"Portrait mode",8:"Landscape mode"},MeteringMode:{0:"Unknown",1:"Average",2:"CenterWeightedAverage",3:"Spot",4:"MultiSpot",5:"Pattern",6:"Partial",255:"Other"},LightSource:{0:"Unknown",1:"Daylight",2:"Fluorescent",3:"Tungsten (incandescent light)",4:"Flash",9:"Fine weather",10:"Cloudy weather",11:"Shade",12:"Daylight fluorescent (D 5700 - 7100K)",13:"Day white fluorescent (N 4600 - 5400K)",14:"Cool white fluorescent (W 3900 - 4500K)",15:"White fluorescent (WW 3200 - 3700K)",17:"Standard light A",18:"Standard light B",19:"Standard light C",20:"D55",21:"D65",22:"D75",23:"D50",24:"ISO studio tungsten",255:"Other"},Flash:{0:"Flash did not fire",1:"Flash fired",5:"Strobe return light not detected",7:"Strobe return light detected",9:"Flash fired, compulsory flash mode",13:"Flash fired, compulsory flash mode, return light not detected",15:"Flash fired, compulsory flash mode, return light detected",16:"Flash did not fire, compulsory flash mode",24:"Flash did not fire, auto mode",25:"Flash fired, auto mode",29:"Flash fired, auto mode, return light not detected",31:"Flash fired, auto mode, return light detected",32:"No flash function",65:"Flash fired, red-eye reduction mode",69:"Flash fired, red-eye reduction mode, return light not detected",71:"Flash fired, red-eye reduction mode, return light detected",73:"Flash fired, compulsory flash mode, red-eye reduction mode",77:"Flash fired, compulsory flash mode, red-eye reduction mode, return light not detected",79:"Flash fired, compulsory flash mode, red-eye reduction mode, return light detected",89:"Flash fired, auto mode, red-eye reduction mode",93:"Flash fired, auto mode, return light not detected, red-eye reduction mode",95:"Flash fired, auto mode, return light detected, red-eye reduction mode"},SensingMethod:{1:"Undefined",2:"One-chip color area sensor",3:"Two-chip color area sensor",4:"Three-chip color area sensor",5:"Color sequential area sensor",7:"Trilinear sensor",8:"Color sequential linear sensor"},SceneCaptureType:{0:"Standard",1:"Landscape",2:"Portrait",3:"Night scene"},SceneType:{1:"Directly photographed"},CustomRendered:{0:"Normal process",1:"Custom process"},WhiteBalance:{0:"Auto white balance",1:"Manual white balance"},GainControl:{0:"None",1:"Low gain up",2:"High gain up",3:"Low gain down",4:"High gain down"},Contrast:{0:"Normal",1:"Soft",2:"Hard"},Saturation:{0:"Normal",1:"Low saturation",2:"High 
saturation"},Sharpness:{0:"Normal",1:"Soft",2:"Hard"},SubjectDistanceRange:{0:"Unknown",1:"Macro",2:"Close view",3:"Distant view"},FileSource:{3:"DSC"},ComponentsConfiguration:{0:"",1:"Y",2:"Cb",3:"Cr",4:"R",5:"G",6:"B"},Orientation:{1:"top-left",2:"top-right",3:"bottom-right",4:"bottom-left",5:"left-top",6:"right-top",7:"right-bottom",8:"left-bottom"}},e.ExifMap.prototype.getText=function(e){var t=this.get(e);switch(e){case"LightSource":case"Flash":case"MeteringMode":case"ExposureProgram":case"SensingMethod":case"SceneCaptureType":case"SceneType":case"CustomRendered":case"WhiteBalance":case"GainControl":case"Contrast":case"Saturation":case"Sharpness":case"SubjectDistanceRange":case"FileSource":case"Orientation":return this.stringValues[e][t];case"ExifVersion":case"FlashpixVersion":if(!t)return;return String.fromCharCode(t[0],t[1],t[2],t[3]);case"ComponentsConfiguration":if(!t)return;return this.stringValues[e][t[0]]+this.stringValues[e][t[1]]+this.stringValues[e][t[2]]+this.stringValues[e][t[3]];case"GPSVersionID":if(!t)return;return t[0]+"."+t[1]+"."+t[2]+"."+t[3]}return String(t)},function(e){var t,i=e.tags,a=e.map;for(t in i)i.hasOwnProperty(t)&&(a[i[t]]=t)}(e.ExifMap.prototype),e.ExifMap.prototype.getAll=function(){var e,t,i={};for(e in this)this.hasOwnProperty(e)&&(t=this.tags[e])&&(i[t]=this.getText(t));return i}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-scale","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-scale"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";var t=e.hasCanvasOption,i=e.hasMetaOption,a=e.transformCoordinates,o=e.getTransformedOptions;e.hasCanvasOption=function(i){return!!i.orientation||t.call(e,i)},e.hasMetaOption=function(t){return t&&!0===t.orientation||i.call(e,t)},e.transformCoordinates=function(t,i){a.call(e,t,i);var o=t.getContext("2d"),n=t.width,r=t.height,s=t.style.width,l=t.style.height,d=i.orientation;if(d&&!(d>8))switch(d>4&&(t.width=r,t.height=n,t.style.width=l,t.style.height=s),d){case 2:o.translate(n,0),o.scale(-1,1);break;case 3:o.translate(n,r),o.rotate(Math.PI);break;case 4:o.translate(0,r),o.scale(1,-1);break;case 5:o.rotate(.5*Math.PI),o.scale(1,-1);break;case 6:o.rotate(.5*Math.PI),o.translate(0,-r);break;case 7:o.rotate(.5*Math.PI),o.translate(n,-r),o.scale(-1,1);break;case 8:o.rotate(-.5*Math.PI),o.translate(-n,0)}},e.getTransformedOptions=function(t,i,a){var n,r,s=o.call(e,t,i),l=s.orientation;if(!0===l&&a&&a.exif&&(l=a.exif.get("Orientation")),!l||l>8||1===l)return s;n={};for(r in s)s.hasOwnProperty(r)&&(n[r]=s[r]);switch(n.orientation=l,l){case 2:n.left=s.right,n.right=s.left;break;case 3:n.left=s.right,n.top=s.bottom,n.right=s.left,n.bottom=s.top;break;case 4:n.top=s.bottom,n.bottom=s.top;break;case 5:n.left=s.top,n.top=s.left,n.right=s.bottom,n.bottom=s.right;break;case 6:n.left=s.top,n.top=s.right,n.right=s.bottom,n.bottom=s.left;break;case 7:n.left=s.bottom,n.top=s.right,n.right=s.top,n.bottom=s.left;break;case 8:n.left=s.bottom,n.top=s.left,n.right=s.top,n.bottom=s.right}return n.orientation>4&&(n.maxWidth=s.maxHeight,n.maxHeight=s.maxWidth,n.minWidth=s.minHeight,n.minHeight=s.minWidth,n.sourceWidth=s.sourceHeight,n.sourceHeight=s.sourceWidth),n}}); 2 | //# sourceMappingURL=load-image.all.min.js.map 3 | -------------------------------------------------------------------------------- /demo/web/load-image.all.min.js.map: 
--------------------------------------------------------------------------------
 1 | (minified source map for load-image.all.min.js omitted; it maps the bundle back to load-image.js, load-image-scale.js, load-image-meta.js, load-image-fetch.js, load-image-exif.js, load-image-exif-map.js and load-image-orientation.js)
 2 | 
--------------------------------------------------------------------------------
/demo/web/media/DownloadIcon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/DownloadIcon.png -------------------------------------------------------------------------------- /demo/web/media/EditIcon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/EditIcon.png -------------------------------------------------------------------------------- /demo/web/media/beyonce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/beyonce.png -------------------------------------------------------------------------------- /demo/web/media/cersei.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/cersei.png -------------------------------------------------------------------------------- /demo/web/media/geoff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/geoff.png -------------------------------------------------------------------------------- /demo/web/media/john.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/john.png -------------------------------------------------------------------------------- /demo/web/media/lena.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/lena.png -------------------------------------------------------------------------------- /demo/web/media/leo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/leo.png -------------------------------------------------------------------------------- /demo/web/media/loading.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/loading.png -------------------------------------------------------------------------------- /demo/web/media/louis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/louis.png -------------------------------------------------------------------------------- /demo/web/media/neil.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/neil.png -------------------------------------------------------------------------------- /demo/web/media/placeholder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder.png -------------------------------------------------------------------------------- 
/demo/web/media/placeholder2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder2.png
--------------------------------------------------------------------------------
/demo/web/media/placeholder4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/placeholder4.png
--------------------------------------------------------------------------------
/demo/web/media/rashida.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/rashida.png
--------------------------------------------------------------------------------
/demo/web/media/seth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/seth.png
--------------------------------------------------------------------------------
/demo/web/media/steve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/glow/91b2c577a5c110b2b38761fc56d81f7d87f077c1/demo/web/media/steve.png
--------------------------------------------------------------------------------
/demo/web/mock.css:
--------------------------------------------------------------------------------
 1 | /* mock.css
 2 |  *
 3 |  * CSS for the mock-up mimicking the article around the demo.
 4 |  */
 5 | 
 6 | /* Global Rules */
 7 | 
 8 | html {
 9 |   font-family: Lato, Helvetica, Arial, sans-serif;
10 |   text-rendering: optimizeLegibility;
11 | }
12 | 
13 | body {
14 |   margin: 0em;
15 |   font-size: 1.1em;
16 |   line-height: 1.5em;
17 | }
18 | 
19 | /* Article Title */
20 | 
21 | div.TitlePanel {
22 |   background: rgb(52,51,57);
23 |   background: linear-gradient(349deg, rgb(85, 83, 95) 0%,
24 |               rgb(139, 139, 163) 100%);
25 |   color: rgb(255, 255, 255);
26 |   display: flex;
27 |   height: 337px;
28 | }
29 | 
30 | div.Title {
31 |   margin: auto;
32 |   max-width: 555px;
33 |   text-align: center;
34 | }
35 | 
36 | .Title h1 {
37 |   font-size: 2.4em;
38 |   line-height: 1.2em;
39 |   margin-bottom: 0.3em;
40 | }
41 | 
42 | .Title time {
43 |   text-transform: uppercase;
44 |   color: #cacaca;
45 |   font-size: 0.7em;
46 |   font-weight: 700;
47 |   display: block;
48 |   text-align: center;
49 | }
50 | 
51 | /* Article Content */
52 | 
53 | div.Content {
54 |   color: #111;
55 |   max-width: 570px;
56 |   padding: 1em;
57 |   margin: 3em auto;
58 |   word-wrap: break-word;
59 | }
60 | 
61 | .MockUpNotice {
62 |   color: #b32e2e;
63 | }
64 | 
65 | .Content p {
66 |   margin: 2em auto;
67 | }
68 | .Content h1, h2, h3, h4, h5, h6 {
69 |   font-size: 1.3em;
70 |   margin-top: 2.5em;
71 |   margin-bottom: 0.5em;
72 | }
--------------------------------------------------------------------------------
/graphics.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image
 3 | import time
 4 | import threading
 5 | 
 6 | 
 7 | def save_image(x, path):
 8 |     im = Image.fromarray(x)
 9 |     im.save(path, optimize=True)
10 |     return
11 | 
12 | # Assumes [NHWC] format: (n_images, height, width, channels)
13 | def save_raster(x, path, rescale=False, width=None):
14 |     t = threading.Thread(target=_save_raster, args=(x, path, rescale, width))
15 |     t.start()
16 | 
17 | 
18 | def _save_raster(x, path, rescale, width):
19 |     x = to_raster(x, rescale, width)
20 |     save_image(x, path)
21 | 
22 | # Shape: (n_patches,rows,columns,channels)
23 | def to_raster_old(x, rescale=False, width=None):
24 |     x = np.transpose(x, (0, 3, 1, 2))
25 | 
26 |     #x = x.swapaxes(2, 3)
27 |     if len(x.shape) == 3:
28 |         x = x.reshape((x.shape[0], 1, x.shape[1], x.shape[2]))
29 |     if x.shape[1] == 1:
30 |         x = np.repeat(x, 3, axis=1)
31 |     if rescale:
32 |         x = (x - x.min()) / (x.max() - x.min()) * 255.
33 |     x = np.clip(x, 0, 255)
34 |     assert len(x.shape) == 4
35 |     assert x.shape[1] == 3
36 |     n_patches = x.shape[0]
37 |     if width is None:
38 |         width = int(np.ceil(np.sqrt(n_patches)))  # result width
39 |     height = int(n_patches/width)  # result height
40 |     tile_height = x.shape[2]
41 |     tile_width = x.shape[3]
42 |     result = np.zeros((3, int(height*tile_height),
43 |                        int(width*tile_width)), dtype='uint8')
44 |     for i in range(height):
45 |         for j in range(width):
46 |             result[:, i*tile_height:(i+1)*tile_height,
47 |                    j*tile_width:(j+1)*tile_width] = x[i*width+j]
48 |     return result
49 | 
50 | 
51 | # Shape: (n_patches,rows,columns,channels)
52 | def to_raster(x, rescale=False, width=None):
53 |     if len(x.shape) == 3:
54 |         x = x.reshape((x.shape[0], x.shape[1], x.shape[2], 1))
55 |     if x.shape[3] == 1:
56 |         x = np.repeat(x, 3, axis=3)
57 |     if rescale:
58 |         x = (x - x.min()) / (x.max() - x.min()) * 255.
59 |     x = np.clip(x, 0, 255)
60 |     assert len(x.shape) == 4
61 |     assert x.shape[3] == 3
62 |     n_batch = x.shape[0]
63 |     if width is None:
64 |         width = int(np.ceil(np.sqrt(n_batch)))  # result width
65 |     height = int(n_batch / width)  # result height
66 |     tile_height = x.shape[1]
67 |     tile_width = x.shape[2]
68 |     result = np.zeros((int(height * tile_height),
69 |                        int(width * tile_width), 3), dtype='uint8')
70 |     for i in range(height):
71 |         for j in range(width):
72 |             result[i * tile_height:(i + 1) * tile_height, j *
73 |                    tile_width:(j + 1) * tile_width] = x[width*i+j]
74 |     return result
75 | 
--------------------------------------------------------------------------------
/memory_saving_gradients.py:
--------------------------------------------------------------------------------
 1 | from toposort import toposort
 2 | import contextlib
 3 | import numpy as np
 4 | import tensorflow as tf
 5 | import tensorflow.contrib.graph_editor as ge
 6 | import time
 7 | import sys
 8 | sys.setrecursionlimit(10000)
 9 | # refers back to current module if we decide to split helpers out
10 | util = sys.modules[__name__]
11 | 
12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated"
13 | setattr(tf.GraphKeys, "VARIABLES", "variables")
14 | 
15 | # save original gradients since tf.gradients could be monkey-patched to point
16 | # to our version
17 | from tensorflow.python.ops import gradients as tf_gradients_lib
18 | tf_gradients = tf_gradients_lib.gradients
19 | 
20 | MIN_CHECKPOINT_NODE_SIZE = 1024  # use lower value during testing
21 | 
22 | # specific versions we can use to do process-wide replacement of tf.gradients
23 | 
24 | 
25 | def gradients_speed(ys, xs, grad_ys=None, **kwargs):
26 |     return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs)
27 | 
28 | 
29 | def gradients_memory(ys, xs, grad_ys=None, **kwargs):
30 |     return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs)
31 | 
32 | 
33 | def gradients_collection(ys, xs, grad_ys=None, **kwargs):
34 |     return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs)
35 | 
36 | 
37 | def gradients(ys, xs, grad_ys=None,
checkpoints='collection', **kwargs): 38 | ''' 39 | Authors: Tim Salimans & Yaroslav Bulatov 40 | 41 | memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" 42 | by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) 43 | 44 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients 45 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 46 | 47 | 'checkpoints' can either be 48 | - a list consisting of tensors from the forward pass of the neural net 49 | that we should re-use when calculating the gradients in the backward pass 50 | all other tensors that do not appear in this list will be re-computed 51 | - a string specifying how this list should be determined. currently we support 52 | - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, 53 | so checkpointing them maximizes the running speed 54 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) 55 | - 'memory': try to minimize the memory usage 56 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) 57 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint 58 | ''' 59 | 60 | # print("Calling memsaving gradients with", checkpoints) 61 | if not isinstance(ys, list): 62 | ys = [ys] 63 | if not isinstance(xs, list): 64 | xs = [xs] 65 | 66 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], 67 | inclusive=True) 68 | 69 | debug_print("bwd_ops: %s", bwd_ops) 70 | 71 | # forward ops are all ops that are candidates for recomputation 72 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], 73 | inclusive=True, 74 | within_ops=bwd_ops) 75 | debug_print("fwd_ops: %s", fwd_ops) 76 | 77 | # exclude ops with no inputs 78 | fwd_ops = [op for op in fwd_ops if op.inputs] 79 | 80 | # don't recompute xs, remove variables 81 | xs_ops = _to_ops(xs) 82 | fwd_ops = [op for op in fwd_ops if not op in xs_ops] 83 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] 84 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] 85 | fwd_ops = [op for op in fwd_ops if not '/read' in op.name] 86 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors 87 | ts_all = [t for t in ts_all if '/read' not in t.name] 88 | ts_all = set(ts_all) - set(xs) - set(ys) 89 | 90 | # construct list of tensors to checkpoint during forward pass, if not 91 | # given as input 92 | if type(checkpoints) is not list: 93 | if checkpoints == 'collection': 94 | checkpoints = tf.get_collection('checkpoints') 95 | 96 | elif checkpoints == 'speed': 97 | # checkpoint all expensive ops to maximize running speed 98 | checkpoints = ge.filter_ts_from_regex( 99 | fwd_ops, 'conv2d|Conv|MatMul') 100 | 101 | elif checkpoints == 'memory': 102 | 103 | # remove very small tensors and some weird ops 104 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually 105 | try: 106 | return [int(e if e.value is not None else 64) for e in t] 107 | except: 108 | return [0] # unknown shape 109 | ts_all = [t for t in ts_all if np.prod( 110 | fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] 111 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name] 112 | ts_all = [t for t in ts_all if 'entropy' not in t.name] 113 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] 114 | ts_all = [t for t in ts_all if 'Switch' not in t.name] 115 | ts_all = [t for t in 
ts_all if 'dropout' not in t.name]
116 | 
117 |             # filter out all tensors that are inputs of the backward graph
118 |             with util.capture_ops() as bwd_ops:
119 |                 tf_gradients(ys, xs, grad_ys, **kwargs)
120 | 
121 |             bwd_inputs = [t for op in bwd_ops for t in op.inputs]
122 |             # list of tensors in the forward graph that are inputs to the bwd graph
123 |             ts_filtered = list(set(bwd_inputs).intersection(ts_all))
124 |             debug_print("Using tensors %s", ts_filtered)
125 | 
126 |             # try two slightly different ways of getting bottleneck tensors
127 |             # to checkpoint
128 |             for ts in [ts_filtered, ts_all]:
129 | 
130 |                 # get all bottlenecks in the graph
131 |                 bottleneck_ts = []
132 |                 for t in ts:
133 |                     b = set(ge.get_backward_walk_ops(
134 |                         t.op, inclusive=True, within_ops=fwd_ops))
135 |                     f = set(ge.get_forward_walk_ops(
136 |                         t.op, inclusive=False, within_ops=fwd_ops))
137 |                     # check that there are no shortcuts
138 |                     b_inp = set(
139 |                         [inp for op in b for inp in op.inputs]).intersection(ts_all)
140 |                     f_inp = set(
141 |                         [inp for op in f for inp in op.inputs]).intersection(ts_all)
142 |                     if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
143 |                         bottleneck_ts.append(t)  # we have a bottleneck!
144 |                     else:
145 |                         debug_print("Rejected bottleneck candidate and ops %s", [
146 |                             t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
147 | 
148 |                 # success? or try again without filtering?
149 |                 if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # yes, enough bottlenecks found!
150 |                     break
151 | 
152 |             if not bottleneck_ts:
153 |                 raise Exception(
154 |                     'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')
155 | 
156 |             # sort the bottlenecks
157 |             bottlenecks_sorted_lists = tf_toposort(
158 |                 bottleneck_ts, within_ops=fwd_ops)
159 |             sorted_bottlenecks = [
160 |                 t for ts in bottlenecks_sorted_lists for t in ts]
161 | 
162 |             # save an approximately optimal number ~ sqrt(N)
163 |             N = len(ts_filtered)
164 |             if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
165 |                 checkpoints = sorted_bottlenecks
166 |             else:
167 |                 step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
168 |                 checkpoints = sorted_bottlenecks[step::step]
169 | 
170 |         else:
171 |             raise Exception(
172 |                 '%s is unsupported input for "checkpoints"' % (checkpoints,))
173 | 
174 |     checkpoints = list(set(checkpoints).intersection(ts_all))
175 | 
176 |     # at this point automatic selection happened and checkpoints is a list of nodes
177 |     assert isinstance(checkpoints, list)
178 | 
179 |     debug_print("Checkpoint nodes used: %s", checkpoints)
180 |     # better error handling of special cases
181 |     # xs are already handled as checkpoint nodes, so no need to include them
182 |     xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
183 |     if xs_intersect_checkpoints:
184 |         debug_print("Warning, some input nodes are also checkpoint nodes: %s",
185 |                     xs_intersect_checkpoints)
186 |     ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
187 |     debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
188 |                 ys_intersect_checkpoints)
189 |     # saving an output node (ys) gives no benefit in memory while creating
190 |     # new edge cases, so exclude them
191 |     if ys_intersect_checkpoints:
192 |         debug_print("Warning, some output nodes are also checkpoint nodes: %s",
193 |                     format_ops(ys_intersect_checkpoints))
194 | 
195 |     # remove initial and terminal nodes from checkpoints list if present
196 |     checkpoints = list(set(checkpoints) - set(ys) - set(xs))
197 | 
198 |     # check that we have some nodes to checkpoint
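    # (the list can legitimately end up empty here, e.g. with
    # checkpoints='collection' when nothing was ever added to the
    # 'checkpoints' collection, so we fail loudly rather than silently
    # falling back to plain tf.gradients)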
199 | if not checkpoints: 200 | raise Exception('no checkpoints nodes found or given as input! ') 201 | 202 | # disconnect dependencies between checkpointed tensors 203 | checkpoints_disconnected = {} 204 | for x in checkpoints: 205 | if x.op and x.op.name is not None: 206 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") 207 | else: 208 | grad_node = tf.stop_gradient(x) 209 | checkpoints_disconnected[x] = grad_node 210 | 211 | # partial derivatives to the checkpointed tensors and xs 212 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], 213 | stop_at_ts=checkpoints, within_ops=fwd_ops) 214 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", 215 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) 216 | debug_print("ops_to_copy = %s", ops_to_copy) 217 | debug_print("Processing list %s", ys) 218 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 219 | copied_ops = info._transformed_ops.values() 220 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 221 | ge.reroute_ts(checkpoints_disconnected.values(), 222 | checkpoints_disconnected.keys(), can_modify=copied_ops) 223 | debug_print("Rewired %s in place of %s restricted to %s", 224 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) 225 | 226 | # get gradients with respect to current boundary + original x's 227 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] 228 | boundary = list(checkpoints_disconnected.values()) 229 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) 230 | debug_print("Got gradients %s", dv) 231 | debug_print("for %s", copied_ys) 232 | debug_print("with respect to %s", boundary+xs) 233 | 234 | inputs_to_do_before = [y.op for y in ys] 235 | if grad_ys is not None: 236 | inputs_to_do_before += grad_ys 237 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 238 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 239 | 240 | # partial derivatives to the checkpointed nodes 241 | # dictionary of "node: backprop" for nodes in the boundary 242 | d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(), 243 | dv[:len(checkpoints_disconnected)])} 244 | # partial derivatives to xs (usually the params of the neural net) 245 | d_xs = dv[len(checkpoints_disconnected):] 246 | 247 | # incorporate derivatives flowing through the checkpointed nodes 248 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) 249 | for ts in checkpoints_sorted_lists[::-1]: 250 | debug_print("Processing list %s", ts) 251 | checkpoints_other = [r for r in checkpoints if r not in ts] 252 | checkpoints_disconnected_other = [ 253 | checkpoints_disconnected[r] for r in checkpoints_other] 254 | 255 | # copy part of the graph below current checkpoint node, stopping at 256 | # other checkpoints nodes 257 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[ 258 | r.op for r in ts], stop_at_ts=checkpoints_other) 259 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", 260 | len(ops_to_copy), fwd_ops, [r.op for r in ts], 261 | checkpoints_other) 262 | debug_print("ops_to_copy = %s", ops_to_copy) 263 | if not ops_to_copy: # we're done! 
264 | break 265 | copied_sgv, info = ge.copy_with_input_replacements( 266 | ge.sgv(ops_to_copy), {}) 267 | copied_ops = info._transformed_ops.values() 268 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 269 | ge.reroute_ts(checkpoints_disconnected_other, 270 | checkpoints_other, can_modify=copied_ops) 271 | debug_print("Rewired %s in place of %s restricted to %s", 272 | checkpoints_disconnected_other, checkpoints_other, copied_ops) 273 | 274 | # gradient flowing through the checkpointed node 275 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] 276 | substitute_backprops = [d_checkpoints[r] for r in ts] 277 | dv = tf_gradients(boundary, 278 | checkpoints_disconnected_other+xs, 279 | grad_ys=substitute_backprops, **kwargs) 280 | debug_print("Got gradients %s", dv) 281 | debug_print("for %s", boundary) 282 | debug_print("with respect to %s", checkpoints_disconnected_other+xs) 283 | debug_print("with boundary backprop substitutions %s", 284 | substitute_backprops) 285 | 286 | inputs_to_do_before = [d_checkpoints[r].op for r in ts] 287 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 288 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 289 | 290 | # partial derivatives to the checkpointed nodes 291 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): 292 | if dr is not None: 293 | if d_checkpoints[r] is None: 294 | d_checkpoints[r] = dr 295 | else: 296 | d_checkpoints[r] += dr 297 | 298 | # partial derivatives to xs (usually the params of the neural net) 299 | d_xs_new = dv[len(checkpoints_other):] 300 | for j in range(len(xs)): 301 | if d_xs_new[j] is not None: 302 | if d_xs[j] is None: 303 | d_xs[j] = d_xs_new[j] 304 | else: 305 | d_xs[j] += d_xs_new[j] 306 | 307 | return d_xs 308 | 309 | 310 | def tf_toposort(ts, within_ops=None): 311 | all_ops = ge.get_forward_walk_ops( 312 | [x.op for x in ts], within_ops=within_ops) 313 | 314 | deps = {} 315 | for op in all_ops: 316 | for o in op.outputs: 317 | deps[o] = set(op.inputs) 318 | sorted_ts = toposort(deps) 319 | 320 | # only keep the tensors from our original list 321 | ts_sorted_lists = [] 322 | for l in sorted_ts: 323 | keep = list(set(l).intersection(ts)) 324 | if keep: 325 | ts_sorted_lists.append(keep) 326 | 327 | return ts_sorted_lists 328 | 329 | 330 | def fast_backward_ops(within_ops, seed_ops, stop_at_ts): 331 | bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts)) 332 | ops = bwd_ops.intersection(within_ops).difference( 333 | [t.op for t in stop_at_ts]) 334 | return list(ops) 335 | 336 | 337 | @contextlib.contextmanager 338 | def capture_ops(): 339 | """Decorator to capture ops created in the block. 340 | with capture_ops() as ops: 341 | # create some ops 342 | print(ops) # => prints ops created. 
343 | """ 344 | 345 | micros = int(time.time()*10**6) 346 | scope_name = str(micros) 347 | op_list = [] 348 | with tf.name_scope(scope_name): 349 | yield op_list 350 | 351 | g = tf.get_default_graph() 352 | op_list.extend(ge.select_ops(scope_name+"/.*", graph=g)) 353 | 354 | 355 | def _to_op(tensor_or_op): 356 | if hasattr(tensor_or_op, "op"): 357 | return tensor_or_op.op 358 | return tensor_or_op 359 | 360 | 361 | def _to_ops(iterable): 362 | if not _is_iterable(iterable): 363 | return iterable 364 | return [_to_op(i) for i in iterable] 365 | 366 | 367 | def _is_iterable(o): 368 | try: 369 | _ = iter(o) 370 | except Exception: 371 | return False 372 | return True 373 | 374 | 375 | DEBUG_LOGGING = False 376 | 377 | 378 | def debug_print(s, *args): 379 | """Like logger.log, but also replaces all TensorFlow ops/tensors with their 380 | names. Sensitive to value of DEBUG_LOGGING, see enable_debug/disable_debug 381 | 382 | Usage: 383 | debug_print("see tensors %s for %s", tensorlist, [1,2,3]) 384 | """ 385 | 386 | if DEBUG_LOGGING: 387 | formatted_args = [format_ops(arg) for arg in args] 388 | print("DEBUG "+s % tuple(formatted_args)) 389 | 390 | 391 | def format_ops(ops, sort_outputs=True): 392 | """Helper method for printing ops. Converts Tensor/Operation op to op.name, 393 | rest to str(op).""" 394 | 395 | if hasattr(ops, '__iter__') and not isinstance(ops, str): 396 | l = [(op.name if hasattr(op, "name") else str(op)) for op in ops] 397 | if sort_outputs: 398 | return sorted(l) 399 | return l 400 | else: 401 | return ops.name if hasattr(ops, "name") else str(ops) 402 | 403 | 404 | def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before): 405 | for op in wait_to_do_ops: 406 | ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs] 407 | ge.add_control_inputs(op, ci) 408 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import tfops as Z 4 | import optim 5 | import numpy as np 6 | import horovod.tensorflow as hvd 7 | from tensorflow.contrib.framework.python.ops import add_arg_scope 8 | 9 | 10 | ''' 11 | f_loss: function with as input the (x,y,reuse=False), and as output a list/tuple whose first element is the loss. 
12 | ''' 13 | 14 | 15 | def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss): 16 | 17 | # == Create class with static fields and methods 18 | class m(object): 19 | pass 20 | m.sess = sess 21 | m.feeds = feeds 22 | m.lr = lr 23 | 24 | # === Loss and optimizer 25 | loss_train, stats_train = f_loss(train_iterator, True) 26 | all_params = tf.trainable_variables() 27 | if hps.gradient_checkpointing == 1: 28 | from memory_saving_gradients import gradients 29 | gs = gradients(loss_train, all_params) 30 | else: 31 | gs = tf.gradients(loss_train, all_params) 32 | 33 | optimizer = {'adam': optim.adam, 'adamax': optim.adamax, 34 | 'adam2': optim.adam2}[hps.optimizer] 35 | 36 | train_op, polyak_swap_op, ema = optimizer( 37 | all_params, gs, alpha=lr, hps=hps) 38 | if hps.direct_iterator: 39 | m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1] 40 | else: 41 | def _train(_lr): 42 | _x, _y = train_iterator() 43 | return sess.run([train_op, stats_train], {feeds['x']: _x, 44 | feeds['y']: _y, lr: _lr})[1] 45 | m.train = _train 46 | 47 | m.polyak_swap = lambda: sess.run(polyak_swap_op) 48 | 49 | # === Testing 50 | loss_test, stats_test = f_loss(test_iterator, False, reuse=True) 51 | if hps.direct_iterator: 52 | m.test = lambda: sess.run(stats_test) 53 | else: 54 | def _test(): 55 | _x, _y = test_iterator() 56 | return sess.run(stats_test, {feeds['x']: _x, 57 | feeds['y']: _y}) 58 | m.test = _test 59 | 60 | # === Saving and restoring 61 | saver = tf.train.Saver() 62 | saver_ema = tf.train.Saver(ema.variables_to_restore()) 63 | m.save_ema = lambda path: saver_ema.save( 64 | sess, path, write_meta_graph=False) 65 | m.save = lambda path: saver.save(sess, path, write_meta_graph=False) 66 | m.restore = lambda path: saver.restore(sess, path) 67 | 68 | # === Initialize the parameters 69 | if hps.restore_path != '': 70 | m.restore(hps.restore_path) 71 | else: 72 | with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True): 73 | results_init = f_loss(None, True, reuse=True) 74 | sess.run(tf.global_variables_initializer()) 75 | sess.run(results_init, {feeds['x']: data_init['x'], 76 | feeds['y']: data_init['y']}) 77 | sess.run(hvd.broadcast_global_variables(0)) 78 | 79 | return m 80 | 81 | 82 | def codec(hps): 83 | 84 | def encoder(z, objective): 85 | eps = [] 86 | for i in range(hps.n_levels): 87 | z, objective = revnet2d(str(i), z, objective, hps) 88 | if i < hps.n_levels-1: 89 | z, objective, _eps = split2d("pool"+str(i), z, objective=objective) 90 | eps.append(_eps) 91 | return z, objective, eps 92 | 93 | def decoder(z, eps=[None]*hps.n_levels, eps_std=None): 94 | for i in reversed(range(hps.n_levels)): 95 | if i < hps.n_levels-1: 96 | z = split2d_reverse("pool"+str(i), z, eps=eps[i], eps_std=eps_std) 97 | z, _ = revnet2d(str(i), z, 0, hps, reverse=True) 98 | 99 | return z 100 | 101 | return encoder, decoder 102 | 103 | 104 | def prior(name, y_onehot, hps): 105 | 106 | with tf.variable_scope(name): 107 | n_z = hps.top_shape[-1] 108 | 109 | h = tf.zeros([tf.shape(y_onehot)[0]]+hps.top_shape[:2]+[2*n_z]) 110 | if hps.learntop: 111 | h = Z.conv2d_zeros('p', h, 2*n_z) 112 | if hps.ycond: 113 | h += tf.reshape(Z.linear_zeros("y_emb", y_onehot, 114 | 2*n_z), [-1, 1, 1, 2 * n_z]) 115 | 116 | pz = Z.gaussian_diag(h[:, :, :, :n_z], h[:, :, :, n_z:]) 117 | 118 | def logp(z1): 119 | objective = pz.logp(z1) 120 | return objective 121 | 122 | def sample(eps=None, eps_std=None): 123 | if eps is not None: 124 | # Already sampled eps. 
Don't use eps_std 125 | z = pz.sample2(eps) 126 | elif eps_std is not None: 127 | # Sample with given eps_std 128 | z = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1])) 129 | else: 130 | # Sample normally 131 | z = pz.sample 132 | 133 | return z 134 | 135 | def eps(z1): 136 | return pz.get_eps(z1) 137 | 138 | return logp, sample, eps 139 | 140 | 141 | def model(sess, hps, train_iterator, test_iterator, data_init): 142 | 143 | # Only for decoding/init, rest use iterators directly 144 | with tf.name_scope('input'): 145 | X = tf.placeholder( 146 | tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image') 147 | Y = tf.placeholder(tf.int32, [None], name='label') 148 | lr = tf.placeholder(tf.float32, None, name='learning_rate') 149 | 150 | encoder, decoder = codec(hps) 151 | hps.n_bins = 2. ** hps.n_bits_x 152 | 153 | def preprocess(x): 154 | x = tf.cast(x, 'float32') 155 | if hps.n_bits_x < 8: 156 | x = tf.floor(x / 2 ** (8 - hps.n_bits_x)) 157 | x = x / hps.n_bins - .5 158 | return x 159 | 160 | def postprocess(x): 161 | return tf.cast(tf.clip_by_value(tf.floor((x + .5)*hps.n_bins)*(256./hps.n_bins), 0, 255), 'uint8') 162 | 163 | def _f_loss(x, y, is_training, reuse=False): 164 | 165 | with tf.variable_scope('model', reuse=reuse): 166 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32') 167 | 168 | # Discrete -> Continuous 169 | objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0] 170 | z = preprocess(x) 171 | z = z + tf.random_uniform(tf.shape(z), 0, 1./hps.n_bins) 172 | objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:]) 173 | 174 | # Encode 175 | z = Z.squeeze2d(z, 2) # > 16x16x12 176 | z, objective, _ = encoder(z, objective) 177 | 178 | # Prior 179 | hps.top_shape = Z.int_shape(z)[1:] 180 | logp, _, _ = prior("prior", y_onehot, hps) 181 | objective += logp(z) 182 | 183 | # Generative loss 184 | nobj = - objective 185 | bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int( 186 | x.get_shape()[2]) * int(x.get_shape()[3])) # bits per subpixel 187 | 188 | # Predictive loss 189 | if hps.weight_y > 0 and hps.ycond: 190 | 191 | # Classification loss 192 | h_y = tf.reduce_mean(z, axis=[1, 2]) 193 | y_logits = Z.linear_zeros("classifier", h_y, hps.n_y) 194 | bits_y = tf.nn.softmax_cross_entropy_with_logits_v2( 195 | labels=y_onehot, logits=y_logits) / np.log(2.) 
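                # (dividing the softmax cross-entropy, which is in nats, by
                # np.log(2.) expresses the classification loss in bits, on the
                # same scale as the bits-per-subpixel term bits_x above)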
196 | 197 | # Classification accuracy 198 | y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32) 199 | classification_error = 1 - \ 200 | tf.cast(tf.equal(y_predicted, y), tf.float32) 201 | else: 202 | bits_y = tf.zeros_like(bits_x) 203 | classification_error = tf.ones_like(bits_x) 204 | 205 | return bits_x, bits_y, classification_error 206 | 207 | def f_loss(iterator, is_training, reuse=False): 208 | if hps.direct_iterator and iterator is not None: 209 | x, y = iterator.get_next() 210 | else: 211 | x, y = X, Y 212 | 213 | bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse) 214 | local_loss = bits_x + hps.weight_y * bits_y 215 | stats = [local_loss, bits_x, bits_y, pred_loss] 216 | global_stats = Z.allreduce_mean( 217 | tf.stack([tf.reduce_mean(i) for i in stats])) 218 | 219 | return tf.reduce_mean(local_loss), global_stats 220 | 221 | feeds = {'x': X, 'y': Y} 222 | m = abstract_model_xy(sess, hps, feeds, train_iterator, 223 | test_iterator, data_init, lr, f_loss) 224 | 225 | # === Sampling function 226 | def f_sample(y, eps_std): 227 | with tf.variable_scope('model', reuse=True): 228 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32') 229 | 230 | _, sample, _ = prior("prior", y_onehot, hps) 231 | z = sample(eps_std=eps_std) 232 | z = decoder(z, eps_std=eps_std) 233 | z = Z.unsqueeze2d(z, 2) # 8x8x12 -> 16x16x3 234 | x = postprocess(z) 235 | 236 | return x 237 | 238 | m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std') 239 | x_sampled = f_sample(Y, m.eps_std) 240 | 241 | def sample(_y, _eps_std): 242 | return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std}) 243 | m.sample = sample 244 | 245 | if hps.inference: 246 | # === Encoder-Decoder functions 247 | def f_encode(x, y, reuse=True): 248 | with tf.variable_scope('model', reuse=reuse): 249 | y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32') 250 | 251 | # Discrete -> Continuous 252 | objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0] 253 | z = preprocess(x) 254 | z = z + tf.random_uniform(tf.shape(z), 0, 1. 
/ hps.n_bins)
255 |                 objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])
256 | 
257 |                 # Encode
258 |                 z = Z.squeeze2d(z, 2)  # > 16x16x12
259 |                 z, objective, eps = encoder(z, objective)
260 | 
261 |                 # Prior
262 |                 hps.top_shape = Z.int_shape(z)[1:]
263 |                 logp, _, _eps = prior("prior", y_onehot, hps)
264 |                 objective += logp(z)
265 |                 eps.append(_eps(z))
266 | 
267 |             return eps
268 | 
269 |         def f_decode(y, eps, reuse=True):
270 |             with tf.variable_scope('model', reuse=reuse):
271 |                 y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')
272 | 
273 |                 _, sample, _ = prior("prior", y_onehot, hps)
274 |                 z = sample(eps=eps[-1])
275 |                 z = decoder(z, eps=eps[:-1])
276 |                 z = Z.unsqueeze2d(z, 2)  # 8x8x12 -> 16x16x3
277 |                 x = postprocess(z)
278 | 
279 |             return x
280 | 
281 |         enc_eps = f_encode(X, Y)
282 |         dec_eps = []
283 |         print(enc_eps)
284 |         for i, _eps in enumerate(enc_eps):
285 |             print(_eps)
286 |             dec_eps.append(tf.placeholder(tf.float32, _eps.get_shape().as_list(), name="dec_eps_" + str(i)))
287 |         dec_x = f_decode(Y, dec_eps)
288 | 
289 |         eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]
290 | 
291 |         def flatten_eps(eps):
292 |             # [BS, eps_size]
293 |             return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)
294 | 
295 |         def unflatten_eps(feps):
296 |             index = 0
297 |             eps = []
298 |             bs = feps.shape[0]
299 |             for shape in eps_shapes:
300 |                 eps.append(np.reshape(feps[:, index: index+np.prod(shape)], (bs, *shape)))
301 |                 index += np.prod(shape)
302 |             return eps
303 | 
304 |         # If the model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
305 |         def encode(x, y):
306 |             return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))
307 | 
308 |         def decode(y, feps):
309 |             eps = unflatten_eps(feps)
310 |             feed_dict = {Y: y}
311 |             for i in range(len(dec_eps)):
312 |                 feed_dict[dec_eps[i]] = eps[i]
313 |             return sess.run(dec_x, feed_dict)
314 | 
315 |         m.encode = encode
316 |         m.decode = decode
317 | 
318 |     return m
319 | 
320 | 
321 | def checkpoint(z, logdet):
322 |     zshape = Z.int_shape(z)
323 |     z = tf.reshape(z, [-1, zshape[1]*zshape[2]*zshape[3]])
324 |     logdet = tf.reshape(logdet, [-1, 1])
325 |     combined = tf.concat([z, logdet], axis=1)
326 |     tf.add_to_collection('checkpoints', combined)
327 |     logdet = combined[:, -1]
328 |     z = tf.reshape(combined[:, :-1], [-1, zshape[1], zshape[2], zshape[3]])
329 |     return z, logdet
330 | 
331 | 
332 | @add_arg_scope
333 | def revnet2d(name, z, logdet, hps, reverse=False):
334 |     with tf.variable_scope(name):
335 |         if not reverse:
336 |             for i in range(hps.depth):
337 |                 z, logdet = checkpoint(z, logdet)
338 |                 z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
339 |             z, logdet = checkpoint(z, logdet)
340 |         else:
341 |             for i in reversed(range(hps.depth)):
342 |                 z, logdet = revnet2d_step(str(i), z, logdet, hps, reverse)
343 |     return z, logdet
344 | 
345 | # Simpler, new version
346 | @add_arg_scope
347 | def revnet2d_step(name, z, logdet, hps, reverse):
348 |     with tf.variable_scope(name):
349 | 
350 |         shape = Z.int_shape(z)
351 |         n_z = shape[3]
352 |         assert n_z % 2 == 0
353 | 
354 |         if not reverse:
355 | 
356 |             z, logdet = Z.actnorm("actnorm", z, logdet=logdet)
357 | 
358 |             if hps.flow_permutation == 0:
359 |                 z = Z.reverse_features("reverse", z)
360 |             elif hps.flow_permutation == 1:
361 |                 z = Z.shuffle_features("shuffle", z)
362 |             elif hps.flow_permutation == 2:
363 |                 z, logdet = invertible_1x1_conv("invconv", z, logdet)
364 |             else:
365 |                 raise Exception()
366 | 
367 |             z1 = z[:, :, :, :n_z // 2]
368 |             z2 = z[:, :, :, n_z // 2:]
369 | 
370 |             if
hps.flow_coupling == 0: 371 | z2 += f("f1", z1, hps.width) 372 | elif hps.flow_coupling == 1: 373 | h = f("f1", z1, hps.width, n_z) 374 | shift = h[:, :, :, 0::2] 375 | # scale = tf.exp(h[:, :, :, 1::2]) 376 | scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) 377 | z2 += shift 378 | z2 *= scale 379 | logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3]) 380 | else: 381 | raise Exception() 382 | 383 | z = tf.concat([z1, z2], 3) 384 | 385 | else: 386 | 387 | z1 = z[:, :, :, :n_z // 2] 388 | z2 = z[:, :, :, n_z // 2:] 389 | 390 | if hps.flow_coupling == 0: 391 | z2 -= f("f1", z1, hps.width) 392 | elif hps.flow_coupling == 1: 393 | h = f("f1", z1, hps.width, n_z) 394 | shift = h[:, :, :, 0::2] 395 | # scale = tf.exp(h[:, :, :, 1::2]) 396 | scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.) 397 | z2 /= scale 398 | z2 -= shift 399 | logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3]) 400 | else: 401 | raise Exception() 402 | 403 | z = tf.concat([z1, z2], 3) 404 | 405 | if hps.flow_permutation == 0: 406 | z = Z.reverse_features("reverse", z, reverse=True) 407 | elif hps.flow_permutation == 1: 408 | z = Z.shuffle_features("shuffle", z, reverse=True) 409 | elif hps.flow_permutation == 2: 410 | z, logdet = invertible_1x1_conv( 411 | "invconv", z, logdet, reverse=True) 412 | else: 413 | raise Exception() 414 | 415 | z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True) 416 | 417 | return z, logdet 418 | 419 | 420 | def f(name, h, width, n_out=None): 421 | n_out = n_out or int(h.get_shape()[3]) 422 | with tf.variable_scope(name): 423 | h = tf.nn.relu(Z.conv2d("l_1", h, width)) 424 | h = tf.nn.relu(Z.conv2d("l_2", h, width, filter_size=[1, 1])) 425 | h = Z.conv2d_zeros("l_last", h, n_out) 426 | return h 427 | 428 | 429 | def f_resnet(name, h, width, n_out=None): 430 | n_out = n_out or int(h.get_shape()[3]) 431 | with tf.variable_scope(name): 432 | h = tf.nn.relu(Z.conv2d("l_1", h, width)) 433 | h = Z.conv2d_zeros("l_2", h, n_out) 434 | return h 435 | 436 | # Invertible 1x1 conv 437 | @add_arg_scope 438 | def invertible_1x1_conv(name, z, logdet, reverse=False): 439 | 440 | if True: # Set to "False" to use the LU-decomposed version 441 | 442 | with tf.variable_scope(name): 443 | 444 | shape = Z.int_shape(z) 445 | w_shape = [shape[3], shape[3]] 446 | 447 | # Sample a random orthogonal matrix: 448 | w_init = np.linalg.qr(np.random.randn( 449 | *w_shape))[0].astype('float32') 450 | 451 | w = tf.get_variable("W", dtype=tf.float32, initializer=w_init) 452 | 453 | # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2] 454 | dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant( 455 | tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2] 456 | 457 | if not reverse: 458 | 459 | _w = tf.reshape(w, [1, 1] + w_shape) 460 | z = tf.nn.conv2d(z, _w, [1, 1, 1, 1], 461 | 'SAME', data_format='NHWC') 462 | logdet += dlogdet 463 | 464 | return z, logdet 465 | else: 466 | 467 | _w = tf.matrix_inverse(w) 468 | _w = tf.reshape(_w, [1, 1]+w_shape) 469 | z = tf.nn.conv2d(z, _w, [1, 1, 1, 1], 470 | 'SAME', data_format='NHWC') 471 | logdet -= dlogdet 472 | 473 | return z, logdet 474 | 475 | else: 476 | 477 | # LU-decomposed version 478 | shape = Z.int_shape(z) 479 | with tf.variable_scope(name): 480 | 481 | dtype = 'float64' 482 | 483 | # Random orthogonal matrix: 484 | import scipy 485 | np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[ 486 | 0].astype('float32') 487 | 488 | np_p, np_l, np_u = scipy.linalg.lu(np_w) 489 | np_s = np.diag(np_u) 490 | np_sign_s = np.sign(np_s) 491 | np_log_s 
= np.log(abs(np_s)) 492 | np_u = np.triu(np_u, k=1) 493 | 494 | p = tf.get_variable("P", initializer=np_p, trainable=False) 495 | l = tf.get_variable("L", initializer=np_l) 496 | sign_s = tf.get_variable( 497 | "sign_S", initializer=np_sign_s, trainable=False) 498 | log_s = tf.get_variable("log_S", initializer=np_log_s) 499 | # S = tf.get_variable("S", initializer=np_s) 500 | u = tf.get_variable("U", initializer=np_u) 501 | 502 | p = tf.cast(p, dtype) 503 | l = tf.cast(l, dtype) 504 | sign_s = tf.cast(sign_s, dtype) 505 | log_s = tf.cast(log_s, dtype) 506 | u = tf.cast(u, dtype) 507 | 508 | w_shape = [shape[3], shape[3]] 509 | 510 | l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1) 511 | l = l * l_mask + tf.eye(*w_shape, dtype=dtype) 512 | u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s)) 513 | w = tf.matmul(p, tf.matmul(l, u)) 514 | 515 | if True: 516 | u_inv = tf.matrix_inverse(u) 517 | l_inv = tf.matrix_inverse(l) 518 | p_inv = tf.matrix_inverse(p) 519 | w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv)) 520 | else: 521 | w_inv = tf.matrix_inverse(w) 522 | 523 | w = tf.cast(w, tf.float32) 524 | w_inv = tf.cast(w_inv, tf.float32) 525 | log_s = tf.cast(log_s, tf.float32) 526 | 527 | if not reverse: 528 | 529 | w = tf.reshape(w, [1, 1] + w_shape) 530 | z = tf.nn.conv2d(z, w, [1, 1, 1, 1], 531 | 'SAME', data_format='NHWC') 532 | logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2]) 533 | 534 | return z, logdet 535 | else: 536 | 537 | w_inv = tf.reshape(w_inv, [1, 1]+w_shape) 538 | z = tf.nn.conv2d( 539 | z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC') 540 | logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2]) 541 | 542 | return z, logdet 543 | 544 | 545 | @add_arg_scope 546 | def split2d(name, z, objective=0.): 547 | with tf.variable_scope(name): 548 | n_z = Z.int_shape(z)[3] 549 | z1 = z[:, :, :, :n_z // 2] 550 | z2 = z[:, :, :, n_z // 2:] 551 | pz = split2d_prior(z1) 552 | objective += pz.logp(z2) 553 | z1 = Z.squeeze2d(z1) 554 | eps = pz.get_eps(z2) 555 | return z1, objective, eps 556 | 557 | 558 | @add_arg_scope 559 | def split2d_reverse(name, z, eps, eps_std): 560 | with tf.variable_scope(name): 561 | z1 = Z.unsqueeze2d(z) 562 | pz = split2d_prior(z1) 563 | if eps is not None: 564 | # Already sampled eps 565 | z2 = pz.sample2(eps) 566 | elif eps_std is not None: 567 | # Sample with given eps_std 568 | z2 = pz.sample2(pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1])) 569 | else: 570 | # Sample normally 571 | z2 = pz.sample 572 | z = tf.concat([z1, z2], 3) 573 | return z 574 | 575 | 576 | @add_arg_scope 577 | def split2d_prior(z): 578 | n_z2 = int(z.get_shape()[3]) 579 | n_z1 = n_z2 580 | h = Z.conv2d_zeros("conv", z, 2 * n_z1) 581 | 582 | mean = h[:, :, :, 0::2] 583 | logs = h[:, :, :, 1::2] 584 | return Z.gaussian_diag(mean, logs) 585 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tfops as Z 3 | import horovod.tensorflow as hvd 4 | 5 | # Optimizers 6 | 7 | ''' 8 | Polyak averaging op 9 | ''' 10 | 11 | 12 | def polyak(params, beta): 13 | #params = tf.trainable_variables() 14 | ema = tf.train.ExponentialMovingAverage(decay=beta, zero_debias=True) 15 | avg_op = tf.group(ema.apply(params)) 16 | # Swapping op 17 | updates = [] 18 | for i in range(len(params)): 19 | p = params[i] 20 | avg = ema.average(p) 21 | tmp = 0. + avg * 1. 
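# tmp snapshots the moving average before it is overwritten; the chained control_dependencies below enforce the order snapshot -> avg := p -> p := snapshot, so running swap_op a second time swaps the weights back.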
22 | with tf.control_dependencies([tmp]): 23 | update1 = avg.assign(p) 24 | with tf.control_dependencies([update1]): 25 | update2 = p.assign(tmp) 26 | updates += [update1, update2] 27 | swap_op = tf.group(*updates) 28 | return avg_op, swap_op, ema 29 | 30 | 31 | def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8): 32 | updates = [] 33 | if type(cost_or_grads) is not list: 34 | gs = tf.gradients(cost_or_grads, params) 35 | else: 36 | gs = cost_or_grads 37 | 38 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs) 39 | 40 | # all-reduce 41 | grads = [Z.allreduce_mean(g) for g in gs] 42 | 43 | t = tf.Variable(1., 'adam_t') 44 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \ 45 | (1. - tf.pow(hps.beta1, t)) 46 | updates.append(t.assign_add(1)) 47 | 48 | for w, g in zip(params, grads): 49 | mom2 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2') 50 | if hps.beta1 > 0: 51 | mom1 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1') 52 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g 53 | updates.append(mom1.assign(mom1_new)) 54 | else: 55 | mom1_new = g 56 | m2_new = beta2 * mom2 + (1. - beta2) * tf.square(g) 57 | delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon) 58 | w_new = hps.weight_decay * w - alpha_t * delta_t 59 | updates.append(mom2.assign(m2_new)) 60 | updates.append(w.assign(w_new)) 61 | 62 | # Polyak averaging 63 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2) 64 | train_op = tf.group(polyak_avg_op, *updates) 65 | return train_op, polyak_swap_op, ema 66 | 67 | 68 | ''' 69 | Adam optimizer 70 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum). 71 | (It doesn't seem to work yet, though.) 72 | ''' 73 | 74 | 75 | def adam2(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8): 76 | updates = [] 77 | if type(cost_or_grads) is not list: 78 | gs = tf.gradients(cost_or_grads, params) 79 | else: 80 | gs = cost_or_grads 81 | 82 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs) 83 | 84 | # all-reduce 85 | grads1 = [Z.allreduce_mean(g) for g in gs] 86 | grads2 = [Z.allreduce_mean(g**2) for g in gs] 87 | 88 | t = tf.Variable(1., 'adam_t') 89 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \ 90 | (1. - tf.pow(hps.beta1, t)) 91 | updates.append(t.assign_add(1)) 92 | 93 | for w, g1, g2 in zip(params, grads1, grads2): 94 | mom2 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2') 95 | if hps.beta1 > 0: 96 | mom1 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1') 97 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g1 98 | updates.append(mom1.assign(mom1_new)) 99 | else: 100 | mom1_new = g1 101 | m2_new = beta2 * mom2 + (1. - beta2) * g2 102 | delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon) 103 | w_new = hps.weight_decay * w - alpha_t * delta_t 104 | updates.append(mom2.assign(m2_new)) 105 | updates.append(w.assign(w_new)) 106 | 107 | # Polyak averaging 108 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2) 109 | train_op = tf.group(polyak_avg_op, *updates) 110 | return train_op, polyak_swap_op, ema 111 | 112 | 113 | ''' 114 | Adam optimizer 115 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum). 116 | It doesn't seem to work though. 
117 | ''' 118 | 119 | 120 | def adam2_old(params, cost_or_grads, lr=3e-4, mom1=0.9, mom2=0.999, epsilon=1e-8): 121 | updates = [] 122 | if type(cost_or_grads) is not list: 123 | gs = tf.gradients(cost_or_grads, params) 124 | else: 125 | gs = cost_or_grads 126 | 127 | # all-reduce 128 | grads1 = [Z.allreduce_mean(g) for g in gs] 129 | grads2 = [Z.allreduce_mean(tf.square(g)) for g in gs] 130 | mom2 = tf.maximum(0., 1. - (hvd.size() * (1 - mom2))) 131 | 132 | t = tf.Variable(1., 'adam_t') 133 | lr_t = lr * tf.sqrt((1. - tf.pow(mom2, t))) / (1. - tf.pow(mom1, t)) 134 | updates.append(t.assign_add(1)) 135 | 136 | for p, g1, g2 in zip(params, grads1, grads2): 137 | mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') 138 | if mom1 > 0: 139 | v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') 140 | v_t = mom1 * v + (1. - mom1) * g1 141 | updates.append(v.assign(v_t)) 142 | else: 143 | v_t = g1 144 | mg_t = mom2 * mg + (1. - mom2) * g2 145 | delta_t = v_t / (tf.sqrt(mg_t) + epsilon) 146 | p_t = p - lr_t * delta_t 147 | updates.append(mg.assign(mg_t)) 148 | updates.append(p.assign(p_t)) 149 | return tf.group(*updates) 150 | 151 | 152 | def adamax(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8): 153 | updates = [] 154 | if type(cost_or_grads) is not list: 155 | gs = tf.gradients(cost_or_grads, params) 156 | else: 157 | gs = cost_or_grads 158 | 159 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs) 160 | 161 | # all-reduce 162 | grads = [Z.allreduce_mean(g) for g in gs] 163 | 164 | t = tf.Variable(1., 'adam_t') 165 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \ 166 | (1. - tf.pow(hps.beta1, t)) 167 | updates.append(t.assign_add(1)) 168 | 169 | for w, g in zip(params, grads): 170 | mom2 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2') 171 | if hps.beta1 > 0: 172 | mom1 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1') 173 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g 174 | updates.append(mom1.assign(mom1_new)) 175 | else: 176 | mom1_new = g 177 | m2_new = tf.maximum(beta2 * mom2, abs(g)) 178 | delta_t = mom1_new / (m2_new + epsilon) 179 | w_new = hps.weight_decay * w - alpha_t * delta_t 180 | updates.append(mom2.assign(m2_new)) 181 | updates.append(w.assign(w_new)) 182 | 183 | # Polyak averaging 184 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2) 185 | train_op = tf.group(polyak_avg_op, *updates) 186 | return train_op, polyak_swap_op, ema 187 | 188 | 189 | def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8): 190 | updates = [] 191 | if type(cost_or_grads) is not list: 192 | gs = tf.gradients(cost_or_grads, params) 193 | else: 194 | gs = cost_or_grads 195 | 196 | beta2 = 1-1./(hps.train_its*hps.polyak_epochs) 197 | 198 | # all-reduce 199 | grads = [Z.allreduce_mean(g) for g in gs] 200 | 201 | t = tf.Variable(1., 'adam_t') 202 | alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \ 203 | (1. - tf.pow(hps.beta1, t)) 204 | updates.append(t.assign_add(1)) 205 | 206 | for w, g in zip(params, grads): 207 | mom2 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2') 208 | if hps.beta1 > 0: 209 | mom1 = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1') 210 | mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g 211 | updates.append(mom1.assign(mom1_new)) 212 | else: 213 | mom1_new = g 214 | m2_new = beta2 * mom2 + (1. 
- beta2) * tf.square(g) 215 | delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon) 216 | w_new = hps.weight_decay * w - alpha_t * delta_t 217 | updates.append(mom2.assign(m2_new)) 218 | updates.append(w.assign(w_new)) 219 | 220 | # Polyak averaging 221 | polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2) 222 | train_op = tf.group(polyak_avg_op, *updates) 223 | return train_op, polyak_swap_op, ema 224 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.8.0 2 | keras==2.2.0 3 | pillow==5.2.0 4 | toposort==1.5 5 | horovod==0.13.8 6 | -------------------------------------------------------------------------------- /tfops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope 3 | from tensorflow.contrib.layers import variance_scaling_initializer 4 | import numpy as np 5 | import horovod.tensorflow as hvd 6 | 7 | # Debugging function 8 | do_print_act_stats = True 9 | 10 | 11 | def print_act_stats(x, _str=""): 12 | if not do_print_act_stats: 13 | return x 14 | if hvd.rank() != 0: 15 | return x 16 | if len(x.get_shape()) == 1: 17 | x_mean, x_var = tf.nn.moments(x, [0], keep_dims=True) 18 | if len(x.get_shape()) == 2: 19 | x_mean, x_var = tf.nn.moments(x, [0], keep_dims=True) 20 | if len(x.get_shape()) == 4: 21 | x_mean, x_var = tf.nn.moments(x, [0, 1, 2], keep_dims=True) 22 | stats = [tf.reduce_min(x_mean), tf.reduce_mean(x_mean), tf.reduce_max(x_mean), 23 | tf.reduce_min(tf.sqrt(x_var)), tf.reduce_mean(tf.sqrt(x_var)), tf.reduce_max(tf.sqrt(x_var))] 24 | return tf.Print(x, stats, "["+_str+"] "+x.name) 25 | 26 | # Allreduce methods 27 | 28 | 29 | def allreduce_sum(x): 30 | if hvd.size() == 1: 31 | return x 32 | return hvd.mpi_ops._allreduce(x) 33 | 34 | 35 | def allreduce_mean(x): 36 | x = allreduce_sum(x) / hvd.size() 37 | return x 38 | 39 | 40 | def default_initial_value(shape, std=0.05): 41 | return tf.random_normal(shape, 0., std) 42 | 43 | 44 | def default_initializer(std=0.05): 45 | return tf.random_normal_initializer(0., std) 46 | 47 | 48 | def int_shape(x): 49 | if str(x.get_shape()[0]) != '?': 50 | return list(map(int, x.get_shape())) 51 | return [-1]+list(map(int, x.get_shape()[1:])) 52 | 53 | # wrapper tf.get_variable, augmented with 'init' functionality 54 | # Get variable with data dependent init 55 | 56 | 57 | @add_arg_scope 58 | def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, trainable=True): 59 | w = tf.get_variable(name, shape, dtype, None, trainable=trainable) 60 | if init: 61 | w = w.assign(initial_value) 62 | with tf.control_dependencies([w]): 63 | return w 64 | return w 65 | 66 | # Activation normalization 67 | # Convenience function that does centering+scaling 68 | 69 | 70 | @add_arg_scope 71 | def actnorm(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True): 72 | if arg_scope([get_variable_ddi], trainable=trainable): 73 | if not reverse: 74 | x = actnorm_center(name+"_center", x, reverse) 75 | x = actnorm_scale(name+"_scale", x, scale, logdet, 76 | logscale_factor, batch_variance, reverse, init) 77 | if logdet != None: 78 | x, logdet = x 79 | else: 80 | x = actnorm_scale(name + "_scale", x, scale, logdet, 81 | logscale_factor, batch_variance, reverse, init) 82 | if logdet != None: 83 | x, 
logdet = x 84 | x = actnorm_center(name+"_center", x, reverse) 85 | if logdet != None: 86 | return x, logdet 87 | return x 88 | 89 | # Activation normalization 90 | 91 | 92 | @add_arg_scope 93 | def actnorm_center(name, x, reverse=False): 94 | shape = x.get_shape() 95 | with tf.variable_scope(name): 96 | assert len(shape) == 2 or len(shape) == 4 97 | if len(shape) == 2: 98 | x_mean = tf.reduce_mean(x, [0], keepdims=True) 99 | b = get_variable_ddi( 100 | "b", (1, int_shape(x)[1]), initial_value=-x_mean) 101 | elif len(shape) == 4: 102 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True) 103 | b = get_variable_ddi( 104 | "b", (1, 1, 1, int_shape(x)[3]), initial_value=-x_mean) 105 | 106 | if not reverse: 107 | x += b 108 | else: 109 | x -= b 110 | 111 | return x 112 | 113 | # Activation normalization 114 | 115 | 116 | @add_arg_scope 117 | def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True): 118 | shape = x.get_shape() 119 | with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable): 120 | assert len(shape) == 2 or len(shape) == 4 121 | if len(shape) == 2: 122 | x_var = tf.reduce_mean(x**2, [0], keepdims=True) 123 | logdet_factor = 1 124 | _shape = (1, int_shape(x)[1]) 125 | 126 | elif len(shape) == 4: 127 | x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True) 128 | logdet_factor = int(shape[1])*int(shape[2]) 129 | _shape = (1, 1, 1, int_shape(x)[3]) 130 | 131 | if batch_variance: 132 | x_var = tf.reduce_mean(x**2, keepdims=True) 133 | 134 | if init and False: 135 | # MPI all-reduce 136 | x_var = allreduce_mean(x_var) 137 | # Somehow this also slows down graph when not initializing 138 | # (it's not optimized away?) 139 | 140 | if True: 141 | logs = get_variable_ddi("logs", _shape, initial_value=tf.log( 142 | scale/(tf.sqrt(x_var)+1e-6))/logscale_factor)*logscale_factor 143 | if not reverse: 144 | x = x * tf.exp(logs) 145 | else: 146 | x = x * tf.exp(-logs) 147 | else: 148 | # Alternative, doesn't seem to do significantly worse or better than the logarithmic version above 149 | s = get_variable_ddi("s", _shape, initial_value=scale / 150 | (tf.sqrt(x_var) + 1e-6) / logscale_factor)*logscale_factor 151 | logs = tf.log(tf.abs(s)) 152 | if not reverse: 153 | x *= s 154 | else: 155 | x /= s 156 | 157 | if logdet != None: 158 | dlogdet = tf.reduce_sum(logs) * logdet_factor 159 | if reverse: 160 | dlogdet *= -1 161 | return x, logdet + dlogdet 162 | 163 | return x 164 | 165 | # Linear layer with layer norm 166 | 167 | 168 | @add_arg_scope 169 | def linear(name, x, width, do_weightnorm=True, do_actnorm=True, initializer=None, scale=1.): 170 | initializer = initializer or default_initializer() 171 | with tf.variable_scope(name): 172 | n_in = int(x.get_shape()[1]) 173 | w = tf.get_variable("W", [n_in, width], 174 | tf.float32, initializer=initializer) 175 | if do_weightnorm: 176 | w = tf.nn.l2_normalize(w, [0]) 177 | x = tf.matmul(x, w) 178 | x += tf.get_variable("b", [1, width], 179 | initializer=tf.zeros_initializer()) 180 | if do_actnorm: 181 | x = actnorm("actnorm", x, scale) 182 | return x 183 | 184 | # Linear layer with zero init 185 | 186 | 187 | @add_arg_scope 188 | def linear_zeros(name, x, width, logscale_factor=3): 189 | with tf.variable_scope(name): 190 | n_in = int(x.get_shape()[1]) 191 | w = tf.get_variable("W", [n_in, width], tf.float32, 192 | initializer=tf.zeros_initializer()) 193 | x = tf.matmul(x, w) 194 | x += tf.get_variable("b", [1, width], 195 | 
initializer=tf.zeros_initializer()) 196 | x *= tf.exp(tf.get_variable("logs", 197 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor) 198 | return x 199 | 200 | # Slow way to add edge padding 201 | 202 | 203 | def add_edge_padding(x, filter_size): 204 | assert filter_size[0] % 2 == 1 205 | if filter_size[0] == 1 and filter_size[1] == 1: 206 | return x 207 | a = (filter_size[0] - 1) // 2 # vertical padding size 208 | b = (filter_size[1] - 1) // 2 # horizontal padding size 209 | if True: 210 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]]) 211 | name = "_".join([str(dim) for dim in [a, b, *int_shape(x)[1:3]]]) 212 | pads = tf.get_collection(name) 213 | if not pads: 214 | if hvd.rank() == 0: 215 | print("Creating pad", name) 216 | pad = np.zeros([1] + int_shape(x)[1:3] + [1], dtype='float32') 217 | pad[:, :a, :, 0] = 1. 218 | pad[:, -a:, :, 0] = 1. 219 | pad[:, :, :b, 0] = 1. 220 | pad[:, :, -b:, 0] = 1. 221 | pad = tf.convert_to_tensor(pad) 222 | tf.add_to_collection(name, pad) 223 | else: 224 | pad = pads[0] 225 | pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1]) 226 | x = tf.concat([x, pad], axis=3) 227 | else: 228 | pad = tf.pad(tf.zeros_like(x[:, :, :, :1]) - 1, 229 | [[0, 0], [a, a], [b, b], [0, 0]]) + 1 230 | x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]]) 231 | x = tf.concat([x, pad], axis=3) 232 | return x 233 | 234 | 235 | @add_arg_scope 236 | def conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, edge_bias=True): 237 | with tf.variable_scope(name): 238 | if edge_bias and pad == "SAME": 239 | x = add_edge_padding(x, filter_size) 240 | pad = 'VALID' 241 | 242 | n_in = int(x.get_shape()[3]) 243 | 244 | stride_shape = [1] + stride + [1] 245 | filter_shape = filter_size + [n_in, width] 246 | w = tf.get_variable("W", filter_shape, tf.float32, 247 | initializer=default_initializer()) 248 | if do_weightnorm: 249 | w = tf.nn.l2_normalize(w, [0, 1, 2]) 250 | if skip == 1: 251 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC') 252 | else: 253 | assert stride[0] == 1 and stride[1] == 1 254 | x = tf.nn.atrous_conv2d(x, w, skip, pad) 255 | if do_actnorm: 256 | x = actnorm("actnorm", x) 257 | else: 258 | x += tf.get_variable("b", [1, 1, 1, width], 259 | initializer=tf.zeros_initializer()) 260 | 261 | if context1d != None: 262 | x += tf.reshape(linear("context", context1d, 263 | width), [-1, 1, 1, width]) 264 | return x 265 | 266 | 267 | @add_arg_scope 268 | def separable_conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], padding="SAME", do_actnorm=True, std=0.05): 269 | n_in = int(x.get_shape()[3]) 270 | with tf.variable_scope(name): 271 | assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1 272 | strides = [1] + stride + [1] 273 | w1_shape = filter_size + [n_in, 1] 274 | w1_init = np.zeros(w1_shape, dtype='float32') 275 | w1_init[(filter_size[0]-1)//2, (filter_size[1]-1)//2, :, 276 | :] = 1. 
# initialize depthwise conv as identity 277 | w1 = tf.get_variable("W1", dtype=tf.float32, initializer=w1_init) 278 | w2_shape = [1, 1, n_in, width] 279 | w2 = tf.get_variable("W2", w2_shape, tf.float32, 280 | initializer=default_initializer(std)) 281 | x = tf.nn.separable_conv2d( 282 | x, w1, w2, strides, padding, data_format='NHWC') 283 | if do_actnorm: 284 | x = actnorm("actnorm", x) 285 | else: 286 | x += tf.get_variable("b", [1, 1, 1, width], 287 | initializer=tf.zeros_initializer()) 288 | 289 | return x 290 | 291 | 292 | @add_arg_scope 293 | def conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", logscale_factor=3, skip=1, edge_bias=True): 294 | with tf.variable_scope(name): 295 | if edge_bias and pad == "SAME": 296 | x = add_edge_padding(x, filter_size) 297 | pad = 'VALID' 298 | 299 | n_in = int(x.get_shape()[3]) 300 | stride_shape = [1] + stride + [1] 301 | filter_shape = filter_size + [n_in, width] 302 | w = tf.get_variable("W", filter_shape, tf.float32, 303 | initializer=tf.zeros_initializer()) 304 | if skip == 1: 305 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC') 306 | else: 307 | assert stride[0] == 1 and stride[1] == 1 308 | x = tf.nn.atrous_conv2d(x, w, skip, pad) 309 | x += tf.get_variable("b", [1, 1, 1, width], 310 | initializer=tf.zeros_initializer()) 311 | x *= tf.exp(tf.get_variable("logs", 312 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor) 313 | return x 314 | 315 | 316 | # 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code 317 | def upsample2d_nearest_neighbour(x): 318 | shape = x.get_shape() 319 | n_batch = int(shape[0]) 320 | height = int(shape[1]) 321 | width = int(shape[2]) 322 | n_channels = int(shape[3]) 323 | x = tf.reshape(x, (n_batch, height, 1, width, 1, n_channels)) 324 | x = tf.concat([x, x], 2) 325 | x = tf.concat([x, x], 4) 326 | x = tf.reshape(x, (n_batch, height*2, width*2, n_channels)) 327 | return x 328 | 329 | 330 | def upsample(x, factor=2): 331 | shape = x.get_shape() 332 | height = int(shape[1]) 333 | width = int(shape[2]) 334 | x = tf.image.resize_nearest_neighbor(x, [height * factor, width * factor]) 335 | return x 336 | 337 | 338 | def squeeze2d(x, factor=2): 339 | assert factor >= 1 340 | if factor == 1: 341 | return x 342 | shape = x.get_shape() 343 | height = int(shape[1]) 344 | width = int(shape[2]) 345 | n_channels = int(shape[3]) 346 | assert height % factor == 0 and width % factor == 0 347 | x = tf.reshape(x, [-1, height//factor, factor, 348 | width//factor, factor, n_channels]) 349 | x = tf.transpose(x, [0, 1, 3, 5, 2, 4]) 350 | x = tf.reshape(x, [-1, height//factor, width // 351 | factor, n_channels*factor*factor]) 352 | return x 353 | 354 | 355 | def unsqueeze2d(x, factor=2): 356 | assert factor >= 1 357 | if factor == 1: 358 | return x 359 | shape = x.get_shape() 360 | height = int(shape[1]) 361 | width = int(shape[2]) 362 | n_channels = int(shape[3]) 363 | assert n_channels >= 4 and n_channels % 4 == 0 364 | x = tf.reshape( 365 | x, (-1, height, width, int(n_channels/factor**2), factor, factor)) 366 | x = tf.transpose(x, [0, 1, 4, 2, 5, 3]) 367 | x = tf.reshape(x, (-1, int(height*factor), 368 | int(width*factor), int(n_channels/factor**2))) 369 | return x 370 | 371 | # Reverse features across channel dimension 372 | 373 | 374 | def reverse_features(name, h, reverse=False): 375 | return h[:, :, :, ::-1] 376 | 377 | # Shuffle across the channel dimension 378 | 379 | 380 | def shuffle_features(name, h, indices=None, return_indices=False,
reverse=False): 381 | with tf.variable_scope(name): 382 | 383 | rng = np.random.RandomState( 384 | (abs(hash(tf.get_variable_scope().name))) % 10000000) 385 | 386 | if indices == None: 387 | # Create numpy and tensorflow variables with indices 388 | n_channels = int(h.get_shape()[-1]) 389 | indices = list(range(n_channels)) 390 | rng.shuffle(indices) 391 | # Reverse it 392 | indices_inverse = [0]*n_channels 393 | for i in range(n_channels): 394 | indices_inverse[indices[i]] = i 395 | 396 | tf_indices = tf.get_variable("indices", dtype=tf.int32, initializer=np.asarray( 397 | indices, dtype='int32'), trainable=False) 398 | tf_indices_reverse = tf.get_variable("indices_inverse", dtype=tf.int32, initializer=np.asarray( 399 | indices_inverse, dtype='int32'), trainable=False) 400 | 401 | _indices = tf_indices 402 | if reverse: 403 | _indices = tf_indices_reverse 404 | 405 | if len(h.get_shape()) == 2: 406 | # Slice 407 | h = tf.transpose(h) 408 | h = tf.gather(h, _indices) 409 | h = tf.transpose(h) 410 | elif len(h.get_shape()) == 4: 411 | # Slice 412 | h = tf.transpose(h, [3, 1, 2, 0]) 413 | h = tf.gather(h, _indices) 414 | h = tf.transpose(h, [3, 1, 2, 0]) 415 | if return_indices: 416 | return h, indices 417 | return h 418 | 419 | 420 | def embedding(name, y, n_y, width): 421 | with tf.variable_scope(name): 422 | params = tf.get_variable( 423 | "embedding", [n_y, width], initializer=default_initializer()) 424 | embeddings = tf.gather(params, y) 425 | return embeddings 426 | 427 | # Random variables 428 | 429 | 430 | def flatten_sum(logps): 431 | if len(logps.get_shape()) == 2: 432 | return tf.reduce_sum(logps, [1]) 433 | elif len(logps.get_shape()) == 4: 434 | return tf.reduce_sum(logps, [1, 2, 3]) 435 | else: 436 | raise Exception() 437 | 438 | 439 | def standard_gaussian(shape): 440 | return gaussian_diag(tf.zeros(shape), tf.zeros(shape)) 441 | 442 | 443 | def gaussian_diag(mean, logsd): 444 | class o(object): 445 | pass 446 | o.mean = mean 447 | o.logsd = logsd 448 | o.eps = tf.random_normal(tf.shape(mean)) 449 | o.sample = mean + tf.exp(logsd) * o.eps 450 | o.sample2 = lambda eps: mean + tf.exp(logsd) * eps 451 | o.logps = lambda x: -0.5 * \ 452 | (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd)) 453 | o.logp = lambda x: flatten_sum(o.logps(x)) 454 | o.get_eps = lambda x: (x - mean) / tf.exp(logsd) 455 | return o 456 | 457 | 458 | # def discretized_logistic_old(mean, logscale, binsize=1 / 256.0, sample=None): 459 | # scale = tf.exp(logscale) 460 | # sample = (tf.floor(sample / binsize) * binsize - mean) / scale 461 | # logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7) 462 | # return tf.reduce_sum(logp, [1, 2, 3]) 463 | 464 | def discretized_logistic(mean, logscale, binsize=1. / 256): 465 | class o(object): 466 | pass 467 | o.mean = mean 468 | o.logscale = logscale 469 | scale = tf.exp(logscale) 470 | 471 | def logps(x): 472 | x = (x - mean) / scale 473 | return tf.log(tf.sigmoid(x + binsize / scale) - tf.sigmoid(x) + 1e-7) 474 | o.logps = logps 475 | o.logp = lambda x: flatten_sum(logps(x)) 476 | return o 477 | 478 | 479 | def _symmetric_matrix_square_root(mat, eps=1e-10): 480 | """Compute square root of a symmetric matrix. 481 | Note that this is different from an elementwise square root. We want to 482 | compute M' where M' = sqrt(mat) such that M' * M' = mat. 483 | Also note that this method **only** works for symmetric matrices. 484 | Args: 485 | mat: Matrix to take the square root of. 
486 | eps: Small epsilon such that any element less than eps will not be square 487 | rooted to guard against numerical instability. 488 | Returns: 489 | Matrix square root of mat. 490 | """ 491 | # Unlike numpy, tensorflow's return order is (s, u, v) 492 | s, u, v = tf.svd(mat) 493 | # sqrt is unstable around 0; values below eps are kept unrooted (they are already near 0) 494 | si = tf.where(tf.less(s, eps), s, tf.sqrt(s)) 495 | # Note that the v returned by Tensorflow is v = V 496 | # (when referencing the equation A = U S V^T) 497 | # This is unlike Numpy which returns v = V^T 498 | return tf.matmul( 499 | tf.matmul(u, tf.diag(si)), v, transpose_b=True) 500 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Modified Horovod MNIST example 4 | 5 | import os 6 | import sys 7 | import time 8 | 9 | import horovod.tensorflow as hvd 10 | import numpy as np 11 | import tensorflow as tf 12 | import graphics 13 | from utils import ResultLogger 14 | 15 | learn = tf.contrib.learn 16 | 17 | # Suppress verbose warnings 18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 19 | 20 | 21 | def _print(*args, **kwargs): 22 | if hvd.rank() == 0: 23 | print(*args, **kwargs) 24 | 25 | 26 | def init_visualizations(hps, model, logdir): 27 | 28 | def sample_batch(y, eps): 29 | n_batch = hps.local_batch_train 30 | xs = [] 31 | for i in range(int(np.ceil(len(eps) / n_batch))): 32 | xs.append(model.sample( 33 | y[i*n_batch:i*n_batch + n_batch], eps[i*n_batch:i*n_batch + n_batch])) 34 | return np.concatenate(xs) 35 | 36 | def draw_samples(epoch): 37 | if hvd.rank() != 0: 38 | return 39 | 40 | rows = 10 if hps.image_size <= 64 else 4 41 | cols = rows 42 | n_batch = rows*cols 43 | y = np.asarray([_y % hps.n_y for _y in ( 44 | list(range(cols)) * rows)], dtype='int32') 45 | 46 | # temperatures = [0., .25, .5, .626, .75, .875, 1.] #previously 47 | temperatures = [0., .25, .5, .6, .7, .8, .9, 1.] 48 | 49 | x_samples = [] 50 | x_samples.append(sample_batch(y, [.0]*n_batch)) 51 | x_samples.append(sample_batch(y, [.25]*n_batch)) 52 | x_samples.append(sample_batch(y, [.5]*n_batch)) 53 | x_samples.append(sample_batch(y, [.6]*n_batch)) 54 | x_samples.append(sample_batch(y, [.7]*n_batch)) 55 | x_samples.append(sample_batch(y, [.8]*n_batch)) 56 | x_samples.append(sample_batch(y, [.9] * n_batch)) 57 | x_samples.append(sample_batch(y, [1.]*n_batch)) 58 | # previously: 0, .25, .5, .625, .75, .875, 1.
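# Each append above is one rows*cols grid sampled at a fixed temperature (the per-image eps_std scaling the latent noise): 0. collapses to the prior means, 1. draws unscaled samples.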
59 | 60 | for i in range(len(x_samples)): 61 | x_sample = np.reshape( 62 | x_samples[i], (n_batch, hps.image_size, hps.image_size, 3)) 63 | graphics.save_raster(x_sample, logdir + 64 | 'epoch_{}_sample_{}.png'.format(epoch, i)) 65 | 66 | return draw_samples 67 | 68 | # === 69 | # Code for getting data 70 | # === 71 | def get_data(hps, sess): 72 | if hps.image_size == -1: 73 | hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64, 74 | 'imagenet': 256, 'celeba': 256, 'lsun_realnvp': 64, 'lsun': 256}[hps.problem] 75 | if hps.n_test == -1: 76 | hps.n_test = {'mnist': 10000, 'cifar10': 10000, 'imagenet-oord': 50000, 'imagenet': 50000, 77 | 'celeba': 3000, 'lsun_realnvp': 300*hvd.size(), 'lsun': 300*hvd.size()}[hps.problem] 78 | hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000, 79 | 'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1, 'lsun': 1}[hps.problem] 80 | if hps.data_dir == "": 81 | hps.data_dir = {'mnist': None, 'cifar10': None, 'imagenet-oord': '/mnt/host/imagenet-oord-tfr', 'imagenet': '/mnt/host/imagenet-tfr', 82 | 'celeba': '/mnt/host/celeba-reshard-tfr', 'lsun_realnvp': '/mnt/host/lsun_realnvp', 'lsun': '/mnt/host/lsun'}[hps.problem] 83 | 84 | if hps.problem == 'lsun_realnvp': 85 | hps.rnd_crop = True 86 | else: 87 | hps.rnd_crop = False 88 | 89 | if hps.category: 90 | hps.data_dir += ('/%s' % hps.category) 91 | 92 | # Use anchor_size to rescale batch size based on image_size 93 | s = hps.anchor_size 94 | hps.local_batch_train = hps.n_batch_train * \ 95 | s * s // (hps.image_size * hps.image_size) 96 | hps.local_batch_test = {64: 50, 32: 25, 16: 10, 8: 5, 4: 2, 2: 2, 1: 1}[ 97 | hps.local_batch_train] # round down to closest divisor of 50 98 | hps.local_batch_init = hps.n_batch_init * \ 99 | s * s // (hps.image_size * hps.image_size) 100 | 101 | print("Rank {} Batch sizes Train {} Test {} Init {}".format( 102 | hvd.rank(), hps.local_batch_train, hps.local_batch_test, hps.local_batch_init)) 103 | 104 | if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']: 105 | hps.direct_iterator = True 106 | import data_loaders.get_data as v 107 | train_iterator, test_iterator, data_init = \ 108 | v.get_data(sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap, hps.local_batch_train, 109 | hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop) 110 | 111 | elif hps.problem in ['mnist', 'cifar10']: 112 | hps.direct_iterator = False 113 | import data_loaders.get_mnist_cifar as v 114 | train_iterator, test_iterator, data_init = \ 115 | v.get_data(hps.problem, hvd.size(), hvd.rank(), hps.dal, hps.local_batch_train, 116 | hps.local_batch_test, hps.local_batch_init, hps.image_size) 117 | 118 | else: 119 | raise Exception() 120 | 121 | return train_iterator, test_iterator, data_init 122 | 123 | 124 | def process_results(results): 125 | stats = ['loss', 'bits_x', 'bits_y', 'pred_loss'] 126 | assert len(stats) == results.shape[0] 127 | res_dict = {} 128 | for i in range(len(stats)): 129 | res_dict[stats[i]] = "{:.4f}".format(results[i]) 130 | return res_dict 131 | 132 | 133 | def main(hps): 134 | 135 | # Initialize Horovod. 136 | hvd.init() 137 | 138 | # Create tensorflow session 139 | sess = tensorflow_session() 140 | 141 | # Download and load dataset. 
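# Seeds are offset by hvd.rank(), so each worker draws a distinct random stream while runs stay reproducible for a fixed --seed and worker count.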
142 | tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed) 143 | np.random.seed(hvd.rank() + hvd.size() * hps.seed) 144 | 145 | # Get data and set train_its and valid_its 146 | train_iterator, test_iterator, data_init = get_data(hps, sess) 147 | hps.train_its, hps.test_its, hps.full_test_its = get_its(hps) 148 | 149 | # Create log dir 150 | logdir = os.path.abspath(hps.logdir) + "/" 151 | if not os.path.exists(logdir): 152 | os.mkdir(logdir) 153 | 154 | # Create model 155 | import model 156 | model = model.model(sess, hps, train_iterator, test_iterator, data_init) 157 | 158 | # Initialize visualization functions 159 | visualise = init_visualizations(hps, model, logdir) 160 | 161 | if not hps.inference: 162 | # Perform training 163 | train(sess, model, hps, logdir, visualise) 164 | else: 165 | infer(sess, model, hps, test_iterator) 166 | 167 | 168 | def infer(sess, model, hps, iterator): 169 | # Example of using model in inference mode. Load saved model using hps.restore_path 170 | # Can provide x, y from files instead of dataset iterator 171 | # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32) 172 | if hps.direct_iterator: 173 | iterator = iterator.get_next() 174 | 175 | xs = [] 176 | zs = [] 177 | for it in range(hps.full_test_its): 178 | if hps.direct_iterator: 179 | # replace with x, y, attr if you're getting CelebA attributes, also modify get_data 180 | x, y = sess.run(iterator) 181 | else: 182 | x, y = iterator() 183 | 184 | z = model.encode(x, y) 185 | x = model.decode(y, z) 186 | xs.append(x) 187 | zs.append(z) 188 | 189 | x = np.concatenate(xs, axis=0) 190 | z = np.concatenate(zs, axis=0) 191 | np.save('logs/x.npy', x) 192 | np.save('logs/z.npy', z) 193 | return zs 194 | 195 | 196 | def train(sess, model, hps, logdir, visualise): 197 | _print(hps) 198 | _print('Starting training. Logging to', logdir) 199 | _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg') 200 | 201 | # Train 202 | sess.graph.finalize() 203 | n_processed = 0 204 | n_images = 0 205 | train_time = 0.0 206 | test_loss_best = 999999 207 | 208 | if hvd.rank() == 0: 209 | train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__) 210 | test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__) 211 | 212 | tcurr = time.time() 213 | for epoch in range(1, hps.epochs): 214 | 215 | t = time.time() 216 | 217 | train_results = [] 218 | for it in range(hps.train_its): 219 | 220 | # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs. 221 | lr = hps.lr * min(1., n_processed / 222 | (hps.n_train * hps.epochs_warmup)) 223 | 224 | # Run a training step synchronously.
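# model.train(lr) below runs one synchronous step: the optimizers in optim.py all-reduce gradients across workers (Z.allreduce_mean), so every rank applies the same averaged update.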
225 | _t = time.time() 226 | train_results += [model.train(lr)] 227 | if hps.verbose and hvd.rank() == 0: 228 | _print(n_processed, time.time()-_t, train_results[-1]) 229 | sys.stdout.flush() 230 | 231 | # Images seen wrt anchor resolution 232 | n_processed += hvd.size() * hps.n_batch_train 233 | # Actual images seen at current resolution 234 | n_images += hvd.size() * hps.local_batch_train 235 | 236 | train_results = np.mean(np.asarray(train_results), axis=0) 237 | 238 | dtrain = time.time() - t 239 | ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain 240 | train_time += dtrain 241 | 242 | if hvd.rank() == 0: 243 | train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int( 244 | train_time), **process_results(train_results)) 245 | 246 | if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0: 247 | test_results = [] 248 | msg = '' 249 | 250 | t = time.time() 251 | # model.polyak_swap() 252 | 253 | if epoch % hps.epochs_full_valid == 0: 254 | # Full validation run 255 | for it in range(hps.full_test_its): 256 | test_results += [model.test()] 257 | test_results = np.mean(np.asarray(test_results), axis=0) 258 | 259 | if hvd.rank() == 0: 260 | test_logger.log(epoch=epoch, n_processed=n_processed, 261 | n_images=n_images, **process_results(test_results)) 262 | 263 | # Save checkpoint 264 | if test_results[0] < test_loss_best: 265 | test_loss_best = test_results[0] 266 | model.save(logdir+"model_best_loss.ckpt") 267 | msg += ' *' 268 | 269 | dtest = time.time() - t 270 | 271 | # Sample 272 | t = time.time() 273 | if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0: 274 | visualise(epoch) 275 | dsample = time.time() - t 276 | 277 | if hvd.rank() == 0: 278 | dcurr = time.time() - tcurr 279 | tcurr = time.time() 280 | _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format( 281 | ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg) 282 | 283 | # model.polyak_swap() 284 | 285 | if hvd.rank() == 0: 286 | _print("Finished!") 287 | 288 | # Get number of training and validation iterations 289 | def get_its(hps): 290 | # These run for a fixed amount of time. 
Since the anchored batch can be smaller than the actual batch, fewer examples have actually been seen 291 | train_its = int(np.ceil(hps.n_train / (hps.n_batch_train * hvd.size()))) 292 | test_its = int(np.ceil(hps.n_test / (hps.n_batch_train * hvd.size()))) 293 | train_epoch = train_its * hps.n_batch_train * hvd.size() 294 | 295 | # Do a full validation run 296 | if hvd.rank() == 0: 297 | print(hps.n_test, hps.local_batch_test, hvd.size()) 298 | assert hps.n_test % (hps.local_batch_test * hvd.size()) == 0 299 | full_test_its = hps.n_test // (hps.local_batch_test * hvd.size()) 300 | 301 | if hvd.rank() == 0: 302 | print("Train epoch size: " + str(train_epoch)) 303 | return train_its, test_its, full_test_its 304 | 305 | 306 | ''' 307 | Create tensorflow session with horovod 308 | ''' 309 | def tensorflow_session(): 310 | # Init session and params 311 | config = tf.ConfigProto() 312 | config.gpu_options.allow_growth = True 313 | # Pin GPU to local rank (one GPU per process) 314 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 315 | sess = tf.Session(config=config) 316 | return sess 317 | 318 | 319 | if __name__ == "__main__": 320 | 321 | # This enables a Ctrl-C without triggering errors 322 | import signal 323 | signal.signal(signal.SIGINT, lambda x, y: sys.exit(0)) 324 | 325 | import argparse 326 | parser = argparse.ArgumentParser() 327 | parser.add_argument("--verbose", action='store_true', help="Verbose mode") 328 | parser.add_argument("--restore_path", type=str, default='', 329 | help="Location of checkpoint to restore") 330 | parser.add_argument("--inference", action="store_true", 331 | help="Use in inference mode") 332 | parser.add_argument("--logdir", type=str, 333 | default='./logs', help="Location to save logs") 334 | 335 | # Dataset hyperparams: 336 | parser.add_argument("--problem", type=str, default='cifar10', 337 | help="Problem (mnist/cifar10/imagenet)") 338 | parser.add_argument("--category", type=str, 339 | default='', help="LSUN category") 340 | parser.add_argument("--data_dir", type=str, default='', 341 | help="Location of data") 342 | parser.add_argument("--dal", type=int, default=1, 343 | help="Data augmentation level: 0=None, 1=Standard, 2=Extra") 344 | 345 | # New dataloader params 346 | parser.add_argument("--fmap", type=int, default=1, 347 | help="# Threads for parallel file reading") 348 | parser.add_argument("--pmap", type=int, default=16, 349 | help="# Threads for parallel map") 350 | 351 | # Optimization hyperparams: 352 | parser.add_argument("--n_train", type=int, 353 | default=50000, help="Train epoch size") 354 | parser.add_argument("--n_test", type=int, default=- 355 | 1, help="Valid epoch size") 356 | parser.add_argument("--n_batch_train", type=int, 357 | default=64, help="Minibatch size") 358 | parser.add_argument("--n_batch_test", type=int, 359 | default=50, help="Minibatch size") 360 | parser.add_argument("--n_batch_init", type=int, default=256, 361 | help="Minibatch size for data-dependent init") 362 | parser.add_argument("--optimizer", type=str, 363 | default="adamax", help="adam or adamax") 364 | parser.add_argument("--lr", type=float, default=0.001, 365 | help="Base learning rate") 366 | parser.add_argument("--beta1", type=float, default=.9, help="Adam beta1") 367 | parser.add_argument("--polyak_epochs", type=float, default=1, 368 | help="Nr of averaging epochs for Polyak and beta2") 369 | parser.add_argument("--weight_decay", type=float, default=1., 370 | help="Weight decay.
Switched off by default.") 371 | parser.add_argument("--epochs", type=int, default=1000000, 372 | help="Total number of training epochs") 373 | parser.add_argument("--epochs_warmup", type=int, 374 | default=10, help="Warmup epochs") 375 | parser.add_argument("--epochs_full_valid", type=int, 376 | default=50, help="Epochs between valid") 377 | parser.add_argument("--gradient_checkpointing", type=int, 378 | default=1, help="Use memory saving gradients") 379 | 380 | # Model hyperparams: 381 | parser.add_argument("--image_size", type=int, 382 | default=-1, help="Image size") 383 | parser.add_argument("--anchor_size", type=int, default=32, 384 | help="Anchor size for deciding batch size") 385 | parser.add_argument("--width", type=int, default=512, 386 | help="Width of hidden layers") 387 | parser.add_argument("--depth", type=int, default=32, 388 | help="Depth of network") 389 | parser.add_argument("--weight_y", type=float, default=0.00, 390 | help="Weight of log p(y|x) in weighted loss") 391 | parser.add_argument("--n_bits_x", type=int, default=8, 392 | help="Number of bits of x") 393 | parser.add_argument("--n_levels", type=int, default=3, 394 | help="Number of levels") 395 | 396 | # Synthesis/Sampling hyperparameters: 397 | parser.add_argument("--n_sample", type=int, default=1, 398 | help="minibatch size for sample") 399 | parser.add_argument("--epochs_full_sample", type=int, 400 | default=50, help="Epochs between full scale sample") 401 | 402 | # Ablation 403 | parser.add_argument("--learntop", action="store_true", 404 | help="Learn spatial prior") 405 | parser.add_argument("--ycond", action="store_true", 406 | help="Use y conditioning") 407 | parser.add_argument("--seed", type=int, default=0, help="Random seed") 408 | parser.add_argument("--flow_permutation", type=int, default=2, 409 | help="Type of flow. 0=reverse (realnvp), 1=shuffle, 2=invconv (ours)") 410 | parser.add_argument("--flow_coupling", type=int, default=0, 411 | help="Coupling type: 0=additive, 1=affine") 412 | 413 | hps = parser.parse_args() # So error if typo 414 | main(hps) 415 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class ResultLogger(object): 5 | def __init__(self, path, *args, **kwargs): 6 | self.f_log = open(path, 'w') 7 | self.f_log.write(json.dumps(kwargs) + '\n') 8 | 9 | def log(self, **kwargs): 10 | self.f_log.write(json.dumps(kwargs) + '\n') 11 | self.f_log.flush() 12 | 13 | def close(self): 14 | self.f_log.close() 15 | --------------------------------------------------------------------------------