├── .gitignore
├── README.md
├── affmnist.py
├── cifar10.py
├── dataset
├── fashion.py
├── gan_models
│   ├── __init__.py
│   ├── config.py
│   ├── discriminator.py
│   ├── generator.py
│   ├── model.py
│   └── train.py
├── inception_score
│   ├── eval_affmnist.py
│   ├── eval_cifar10.py
│   ├── eval_mnist.py
│   ├── model_affmnist.py
│   ├── model_cifar10.py
│   ├── model_mnist.py
│   ├── train_affmnist_classifier.py
│   └── train_mnist_classifier.py
├── mnist.py
├── ops.py
├── run_gan.py
├── run_vae.py
├── utils.py
└── vae_models
    ├── __init__.py
    ├── config.py
    ├── decoder.py
    ├── encoder.py
    ├── model.py
    └── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | #**.pyc
2 | #**/*.pyc
3 | # Data
4 | data
5 | exps
6 | samples
7 | *.zip
8 | *.ckpt*
9 | *.gz
10 | *events.*
11 | logs
12 | log_*
13 | */log/
14 | test*
15 | 
16 | web/js/gen_layers.js
17 | 
18 | # checkpoint
19 | checkpoint
20 | checkpoints
21 | inception_score/checkpoints
22 | inception_score/affmnist_checkpoints
23 | 
24 | # trash
25 | .dropbox
26 | 
27 | # Created by https://www.gitignore.io/api/python,vim
28 | 
29 | ### Python ###
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 | *$py.class
34 | 
35 | # C extensions
36 | *.so
37 | 
38 | # Distribution / packaging
39 | .Python
40 | env/
41 | build/
42 | develop-eggs/
43 | dist/
44 | downloads/
45 | eggs/
46 | .eggs/
47 | lib/
48 | lib64/
49 | parts/
50 | sdist/
51 | var/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | 
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 | 
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 | 
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *,cover
75 | .hypothesis/
76 | 
77 | # Translations
78 | *.mo
79 | *.pot
80 | 
81 | # Django stuff:
82 | *.log
83 | 
84 | # Sphinx documentation
85 | docs/_build/
86 | 
87 | # PyBuilder
88 | target/
89 | 
90 | 
91 | ### Vim ###
92 | [._]*.s[a-w][a-z]
93 | [._]s[a-w][a-z]
94 | *.un~
95 | Session.vim
96 | .netrwhist
97 | *~
98 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generic GAN framework for experiments
2 | 
3 | ## Key Files
4 | 
5 | #### GAN
6 | 
7 | * run_gan.py
8 | * gan_models/config.py
9 | * gan_models/model.py
10 | * gan_models/generator.py
11 | * gan_models/discriminator.py
12 | 
13 | #### VAE
14 | 
15 | * run_vae.py
16 | * vae_models/config.py
17 | * vae_models/model.py
18 | * vae_models/encoder.py
19 | * vae_models/decoder.py
20 | 
21 | ## Getting Started with the Base Model
22 | 
23 | * Choose a dataset and set hyperparameters in `run_gan.py` or `run_vae.py`
24 | ```
25 | MNIST, affine-transformed MNIST ("affMNIST"), Fashion-MNIST, and CIFAR-10 are supported by default.
26 | The data loaders download each dataset automatically and offer a batch-sampling method, `next_batch()`.
27 | See the `{DATASET_NAME}.py` scripts in the project root for details, or `gan_models/train.py` and `vae_models/train.py` for usage examples.
28 | ```
29 | 
30 | * Run it. A minimal sketch of the workflow follows.
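The snippet below is illustrative only: the loader call mirrors `gan_models/train.py`, and the flag names are taken from `gan_models/config.py` (`run_gan.py` itself is not shown here, so the exact command line may differ).

```
# Load a dataset and draw a batch; every loader exposes the same interface.
import affmnist
data = affmnist.read_data_sets('dataset/affmnist', reshape=False, validation_size=0).train
images, labels = data.next_batch(64)  # images: (64, 40, 40, 1) floats in [0, 1]

# Then launch training, e.g.:
#   python run_gan.py --dataset affmnist --batch_size 64 --learning_rate 2e-4 --epoch 20
```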
31 | 
32 | ## Implementing a Custom Network
33 | 
34 | #### GAN (Discriminator or Generator)
35 | * Define your function in `gan_models/discriminator.py` or `gan_models/generator.py`
36 | * Open `run_gan.py` and set the `generator` or `discriminator` argument to the name of your new function.
37 | 
38 | #### VAE (Encoder or Decoder)
39 | * Define your function in `vae_models/encoder.py` or `vae_models/decoder.py`
40 | * Open `run_vae.py` and set the `encoder` or `decoder` argument to the name of your new function.
41 | 
42 | ## Implementing a Whole New Model
43 | 
44 | * Make a new model class that inherits from the base model in `gan_models/model.py` or `vae_models/model.py`, and place it in `gan_models` or `vae_models`.
45 | * If you want, add a new discriminator, generator, encoder, or decoder as described above.
46 | 
47 | ## Evaluation (GAN only)
48 | Supports MNIST, affMNIST, Fashion-MNIST, and CIFAR-10.
49 | Train a classification model using
50 | 
51 | ```
52 | inception_score/train_{DATASET}_classifier.py
53 | ```
54 | 
55 | Evaluate the Inception Score using
56 | 
57 | ```
58 | inception_score/eval_{DATASET}.py
59 | ```
60 | 
61 | Classification models are defined in
62 | 
63 | ```
64 | inception_score/model_{DATASET}.py
65 | ```

--------------------------------------------------------------------------------
/affmnist.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Functions for downloading MNIST and serving randomly rotated and translated 40x40 digits ("affined MNIST")."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import gzip
23 | 
24 | import numpy
25 | from six.moves import xrange  # pylint: disable=redefined-builtin
26 | 
27 | from tensorflow.contrib.learn.python.learn.datasets import base
28 | from tensorflow.python.framework import dtypes
29 | from scipy.ndimage import rotate
30 | from matplotlib import pyplot as plt
31 | 
32 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
33 | 
34 | 
35 | def _read32(bytestream):
36 |   dt = numpy.dtype(numpy.uint32).newbyteorder('>')
37 |   return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
38 | 
39 | 
40 | def extract_images(f):
41 |   """Extract the images into a 4D uint8 numpy array [index, y, x, depth].
42 | 
43 |   Args:
44 |     f: A file object that can be passed into a gzip reader.
45 | 
46 |   Returns:
47 |     data: A 4D uint8 numpy array [index, y, x, depth].
48 | 
49 |   Raises:
50 |     ValueError: If the bytestream does not start with 2051.
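    (2051 is the big-endian IDX magic number 0x00000803 for unsigned-byte,
    3-D image data; the label files checked by extract_labels use 2049.)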
51 | 52 | """ 53 | print('Extracting', f.name) 54 | with gzip.GzipFile(fileobj=f) as bytestream: 55 | magic = _read32(bytestream) 56 | if magic != 2051: 57 | raise ValueError('Invalid magic number %d in MNIST image file: %s' % 58 | (magic, f.name)) 59 | num_images = _read32(bytestream) 60 | rows = _read32(bytestream) 61 | cols = _read32(bytestream) 62 | buf = bytestream.read(rows * cols * num_images) 63 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 64 | data = data.reshape(num_images, rows, cols, 1) 65 | return data 66 | 67 | 68 | def dense_to_one_hot(labels_dense, num_classes): 69 | """Convert class labels from scalars to one-hot vectors.""" 70 | num_labels = labels_dense.shape[0] 71 | index_offset = numpy.arange(num_labels) * num_classes 72 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 73 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 74 | return labels_one_hot 75 | 76 | 77 | def extract_labels(f, one_hot=False, num_classes=10): 78 | """Extract the labels into a 1D uint8 numpy array [index]. 79 | 80 | Args: 81 | f: A file object that can be passed into a gzip reader. 82 | one_hot: Does one hot encoding for the result. 83 | num_classes: Number of classes for the one hot encoding. 84 | 85 | Returns: 86 | labels: a 1D uint8 numpy array. 87 | 88 | Raises: 89 | ValueError: If the bystream doesn't start with 2049. 90 | """ 91 | print('Extracting', f.name) 92 | with gzip.GzipFile(fileobj=f) as bytestream: 93 | magic = _read32(bytestream) 94 | if magic != 2049: 95 | raise ValueError('Invalid magic number %d in MNIST label file: %s' % 96 | (magic, f.name)) 97 | num_items = _read32(bytestream) 98 | buf = bytestream.read(num_items) 99 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 100 | if one_hot: 101 | return dense_to_one_hot(labels, num_classes) 102 | return labels 103 | 104 | 105 | class DataSet(object): 106 | 107 | def __init__(self, 108 | images, 109 | labels, 110 | fake_data=False, 111 | one_hot=False, 112 | dtype=dtypes.float32, 113 | reshape=True): 114 | """Construct a DataSet. 115 | one_hot arg is used only if fake_data is true. `dtype` can be either 116 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 117 | `[0, 1]`. 118 | """ 119 | dtype = dtypes.as_dtype(dtype).base_dtype 120 | if dtype not in (dtypes.uint8, dtypes.float32): 121 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 122 | dtype) 123 | if fake_data: 124 | self._num_examples = 10000 125 | self.one_hot = one_hot 126 | else: 127 | assert images.shape[0] == labels.shape[0], ( 128 | 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) 129 | self._num_examples = images.shape[0] 130 | 131 | # Convert shape from [num examples, rows, columns, depth] 132 | # to [num examples, rows*columns] (assuming depth == 1) 133 | if reshape: 134 | assert images.shape[3] == 1 135 | images = images.reshape(images.shape[0], 136 | images.shape[1] * images.shape[2]) 137 | if dtype == dtypes.float32: 138 | # Convert from [0, 255] -> [0.0, 1.0]. 
139 | images = images.astype(numpy.float32) 140 | images = numpy.multiply(images, 1.0 / 255.0) 141 | self._images = images 142 | self._labels = labels 143 | self._epochs_completed = 0 144 | self._index_in_epoch = 0 145 | 146 | @property 147 | def images(self): 148 | return self._images 149 | 150 | @property 151 | def labels(self): 152 | return self._labels 153 | 154 | @property 155 | def num_examples(self): 156 | return self._num_examples 157 | 158 | @property 159 | def epochs_completed(self): 160 | return self._epochs_completed 161 | 162 | def next_batch(self, batch_size, fake_data=False): 163 | """Return the next `batch_size` examples from this data set.""" 164 | if fake_data: 165 | fake_image = [1] * 784 166 | if self.one_hot: 167 | fake_label = [1] + [0] * 9 168 | else: 169 | fake_label = 0 170 | return [fake_image for _ in xrange(batch_size)], [ 171 | fake_label for _ in xrange(batch_size) 172 | ] 173 | 174 | 175 | start = self._index_in_epoch 176 | self._index_in_epoch += batch_size 177 | if self._index_in_epoch > self._num_examples: 178 | # Finished epoch 179 | self._epochs_completed += 1 180 | # Shuffle the data 181 | perm = numpy.arange(self._num_examples) 182 | numpy.random.shuffle(perm) 183 | self._images = self._images[perm] 184 | self._labels = self._labels[perm] 185 | # Start next epoch 186 | start = 0 187 | self._index_in_epoch = batch_size 188 | assert batch_size <= self._num_examples 189 | 190 | # edit for moving mnist 191 | end = self._index_in_epoch 192 | move = numpy.zeros((batch_size, 40, 40, 1)) # initialize 193 | 194 | for i in range(start, end): 195 | degree = numpy.random.randint(41) - 20 196 | h_move = numpy.random.randint(12) 197 | w_move = numpy.random.randint(12) 198 | rot = rotate(self._images[i], degree, reshape = False) 199 | move[i-start][h_move:h_move+28, w_move:w_move+28] = rot 200 | 201 | return move, self._labels[start:end] 202 | 203 | def read_data_sets(train_dir, 204 | fake_data=False, 205 | one_hot=False, 206 | dtype=dtypes.float32, 207 | reshape=True, 208 | validation_size=5000): 209 | if fake_data: 210 | 211 | def fake(): 212 | return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 213 | 214 | train = fake() 215 | validation = fake() 216 | test = fake() 217 | return base.Datasets(train=train, validation=validation, test=test) 218 | 219 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 220 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 221 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 222 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 223 | 224 | local_file = base.maybe_download(TRAIN_IMAGES, train_dir, 225 | SOURCE_URL + TRAIN_IMAGES) 226 | with open(local_file, 'rb') as f: 227 | train_images = extract_images(f) 228 | 229 | local_file = base.maybe_download(TRAIN_LABELS, train_dir, 230 | SOURCE_URL + TRAIN_LABELS) 231 | with open(local_file, 'rb') as f: 232 | train_labels = extract_labels(f, one_hot=one_hot) 233 | 234 | local_file = base.maybe_download(TEST_IMAGES, train_dir, 235 | SOURCE_URL + TEST_IMAGES) 236 | with open(local_file, 'rb') as f: 237 | test_images = extract_images(f) 238 | 239 | local_file = base.maybe_download(TEST_LABELS, train_dir, 240 | SOURCE_URL + TEST_LABELS) 241 | with open(local_file, 'rb') as f: 242 | test_labels = extract_labels(f, one_hot=one_hot) 243 | 244 | if not 0 <= validation_size <= len(train_images): 245 | raise ValueError( 246 | 'Validation size should be between 0 and {}. Received: {}.' 
247 | .format(len(train_images), validation_size)) 248 | 249 | validation_images = train_images[:validation_size] 250 | validation_labels = train_labels[:validation_size] 251 | train_images = train_images[validation_size:] 252 | train_labels = train_labels[validation_size:] 253 | 254 | train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape) 255 | validation = DataSet(validation_images, 256 | validation_labels, 257 | dtype=dtype, 258 | reshape=reshape) 259 | test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape) 260 | 261 | return base.Datasets(train=train, validation=validation, test=test) 262 | 263 | 264 | def load_affmnist(train_dir='MNIST-data'): 265 | return read_data_sets(train_dir) 266 | -------------------------------------------------------------------------------- /cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | """Functions for downloading and reading MNIST data.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import gzip 24 | 25 | import os 26 | import numpy 27 | from six.moves import xrange # pylint: disable=redefined-builtin 28 | 29 | from tensorflow.contrib.learn.python.learn.datasets import base 30 | from tensorflow.python.framework import dtypes 31 | 32 | SOURCE_URL = 'https://www.cs.toronto.edu/~kriz/' 33 | 34 | 35 | def _read32(bytestream): 36 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 37 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 38 | 39 | 40 | def extract_images(f): 41 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. 42 | 43 | Args: 44 | f: A file object that can be passed into a gzip reader. 45 | 46 | Returns: 47 | data: A 4D uint8 numpy array [index, y, x, depth]. 48 | 49 | Raises: 50 | ValueError: If the bytestream does not start with 2051. 
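    Note: retained from the MNIST loader for reference; the CIFAR-10
    pipeline in read_data_sets below reads the pickled batch files directly
    and never calls this helper.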
51 | 52 | """ 53 | print('Extracting', f.name) 54 | with gzip.GzipFile(fileobj=f) as bytestream: 55 | magic = _read32(bytestream) 56 | if magic != 2051: 57 | raise ValueError('Invalid magic number %d in MNIST image file: %s' % 58 | (magic, f.name)) 59 | num_images = _read32(bytestream) 60 | rows = _read32(bytestream) 61 | cols = _read32(bytestream) 62 | buf = bytestream.read(rows * cols * num_images) 63 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 64 | data = data.reshape(num_images, rows, cols, 1) 65 | return data 66 | 67 | 68 | def dense_to_one_hot(labels_dense, num_classes): 69 | """Convert class labels from scalars to one-hot vectors.""" 70 | num_labels = labels_dense.shape[0] 71 | index_offset = numpy.arange(num_labels) * num_classes 72 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 73 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 74 | return labels_one_hot 75 | 76 | 77 | def extract_labels(f, one_hot=False, num_classes=10): 78 | """Extract the labels into a 1D uint8 numpy array [index]. 79 | 80 | Args: 81 | f: A file object that can be passed into a gzip reader. 82 | one_hot: Does one hot encoding for the result. 83 | num_classes: Number of classes for the one hot encoding. 84 | 85 | Returns: 86 | labels: a 1D uint8 numpy array. 87 | 88 | Raises: 89 | ValueError: If the bystream doesn't start with 2049. 90 | """ 91 | print('Extracting', f.name) 92 | with gzip.GzipFile(fileobj=f) as bytestream: 93 | magic = _read32(bytestream) 94 | if magic != 2049: 95 | raise ValueError('Invalid magic number %d in MNIST label file: %s' % 96 | (magic, f.name)) 97 | num_items = _read32(bytestream) 98 | buf = bytestream.read(num_items) 99 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 100 | if one_hot: 101 | return dense_to_one_hot(labels, num_classes) 102 | return labels 103 | 104 | 105 | class DataSet(object): 106 | 107 | def __init__(self, 108 | images, 109 | labels, 110 | fake_data=False, 111 | one_hot=False, 112 | dtype=dtypes.float32, 113 | reshape=True): 114 | """Construct a DataSet. 115 | one_hot arg is used only if fake_data is true. `dtype` can be either 116 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 117 | `[0, 1]`. 118 | """ 119 | dtype = dtypes.as_dtype(dtype).base_dtype 120 | if dtype not in (dtypes.uint8, dtypes.float32): 121 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 122 | dtype) 123 | if fake_data: 124 | self._num_examples = 10000 125 | self.one_hot = one_hot 126 | else: 127 | assert images.shape[0] == labels.shape[0], ( 128 | 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) 129 | self._num_examples = images.shape[0] 130 | 131 | # Convert shape from [num examples, rows, columns, depth] 132 | # to [num examples, rows*columns] (assuming depth == 1) 133 | if reshape: 134 | assert images.shape[3] == 1 135 | images = images.reshape(images.shape[0], 136 | images.shape[1] * images.shape[2]) 137 | if dtype == dtypes.float32: 138 | # Convert from [0, 255] -> [0.0, 1.0]. 
139 |       images = images.astype(numpy.float32)
140 |       images = numpy.multiply(images, 1.0 / 255.0)
141 |     self._images = images
142 |     self._labels = labels
143 |     self._epochs_completed = 0
144 |     self._index_in_epoch = 0
145 | 
146 |   @property
147 |   def images(self):
148 |     return self._images
149 | 
150 |   @property
151 |   def labels(self):
152 |     return self._labels
153 | 
154 |   @property
155 |   def num_examples(self):
156 |     return self._num_examples
157 | 
158 |   @property
159 |   def epochs_completed(self):
160 |     return self._epochs_completed
161 | 
162 |   def next_batch(self, batch_size, fake_data=False):
163 |     """Return the next `batch_size` examples from this data set."""
164 |     if fake_data:
165 |       fake_image = [1] * 3072
166 |       if self.one_hot:
167 |         fake_label = [1] + [0] * 9
168 |       else:
169 |         fake_label = 0
170 |       return [fake_image for _ in xrange(batch_size)], [
171 |           fake_label for _ in xrange(batch_size)
172 |       ]
173 |     start = self._index_in_epoch
174 |     self._index_in_epoch += batch_size
175 |     if self._index_in_epoch > self._num_examples:
176 |       # Finished epoch
177 |       self._epochs_completed += 1
178 |       # Shuffle the data
179 |       perm = numpy.arange(self._num_examples)
180 |       numpy.random.shuffle(perm)
181 |       self._images = self._images[perm]
182 |       self._labels = self._labels[perm]
183 |       # Start next epoch
184 |       start = 0
185 |       self._index_in_epoch = batch_size
186 |       assert batch_size <= self._num_examples
187 |     end = self._index_in_epoch
188 |     return self._images[start:end], self._labels[start:end]
189 | 
190 | 
191 | def read_data_sets(train_dir,
192 |                    fake_data=False,
193 |                    one_hot=False,
194 |                    dtype=dtypes.float32,
195 |                    reshape=True,
196 |                    validation_size=5000):
197 |   if fake_data:
198 | 
199 |     def fake():
200 |       return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
201 | 
202 |     train = fake()
203 |     validation = fake()
204 |     test = fake()
205 |     return base.Datasets(train=train, validation=validation, test=test)
206 | 
207 |   import tarfile
208 |   gz_file_name = 'cifar-10-python.tar.gz'
209 | 
210 |   local_file = base.maybe_download(gz_file_name, train_dir,
211 |                                    SOURCE_URL + gz_file_name)
212 |   tarfile.open(local_file, 'r:gz').extractall(train_dir)  # maybe_download only fetches the tarball
213 |   train_images = []
214 |   train_labels = []
215 |   for i in range(1, 6):
216 |     with open(os.path.join(train_dir, 'cifar-10-batches-py', 'data_batch_%d'%i), 'rb') as f:
217 |       batch = numpy.load(f)  # each batch is a pickled dict with 'data' and 'labels'
218 |       tmp_images = batch['data'].reshape([-1, 3, 32, 32])
219 |       train_images.append(tmp_images.transpose([0, 2, 3, 1]))
220 |       train_labels += batch['labels']
221 |   train_images = numpy.concatenate(train_images)
222 |   train_labels = numpy.array(train_labels)
223 | 
224 |   if not 0 <= validation_size <= len(train_images):
225 |     raise ValueError(
226 |         'Validation size should be between 0 and {}. Received: {}.'
227 | .format(len(train_images), validation_size)) 228 | 229 | validation_images = train_images[:validation_size] 230 | validation_labels = train_labels[:validation_size] 231 | train_images = train_images[validation_size:] 232 | train_labels = train_labels[validation_size:] 233 | 234 | train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape) 235 | validation = DataSet(validation_images, 236 | validation_labels, 237 | dtype=dtype, 238 | reshape=reshape) 239 | #test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape) 240 | test = None 241 | 242 | return base.Datasets(train=train, validation=validation, test=test) 243 | 244 | 245 | def load_cifar10(train_dir='CIFAR10-data'): 246 | return read_data_sets(train_dir) 247 | -------------------------------------------------------------------------------- /dataset: -------------------------------------------------------------------------------- 1 | /data2/whyjay/CVPR2018_data -------------------------------------------------------------------------------- /fashion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for downloading and reading MNIST data.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gzip 23 | 24 | import numpy 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | 27 | from tensorflow.contrib.learn.python.learn.datasets import base 28 | from tensorflow.python.framework import dtypes 29 | 30 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 31 | 32 | 33 | def _read32(bytestream): 34 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 35 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 36 | 37 | 38 | def extract_images(f): 39 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. 40 | 41 | Args: 42 | f: A file object that can be passed into a gzip reader. 43 | 44 | Returns: 45 | data: A 4D uint8 numpy array [index, y, x, depth]. 46 | 47 | Raises: 48 | ValueError: If the bytestream does not start with 2051. 
49 | 50 | """ 51 | print('Extracting', f.name) 52 | with gzip.GzipFile(fileobj=f) as bytestream: 53 | magic = _read32(bytestream) 54 | if magic != 2051: 55 | raise ValueError('Invalid magic number %d in MNIST image file: %s' % 56 | (magic, f.name)) 57 | num_images = _read32(bytestream) 58 | rows = _read32(bytestream) 59 | cols = _read32(bytestream) 60 | buf = bytestream.read(rows * cols * num_images) 61 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 62 | data = data.reshape(num_images, rows, cols, 1) 63 | return data 64 | 65 | 66 | def dense_to_one_hot(labels_dense, num_classes): 67 | """Convert class labels from scalars to one-hot vectors.""" 68 | num_labels = labels_dense.shape[0] 69 | index_offset = numpy.arange(num_labels) * num_classes 70 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 71 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 72 | return labels_one_hot 73 | 74 | 75 | def extract_labels(f, one_hot=False, num_classes=10): 76 | """Extract the labels into a 1D uint8 numpy array [index]. 77 | 78 | Args: 79 | f: A file object that can be passed into a gzip reader. 80 | one_hot: Does one hot encoding for the result. 81 | num_classes: Number of classes for the one hot encoding. 82 | 83 | Returns: 84 | labels: a 1D uint8 numpy array. 85 | 86 | Raises: 87 | ValueError: If the bystream doesn't start with 2049. 88 | """ 89 | print('Extracting', f.name) 90 | with gzip.GzipFile(fileobj=f) as bytestream: 91 | magic = _read32(bytestream) 92 | if magic != 2049: 93 | raise ValueError('Invalid magic number %d in MNIST label file: %s' % 94 | (magic, f.name)) 95 | num_items = _read32(bytestream) 96 | buf = bytestream.read(num_items) 97 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 98 | if one_hot: 99 | return dense_to_one_hot(labels, num_classes) 100 | return labels 101 | 102 | 103 | class DataSet(object): 104 | 105 | def __init__(self, 106 | images, 107 | labels, 108 | fake_data=False, 109 | one_hot=False, 110 | dtype=dtypes.float32, 111 | reshape=True): 112 | """Construct a DataSet. 113 | one_hot arg is used only if fake_data is true. `dtype` can be either 114 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 115 | `[0, 1]`. 116 | """ 117 | dtype = dtypes.as_dtype(dtype).base_dtype 118 | if dtype not in (dtypes.uint8, dtypes.float32): 119 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 120 | dtype) 121 | if fake_data: 122 | self._num_examples = 10000 123 | self.one_hot = one_hot 124 | else: 125 | assert images.shape[0] == labels.shape[0], ( 126 | 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) 127 | self._num_examples = images.shape[0] 128 | 129 | # Convert shape from [num examples, rows, columns, depth] 130 | # to [num examples, rows*columns] (assuming depth == 1) 131 | if reshape: 132 | assert images.shape[3] == 1 133 | images = images.reshape(images.shape[0], 134 | images.shape[1] * images.shape[2]) 135 | if dtype == dtypes.float32: 136 | # Convert from [0, 255] -> [0.0, 1.0]. 
137 | images = images.astype(numpy.float32) 138 | images = numpy.multiply(images, 1.0 / 255.0) 139 | self._images = images 140 | self._labels = labels 141 | self._epochs_completed = 0 142 | self._index_in_epoch = 0 143 | 144 | @property 145 | def images(self): 146 | return self._images 147 | 148 | @property 149 | def labels(self): 150 | return self._labels 151 | 152 | @property 153 | def num_examples(self): 154 | return self._num_examples 155 | 156 | @property 157 | def epochs_completed(self): 158 | return self._epochs_completed 159 | 160 | def next_batch(self, batch_size, fake_data=False): 161 | """Return the next `batch_size` examples from this data set.""" 162 | if fake_data: 163 | fake_image = [1] * 784 164 | if self.one_hot: 165 | fake_label = [1] + [0] * 9 166 | else: 167 | fake_label = 0 168 | return [fake_image for _ in xrange(batch_size)], [ 169 | fake_label for _ in xrange(batch_size) 170 | ] 171 | start = self._index_in_epoch 172 | self._index_in_epoch += batch_size 173 | if self._index_in_epoch > self._num_examples: 174 | # Finished epoch 175 | self._epochs_completed += 1 176 | # Shuffle the data 177 | perm = numpy.arange(self._num_examples) 178 | numpy.random.shuffle(perm) 179 | self._images = self._images[perm] 180 | self._labels = self._labels[perm] 181 | # Start next epoch 182 | start = 0 183 | self._index_in_epoch = batch_size 184 | assert batch_size <= self._num_examples 185 | end = self._index_in_epoch 186 | return self._images[start:end], self._labels[start:end] 187 | 188 | 189 | def read_data_sets(train_dir, 190 | fake_data=False, 191 | one_hot=False, 192 | dtype=dtypes.float32, 193 | reshape=True, 194 | validation_size=5000): 195 | if fake_data: 196 | 197 | def fake(): 198 | return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 199 | 200 | train = fake() 201 | validation = fake() 202 | test = fake() 203 | return base.Datasets(train=train, validation=validation, test=test) 204 | 205 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 206 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 207 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 208 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 209 | 210 | local_file = base.maybe_download(TRAIN_IMAGES, train_dir, 211 | SOURCE_URL + TRAIN_IMAGES) 212 | with open(local_file, 'rb') as f: 213 | train_images = extract_images(f) 214 | 215 | local_file = base.maybe_download(TRAIN_LABELS, train_dir, 216 | SOURCE_URL + TRAIN_LABELS) 217 | with open(local_file, 'rb') as f: 218 | train_labels = extract_labels(f, one_hot=one_hot) 219 | 220 | local_file = base.maybe_download(TEST_IMAGES, train_dir, 221 | SOURCE_URL + TEST_IMAGES) 222 | with open(local_file, 'rb') as f: 223 | test_images = extract_images(f) 224 | 225 | local_file = base.maybe_download(TEST_LABELS, train_dir, 226 | SOURCE_URL + TEST_LABELS) 227 | with open(local_file, 'rb') as f: 228 | test_labels = extract_labels(f, one_hot=one_hot) 229 | 230 | if not 0 <= validation_size <= len(train_images): 231 | raise ValueError( 232 | 'Validation size should be between 0 and {}. Received: {}.' 
233 | .format(len(train_images), validation_size)) 234 | 235 | validation_images = train_images[:validation_size] 236 | validation_labels = train_labels[:validation_size] 237 | train_images = train_images[validation_size:] 238 | train_labels = train_labels[validation_size:] 239 | 240 | train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape) 241 | validation = DataSet(validation_images, 242 | validation_labels, 243 | dtype=dtype, 244 | reshape=reshape) 245 | test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape) 246 | 247 | return base.Datasets(train=train, validation=validation, test=test) 248 | 249 | 250 | def load_mnist(train_dir='MNIST-data'): 251 | return read_data_sets(train_dir) 252 | -------------------------------------------------------------------------------- /gan_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whyjay/GENERATIVE_MODEL_TESTBED.tensorflow/78c7938a85b6a6e94929e5277446b00cc6c0544c/gan_models/__init__.py -------------------------------------------------------------------------------- /gan_models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | from glob import glob 5 | import tensorflow as tf 6 | 7 | from ops import * 8 | from utils import * 9 | 10 | from gan_models.generator import * 11 | from gan_models.discriminator import * 12 | #from models.evaluate import evaluate 13 | from utils import pp, visualize, to_json 14 | 15 | from IPython import embed 16 | 17 | class Config(object): 18 | def __init__(self, FLAGS): 19 | self.exp_num = str(FLAGS.exp) 20 | self.load_cp_dir = FLAGS.load_cp_dir 21 | self.dataset = FLAGS.dataset 22 | self.dataset_path = os.path.join("./dataset/", self.dataset) 23 | self.devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"] 24 | self.use_augmentation = FLAGS.use_augmentation 25 | self.batch_size = FLAGS.batch_size 26 | self.learning_rate = FLAGS.learning_rate 27 | 28 | self.add_noise = True 29 | self.noise_stddev = 0.1 30 | 31 | 32 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")+str(self.learning_rate) 33 | 34 | self.epoch = FLAGS.epoch 35 | self.log_dir = os.path.join('logs/gan', self.exp_num, timestamp) 36 | self.checkpoint_dir = os.path.join('checkpoint/gan', self.exp_num, timestamp) 37 | self.sample_dir = os.path.join('samples/gan', self.exp_num, timestamp) 38 | self.timestamp = timestamp 39 | 40 | self.generator_name = FLAGS.generator 41 | self.discriminator_name = FLAGS.discriminator 42 | 43 | self.generator_func = globals()[self.generator_name] 44 | self.discriminator_func = globals()[self.discriminator_name] 45 | 46 | self.loss = FLAGS.loss 47 | 48 | 49 | 50 | if FLAGS.dataset in ['mnist', 'fashion']: 51 | self.y_dim=10 52 | self.image_shape=[28, 28, 1] 53 | self.c_dim=1 54 | self.z_dim=100 55 | self.f_dim = 64 56 | self.fc_dim = 1024 57 | self.beta1 = 0.5 58 | self.beta2 = 0.999 59 | 60 | elif FLAGS.dataset == 'affmnist': 61 | self.y_dim=10 62 | self.image_shape=[40, 40, 1] 63 | self.c_dim=1 64 | self.z_dim=128 65 | self.f_dim = 64 66 | self.fc_dim = 1024 67 | self.beta1 = 0.5 68 | self.beta2 = 0.999 69 | 70 | elif FLAGS.dataset == 'celebA': 71 | self.y_dim=1 72 | self.image_shape=[64, 64, 3] 73 | self.c_dim=3 74 | self.z_dim=256 # 256, 10 75 | self.f_dim = 64 76 | self.fc_dim = 1024 77 | self.beta1 = 0.5 78 | self.beta2 = 0.999 79 | 80 | elif FLAGS.dataset == 'cifar10': 81 | self.y_dim=10 82 | 
self.image_shape=[32, 32, 3] 83 | self.c_dim=3 84 | self.z_dim=128 85 | self.f_dim = 128 86 | self.fc_dim = 1024 87 | self.beta1 = 0. 88 | self.beta2 = 0.9 89 | 90 | self.sample_size=10*self.batch_size 91 | 92 | def print_config(self): 93 | dicts = self.__dict__ 94 | for key in dicts.keys(): 95 | print key, dicts[key] 96 | 97 | def make_dirs(self): 98 | if not os.path.exists(self.checkpoint_dir): 99 | os.makedirs(self.checkpoint_dir) 100 | if not os.path.exists(self.sample_dir): 101 | os.makedirs(self.sample_dir) 102 | -------------------------------------------------------------------------------- /gan_models/discriminator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from ops import * 4 | from IPython import embed 5 | 6 | def dcgan_d(model, x, reuse=False): 7 | bs = model.batch_size 8 | f_dim = model.f_dim 9 | fc_dim = model.fc_dim 10 | c_dim = model.c_dim 11 | 12 | with slim.arg_scope(ops_with_bn, is_training=model.is_training): 13 | with tf.variable_scope('d_', reuse=reuse) as scope: 14 | 15 | if model.dataset_name == 'mnist': 16 | w = model.image_shape[0] 17 | h = conv2d(x, f_dim, 3, 1, act=lrelu, norm=None) 18 | h = conv2d(h, f_dim*2, 3, 1, act=lrelu) 19 | h = tf.reshape(h, [bs, -1]) 20 | h = fc(h, fc_dim, act=lrelu) 21 | 22 | elif model.dataset_name == 'affmnist': 23 | n_layer = 3 24 | c = 1 25 | w = model.image_shape[0]/2**(n_layer) 26 | 27 | h = conv2d(x, f_dim * c, 4, 2, act=lrelu, norm=None) 28 | for i in range(n_layer - 1): 29 | w /= 2 30 | c *= 2 31 | h = conv2d(h, f_dim * c, 4, 2, act=lrelu) 32 | h = conv2d(h, f_dim * c, 1, 1, act=lrelu) 33 | 34 | if i == n_layer - 2: 35 | feats = h 36 | 37 | elif model.dataset_name == 'cifar10': 38 | h = conv2d(x, f_dim, 3, 1, act=tf.nn.elu, norm=None) 39 | h = conv_mean_pool(h, f_dim, 3, act=None, norm=None) 40 | h += conv_mean_pool(x, f_dim, 1, act=None, norm=None) 41 | h = tf.nn.elu(ln(h)) 42 | 43 | h = residual_block(h, resample='down', act=tf.nn.elu) 44 | h = residual_block(h, resample=None, act=tf.nn.elu) 45 | h = residual_block(h, resample=None, act=tf.nn.elu) 46 | 47 | h = conv2d(h, f_dim, 4, 2, act=None, norm=ln) 48 | h = tf.reduce_mean(h, axis=[1,2]) 49 | 50 | else: 51 | n_layer = 4 52 | c = 1 53 | w = model.image_shape[0]/2**(n_layer) 54 | 55 | h = conv2d(x, f_dim * c, 4, 2, act=lrelu, norm=None) 56 | for i in range(n_layer - 1): 57 | w /= 2 58 | c *= 2 59 | h = conv2d(h, f_dim * c, 4, 2, act=lrelu) 60 | h = conv2d(h, f_dim * c, 1, 1, act=lrelu) 61 | 62 | if i == n_layer - 2: 63 | feats = h 64 | 65 | h = tf.reshape(h, [bs, -1]) 66 | logits = fc(h, 1, act=None, norm=None) 67 | return logits 68 | -------------------------------------------------------------------------------- /gan_models/generator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import * 3 | from IPython import embed 4 | 5 | def dcgan_g(model, z, reuse=False): 6 | bs = model.batch_size 7 | f_dim = model.f_dim 8 | fc_dim = model.fc_dim 9 | c_dim = model.c_dim 10 | 11 | with slim.arg_scope(ops_with_bn, is_training=model.is_training): 12 | with tf.variable_scope('g_', reuse=reuse) as scope: 13 | 14 | if model.dataset_name == 'mnist': 15 | n_layer = 2 16 | w = model.image_shape[0] 17 | 18 | h = fc(z, fc_dim) 19 | h = fc(h, f_dim*2*w/4*w/4) 20 | h = tf.reshape(h, [-1, w/4, w/4, f_dim*2]) 21 | h = deconv2d(h, f_dim, 4, 2) 22 | x = deconv2d(h, c_dim, 4, 2, act=tf.nn.sigmoid, norm=None) 23 | 24 | elif 
model.dataset_name == 'affmnist':
25 |                 n_layer = 3
26 |                 c = 2**(n_layer - 1)
27 |                 w = model.image_shape[0]/2**(n_layer)
28 | 
29 |                 h = fc(z, f_dim * c * w * w, act=lrelu)
30 |                 h = tf.reshape(h, [-1, w, w, f_dim * c])
31 | 
32 |                 for i in range(n_layer - 1):
33 |                     w *= 2
34 |                     c /= 2
35 |                     h = deconv2d(h, f_dim * c, 4, 2)
36 |                     h = deconv2d(h, f_dim * c, 1, 1)
37 | 
38 |                 x = deconv2d(h, c_dim, 4, 2, act=tf.nn.sigmoid, norm=None)
39 | 
40 |             elif model.dataset_name == 'cifar10':
41 |                 n_layer = 3
42 |                 w = model.image_shape[0]/2**(n_layer)
43 | 
44 |                 h = fc(z, f_dim * w * w, act=tf.nn.elu, norm=ln)
45 |                 h = tf.reshape(h, [-1, w, w, f_dim])
46 | 
47 |                 c = f_dim
48 |                 for i in range(n_layer):
49 |                     c /= 2
50 |                     h = residual_block(h, resample='up', act=tf.nn.elu)
51 | 
52 |                 x = conv2d(h, c_dim, 3, 1, act=tf.nn.tanh, norm=None)
53 | 
54 |             else:
55 |                 n_layer = 4
56 |                 c = 2**(n_layer - 1)
57 |                 w = model.image_shape[0]/2**(n_layer)
58 | 
59 |                 h = fc(z, f_dim * c * w * w, act=lrelu)
60 |                 h = tf.reshape(h, [-1, w, w, f_dim * c])
61 | 
62 |                 for i in range(n_layer - 1):
63 |                     w *= 2
64 |                     c /= 2
65 |                     h = deconv2d(h, f_dim * c, 4, 2)
66 |                     h = deconv2d(h, f_dim * c, 1, 1)
67 | 
68 |                 x = deconv2d(h, c_dim, 4, 2, act=tf.nn.tanh, norm=None)
69 | 
70 | 
71 |         return x
72 | 

--------------------------------------------------------------------------------
/gan_models/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from glob import glob
4 | import tensorflow as tf
5 | from tensorflow.python.ops import parsing_ops
6 | 
7 | from ops import *
8 | from utils import *
9 | from IPython import embed
10 | 
11 | slim = tf.contrib.slim
12 | 
13 | class GAN(object):
14 |     def __init__(self, config):
15 |         self.devices = config.devices
16 |         self.noise_stddev = config.noise_stddev
17 |         self.config = config
18 | 
19 |         self.generator = NetworkWrapper(self, config.generator_func)
20 |         self.discriminator = NetworkWrapper(self, config.discriminator_func)
21 | 
22 |         #self.evaluate = Evaluate(self, config.eval_func)
23 | 
24 |         self.batch_size = config.batch_size
25 |         self.sample_size = config.sample_size
26 |         self.image_shape = config.image_shape
27 |         self.sample_dir = config.sample_dir
28 | 
29 |         self.y_dim = config.y_dim
30 |         self.c_dim = config.c_dim
31 |         self.f_dim = config.f_dim
32 |         self.fc_dim = config.fc_dim
33 |         self.z_dim = config.z_dim
34 |         self.beta1 = config.beta1
35 |         self.beta2 = config.beta2
36 | 
37 |         self.dataset_name = config.dataset
38 |         self.dataset_path = config.dataset_path
39 |         self.checkpoint_dir = config.checkpoint_dir
40 | 
41 |         self.use_augmentation = config.use_augmentation
42 | 
43 |     def save(self, sess, checkpoint_dir, step):
44 |         model_name = "gan"
45 | 
46 |         if not os.path.exists(checkpoint_dir):
47 |             os.makedirs(checkpoint_dir)
48 | 
49 |         self.saver.save(sess,
50 |                         os.path.join(checkpoint_dir, model_name),
51 |                         global_step=step)
52 | 
53 |     def load(self, sess, checkpoint_dir):
54 |         print(" [*] Reading checkpoints...")
55 | 
56 |         model_dir = "%s_%s" % (self.batch_size, self.config.learning_rate)
57 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
58 | 
59 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
60 |         if ckpt and ckpt.model_checkpoint_path:
61 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
62 |             self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
63 |             return True
64 |         else:
65 |             print "Bad checkpoint: ", ckpt
66 |             return False
67 | 
68 |     def get_vars(self):
69 |         t_vars = tf.trainable_variables()
70 |         self.d_vars = [var
for var in t_vars if var.name.startswith('d_')] 71 | self.g_vars = [var for var in t_vars if var.name.startswith('g_')] 72 | 73 | for x in self.d_vars: 74 | assert x not in self.g_vars 75 | for x in self.g_vars: 76 | assert x not in self.d_vars 77 | for x in t_vars: 78 | assert x in self.g_vars or x in self.d_vars, x.name 79 | self.all_vars = t_vars 80 | 81 | def build_model(self): 82 | config = self.config 83 | self.is_training = tf.placeholder_with_default(False, shape=[], name='is_training') 84 | 85 | # input 86 | self.image = tf.placeholder(tf.float32, shape=[self.batch_size]+self.image_shape) 87 | self.label = tf.placeholder(tf.float32, shape=[self.batch_size]) 88 | self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim]) 89 | image = preprocess_image(self.image, self.dataset_name, self.use_augmentation) 90 | 91 | #self.z = make_z(shape=[self.batch_size, self.z_dim]) 92 | 93 | self.gen_image = self.generator(self.z) 94 | d_out_real = self.discriminator(image) 95 | d_out_fake = self.discriminator(self.gen_image, reuse=True) 96 | 97 | d_loss, g_loss, d_real, d_fake = self.get_loss(d_out_real, d_out_fake, config.loss) 98 | 99 | # optimizer 100 | self.get_vars() 101 | d_opt = tf.train.AdamOptimizer(config.learning_rate, beta1=self.beta1, beta2=self.beta2) 102 | g_opt = tf.train.AdamOptimizer(config.learning_rate, beta1=self.beta1, beta2=self.beta2) 103 | d_optimize = slim.learning.create_train_op(d_loss, d_opt, variables_to_train=self.d_vars) 104 | g_optimize = slim.learning.create_train_op(g_loss, g_opt, variables_to_train=self.g_vars) 105 | 106 | # logging 107 | tf.summary.scalar("d_real", d_real) 108 | tf.summary.scalar("d_fake", d_fake) 109 | tf.summary.scalar("d_loss", d_loss) 110 | tf.summary.scalar("g_loss", g_loss) 111 | tf.summary.image("fake_images", batch_to_grid(self.gen_image)) 112 | tf.summary.image("real_images", batch_to_grid(image)) 113 | self.d_real = d_real 114 | self.d_fake = d_fake 115 | self.saver = tf.train.Saver(max_to_keep=None) 116 | 117 | return d_optimize, g_optimize 118 | 119 | def get_loss(self, d_out_real, d_out_fake, loss='jsd'): 120 | sigm_ce = tf.nn.sigmoid_cross_entropy_with_logits 121 | loss_real = tf.reduce_mean(sigm_ce(logits=d_out_real, labels=tf.ones_like(d_out_real))) 122 | loss_fake = tf.reduce_mean(sigm_ce(logits=d_out_fake, labels=tf.zeros_like(d_out_fake))) 123 | loss_fake_ = tf.reduce_mean(sigm_ce(logits=d_out_fake, labels=tf.ones_like(d_out_fake))) 124 | 125 | if loss == 'jsd': 126 | d_loss = loss_real + loss_fake 127 | g_loss = - loss_fake 128 | elif loss == 'alternative': 129 | d_loss = loss_real + loss_fake 130 | g_loss = loss_fake_ 131 | elif loss == 'reverse_kl': 132 | d_loss = loss_real + loss_fake 133 | g_loss = loss_fake_ - loss_fake 134 | 135 | return d_loss, g_loss, tf.reduce_mean(tf.nn.sigmoid(d_out_real)), tf.reduce_mean(tf.nn.sigmoid(d_out_fake)) 136 | 137 | 138 | 139 | class NetworkWrapper(object): 140 | def __init__(self, model, func): 141 | self.model = model 142 | self.func = func 143 | 144 | def __call__(self, z, reuse=False): 145 | return self.func(self.model, z, reuse=reuse) 146 | 147 | 148 | -------------------------------------------------------------------------------- /gan_models/train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import time 4 | import numpy as np 5 | from utils import * 6 | from ops import * 7 | from IPython import embed 8 | 9 | def train(model, sess): 10 | config = model.config 11 | d_optim, g_optim 
= model.build_model()
12 | 
13 |     if not (config.load_cp_dir == ''):
14 |         model.load(sess, config.load_cp_dir)
15 |     merged_sum = init_training(model, sess)
16 |     start_time = time.time()
17 |     print_time = time.time()
18 | 
19 |     dataset = load_dataset(model)
20 |     N = dataset.num_examples
21 |     max_iter = int(N/model.batch_size) * model.config.epoch
22 | 
23 |     print "[*] Training Start : N=%d, Batch=%d, epoch=%d, max_iter=%d" \
24 |         %(N, model.batch_size, model.config.epoch, max_iter)
25 | 
26 |     for idx in xrange(1, max_iter):
27 |         batch_start_time = time.time()
28 | 
29 |         # D step
30 |         image, label = dataset.next_batch(model.batch_size)
31 |         _, d_real, d_fake = sess.run(
32 |             [d_optim, model.d_real, model.d_fake],
33 |             feed_dict={model.image:image, model.label:label, model.z:get_z(model), model.is_training:True})
34 |         '''
35 |         # Wasserstein
36 |         _ = sess.run([model.clip_d_op])
37 |         '''
38 | 
39 |         # G step
40 |         image, label = dataset.next_batch(model.batch_size)
41 |         _ = sess.run([g_optim], feed_dict={model.image:image, model.label:label, model.z:get_z(model), model.is_training:True})  # keep batch norm in training mode for G updates as well
42 | 
43 |         # save checkpoint for every epoch
44 |         if (idx*model.batch_size) % N < model.batch_size:
45 |             epoch = int(idx*model.batch_size/N)
46 |             print_time = time.time()
47 |             total_time = print_time - start_time
48 |             sec_per_epoch = (print_time - start_time) / epoch
49 | 
50 |             image, label = dataset.next_batch(model.batch_size)
51 |             summary = sess.run(merged_sum, feed_dict={model.image:image, model.label:label, model.z:get_z(model)})
52 |             model.writer.add_summary(summary, epoch)
53 | 
54 |             _save_samples(model, sess, epoch)
55 |             model.save(sess, model.checkpoint_dir, epoch)
56 | 
57 |             print '[Epoch %(epoch)d] time: %(total_time)4.4f, d_real: %(d_real).8f, d_fake: %(d_fake).8f, sec_per_epoch: %(sec_per_epoch)4.4f' % locals()
58 | 
59 |     sess.close()
60 | 
61 | def _save_samples(model, sess, epoch):
62 |     samples = []
63 |     noises = []
64 | 
65 |     # generator hard codes the batch size
66 |     for i in xrange(model.sample_size // model.batch_size):
67 |         # gen_image, noise = sess.run([model.gen_image, model.z])
68 |         gen_image, noise = sess.run([model.gen_image, model.z],
69 |                                     feed_dict={model.z:get_z(model)})
70 |         samples.append(gen_image)
71 |         noises.append(noise)
72 | 
73 |     samples = np.concatenate(samples, axis=0)
74 |     noises = np.concatenate(noises, axis=0)
75 | 
76 |     assert samples.shape[0] == model.sample_size
77 |     save_images(samples, [8, 8], os.path.join(model.sample_dir, 'samples_%s.png' % (epoch)))
78 | 
79 |     print "Save Samples at %s/%s" % (model.sample_dir, 'samples_%s' % (epoch))
80 |     with open(os.path.join(model.sample_dir, 'samples_%d.npy'%(epoch)), 'wb') as f:
81 |         np.save(f, samples)
82 |     with open(os.path.join(model.sample_dir, 'noises_%d.npy'%(epoch)), 'wb') as f:
83 |         np.save(f, noises)
84 | 
85 | def init_training(model, sess):
86 |     config = model.config
87 |     init_op = tf.global_variables_initializer()
88 |     sess.run(init_op)
89 | 
90 |     merged_sum = tf.summary.merge_all()
91 |     model.writer = tf.summary.FileWriter(config.log_dir, sess.graph)
92 | 
93 | 
94 |     if model.load(sess, model.checkpoint_dir):
95 |         print(" [*] Load SUCCESS")
96 |     else:
97 |         print(" [!] Load failed...")
98 | 
99 |     if not os.path.exists(config.dataset_path):
100 |         print(" [!]
Data does not exist : %s" % config.dataset_path) 101 | return merged_sum 102 | 103 | def load_dataset(model): 104 | if model.dataset_name == 'mnist': 105 | import mnist 106 | return mnist.read_data_sets(model.dataset_path, dtype=tf.uint8, reshape=False, validation_size=0).train 107 | elif model.dataset_name == 'cifar10': 108 | import cifar10 109 | return cifar10.read_data_sets(model.dataset_path, dtype=tf.uint8, reshape=False, validation_size=0).train 110 | 111 | def get_z(model): 112 | return np.random.uniform(-1., 1., size=(model.batch_size, model.z_dim)) 113 | -------------------------------------------------------------------------------- /inception_score/eval_affmnist.py: -------------------------------------------------------------------------------- 1 | import model_affmnist as model 2 | import numpy as np 3 | import os 4 | from IPython import embed 5 | 6 | exps = [ 7 | ] 8 | maxs = [] 9 | 10 | for exp in exps: 11 | base_dir = '../samples' 12 | 13 | sample_dir = os.path.join(base_dir, exp) 14 | samples = sorted(list(set([int(s.split('_')[-1][:-4]) for s in os.listdir(sample_dir) if 'samples' in s])))[:100] 15 | print "num samples : %d" % (len(samples)) 16 | print sample_dir 17 | 18 | mean_stddev = np.zeros((len(samples),2)) 19 | preds = [] 20 | for i, s in enumerate(samples): 21 | with open(sample_dir + '/samples_%d.npy'%s) as f: 22 | images = np.load(f) 23 | images = np.split(images, images.shape[0]) 24 | images = [im.reshape(im.shape[1:]) for im in images] 25 | 26 | mean, stddev, pred = model.get_inception_score(images) 27 | if i % 10 == 0: 28 | print "---------- SCORE in %s ------------" % s 29 | print "%f, %f"%(mean, stddev) 30 | 31 | if np.isnan(mean): 32 | mean_stddev[i,0] = -1 33 | mean_stddev[i,1] = -1 34 | else: 35 | mean_stddev[i,0] = mean 36 | mean_stddev[i,1] = stddev 37 | preds.append(pred) 38 | 39 | max_idx = mean_stddev[:,0].argmax() 40 | maxs.append((exp, mean_stddev[max_idx,0], mean_stddev[max_idx,1], samples[max_idx])) 41 | print 'MAX = %f, at %d' % (mean_stddev[max_idx,0], samples[max_idx]) 42 | 43 | with open(sample_dir+'/scores.npy', 'w') as f: 44 | np.save(f, mean_stddev) 45 | with open(sample_dir+'/predictions.npy', 'w') as f: 46 | np.save(f, preds) 47 | print maxs 48 | -------------------------------------------------------------------------------- /inception_score/eval_cifar10.py: -------------------------------------------------------------------------------- 1 | import model_cifar10 as model 2 | import numpy as np 3 | import os 4 | from IPython import embed 5 | 6 | exps = [ 7 | ] 8 | maxs = [] 9 | 10 | for exp in exps: 11 | base_dir = '/data/whyjay/NIPS2017/' 12 | sample_dir = base_dir + exp 13 | samples = sorted(list(set([int(s.split('_')[-1][:-4]) for s in os.listdir(sample_dir) if 'samples' in s]))) 14 | print "num samples : %d" % (len(samples)) 15 | print sample_dir 16 | 17 | mean_stddev = np.zeros((len(samples),2)) 18 | for i, s in enumerate(samples): 19 | #with open(sample_dir + '/samples_rec_%d.npy'%s) as f: 20 | with open(sample_dir + '/samples_%d.npy'%s) as f: 21 | images = np.load(f) 22 | images = np.split(images, images.shape[0]) 23 | images = [(im.reshape(im.shape[1:]) + 1)*255./2 for im in images] 24 | 25 | mean, stddev = model.get_inception_score(images) 26 | print "---------- SCORE in %s ------------" % s 27 | print "%f, %f"%(mean, stddev) 28 | mean_stddev[i,0] = mean 29 | mean_stddev[i,1] = stddev 30 | 31 | max_idx = mean_stddev[:,0].argmax() 32 | maxs.append((exp, mean_stddev[max_idx,0], mean_stddev[max_idx,1], samples[max_idx])) 33 | 
print 'MAX = %f, at %d' % (mean_stddev[max_idx,0], samples[max_idx]) 34 | 35 | with open(sample_dir+'/scores.npy', 'w') as f: 36 | np.save(f, mean_stddev) 37 | print maxs 38 | -------------------------------------------------------------------------------- /inception_score/eval_mnist.py: -------------------------------------------------------------------------------- 1 | import model_mnist as model 2 | import numpy as np 3 | import os 4 | from IPython import embed 5 | 6 | exps = [ 7 | ] 8 | maxs = [] 9 | 10 | for exp in exps: 11 | base_dir = '../samples' 12 | 13 | sample_dir = os.path.join(base_dir, exp) 14 | samples = sorted(list(set([int(s.split('_')[-1][:-4]) for s in os.listdir(sample_dir) if 'samples' in s])))[:100] 15 | print "num samples : %d" % (len(samples)) 16 | print sample_dir 17 | 18 | mean_stddev = np.zeros((len(samples),2)) 19 | preds = [] 20 | for i, s in enumerate(samples): 21 | with open(sample_dir + '/samples_%d.npy'%s) as f: 22 | images = np.load(f) 23 | images = np.split(images, images.shape[0]) 24 | images = [im.reshape(im.shape[1:]) for im in images] 25 | 26 | mean, stddev, pred = model.get_inception_score(images) 27 | if i % 10 == 0: 28 | print "---------- SCORE in %s ------------" % s 29 | print "%f, %f"%(mean, stddev) 30 | mean_stddev[i,0] = mean 31 | mean_stddev[i,1] = stddev 32 | preds.append(pred) 33 | 34 | max_idx = mean_stddev[:,0].argmax() 35 | maxs.append((exp, mean_stddev[max_idx,0], mean_stddev[max_idx,1], samples[max_idx])) 36 | print 'MAX = %f, at %d' % (mean_stddev[max_idx,0], samples[max_idx]) 37 | 38 | with open(sample_dir+'/scores.npy', 'w') as f: 39 | np.save(f, mean_stddev) 40 | with open(sample_dir+'/predictions.npy', 'w') as f: 41 | np.save(f, preds) 42 | print maxs 43 | -------------------------------------------------------------------------------- /inception_score/model_affmnist.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os.path 7 | import sys 8 | import tarfile 9 | 10 | import sys 11 | sys.path.insert(0, '/data/whyjay/NIPS2017') 12 | import numpy as np 13 | from six.moves import urllib 14 | import tensorflow as tf 15 | import glob 16 | import scipy.misc 17 | import math 18 | import sys 19 | slim = tf.contrib.slim 20 | 21 | MODEL_DIR = 'affmnist_checkpoints' 22 | softmax = None 23 | x = None 24 | 25 | # Call this function with list of images. Each of elements should be a 26 | # numpy array with values ranging from 0 to 255. 
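# The score follows Salimans et al. (2016): predictions are split into
# `splits` groups and the reported numbers are the mean and stddev over
# groups of exp(E_x[KL(p(y|x) || p(y))]), computed below as
# exp(mean(sum(p * (log p - log mean_p), axis=1))). The classifier here is
# not Inception-v3 but the small affNIST CNN built in _init_model (trained
# by train_affmnist_classifier.py), so scores are only comparable within
# this dataset.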
27 | def get_inception_score(images, splits=10): 28 | assert(type(images) == list) 29 | assert(type(images[0]) == np.ndarray) 30 | assert(len(images[0].shape) == 3) 31 | assert(np.min(images[0]) >= 0.0) 32 | inps = [] 33 | 34 | for img in images: 35 | img = img.astype(np.float32) 36 | inps.append(np.expand_dims(img, 0)) 37 | 38 | bs = 100 39 | with tf.Session() as sess: 40 | ckpt = tf.train.get_checkpoint_state(MODEL_DIR) 41 | if ckpt and ckpt.model_checkpoint_path: 42 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 43 | saver = tf.train.Saver() 44 | saver.restore(sess, os.path.join(MODEL_DIR, ckpt_name)) 45 | 46 | preds = [] 47 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 48 | 49 | for i in range(n_batches): 50 | sys.stdout.write(".") 51 | sys.stdout.flush() 52 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 53 | inp = np.concatenate(inp, 0) 54 | pred = sess.run(softmax, {x: inp}) 55 | preds.append(pred) 56 | 57 | preds = np.concatenate(preds, 0) 58 | scores = [] 59 | 60 | for i in range(splits): 61 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 62 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 63 | kl = np.mean(np.sum(kl, 1)) 64 | scores.append(np.exp(kl)) 65 | 66 | return np.mean(scores), np.std(scores), preds 67 | 68 | # This function is called automatically. 69 | def _init_model(): 70 | global softmax 71 | global x 72 | if not os.path.exists(MODEL_DIR): 73 | os.makedirs(MODEL_DIR) 74 | 75 | # Works with an arbitrary minibatch size. 76 | x = tf.placeholder(tf.float32, shape=[None, 40, 40, 1]) 77 | 78 | h = slim.conv2d(x, 32, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 79 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 80 | h = slim.conv2d(h, 64, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 81 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 82 | h = tf.reshape(h, [-1, 10*10*64]) 83 | h = slim.fully_connected(h, 1024, activation_fn=None, normalizer_fn=None) 84 | h = tf.nn.dropout(h, 1.0) 85 | logits = slim.fully_connected(h, 10, activation_fn=None, normalizer_fn=None) 86 | softmax = tf.nn.softmax(logits) 87 | 88 | if softmax is None or x is None: 89 | _init_model() 90 | -------------------------------------------------------------------------------- /inception_score/model_cifar10.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os.path 7 | import sys 8 | import tarfile 9 | 10 | import numpy as np 11 | from six.moves import urllib 12 | import tensorflow as tf 13 | import glob 14 | import scipy.misc 15 | import math 16 | import sys 17 | 18 | MODEL_DIR = '/tmp/imagenet' 19 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 20 | softmax = None 21 | 22 | # Call this function with list of images. Each of elements should be a 23 | # numpy array with values ranging from 0 to 255. 
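# Same splitting/KL formula as model_affmnist.py, but scored with the
# pretrained Inception-v3 graph downloaded in _init_inception below, i.e.
# the standard Inception Score. Note the extra assert that inputs look like
# raw [0, 255] pixels (np.max > 10), matching the rescaling done in
# eval_cifar10.py.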
24 | def get_inception_score(images, splits=10): 25 | assert(type(images) == list) 26 | assert(type(images[0]) == np.ndarray) 27 | assert(len(images[0].shape) == 3) 28 | assert(np.max(images[0]) > 10) 29 | assert(np.min(images[0]) >= 0.0) 30 | inps = [] 31 | 32 | for img in images: 33 | img = img.astype(np.float32) 34 | inps.append(np.expand_dims(img, 0)) 35 | 36 | bs = 100 37 | with tf.Session() as sess: 38 | preds = [] 39 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 40 | for i in range(n_batches): 41 | sys.stdout.write(".") 42 | sys.stdout.flush() 43 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 44 | inp = np.concatenate(inp, 0) 45 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 46 | preds.append(pred) 47 | 48 | preds = np.concatenate(preds, 0) 49 | scores = [] 50 | 51 | for i in range(splits): 52 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 53 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 54 | kl = np.mean(np.sum(kl, 1)) 55 | scores.append(np.exp(kl)) 56 | 57 | return np.mean(scores), np.std(scores) 58 | 59 | # This function is called automatically. 60 | def _init_inception(): 61 | global softmax 62 | if not os.path.exists(MODEL_DIR): 63 | os.makedirs(MODEL_DIR) 64 | filename = DATA_URL.split('/')[-1] 65 | filepath = os.path.join(MODEL_DIR, filename) 66 | if not os.path.exists(filepath): 67 | def _progress(count, block_size, total_size): 68 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 69 | filename, float(count * block_size) / float(total_size) * 100.0)) 70 | sys.stdout.flush() 71 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 72 | print() 73 | statinfo = os.stat(filepath) 74 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 75 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 76 | with tf.gfile.FastGFile(os.path.join( 77 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 78 | graph_def = tf.GraphDef() 79 | graph_def.ParseFromString(f.read()) 80 | _ = tf.import_graph_def(graph_def, name='') 81 | # Works with an arbitrary minibatch size. 
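# The session block below patches the imported Inception graph so it accepts an
# arbitrary batch size: every op output whose leading static dimension is 1 has
# that dimension rewritten to None. Note that o._shape is a private TensorFlow
# attribute, so this shape-rewriting hack only works on the old 1.x graph API.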
82 | with tf.Session() as sess: 83 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 84 | ops = pool3.graph.get_operations() 85 | for op_idx, op in enumerate(ops): 86 | for o in op.outputs: 87 | shape = o.get_shape() 88 | shape = [s.value for s in shape] 89 | new_shape = [] 90 | for j, s in enumerate(shape): 91 | if s == 1 and j == 0: 92 | new_shape.append(None) 93 | else: 94 | new_shape.append(s) 95 | o._shape = tf.TensorShape(new_shape) 96 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 97 | logits = tf.matmul(tf.squeeze(pool3), w) 98 | softmax = tf.nn.softmax(logits) 99 | 100 | if softmax is None: 101 | _init_inception() 102 | -------------------------------------------------------------------------------- /inception_score/model_mnist.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os.path 7 | import sys 8 | import tarfile 9 | 10 | import sys 11 | sys.path.insert(0, '/data/whyjay/NIPS2017') 12 | import numpy as np 13 | from six.moves import urllib 14 | import tensorflow as tf 15 | import glob 16 | import scipy.misc 17 | import math 18 | import sys 19 | slim = tf.contrib.slim 20 | 21 | MODEL_DIR = 'checkpoints' 22 | softmax = None 23 | x = None 24 | 25 | # Call this function with list of images. Each of elements should be a 26 | # numpy array with values ranging from 0 to 255. 27 | def get_inception_score(images, splits=10): 28 | assert(type(images) == list) 29 | assert(type(images[0]) == np.ndarray) 30 | assert(len(images[0].shape) == 3) 31 | assert(np.min(images[0]) >= 0.0) 32 | inps = [] 33 | 34 | for img in images: 35 | img = img.astype(np.float32) 36 | inps.append(np.expand_dims(img, 0)) 37 | 38 | bs = 100 39 | with tf.Session() as sess: 40 | ckpt = tf.train.get_checkpoint_state(MODEL_DIR) 41 | if ckpt and ckpt.model_checkpoint_path: 42 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 43 | saver = tf.train.Saver() 44 | saver.restore(sess, os.path.join(MODEL_DIR, ckpt_name)) 45 | 46 | preds = [] 47 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 48 | 49 | for i in range(n_batches): 50 | sys.stdout.write(".") 51 | sys.stdout.flush() 52 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 53 | inp = np.concatenate(inp, 0) 54 | pred = sess.run(softmax, {x: inp}) 55 | preds.append(pred) 56 | 57 | preds = np.concatenate(preds, 0) 58 | scores = [] 59 | 60 | for i in range(splits): 61 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 62 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 63 | kl = np.mean(np.sum(kl, 1)) 64 | scores.append(np.exp(kl)) 65 | 66 | return np.mean(scores), np.std(scores), preds 67 | 68 | # This function is called automatically. 69 | def _init_model(): 70 | global softmax 71 | global x 72 | if not os.path.exists(MODEL_DIR): 73 | os.makedirs(MODEL_DIR) 74 | 75 | # Works with an arbitrary minibatch size. 
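# The classifier graph below must match inception_score/train_mnist_classifier.py
# layer for layer: get_inception_score() above restores the trained weights by
# variable name with tf.train.Saver, so any architectural change here invalidates
# the checkpoints saved under MODEL_DIR.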
76 | x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1]) 77 | 78 | h = slim.conv2d(x, 32, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 79 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 80 | h = slim.conv2d(h, 64, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 81 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 82 | h = tf.reshape(h, [-1, 7*7*64]) 83 | h = slim.fully_connected(h, 1024, activation_fn=None, normalizer_fn=None) 84 | h = tf.nn.dropout(h, 1.0) 85 | logits = slim.fully_connected(h, 10, activation_fn=None, normalizer_fn=None) 86 | softmax = tf.nn.softmax(logits) 87 | 88 | if softmax is None or x is None: 89 | _init_model() 90 | -------------------------------------------------------------------------------- /inception_score/train_affmnist_classifier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../') 3 | import math 4 | 5 | import numpy as np 6 | from scipy.ndimage import rotate 7 | import tensorflow as tf 8 | from tensorflow.examples.tutorials.mnist import input_data 9 | from ops import * 10 | from IPython import embed 11 | slim = tf.contrib.slim 12 | 13 | # The MNIST dataset has 10 classes, representing the digits 0 through 9. 14 | NUM_CLASSES = 10 15 | 16 | def main(): 17 | sess = tf.Session() 18 | mnist = input_data.read_data_sets('../dataset/mnist', one_hot=True) 19 | save_path = 'checkpoints' 20 | 21 | x = tf.placeholder(tf.float32, shape=[None, 40, 40, 1]) 22 | y_ = tf.placeholder(tf.float32, shape=[None, 10]) 23 | 24 | h = slim.conv2d(x, 32, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 25 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 26 | h = slim.conv2d(h, 64, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 27 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 28 | h = tf.reshape(h, [-1, 10*10*64]) 29 | h = slim.fully_connected(h, 1024, activation_fn=None, normalizer_fn=None) 30 | keep_prob = tf.placeholder(tf.float32) 31 | h = tf.nn.dropout(h, keep_prob) 32 | logits = slim.fully_connected(h, 10, activation_fn=None, normalizer_fn=None) 33 | 34 | cross_entropy = tf.reduce_mean( 35 | tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits)) 36 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 37 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(y_,1)) 38 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 39 | sess.run(tf.global_variables_initializer()) 40 | 41 | saver = tf.train.Saver(max_to_keep=None) 42 | 43 | test_images = mnist.test.images 44 | test_images = test_images.reshape((-1, 28, 28, 1)) 45 | move = np.zeros((len(test_images), 40, 40, 1)) # initialize 46 | for i in range(len(test_images)): 47 | degree = np.random.randint(41) - 20 48 | h_move = np.random.randint(12) 49 | w_move = np.random.randint(12) 50 | rot = rotate(test_images[i], degree, reshape = False) 51 | move[i][h_move:h_move+28, w_move:w_move+28] = rot 52 | test_images = move 53 | 54 | for i in range(200000): 55 | batch = mnist.train.next_batch(50) 56 | batch_img = batch[0].reshape((-1, 28, 28, 1)) 57 | move = np.zeros((50, 40, 40, 1)) # initialize 58 | for j in range(50): 59 | degree = np.random.randint(41) - 20 60 | h_move = np.random.randint(12) 61 | w_move = np.random.randint(12) 62 | rot = rotate(batch_img[j], degree, reshape = False) 63 | move[j][h_move:h_move+28, w_move:w_move+28] = rot 64 | batch_img = 
move 65 | 66 | sess.run(train_step, feed_dict={x: batch_img, y_: batch[1], keep_prob: 0.5}) 67 | 68 | if i % 1000 == 0: 69 | train_accuracy = sess.run(accuracy, feed_dict={x:batch_img, y_: batch[1], keep_prob: 1.0}) 70 | test_accuracy = sess.run(accuracy, feed_dict={x:test_images, y_: mnist.test.labels, keep_prob: 1.0}) 71 | save_at = os.path.join(save_path, 'affmnist_ckpt_%f' % test_accuracy) 72 | 73 | print("step %d, training accuracy %g, test accuracy %f"%(i, train_accuracy, test_accuracy)) 74 | 75 | print 'Save at %s' % save_at 76 | saver.save(sess, save_at, global_step=i) 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /inception_score/train_mnist_classifier.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '../') 3 | import math 4 | 5 | import tensorflow as tf 6 | from tensorflow.examples.tutorials.mnist import input_data 7 | from ops import * 8 | from IPython import embed 9 | slim = tf.contrib.slim 10 | 11 | # The MNIST dataset has 10 classes, representing the digits 0 through 9. 12 | NUM_CLASSES = 10 13 | 14 | def main(): 15 | sess = tf.Session() 16 | mnist = input_data.read_data_sets('../dataset/mnist', one_hot=True) 17 | save_path = 'checkpoints' 18 | 19 | x = tf.placeholder(tf.float32, shape=[None, 784]) 20 | images = tf.reshape(x, [-1, 28, 28, 1]) 21 | y_ = tf.placeholder(tf.float32, shape=[None, 10]) 22 | 23 | h = slim.conv2d(images, 32, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 24 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 25 | h = slim.conv2d(h, 64, 3, 1, activation_fn=tf.nn.relu, normalizer_fn=None) 26 | h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 27 | h = tf.reshape(h, [-1, 7*7*64]) 28 | h = slim.fully_connected(h, 1024, activation_fn=None, normalizer_fn=None) 29 | keep_prob = tf.placeholder(tf.float32) 30 | h = tf.nn.dropout(h, keep_prob) 31 | logits = slim.fully_connected(h, 10, activation_fn=None, normalizer_fn=None) 32 | 33 | cross_entropy = tf.reduce_mean( 34 | tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits)) 35 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 36 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(y_,1)) 37 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 38 | sess.run(tf.global_variables_initializer()) 39 | 40 | saver = tf.train.Saver(max_to_keep=None) 41 | for i in range(20000): 42 | batch = mnist.train.next_batch(50) 43 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 44 | 45 | if i % 1000 == 0: 46 | train_accuracy = sess.run(accuracy, feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0}) 47 | test_accuracy = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}) 48 | save_at = os.path.join(save_path, 'ckpt_%f' % test_accuracy) 49 | 50 | print("step %d, training accuracy %g, test accuracy %f"%(i, train_accuracy, test_accuracy)) 51 | 52 | print 'Save at %s' % save_at 53 | saver.save(sess, save_at, global_step=i) 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for downloading and reading MNIST data.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import gzip 23 | 24 | import numpy 25 | from six.moves import xrange # pylint: disable=redefined-builtin 26 | 27 | from tensorflow.contrib.learn.python.learn.datasets import base 28 | from tensorflow.python.framework import dtypes 29 | 30 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 31 | 32 | 33 | def _read32(bytestream): 34 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 35 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 36 | 37 | 38 | def extract_images(f): 39 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. 40 | 41 | Args: 42 | f: A file object that can be passed into a gzip reader. 43 | 44 | Returns: 45 | data: A 4D uint8 numpy array [index, y, x, depth]. 46 | 47 | Raises: 48 | ValueError: If the bytestream does not start with 2051. 49 | 50 | """ 51 | print('Extracting', f.name) 52 | with gzip.GzipFile(fileobj=f) as bytestream: 53 | magic = _read32(bytestream) 54 | if magic != 2051: 55 | raise ValueError('Invalid magic number %d in MNIST image file: %s' % 56 | (magic, f.name)) 57 | num_images = _read32(bytestream) 58 | rows = _read32(bytestream) 59 | cols = _read32(bytestream) 60 | buf = bytestream.read(rows * cols * num_images) 61 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 62 | data = data.reshape(num_images, rows, cols, 1) 63 | return data 64 | 65 | 66 | def dense_to_one_hot(labels_dense, num_classes): 67 | """Convert class labels from scalars to one-hot vectors.""" 68 | num_labels = labels_dense.shape[0] 69 | index_offset = numpy.arange(num_labels) * num_classes 70 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 71 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 72 | return labels_one_hot 73 | 74 | 75 | def extract_labels(f, one_hot=False, num_classes=10): 76 | """Extract the labels into a 1D uint8 numpy array [index]. 77 | 78 | Args: 79 | f: A file object that can be passed into a gzip reader. 80 | one_hot: Does one hot encoding for the result. 81 | num_classes: Number of classes for the one hot encoding. 82 | 83 | Returns: 84 | labels: a 1D uint8 numpy array. 85 | 86 | Raises: 87 | ValueError: If the bystream doesn't start with 2049. 
88 | """ 89 | print('Extracting', f.name) 90 | with gzip.GzipFile(fileobj=f) as bytestream: 91 | magic = _read32(bytestream) 92 | if magic != 2049: 93 | raise ValueError('Invalid magic number %d in MNIST label file: %s' % 94 | (magic, f.name)) 95 | num_items = _read32(bytestream) 96 | buf = bytestream.read(num_items) 97 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 98 | if one_hot: 99 | return dense_to_one_hot(labels, num_classes) 100 | return labels 101 | 102 | 103 | class DataSet(object): 104 | 105 | def __init__(self, 106 | images, 107 | labels, 108 | fake_data=False, 109 | one_hot=False, 110 | dtype=dtypes.float32, 111 | reshape=True): 112 | """Construct a DataSet. 113 | one_hot arg is used only if fake_data is true. `dtype` can be either 114 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 115 | `[0, 1]`. 116 | """ 117 | dtype = dtypes.as_dtype(dtype).base_dtype 118 | if dtype not in (dtypes.uint8, dtypes.float32): 119 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 120 | dtype) 121 | if fake_data: 122 | self._num_examples = 10000 123 | self.one_hot = one_hot 124 | else: 125 | assert images.shape[0] == labels.shape[0], ( 126 | 'images.shape: %s labels.shape: %s' % (images.shape, labels.shape)) 127 | self._num_examples = images.shape[0] 128 | 129 | # Convert shape from [num examples, rows, columns, depth] 130 | # to [num examples, rows*columns] (assuming depth == 1) 131 | if reshape: 132 | assert images.shape[3] == 1 133 | images = images.reshape(images.shape[0], 134 | images.shape[1] * images.shape[2]) 135 | if dtype == dtypes.float32: 136 | # Convert from [0, 255] -> [0.0, 1.0]. 137 | images = images.astype(numpy.float32) 138 | images = numpy.multiply(images, 1.0 / 255.0) 139 | self._images = images 140 | self._labels = labels 141 | self._epochs_completed = 0 142 | self._index_in_epoch = 0 143 | 144 | @property 145 | def images(self): 146 | return self._images 147 | 148 | @property 149 | def labels(self): 150 | return self._labels 151 | 152 | @property 153 | def num_examples(self): 154 | return self._num_examples 155 | 156 | @property 157 | def epochs_completed(self): 158 | return self._epochs_completed 159 | 160 | def next_batch(self, batch_size, fake_data=False): 161 | """Return the next `batch_size` examples from this data set.""" 162 | if fake_data: 163 | fake_image = [1] * 784 164 | if self.one_hot: 165 | fake_label = [1] + [0] * 9 166 | else: 167 | fake_label = 0 168 | return [fake_image for _ in xrange(batch_size)], [ 169 | fake_label for _ in xrange(batch_size) 170 | ] 171 | start = self._index_in_epoch 172 | self._index_in_epoch += batch_size 173 | if self._index_in_epoch > self._num_examples: 174 | # Finished epoch 175 | self._epochs_completed += 1 176 | # Shuffle the data 177 | perm = numpy.arange(self._num_examples) 178 | numpy.random.shuffle(perm) 179 | self._images = self._images[perm] 180 | self._labels = self._labels[perm] 181 | # Start next epoch 182 | start = 0 183 | self._index_in_epoch = batch_size 184 | assert batch_size <= self._num_examples 185 | end = self._index_in_epoch 186 | return self._images[start:end], self._labels[start:end] 187 | 188 | 189 | def read_data_sets(train_dir, 190 | fake_data=False, 191 | one_hot=False, 192 | dtype=dtypes.float32, 193 | reshape=True, 194 | validation_size=5000): 195 | if fake_data: 196 | 197 | def fake(): 198 | return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 199 | 200 | train = fake() 201 | validation = fake() 202 | test = fake() 203 | return 
base.Datasets(train=train, validation=validation, test=test) 204 | 205 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 206 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 207 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 208 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 209 | 210 | local_file = base.maybe_download(TRAIN_IMAGES, train_dir, 211 | SOURCE_URL + TRAIN_IMAGES) 212 | with open(local_file, 'rb') as f: 213 | train_images = extract_images(f) 214 | 215 | local_file = base.maybe_download(TRAIN_LABELS, train_dir, 216 | SOURCE_URL + TRAIN_LABELS) 217 | with open(local_file, 'rb') as f: 218 | train_labels = extract_labels(f, one_hot=one_hot) 219 | 220 | local_file = base.maybe_download(TEST_IMAGES, train_dir, 221 | SOURCE_URL + TEST_IMAGES) 222 | with open(local_file, 'rb') as f: 223 | test_images = extract_images(f) 224 | 225 | local_file = base.maybe_download(TEST_LABELS, train_dir, 226 | SOURCE_URL + TEST_LABELS) 227 | with open(local_file, 'rb') as f: 228 | test_labels = extract_labels(f, one_hot=one_hot) 229 | 230 | if not 0 <= validation_size <= len(train_images): 231 | raise ValueError( 232 | 'Validation size should be between 0 and {}. Received: {}.' 233 | .format(len(train_images), validation_size)) 234 | 235 | validation_images = train_images[:validation_size] 236 | validation_labels = train_labels[:validation_size] 237 | train_images = train_images[validation_size:] 238 | train_labels = train_labels[validation_size:] 239 | 240 | train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape) 241 | validation = DataSet(validation_images, 242 | validation_labels, 243 | dtype=dtype, 244 | reshape=reshape) 245 | test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape) 246 | 247 | return base.Datasets(train=train, validation=validation, test=test) 248 | 249 | 250 | def load_mnist(train_dir='MNIST-data'): 251 | return read_data_sets(train_dir) 252 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import contextmanager 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.framework import ops 7 | 8 | from utils import * 9 | 10 | slim = tf.contrib.slim 11 | rng = np.random.RandomState([2016, 6, 1]) 12 | ln = tf.contrib.layers.layer_norm 13 | bn = slim.batch_norm 14 | 15 | def conv_cond_concat(x, y): 16 | """Concatenate conditioning vector on feature map axis.""" 17 | x_shapes = x.get_shape() 18 | y_shapes = y.get_shape() 19 | return tf.concat([x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 20 | 21 | 22 | def lrelu(x, leak=0.2, name="lrelu"): 23 | with tf.variable_scope(name): 24 | f1 = 0.5 * (1 + leak) 25 | f2 = 0.5 * (1 - leak) 26 | return f1 * x + f2 * abs(x) 27 | 28 | def sin_and_cos(x, name="ignored"): 29 | return tf.concat(len(x.get_shape()) - 1, [tf.sin(x), tf.cos(x)]) 30 | 31 | def maxout(x, k = 2): 32 | shape = [int(e) for e in x.get_shape()] 33 | ax = len(shape) 34 | ch = shape[-1] 35 | assert ch % k == 0 36 | shape[-1] = ch / k 37 | shape.append(k) 38 | x = tf.reshape(x, shape) 39 | return tf.reduce_max(x, ax) 40 | 41 | def offset_maxout(x, k = 2): 42 | shape = [int(e) for e in x.get_shape()] 43 | ax = len(shape) 44 | ch = shape[-1] 45 | assert ch % k == 0 46 | shape[-1] = ch / k 47 | shape.append(k) 48 | x = tf.reshape(x, shape) 49 | ofs = rng.randn(1000, k).max(axis=1).mean() 50 | return tf.reduce_max(x, ax) - ofs 51 | 52 | def lrelu_sq(x): 53 | """ 
54 | Concatenates lrelu and square 55 | """ 56 | dim = len(x.get_shape()) - 1 57 | return tf.concat([lrelu(x), tf.minimum(tf.abs(x), tf.square(x))], dim) 58 | 59 | 60 | def nin(input_, output_size, name=None, mean=0., stddev=0.02, bias_start=0.0, with_w=False): 61 | s = list(map(int, input_.get_shape())) 62 | input_ = tf.reshape(input_, [np.prod(s[:-1]), s[-1]]) 63 | input_ = linear(input_, output_size, name=name, mean=mean, stddev=stddev, bias_start=bias_start, with_w=with_w) 64 | return tf.reshape(input_, s[:-1]+[output_size]) 65 | 66 | @contextmanager 67 | def variables_on_cpu(): 68 | old_fn = tf.get_variable 69 | def new_fn(*args, **kwargs): 70 | with tf.device("/cpu:0"): 71 | return old_fn(*args, **kwargs) 72 | tf.get_variable = new_fn 73 | yield 74 | tf.get_variable = old_fn 75 | 76 | @contextmanager 77 | def variables_on_gpu0(): 78 | old_fn = tf.get_variable 79 | def new_fn(*args, **kwargs): 80 | with tf.device("/gpu:0"): 81 | return old_fn(*args, **kwargs) 82 | tf.get_variable = new_fn 83 | yield 84 | tf.get_variable = old_fn 85 | 86 | def avg_grads(tower_grads): 87 | """Calculate the average gradient for each shared variable across all towers. 88 | 89 | Note that this function provides a synchronization point across all towers. 90 | 91 | Args: 92 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 93 | is over individual gradients. The inner list is over the gradient 94 | calculation for each tower. 95 | Returns: 96 | List of pairs of (gradient, variable) where the gradient has been averaged 97 | across all towers. 98 | """ 99 | average_grads = [] 100 | for grad_and_vars in zip(*tower_grads): 101 | # Note that each grad_and_vars looks like the following: 102 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 103 | grads = [] 104 | for g, _ in grad_and_vars: 105 | # Add 0 dimension to the gradients to represent the tower. 106 | expanded_g = tf.expand_dims(g, 0) 107 | 108 | # Append on a 'tower' dimension which we will average over below. 109 | grads.append(expanded_g) 110 | 111 | # Average over the 'tower' dimension. 112 | grad = tf.reduce_mean(tf.concat(grads, 0), 0) 113 | 114 | # Keep in mind that the Variables are redundant because they are shared 115 | # across towers. So .. we will just return the first tower's pointer to 116 | # the Variable. 117 | v = grad_and_vars[0][1] 118 | grad_and_var = (grad, v) 119 | average_grads.append(grad_and_var) 120 | return average_grads 121 | 122 | 123 | def decayer(x, name="decayer"): 124 | with tf.variable_scope(name): 125 | scale = tf.get_variable("scale", [1], initializer=tf.constant_initializer(1.)) 126 | decay_scale = tf.get_variable("decay_scale", [1], initializer=tf.constant_initializer(1.)) 127 | relu = tf.nn.relu(x) 128 | return scale * relu / (1. + tf.abs(decay_scale) * tf.square(decay_scale)) 129 | 130 | def decayer2(x, name="decayer"): 131 | with tf.variable_scope(name): 132 | scale = tf.get_variable("scale", [int(x.get_shape()[-1])], initializer=tf.constant_initializer(1.)) 133 | decay_scale = tf.get_variable("decay_scale", [int(x.get_shape()[-1])], initializer=tf.constant_initializer(1.)) 134 | relu = tf.nn.relu(x) 135 | return scale * relu / (1. 
+ tf.abs(decay_scale) * tf.square(decay_scale)) 136 | 137 | def masked_relu(x, name="ignored"): 138 | shape = [int(e) for e in x.get_shape()] 139 | prefix = [0] * (len(shape) - 1) 140 | most = shape[:-1] 141 | assert shape[-1] % 2 == 0 142 | half = shape[-1] // 2 143 | first_half = tf.slice(x, prefix + [0], most + [half]) 144 | second_half = tf.slice(x, prefix + [half], most + [half]) 145 | return tf.nn.relu(first_half) * tf.nn.sigmoid(second_half) 146 | 147 | def make_z(shape, minval=-1.0, maxval=1.0, name="z"): 148 | z = tf.random_uniform(shape, 149 | minval=minval, maxval=maxval, 150 | name=name, dtype=tf.float32) 151 | #z = tf.random_normal(shape, name=name, stddev=0.5, dtype=tf.float32) 152 | return z 153 | 154 | def get_sample_zs(model): 155 | assert model.sample_size > model.batch_size 156 | assert model.sample_size % model.batch_size == 0 157 | if model.config.multigpu: 158 | batch_size = model.batch_size // len(model.devices) 159 | else: 160 | batch_size = model.batch_size 161 | 162 | steps = model.sample_size // batch_size 163 | assert steps > 0 164 | 165 | sample_zs = [] 166 | for i in xrange(steps): 167 | cur_zs = model.sess.run(model.z) 168 | sample_zs.append(cur_zs) 169 | 170 | sample_zs = np.concatenate(sample_zs, axis=0) 171 | assert sample_zs.shape[0] == model.sample_size 172 | return sample_zs 173 | 174 | def batch_to_grid(images, width=4): 175 | images = tf.squeeze(images[:width**2]) 176 | images_list = tf.unstack(images, num=width**2, axis=0) 177 | conc = tf.concat(images_list, axis=1) 178 | sp = tf.split(conc, width, axis=1) 179 | grid = tf.expand_dims(tf.concat(sp, axis=0), axis=0) 180 | if len(grid.get_shape().as_list()) < 4: 181 | grid = tf.expand_dims(grid, axis=-1) 182 | 183 | return grid 184 | 185 | @tf.contrib.framework.add_arg_scope 186 | def fc(x, out_dim, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):#, is_training=False): 187 | if norm == bn: 188 | return slim.fully_connected( 189 | x, out_dim, activation_fn=act, normalizer_fn=norm, 190 | weights_initializer=init, normalizer_params={'is_training':is_training}) 191 | else: 192 | return slim.fully_connected( 193 | x, out_dim, activation_fn=act, normalizer_fn=norm, weights_initializer=init) 194 | 195 | @tf.contrib.framework.add_arg_scope 196 | def deconv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):#, is_training=False): 197 | if norm == bn: 198 | return slim.conv2d_transpose( 199 | x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, 200 | weights_initializer=init, normalizer_params={'is_training':is_training}) 201 | else: 202 | return slim.conv2d_transpose( 203 | x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, weights_initializer=init) 204 | 205 | @tf.contrib.framework.add_arg_scope 206 | def conv2d(x, out_dim, k, s, is_training, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):#, is_training=False): 207 | if norm == bn: 208 | return slim.conv2d( 209 | x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, 210 | weights_initializer=init, normalizer_params={'is_training':is_training}) 211 | else: 212 | return slim.conv2d( 213 | x, out_dim, k, s, activation_fn=act, normalizer_fn=norm, weights_initializer=init) 214 | 215 | 216 | def preprocess_image(image, dataset, use_augmentation=False): 217 | image = tf.divide(image, 255., name=None) 218 | if use_augmentation: 219 | image = tf.image.random_brightness(image, max_delta=16. / 255.) 
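        # keep the photometric jitter mild: brightness shifts by at most 16/255 above,
        # contrast scales by at most 10% on the next line, and the result is re-clipped to [0, 1]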
220 |         image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
221 |         image = tf.minimum(tf.maximum(image, 0.0), 1.0)
222 | 
223 |     if ('mnist' not in dataset) and ('fashion' not in dataset):
224 |         image = tf.subtract(image * 2., 1.)
225 | 
226 |     return image
227 | 
228 | def conv_mean_pool(x, out_dim, k=3, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):
229 |     h = conv2d(x, out_dim, k=k, s=1, act=act, norm=norm, init=init)
230 |     return tf.add_n([h[:,::2,::2,:], h[:,1::2,::2,:], h[:,::2,1::2,:], h[:,1::2,1::2,:]]) / 4.
231 | 
232 | def resize_conv2d(x, out_dim, k=3, scale=2, act=tf.nn.relu, norm=bn, init=tf.truncated_normal_initializer(stddev=0.02)):
233 |     # h, w = x.get_shape().as_list()[1:3]
234 |     # h = tf.image.resize_nearest_neighbor(x, (h*scale, w*scale))
235 |     h = tf.concat([x, x, x, x], axis=3)
236 |     h = tf.depth_to_space(h, 2)
237 |     return conv2d(h, out_dim, k=k, s=1, act=act, norm=norm, init=init)
238 | 
239 | def residual_block(x, resample=None, labels=None, act=tf.nn.relu, norm=ln, init=tf.truncated_normal_initializer(stddev=0.02)):
240 |     c_dim = x.get_shape().as_list()[-1]
241 | 
242 |     if resample=='down':
243 |         h = conv2d(x, c_dim, 3, 1, act=act, init=init, norm=norm)
244 |         #h = conv2d(x, c_dim, 3, 1, act=act, init=init)
245 |         h = conv_mean_pool(h, c_dim, 3, act=None, norm=None, init=init)
246 |         h += conv_mean_pool(x, c_dim, 1, act=None, norm=None, init=init)
247 |         h = act(norm(h))
248 |     elif resample=='up':
249 |         h = resize_conv2d(x, c_dim, 3, act=act, init=init, norm=norm)
250 |         #h = resize_conv2d(x, c_dim, 3, act=act, init=init)
251 |         h = conv2d(h, c_dim, 3, 1, act=None, norm=None, init=init)
252 |         h += resize_conv2d(x, c_dim, 1, act=None, norm=None, init=init)
253 |         h = act(norm(h))
254 |     elif resample==None:
255 |         h = conv2d(x, c_dim, 3, 1, act=act, init=init, norm=norm)
256 |         #h = conv2d(x, c_dim, 3, 1, act=act, init=init)
257 |         h = conv2d(h, c_dim, 3, 1, act=None, norm=None, init=init)
258 |         h += x
259 |         h = act(norm(h))
260 |     else:
261 |         raise Exception('invalid resample value')
262 | 
263 |     return h
264 | 
265 | def reparameterize(mu, logvar, distribution='gaussian'):
266 |     if distribution == 'gaussian':
267 |         epsilon = tf.random_normal(tf.shape(logvar), name="epsilon")
268 |         z = mu + epsilon * tf.exp(0.5*logvar)
269 |     else:
270 |         raise NotImplementedError('unsupported latent distribution: %s' % distribution)
271 |     return z
272 | 
273 | ops_with_bn = [fc, conv2d, deconv2d]
-------------------------------------------------------------------------------- /run_gan.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | 
5 | from gan_models.config import Config
6 | from gan_models.model import GAN
7 | from gan_models.train import train
8 | from utils import pp, visualize, to_json
9 | 
10 | from IPython import embed
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
12 | 
13 | flags = tf.app.flags
14 | flags.DEFINE_integer("epoch", 1000000, "Max epoch to train")
15 | flags.DEFINE_integer("exp", 0, "Experiment number")
16 | flags.DEFINE_integer("batch_size", 64, "Batch size")
17 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate")
18 | flags.DEFINE_string("load_cp_dir", '', "checkpoint path")
19 | flags.DEFINE_string("dataset", "mnist", "[mnist, affmnist, cifar10]")
20 | flags.DEFINE_string("loss", "jsd", "[jsd, alternative, reverse_kl]")
21 | flags.DEFINE_boolean("is_train", True, "True for training, False for testing [False]")
22 | flags.DEFINE_boolean("use_augmentation", True, "Normalization and random brightness/contrast")
brightness/contrast") 23 | flags.DEFINE_string("generator", 'dcgan_g', '') 24 | flags.DEFINE_string("discriminator", 'dcgan_d', '') 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | def main(_): 29 | pp.pprint(flags.FLAGS.__flags) 30 | 31 | config = Config(FLAGS) 32 | config.print_config() 33 | config.make_dirs() 34 | 35 | config_proto = tf.ConfigProto(allow_soft_placement=FLAGS.is_train, log_device_placement=False) 36 | config_proto.gpu_options.allow_growth = True 37 | 38 | with tf.Session(config=config_proto) as sess: 39 | model = GAN(config) 40 | train(model, sess) 41 | 42 | if __name__ == '__main__': 43 | tf.app.run() 44 | -------------------------------------------------------------------------------- /run_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from vae_models.config import Config 6 | from vae_models.model import VAE 7 | from vae_models.train import train 8 | from utils import pp, visualize, to_json 9 | 10 | from IPython import embed 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | 13 | flags = tf.app.flags 14 | flags.DEFINE_integer("epoch", 1500, "Max epoch to train") 15 | flags.DEFINE_string("exp", 1, "Experiment number") 16 | flags.DEFINE_string("batch_size", 100, "Batch size") 17 | flags.DEFINE_string("learning_rate", 1e-4, "Learning rate") 18 | flags.DEFINE_string("load_cp_dir", '', "checkpoint path") 19 | flags.DEFINE_string("dataset", "cifar10", "[mnist, fashion, affmnist, cifar10]") 20 | flags.DEFINE_string("latent_distribution", "gaussian", "[gaussian]") 21 | flags.DEFINE_boolean("is_train", True, "True for training, False for testing [False]") 22 | flags.DEFINE_boolean("use_augmentation", True, "Normalization and random brightness/contrast") 23 | flags.DEFINE_string("encoder", 'base_encoder', '') 24 | flags.DEFINE_string("decoder", 'base_decoder', '') 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | def main(_): 29 | pp.pprint(flags.FLAGS.__flags) 30 | 31 | config = Config(FLAGS) 32 | config.print_config() 33 | config.make_dirs() 34 | 35 | config_proto = tf.ConfigProto(allow_soft_placement=FLAGS.is_train, log_device_placement=False) 36 | config_proto.gpu_options.allow_growth = True 37 | 38 | with tf.Session(config=config_proto) as sess: 39 | model = VAE(config) 40 | train(model, sess) 41 | 42 | if __name__ == '__main__': 43 | tf.app.run() 44 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some codes from https://github.com/Newmu/dcgan_code 3 | """ 4 | import math 5 | import os 6 | import errno 7 | import json 8 | import random 9 | import pprint 10 | import scipy.misc 11 | import scipy.stats as stats 12 | import numpy as np 13 | from time import gmtime, strftime 14 | import tensorflow as tf 15 | from sklearn.cluster import mean_shift, estimate_bandwidth 16 | from sklearn.metrics.pairwise import pairwise_distances 17 | 18 | from IPython import embed 19 | np.seterr(all='warn') 20 | 21 | pp = pprint.PrettyPrinter() 22 | 23 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 24 | 25 | index = 0 26 | def get_image(image_path, image_size, is_crop=True, resize_w=64): 27 | global index 28 | out = transform(imread(image_path), image_size, is_crop, resize_w) 29 | return out 30 | 31 | 32 | def save_images(images, size, image_path): 33 | dir_path = '/'.join(image_path.split('/')[:-1]) 34 | if not os.path.exists(dir_path): 35 
| os.makedirs(dir_path) 36 | return imsave(inverse_transform(images), size, image_path) 37 | 38 | def imread(path): 39 | img = scipy.misc.imread(path) 40 | if len(img.shape) == 0: 41 | raise ValueError(path + " got loaded as a dimensionless array!") 42 | return img.astype(np.float) 43 | 44 | def merge_images(images, size): 45 | return inverse_transform(images) 46 | 47 | def merge(images, size): 48 | h, w = images.shape[1], images.shape[2] 49 | img = np.zeros((h * size[0], w * size[1], 3)) 50 | images = images[:size[0]*size[1]] 51 | 52 | for idx, image in enumerate(images): 53 | i = idx % size[1] 54 | j = idx / size[0] 55 | img[j*h:j*h+h, i*w:i*w+w, :] = image 56 | 57 | return img 58 | 59 | def imsave(images, size, path): 60 | if images.shape[-1] == 1: 61 | images = np.repeat(images, 3, axis=3) 62 | return scipy.misc.imsave(path, merge(images, size)) 63 | 64 | def center_crop(x, crop_h, crop_w=None, resize_w=64): 65 | 66 | h, w = x.shape[:2] 67 | crop_h = min(h, w) # we changed this to override the original DCGAN-TensorFlow behavior 68 | # Just use as much of the image as possible while keeping it square 69 | 70 | if crop_w is None: 71 | crop_w = crop_h 72 | j = int(round((h - crop_h)/2.)) 73 | i = int(round((w - crop_w)/2.)) 74 | return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], 75 | [resize_w, resize_w]) 76 | 77 | def transform(image, npx=64, is_crop=True, resize_w=64): 78 | # npx : # of pixels width/height of image 79 | cropped_image = center_crop(image, npx, resize_w=resize_w) 80 | return np.array(cropped_image)/127.5 - 1. 81 | 82 | def inverse_transform(images): 83 | return (images+1.)/2. 84 | 85 | 86 | def to_json(output_path, *layers): 87 | with open(output_path, "w") as layer_f: 88 | lines = "" 89 | for w, b, bn in layers: 90 | layer_idx = w.name.split('/')[0].split('h')[1] 91 | 92 | B = b.eval() 93 | 94 | if "lin/" in w.name: 95 | W = w.eval() 96 | depth = W.shape[1] 97 | else: 98 | W = np.rollaxis(w.eval(), 2, 0) 99 | depth = W.shape[0] 100 | 101 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]} 102 | if bn != None: 103 | gamma = bn.gamma.eval() 104 | beta = bn.beta.eval() 105 | 106 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]} 107 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]} 108 | else: 109 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 110 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 111 | 112 | if "lin/" in w.name: 113 | fs = [] 114 | for w in W.T: 115 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]}) 116 | 117 | lines += """ 118 | var layer_%s = { 119 | "layer_type": "fc", 120 | "sy": 1, "sx": 1, 121 | "out_sx": 1, "out_sy": 1, 122 | "stride": 1, "pad": 0, 123 | "out_depth": %s, "in_depth": %s, 124 | "biases": %s, 125 | "gamma": %s, 126 | "beta": %s, 127 | "filters": %s 128 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs) 129 | else: 130 | fs = [] 131 | for w_ in W: 132 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]}) 133 | 134 | lines += """ 135 | var layer_%s = { 136 | "layer_type": "deconv", 137 | "sy": 5, "sx": 5, 138 | "out_sx": %s, "out_sy": %s, 139 | "stride": 2, "pad": 1, 140 | "out_depth": %s, "in_depth": %s, 141 | "biases": %s, 142 | "gamma": %s, 143 | "beta": %s, 144 | "filters": %s 145 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2), 146 | W.shape[0], W.shape[3], biases, 
gamma, beta, fs) 147 | layer_f.write(" ".join(lines.replace("'","").split())) 148 | 149 | def make_gif(images, fname, duration=2, true_image=False): 150 | import moviepy.editor as mpy 151 | 152 | def make_frame(t): 153 | try: 154 | x = images[int(len(images)/duration*t)] 155 | except: 156 | x = images[-1] 157 | 158 | if true_image: 159 | return x.astype(np.uint8) 160 | else: 161 | return ((x+1)/2*255).astype(np.uint8) 162 | 163 | clip = mpy.VideoClip(make_frame, duration=duration) 164 | clip.write_gif(fname, fps = len(images) / duration) 165 | 166 | def visualize(sess, dcgan, config, option): 167 | option = 0 168 | if option == 0: 169 | all_samples = [] 170 | for i in range(484): 171 | print(i) 172 | samples = sess.run(dcgan.generator()) 173 | all_samples.append(samples) 174 | samples = np.concatenate(all_samples, 0) 175 | n = int(np.sqrt(samples.shape[0])) 176 | m = samples.shape[0] // n 177 | save_images(samples, [m, n], './' + config.sample_dir + '/test.png')#_%s.png' % strftime("%Y-%m-%d %H:%M:%S", gmtime())) 178 | elif option == 5: 179 | counter = 0 180 | coord = tf.train.Coordinator() 181 | threads = tf.train.start_queue_runners(coord=coord) 182 | while counter < 1005: 183 | print(counter) 184 | samples, fake = sess.run([dcgan.generator(), dcgan.d_loss_class]) 185 | fake = np.argsort(fake) 186 | print(np.sum(samples)) 187 | print(fake) 188 | for i in range(samples.shape[0]): 189 | name = "%s%d.png" % (chr(ord('a') + counter % 10), counter) 190 | img = np.expand_dims(samples[fake[i]], 0) 191 | if counter >= 1000: 192 | save_images(img, [1, 1], './' + config.sample_dir + '/turk/fake%d.png' % (counter - 1000)) 193 | else: 194 | save_images(img, [1, 1], './' + config.sample_dir + '/turk/%s' % (name)) 195 | counter += 1 196 | elif option == 1: 197 | values = np.arange(0, 1, 1./config.batch_size) 198 | for idx in xrange(100): 199 | print(" [*] %d" % idx) 200 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 201 | for kdx, z in enumerate(z_sample): 202 | z[idx] = values[kdx] 203 | 204 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 205 | save_images(samples, [8, 8], './' + options.sample_dir + '/test_arange_%s.png' % (idx)) 206 | elif option == 2: 207 | values = np.arange(0, 1, 1./config.batch_size) 208 | for idx in [random.randint(0, 99) for _ in xrange(100)]: 209 | print(" [*] %d" % idx) 210 | 211 | if hasattr(dcgan, z): 212 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim)) 213 | z_sample = np.tile(z, (config.batch_size, 1)) 214 | #z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 215 | for kdx, z in enumerate(z_sample): 216 | z[idx] = values[kdx] 217 | 218 | if hasattr(dcgan, "sampler"): 219 | sampler = dcgan.sampler 220 | else: 221 | sampler = dcgan.generator() 222 | samples = sess.run(sampler, feed_dict={dcgan.z: z_sample}) 223 | make_gif(samples, './' + config.sample_dir + '/test_gif_%s.gif' % (idx)) 224 | elif option == 3: 225 | values = np.arange(0, 1, 1./config.batch_size) 226 | for idx in xrange(100): 227 | print(" [*] %d" % idx) 228 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 229 | for kdx, z in enumerate(z_sample): 230 | z[idx] = values[kdx] 231 | 232 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 233 | make_gif(samples, './' + config.sample_dir + '/test_gif_%s.gif' % (idx)) 234 | elif option == 4: 235 | image_set = [] 236 | values = np.arange(0, 1, 1./config.batch_size) 237 | 238 | for idx in xrange(100): 239 | print(" [*] %d" % idx) 240 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 241 | for kdx, z in 
enumerate(z_sample): z[idx] = values[kdx] 242 | 243 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) 244 | make_gif(image_set[-1], './' + config.sample_dir + '/test_gif_%s.gif' % (idx)) 245 | 246 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ 247 | for idx in range(64) + range(63, -1, -1)] 248 | make_gif(new_image_set, './' + config.sample_dir + '/test_gif_merged.gif', duration=8) 249 | 250 | 251 | def colorize(img): 252 | if img.ndim == 2: 253 | img = img.reshape(img.shape[0], img.shape[1], 1) 254 | img = np.concatenate([img, img, img], axis=2) 255 | if img.shape[2] == 4: 256 | img = img[:, :, 0:3] 257 | return img 258 | 259 | 260 | def mkdir_p(path): 261 | # Copied from http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python 262 | try: 263 | os.makedirs(path) 264 | except OSError as exc: # Python >2.5 265 | if exc.errno == errno.EEXIST and os.path.isdir(path): 266 | pass 267 | else: 268 | raise 269 | 270 | -------------------------------------------------------------------------------- /vae_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whyjay/GENERATIVE_MODEL_TESTBED.tensorflow/78c7938a85b6a6e94929e5277446b00cc6c0544c/vae_models/__init__.py -------------------------------------------------------------------------------- /vae_models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import datetime 4 | from glob import glob 5 | import tensorflow as tf 6 | 7 | from ops import * 8 | from utils import * 9 | 10 | from vae_models.encoder import * 11 | from vae_models.decoder import * 12 | #from models.evaluate import evaluate 13 | from utils import pp, visualize, to_json 14 | 15 | from IPython import embed 16 | 17 | class Config(object): 18 | def __init__(self, FLAGS): 19 | self.exp_num = str(FLAGS.exp) 20 | self.load_cp_dir = FLAGS.load_cp_dir 21 | self.dataset = FLAGS.dataset 22 | self.dataset_path = os.path.join("./dataset/", self.dataset) 23 | self.devices = ["gpu:0", "gpu:1", "gpu:2", "gpu:3"] 24 | self.use_augmentation = FLAGS.use_augmentation 25 | self.batch_size = FLAGS.batch_size 26 | self.learning_rate = FLAGS.learning_rate 27 | self.latent_distribution = FLAGS.latent_distribution 28 | 29 | self.add_noise = True 30 | self.noise_stddev = 0.1 31 | 32 | 33 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")+str(self.learning_rate) 34 | 35 | self.epoch = FLAGS.epoch 36 | self.log_dir = os.path.join('logs/vae', self.exp_num, timestamp) 37 | self.checkpoint_dir = os.path.join('checkpoint/vae', self.exp_num, timestamp) 38 | self.sample_dir = os.path.join('samples/vae', self.exp_num, timestamp) 39 | self.timestamp = timestamp 40 | 41 | self.encoder_name = FLAGS.encoder 42 | self.decoder_name = FLAGS.decoder 43 | 44 | self.encoder_func = globals()[self.encoder_name] 45 | self.decoder_func = globals()[self.decoder_name] 46 | 47 | self.kappa = 1 48 | 49 | if FLAGS.dataset in ['mnist', 'fashion']: 50 | self.y_dim=10 51 | self.image_shape=[28, 28, 1] 52 | self.c_dim=1 53 | self.z_dim=20 54 | self.f_dim = 64 55 | self.fc_dim = 512 56 | self.beta1 = 0.5 57 | self.beta2 = 0.999 58 | 59 | elif FLAGS.dataset == 'affmnist': 60 | self.y_dim=10 61 | self.image_shape=[40, 40, 1] 62 | self.c_dim=1 63 | self.z_dim=20 64 | self.f_dim = 64 65 | self.fc_dim = 512 66 | self.beta1 = 0.5 67 | self.beta2 = 0.999 68 | 69 | elif FLAGS.dataset == 'celebA': 70 
| self.y_dim=1 71 | self.image_shape=[64, 64, 3] 72 | self.c_dim=3 73 | self.z_dim=64 74 | self.f_dim = 64 75 | self.fc_dim = 1024 76 | self.beta1 = 0.5 77 | self.beta2 = 0.999 78 | 79 | elif FLAGS.dataset == 'cifar10': 80 | self.y_dim=10 81 | self.image_shape=[32, 32, 3] 82 | self.c_dim=3 83 | self.z_dim=64 84 | self.f_dim = 128 85 | self.fc_dim = 1024 86 | self.beta1 = 0. 87 | self.beta2 = 0.9 88 | 89 | self.sample_size=10*self.batch_size 90 | 91 | def print_config(self): 92 | dicts = self.__dict__ 93 | for key in dicts.keys(): 94 | print key, dicts[key] 95 | 96 | def make_dirs(self): 97 | if not os.path.exists(self.checkpoint_dir): 98 | os.makedirs(self.checkpoint_dir) 99 | if not os.path.exists(self.sample_dir): 100 | os.makedirs(self.sample_dir) 101 | -------------------------------------------------------------------------------- /vae_models/decoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from ops import * 3 | from IPython import embed 4 | 5 | def base_decoder(model, z, reuse=False): 6 | bs = model.batch_size 7 | f_dim = model.f_dim 8 | fc_dim = model.fc_dim 9 | c_dim = model.c_dim 10 | 11 | with slim.arg_scope(ops_with_bn, is_training=model.is_training, init=None, norm=None): 12 | 13 | if model.dataset_name in ['mnist', 'fashion']: 14 | w = model.image_shape[0] 15 | 16 | h = fc(z, fc_dim/2, act=tf.nn.elu) 17 | h = fc(h, 384, act=tf.nn.elu) 18 | h = fc(h, fc_dim, act=tf.nn.elu) 19 | x = fc(h, c_dim*w*w, act=tf.nn.sigmoid, norm=None) 20 | x = tf.reshape(x, [-1, w, w, c_dim]) 21 | 22 | elif model.dataset_name == 'affmnist': 23 | n_layer = 3 24 | c = 2**(n_layer - 1) 25 | w = model.image_shape[0]/2**(n_layer) 26 | 27 | h = fc(z, f_dim * c * w * w, act=lrelu) 28 | h = tf.reshape(h, [-1, w, w, f_dim * c]) 29 | 30 | for i in range(n_layer - 1): 31 | w *= 2 32 | c /= 2 33 | h = deconv2d(h, f_dim * c, 4, 2) 34 | h = deconv2d(h, f_dim * c, 1, 1) 35 | 36 | x = deconv2d(h, c_dim, 4, 2, act=tf.nn.sigmoid, norm=None) 37 | 38 | elif model.dataset_name == 'cifar10': 39 | n_layer = 3 40 | w = model.image_shape[0]/2**(n_layer) 41 | 42 | h = fc(z, f_dim * w * w, act=tf.nn.elu, norm=ln) 43 | h = tf.reshape(h, [-1, w, w, f_dim]) 44 | 45 | c = f_dim 46 | for i in range(n_layer): 47 | c /= 2 48 | h = residual_block(h, resample='up', act=tf.nn.elu) 49 | 50 | x = conv2d(h, c_dim, 3, 1, act=tf.nn.tanh, norm=None) 51 | 52 | else: 53 | n_layer = 4 54 | c = 2**(n_layer - 1) 55 | w = model.image_shape[0]/2**(n_layer) 56 | 57 | h = fc(z, f_dim * c * w * w, act=lrelu) 58 | h = tf.reshape(h, [-1, w, w, f_dim * c]) 59 | 60 | for i in range(n_layer - 1): 61 | w *= 2 62 | c /= 2 63 | h = deconv2d(h, f_dim * c, 4, 2) 64 | h = deconv2d(h, f_dim * c, 1, 1) 65 | 66 | x = deconv2d(h, c_dim, 4, 2, act=tf.nn.tanh, norm=None) 67 | 68 | return x 69 | 70 | -------------------------------------------------------------------------------- /vae_models/encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from ops import * 4 | from IPython import embed 5 | 6 | def base_encoder(model, x, reuse=False): 7 | bs = model.batch_size 8 | f_dim = model.f_dim 9 | fc_dim = model.fc_dim 10 | c_dim = model.c_dim 11 | z_dim = model.z_dim 12 | 13 | with slim.arg_scope(ops_with_bn, is_training=model.is_training, init=None, norm=None): 14 | 15 | if model.dataset_name in ['mnist', 'fashion']: 16 | h = fc(tf.reshape(x, [bs, -1]), fc_dim, act=tf.nn.elu) 17 | h = fc(h, 384, 
act=tf.nn.elu)
18 |             h = fc(h, fc_dim/2, act=tf.nn.elu)
19 | 
20 |         elif model.dataset_name == 'affmnist':
21 |             n_layer = 3
22 |             c = 1
23 |             w = model.image_shape[0]/2**(n_layer)
24 | 
25 |             h = conv2d(x, f_dim * c, 4, 2, act=lrelu)
26 |             for i in range(n_layer - 1):
27 |                 w /= 2
28 |                 c *= 2
29 |                 h = conv2d(h, f_dim * c, 4, 2, act=lrelu)
30 |                 h = conv2d(h, f_dim * c, 1, 1, act=lrelu)
31 | 
32 |                 if i == n_layer - 2:
33 |                     feats = h
34 | 
35 |         elif model.dataset_name == 'cifar10':
36 |             h = conv2d(x, f_dim, 4, 2, act=lrelu)  # four stride-2 convolutions: 32x32 -> 2x2
37 |             h = conv2d(h, f_dim, 4, 2, act=lrelu)
38 |             h = conv2d(h, f_dim, 4, 2, act=lrelu)
39 |             h = conv2d(h, f_dim, 4, 2, act=lrelu)
40 | 
41 |             '''
42 |             h = conv2d(x, f_dim, 3, 1, act=tf.nn.elu, norm=None)
43 |             h = conv_mean_pool(h, f_dim, 3, act=None, norm=None)
44 |             h += conv_mean_pool(x, f_dim, 1, act=None, norm=None)
45 |             h = tf.nn.elu(ln(h))
46 | 
47 |             h = residual_block(h, resample='down', act=tf.nn.elu)
48 |             h = residual_block(h, resample=None, act=tf.nn.elu)
49 |             h = residual_block(h, resample=None, act=tf.nn.elu)
50 | 
51 |             h = conv2d(h, f_dim, 4, 2, act=None, norm=ln)
52 |             h = tf.reduce_mean(h, axis=[1,2])
53 |             '''
54 | 
55 |         else:
56 |             n_layer = 4
57 |             c = 1
58 |             w = model.image_shape[0]/2**(n_layer)
59 | 
60 |             h = conv2d(x, f_dim * c, 4, 2, act=lrelu, norm=None)
61 |             for i in range(n_layer - 1):
62 |                 w /= 2
63 |                 c *= 2
64 |                 h = conv2d(h, f_dim * c, 4, 2, act=lrelu)
65 |                 h = conv2d(h, f_dim * c, 1, 1, act=lrelu)
66 | 
67 |                 if i == n_layer - 2:
68 |                     feats = h
69 | 
70 |         h = tf.reshape(h, [bs, -1])
71 | 
72 |         z_mu = fc(h, z_dim, act=None, norm=None)
73 |         z_logvar = fc(h, z_dim, act=None, norm=None)
74 | 
75 |     return z_mu, z_logvar
76 | 
-------------------------------------------------------------------------------- /vae_models/model.py: --------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from glob import glob
4 | import tensorflow as tf
5 | from tensorflow.python.ops import parsing_ops
6 | 
7 | from ops import *
8 | from utils import *
9 | from IPython import embed
10 | 
11 | slim = tf.contrib.slim
12 | 
13 | class VAE(object):
14 |     def __init__(self, config):
15 |         self.devices = config.devices
16 |         self.config = config
17 | 
18 |         self.encoder = NetworkWrapper(self, config.encoder_func)
19 |         self.decoder = NetworkWrapper(self, config.decoder_func)
20 | 
21 |         #self.evaluate = Evaluate(self, config.eval_func)
22 | 
23 |         self.batch_size = config.batch_size
24 |         self.sample_size = config.sample_size
25 |         self.image_shape = config.image_shape
26 |         self.sample_dir = config.sample_dir
27 | 
28 |         self.k = config.kappa
29 |         self.latent_distribution = config.latent_distribution
30 |         self.y_dim = config.y_dim
31 |         self.c_dim = config.c_dim
32 |         self.f_dim = config.f_dim
33 |         self.fc_dim = config.fc_dim
34 |         self.z_dim = config.z_dim
35 |         self.beta1 = config.beta1
36 |         self.beta2 = config.beta2
37 | 
38 |         self.dataset_name = config.dataset
39 |         self.dataset_path = config.dataset_path
40 |         self.checkpoint_dir = config.checkpoint_dir
41 | 
42 |         self.use_augmentation = config.use_augmentation
43 | 
44 |     def save(self, sess, checkpoint_dir, step):
45 |         model_name = "vae"
46 | 
47 |         if not os.path.exists(checkpoint_dir):
48 |             os.makedirs(checkpoint_dir)
49 | 
50 |         self.saver.save(sess,
51 |             os.path.join(checkpoint_dir, model_name),
52 |             global_step=step)
53 | 
54 |     def load(self, sess, checkpoint_dir):
55 |         print(" [*] Reading checkpoints...")
56 | 
57 |         model_dir = "%s_%s" % (self.batch_size, self.config.learning_rate)
58 |         checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
59 | 
60 |         ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
61 |         if ckpt and ckpt.model_checkpoint_path:
62 |             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
63 |             self.saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
64 |             return True
65 |         else:
66 |             print "Bad checkpoint: ", ckpt
67 |             return False
68 | 
69 |     def get_vars(self):
70 |         self.t_vars = tf.trainable_variables()
71 | 
72 |     def build_model(self):
73 |         config = self.config
74 |         self.is_training = tf.placeholder_with_default(False, shape=[], name='is_training')
75 | 
76 |         # input
77 |         self.image = tf.placeholder(tf.float32, shape=[self.batch_size]+self.image_shape)
78 |         self.label = tf.placeholder(tf.float32, shape=[self.batch_size])
79 |         self.noise = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim])
80 |         image = preprocess_image(self.image, self.dataset_name, self.use_augmentation)
81 |         if self.dataset_name == 'mnist':
82 |             # dynamic binarization: draw fresh pixel thresholds every step
83 |             image = tf.to_float(image > tf.random_uniform(tf.shape(image)))
84 | 
85 |         z_mu, z_logvar = self.encoder(image)
86 |         z = reparameterize(z_mu, z_logvar, self.latent_distribution)
87 |         recon_image = self.decoder(z)
88 | 
89 |         loss_elbo, loss_recon, loss_kl = self.get_loss(image, recon_image, z_mu, z_logvar)
90 | 
91 |         # optimizer
92 |         self.get_vars()
93 |         opt = tf.train.AdamOptimizer(config.learning_rate)
94 |         train_op = slim.learning.create_train_op(loss_elbo, opt, variables_to_train=self.t_vars)
95 | 
96 |         # logging
97 |         tf.summary.scalar("loss_elbo", loss_elbo)
98 |         tf.summary.scalar("loss_recon", loss_recon)
99 |         tf.summary.scalar("loss_kl", loss_kl)
100 |         tf.summary.image("input_images", batch_to_grid(image))
101 |         tf.summary.image("recon_images", batch_to_grid(recon_image))
102 | 
103 |         self.recon_image = recon_image
104 |         self.input_image = image
105 |         self.z = z
106 |         self.gen_image = self.decoder(self.noise)
107 | 
108 |         self.loss_elbo = loss_elbo
109 |         self.loss_recon = loss_recon
110 |         self.loss_kl = loss_kl + (np.prod(image.get_shape().as_list())-1.)/2.*np.log(2.)
111 |         self.saver = tf.train.Saver(max_to_keep=None)
112 | 
113 |         return train_op
114 | 
115 |     def get_loss(self, image, recon_image, z_mu, z_logvar, eps = 1e-10):
116 |         if self.dataset_name == 'mnist':
117 |             loss_recon = -tf.reduce_sum(
118 |                 image * tf.log(eps+recon_image) + (1-image) * tf.log(eps+1-recon_image), axis=[1, 2, 3]
119 |             )
120 |         else:
121 |             loss_recon = tf.reduce_sum(2. * tf.square(image - recon_image), axis=[1,2,3])
122 |         loss_recon = tf.reduce_mean(loss_recon)# / np.prod(self.image_shape)
123 | 
124 |         if self.latent_distribution == 'gaussian':
125 |             loss_kl = -0.5 * tf.reduce_sum(
126 |                 1 + z_logvar - tf.square(z_mu) - tf.exp(z_logvar), axis=1
127 |             )
128 |         else:
129 |             raise ValueError("invalid latent distribution : %s" % self.latent_distribution)
130 | 
131 |         # loss_recon is already a scalar here; only loss_kl still needs reducing
132 |         loss_kl = tf.reduce_mean(loss_kl)
133 |         loss_elbo = loss_recon + loss_kl
134 | 
135 |         return loss_elbo, loss_recon, loss_kl
136 | 
137 | class NetworkWrapper(object):
138 |     def __init__(self, model, func):
139 |         self.model = model
140 |         self.func = func
141 | 
142 |     def __call__(self, z, reuse=False):
143 |         return self.func(self.model, z, reuse=reuse)
144 | 
145 | 
146 | 
-------------------------------------------------------------------------------- /vae_models/train.py: --------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import time
4 | import numpy as np
5 | from utils import *
6 | from ops import *
7 | from IPython import embed
8 | 
9 | def train(model, sess):
10 |     config = model.config
11 |     train_op = model.build_model()
12 | 
13 |     if not (config.load_cp_dir == ''):
14 |         model.load(sess, config.load_cp_dir)
15 |     merged_sum = init_training(model, sess)
16 |     start_time = time.time()
17 |     print_time = time.time()
18 | 
19 |     dataset = load_dataset(model)
20 |     N = dataset.num_examples
21 |     max_iter = int(N/model.batch_size) * model.config.epoch
22 | 
23 |     print "[*] Training Start : N=%d, Batch=%d, epoch=%d, max_iter=%d" \
24 |         %(N, model.batch_size, model.config.epoch, max_iter)
25 | 
26 |     for idx in xrange(1, max_iter):
27 |         batch_start_time = time.time()
28 | 
29 |         image, label = dataset.next_batch(model.batch_size)
30 |         _, recon_image, input_image, z, loss_elbo, loss_kl, loss_recon = sess.run(
31 |             [train_op, model.recon_image, model.input_image, model.z, model.loss_elbo, model.loss_kl, model.loss_recon],
32 |             feed_dict={model.image:image, model.label:label, model.is_training:True})
33 | 
34 |         # save checkpoint for every epoch
35 |         if (idx*model.batch_size) % N < model.batch_size:
36 |             epoch = int(idx*model.batch_size/N)
37 |             print_time = time.time()
38 |             total_time = print_time - start_time
39 |             sec_per_epoch = (print_time - start_time) / epoch
40 | 
41 |             image, label = dataset.next_batch(model.batch_size)
42 |             summary = sess.run(merged_sum, feed_dict={model.image:image, model.label:label, model.z:get_z(model)})
43 | 
44 |             model.writer.add_summary(summary, epoch)
45 | 
46 |             _save_samples(model, sess, epoch)
47 |             model.save(sess, model.checkpoint_dir, epoch)
48 | 
49 |             print '[Epoch %(epoch)d] time: %(total_time)4.4f, loss_elbo: %(loss_elbo).4f, loss_kl: %(loss_kl).4f, loss_recon: %(loss_recon).4f, sec_per_epoch: %(sec_per_epoch)4.4f' % locals()
50 | 
51 |     sess.close()
52 | 
53 | def _save_samples(model, sess, epoch):
54 |     samples = []
55 |     noises = []
56 | 
57 |     # generator hard codes the batch size
58 |     for i in xrange(model.sample_size // model.batch_size):
59 |         noise = get_z(model)
60 |         gen_image = sess.run(model.gen_image, feed_dict={model.noise:noise})
61 |         samples.append(gen_image)
62 |         noises.append(noise)
63 | 
64 |     samples = np.concatenate(samples, axis=0)
65 |     noises = np.concatenate(noises, axis=0)
66 | 
67 |     assert samples.shape[0] == model.sample_size
68 |     save_images(samples, [8, 8], os.path.join(model.sample_dir, 'samples_%s.png' % (epoch)))
69 | 
70 |     print "Save Samples at %s/%s" % (model.sample_dir, 'samples_%s' % (epoch))
%s/%s" % (model.sample_dir, 'samples_%s' % (epoch)) 71 | with open(os.path.join(model.sample_dir, 'samples_%d.npy'%(epoch)), 'w') as f: 72 | np.save(f, samples) 73 | with open(os.path.join(model.sample_dir, 'noises_%d.npy'%(epoch)), 'w') as f: 74 | np.save(f, noises) 75 | 76 | def init_training(model, sess): 77 | config = model.config 78 | init_op = tf.global_variables_initializer() 79 | sess.run(init_op) 80 | 81 | merged_sum = tf.summary.merge_all() 82 | model.writer = tf.summary.FileWriter(config.log_dir, sess.graph) 83 | 84 | if model.load(sess, model.checkpoint_dir): 85 | print(" [*] Load SUCCESS") 86 | else: 87 | print(" [!] Load failed...") 88 | 89 | if not os.path.exists(config.dataset_path): 90 | print(" [!] Data does not exist : %s" % config.dataset_path) 91 | return merged_sum 92 | 93 | def load_dataset(model): 94 | if model.dataset_name == 'mnist': 95 | import mnist as ds 96 | elif model.dataset_name == 'fashion': 97 | import fashion as ds 98 | elif model.dataset_name == 'cifar10': 99 | import cifar10 as ds 100 | return ds.read_data_sets(model.dataset_path, dtype=tf.uint8, reshape=False, validation_size=0).train 101 | 102 | def get_z(model): 103 | if model.latent_distribution == 'vmf': 104 | z = np.random.normal(0., 1., size=(model.batch_size, model.z_dim)) 105 | return z/np.linalg.norm(z) 106 | else: 107 | return np.random.normal(0., 1., size=(model.batch_size, model.z_dim)) 108 | 109 | --------------------------------------------------------------------------------