├── .gitignore
├── README.md
├── datasets
├── __init__.py
├── create_celeba.py
├── datasets.py
├── mnist.py
└── svhn.py
├── models
├── __init__.py
├── base.py
├── began.py
├── cvae.py
├── cvaegan.py
├── dcgan.py
├── improved_gan.py
├── lsgan.py
├── resnet_gan.py
├── utils.py
├── vae.py
├── wgan.py
└── wnorm.py
├── results
├── svhn_cvaegan_epoch_0050_batch_73257.png
└── svhn_dcgan_epoch_0050_batch_73257.png
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output/*
2 | datasets/files/*
3 |
4 | __pycache__
5 |
6 | # Compiled source #
7 | ###################
8 | *.com
9 | *.class
10 | *.dll
11 | *.exe
12 | *.o
13 | *.so
14 | *.pyc
15 |
16 | # Packages #
17 | ############
18 | # it's better to unpack these files and commit the raw source
19 | # git has its own built in compression methods
20 | *.7z
21 | *.dmg
22 | *.gz
23 | *.iso
24 | *.rar
25 | #*.tar
26 | *.zip
27 |
28 | # Logs and databases #
29 | ######################
30 | *.log
31 | *.sqlite
32 |
33 | # OS generated files #
34 | ######################
35 | .DS_Store
36 | ehthumbs.db
37 | Icon
38 | Thumbs.db
39 | .tmtags
40 | .idea
41 | tags
42 | vendor.tags
43 | tmtagsHistory
44 | *.sublime-project
45 | *.sublime-workspace
46 | .bundle
47 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | TensorFlow VAEs and GANs
2 | ===
3 |
4 | TensorFlow implementation of various deep generative networks such as VAE and GAN.
5 |
6 | ## Models
7 |
8 | ### Standard models
9 |
10 | * Variational autoencoder (VAE) [Kingma et al. 2013]
11 | * Generative adversarial network (GAN or DCGAN) [Goodfellow et al. 2014]
12 |
13 |
14 |
15 |
16 | ### Conditional models
17 |
18 | * Conditional variational autoencoder [Kingma et al. 2014]
19 | * CVAE-GAN [Bao et al. 2017]
20 |
21 | ## Usage
22 |
23 | ### Prepare datasets
24 |
25 | #### MNIST and SVHN
26 |
27 | MNIST and SVHN datasets are automatically downloaded from their websites.
28 |
29 | #### CelebA
30 |
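31 | First, download ``img_align_celeba.zip`` and ``list_attr_celeba.txt`` from the CelebA [webpage](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
32 | Then, place these files in the ``datasets`` directory and run ``create_celeba.py`` from within that directory (if either file is missing, the script tries to download it from Google Drive). The script writes ``datasets/files/celebA.hdf5`` and deletes the two input files afterwards.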
33 |
34 | ### Training
35 |
36 | ```shell
37 | # Both standard and conditional models are available!
38 | python train.py --model=dcgan --epoch=200 --batchsize=100 --output=output
39 | ```
40 |
41 | Training progress can be monitored in TensorBoard with the following command.
42 |
43 | ```shell
44 | tensorboard --logdir="output/dcgan/log"
45 | ```
46 |
47 | ### Results
48 |
49 | #### DCGAN (SVHN, 50 epochs)
50 |
51 |
52 |
53 | #### CVAE-GAN (SVHN, 50 epochs)
54 |
55 |
56 |
57 | ## References
58 |
59 | * Kingma et al., "Auto-Encoding Variational Bayes", arXiv preprint 2013.
60 | * Goodfellow et al., "Generative adversarial nets", NIPS 2014.
61 |
62 |
63 |
64 | * Kingma et al., "Semi-supervised learning with deep generative models", NIPS 2014.
65 | * Bao et al., "CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training", arXiv preprint 2017.
66 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import load_data, Dataset, ConditionalDataset, PairwiseDataset
2 |
--------------------------------------------------------------------------------
/datasets/create_celeba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import zipfile
5 |
6 | import numpy as np
7 | import h5py
8 |
9 | import requests
10 | from PIL import Image
11 |
# Input archives expected in the current working directory.
image_file = 'img_align_celeba.zip'
attr_file = 'list_attr_celeba.txt'

# Google Drive download endpoint and the shareable links of the two inputs.
google_drive_prefix = "https://docs.google.com/uc?export=download"
image_url = 'https://drive.google.com/open?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM'
attr_url = 'https://drive.google.com/open?id=0B7EVK8r0v71pblRyaVFSWGxPY0U'

# The converted dataset is written to ``files/celebA.hdf5`` next to this script.
_script_dir = os.path.dirname(__file__)
target_dir = os.path.join(_script_dir, 'files')
outfile = os.path.join(target_dir, 'celebA.hdf5')
20 |
def get_confirm_token(response):
    """Return the value of the first cookie named ``download_warning*``.

    Args:
        response: object with a ``cookies`` mapping (e.g. a requests response).

    Returns:
        The cookie value, or ``None`` when no such cookie is present.
    """
    matches = (
        value
        for key, value in response.cookies.items()
        if key.startswith('download_warning')
    )
    return next(matches, None)
27 |
def save_response_content(response, destination, chunk_size=32768):
    """Stream the body of ``response`` into the file ``destination``.

    The running total (in MB) is printed to stdout after every chunk, and
    ``Finish!`` is printed once the file has been closed.

    Args:
        response: object exposing ``iter_content(chunk_size)``, e.g. a
            ``requests.Response`` obtained with ``stream=True``.
        destination: output path; overwritten if it already exists.
        chunk_size: number of bytes requested per chunk.
    """
    downloaded = 0
    with open(destination, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if not chunk:
                # Empty chunks carry no data; nothing to write or report.
                continue
            downloaded += len(chunk)
            f.write(chunk)

            sys.stdout.write('\r%.2f MB downloaded...' % (downloaded / 1.0e6))
            sys.stdout.flush()

    sys.stdout.write('\nFinish!\n')
    sys.stdout.flush()
45 |
def _extract_drive_id(url):
    """Return the ``id=`` query value of a Google Drive link, or None."""
    # Drive file IDs may contain '-' and '_', which the old [a-zA-Z0-9]+ missed.
    mat = re.search(r'id=([\w-]+)', url)
    return mat.group(1) if mat else None


def download_from_google_drive(url, dest):
    """Download the Google Drive file referenced by ``url`` into ``dest``.

    If the first response sets a ``download_warning*`` cookie, the request is
    repeated with that token passed as the ``confirm`` parameter.

    Args:
        url: shareable link containing ``id=<file id>``.
        dest: output file path.

    Raises:
        ValueError: if ``url`` contains no ``id=`` parameter.
        requests.HTTPError: if the server answers with an error status.
    """
    idx = _extract_drive_id(url)
    if idx is None:
        raise ValueError('Invalid url: %s' % url)

    with requests.Session() as session:
        response = session.get(google_drive_prefix, params={'id': idx}, stream=True)
        token = get_confirm_token(response)

        if token:
            params = {'id': idx, 'confirm': token}
            response = session.get(google_drive_prefix, params=params, stream=True)

        # Fail loudly instead of saving an HTML error page as the dataset.
        response.raise_for_status()

        print('Downloading:', url)
        save_response_content(response, dest)
65 |
def _parse_attributes(path):
    """Parse the CelebA attribute list file.

    Line 1 holds the number of images, line 2 the whitespace-separated
    attribute names; each following line is a file name followed by one
    integer per attribute. Negative values are clamped to 0.

    Returns:
        ``(label_names, labels)``: an object array of names and a
        ``(num_images, num_labels)`` uint8 array.
    """
    with open(path, 'r') as fp:
        lines = [line.strip() for line in fp]

    num_images = int(lines[0])
    label_names = np.array(re.split(r'\s+', lines[1]), dtype=object)

    body = lines[2:]
    labels = np.empty((num_images, len(label_names)), dtype='uint8')
    for i in range(num_images):
        values = [int(v) for v in re.split(r'\s+', body[i])[1:]]
        labels[i] = np.maximum(0, values).astype(np.uint8)
    return label_names, labels


def _load_images(zip_path):
    """Load all ``.jpg`` members of ``zip_path`` (sorted by name) as 64x64 RGB.

    Each image is resized to 64x78 and rows 7-70 are kept (a vertically
    centered 64x64 crop).

    Returns:
        uint8 array of shape ``(num_images, 64, 64, 3)``.
    """
    with zipfile.ZipFile(zip_path, 'r') as zf:
        image_files = sorted(f for f in zf.namelist() if f.endswith('.jpg'))
        num_images = len(image_files)
        print('%d images' % (num_images))

        image_data = np.empty((num_images, 64, 64, 3), dtype='uint8')
        for i, name in enumerate(image_files):
            with zf.open(name, 'r') as fp:
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                image = Image.open(fp).resize((64, 78), Image.LANCZOS).crop((0, 7, 64, 64 + 7))
                image_data[i] = np.asarray(image, dtype='uint8')
            print('%d / %d' % (i + 1, num_images), end='\r', flush=True)
    print('')
    return image_data


def _write_hdf5(path, image_data, label_names, labels):
    """Write ``images``, ``attr_names`` and ``attrs`` datasets to ``path``."""
    string_dt = h5py.special_dtype(vlen=str)
    # The context manager closes (and flushes) the file even on errors.
    with h5py.File(path, 'w') as h5:
        h5.create_dataset('images', data=image_data, dtype='uint8')
        h5.create_dataset('attr_names', data=label_names, dtype=string_dt)
        h5.create_dataset('attrs', data=labels, dtype='uint8')


def main():
    """Convert CelebA into ``outfile`` (``files/celebA.hdf5``).

    ``image_file`` and ``attr_file`` are looked up in the current working
    directory and downloaded from Google Drive when missing. Both are
    deleted after the HDF5 file has been written.
    """
    # Download image ZIP
    if os.path.exists(image_file):
        print('Image ZIP file exists. Skip downloading.')
    else:
        download_from_google_drive(image_url, image_file)

    # Download attribute file
    if os.path.exists(attr_file):
        print('Attribute file exists. Skip downloading.')
    else:
        download_from_google_drive(attr_url, attr_file)

    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)

    label_names, labels = _parse_attributes(attr_file)
    image_data = _load_images(image_file)
    _write_hdf5(outfile, image_data, label_names, labels)

    # Only reached when the HDF5 file was written successfully.
    os.remove(image_file)
    os.remove(attr_file)


if __name__ == '__main__':
    main()
131 |
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 |
class Dataset(object):
    """Container for an ``images`` array filled in by a loader."""

    def __init__(self):
        self.images = None

    def __len__(self):
        """Number of samples (first dimension of ``images``)."""
        return len(self.images)

    def _get_shape(self):
        return self.images.shape

    @property
    def shape(self):
        """Shape of the underlying ``images`` array."""
        return self._get_shape()
15 |
class ConditionalDataset(Dataset):
    """Dataset whose samples also carry an attribute vector."""

    def __init__(self):
        Dataset.__init__(self)
        # Per-sample attributes, indexed in parallel with ``images``.
        self.attrs = None
        # Display names of the attribute columns.
        self.attr_names = None
21 |
class PairwiseDataset(object):
    """Two aligned 4-D image arrays ``x_data`` / ``y_data``.

    Axes 1 and 2 must match. On axis 3 the sizes must either match, or one
    of them must be 1, in which case that array is tiled along axis 3 to the
    larger size. Both arrays are truncated to the shorter length.
    """

    def __init__(self, x_data, y_data):
        assert x_data.shape[1] == y_data.shape[1]
        assert x_data.shape[2] == y_data.shape[2]
        assert x_data.shape[3] == 1 or y_data.shape[3] == 1 or \
            x_data.shape[3] == y_data.shape[3]

        depth = max(x_data.shape[3], y_data.shape[3])
        if x_data.shape[3] != depth:
            x_data = np.tile(x_data, [1, 1, 1, depth])
        if y_data.shape[3] != depth:
            # Previously assigned to a misspelled ``y_Data``, so y was never tiled.
            y_data = np.tile(y_data, [1, 1, 1, depth])

        length = min(len(x_data), len(y_data))
        self.x_data = x_data[:length]
        self.y_data = y_data[:length]

    def __len__(self):
        return len(self.x_data)

    def _get_shape(self):
        return self.x_data.shape

    shape = property(_get_shape)
50 |
def load_data(filename, size=-1):
    """Load an HDF5 file with ``images``, ``attrs`` and ``attr_names``.

    Args:
        filename: path to the HDF5 file (as written by ``create_celeba.py``).
        size: when positive, only the first ``size`` samples are loaded.

    Returns:
        ConditionalDataset with ``images`` as float32 divided by 255 and
        ``attrs`` as float32.
    """
    # Keep everything when size <= 0, as before.
    rows = slice(size) if size > 0 else slice(None)

    # Open read-only (h5py < 3 defaulted to append mode when no mode was
    # given) and close the file once the arrays are in memory.
    with h5py.File(filename, 'r') as f:
        dset = ConditionalDataset()
        # Slice on the HDF5 dataset so only the requested rows are read.
        dset.images = np.asarray(f['images'][rows], 'float32') / 255.0
        dset.attrs = np.asarray(f['attrs'][rows], 'float32')
        dset.attr_names = np.asarray(f['attr_names'])

    return dset
64 |
--------------------------------------------------------------------------------
/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import gzip
4 | import struct
5 | import requests
6 |
7 | import numpy as np
8 |
9 | import tensorflow as tf
10 |
11 | from .datasets import ConditionalDataset
# Remote location of the MNIST archives.
url = 'http://yann.lecun.com/exdb/mnist/'

# Training images/labels and test images/labels (gzip-compressed IDX files).
x_train_file, y_train_file = 'train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz'
x_test_file, y_test_file = 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz'

# Local cache directory: <this package>/files/mnist
curdir = os.path.abspath(os.path.dirname(__file__))
outdir = os.path.join(curdir, 'files', 'mnist')

# Bytes requested per streamed download chunk.
CHUNK_SIZE = 32768
22 |
def download_mnist():
    """Download the four MNIST archives from ``url`` into ``outdir``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    files = [x_train_file, y_train_file, x_test_file, y_test_file]
    # One session for all files; closed automatically afterwards.
    with requests.Session() as session:
        for f in files:
            # URLs always use '/', whereas os.path.join would insert a
            # backslash on Windows.
            src = '%s/%s' % (url.rstrip('/'), f)
            response = session.get(src, stream=True)
            print('Downloading: %s' % (src))
            # Don't save an HTTP error page as if it were the archive.
            response.raise_for_status()

            with open(os.path.join(outdir, f), 'wb') as fp:
                dl = 0
                for chunk in response.iter_content(CHUNK_SIZE):
                    if chunk:
                        dl += len(chunk)
                        fp.write(chunk)

                        mb = dl / 1.0e6
                        sys.stdout.write('\r%.2f MB downloaded...' % (mb))
                        sys.stdout.flush()

            sys.stdout.write('\nFinish!\n')
            sys.stdout.flush()
46 |
def load_images(filename):
    """Read a gzip-compressed IDX3 image file.

    The header is four big-endian uint32 values: magic number (2051),
    image count, rows and columns, followed by the uint8 pixels.

    Returns:
        Writable uint8 array of shape ``(n, rows, cols)``.

    Raises:
        ValueError: on a wrong magic number or a truncated file.
    """
    with gzip.GzipFile(filename, 'rb') as fp:
        magic, n, rows, cols = struct.unpack('>IIII', fp.read(16))
        if magic != 2051:
            raise ValueError('%s is not an IDX3 image file (magic %d)' % (filename, magic))
        expected = n * rows * cols
        buf = fp.read(expected)

    if len(buf) != expected:
        raise ValueError('%s is truncated: expected %d bytes, got %d' % (filename, expected, len(buf)))

    # One vectorized conversion instead of a struct.unpack per image;
    # copy() because frombuffer on bytes yields a read-only view.
    return np.frombuffer(buf, dtype=np.uint8).reshape((n, rows, cols)).copy()
62 |
def load_labels(filename, num_classes=10):
    """Read a gzip-compressed IDX1 label file and one-hot encode it.

    The header is two big-endian uint32 values, magic number (2049) and
    label count, followed by one uint8 label per item.

    Args:
        filename: path to the ``.gz`` label file.
        num_classes: width of the one-hot encoding.

    Returns:
        uint8 array of shape ``(n, num_classes)`` with a single 1 per row.

    Raises:
        ValueError: on a wrong magic number or a truncated file.
    """
    with gzip.GzipFile(filename, 'rb') as fp:
        magic, n = struct.unpack('>II', fp.read(8))
        if magic != 2049:
            raise ValueError('%s is not an IDX1 label file (magic %d)' % (filename, magic))
        buf = fp.read(n)

    if len(buf) != n:
        raise ValueError('%s is truncated: expected %d labels, got %d' % (filename, n, len(buf)))

    labels = np.frombuffer(buf, dtype=np.uint8)
    data = np.zeros((n, num_classes), dtype=np.uint8)
    data[np.arange(n), labels] = 1
    return data
78 |
def load_data():
    """Load the MNIST training split as a ConditionalDataset.

    Images are zero-padded by 2 pixels on each side, given a trailing
    channel axis and scaled by 1/255 to float32. ``attrs`` are the one-hot
    labels as float32, and ``attr_names`` are ``'0'`` .. ``'9'``.
    """
    x_path = os.path.join(outdir, x_train_file)
    y_path = os.path.join(outdir, y_train_file)
    # Check the archives themselves: an interrupted download leaves the
    # directory behind, which used to suppress any further download.
    if not (os.path.exists(x_path) and os.path.exists(y_path)):
        download_mnist()

    x_train = load_images(x_path)
    y_train = load_labels(y_path)

    x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)
    x_train = (x_train[:, :, :, np.newaxis] / 255.0).astype('float32')
    y_train = y_train.astype('float32')

    datasets = ConditionalDataset()
    datasets.images = x_train
    datasets.attrs = y_train
    datasets.attr_names = [str(i) for i in range(10)]

    return datasets
96 |
--------------------------------------------------------------------------------
/datasets/svhn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import requests
4 |
5 | import numpy as np
6 | import scipy as sp
7 | import scipy.io
8 |
9 | import matplotlib.pyplot as plt
10 |
11 | import tensorflow as tf
12 |
13 | from .datasets import ConditionalDataset
14 |
# SVHN training split, distributed as a MATLAB .mat file.
url = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat'

# Local cache: <this package>/files/svhn/svhn.mat
curdir = os.path.abspath(os.path.dirname(__file__))
outdir = os.path.join(curdir, 'files', 'svhn')
outfile = os.path.join(outdir, 'svhn.mat')

# Bytes requested per streamed download chunk.
CHUNK_SIZE = 32768
21 |
def download_svhn():
    """Download the SVHN training ``.mat`` file from ``url`` to ``outfile``.

    The data is first streamed to ``outfile + '.part'`` and renamed only
    after the download completes. ``load_data`` only checks that
    ``outfile`` exists, so a partial file must never appear under that name.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if not os.path.exists(outdir):
        os.makedirs(outdir)

    partial = outfile + '.part'
    with requests.Session() as session:
        response = session.get(url, stream=True)
        print('Downloading: %s' % (url))
        response.raise_for_status()

        with open(partial, 'wb') as fp:
            dl = 0
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:
                    dl += len(chunk)
                    fp.write(chunk)

                    mb = dl / 1.0e6
                    sys.stdout.write('\r%.2f MB downloaded...' % (mb))
                    sys.stdout.flush()

    # Atomic on the same filesystem; overwrites any stale file.
    os.replace(partial, outfile)

    sys.stdout.write('\nFinish!\n')
    sys.stdout.flush()
42 |
def load_data():
    """Return the SVHN training split as a ConditionalDataset.

    Downloads the ``.mat`` file on first use. Images are moved from
    sample-last to sample-first layout and scaled by 1/255 to float32.
    Labels stored as 10 are mapped to 0 and one-hot encoded as float32.
    """
    if not os.path.exists(outfile):
        download_svhn()

    mat = sp.io.loadmat(outfile)

    # The .mat array keeps the sample axis last; move it to the front.
    images = np.transpose(mat['X'], axes=[3, 0, 1, 2])
    images = (images / 255.0).astype('float32')

    digits = np.squeeze(mat['y'])
    digits[digits == 10] = 0
    onehot = np.eye(10)[digits].astype('float32')

    result = ConditionalDataset()
    result.images = images
    result.attrs = onehot
    result.attr_names = [str(d) for d in range(10)]
    return result
66 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseModel, CondBaseModel
2 |
3 | from .vae import VAE
4 | from .dcgan import DCGAN
5 | from .improved_gan import ImprovedGAN
6 | from .resnet_gan import ResNetGAN
7 | from .began import BEGAN
8 | from .wgan import WGAN
9 | from .lsgan import LSGAN
10 |
11 | from .cvae import CVAE
12 | from .cvaegan import CVAEGAN
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import numpy as np
6 | from PIL import Image
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 | import matplotlib.pyplot as plt
11 | import matplotlib.gridspec as gridspec
12 |
13 | import tensorflow as tf
14 |
15 | from abc import ABCMeta, abstractmethod
16 | from .utils import *
17 |
class BaseModel(metaclass=ABCMeta):
    """
    Base class for non-conditional generative networks
    """

    def __init__(self, **kwargs):
        """Store common settings and create the TensorFlow session.

        Required keyword arguments: ``name``, ``batchsize``, ``input_shape``.
        Optional: ``output`` (root output directory, default ``'output'``)
        and ``resume`` (checkpoint restored by ``main_loop``, default None).
        """
        if 'name' not in kwargs:
            raise Exception('Please specify model name!')
        self.name = kwargs['name']

        if 'batchsize' not in kwargs:
            raise Exception('Please specify batchsize!')
        self.batchsize = kwargs['batchsize']

        if 'input_shape' not in kwargs:
            raise Exception('Please specify input shape!')

        self.check_input_shape(kwargs['input_shape'])
        # Stored as a tuple so subclasses can write ``(batchsize,) + input_shape``.
        self.input_shape = tuple(kwargs['input_shape'])

        self.output = kwargs.get('output', 'output')

        # Previously a KeyError when omitted, although None means "train from scratch".
        self.resume = kwargs.get('resume')

        self.sess = tf.Session()
        self.writer = None
        self.saver = None
        self.summary = None

        self.test_size = 10
        self.test_data = None

        self.test_mode = False

    def check_input_shape(self, input_shape):
        """Raise unless ``input_shape`` (any sequence) is a supported (H, W, C).

        Supported shapes: (64, 64, 3) for CelebA, (32, 32, 1) for padded
        MNIST, and (32, 32, 3) for SVHN/CIFAR-like data.
        """
        supported = ((64, 64, 3), (32, 32, 1), (32, 32, 3))
        if tuple(input_shape) in supported:
            return

        errmsg = 'Input size should be 32 x 32 or 64 x 64!'
        raise Exception(errmsg)

    def main_loop(self, datasets, epochs=100):
        """Train on ``datasets`` for ``epochs`` epochs.

        Creates ``<output>/<name>/{results,checkpoints,log/<timestamp>}`` and
        restores ``self.resume`` if given. It then trains on shuffled full
        mini-batches; a trailing batch smaller than ``batchsize`` is skipped.

        A sample grid and a checkpoint are written when the processed-sample
        count crosses a multiple of 10000, or when a batch ends exactly at
        the end of the data (never after the very first batch). Epoch and
        batch counters live in graph variables so they are checkpointed too.
        In ``test_mode`` training stops after the first batch.
        """
        out_dir = os.path.join(self.output, self.name)
        res_out_dir = os.path.join(out_dir, 'results')
        chk_out_dir = os.path.join(out_dir, 'checkpoints')
        time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_out_dir = os.path.join(out_dir, 'log', time_str)
        for d in (out_dir, res_out_dir, chk_out_dir, log_out_dir):
            if not os.path.isdir(d):
                os.makedirs(d)

        # Make test data
        self.make_test_data()

        with self.sess.as_default():
            current_epoch = tf.Variable(0, name='current_epoch', dtype=tf.int32)
            current_batch = tf.Variable(0, name='current_batch', dtype=tf.int32)

            # Initialize or restore global variables
            self.saver = tf.train.Saver()
            if self.resume is not None:
                print('Resume training: %s' % self.resume)
                self.load_model(self.resume)
            else:
                self.sess.run(tf.global_variables_initializer())
                self.sess.run(tf.local_variables_initializer())

            # Update rules; the batch cursor wraps to 0 at the end of the data.
            num_data = len(datasets)
            update_epoch = current_epoch.assign(current_epoch + 1)
            update_batch = current_batch.assign(tf.mod(tf.minimum(current_batch + self.batchsize, num_data), num_data))

            self.writer = tf.summary.FileWriter(log_out_dir, self.sess.graph)
            # The graph is read-only from here on; creating new ops would raise.
            self.sess.graph.finalize()

            print('\n\n--- START TRAINING ---\n')
            for e in range(current_epoch.eval(), epochs):
                perm = np.random.permutation(num_data)
                start_time = time.time()
                for b in range(current_batch.eval(), num_data, self.batchsize):
                    # Update batch index
                    self.sess.run(update_batch)

                    # Skip the incomplete trailing batch
                    bsize = min(self.batchsize, num_data - b)
                    indx = perm[b:b + bsize]
                    if bsize < self.batchsize:
                        break

                    # Get batch and train on it
                    x_batch = self.make_batch(datasets, indx)
                    losses = self.train_on_batch(x_batch, e * num_data + (b + bsize))

                    # Print current status
                    elapsed_time = time.time() - start_time
                    eta = elapsed_time / (b + bsize) * (num_data - (b + bsize))
                    ratio = 100.0 * (b + bsize) / num_data
                    print('Epoch #%d, Batch: %d / %d (%6.2f %%) ETA: %s' % \
                          (e + 1, b + bsize, num_data, ratio, time_format(eta)))

                    for i, (k, v) in enumerate(losses):
                        text = '%s = %8.6f' % (k, v)
                        print(' %25s' % (text), end='')
                        if (i + 1) % 3 == 0:
                            print('')

                    print('\n')
                    sys.stdout.flush()

                    # Save generated images and a checkpoint
                    save_period = 10000
                    if b != 0 and ((b // save_period != (b + bsize) // save_period) or ((b + bsize) == num_data)):
                        outfile = os.path.join(res_out_dir, 'epoch_%04d_batch_%d.png' % (e + 1, b + bsize))
                        self.save_images(outfile)
                        outfile = os.path.join(chk_out_dir, 'epoch_%04d' % (e + 1))
                        self.save_model(outfile)

                    if self.test_mode:
                        print('\nFinish testing: %s' % self.name)
                        return

                print('')
                self.sess.run(update_epoch)

    def make_batch(self, datasets, indx):
        """
        Get batch from datasets
        """
        return datasets.images[indx]

    def save_images(self, filename):
        """Save a ``test_size`` x ``test_size`` grid of generated samples.

        ``predict(self.test_data)`` outputs are mapped with ``x * 0.5 + 0.5``
        and clipped to [0, 1]. Single-channel outputs are saved as grayscale.
        """
        imgs = self.predict(self.test_data) * 0.5 + 0.5
        imgs = np.clip(imgs, 0.0, 1.0)

        # Keep the 4-D layout for tiling; the old squeeze of single-channel
        # images broke the 4-way shape unpacking below.
        _, height, width, dims = imgs.shape
        grid = self.test_size

        margin = min(width, height) // 10
        figure = np.ones(((margin + height) * grid + margin, (margin + width) * grid + margin, dims), np.float32)

        for i in range(min(grid * grid, len(imgs))):
            row, col = divmod(i, grid)
            y = margin + (margin + height) * row
            x = margin + (margin + width) * col
            figure[y:y + height, x:x + width, :] = imgs[i, :, :, :]

        if dims == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            figure = figure[:, :, 0]

        figure = Image.fromarray((figure * 255.0).astype(np.uint8))
        figure.save(filename)

    def save_model(self, model_file):
        self.saver.save(self.sess, model_file)

    def load_model(self, model_file):
        self.saver.restore(self.sess, model_file)

    @abstractmethod
    def make_test_data(self):
        """
        Please override "make_test_data" method in the derived model!
        """
        pass

    @abstractmethod
    def predict(self, z_sample):
        """
        Please override "predict" method in the derived model!
        """
        pass

    @abstractmethod
    def train_on_batch(self, x_batch, index):
        """
        Please override "train_on_batch" method in the derived model!
        """
        pass

    def image_tiling(self, images, rows, cols):
        """Tile ``rows * cols`` images into a single image tensor.

        Each image is padded on all sides with the constant 1.0 (a margin of
        ``max(H, W) // 20`` pixels taken from ``input_shape``). The result
        has shape ``(1, rows * H', cols * W', C)`` when ``images`` holds
        exactly ``rows * cols`` images.
        """
        n_images = rows * cols
        mg = max(self.input_shape[0], self.input_shape[1]) // 20
        pad_img = tf.pad(images, [[0, 0], [mg, mg], [mg, mg], [0, 0]], constant_values=1.0)
        img_arr = tf.split(pad_img, n_images, 0)

        # Iterate over the requested rows; the old loop used self.test_size
        # and shadowed the ``rows`` argument.
        row_tiles = [tf.concat(img_arr[r * cols:(r + 1) * cols], axis=2) for r in range(rows)]
        return tf.concat(row_tiles, axis=1)
240 |
class CondBaseModel(BaseModel):
    """Base class for conditional models; batches are ``(images, attrs)``."""

    def __init__(self, **kwargs):
        """Additionally requires ``attr_names`` (one name per attribute)."""
        super(CondBaseModel, self).__init__(**kwargs)

        if 'attr_names' not in kwargs:
            raise Exception('Please specify attribute names (attr_names)')
        self.attr_names = kwargs['attr_names']
        self.num_attrs = len(self.attr_names)

        self.test_size = 10

    def make_batch(self, datasets, indx):
        images = datasets.images[indx]
        attrs = datasets.attrs[indx]
        return images, attrs

    def save_images(self, filename):
        """Save a ``test_size`` x ``num_attrs`` grid of conditional samples.

        ``self.test_data`` must provide ``'z_test'`` and ``'c_test'``; their
        predictions fill the grid row by row. Outputs are mapped with
        ``x * 0.5 + 0.5`` and clipped to [0, 1]. Single-channel outputs are
        saved as grayscale.
        """
        assert self.attr_names is not None

        try:
            test_samples = self.test_data['z_test']
        except KeyError:
            print('Key "z_test" must be provided in "make_test_data" method!')
            raise

        try:
            test_attrs = self.test_data['c_test']
        except KeyError:
            print('Key "c_test" must be provided in "make_test_data" method!')
            raise

        imgs = self.predict([test_samples, test_attrs]) * 0.5 + 0.5
        imgs = np.clip(imgs, 0.0, 1.0)

        _, height, width, dims = imgs.shape

        margin = min(width, height) // 10
        figure = np.ones(((margin + height) * self.test_size + margin, (margin + width) * self.num_attrs + margin, dims), np.float32)

        for i in range(min(self.test_size * self.num_attrs, len(imgs))):
            row, col = divmod(i, self.num_attrs)
            y = margin + (margin + height) * row
            x = margin + (margin + width) * col
            figure[y:y + height, x:x + width, :] = imgs[i, :, :, :]

        if dims == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            figure = figure[:, :, 0]

        figure = Image.fromarray((figure * 255.0).astype(np.uint8))
        figure.save(filename)
290 |
--------------------------------------------------------------------------------
/models/began.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 |
def repelling_regularizer(x, batchsize):
    """Pulling-away term over the rows of ``x``.

    Computes the mean over ordered pairs (i, j) with i != j of the squared
    cosine similarity between rows ``S_i`` and ``S_j``::

        f_PT = 1 / (N (N - 1)) * sum_{i != j} (S_i^T S_j / (|S_i| |S_j|))^2

    Args:
        x: 2-D tensor whose first dimension is ``batchsize`` (N).
        batchsize: Python int N.

    Returns:
        Scalar float32 tensor (0 when N < 2).
    """
    dims = x.get_shape()[1]
    # Row k of S_i is x[k % N] and row k of S_j is x[k // N]: all N^2 pairs.
    S_i = tf.tile(x, [batchsize, 1])
    S_j = tf.tile(x, [1, batchsize])
    S_j = tf.reshape(S_j, [-1, dims])
    S_i_T_S_j = tf.reduce_sum(tf.multiply(S_i, S_j), axis=1)
    S_i_norm2 = tf.reduce_sum(tf.square(S_i), axis=1)
    S_j_norm2 = tf.reduce_sum(tf.square(S_j), axis=1)
    f_PT = tf.square(S_i_T_S_j) / (tf.multiply(S_i_norm2, S_j_norm2) + 1.0e-8)

    # Drop the i == j pairs (k // N == k % N). Each contributes ~1, yet they
    # were summed while still normalizing by N (N - 1).
    off_diagonal = tf.reshape(1.0 - tf.eye(batchsize), [-1])
    num_pairs = max(batchsize * (batchsize - 1), 1)  # avoid 0/0 when N == 1
    f_PT = tf.reduce_sum(f_PT * off_diagonal) / tf.cast(num_pairs, 'float32')
    return f_PT
18 |
class Generator(object):
    """Maps latent vectors ``z`` to images with a tanh output layer.

    Trainable variables and batch-norm update ops are collected after each
    call; variables are reused on every call after the first.
    """

    def __init__(self, input_shape, z_dims):
        self.name = 'generator'
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.reuse = False
        self.variables = None
        self.update_ops = None

    def __call__(self, inputs, training=True):
        xavier = tf.contrib.layers.xavier_initializer
        with tf.variable_scope(self.name, reuse=self.reuse):
            # Project z to a (side x side x 512) map, side = image height / 8.
            with tf.variable_scope('deconv1'):
                side = self.input_shape[0] // (2 ** 3)
                h = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
                h = tf.layers.conv2d_transpose(h, 512, (side, side), (1, 1), 'valid',
                                               kernel_initializer=xavier())
                h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            # Three stride-2 upsampling stages.
            for scope, filters in (('deconv2', 256), ('deconv3', 128), ('deconv4', 64)):
                with tf.variable_scope(scope):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same',
                                                   kernel_initializer=xavier())
                    h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            # Final convolution down to the image channel count.
            with tf.variable_scope('deconv5'):
                h = tf.layers.conv2d(h, self.input_shape[2], (5, 5), (1, 1), 'same',
                                     kernel_initializer=xavier())
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name)
        self.reuse = True
        return h
66 |
class Discriminator(object):
    """Auto-encoding discriminator.

    Calling it returns ``(reconstruction, S)``, where ``S`` is the flattened
    encoder output (BEGAN feeds it to ``repelling_regularizer``). Variables
    are reused on every call after the first.
    """

    def __init__(self, input_shape):
        self.name = 'discriminator'
        self.input_shape = input_shape
        self.reuse = False
        self.variables = None
        self.update_ops = None

    def __call__(self, inputs, training=True):
        xavier = tf.contrib.layers.xavier_initializer
        with tf.variable_scope(self.name, reuse=self.reuse):
            h = inputs

            # Encoder: three stride-2 convolutions.
            for scope, filters in (('conv1', 64), ('conv2', 128), ('conv3', 256)):
                with tf.variable_scope(scope):
                    h = tf.layers.conv2d(h, filters, (5, 5), (2, 2), 'same',
                                         kernel_initializer=xavier())
                    h = lrelu(tf.layers.batch_normalization(h, training=training))

            S = tf.contrib.layers.flatten(h)

            # Decoder: three stride-2 transposed convolutions.
            for scope, filters in (('deconv1', 256), ('deconv2', 128), ('deconv3', 64)):
                with tf.variable_scope(scope):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same',
                                                   kernel_initializer=xavier())
                    h = lrelu(tf.layers.batch_normalization(h, training=training))

            # Output layer with the image channel count.
            with tf.variable_scope('deconv4'):
                h = tf.layers.conv2d_transpose(h, self.input_shape[2], (5, 5), (1, 1), 'same',
                                               kernel_initializer=xavier())
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name)
        self.reuse = True
        return h, S
125 |
126 | class BEGAN(BaseModel):
127 | def __init__(self,
128 | input_shape=(64, 64, 3),
129 | z_dims = 128,
130 | name='began',
131 | **kwargs
132 | ):
133 | super(BEGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)
134 |
135 | self.z_dims = z_dims
136 |
137 | self.beta = 0.01
138 | self.boundary_equil = True
139 | self.margin = 0.1
140 | self.update_k_t = None
141 | self.k_t = tf.Variable(0.5, name='k_t')
142 | self.lambda_k = 1.0e-4
143 | self.gamma = 0.7
144 |
145 | self.gen_trainer = None
146 | self.dis_trainer = None
147 | self.gen_loss_D = None
148 | self.gen_loss_G = None
149 | self.dis_loss = None
150 |
151 | self.f_gen = None
152 | self.f_dis = None
153 |
154 | self.x_train = None
155 | self.z_D = None
156 | self.z_G = None
157 |
158 | self.z_test = None
159 | self.x_test = None
160 | self.x_tile = None
161 |
162 | self.train_op = None
163 |
164 | self.build_model()
165 |
166 | def train_on_batch(self, x_batch, index):
167 | batchsize = x_batch.shape[0]
168 | z_D = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))
169 | z_G = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))
170 |
171 | # Training
172 | _, g_loss, _, d_loss = self.sess.run(
173 | (self.train_op, self.gen_loss_G, self.gen_loss_D, self.dis_loss),
174 | feed_dict={
175 | self.x_train: x_batch,
176 | self.z_G: z_G,
177 | self.z_D: z_D,
178 | }
179 | )
180 |
181 | # Summary update
182 | if index // 1000 != (index - batchsize) // 1000:
183 | summary = self.sess.run(
184 | self.summary,
185 | feed_dict={
186 | self.x_train: x_batch,
187 | self.z_D: z_D,
188 | self.z_G: z_G,
189 | self.z_test: self.test_data
190 | }
191 | )
192 | self.writer.add_summary(summary, index)
193 |
194 | return [
195 | ('g_loss', g_loss),
196 | ('d_loss', d_loss)
197 | ]
198 |
199 | def predict(self, z_samples):
200 | x_sample = self.sess.run(
201 | (self.x_test),
202 | feed_dict={self.z_test: z_samples}
203 | )
204 | return x_sample
205 |
206 | def make_test_data(self):
207 | self.test_data = np.random.uniform(-1, 1, size=(self.test_size * self.test_size, self.z_dims))
208 |
def build_model(self):
    """Build the training graph, the prediction graph and the summaries.

    The discriminator ``f_dis`` returns a reconstruction of its input, which is
    compared to that input with an L1 loss. It also returns a second output
    ``S`` that is passed to ``repelling_regularizer``. Two discriminator
    objectives are supported:

    * ``self.boundary_equil`` true (BEGAN):
      ``L_D = L(x) - k_t * L(G(z_D))``. ``k_t`` is updated by
      ``k_t + lambda_k * (gamma * L(x) - L(G(z_D)))`` and clipped to [0, 1].
    * otherwise (EBGAN hinge):
      ``L_D = L(x) + max(0, margin - L(G(z_D)))``.

    In both cases the generator minimizes ``L(G(z_G)) + beta * f_PT``.
    """
    # Trainer
    self.f_dis = Discriminator(self.input_shape)
    self.f_gen = Generator(self.input_shape, self.z_dims)

    x_shape = (self.batchsize,) + self.input_shape
    self.x_train = tf.placeholder(tf.float32, shape=x_shape, name='x_train')

    z_shape = (self.batchsize, self.z_dims)
    self.z_D = tf.placeholder(tf.float32, shape=z_shape, name='z_D')
    self.z_G = tf.placeholder(tf.float32, shape=z_shape, name='z_G')

    # Fake samples whose reconstruction error trains the discriminator.
    x_f_D = self.f_gen(self.z_D)
    x_f_D_pred, _ = self.f_dis(x_f_D)

    # Fake samples whose reconstruction error trains the generator.
    # S is the discriminator's second output, used by the repelling term.
    x_f_G = self.f_gen(self.z_G)
    x_f_G_pred, S = self.f_dis(x_f_G)

    x_train_pred, _ = self.f_dis(self.x_train)

    f_PT = repelling_regularizer(S, self.batchsize)

    # L1 reconstruction errors of the auto-encoding discriminator.
    self.gen_loss_D = tf.losses.absolute_difference(x_f_D, x_f_D_pred)
    self.gen_loss_G = tf.losses.absolute_difference(x_f_G, x_f_G_pred)
    self.dis_loss = tf.losses.absolute_difference(self.x_train, x_train_pred)

    gen_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
    dis_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)

    if self.boundary_equil:
        self.gen_trainer = gen_opt.minimize(self.gen_loss_G + self.beta * f_PT, var_list=self.f_gen.variables)
        self.dis_trainer = dis_opt.minimize(self.dis_loss - self.k_t * self.gen_loss_D, var_list=self.f_dis.variables)
        # Proportional control of k_t, kept inside [0, 1].
        self.update_k_t = self.k_t.assign(
            tf.clip_by_value(self.k_t + self.lambda_k * (self.gamma * self.dis_loss - self.gen_loss_D), 0.0, 1.0))

        with tf.control_dependencies([self.gen_trainer, self.dis_trainer, self.update_k_t] +
                                     self.f_dis.update_ops + self.f_gen.update_ops):
            self.train_op = tf.no_op(name='train')

    else:
        self.gen_trainer = gen_opt.minimize(self.gen_loss_G + self.beta * f_PT, var_list=self.f_gen.variables)
        # EBGAN hinge loss: the hinge must be ADDED. Minimizing it then pushes the
        # fake reconstruction error up towards the margin. Subtracting it (as
        # before) made the discriminator reconstruct fakes well, which removed
        # the adversarial signal.
        self.dis_trainer = dis_opt.minimize(self.dis_loss + tf.maximum(0.0, self.margin - self.gen_loss_D),
                                            var_list=self.f_dis.variables)

        with tf.control_dependencies([self.gen_trainer, self.dis_trainer] +
                                     self.f_dis.update_ops + self.f_gen.update_ops):
            self.train_op = tf.no_op(name='train')

    # Predictor
    self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
    self.x_test = self.f_gen(self.z_test)
    self.x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)

    tf.summary.image('x_real', image_cast(self.x_train), 10)
    tf.summary.image('x_real_rec', image_cast(x_train_pred), 10)
    tf.summary.image('x_fake', image_cast(x_f_G), 10)
    tf.summary.image('x_fake_rec', image_cast(x_f_G_pred), 10)
    tf.summary.image('x_tile', image_cast(self.x_tile), 1)
    tf.summary.scalar('gen_loss', self.gen_loss_G)
    tf.summary.scalar('dis_loss', self.dis_loss)
    tf.summary.scalar('k_t', self.k_t)

    self.summary = tf.summary.merge_all()
270 |
--------------------------------------------------------------------------------
/models/cvae.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import CondBaseModel
5 | from .utils import *
6 |
class Encoder(object):
    """Conditional convolutional encoder producing Gaussian posterior parameters.

    The attribute vector is broadcast over the spatial grid and concatenated
    to the image channels before the first convolution. Variables are created
    on the first call and reused on later calls.
    """

    def __init__(self, input_shape, z_dims, num_attrs):
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.num_attrs = num_attrs
        self.variables = None
        self.reuse = False

    @staticmethod
    def _conv_bn_relu(x, filters, training):
        """Apply a 5x5 stride-2 convolution, batch normalization and ReLU."""
        h = tf.layers.conv2d(x, filters, (5, 5), (2, 2), 'same')
        h = tf.layers.batch_normalization(h, training=training)
        return tf.nn.relu(h)

    def __call__(self, inputs, attrs, training=True):
        """Encode images conditioned on attributes.

        Returns:
            ``(z_avg, z_log_var)``: mean and log-variance heads, each of width ``z_dims``.
        """
        height, width = self.input_shape[0], self.input_shape[1]
        with tf.variable_scope('encoder', reuse=self.reuse):
            with tf.variable_scope('conv1'):
                attr_map = tf.reshape(attrs, [-1, 1, 1, self.num_attrs])
                attr_map = tf.tile(attr_map, [1, height, width, 1])
                h = tf.concat([inputs, attr_map], axis=-1)
                h = self._conv_bn_relu(h, 64, training)

            for scope_name, filters in (('conv2', 128), ('conv3', 256)):
                with tf.variable_scope(scope_name):
                    h = self._conv_bn_relu(h, filters, training)

            with tf.variable_scope('global_average'):
                h = tf.reduce_mean(h, axis=[1, 2])

            with tf.variable_scope('fc1'):
                # Creation order (mean first) fixes the auto-generated layer names.
                z_avg = tf.layers.dense(h, self.z_dims)
                z_log_var = tf.layers.dense(h, self.z_dims)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='encoder')
        self.reuse = True

        return z_avg, z_log_var
47 |
class Decoder(object):
    """Conditional decoder mapping (z, attributes) to an image via tanh.

    Variables are created on the first call and reused on later calls.
    """

    def __init__(self, input_shape):
        self.input_shape = input_shape
        self.variables = None
        self.reuse = False

    def __call__(self, inputs, attrs, training=True):
        """Decode latent codes concatenated with attribute vectors into images."""
        channels = self.input_shape[2]
        with tf.variable_scope('decoder', reuse=self.reuse):
            with tf.variable_scope('fc1'):
                # Three stride-2 upsamplings follow, so start at 1/8 resolution.
                w = self.input_shape[0] // (2 ** 3)
                h = tf.concat([inputs, attrs], axis=-1)
                h = tf.layers.dense(h, w * w * 256)
                h = tf.layers.batch_normalization(h, training=training)
                h = tf.nn.relu(h)
                h = tf.reshape(h, [-1, w, w, 256])

            for scope_name, filters in (('conv1', 256), ('conv2', 128), ('conv3', 64)):
                with tf.variable_scope(scope_name):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same')
                    h = tf.layers.batch_normalization(h, training=training)
                    h = lrelu(h)

            with tf.variable_scope('conv4'):
                h = tf.layers.conv2d_transpose(h, channels, (3, 3), (1, 1), 'same')
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='decoder')
        self.reuse = True

        return h
89 |
class CVAE(CondBaseModel):
    """Conditional variational auto-encoder.

    The encoder and the decoder are both conditioned on an attribute vector.
    The graph is built in the constructor.
    """

    def __init__(self,
                 input_shape=(64, 64, 3),
                 z_dims=128,
                 name='cvae',
                 **kwargs
                 ):
        super(CVAE, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims

        # Networks
        self.encoder = None
        self.decoder = None

        # Training-graph handles
        self.total_loss = None
        self.optimizer = None
        self.train_op = None
        self.x_train = None
        self.c_train = None

        # Prediction-graph handles
        self.z_test = None
        self.x_test = None
        self.c_test = None

        self.build_model()

    def train_on_batch(self, batch, index):
        """Run one optimization step and write a summary at step ``index``.

        Args:
            batch: ``(x_batch, c_batch)`` images and attribute vectors.
            index: step value passed to the summary writer.

        Returns:
            ``[('loss', loss)]``.
        """
        x_batch, c_batch = batch
        feed = {
            self.x_train: x_batch,
            self.c_train: c_batch,
            self.z_test: self.test_data['z_test'],
            self.c_test: self.test_data['c_test'],
        }
        fetches = (self.train_op, self.total_loss, self.summary)
        _, loss, summary = self.sess.run(fetches, feed_dict=feed)

        self.writer.add_summary(summary, index)
        return [('loss', loss)]

    def predict(self, batch):
        """Decode ``(z_samples, c_samples)`` into images."""
        z_samples, c_samples = batch
        feed = {self.z_test: z_samples, self.c_test: c_samples}
        return self.sess.run(self.x_test, feed_dict=feed)

    def make_test_data(self):
        """Build a preview grid of ``test_size`` latent codes times ``num_attrs`` classes.

        Each latent code is repeated once per attribute, so consecutive rows pair
        one code with every one-hot attribute in turn.
        """
        n_attrs = self.num_attrs
        c_t = np.tile(np.identity(n_attrs), (self.test_size, 1))
        z_base = np.random.normal(size=(self.test_size, self.z_dims))
        z_t = np.tile(z_base, (1, n_attrs)).reshape((self.test_size * n_attrs, -1))
        self.test_data = {'z_test': z_t, 'c_test': c_t}

    def build_model(self):
        """Create the training graph, the prediction graph and the summaries."""
        self.encoder = Encoder(self.input_shape, self.z_dims, self.num_attrs)
        self.decoder = Decoder(self.input_shape)

        # Trainer
        self.x_train = tf.placeholder(tf.float32, shape=(None,) + self.input_shape)
        self.c_train = tf.placeholder(tf.float32, shape=(None, self.num_attrs))

        z_mean, z_log_var = self.encoder(self.x_train, self.c_train)
        # Reparameterization: z = mean + exp(0.5 * log_var) * eps, eps ~ N(0, I).
        eps = tf.random_normal(tf.shape(z_mean))
        z_post = z_mean + tf.multiply(tf.exp(0.5 * z_log_var), eps)
        x_rec = self.decoder(z_post, self.c_train)

        loss = tf.constant(0.0)
        loss = loss + tf.reduce_mean(tf.squared_difference(self.x_train, x_rec))
        loss = loss + kl_loss(z_mean, z_log_var)
        self.total_loss = loss

        adam = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        self.optimizer = adam.minimize(self.total_loss)

        # Run batch-norm moving-average updates together with the optimizer step.
        deps = [self.optimizer] + self.encoder.update_ops + self.decoder.update_ops
        with tf.control_dependencies(deps):
            self.train_op = tf.no_op(name='train')

        # Predictor
        self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs))

        self.x_test = self.decoder(self.z_test, self.c_test)
        x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs)

        # Summary
        tf.summary.image('x_real', self.x_train, 10)
        tf.summary.image('x_fake', x_rec, 10)
        tf.summary.image('x_tile', x_tile, 1)
        tf.summary.scalar('total_loss', self.total_loss)

        self.summary = tf.summary.merge_all()
--------------------------------------------------------------------------------
/models/cvaegan.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from .base import CondBaseModel
6 | from .utils import *
7 |
class Encoder(object):
    """Conditional encoder for CVAE-GAN returning posterior mean and log-variance.

    The attribute vector is tiled over the spatial grid and concatenated to the
    image channels. Variables live under the ``'encoder'`` scope and are reused
    after the first call.
    """

    def __init__(self, input_shape, z_dims, num_attrs):
        self.variables = None
        self.reuse = False
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.num_attrs = num_attrs
        self.name = 'encoder'

    def __call__(self, inputs, attrs, training=True):
        """Return ``(z_avg, z_log_var)`` for images ``inputs`` and attributes ``attrs``."""
        height, width = self.input_shape[0], self.input_shape[1]
        with tf.variable_scope(self.name, reuse=self.reuse):
            with tf.variable_scope('conv1'):
                attr_map = tf.tile(tf.reshape(attrs, [-1, 1, 1, self.num_attrs]),
                                   [1, height, width, 1])
                h = tf.concat([inputs, attr_map], axis=-1)
                h = tf.layers.conv2d(h, 64, (5, 5), (2, 2), 'same')
                h = tf.layers.batch_normalization(h, training=training)
                h = tf.nn.relu(h)

            # Remaining stages use leaky ReLU instead of ReLU.
            for scope_name, filters in (('conv2', 128), ('conv3', 256), ('conv4', 512)):
                with tf.variable_scope(scope_name):
                    h = tf.layers.conv2d(h, filters, (5, 5), (2, 2), 'same')
                    h = tf.layers.batch_normalization(h, training=training)
                    h = lrelu(h)

            with tf.variable_scope('global_avg'):
                h = tf.reduce_mean(h, axis=[1, 2])

            with tf.variable_scope('fc1'):
                z_avg = tf.layers.dense(h, self.z_dims)
                z_log_var = tf.layers.dense(h, self.z_dims)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.reuse = True

        return z_avg, z_log_var
53 |
class Decoder(object):
    """Conditional generator for CVAE-GAN: (z, attributes) -> image in tanh range.

    Variables live under the ``'decoder'`` scope and are reused after the first call.
    """

    def __init__(self, input_shape):
        self.variables = None
        self.reuse = False
        self.input_shape = input_shape
        self.name = 'decoder'

    def __call__(self, inputs, attrs, training=True):
        """Decode latent codes ``inputs`` conditioned on ``attrs``."""
        out_channels = self.input_shape[2]
        with tf.variable_scope(self.name, reuse=self.reuse):
            with tf.variable_scope('fc1'):
                # Three stride-2 upsamplings follow, so start at 1/8 of the width.
                w = self.input_shape[0] // (2 ** 3)
                h = tf.concat([inputs, attrs], axis=-1)
                h = tf.layers.dense(h, w * w * 256)
                h = tf.layers.batch_normalization(h, training=training)
                h = tf.nn.relu(h)
                h = tf.reshape(h, [-1, w, w, 256])

            for scope_name, filters in (('conv1', 256), ('conv2', 128), ('conv3', 64)):
                with tf.variable_scope(scope_name):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same')
                    h = tf.layers.batch_normalization(h, training=training)
                    h = lrelu(h)

            with tf.variable_scope('conv4'):
                h = tf.layers.conv2d_transpose(h, out_channels, (5, 5), (1, 1), 'same')
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.reuse = True

        return h
95 |
class Classifier(object):
    """Convolutional classifier returning class logits and its penultimate features.

    Variables live under the ``'classifier'`` scope and are reused after the first call.
    """

    def __init__(self, input_shape, num_attrs):
        self.variables = None
        self.reuse = False
        self.input_shape = input_shape
        self.num_attrs = num_attrs
        self.name = 'classifier'

    def __call__(self, inputs, training=True):
        """Return ``(logits, features)`` for the image batch ``inputs``."""
        h = inputs
        with tf.variable_scope(self.name, reuse=self.reuse):
            for stage, filters in enumerate((64, 128, 256, 512)):
                with tf.variable_scope('conv%d' % (stage + 1)):
                    h = tf.layers.conv2d(h, filters, (5, 5), (2, 2), 'same')
                    h = tf.layers.batch_normalization(h, training=training)
                    h = tf.nn.relu(h)

            with tf.variable_scope('global_avg'):
                h = tf.reduce_mean(h, axis=[1, 2])

            with tf.variable_scope('fc1'):
                features = tf.contrib.layers.flatten(h)
                logits = tf.layers.dense(features, self.num_attrs)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.reuse = True

        return logits, features
137 |
class Discriminator(object):
    """Convolutional real/fake discriminator returning one logit and its features.

    Variables live under the ``'discriminator'`` scope and are reused after the first call.
    """

    def __init__(self, input_shape):
        self.variables = None
        self.reuse = False
        self.input_shape = input_shape
        self.name = 'discriminator'

    def __call__(self, inputs, training=True):
        """Return ``(logit, features)`` for the image batch ``inputs``."""
        stages = (('conv1', 64), ('conv2', 128), ('conv3', 256), ('conv4', 512))
        h = inputs
        with tf.variable_scope(self.name, reuse=self.reuse):
            for scope_name, filters in stages:
                with tf.variable_scope(scope_name):
                    h = tf.layers.conv2d(h, filters, (5, 5), (2, 2), 'same')
                    h = tf.layers.batch_normalization(h, training=training)
                    h = tf.nn.relu(h)

            with tf.variable_scope('global_avg'):
                h = tf.reduce_mean(h, axis=[1, 2])

            with tf.variable_scope('fc1'):
                features = tf.contrib.layers.flatten(h)
                logit = tf.layers.dense(features, 1)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
        self.reuse = True

        return logit, features
178 |
179 |
class CVAEGAN(CondBaseModel):
    """Conditional VAE-GAN built from an encoder, generator, classifier and discriminator.

    ``build_model`` (called from the constructor) creates one optimizer per
    network. ``train_on_batch`` runs all four updates in a single session call.
    """

    def __init__(self,
                 input_shape=(64, 64, 3),
                 z_dims=128,
                 name='cvaegan',
                 **kwargs
                 ):
        super(CVAEGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims

        # Parameters for feature matching
        self.use_feature_match = False
        self.alpha = 0.7

        self.E_f_D_r = None
        self.E_f_D_p = None
        self.E_f_C_r = None
        self.E_f_C_p = None

        self.f_enc = None
        self.f_gen = None
        self.f_cls = None
        self.f_dis = None

        self.x_r = None
        self.c_r = None
        self.z_p = None

        self.z_test = None
        self.x_test = None
        self.c_test = None

        self.enc_trainer = None
        self.gen_trainer = None
        self.dis_trainer = None
        self.cls_trainer = None

        self.gen_loss = None
        self.dis_loss = None
        self.gen_acc = None
        self.dis_acc = None

        self.build_model()

    def train_on_batch(self, batch, index):
        """Run one joint update of the encoder, generator, discriminator and classifier.

        Args:
            batch: ``(x_r, c_r)`` real images and their attribute vectors.
            index: progress counter. A summary is written whenever ``index`` and
                ``index + batchsize`` fall in different 1000-wide buckets.

        Returns:
            List of ``(name, value)`` pairs for logging.
        """
        x_r, c_r = batch
        batchsize = len(x_r)
        # Sample z_p from the N(0, I) prior. This is the distribution the KL term
        # pulls the encoder posterior towards, and the one make_test_data() uses
        # for previews. (Uniform(-1, 1) samples here trained the generator on a
        # different latent distribution than it is evaluated on.)
        z_p = np.random.normal(size=(batchsize, self.z_dims))

        feed_dict = {
            self.x_r: x_r, self.z_p: z_p, self.c_r: c_r,
            self.z_test: self.test_data['z_test'], self.c_test: self.test_data['c_test']
        }

        _, _, _, _, gen_loss, dis_loss, gen_acc, dis_acc = self.sess.run(
            (self.gen_trainer, self.enc_trainer, self.dis_trainer, self.cls_trainer,
             self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc),
            feed_dict=feed_dict
        )

        summary_period = 1000
        if index // summary_period != (index + batchsize) // summary_period:
            summary = self.sess.run(self.summary, feed_dict=feed_dict)
            self.writer.add_summary(summary, index)

        return [
            ('gen_loss', gen_loss), ('dis_loss', dis_loss),
            ('gen_acc', gen_acc), ('dis_acc', dis_acc)
        ]

    def predict(self, batch):
        """Generate images for ``(z_samples, c_samples)``."""
        z_samples, c_samples = batch
        x_sample = self.sess.run(
            self.x_test,
            feed_dict={self.z_test: z_samples, self.c_test: c_samples}
        )
        return x_sample

    def make_test_data(self):
        """Build a preview grid pairing each of ``test_size`` N(0, I) codes with every attribute."""
        c_t = np.identity(self.num_attrs)
        c_t = np.tile(c_t, (self.test_size, 1))
        z_t = np.random.normal(size=(self.test_size, self.z_dims))
        z_t = np.tile(z_t, (1, self.num_attrs))
        z_t = z_t.reshape((self.test_size * self.num_attrs, self.z_dims))
        self.test_data = {'z_test': z_t, 'c_test': c_t}

    def build_model(self):
        """Create the four networks, their losses, optimizers and summaries."""
        self.f_enc = Encoder(self.input_shape, self.z_dims, self.num_attrs)
        self.f_gen = Decoder(self.input_shape)

        # Without feature matching the classifier has an extra "fake" class.
        n_cls_out = self.num_attrs if self.use_feature_match else self.num_attrs + 1
        self.f_cls = Classifier(self.input_shape, n_cls_out)
        self.f_dis = Discriminator(self.input_shape)

        # Trainer
        self.x_r = tf.placeholder(tf.float32, shape=(None,) + self.input_shape)
        self.c_r = tf.placeholder(tf.float32, shape=(None, self.num_attrs))

        z_avg, z_log_var = self.f_enc(self.x_r, self.c_r)

        # x_f: reconstruction of x_r through the encoder posterior.
        z_f = sample_normal(z_avg, z_log_var)
        x_f = self.f_gen(z_f, self.c_r)

        # x_p: sample generated from a prior code z_p.
        self.z_p = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        x_p = self.f_gen(self.z_p, self.c_r)

        c_r_pred, f_C_r = self.f_cls(self.x_r)
        c_f, f_C_f = self.f_cls(x_f)
        c_p, f_C_p = self.f_cls(x_p)

        y_r, f_D_r = self.f_dis(self.x_r)
        y_f, f_D_f = self.f_dis(x_f)
        y_p, f_D_p = self.f_dis(x_p)

        L_KL = kl_loss(z_avg, z_log_var)

        enc_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        gen_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        cls_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        dis_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)

        if self.use_feature_match:
            # Use feature matching (it is usually unstable)
            L_GD = self.L_GD(f_D_r, f_D_p)
            L_GC = self.L_GC(f_C_r, f_C_p, self.c_r)
            L_G = self.L_G(self.x_r, x_f, f_D_r, f_D_f, f_C_r, f_C_f)

            with tf.name_scope('L_D'):
                L_D = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_r), y_r) + \
                      tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_f), y_f) + \
                      tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_p), y_p)

            with tf.name_scope('L_C'):
                L_C = tf.losses.softmax_cross_entropy(self.c_r, c_r_pred)

            self.enc_trainer = enc_opt.minimize(L_G + L_KL, var_list=self.f_enc.variables)
            self.gen_trainer = gen_opt.minimize(L_G + L_GD + L_GC, var_list=self.f_gen.variables)
            self.cls_trainer = cls_opt.minimize(L_C, var_list=self.f_cls.variables)
            self.dis_trainer = dis_opt.minimize(L_D, var_list=self.f_dis.variables)

            self.gen_loss = L_G + L_GD + L_GC
            self.dis_loss = L_D

            # Predictor
            self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
            self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs))

            self.x_test = self.f_gen(self.z_test, self.c_test)
            x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs)

            # Summary
            tf.summary.image('x_real', self.x_r, 10)
            tf.summary.image('x_fake', x_f, 10)
            tf.summary.image('x_tile', x_tile, 1)
            tf.summary.scalar('L_G', L_G)
            tf.summary.scalar('L_GD', L_GD)
            tf.summary.scalar('L_GC', L_GC)
            tf.summary.scalar('L_C', L_C)
            tf.summary.scalar('L_D', L_D)
            tf.summary.scalar('L_KL', L_KL)
            tf.summary.scalar('gen_loss', self.gen_loss)
            tf.summary.scalar('dis_loss', self.dis_loss)
        else:
            # Not use feature matching (it is more similar to ordinary GANs)
            # c_r_aug: true class, with 0 in the extra "fake" slot.
            # c_other: the extra "fake" class only.
            c_r_aug = tf.concat((self.c_r, tf.zeros((tf.shape(self.c_r)[0], 1))), axis=1)
            c_other = tf.concat((tf.zeros_like(self.c_r), tf.ones((tf.shape(self.c_r)[0], 1))), axis=1)
            with tf.name_scope('L_G'):
                L_G = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_f), y_f) + \
                      tf.losses.sigmoid_cross_entropy(tf.ones_like(y_p), y_p) + \
                      tf.losses.softmax_cross_entropy(c_r_aug, c_f) + \
                      tf.losses.softmax_cross_entropy(c_r_aug, c_p)

            with tf.name_scope('L_rec'):
                # Per-sample sum of squared errors, averaged over the batch.
                L_rec = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(self.x_r, x_f), axis=[1, 2, 3]))

            with tf.name_scope('L_D'):
                L_D = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_r), y_r) + \
                      tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_f), y_f) + \
                      tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_p), y_p)

            with tf.name_scope('L_C'):
                L_C = tf.losses.softmax_cross_entropy(c_r_aug, c_r_pred) + \
                      tf.losses.softmax_cross_entropy(c_other, c_f) + \
                      tf.losses.softmax_cross_entropy(c_other, c_p)

            self.enc_trainer = enc_opt.minimize(L_rec + L_KL, var_list=self.f_enc.variables)
            self.gen_trainer = gen_opt.minimize(L_G + L_rec, var_list=self.f_gen.variables)
            self.cls_trainer = cls_opt.minimize(L_C, var_list=self.f_cls.variables)
            self.dis_trainer = dis_opt.minimize(L_D, var_list=self.f_dis.variables)

            self.gen_loss = L_G + L_rec
            self.dis_loss = L_D

            # Predictor
            self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
            self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs))

            self.x_test = self.f_gen(self.z_test, self.c_test)
            x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs)

            # Summary
            tf.summary.image('x_real', self.x_r, 10)
            tf.summary.image('x_fake', x_f, 10)
            tf.summary.image('x_tile', x_tile, 1)
            tf.summary.scalar('L_G', L_G)
            tf.summary.scalar('L_rec', L_rec)
            tf.summary.scalar('L_C', L_C)
            tf.summary.scalar('L_D', L_D)
            tf.summary.scalar('L_KL', L_KL)
            tf.summary.scalar('gen_loss', self.gen_loss)
            tf.summary.scalar('dis_loss', self.dis_loss)

        # Accuracy
        self.gen_acc = 0.5 * binary_accuracy(tf.ones_like(y_f), y_f) + \
                       0.5 * binary_accuracy(tf.ones_like(y_p), y_p)

        self.dis_acc = binary_accuracy(tf.ones_like(y_r), y_r) / 3.0 + \
                       binary_accuracy(tf.zeros_like(y_f), y_f) / 3.0 + \
                       binary_accuracy(tf.zeros_like(y_p), y_p) / 3.0

        tf.summary.scalar('gen_acc', self.gen_acc)
        tf.summary.scalar('dis_acc', self.dis_acc)

        self.summary = tf.summary.merge_all()

    def L_G(self, x_r, x_f, f_D_r, f_D_f, f_C_r, f_C_f):
        """Reconstruction loss in pixel space plus discriminator and classifier feature space."""
        with tf.name_scope('L_G'):
            loss = tf.constant(0.0, dtype=tf.float32)
            loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(x_r, x_f), axis=[1, 2, 3]))
            loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(f_D_r, f_D_f), axis=[1]))
            loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(f_C_r, f_C_f), axis=[1]))

        return loss

    def L_GD(self, f_D_r, f_D_p):
        """Mean-feature matching loss between real and prior-sampled discriminator features."""
        with tf.name_scope('L_GD'):
            # Compute loss
            E_f_D_r = tf.reduce_mean(f_D_r, axis=0)
            E_f_D_p = tf.reduce_mean(f_D_p, axis=0)

            # Update features
            # NOTE(review): this blends symbolic tensors at graph-construction time.
            # build_model calls it once, starting from zeros, so the result is
            # (1 - alpha) * batch mean. It is not a running average across steps;
            # that would need a tf.Variable plus an assign op.
            if self.E_f_D_r is None:
                self.E_f_D_r = tf.zeros_like(E_f_D_r)

            if self.E_f_D_p is None:
                self.E_f_D_p = tf.zeros_like(E_f_D_p)

            self.E_f_D_r = self.alpha * self.E_f_D_r + (1.0 - self.alpha) * E_f_D_r
            self.E_f_D_p = self.alpha * self.E_f_D_p + (1.0 - self.alpha) * E_f_D_p
            return 0.5 * tf.reduce_sum(tf.squared_difference(self.E_f_D_r, self.E_f_D_p))

    def L_GC(self, f_C_r, f_C_p, c):
        """Per-class mean-feature matching loss for classifier features.

        For each class i, features of the samples whose attribute row ``c`` has a
        nonzero entry i are averaged. The averages for real (``f_C_r``) and
        prior-sampled (``f_C_p``) features are then compared.
        """
        with tf.name_scope('L_GC'):
            image_shape = tf.shape(f_C_r)

            # Row i * B + j of `indices` is e_i; the same row of `classes` is c[j].
            indices = tf.eye(self.num_attrs, dtype=tf.float32)
            indices = tf.tile(indices, (1, image_shape[0]))
            indices = tf.reshape(indices, (-1, self.num_attrs))

            classes = tf.tile(c, (self.num_attrs, 1))

            mask = tf.reduce_max(tf.multiply(indices, classes), axis=1)
            mask = tf.reshape(mask, (-1, 1))
            mask = tf.tile(mask, (1, image_shape[1]))

            # Number of samples per class, broadcast over the feature width.
            denom = tf.reshape(tf.multiply(indices, classes), (self.num_attrs, image_shape[0], self.num_attrs))
            denom = tf.reduce_sum(denom, axis=[1, 2])
            denom = tf.tile(tf.reshape(denom, (-1, 1)), (1, image_shape[1]))

            f_1_sum = tf.tile(f_C_r, (self.num_attrs, 1))
            f_1_sum = tf.multiply(f_1_sum, mask)
            f_1_sum = tf.reshape(f_1_sum, (self.num_attrs, image_shape[0], image_shape[1]))
            E_f_1 = tf.divide(tf.reduce_sum(f_1_sum, axis=1), denom + 1.0e-8)

            f_2_sum = tf.tile(f_C_p, (self.num_attrs, 1))
            f_2_sum = tf.multiply(f_2_sum, mask)
            f_2_sum = tf.reshape(f_2_sum, (self.num_attrs, image_shape[0], image_shape[1]))
            E_f_2 = tf.divide(tf.reduce_sum(f_2_sum, axis=1), denom + 1.0e-8)

            # Update features
            # NOTE(review): same single static blend as in L_GD, not a running average.
            if self.E_f_C_r is None:
                self.E_f_C_r = tf.zeros_like(E_f_1)

            if self.E_f_C_p is None:
                self.E_f_C_p = tf.zeros_like(E_f_2)

            self.E_f_C_r = self.alpha * self.E_f_C_r + (1.0 - self.alpha) * E_f_1
            self.E_f_C_p = self.alpha * self.E_f_C_p + (1.0 - self.alpha) * E_f_2

            return 0.5 * tf.reduce_sum(tf.squared_difference(self.E_f_C_r, self.E_f_C_p))
476 |
--------------------------------------------------------------------------------
/models/dcgan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 | from .wnorm import *
7 |
class Generator(object):
    """DCGAN generator mapping latent vectors to images in tanh range.

    With ``use_wnorm`` the transposed convolutions are weight-normalized
    (``conv2d_transpose_wnorm`` with ``use_scale=True``). In that case the
    following batch normalization omits its own scale parameter.
    """

    def __init__(self, input_shape, z_dims, use_wnorm=False):
        self.variables = None
        self.update_ops = None
        self.reuse = False
        self.use_wnorm = use_wnorm
        self.input_shape = input_shape
        self.z_dims = z_dims

    def _deconv(self, x, *spec):
        """Transposed convolution with Xavier init; ``spec`` is forwarded positionally."""
        init = tf.contrib.layers.xavier_initializer()
        if self.use_wnorm:
            return conv2d_transpose_wnorm(x, *spec, use_scale=True, kernel_initializer=init)
        return tf.layers.conv2d_transpose(x, *spec, kernel_initializer=init)

    def _deconv_bn_relu(self, x, training, *spec):
        """Transposed convolution, batch normalization (scale only without wnorm) and ReLU."""
        x = self._deconv(x, *spec)
        x = tf.layers.batch_normalization(x, scale=not self.use_wnorm, training=training)
        return tf.nn.relu(x)

    def __call__(self, inputs, training=True):
        """Generate images from the latent batch ``inputs``."""
        with tf.variable_scope('generator', reuse=self.reuse):
            with tf.variable_scope('fc1'):
                # Three stride-2 upsamplings follow, so project to 1/8 of the width.
                w = self.input_shape[0] // (2 ** 3)
                h = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
                # No padding argument: each layer function's default applies.
                h = self._deconv_bn_relu(h, training, 256, (w, w), (1, 1))

            for scope_name, filters in (('conv1', 256), ('conv2', 128), ('conv3', 64)):
                with tf.variable_scope(scope_name):
                    h = self._deconv_bn_relu(h, training, filters, (5, 5), (2, 2), 'same')

            with tf.variable_scope('conv4'):
                h = self._deconv(h, self.input_shape[2], (5, 5), (1, 1), 'same')
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
        self.reuse = True
        return h
77 |
class Discriminator(object):
    """DCGAN discriminator returning one logit per image.

    With ``use_wnorm`` the convolutions are weight-normalized
    (``conv2d_wnorm`` with ``use_scale=True``). In that case the following
    batch normalization omits its own scale parameter.
    """

    def __init__(self, input_shape, use_wnorm=False):
        self.input_shape = input_shape
        self.variables = None
        self.update_ops = None
        self.use_wnorm = use_wnorm
        self.reuse = False

    def _conv(self, x, *spec):
        """Convolution with Xavier init; ``spec`` is forwarded positionally."""
        init = tf.contrib.layers.xavier_initializer()
        if self.use_wnorm:
            return conv2d_wnorm(x, *spec, use_scale=True, kernel_initializer=init)
        return tf.layers.conv2d(x, *spec, kernel_initializer=init)

    def _conv_bn_lrelu(self, x, filters, training):
        """5x5 stride-2 convolution, batch normalization and leaky ReLU."""
        x = self._conv(x, filters, (5, 5), (2, 2), 'same')
        x = tf.layers.batch_normalization(x, scale=not self.use_wnorm, training=training)
        return lrelu(x)

    def __call__(self, inputs, training=True):
        """Return a ``[-1, 1]``-shaped logit tensor for the image batch ``inputs``."""
        stages = (('conv1', 64), ('conv2', 128), ('conv3', 256), ('conv4', 512))
        h = inputs
        with tf.variable_scope('discriminator', reuse=self.reuse):
            for scope_name, filters in stages:
                with tf.variable_scope(scope_name):
                    h = self._conv_bn_lrelu(h, filters, training)

            with tf.variable_scope('conv5'):
                # After four stride-2 stages the map is 1/16 of the width; a
                # 'valid' conv of that size reduces it to a single output.
                w = self.input_shape[0] // (2 ** 4)
                y = self._conv(h, 1, (w, w), (1, 1), 'valid')
                y = tf.reshape(y, [-1, 1])

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
        self.reuse = True
        return y
142 |
class DCGAN(BaseModel):
    """Deep convolutional GAN trained with the sigmoid cross-entropy loss.

    ``BaseModel`` (not visible here) supplies ``self.sess``, ``self.writer``,
    ``self.test_size`` and ``self.image_tiling``.
    NOTE(review): ``self.test_data`` is only set by ``make_test_data``;
    presumably BaseModel calls it before training -- confirm.
    """

    def __init__(self, input_shape=(64, 64, 3), z_dims=128, name='dcgan', **kwargs):
        super(DCGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims
        # Forwarded to both Generator and Discriminator in build_model.
        self.use_wnorm = True

        # Networks, losses and the combined train op (filled by build_model).
        self.f_gen = None
        self.f_dis = None
        self.gen_loss = None
        self.dis_loss = None
        self.train_op = None

        self.gen_acc = None
        self.dis_acc = None

        # Training placeholders.
        self.x_train = None
        self.z_train = None

        # Test-time placeholder and generator output.
        self.z_test = None
        self.x_test = None

        self.build_model()

    def train_on_batch(self, x_batch, index):
        """Run one simultaneous generator/discriminator update.

        Args:
            x_batch: batch of real images; its first axis is the batch size.
            index: running counter used as the TensorBoard step and to decide
                when to write summaries.

        Returns:
            List of ``(name, value)`` pairs with losses and accuracies.
        """
        batchsize = x_batch.shape[0]
        z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))

        _, g_loss, d_loss, g_acc, d_acc = self.sess.run(
            (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc),
            feed_dict={self.x_train: x_batch, self.z_train: z_sample}
        )

        # The merged summary contains image grids and an extra generator pass
        # over the fixed test codes, so only write it when ``index`` crosses a
        # multiple of ``summary_period`` (same schedule as LSGAN / ResNetGAN)
        # instead of on every step.
        summary_period = 1000
        if index // summary_period != (index - batchsize) // summary_period:
            summary = self.sess.run(
                self.summary,
                feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data}
            )
            self.writer.add_summary(summary, index)

        return [
            ('g_loss', g_loss), ('d_loss', d_loss),
            ('g_acc', g_acc), ('d_acc', d_acc)
        ]

    def predict(self, z_samples):
        """Generate images from latent codes with batch norm in inference mode."""
        x_sample = self.sess.run(
            self.x_test,
            feed_dict={self.z_test: z_samples}
        )
        return x_sample

    def make_test_data(self):
        """Draw a fixed ``test_size**2`` grid of uniform(-1, 1) latent codes."""
        self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims))

    def build_model(self):
        """Create placeholders, losses, the combined train op and summaries.

        Both Adam optimizers step from the same forward pass. ``train_op`` is
        a no-op that depends on both updates and on the batch-norm update ops.
        """
        # Trainer
        self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm)
        self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm)

        x_shape = (None,) + self.input_shape
        z_shape = (None, self.z_dims)
        self.x_train = tf.placeholder(tf.float32, shape=x_shape)
        self.z_train = tf.placeholder(tf.float32, shape=z_shape)
        x_fake = self.f_gen(self.z_train)
        y_fake = self.f_dis(x_fake)
        y_real = self.f_dis(self.x_train)

        # Non-saturating generator loss; discriminator averages real/fake terms.
        self.gen_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_fake), y_fake)
        self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \
                        0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake)

        gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)

        gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables)
        dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables)

        self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake)
        self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \
                       0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake)

        with tf.control_dependencies([gen_train_op, dis_train_op] +
                                     self.f_dis.update_ops +
                                     self.f_gen.update_ops):
            self.train_op = tf.no_op(name='train')

        # Predictor
        self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        self.x_test = self.f_gen(self.z_test, training=False)

        x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)

        tf.summary.image('x_real', image_cast(self.x_train), 10)
        tf.summary.image('x_fake', image_cast(x_fake), 10)
        tf.summary.image('x_tile', image_cast(x_tile), 1)
        tf.summary.scalar('gen_loss', self.gen_loss)
        tf.summary.scalar('dis_loss', self.dis_loss)
        tf.summary.scalar('gen_acc', self.gen_acc)
        tf.summary.scalar('dis_acc', self.dis_acc)
        self.summary = tf.summary.merge_all()
244 |
--------------------------------------------------------------------------------
/models/improved_gan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 |
def minibatch_discrimination(x, kernels=50, dims=5):
    """Minibatch-discrimination features (Salimans et al., 2016, Sec. 3.2).

    Each sample's feature vector is projected to a ``(kernels, dims)`` matrix
    ``M_i``. For every kernel ``b`` the output is
    ``o(x_i)_b = sum_j exp(-||M_{i,b} - M_{j,b}||_1)``, where ``j`` runs over
    every sample in the minibatch (including ``i`` itself). This is what lets
    the discriminator detect a lack of diversity across the batch.

    Args:
        x: 2-D tensor ``[batch, features]``. The feature dimension must be
            statically known because it sizes the projection variable.
        kernels: number of similarity kernels (output width).
        dims: length of each kernel's row of ``M_i``.

    Returns:
        Tensor of shape ``[batch, kernels]``.

    Note:
        The projection is created with ``tf.get_variable(name='kernel')``, so it
        lives in the caller's *variable* scope (``tf.name_scope`` does not
        affect ``get_variable`` names).
    """
    with tf.name_scope('MinibatchDiscrimination'):
        size = x.get_shape()[1]
        W = tf.get_variable(shape=(size, kernels * dims), trainable=True, name='kernel')
        Ms = tf.reshape(tf.tensordot(x, W, axes=1), [-1, kernels, dims])
        # Pairwise comparison across the *batch* axis, per kernel:
        # [batch, kernels, dims, 1] - [1, kernels, dims, batch]
        #   -> [batch, kernels, dims, batch]
        M_i = tf.expand_dims(Ms, 3)
        M_j = tf.expand_dims(tf.transpose(Ms, [1, 2, 0]), 0)
        norm = tf.reduce_sum(tf.abs(M_i - M_j), axis=2)  # [batch, kernels, batch]
        Os = tf.reduce_sum(tf.exp(-norm), axis=2)        # [batch, kernels]
        return Os
19 |
class Generator(object):
    """DCGAN-style generator mapping a latent vector to a tanh image.

    The output side length is ``8 * (input_shape[0] // 8)``: an initial
    projection to ``input_shape[0] // 8`` followed by three stride-2
    upsamplings.

    Args:
        input_shape: ``(height, width, channels)`` of the images to produce.
        z_dims: length of each latent vector.
        use_wnorm: stored for interface parity; this generator does not read it.
    """

    def __init__(self, input_shape, z_dims, use_wnorm=False):
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.use_wnorm = use_wnorm
        self.variables = None
        self.update_ops = None
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the generator graph for latent codes ``inputs``.

        ``training`` is forwarded to batch normalization. Refreshes
        ``self.variables`` / ``self.update_ops`` and enables weight reuse.
        """
        xavier = tf.contrib.layers.xavier_initializer
        with tf.variable_scope('generator', reuse=self.reuse):
            with tf.variable_scope('fc1'):
                # Treat z as a 1x1 map and expand it with a transposed conv
                # whose kernel spans the whole starting resolution.
                side = self.input_shape[0] // (2 ** 3)
                h = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
                h = tf.layers.conv2d_transpose(h, 256, (side, side), (1, 1),
                                               kernel_initializer=xavier())
                h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            # Three stride-2 upsampling stages.
            for stage, filters in enumerate((256, 128, 64), start=1):
                with tf.variable_scope('conv%d' % stage):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same',
                                                   kernel_initializer=xavier())
                    h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            with tf.variable_scope('conv4'):
                channels = self.input_shape[2]
                h = tf.layers.conv2d_transpose(h, channels, (5, 5), (1, 1), 'same',
                                               kernel_initializer=xavier())
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
        self.reuse = True
        return h
67 |
class Discriminator(object):
    """Convolutional discriminator with a minibatch-discrimination head.

    ``__call__`` returns ``(logits, features)``. ``features`` is the tensor fed
    to the final dense layer; ImprovedGAN uses it for feature matching.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        use_wnorm: stored for interface parity; not read by this class.
    """

    def __init__(self, input_shape, use_wnorm=False):
        self.input_shape = input_shape
        self.variables = None
        self.update_ops = None
        self.use_wnorm = use_wnorm
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the discriminator graph.

        Args:
            inputs: image batch.
            training: forwarded to every batch-normalization layer.

        Returns:
            ``(y, f)``: logits of shape ``[batch, 1]`` and the features that
            produced them (fc1 activations concatenated with the
            minibatch-discrimination statistics).
        """
        with tf.variable_scope('discriminator', reuse=self.reuse):
            x = inputs
            # Four stride-2 5x5 conv -> batch-norm -> leaky-ReLU stages.
            for stage, filters in enumerate((64, 128, 256, 512), start=1):
                with tf.variable_scope('conv%d' % stage):
                    x = tf.layers.conv2d(x, filters, (5, 5), (2, 2), 'same',
                                         kernel_initializer=tf.contrib.layers.xavier_initializer())
                    x = tf.layers.batch_normalization(x, training=training)
                    x = lrelu(x)

            with tf.variable_scope('fc1'):
                x = tf.contrib.layers.flatten(x)
                x = tf.layers.dense(x, 1024)
                x = tf.layers.batch_normalization(x, training=training)
                x = lrelu(x)

            with tf.variable_scope('minibatch_discrimination'):
                # Salimans et al. (2016) concatenate the minibatch statistics
                # o(x) with the per-sample features they were computed from.
                # Feeding o(x) alone would hide all per-sample information
                # from the final layer.
                o = minibatch_discrimination(x, kernels=50, dims=5)
                x = tf.concat([x, o], axis=1)
                f = tf.identity(x)

            with tf.variable_scope('fc2'):
                y = tf.layers.dense(x, 1)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
        self.reuse = True
        return y, f
119 |
class ImprovedGAN(BaseModel):
    """GAN with a feature-matching generator loss and minibatch discrimination.

    The generator minimizes the squared distance between the batch means of
    the discriminator features on real and fake images. The discriminator
    uses the sigmoid cross-entropy loss.
    ``BaseModel`` (not visible here) supplies ``self.sess``, ``self.writer``,
    ``self.test_size`` and ``self.image_tiling``.
    NOTE(review): ``self.test_data`` is only set by ``make_test_data``;
    presumably BaseModel calls it before training -- confirm.
    """

    def __init__(self, input_shape=(64, 64, 3), z_dims=128, name='improved', **kwargs):
        super(ImprovedGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims
        # Forwarded to Generator/Discriminator, which store but do not use it.
        self.use_wnorm = True

        # Networks, losses and the combined train op (filled by build_model).
        self.f_gen = None
        self.f_dis = None
        self.gen_loss = None
        self.dis_loss = None
        self.train_op = None

        self.gen_acc = None
        self.dis_acc = None

        # Training placeholders.
        self.x_train = None
        self.z_train = None

        # Test-time placeholder and generator output.
        self.z_test = None
        self.x_test = None

        self.build_model()

    def train_on_batch(self, x_batch, index):
        """Run one simultaneous generator/discriminator update.

        Args:
            x_batch: batch of real images; its first axis is the batch size.
            index: running counter used as the TensorBoard step and to decide
                when to write summaries.

        Returns:
            List of ``(name, value)`` pairs with losses and accuracies.
        """
        batchsize = x_batch.shape[0]
        z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))

        _, g_loss, d_loss, g_acc, d_acc = self.sess.run(
            (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc),
            feed_dict={self.x_train: x_batch, self.z_train: z_sample}
        )

        # The merged summary contains image grids and an extra generator pass
        # over the fixed test codes, so only write it when ``index`` crosses a
        # multiple of ``summary_period`` (same schedule as LSGAN / ResNetGAN)
        # instead of on every step.
        summary_period = 1000
        if index // summary_period != (index - batchsize) // summary_period:
            summary = self.sess.run(
                self.summary,
                feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data}
            )
            self.writer.add_summary(summary, index)

        return [
            ('g_loss', g_loss), ('d_loss', d_loss),
            ('g_acc', g_acc), ('d_acc', d_acc)
        ]

    def predict(self, z_samples):
        """Generate images from latent codes with batch norm in inference mode."""
        x_sample = self.sess.run(
            self.x_test,
            feed_dict={self.z_test: z_samples}
        )
        return x_sample

    def make_test_data(self):
        """Draw a fixed ``test_size**2`` grid of uniform(-1, 1) latent codes."""
        self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims))

    def build_model(self):
        """Create placeholders, losses, the combined train op and summaries."""
        # Trainer
        self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm)
        self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm)

        x_shape = (None,) + self.input_shape
        z_shape = (None, self.z_dims)
        self.x_train = tf.placeholder(tf.float32, shape=x_shape)
        self.z_train = tf.placeholder(tf.float32, shape=z_shape)
        x_fake = self.f_gen(self.z_train)
        y_fake, f_fake = self.f_dis(x_fake)
        y_real, f_real = self.f_dis(self.x_train)

        # Feature matching: match batch-mean discriminator features.
        E_f_fake = tf.reduce_mean(f_fake, axis=0)
        E_f_real = tf.reduce_mean(f_real, axis=0)
        self.gen_loss = tf.reduce_sum(tf.square(E_f_real - E_f_fake))
        self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \
                        0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake)

        gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)

        gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables)
        dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables)

        self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake)
        self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \
                       0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake)

        with tf.control_dependencies([gen_train_op, dis_train_op] +
                                     self.f_dis.update_ops +
                                     self.f_gen.update_ops):
            self.train_op = tf.no_op(name='train')

        # Predictor
        self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        self.x_test = self.f_gen(self.z_test, training=False)

        x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)

        tf.summary.image('x_real', image_cast(self.x_train), 10)
        tf.summary.image('x_fake', image_cast(x_fake), 10)
        tf.summary.image('x_tile', image_cast(x_tile), 1)
        tf.summary.scalar('gen_loss', self.gen_loss)
        tf.summary.scalar('dis_loss', self.dis_loss)
        tf.summary.scalar('gen_acc', self.gen_acc)
        tf.summary.scalar('dis_acc', self.dis_acc)
        self.summary = tf.summary.merge_all()
223 |
--------------------------------------------------------------------------------
/models/lsgan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 |
class Generator(object):
    """Generator for LSGAN: latent vector -> tanh image.

    Output side length is ``8 * (input_shape[0] // 8)``.

    Args:
        input_shape: ``(height, width, channels)`` of the images to produce.
        z_dims: length of each latent vector.
        use_wnorm: stored for interface parity; not read here.
    """

    def __init__(self, input_shape, z_dims, use_wnorm=False):
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.use_wnorm = use_wnorm
        self.variables = None
        self.update_ops = None
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the generator graph; ``training`` drives batch norm."""
        xavier = tf.contrib.layers.xavier_initializer
        with tf.variable_scope('generator', reuse=self.reuse):
            with tf.variable_scope('fc1'):
                side = self.input_shape[0] // (2 ** 3)
                net = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
                net = tf.layers.conv2d_transpose(net, 256, (side, side), (1, 1),
                                                 kernel_initializer=xavier())
                net = tf.nn.relu(tf.layers.batch_normalization(net, training=training))

            # Three stride-2 upsampling stages.
            for stage, width in ((1, 256), (2, 128), (3, 64)):
                with tf.variable_scope('conv%d' % stage):
                    net = tf.layers.conv2d_transpose(net, width, (5, 5), (2, 2), 'same',
                                                     kernel_initializer=xavier())
                    net = tf.nn.relu(tf.layers.batch_normalization(net, training=training))

            with tf.variable_scope('conv4'):
                net = tf.layers.conv2d_transpose(net, self.input_shape[2], (5, 5), (1, 1), 'same',
                                                 kernel_initializer=xavier())
                net = tf.tanh(net)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
        self.reuse = True
        return net
51 |
class Discriminator(object):
    """Convolutional discriminator for LSGAN returning raw (unsquashed) scores.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        use_wnorm: stored for interface parity; not read here.
    """

    def __init__(self, input_shape, use_wnorm=False):
        self.input_shape = input_shape
        self.use_wnorm = use_wnorm
        self.variables = None
        self.update_ops = None
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the discriminator graph; returns scores shaped ``[-1, 1]``."""
        xavier = tf.contrib.layers.xavier_initializer
        net = inputs
        with tf.variable_scope('discriminator', reuse=self.reuse):
            # Four stride-2 5x5 conv -> batch-norm -> leaky-ReLU stages.
            for stage, width in ((1, 64), (2, 128), (3, 256), (4, 512)):
                with tf.variable_scope('conv%d' % stage):
                    net = tf.layers.conv2d(net, width, (5, 5), (2, 2), 'same',
                                           kernel_initializer=xavier())
                    net = lrelu(tf.layers.batch_normalization(net, training=training))

            with tf.variable_scope('conv5'):
                side = self.input_shape[0] // (2 ** 4)
                score = tf.layers.conv2d(net, 1, (side, side), (1, 1), 'valid',
                                         kernel_initializer=xavier())
                score = tf.reshape(score, [-1, 1])

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
        self.reuse = True
        return score
92 |
class LSGAN(BaseModel):
    """Least-squares GAN.

    The discriminator regresses real scores toward ``param_b`` and fake
    scores toward ``param_a``; the generator pushes fake scores toward
    ``param_c``. ``BaseModel`` (not visible here) supplies ``self.sess``,
    ``self.writer``, ``self.test_size`` and ``self.image_tiling``.
    """

    def __init__(self, input_shape=(64, 64, 3), z_dims=128, name='lsgan', **kwargs):
        super(LSGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims
        self.use_wnorm = True

        self.f_gen = self.f_dis = None
        self.gen_loss = self.dis_loss = None
        self.train_op = None

        # Least-squares targets (a: fake for D, b: real for D, c: fake for G).
        self.param_a, self.param_b, self.param_c = 0.0, 1.0, 1.0

        self.x_train = self.z_train = None
        self.z_test = self.x_test = None

        self.build_model()

    def train_on_batch(self, x_batch, index):
        """One simultaneous G/D step; writes summaries every 1000 counted samples.

        Returns ``[('g_loss', ...), ('d_loss', ...)]``.
        """
        n = x_batch.shape[0]
        z_sample = np.random.uniform(-1.0, 1.0, size=(n, self.z_dims))
        feeds = {self.x_train: x_batch, self.z_train: z_sample}

        _, g_loss, d_loss = self.sess.run((self.train_op, self.gen_loss, self.dis_loss), feed_dict=feeds)

        summary_period = 1000
        crossed_boundary = index // summary_period != (index - n) // summary_period
        if crossed_boundary:
            feeds[self.z_test] = self.test_data
            self.writer.add_summary(self.sess.run(self.summary, feed_dict=feeds), index)

        return [('g_loss', g_loss), ('d_loss', d_loss)]

    def predict(self, z_samples):
        """Generate images from latent codes with batch norm in inference mode."""
        return self.sess.run(self.x_test, feed_dict={self.z_test: z_samples})

    def make_test_data(self):
        """Draw a fixed ``test_size**2`` grid of uniform(-1, 1) latent codes."""
        grid = self.test_size * self.test_size
        self.test_data = np.random.uniform(-1.0, 1.0, size=(grid, self.z_dims))

    def build_model(self):
        """Create placeholders, least-squares losses, the train op and summaries."""
        # --- Trainer ---
        self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm)
        self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm)

        self.x_train = tf.placeholder(tf.float32, shape=(None,) + self.input_shape)
        self.z_train = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        x_fake = self.f_gen(self.z_train)
        y_fake = self.f_dis(x_fake)
        y_real = self.f_dis(self.x_train)

        self.gen_loss = tf.reduce_mean(tf.square(y_fake - self.param_c))
        real_term = tf.reduce_mean(tf.square(y_real - self.param_b))
        fake_term = tf.reduce_mean(tf.square(y_fake - self.param_a))
        self.dis_loss = real_term + fake_term

        gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables)
        dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables)

        # A single op that runs both optimizer steps plus batch-norm updates.
        deps = [gen_train_op, dis_train_op] + self.f_dis.update_ops + self.f_gen.update_ops
        with tf.control_dependencies(deps):
            self.train_op = tf.no_op(name='train')

        # --- Predictor ---
        self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        self.x_test = self.f_gen(self.z_test, training=False)

        x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)

        # --- Summaries ---
        tf.summary.image('x_real', image_cast(self.x_train), 10)
        tf.summary.image('x_fake', image_cast(x_fake), 10)
        tf.summary.image('x_tile', image_cast(x_tile), 1)
        tf.summary.scalar('gen_loss', self.gen_loss)
        tf.summary.scalar('dis_loss', self.dis_loss)
        self.summary = tf.summary.merge_all()
194 |
--------------------------------------------------------------------------------
/models/resnet_gan.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 | from .wnorm import *
7 |
def residual_plain_unit(x, filters, training=True):
    """Pre-activation residual unit: BN-ReLU-conv-BN-ReLU-dropout-conv + identity.

    Args:
        x: 4-D feature map. Its channel count must equal ``filters`` because
            the identity shortcut is added without a projection.
        filters: output channels of both 3x3 convolutions.
        training: forwarded to batch normalization and dropout.

    Returns:
        ``x + F(x)``, with the same shape as ``x``.
    """
    shortcut = tf.identity(x)

    x = tf.layers.batch_normalization(x, training=training)
    x = tf.nn.relu(x)
    x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same')

    x = tf.layers.batch_normalization(x, training=training)
    # The shortcut is added exactly once, after the second convolution.
    # Adding it here as well would inject the identity path twice.
    x = tf.nn.relu(x)
    x = tf.layers.dropout(x, rate=0.5, training=training)

    x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same')

    return x + shortcut
22 |
class Generator(object):
    """Residual generator: latent vector -> tanh image.

    Each of the three upsampling stages is a stride-2 transposed conv followed
    by ``residual_plain_unit`` and BN + ReLU.

    Args:
        input_shape: ``(height, width, channels)`` of the images to produce.
        z_dims: length of each latent vector.
        use_wnorm: stored for interface parity; not read here.
    """

    def __init__(self, input_shape, z_dims, use_wnorm=False):
        self.input_shape = input_shape
        self.z_dims = z_dims
        self.use_wnorm = use_wnorm
        self.variables = None
        self.update_ops = None
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the generator graph; ``training`` drives BN and dropout."""
        with tf.variable_scope('generator', reuse=self.reuse):
            with tf.variable_scope('deconv1'):
                side = self.input_shape[0] // (2 ** 3)
                h = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
                h = tf.layers.conv2d_transpose(h, 256, (side, side), (1, 1), 'valid')
                h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            for stage, filters in ((2, 256), (3, 128), (4, 64)):
                with tf.variable_scope('deconv%d' % stage):
                    h = tf.layers.conv2d_transpose(h, filters, (5, 5), (2, 2), 'same')
                    h = residual_plain_unit(h, filters, training=training)
                    h = tf.nn.relu(tf.layers.batch_normalization(h, training=training))

            with tf.variable_scope('deconv5'):
                channels = self.input_shape[2]
                h = tf.layers.conv2d(h, channels, (5, 5), (1, 1), 'same',
                                     kernel_initializer=tf.contrib.layers.xavier_initializer())
                h = tf.tanh(h)

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
        self.reuse = True
        return h
69 |
class Discriminator(object):
    """Residual discriminator returning logits shaped ``[-1, 1]``.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        use_wnorm: stored for interface parity; not read here.
    """

    def __init__(self, input_shape, use_wnorm=False):
        self.input_shape = input_shape
        self.use_wnorm = use_wnorm
        self.variables = None
        self.update_ops = None
        self.reuse = False

    def __call__(self, inputs, training=True):
        """Build (or reuse) the discriminator graph; ``training`` drives BN and dropout."""
        h = inputs
        with tf.variable_scope('discriminator', reuse=self.reuse):
            # Four stride-2 conv -> residual unit -> BN -> leaky-ReLU stages.
            for stage, filters in ((1, 64), (2, 128), (3, 256), (4, 512)):
                with tf.variable_scope('conv%d' % stage):
                    h = tf.layers.conv2d(h, filters, (5, 5), (2, 2), 'same')
                    h = residual_plain_unit(h, filters, training=training)
                    h = lrelu(tf.layers.batch_normalization(h, training=training))

            with tf.variable_scope('conv5'):
                side = self.input_shape[0] // (2 ** 4)
                logits = tf.layers.conv2d(h, 1, (side, side), (1, 1), 'valid')
                logits = tf.reshape(logits, [-1, 1])

        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
        self.reuse = True
        return logits
113 |
class ResNetGAN(BaseModel):
    """GAN with residual generator/discriminator and the sigmoid cross-entropy loss.

    ``BaseModel`` (not visible here) supplies ``self.sess``, ``self.writer``,
    ``self.test_size`` and ``self.image_tiling``.
    NOTE(review): ``self.test_data`` is only set by ``make_test_data``;
    presumably BaseModel calls it before training -- confirm.
    """

    def __init__(self, input_shape=(64, 64, 3), z_dims=128, name='resnet', **kwargs):
        super(ResNetGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)

        self.z_dims = z_dims
        # Forwarded to Generator/Discriminator, which store but do not use it.
        self.use_wnorm = True

        # Networks, losses and the combined train op (filled by build_model).
        self.f_gen = None
        self.f_dis = None
        self.gen_loss = None
        self.dis_loss = None
        self.train_op = None

        self.gen_acc = None
        self.dis_acc = None

        # Training placeholders.
        self.x_train = None
        self.z_train = None

        # Test-time placeholder and generator output.
        self.z_test = None
        self.x_test = None

        self.build_model()

    def train_on_batch(self, x_batch, index):
        """Run one simultaneous generator/discriminator update.

        Args:
            x_batch: batch of real images; its first axis is the batch size.
            index: running counter used as the TensorBoard step and to decide
                when to write summaries.

        Returns:
            List of ``(name, value)`` pairs with losses and accuracies.
        """
        batchsize = x_batch.shape[0]
        z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))

        _, g_loss, d_loss, g_acc, d_acc = self.sess.run(
            (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc),
            feed_dict={self.x_train: x_batch, self.z_train: z_sample}
        )

        # Write summaries only when ``index`` crosses a multiple of the period.
        summary_period = 1000
        if index // summary_period != (index - batchsize) // summary_period:
            summary = self.sess.run(
                self.summary,
                feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data}
            )
            self.writer.add_summary(summary, index)

        return [
            ('g_loss', g_loss), ('d_loss', d_loss),
            ('g_acc', g_acc), ('d_acc', d_acc)
        ]

    def predict(self, z_samples):
        """Generate images from latent codes with BN/dropout in inference mode."""
        x_sample = self.sess.run(
            self.x_test,
            feed_dict={self.z_test: z_samples}
        )
        return x_sample

    def make_test_data(self):
        """Draw a fixed ``test_size**2`` grid of uniform(-1, 1) latent codes."""
        self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims))

    def build_model(self):
        """Create placeholders, losses, the combined train op and summaries."""
        # Trainer
        self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm)
        self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm)

        x_shape = (None,) + self.input_shape
        z_shape = (None, self.z_dims)
        self.x_train = tf.placeholder(tf.float32, shape=x_shape)
        self.z_train = tf.placeholder(tf.float32, shape=z_shape)
        x_fake = self.f_gen(self.z_train)
        y_fake = self.f_dis(x_fake)
        y_real = self.f_dis(self.x_train)

        self.gen_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_fake), y_fake)
        self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \
                        0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake)

        gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
        dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)

        gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables)
        dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables)

        self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake)
        self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \
                       0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake)

        with tf.control_dependencies([gen_train_op, dis_train_op] +
                                     self.f_dis.update_ops +
                                     self.f_gen.update_ops):
            self.train_op = tf.no_op(name='train')

        # Predictor
        self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
        self.x_test = self.f_gen(self.z_test, training=False)

        x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)

        tf.summary.image('x_real', image_cast(self.x_train), 10)
        tf.summary.image('x_fake', image_cast(x_fake), 10)
        tf.summary.image('x_tile', image_cast(x_tile), 1)
        tf.summary.scalar('gen_loss', self.gen_loss)
        tf.summary.scalar('dis_loss', self.dis_loss)
        tf.summary.scalar('gen_acc', self.gen_acc)
        tf.summary.scalar('dis_acc', self.dis_acc)
        self.summary = tf.summary.merge_all()
221 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
def image_cast(img):
    """Map an image tensor from [-1, 1] to uint8 [0, 255] (e.g. for tf.summary.image).

    The scaled values are clipped to [0, 255] before the cast. A float -> uint8
    ``tf.cast`` does not saturate, so pixels slightly outside [-1, 1] would
    otherwise wrap to an unrelated intensity. In-range inputs are unaffected.
    """
    scaled = img * 127.5 + 127.5
    return tf.cast(tf.clip_by_value(scaled, 0.0, 255.0), tf.uint8)
5 |
def kl_loss(avg, log_var):
    """KL(N(avg, exp(log_var)) || N(0, I)) in closed form.

    The per-dimension terms are summed over the last axis, and the result is
    averaged over all remaining entries.
    """
    with tf.name_scope('KLLoss'):
        per_dim = 1.0 + log_var - tf.square(avg) - tf.exp(log_var)
        per_sample = -0.5 * tf.reduce_sum(per_dim, axis=-1)
        return tf.reduce_mean(per_sample)
9 |
def lrelu(x, alpha=0.02):
    """Leaky ReLU: element-wise ``max(x, alpha * x)``.

    The default slope is 0.02; the call sites visible in this package all use
    the default.
    """
    with tf.name_scope('LeakyReLU'):
        leaked = alpha * x
        return tf.maximum(x, leaked)
13 |
def binary_accuracy(y_true, y_pred):
    """Fraction of entries where the thresholded prediction equals the label.

    ``y_pred`` holds logits. They are passed through a sigmoid and rounded to
    0/1 before being compared with ``y_true``.
    """
    with tf.name_scope('BinaryAccuracy'):
        predicted = tf.round(tf.sigmoid(y_pred))
        hits = tf.cast(tf.equal(y_true, predicted), dtype=tf.float32)
        return tf.reduce_mean(hits)
17 |
def sample_normal(avg, log_var):
    """Reparameterized sample ``avg + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``."""
    with tf.name_scope('SampleNormal'):
        epsilon = tf.random_normal(tf.shape(avg))
        std = tf.exp(0.5 * log_var)
        return tf.add(avg, tf.multiply(std, epsilon))
22 |
def vgg_conv_unit(x, filters, layers, training=True):
    """VGG-style block: ``layers`` 3x3 convolutions, then a 2x2 stride-2 conv.

    Every convolution uses Xavier initialization and is followed by batch
    normalization (``training`` forwarded) and leaky ReLU.
    """
    def conv_bn_act(t, kernel, stride):
        # One conv -> batch-norm -> leaky-ReLU step with ``filters`` channels.
        t = tf.layers.conv2d(t, filters, kernel, stride, 'same',
                             kernel_initializer=tf.contrib.layers.xavier_initializer())
        t = tf.layers.batch_normalization(t, training=training)
        return lrelu(t)

    for _ in range(layers):
        x = conv_bn_act(x, (3, 3), (1, 1))

    # Downsample with a learned stride-2 convolution rather than pooling.
    return conv_bn_act(x, (2, 2), (2, 2))
38 |
def vgg_deconv_unit(x, filters, layers, training=True):
    """Mirror of ``vgg_conv_unit``: 2x2 stride-2 transposed conv, then ``layers`` 3x3 convs.

    Every layer uses Xavier initialization and is followed by batch
    normalization (``training`` forwarded) and leaky ReLU.
    """
    xavier = tf.contrib.layers.xavier_initializer

    # Upsample with a learned stride-2 transposed convolution.
    x = tf.layers.conv2d_transpose(x, filters, (2, 2), (2, 2), 'same',
                                   kernel_initializer=xavier())
    x = lrelu(tf.layers.batch_normalization(x, training=training))

    # Refine at the new resolution.
    for _ in range(layers):
        x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same',
                             kernel_initializer=xavier())
        x = lrelu(tf.layers.batch_normalization(x, training=training))

    return x
54 |
def time_format(t):
    """Render a duration in seconds as '<m> min <s> sec', or '<s> sec' under a minute.

    Both components are truncated to integers.
    """
    minutes, seconds = (int(part) for part in divmod(t, 60))
    if minutes:
        return '%d min %d sec' % (minutes, seconds)
    return '%d sec' % seconds
63 |
--------------------------------------------------------------------------------
/models/vae.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from .base import BaseModel
5 | from .utils import *
6 | from .wnorm import *
7 |
8 | class Encoder(object):
9 | def __init__(self, input_shape, z_dims, use_wnorm=True):
10 | self.variables = None
11 | self.update_ops = None
12 | self.reuse = False
13 | self.input_shape = input_shape
14 | self.z_dims = z_dims
15 | self.use_wnorm = use_wnorm
16 |
17 | def __call__(self, inputs, training=True):
18 | with tf.variable_scope('encoder', reuse=self.reuse):
19 | with tf.variable_scope('conv1'):
20 | if self.use_wnorm:
21 | x = conv2d_wnorm(inputs, 64, (5, 5), (2, 2), 'same', use_scale=True)
22 | x = tf.layers.batch_normalization(x, scale=False, training=training)
23 | else:
24 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same')
25 | x = tf.layers.batch_normalization(x, training=training)
26 | x = lrelu(x)
27 |
28 | with tf.variable_scope('conv2'):
29 | if self.use_wnorm:
30 | x = conv2d_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True)
31 | x = tf.layers.batch_normalization(x, scale=False, training=training)
32 | else:
33 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same')
34 | x = tf.layers.batch_normalization(x, training=training)
35 | x = lrelu(x)
36 |
37 | with tf.variable_scope('conv3'):
38 | if self.use_wnorm:
39 | x = conv2d_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True)
40 | x = tf.layers.batch_normalization(x, scale=False, training=training)
41 | else:
42 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same')
43 | x = tf.layers.batch_normalization(x, training=training)
44 | x = lrelu(x)
45 |
46 | with tf.variable_scope('conv4'):
47 | if self.use_wnorm:
48 | x = conv2d_wnorm(x, 512, (5, 5), (2, 2), 'same', use_scale=True)
49 | x = tf.layers.batch_normalization(x, scale=False, training=training)
50 | else:
51 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same')
52 | x = tf.layers.batch_normalization(x, training=training)
53 | x = lrelu(x)
54 |
55 | with tf.variable_scope('fc1'):
56 | w = self.input_shape[0] // (2 ** 4)
57 | if self.use_wnorm:
58 | z_avg = conv2d_wnorm(x, self.z_dims, (w, w), (1, 1), 'valid', use_scale=True)
59 | z_log_var = conv2d_wnorm(x, self.z_dims, (w, w), (1, 1), 'valid', use_scale=True)
60 | else:
61 | z_avg = tf.layers.conv2d(x, self.z_dims, (w, w), (1, 1), 'valid')
62 | z_log_var = tf.layers.conv2d(x, self.z_dims, (w, w), (1, 1), 'valid')
63 |
64 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder')
65 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='encoder')
66 | self.reuse = True
67 |
68 | return z_avg, z_log_var
69 |
70 | class Decoder(object):
71 | def __init__(self, input_shape, z_dims, use_wnorm=True):
72 | self.variables = None
73 | self.update_ops = None
74 | self.reuse = False
75 | self.input_shape = input_shape
76 | self.z_dims = z_dims
77 | self.use_wnorm = use_wnorm
78 |
79 | def __call__(self, inputs, training=True):
80 | with tf.variable_scope('decoder', reuse=self.reuse):
81 | with tf.variable_scope('deconv1'):
82 | w = self.input_shape[0] // (2 ** 3)
83 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
84 | if self.use_wnorm:
85 | x = conv2d_transpose_wnorm(x, 256, (w, w), (1, 1), 'valid', use_scale=True)
86 | x = tf.layers.batch_normalization(x, scale=False, training=training)
87 | else:
88 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 'valid')
89 | x = tf.layers.batch_normalization(x, training=training)
90 | x = tf.nn.relu(x)
91 |
92 | with tf.variable_scope('deconv2'):
93 | if self.use_wnorm:
94 | x = conv2d_transpose_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True)
95 | x = tf.layers.batch_normalization(x, scale=False, training=training)
96 | else:
97 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same')
98 | x = tf.layers.batch_normalization(x, training=training)
99 | x = tf.nn.relu(x)
100 |
101 | with tf.variable_scope('deconv3'):
102 | if self.use_wnorm:
103 | x = conv2d_transpose_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True)
104 | x = tf.layers.batch_normalization(x, scale=False, training=training)
105 | else:
106 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same')
107 | x = tf.layers.batch_normalization(x, training=training)
108 | x = tf.nn.relu(x)
109 |
110 | with tf.variable_scope('deconv4'):
111 | if self.use_wnorm:
112 | x = conv2d_transpose_wnorm(x, 64, (5, 5), (2, 2), 'same', use_scale=True)
113 | x = tf.layers.batch_normalization(x, scale=False, training=training)
114 | else:
115 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same')
116 | x = tf.layers.batch_normalization(x, training=training)
117 | x = tf.nn.relu(x)
118 |
119 | with tf.variable_scope('deconv5'):
120 | d = self.input_shape[2]
121 | if self.use_wnorm:
122 | x = conv2d_transpose_wnorm(x, d, (5, 5), (1, 1), 'same', use_scale=True)
123 | else:
124 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same')
125 | x = tf.tanh(x)
126 |
127 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder')
128 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='decoder')
129 | self.reuse = True
130 |
131 | return x
132 |
133 | class VAE(BaseModel):
134 | def __init__(self,
135 | input_shape=(64, 64, 3),
136 | z_dims = 128,
137 | name='vae',
138 | **kwargs
139 | ):
140 | super(VAE, self).__init__(input_shape=input_shape, name=name, **kwargs)
141 |
142 | self.z_dims = z_dims
143 | self.use_wnorm = False
144 |
145 | self.encoder = None
146 | self.decoder = None
147 | self.rec_loss = None
148 | self.kl_loss = None
149 | self.train_op = None
150 |
151 | self.x_train = None
152 |
153 | self.z_test = None
154 | self.x_test = None
155 |
156 | self.build_model()
157 |
158 | def train_on_batch(self, x_batch, index):
159 | _, rec_loss, kl_loss, summary = self.sess.run(
160 | (self.train_op, self.rec_loss, self.kl_loss, self.summary),
161 | feed_dict={self.x_train: x_batch, self.z_test: self.test_data}
162 | )
163 | self.writer.add_summary(summary, index)
164 | return [ ('rec_loss', rec_loss), ('kl_loss', kl_loss) ]
165 |
166 | def predict(self, z_samples):
167 | x_sample = self.sess.run(
168 | self.x_test,
169 | feed_dict={self.z_test: z_samples}
170 | )
171 | return x_sample
172 |
173 | def make_test_data(self):
174 | self.test_data = np.random.normal(size=(self.test_size * self.test_size, self.z_dims))
175 |
176 | def build_model(self):
177 | self.encoder = Encoder(self.input_shape, self.z_dims, self.use_wnorm)
178 | self.decoder = Decoder(self.input_shape, self.z_dims, self.use_wnorm)
179 |
180 | # Trainer
181 | batch_shape = (None,) + self.input_shape
182 | self.x_train = tf.placeholder(tf.float32, shape=batch_shape)
183 |
184 | z_avg, z_log_var = self.encoder(self.x_train)
185 | z_sample = sample_normal(z_avg, z_log_var)
186 | x_sample = self.decoder(z_sample)
187 |
188 | rec_loss_scale = tf.constant(np.prod(self.input_shape), tf.float32)
189 | self.rec_loss = tf.losses.absolute_difference(self.x_train, x_sample) * rec_loss_scale
190 | self.kl_loss = kl_loss(z_avg, z_log_var)
191 |
192 | optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5)
193 | fmin = optim.minimize(self.rec_loss + self.kl_loss)
194 |
195 | with tf.control_dependencies([fmin] + self.encoder.update_ops + self.decoder.update_ops):
196 | self.train_op = tf.no_op(name='train')
197 |
198 | # Predictor
199 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
200 | self.x_test = self.decoder(self.z_test)
201 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)
202 |
203 | # Summary
204 | tf.summary.image('x_real', image_cast(self.x_train), 10)
205 | tf.summary.image('x_fake', image_cast(x_sample), 10)
206 | tf.summary.image('x_tile', image_cast(x_tile), 1)
207 | tf.summary.scalar('rec_loss', self.rec_loss)
208 | tf.summary.scalar('kl_loss', self.kl_loss)
209 |
210 | self.summary = tf.summary.merge_all()
211 |
--------------------------------------------------------------------------------
/models/wgan.py:
--------------------------------------------------------------------------------
1 | """
2 | Wasserstain GAN:
3 | This is an implementation of "improved" version of Wasserstein GAN.
4 | Gulrajani et al., "Improved Training of Wasserstein GAN", arXiv preprint, 2017.
5 | """
6 |
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | from .base import BaseModel
11 | from .utils import *
12 |
13 | class Generator(object):
14 | def __init__(self, input_shape, z_dims):
15 | self.variables = None
16 | self.update_ops = None
17 | self.reuse = False
18 | self.name = 'generator'
19 | self.input_shape = input_shape
20 | self.z_dims = z_dims
21 |
22 | def __call__(self, inputs, training=True):
23 | with tf.variable_scope(self.name, reuse=self.reuse):
24 | with tf.variable_scope('fc1'):
25 | w = self.input_shape[0] // (2 ** 3)
26 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims])
27 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 'valid',
28 | kernel_initializer=tf.contrib.layers.xavier_initializer())
29 | x = tf.layers.batch_normalization(x, training=training)
30 | x = tf.nn.relu(x)
31 |
32 | with tf.variable_scope('conv1'):
33 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same',
34 | kernel_initializer=tf.contrib.layers.xavier_initializer())
35 | x = tf.layers.batch_normalization(x, training=training)
36 | x = tf.nn.relu(x)
37 |
38 | with tf.variable_scope('conv2'):
39 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same',
40 | kernel_initializer=tf.contrib.layers.xavier_initializer())
41 | x = tf.layers.batch_normalization(x, training=training)
42 | x = tf.nn.relu(x)
43 |
44 | with tf.variable_scope('conv3'):
45 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same',
46 | kernel_initializer=tf.contrib.layers.xavier_initializer())
47 | x = tf.layers.batch_normalization(x, training=training)
48 | x = tf.nn.relu(x)
49 |
50 | with tf.variable_scope('conv4'):
51 | d = self.input_shape[2]
52 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same',
53 | kernel_initializer=tf.contrib.layers.xavier_initializer())
54 | x = tf.tanh(x)
55 |
56 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
57 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name)
58 | self.reuse = True
59 | return x
60 |
61 | class Discriminator(object):
62 | def __init__(self, input_shape):
63 | self.input_shape = input_shape
64 | self.variables = None
65 | self.update_ops = None
66 | self.name = 'discriminator'
67 | self.reuse = False
68 |
69 | def __call__(self, inputs, training=True):
70 | with tf.variable_scope(self.name, reuse=self.reuse):
71 | with tf.variable_scope('conv1'):
72 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same',
73 | kernel_initializer=tf.contrib.layers.xavier_initializer())
74 | x = tf.layers.batch_normalization(x, training=training)
75 | x = lrelu(x)
76 |
77 | with tf.variable_scope('conv2'):
78 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same',
79 | kernel_initializer=tf.contrib.layers.xavier_initializer())
80 | x = tf.layers.batch_normalization(x, training=training)
81 | x = lrelu(x)
82 |
83 | with tf.variable_scope('conv3'):
84 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same',
85 | kernel_initializer=tf.contrib.layers.xavier_initializer())
86 | x = tf.layers.batch_normalization(x, training=training)
87 | x = lrelu(x)
88 |
89 | with tf.variable_scope('conv4'):
90 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same',
91 | kernel_initializer=tf.contrib.layers.xavier_initializer())
92 | x = tf.layers.batch_normalization(x, training=training)
93 | x = lrelu(x)
94 |
95 | with tf.variable_scope('conv5'):
96 | w = self.input_shape[0] // (2 ** 4)
97 | x = tf.layers.conv2d(x, 1, (w, w), (1, 1), 'valid',
98 | kernel_initializer=tf.random_normal_initializer(stddev=0.005))
99 | y = tf.reshape(x, [-1, 1])
100 |
101 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name)
102 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name)
103 | self.reuse = True
104 | return y
105 |
106 | class WGAN(BaseModel):
107 | def __init__(self,
108 | input_shape=(64, 64, 3),
109 | z_dims = 128,
110 | name='wgan',
111 | **kwargs
112 | ):
113 | super(WGAN, self).__init__(input_shape=input_shape, name=name, **kwargs)
114 |
115 | self.z_dims = z_dims
116 | self.n_critic = 2
117 | self.lmbda = 10.0
118 |
119 | self.gen_loss = None
120 | self.dis_loss = None
121 | self.gen_train_op = None
122 | self.dis_train_op = None
123 |
124 | self.x_train = None
125 | self.e_random = None
126 | self.batch_idx = 0
127 |
128 | self.z_test = None
129 | self.x_test = None
130 |
131 | self.build_model()
132 |
133 | def train_on_batch(self, x_batch, index):
134 | batchsize = x_batch.shape[0]
135 | self.batch_idx += 1
136 |
137 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))
138 | eps = float(np.random.uniform(0.0, 1.0, size=(1)))
139 | _, g_loss, d_loss = self.sess.run(
140 | (self.dis_train_op, self.gen_loss, self.dis_loss),
141 | feed_dict={
142 | self.x_train: x_batch,
143 | self.z_train: z_sample,
144 | self.e_random: eps
145 | }
146 | )
147 |
148 | if self.batch_idx % self.n_critic == 0:
149 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))
150 | eps = float(np.random.uniform(0.0, 1.0, size=(1)))
151 | _, g_loss, d_loss = self.sess.run(
152 | (self.gen_train_op, self.gen_loss, self.dis_loss),
153 | feed_dict={
154 | self.x_train: x_batch,
155 | self.z_train: z_sample,
156 | self.e_random: eps,
157 | self.z_test: self.test_data
158 | }
159 | )
160 |
161 | # Summary update
162 | summary_priod = 1000
163 | if index // summary_priod != (index - batchsize) // summary_priod:
164 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims))
165 | eps = float(np.random.uniform(0.0, 1.0, size=(1)))
166 | summary = self.sess.run(
167 | self.summary,
168 | feed_dict={
169 | self.x_train: x_batch,
170 | self.z_train: z_sample,
171 | self.e_random: eps,
172 | self.z_test: self.test_data
173 | }
174 | )
175 | self.writer.add_summary(summary, index)
176 |
177 | return [
178 | ('g_loss', g_loss), ('d_loss', d_loss)
179 | ]
180 |
181 | def predict(self, z_samples):
182 | x_sample = self.sess.run(
183 | self.x_test,
184 | feed_dict={self.z_test: z_samples}
185 | )
186 | return x_sample
187 |
188 | def make_test_data(self):
189 | self.test_data = np.random.uniform(-1, 1, size=(self.test_size * self.test_size, self.z_dims))
190 |
191 | def build_model(self):
192 | # Trainer
193 | self.f_dis = Discriminator(self.input_shape)
194 | self.f_gen = Generator(self.input_shape, self.z_dims)
195 |
196 | x_shape = (None,) + self.input_shape
197 | z_shape = (None,) + (self.z_dims,)
198 | self.x_train = tf.placeholder(tf.float32, shape=x_shape)
199 | self.z_train = tf.placeholder(tf.float32, shape=z_shape)
200 | self.e_random = tf.placeholder(tf.float32, shape=())
201 |
202 | x_fake = self.f_gen(self.z_train)
203 | y_fake = self.f_dis(x_fake)
204 | y_real = self.f_dis(self.x_train)
205 |
206 | gen_optim = tf.train.AdamOptimizer(learning_rate=1.0e-4, beta1=0.0, beta2=0.9)
207 | dis_optim = tf.train.AdamOptimizer(learning_rate=1.0e-4, beta1=0.0, beta2=0.9)
208 |
209 | x_hat = self.e_random * self.x_train + (1.0 - self.e_random) * x_fake
210 | y_hat = self.f_dis(x_hat)
211 | d_grad = tf.gradients(y_hat, [x_hat])
212 | d_reg = tf.square(1.0 - tf.sqrt(tf.reduce_sum(tf.square(d_grad))))
213 |
214 | self.gen_loss = -tf.reduce_mean(y_fake)
215 | self.dis_loss = -tf.reduce_mean(y_real) + tf.reduce_mean(y_fake) + self.lmbda * d_reg
216 |
217 | gen_optim_min = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables)
218 | with tf.control_dependencies([gen_optim_min] + self.f_gen.update_ops):
219 | self.gen_train_op = tf.no_op(name='gen_train')
220 |
221 | dis_optim_min = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables)
222 |
223 | with tf.control_dependencies([dis_optim_min] + self.f_dis.update_ops):
224 | self.dis_train_op = tf.no_op(name='dis_train')
225 |
226 | # Predictor
227 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims))
228 | self.x_test = self.f_gen(self.z_test, training=False)
229 |
230 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size)
231 |
232 | tf.summary.image('x_real', image_cast(self.x_train), 10)
233 | tf.summary.image('x_fake', image_cast(x_fake), 10)
234 | tf.summary.image('x_tile', image_cast(x_tile), 1)
235 | tf.summary.scalar('gen_loss', self.gen_loss)
236 | tf.summary.scalar('dis_loss', self.dis_loss)
237 | self.summary = tf.summary.merge_all()
238 |
--------------------------------------------------------------------------------
/models/wnorm.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.eager import context
3 | from tensorflow.python.framework import ops, tensor_shape
4 | from tensorflow.python.layers import base, utils
5 | from tensorflow.python.ops import nn, standard_ops, array_ops, init_ops, nn_ops
6 |
7 | class DenseWNorm(base.Layer):
8 | def __init__(self, units,
9 | activation=None,
10 | use_scale=True,
11 | use_bias=True,
12 | kernel_initializer=None,
13 | scale_initializer=None,
14 | bias_initializer=init_ops.zeros_initializer(),
15 | kernel_regularizer=None,
16 | scale_regularizer=None,
17 | bias_regularizer=None,
18 | activity_regularizer=None,
19 | kernel_constraint=None,
20 | scale_constraint=None,
21 | bias_constraint=None,
22 | trainable=True,
23 | name=None,
24 | **kwargs):
25 | super(DenseWNorm, self).__init__(trainable=trainable, name=name,
26 | activity_regularizer=activity_regularizer,
27 | **kwargs)
28 | self.units = units
29 | self.activation = activation
30 | self.use_scale = use_scale
31 | self.use_bias = use_bias
32 | self.kernel_initializer = kernel_initializer
33 | self.scale_initializer = scale_initializer
34 | self.bias_initializer = bias_initializer
35 | self.kernel_regularizer = kernel_regularizer
36 | self.scale_regularizer = scale_regularizer
37 | self.bias_regularizer = bias_regularizer
38 | self.kernel_constraint = kernel_constraint
39 | self.scale_constraint = scale_constraint
40 | self.bias_constraint = bias_constraint
41 | self.input_spec = base.InputSpec(min_ndim=2)
42 |
43 | def build(self, input_shape):
44 | input_shape = tensor_shape.TensorShape(input_shape)
45 | if input_shape[-1].value is None:
46 | raise ValueError('The last dimension of the inputs to `Dense` '
47 | 'should be defined. Found `None`.')
48 | self.input_spec = base.InputSpec(min_ndim=2,
49 | axes={-1: input_shape[-1].value})
50 | self.kernel = self.add_variable('kernel',
51 | shape=[input_shape[-1].value, self.units],
52 | initializer=self.kernel_initializer,
53 | regularizer=self.kernel_regularizer,
54 | constraint=self.kernel_constraint,
55 | dtype=self.dtype,
56 | trainable=True)
57 |
58 | if self.use_scale:
59 | self.scale = self.add_variable('scale',
60 | shape=[self.units,],
61 | initializer=self.scale_initializer,
62 | regularizer=self.scale_regularizer,
63 | constraint=self.scale_constraint,
64 | dtype=self.dtype,
65 | trainable=True)
66 | else:
67 | self.scale = 1.0
68 |
69 | if self.use_bias:
70 | self.bias = self.add_variable('bias',
71 | shape=[self.units,],
72 | initializer=self.bias_initializer,
73 | regularizer=self.bias_regularizer,
74 | constraint=self.bias_constraint,
75 | dtype=self.dtype,
76 | trainable=True)
77 | else:
78 | self.bias = None
79 | self.built = True
80 |
81 | def call(self, inputs):
82 | inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
83 | shape = inputs.get_shape().as_list()
84 |
85 | if len(shape) > 2:
86 | # Broadcasting is required for the inputs.
87 | outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1],
88 | [0]])
89 | # Reshape the output back to the original ndim of the input.
90 | if context.in_graph_mode():
91 | output_shape = shape[:-1] + [self.units]
92 | outputs.set_shape(output_shape)
93 | else:
94 | outputs = standard_ops.matmul(inputs, self.kernel)
95 |
96 | scaler = self.scale / tf.sqrt(tf.reduce_sum(tf.square(self.kernel), [0]))
97 | outputs = scaler * outputs
98 |
99 | if self.use_bias:
100 | outputs = nn.bias_add(outputs, self.bias)
101 | if self.activation is not None:
102 | return self.activation(outputs) # pylint: disable=not-callable
103 | return outputs
104 |
105 | def _compute_output_shape(self, input_shape):
106 | input_shape = tensor_shape.TensorShape(input_shape)
107 | input_shape = input_shape.with_rank_at_least(2)
108 | if input_shape[-1].value is None:
109 | raise ValueError(
110 | 'The innermost dimension of input_shape must be defined, but saw: %s'
111 | % input_shape)
112 | return input_shape[:-1].concatenate(self.units)
113 |
114 |
115 | def dense_wnorm(
116 | inputs, units,
117 | activation=None,
118 | use_scale=True,
119 | use_bias=True,
120 | kernel_initializer=None,
121 | scale_initializer=None,
122 | bias_initializer=init_ops.zeros_initializer(),
123 | kernel_regularizer=None,
124 | scale_regularizer=None,
125 | bias_regularizer=None,
126 | activity_regularizer=None,
127 | kernel_constraint=None,
128 | scale_constraint=None,
129 | bias_constraint=None,
130 | trainable=True,
131 | name=None,
132 | reuse=None):
133 |
134 | layer = DenseWNorm(units,
135 | activation=activation,
136 | use_scale=use_scale,
137 | use_bias=use_bias,
138 | kernel_initializer=kernel_initializer,
139 | scale_initializer=scale_initializer,
140 | bias_initializer=bias_initializer,
141 | kernel_regularizer=kernel_regularizer,
142 | scale_regularizer=scale_regularizer,
143 | bias_regularizer=bias_regularizer,
144 | activity_regularizer=activity_regularizer,
145 | kernel_constraint=kernel_constraint,
146 | scale_constraint=scale_constraint,
147 | bias_constraint=bias_constraint,
148 | trainable=trainable,
149 | name=name,
150 | dtype=inputs.dtype.base_dtype,
151 | _scope=name,
152 | _reuse=reuse)
153 | return layer.apply(inputs)
154 |
155 |
156 | class _ConvWNorm(base.Layer):
157 | def __init__(self, rank,
158 | filters,
159 | kernel_size,
160 | strides=1,
161 | padding='valid',
162 | data_format='channels_last',
163 | dilation_rate=1,
164 | activation=None,
165 | use_scale=True,
166 | use_bias=True,
167 | kernel_initializer=None,
168 | scale_initializer=None,
169 | bias_initializer=init_ops.zeros_initializer(),
170 | scale_regularizer=None,
171 | kernel_regularizer=None,
172 | bias_regularizer=None,
173 | activity_regularizer=None,
174 | kernel_constraint=None,
175 | scale_constraint=None,
176 | bias_constraint=None,
177 | trainable=True,
178 | name=None,
179 | **kwargs):
180 | super(_ConvWNorm, self).__init__(trainable=trainable, name=name,
181 | activity_regularizer=activity_regularizer,
182 | **kwargs)
183 | self.rank = rank
184 | self.filters = filters
185 | self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size')
186 | self.strides = utils.normalize_tuple(strides, rank, 'strides')
187 | self.padding = utils.normalize_padding(padding)
188 | self.data_format = utils.normalize_data_format(data_format)
189 | self.dilation_rate = utils.normalize_tuple(
190 | dilation_rate, rank, 'dilation_rate')
191 | self.activation = activation
192 | self.use_scale = use_scale
193 | self.use_bias = use_bias
194 | self.kernel_initializer = kernel_initializer
195 | self.scale_initializer = scale_initializer
196 | self.bias_initializer = bias_initializer
197 | self.kernel_regularizer = kernel_regularizer
198 | self.scale_regularizer = scale_regularizer
199 | self.bias_regularizer = bias_regularizer
200 | self.kernel_constraint = kernel_constraint
201 | self.scale_constraint = scale_constraint
202 | self.bias_constraint = bias_constraint
203 | self.input_spec = base.InputSpec(ndim=self.rank + 2)
204 |
205 | def build(self, input_shape):
206 | input_shape = tensor_shape.TensorShape(input_shape)
207 | if self.data_format == 'channels_first':
208 | channel_axis = 1
209 | else:
210 | channel_axis = -1
211 | if input_shape[channel_axis].value is None:
212 | raise ValueError('The channel dimension of the inputs '
213 | 'should be defined. Found `None`.')
214 | input_dim = input_shape[channel_axis].value
215 | kernel_shape = self.kernel_size + (input_dim, self.filters)
216 |
217 | self.kernel = self.add_variable(name='kernel',
218 | shape=kernel_shape,
219 | initializer=self.kernel_initializer,
220 | regularizer=self.kernel_regularizer,
221 | constraint=self.kernel_constraint,
222 | trainable=True,
223 | dtype=self.dtype)
224 |
225 | if self.use_scale:
226 | self.scale = self.add_variable(name='scale',
227 | shape=(self.filters,),
228 | initializer=self.scale_initializer,
229 | regularizer=self.scale_regularizer,
230 | constraint=self.scale_constraint,
231 | trainable=True,
232 | dtype=self.dtype)
233 | else:
234 | self.scale = None
235 |
236 | if self.use_bias:
237 | self.bias = self.add_variable(name='bias',
238 | shape=(self.filters,),
239 | initializer=self.bias_initializer,
240 | regularizer=self.bias_regularizer,
241 | constraint=self.bias_constraint,
242 | trainable=True,
243 | dtype=self.dtype)
244 | else:
245 | self.bias = None
246 |
247 | self.input_spec = base.InputSpec(ndim=self.rank + 2,
248 | axes={channel_axis: input_dim})
249 |
250 | self._convolution_op = nn_ops.Convolution(
251 | input_shape,
252 | filter_shape=self.kernel.get_shape(),
253 | dilation_rate=self.dilation_rate,
254 | strides=self.strides,
255 | padding=self.padding.upper(),
256 | data_format=utils.convert_data_format(self.data_format,
257 | self.rank + 2))
258 | self.built = True
259 |
260 | def call(self, inputs):
261 | kernel_norm = nn.l2_normalize(self.kernel, [0, 1, 2])
262 | if self.use_scale:
263 | kernel_norm = tf.reshape(self.scale, [1, 1, 1, self.filters]) * kernel_norm
264 | outputs = self._convolution_op(inputs, kernel_norm)
265 |
266 | if self.use_bias:
267 | if self.data_format == 'channels_first':
268 | if self.rank == 1:
269 | # nn.bias_add does not accept a 1D input tensor.
270 | bias = array_ops.reshape(self.bias, (1, self.filters, 1))
271 | outputs += bias
272 | if self.rank == 2:
273 | outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
274 | if self.rank == 3:
275 | # As of Mar 2017, direct addition is significantly slower than
276 | # bias_add when computing gradients. To use bias_add, we collapse Z
277 | # and Y into a single dimension to obtain a 4D input tensor.
278 | outputs_shape = outputs.shape.as_list()
279 | outputs_4d = array_ops.reshape(outputs,
280 | [outputs_shape[0], outputs_shape[1],
281 | outputs_shape[2] * outputs_shape[3],
282 | outputs_shape[4]])
283 | outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
284 | outputs = array_ops.reshape(outputs_4d, outputs_shape)
285 | else:
286 | outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')
287 |
288 | if self.activation is not None:
289 | return self.activation(outputs)
290 | return outputs
291 |
292 | def _compute_output_shape(self, input_shape):
293 | input_shape = tensor_shape.TensorShape(input_shape).as_list()
294 | if self.data_format == 'channels_last':
295 | space = input_shape[1:-1]
296 | new_space = []
297 | for i in range(len(space)):
298 | new_dim = utils.conv_output_length(
299 | space[i],
300 | self.kernel_size[i],
301 | padding=self.padding,
302 | stride=self.strides[i],
303 | dilation=self.dilation_rate[i])
304 | new_space.append(new_dim)
305 | return tensor_shape.TensorShape([input_shape[0]] + new_space +
306 | [self.filters])
307 | else:
308 | space = input_shape[2:]
309 | new_space = []
310 | for i in range(len(space)):
311 | new_dim = utils.conv_output_length(
312 | space[i],
313 | self.kernel_size[i],
314 | padding=self.padding,
315 | stride=self.strides[i],
316 | dilation=self.dilation_rate[i])
317 | new_space.append(new_dim)
318 | return tensor_shape.TensorShape([input_shape[0], self.filters] + new_space)
319 |
320 | class Conv2DWNorm(_ConvWNorm):
321 | def __init__(self, filters,
322 | kernel_size,
323 | strides=(1, 1),
324 | padding='valid',
325 | data_format='channels_last',
326 | dilation_rate=(1, 1),
327 | activation=None,
328 | use_scale=True,
329 | use_bias=True,
330 | kernel_initializer=None,
331 | scale_initializer=None,
332 | bias_initializer=init_ops.zeros_initializer(),
333 | kernel_regularizer=None,
334 | scale_regularizer=None,
335 | bias_regularizer=None,
336 | activity_regularizer=None,
337 | kernel_constraint=None,
338 | scale_constraint=None,
339 | bias_constraint=None,
340 | trainable=True,
341 | name=None,
342 | **kwargs):
343 | super(Conv2DWNorm, self).__init__(
344 | rank=2,
345 | filters=filters,
346 | kernel_size=kernel_size,
347 | strides=strides,
348 | padding=padding,
349 | data_format=data_format,
350 | dilation_rate=dilation_rate,
351 | activation=activation,
352 | use_scale=use_scale,
353 | use_bias=use_bias,
354 | kernel_initializer=kernel_initializer,
355 | scale_initializer=scale_initializer,
356 | bias_initializer=bias_initializer,
357 | kernel_regularizer=kernel_regularizer,
358 | scale_regularizer=scale_regularizer,
359 | bias_regularizer=bias_regularizer,
360 | activity_regularizer=activity_regularizer,
361 | kernel_constraint=kernel_constraint,
362 | scale_constraint=scale_constraint,
363 | bias_constraint=bias_constraint,
364 | trainable=trainable,
365 | name=name, **kwargs)
366 |
367 |
368 | def conv2d_wnorm(inputs,
369 | filters,
370 | kernel_size,
371 | strides=(1, 1),
372 | padding='valid',
373 | data_format='channels_last',
374 | dilation_rate=(1, 1),
375 | activation=None,
376 | use_scale=True,
377 | use_bias=True,
378 | kernel_initializer=None,
379 | scale_initializer=None,
380 | bias_initializer=init_ops.zeros_initializer(),
381 | kernel_regularizer=None,
382 | scale_regularizer=None,
383 | bias_regularizer=None,
384 | activity_regularizer=None,
385 | kernel_constraint=None,
386 | scale_constraint=None,
387 | bias_constraint=None,
388 | trainable=True,
389 | name=None,
390 | reuse=None):
391 |
392 | layer = Conv2DWNorm(
393 | filters=filters,
394 | kernel_size=kernel_size,
395 | strides=strides,
396 | padding=padding,
397 | data_format=data_format,
398 | dilation_rate=dilation_rate,
399 | activation=activation,
400 | use_scale=use_scale,
401 | use_bias=use_bias,
402 | kernel_initializer=kernel_initializer,
403 | scale_initializer=scale_initializer,
404 | bias_initializer=bias_initializer,
405 | kernel_regularizer=kernel_regularizer,
406 | scale_regularizer=scale_regularizer,
407 | bias_regularizer=bias_regularizer,
408 | activity_regularizer=activity_regularizer,
409 | kernel_constraint=kernel_constraint,
410 | scale_constraint=scale_constraint,
411 | bias_constraint=bias_constraint,
412 | trainable=trainable,
413 | name=name,
414 | dtype=inputs.dtype.base_dtype,
415 | _reuse=reuse,
416 | _scope=name)
417 | return layer.apply(inputs)
418 |
class Conv2DTransposeWNorm(Conv2DWNorm):
    """Transposed 2-D convolution layer with weight normalization.

    The kernel variable has shape ``kernel_size + (filters, input_dim)``.
    At call time it is L2-normalized over every axis except axis 2 (the
    ``filters`` axis), i.e. each output filter gets unit norm.  When
    ``use_scale`` is true it is then multiplied by a trainable per-filter
    ``scale`` vector.  The result is used as the kernel of
    ``nn.conv2d_transpose``.  An optional ``bias`` and ``activation`` follow.

    Constructor arguments are forwarded unchanged to ``Conv2DWNorm``; the
    layer only accepts rank-4 inputs.
    """

    def __init__(self, filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format='channels_last',
                 activation=None,
                 use_scale=True,
                 use_bias=True,
                 kernel_initializer=None,
                 scale_initializer=None,
                 bias_initializer=init_ops.zeros_initializer(),
                 kernel_regularizer=None,
                 scale_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 scale_constraint=None,
                 bias_constraint=None,
                 trainable=True,
                 name=None,
                 **kwargs):
        super(Conv2DTransposeWNorm, self).__init__(
            filters,
            kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_scale=use_scale,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            scale_initializer=scale_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            scale_regularizer=scale_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            scale_constraint=scale_constraint,
            bias_constraint=bias_constraint,
            trainable=trainable,
            name=name,
            **kwargs)
        self.input_spec = base.InputSpec(ndim=4)

    def build(self, input_shape):
        """Create the ``kernel``, ``scale`` and ``bias`` variables.

        Args:
            input_shape: shape of the input, as a ``TensorShape`` or anything
                ``TensorShape`` accepts (e.g. a list of ints / ``None``).

        Raises:
            ValueError: if the input is not rank 4, or its channel dimension
                is not statically known.
        """
        input_shape = tensor_shape.as_shape(input_shape)
        # ``ndims`` is None for unknown rank, which is rejected here as well.
        if input_shape.ndims != 4:
            raise ValueError('Inputs should have rank 4. '
                             'Received input shape: ' + str(input_shape))
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        # Indexing a TensorShape yields a Dimension object, which is never
        # ``None`` itself; an unknown size shows up as ``.value is None``.
        input_dim = input_shape[channel_axis].value
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        self.input_spec = base.InputSpec(ndim=4, axes={channel_axis: input_dim})
        # Transposed-conv kernels are laid out as (h, w, out_channels, in_channels).
        kernel_shape = self.kernel_size + (self.filters, input_dim)

        self.kernel = self.add_variable(name='kernel',
                                        shape=kernel_shape,
                                        initializer=self.kernel_initializer,
                                        regularizer=self.kernel_regularizer,
                                        constraint=self.kernel_constraint,
                                        trainable=True,
                                        dtype=self.dtype)

        if self.use_scale:
            self.scale = self.add_variable(name='scale',
                                           shape=(self.filters,),
                                           initializer=self.scale_initializer,
                                           regularizer=self.scale_regularizer,
                                           constraint=self.scale_constraint,
                                           trainable=True,
                                           dtype=self.dtype)
        else:
            self.scale = None

        if self.use_bias:
            self.bias = self.add_variable(name='bias',
                                          shape=(self.filters,),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          trainable=True,
                                          dtype=self.dtype)
        else:
            self.bias = None
        self.built = True

    def call(self, inputs):
        """Apply the weight-normalized transposed convolution to ``inputs``.

        Returns:
            The convolved tensor, with ``bias`` added and ``activation``
            applied when those are configured.
        """
        inputs_shape = array_ops.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        height, width = inputs_shape[h_axis], inputs_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Infer the dynamic output shape:
        out_height = utils.deconv_output_length(height, kernel_h,
                                                self.padding, stride_h)
        out_width = utils.deconv_output_length(width, kernel_w,
                                               self.padding, stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
            strides = (1, 1, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
            strides = (1, stride_h, stride_w, 1)

        output_shape_tensor = array_ops.stack(output_shape)

        # Weight normalization: the ``filters`` axis of the kernel is axis 2
        # (see ``build``), so normalize over the remaining axes 0, 1 and 3.
        kernel_norm = nn.l2_normalize(self.kernel, [0, 1, 3])
        if self.use_scale:
            # Broadcast the per-filter scale along kernel axis 2.
            kernel_norm = array_ops.reshape(
                self.scale, [1, 1, self.filters, 1]) * kernel_norm

        outputs = nn.conv2d_transpose(
            inputs,
            kernel_norm,
            output_shape_tensor,
            strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format, ndim=4))

        if context.in_graph_mode():
            # Infer the static output shape:
            out_shape = inputs.get_shape().as_list()
            out_shape[c_axis] = self.filters
            out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
                                                           kernel_h,
                                                           self.padding,
                                                           stride_h)
            out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
                                                           kernel_w,
                                                           self.padding,
                                                           stride_w)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def _compute_output_shape(self, input_shape):
        """Return the static output ``TensorShape`` for ``input_shape``."""
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        output_shape = list(input_shape)
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        output_shape[c_axis] = self.filters
        output_shape[h_axis] = utils.deconv_output_length(
            output_shape[h_axis], kernel_h, self.padding, stride_h)
        output_shape[w_axis] = utils.deconv_output_length(
            output_shape[w_axis], kernel_w, self.padding, stride_w)
        return tensor_shape.TensorShape(output_shape)
596 |
597 |
def conv2d_transpose_wnorm(
        inputs,
        filters,
        kernel_size,
        strides=(1, 1),
        padding='valid',
        data_format='channels_last',
        activation=None,
        use_scale=True,
        use_bias=True,
        kernel_initializer=None,
        scale_initializer=None,
        bias_initializer=init_ops.zeros_initializer(),
        kernel_regularizer=None,
        scale_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        scale_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        reuse=None):
    """Functional interface for the ``Conv2DTransposeWNorm`` layer.

    A fresh ``Conv2DTransposeWNorm`` is constructed from the given
    hyper-parameters and applied to ``inputs``.  The layer's dtype is the
    base dtype of ``inputs``.  ``name`` is used both as the layer name and as
    its scope, and ``reuse`` is forwarded as the layer's ``_reuse`` flag.

    Returns:
        The tensor produced by applying the layer to ``inputs``.
    """
    layer_config = dict(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
        data_format=data_format,
        activation=activation,
        use_scale=use_scale,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        scale_initializer=scale_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        scale_regularizer=scale_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        scale_constraint=scale_constraint,
        bias_constraint=bias_constraint,
        trainable=trainable,
        name=name,
    )
    transposed_conv = Conv2DTransposeWNorm(
        dtype=inputs.dtype.base_dtype,
        _reuse=reuse,
        _scope=name,
        **layer_config)
    return transposed_conv.apply(inputs)
646 |
--------------------------------------------------------------------------------
/results/svhn_cvaegan_epoch_0050_batch_73257.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tatsy/tf-generative/5d7fe9e8a84d0d6f82553fc1eb32c4fdadd0d1b2/results/svhn_cvaegan_epoch_0050_batch_73257.png
--------------------------------------------------------------------------------
/results/svhn_dcgan_epoch_0050_batch_73257.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tatsy/tf-generative/5d7fe9e8a84d0d6f82553fc1eb32c4fdadd0d1b2/results/svhn_dcgan_epoch_0050_batch_73257.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
5 |
6 | import numpy as np
7 | import matplotlib
8 | matplotlib.use('Agg')
9 |
10 | import tensorflow as tf
11 |
12 | from models import *
13 | from datasets import load_data, mnist, svhn
14 |
# Registry mapping each accepted ``--model`` value to its model class.
models = dict(
    vae=VAE,
    dcgan=DCGAN,
    improved=ImprovedGAN,
    resnet=ResNetGAN,
    began=BEGAN,
    wgan=WGAN,
    lsgan=LSGAN,
    cvae=CVAE,
    cvaegan=CVAEGAN,
)
26 |
def main(_):
    """Parse command-line options, then build and train the selected model.

    The positional argument (the argv handed over by ``tf.app.run``) is
    ignored; options are parsed here with ``argparse``.

    Raises:
        ValueError: if ``--model`` is not a key of the ``models`` registry.
    """
    # Parsing arguments
    parser = argparse.ArgumentParser(description='Training GANs or VAEs')
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--datasize', type=int, default=-1)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--batchsize', type=int, default=50)
    parser.add_argument('--output', default='output')
    parser.add_argument('--zdims', type=int, default=256)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--testmode', action='store_true')

    args = parser.parse_args()

    # Validate the model name before the (potentially slow) dataset load.
    if args.model not in models:
        raise ValueError('Unknown model: %s (available: %s)'
                         % (args.model, ', '.join(sorted(models))))

    # Select GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Make output directory (and any missing parents) if it does not exist
    if not os.path.isdir(args.output):
        os.makedirs(args.output)

    # Load datasets
    if args.dataset == 'mnist':
        datasets = mnist.load_data()
    elif args.dataset == 'svhn':
        datasets = svhn.load_data()
    else:
        datasets = load_data(args.dataset, args.datasize)

    # The graph-level seed only affects ops created after it is set, so set
    # it before constructing the model (whose constructor may create ops).
    tf.set_random_seed(12345)

    # Construct model
    model = models[args.model](
        batchsize=args.batchsize,
        input_shape=datasets.shape[1:],
        # Datasets without attribute labels fall back to None.
        attr_names=getattr(datasets, 'attr_names', None),
        z_dims=args.zdims,
        output=args.output,
        resume=args.resume
    )

    if args.testmode:
        model.test_mode = True

    # Training loop: rescale images to [-1, 1], assuming they are in [0, 1].
    datasets.images = datasets.images.astype('float32') * 2.0 - 1.0
    model.main_loop(datasets,
                    epochs=args.epoch)
80 |
if __name__ == '__main__':
    # Delegate to TensorFlow's application runner, which calls ``main``.
    tf.app.run(main=main)
--------------------------------------------------------------------------------