44 | This mock-up mimics the look of the in-progress article to inform a design
45 | that embeds the demo into the article. The relevant assets just need to be
46 | migrated into the final article.
47 |
48 |
49 |
We’ve developed Glow, a new type of generative model which uses
50 | invertible 1x1 convolutions to create rich, synthetic models of data,
51 | automatically discovering features we can manipulate. The model extends
52 | previous work on reversible generative models, simplifying the
53 | architecture and leading to substantially better results. We’re releasing
54 | code for the model and an online visualization tool so people can explore
55 | and build on these results.
Generative modeling is about observing data, like a set of pictures of
69 | faces, then learning a model of how this data was generated. Learning to
70 | approximate the data-generating process requires learning all structure
71 | present in the data, and successful models should be able to synthesize
72 | outputs that look similar to the data. Accurate generative models have
73 | broad applications, including speech synthesis, text analysis and
74 | synthesis, semi-supervised learning and model-based control. The technique
75 | we propose can be applied to those problems as well.
76 |
77 |
78 |
--------------------------------------------------------------------------------
/data_loaders/get_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | import numpy as np
4 | import glob
5 |
# Buffer size used by input_fn when shuffling the order of tfrecord files.
_FILES_SHUFFLE = 1024
# input_fn shuffles records with a buffer of n_batch * _SHUFFLE_FACTOR.
_SHUFFLE_FACTOR = 4
8 |
9 |
def parse_tfrecord_tf(record, res, rnd_crop):
    """Decode one serialized example into an image tensor and a scalar label.

    Args:
        record: serialized example holding 'shape', 'data' and 'label' features.
        res: side length of the returned square image.
        rnd_crop: when truthy, the raw bytes are first reshaped to the stored
            'shape' feature and a random [res, res, 3] crop is taken
            (LSUN RealNVP data only).

    Returns:
        (img, label): the uint8 image reshaped to [res, res, 3] and the label
        reshaped to a scalar and cast to int32.
    """
    # The label is always 0 for unconditional datasets.
    # To read CelebA attributes, add 'attr': tf.FixedLenFeature([40], tf.int64).
    feature_spec = {
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([1], tf.int64),
    }
    parsed = tf.parse_single_example(record, features=feature_spec)

    label = tf.cast(tf.reshape(parsed['label'], shape=[]), dtype=tf.int32)
    img = tf.decode_raw(parsed['data'], tf.uint8)
    if rnd_crop:
        # For LSUN RealNVP only: crop randomly out of the stored full image.
        full_img = tf.reshape(img, parsed['shape'])
        img = tf.random_crop(full_img, [res, res, 3])
    # To expose CelebA attributes, also return the parsed 'attr' feature here.
    return tf.reshape(img, [res, res, 3]), label
26 |
27 |
def input_fn(tfr_file, shards, rank, pmap, fmap, n_batch, resolution, rnd_crop, is_training):
    """Build a one-shot iterator over batches parsed from tfrecord files.

    Args:
        tfr_file: file pattern passed to `tf.data.Dataset.list_files`.
        shards, rank: arguments of the per-worker file sharding.
        pmap: number of parallel calls used when parsing records.
        fmap: cycle length of the parallel file interleave.
        n_batch: batch size.
        resolution, rnd_crop: forwarded to `parse_tfrecord_tf`.
        is_training: enables file-level and record-level shuffling.

    Returns:
        The one-shot iterator of the repeated, batched dataset.
    """
    def _parse(serialized):
        return parse_tfrecord_tf(serialized, resolution, rnd_crop)

    file_ds = tf.data.Dataset.list_files(tfr_file)
    # 'lsun' validation has only one shard, so there every machine walks the
    # full dataset; in all other cases each worker takes a subset of files.
    if ('lsun' not in tfr_file) or is_training:
        file_ds = file_ds.shard(shards, rank)
    if is_training:
        # Randomize the order of the files inside this worker's shard.
        file_ds = file_ds.shuffle(buffer_size=_FILES_SHUFFLE)

    interleave = tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=fmap)
    records = file_ds.apply(interleave)
    if is_training:
        records = records.shuffle(buffer_size=n_batch * _SHUFFLE_FACTOR)

    records = records.repeat()
    examples = records.map(_parse, num_parallel_calls=pmap)
    batches = examples.batch(n_batch).prefetch(1)
    return batches.make_one_shot_iterator()
48 |
49 |
def get_tfr_file(data_dir, split, res_lg2):
    """Return the glob pattern of the tfrecord shards for one split.

    Shards are expected at
    ``<data_dir>/<split>/<split>-rXX-s-NNNN-of-MMMM.tfrecords`` where XX is
    `res_lg2`. The total shard count MMMM is read from the first matching
    file name and compared with the number of files found.

    Args:
        data_dir: root directory of the dataset.
        split: sub-directory name, e.g. 'train' or 'validation'.
        res_lg2: log2 of the image resolution.

    Returns:
        The glob pattern (not the expanded file list).

    Raises:
        AssertionError: if no shard matches the pattern, or if the number of
            matching files differs from the count encoded in the file name.
    """
    split_dir = os.path.join(data_dir, split)
    tfr_prefix = os.path.join(split_dir, os.path.basename(split_dir))
    tfr_file = tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % (res_lg2)
    files = glob.glob(tfr_file)
    # Explicit raises (rather than `assert`) so the checks survive `python -O`
    # and an empty match reports the pattern instead of a bare IndexError.
    if not files:
        raise AssertionError(
            "No tfrecords files found matching %s" % tfr_file)
    expected_count = int(files[0].split("-")[-1].split(".")[0])
    if len(files) != expected_count:
        raise AssertionError(
            "Not all tfrecords files present at %s" % tfr_prefix)
    return tfr_file
58 |
59 |
def get_data(sess, data_dir, shards, rank, pmap, fmap, n_batch_train, n_batch_test, n_batch_init, resolution, rnd_crop):
    """Create train/validation iterators and fetch an initialization batch.

    `resolution` must be a power of two. The train and validation tfrecord
    patterns are resolved with `get_tfr_file`, wrapped with `input_fn`, and
    `n_batch_init` examples are pulled from the training iterator.

    Returns:
        (train_itr, valid_itr, data_init)
    """
    res_lg2 = int(np.log2(resolution))
    assert resolution == 2 ** res_lg2

    patterns = {
        'train': get_tfr_file(data_dir, 'train', res_lg2),
        'validation': get_tfr_file(data_dir, 'validation', res_lg2),
    }

    train_itr = input_fn(
        patterns['train'], shards, rank, pmap, fmap,
        n_batch_train, resolution, rnd_crop, True)
    valid_itr = input_fn(
        patterns['validation'], shards, rank, pmap, fmap,
        n_batch_test, resolution, rnd_crop, False)

    # Batch used for initialization, drawn from the training stream.
    data_init = make_batch(sess, train_itr, n_batch_train, n_batch_init)
    return train_itr, valid_itr, data_init
74 |
75 | #
76 |
77 |
def make_batch(sess, itr, itr_batch_size, required_batch_size):
    """Assemble `required_batch_size` examples from an iterator of batches.

    Runs `itr.get_next()` through `sess` ceil(required / itr_batch_size)
    times, concatenates the fetched (x, y) pairs and truncates both to
    `required_batch_size` entries.

    Returns:
        dict with keys 'x' and 'y'.
    """
    # The required size does not have to be a multiple of the iterator's.
    n_fetches = int(np.ceil(required_batch_size / itr_batch_size))
    next_op = itr.get_next()
    fetched = [sess.run(next_op) for _ in range(n_fetches)]
    images = np.concatenate([x for x, _ in fetched])[:required_batch_size]
    labels = np.concatenate([y for _, y in fetched])[:required_batch_size]
    return {'x': images, 'y': labels}
90 |
--------------------------------------------------------------------------------
/data_loaders/get_mnist_cifar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def downsample(x, resolution):
    """Average-pool a float32 batch down to `resolution` along axes 1 and 2.

    Both spatial sizes must be multiples of `resolution`. When the input
    already has that size it is returned unchanged (same object).
    """
    assert x.dtype == np.float32
    assert x.shape[1] % resolution == 0
    assert x.shape[2] % resolution == 0
    if x.shape[1] == x.shape[2] == resolution:
        return x
    shape = x.shape
    row_factor = shape[1] // resolution
    col_factor = shape[2] // resolution
    # Split each spatial axis into (resolution, factor) and average the factors.
    blocks = np.reshape(
        x, [shape[0], resolution, row_factor, resolution, col_factor, shape[3]])
    return np.mean(blocks, (2, 4))
15 |
16 |
def x_to_uint8(x):
    """Floor `x`, clip the result to [0, 255] and cast it to uint8."""
    floored = np.floor(x)
    clipped = np.clip(floored, 0, 255)
    return clipped.astype(np.uint8)
20 |
21 |
def shard(data, shards, rank):
    """Return the contiguous slice of (x, y) that belongs to shard `rank`.

    Sharding is deterministic: the data is cut into `shards` equal,
    consecutive pieces and piece number `rank` is returned.
    """
    x, y = data
    n_examples = x.shape[0]
    assert n_examples == y.shape[0]
    assert n_examples % shards == 0
    assert 0 <= rank < shards
    size = n_examples // shards
    window = slice(rank * size, (rank + 1) * size)
    return x[window], y[window]
31 |
32 |
def get_data(problem, shards, rank, data_augmentation_level, n_batch_train, n_batch_test, n_batch_init, resolution):
    """Build MNIST / CIFAR-10 batch iterators plus an initialization batch.

    Args:
        problem: 'mnist' or 'cifar10'.
        shards, rank: this worker keeps contiguous shard `rank` of `shards`
            (see `shard`).
        data_augmentation_level: 0 disables augmentation. For MNIST any other
            value enables 10% width/height shifts. For CIFAR-10, 1 enables
            10% shifts and 2 additionally enables horizontal flips, rotation,
            zoom and shear.
        n_batch_train, n_batch_test: batch sizes of the two iterators.
        n_batch_init: number of examples in the initialization batch.
        resolution: target side length passed to `downsample`.

    Returns:
        (train_iterator, test_iterator, data_init). Each iterator is a
        zero-argument callable returning the next (x, y) batch with x
        converted to uint8; data_init is the dict built by `make_batch`.

    Raises:
        ValueError: for an unknown `problem`, or an unsupported CIFAR-10
            augmentation level.
    """
    if problem == 'mnist':
        from keras.datasets import mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
        # Pad the 28x28 digits by 2 pixels per side to make them 32x32.
        # Mode 'minimum' is kept from the original code; it is only equal to
        # zero padding where the minimum along the padded axis is 0.
        pad_width = ((0, 0), (2, 2), (2, 2))
        x_train = np.pad(x_train, pad_width, 'minimum')
        x_test = np.pad(x_test, pad_width, 'minimum')
        # Replicate the single grey channel three times.
        x_train = np.tile(np.reshape(x_train, (-1, 32, 32, 1)), (1, 1, 1, 3))
        x_test = np.tile(np.reshape(x_test, (-1, 32, 32, 1)), (1, 1, 1, 3))
    elif problem == 'cifar10':
        from keras.datasets import cifar10
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
    else:
        raise ValueError(
            "Unknown problem %r (expected 'mnist' or 'cifar10')" % (problem,))

    print('n_train:', x_train.shape[0], 'n_test:', x_test.shape[0])

    # Shard before any shuffling
    x_train, y_train = shard((x_train, y_train), shards, rank)
    x_test, y_test = shard((x_test, y_test), shards, rank)

    print('n_shard_train:', x_train.shape[0], 'n_shard_test:', x_test.shape[0])

    from keras.preprocessing.image import ImageDataGenerator
    datagen_test = ImageDataGenerator()
    # `problem` is known to be 'mnist' or 'cifar10' at this point.
    if data_augmentation_level == 0:
        datagen_train = ImageDataGenerator()
    elif problem == 'mnist' or data_augmentation_level == 1:
        datagen_train = ImageDataGenerator(
            width_shift_range=0.1,
            height_shift_range=0.1
        )
    elif data_augmentation_level == 2:
        datagen_train = ImageDataGenerator(
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            rotation_range=15,  # degrees rotation
            zoom_range=0.1,
            shear_range=0.02,
        )
    else:
        raise ValueError(
            "Unsupported data_augmentation_level %r for cifar10 "
            "(expected 0, 1 or 2)" % (data_augmentation_level,))

    datagen_train.fit(x_train)
    datagen_test.fit(x_test)
    train_flow = datagen_train.flow(x_train, y_train, n_batch_train)
    test_flow = datagen_test.flow(x_test, y_test, n_batch_test, shuffle=False)

    def make_iterator(flow, resolution):
        # Wrap a keras flow into a callable that yields uint8 batches.
        def iterator():
            x_full, y = flow.next()
            x_full = x_full.astype(np.float32)
            x = downsample(x_full, resolution)
            x = x_to_uint8(x)
            return x, y

        return iterator

    train_iterator = make_iterator(train_flow, resolution)
    test_iterator = make_iterator(test_flow, resolution)

    # Get data for initialization
    data_init = make_batch(train_iterator, n_batch_train, n_batch_init)

    return train_iterator, test_iterator, data_init
114 |
115 |
def make_batch(iterator, iterator_batch_size, required_batch_size):
    """Collect `required_batch_size` examples by calling `iterator` repeatedly.

    `iterator` is called ceil(required / iterator_batch_size) times; the
    returned (x, y) pairs are concatenated and truncated to the required size.

    Returns:
        dict with keys 'x' and 'y'.
    """
    # The required size does not have to be a multiple of the iterator's.
    n_calls = int(np.ceil(required_batch_size / iterator_batch_size))
    pairs = [iterator() for _ in range(n_calls)]
    images = np.concatenate([x for x, _ in pairs])[:required_batch_size]
    labels = np.concatenate([y for _, y in pairs])[:required_batch_size]
    return {'x': images, 'y': labels}
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 |
3 | # Glow
4 |
5 | Code for reproducing results in ["Glow: Generative Flow with Invertible 1x1 Convolutions"](https://d4mucfpksywv.cloudfront.net/research-covers/glow/paper/glow.pdf)
6 |
7 | To use the pretrained CelebA-HQ model, make your own manipulation vectors, and run our interactive demo, check the `demo` folder.
8 |
9 | ## Requirements
10 |
11 | - Tensorflow (tested with v1.8.0)
12 | - Horovod (tested with v0.13.8) and (Open)MPI
13 |
14 | Run
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 |
19 | To set up (Open)MPI, check the instructions on the Horovod GitHub [page](https://github.com/uber/horovod).
20 |
21 | ## Download datasets
22 | For small scale experiments, use MNIST/CIFAR-10 (directly downloaded by `train.py` using keras)
23 |
24 | For larger scale experiments, the datasets used are hosted at `https://openaipublic.azureedge.net/glow-demo/data/{dataset_name}-tfr.tar`. The dataset names are listed below; for each one we mention the exact preprocessing / downsampling method so that likelihoods can be compared correctly.
25 |
26 | Quantitative results
27 | - `imagenet-oord` - 20GB. Unconditional ImageNet 32x32 and 64x64, as described in PixelRNN/RealNVP papers (we downloaded [this](http://image-net.org/small/download.php) processed version).
28 | - `lsun_realnvp` - 140GB. LSUN 96x96. Random 64x64 crops taken at processing time, as described in RealNVP.
29 |
30 | Qualitative results
31 | - `celeba` - 4GB. CelebA-HQ 256x256 dataset, as described in Progressive growing of GAN's. For 1024x1024 version (120GB), use `celeba-full-tfr.tar` while downloading.
32 | - `imagenet` - 20GB. ImageNet 32x32 and 64x64 with class labels. Centre cropped, area downsampled.
33 | - `lsun` - 700GB. LSUN 256x256. Centre cropped, area downsampled.
34 |
35 | To download and extract `celeba`, for example, run
36 | ```
37 | wget https://openaipublic.azureedge.net/glow-demo/data/celeba-tfr.tar
38 | tar -xvf celeba-tfr.tar
39 | ```
40 | Change `hps.data_dir` in train.py file to point to the above folder (or use the `--data_dir` flag when you run train.py)
41 |
42 | For `lsun`, since download can be quite big, you can instead follow the instructions in `data_loaders/generate_tfr/lsun.py` to generate the tfr file directly from LSUN images. `church_outdoor` will be the smallest category.
43 |
44 | ## Simple Train with 1 GPU
45 |
46 | Run with small depth to test
47 | ```
48 | CUDA_VISIBLE_DEVICES=0 python train.py --depth 1
49 | ```
50 |
51 | ## Train with multiple GPUs using MPI and Horovod
52 |
53 | Run default training script with 8 GPUs:
54 | ```
55 | mpiexec -n 8 python train.py
56 | ```
57 |
58 | ##### Ablation experiments
59 |
60 | ```
61 | mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation [0/1/2] --flow_coupling [0/1] --seed [0/1/2] --learntop --lr 0.001
62 | ```
63 |
64 | Pretrained models, logs and samples
65 | ```
66 | wget https://openaipublic.azureedge.net/glow-demo/logs/abl-[reverse/shuffle/1x1]-[add/aff].tar
67 | ```
68 |
69 | ##### CIFAR-10 Quantitative result
70 |
71 | ```
72 | mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
73 | ```
74 |
75 | ##### ImageNet 32x32 Quantitative result
76 |
77 | ```
78 | mpiexec -n 8 python train.py --problem imagenet-oord --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
79 | ```
80 |
81 | ##### ImageNet 64x64 Quantitative result
82 | ```
83 | mpiexec -n 8 python train.py --problem imagenet-oord --image_size 64 --n_level 4 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
84 | ```
85 |
86 | ##### LSUN 64x64 Quantitative result
87 | ```
88 | mpiexec -n 8 python train.py --problem lsun_realnvp --category [bedroom/church_outdoor/tower] --image_size 64 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 1 --seed 0 --learntop --lr 0.001 --n_bits_x 8
89 | ```
90 |
91 | Pretrained models, logs and samples
92 | ```
93 | wget https://openaipublic.azureedge.net/glow-demo/logs/lsun-rnvp-[bdr/crh/twr].tar
94 | ```
95 |
96 | ##### CelebA-HQ 256x256 Qualitative result
97 |
98 | ```
99 | mpiexec -n 40 python train.py --problem celeba --image_size 256 --n_level 6 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5
100 | ```
101 |
102 | ##### LSUN 96x96 and 128x128 Qualitative result
103 | ```
104 | mpiexec -n 40 python train.py --problem lsun --category [bedroom/church_outdoor/tower] --image_size [96/128] --n_level 5 --depth 64 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5
105 | ```
106 |
107 | Logs and samples
108 | ```
109 | wget https://openaipublic.azureedge.net/glow-demo/logs/lsun-bdr-[96/128].tar
110 | ```
111 |
112 | ##### Conditional CIFAR-10 Qualitative result
113 | ```
114 | mpiexec -n 8 python train.py --problem cifar10 --image_size 32 --n_level 3 --depth 32 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
115 | ```
116 |
117 | ##### Conditional ImageNet 32x32 Qualitative result
118 | ```
119 | mpiexec -n 8 python train.py --problem imagenet --image_size 32 --n_level 3 --depth 48 --flow_permutation 2 --flow_coupling 0 --seed 0 --learntop --lr 0.001 --n_bits_x 5 --ycond --weight_y=0.01
120 | ```
121 |
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/imagenet_oord.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Generate tfrecords for ImageNet 32x32 and 64x64.
18 |
19 | # Get images
20 | Download the images from http://image-net.org/small/download.php and unzip them.
21 | (Move one file from training to test to have 50000 test images)
22 |
23 | # Get tfr file from images
24 | Use this script to generate the tfr file.
25 | python imagenet_oord.py --res [RES] --tfrecord_dir [OUTPUT_FOLDER] --write
26 |
27 | """
28 |
29 | from __future__ import print_function
30 |
31 | import os
32 | import os.path
33 |
34 | import scipy.io
35 | import scipy.io.wavfile
36 | import scipy.ndimage
37 | import tensorflow as tf
38 | import numpy as np
39 | from tqdm import tqdm
40 |
41 | from typing import Iterable
42 |
43 |
def _int64_feature(value):
    """Wrap an int, or an iterable of ints, into an int64-list Feature."""
    # A single value is promoted to a one-element list.
    values = value if isinstance(value, Iterable) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
48 |
49 |
def _bytes_feature(value):
    """Wrap a single bytes value into a bytes-list Feature."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
52 |
53 |
def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write):
    """Convert the .png images in `fn_root` into sharded tfrecord files.

    Args:
        fn_root: directory that contains the source .png images.
        tfrecord_dir: output directory; shard files are named
            ``<basename>-rXX-s-NNNN-of-MMMM.tfrecords`` inside it.
        max_res: image resolution; only its log2 is used, for the file name.
        expected_images: number of .png files that must be present.
        shards: number of tfrecord files to distribute the images over.
        write: when falsy the examples are built but not written (the shard
            files are still created by the writers).

    Raises:
        AssertionError: if the number of images found differs from
            `expected_images`, or if some shard would receive no image.
    """
    resolution_log2 = int(np.log2(max_res))
    tfr_prefix = os.path.join(tfrecord_dir, os.path.basename(tfrecord_dir))

    print("Checking in", fn_root)
    img_fn_list = [img_fn for img_fn in os.listdir(fn_root)
                   if img_fn.endswith('.png')]
    num_examples = len(img_fn_list)
    print("Found", num_examples)
    assert num_examples == expected_images

    # Sharding: assign every image index to a random shard of near-equal size.
    tfr_opt = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.NONE)
    p_shard = np.array_split(np.random.permutation(expected_images), shards)
    # `np.int` was only an alias of the builtin and is gone from current NumPy.
    img_to_shard = np.zeros(expected_images, dtype=int)
    writers = []
    try:
        for shard in range(shards):
            img_to_shard[p_shard[shard]] = shard
            tfr_file = tfr_prefix + \
                '-r%02d-s-%04d-of-%04d.tfrecords' % (
                    resolution_log2, shard, shards)
            writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))

        counts = np.unique(img_to_shard, return_counts=True)[1]
        assert len(counts) == shards
        print("Smallest and largest shards have size",
              np.min(counts), np.max(counts))

        for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
            shard = img_to_shard[example_idx]
            # NOTE(review): scipy.ndimage.imread only exists in old SciPy
            # releases; this file imports no other image reader, so the call
            # is left as is - confirm the SciPy version in use.
            img = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
            shape = (img.shape[0], img.shape[1], img.shape[2])
            # tobytes() returns the same bytes as the deprecated tostring().
            img = img.astype("uint8").tobytes()
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "shape": _int64_feature(shape),
                        "data": _bytes_feature(img),
                        "label": _int64_feature(0)
                    }
                )
            )
            if write:
                writers[shard].write(example.SerializeToString())

        print('%-40s\r' % 'Flushing data...', end='', flush=True)
    finally:
        # Close every writer that was opened, even if conversion failed.
        for writer in writers:
            writer.close()

    print('%-40s\r' % '', end='', flush=True)
    print('Added %d images.' % num_examples)
115 |
116 |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--res", type=int, default=32, help="Image size")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()

    # Imagenet: expected image count per split.
    _NUM_IMAGES = {
        'train': 1281148,
        'validation': 50000,
    }

    # Number of tfrecord files per split.
    _NUM_SHARDS = {
        'train': 2000,
        'validation': 80,
    }

    # Source image folders, relative to the working directory.
    _FILE = {
        'train': 'train_%dx%d' % (hps.res, hps.res),
        'validation': 'valid_%dx%d' % (hps.res, hps.res),
    }

    for split in ['validation', 'train']:
        fn_root = _FILE[split]
        tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
        total_imgs = _NUM_IMAGES[split]
        shards = _NUM_SHARDS[split]
        # makedirs also creates a missing --tfrecord_dir parent and avoids the
        # check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(tfrecord_dir, exist_ok=True)
        dump(fn_root, tfrecord_dir, hps.res, total_imgs, shards, hps.write)
152 |
--------------------------------------------------------------------------------
/demo/server.py:
--------------------------------------------------------------------------------
1 | import model
2 | from align_face import align_face
3 | from flask import Flask, jsonify, request
4 | from flask_cors import CORS
5 |
6 | import base64
7 | import time
8 | import numpy as np
9 | from PIL import Image
10 | from io import BytesIO
11 | app = Flask(__name__)
12 | CORS(app)
13 |
14 |
def deserialise_img(img_str):
    """Decode a base64 image string into an RGB numpy array.

    Only the part after the last comma is decoded, so a
    "data:image/...;base64," prefix is accepted.
    """
    payload = img_str.split(",")[-1]
    raw = base64.b64decode(payload)
    rgb = Image.open(BytesIO(raw)).convert('RGB')
    return np.array(rgb)
21 |
22 |
def serialise_img(arr):
    """Encode an image array as a base64 PNG data URI string."""
    png = BytesIO()
    Image.fromarray(arr).save(png, format='PNG')
    encoded = base64.b64encode(png.getvalue()).decode('utf-8')
    return "data:image/png;base64," + encoded
29 |
30 |
def deserialise_nparr(arr_str):
    """Decode a base64 string produced by `serialise_nparr` into float32.

    Args:
        arr_str: base64 text (or bytes) of a pickled numpy array.

    Returns:
        The array converted to dtype float32.
    """
    # np.loads was only a deprecated alias of pickle.loads and is no longer
    # available in current NumPy, so call pickle directly.
    import pickle

    # SECURITY(review): unpickling can execute arbitrary code and get_z feeds
    # this function with request data. Only expose the server to trusted
    # clients, or move to a non-pickle format (e.g. raw float16 bytes + shape).
    arr = pickle.loads(base64.b64decode(arr_str))
    return np.array(arr, dtype=np.float32)
34 |
35 |
def serialise_nparr(arr):
    """Convert `arr` to float16, pickle it and return the base64 text."""
    half = np.array(arr, dtype=np.float16)
    pickled = half.dumps()
    return base64.b64encode(pickled).decode('utf-8')
39 |
40 |
def send(result):
    """Return a JSON response holding the serialised images of `result`.

    `result` is an (img, z) pair of batches; only the images are sent, as a
    list with one entry per batch element.
    """
    images, _ = result
    return jsonify(img=[serialise_img(single) for single in images])
46 |
47 |
def send_proj(result, proj):
    """Return a JSON response with images, latents and projections.

    `result` is an (img, z) pair of batches; both are sent as lists with one
    serialised entry per batch element, together with `proj.tolist()` and
    `face_found=True`.
    """
    images, latents = result
    encoded_images = [serialise_img(single) for single in images]
    encoded_latents = [serialise_nparr(single) for single in latents]
    return jsonify(face_found=True, img=encoded_images,
                   z=encoded_latents, proj=proj.tolist())
52 |
53 |
def get(request, key):
    """Return the value stored under `key` in the JSON body of `request`."""
    payload = request.get_json()
    return payload.get(key)
56 |
57 |
def get_z(request, key):
    """Read a serialised latent from the request and return it as a batch.

    The request carries a single point, so a leading axis is added.
    """
    encoded = get(request, key)
    latent = deserialise_nparr(encoded)
    return np.expand_dims(latent, axis=0)
62 |
63 |
@app.route('/')
def hello_world():
    """Serve the plain-text greeting of the root URL."""
    greeting = 'Welcome to Glow!'
    return greeting
67 |
68 | # Align and encode image
69 | #
70 | # args
71 | # img: Image as base64 string
72 | #
73 | # returns
74 | # json: {'face_found': face_found, 'img':[base64 img], 'z': [serialised z]}
@app.route('/api/align_encode', methods=['POST'])
def align_encode():
    """Align the face in the posted image and encode it with the model.

    Responds with `face_found=False` when `align_face` finds no face;
    otherwise with the aligned image, its latent and the projections.
    """
    encoded = get(request, 'img')
    # The image arrives as a base64 string.
    aligned, face_found = align_face(deserialise_img(encoded))
    if not face_found:
        return jsonify(face_found=False)

    batch = np.reshape(aligned, [1, 256, 256, 3])
    print(batch.shape)
    z = model.encode(batch)
    proj = model.project(z)
    return send_proj((batch, z), proj)
92 |
93 | # Manipulate single attribute
94 | #
95 | # args
96 | # z: Serialised np array for encoding of image
97 | # typ: int in [0,40), representing which attribute to manipulate
98 | # alpha: float, usually in [-1,1], representing how much to manipulate. 0 gives original image
99 | #
100 | # returns
101 | # json: {'img': [img]}
@app.route('/api/manipulate', methods=['POST'])
def manipulate():
    """Apply one attribute manipulation to the posted latent."""
    latent = get_z(request, 'z')
    attribute = get(request, 'typ')
    strength = get(request, 'alpha')
    return send(model.manipulate(latent, attribute, strength))
110 |
111 | # Manipulate multiple attributes
112 | # typs: list of typ
113 | # alphas: list of corresponding alphas
@app.route('/api/manipulate_all', methods=['POST'])
def manipulate_all():
    """Apply several attribute manipulations to the posted latent."""
    latent = get_z(request, 'z')
    attributes = get(request, 'typs')
    strengths = get(request, 'alphas')
    return send(model.manipulate_all(latent, attributes, strengths))
122 |
123 | # Mix two faces
124 | #
125 | # args
126 | # z1: Serialised np array for encoding of image 1
127 | # z2: Serialised np array for encoding of image 2
128 | # alpha: float in [0,1], representing how much to mix. 0.5 gives middle image
129 | #
130 | # returns
131 | # json: {'img': [img]}
@app.route('/api/mix', methods=['POST'])
def mix():
    """Mix the two posted latents with weight `alpha`."""
    first = get_z(request, 'z1')
    second = get_z(request, 'z2')
    weight = get(request, 'alpha')
    return send(model.mix(first, second, weight))
140 |
141 | # Get random image
@app.route('/api/random', methods=['POST'])
def random():
    """Sample `bs` random images and send them with latents and projections."""
    batch_size = get(request, 'bs')
    sampled = model.random(batch_size)
    _, latents = sampled
    return send_proj(sampled, model.project(latents))
150 |
151 | # Extra functions
@app.route('/api/test', methods=['POST'])
def test():
    """Parse the manipulation inputs and answer with an empty `z`."""
    # The parsed values are not used; the reads are kept so malformed
    # requests still fail the same way.
    get_z(request, 'z')
    get(request, 'typs')
    get(request, 'alphas')
    return jsonify(z="")
159 |
160 |
@app.route('/api/manipulate_range', methods=['POST'])
def manipulate_range():
    """Pass the posted latent, attribute and `points` to model.manipulate_range."""
    latent = get_z(request, 'z')
    attribute = get(request, 'typ')
    n_points = get(request, 'points')
    return send(model.manipulate_range(latent, attribute, n_points))
169 |
170 |
@app.route('/api/mix_range', methods=['POST'])
def mix_range():
    """Pass the two posted latents and `points` to model.mix_range."""
    first = get_z(request, 'z1')
    second = get_z(request, 'z2')
    n_points = get(request, 'points')
    return send(model.mix_range(first, second, n_points))
179 |
180 | # Legacy
181 | # @app.route('/api/encode', methods=['POST'])
182 | # def encode():
183 | # t = time.time()
184 | # r = request
185 | # img = get(r, 'img')
186 | # print("Time to read from request", time.time() - t)
187 | # t = time.time()
188 | # img = deserialise_nparr(img)
189 | # print("Time to serialise from request", time.time() - t)
190 | # t = time.time()
191 | # z = model.encode(img)
192 | # print("TIme to encode", time.time() - t)
193 | # t = time.time()
194 | # json = jsonify(z=serialise_nparr(z))
195 | # print("Time to jsonify", time.time() - t)
196 | # return json
197 | #
198 | # @app.route('/api/decode', methods=['POST'])
199 | # def decode():
200 | # t = time.time()
201 | # r = request
202 | # z = get(r, 'z')
203 | # print("Time to read from request", time.time() - t)
204 | # t = time.time()
205 | # z = deserialise_nparr(z)
206 | # print("Time to serialise from request", time.time() - t)
207 | # t = time.time()
208 | # img = model.decode(z)
209 | # print("TIme to decode", time.time() - t)
210 | # t = time.time()
211 | # json = jsonify(img=serialise_img(img))
212 | # print("Time to jsonify", time.time() - t)
213 | # return json
214 | #
215 | # @app.route('/api/align', methods=['POST'])
216 | # def align():
217 | # r = request
218 | # img = get(r, 'img')
219 | # # img = parse_img(img) if in jpg etc format
220 | # img = deserialise_img(img)
221 | # img = align_face(img)
222 | # return jsonify(img=serialise_img(img))
223 |
224 |
225 | # FaceOff! Use for manipulation and blending faces
if __name__ == '__main__':
    # Listen on all interfaces, port 5050.
    bind_host, bind_port = '0.0.0.0', 5050
    print('Running Flask app...')
    app.run(host=bind_host, port=bind_port)
229 |
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/lsun.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | LSUN dataset
18 |
19 | # Get image files
20 | Download the LSUN dataset as follows:
21 | git clone https://github.com/fyu/lsun.git
22 | cd lsun
23 | python2.7 download.py -c [CATEGORY]
24 | Unzip the downloaded .zip files and execute:
25 | python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat
26 |
27 | # Get tfr file from images
28 | Use this script to generate the tfr file.
29 | python lsun.py --res [RES] --category [CATEGORY] --lsun_dir [LSUN_FOLDER] --tfrecord_dir [OUTPUT_FOLDER] --write [--realnvp]
30 | Without realnvp flag you get 256x256 centre cropped area downsampled images, with flag you get 96x96 images with realnvp preprocessing.
31 | """
32 |
33 | from __future__ import print_function
34 |
35 | import os
36 | import os.path
37 |
38 | import numpy
39 | import skimage.transform
40 | from PIL import Image
41 | import tensorflow as tf
42 | import numpy as np
43 | from tqdm import tqdm
44 |
45 | from typing import Iterable
46 |
47 |
def _int64_feature(value):
    """Wrap an int, or an iterable of ints, into an int64-list Feature."""
    # A single value is promoted to a one-element list.
    values = value if isinstance(value, Iterable) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
52 |
53 |
def _bytes_feature(value):
    """Wrap a single bytes value into a bytes-list Feature."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
56 |
57 |
def centre_crop(img):
    """Return the centred square crop whose side is the shorter image side."""
    height, width = img.shape[:2]
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return img[top: top + side, left: left + side]
62 |
63 |
def dump(fn_root, tfrecord_dir, max_res, expected_images, shards, write, realnvp=False):
    """Convert the .webp images in `fn_root` into sharded tfrecord files.

    Args:
        fn_root: directory that contains the source .webp images.
        tfrecord_dir: output directory; shard files are named
            ``<basename>-rXX-s-NNNN-of-MMMM.tfrecords`` inside it.
        max_res: target side length of the centre-cropped images; its log2 is
            also used in the file names.
        expected_images: number of .webp files that must be present.
        shards: number of tfrecord files to distribute the images over.
        write: when falsy the examples are built but not written (the shard
            files are still created by the writers).
        realnvp: when true, images are instead reduced with
            `skimage.transform.pyramid_reduce` so that the shorter side is
            scaled by min(rows, cols) / 96; no centre crop is applied.

    Raises:
        AssertionError: if the number of images found differs from
            `expected_images`, or if some shard would receive no image.
    """
    resolution_log2 = int(np.log2(max_res))
    tfr_prefix = os.path.join(tfrecord_dir, os.path.basename(tfrecord_dir))

    print("Checking in", fn_root)
    img_fn_list = [img_fn for img_fn in os.listdir(fn_root)
                   if img_fn.endswith('.webp')]
    num_examples = len(img_fn_list)
    print("Found", num_examples)
    assert num_examples == expected_images

    # Assign every image index to a random shard of near-equal size.
    tfr_opt = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.NONE)
    p_shard = np.array_split(np.random.permutation(expected_images), shards)
    # `np.int` was only an alias of the builtin and is gone from current NumPy.
    img_to_shard = np.zeros(expected_images, dtype=int)
    writers = []
    try:
        for shard in tqdm(range(shards)):
            img_to_shard[p_shard[shard]] = shard
            tfr_file = tfr_prefix + \
                '-r%02d-s-%04d-of-%04d.tfrecords' % (
                    resolution_log2, shard, shards)
            writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))

        counts = np.unique(img_to_shard, return_counts=True)[1]
        assert len(counts) == shards
        print("Smallest and largest shards have size",
              np.min(counts), np.max(counts))

        for example_idx, img_fn in enumerate(tqdm(img_fn_list)):
            shard = img_to_shard[example_idx]
            img = np.array(Image.open(os.path.join(fn_root, img_fn)))
            rows = img.shape[0]
            cols = img.shape[1]
            if realnvp:
                downscale = min(rows / 96., cols / 96.)
                img = skimage.transform.pyramid_reduce(img, downscale)
                # The result is scaled by 255 before the cast back to uint8.
                img *= 255.
                img = img.astype("uint8")
            else:
                img = centre_crop(img)
                img = Image.fromarray(img, 'RGB')
                # Image.LANCZOS is the same filter as the ANTIALIAS alias,
                # which newer Pillow releases no longer provide.
                img = img.resize((max_res, max_res), Image.LANCZOS)
                img = np.asarray(img)
            shape = (img.shape[0], img.shape[1], img.shape[2])
            # tobytes() returns the same bytes as the deprecated tostring().
            img = img.tobytes()
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "shape": _int64_feature(shape),
                        "data": _bytes_feature(img),
                        "label": _int64_feature(0)
                    }
                )
            )
            if write:
                writers[shard].write(example.SerializeToString())

        print('%-40s\r' % 'Flushing data...', end='', flush=True)
    finally:
        # Close every writer that was opened, even if conversion failed.
        for writer in writers:
            writer.close()

    print('%-40s\r' % '', end='', flush=True)
    print('Added %d images.' % num_examples)
133 |
134 |
if __name__ == "__main__":
    # Script entry point: convert LSUN image folders into sharded tfrecords
    # by calling dump() once per (preprocessing mode, category, split).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--category", type=str, help="LSUN category")
    parser.add_argument("--realnvp", action='store_true',
                        help="Use this flag to do realnvp preprocessing instead of our centre-crops")
    parser.add_argument("--res", type=int, default=256, help="Image size")
    parser.add_argument("--lsun_dir", type=str,
                        required=True, help="place of lsun dir")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()

    # LSUN
    # CATEGORIES = ["bedroom", "bridge", "church_outdoor", "classroom", "conference_room", "dining_room", "kitchen", "living"]

    # Expected number of training images per category; passed to dump() as
    # the image count for the 'train' split.
    imgs = {
        'bedroom': 3033042,
        'bridge': 818687,
        'church_outdoor': 126227,
        'classroom': 168103,
        'conference_room': 229069,
        'dining_room': 657571,
        'kitchen': 2212277,
        'living_room': 1315802,
        'restaurant': 626331,
        'tower': 708264
    }

    _NUM_SHARDS = {
        'train': 2560,
        'validation': 1,
    }

    base_tfr = hps.tfrecord_dir
    # NOTE: --category and --realnvp are parsed above but overwritten here;
    # every run processes the three categories below in both modes.
    for realnvp in [False, True]:
        for category in ["tower", "church_outdoor", "bedroom"]:
            hps.realnvp = realnvp
            hps.category = category
            if realnvp:
                hps.tfrecord_dir = "%s_%s/%s" % (base_tfr,
                                                 "realnvp", hps.category)
            else:
                hps.tfrecord_dir = "%s/%s" % (base_tfr, hps.category)
            print(hps.realnvp, hps.category, hps.lsun_dir, hps.tfrecord_dir)

            _NUM_IMAGES = {
                'train': imgs[hps.category],
                'validation': 300,
            }

            _FILE = {
                'train': os.path.join(hps.lsun_dir, '%s_train' % hps.category),
                'validation': os.path.join(hps.lsun_dir, '%s_val' % hps.category)
            }

            # realnvp mode uses a fixed resolution of 96; otherwise --res.
            res = 96 if hps.realnvp else hps.res

            for split in ['validation', 'train']:
                fn_root = _FILE[split]
                tfrecord_dir = os.path.join(hps.tfrecord_dir, split)
                total_imgs = _NUM_IMAGES[split]
                shards = _NUM_SHARDS[split]
                # The output path is nested (<base>/<category>/<split>), so
                # create missing parents too; a plain mkdir would fail when
                # <base>/<category> does not exist yet.
                os.makedirs(tfrecord_dir, exist_ok=True)
                dump(fn_root, tfrecord_dir, res, total_imgs,
                     shards, hps.write, hps.realnvp)
208 |
--------------------------------------------------------------------------------
/optim.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tfops as Z
3 | import horovod.tensorflow as hvd
4 |
5 | # Optimizers
6 |
7 | '''
8 | Polyak averaging op
9 | '''
10 |
11 |
def polyak(params, beta):
    """Build Polyak (exponential moving average) ops for `params`.

    Args:
        params: variables to track.
        beta: decay passed to tf.train.ExponentialMovingAverage.

    Returns:
        (avg_op, swap_op, ema): `avg_op` applies the moving-average update,
        `swap_op` exchanges each variable's value with its average, and
        `ema` is the ExponentialMovingAverage object itself.
    """
    ema = tf.train.ExponentialMovingAverage(decay=beta, zero_debias=True)
    avg_op = tf.group(ema.apply(params))

    # Swapping op
    swap_ops = []
    for param in params:
        shadow = ema.average(param)
        # Materialise the current average as a separate tensor so it can be
        # written into the parameter after the average has been overwritten.
        saved = 0. + shadow * 1.
        with tf.control_dependencies([saved]):
            to_shadow = shadow.assign(param)
        with tf.control_dependencies([to_shadow]):
            to_param = param.assign(saved)
        swap_ops.extend([to_shadow, to_param])
    swap_op = tf.group(*swap_ops)
    return avg_op, swap_op, ema
29 |
30 |
def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
    """Build an Adam training op with all-reduced gradients and Polyak averaging.

    NOTE: a later definition of `adam` in this module has the same body and
    rebinds the name, so that later one is what importers actually get.

    Args:
        params: variables to optimise.
        cost_or_grads: a list of gradients (used as-is), or anything else,
            which is differentiated w.r.t. `params` via tf.gradients.
        alpha: base step size.
        hps: object providing train_its, polyak_epochs, beta1, weight_decay.
        epsilon: constant added to the denominator of the update.

    Returns:
        (train_op, polyak_swap_op, ema)
    """
    if type(cost_or_grads) is list:
        gs = cost_or_grads
    else:
        gs = tf.gradients(cost_or_grads, params)

    # The same decay is used for the second moment and for Polyak averaging.
    beta2 = 1-1./(hps.train_its*hps.polyak_epochs)

    # all-reduce
    grads = [Z.allreduce_mean(g) for g in gs]

    # NOTE(review): in TF 1.x the second positional parameter of tf.Variable
    # is `trainable`, so the strings passed positionally here and below are
    # presumably not used as variable names -- confirm before relying on them.
    t = tf.Variable(1., 'adam_t')
    step_num = alpha * tf.sqrt((1. - tf.pow(beta2, t)))
    step_den = (1. - tf.pow(hps.beta1, t))
    alpha_t = step_num / step_den
    train_ops = [t.assign_add(1)]

    for w, g in zip(params, grads):
        second = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2')
        if hps.beta1 > 0:
            first = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1')
            first_new = hps.beta1 * first + (1. - hps.beta1) * g
            train_ops.append(first.assign(first_new))
        else:
            # No momentum: use the raw gradient as the first moment.
            first_new = g
        second_new = beta2 * second + (1. - beta2) * tf.square(g)
        delta_t = first_new / (tf.sqrt(second_new) + epsilon)
        # weight_decay multiplies the current weights before the step.
        w_new = hps.weight_decay * w - alpha_t * delta_t
        train_ops.append(second.assign(second_new))
        train_ops.append(w.assign(w_new))

    # Polyak averaging
    polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
    train_op = tf.group(polyak_avg_op, *train_ops)
    return train_op, polyak_swap_op, ema
66 |
67 |
68 | '''
69 | Adam optimizer
70 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum).
71 | (It doesn't seem to work yet, though.)
72 | '''
73 |
74 |
def adam2(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
    """Adam variant that all-reduces squared gradients separately.

    Unlike `adam`, the second moment is fed with the all-reduced mean of
    g**2 rather than the square of the all-reduced gradient.

    Args:
        params: variables to optimise.
        cost_or_grads: a list of gradients (used as-is), or anything else,
            which is differentiated w.r.t. `params` via tf.gradients.
        alpha: base step size.
        hps: object providing train_its, polyak_epochs, beta1, weight_decay.
        epsilon: constant added to the denominator of the update.

    Returns:
        (train_op, polyak_swap_op, ema)
    """
    if type(cost_or_grads) is list:
        gs = cost_or_grads
    else:
        gs = tf.gradients(cost_or_grads, params)

    # The same decay is used for the second moment and for Polyak averaging.
    beta2 = 1-1./(hps.train_its*hps.polyak_epochs)

    # all-reduce
    grads1 = [Z.allreduce_mean(g) for g in gs]
    grads2 = [Z.allreduce_mean(g**2) for g in gs]

    # NOTE(review): in TF 1.x the second positional parameter of tf.Variable
    # is `trainable`, so the strings passed positionally here and below are
    # presumably not used as variable names -- confirm before relying on them.
    t = tf.Variable(1., 'adam_t')
    step_num = alpha * tf.sqrt((1. - tf.pow(beta2, t)))
    step_den = (1. - tf.pow(hps.beta1, t))
    alpha_t = step_num / step_den
    train_ops = [t.assign_add(1)]

    for w, g1, g2 in zip(params, grads1, grads2):
        second = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2')
        if hps.beta1 > 0:
            first = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1')
            first_new = hps.beta1 * first + (1. - hps.beta1) * g1
            train_ops.append(first.assign(first_new))
        else:
            # No momentum: use the raw gradient as the first moment.
            first_new = g1
        second_new = beta2 * second + (1. - beta2) * g2
        delta_t = first_new / (tf.sqrt(second_new) + epsilon)
        # weight_decay multiplies the current weights before the step.
        w_new = hps.weight_decay * w - alpha_t * delta_t
        train_ops.append(second.assign(second_new))
        train_ops.append(w.assign(w_new))

    # Polyak averaging
    polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
    train_op = tf.group(polyak_avg_op, *train_ops)
    return train_op, polyak_swap_op, ema
111 |
112 |
113 | '''
114 | Adam optimizer
115 | Version whose learning rate could, in theory, be scaled linearly (like SGD+momentum).
116 | It doesn't seem to work though.
117 | '''
118 |
119 |
def adam2_old(params, cost_or_grads, lr=3e-4, mom1=0.9, mom2=0.999, epsilon=1e-8):
    """Older Adam variant whose second-moment decay depends on hvd.size().

    Args:
        params: variables to optimise.
        cost_or_grads: a list of gradients (used as-is), or anything else,
            which is differentiated w.r.t. `params` via tf.gradients.
        lr: base learning rate.
        mom1: first-moment decay; values <= 0 disable the first moment.
        mom2: second-moment decay before rescaling by hvd.size().
        epsilon: constant added to the denominator of the update.

    Returns:
        A single grouped op performing all updates (no Polyak averaging).
    """
    if type(cost_or_grads) is list:
        gs = cost_or_grads
    else:
        gs = tf.gradients(cost_or_grads, params)

    # all-reduce
    grads1 = [Z.allreduce_mean(g) for g in gs]
    grads2 = [Z.allreduce_mean(tf.square(g)) for g in gs]
    # Rescale (1 - mom2) by hvd.size(), clamping the result at zero.
    mom2 = tf.maximum(0., 1. - (hvd.size() * (1 - mom2)))

    # NOTE(review): in TF 1.x the second positional parameter of tf.Variable
    # is `trainable`, so the strings passed positionally here and below are
    # presumably not used as variable names -- confirm before relying on them.
    t = tf.Variable(1., 'adam_t')
    step_num = lr * tf.sqrt((1. - tf.pow(mom2, t)))
    step_den = (1. - tf.pow(mom1, t))
    lr_t = step_num / step_den
    train_ops = [t.assign_add(1)]

    for p, g1, g2 in zip(params, grads1, grads2):
        sq_avg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg')
        if mom1 > 0:
            velocity = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v')
            velocity_new = mom1 * velocity + (1. - mom1) * g1
            train_ops.append(velocity.assign(velocity_new))
        else:
            # No momentum: use the raw gradient as the first moment.
            velocity_new = g1
        sq_avg_new = mom2 * sq_avg + (1. - mom2) * g2
        delta_t = velocity_new / (tf.sqrt(sq_avg_new) + epsilon)
        p_new = p - lr_t * delta_t
        train_ops.append(sq_avg.assign(sq_avg_new))
        train_ops.append(p.assign(p_new))
    return tf.group(*train_ops)
150 |
151 |
def adamax(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
    """Build an Adamax-style training op with Polyak averaging.

    The second-moment accumulator tracks max(beta2 * previous, |g|) rather
    than a running mean of squared gradients.

    Args:
        params: variables to optimise.
        cost_or_grads: a list of gradients (used as-is), or anything else,
            which is differentiated w.r.t. `params` via tf.gradients.
        alpha: base step size.
        hps: object providing train_its, polyak_epochs, beta1, weight_decay.
        epsilon: constant added to the denominator of the update.

    Returns:
        (train_op, polyak_swap_op, ema)
    """
    if type(cost_or_grads) is list:
        gs = cost_or_grads
    else:
        gs = tf.gradients(cost_or_grads, params)

    # The same decay is used for the accumulator and for Polyak averaging.
    beta2 = 1-1./(hps.train_its*hps.polyak_epochs)

    # all-reduce
    grads = [Z.allreduce_mean(g) for g in gs]

    # NOTE(review): in TF 1.x the second positional parameter of tf.Variable
    # is `trainable`, so the strings passed positionally here and below are
    # presumably not used as variable names -- confirm before relying on them.
    t = tf.Variable(1., 'adam_t')
    step_num = alpha * tf.sqrt((1. - tf.pow(beta2, t)))
    step_den = (1. - tf.pow(hps.beta1, t))
    alpha_t = step_num / step_den
    train_ops = [t.assign_add(1)]

    for w, g in zip(params, grads):
        peak = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2')
        if hps.beta1 > 0:
            first = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1')
            first_new = hps.beta1 * first + (1. - hps.beta1) * g
            train_ops.append(first.assign(first_new))
        else:
            # No momentum: use the raw gradient as the first moment.
            first_new = g
        peak_new = tf.maximum(beta2 * peak, abs(g))
        delta_t = first_new / (peak_new + epsilon)
        # weight_decay multiplies the current weights before the step.
        w_new = hps.weight_decay * w - alpha_t * delta_t
        train_ops.append(peak.assign(peak_new))
        train_ops.append(w.assign(w_new))

    # Polyak averaging
    polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
    train_op = tf.group(polyak_avg_op, *train_ops)
    return train_op, polyak_swap_op, ema
187 |
188 |
def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
    """Build an Adam training op with all-reduced gradients and Polyak averaging.

    NOTE: this module defines `adam` twice with the same body; this later
    definition is the one bound to the name after import.

    Args:
        params: variables to optimise.
        cost_or_grads: a list of gradients (used as-is), or anything else,
            which is differentiated w.r.t. `params` via tf.gradients.
        alpha: base step size.
        hps: object providing train_its, polyak_epochs, beta1, weight_decay.
        epsilon: constant added to the denominator of the update.

    Returns:
        (train_op, polyak_swap_op, ema)
    """
    if type(cost_or_grads) is list:
        gs = cost_or_grads
    else:
        gs = tf.gradients(cost_or_grads, params)

    # The same decay is used for the second moment and for Polyak averaging.
    beta2 = 1-1./(hps.train_its*hps.polyak_epochs)

    # all-reduce
    grads = [Z.allreduce_mean(g) for g in gs]

    # NOTE(review): in TF 1.x the second positional parameter of tf.Variable
    # is `trainable`, so the strings passed positionally here and below are
    # presumably not used as variable names -- confirm before relying on them.
    t = tf.Variable(1., 'adam_t')
    step_num = alpha * tf.sqrt((1. - tf.pow(beta2, t)))
    step_den = (1. - tf.pow(hps.beta1, t))
    alpha_t = step_num / step_den
    train_ops = [t.assign_add(1)]

    for w, g in zip(params, grads):
        second = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m2')
        if hps.beta1 > 0:
            first = tf.Variable(tf.zeros(w.get_shape()), w.name + '_adam_m1')
            first_new = hps.beta1 * first + (1. - hps.beta1) * g
            train_ops.append(first.assign(first_new))
        else:
            # No momentum: use the raw gradient as the first moment.
            first_new = g
        second_new = beta2 * second + (1. - beta2) * tf.square(g)
        delta_t = first_new / (tf.sqrt(second_new) + epsilon)
        # weight_decay multiplies the current weights before the step.
        w_new = hps.weight_decay * w - alpha_t * delta_t
        train_ops.append(second.assign(second_new))
        train_ops.append(w.assign(w_new))

    # Polyak averaging
    polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
    train_op = tf.group(polyak_avg_op, *train_ops)
    return train_op, polyak_swap_op, ema
224 |
--------------------------------------------------------------------------------
/demo/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import time
4 | from tqdm import tqdm
5 | from PIL import Image
6 | from threading import Lock
7 |
8 | lock = Lock()
9 |
10 |
def get(name):
    """Return output 0 of the node `name` under the 'import/' scope of the default graph."""
    tensor_name = 'import/' + name + ':0'
    graph = tf.get_default_graph()
    return graph.get_tensor_by_name(tensor_name)
13 |
14 |
def tensorflow_session():
    """Create a tf.Session with GPU memory growth enabled and device 0 visible."""
    # Init session and params
    session_config = tf.ConfigProto()
    gpu = session_config.gpu_options
    gpu.allow_growth = True
    # Pin GPU to local rank (one GPU per process)
    gpu.visible_device_list = str(0)
    return tf.Session(config=session_config)
23 |
24 |
# Select which frozen graph to serve and map logical tensor names
# (dec_eps_*, enc_x, dec_x, enc_eps_*) to node names inside that graph.
optimized = True
if optimized:
    # Optimized model. Twice as fast as
    # 1. we freeze conditional network (label is always 0)
    # 2. we use fused kernels
    import blocksparse
    graph_path = 'graph_optimized.pb'
    inputs = dict(
        dec_eps_0='dec_eps_0',
        dec_eps_1='dec_eps_1',
        dec_eps_2='dec_eps_2',
        dec_eps_3='dec_eps_3',
        dec_eps_4='dec_eps_4',
        dec_eps_5='dec_eps_5',
        enc_x='input/enc_x',
    )
    outputs = dict(
        dec_x='model_3/Cast_1',
        enc_eps_0='model_2/pool0/truediv_1',
        enc_eps_1='model_2/pool1/truediv_1',
        enc_eps_2='model_2/pool2/truediv_1',
        enc_eps_3='model_2/pool3/truediv_1',
        enc_eps_4='model_2/pool4/truediv_1',
        enc_eps_5='model_2/truediv_4',
    )

    def update_feed(feed_dict, bs):
        """Return `feed_dict` unchanged; this graph needs no extra inputs."""
        return feed_dict
else:
    graph_path = 'graph_unoptimized.pb'
    inputs = dict(
        dec_eps_0='Placeholder',
        dec_eps_1='Placeholder_1',
        dec_eps_2='Placeholder_2',
        dec_eps_3='Placeholder_3',
        dec_eps_4='Placeholder_4',
        dec_eps_5='Placeholder_5',
        enc_x='input/image',
        enc_x_d='input/downsampled_image',
        enc_y='input/label',
    )
    outputs = dict(
        dec_x='model_1/Cast_1',
        enc_eps_0='model/pool0/truediv_1',
        enc_eps_1='model/pool1/truediv_1',
        enc_eps_2='model/pool2/truediv_1',
        enc_eps_3='model/pool3/truediv_1',
        enc_eps_4='model/pool4/truediv_1',
        enc_eps_5='model/truediv_4',
    )

    def update_feed(feed_dict, bs):
        """Add constant placeholder values for the extra inputs, in place.

        Feeds a uint8 array filled with 128 for the downsampled image and an
        all-zero int32 label vector, both with batch size `bs`.
        """
        feed_dict[enc_x_d] = np.full([bs, 128, 128, 3], 128, dtype=np.uint8)
        feed_dict[enc_y] = np.zeros([bs], dtype=np.int32)
        return feed_dict
82 |
# Parse the serialized GraphDef selected above and import it into the
# default graph (nodes end up under the default 'import/' scope used by get()).
graph_def_optimized = tf.GraphDef()
with tf.gfile.GFile(graph_path, 'rb') as f:
    graph_def_optimized.ParseFromString(f.read())

sess = tensorflow_session()
tf.import_graph_def(graph_def_optimized)

print("Loaded model")

# Number of latent tensors produced by the encoder / consumed by the decoder.
n_eps = 6

# Encoder
enc_x = get(inputs['enc_x'])
enc_eps = [get(outputs['enc_eps_%d' % i]) for i in range(n_eps)]
if not optimized:
    # Extra inputs that only the unoptimized graph exposes.
    enc_x_d = get(inputs['enc_x_d'])
    enc_y = get(inputs['enc_y'])

# Decoder
dec_x = get(outputs['dec_x'])
dec_eps = [get(inputs['dec_eps_%d' % i]) for i in range(n_eps)]

# Per-example shape of each latent tensor, in encoder output order.
eps_shapes = [(128, 128, 6), (64, 64, 12), (32, 32, 24),
              (16, 16, 48), (8, 8, 96), (4, 4, 384)]
eps_sizes = list(map(np.prod, eps_shapes))
# Length of one flattened latent vector.
eps_size = 256 * 256 * 3
# One manipulation direction per attribute tag, loaded from disk.
z_manipulate = np.load('z_manipulate.npy')
110 |
# Attribute names; row i of z_manipulate is the direction for _TAGS[i].
_TAGS = "5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young".split()

# Directions whose sign is reversed.
flip_tags = ['No_Beard', 'Young']
for tag in flip_tags:
    i = _TAGS.index(tag)
    z_manipulate[i] = np.negative(z_manipulate[i])

# Directions that are amplified by a constant factor.
scale_tags = ['Narrow_Eyes']
for tag in scale_tags:
    i = _TAGS.index(tag)
    z_manipulate[i] = z_manipulate[i] * 1.2

# Each direction divided by its squared norm, transposed so that
# np.dot(z, z_proj) yields one coefficient per tag.
z_sq_norms = np.square(z_manipulate).sum(axis=-1, keepdims=True)
z_proj = np.transpose(z_manipulate / z_sq_norms)
126 |
127 |
def run(sess, fetches, feed_dict):
    """Run `fetches` on `sess` while holding the module-level lock.

    Calls are serialised through `lock`; the original author's stated
    reason is a lower average server response time per user.
    """
    with lock:
        return sess.run(fetches, feed_dict)
133 |
134 |
def flatten_eps(eps):
    """Flatten each array in `eps` to [batch, -1] and join them along the last axis.

    Returns an array of shape [BS, total_size].
    """
    flat_parts = []
    for e in eps:
        flat_parts.append(np.reshape(e, (e.shape[0], -1)))
    return np.concatenate(flat_parts, axis=-1)
138 |
139 |
def unflatten_eps(feps):
    """Split flattened latents back into one array per entry of `eps_shapes`.

    Consecutive column ranges of `feps` are reshaped to (batch, *shape),
    in the order the shapes appear in `eps_shapes`.
    """
    batch = feps.shape[0]  # feps.size // eps_size
    pieces = []
    start = 0
    for shape in eps_shapes:
        stop = start + np.prod(shape)
        pieces.append(
            np.reshape(feps[:, start:stop], (batch,) + tuple(shape)))
        start = stop
    return pieces
149 |
150 |
def encode(img):
    """Encode an image (or a batch of images) into flattened latents.

    A rank-3 input is treated as a single image and given a batch axis;
    every image must have shape (256, 256, 3).
    """
    is_single = len(img.shape) == 3
    if is_single:
        img = np.expand_dims(img, 0)
    batch = img.shape[0]
    assert img.shape[1:] == (256, 256, 3)

    feed = {enc_x: img}
    update_feed(feed, batch)  # For unoptimized model
    latents = run(sess, enc_eps, feed)
    return flatten_eps(latents)
160 |
161 |
def decode(feps):
    """Decode flattened latents into whatever the `dec_x` tensor produces.

    A rank-1 input is treated as a single latent vector and given a
    batch axis before being split into the per-level latent tensors.
    """
    if len(feps.shape) == 1:
        feps = np.expand_dims(feps, 0)
    batch = feps.shape[0]
    levels = unflatten_eps(feps)

    # One decoder placeholder per latent level.
    feed = {dec_eps[i]: levels[i] for i in range(n_eps)}

    update_feed(feed, batch)  # For unoptimized model
    return run(sess, dec_x, feed)
178 |
179 |
def project(z):
    """Return the dot product of `z` with the module-level `z_proj` matrix."""
    coefficients = np.dot(z, z_proj)
    return coefficients
182 |
183 |
def _manipulate(z, dz, alpha):
    """Shift `z` by `alpha * dz`; return (decoded result, shifted latent)."""
    shifted = z + alpha * dz
    decoded = decode(shifted)
    return decoded, shifted
187 |
188 |
def _manipulate_range(z, dz, points, scale):
    """Decode `points` latents evenly spaced from `z` to `z + scale * dz`.

    Args:
        z: starting latent; presumably shaped (1, eps_size) so the samples
            stack along axis 0 -- confirm against callers.
        dz: direction added to `z`.
        points: number of samples. With a single point only `z` itself is
            decoded.
        scale: multiplier applied to `dz` at the last sample.

    Returns:
        (decoded samples, concatenated latents)
    """
    # The interval is divided into points - 1 steps. For points == 1 there is
    # nothing to divide, so use a unit denominator instead of raising
    # ZeroDivisionError; the only sample (pt == 0) is then exactly `z`.
    steps = max(points - 1, 1)
    z_range = np.concatenate(
        [z + scale * (pt / steps) * dz for pt in range(points)], axis=0)
    return decode(z_range), z_range
193 |
194 |
195 | # alpha from [0,1]
def mix(z1, z2, alpha):
    """Move from `z1` towards `z2` by fraction `alpha`; return (decoded, latent)."""
    direction = (z2 - z1)
    return _manipulate(z1, direction, alpha)
199 |
200 |
def mix_range(z1, z2, points=5):
    """Decode `points` evenly spaced latents going from `z1` to `z2`."""
    direction = (z2 - z1)
    return _manipulate_range(z1, direction, points, 1.)
204 |
205 |
206 | # alpha goes from [-1,1]
def manipulate(z, typ, alpha):
    """Shift `z` along attribute direction number `typ`, scaled by `alpha`."""
    direction = z_manipulate[typ]
    return _manipulate(z, direction, alpha)
210 |
211 |
def manipulate_all(z, typs, alphas):
    """Shift `z` by the weighted sum of several attribute directions.

    `alphas[i]` weights the direction selected by `typs[i]`.
    """
    total = 0.0
    for pos, typ in enumerate(typs):
        total = total + alphas[pos] * z_manipulate[typ]
    return _manipulate(z, total, 1.0)
217 |
218 |
def manipulate_range(z, typ, points=5, scale=1):
    """Decode samples along attribute direction `typ`, starting at `z - dz`.

    The range starts one direction-length below `z` and spans `2 * dz`
    (times `scale`).
    """
    direction = z_manipulate[typ]
    start = z - direction
    return _manipulate_range(start, 2*direction, points, scale)
222 |
223 |
def random(bs=1, eps_std=0.7):
    """Sample `bs` Gaussian latents with std `eps_std`; return (decoded, latents)."""
    latents = np.random.normal(scale=eps_std, size=(bs, eps_size))
    decoded = decode(latents)
    return decoded, latents
227 |
228 |
def test():
    """Smoke test: time encode/decode on test/img.png and save sample outputs."""
    n_runs = 10
    img = np.reshape(np.array(Image.open('test/img.png')), [1, 256, 256, 3])

    # Encoding speed
    eps = encode(img)  # one untimed call before the timed loop
    started = time.time()
    for _ in tqdm(range(n_runs)):
        eps = encode(img)
    elapsed = time.time() - started
    print("Encoding latency {} sec/img".format(elapsed / n_runs))

    # Decoding speed
    dec = decode(eps)  # one untimed call before the timed loop
    started = time.time()
    for _ in tqdm(range(n_runs)):
        dec = decode(eps)
    elapsed = time.time() - started
    print("Decoding latency {} sec/img".format(elapsed / n_runs))
    Image.fromarray(dec[0]).save('test/dec.png')

    # Manipulation
    smiling = _TAGS.index('Smiling')
    dec, _ = manipulate(eps, smiling, 0.66)
    Image.fromarray(dec[0]).save('test/smile.png')
253 |
254 |
255 | # warm start
# Run one decode and one encode at import time so the first real request
# does not hit a cold model.
_img, _z = random(bs=1)
_z = encode(_img)
print("Warm started tf model")

if __name__ == '__main__':
    test()
262 |
--------------------------------------------------------------------------------
/demo/web/glowDemo.css:
--------------------------------------------------------------------------------
1 | /* glowDemo.css
2 | *
3 | * CSS driving the Glow paper face-mixing demo.
4 | */
5 |
6 | /* Tabs */
7 |
8 | .GlowDemo_TabLabelContainer {
9 | display: table;
10 | margin: auto;
11 | }
12 |
.GlowDemo_TabLabel {
  cursor: pointer;
  display: table-cell;
  padding: 10px;
  background-color: #fff;
  border-radius: 5px 5px 5px 5px;
  border: 1px solid #4bacff;
  color: #0b8dff;
  /* Keep tab labels from being text-selected when clicked. Vendor-prefixed
     forms come first; the standard property is last so it takes precedence
     where supported. */
  -moz-user-select: none;
  -webkit-user-select: none;
  -ms-user-select: none;
  user-select: none;
  min-width: 90px;
  text-align: center;
}
27 |
28 | .GlowDemo_TabLabel:hover {
29 | background-color: #c3e3ff;
30 | }
31 |
32 | .GlowDemo_TabLabel:active, .GlowDemo_ActiveTab {
33 | color: #fff;
34 | background-color: #4bacff;
35 | }
36 |
37 | .GlowDemo_TabLabel:hover.GlowDemo_ActiveTab {
38 | background-color: #6dbbff;
39 | }
40 |
41 | /* Demo Container */
42 |
43 | .GlowDemo {
44 | box-sizing: initial;
45 | font-size: initial;
46 | line-height: initial;
47 | }
48 |
49 | .GlowDemo img {
50 | display: initial;
51 | padding: initial;
52 | position: absolute;
53 | left: initial;
54 | transform: initial;
55 | }
56 |
57 | .GlowDemo_Container {
58 | margin-top: 1em;
59 | margin-left: auto;
60 | margin-right: auto;
61 | width: 375px;
62 | transform: translateX(-16px);
63 | }
64 |
65 | /* Face Sliders Demo */
66 |
67 | .GlowDemo_FaceSlidersDemo {
68 | /* width: 404px; */
69 | overflow: hidden;
70 | display: table;
71 | margin: 1em auto 3em;
72 | }
73 |
74 | /* Face Slider Mode */
75 |
76 | .GlowDemo_SelectorAndOutput {
77 | display: table;
78 | }
79 |
80 | /* Image Selector (Input) */
81 |
82 | .GlowDemo_InputLabel {
83 | display: block;
84 | text-transform: uppercase;
85 | margin: auto;
86 | color: #747c9f;
87 | font-size: 0.8em;
88 | padding-left: 5px;
89 | }
90 |
91 | .GlowDemo_SelectorFrame {
92 | display: table-cell;
93 | padding-right: 5px;
94 | }
95 |
96 | .GlowDemo_ImageSelectorNoFaceFoundOverlay {
97 | position: absolute;
98 | z-index: 7;
99 | margin-left: -16px;
100 | margin-top: -154px;
101 | width: 14em;
102 | background-color: #fbb4d7;
103 | border-radius: 4px;
104 | padding: 0.5em;
105 | box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.205);
106 | }
107 |
108 | .GlowDemo_ImageChoice {
109 | cursor: pointer;
110 | width: 59px;
111 | height: 59px;
112 | position: absolute;
113 | }
114 |
115 | /* Input & Output Images */
116 |
117 | .GlowDemo_ImageFrame {
118 | width: 178px;
119 | height: 178px;
120 | margin: auto;
121 |
122 | padding: 2px;
123 | box-shadow: 0px 3px 6px #ddd;
124 |
125 | background-color: #f9f9f9;
126 | }
127 |
128 | /* Output Images */
129 |
130 | .GlowDemo_OutputImage {
131 | width: 178px !important;
132 | height: 178px !important;
133 | }
134 |
135 | .GlowDemo_OutputImageFrame {
136 | /* background-color: #d1d1d1 */
137 | }
138 |
139 | .GlowDemo_OutputLabel {
140 | display: block;
141 | text-transform: uppercase;
142 | margin: auto;
143 | color: #747c9f;
144 | font-size: 0.8em;
145 | text-align: right;
146 | padding-right: 5px;
147 | }
148 |
149 | .GlowDemo_OutputHider {
150 | z-index: 2;
151 | color: #fff;
152 | background-color: white;
153 | width: 187px;
154 | height: 210px;
155 | position: absolute;
156 | margin-top: -200px;
157 | margin-left: -5px;
158 | }
159 |
160 | .GlowDemo_MixingOutputFrame .GlowDemo_OutputHider {
161 | width: 248px;
162 | }
163 |
164 | .GlowDemo_FadeButton {
165 | width: 35px !important;
166 | height: 35px !important;
167 | position: absolute;
168 | border-radius: 4px;
169 | padding: 5px !important;
170 | box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.48);
171 | z-index: 5;
172 | }
173 |
174 | .GlowDemo_DownloadButton {
175 | cursor: pointer;
176 | margin-left: 128px;
177 | margin-top: 129px;
178 | background-color: #ffffffe6;
179 | }
180 |
181 | .GlowDemo_DownloadButton:hover {
182 | margin-left: 127px;
183 | margin-top: 128px;
184 | padding: 6px !important;
185 | background-color: #ffffff;
186 | }
187 |
188 | .GlowDemo_DownloadButton:active {
189 | margin-left: 130px;
190 | margin-top: 131px;
191 | padding: 3px !important;
192 | background-color: #ffffff;
193 | }
194 |
195 | .GlowDemo_UserImageEditButton {
196 | cursor: pointer;
197 | margin-left: 128px;
198 | margin-top: 129px;
199 | background-color: #ffffffe6;
200 | }
201 |
202 | .GlowDemo_UserImageEditButton:hover {
203 | margin-left: 127px;
204 | margin-top: 128px;
205 | padding: 6px !important;
206 | background-color: #ffffff;
207 | }
208 |
209 | .GlowDemo_UserImageEditButton:active {
210 | margin-left: 130px;
211 | margin-top: 131px;
212 | padding: 3px !important;
213 | background-color: #ffffff;
214 | }
215 |
216 | /* Feature Sliders */
217 |
218 | .GlowDemo_SliderFrame {
219 | width: 160px;
220 | margin: auto;
221 | display: table;
222 | padding-top: 2.7em;
223 | }
224 |
225 | .GlowDemo_MixingSliderFrame {
226 | padding-top: 0.8em;
227 | }
228 |
229 | .GlowDemo_FaceSliderContainer {
230 | /* display: table-row; */
231 | }
232 |
233 | .GlowDemo_FaceSliderLabel {
234 | display: table-cell;
235 | vertical-align: middle;
236 | }
237 |
238 | .GlowDemo_FaceSliderLabel p {
239 | margin: 0.5em 0.5em;
240 | min-width: 8em;
241 | font-size: 1.1em;
242 | line-height: 1.1em;
243 | }
244 |
245 | .GlowDemo_FaceSlider {
246 | /* display: inline-block; */
247 | /* margin: 10px 0 0px 0 !important; */
248 | width: 190px !important;
249 | vertical-align: middle;
250 | display: table-cell;
251 | }
252 |
253 | .GlowDemo_SliderHider {
254 | position: absolute;
255 | height: 43px;
256 | width: 365px;
257 | color: #fff;
258 | background-color: #ffffff;
259 | /* margin-top: -30px; */
260 | margin-left: -365px;
261 | }
262 |
263 | /* Face Mixing Demo */
264 |
265 | .GlowDemo_FaceMixingDemo {
266 | /* width: 404px; */
267 | overflow: hidden;
268 | display: table;
269 | margin: 1em auto 3em;
270 | }
271 |
272 | .GlowDemo_LeftInputLabel {
273 | display: block;
274 | text-transform: uppercase;
275 | margin: auto;
276 | color: #747c9f;
277 | font-size: 0.8em;
278 | padding-left: 5px;
279 | }
280 |
281 | .GlowDemo_RightInputLabel {
282 | display: block;
283 | text-transform: uppercase;
284 | margin: auto;
285 | color: #747c9f;
286 | font-size: 0.8em;
287 | text-align: right;
288 | padding-right: 5px;
289 | }
290 |
291 | .GlowDemo_MixingInputImagesContainer {
292 | display: table;
293 | }
294 |
295 | .GlowDemo_MixingSelectorFrame {
296 | display: table-cell;
297 | }
298 |
299 | .GlowDemo_MixingSelectorFrameLeft {
300 | padding-right: 5px;
301 | }
302 |
303 | .GlowDemo_MixingSelectorFrameRight {
304 | /* display: table-cell; */
305 | }
306 |
307 | .GlowDemo_OutputAndMixingSliderContainer {
308 | display: table;
309 | margin: auto;
310 | }
311 |
312 | .GlowDemo_MixingOutputFrame {
313 | margin-top: 15px;
314 | }
315 |
316 | .GlowDemo_MixingOutputLabel {
317 | display: block;
318 | text-transform: uppercase;
319 | margin: auto;
320 | color: #747c9f;
321 | font-size: 0.8em;
322 | text-align: center;
323 | }
324 |
325 | .GlowDemo_MixingSliderLabel {
326 | display: table;
327 | text-transform: uppercase;
328 | margin: auto;
329 | color: #747c9f;
330 | font-size: 0.8em;
331 | text-align: center;
332 | }
333 |
334 | .GlowDemo_MixingSlider {
335 | display: table;
336 | }
337 |
338 | .GlowDemo_MixingSliderHider {
339 | margin-top: -69px;
340 | height: 60px;
341 | }
342 |
343 | /* Loading Visuals */
344 |
345 | .GlowDemo_LoadingVisual {
346 | position: absolute;
347 | display: block;
348 | z-index: 5;
349 | width: 60px;
350 | height: 60px;
351 | margin-top: 59px;
352 | margin-left: 60px;
353 | }
354 |
355 | /* Hints */
356 |
.GlowDemo_Hint {
  position: absolute;
  z-index: 4;
  font-size: 1em;
  line-height: 1.2em;
  background-color: #ffffffe6;
  border-radius: 4px;
  padding: 0.1em 0.4em;
  color: #818181;
  /* Hint text is not selectable. Vendor-prefixed forms come first; the
     standard property is last so it takes precedence where supported. */
  -moz-user-select: none;
  -webkit-user-select: none;
  -ms-user-select: none;
  user-select: none;
  /* box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.07); */
}
371 |
372 | .GlowDemo_SelectorHint {
373 | max-width: 220px;
374 | margin-left: -173px;
375 | margin-top: 7px;
376 | }
377 |
.GlowDemo_DownloadHint {
  max-width: 220px;
  margin-left: 23px;
  margin-top: 7px;
  background-color: #ffffffe6;
  border-radius: 4px;
  padding: 0.1em 0.4em;
  color: #818181;
  /* box-shadow: 0px 3px 6px rgba(0, 0, 0, 0.07); */
  /* Hint text is not selectable. Vendor-prefixed forms come first; the
     standard property is last so it takes precedence where supported. */
  -moz-user-select: none;
  -webkit-user-select: none;
  -ms-user-select: none;
  user-select: none;
  font-size: 1em;
}
392 |
393 | .GlowDemo_MixingHint {
394 | max-width: 324px;
395 | margin-left: -36px;
396 | margin-top: 84px;
397 | }
398 |
399 | input[type=range] {
400 | /*removes default webkit styles*/
401 | -webkit-appearance: none;
402 |
403 | /*fix for FF unable to apply focus style bug */
404 | border: 1px solid white;
405 |
406 | /*required for proper track sizing in FF*/
407 | width: 190px;
408 | height: 35px;
409 |
410 | /*centering*/
411 | vertical-align: middle;
412 | display: table-cell;
413 | }
414 | input[type=range]::-webkit-slider-runnable-track {
415 | width: 190px;
416 | height: 3px;
417 | background: #ddd;
418 | border: none;
419 | border-radius: 3px;
420 | }
421 | input[type=range]::-webkit-slider-thumb {
422 | -webkit-appearance: none;
423 | border: none;
424 | height: 16px;
425 | width: 16px;
426 | border-radius: 50%;
427 | background: #4bacff;
428 | margin-top: -6px;
429 | }
430 | input[type=range]:focus {
431 | outline: none;
432 | }
433 | input[type=range]:focus::-webkit-slider-runnable-track {
434 | background: #ccc;
435 | }
436 |
437 | input[type=range]::-moz-range-track {
438 | width: 190px;
439 | height: 3px;
440 | background: #ddd;
441 | border: none;
442 | border-radius: 3px;
443 | }
444 | input[type=range]::-moz-range-thumb {
445 | border: none;
446 | height: 16px;
447 | width: 16px;
448 | border-radius: 50%;
449 | background: #4bacff;
450 | }
451 |
452 | /*hide the outline behind the border*/
453 | input[type=range]:-moz-focusring{
454 | outline: 1px solid white;
455 | outline-offset: -1px;
456 | }
457 |
458 | input[type=range]::-ms-track {
459 | width: 190px;
460 | height: 3px;
461 |
462 | /*remove bg colour from the track, we'll use ms-fill-lower and ms-fill-upper instead */
463 | background: transparent;
464 |
465 | /*leave room for the larger thumb to overflow with a transparent border */
466 | border-color: transparent;
467 | border-width: 6px 0;
468 |
469 | /*remove default tick marks*/
470 | color: transparent;
471 | }
472 | input[type=range]::-ms-fill-lower {
473 | background: #777;
474 | border-radius: 10px;
475 | }
476 | input[type=range]::-ms-fill-upper {
477 | background: #ddd;
478 | border-radius: 10px;
479 | }
480 | input[type=range]::-ms-thumb {
481 | border: none;
482 | height: 16px;
483 | width: 16px;
484 | border-radius: 50%;
485 | background: #4bacff;
486 | }
487 | input[type=range]:focus::-ms-fill-lower {
488 | background: #888;
489 | }
490 | input[type=range]:focus::-ms-fill-upper {
491 | background: #ccc;
492 | }
493 |
--------------------------------------------------------------------------------
/demo/web/canvas2image.js:
--------------------------------------------------------------------------------
1 | // https://github.com/hongru/canvas2image
2 |
3 | /**
4 | * convert canvas to image
5 | * and save the image file
6 | */
7 |
8 | var Canvas2Image = function () {
9 |
10 | // check if support sth.
// Feature-detect the browser APIs this module relies on.
var $support = function () {
	var canvas = document.createElement('canvas'),
		// getContext is absent without canvas support and may also return
		// null; guard both so detection reports false instead of throwing.
		ctx = canvas.getContext ? canvas.getContext('2d') : null;

	return {
		canvas: !!ctx,
		imageData: !!(ctx && ctx.getImageData),
		dataURL: !!canvas.toDataURL,
		btoa: !!window.btoa
	};
}();
22 |
23 | var downloadMime = 'image/octet-stream';
24 |
/**
 * Copy `canvas` onto a new canvas.
 * When window.GlowDemoCanvasCropRect is set, the new canvas holds exactly
 * that crop rectangle (width/height arguments are ignored); otherwise the
 * whole source is scaled to width x height (defaulting to the source size).
 */
function scaleCanvas (canvas, width, height) {
	var srcWidth = canvas.width,
		srcHeight = canvas.height;
	width = (width == null) ? srcWidth : width;
	height = (height == null) ? srcHeight : height;

	var cropRect = window.GlowDemoCanvasCropRect;
	var sx, sy, sw, sh, outWidth, outHeight;

	if (cropRect) {
		sx = getCropRectParam(cropRect.x, 0);
		sy = getCropRectParam(cropRect.y, 0);
		sw = getCropRectParam(cropRect.width, canvas.width);
		sh = getCropRectParam(cropRect.height, canvas.height);
		console.log([sx, sy, sw, sh]);
		console.log('cropped image');
		// Cropped output keeps the crop rectangle's own size.
		outWidth = sw;
		outHeight = sh;
	}
	else {
		console.log('will NOT crop image');
		sx = 0;
		sy = 0;
		sw = srcWidth;
		sh = srcHeight;
		outWidth = width;
		outHeight = height;
	}

	var retCanvas = document.createElement('canvas');
	var retCtx = retCanvas.getContext('2d');
	retCanvas.width = outWidth;
	retCanvas.height = outHeight;
	retCtx.drawImage(canvas, sx, sy, sw, sh, 0, 0, outWidth, outHeight);
	return retCanvas;
}
62 |
// Scale/crop `canvas` via scaleCanvas and serialise the result as a data URL of `type`.
function getDataURL (canvas, type, width, height) {
	var scaled = scaleCanvas(canvas, width, height);
	return scaled.toDataURL(type);
}
67 |
/**
 * Trigger a browser download of `strData` through a temporary hidden link.
 * Does nothing unless window.GlowDemoDownloadFileName is set; that value
 * is used as the download file name.
 */
function saveFile (strData) {
	var fileName = window.GlowDemoDownloadFileName;
	if (!fileName) {
		return;
	}

	var link = document.createElement('a');
	link.setAttribute('href', strData);
	link.setAttribute('download', fileName);
	link.style.display = 'none';

	// The link is attached to the document only for the duration of the click.
	document.body.appendChild(link);
	link.click();
	document.body.removeChild(link);
}
84 |
// Create an <img> element whose src is `strData`.
function genImage (strData) {
	return Object.assign(document.createElement('img'), { src: strData });
}
// Normalise an image type name to a MIME type ('jpg' -> 'image/jpeg').
// Throws a TypeError when the name contains none of png/jpeg/bmp/gif.
function fixType (type) {
	var normalized = type.toLowerCase().replace(/jpg/i, 'jpeg');
	var found = normalized.match(/png|jpeg|bmp|gif/);
	return 'image/' + found[0];
}
/**
 * Base64-encode `data` with btoa.
 *
 * A string is encoded as-is; anything else is read as an indexable list of
 * character codes (0 .. data.length - 1) and converted to a string first.
 * Throws the string 'btoa undefined' when window.btoa is unavailable.
 */
function encodeData (data) {
    if (!window.btoa) { throw 'btoa undefined' }

    if (typeof data == 'string') {
        return btoa(data);
    }

    var chars = [];
    for (var idx = 0; idx < data.length; idx++) {
        chars.push(String.fromCharCode(data[idx]));
    }
    return btoa(chars.join(''));
}
/**
 * Read the pixel data of `canvas`.
 *
 * When window.GlowDemoCanvasCropRect is set, only that rectangle is read.
 * Its x / y / width / height entries are resolved through getCropRectParam
 * (defaults 0, 0, canvas.width, canvas.height), so "auto" and function-valued
 * entries work here the same way they do in scaleCanvas. Without a crop
 * rectangle the whole canvas is read.
 *
 * NOTE(review): the BMP paths in this file call this on the result of
 * scaleCanvas, which has already applied the crop, so the crop rectangle is
 * applied a second time to the cropped canvas. Confirm whether that is
 * intended.
 *
 * @param canvas canvas element to read from
 * @return {ImageData} pixel data of the selected region
 */
function getImageData (canvas) {
    console.log(window.GlowDemoCanvasCropRect);
    if (window.GlowDemoCanvasCropRect) {
        let r = window.GlowDemoCanvasCropRect;

        let x = getCropRectParam(r.x, 0);
        let y = getCropRectParam(r.y, 0);
        let w = getCropRectParam(r.width, canvas.width);
        let h = getCropRectParam(r.height, canvas.height);
        console.log([x, y, w, h]);

        // Use the resolved values: the raw entries may be "auto" or functions.
        return canvas.getContext('2d').getImageData(x, y, w, h);
    }

    return canvas.getContext('2d').getImageData(0, 0, canvas.width, canvas.height);
}
/**
 * Resolve one crop-rectangle entry.
 *
 * - a function is called with `autoParam` and its result is returned
 * - the string "auto" yields `autoParam`
 * - anything else is returned unchanged
 */
function getCropRectParam(param, autoParam) {
    if (isFunction(param)) {
        return param(autoParam);
    }
    return param === "auto" ? autoParam : param;
}
/**
 * Duck-typed function check: true when `obj` is truthy and has truthy
 * `constructor`, `call` and `apply` properties.
 */
function isFunction(obj) {
    return Boolean(obj && obj.constructor && obj.call && obj.apply);
}
/**
 * Build a base64 data URI from already-encoded `strData` and a MIME `type`.
 */
function makeURI (strData, type) {
    return `data:${type};base64,${strData}`;
}
144 |
145 |
/**
 * Create a bitmap image.
 * Builds the BMP file header, info header and pixel body according to the
 * format rules and returns the whole file base64-encoded.
 *
 * The output is an uncompressed 24-bit BMP: pixels are written as B, G, R
 * (the alpha channel of the input is dropped), rows are written bottom-up,
 * and every row is zero-padded to a multiple of 4 bytes.
 *
 * @param oData object with `width`, `height` and an RGBA byte array `data`
 *              (4 bytes per pixel, rows top-down), e.g. canvas ImageData
 * @return {String} base64 text of the BMP file
 */
var genBitmapImage = function (oData) {

    //
    // BITMAPFILEHEADER: http://msdn.microsoft.com/en-us/library/windows/desktop/dd183374(v=vs.85).aspx
    // BITMAPINFOHEADER: http://msdn.microsoft.com/en-us/library/dd183376.aspx
    //

    var biWidth = oData.width;
    var biHeight = oData.height;
    // Number of zero bytes appended to each row to reach a 4-byte boundary.
    var iPadding = (4 - ((biWidth * 3) % 4)) % 4;
    // The size fields must count the padding bytes that are actually written.
    var iRowSize = biWidth * 3 + iPadding;
    var biSizeImage = iRowSize * biHeight;
    var bfSize = biSizeImage + 54; // total header size = 54 bytes

    //
    //  typedef struct tagBITMAPFILEHEADER {
    //      WORD bfType;
    //      DWORD bfSize;
    //      WORD bfReserved1;
    //      WORD bfReserved2;
    //      DWORD bfOffBits;
    //  } BITMAPFILEHEADER;
    //
    var BITMAPFILEHEADER = [
        // WORD bfType -- The file type signature; must be "BM"
        0x42, 0x4D,
        // DWORD bfSize -- The size, in bytes, of the bitmap file
        bfSize & 0xff, bfSize >> 8 & 0xff, bfSize >> 16 & 0xff, bfSize >> 24 & 0xff,
        // WORD bfReserved1 -- Reserved; must be zero
        0, 0,
        // WORD bfReserved2 -- Reserved; must be zero
        0, 0,
        // DWORD bfOffBits -- The offset, in bytes, from the beginning of the BITMAPFILEHEADER structure to the bitmap bits.
        54, 0, 0, 0
    ];

    //
    //  typedef struct tagBITMAPINFOHEADER {
    //      DWORD biSize;
    //      LONG  biWidth;
    //      LONG  biHeight;
    //      WORD  biPlanes;
    //      WORD  biBitCount;
    //      DWORD biCompression;
    //      DWORD biSizeImage;
    //      LONG  biXPelsPerMeter;
    //      LONG  biYPelsPerMeter;
    //      DWORD biClrUsed;
    //      DWORD biClrImportant;
    //  } BITMAPINFOHEADER, *PBITMAPINFOHEADER;
    //
    var BITMAPINFOHEADER = [
        // DWORD biSize -- The number of bytes required by the structure
        40, 0, 0, 0,
        // LONG biWidth -- The width of the bitmap, in pixels
        biWidth & 0xff, biWidth >> 8 & 0xff, biWidth >> 16 & 0xff, biWidth >> 24 & 0xff,
        // LONG biHeight -- The height of the bitmap, in pixels
        biHeight & 0xff, biHeight >> 8  & 0xff, biHeight >> 16 & 0xff, biHeight >> 24 & 0xff,
        // WORD biPlanes -- The number of planes for the target device. This value must be set to 1
        1, 0,
        // WORD biBitCount -- The number of bits-per-pixel, 24 bits-per-pixel -- the bitmap
        // has a maximum of 2^24 colors (16777216, Truecolor)
        24, 0,
        // DWORD biCompression -- The type of compression, BI_RGB (code 0) -- uncompressed
        0, 0, 0, 0,
        // DWORD biSizeImage -- The size, in bytes, of the image. This may be set to zero for BI_RGB bitmaps
        biSizeImage & 0xff, biSizeImage >> 8 & 0xff, biSizeImage >> 16 & 0xff, biSizeImage >> 24 & 0xff,
        // LONG biXPelsPerMeter, unused
        0,0,0,0,
        // LONG biYPelsPerMeter, unused
        0,0,0,0,
        // DWORD biClrUsed, the number of color indexes of palette, unused
        0,0,0,0,
        // DWORD biClrImportant, unused
        0,0,0,0
    ];

    var aImgData = oData.data;

    var strPixelData = '';
    var biWidth4 = biWidth << 2; // bytes per source row (RGBA)
    var fromCharCode = String.fromCharCode;

    // Rows are emitted last-to-first. A plain for loop also terminates
    // when the height is 0.
    for (var y = biHeight - 1; y >= 0; y--) {
        var iOffsetY = biWidth4 * y;
        var strPixelRow = '';
        for (var x = 0; x < biWidth; x++) {
            var iOffsetX = x << 2;
            // RGBA in, BGR out
            strPixelRow += fromCharCode(aImgData[iOffsetY + iOffsetX + 2]) +
                           fromCharCode(aImgData[iOffsetY + iOffsetX + 1]) +
                           fromCharCode(aImgData[iOffsetY + iOffsetX]);
        }

        for (var c = 0; c < iPadding; c++) {
            strPixelRow += fromCharCode(0);
        }

        strPixelData += strPixelRow;
    }

    // The 54 header bytes are a multiple of 3, so their base64 text has no
    // '=' padding and can be concatenated with the encoded pixel data.
    var strEncoded = encodeData(BITMAPFILEHEADER.concat(BITMAPINFOHEADER)) + encodeData(strPixelData);

    return strEncoded;
};
255 |
/**
 * saveAsImage
 *
 * Encode a canvas as an image and hand the resulting data URI to saveFile.
 * Does nothing when $support.canvas or $support.dataURL is falsy. In the
 * data URI the image MIME type is replaced by `downloadMime`.
 *
 * @param canvas canvas element, or the id of one
 * @param {Number} [width]  optional output width
 * @param {Number} [height] optional output height
 * @param {String} [type]   image type, defaults to 'png'
 */
var saveAsImage = function (canvas, width, height, type) {
    if (!($support.canvas && $support.dataURL)) {
        return;
    }
    if (typeof canvas == "string") { canvas = document.getElementById(canvas); }

    var mime = fixType(type == undefined ? 'png' : type);

    if (/bmp/.test(mime)) {
        // BMP is encoded by hand instead of through canvas.toDataURL.
        var pixels = getImageData(scaleCanvas(canvas, width, height));
        saveFile(makeURI(genBitmapImage(pixels), downloadMime));
        return;
    }

    var dataUrl = getDataURL(canvas, mime, width, height);
    saveFile(dataUrl.replace(mime, downloadMime));
};
278 |
/**
 * convertToImage
 *
 * Encode a canvas as an image and return an <img> element showing it.
 * Returns undefined when $support.canvas or $support.dataURL is falsy.
 *
 * @param canvas canvas element, or the id of one
 * @param {Number} [width]  optional output width
 * @param {Number} [height] optional output height
 * @param {String} [type]   image type, defaults to 'png'
 */
var convertToImage = function (canvas, width, height, type) {
    if (!($support.canvas && $support.dataURL)) {
        return;
    }
    if (typeof canvas == "string") { canvas = document.getElementById(canvas); }

    var mime = fixType(type == undefined ? 'png' : type);

    if (/bmp/.test(mime)) {
        // BMP is encoded by hand instead of through canvas.toDataURL.
        var pixels = getImageData(scaleCanvas(canvas, width, height));
        return genImage(makeURI(genBitmapImage(pixels), 'image/bmp'));
    }

    return genImage(getDataURL(canvas, mime, width, height));
};
295 |
296 |
297 |
// Object returned from the enclosing function: the generic entry points plus
// per-format shortcuts that fix the `type` argument.
return {
    // Download helpers: each one ends in saveFile().
    saveAsImage: saveAsImage,
    saveAsPNG: function (canvas, width, height) {
        return saveAsImage(canvas, width, height, 'png');
    },
    saveAsJPEG: function (canvas, width, height) {
        return saveAsImage(canvas, width, height, 'jpeg');
    },
    saveAsGIF: function (canvas, width, height) {
        return saveAsImage(canvas, width, height, 'gif');
    },
    saveAsBMP: function (canvas, width, height) {
        return saveAsImage(canvas, width, height, 'bmp');
    },

    // Conversion helpers: each one returns an <img> element.
    convertToImage: convertToImage,
    convertToPNG: function (canvas, width, height) {
        return convertToImage(canvas, width, height, 'png');
    },
    convertToJPEG: function (canvas, width, height) {
        return convertToImage(canvas, width, height, 'jpeg');
    },
    convertToGIF: function (canvas, width, height) {
        return convertToImage(canvas, width, height, 'gif');
    },
    convertToBMP: function (canvas, width, height) {
        return convertToImage(canvas, width, height, 'bmp');
    }
};
327 |
328 | }();
329 |
--------------------------------------------------------------------------------
/data_loaders/generate_tfr/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Generate CelebA-HQ and Imagenet datasets
18 | For CelebA-HQ, first create original tfrecords file using https://github.com/tkarras/progressive_growing_of_gans/blob/master/dataset_tool.py
19 | For Imagenet, first create original tfrecords file using https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py
20 | Then, use this script to get our tfr file from those records.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import os
28 |
29 | import tensorflow as tf
30 | import numpy as np
31 | from tqdm import tqdm
32 | from typing import Iterable
33 |
# Number of channels requested from tf.image.decode_jpeg in parse_image.
_NUM_CHANNELS = 3


# cycle_length for tf.contrib.data.parallel_interleave over the input files.
_NUM_PARALLEL_FILE_READERS = 32
# num_parallel_calls for Dataset.map when parsing examples.
_NUM_PARALLEL_MAP_CALLS = 32
# Resize method passed to tf.image.resize_images in parse_image.
_DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR
# Buffer size of the record-level shuffle in dump_imagenet.
_SHUFFLE_BUFFER = 1024
41 |
42 |
def _int64_feature(value):
    """Build an int64_list tf.train.Feature from `value`.

    A non-iterable value is wrapped in a one-element list; an iterable is
    passed through unchanged.
    """
    values = value if isinstance(value, Iterable) else [value]
    int64_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int64_list)
47 |
48 |
def _bytes_feature(value):
    """Build a bytes_list tf.train.Feature holding the single item `value`."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
51 |
52 |
def error(msg):
    """Print 'Error: <msg>' and terminate the process with exit status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    # Imported locally so this function does not depend on the file header.
    # sys.exit is used instead of the builtin exit(): the latter is injected
    # by the `site` module for interactive use and is not always available.
    import sys

    print('Error: ' + msg)
    sys.exit(1)
56 |
57 |
def x_to_uint8(x):
    """Floor `x`, clip it to [0, 255] and cast the result to uint8."""
    floored = tf.floor(x)
    clipped = tf.clip_by_value(floored, 0, 255)
    return tf.cast(clipped, 'uint8')
60 |
61 |
def centre_crop(img):
    """Crop `img` to a centred square whose side is min(height, width).

    Height and width are taken from the first two dimensions of
    tf.shape(img).
    """
    dims = tf.shape(img)
    height, width = dims[0], dims[1]
    side = tf.minimum(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return tf.image.crop_to_bounding_box(img, top, left, side, side)
68 |
69 |
def downsample(img):
    """Halve the first two dimensions of `img` by averaging 2x2 blocks.

    The four stride-2 sub-grids are summed and multiplied by 0.25; the
    third dimension is left untouched.
    """
    top_left = img[0::2, 0::2, :]
    top_right = img[0::2, 1::2, :]
    bottom_left = img[1::2, 0::2, :]
    bottom_right = img[1::2, 1::2, :]
    return (top_left + top_right + bottom_left + bottom_right) * 0.25
72 |
73 |
def parse_image(max_res):
    """Return a Dataset.map function for ImageNet-style JPEG examples.

    The returned function parses 'image/encoded' and 'image/class/label'
    from a serialized example and yields a tuple
    (label, img_lod0, img_lod1, ...), where the label is shifted down by 1
    and cast to int32, and the images form a uint8 pyramid starting at
    max_res x max_res and halving at each further level. The pyramid has
    int(log2(max_res)) - 1 levels.
    """

    def _process_image(img):
        # Square centre crop, then resize to the top resolution.
        img = centre_crop(img)
        img = tf.image.resize_images(
            img, [max_res, max_res], method=_DOWNSAMPLING)
        level = tf.cast(img, 'float32')

        num_lods = int(np.log2(max_res)) - 1
        pyramid = []
        for lod in range(num_lods):
            # Level 0 is the full-resolution image; later levels are
            # produced from the previous float level, not the quantized one.
            if lod > 0:
                level = downsample(level)
            pyramid.append(x_to_uint8(level))
        return pyramid

    def _parse_image(example):
        encoded_spec = tf.FixedLenFeature([], dtype=tf.string,
                                          default_value='')
        label_spec = tf.FixedLenFeature([1], dtype=tf.int64,
                                        default_value=-1)
        features = tf.parse_single_example(example, {
            'image/encoded': encoded_spec,
            'image/class/label': label_spec,
        })

        raw_label = tf.reshape(features['image/class/label'], shape=[])
        label = tf.cast(raw_label, dtype=tf.int32) - 1
        decoded = tf.image.decode_jpeg(
            features['image/encoded'], channels=_NUM_CHANNELS)
        return (label,) + tuple(_process_image(decoded))

    return _parse_image
105 |
106 |
def parse_celeba_image(max_res, transpose=False):
    """Return a Dataset.map function for CelebA-HQ raw-pixel examples.

    The returned function parses 'shape', 'data' and 'attr' from a
    serialized example, decodes 'data' as uint8 reshaped to 'shape', and
    yields a tuple (attr, img_lod0, img_lod1, ...). The images form a uint8
    pyramid in which each level after the first is a 2x2-averaged version
    of the previous one; it has int(log2(max_res)) - 1 levels.

    When `transpose` is true the decoded image is permuted with axes
    (1, 2, 0), i.e. CHW -> HWC, before the pyramid is built.
    """

    def _process_image(img):
        level = tf.cast(img, 'float32')
        num_lods = int(np.log2(max_res)) - 1
        pyramid = []
        for lod in range(num_lods):
            # Level 0 keeps the input resolution.
            if lod > 0:
                level = downsample(level)
            pyramid.append(x_to_uint8(level))
        return pyramid

    def _parse_image(example):
        spec = {
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string),
            'attr': tf.FixedLenFeature([40], tf.int64),
        }
        features = tf.parse_single_example(example, features=spec)

        pixels = tf.decode_raw(features['data'], tf.uint8)
        img = tf.reshape(pixels, features['shape'])
        if transpose:
            img = tf.transpose(img, (1, 2, 0))  # CHW -> HWC
        return (features['attr'],) + tuple(_process_image(img))

    return _parse_image
136 |
137 |
def get_tfr_files(data_dir, split, lgres):
    """Return the glob pattern of the sharded tfrecords of one split.

    The pattern is <data_dir>/<split>/<basename>-rNN-s-*-of-*.tfrecords,
    where <basename> is the last path component of <data_dir>/<split> and
    NN is `lgres` zero-padded to two digits.
    """
    split_dir = os.path.join(data_dir, split)
    suffix = '-r%02d-s-*-of-*.tfrecords' % (lgres)
    return os.path.join(split_dir, os.path.basename(split_dir)) + suffix
143 |
144 |
def get_tfr_file(data_dir, split, lgres):
    """Return the path of the single (unsharded) tfrecords file.

    The path is <dir>/<basename>-rNN.tfrecords, where <dir> is
    <data_dir>/<split> when `split` is non-empty and <data_dir> otherwise,
    <basename> is the last path component of <dir>, and NN is `lgres`
    zero-padded to two digits.
    """
    base_dir = os.path.join(data_dir, split) if split else data_dir
    suffix = '-r%02d.tfrecords' % (lgres)
    return os.path.join(base_dir, os.path.basename(base_dir)) + suffix
151 |
152 |
def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
    """Convert CelebA-HQ tfrecords into sharded multi-resolution tfrecords.

    With a non-empty `split`, the sharded input files matching
    get_tfr_files(data_dir, split, log2(max_res)) are read and only that
    split is produced. With an empty `split`, the single file named by
    get_tfr_file(data_dir, "", log2(max_res)) is read (images are transposed
    CHW -> HWC) and consumed in order: first the 'validation' images, then
    the 'train' images.

    Output goes to <tfrecord_dir>/<split> through TFRecordExporter. Examples
    are only written when `write` is true; otherwise the input is just read
    through.
    """
    images_per_split = {
        'train': 27000,
        'validation': 3000,
    }
    shards_per_split = {
        'train': 120,
        'validation': 40,
    }

    resolution_log2 = int(np.log2(max_res))
    if 2 ** resolution_log2 != max_res:
        error('Input image resolution must be a power-of-two')

    with tf.Session() as sess:
        print("Reading data from ", data_dir)
        if split:
            pattern = get_tfr_files(data_dir, split, int(np.log2(max_res)))
            shard_files = tf.data.Dataset.list_files(pattern)
            interleave = tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset,
                cycle_length=_NUM_PARALLEL_FILE_READERS)
            records = shard_files.apply(interleave)
            needs_transpose = False
        else:
            single_file = get_tfr_file(data_dir, "", int(np.log2(max_res)))
            records = tf.data.TFRecordDataset(single_file, compression_type='')
            needs_transpose = True

        parse_fn = parse_celeba_image(max_res, needs_transpose)
        records = records.map(
            parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        records = records.prefetch(1)

        next_example = records.make_one_shot_iterator().get_next()
        # First tensor is the attribute vector, the rest is the image pyramid.
        fetches = [next_example[0]] + list(next_example[1:])
        sess.run(tf.global_variables_initializer())

        for name in ([split] if split else ["validation", "train"]):
            total_imgs = images_per_split[name]
            shards = shards_per_split[name]
            out_dir = os.path.join(tfrecord_dir, name)
            with TFRecordExporter(out_dir, resolution_log2, total_imgs, shards) as tfr:
                for _ in tqdm(range(total_imgs)):
                    fetched = sess.run(fetches)
                    if write:
                        tfr.add_image(0, fetched[1:], fetched[0])
                if write:
                    assert tfr.cur_images == total_imgs, (
                        tfr.cur_images, total_imgs)
199 |
200 |
def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
    """Convert one ImageNet split into sharded multi-resolution tfrecords.

    Args:
        data_dir: directory holding the input files `train-*-of-01024`
            (read when split == 'train') or `validation-*-of-00128`
            (read otherwise).
        tfrecord_dir: output root; files are written to
            <tfrecord_dir>/<split>.
        max_res: top image resolution; must be a power of two.
        split: 'train' or 'validation' (the keys of the size tables below).
        write: when false the input is only read through and nothing is
            added to the output files.

    Input files are shuffled, interleaved, shuffled again at record level
    and parsed with parse_image(max_res) before being handed to
    TFRecordExporter.
    """
    _NUM_IMAGES = {
        'train': 1281167,
        'validation': 50000,
    }

    _NUM_FILES = _NUM_SHARDS = {
        'train': 2000,
        'validation': 80,
    }
    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')

    with tf.Session() as sess:
        is_training = (split == 'train')
        if is_training:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'train-*-of-01024'))
        else:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'validation-*-of-00128'))

        files = files.shuffle(buffer_size=_NUM_FILES[split])

        dataset = files.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))

        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
        parse_fn = parse_image(max_res)
        dataset = dataset.map(
            parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dataset = dataset.prefetch(1)
        iterator = dataset.make_one_shot_iterator()

        _label, *_imgs = iterator.get_next()

        sess.run(tf.global_variables_initializer())

        total_imgs = _NUM_IMAGES[split]
        shards = _NUM_SHARDS[split]
        tfrecord_dir = os.path.join(tfrecord_dir, split)
        with TFRecordExporter(tfrecord_dir, resolution_log2, total_imgs, shards) as tfr:
            for _ in tqdm(range(total_imgs)):
                label, *imgs = sess.run([_label, *_imgs])
                if write:
                    tfr.add_image(label, imgs, [])
            # The exporter only counts images that were actually added, so
            # the count can only be checked when writing (same guard as in
            # dump_celebahq); otherwise every dry run would fail here.
            if write:
                assert tfr.cur_images == total_imgs, (
                    tfr.cur_images, total_imgs)
251 |
252 |
class TFRecordExporter:
    """Write a multi-resolution image dataset as sharded tfrecords files.

    For every level of detail `lod` in range(resolution_log2 - 1) the
    exporter opens `shards` uncompressed TFRecord writers named
    <tfrecord_dir>/<basename>-rNN-s-SSSS-of-TTTT.tfrecords, where NN is
    resolution_log2 - lod, SSSS the shard index and TTTT the shard count.
    Each image index is assigned to a shard up front through a random
    permutation, drawn independently per level.

    Usable as a context manager; leaving the block closes all writers.
    """

    def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10):
        """Create the output directory and open all shard writers.

        Args:
            tfrecord_dir: output directory (created if missing).
            resolution_log2: log2 of the top image resolution.
            expected_images: number of images that will be added; sizes the
                image-to-shard table.
            shards: number of output files per level of detail.
            print_progress: print status messages when true.
            progress_interval: stored on the instance; not used by the
                methods of this class.
        """
        self.tfrecord_dir = tfrecord_dir
        self.tfr_prefix = os.path.join(
            self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.resolution_log2 = resolution_log2
        self.expected_images = expected_images

        self.cur_images = 0
        self.shape = None
        # One (writers, img_to_shard) pair per level of detail.
        self.tfr_writers = []
        self.print_progress = print_progress
        self.progress_interval = progress_interval
        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert (os.path.isdir(self.tfrecord_dir))
        tfr_opt = tf.python_io.TFRecordOptions(
            tf.python_io.TFRecordCompressionType.NONE)
        for lod in range(self.resolution_log2 - 1):
            # Random, near-equal partition of image indices into shards.
            p_shard = np.array_split(
                np.random.permutation(expected_images), shards)
            # Builtin int: the np.int alias no longer exists in current NumPy.
            img_to_shard = np.zeros(expected_images, dtype=int)
            writers = []
            for shard in range(shards):
                img_to_shard[p_shard[shard]] = shard
                tfr_file = self.tfr_prefix + \
                    '-r%02d-s-%04d-of-%04d.tfrecords' % (
                        self.resolution_log2 - lod, shard, shards)
                writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
            counts = np.unique(img_to_shard, return_counts=True)[1]
            # Every shard must have received at least one image.
            assert len(counts) == shards
            print("Smallest and largest shards have size",
                  np.min(counts), np.max(counts))
            self.tfr_writers.append((writers, img_to_shard))

    def close(self):
        """Close every shard writer and report how many images were added."""
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for (writers, _) in self.tfr_writers:
            for writer in writers:
                writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def add_image(self, label, imgs, attr):
        """Append one image, given at every level of detail, to its shards.

        Args:
            label: int label stored under 'label'.
            imgs: one array per level of detail, largest first; entry `lod`
                must have shape (size, size, 3) with
                size = 2 ** (resolution_log2 - lod).
            attr: iterable of ints stored under 'attr'.
        """
        assert len(imgs) == len(self.tfr_writers)
        for lod, (writers, img_to_shard) in enumerate(self.tfr_writers):
            quant = imgs[lod]
            size = 2 ** (self.resolution_log2 - lod)
            assert quant.shape == (size, size, 3), quant.shape
            ex = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'shape': _int64_feature(quant.shape),
                        # tobytes() replaces the deprecated tostring() alias.
                        'data': _bytes_feature(quant.tobytes()),
                        'label': _int64_feature(label),
                        'attr': _int64_feature(attr)
                    }
                )
            )
            writers[img_to_shard[self.cur_images]].write(
                ex.SerializeToString())
        self.cur_images += 1

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
336 |
337 |
if __name__ == "__main__":
    import argparse

    # Command line: --data_dir and --tfrecord_dir are mandatory, --max_res
    # defaults to 256, and nothing is written unless --write is given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--max_res", type=int, default=256, help="Image size")
    parser.add_argument("--tfrecord_dir", type=str,
                        required=True, help='place to dump')
    parser.add_argument("--write", action='store_true',
                        help="Whether to write")
    hps = parser.parse_args()  # So error if typo

    # ImageNet conversion is disabled; swap dump_celebahq for dump_imagenet
    # in the loop below to run it instead.
    for split_name in ('validation', 'train'):
        dump_celebahq(hps.data_dir, hps.tfrecord_dir,
                      hps.max_res, split_name, hps.write)
354 |
--------------------------------------------------------------------------------
/demo/web/load-image.all.min.js:
--------------------------------------------------------------------------------
1 | !function(e){"use strict";function t(e,i,a){var o,n=document.createElement("img");return n.onerror=function(o){return t.onerror(n,o,e,i,a)},n.onload=function(o){return t.onload(n,o,e,i,a)},"string"==typeof e?(t.fetchBlob(e,function(i){i?(e=i,o=t.createObjectURL(e)):(o=e,a&&a.crossOrigin&&(n.crossOrigin=a.crossOrigin)),n.src=o},a),n):t.isInstanceOf("Blob",e)||t.isInstanceOf("File",e)?(o=n._objectURL=t.createObjectURL(e))?(n.src=o,n):t.readFile(e,function(e){var t=e.target;t&&t.result?n.src=t.result:i&&i(e)}):void 0}function i(e,i){!e._objectURL||i&&i.noRevoke||(t.revokeObjectURL(e._objectURL),delete e._objectURL)}var a=e.createObjectURL&&e||e.URL&&URL.revokeObjectURL&&URL||e.webkitURL&&webkitURL;t.fetchBlob=function(e,t,i){t()},t.isInstanceOf=function(e,t){return Object.prototype.toString.call(t)==="[object "+e+"]"},t.transform=function(e,t,i,a,o){i(e,o)},t.onerror=function(e,t,a,o,n){i(e,n),o&&o.call(e,t)},t.onload=function(e,a,o,n,r){i(e,r),n&&t.transform(e,r,n,o,{})},t.createObjectURL=function(e){return!!a&&a.createObjectURL(e)},t.revokeObjectURL=function(e){return!!a&&a.revokeObjectURL(e)},t.readFile=function(t,i,a){if(e.FileReader){var o=new FileReader;if(o.onload=o.onerror=i,a=a||"readAsDataURL",o[a])return o[a](t),o}return!1},"function"==typeof define&&define.amd?define(function(){return t}):"object"==typeof module&&module.exports?module.exports=t:e.loadImage=t}("undefined"!=typeof window&&window||this),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t=e.transform;e.transform=function(i,a,o,n,r){t.call(e,e.scale(i,a,r),a,o,n,r)},e.transformCoordinates=function(){},e.getTransformedOptions=function(e,t){var i,a,o,n,r=t.aspectRatio;if(!r)return t;i={};for(a in t)t.hasOwnProperty(a)&&(i[a]=t[a]);return 
i.crop=!0,o=e.naturalWidth||e.width,n=e.naturalHeight||e.height,o/n>r?(i.maxWidth=n*r,i.maxHeight=n):(i.maxWidth=o,i.maxHeight=o/r),i},e.renderImageToCanvas=function(e,t,i,a,o,n,r,s,l,d){return e.getContext("2d").drawImage(t,i,a,o,n,r,s,l,d),e},e.hasCanvasOption=function(e){return e.canvas||e.crop||!!e.aspectRatio},e.scale=function(t,i,a){function o(){var e=Math.max((l||v)/v,(d||P)/P);e>1&&(v*=e,P*=e)}function n(){var e=Math.min((r||v)/v,(s||P)/P);e<1&&(v*=e,P*=e)}i=i||{};var r,s,l,d,c,u,f,g,h,m,p,S=document.createElement("canvas"),b=t.getContext||e.hasCanvasOption(i)&&S.getContext,y=t.naturalWidth||t.width,x=t.naturalHeight||t.height,v=y,P=x;if(b&&(f=(i=e.getTransformedOptions(t,i,a)).left||0,g=i.top||0,i.sourceWidth?(c=i.sourceWidth,void 0!==i.right&&void 0===i.left&&(f=y-c-i.right)):c=y-f-(i.right||0),i.sourceHeight?(u=i.sourceHeight,void 0!==i.bottom&&void 0===i.top&&(g=x-u-i.bottom)):u=x-g-(i.bottom||0),v=c,P=u),r=i.maxWidth,s=i.maxHeight,l=i.minWidth,d=i.minHeight,b&&r&&s&&i.crop?(v=r,P=s,(p=c/u-r/s)<0?(u=s*c/r,void 0===i.top&&void 0===i.bottom&&(g=(x-u)/2)):p>0&&(c=r*u/s,void 0===i.left&&void 0===i.right&&(f=(y-c)/2))):((i.contain||i.cover)&&(l=r=r||l,d=s=s||d),i.cover?(n(),o()):(o(),n())),b){if((h=i.pixelRatio)>1&&(S.style.width=v+"px",S.style.height=P+"px",v*=h,P*=h,S.getContext("2d").scale(h,h)),(m=i.downsamplingRatio)>0&&m<1&&vv;)S.width=c*m,S.height=u*m,e.renderImageToCanvas(S,t,f,g,c,u,0,0,S.width,S.height),f=0,g=0,c=S.width,u=S.height,(t=document.createElement("canvas")).width=c,t.height=u,e.renderImageToCanvas(t,S,0,0,c,u,0,0,c,u);return S.width=v,S.height=P,e.transformCoordinates(S,i),e.renderImageToCanvas(S,t,f,g,c,u,0,0,v,P)}return t.width=v,t.height=P,t}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image"],e):e("object"==typeof module&&module.exports?require("./load-image"):window.loadImage)}(function(e){"use strict";var t="undefined"!=typeof 
Blob&&(Blob.prototype.slice||Blob.prototype.webkitSlice||Blob.prototype.mozSlice);e.blobSlice=t&&function(){return(this.slice||this.webkitSlice||this.mozSlice).apply(this,arguments)},e.metaDataParsers={jpeg:{65505:[]}},e.parseMetaData=function(t,i,a,o){a=a||{},o=o||{};var n=this,r=a.maxMetaDataSize||262144;!!("undefined"!=typeof DataView&&t&&t.size>=12&&"image/jpeg"===t.type&&e.blobSlice)&&e.readFile(e.blobSlice.call(t,0,r),function(t){if(t.target.error)return console.log(t.target.error),void i(o);var r,s,l,d,c=t.target.result,u=new DataView(c),f=2,g=u.byteLength-4,h=f;if(65496===u.getUint16(0)){for(;f=65504&&r<=65519||65534===r);){if(s=u.getUint16(f+2)+2,f+s>u.byteLength){console.log("Invalid meta data: Invalid segment size.");break}if(l=e.metaDataParsers.jpeg[r])for(d=0;d6&&(c.slice?o.imageHead=c.slice(0,h):o.imageHead=new Uint8Array(c).subarray(0,h))}else console.log("Invalid JPEG file: Missing JPEG marker.");i(o)},"readAsArrayBuffer")||i(o)},e.hasMetaOption=function(e){return e&&e.meta};var i=e.transform;e.transform=function(t,a,o,n,r){e.hasMetaOption(a)?e.parseMetaData(n,function(r){i.call(e,t,a,o,n,r)},a,r):i.apply(e,arguments)}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";"undefined"!=typeof fetch&&"undefined"!=typeof Request&&(e.fetchBlob=function(t,i,a){if(e.hasMetaOption(a))return fetch(new Request(t,a)).then(function(e){return e.blob()}).then(i).catch(function(e){console.log(e),i()});i()})}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";e.ExifMap=function(){return 
this},e.ExifMap.prototype.map={Orientation:274},e.ExifMap.prototype.get=function(e){return this[e]||this[this.map[e]]},e.getExifThumbnail=function(t,i,a){if(a&&!(i+a>t.byteLength))return e.createObjectURL(new Blob([t.buffer.slice(i,i+a)]));console.log("Invalid Exif data: Invalid thumbnail data.")},e.exifTagTypes={1:{getValue:function(e,t){return e.getUint8(t)},size:1},2:{getValue:function(e,t){return String.fromCharCode(e.getUint8(t))},size:1,ascii:!0},3:{getValue:function(e,t,i){return e.getUint16(t,i)},size:2},4:{getValue:function(e,t,i){return e.getUint32(t,i)},size:4},5:{getValue:function(e,t,i){return e.getUint32(t,i)/e.getUint32(t+4,i)},size:8},9:{getValue:function(e,t,i){return e.getInt32(t,i)},size:4},10:{getValue:function(e,t,i){return e.getInt32(t,i)/e.getInt32(t+4,i)},size:8}},e.exifTagTypes[7]=e.exifTagTypes[1],e.getExifValue=function(t,i,a,o,n,r){var s,l,d,c,u,f,g=e.exifTagTypes[o];if(g){if(s=g.size*n,!((l=s>4?i+t.getUint32(a+8,r):a+8)+s>t.byteLength)){if(1===n)return g.getValue(t,l,r);for(d=[],c=0;ce.byteLength)console.log("Invalid Exif data: Invalid directory offset.");else{if(n=e.getUint16(i,a),!((r=i+2+12*n)+4>e.byteLength)){for(s=0;st.byteLength)console.log("Invalid Exif data: Invalid segment size.");else if(0===t.getUint16(i+8)){switch(t.getUint16(d)){case 18761:r=!0;break;case 19789:r=!1;break;default:return void console.log("Invalid Exif data: Invalid byte alignment marker.")}42===t.getUint16(d+2,r)?(s=t.getUint32(d+4,r),o.exif=new e.ExifMap,(s=e.parseExifTags(t,d,d+s,r,o))&&!n.disableExifThumbnail&&(l={exif:{}},s=e.parseExifTags(t,d,d+s,r,l),l.exif[513]&&(o.exif.Thumbnail=e.getExifThumbnail(t,d+l.exif[513],l.exif[514]))),o.exif[34665]&&!n.disableExifSub&&e.parseExifTags(t,d,d+o.exif[34665],r,o),o.exif[34853]&&!n.disableExifGps&&e.parseExifTags(t,d,d+o.exif[34853],r,o)):console.log("Invalid Exif data: Missing TIFF marker.")}else console.log("Invalid Exif data: Missing byte alignment 
offset.")}},e.metaDataParsers.jpeg[65505].push(e.parseExifData)}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-exif"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-exif")):e(window.loadImage)}(function(e){"use strict";e.ExifMap.prototype.tags={256:"ImageWidth",257:"ImageHeight",34665:"ExifIFDPointer",34853:"GPSInfoIFDPointer",40965:"InteroperabilityIFDPointer",258:"BitsPerSample",259:"Compression",262:"PhotometricInterpretation",274:"Orientation",277:"SamplesPerPixel",284:"PlanarConfiguration",530:"YCbCrSubSampling",531:"YCbCrPositioning",282:"XResolution",283:"YResolution",296:"ResolutionUnit",273:"StripOffsets",278:"RowsPerStrip",279:"StripByteCounts",513:"JPEGInterchangeFormat",514:"JPEGInterchangeFormatLength",301:"TransferFunction",318:"WhitePoint",319:"PrimaryChromaticities",529:"YCbCrCoefficients",532:"ReferenceBlackWhite",306:"DateTime",270:"ImageDescription",271:"Make",272:"Model",305:"Software",315:"Artist",33432:"Copyright",36864:"ExifVersion",40960:"FlashpixVersion",40961:"ColorSpace",40962:"PixelXDimension",40963:"PixelYDimension",42240:"Gamma",37121:"ComponentsConfiguration",37122:"CompressedBitsPerPixel",37500:"MakerNote",37510:"UserComment",40964:"RelatedSoundFile",36867:"DateTimeOriginal",36868:"DateTimeDigitized",37520:"SubSecTime",37521:"SubSecTimeOriginal",37522:"SubSecTimeDigitized",33434:"ExposureTime",33437:"FNumber",34850:"ExposureProgram",34852:"SpectralSensitivity",34855:"PhotographicSensitivity",34856:"OECF",34864:"SensitivityType",34865:"StandardOutputSensitivity",34866:"RecommendedExposureIndex",34867:"ISOSpeed",34868:"ISOSpeedLatitudeyyy",34869:"ISOSpeedLatitudezzz",37377:"ShutterSpeedValue",37378:"ApertureValue",37379:"BrightnessValue",37380:"ExposureBias",37381:"MaxApertureValue",37382:"SubjectDistance",37383:"MeteringMode",37384:"LightSource",37385:"Flash",37396:"SubjectArea",37386:"FocalLength",41483:"FlashEnergy",41484:"SpatialFreque
ncyResponse",41486:"FocalPlaneXResolution",41487:"FocalPlaneYResolution",41488:"FocalPlaneResolutionUnit",41492:"SubjectLocation",41493:"ExposureIndex",41495:"SensingMethod",41728:"FileSource",41729:"SceneType",41730:"CFAPattern",41985:"CustomRendered",41986:"ExposureMode",41987:"WhiteBalance",41988:"DigitalZoomRatio",41989:"FocalLengthIn35mmFilm",41990:"SceneCaptureType",41991:"GainControl",41992:"Contrast",41993:"Saturation",41994:"Sharpness",41995:"DeviceSettingDescription",41996:"SubjectDistanceRange",42016:"ImageUniqueID",42032:"CameraOwnerName",42033:"BodySerialNumber",42034:"LensSpecification",42035:"LensMake",42036:"LensModel",42037:"LensSerialNumber",0:"GPSVersionID",1:"GPSLatitudeRef",2:"GPSLatitude",3:"GPSLongitudeRef",4:"GPSLongitude",5:"GPSAltitudeRef",6:"GPSAltitude",7:"GPSTimeStamp",8:"GPSSatellites",9:"GPSStatus",10:"GPSMeasureMode",11:"GPSDOP",12:"GPSSpeedRef",13:"GPSSpeed",14:"GPSTrackRef",15:"GPSTrack",16:"GPSImgDirectionRef",17:"GPSImgDirection",18:"GPSMapDatum",19:"GPSDestLatitudeRef",20:"GPSDestLatitude",21:"GPSDestLongitudeRef",22:"GPSDestLongitude",23:"GPSDestBearingRef",24:"GPSDestBearing",25:"GPSDestDistanceRef",26:"GPSDestDistance",27:"GPSProcessingMethod",28:"GPSAreaInformation",29:"GPSDateStamp",30:"GPSDifferential",31:"GPSHPositioningError"},e.ExifMap.prototype.stringValues={ExposureProgram:{0:"Undefined",1:"Manual",2:"Normal program",3:"Aperture priority",4:"Shutter priority",5:"Creative program",6:"Action program",7:"Portrait mode",8:"Landscape mode"},MeteringMode:{0:"Unknown",1:"Average",2:"CenterWeightedAverage",3:"Spot",4:"MultiSpot",5:"Pattern",6:"Partial",255:"Other"},LightSource:{0:"Unknown",1:"Daylight",2:"Fluorescent",3:"Tungsten (incandescent light)",4:"Flash",9:"Fine weather",10:"Cloudy weather",11:"Shade",12:"Daylight fluorescent (D 5700 - 7100K)",13:"Day white fluorescent (N 4600 - 5400K)",14:"Cool white fluorescent (W 3900 - 4500K)",15:"White fluorescent (WW 3200 - 3700K)",17:"Standard light A",18:"Standard light 
B",19:"Standard light C",20:"D55",21:"D65",22:"D75",23:"D50",24:"ISO studio tungsten",255:"Other"},Flash:{0:"Flash did not fire",1:"Flash fired",5:"Strobe return light not detected",7:"Strobe return light detected",9:"Flash fired, compulsory flash mode",13:"Flash fired, compulsory flash mode, return light not detected",15:"Flash fired, compulsory flash mode, return light detected",16:"Flash did not fire, compulsory flash mode",24:"Flash did not fire, auto mode",25:"Flash fired, auto mode",29:"Flash fired, auto mode, return light not detected",31:"Flash fired, auto mode, return light detected",32:"No flash function",65:"Flash fired, red-eye reduction mode",69:"Flash fired, red-eye reduction mode, return light not detected",71:"Flash fired, red-eye reduction mode, return light detected",73:"Flash fired, compulsory flash mode, red-eye reduction mode",77:"Flash fired, compulsory flash mode, red-eye reduction mode, return light not detected",79:"Flash fired, compulsory flash mode, red-eye reduction mode, return light detected",89:"Flash fired, auto mode, red-eye reduction mode",93:"Flash fired, auto mode, return light not detected, red-eye reduction mode",95:"Flash fired, auto mode, return light detected, red-eye reduction mode"},SensingMethod:{1:"Undefined",2:"One-chip color area sensor",3:"Two-chip color area sensor",4:"Three-chip color area sensor",5:"Color sequential area sensor",7:"Trilinear sensor",8:"Color sequential linear sensor"},SceneCaptureType:{0:"Standard",1:"Landscape",2:"Portrait",3:"Night scene"},SceneType:{1:"Directly photographed"},CustomRendered:{0:"Normal process",1:"Custom process"},WhiteBalance:{0:"Auto white balance",1:"Manual white balance"},GainControl:{0:"None",1:"Low gain up",2:"High gain up",3:"Low gain down",4:"High gain down"},Contrast:{0:"Normal",1:"Soft",2:"Hard"},Saturation:{0:"Normal",1:"Low saturation",2:"High saturation"},Sharpness:{0:"Normal",1:"Soft",2:"Hard"},SubjectDistanceRange:{0:"Unknown",1:"Macro",2:"Close view",3:"Distant 
view"},FileSource:{3:"DSC"},ComponentsConfiguration:{0:"",1:"Y",2:"Cb",3:"Cr",4:"R",5:"G",6:"B"},Orientation:{1:"top-left",2:"top-right",3:"bottom-right",4:"bottom-left",5:"left-top",6:"right-top",7:"right-bottom",8:"left-bottom"}},e.ExifMap.prototype.getText=function(e){var t=this.get(e);switch(e){case"LightSource":case"Flash":case"MeteringMode":case"ExposureProgram":case"SensingMethod":case"SceneCaptureType":case"SceneType":case"CustomRendered":case"WhiteBalance":case"GainControl":case"Contrast":case"Saturation":case"Sharpness":case"SubjectDistanceRange":case"FileSource":case"Orientation":return this.stringValues[e][t];case"ExifVersion":case"FlashpixVersion":if(!t)return;return String.fromCharCode(t[0],t[1],t[2],t[3]);case"ComponentsConfiguration":if(!t)return;return this.stringValues[e][t[0]]+this.stringValues[e][t[1]]+this.stringValues[e][t[2]]+this.stringValues[e][t[3]];case"GPSVersionID":if(!t)return;return t[0]+"."+t[1]+"."+t[2]+"."+t[3]}return String(t)},function(e){var t,i=e.tags,a=e.map;for(t in i)i.hasOwnProperty(t)&&(a[i[t]]=t)}(e.ExifMap.prototype),e.ExifMap.prototype.getAll=function(){var e,t,i={};for(e in this)this.hasOwnProperty(e)&&(t=this.tags[e])&&(i[t]=this.getText(t));return i}}),function(e){"use strict";"function"==typeof define&&define.amd?define(["./load-image","./load-image-scale","./load-image-meta"],e):"object"==typeof module&&module.exports?e(require("./load-image"),require("./load-image-scale"),require("./load-image-meta")):e(window.loadImage)}(function(e){"use strict";var t=e.hasCanvasOption,i=e.hasMetaOption,a=e.transformCoordinates,o=e.getTransformedOptions;e.hasCanvasOption=function(i){return!!i.orientation||t.call(e,i)},e.hasMetaOption=function(t){return t&&!0===t.orientation||i.call(e,t)},e.transformCoordinates=function(t,i){a.call(e,t,i);var o=t.getContext("2d"),n=t.width,r=t.height,s=t.style.width,l=t.style.height,d=i.orientation;if(d&&!(d>8))switch(d>4&&(t.width=r,t.height=n,t.style.width=l,t.style.height=s),d){case 
2:o.translate(n,0),o.scale(-1,1);break;case 3:o.translate(n,r),o.rotate(Math.PI);break;case 4:o.translate(0,r),o.scale(1,-1);break;case 5:o.rotate(.5*Math.PI),o.scale(1,-1);break;case 6:o.rotate(.5*Math.PI),o.translate(0,-r);break;case 7:o.rotate(.5*Math.PI),o.translate(n,-r),o.scale(-1,1);break;case 8:o.rotate(-.5*Math.PI),o.translate(-n,0)}},e.getTransformedOptions=function(t,i,a){var n,r,s=o.call(e,t,i),l=s.orientation;if(!0===l&&a&&a.exif&&(l=a.exif.get("Orientation")),!l||l>8||1===l)return s;n={};for(r in s)s.hasOwnProperty(r)&&(n[r]=s[r]);switch(n.orientation=l,l){case 2:n.left=s.right,n.right=s.left;break;case 3:n.left=s.right,n.top=s.bottom,n.right=s.left,n.bottom=s.top;break;case 4:n.top=s.bottom,n.bottom=s.top;break;case 5:n.left=s.top,n.top=s.left,n.right=s.bottom,n.bottom=s.right;break;case 6:n.left=s.top,n.top=s.right,n.right=s.bottom,n.bottom=s.left;break;case 7:n.left=s.bottom,n.top=s.right,n.right=s.top,n.bottom=s.left;break;case 8:n.left=s.bottom,n.top=s.left,n.right=s.top,n.bottom=s.right}return n.orientation>4&&(n.maxWidth=s.maxHeight,n.maxHeight=s.maxWidth,n.minWidth=s.minHeight,n.minHeight=s.minWidth,n.sourceWidth=s.sourceHeight,n.sourceHeight=s.sourceWidth),n}});
2 | //# sourceMappingURL=load-image.all.min.js.map
3 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Modified Horovod MNIST example
4 |
5 | import os
6 | import sys
7 | import time
8 |
9 | import horovod.tensorflow as hvd
10 | import numpy as np
11 | import tensorflow as tf
12 | import graphics
13 | from utils import ResultLogger
14 |
15 | learn = tf.contrib.learn
16 |
17 | # Suppress verbose warnings
18 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
19 |
20 |
def _print(*args, **kwargs):
    """Forward to the built-in ``print`` on Horovod rank 0 only.

    All other ranks return without printing, so multi-process runs do not
    emit duplicate lines.
    """
    if hvd.rank() != 0:
        return
    print(*args, **kwargs)
24 |
25 |
def init_visualizations(hps, model, logdir):
    """Build the sample-drawing callback used during training.

    Args:
        hps: hyper-parameter namespace; reads ``local_batch_train``,
            ``image_size`` and ``n_y``.
        model: object exposing ``sample(y, eps)``.
        logdir: path prefix (expected to end with a separator) that the
            PNG file names are appended to.

    Returns:
        ``draw_samples(epoch)``, which on rank 0 writes one raster image per
        entry of the temperature list and does nothing on other ranks.
    """

    def sample_batch(y, eps):
        # Sample in chunks of the training batch size and stitch the chunks
        # back together along the first axis.
        n_batch = hps.local_batch_train
        n_chunks = int(np.ceil(len(eps) / n_batch))
        xs = [
            model.sample(y[i * n_batch:(i + 1) * n_batch],
                         eps[i * n_batch:(i + 1) * n_batch])
            for i in range(n_chunks)
        ]
        return np.concatenate(xs)

    def draw_samples(epoch):
        if hvd.rank() != 0:
            return

        # Smaller images get a 10x10 grid, larger ones a 4x4 grid.
        rows = 10 if hps.image_size <= 64 else 4
        cols = rows
        n_batch = rows * cols
        # Labels cycle 0..cols-1 across each row, wrapped into [0, n_y).
        y = np.asarray([_y % hps.n_y for _y in (
            list(range(cols)) * rows)], dtype='int32')

        # previously: 0., .25, .5, .625, .75, .875, 1.
        temperatures = [0., .25, .5, .6, .7, .8, .9, 1.]

        # One full grid per temperature; the list above is the single source
        # of truth, so edits to it can no longer drift from the sampling calls.
        x_samples = [sample_batch(y, [temp] * n_batch)
                     for temp in temperatures]

        for i, x_sample in enumerate(x_samples):
            x_sample = np.reshape(
                x_sample, (n_batch, hps.image_size, hps.image_size, 3))
            graphics.save_raster(x_sample, logdir +
                                 'epoch_{}_sample_{}.png'.format(epoch, i))

    return draw_samples
67 |
68 | # ===
69 | # Code for getting data
70 | # ===
def get_data(hps, sess):
    """Fill in dataset-dependent hyper-parameters and build the data iterators.

    Mutates ``hps`` in place: resolves ``image_size``, ``n_test`` and
    ``data_dir`` when they hold their "unset" sentinels (``-1`` / ``""``),
    and sets ``n_y``, ``rnd_crop``, ``local_batch_train``,
    ``local_batch_test``, ``local_batch_init`` and ``direct_iterator``.

    Args:
        hps: hyper-parameter namespace (see the argparse definitions).
        sess: session object forwarded to the TFRecord-based loader.

    Returns:
        Tuple ``(train_iterator, test_iterator, data_init)`` as produced by
        the selected loader module.

    Raises:
        KeyError: if ``hps.problem`` is not one of the known datasets.
        ValueError: if the rescaled training batch size is smaller than 1.
        Exception: if ``hps.problem`` matches no loader.
    """
    if hps.image_size == -1:
        hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64,
                          'imagenet': 256, 'celeba': 256, 'lsun_realnvp': 64, 'lsun': 256}[hps.problem]
    if hps.n_test == -1:
        hps.n_test = {'mnist': 10000, 'cifar10': 10000, 'imagenet-oord': 50000, 'imagenet': 50000,
                      'celeba': 3000, 'lsun_realnvp': 300*hvd.size(), 'lsun': 300*hvd.size()}[hps.problem]
    hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000,
               'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1, 'lsun': 1}[hps.problem]
    if hps.data_dir == "":
        hps.data_dir = {'mnist': None, 'cifar10': None, 'imagenet-oord': '/mnt/host/imagenet-oord-tfr', 'imagenet': '/mnt/host/imagenet-tfr',
                        'celeba': '/mnt/host/celeba-reshard-tfr', 'lsun_realnvp': '/mnt/host/lsun_realnvp', 'lsun': '/mnt/host/lsun'}[hps.problem]

    # Only the RealNVP-style LSUN variant uses random crops.
    hps.rnd_crop = hps.problem == 'lsun_realnvp'

    if hps.category:
        hps.data_dir += ('/%s' % hps.category)

    # Use anchor_size to rescale batch size based on image_size
    s = hps.anchor_size
    hps.local_batch_train = hps.n_batch_train * \
        s * s // (hps.image_size * hps.image_size)

    # Round down to the closest divisor of 50. For train batches of
    # 1, 2, 4, 8, 16, 32 and 64 this yields 1, 2, 2, 5, 10, 25 and 50, the
    # same values as the former fixed lookup table, and it also accepts
    # batch sizes that table did not list.
    test_batch_candidates = [d for d in (1, 2, 5, 10, 25, 50)
                             if d <= hps.local_batch_train]
    if not test_batch_candidates:
        raise ValueError(
            "local_batch_train is {}; n_batch_train={} is too small for "
            "image_size={} with anchor_size={}".format(
                hps.local_batch_train, hps.n_batch_train, hps.image_size, s))
    hps.local_batch_test = max(test_batch_candidates)

    hps.local_batch_init = hps.n_batch_init * \
        s * s // (hps.image_size * hps.image_size)

    # Deliberately printed on every rank (not via _print).
    print("Rank {} Batch sizes Train {} Test {} Init {}".format(
        hvd.rank(), hps.local_batch_train, hps.local_batch_test, hps.local_batch_init))

    if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']:
        hps.direct_iterator = True
        import data_loaders.get_data as v
        train_iterator, test_iterator, data_init = \
            v.get_data(sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap, hps.local_batch_train,
                       hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop)

    elif hps.problem in ['mnist', 'cifar10']:
        hps.direct_iterator = False
        import data_loaders.get_mnist_cifar as v
        train_iterator, test_iterator, data_init = \
            v.get_data(hps.problem, hvd.size(), hvd.rank(), hps.dal, hps.local_batch_train,
                       hps.local_batch_test, hps.local_batch_init, hps.image_size)

    else:
        raise Exception("Unsupported problem: {!r}".format(hps.problem))

    return train_iterator, test_iterator, data_init
122 |
123 |
def process_results(results):
    """Label a 4-entry result vector and format each entry to 4 decimals.

    ``results`` must have ``shape[0] == 4``; entries are mapped, in order,
    to the keys ``loss``, ``bits_x``, ``bits_y`` and ``pred_loss``.
    Returns a dict of name -> formatted string.
    """
    names = ['loss', 'bits_x', 'bits_y', 'pred_loss']
    assert len(names) == results.shape[0]
    return {name: "{:.4f}".format(results[idx])
            for idx, name in enumerate(names)}
131 |
132 |
def main(hps):
    """Entry point: set up Horovod, data and model, then train or run inference.

    Args:
        hps: parsed hyper-parameter namespace. ``hps.inference`` selects
            between ``infer`` and ``train``; ``hps.logdir`` is where logs
            and samples are written.
    """
    # Initialize Horovod.
    hvd.init()

    # Create tensorflow session
    sess = tensorflow_session()

    # Seed TF and NumPy differently on each rank, derived from hps.seed.
    tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed)
    np.random.seed(hvd.rank() + hvd.size() * hps.seed)

    # Get data and set train_its and valid_its
    train_iterator, test_iterator, data_init = get_data(hps, sess)
    hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)

    # Create log dir. Every rank executes this, so creation must tolerate
    # the directory already existing (another rank may win the race), and
    # nested paths should work too.
    logdir = os.path.abspath(hps.logdir) + "/"
    os.makedirs(logdir, exist_ok=True)

    # Create model (imported lazily, under an alias so the module and the
    # model instance do not share one name).
    import model as model_module
    model = model_module.model(
        sess, hps, train_iterator, test_iterator, data_init)

    # Initialize visualization functions
    visualise = init_visualizations(hps, model, logdir)

    if not hps.inference:
        # Perform training
        train(sess, model, hps, logdir, visualise)
    else:
        infer(sess, model, hps, test_iterator)
167 |
def infer(sess, model, hps, iterator):
    """Encode and reconstruct the test set, saving the results as ``.npy`` files.

    Example of using the model in inference mode (load a saved model via
    ``hps.restore_path``). ``x, y`` could be provided from files instead of
    the dataset iterator. If the model is unconditional, always pass
    ``y = np.zeros([bs], dtype=np.int32)``.

    Args:
        sess: session used to evaluate ``iterator`` when
            ``hps.direct_iterator`` is true.
        model: object exposing ``encode(x, y)`` and ``decode(y, z)``.
        hps: hyper-parameters; reads ``direct_iterator``, ``full_test_its``
            and, if present, ``logdir``.
        iterator: either an object with ``get_next()`` (direct iterator) or
            a zero-argument callable returning ``(x, y)``.

    Returns:
        The list of per-batch latents ``z`` (not concatenated).

    Side effects:
        Writes ``x.npy`` (reconstructions) and ``z.npy`` (latents) into
        ``hps.logdir``; falls back to ``logs`` when ``hps`` has no
        ``logdir`` attribute.
    """
    if hps.direct_iterator:
        iterator = iterator.get_next()

    xs = []
    zs = []
    for _ in range(hps.full_test_its):
        if hps.direct_iterator:
            # replace with x, y, attr if you're getting CelebA attributes, also modify get_data
            x, y = sess.run(iterator)
        else:
            x, y = iterator()

        z = model.encode(x, y)
        x = model.decode(y, z)
        xs.append(x)
        zs.append(z)

    x = np.concatenate(xs, axis=0)
    z = np.concatenate(zs, axis=0)

    # Honour --logdir instead of a hard-coded relative "logs/" directory, and
    # make sure it exists so a long inference run cannot fail at save time.
    out_dir = getattr(hps, 'logdir', 'logs')
    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, 'x.npy'), x)
    np.save(os.path.join(out_dir, 'z.npy'), z)
    return zs
194 |
195 |
def train(sess, model, hps, logdir, visualise):
    """Run the training loop with periodic validation, checkpointing and sampling.

    Args:
        sess: session whose graph is finalized before training starts.
        model: object exposing ``train(lr)``, ``test()`` and ``save(path)``.
        hps: hyper-parameters; reads ``epochs``, ``train_its``,
            ``full_test_its``, ``lr``, ``n_train``, ``epochs_warmup``,
            ``n_batch_train``, ``local_batch_train``, ``verbose``,
            ``epochs_full_valid`` and ``epochs_full_sample``.
        logdir: path prefix (ending in a separator) for log files and the
            best-loss checkpoint.
        visualise: callable ``visualise(epoch)`` that writes sample images.

    Logging, checkpointing and console output happen on rank 0 only.
    """
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    # Number of (anchor-resolution) images over which the learning rate is
    # ramped up. A non-positive value disables warmup instead of dividing
    # by zero.
    warmup_images = hps.n_train * hps.epochs_warmup

    tcurr = time.time()
    # NOTE(review): range(1, hps.epochs) runs hps.epochs - 1 epochs; kept
    # as-is - confirm whether that is intended.
    for epoch in range(1, hps.epochs):

        t = time.time()

        train_results = []
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            if warmup_images > 0:
                lr = hps.lr * min(1., n_processed / warmup_images)
            else:
                lr = hps.lr

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        # Average the per-iteration results over the epoch.
        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results))

        # Report on epochs 1-9, on epochs 10..40 in steps of 10, and on every
        # multiple of epochs_full_valid.
        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint when the validation loss improves.
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                # Wall-clock time since the previous report line.
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)

            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")
287 |
288 | # Get number of training and validation iterations
def get_its(hps):
    """Compute iteration counts for one train epoch, a short test and a full test.

    Returns:
        ``(train_its, test_its, full_test_its)``. The first two are based on
        the anchor batch size ``n_batch_train``; the last on
        ``local_batch_test``, which must evenly divide ``n_test`` across
        all workers.
    """
    world_size = hvd.size()
    is_root = hvd.rank() == 0

    # These run for a fixed amount of time. As anchored batch is smaller,
    # we've actually seen fewer examples
    anchor_global_batch = hps.n_batch_train * world_size
    train_its = int(np.ceil(hps.n_train / anchor_global_batch))
    test_its = int(np.ceil(hps.n_test / anchor_global_batch))
    train_epoch = train_its * anchor_global_batch

    # Do a full validation run
    if is_root:
        print(hps.n_test, hps.local_batch_test, world_size)
    test_global_batch = hps.local_batch_test * world_size
    assert hps.n_test % test_global_batch == 0
    full_test_its = hps.n_test // test_global_batch

    if is_root:
        print("Train epoch size: " + str(train_epoch))
    return train_its, test_its, full_test_its
304 |
305 |
def tensorflow_session():
    """Create tensorflow session with horovod.

    GPU memory is allocated on demand, and the process is pinned to the GPU
    matching its Horovod local rank (one GPU per process).
    """
    session_config = tf.ConfigProto()
    gpu_options = session_config.gpu_options
    gpu_options.allow_growth = True
    gpu_options.visible_device_list = str(hvd.local_rank())
    return tf.Session(config=session_config)
317 |
318 |
if __name__ == "__main__":

    # Exit quietly on Ctrl-C instead of printing a KeyboardInterrupt traceback
    import signal
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action='store_true', help="Verbose mode")
    parser.add_argument("--restore_path", type=str, default='',
                        help="Location of checkpoint to restore")
    parser.add_argument("--inference", action="store_true",
                        help="Use in inference mode")
    parser.add_argument("--logdir", type=str,
                        default='./logs', help="Location to save logs")

    # Dataset hyperparams:
    # The help text lists every dataset key that get_data() knows about.
    parser.add_argument("--problem", type=str, default='cifar10',
                        help="Problem (mnist/cifar10/imagenet-oord/imagenet/celeba/lsun_realnvp/lsun)")
    parser.add_argument("--category", type=str,
                        default='', help="LSUN category")
    parser.add_argument("--data_dir", type=str, default='',
                        help="Location of data")
    parser.add_argument("--dal", type=int, default=1,
                        help="Data augmentation level: 0=None, 1=Standard, 2=Extra")

    # New dataloader params
    parser.add_argument("--fmap", type=int, default=1,
                        help="# Threads for parallel file reading")
    parser.add_argument("--pmap", type=int, default=16,
                        help="# Threads for parallel map")

    # Optimization hyperparams:
    parser.add_argument("--n_train", type=int,
                        default=50000, help="Train epoch size")
    # -1 means "use the per-dataset default" (resolved in get_data).
    parser.add_argument("--n_test", type=int, default=-1,
                        help="Valid epoch size")
    parser.add_argument("--n_batch_train", type=int,
                        default=64, help="Minibatch size")
    parser.add_argument("--n_batch_test", type=int,
                        default=50, help="Minibatch size")
    parser.add_argument("--n_batch_init", type=int, default=256,
                        help="Minibatch size for data-dependent init")
    parser.add_argument("--optimizer", type=str,
                        default="adamax", help="adam or adamax")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="Base learning rate")
    parser.add_argument("--beta1", type=float, default=.9, help="Adam beta1")
    parser.add_argument("--polyak_epochs", type=float, default=1,
                        help="Nr of averaging epochs for Polyak and beta2")
    parser.add_argument("--weight_decay", type=float, default=1.,
                        help="Weight decay. Switched off by default.")
    parser.add_argument("--epochs", type=int, default=1000000,
                        help="Total number of training epochs")
    parser.add_argument("--epochs_warmup", type=int,
                        default=10, help="Warmup epochs")
    parser.add_argument("--epochs_full_valid", type=int,
                        default=50, help="Epochs between valid")
    parser.add_argument("--gradient_checkpointing", type=int,
                        default=1, help="Use memory saving gradients")

    # Model hyperparams:
    # -1 means "use the per-dataset default" (resolved in get_data).
    parser.add_argument("--image_size", type=int,
                        default=-1, help="Image size")
    parser.add_argument("--anchor_size", type=int, default=32,
                        help="Anchor size for deciding batch size")
    parser.add_argument("--width", type=int, default=512,
                        help="Width of hidden layers")
    parser.add_argument("--depth", type=int, default=32,
                        help="Depth of network")
    parser.add_argument("--weight_y", type=float, default=0.00,
                        help="Weight of log p(y|x) in weighted loss")
    parser.add_argument("--n_bits_x", type=int, default=8,
                        help="Number of bits of x")
    parser.add_argument("--n_levels", type=int, default=3,
                        help="Number of levels")

    # Synthesis/Sampling hyperparameters:
    parser.add_argument("--n_sample", type=int, default=1,
                        help="minibatch size for sample")
    parser.add_argument("--epochs_full_sample", type=int,
                        default=50, help="Epochs between full scale sample")

    # Ablation
    parser.add_argument("--learntop", action="store_true",
                        help="Learn spatial prior")
    parser.add_argument("--ycond", action="store_true",
                        help="Use y conditioning")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument("--flow_permutation", type=int, default=2,
                        help="Type of flow. 0=reverse (realnvp), 1=shuffle, 2=invconv (ours)")
    parser.add_argument("--flow_coupling", type=int, default=0,
                        help="Coupling type: 0=additive, 1=affine")

    # parse_args (rather than parse_known_args) so a mistyped flag is an error
    hps = parser.parse_args()
    main(hps)
415 |
--------------------------------------------------------------------------------
/memory_saving_gradients.py:
--------------------------------------------------------------------------------
1 | from toposort import toposort
2 | import contextlib
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow.contrib.graph_editor as ge
6 | import time
7 | import sys
8 | sys.setrecursionlimit(10000)
9 | # refers back to current module if we decide to split helpers out
10 | util = sys.modules[__name__]
11 |
12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated"
13 | setattr(tf.GraphKeys, "VARIABLES", "variables")
14 |
15 | # save original gradients since tf.gradient could be monkey-patched to point
16 | # to our version
17 | from tensorflow.python.ops import gradients as tf_gradients_lib
18 | tf_gradients = tf_gradients_lib.gradients
19 |
20 | MIN_CHECKPOINT_NODE_SIZE = 1024 # use lower value during testing
21 |
22 | # specific versions we can use to do process-wide replacement of tf.gradients
23 |
24 |
def gradients_speed(ys, xs, grad_ys=None, **kwargs):
    """Call ``gradients`` with ``checkpoints='speed'``; other arguments pass through."""
    return gradients(ys, xs, grad_ys=grad_ys, checkpoints='speed', **kwargs)
27 |
28 |
def gradients_memory(ys, xs, grad_ys=None, **kwargs):
    """Call ``gradients`` with ``checkpoints='memory'``; other arguments pass through."""
    return gradients(ys, xs, grad_ys=grad_ys, checkpoints='memory', **kwargs)
31 |
32 |
def gradients_collection(ys, xs, grad_ys=None, **kwargs):
    """Call ``gradients`` with ``checkpoints='collection'``; other arguments pass through."""
    return gradients(ys, xs, grad_ys=grad_ys, checkpoints='collection', **kwargs)
35 |
36 |
37 | def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
38 | '''
39 | Authors: Tim Salimans & Yaroslav Bulatov
40 |
41 | memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
42 | by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)
43 |
44 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
45 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)
46 |
47 | 'checkpoints' can either be
48 | - a list consisting of tensors from the forward pass of the neural net
49 | that we should re-use when calculating the gradients in the backward pass
50 | all other tensors that do not appear in this list will be re-computed
51 | - a string specifying how this list should be determined. currently we support
52 | - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
53 | so checkpointing them maximizes the running speed
54 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
55 | - 'memory': try to minimize the memory usage
56 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
57 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
58 | '''
59 |
60 | # print("Calling memsaving gradients with", checkpoints)
61 | if not isinstance(ys, list):
62 | ys = [ys]
63 | if not isinstance(xs, list):
64 | xs = [xs]
65 |
66 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
67 | inclusive=True)
68 |
69 | debug_print("bwd_ops: %s", bwd_ops)
70 |
71 | # forward ops are all ops that are candidates for recomputation
72 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
73 | inclusive=True,
74 | within_ops=bwd_ops)
75 | debug_print("fwd_ops: %s", fwd_ops)
76 |
77 | # exclude ops with no inputs
78 | fwd_ops = [op for op in fwd_ops if op.inputs]
79 |
80 | # don't recompute xs, remove variables
81 | xs_ops = _to_ops(xs)
82 | fwd_ops = [op for op in fwd_ops if not op in xs_ops]
83 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
84 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
85 | fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
86 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors
87 | ts_all = [t for t in ts_all if '/read' not in t.name]
88 | ts_all = set(ts_all) - set(xs) - set(ys)
89 |
90 | # construct list of tensors to checkpoint during forward pass, if not
91 | # given as input
92 | if type(checkpoints) is not list:
93 | if checkpoints == 'collection':
94 | checkpoints = tf.get_collection('checkpoints')
95 |
96 | elif checkpoints == 'speed':
97 | # checkpoint all expensive ops to maximize running speed
98 | checkpoints = ge.filter_ts_from_regex(
99 | fwd_ops, 'conv2d|Conv|MatMul')
100 |
101 | elif checkpoints == 'memory':
102 |
103 | # remove very small tensors and some weird ops
104 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually
105 | try:
106 | return [int(e if e.value is not None else 64) for e in t]
107 | except:
108 | return [0] # unknown shape
109 | ts_all = [t for t in ts_all if np.prod(
110 | fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
111 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
112 | ts_all = [t for t in ts_all if 'entropy' not in t.name]
113 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
114 | ts_all = [t for t in ts_all if 'Switch' not in t.name]
115 | ts_all = [t for t in ts_all if 'dropout' not in t.name]
116 |
117 | # filter out all tensors that are inputs of the backward graph
118 | with util.capture_ops() as bwd_ops:
119 | tf_gradients(ys, xs, grad_ys, **kwargs)
120 |
121 | bwd_inputs = [t for op in bwd_ops for t in op.inputs]
122 | # list of tensors in forward graph that is in input to bwd graph
123 | ts_filtered = list(set(bwd_inputs).intersection(ts_all))
124 | debug_print("Using tensors %s", ts_filtered)
125 |
126 | # try two slightly different ways of getting bottlenecks tensors
127 | # to checkpoint
128 | for ts in [ts_filtered, ts_all]:
129 |
130 | # get all bottlenecks in the graph
131 | bottleneck_ts = []
132 | for t in ts:
133 | b = set(ge.get_backward_walk_ops(
134 | t.op, inclusive=True, within_ops=fwd_ops))
135 | f = set(ge.get_forward_walk_ops(
136 | t.op, inclusive=False, within_ops=fwd_ops))
137 | # check that there are not shortcuts
138 | b_inp = set(
139 | [inp for op in b for inp in op.inputs]).intersection(ts_all)
140 | f_inp = set(
141 | [inp for op in f for inp in op.inputs]).intersection(ts_all)
142 | if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
143 | bottleneck_ts.append(t) # we have a bottleneck!
144 | else:
145 | debug_print("Rejected bottleneck candidate and ops %s", [
146 | t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
147 |
148 | # success? or try again without filtering?
149 | if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found!
150 | break
151 |
152 | if not bottleneck_ts:
153 | raise Exception(
154 | 'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".')
155 |
156 | # sort the bottlenecks
157 | bottlenecks_sorted_lists = tf_toposort(
158 | bottleneck_ts, within_ops=fwd_ops)
159 | sorted_bottlenecks = [
160 | t for ts in bottlenecks_sorted_lists for t in ts]
161 |
162 | # save an approximately optimal number ~ sqrt(N)
163 | N = len(ts_filtered)
164 | if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
165 | checkpoints = sorted_bottlenecks
166 | else:
167 | step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
168 | checkpoints = sorted_bottlenecks[step::step]
169 |
170 | else:
171 | raise Exception(
172 | '%s is unsupported input for "checkpoints"' % (checkpoints,))
173 |
174 | checkpoints = list(set(checkpoints).intersection(ts_all))
175 |
176 | # at this point automatic selection happened and checkpoints is list of nodes
177 | assert isinstance(checkpoints, list)
178 |
179 | debug_print("Checkpoint nodes used: %s", checkpoints)
180 | # better error handling of special cases
181 | # xs are already handled as checkpoint nodes, so no need to include them
182 | xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
183 | if xs_intersect_checkpoints:
184 | debug_print("Warning, some input nodes are also checkpoint nodes: %s",
185 | xs_intersect_checkpoints)
186 | ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
187 | debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
188 | ys_intersect_checkpoints)
189 | # saving an output node (ys) gives no benefit in memory while creating
190 | # new edge cases, exclude them
191 | if ys_intersect_checkpoints:
192 | debug_print("Warning, some output nodes are also checkpoints nodes: %s",
193 | format_ops(ys_intersect_checkpoints))
194 |
195 | # remove initial and terminal nodes from checkpoints list if present
196 | checkpoints = list(set(checkpoints) - set(ys) - set(xs))
197 |
198 | # check that we have some nodes to checkpoint
199 | if not checkpoints:
200 | raise Exception('no checkpoints nodes found or given as input! ')
201 |
202 | # disconnect dependencies between checkpointed tensors
203 | checkpoints_disconnected = {}
204 | for x in checkpoints:
205 | if x.op and x.op.name is not None:
206 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
207 | else:
208 | grad_node = tf.stop_gradient(x)
209 | checkpoints_disconnected[x] = grad_node
210 |
211 | # partial derivatives to the checkpointed tensors and xs
212 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
213 | stop_at_ts=checkpoints, within_ops=fwd_ops)
214 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
215 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
216 | debug_print("ops_to_copy = %s", ops_to_copy)
217 | debug_print("Processing list %s", ys)
218 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
219 | copied_ops = info._transformed_ops.values()
220 | debug_print("Copied %s to %s", ops_to_copy, copied_ops)
221 | ge.reroute_ts(checkpoints_disconnected.values(),
222 | checkpoints_disconnected.keys(), can_modify=copied_ops)
223 | debug_print("Rewired %s in place of %s restricted to %s",
224 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)
225 |
226 | # get gradients with respect to current boundary + original x's
227 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
228 | boundary = list(checkpoints_disconnected.values())
229 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
230 | debug_print("Got gradients %s", dv)
231 | debug_print("for %s", copied_ys)
232 | debug_print("with respect to %s", boundary+xs)
233 |
234 | inputs_to_do_before = [y.op for y in ys]
235 | if grad_ys is not None:
236 | inputs_to_do_before += grad_ys
237 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
238 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
239 |
240 | # partial derivatives to the checkpointed nodes
241 | # dictionary of "node: backprop" for nodes in the boundary
242 | d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
243 | dv[:len(checkpoints_disconnected)])}
244 | # partial derivatives to xs (usually the params of the neural net)
245 | d_xs = dv[len(checkpoints_disconnected):]
246 |
247 | # incorporate derivatives flowing through the checkpointed nodes
248 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
249 | for ts in checkpoints_sorted_lists[::-1]:
250 | debug_print("Processing list %s", ts)
251 | checkpoints_other = [r for r in checkpoints if r not in ts]
252 | checkpoints_disconnected_other = [
253 | checkpoints_disconnected[r] for r in checkpoints_other]
254 |
255 | # copy part of the graph below current checkpoint node, stopping at
256 | # other checkpoints nodes
257 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[
258 | r.op for r in ts], stop_at_ts=checkpoints_other)
259 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
260 | len(ops_to_copy), fwd_ops, [r.op for r in ts],
261 | checkpoints_other)
262 | debug_print("ops_to_copy = %s", ops_to_copy)
263 | if not ops_to_copy: # we're done!
264 | break
265 | copied_sgv, info = ge.copy_with_input_replacements(
266 | ge.sgv(ops_to_copy), {})
267 | copied_ops = info._transformed_ops.values()
268 | debug_print("Copied %s to %s", ops_to_copy, copied_ops)
269 | ge.reroute_ts(checkpoints_disconnected_other,
270 | checkpoints_other, can_modify=copied_ops)
271 | debug_print("Rewired %s in place of %s restricted to %s",
272 | checkpoints_disconnected_other, checkpoints_other, copied_ops)
273 |
274 | # gradient flowing through the checkpointed node
275 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
276 | substitute_backprops = [d_checkpoints[r] for r in ts]
277 | dv = tf_gradients(boundary,
278 | checkpoints_disconnected_other+xs,
279 | grad_ys=substitute_backprops, **kwargs)
280 | debug_print("Got gradients %s", dv)
281 | debug_print("for %s", boundary)
282 | debug_print("with respect to %s", checkpoints_disconnected_other+xs)
283 | debug_print("with boundary backprop substitutions %s",
284 | substitute_backprops)
285 |
286 | inputs_to_do_before = [d_checkpoints[r].op for r in ts]
287 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
288 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)
289 |
290 | # partial derivatives to the checkpointed nodes
291 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
292 | if dr is not None:
293 | if d_checkpoints[r] is None:
294 | d_checkpoints[r] = dr
295 | else:
296 | d_checkpoints[r] += dr
297 |
298 | # partial derivatives to xs (usually the params of the neural net)
299 | d_xs_new = dv[len(checkpoints_other):]
300 | for j in range(len(xs)):
301 | if d_xs_new[j] is not None:
302 | if d_xs[j] is None:
303 | d_xs[j] = d_xs_new[j]
304 | else:
305 | d_xs[j] += d_xs_new[j]
306 |
307 | return d_xs
308 |
309 |
def tf_toposort(ts, within_ops=None):
    """Group the tensors in `ts` into topologically ordered levels.

    Walks forward from the ops that produce `ts` (restricted to `within_ops`
    when given), records for every output tensor of those ops the set of the
    op's input tensors, and hands that dependency map to `toposort`.

    Args:
      ts: tensors to order.
      within_ops: optional restriction passed to `ge.get_forward_walk_ops`.

    Returns:
      A list of non-empty lists. Each inner list holds the members of `ts`
      found in one group produced by `toposort`; order inside a list is
      arbitrary because it comes from a set intersection.
    """
    seed_ops = [tensor.op for tensor in ts]
    reachable_ops = ge.get_forward_walk_ops(seed_ops, within_ops=within_ops)

    # tensor -> set of tensors it directly depends on
    dependencies = {
        output: set(op.inputs)
        for op in reachable_ops
        for output in op.outputs
    }

    # only keep the tensors from our original list, dropping empty groups
    levels = []
    for level in toposort(dependencies):
        kept = list(set(level).intersection(ts))
        if kept:
            levels.append(kept)
    return levels
328 |
329 |
def fast_backward_ops(within_ops, seed_ops, stop_at_ts):
    """Return the ops reachable backwards from `seed_ops`, inside `within_ops`.

    The backward walk stops at the tensors in `stop_at_ts`; the ops that
    produce those tensors are also removed from the result.

    Returns:
      A list of ops (order unspecified, it is built from a set).
    """
    reachable = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts))
    stop_ops = [tensor.op for tensor in stop_at_ts]
    selected = reachable.intersection(within_ops)
    selected = selected.difference(stop_ops)
    return list(selected)
335 |
336 |
@contextlib.contextmanager
def capture_ops():
    """Context manager that collects the ops created inside its block.

    with capture_ops() as ops:
      # create some ops
    print(ops) # => prints ops created.

    The list is yielded empty and is only filled once the block has exited:
    ops are created under a name scope named after the current time in
    microseconds and are then selected from the default graph by that name.
    """
    scope_name = str(int(time.time() * 10**6))
    captured = []
    with tf.name_scope(scope_name):
        yield captured

    graph = tf.get_default_graph()
    captured.extend(ge.select_ops(scope_name + "/.*", graph=graph))
353 |
354 |
355 | def _to_op(tensor_or_op):
356 | if hasattr(tensor_or_op, "op"):
357 | return tensor_or_op.op
358 | return tensor_or_op
359 |
360 |
def _to_ops(iterable):
    """Map `_to_op` over `iterable`; non-iterable input is returned as-is."""
    if _is_iterable(iterable):
        return [_to_op(item) for item in iterable]
    return iterable
365 |
366 |
367 | def _is_iterable(o):
368 | try:
369 | _ = iter(o)
370 | except Exception:
371 | return False
372 | return True
373 |
374 |
# Module-level switch read by debug_print: messages are printed only when
# this is True.
DEBUG_LOGGING = False
376 |
377 |
def debug_print(s, *args):
    """Print a %-formatted debug message with TensorFlow objects shown by name.

    Does nothing unless the module-level DEBUG_LOGGING flag is truthy. Each
    argument is passed through format_ops before being substituted into `s`.

    Usage:
      debug_print("see tensors %s for %s", tensorlist, [1,2,3])
    """
    if not DEBUG_LOGGING:
        return
    rendered = tuple(format_ops(arg) for arg in args)
    print("DEBUG " + s % rendered)
389 |
390 |
def format_ops(ops, sort_outputs=True):
    """Render ops/tensors for printing.

    An object with a `name` attribute is rendered as that name, anything else
    as `str(obj)`. A non-string object exposing `__iter__` is rendered element
    by element into a list, sorted when `sort_outputs` is true; any other
    object is rendered as a single value.
    """

    def _label(item):
        return item.name if hasattr(item, "name") else str(item)

    is_collection = hasattr(ops, '__iter__') and not isinstance(ops, str)
    if not is_collection:
        return _label(ops)

    labels = [_label(op) for op in ops]
    return sorted(labels) if sort_outputs else labels
402 |
403 |
def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before):
    """Make every op in `wait_to_do_ops` wait for `inputs_to_do_before`.

    For each op, only the entries of `inputs_to_do_before` that are not
    already among its control inputs are added via `ge.add_control_inputs`.
    """
    for op in wait_to_do_ops:
        existing = op.control_inputs
        missing = [
            dep for dep in inputs_to_do_before
            if existing is None or dep not in existing
        ]
        ge.add_control_inputs(op, missing)
408 |
--------------------------------------------------------------------------------
/tfops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.framework.python.ops import add_arg_scope, arg_scope
3 | from tensorflow.contrib.layers import variance_scaling_initializer
4 | import numpy as np
5 | import horovod.tensorflow as hvd
6 |
# Debugging function
# Module-level switch read by print_act_stats: when False, that function
# returns its input unchanged without adding a print op.
do_print_act_stats = True
9 |
10 |
def print_act_stats(x, _str=""):
    """Attach a tf.Print op that logs min/mean/max of activation statistics.

    The mean and variance of `x` are taken over every axis except the last
    (axis 0 for a rank-1 tensor). The printed values are the min, mean and
    max of those means, followed by the min, mean and max of the
    corresponding standard deviations.

    Args:
      x: tensor to inspect.
      _str: tag placed in square brackets at the start of the printed line.

    Returns:
      `x` routed through the print op. `x` is returned unchanged when
      `do_print_act_stats` is False, when this is not horovod rank 0, or when
      `x` is a scalar (no axis to reduce over).
    """
    if not do_print_act_stats or hvd.rank() != 0:
        return x

    rank = len(x.get_shape())
    if rank == 0:
        return x
    # Ranks 1, 2 and 4 reduce over [0], [0] and [0, 1, 2] as before; any other
    # rank used to hit an unbound-variable error and now follows the same
    # "all but the last axis" rule.
    axes = [0] if rank == 1 else list(range(rank - 1))
    x_mean, x_var = tf.nn.moments(x, axes, keep_dims=True)
    x_std = tf.sqrt(x_var)
    stats = [tf.reduce_min(x_mean), tf.reduce_mean(x_mean), tf.reduce_max(x_mean),
             tf.reduce_min(x_std), tf.reduce_mean(x_std), tf.reduce_max(x_std)]
    return tf.Print(x, stats, "["+_str+"] "+x.name)
25 |
26 | # Allreduce methods
27 |
28 |
def allreduce_sum(x):
    """Sum `x` across horovod workers; with a single worker return `x` as-is."""
    single_worker = hvd.size() == 1
    return x if single_worker else hvd.mpi_ops._allreduce(x)
33 |
34 |
def allreduce_mean(x):
    """Average `x` across horovod workers: allreduce_sum(x) / hvd.size()."""
    return allreduce_sum(x) / hvd.size()
38 |
39 |
def default_initial_value(shape, std=0.05):
    """Return a random-normal tensor of `shape` with mean 0 and stddev `std`."""
    return tf.random_normal(shape, mean=0., stddev=std)
42 |
43 |
def default_initializer(std=0.05):
    """Return a random-normal initializer with mean 0 and stddev `std`."""
    return tf.random_normal_initializer(mean=0., stddev=std)
46 |
47 |
def int_shape(x):
    """Return the static shape of `x` as a list of ints.

    When the leading dimension prints as '?' (unknown), it is reported as -1
    and only the remaining dimensions are converted with int().
    """
    static_shape = x.get_shape()
    batch_is_known = str(static_shape[0]) != '?'
    if batch_is_known:
        return [int(dim) for dim in static_shape]
    return [-1] + [int(dim) for dim in static_shape[1:]]
52 |
53 | # wrapper tf.get_variable, augmented with 'init' functionality
54 | # Get variable with data dependent init
55 |
56 |
@add_arg_scope
def get_variable_ddi(name, shape, initial_value, dtype=tf.float32, init=False, trainable=True):
    """tf.get_variable wrapper with optional data-dependent initialization.

    Args:
      name, shape, dtype, trainable: forwarded to tf.get_variable (no
        initializer is supplied).
      initial_value: value assigned to the variable when `init` is true.
      init: when true, return the result of assigning `initial_value` to the
        variable instead of the variable itself.

    Returns:
      The variable, or the assign op's output when `init` is true.
    """
    variable = tf.get_variable(name, shape, dtype, None, trainable=trainable)
    if not init:
        return variable
    assigned = variable.assign(initial_value)
    with tf.control_dependencies([assigned]):
        return assigned
65 |
66 | # Activation normalization
67 | # Convenience function that does centering+scaling
68 |
69 |
@add_arg_scope
def actnorm(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
    """Activation normalization: centering followed by scaling.

    In the forward direction the input is centered (actnorm_center) and then
    scaled (actnorm_scale); with `reverse=True` the two steps are applied in
    the opposite order so the forward transform is undone.

    Args:
      name: prefix for the "<name>_center" / "<name>_scale" variable scopes.
      x: input tensor (the helpers accept rank 2 or rank 4).
      scale, logscale_factor, batch_variance, init: forwarded to
        actnorm_scale.
      logdet: optional running log-determinant; when given, the updated value
        is returned alongside the output.
      reverse: apply the inverse transform.
      trainable: whether the variables created through get_variable_ddi are
        trainable.

    Returns:
      `x`, or `(x, logdet)` when `logdet` is not None.
    """
    # This must be a `with`: the scope object is always truthy, so testing it
    # with `if` would never enter the scope and `trainable` would be ignored.
    with arg_scope([get_variable_ddi], trainable=trainable):
        if not reverse:
            x = actnorm_center(name+"_center", x, reverse)
            # actnorm_scale opens its own arg_scope from its `trainable`
            # parameter, so the flag has to be forwarded explicitly.
            x = actnorm_scale(name+"_scale", x, scale, logdet,
                              logscale_factor, batch_variance, reverse, init,
                              trainable=trainable)
            if logdet is not None:
                x, logdet = x
        else:
            x = actnorm_scale(name + "_scale", x, scale, logdet,
                              logscale_factor, batch_variance, reverse, init,
                              trainable=trainable)
            if logdet is not None:
                x, logdet = x
            x = actnorm_center(name+"_center", x, reverse)
        if logdet is not None:
            return x, logdet
        return x
88 |
89 | # Activation normalization
90 |
91 |
92 | @add_arg_scope
93 | def actnorm_center(name, x, reverse=False):
94 | shape = x.get_shape()
95 | with tf.variable_scope(name):
96 | assert len(shape) == 2 or len(shape) == 4
97 | if len(shape) == 2:
98 | x_mean = tf.reduce_mean(x, [0], keepdims=True)
99 | b = get_variable_ddi(
100 | "b", (1, int_shape(x)[1]), initial_value=-x_mean)
101 | elif len(shape) == 4:
102 | x_mean = tf.reduce_mean(x, [0, 1, 2], keepdims=True)
103 | b = get_variable_ddi(
104 | "b", (1, 1, 1, int_shape(x)[3]), initial_value=-x_mean)
105 |
106 | if not reverse:
107 | x += b
108 | else:
109 | x -= b
110 |
111 | return x
112 |
113 | # Activation normalization
114 |
115 |
116 | @add_arg_scope
117 | def actnorm_scale(name, x, scale=1., logdet=None, logscale_factor=3., batch_variance=False, reverse=False, init=False, trainable=True):
118 | shape = x.get_shape()
119 | with tf.variable_scope(name), arg_scope([get_variable_ddi], trainable=trainable):
120 | assert len(shape) == 2 or len(shape) == 4
121 | if len(shape) == 2:
122 | x_var = tf.reduce_mean(x**2, [0], keepdims=True)
123 | logdet_factor = 1
124 | _shape = (1, int_shape(x)[1])
125 |
126 | elif len(shape) == 4:
127 | x_var = tf.reduce_mean(x**2, [0, 1, 2], keepdims=True)
128 | logdet_factor = int(shape[1])*int(shape[2])
129 | _shape = (1, 1, 1, int_shape(x)[3])
130 |
131 | if batch_variance:
132 | x_var = tf.reduce_mean(x**2, keepdims=True)
133 |
134 | if init and False:
135 | # MPI all-reduce
136 | x_var = allreduce_mean(x_var)
137 | # Somehow this also slows down graph when not initializing
138 | # (it's not optimized away?)
139 |
140 | if True:
141 | logs = get_variable_ddi("logs", _shape, initial_value=tf.log(
142 | scale/(tf.sqrt(x_var)+1e-6))/logscale_factor)*logscale_factor
143 | if not reverse:
144 | x = x * tf.exp(logs)
145 | else:
146 | x = x * tf.exp(-logs)
147 | else:
148 | # Alternative, doesn't seem to do significantly worse or better than the logarithmic version above
149 | s = get_variable_ddi("s", _shape, initial_value=scale /
150 | (tf.sqrt(x_var) + 1e-6) / logscale_factor)*logscale_factor
151 | logs = tf.log(tf.abs(s))
152 | if not reverse:
153 | x *= s
154 | else:
155 | x /= s
156 |
157 | if logdet != None:
158 | dlogdet = tf.reduce_sum(logs) * logdet_factor
159 | if reverse:
160 | dlogdet *= -1
161 | return x, logdet + dlogdet
162 |
163 | return x
164 |
165 | # Linear layer with layer norm
166 |
167 |
168 | @add_arg_scope
169 | def linear(name, x, width, do_weightnorm=True, do_actnorm=True, initializer=None, scale=1.):
170 | initializer = initializer or default_initializer()
171 | with tf.variable_scope(name):
172 | n_in = int(x.get_shape()[1])
173 | w = tf.get_variable("W", [n_in, width],
174 | tf.float32, initializer=initializer)
175 | if do_weightnorm:
176 | w = tf.nn.l2_normalize(w, [0])
177 | x = tf.matmul(x, w)
178 | x += tf.get_variable("b", [1, width],
179 | initializer=tf.zeros_initializer())
180 | if do_actnorm:
181 | x = actnorm("actnorm", x, scale)
182 | return x
183 |
184 | # Linear layer with zero init
185 |
186 |
187 | @add_arg_scope
188 | def linear_zeros(name, x, width, logscale_factor=3):
189 | with tf.variable_scope(name):
190 | n_in = int(x.get_shape()[1])
191 | w = tf.get_variable("W", [n_in, width], tf.float32,
192 | initializer=tf.zeros_initializer())
193 | x = tf.matmul(x, w)
194 | x += tf.get_variable("b", [1, width],
195 | initializer=tf.zeros_initializer())
196 | x *= tf.exp(tf.get_variable("logs",
197 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
198 | return x
199 |
200 | # Slow way to add edge padding
201 |
202 |
def add_edge_padding(x, filter_size):
    """Zero-pad `x` spatially and append a channel that marks the padding.

    `x` is padded by (filter_size[0]-1)//2 rows and (filter_size[1]-1)//2
    columns on each side. One extra channel is concatenated on axis 3 that is
    1 on the padded border and 0 elsewhere. The mask for a given
    (padding, padded height, padded width) is built once and cached in a
    graph collection keyed by those numbers.

    Args:
      x: rank-4 tensor with statically known height and width (axes 1, 2).
      filter_size: [height, width] of the filter; the height must be odd.

    Returns:
      The padded tensor with one additional channel, or `x` unchanged for a
      1x1 filter.
    """
    assert filter_size[0] % 2 == 1
    if filter_size[0] == 1 and filter_size[1] == 1:
        return x
    a = (filter_size[0] - 1) // 2  # vertical padding size
    b = (filter_size[1] - 1) // 2  # horizontal padding size

    x = tf.pad(x, [[0, 0], [a, a], [b, b], [0, 0]])
    padded_hw = int_shape(x)[1:3]
    name = "_".join([str(dim) for dim in [a, b] + padded_hw])
    pads = tf.get_collection(name)
    if not pads:
        if hvd.rank() == 0:
            print("Creating pad", name)
        pad = np.zeros([1] + padded_hw + [1], dtype='float32')
        # Guard against a zero padding size: `-0:` is the same slice as `0:`
        # and would mark the whole axis as border instead of nothing.
        if a > 0:
            pad[:, :a, :, 0] = 1.
            pad[:, -a:, :, 0] = 1.
        if b > 0:
            pad[:, :, :b, 0] = 1.
            pad[:, :, -b:, 0] = 1.
        pad = tf.convert_to_tensor(pad)
        tf.add_to_collection(name, pad)
    else:
        pad = pads[0]
    # Repeat the single mask for every example in the (dynamic) batch.
    pad = tf.tile(pad, [tf.shape(x)[0], 1, 1, 1])
    x = tf.concat([x, pad], axis=3)
    return x
233 |
234 |
235 | @add_arg_scope
236 | def conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", do_weightnorm=False, do_actnorm=True, context1d=None, skip=1, edge_bias=True):
237 | with tf.variable_scope(name):
238 | if edge_bias and pad == "SAME":
239 | x = add_edge_padding(x, filter_size)
240 | pad = 'VALID'
241 |
242 | n_in = int(x.get_shape()[3])
243 |
244 | stride_shape = [1] + stride + [1]
245 | filter_shape = filter_size + [n_in, width]
246 | w = tf.get_variable("W", filter_shape, tf.float32,
247 | initializer=default_initializer())
248 | if do_weightnorm:
249 | w = tf.nn.l2_normalize(w, [0, 1, 2])
250 | if skip == 1:
251 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
252 | else:
253 | assert stride[0] == 1 and stride[1] == 1
254 | x = tf.nn.atrous_conv2d(x, w, skip, pad)
255 | if do_actnorm:
256 | x = actnorm("actnorm", x)
257 | else:
258 | x += tf.get_variable("b", [1, 1, 1, width],
259 | initializer=tf.zeros_initializer())
260 |
261 | if context1d != None:
262 | x += tf.reshape(linear("context", context1d,
263 | width), [-1, 1, 1, width])
264 | return x
265 |
266 |
267 | @add_arg_scope
268 | def separable_conv2d(name, x, width, filter_size=[3, 3], stride=[1, 1], padding="SAME", do_actnorm=True, std=0.05):
269 | n_in = int(x.get_shape()[3])
270 | with tf.variable_scope(name):
271 | assert filter_size[0] % 2 == 1 and filter_size[1] % 2 == 1
272 | strides = [1] + stride + [1]
273 | w1_shape = filter_size + [n_in, 1]
274 | w1_init = np.zeros(w1_shape, dtype='float32')
275 | w1_init[(filter_size[0]-1)//2, (filter_size[1]-1)//2, :,
276 | :] = 1. # initialize depthwise conv as identity
277 | w1 = tf.get_variable("W1", dtype=tf.float32, initializer=w1_init)
278 | w2_shape = [1, 1, n_in, width]
279 | w2 = tf.get_variable("W2", w2_shape, tf.float32,
280 | initializer=default_initializer(std))
281 | x = tf.nn.separable_conv2d(
282 | x, w1, w2, strides, padding, data_format='NHWC')
283 | if do_actnorm:
284 | x = actnorm("actnorm", x)
285 | else:
286 | x += tf.get_variable("b", [1, 1, 1, width],
287 | initializer=tf.zeros_initializer(std))
288 |
289 | return x
290 |
291 |
292 | @add_arg_scope
293 | def conv2d_zeros(name, x, width, filter_size=[3, 3], stride=[1, 1], pad="SAME", logscale_factor=3, skip=1, edge_bias=True):
294 | with tf.variable_scope(name):
295 | if edge_bias and pad == "SAME":
296 | x = add_edge_padding(x, filter_size)
297 | pad = 'VALID'
298 |
299 | n_in = int(x.get_shape()[3])
300 | stride_shape = [1] + stride + [1]
301 | filter_shape = filter_size + [n_in, width]
302 | w = tf.get_variable("W", filter_shape, tf.float32,
303 | initializer=tf.zeros_initializer())
304 | if skip == 1:
305 | x = tf.nn.conv2d(x, w, stride_shape, pad, data_format='NHWC')
306 | else:
307 | assert stride[0] == 1 and stride[1] == 1
308 | x = tf.nn.atrous_conv2d(x, w, skip, pad)
309 | x += tf.get_variable("b", [1, 1, 1, width],
310 | initializer=tf.zeros_initializer())
311 | x *= tf.exp(tf.get_variable("logs",
312 | [1, width], initializer=tf.zeros_initializer()) * logscale_factor)
313 | return x
314 |
315 |
316 | # 2X nearest-neighbour upsampling, also inspired by Jascha Sohl-Dickstein's code
def upsample2d_nearest_neighbour(x):
    """Double the height and width of `x` by repeating every pixel 2x2.

    Args:
      x: rank-4 tensor whose four dimensions are all statically known; axes
        1 and 2 are treated as height and width.

    Returns:
      A tensor of shape (n_batch, 2*height, 2*width, n_channels).
    """
    shape = x.get_shape()
    n_batch = int(shape[0])
    height = int(shape[1])
    width = int(shape[2])
    n_channels = int(shape[3])
    # Insert a singleton axis after height and after width, duplicate along
    # each, then fold the duplicates back into height and width.
    x = tf.reshape(x, (n_batch, height, 1, width, 1, n_channels))
    # tf.concat takes (values, axis); the reversed order is the pre-1.0 API.
    x = tf.concat([x, x], axis=2)
    x = tf.concat([x, x], axis=4)
    x = tf.reshape(x, (n_batch, height*2, width*2, n_channels))
    return x
328 |
329 |
def upsample(x, factor=2):
    """Nearest-neighbour resize of `x` to `factor` times its height and width.

    Axes 1 and 2 of `x` must be statically known.
    """
    static_shape = x.get_shape()
    target_size = [int(static_shape[1]) * factor,
                   int(static_shape[2]) * factor]
    return tf.image.resize_nearest_neighbor(x, target_size)
336 |
337 |
def squeeze2d(x, factor=2):
    """Trade spatial resolution for channels.

    Every factor x factor spatial block is moved into the channel axis, giving
    shape (-1, height//factor, width//factor, n_channels*factor*factor).
    Height and width must be divisible by `factor`; `factor == 1` returns `x`
    unchanged.
    """
    assert factor >= 1
    if factor == 1:
        return x
    static_shape = x.get_shape()
    height, width, n_channels = (int(static_shape[i]) for i in (1, 2, 3))
    assert height % factor == 0 and width % factor == 0
    out_h = height // factor
    out_w = width // factor
    x = tf.reshape(x, [-1, out_h, factor, out_w, factor, n_channels])
    x = tf.transpose(x, [0, 1, 3, 5, 2, 4])
    x = tf.reshape(x, [-1, out_h, out_w, n_channels*factor*factor])
    return x
353 |
354 |
def unsqueeze2d(x, factor=2):
    """Inverse of squeeze2d: trade channels for spatial resolution.

    Groups of factor*factor channels are spread over a factor x factor spatial
    block, giving shape
    (-1, height*factor, width*factor, n_channels // factor**2).

    Args:
      x: rank-4 tensor with statically known height, width and channels.
      factor: spatial upscaling factor; 1 returns `x` unchanged.

    The channel count must be a positive multiple of factor**2 (4 for the
    default factor of 2).
    """
    assert factor >= 1
    if factor == 1:
        return x
    shape = x.get_shape()
    height = int(shape[1])
    width = int(shape[2])
    n_channels = int(shape[3])
    # Check against factor**2 rather than a hard-coded 4 so that factors other
    # than 2 are validated correctly.
    block = factor ** 2
    assert n_channels >= block and n_channels % block == 0
    out_channels = n_channels // block
    x = tf.reshape(
        x, (-1, height, width, out_channels, factor, factor))
    x = tf.transpose(x, [0, 1, 4, 2, 5, 3])
    x = tf.reshape(x, (-1, height*factor, width*factor, out_channels))
    return x
370 |
371 | # Reverse features across channel dimension
372 |
373 |
def reverse_features(name, h, reverse=False):
    """Return `h` with the order of its fourth axis flipped.

    `name` and `reverse` are accepted but not used: flipping is its own
    inverse.
    """
    flipped = h[:, :, :, ::-1]
    return flipped
376 |
377 | # Shuffle across the channel dimension
378 |
379 |
def shuffle_features(name, h, indices=None, return_indices=False, reverse=False):
    """Permute the last axis of `h` with a fixed permutation.

    The permutation and its inverse are stored as the non-trainable int32
    variables "indices" and "indices_inverse" in scope `name`.

    Args:
      name: variable scope.
      h: rank-2 or rank-4 tensor with a statically known last dimension.
      indices: optional permutation to use; when None a random one is drawn.
      return_indices: also return the permutation that was used.
      reverse: apply the inverse permutation instead.

    Returns:
      The permuted tensor, or `(tensor, indices)` when `return_indices`.
    """
    with tf.variable_scope(name):

        if indices is None:
            # NOTE(review): hash() of a str is salted per process in Python 3
            # unless PYTHONHASHSEED is fixed, so this seed (and therefore the
            # permutation) can differ between processes.
            rng = np.random.RandomState(
                (abs(hash(tf.get_variable_scope().name))) % 10000000)
            n_channels = int(h.get_shape()[-1])
            indices = list(range(n_channels))
            rng.shuffle(indices)

        # Build the inverse for caller-supplied permutations too; previously
        # it was only defined when `indices` was None.
        indices_inverse = [0] * len(indices)
        for position, source in enumerate(indices):
            indices_inverse[source] = position

        tf_indices = tf.get_variable("indices", dtype=tf.int32, initializer=np.asarray(
            indices, dtype='int32'), trainable=False)
        tf_indices_reverse = tf.get_variable("indices_inverse", dtype=tf.int32, initializer=np.asarray(
            indices_inverse, dtype='int32'), trainable=False)

        _indices = tf_indices_reverse if reverse else tf_indices

        # tf.gather picks along axis 0, so move the feature axis to the front,
        # gather, and move it back.
        if len(h.get_shape()) == 2:
            h = tf.transpose(h)
            h = tf.gather(h, _indices)
            h = tf.transpose(h)
        elif len(h.get_shape()) == 4:
            h = tf.transpose(h, [3, 1, 2, 0])
            h = tf.gather(h, _indices)
            h = tf.transpose(h, [3, 1, 2, 0])
        if return_indices:
            return h, indices
        return h
418 |
419 |
def embedding(name, y, n_y, width):
    """Look up rows of a learned [n_y, width] table "embedding" by index `y`."""
    with tf.variable_scope(name):
        table = tf.get_variable(
            "embedding", [n_y, width], initializer=default_initializer())
        return tf.gather(table, y)
426 |
427 | # Random variables
428 |
429 |
def flatten_sum(logps):
    """Sum `logps` over every axis except the first.

    Args:
      logps: rank-2 or rank-4 tensor.

    Returns:
      A tensor reduced over axis [1] (rank 2) or axes [1, 2, 3] (rank 4).

    Raises:
      ValueError: if `logps` has any other rank.
    """
    rank = len(logps.get_shape())
    if rank == 2:
        return tf.reduce_sum(logps, [1])
    if rank == 4:
        return tf.reduce_sum(logps, [1, 2, 3])
    raise ValueError(
        "flatten_sum expects a rank-2 or rank-4 tensor, got rank %d" % rank)
437 |
438 |
def standard_gaussian(shape):
    """Return gaussian_diag with zero mean and zero log-stddev of `shape`."""
    zero_mean = tf.zeros(shape)
    zero_logsd = tf.zeros(shape)
    return gaussian_diag(zero_mean, zero_logsd)
441 |
442 |
def gaussian_diag(mean, logsd):
    """Build a diagonal Gaussian with the given mean and log-stddev.

    Returns a namespace-like class (not an instance) with these attributes:
      mean, logsd: the arguments.
      eps: standard-normal noise with the shape of `mean`.
      sample: mean + exp(logsd) * eps.
      sample2(eps): the same expression for caller-supplied noise.
      logps(x): elementwise log-density of x.
      logp(x): logps(x) summed with flatten_sum.
      get_eps(x): (x - mean) / exp(logsd).
    """
    class o(object):
        pass

    def _sample_with(eps):
        return mean + tf.exp(logsd) * eps

    def _elementwise_logp(x):
        return -0.5 * \
            (np.log(2 * np.pi) + 2. * logsd + (x - mean) ** 2 / tf.exp(2. * logsd))

    def _summed_logp(x):
        # Looked up on `o` at call time, as in the attribute-based original.
        return flatten_sum(o.logps(x))

    def _noise_for(x):
        return (x - mean) / tf.exp(logsd)

    o.mean = mean
    o.logsd = logsd
    o.eps = tf.random_normal(tf.shape(mean))
    o.sample = mean + tf.exp(logsd) * o.eps
    o.sample2 = _sample_with
    o.logps = _elementwise_logp
    o.logp = _summed_logp
    o.get_eps = _noise_for
    return o
456 |
457 |
458 | # def discretized_logistic_old(mean, logscale, binsize=1 / 256.0, sample=None):
459 | # scale = tf.exp(logscale)
460 | # sample = (tf.floor(sample / binsize) * binsize - mean) / scale
461 | # logp = tf.log(tf.sigmoid(sample + binsize / scale) - tf.sigmoid(sample) + 1e-7)
462 | # return tf.reduce_sum(logp, [1, 2, 3])
463 |
def discretized_logistic(mean, logscale, binsize=1. / 256):
    """Build a discretized logistic distribution.

    Returns a namespace-like class (not an instance) with these attributes:
      mean, logscale: the arguments.
      logps(x): elementwise log-probability of the bin starting at x, computed
        as log(sigmoid(z + binsize/scale) - sigmoid(z) + 1e-7) with
        z = (x - mean) / scale and scale = exp(logscale).
      logp(x): logps(x) summed with flatten_sum.
    """
    class o(object):
        pass

    scale = tf.exp(logscale)

    def logps(x):
        z = (x - mean) / scale
        bin_mass = tf.sigmoid(z + binsize / scale) - tf.sigmoid(z)
        return tf.log(bin_mass + 1e-7)

    def logp(x):
        return flatten_sum(logps(x))

    o.mean = mean
    o.logscale = logscale
    o.logps = logps
    o.logp = logp
    return o
477 |
478 |
def _symmetric_matrix_square_root(mat, eps=1e-10):
    """Compute the matrix square root of a symmetric matrix.

    This is not an elementwise square root: the result M' satisfies
    M' * M' = mat. It only works for symmetric matrices.

    Args:
      mat: Matrix to take the square root of.
      eps: Singular values below this threshold are left as they are instead
        of being square rooted, to guard against numerical instability.

    Returns:
      Matrix square root of mat.
    """
    # Unlike numpy, tensorflow's return order is (s, u, v), and v is V itself
    # (not V^T) in A = U S V^T - hence transpose_b below.
    singular_values, left, right = tf.svd(mat)
    # sqrt is unstable around 0, so tiny singular values are passed through.
    too_small = tf.less(singular_values, eps)
    root_values = tf.where(too_small, singular_values,
                           tf.sqrt(singular_values))
    scaled_left = tf.matmul(left, tf.diag(root_values))
    return tf.matmul(scaled_left, right, transpose_b=True)
500 |
--------------------------------------------------------------------------------
/demo/web/load-image.all.min.js.map:
--------------------------------------------------------------------------------
1 | {"version":3,"sources":["load-image.js","load-image-scale.js","load-image-meta.js","load-image-fetch.js","load-image-exif.js","load-image-exif-map.js","load-image-orientation.js"],"names":["$","loadImage","file","callback","options","url","img","document","createElement","onerror","event","onload","fetchBlob","blob","createObjectURL","crossOrigin","src","isInstanceOf","_objectURL","readFile","e","target","result","revokeHelper","noRevoke","revokeObjectURL","urlAPI","URL","webkitURL","type","obj","Object","prototype","toString","call","transform","data","method","FileReader","fileReader","define","amd","module","exports","window","this","factory","require","originalTransform","scale","transformCoordinates","getTransformedOptions","newOptions","i","width","height","aspectRatio","hasOwnProperty","crop","naturalWidth","naturalHeight","maxWidth","maxHeight","renderImageToCanvas","canvas","sourceX","sourceY","sourceWidth","sourceHeight","destX","destY","destWidth","destHeight","getContext","drawImage","hasCanvasOption","scaleUp","Math","max","minWidth","minHeight","scaleDown","min","pixelRatio","downsamplingRatio","tmp","useCanvas","left","top","undefined","right","bottom","contain","cover","style","hasblobSlice","Blob","slice","webkitSlice","mozSlice","blobSlice","apply","arguments","metaDataParsers","jpeg","65505","parseMetaData","that","maxMetaDataSize","DataView","size","error","console","log","markerBytes","markerLength","parsers","buffer","dataView","offset","maxOffset","byteLength","headLength","getUint16","length","disableImageHead","imageHead","Uint8Array","subarray","hasMetaOption","meta","fetch","Request","then","response","catch","err","ExifMap","map","Orientation","get","id","getExifThumbnail","exifTagTypes","1","getValue","dataOffset","getUint8","2","String","fromCharCode","ascii","3","littleEndian","4","getUint32","5","9","getInt32","10","getExifValue","tiffOffset","tagSize","values","str","c","tagType","parseExifTag","tag","exif","parseExifTags","dirOf
fset","tagsNumber","dirEndOffset","parseExifData","disableExif","thumbnailData","disableExifThumbnail","Thumbnail","disableExifSub","disableExifGps","push","tags","256","257","34665","34853","40965","258","259","262","274","277","284","530","531","282","283","296","273","278","279","513","514","301","318","319","529","532","306","270","271","272","305","315","33432","36864","40960","40961","40962","40963","42240","37121","37122","37500","37510","40964","36867","36868","37520","37521","37522","33434","33437","34850","34852","34855","34856","34864","34865","34866","34867","34868","34869","37377","37378","37379","37380","37381","37382","37383","37384","37385","37396","37386","41483","41484","41486","41487","41488","41492","41493","41495","41728","41729","41730","41985","41986","41987","41988","41989","41990","41991","41992","41993","41994","41995","41996","42016","42032","42033","42034","42035","42036","42037","0","6","7","8","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","stringValues","ExposureProgram","MeteringMode","255","LightSource","Flash","32","65","69","71","73","77","79","89","93","95","SensingMethod","SceneCaptureType","SceneType","CustomRendered","WhiteBalance","GainControl","Contrast","Saturation","Sharpness","SubjectDistanceRange","FileSource","ComponentsConfiguration","getText","value","exifMapPrototype","prop","getAll","originalHasCanvasOption","originalHasMetaOption","originalTransformCoordinates","originalGetTransformedOptions","orientation","ctx","styleWidth","styleHeight","translate","rotate","PI","opts"],"mappings":"CAaC,SAAWA,GACV,aAKA,SAASC,EAAWC,EAAMC,EAAUC,GAClC,IACIC,EADAC,EAAMC,SAASC,cAAc,OAQjC,OANAF,EAAIG,QAAU,SAAUC,GACtB,OAAOT,EAAUQ,QAAQH,EAAKI,EAAOR,EAAMC,EAAUC,IAEvDE,EAAIK,OAAS,SAAUD,GACrB,OAAOT,EAAUU,OAAOL,EAAKI,EAAOR,EAAMC,EAAUC,IAElC,iBAATF,GACTD,EAAUW,UACRV,EACA,SAAUW,GACJA,GACFX,EAAOW,EACPR,EAAMJ,EAAUa,gBAAgBZ,KAEhCG,EAAMH,EACFE,GAAWA,EAAQW,cACrBT,EAAIS,YAAcX,EAAQW,cAG9BT,EAAI
U,IAAMX,GAEZD,GAEKE,GAEPL,EAAUgB,aAAa,OAAQf,IAG/BD,EAAUgB,aAAa,OAAQf,IAE/BG,EAAMC,EAAIY,WAAajB,EAAUa,gBAAgBZ,KAE/CI,EAAIU,IAAMX,EACHC,GAEFL,EAAUkB,SAASjB,EAAM,SAAUkB,GACxC,IAAIC,EAASD,EAAEC,OACXA,GAAUA,EAAOC,OACnBhB,EAAIU,IAAMK,EAAOC,OACRnB,GACTA,EAASiB,UAhBR,EA4BT,SAASG,EAAcjB,EAAKF,IACtBE,EAAIY,YAAgBd,GAAWA,EAAQoB,WACzCvB,EAAUwB,gBAAgBnB,EAAIY,mBACvBZ,EAAIY,YARf,IAAIQ,EACD1B,EAAEc,iBAAmBd,GACrBA,EAAE2B,KAAOA,IAAIF,iBAAmBE,KAChC3B,EAAE4B,WAAaA,UAYlB3B,EAAUW,UAAY,SAAUP,EAAKF,EAAUC,GAC7CD,KAGFF,EAAUgB,aAAe,SAAUY,EAAMC,GAEvC,OAAOC,OAAOC,UAAUC,SAASC,KAAKJ,KAAS,WAAaD,EAAO,KAGrE5B,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GAC5DjC,EAASG,EAAK8B,IAGhBnC,EAAUQ,QAAU,SAAUH,EAAKI,EAAOR,EAAMC,EAAUC,GACxDmB,EAAajB,EAAKF,GACdD,GACFA,EAAS+B,KAAK5B,EAAKI,IAIvBT,EAAUU,OAAS,SAAUL,EAAKI,EAAOR,EAAMC,EAAUC,GACvDmB,EAAajB,EAAKF,GACdD,GACFF,EAAUkC,UAAU7B,EAAKF,EAASD,EAAUD,OAIhDD,EAAUa,gBAAkB,SAAUZ,GACpC,QAAOwB,GAASA,EAAOZ,gBAAgBZ,IAGzCD,EAAUwB,gBAAkB,SAAUpB,GACpC,QAAOqB,GAASA,EAAOD,gBAAgBpB,IAMzCJ,EAAUkB,SAAW,SAAUjB,EAAMC,EAAUkC,GAC7C,GAAIrC,EAAEsC,WAAY,CAChB,IAAIC,EAAa,IAAID,WAGrB,GAFAC,EAAW5B,OAAS4B,EAAW9B,QAAUN,EACzCkC,EAASA,GAAU,gBACfE,EAAWF,GAEb,OADAE,EAAWF,GAAQnC,GACZqC,EAGX,OAAO,GAGa,mBAAXC,QAAyBA,OAAOC,IACzCD,OAAO,WACL,OAAOvC,IAEkB,iBAAXyC,QAAuBA,OAAOC,QAC9CD,OAAOC,QAAU1C,EAEjBD,EAAEC,UAAYA,EAjIjB,CAmIqB,oBAAX2C,QAA0BA,QAAWC,MCnI/C,SAAWC,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,gBAAiBM,GAEzBA,EAD2B,iBAAXJ,QAAuBA,OAAOC,QACtCI,QAAQ,gBAGRH,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEA,IAAI+C,EAAoB/C,EAAUkC,UAElClC,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GAC5DY,EAAkBd,KAChBjC,EACAA,EAAUgD,MAAM3C,EAAKF,EAASgC,GAC9BhC,EACAD,EACAD,EACAkC,IAOJnC,EAAUiD,qBAAuB,aAKjCjD,EAAUkD,sBAAwB,SAAU7C,EAAKF,GAC/C,IACIgD,EACAC,EACAC,EACAC,EAJAC,EAAcpD,EAAQoD,YAK1B,IAAKA,EACH,OAAOpD,EAETgD,KACA,IAAKC,KAAKjD,EACJA,EAAQqD,eAAeJ,KACzBD,EAAWC,GAAKjD,EAAQiD,IAa5B,OAVAD,EAAWM,MAAO,EAClBJ,EAAQhD,EAAIqD,cAAgBrD,EAAIgD,MAChCC,EAASjD,EAAIsD,eAAiBtD,EAAIiD,OAC9BD,EAAQC,EAASC,GACnBJ,EAAWS,SAAWN,EAASC,EAC/BJ,EAAWU,UAAYP,I
AEvBH,EAAWS,SAAWP,EACtBF,EAAWU,UAAYR,EAAQE,GAE1BJ,GAITnD,EAAU8D,oBAAsB,SAC9BC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,GAeA,OAbAR,EACGS,WAAW,MACXC,UACCpE,EACA2D,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,EACAC,GAEGR,GAIT/D,EAAU0E,gBAAkB,SAAUvE,GACpC,OAAOA,EAAQ4D,QAAU5D,EAAQsD,QAAUtD,EAAQoD,aAQrDvD,EAAUgD,MAAQ,SAAU3C,EAAKF,EAASgC,GAqBxC,SAASwC,IACP,IAAI3B,EAAQ4B,KAAKC,KACdC,GAAYR,GAAaA,GACzBS,GAAaR,GAAcA,GAE1BvB,EAAQ,IACVsB,GAAatB,EACbuB,GAAcvB,GAGlB,SAASgC,IACP,IAAIhC,EAAQ4B,KAAKK,KACdrB,GAAYU,GAAaA,GACzBT,GAAaU,GAAcA,GAE1BvB,EAAQ,IACVsB,GAAatB,EACbuB,GAAcvB,GArClB7C,EAAUA,MACV,IAQIyD,EACAC,EACAiB,EACAC,EACAb,EACAC,EACAH,EACAC,EACAiB,EACAC,EACAC,EAlBArB,EAASzD,SAASC,cAAc,UAChC8E,EACFhF,EAAImE,YACHxE,EAAU0E,gBAAgBvE,IAAY4D,EAAOS,WAC5CnB,EAAQhD,EAAIqD,cAAgBrD,EAAIgD,MAChCC,EAASjD,EAAIsD,eAAiBtD,EAAIiD,OAClCgB,EAAYjB,EACZkB,EAAajB,EAuFjB,GAvDI+B,IAEFrB,GADA7D,EAAUH,EAAUkD,sBAAsB7C,EAAKF,EAASgC,IACtCmD,MAAQ,EAC1BrB,EAAU9D,EAAQoF,KAAO,EACrBpF,EAAQ+D,aACVA,EAAc/D,EAAQ+D,iBACAsB,IAAlBrF,EAAQsF,YAAwCD,IAAjBrF,EAAQmF,OACzCtB,EAAUX,EAAQa,EAAc/D,EAAQsF,QAG1CvB,EAAcb,EAAQW,GAAW7D,EAAQsF,OAAS,GAEhDtF,EAAQgE,cACVA,EAAehE,EAAQgE,kBACAqB,IAAnBrF,EAAQuF,aAAwCF,IAAhBrF,EAAQoF,MAC1CtB,EAAUX,EAASa,EAAehE,EAAQuF,SAG5CvB,EAAeb,EAASW,GAAW9D,EAAQuF,QAAU,GAEvDpB,EAAYJ,EACZK,EAAaJ,GAEfP,EAAWzD,EAAQyD,SACnBC,EAAY1D,EAAQ0D,UACpBiB,EAAW3E,EAAQ2E,SACnBC,EAAY5E,EAAQ4E,UAChBM,GAAazB,GAAYC,GAAa1D,EAAQsD,MAChDa,EAAYV,EACZW,EAAaV,GACbuB,EAAMlB,EAAcC,EAAeP,EAAWC,GACpC,GACRM,EAAeN,EAAYK,EAAcN,OACrB4B,IAAhBrF,EAAQoF,UAAwCC,IAAnBrF,EAAQuF,SACvCzB,GAAWX,EAASa,GAAgB,IAE7BiB,EAAM,IACflB,EAAcN,EAAWO,EAAeN,OACnB2B,IAAjBrF,EAAQmF,WAAwCE,IAAlBrF,EAAQsF,QACxCzB,GAAWX,EAAQa,GAAe,OAIlC/D,EAAQwF,SAAWxF,EAAQyF,SAC7Bd,EAAWlB,EAAWA,GAAYkB,EAClCC,EAAYlB,EAAYA,GAAakB,GAEnC5E,EAAQyF,OACVZ,IACAL,MAEAA,IACAK,MAGAK,EAAW,CAUb,IATAH,EAAa/E,EAAQ+E,YACJ,IACfnB,EAAO8B,MAAMxC,MAAQiB,EAAY,KACjCP,EAAO8B,MAAMvC,OAASiB,EAAa,KACnCD,GAAaY,EACbX,GAAcW,EACdnB,EAAOS,WAAW,MAAMxB,MAAMkC,EAAYA,KAE5CC,EAAoBhF,EAAQgF,mBAEN,GACpBA,EAAoB
,GACpBb,EAAYJ,GACZK,EAAaJ,EAEb,KAAOD,EAAciB,EAAoBb,GACvCP,EAAOV,MAAQa,EAAciB,EAC7BpB,EAAOT,OAASa,EAAegB,EAC/BnF,EAAU8D,oBACRC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACA,EACA,EACAJ,EAAOV,MACPU,EAAOT,QAETU,EAAU,EACVC,EAAU,EACVC,EAAcH,EAAOV,MACrBc,EAAeJ,EAAOT,QACtBjD,EAAMC,SAASC,cAAc,WACzB8C,MAAQa,EACZ7D,EAAIiD,OAASa,EACbnE,EAAU8D,oBACRzD,EACA0D,EACA,EACA,EACAG,EACAC,EACA,EACA,EACAD,EACAC,GAON,OAHAJ,EAAOV,MAAQiB,EACfP,EAAOT,OAASiB,EAChBvE,EAAUiD,qBAAqBc,EAAQ5D,GAChCH,EAAU8D,oBACfC,EACA1D,EACA2D,EACAC,EACAC,EACAC,EACA,EACA,EACAG,EACAC,GAKJ,OAFAlE,EAAIgD,MAAQiB,EACZjE,EAAIiD,OAASiB,EACNlE,KCxQV,SAAWwC,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,gBAAiBM,GAEzBA,EAD2B,iBAAXJ,QAAuBA,OAAOC,QACtCI,QAAQ,gBAGRH,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEA,IAAI8F,EACc,oBAATC,OACNA,KAAKhE,UAAUiE,OACdD,KAAKhE,UAAUkE,aACfF,KAAKhE,UAAUmE,UAEnBlG,EAAUmG,UACRL,GACA,WAEE,OADYlD,KAAKoD,OAASpD,KAAKqD,aAAerD,KAAKsD,UACtCE,MAAMxD,KAAMyD,YAG7BrG,EAAUsG,iBACRC,MACEC,WAUJxG,EAAUyG,cAAgB,SAAUxG,EAAMC,EAAUC,EAASgC,GAC3DhC,EAAUA,MACVgC,EAAOA,MACP,IAAIuE,EAAO9D,KAEP+D,EAAkBxG,EAAQwG,iBAAmB,UAE3B,oBAAbC,UACP3G,GACAA,EAAK4G,MAAQ,IACC,eAAd5G,EAAK2B,MACL5B,EAAUmG,YAITnG,EAAUkB,SACTlB,EAAUmG,UAAUlE,KAAKhC,EAAM,EAAG0G,GAClC,SAAUxF,GACR,GAAIA,EAAEC,OAAO0F,MAIX,OAFAC,QAAQC,IAAI7F,EAAEC,OAAO0F,YACrB5G,EAASiC,GAOX,IAKI8E,EACAC,EACAC,EACA/D,EARAgE,EAASjG,EAAEC,OAAOC,OAClBgG,EAAW,IAAIT,SAASQ,GACxBE,EAAS,EACTC,EAAYF,EAASG,WAAa,EAClCC,EAAaH,EAMjB,GAA8B,QAA1BD,EAASK,UAAU,GAAe,CACpC,KAAOJ,EAASC,KACdN,EAAcI,EAASK,UAAUJ,KAKf,OAAUL,GAAe,OACzB,QAAhBA,IAPuB,CAcvB,GADAC,EAAeG,EAASK,UAAUJ,EAAS,GAAK,EAC5CA,EAASJ,EAAeG,EAASG,WAAY,CAC/CT,QAAQC,IAAI,4CACZ,MAGF,GADAG,EAAUnH,EAAUsG,gBAAgBC,KAAKU,GAEvC,IAAK7D,EAAI,EAAGA,EAAI+D,EAAQQ,OAAQvE,GAAK,EACnC+D,EAAQ/D,GAAGnB,KACTyE,EACAW,EACAC,EACAJ,EACA/E,EACAhC,GAKNsH,EADAH,GAAUJ,GAUT/G,EAAQyH,kBAAoBH,EAAa,IACxCL,EAAOpB,MACT7D,EAAK0F,UAAYT,EAAOpB,MAAM,EAAGyB,GAIjCtF,EAAK0F,UAAY,IAAIC,WAAWV,GAAQW,SAAS,EAAGN,SAIxDV,QAAQC,IAAI,2CAEd9G,EAASiC,IAEX,sBAGFjC,EAASiC,IAKbnC,EAAUgI,cAAgB,SAAU7H,GAClC,OAAOA,GAAW
A,EAAQ8H,MAG5B,IAAIlF,EAAoB/C,EAAUkC,UAClClC,EAAUkC,UAAY,SAAU7B,EAAKF,EAASD,EAAUD,EAAMkC,GACxDnC,EAAUgI,cAAc7H,GAC1BH,EAAUyG,cACRxG,EACA,SAAUkC,GACRY,EAAkBd,KAAKjC,EAAWK,EAAKF,EAASD,EAAUD,EAAMkC,IAElEhC,EACAgC,GAGFY,EAAkBqD,MAAMpG,EAAWqG,cCjKxC,SAAWxD,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBAAiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEqB,oBAAVkI,OAA4C,oBAAZC,UACzCnI,EAAUW,UAAY,SAAUP,EAAKF,EAAUC,GAC7C,GAAIH,EAAUgI,cAAc7H,GAC1B,OAAO+H,MAAM,IAAIC,QAAQ/H,EAAKD,IAC3BiI,KAAK,SAAUC,GACd,OAAOA,EAASzH,SAEjBwH,KAAKlI,GACLoI,MAAM,SAAUC,GACfxB,QAAQC,IAAIuB,GACZrI,MAGJA,QC3BP,SAAW2C,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBAAiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEAA,EAAUwI,QAAU,WAClB,OAAO5F,MAGT5C,EAAUwI,QAAQzG,UAAU0G,KAC1BC,YAAa,KAGf1I,EAAUwI,QAAQzG,UAAU4G,IAAM,SAAUC,GAC1C,OAAOhG,KAAKgG,IAAOhG,KAAKA,KAAK6F,IAAIG,KAGnC5I,EAAU6I,iBAAmB,SAAUxB,EAAUC,EAAQK,GACvD,GAAKA,KAAUL,EAASK,EAASN,EAASG,YAI1C,OAAOxH,EAAUa,gBACf,IAAIkF,MAAMsB,EAASD,OAAOpB,MAAMsB,EAAQA,EAASK,MAJjDZ,QAAQC,IAAI,+CAQhBhH,EAAU8I,cAERC,GACEC,SAAU,SAAU3B,EAAU4B,GAC5B,OAAO5B,EAAS6B,SAASD,IAE3BpC,KAAM,GAGRsC,GACEH,SAAU,SAAU3B,EAAU4B,GAC5B,OAAOG,OAAOC,aAAahC,EAAS6B,SAASD,KAE/CpC,KAAM,EACNyC,OAAO,GAGTC,GACEP,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASK,UAAUuB,EAAYO,IAExC3C,KAAM,GAGR4C,GACET,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASqC,UAAUT,EAAYO,IAExC3C,KAAM,GAGR8C,GACEX,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OACEnC,EAASqC,UAAUT,EAAYO,GAC/BnC,EAASqC,UAAUT,EAAa,EAAGO,IAGvC3C,KAAM,GAGR+C,GACEZ,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OAAOnC,EAASwC,SAASZ,EAAYO,IAEvC3C,KAAM,GAGRiD,IACEd,SAAU,SAAU3B,EAAU4B,EAAYO,GACxC,OACEnC,EAASwC,SAASZ,EAAYO,GAC9BnC,EAASwC,SAASZ,EAAa,EAAGO,IAGtC3C,KAAM,IAIV7G,EAAU8I,aAAa,GAAK9I,EAAU8I,aAAa,GAEnD9I,EAAU+J,aAAe,SACvB1C,EACA2C,EACA1C,EACA1F,EACA+F,EACA6B,GAEA,IACIS,EACAhB,EACAiB,EACA9G,EACA+G,EACAC,EANAC,EAAUrK,EAAU8I,aAAalH,GAOrC,GAAKyI,EAAL,CAWA,GAPAJ,EAAUI,EA
AQxD,KAAOc,KAGzBsB,EACEgB,EAAU,EACND,EAAa3C,EAASqC,UAAUpC,EAAS,EAAGkC,GAC5ClC,EAAS,GACE2C,EAAU5C,EAASG,YAApC,CAIA,GAAe,IAAXG,EACF,OAAO0C,EAAQrB,SAAS3B,EAAU4B,EAAYO,GAGhD,IADAU,KACK9G,EAAI,EAAGA,EAAIuE,EAAQvE,GAAK,EAC3B8G,EAAO9G,GAAKiH,EAAQrB,SAClB3B,EACA4B,EAAa7F,EAAIiH,EAAQxD,KACzB2C,GAGJ,GAAIa,EAAQf,MAAO,CAGjB,IAFAa,EAAM,GAED/G,EAAI,EAAGA,EAAI8G,EAAOvC,QAGX,QAFVyC,EAAIF,EAAO9G,IADkBA,GAAK,EAMlC+G,GAAOC,EAET,OAAOD,EAET,OAAOD,EA3BLnD,QAAQC,IAAI,gDAXZD,QAAQC,IAAI,yCAyChBhH,EAAUsK,aAAe,SACvBjD,EACA2C,EACA1C,EACAkC,EACArH,GAEA,IAAIoI,EAAMlD,EAASK,UAAUJ,EAAQkC,GACrCrH,EAAKqI,KAAKD,GAAOvK,EAAU+J,aACzB1C,EACA2C,EACA1C,EACAD,EAASK,UAAUJ,EAAS,EAAGkC,GAC/BnC,EAASqC,UAAUpC,EAAS,EAAGkC,GAC/BA,IAIJxJ,EAAUyK,cAAgB,SACxBpD,EACA2C,EACAU,EACAlB,EACArH,GAEA,IAAIwI,EAAYC,EAAcxH,EAC9B,GAAIsH,EAAY,EAAIrD,EAASG,WAC3BT,QAAQC,IAAI,oDADd,CAMA,GAFA2D,EAAatD,EAASK,UAAUgD,EAAWlB,MAC3CoB,EAAeF,EAAY,EAAI,GAAKC,GACjB,EAAItD,EAASG,YAAhC,CAIA,IAAKpE,EAAI,EAAGA,EAAIuH,EAAYvH,GAAK,EAC/BR,KAAK0H,aACHjD,EACA2C,EACAU,EAAY,EAAI,GAAKtH,EACrBoG,EACArH,GAIJ,OAAOkF,EAASqC,UAAUkB,EAAcpB,GAbtCzC,QAAQC,IAAI,gDAgBhBhH,EAAU6K,cAAgB,SAAUxD,EAAUC,EAAQK,EAAQxF,EAAMhC,GAClE,IAAIA,EAAQ2K,YAAZ,CAGA,IACItB,EACAkB,EACAK,EAHAf,EAAa1C,EAAS,GAK1B,GAAuC,aAAnCD,EAASqC,UAAUpC,EAAS,GAIhC,GAAI0C,EAAa,EAAI3C,EAASG,WAC5BT,QAAQC,IAAI,iDAId,GAAuC,IAAnCK,EAASK,UAAUJ,EAAS,GAAhC,CAKA,OAAQD,EAASK,UAAUsC,IACzB,KAAK,MACHR,GAAe,EACf,MACF,KAAK,MACHA,GAAe,EACf,MACF,QAEE,YADAzC,QAAQC,IAAI,qDAIyC,KAArDK,EAASK,UAAUsC,EAAa,EAAGR,IAKvCkB,EAAYrD,EAASqC,UAAUM,EAAa,EAAGR,GAE/CrH,EAAKqI,KAAO,IAAIxK,EAAUwI,SAG1BkC,EAAY1K,EAAUyK,cACpBpD,EACA2C,EACAA,EAAaU,EACblB,EACArH,MAEgBhC,EAAQ6K,uBACxBD,GAAkBP,SAClBE,EAAY1K,EAAUyK,cACpBpD,EACA2C,EACAA,EAAaU,EACblB,EACAuB,GAGEA,EAAcP,KAAK,OACrBrI,EAAKqI,KAAKS,UAAYjL,EAAU6I,iBAC9BxB,EACA2C,EAAae,EAAcP,KAAK,KAChCO,EAAcP,KAAK,QAKrBrI,EAAKqI,KAAK,SAAYrK,EAAQ+K,gBAChClL,EAAUyK,cACRpD,EACA2C,EACAA,EAAa7H,EAAKqI,KAAK,OACvBhB,EACArH,GAIAA,EAAKqI,KAAK,SAAYrK,EAAQgL,gBAChCnL,EAAUyK,cACRpD,EACA2C,EACAA,EAAa7H,EAAKqI,KAAK,OAC
vBhB,EACArH,IAnDF4E,QAAQC,IAAI,gDAjBZD,QAAQC,IAAI,uDA0EhBhH,EAAUsG,gBAAgBC,KAAK,OAAQ6E,KAAKpL,EAAU6K,iBCrSvD,SAAWhI,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsBM,GACnB,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EAAQC,QAAQ,gBAAiBA,QAAQ,sBAGzCD,EAAQF,OAAO3C,WATlB,CAWE,SAAUA,GACX,aAEAA,EAAUwI,QAAQzG,UAAUsJ,MAI1BC,IAAQ,aACRC,IAAQ,cACRC,MAAQ,iBACRC,MAAQ,oBACRC,MAAQ,6BACRC,IAAQ,gBACRC,IAAQ,cACRC,IAAQ,4BACRC,IAAQ,cACRC,IAAQ,kBACRC,IAAQ,sBACRC,IAAQ,mBACRC,IAAQ,mBACRC,IAAQ,cACRC,IAAQ,cACRC,IAAQ,iBACRC,IAAQ,eACRC,IAAQ,eACRC,IAAQ,kBACRC,IAAQ,wBACRC,IAAQ,8BACRC,IAAQ,mBACRC,IAAQ,aACRC,IAAQ,wBACRC,IAAQ,oBACRC,IAAQ,sBACRC,IAAQ,WACRC,IAAQ,mBACRC,IAAQ,OACRC,IAAQ,QACRC,IAAQ,WACRC,IAAQ,SACRC,MAAQ,YAIRC,MAAQ,cACRC,MAAQ,kBACRC,MAAQ,aACRC,MAAQ,kBACRC,MAAQ,kBACRC,MAAQ,QACRC,MAAQ,0BACRC,MAAQ,yBACRC,MAAQ,YACRC,MAAQ,cACRC,MAAQ,mBACRC,MAAQ,mBACRC,MAAQ,oBACRC,MAAQ,aACRC,MAAQ,qBACRC,MAAQ,sBACRC,MAAQ,eACRC,MAAQ,UACRC,MAAQ,kBACRC,MAAQ,sBACRC,MAAQ,0BACRC,MAAQ,OACRC,MAAQ,kBACRC,MAAQ,4BACRC,MAAQ,2BACRC,MAAQ,WACRC,MAAQ,sBACRC,MAAQ,sBACRC,MAAQ,oBACRC,MAAQ,gBACRC,MAAQ,kBACRC,MAAQ,eACRC,MAAQ,mBACRC,MAAQ,kBACRC,MAAQ,eACRC,MAAQ,cACRC,MAAQ,QACRC,MAAQ,cACRC,MAAQ,cACRC,MAAQ,cACRC,MAAQ,2BACRC,MAAQ,wBACRC,MAAQ,wBACRC,MAAQ,2BACRC,MAAQ,kBACRC,MAAQ,gBACRC,MAAQ,gBACRC,MAAQ,aACRC,MAAQ,YACRC,MAAQ,aACRC,MAAQ,iBACRC,MAAQ,eACRC,MAAQ,eACRC,MAAQ,mBACRC,MAAQ,wBACRC,MAAQ,mBACRC,MAAQ,cACRC,MAAQ,WACRC,MAAQ,aACRC,MAAQ,YACRC,MAAQ,2BACRC,MAAQ,uBACRC,MAAQ,gBACRC,MAAQ,kBACRC,MAAQ,mBACRC,MAAQ,oBACRC,MAAQ,WACRC,MAAQ,YACRC,MAAQ,mBAIRC,EAAQ,eACR7I,EAAQ,iBACRI,EAAQ,cACRI,EAAQ,kBACRE,EAAQ,eACRE,EAAQ,iBACRkI,EAAQ,cACRC,EAAQ,eACRC,EAAQ,gBACRnI,EAAQ,YACRE,GAAQ,iBACRkI,GAAQ,SACRC,GAAQ,cACRC,GAAQ,WACRC,GAAQ,cACRC,GAAQ,WACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,cACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,sBACRC,GAAQ,mBACRC,GAAQ,oBACRC,GAAQ,iBACRC,GAAQ,qBACRC,GAAQ,kBACRC,GAAQ,sBACRC,GAAQ,qBACRC,GAAQ,eACRC,GAAQ,kBACRC,GAAQ,wBAGVpT,EAAUwI,QAAQzG,UAAUsR,cAC1BC,iBACE1B,EAAG,YACH7I,EAAG,SACHI,EAAG,iBACHI,EAAG,oBACHE,EAAG,mBACHE,EAAG,mBACHkI,EAAG,iBACHC,EAAG,gBACHC,EAAG
,kBAELwB,cACE3B,EAAG,UACH7I,EAAG,UACHI,EAAG,wBACHI,EAAG,OACHE,EAAG,YACHE,EAAG,UACHkI,EAAG,UACH2B,IAAK,SAEPC,aACE7B,EAAG,UACH7I,EAAG,WACHI,EAAG,cACHI,EAAG,gCACHE,EAAG,QACHG,EAAG,eACHE,GAAI,iBACJkI,GAAI,QACJC,GAAI,wCACJC,GAAI,yCACJC,GAAI,0CACJC,GAAI,sCACJE,GAAI,mBACJC,GAAI,mBACJC,GAAI,mBACJC,GAAI,MACJC,GAAI,MACJC,GAAI,MACJC,GAAI,MACJC,GAAI,sBACJW,IAAK,SAEPE,OACE9B,EAAQ,qBACR7I,EAAQ,cACRY,EAAQ,mCACRmI,EAAQ,+BACRlI,EAAQ,qCACRsI,GAAQ,gEACRE,GAAQ,4DACRC,GAAQ,4CACRQ,GAAQ,gCACRC,GAAQ,yBACRI,GAAQ,oDACRE,GAAQ,gDACRO,GAAQ,oBACRC,GAAQ,sCACRC,GAAQ,iEACRC,GAAQ,6DACRC,GAAQ,6DACRC,GAAQ,wFACRC,GAAQ,oFACRC,GAAQ,iDACRC,GAAQ,4EACRC,GAAQ,yEAEVC,eACEtL,EAAG,YACHI,EAAG,6BACHI,EAAG,6BACHE,EAAG,+BACHE,EAAG,+BACHmI,EAAG,mBACHC,EAAG,kCAELuC,kBACE1C,EAAG,WACH7I,EAAG,YACHI,EAAG,WACHI,EAAG,eAELgL,WACExL,EAAG,yBAELyL,gBACE5C,EAAG,iBACH7I,EAAG,kBAEL0L,cACE7C,EAAG,qBACH7I,EAAG,wBAEL2L,aACE9C,EAAG,OACH7I,EAAG,cACHI,EAAG,eACHI,EAAG,gBACHE,EAAG,kBAELkL,UACE/C,EAAG,SACH7I,EAAG,OACHI,EAAG,QAELyL,YACEhD,EAAG,SACH7I,EAAG,iBACHI,EAAG,mBAEL0L,WACEjD,EAAG,SACH7I,EAAG,OACHI,EAAG,QAEL2L,sBACElD,EAAG,UACH7I,EAAG,QACHI,EAAG,aACHI,EAAG,gBAELwL,YACExL,EAAG,OAELyL,yBACEpD,EAAG,GACH7I,EAAG,IACHI,EAAG,KACHI,EAAG,KACHE,EAAG,IACHE,EAAG,IACHkI,EAAG,KAELnJ,aACEK,EAAG,WACHI,EAAG,YACHI,EAAG,eACHE,EAAG,cACHE,EAAG,WACHkI,EAAG,YACHC,EAAG,eACHC,EAAG,gBAIP/R,EAAUwI,QAAQzG,UAAUkT,QAAU,SAAUrM,GAC9C,IAAIsM,EAAQtS,KAAK+F,IAAIC,GACrB,OAAQA,GACN,IAAK,cACL,IAAK,QACL,IAAK,eACL,IAAK,kBACL,IAAK,gBACL,IAAK,mBACL,IAAK,YACL,IAAK,iBACL,IAAK,eACL,IAAK,cACL,IAAK,WACL,IAAK,aACL,IAAK,YACL,IAAK,uBACL,IAAK,aACL,IAAK,cACH,OAAOhG,KAAKyQ,aAAazK,GAAIsM,GAC/B,IAAK,cACL,IAAK,kBACH,IAAKA,EAAO,OACZ,OAAO9L,OAAOC,aAAa6L,EAAM,GAAIA,EAAM,GAAIA,EAAM,GAAIA,EAAM,IACjE,IAAK,0BACH,IAAKA,EAAO,OACZ,OACEtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAC5BtS,KAAKyQ,aAAazK,GAAIsM,EAAM,IAEhC,IAAK,eACH,IAAKA,EAAO,OACZ,OAAOA,EAAM,GAAK,IAAMA,EAAM,GAAK,IAAMA,EAAM,GAAK,IAAMA,EAAM,GAEpE,OAAO9L,OAAO8L,IAEf,SAAWC,GACV,IAEIC,EAFA
/J,EAAO8J,EAAiB9J,KACxB5C,EAAM0M,EAAiB1M,IAG3B,IAAK2M,KAAQ/J,EACPA,EAAK7H,eAAe4R,KACtB3M,EAAI4C,EAAK+J,IAASA,GAPvB,CAUEpV,EAAUwI,QAAQzG,WAErB/B,EAAUwI,QAAQzG,UAAUsT,OAAS,WACnC,IACID,EACAxM,EAFAH,KAGJ,IAAK2M,KAAQxS,KACPA,KAAKY,eAAe4R,KACtBxM,EAAKhG,KAAKyI,KAAK+J,MAEb3M,EAAIG,GAAMhG,KAAKqS,QAAQrM,IAI7B,OAAOH,KCpXV,SAAW5F,GACV,aACsB,mBAAXN,QAAyBA,OAAOC,IAEzCD,QAAQ,eAAgB,qBAAsB,qBAAsBM,GACzC,iBAAXJ,QAAuBA,OAAOC,QAC9CG,EACEC,QAAQ,gBACRA,QAAQ,sBACRA,QAAQ,sBAIVD,EAAQF,OAAO3C,WAblB,CAeE,SAAUA,GACX,aAEA,IAAIsV,EAA0BtV,EAAU0E,gBACpC6Q,EAAwBvV,EAAUgI,cAClCwN,EAA+BxV,EAAUiD,qBACzCwS,EAAgCzV,EAAUkD,sBAG9ClD,EAAU0E,gBAAkB,SAAUvE,GACpC,QACIA,EAAQuV,aAAeJ,EAAwBrT,KAAKjC,EAAWG,IAKrEH,EAAUgI,cAAgB,SAAU7H,GAClC,OACGA,IAAmC,IAAxBA,EAAQuV,aACpBH,EAAsBtT,KAAKjC,EAAWG,IAM1CH,EAAUiD,qBAAuB,SAAUc,EAAQ5D,GACjDqV,EAA6BvT,KAAKjC,EAAW+D,EAAQ5D,GACrD,IAAIwV,EAAM5R,EAAOS,WAAW,MACxBnB,EAAQU,EAAOV,MACfC,EAASS,EAAOT,OAChBsS,EAAa7R,EAAO8B,MAAMxC,MAC1BwS,EAAc9R,EAAO8B,MAAMvC,OAC3BoS,EAAcvV,EAAQuV,YAC1B,GAAKA,KAAeA,EAAc,GASlC,OANIA,EAAc,IAChB3R,EAAOV,MAAQC,EACfS,EAAOT,OAASD,EAChBU,EAAO8B,MAAMxC,MAAQwS,EACrB9R,EAAO8B,MAAMvC,OAASsS,GAEhBF,GACN,KAAK,EAEHC,EAAIG,UAAUzS,EAAO,GACrBsS,EAAI3S,OAAO,EAAG,GACd,MACF,KAAK,EAEH2S,EAAIG,UAAUzS,EAAOC,GACrBqS,EAAII,OAAOnR,KAAKoR,IAChB,MACF,KAAK,EAEHL,EAAIG,UAAU,EAAGxS,GACjBqS,EAAI3S,MAAM,GAAI,GACd,MACF,KAAK,EAEH2S,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAI3S,MAAM,GAAI,GACd,MACF,KAAK,EAEH2S,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAIG,UAAU,GAAIxS,GAClB,MACF,KAAK,EAEHqS,EAAII,OAAO,GAAMnR,KAAKoR,IACtBL,EAAIG,UAAUzS,GAAQC,GACtBqS,EAAI3S,OAAO,EAAG,GACd,MACF,KAAK,EAEH2S,EAAII,QAAQ,GAAMnR,KAAKoR,IACvBL,EAAIG,WAAWzS,EAAO,KAO5BrD,EAAUkD,sBAAwB,SAAU7C,EAAK4V,EAAM9T,GACrD,IAEIgB,EACAC,EAHAjD,EAAUsV,EAA8BxT,KAAKjC,EAAWK,EAAK4V,GAC7DP,EAAcvV,EAAQuV,YAM1B,IAHoB,IAAhBA,GAAwBvT,GAAQA,EAAKqI,OACvCkL,EAAcvT,EAAKqI,KAAK7B,IAAI,iBAEzB+M,GAAeA,EAAc,GAAqB,IAAhBA,EACrC,OAAOvV,EAETgD,KACA,IAAKC,KAAKjD,EACJA,EAAQqD,eAAeJ,KACzBD,EAAWC,GAAKjD,EAAQiD,IAI5B,OADAD,EAAWuS,YAAcA,EACjBA,GACN,KAAK,EAEHvS,EAAWmC,KAA
OnF,EAAQsF,MAC1BtC,EAAWsC,MAAQtF,EAAQmF,KAC3B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQsF,MAC1BtC,EAAWoC,IAAMpF,EAAQuF,OACzBvC,EAAWsC,MAAQtF,EAAQmF,KAC3BnC,EAAWuC,OAASvF,EAAQoF,IAC5B,MACF,KAAK,EAEHpC,EAAWoC,IAAMpF,EAAQuF,OACzBvC,EAAWuC,OAASvF,EAAQoF,IAC5B,MACF,KAAK,EAEHpC,EAAWmC,KAAOnF,EAAQoF,IAC1BpC,EAAWoC,IAAMpF,EAAQmF,KACzBnC,EAAWsC,MAAQtF,EAAQuF,OAC3BvC,EAAWuC,OAASvF,EAAQsF,MAC5B,MACF,KAAK,EAEHtC,EAAWmC,KAAOnF,EAAQoF,IAC1BpC,EAAWoC,IAAMpF,EAAQsF,MACzBtC,EAAWsC,MAAQtF,EAAQuF,OAC3BvC,EAAWuC,OAASvF,EAAQmF,KAC5B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQuF,OAC1BvC,EAAWoC,IAAMpF,EAAQsF,MACzBtC,EAAWsC,MAAQtF,EAAQoF,IAC3BpC,EAAWuC,OAASvF,EAAQmF,KAC5B,MACF,KAAK,EAEHnC,EAAWmC,KAAOnF,EAAQuF,OAC1BvC,EAAWoC,IAAMpF,EAAQmF,KACzBnC,EAAWsC,MAAQtF,EAAQoF,IAC3BpC,EAAWuC,OAASvF,EAAQsF,MAWhC,OARItC,EAAWuS,YAAc,IAC3BvS,EAAWS,SAAWzD,EAAQ0D,UAC9BV,EAAWU,UAAY1D,EAAQyD,SAC/BT,EAAW2B,SAAW3E,EAAQ4E,UAC9B5B,EAAW4B,UAAY5E,EAAQ2E,SAC/B3B,EAAWe,YAAc/D,EAAQgE,aACjChB,EAAWgB,aAAehE,EAAQ+D,aAE7Bf"}
2 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import tfops as Z
4 | import optim
5 | import numpy as np
6 | import horovod.tensorflow as hvd
7 | from tensorflow.contrib.framework.python.ops import add_arg_scope
8 |
9 |
10 | '''
f_loss: function called as f_loss(iterator, is_training, reuse=False) that returns a tuple (loss, stats) whose first element is the loss.
12 | '''
13 |
14 |
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):
    """Wrap a loss function into a handle with train/test/save/restore calls.

    Builds the training and test graphs through ``f_loss``, creates the
    optimizer named by ``hps.optimizer``, sets up two savers and finally
    either restores the variables or runs the initialization pass.

    Args:
        sess: session on which every ``run`` call is issued.
        hps: hyper-parameter object; this function reads
            ``gradient_checkpointing``, ``optimizer``, ``direct_iterator``
            and ``restore_path`` and forwards ``hps`` to the optimizer.
        feeds: mapping with keys ``'x'`` and ``'y'`` whose values are used
            as feed_dict keys.
        train_iterator: forwarded to ``f_loss``; when
            ``hps.direct_iterator`` is falsy it is also called with no
            arguments to obtain an ``(x, y)`` batch.
        test_iterator: same role as ``train_iterator`` for evaluation.
        data_init: mapping with keys ``'x'`` and ``'y'`` fed during the
            initialization pass.
        lr: feed_dict key for the learning rate; also passed to the
            optimizer as ``alpha``.
        f_loss: callable invoked as ``f_loss(iterator, is_training, reuse=...)``
            whose result is unpacked as ``(loss, stats)``.

    Returns:
        The class object ``m`` carrying ``sess``, ``feeds``, ``lr``,
        ``train``, ``polyak_swap``, ``test``, ``save_ema``, ``save`` and
        ``restore``.
    """

    # A bare class is used as a namespace for the returned handles.
    class m(object):
        pass

    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        # Imported lazily so the module is only needed when the option is on.
        from memory_saving_gradients import gradients as compute_gradients
    else:
        compute_gradients = tf.gradients
    gs = compute_gradients(loss_train, all_params)

    optimizers = {
        'adam': optim.adam,
        'adamax': optim.adamax,
        'adam2': optim.adam2,
    }
    train_op, polyak_swap_op, ema = optimizers[hps.optimizer](
        all_params, gs, alpha=lr, hps=hps)

    if hps.direct_iterator:
        def _train(_lr):
            # Data comes from the graph itself; only the learning rate is fed.
            fetched = sess.run([train_op, stats_train], {lr: _lr})
            return fetched[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            feed_dict = {feeds['x']: _x, feeds['y']: _y, lr: _lr}
            return sess.run([train_op, stats_train], feed_dict)[1]
    m.train = _train

    def _polyak_swap():
        return sess.run(polyak_swap_op)
    m.polyak_swap = _polyak_swap

    # === Testing
    # Only the stats of the test graph are used; its loss is discarded.
    _, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        def _test():
            return sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x, feeds['y']: _y})
    m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())

    def _save_ema(path):
        return saver_ema.save(sess, path, write_meta_graph=False)

    def _save(path):
        return saver.save(sess, path, write_meta_graph=False)

    def _restore(path):
        return saver.restore(sess, path)

    m.save_ema = _save_ema
    m.save = _save
    m.restore = _restore

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        # Build the graph once more with init=True for the listed ops, then
        # run the global initializer followed by that init pass on data_init.
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        init_feed = {feeds['x']: data_init['x'], feeds['y']: data_init['y']}
        sess.run(results_init, init_feed)
    sess.run(hvd.broadcast_global_variables(0))

    return m
80 |
81 |
def codec(hps):
    """Return the multi-level ``(encoder, decoder)`` pair for ``hps``.

    Both closures iterate over ``hps.n_levels`` levels. Each level runs
    ``revnet2d``; every level except the last is followed (encoder) or
    preceded (decoder) by a ``split2d`` / ``split2d_reverse`` call.

    Args:
        hps: hyper-parameter object; ``n_levels`` is read here and ``hps``
            is forwarded to ``revnet2d``.

    Returns:
        Tuple ``(encoder, decoder)``:
          * ``encoder(z, objective) -> (z, objective, eps)`` where ``eps``
            is the list of values returned by each ``split2d`` call.
          * ``decoder(z, eps=None, eps_std=None) -> z``. ``eps`` is indexed
            by level for levels ``0 .. n_levels-2``; ``None`` (the default)
            stands for "no eps for any level".
    """

    def encoder(z, objective):
        eps = []
        for i in range(hps.n_levels):
            z, objective = revnet2d(str(i), z, objective, hps)
            if i < hps.n_levels - 1:
                z, objective, _eps = split2d(
                    "pool" + str(i), z, objective=objective)
                eps.append(_eps)
        return z, objective, eps

    def decoder(z, eps=None, eps_std=None):
        # A None sentinel replaces the former list default, which was a
        # single shared object built once when codec() was called.
        if eps is None:
            eps = [None] * hps.n_levels
        for i in reversed(range(hps.n_levels)):
            if i < hps.n_levels - 1:
                z = split2d_reverse(
                    "pool" + str(i), z, eps=eps[i], eps_std=eps_std)
            z, _ = revnet2d(str(i), z, 0, hps, reverse=True)

        return z

    return encoder, decoder
102 |
103 |
def prior(name, y_onehot, hps):
    """Build the top-level prior and return its accessor functions.

    Inside variable scope ``name`` a zero tensor with ``2 * n_z`` channels
    is created (``n_z = hps.top_shape[-1]``), optionally passed through
    ``Z.conv2d_zeros`` (``hps.learntop``) and optionally shifted by a
    ``Z.linear_zeros`` projection of ``y_onehot`` (``hps.ycond``). The
    first ``n_z`` channels and the remaining channels are handed to
    ``Z.gaussian_diag``.

    Args:
        name: variable scope name.
        y_onehot: tensor whose leading dimension gives the batch size.
        hps: hyper-parameter object; reads ``top_shape``, ``learntop`` and
            ``ycond``. ``top_shape`` is concatenated with Python lists, so
            it is presumably a list -- confirm against the caller.

    Returns:
        Tuple ``(logp, sample, eps)`` of closures over the distribution.
    """
    with tf.variable_scope(name):
        n_z = hps.top_shape[-1]
        batch_size = tf.shape(y_onehot)[0]

        h = tf.zeros([batch_size] + hps.top_shape[:2] + [2 * n_z])
        if hps.learntop:
            h = Z.conv2d_zeros('p', h, 2 * n_z)
        if hps.ycond:
            y_embedding = Z.linear_zeros("y_emb", y_onehot, 2 * n_z)
            h += tf.reshape(y_embedding, [-1, 1, 1, 2 * n_z])

        first_half = h[:, :, :, :n_z]
        second_half = h[:, :, :, n_z:]
        pz = Z.gaussian_diag(first_half, second_half)

    def logp(z1):
        return pz.logp(z1)

    def sample(eps=None, eps_std=None):
        if eps is not None:
            # Already sampled eps. Don't use eps_std
            return pz.sample2(eps)
        if eps_std is not None:
            # Sample with given eps_std
            scaled_eps = pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1])
            return pz.sample2(scaled_eps)
        # Sample normally
        return pz.sample

    def eps(z1):
        return pz.get_eps(z1)

    return logp, sample, eps
139 |
140 |
def model(sess, hps, train_iterator, test_iterator, data_init):
    """Build the complete model graph and return the model handle.

    Creates the input placeholders, the loss graph (through
    ``abstract_model_xy``), a sampling function and, when
    ``hps.inference`` is truthy, encode/decode functions.

    Args:
        sess: session used for all ``run`` calls.
        hps: hyper-parameter object. Reads ``image_size``, ``n_bits_x``,
            ``n_y``, ``weight_y``, ``ycond``, ``direct_iterator`` and
            ``inference``; writes ``n_bins`` and ``top_shape``.
        train_iterator: forwarded to ``abstract_model_xy``; when
            ``hps.direct_iterator`` is truthy its ``get_next()`` supplies
            ``(x, y)`` inside the graph.
        test_iterator: same as ``train_iterator`` for evaluation.
        data_init: forwarded to ``abstract_model_xy``.

    Returns:
        The handle returned by ``abstract_model_xy`` extended with
        ``eps_std``, ``sample`` and (if ``hps.inference``) ``encode`` and
        ``decode``.
    """

    # Only for decoding/init, rest use iterators directly
    with tf.name_scope('input'):
        X = tf.placeholder(
            tf.uint8, [None, hps.image_size, hps.image_size, 3], name='image')
        Y = tf.placeholder(tf.int32, [None], name='label')
        lr = tf.placeholder(tf.float32, None, name='learning_rate')

    encoder, decoder = codec(hps)
    hps.n_bins = 2. ** hps.n_bits_x

    def preprocess(x):
        # Cast to float, optionally reduce to n_bits_x bits, then rescale by
        # n_bins and shift by -0.5.
        x = tf.cast(x, 'float32')
        if hps.n_bits_x < 8:
            x = tf.floor(x / 2 ** (8 - hps.n_bits_x))
        x = x / hps.n_bins - .5
        return x

    def postprocess(x):
        # Inverse of preprocess, clipped to [0, 255] and cast to uint8.
        return tf.cast(tf.clip_by_value(tf.floor((x + .5)*hps.n_bins)*(256./hps.n_bins), 0, 255), 'uint8')

    def _f_loss(x, y, is_training, reuse=False):
        # NOTE: is_training is accepted but not read in this function.

        with tf.variable_scope('model', reuse=reuse):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            # Discrete -> Continuous
            objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
            z = preprocess(x)
            z = z + tf.random_uniform(tf.shape(z), 0, 1./hps.n_bins)
            objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

            # Encode
            z = Z.squeeze2d(z, 2)  # > 16x16x12
            z, objective, _ = encoder(z, objective)

            # Prior
            hps.top_shape = Z.int_shape(z)[1:]
            logp, _, _ = prior("prior", y_onehot, hps)
            objective += logp(z)

            # Generative loss
            nobj = - objective
            bits_x = nobj / (np.log(2.) * int(x.get_shape()[1]) * int(
                x.get_shape()[2]) * int(x.get_shape()[3]))  # bits per subpixel

            # Predictive loss
            if hps.weight_y > 0 and hps.ycond:

                # Classification loss
                h_y = tf.reduce_mean(z, axis=[1, 2])
                y_logits = Z.linear_zeros("classifier", h_y, hps.n_y)
                bits_y = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=y_onehot, logits=y_logits) / np.log(2.)

                # Classification accuracy
                y_predicted = tf.argmax(y_logits, 1, output_type=tf.int32)
                classification_error = 1 - \
                    tf.cast(tf.equal(y_predicted, y), tf.float32)
            else:
                bits_y = tf.zeros_like(bits_x)
                classification_error = tf.ones_like(bits_x)

        return bits_x, bits_y, classification_error

    def f_loss(iterator, is_training, reuse=False):
        # Read the batch from the iterator when running in direct-iterator
        # mode; otherwise (or when iterator is None) use the placeholders.
        if hps.direct_iterator and iterator is not None:
            x, y = iterator.get_next()
        else:
            x, y = X, Y

        bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
        local_loss = bits_x + hps.weight_y * bits_y
        stats = [local_loss, bits_x, bits_y, pred_loss]
        global_stats = Z.allreduce_mean(
            tf.stack([tf.reduce_mean(i) for i in stats]))

        return tf.reduce_mean(local_loss), global_stats

    feeds = {'x': X, 'y': Y}
    m = abstract_model_xy(sess, hps, feeds, train_iterator,
                          test_iterator, data_init, lr, f_loss)

    # === Sampling function
    def f_sample(y, eps_std):
        with tf.variable_scope('model', reuse=True):
            y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

            _, sample, _ = prior("prior", y_onehot, hps)
            z = sample(eps_std=eps_std)
            z = decoder(z, eps_std=eps_std)
            z = Z.unsqueeze2d(z, 2)  # 8x8x12 -> 16x16x3
            x = postprocess(z)

        return x

    m.eps_std = tf.placeholder(tf.float32, [None], name='eps_std')
    x_sampled = f_sample(Y, m.eps_std)

    def sample(_y, _eps_std):
        return m.sess.run(x_sampled, {Y: _y, m.eps_std: _eps_std})
    m.sample = sample

    if hps.inference:
        # === Encoder-Decoder functions
        def f_encode(x, y, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                # Discrete -> Continuous
                objective = tf.zeros_like(x, dtype='float32')[:, 0, 0, 0]
                z = preprocess(x)
                z = z + tf.random_uniform(tf.shape(z), 0, 1. / hps.n_bins)
                objective += - np.log(hps.n_bins) * np.prod(Z.int_shape(z)[1:])

                # Encode
                z = Z.squeeze2d(z, 2)  # > 16x16x12
                z, objective, eps = encoder(z, objective)

                # Prior
                hps.top_shape = Z.int_shape(z)[1:]
                logp, _, _eps = prior("prior", y_onehot, hps)
                objective += logp(z)
                eps.append(_eps(z))

            return eps

        def f_decode(y, eps, reuse=True):
            with tf.variable_scope('model', reuse=reuse):
                y_onehot = tf.cast(tf.one_hot(y, hps.n_y, 1, 0), 'float32')

                # The last eps entry belongs to the top prior, the others to
                # the per-level splits.
                _, sample, _ = prior("prior", y_onehot, hps)
                z = sample(eps=eps[-1])
                z = decoder(z, eps=eps[:-1])
                z = Z.unsqueeze2d(z, 2)  # 8x8x12 -> 16x16x3
                x = postprocess(z)

            return x

        enc_eps = f_encode(X, Y)
        # One float32 placeholder per encoder eps tensor, with the same
        # static shape. (The former debug print() calls were removed.)
        dec_eps = [
            tf.placeholder(tf.float32, _eps.get_shape().as_list(),
                           name="dec_eps_" + str(i))
            for i, _eps in enumerate(enc_eps)
        ]
        dec_x = f_decode(Y, dec_eps)

        eps_shapes = [_eps.get_shape().as_list()[1:] for _eps in enc_eps]

        def flatten_eps(eps):
            # [BS, eps_size]
            return np.concatenate([np.reshape(e, (e.shape[0], -1)) for e in eps], axis=-1)

        def unflatten_eps(feps):
            # Inverse of flatten_eps: cut the flat vector back into pieces
            # of the per-tensor shapes recorded in eps_shapes.
            index = 0
            eps = []
            bs = feps.shape[0]
            for shape in eps_shapes:
                size = int(np.prod(shape))
                eps.append(np.reshape(feps[:, index: index + size], (bs, *shape)))
                index += size
            return eps

        # If model is unconditional, always pass y = np.zeros([bs], dtype=np.int32)
        def encode(x, y):
            return flatten_eps(sess.run(enc_eps, {X: x, Y: y}))

        def decode(y, feps):
            feed_dict = {Y: y}
            feed_dict.update(zip(dec_eps, unflatten_eps(feps)))
            return sess.run(dec_x, feed_dict)

        m.encode = encode
        m.decode = decode

    return m
319 |
320 |
def checkpoint(z, logdet):
    """Register ``(z, logdet)`` as one tensor in the 'checkpoints' collection.

    ``z`` is flattened per example, ``logdet`` is appended as one extra
    column, and the concatenated tensor is added to the graph collection
    ``'checkpoints'``. Both values are then sliced back out of that tensor
    and returned, so downstream ops depend on the registered tensor.

    Args:
        z: tensor indexed here with four dimensions.
        logdet: tensor reshaped to one column per example.

    Returns:
        Tuple ``(z, logdet)`` rebuilt from the combined tensor.
    """
    dims = Z.int_shape(z)
    flat_z = tf.reshape(z, [-1, dims[1] * dims[2] * dims[3]])
    logdet_column = tf.reshape(logdet, [-1, 1])
    packed = tf.concat([flat_z, logdet_column], axis=1)
    tf.add_to_collection('checkpoints', packed)
    logdet_out = packed[:, -1]
    z_out = tf.reshape(packed[:, :-1], [-1, dims[1], dims[2], dims[3]])
    return z_out, logdet_out
330 |
331 |
@add_arg_scope
def revnet2d(name, z, logdet, hps, reverse=False):
    """Run ``hps.depth`` flow steps under variable scope ``name``.

    Forward: each step is preceded by ``checkpoint`` and one more
    ``checkpoint`` follows the last step. Reverse: the steps are applied in
    descending index order, without checkpoints.

    Args:
        name: variable scope name.
        z: input tensor.
        logdet: running log-determinant, updated by every step.
        hps: hyper-parameter object; ``depth`` is read here and ``hps`` is
            forwarded to ``revnet2d_step``.
        reverse: truthy to run the inverse direction.

    Returns:
        Tuple ``(z, logdet)``.
    """
    with tf.variable_scope(name):
        if reverse:
            for step in reversed(range(hps.depth)):
                z, logdet = revnet2d_step(str(step), z, logdet, hps, reverse)
        else:
            for step in range(hps.depth):
                z, logdet = checkpoint(z, logdet)
                z, logdet = revnet2d_step(str(step), z, logdet, hps, reverse)
            z, logdet = checkpoint(z, logdet)
    return z, logdet
344 |
345 | # Simpler, new version
@add_arg_scope
def revnet2d_step(name, z, logdet, hps, reverse):
    """Apply one flow step, or its inverse, under variable scope ``name``.

    Forward order: actnorm -> channel permutation -> coupling.
    Reverse order: inverse coupling -> inverse permutation -> inverse actnorm.

    Args:
        name: variable scope name.
        z: tensor whose 4th dimension (channels) must be even.
        logdet: running log-determinant.
        hps: hyper-parameter object. Reads:
            ``flow_permutation``: 0 = ``Z.reverse_features``,
            1 = ``Z.shuffle_features``, 2 = ``invertible_1x1_conv``;
            ``flow_coupling``: 0 = additive (shift only),
            1 = shift followed by a sigmoid-derived scale;
            ``width``: forwarded to ``f``.
        reverse: truthy to apply the inverse transformation.

    Returns:
        Tuple ``(z, logdet)``.

    Raises:
        ValueError: if ``hps.flow_permutation`` is not 0, 1 or 2, or
            ``hps.flow_coupling`` is not 0 or 1.
    """
    with tf.variable_scope(name):

        shape = Z.int_shape(z)
        n_z = shape[3]
        assert n_z % 2 == 0, \
            "revnet2d_step needs an even number of channels, got %r" % (n_z,)

        if not reverse:

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv("invconv", z, logdet)
            else:
                raise ValueError(
                    "unsupported hps.flow_permutation %r (expected 0, 1 or 2)"
                    % (hps.flow_permutation,))

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 += f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                # Even output channels are the shift, odd ones the scale input.
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                z2 += shift
                z2 *= scale
                logdet += tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "unsupported hps.flow_coupling %r (expected 0 or 1)"
                    % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

        else:

            z1 = z[:, :, :, :n_z // 2]
            z2 = z[:, :, :, n_z // 2:]

            if hps.flow_coupling == 0:
                z2 -= f("f1", z1, hps.width)
            elif hps.flow_coupling == 1:
                h = f("f1", z1, hps.width, n_z)
                shift = h[:, :, :, 0::2]
                # scale = tf.exp(h[:, :, :, 1::2])
                scale = tf.nn.sigmoid(h[:, :, :, 1::2] + 2.)
                # Undo the forward pass in the opposite order: scale, then shift.
                z2 /= scale
                z2 -= shift
                logdet -= tf.reduce_sum(tf.log(scale), axis=[1, 2, 3])
            else:
                raise ValueError(
                    "unsupported hps.flow_coupling %r (expected 0 or 1)"
                    % (hps.flow_coupling,))

            z = tf.concat([z1, z2], 3)

            if hps.flow_permutation == 0:
                z = Z.reverse_features("reverse", z, reverse=True)
            elif hps.flow_permutation == 1:
                z = Z.shuffle_features("shuffle", z, reverse=True)
            elif hps.flow_permutation == 2:
                z, logdet = invertible_1x1_conv(
                    "invconv", z, logdet, reverse=True)
            else:
                raise ValueError(
                    "unsupported hps.flow_permutation %r (expected 0, 1 or 2)"
                    % (hps.flow_permutation,))

            z, logdet = Z.actnorm("actnorm", z, logdet=logdet, reverse=True)

    return z, logdet
418 |
419 |
def f(name, h, width, n_out=None):
    """Three-layer conv block used by the coupling layers.

    Under variable scope ``name``: ``Z.conv2d`` + ReLU, a second
    ``Z.conv2d`` with ``filter_size=[1, 1]`` + ReLU, then
    ``Z.conv2d_zeros`` producing ``n_out`` channels.

    Args:
        name: variable scope name.
        h: input tensor; its 4th dimension is the default for ``n_out``.
        width: passed as the output size of the first two convolutions.
        n_out: channels of the last layer; a falsy value means "same as
            the input's 4th dimension".

    Returns:
        The output tensor of the last layer.
    """
    if not n_out:
        n_out = int(h.get_shape()[3])
    with tf.variable_scope(name):
        hidden = tf.nn.relu(Z.conv2d("l_1", h, width))
        hidden = Z.conv2d("l_2", hidden, width, filter_size=[1, 1])
        hidden = tf.nn.relu(hidden)
        out = Z.conv2d_zeros("l_last", hidden, n_out)
    return out
427 |
428 |
def f_resnet(name, h, width, n_out=None):
    """Two-layer conv block: ``Z.conv2d`` + ReLU, then ``Z.conv2d_zeros``.

    Args:
        name: variable scope name.
        h: input tensor; its 4th dimension is the default for ``n_out``.
        width: passed as the output size of the first convolution.
        n_out: channels of the last layer; a falsy value means "same as
            the input's 4th dimension".

    Returns:
        The output tensor of the last layer.
    """
    if not n_out:
        n_out = int(h.get_shape()[3])
    with tf.variable_scope(name):
        hidden = tf.nn.relu(Z.conv2d("l_1", h, width))
        out = Z.conv2d_zeros("l_2", hidden, n_out)
    return out
435 |
436 | # Invertible 1x1 conv
@add_arg_scope
def invertible_1x1_conv(name, z, logdet, reverse=False):
    """Apply a learned invertible 1x1 convolution (or its inverse) to ``z``.

    The weight is a square matrix of size ``channels x channels``
    (``channels = Z.int_shape(z)[3]``), initialised from the Q factor of a
    QR decomposition of a random normal matrix. The log-determinant term
    added to (forward) or subtracted from (reverse) ``logdet`` is
    ``log|det W| * shape[1] * shape[2]``.

    Two implementations exist, selected by the hard-coded ``if True:``
    switch below: the direct one (active) and an LU-decomposed one.

    Args:
        name: variable scope name.
        z: input tensor, convolved with ``data_format='NHWC'``.
        logdet: running log-determinant.
        reverse: truthy to apply the inverse convolution.

    Returns:
        Tuple ``(z, logdet)``.
    """

    if True:  # Set to "False" to use the LU-decomposed version

        with tf.variable_scope(name):

            shape = Z.int_shape(z)
            w_shape = [shape[3], shape[3]]

            # Sample a random orthogonal matrix:
            w_init = np.linalg.qr(np.random.randn(
                *w_shape))[0].astype('float32')

            w = tf.get_variable("W", dtype=tf.float32, initializer=w_init)

            # dlogdet = tf.linalg.LinearOperator(w).log_abs_determinant() * shape[1]*shape[2]
            # The determinant is evaluated in float64 and cast back.
            dlogdet = tf.cast(tf.log(abs(tf.matrix_determinant(
                tf.cast(w, 'float64')))), 'float32') * shape[1]*shape[2]

            if not reverse:

                _w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += dlogdet

                return z, logdet
            else:

                _w = tf.matrix_inverse(w)
                _w = tf.reshape(_w, [1, 1]+w_shape)
                z = tf.nn.conv2d(z, _w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet -= dlogdet

                return z, logdet

    else:

        # LU-decomposed version
        shape = Z.int_shape(z)
        with tf.variable_scope(name):

            dtype = 'float64'

            # Random orthogonal matrix:
            # Import the submodule explicitly: a bare `import scipy` does
            # not reliably make `scipy.linalg` available as an attribute.
            import scipy.linalg
            np_w = scipy.linalg.qr(np.random.randn(shape[3], shape[3]))[
                0].astype('float32')

            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(abs(np_s))
            np_u = np.triu(np_u, k=1)

            # P and sign(S) are stored as non-trainable variables.
            p = tf.get_variable("P", initializer=np_p, trainable=False)
            l = tf.get_variable("L", initializer=np_l)
            sign_s = tf.get_variable(
                "sign_S", initializer=np_sign_s, trainable=False)
            log_s = tf.get_variable("log_S", initializer=np_log_s)
            # S = tf.get_variable("S", initializer=np_s)
            u = tf.get_variable("U", initializer=np_u)

            p = tf.cast(p, dtype)
            l = tf.cast(l, dtype)
            sign_s = tf.cast(sign_s, dtype)
            log_s = tf.cast(log_s, dtype)
            u = tf.cast(u, dtype)

            w_shape = [shape[3], shape[3]]

            # Mask L to strictly-lower plus identity, U to strictly-upper
            # plus the diagonal sign_s * exp(log_s).
            l_mask = np.tril(np.ones(w_shape, dtype=dtype), -1)
            l = l * l_mask + tf.eye(*w_shape, dtype=dtype)
            u = u * np.transpose(l_mask) + tf.diag(sign_s * tf.exp(log_s))
            w = tf.matmul(p, tf.matmul(l, u))

            if True:
                u_inv = tf.matrix_inverse(u)
                l_inv = tf.matrix_inverse(l)
                p_inv = tf.matrix_inverse(p)
                w_inv = tf.matmul(u_inv, tf.matmul(l_inv, p_inv))
            else:
                w_inv = tf.matrix_inverse(w)

            w = tf.cast(w, tf.float32)
            w_inv = tf.cast(w_inv, tf.float32)
            log_s = tf.cast(log_s, tf.float32)

            if not reverse:

                w = tf.reshape(w, [1, 1] + w_shape)
                z = tf.nn.conv2d(z, w, [1, 1, 1, 1],
                                 'SAME', data_format='NHWC')
                logdet += tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
            else:

                w_inv = tf.reshape(w_inv, [1, 1]+w_shape)
                z = tf.nn.conv2d(
                    z, w_inv, [1, 1, 1, 1], 'SAME', data_format='NHWC')
                logdet -= tf.reduce_sum(log_s) * (shape[1]*shape[2])

                return z, logdet
543 |
544 |
@add_arg_scope
def split2d(name, z, objective=0.):
    """Split ``z`` along channels and score the second half.

    The first half of the channels parameterises a distribution via
    ``split2d_prior``; the log-probability of the second half under it is
    added to ``objective``. The first half is then passed through
    ``Z.squeeze2d`` and returned together with ``pz.get_eps`` of the
    second half.

    Args:
        name: variable scope name.
        z: tensor indexed with four dimensions.
        objective: running objective the log-probability is added to.

    Returns:
        Tuple ``(z1, objective, eps)``.
    """
    with tf.variable_scope(name):
        half = Z.int_shape(z)[3] // 2
        kept = z[:, :, :, :half]
        factored_out = z[:, :, :, half:]
        pz = split2d_prior(kept)
        objective += pz.logp(factored_out)
        kept = Z.squeeze2d(kept)
        eps = pz.get_eps(factored_out)
        return kept, objective, eps
556 |
557 |
@add_arg_scope
def split2d_reverse(name, z, eps, eps_std):
    """Inverse of ``split2d``: regenerate the split-off half and concatenate.

    ``z`` is passed through ``Z.unsqueeze2d`` and used to build the
    distribution via ``split2d_prior``. The missing half is obtained from
    that distribution in one of three ways, checked in this order:

      1. ``eps`` given: ``pz.sample2(eps)``;
      2. ``eps_std`` given: ``pz.sample2`` of ``pz.eps`` scaled per example;
      3. otherwise: ``pz.sample``.

    Args:
        name: variable scope name.
        z: input tensor.
        eps: optional value passed to ``pz.sample2``.
        eps_std: optional per-example scale, reshaped to ``[-1, 1, 1, 1]``.

    Returns:
        The two halves concatenated along axis 3.
    """
    with tf.variable_scope(name):
        known_half = Z.unsqueeze2d(z)
        pz = split2d_prior(known_half)
        if eps is not None:
            # Already sampled eps
            sampled_half = pz.sample2(eps)
        elif eps_std is not None:
            # Sample with given eps_std
            scaled_eps = pz.eps * tf.reshape(eps_std, [-1, 1, 1, 1])
            sampled_half = pz.sample2(scaled_eps)
        else:
            # Sample normally
            sampled_half = pz.sample
        return tf.concat([known_half, sampled_half], 3)
574 |
575 |
@add_arg_scope
def split2d_prior(z):
    """Build the conditional distribution used by ``split2d``.

    ``z`` is mapped by ``Z.conv2d_zeros`` to twice its channel count; the
    even-indexed output channels are passed to ``Z.gaussian_diag`` as the
    mean and the odd-indexed ones as the second argument (``logs``).

    Args:
        z: tensor whose 4th dimension is statically known.

    Returns:
        The object returned by ``Z.gaussian_diag``.
    """
    n_channels = int(z.get_shape()[3])
    h = Z.conv2d_zeros("conv", z, 2 * n_channels)
    mean, logs = h[:, :, :, 0::2], h[:, :, :, 1::2]
    return Z.gaussian_diag(mean, logs)
585 |
--------------------------------------------------------------------------------