├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── create_celeba.py ├── datasets.py ├── mnist.py └── svhn.py ├── models ├── __init__.py ├── base.py ├── began.py ├── cvae.py ├── cvaegan.py ├── dcgan.py ├── improved_gan.py ├── lsgan.py ├── resnet_gan.py ├── utils.py ├── vae.py ├── wgan.py └── wnorm.py ├── results ├── svhn_cvaegan_epoch_0050_batch_73257.png └── svhn_dcgan_epoch_0050_batch_73257.png └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | output/* 2 | datasets/files/* 3 | 4 | __pycache__ 5 | 6 | # Compiled source # 7 | ################### 8 | *.com 9 | *.class 10 | *.dll 11 | *.exe 12 | *.o 13 | *.so 14 | *.pyc 15 | 16 | # Packages # 17 | ############ 18 | # it's better to unpack these files and commit the raw source 19 | # git has its own built in compression methods 20 | *.7z 21 | *.dmg 22 | *.gz 23 | *.iso 24 | *.rar 25 | #*.tar 26 | *.zip 27 | 28 | # Logs and databases # 29 | ###################### 30 | *.log 31 | *.sqlite 32 | 33 | # OS generated files # 34 | ###################### 35 | .DS_Store 36 | ehthumbs.db 37 | Icon 38 | Thumbs.db 39 | .tmtags 40 | .idea 41 | tags 42 | vendor.tags 43 | tmtagsHistory 44 | *.sublime-project 45 | *.sublime-workspace 46 | .bundle 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | TensorFlow VAEs and GANs 2 | === 3 | 4 | TensorFlow implementation of various deep generative networks such as VAE and GAN. 5 | 6 | ## Models 7 | 8 | ### Standard models 9 | 10 | * Variational autoencoder (VAE) [Kingma et al. 2013] 11 | * Generative adversarial network (GAN or DCGAN) [Goodfellow et al. 2014] 12 | 13 | 14 | 15 | 16 | ### Conditional models 17 | 18 | * Conditional variational autoencoder [Kingma et al. 2014] 19 | * CVAE-GAN [Bao et al. 2017] 20 | 21 | ## Usage 22 | 23 | ### Prepare datasets 24 | 25 | #### MNIST and SVHN 26 | 27 | MNIST and SVHN datasets are automatically downloaded from their websites. 28 | 29 | #### CelebA 30 | 31 | First, download ``img_align_celeba.zip`` and ``list_attr_celeba.txt`` from CelebA [webpage](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). 32 | Then, place these files to ``datasets`` and run ``create_database.py`` on ``databsets`` directory. 33 | 34 | ### Training 35 | 36 | ```shell 37 | # Both standard and conditional models are available! 38 | python train.py --model=dcgan --epoch=200 --batchsize=100 --output=output 39 | ``` 40 | 41 | TensorBoard is also available with the following script. 42 | 43 | ```shell 44 | tensorboard --logdir="output/dcgan/log" 45 | ``` 46 | 47 | ### Results 48 | 49 | #### DCGAN (for SVHN 50 epochs) 50 | 51 | 52 | 53 | #### CVAE-GAN (for SVHN 50 epochs) 54 | 55 | 56 | 57 | ## References 58 | 59 | * Kingma et al., "Auto-Encoding Variational Bayes", arXiv preprint 2013. 60 | * Goodfellow et al., "Generative adversarial nets", NIPS 2014. 61 | 62 | 63 | 64 | * Kingma et al., "Semi-supervised learning with deep generative models", NIPS 2014. 65 | * Bao et al., "CVAE-GAN: Fine-Grained Image Generation through Asymmetric Training", arXiv preprint 2017. 66 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import load_data, Dataset, ConditionalDataset, PairwiseDataset 2 | -------------------------------------------------------------------------------- /datasets/create_celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import zipfile 5 | 6 | import numpy as np 7 | import h5py 8 | 9 | import requests 10 | from PIL import Image 11 | 12 | google_drive_prefix = "https://docs.google.com/uc?export=download" 13 | image_url = 'https://drive.google.com/open?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM' 14 | attr_url = 'https://drive.google.com/open?id=0B7EVK8r0v71pblRyaVFSWGxPY0U' 15 | 16 | target_dir = os.path.join(os.path.dirname(__file__), 'files') 17 | outfile = os.path.join(target_dir, 'celebA.hdf5') 18 | image_file = 'img_align_celeba.zip' 19 | attr_file = 'list_attr_celeba.txt' 20 | 21 | def get_confirm_token(response): 22 | for key, value in response.cookies.items(): 23 | if key.startswith('download_warning'): 24 | return value 25 | 26 | return None 27 | 28 | def save_response_content(response, destination): 29 | CHUNK_SIZE = 32768 30 | PROGBAR_WIDTH = 50 31 | 32 | with open(destination, "wb") as f: 33 | dl = 0 34 | for chunk in response.iter_content(CHUNK_SIZE): 35 | if chunk: 36 | dl += len(chunk) 37 | f.write(chunk) 38 | 39 | mb = dl / 1.0e6 40 | sys.stdout.write('\r%.2f MB downloaded...' % mb) 41 | sys.stdout.flush() 42 | 43 | sys.stdout.write('\nFinish!\n') 44 | sys.stdout.flush() 45 | 46 | def download_from_google_drive(url, dest): 47 | pat = re.compile('id=([a-zA-Z0-9]+)') 48 | mat = pat.search(url) 49 | if mat is None: 50 | raise Exception('Invalide url:', url) 51 | 52 | idx = mat.group(1) 53 | 54 | session = requests.Session() 55 | 56 | response = session.get(google_drive_prefix, params={'id': idx}, stream=True) 57 | token = get_confirm_token(response) 58 | 59 | if token: 60 | params = {'id': idx, 'confirm': token} 61 | response = session.get(google_drive_prefix, params=params, stream=True) 62 | 63 | print('Downloading:', url) 64 | save_response_content(response, dest) 65 | 66 | def main(): 67 | # Download image ZIP 68 | if os.path.exists(image_file): 69 | print('Image ZIP file exists. Skip downloading.') 70 | else: 71 | download_from_google_drive(image_url, image_file) 72 | 73 | # Download attribute file 74 | if os.path.exists(attr_file): 75 | print('Attribute file exists. Skip downloading.') 76 | else: 77 | download_from_google_drive(attr_url, attr_file) 78 | 79 | # Create folder 80 | if not os.path.isdir(target_dir): 81 | os.mkdir(target_dir) 82 | 83 | # Parse labels 84 | with open(attr_file, 'r') as lines: 85 | lines = [l.strip() for l in lines] 86 | num_images = int(lines[0]) 87 | 88 | label_names = re.split('\s+', lines[1]) 89 | label_names = np.array(label_names, dtype=object) 90 | num_labels = len(label_names) 91 | 92 | lines = lines[2:] 93 | labels = np.ndarray((num_images, num_labels), dtype='uint8') 94 | for i in range(num_images): 95 | label = [int(l) for l in re.split('\s+', lines[i])[1:]] 96 | label = np.maximum(0, label).astype(np.uint8) 97 | labels[i] = label 98 | 99 | ## Parse images 100 | with zipfile.ZipFile(image_file, 'r', zipfile.ZIP_DEFLATED) as zf: 101 | image_files = [f for f in zf.namelist()] 102 | image_files = sorted(image_files) 103 | image_files = list(filter(lambda f: f.endswith('.jpg'), image_files)) 104 | 105 | num_images = len(image_files) 106 | print('%d images' % (num_images)) 107 | 108 | image_data = np.ndarray((num_images, 64, 64, 3), dtype='uint8') 109 | for i, f in enumerate(image_files): 110 | image = Image.open(zf.open(f, 'r')).resize((64, 78), Image.ANTIALIAS).crop((0, 7, 64, 64 + 7)) 111 | image = np.asarray(image, dtype='uint8') 112 | image_data[i] = image 113 | print('%d / %d' % (i + 1, num_images), end='\r', flush=True) 114 | 115 | # Create HDF5 file 116 | h5 = h5py.File(outfile, 'w') 117 | string_dt = h5py.special_dtype(vlen=str) 118 | dset = h5.create_dataset('images', data=image_data, dtype='uint8') 119 | dset = h5.create_dataset('attr_names', data=label_names, dtype=string_dt) 120 | dset = h5.create_dataset('attrs', data=labels, dtype='uint8') 121 | 122 | h5.flush() 123 | h5.close() 124 | 125 | # Delete files 126 | os.remove(image_file) 127 | os.remove(attr_file) 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | 4 | class Dataset(object): 5 | def __init__(self): 6 | self.images = None 7 | 8 | def __len__(self): 9 | return len(self.images) 10 | 11 | def _get_shape(self): 12 | return self.images.shape 13 | 14 | shape = property(_get_shape) 15 | 16 | class ConditionalDataset(Dataset): 17 | def __init__(self): 18 | super(ConditionalDataset, self).__init__() 19 | self.attrs = None 20 | self.attr_names = None 21 | 22 | class PairwiseDataset(object): 23 | def __init__(self, x_data, y_data): 24 | assert x_data.shape[1] == y_data.shape[1] 25 | assert x_data.shape[2] == y_data.shape[2] 26 | assert x_data.shape[3] == 1 or y_data.shape[3] == 1 or \ 27 | x_data.shape[3] == y_data.shape[3] 28 | 29 | if x_data.shape[3] != y_data.shape[3]: 30 | d = max(x_data.shape[3], y_data.shape[3]) 31 | if x_data.shape[3] != d: 32 | x_data = np.tile(x_data, [1, 1, 1, d]) 33 | if y_data.shape[3] != d: 34 | y_Data = np.tile(y_data, [1, 1, 1, d]) 35 | 36 | x_len = len(x_data) 37 | y_len = len(y_data) 38 | l = min(x_len, y_len) 39 | 40 | self.x_data = x_data[:l] 41 | self.y_data = y_data[:l] 42 | 43 | def __len__(self): 44 | return len(self.x_data) 45 | 46 | def _get_shape(self): 47 | return self.x_data.shape 48 | 49 | shape = property(_get_shape) 50 | 51 | def load_data(filename, size=-1): 52 | f = h5py.File(filename) 53 | 54 | dset = ConditionalDataset() 55 | dset.images = np.asarray(f['images'], 'float32') / 255.0 56 | dset.attrs = np.asarray(f['attrs'], 'float32') 57 | dset.attr_names = np.asarray(f['attr_names']) 58 | 59 | if size > 0: 60 | dset.images = dset.images[:size] 61 | dset.attrs = dset.attrs[:size] 62 | 63 | return dset 64 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import gzip 4 | import struct 5 | import requests 6 | 7 | import numpy as np 8 | 9 | import tensorflow as tf 10 | 11 | from .datasets import ConditionalDataset 12 | url = 'http://yann.lecun.com/exdb/mnist/' 13 | x_train_file = 'train-images-idx3-ubyte.gz' 14 | y_train_file = 'train-labels-idx1-ubyte.gz' 15 | x_test_file = 't10k-images-idx3-ubyte.gz' 16 | y_test_file = 't10k-labels-idx1-ubyte.gz' 17 | 18 | curdir = os.path.abspath(os.path.dirname(__file__)) 19 | outdir = os.path.join(curdir, 'files', 'mnist') 20 | 21 | CHUNK_SIZE = 32768 22 | 23 | def download_mnist(): 24 | if not os.path.exists(outdir): 25 | os.makedirs(outdir) 26 | 27 | # Download files 28 | files = [x_train_file, y_train_file, x_test_file, y_test_file] 29 | for f in files: 30 | session = requests.Session() 31 | response = session.get(os.path.join(url, f), stream=True) 32 | print('Downloading: %s' % (os.path.join(url, f))) 33 | with open(os.path.join(outdir, f), 'wb') as fp: 34 | dl = 0 35 | for chunk in response.iter_content(CHUNK_SIZE): 36 | if chunk: 37 | dl += len(chunk) 38 | fp.write(chunk) 39 | 40 | mb = dl / 1.0e6 41 | sys.stdout.write('\r%.2f MB downloaded...' % (mb)) 42 | sys.stdout.flush() 43 | 44 | sys.stdout.write('\nFinish!\n') 45 | sys.stdout.flush() 46 | 47 | def load_images(filename): 48 | with gzip.GzipFile(filename, 'rb') as fp: 49 | # Magic number 50 | magic = struct.unpack('>I', fp.read(4))[0] 51 | 52 | # item sizes 53 | n, rows, cols = struct.unpack('>III', fp.read(4 * 3)) 54 | 55 | # Load items 56 | data = np.ndarray((n, rows, cols), dtype=np.uint8) 57 | for i in range(n): 58 | sub = struct.unpack('B' * rows * cols, fp.read(rows * cols)) 59 | data[i] = np.asarray(sub).reshape((rows, cols)) 60 | 61 | return data 62 | 63 | def load_labels(filename): 64 | with gzip.GzipFile(filename, 'rb') as fp: 65 | # Magic number 66 | magic = struct.unpack('>I', fp.read(4)) 67 | 68 | # item sizes 69 | n= struct.unpack('>I', fp.read(4))[0] 70 | 71 | # Load items 72 | data = np.zeros((n, 10), dtype=np.uint8) 73 | for i in range(n): 74 | b = struct.unpack('>B', fp.read(1))[0] 75 | data[i, b] = 1 76 | 77 | return data 78 | 79 | def load_data(): 80 | if not os.path.exists(outdir): 81 | download_mnist() 82 | 83 | x_train = load_images(os.path.join(outdir, x_train_file)) 84 | y_train = load_labels(os.path.join(outdir, y_train_file)) 85 | 86 | x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0) 87 | x_train = (x_train[:, :, :, np.newaxis] / 255.0).astype('float32') 88 | y_train = y_train.astype('float32') 89 | 90 | datasets = ConditionalDataset() 91 | datasets.images = x_train 92 | datasets.attrs = y_train 93 | datasets.attr_names = [str(i) for i in range(10)] 94 | 95 | return datasets 96 | -------------------------------------------------------------------------------- /datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | 5 | import numpy as np 6 | import scipy as sp 7 | import scipy.io 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | import tensorflow as tf 12 | 13 | from .datasets import ConditionalDataset 14 | 15 | url = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat' 16 | curdir = os.path.abspath(os.path.dirname(__file__)) 17 | outdir = os.path.join(curdir, 'files', 'svhn') 18 | outfile = os.path.join(outdir, 'svhn.mat') 19 | 20 | CHUNK_SIZE = 32768 21 | 22 | def download_svhn(): 23 | if not os.path.exists(outdir): 24 | os.makedirs(outdir) 25 | 26 | session = requests.Session() 27 | response = session.get(url, stream=True) 28 | print('Downloading: %s' % (url)) 29 | with open(outfile, 'wb') as fp: 30 | dl = 0 31 | for chunk in response.iter_content(CHUNK_SIZE): 32 | if chunk: 33 | dl += len(chunk) 34 | fp.write(chunk) 35 | 36 | mb = dl / 1.0e6 37 | sys.stdout.write('\r%.2f MB downloaded...' % (mb)) 38 | sys.stdout.flush() 39 | 40 | sys.stdout.write('\nFinish!\n') 41 | sys.stdout.flush() 42 | 43 | def load_data(): 44 | if not os.path.exists(outfile): 45 | download_svhn() 46 | 47 | mat = sp.io.loadmat(outfile) 48 | x_train = mat['X'] 49 | 50 | x_train = np.transpose(x_train, axes=[3, 0, 1, 2]) 51 | x_train = (x_train / 255.0).astype('float32') 52 | 53 | indices = mat['y'] 54 | indices = np.squeeze(indices) 55 | indices[indices == 10] = 0 56 | y_train = np.zeros((len(indices), 10)) 57 | y_train[np.arange(len(indices)), indices] = 1 58 | y_train = y_train.astype('float32') 59 | 60 | datasets = ConditionalDataset() 61 | datasets.images = x_train 62 | datasets.attrs = y_train 63 | datasets.attr_names = [str(i) for i in range(10)] 64 | 65 | return datasets 66 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseModel, CondBaseModel 2 | 3 | from .vae import VAE 4 | from .dcgan import DCGAN 5 | from .improved_gan import ImprovedGAN 6 | from .resnet_gan import ResNetGAN 7 | from .began import BEGAN 8 | from .wgan import WGAN 9 | from .lsgan import LSGAN 10 | 11 | from .cvae import CVAE 12 | from .cvaegan import CVAEGAN -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import matplotlib.gridspec as gridspec 12 | 13 | import tensorflow as tf 14 | 15 | from abc import ABCMeta, abstractmethod 16 | from .utils import * 17 | 18 | class BaseModel(metaclass=ABCMeta): 19 | """ 20 | Base class for non-conditional generative networks 21 | """ 22 | 23 | def __init__(self, **kwargs): 24 | """ 25 | Initialization 26 | """ 27 | if 'name' not in kwargs: 28 | raise Exception('Please specify model name!') 29 | self.name = kwargs['name'] 30 | 31 | if 'batchsize' not in kwargs: 32 | raise Exception('Please specify batchsize!') 33 | self.batchsize = kwargs['batchsize'] 34 | 35 | if 'input_shape' not in kwargs: 36 | raise Exception('Please specify input shape!') 37 | 38 | self.check_input_shape(kwargs['input_shape']) 39 | self.input_shape = kwargs['input_shape'] 40 | 41 | if 'output' not in kwargs: 42 | self.output = 'output' 43 | else: 44 | self.output = kwargs['output'] 45 | 46 | self.resume = kwargs['resume'] 47 | 48 | self.sess = tf.Session() 49 | self.writer = None 50 | self.saver = None 51 | self.summary = None 52 | 53 | self.test_size = 10 54 | self.test_data = None 55 | 56 | self.test_mode = False 57 | 58 | def check_input_shape(self, input_shape): 59 | # Check for CelebA 60 | if input_shape == (64, 64, 3): 61 | return 62 | 63 | # Check for MNIST (size modified) 64 | if input_shape == (32, 32, 1): 65 | return 66 | 67 | # Check for Cifar10, 100 etc 68 | if input_shape == (32, 32, 3): 69 | return 70 | 71 | errmsg = 'Input size should be 32 x 32 or 64 x 64!' 72 | raise Exception(errmsg) 73 | 74 | def main_loop(self, datasets, epochs=100): 75 | """ 76 | Main learning loop 77 | """ 78 | # Create output directories if not exist 79 | out_dir = os.path.join(self.output, self.name) 80 | if not os.path.isdir(out_dir): 81 | os.makedirs(out_dir) 82 | 83 | res_out_dir = os.path.join(out_dir, 'results') 84 | if not os.path.isdir(res_out_dir): 85 | os.makedirs(res_out_dir) 86 | 87 | chk_out_dir = os.path.join(out_dir, 'checkpoints') 88 | if not os.path.isdir(chk_out_dir): 89 | os.makedirs(chk_out_dir) 90 | 91 | time_str = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 92 | log_out_dir = os.path.join(out_dir, 'log', time_str) 93 | if not os.path.isdir(log_out_dir): 94 | os.makedirs(log_out_dir) 95 | 96 | # Make test data 97 | self.make_test_data() 98 | 99 | # Start training 100 | with self.sess.as_default(): 101 | current_epoch = tf.Variable(0, name='current_epoch', dtype=tf.int32) 102 | current_batch = tf.Variable(0, name='current_batch', dtype=tf.int32) 103 | 104 | # Initialize global variables 105 | self.saver = tf.train.Saver() 106 | if self.resume is not None: 107 | print('Resume training: %s' % self.resume) 108 | self.load_model(self.resume) 109 | else: 110 | self.sess.run(tf.global_variables_initializer()) 111 | self.sess.run(tf.local_variables_initializer()) 112 | 113 | # Update rule 114 | num_data = len(datasets) 115 | update_epoch = current_epoch.assign(current_epoch + 1) 116 | update_batch = current_batch.assign(tf.mod(tf.minimum(current_batch + self.batchsize, num_data), num_data)) 117 | 118 | self.writer = tf.summary.FileWriter(log_out_dir, self.sess.graph) 119 | self.sess.graph.finalize() 120 | 121 | print('\n\n--- START TRAINING ---\n') 122 | for e in range(current_epoch.eval(), epochs): 123 | perm = np.random.permutation(num_data) 124 | start_time = time.time() 125 | for b in range(current_batch.eval(), num_data, self.batchsize): 126 | # Update batch index 127 | self.sess.run(update_batch) 128 | 129 | # Check batch size 130 | bsize = min(self.batchsize, num_data - b) 131 | indx = perm[b:b+bsize] 132 | if bsize < self.batchsize: 133 | break 134 | 135 | # Get batch and train on it 136 | x_batch = self.make_batch(datasets, indx) 137 | losses = self.train_on_batch(x_batch, e * num_data + (b + bsize)) 138 | 139 | # Print current status 140 | elapsed_time = time.time() - start_time 141 | eta = elapsed_time / (b + bsize) * (num_data - (b + bsize)) 142 | ratio = 100.0 * (b + bsize) / num_data 143 | print('Epoch #%d, Batch: %d / %d (%6.2f %%) ETA: %s' % \ 144 | (e + 1, b + bsize, num_data, ratio, time_format(eta))) 145 | 146 | for i, (k, v) in enumerate(losses): 147 | text = '%s = %8.6f' % (k, v) 148 | print(' %25s' % (text), end='') 149 | if (i + 1) % 3 == 0: 150 | print('') 151 | 152 | print('\n') 153 | sys.stdout.flush() 154 | 155 | # Save generated images 156 | save_period = 10000 157 | if b != 0 and ((b // save_period != (b + bsize) // save_period) or ((b + bsize) == num_data)): 158 | outfile = os.path.join(res_out_dir, 'epoch_%04d_batch_%d.png' % (e + 1, b + bsize)) 159 | self.save_images(outfile) 160 | outfile = os.path.join(chk_out_dir, 'epoch_%04d' % (e + 1)) 161 | self.save_model(outfile) 162 | 163 | if self.test_mode: 164 | print('\nFinish testing: %s' % self.name) 165 | return 166 | 167 | print('') 168 | self.sess.run(update_epoch) 169 | 170 | def make_batch(self, datasets, indx): 171 | """ 172 | Get batch from datasets 173 | """ 174 | return datasets.images[indx] 175 | 176 | def save_images(self, filename): 177 | """ 178 | Save images generated from random sample numbers 179 | """ 180 | imgs = self.predict(self.test_data) * 0.5 + 0.5 181 | imgs = np.clip(imgs, 0.0, 1.0) 182 | if imgs.shape[3] == 1: 183 | imgs = np.squeeze(imgs, axis=(3,)) 184 | 185 | _, height, width, dims = imgs.shape 186 | 187 | margin = min(width, height) // 10 188 | figure = np.ones(((margin + height) * 10 + margin, (margin + width) * 10 + margin, dims), np.float32) 189 | 190 | for i in range(100): 191 | row = i // 10 192 | col = i % 10 193 | 194 | y = margin + (margin + height) * row 195 | x = margin + (margin + width) * col 196 | figure[y:y+height, x:x+width, :] = imgs[i, :, :, :] 197 | 198 | figure = Image.fromarray((figure * 255.0).astype(np.uint8)) 199 | figure.save(filename) 200 | 201 | def save_model(self, model_file): 202 | self.saver.save(self.sess, model_file) 203 | 204 | def load_model(self, model_file): 205 | self.saver.restore(self.sess, model_file) 206 | 207 | @abstractmethod 208 | def make_test_data(self): 209 | """ 210 | Please override "make_test_data" method in the derived model! 211 | """ 212 | pass 213 | 214 | @abstractmethod 215 | def predict(self, z_sample): 216 | """ 217 | Please override "predict" method in the derived model! 218 | """ 219 | pass 220 | 221 | @abstractmethod 222 | def train_on_batch(self, x_batch, index): 223 | """ 224 | Please override "train_on_batch" method in the derived model! 225 | """ 226 | pass 227 | 228 | def image_tiling(self, images, rows, cols): 229 | n_images = rows * cols 230 | mg = max(self.input_shape[0], self.input_shape[1]) // 20 231 | pad_img = tf.pad(images, [[0, 0], [mg, mg], [mg, mg], [0, 0]], constant_values=1.0) 232 | img_arr = tf.split(pad_img, n_images, 0) 233 | 234 | rows = [] 235 | for i in range(self.test_size): 236 | rows.append(tf.concat(img_arr[i * cols: (i + 1) * cols], axis=2)) 237 | 238 | tile = tf.concat(rows, axis=1) 239 | return tile 240 | 241 | class CondBaseModel(BaseModel): 242 | def __init__(self, **kwargs): 243 | super(CondBaseModel, self).__init__(**kwargs) 244 | 245 | if 'attr_names' not in kwargs: 246 | raise Exception('Please specify attribute names (attr_names') 247 | self.attr_names = kwargs['attr_names'] 248 | self.num_attrs = len(self.attr_names) 249 | 250 | self.test_size = 10 251 | 252 | def make_batch(self, datasets, indx): 253 | images = datasets.images[indx] 254 | attrs = datasets.attrs[indx] 255 | return images, attrs 256 | 257 | def save_images(self, filename): 258 | assert self.attr_names is not None 259 | 260 | try: 261 | test_samples = self.test_data['z_test'] 262 | except KeyError as e: 263 | print('Key "z_test" must be provided in "make_test_data" method!') 264 | raise e 265 | 266 | try: 267 | test_attrs = self.test_data['c_test'] 268 | except KeyError as e: 269 | print('Key "c_test" must be provided in "make_test_data" method!') 270 | raise e 271 | 272 | imgs = self.predict([test_samples, test_attrs]) * 0.5 + 0.5 273 | imgs = np.clip(imgs, 0.0, 1.0) 274 | 275 | _, height, width, dims = imgs.shape 276 | 277 | margin = min(width, height) // 10 278 | figure = np.ones(((margin + height) * self.test_size + margin, (margin + width) * self.num_attrs + margin, dims), np.float32) 279 | 280 | for i in range(self.test_size * self.num_attrs): 281 | row = i // self.num_attrs 282 | col = i % self.num_attrs 283 | 284 | y = margin + (margin + height) * row 285 | x = margin + (margin + width) * col 286 | figure[y:y+height, x:x+width, :] = imgs[i, :, :, :] 287 | 288 | figure = Image.fromarray((figure * 255.0).astype(np.uint8)) 289 | figure.save(filename) 290 | -------------------------------------------------------------------------------- /models/began.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | 7 | def repelling_regularizer(x, batchsize): 8 | dims = x.get_shape()[1] 9 | S_i = tf.tile(x, [batchsize, 1]) 10 | S_j = tf.tile(x, [1, batchsize]) 11 | S_j = tf.reshape(S_j, [-1, dims]) 12 | S_i_T_S_j = tf.reduce_sum(tf.multiply(S_i, S_j), axis=1) 13 | S_i_norm2 = tf.reduce_sum(tf.square(S_i), axis=1) 14 | S_j_norm2 = tf.reduce_sum(tf.square(S_j), axis=1) 15 | f_PT = tf.square(S_i_T_S_j) / (tf.multiply(S_i_norm2, S_j_norm2) + 1.0e-8) 16 | f_PT = tf.reduce_sum(f_PT) / tf.cast(batchsize * (batchsize - 1), 'float32') 17 | return f_PT 18 | 19 | class Generator(object): 20 | def __init__(self, input_shape, z_dims): 21 | self.variables = None 22 | self.update_ops = None 23 | self.reuse = False 24 | self.input_shape = input_shape 25 | self.z_dims = z_dims 26 | self.name = 'generator' 27 | 28 | def __call__(self, inputs, training=True): 29 | with tf.variable_scope(self.name, reuse=self.reuse): 30 | with tf.variable_scope('deconv1'): 31 | w = self.input_shape[0] // (2 ** 3) 32 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 33 | x = tf.layers.conv2d_transpose(x, 512, (w, w), (1, 1), 'valid', 34 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 35 | x = tf.layers.batch_normalization(x, training=training) 36 | x = tf.nn.relu(x) 37 | 38 | with tf.variable_scope('deconv2'): 39 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', 40 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 41 | x = tf.layers.batch_normalization(x, training=training) 42 | x = tf.nn.relu(x) 43 | 44 | with tf.variable_scope('deconv3'): 45 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', 46 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 47 | x = tf.layers.batch_normalization(x, training=training) 48 | x = tf.nn.relu(x) 49 | 50 | with tf.variable_scope('deconv4'): 51 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', 52 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 53 | x = tf.layers.batch_normalization(x, training=training) 54 | x = tf.nn.relu(x) 55 | 56 | with tf.variable_scope('deconv5'): 57 | d = self.input_shape[2] 58 | x = tf.layers.conv2d(x, d, (5, 5), (1, 1), 'same', 59 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 60 | x = tf.tanh(x) 61 | 62 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 63 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name) 64 | self.reuse = True 65 | return x 66 | 67 | class Discriminator(object): 68 | def __init__(self, input_shape): 69 | self.input_shape = input_shape 70 | self.variables = None 71 | self.update_ops = None 72 | self.reuse = False 73 | self.name = 'discriminator' 74 | 75 | def __call__(self, inputs, training=True): 76 | with tf.variable_scope(self.name, reuse=self.reuse): 77 | with tf.variable_scope('conv1'): 78 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same', 79 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 80 | x = tf.layers.batch_normalization(x, training=training) 81 | x = lrelu(x) 82 | 83 | with tf.variable_scope('conv2'): 84 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same', 85 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 86 | x = tf.layers.batch_normalization(x, training=training) 87 | x = lrelu(x) 88 | 89 | with tf.variable_scope('conv3'): 90 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same', 91 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 92 | x = tf.layers.batch_normalization(x, training=training) 93 | x = lrelu(x) 94 | 95 | S = tf.contrib.layers.flatten(x) 96 | 97 | with tf.variable_scope('deconv1'): 98 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', 99 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 100 | x = tf.layers.batch_normalization(x, training=training) 101 | x = lrelu(x) 102 | 103 | with tf.variable_scope('deconv2'): 104 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', 105 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 106 | x = tf.layers.batch_normalization(x, training=training) 107 | x = lrelu(x) 108 | 109 | with tf.variable_scope('deconv3'): 110 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', 111 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 112 | x = tf.layers.batch_normalization(x, training=training) 113 | x = lrelu(x) 114 | 115 | with tf.variable_scope('deconv4'): 116 | d = self.input_shape[2] 117 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same', 118 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 119 | x = tf.tanh(x) 120 | 121 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 122 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name) 123 | self.reuse = True 124 | return x, S 125 | 126 | class BEGAN(BaseModel): 127 | def __init__(self, 128 | input_shape=(64, 64, 3), 129 | z_dims = 128, 130 | name='began', 131 | **kwargs 132 | ): 133 | super(BEGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 134 | 135 | self.z_dims = z_dims 136 | 137 | self.beta = 0.01 138 | self.boundary_equil = True 139 | self.margin = 0.1 140 | self.update_k_t = None 141 | self.k_t = tf.Variable(0.5, name='k_t') 142 | self.lambda_k = 1.0e-4 143 | self.gamma = 0.7 144 | 145 | self.gen_trainer = None 146 | self.dis_trainer = None 147 | self.gen_loss_D = None 148 | self.gen_loss_G = None 149 | self.dis_loss = None 150 | 151 | self.f_gen = None 152 | self.f_dis = None 153 | 154 | self.x_train = None 155 | self.z_D = None 156 | self.z_G = None 157 | 158 | self.z_test = None 159 | self.x_test = None 160 | self.x_tile = None 161 | 162 | self.train_op = None 163 | 164 | self.build_model() 165 | 166 | def train_on_batch(self, x_batch, index): 167 | batchsize = x_batch.shape[0] 168 | z_D = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 169 | z_G = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 170 | 171 | # Training 172 | _, g_loss, _, d_loss = self.sess.run( 173 | (self.train_op, self.gen_loss_G, self.gen_loss_D, self.dis_loss), 174 | feed_dict={ 175 | self.x_train: x_batch, 176 | self.z_G: z_G, 177 | self.z_D: z_D, 178 | } 179 | ) 180 | 181 | # Summary update 182 | if index // 1000 != (index - batchsize) // 1000: 183 | summary = self.sess.run( 184 | self.summary, 185 | feed_dict={ 186 | self.x_train: x_batch, 187 | self.z_D: z_D, 188 | self.z_G: z_G, 189 | self.z_test: self.test_data 190 | } 191 | ) 192 | self.writer.add_summary(summary, index) 193 | 194 | return [ 195 | ('g_loss', g_loss), 196 | ('d_loss', d_loss) 197 | ] 198 | 199 | def predict(self, z_samples): 200 | x_sample = self.sess.run( 201 | (self.x_test), 202 | feed_dict={self.z_test: z_samples} 203 | ) 204 | return x_sample 205 | 206 | def make_test_data(self): 207 | self.test_data = np.random.uniform(-1, 1, size=(self.test_size * self.test_size, self.z_dims)) 208 | 209 | def build_model(self): 210 | # Trainer 211 | self.f_dis = Discriminator(self.input_shape) 212 | self.f_gen = Generator(self.input_shape, self.z_dims) 213 | 214 | x_shape = (self.batchsize,) + self.input_shape 215 | self.x_train = tf.placeholder(tf.float32, shape=(self.batchsize,) + self.input_shape, name='x_train') 216 | 217 | z_shape = (self.batchsize, self.z_dims) 218 | self.z_D = tf.placeholder(tf.float32, shape=z_shape, name='z_D') 219 | self.z_G = tf.placeholder(tf.float32, shape=z_shape, name='z_G') 220 | 221 | x_f_D = self.f_gen(self.z_D) 222 | x_f_D_pred, _ = self.f_dis(x_f_D) 223 | 224 | x_f_G = self.f_gen(self.z_G) 225 | x_f_G_pred, S = self.f_dis(x_f_G) 226 | 227 | x_train_pred, _ = self.f_dis(self.x_train) 228 | 229 | f_PT = repelling_regularizer(S, self.batchsize) 230 | 231 | self.gen_loss_D = tf.losses.absolute_difference(x_f_D, x_f_D_pred) 232 | self.gen_loss_G = tf.losses.absolute_difference(x_f_G, x_f_G_pred) 233 | self.dis_loss = tf.losses.absolute_difference(self.x_train, x_train_pred) 234 | 235 | gen_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 236 | dis_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 237 | 238 | if self.boundary_equil: 239 | self.gen_trainer = gen_opt.minimize(self.gen_loss_G + self.beta * f_PT, var_list=self.f_gen.variables) 240 | self.dis_trainer = dis_opt.minimize(self.dis_loss - self.k_t * self.gen_loss_D, var_list=self.f_dis.variables) 241 | self.update_k_t = self.k_t.assign(tf.clip_by_value(self.k_t + self.lambda_k * (self.gamma * self.dis_loss - self.gen_loss_D), 0.0, 1.0)) 242 | 243 | with tf.control_dependencies([self.gen_trainer, self.dis_trainer, self.update_k_t] + \ 244 | self.f_dis.update_ops + self.f_gen.update_ops): 245 | self.train_op = tf.no_op(name='train') 246 | 247 | else: 248 | self.gen_trainer = gen_opt.minimize(self.gen_loss_G + self.beta * f_PT, var_list=self.f_gen.variables) 249 | self.dis_trainer = dis_opt.minimize(self.dis_loss - tf.maximum(0.0, self.margin - self.gen_loss_D), var_list=self.f_dis.variables) 250 | 251 | with tf.control_dependencies([self.gen_trainer, self.dis_trainer] + \ 252 | self.f_dis.update_ops + self.f_gen.update_ops): 253 | self.train_op = tf.no_op(name='train') 254 | 255 | # Predictor 256 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 257 | self.x_test = self.f_gen(self.z_test) 258 | self.x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 259 | 260 | tf.summary.image('x_real', image_cast(self.x_train), 10) 261 | tf.summary.image('x_real_rec', image_cast(x_train_pred), 10) 262 | tf.summary.image('x_fake', image_cast(x_f_G), 10) 263 | tf.summary.image('x_fake_rec', image_cast(x_f_G_pred), 10) 264 | tf.summary.image('x_tile', image_cast(self.x_tile), 1) 265 | tf.summary.scalar('gen_loss', self.gen_loss_G) 266 | tf.summary.scalar('dis_loss', self.dis_loss) 267 | tf.summary.scalar('k_t', self.k_t) 268 | 269 | self.summary = tf.summary.merge_all() 270 | -------------------------------------------------------------------------------- /models/cvae.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import CondBaseModel 5 | from .utils import * 6 | 7 | class Encoder(object): 8 | def __init__(self, input_shape, z_dims, num_attrs): 9 | self.variables = None 10 | self.reuse = False 11 | self.input_shape = input_shape 12 | self.z_dims = z_dims 13 | self.num_attrs = num_attrs 14 | 15 | def __call__(self, inputs, attrs, training=True): 16 | with tf.variable_scope('encoder', reuse=self.reuse): 17 | with tf.variable_scope('conv1'): 18 | a = tf.reshape(attrs, [-1, 1, 1, self.num_attrs]) 19 | a = tf.tile(a, [1, self.input_shape[0], self.input_shape[1], 1]) 20 | x = tf.concat([inputs, a], axis=-1) 21 | x = tf.layers.conv2d(x, 64, (5, 5), (2, 2), 'same') 22 | x = tf.layers.batch_normalization(x, training=training) 23 | x = tf.nn.relu(x) 24 | 25 | with tf.variable_scope('conv2'): 26 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 27 | x = tf.layers.batch_normalization(x, training=training) 28 | x = tf.nn.relu(x) 29 | 30 | with tf.variable_scope('conv3'): 31 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 32 | x = tf.layers.batch_normalization(x, training=training) 33 | x = tf.nn.relu(x) 34 | 35 | with tf.variable_scope('global_average'): 36 | x = tf.reduce_mean(x, axis=[1, 2]) 37 | 38 | with tf.variable_scope('fc1'): 39 | z_avg = tf.layers.dense(x, self.z_dims) 40 | z_log_var = tf.layers.dense(x, self.z_dims) 41 | 42 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder') 43 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='encoder') 44 | self.reuse = True 45 | 46 | return z_avg, z_log_var 47 | 48 | class Decoder(object): 49 | def __init__(self, input_shape): 50 | self.variables = None 51 | self.reuse = False 52 | self.input_shape = input_shape 53 | 54 | def __call__(self, inputs, attrs, training=True): 55 | with tf.variable_scope('decoder', reuse=self.reuse): 56 | with tf.variable_scope('fc1'): 57 | w = self.input_shape[0] // (2 ** 3) 58 | x = tf.concat([inputs, attrs], axis=-1) 59 | x = tf.layers.dense(x, w * w * 256) 60 | x = tf.layers.batch_normalization(x, training=training) 61 | x = tf.nn.relu(x) 62 | x = tf.reshape(x, [-1, w, w, 256]) 63 | 64 | with tf.variable_scope('conv1'): 65 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same') 66 | x = tf.layers.batch_normalization(x, training=training) 67 | x = lrelu(x) 68 | 69 | with tf.variable_scope('conv2'): 70 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same') 71 | x = tf.layers.batch_normalization(x, training=training) 72 | x = lrelu(x) 73 | 74 | with tf.variable_scope('conv3'): 75 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same') 76 | x = tf.layers.batch_normalization(x, training=training) 77 | x = lrelu(x) 78 | 79 | with tf.variable_scope('conv4'): 80 | d = self.input_shape[2] 81 | x = tf.layers.conv2d_transpose(x, d, (3, 3), (1, 1), 'same') 82 | x = tf.tanh(x) 83 | 84 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder') 85 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='decoder') 86 | self.reuse = True 87 | 88 | return x 89 | 90 | class CVAE(CondBaseModel): 91 | def __init__(self, 92 | input_shape=(64, 64, 3), 93 | z_dims = 128, 94 | name='cvae', 95 | **kwargs 96 | ): 97 | super(CVAE, self).__init__(input_shape=input_shape, name=name, **kwargs) 98 | 99 | self.z_dims = z_dims 100 | 101 | self.total_loss = None 102 | self.optimizer = None 103 | self.train_op = None 104 | 105 | self.encoder = None 106 | self.decoder = None 107 | 108 | self.x_train = None 109 | self.c_train = None 110 | 111 | self.z_test = None 112 | self.x_test = None 113 | self.c_test = None 114 | 115 | self.build_model() 116 | 117 | def train_on_batch(self, batch, index): 118 | x_batch, c_batch = batch 119 | 120 | _, loss, summary = self.sess.run( 121 | (self.train_op, self.total_loss, self.summary), 122 | feed_dict={self.x_train: x_batch, self.c_train: c_batch, self.z_test: self.test_data['z_test'], self.c_test: self.test_data['c_test']} 123 | ) 124 | 125 | self.writer.add_summary(summary, index) 126 | return [ ('loss', loss) ] 127 | 128 | def predict(self, batch): 129 | z_samples, c_samples = batch 130 | x_sample = self.sess.run( 131 | self.x_test, 132 | feed_dict={self.z_test: z_samples, self.c_test: c_samples} 133 | ) 134 | return x_sample 135 | 136 | def make_test_data(self): 137 | c_t = np.identity(self.num_attrs) 138 | c_t = np.tile(c_t, (self.test_size, 1)) 139 | z_t = np.random.normal(size=(self.test_size, self.z_dims)) 140 | z_t = np.tile(z_t, (1, self.num_attrs)) 141 | z_t = z_t.reshape((self.test_size * self.num_attrs, -1)) 142 | self.test_data = {'z_test': z_t, 'c_test': c_t} 143 | 144 | def build_model(self): 145 | self.encoder = Encoder(self.input_shape, self.z_dims, self.num_attrs) 146 | self.decoder = Decoder(self.input_shape) 147 | 148 | # Trainer 149 | batch_shape = (None,) + self.input_shape 150 | self.x_train = tf.placeholder(tf.float32, shape=batch_shape) 151 | self.c_train = tf.placeholder(tf.float32, shape=(None, self.num_attrs)) 152 | 153 | z_avg, z_log_var = self.encoder(self.x_train, self.c_train) 154 | epsilon = tf.random_normal(tf.shape(z_avg)) 155 | z_sample = z_avg + tf.multiply(tf.exp(0.5 * z_log_var), epsilon) 156 | x_sample = self.decoder(z_sample, self.c_train) 157 | 158 | self.total_loss = tf.constant(0.0) 159 | self.total_loss += tf.reduce_mean(tf.squared_difference(self.x_train, x_sample)) 160 | self.total_loss += kl_loss(z_avg, z_log_var) 161 | self.optimizer = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5).minimize(self.total_loss) 162 | 163 | with tf.control_dependencies([self.optimizer] + self.encoder.update_ops + self.decoder.update_ops): 164 | self.train_op = tf.no_op(name='train') 165 | 166 | # Predictor 167 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 168 | self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs)) 169 | 170 | self.x_test = self.decoder(self.z_test, self.c_test) 171 | x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs) 172 | 173 | # Summary 174 | tf.summary.image('x_real', self.x_train, 10) 175 | tf.summary.image('x_fake', x_sample, 10) 176 | tf.summary.image('x_tile', x_tile, 1) 177 | tf.summary.scalar('total_loss', self.total_loss) 178 | 179 | self.summary = tf.summary.merge_all() -------------------------------------------------------------------------------- /models/cvaegan.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from .base import CondBaseModel 6 | from .utils import * 7 | 8 | class Encoder(object): 9 | def __init__(self, input_shape, z_dims, num_attrs): 10 | self.variables = None 11 | self.reuse = False 12 | self.input_shape = input_shape 13 | self.z_dims = z_dims 14 | self.num_attrs = num_attrs 15 | self.name = 'encoder' 16 | 17 | def __call__(self, inputs, attrs, training=True): 18 | with tf.variable_scope(self.name, reuse=self.reuse): 19 | with tf.variable_scope('conv1'): 20 | a = tf.reshape(attrs, [-1, 1, 1, self.num_attrs]) 21 | a = tf.tile(a, [1, self.input_shape[0], self.input_shape[1], 1]) 22 | x = tf.concat([inputs, a], axis=-1) 23 | x = tf.layers.conv2d(x, 64, (5, 5), (2, 2), 'same') 24 | x = tf.layers.batch_normalization(x, training=training) 25 | x = tf.nn.relu(x) 26 | 27 | with tf.variable_scope('conv2'): 28 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 29 | x = tf.layers.batch_normalization(x, training=training) 30 | x = lrelu(x) 31 | 32 | with tf.variable_scope('conv3'): 33 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 34 | x = tf.layers.batch_normalization(x, training=training) 35 | x = lrelu(x) 36 | 37 | with tf.variable_scope('conv4'): 38 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same') 39 | x = tf.layers.batch_normalization(x, training=training) 40 | x = lrelu(x) 41 | 42 | with tf.variable_scope('global_avg'): 43 | x = tf.reduce_mean(x, axis=[1, 2]) 44 | 45 | with tf.variable_scope('fc1'): 46 | z_avg = tf.layers.dense(x, self.z_dims) 47 | z_log_var = tf.layers.dense(x, self.z_dims) 48 | 49 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 50 | self.reuse = True 51 | 52 | return z_avg, z_log_var 53 | 54 | class Decoder(object): 55 | def __init__(self, input_shape): 56 | self.variables = None 57 | self.reuse = False 58 | self.input_shape = input_shape 59 | self.name = 'decoder' 60 | 61 | def __call__(self, inputs, attrs, training=True): 62 | with tf.variable_scope(self.name, reuse=self.reuse): 63 | with tf.variable_scope('fc1'): 64 | w = self.input_shape[0] // (2 ** 3) 65 | x = tf.concat([inputs, attrs], axis=-1) 66 | x = tf.layers.dense(x, w * w * 256) 67 | x = tf.layers.batch_normalization(x, training=training) 68 | x = tf.nn.relu(x) 69 | x = tf.reshape(x, [-1, w, w, 256]) 70 | 71 | with tf.variable_scope('conv1'): 72 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same') 73 | x = tf.layers.batch_normalization(x, training=training) 74 | x = lrelu(x) 75 | 76 | with tf.variable_scope('conv2'): 77 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same') 78 | x = tf.layers.batch_normalization(x, training=training) 79 | x = lrelu(x) 80 | 81 | with tf.variable_scope('conv3'): 82 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same') 83 | x = tf.layers.batch_normalization(x, training=training) 84 | x = lrelu(x) 85 | 86 | with tf.variable_scope('conv4'): 87 | d = self.input_shape[2] 88 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same') 89 | x = tf.tanh(x) 90 | 91 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 92 | self.reuse = True 93 | 94 | return x 95 | 96 | class Classifier(object): 97 | def __init__(self, input_shape, num_attrs): 98 | self.variables = None 99 | self.reuse = False 100 | self.input_shape = input_shape 101 | self.num_attrs = num_attrs 102 | self.name = 'classifier' 103 | 104 | def __call__(self, inputs, training=True): 105 | with tf.variable_scope(self.name, reuse=self.reuse): 106 | with tf.variable_scope('conv1'): 107 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same') 108 | x = tf.layers.batch_normalization(x, training=training) 109 | x = tf.nn.relu(x) 110 | 111 | with tf.variable_scope('conv2'): 112 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 113 | x = tf.layers.batch_normalization(x, training=training) 114 | x = tf.nn.relu(x) 115 | 116 | with tf.variable_scope('conv3'): 117 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 118 | x = tf.layers.batch_normalization(x, training=training) 119 | x = tf.nn.relu(x) 120 | 121 | with tf.variable_scope('conv4'): 122 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same') 123 | x = tf.layers.batch_normalization(x, training=training) 124 | x = tf.nn.relu(x) 125 | 126 | with tf.variable_scope('global_avg'): 127 | x = tf.reduce_mean(x, axis=[1, 2]) 128 | 129 | with tf.variable_scope('fc1'): 130 | f = tf.contrib.layers.flatten(x) 131 | y = tf.layers.dense(f, self.num_attrs) 132 | 133 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 134 | self.reuse = True 135 | 136 | return y, f 137 | 138 | class Discriminator(object): 139 | def __init__(self, input_shape): 140 | self.variables = None 141 | self.reuse = False 142 | self.input_shape = input_shape 143 | self.name = 'discriminator' 144 | 145 | def __call__(self, inputs, training=True): 146 | with tf.variable_scope(self.name, reuse=self.reuse): 147 | with tf.variable_scope('conv1'): 148 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same') 149 | x = tf.layers.batch_normalization(x, training=training) 150 | x = tf.nn.relu(x) 151 | 152 | with tf.variable_scope('conv2'): 153 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 154 | x = tf.layers.batch_normalization(x, training=training) 155 | x = tf.nn.relu(x) 156 | 157 | with tf.variable_scope('conv3'): 158 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 159 | x = tf.layers.batch_normalization(x, training=training) 160 | x = tf.nn.relu(x) 161 | 162 | with tf.variable_scope('conv4'): 163 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same') 164 | x = tf.layers.batch_normalization(x, training=training) 165 | x = tf.nn.relu(x) 166 | 167 | with tf.variable_scope('global_avg'): 168 | x = tf.reduce_mean(x, axis=[1, 2]) 169 | 170 | with tf.variable_scope('fc1'): 171 | f = tf.contrib.layers.flatten(x) 172 | y = tf.layers.dense(f, 1) 173 | 174 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 175 | self.reuse = True 176 | 177 | return y, f 178 | 179 | 180 | class CVAEGAN(CondBaseModel): 181 | def __init__(self, 182 | input_shape=(64, 64, 3), 183 | z_dims = 128, 184 | name='cvaegan', 185 | **kwargs 186 | ): 187 | super(CVAEGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 188 | 189 | self.z_dims = z_dims 190 | 191 | # Parameters for feature matching 192 | self.use_feature_match = False 193 | self.alpha = 0.7 194 | 195 | self.E_f_D_r = None 196 | self.E_f_D_p = None 197 | self.E_f_C_r = None 198 | self.E_f_C_p = None 199 | 200 | self.f_enc = None 201 | self.f_gen = None 202 | self.f_cls = None 203 | self.f_dis = None 204 | 205 | self.x_r = None 206 | self.c_r = None 207 | self.z_p = None 208 | 209 | self.z_test = None 210 | self.x_test = None 211 | self.c_test = None 212 | 213 | self.enc_trainer = None 214 | self.gen_trainer = None 215 | self.dis_trainer = None 216 | self.cls_trainer = None 217 | 218 | self.gen_loss = None 219 | self.dis_loss = None 220 | self.gen_acc = None 221 | self.dis_acc = None 222 | 223 | self.build_model() 224 | 225 | def train_on_batch(self, batch, index): 226 | x_r, c_r = batch 227 | batchsize = len(x_r) 228 | z_p = np.random.uniform(-1, 1, size=(len(x_r), self.z_dims)) 229 | 230 | _, _, _, _, gen_loss, dis_loss, gen_acc, dis_acc = self.sess.run( 231 | (self.gen_trainer, self.enc_trainer, self.dis_trainer, self.cls_trainer, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc), 232 | feed_dict={ 233 | self.x_r: x_r, self.z_p: z_p, self.c_r: c_r, 234 | self.z_test: self.test_data['z_test'], self.c_test: self.test_data['c_test'] 235 | } 236 | ) 237 | 238 | summary_priod = 1000 239 | if index // summary_priod != (index + batchsize) // summary_priod: 240 | summary = self.sess.run( 241 | self.summary, 242 | feed_dict={ 243 | self.x_r: x_r, self.z_p: z_p, self.c_r: c_r, 244 | self.z_test: self.test_data['z_test'], self.c_test: self.test_data['c_test'] 245 | } 246 | ) 247 | self.writer.add_summary(summary, index) 248 | 249 | return [ 250 | ('gen_loss', gen_loss), ('dis_loss', dis_loss), 251 | ('gen_acc', gen_acc), ('dis_acc', dis_acc) 252 | ] 253 | 254 | def predict(self, batch): 255 | z_samples, c_samples = batch 256 | x_sample = self.sess.run( 257 | self.x_test, 258 | feed_dict={self.z_test: z_samples, self.c_test: c_samples} 259 | ) 260 | return x_sample 261 | 262 | def make_test_data(self): 263 | c_t = np.identity(self.num_attrs) 264 | c_t = np.tile(c_t, (self.test_size, 1)) 265 | z_t = np.random.normal(size=(self.test_size, self.z_dims)) 266 | z_t = np.tile(z_t, (1, self.num_attrs)) 267 | z_t = z_t.reshape((self.test_size * self.num_attrs, self.z_dims)) 268 | self.test_data = {'z_test': z_t, 'c_test': c_t} 269 | 270 | def build_model(self): 271 | self.f_enc = Encoder(self.input_shape, self.z_dims, self.num_attrs) 272 | self.f_gen = Decoder(self.input_shape) 273 | 274 | n_cls_out = self.num_attrs if self.use_feature_match else self.num_attrs + 1 275 | self.f_cls = Classifier(self.input_shape, n_cls_out) 276 | self.f_dis = Discriminator(self.input_shape) 277 | 278 | # Trainer 279 | self.x_r = tf.placeholder(tf.float32, shape=(None,) + self.input_shape) 280 | self.c_r = tf.placeholder(tf.float32, shape=(None, self.num_attrs)) 281 | 282 | z_avg, z_log_var = self.f_enc(self.x_r, self.c_r) 283 | 284 | z_f = sample_normal(z_avg, z_log_var) 285 | x_f = self.f_gen(z_f, self.c_r) 286 | 287 | self.z_p = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 288 | x_p = self.f_gen(self.z_p, self.c_r) 289 | 290 | c_r_pred, f_C_r = self.f_cls(self.x_r) 291 | c_f, f_C_f = self.f_cls(x_f) 292 | c_p, f_C_p = self.f_cls(x_p) 293 | 294 | y_r, f_D_r = self.f_dis(self.x_r) 295 | y_f, f_D_f = self.f_dis(x_f) 296 | y_p, f_D_p = self.f_dis(x_p) 297 | 298 | L_KL = kl_loss(z_avg, z_log_var) 299 | 300 | enc_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 301 | gen_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 302 | cls_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 303 | dis_opt = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 304 | 305 | if self.use_feature_match: 306 | # Use feature matching (it is usually unstable) 307 | L_GD = self.L_GD(f_D_r, f_D_p) 308 | L_GC = self.L_GC(f_C_r, f_C_p, self.c_r) 309 | L_G = self.L_G(self.x_r, x_f, f_D_r, f_D_f, f_C_r, f_C_f) 310 | 311 | with tf.name_scope('L_D'): 312 | L_D = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_r), y_r) + \ 313 | tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_f), y_f) + \ 314 | tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_p), y_p) 315 | 316 | with tf.name_scope('L_C'): 317 | L_C = tf.losses.softmax_cross_entropy(self.c_r, c_r_pred) 318 | 319 | self.enc_trainer = enc_opt.minimize(L_G + L_KL, var_list=self.f_enc.variables) 320 | self.gen_trainer = gen_opt.minimize(L_G + L_GD + L_GC, var_list=self.f_gen.variables) 321 | self.cls_trainer = cls_opt.minimize(L_C, var_list=self.f_cls.variables) 322 | self.dis_trainer = dis_opt.minimize(L_D, var_list=self.f_dis.variables) 323 | 324 | self.gen_loss = L_G + L_GD + L_GC 325 | self.dis_loss = L_D 326 | 327 | # Predictor 328 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 329 | self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs)) 330 | 331 | self.x_test = self.f_gen(self.z_test, self.c_test) 332 | x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs) 333 | 334 | # Summary 335 | tf.summary.image('x_real', self.x_r, 10) 336 | tf.summary.image('x_fake', x_f, 10) 337 | tf.summary.image('x_tile', x_tile, 1) 338 | tf.summary.scalar('L_G', L_G) 339 | tf.summary.scalar('L_GD', L_GD) 340 | tf.summary.scalar('L_GC', L_GC) 341 | tf.summary.scalar('L_C', L_C) 342 | tf.summary.scalar('L_D', L_D) 343 | tf.summary.scalar('L_KL', L_KL) 344 | tf.summary.scalar('gen_loss', self.gen_loss) 345 | tf.summary.scalar('dis_loss', self.dis_loss) 346 | else: 347 | # Not use feature matching (it is more similar to ordinary GANs) 348 | c_r_aug = tf.concat((self.c_r, tf.zeros((tf.shape(self.c_r)[0], 1))), axis=1) 349 | c_other = tf.concat((tf.zeros_like(self.c_r), tf.ones((tf.shape(self.c_r)[0], 1))), axis=1) 350 | with tf.name_scope('L_G'): 351 | L_G = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_f), y_f) + \ 352 | tf.losses.sigmoid_cross_entropy(tf.ones_like(y_p), y_p) + \ 353 | tf.losses.softmax_cross_entropy(c_r_aug, c_f) + \ 354 | tf.losses.softmax_cross_entropy(c_r_aug, c_p) 355 | 356 | with tf.name_scope('L_rec'): 357 | # L_rec = 0.5 * tf.losses.mean_squared_error(self.x_r, x_f) 358 | L_rec = 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(self.x_r, x_f), axis=[1, 2, 3])) 359 | 360 | with tf.name_scope('L_D'): 361 | L_D = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_r), y_r) + \ 362 | tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_f), y_f) + \ 363 | tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_p), y_p) 364 | 365 | with tf.name_scope('L_C'): 366 | L_C = tf.losses.softmax_cross_entropy(c_r_aug, c_r_pred) + \ 367 | tf.losses.softmax_cross_entropy(c_other, c_f) + \ 368 | tf.losses.softmax_cross_entropy(c_other, c_p) 369 | 370 | self.enc_trainer = enc_opt.minimize(L_rec + L_KL, var_list=self.f_enc.variables) 371 | self.gen_trainer = gen_opt.minimize(L_G + L_rec, var_list=self.f_gen.variables) 372 | self.cls_trainer = cls_opt.minimize(L_C, var_list=self.f_cls.variables) 373 | self.dis_trainer = dis_opt.minimize(L_D, var_list=self.f_dis.variables) 374 | 375 | self.gen_loss = L_G + L_rec 376 | self.dis_loss = L_D 377 | 378 | # Predictor 379 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 380 | self.c_test = tf.placeholder(tf.float32, shape=(None, self.num_attrs)) 381 | 382 | self.x_test = self.f_gen(self.z_test, self.c_test) 383 | x_tile = self.image_tiling(self.x_test, self.test_size, self.num_attrs) 384 | 385 | # Summary 386 | tf.summary.image('x_real', self.x_r, 10) 387 | tf.summary.image('x_fake', x_f, 10) 388 | tf.summary.image('x_tile', x_tile, 1) 389 | tf.summary.scalar('L_G', L_G) 390 | tf.summary.scalar('L_rec', L_rec) 391 | tf.summary.scalar('L_C', L_C) 392 | tf.summary.scalar('L_D', L_D) 393 | tf.summary.scalar('L_KL', L_KL) 394 | tf.summary.scalar('gen_loss', self.gen_loss) 395 | tf.summary.scalar('dis_loss', self.dis_loss) 396 | 397 | # Accuracy 398 | self.gen_acc = 0.5 * binary_accuracy(tf.ones_like(y_f), y_f) + \ 399 | 0.5 * binary_accuracy(tf.ones_like(y_p), y_p) 400 | 401 | self.dis_acc = binary_accuracy(tf.ones_like(y_r), y_r) / 3.0 + \ 402 | binary_accuracy(tf.zeros_like(y_f), y_f) / 3.0 + \ 403 | binary_accuracy(tf.zeros_like(y_p), y_p) / 3.0 404 | 405 | tf.summary.scalar('gen_acc', self.gen_acc) 406 | tf.summary.scalar('dis_acc', self.dis_acc) 407 | 408 | self.summary = tf.summary.merge_all() 409 | 410 | def L_G(self, x_r, x_f, f_D_r, f_D_f, f_C_r, f_C_f): 411 | with tf.name_scope('L_G'): 412 | loss = tf.constant(0.0, dtype=tf.float32) 413 | loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(x_r, x_f), axis=[1, 2, 3])) 414 | loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(f_D_r, f_D_f), axis=[1])) 415 | loss += 0.5 * tf.reduce_mean(tf.reduce_sum(tf.squared_difference(f_C_r, f_C_f), axis=[1])) 416 | 417 | return loss 418 | 419 | def L_GD(self, f_D_r, f_D_p): 420 | with tf.name_scope('L_GD'): 421 | # Compute loss 422 | E_f_D_r = tf.reduce_mean(f_D_r, axis=0) 423 | E_f_D_p = tf.reduce_mean(f_D_p, axis=0) 424 | 425 | # Update features 426 | if self.E_f_D_r is None: 427 | self.E_f_D_r = tf.zeros_like(E_f_D_r) 428 | 429 | if self.E_f_D_p is None: 430 | self.E_f_D_p = tf.zeros_like(E_f_D_p) 431 | 432 | self.E_f_D_r = self.alpha * self.E_f_D_r + (1.0 - self.alpha) * E_f_D_r 433 | self.E_f_D_p = self.alpha * self.E_f_D_p + (1.0 - self.alpha) * E_f_D_p 434 | return 0.5 * tf.reduce_sum(tf.squared_difference(self.E_f_D_r, self.E_f_D_p)) 435 | 436 | def L_GC(self, f_C_r, f_C_p, c): 437 | with tf.name_scope('L_GC'): 438 | image_shape = tf.shape(f_C_r) 439 | 440 | indices = tf.eye(self.num_attrs, dtype=tf.float32) 441 | indices = tf.tile(indices, (1, image_shape[0])) 442 | indices = tf.reshape(indices, (-1, self.num_attrs)) 443 | 444 | classes = tf.tile(c, (self.num_attrs, 1)) 445 | 446 | mask = tf.reduce_max(tf.multiply(indices, classes), axis=1) 447 | mask = tf.reshape(mask, (-1, 1)) 448 | mask = tf.tile(mask, (1, image_shape[1])) 449 | 450 | denom = tf.reshape(tf.multiply(indices, classes), (self.num_attrs, image_shape[0], self.num_attrs)) 451 | denom = tf.reduce_sum(denom, axis=[1, 2]) 452 | denom = tf.tile(tf.reshape(denom, (-1, 1)), (1, image_shape[1])) 453 | 454 | f_1_sum = tf.tile(f_C_r, (self.num_attrs, 1)) 455 | f_1_sum = tf.multiply(f_1_sum, mask) 456 | f_1_sum = tf.reshape(f_1_sum, (self.num_attrs, image_shape[0], image_shape[1])) 457 | E_f_1 = tf.divide(tf.reduce_sum(f_1_sum, axis=1), denom + 1.0e-8) 458 | 459 | f_2_sum = tf.tile(f_C_p, (self.num_attrs, 1)) 460 | f_2_sum = tf.multiply(f_2_sum, mask) 461 | f_2_sum = tf.reshape(f_2_sum, (self.num_attrs, image_shape[0], image_shape[1])) 462 | E_f_2 = tf.divide(tf.reduce_sum(f_2_sum, axis=1), denom + 1.0e-8) 463 | 464 | # Update features 465 | if self.E_f_C_r is None: 466 | self.E_f_C_r = tf.zeros_like(E_f_1) 467 | 468 | if self.E_f_C_p is None: 469 | self.E_f_C_p = tf.zeros_like(E_f_2) 470 | 471 | self.E_f_C_r = self.alpha * self.E_f_C_r + (1.0 - self.alpha) * E_f_1 472 | self.E_f_C_p = self.alpha * self.E_f_C_p + (1.0 - self.alpha) * E_f_2 473 | 474 | # return 0.5 * tf.losses.mean_squared_error(self.E_f_C_r, self.E_f_C_p) 475 | return 0.5 * tf.reduce_sum(tf.squared_difference(self.E_f_C_r, self.E_f_C_p)) 476 | -------------------------------------------------------------------------------- /models/dcgan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | from .wnorm import * 7 | 8 | class Generator(object): 9 | def __init__(self, input_shape, z_dims, use_wnorm=False): 10 | self.variables = None 11 | self.update_ops = None 12 | self.reuse = False 13 | self.use_wnorm = use_wnorm 14 | self.input_shape = input_shape 15 | self.z_dims = z_dims 16 | 17 | def __call__(self, inputs, training=True): 18 | with tf.variable_scope('generator', reuse=self.reuse): 19 | with tf.variable_scope('fc1'): 20 | w = self.input_shape[0] // (2 ** 3) 21 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 22 | if self.use_wnorm: 23 | x = conv2d_transpose_wnorm(x, 256, (w, w), (1, 1), use_scale=True, 24 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 25 | x = tf.layers.batch_normalization(x, scale=False, training=training) 26 | else: 27 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 28 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 29 | x = tf.layers.batch_normalization(x, training=training) 30 | 31 | x = tf.nn.relu(x) 32 | 33 | with tf.variable_scope('conv1'): 34 | if self.use_wnorm: 35 | x = conv2d_transpose_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True, 36 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 37 | x = tf.layers.batch_normalization(x, scale=False, training=training) 38 | else: 39 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 40 | x = tf.layers.batch_normalization(x, training=training) 41 | x = tf.nn.relu(x) 42 | 43 | with tf.variable_scope('conv2'): 44 | if self.use_wnorm: 45 | x = conv2d_transpose_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True, 46 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 47 | x = tf.layers.batch_normalization(x, scale=False, training=training) 48 | else: 49 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 50 | x = tf.layers.batch_normalization(x, training=training) 51 | x = tf.nn.relu(x) 52 | 53 | with tf.variable_scope('conv3'): 54 | if self.use_wnorm: 55 | x = conv2d_transpose_wnorm(x, 64, (5, 5), (2, 2), 'same', use_scale=True, 56 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 57 | x = tf.layers.batch_normalization(x, scale=False, training=training) 58 | else: 59 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 60 | x = tf.layers.batch_normalization(x, training=training) 61 | x = tf.nn.relu(x) 62 | 63 | with tf.variable_scope('conv4'): 64 | d = self.input_shape[2] 65 | if self.use_wnorm: 66 | x = conv2d_transpose_wnorm(x, d, (5, 5), (1, 1), 'same', use_scale=True, 67 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 68 | else: 69 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same', 70 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 71 | x = tf.tanh(x) 72 | 73 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 74 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') 75 | self.reuse = True 76 | return x 77 | 78 | class Discriminator(object): 79 | def __init__(self, input_shape, use_wnorm=False): 80 | self.input_shape = input_shape 81 | self.variables = None 82 | self.update_ops = None 83 | self.use_wnorm = use_wnorm 84 | self.reuse = False 85 | 86 | def __call__(self, inputs, training=True): 87 | with tf.variable_scope('discriminator', reuse=self.reuse): 88 | with tf.variable_scope('conv1'): 89 | if self.use_wnorm: 90 | x = conv2d_wnorm(inputs, 64, (5, 5), (2, 2), 'same', use_scale=True, 91 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 92 | x = tf.layers.batch_normalization(x, scale=False, training=training) 93 | else: 94 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 95 | x = tf.layers.batch_normalization(x, training=training) 96 | x = lrelu(x) 97 | 98 | with tf.variable_scope('conv2'): 99 | if self.use_wnorm: 100 | x = conv2d_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True, 101 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 102 | x = tf.layers.batch_normalization(x, scale=False, training=training) 103 | else: 104 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 105 | x = tf.layers.batch_normalization(x, training=training) 106 | x = lrelu(x) 107 | 108 | with tf.variable_scope('conv3'): 109 | if self.use_wnorm: 110 | x = conv2d_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True, 111 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 112 | x = tf.layers.batch_normalization(x, scale=False, training=training) 113 | else: 114 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 115 | x = tf.layers.batch_normalization(x, training=training) 116 | x = lrelu(x) 117 | 118 | with tf.variable_scope('conv4'): 119 | if self.use_wnorm: 120 | x = conv2d_wnorm(x, 512, (5, 5), (2, 2), 'same', use_scale=True, 121 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 122 | x = tf.layers.batch_normalization(x, scale=False, training=training) 123 | else: 124 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 125 | x = tf.layers.batch_normalization(x, training=training) 126 | x = lrelu(x) 127 | 128 | with tf.variable_scope('conv5'): 129 | w = self.input_shape[0] // (2 ** 4) 130 | if self.use_wnorm: 131 | y = conv2d_wnorm(x, 1, (w, w), (1, 1), 'valid', use_scale=True, 132 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 133 | else: 134 | y = tf.layers.conv2d(x, 1, (w, w), (1, 1), 'valid', 135 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 136 | y = tf.reshape(y, [-1, 1]) 137 | 138 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 139 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') 140 | self.reuse = True 141 | return y 142 | 143 | class DCGAN(BaseModel): 144 | def __init__(self, 145 | input_shape=(64, 64, 3), 146 | z_dims = 128, 147 | name='dcgan', 148 | **kwargs 149 | ): 150 | super(DCGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 151 | 152 | self.z_dims = z_dims 153 | self.use_wnorm = True 154 | 155 | self.f_gen = None 156 | self.f_dis = None 157 | self.gen_loss = None 158 | self.dis_loss = None 159 | self.train_op = None 160 | 161 | self.gen_acc = None 162 | self.dis_acc = None 163 | 164 | self.x_train = None 165 | self.z_train = None 166 | 167 | self.z_test = None 168 | self.x_test = None 169 | 170 | self.build_model() 171 | 172 | def train_on_batch(self, x_batch, index): 173 | batchsize = x_batch.shape[0] 174 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 175 | 176 | _, g_loss, d_loss, g_acc, d_acc, summary = self.sess.run( 177 | (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc, self.summary), 178 | feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data} 179 | ) 180 | 181 | self.writer.add_summary(summary, index) 182 | 183 | return [ 184 | ('g_loss', g_loss), ('d_loss', d_loss), 185 | ('g_acc', g_acc), ('d_acc', d_acc) 186 | ] 187 | 188 | def predict(self, z_samples): 189 | x_sample = self.sess.run( 190 | self.x_test, 191 | feed_dict={self.z_test: z_samples} 192 | ) 193 | return x_sample 194 | 195 | def make_test_data(self): 196 | self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims)) 197 | 198 | def build_model(self): 199 | # Trainer 200 | self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm) 201 | self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm) 202 | 203 | x_shape = (None,) + self.input_shape 204 | z_shape = (None, self.z_dims) 205 | self.x_train = tf.placeholder(tf.float32, shape=x_shape) 206 | self.z_train = tf.placeholder(tf.float32, shape=z_shape) 207 | x_fake = self.f_gen(self.z_train) 208 | y_fake = self.f_dis(x_fake) 209 | y_real = self.f_dis(self.x_train) 210 | 211 | self.gen_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_fake), y_fake) 212 | self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \ 213 | 0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake) 214 | 215 | gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 216 | dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 217 | 218 | gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables) 219 | dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables) 220 | 221 | self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake) 222 | self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \ 223 | 0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake) 224 | 225 | with tf.control_dependencies([gen_train_op, dis_train_op] + \ 226 | self.f_dis.update_ops + \ 227 | self.f_gen.update_ops): 228 | self.train_op = tf.no_op(name='train') 229 | 230 | # Predictor 231 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 232 | self.x_test = self.f_gen(self.z_test, training=False) 233 | 234 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 235 | 236 | tf.summary.image('x_real', image_cast(self.x_train), 10) 237 | tf.summary.image('x_fake', image_cast(x_fake), 10) 238 | tf.summary.image('x_tile', image_cast(x_tile), 1) 239 | tf.summary.scalar('gen_loss', self.gen_loss) 240 | tf.summary.scalar('dis_loss', self.dis_loss) 241 | tf.summary.scalar('gen_acc', self.gen_acc) 242 | tf.summary.scalar('dis_acc', self.dis_acc) 243 | self.summary = tf.summary.merge_all() 244 | -------------------------------------------------------------------------------- /models/improved_gan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | 7 | def minibatch_discrimination(x, kernels=50, dims=5): 8 | with tf.name_scope('MinibatchDiscrimination'): 9 | size = x.get_shape()[1] 10 | W = tf.get_variable(shape=(size, kernels * dims), trainable=True, name='kernel') 11 | Ms = tf.tensordot(x, W, axes=1) 12 | x_i = tf.reshape(Ms, [-1, kernels, 1, dims]) 13 | x_j = tf.reshape(Ms, [-1, 1, kernels, dims]) 14 | x_i = tf.tile(x_i, [1, 1, kernels, 1]) 15 | x_j = tf.tile(x_j, [1, kernels, 1, 1]) 16 | norm = tf.reduce_sum(tf.abs(x_i - x_j), axis=3) 17 | Os = tf.reduce_sum(tf.exp(-norm), axis=2) 18 | return Os 19 | 20 | class Generator(object): 21 | def __init__(self, input_shape, z_dims, use_wnorm=False): 22 | self.variables = None 23 | self.update_ops = None 24 | self.reuse = False 25 | self.use_wnorm = use_wnorm 26 | self.input_shape = input_shape 27 | self.z_dims = z_dims 28 | 29 | def __call__(self, inputs, training=True): 30 | with tf.variable_scope('generator', reuse=self.reuse): 31 | with tf.variable_scope('fc1'): 32 | w = self.input_shape[0] // (2 ** 3) 33 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 34 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 35 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 36 | x = tf.layers.batch_normalization(x, training=training) 37 | x = tf.nn.relu(x) 38 | 39 | with tf.variable_scope('conv1'): 40 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', 41 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 42 | x = tf.layers.batch_normalization(x, training=training) 43 | x = tf.nn.relu(x) 44 | 45 | with tf.variable_scope('conv2'): 46 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', 47 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 48 | x = tf.layers.batch_normalization(x, training=training) 49 | x = tf.nn.relu(x) 50 | 51 | with tf.variable_scope('conv3'): 52 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', 53 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 54 | x = tf.layers.batch_normalization(x, training=training) 55 | x = tf.nn.relu(x) 56 | 57 | with tf.variable_scope('conv4'): 58 | d = self.input_shape[2] 59 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same', 60 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 61 | x = tf.tanh(x) 62 | 63 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 64 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') 65 | self.reuse = True 66 | return x 67 | 68 | class Discriminator(object): 69 | def __init__(self, input_shape, use_wnorm=False): 70 | self.input_shape = input_shape 71 | self.variables = None 72 | self.update_ops = None 73 | self.use_wnorm = use_wnorm 74 | self.reuse = False 75 | 76 | def __call__(self, inputs, training=True): 77 | with tf.variable_scope('discriminator', reuse=self.reuse): 78 | with tf.variable_scope('conv1'): 79 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same', 80 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 81 | x = tf.layers.batch_normalization(x, training=training) 82 | x = lrelu(x) 83 | 84 | with tf.variable_scope('conv2'): 85 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same', 86 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 87 | x = tf.layers.batch_normalization(x, training=training) 88 | x = lrelu(x) 89 | 90 | with tf.variable_scope('conv3'): 91 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same', 92 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 93 | x = tf.layers.batch_normalization(x, training=training) 94 | x = lrelu(x) 95 | 96 | with tf.variable_scope('conv4'): 97 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same', 98 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 99 | x = tf.layers.batch_normalization(x, training=training) 100 | x = lrelu(x) 101 | 102 | with tf.variable_scope('fc1'): 103 | x = tf.contrib.layers.flatten(x) 104 | x = tf.layers.dense(x, 1024) 105 | x = tf.layers.batch_normalization(x, training=training) 106 | x = lrelu(x) 107 | 108 | with tf.variable_scope('minibatch_discrimination'): 109 | x = minibatch_discrimination(x, kernels=50, dims=5) 110 | f = tf.identity(x) 111 | 112 | with tf.variable_scope('fc2'): 113 | y = tf.layers.dense(x, 1) 114 | 115 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 116 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') 117 | self.reuse = True 118 | return y, f 119 | 120 | class ImprovedGAN(BaseModel): 121 | def __init__(self, 122 | input_shape=(64, 64, 3), 123 | z_dims = 128, 124 | name='improved', 125 | **kwargs 126 | ): 127 | super(ImprovedGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 128 | 129 | self.z_dims = z_dims 130 | self.use_wnorm = True 131 | 132 | self.f_gen = None 133 | self.f_dis = None 134 | self.gen_loss = None 135 | self.dis_loss = None 136 | self.train_op = None 137 | 138 | self.gen_acc = None 139 | self.dis_acc = None 140 | 141 | self.x_train = None 142 | self.z_train = None 143 | 144 | self.z_test = None 145 | self.x_test = None 146 | 147 | self.build_model() 148 | 149 | def train_on_batch(self, x_batch, index): 150 | batchsize = x_batch.shape[0] 151 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 152 | 153 | _, g_loss, d_loss, g_acc, d_acc, summary = self.sess.run( 154 | (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc, self.summary), 155 | feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data} 156 | ) 157 | 158 | self.writer.add_summary(summary, index) 159 | 160 | return [ 161 | ('g_loss', g_loss), ('d_loss', d_loss), 162 | ('g_acc', g_acc), ('d_acc', d_acc) 163 | ] 164 | 165 | def predict(self, z_samples): 166 | x_sample = self.sess.run( 167 | self.x_test, 168 | feed_dict={self.z_test: z_samples} 169 | ) 170 | return x_sample 171 | 172 | def make_test_data(self): 173 | self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims)) 174 | 175 | def build_model(self): 176 | # Trainer 177 | self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm) 178 | self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm) 179 | 180 | x_shape = (None,) + self.input_shape 181 | z_shape = (None, self.z_dims) 182 | self.x_train = tf.placeholder(tf.float32, shape=x_shape) 183 | self.z_train = tf.placeholder(tf.float32, shape=z_shape) 184 | x_fake = self.f_gen(self.z_train) 185 | y_fake, f_fake = self.f_dis(x_fake) 186 | y_real, f_real = self.f_dis(self.x_train) 187 | 188 | E_f_fake = tf.reduce_mean(f_fake, axis=0) 189 | E_f_real = tf.reduce_mean(f_real, axis=0) 190 | self.gen_loss = tf.reduce_sum(tf.square(E_f_real - E_f_fake)) 191 | self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \ 192 | 0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake) 193 | 194 | gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 195 | dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 196 | 197 | gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables) 198 | dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables) 199 | 200 | self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake) 201 | self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \ 202 | 0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake) 203 | 204 | with tf.control_dependencies([gen_train_op, dis_train_op] + \ 205 | self.f_dis.update_ops + \ 206 | self.f_gen.update_ops): 207 | self.train_op = tf.no_op(name='train') 208 | 209 | # Predictor 210 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 211 | self.x_test = self.f_gen(self.z_test, training=False) 212 | 213 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 214 | 215 | tf.summary.image('x_real', image_cast(self.x_train), 10) 216 | tf.summary.image('x_fake', image_cast(x_fake), 10) 217 | tf.summary.image('x_tile', image_cast(x_tile), 1) 218 | tf.summary.scalar('gen_loss', self.gen_loss) 219 | tf.summary.scalar('dis_loss', self.dis_loss) 220 | tf.summary.scalar('gen_acc', self.gen_acc) 221 | tf.summary.scalar('dis_acc', self.dis_acc) 222 | self.summary = tf.summary.merge_all() 223 | -------------------------------------------------------------------------------- /models/lsgan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | 7 | class Generator(object): 8 | def __init__(self, input_shape, z_dims, use_wnorm=False): 9 | self.variables = None 10 | self.update_ops = None 11 | self.reuse = False 12 | self.use_wnorm = use_wnorm 13 | self.input_shape = input_shape 14 | self.z_dims = z_dims 15 | 16 | def __call__(self, inputs, training=True): 17 | with tf.variable_scope('generator', reuse=self.reuse): 18 | with tf.variable_scope('fc1'): 19 | w = self.input_shape[0] // (2 ** 3) 20 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 21 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 22 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 23 | x = tf.layers.batch_normalization(x, training=training) 24 | x = tf.nn.relu(x) 25 | 26 | with tf.variable_scope('conv1'): 27 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 28 | x = tf.layers.batch_normalization(x, training=training) 29 | x = tf.nn.relu(x) 30 | 31 | with tf.variable_scope('conv2'): 32 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 33 | x = tf.layers.batch_normalization(x, training=training) 34 | x = tf.nn.relu(x) 35 | 36 | with tf.variable_scope('conv3'): 37 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 38 | x = tf.layers.batch_normalization(x, training=training) 39 | x = tf.nn.relu(x) 40 | 41 | with tf.variable_scope('conv4'): 42 | d = self.input_shape[2] 43 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same', 44 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 45 | x = tf.tanh(x) 46 | 47 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 48 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') 49 | self.reuse = True 50 | return x 51 | 52 | class Discriminator(object): 53 | def __init__(self, input_shape, use_wnorm=False): 54 | self.input_shape = input_shape 55 | self.variables = None 56 | self.update_ops = None 57 | self.use_wnorm = use_wnorm 58 | self.reuse = False 59 | 60 | def __call__(self, inputs, training=True): 61 | with tf.variable_scope('discriminator', reuse=self.reuse): 62 | with tf.variable_scope('conv1'): 63 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 64 | x = tf.layers.batch_normalization(x, training=training) 65 | x = lrelu(x) 66 | 67 | with tf.variable_scope('conv2'): 68 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 69 | x = tf.layers.batch_normalization(x, training=training) 70 | x = lrelu(x) 71 | 72 | with tf.variable_scope('conv3'): 73 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 74 | x = tf.layers.batch_normalization(x, training=training) 75 | x = lrelu(x) 76 | 77 | with tf.variable_scope('conv4'): 78 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same', kernel_initializer=tf.contrib.layers.xavier_initializer()) 79 | x = tf.layers.batch_normalization(x, training=training) 80 | x = lrelu(x) 81 | 82 | with tf.variable_scope('conv5'): 83 | w = self.input_shape[0] // (2 ** 4) 84 | y = tf.layers.conv2d(x, 1, (w, w), (1, 1), 'valid', 85 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 86 | y = tf.reshape(y, [-1, 1]) 87 | 88 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 89 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') 90 | self.reuse = True 91 | return y 92 | 93 | class LSGAN(BaseModel): 94 | def __init__(self, 95 | input_shape=(64, 64, 3), 96 | z_dims = 128, 97 | name='lsgan', 98 | **kwargs 99 | ): 100 | super(LSGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 101 | 102 | self.z_dims = z_dims 103 | self.use_wnorm = True 104 | 105 | self.f_gen = None 106 | self.f_dis = None 107 | self.gen_loss = None 108 | self.dis_loss = None 109 | self.train_op = None 110 | 111 | self.param_a = 0.0 112 | self.param_b = 1.0 113 | self.param_c = 1.0 114 | 115 | self.x_train = None 116 | self.z_train = None 117 | 118 | self.z_test = None 119 | self.x_test = None 120 | 121 | self.build_model() 122 | 123 | def train_on_batch(self, x_batch, index): 124 | batchsize = x_batch.shape[0] 125 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 126 | 127 | _, g_loss, d_loss = self.sess.run( 128 | (self.train_op, self.gen_loss, self.dis_loss), 129 | feed_dict={self.x_train: x_batch, self.z_train: z_sample} 130 | ) 131 | 132 | summary_period = 1000 133 | if index // summary_period != (index - batchsize) // summary_period: 134 | summary = self.sess.run( 135 | self.summary, 136 | feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data} 137 | ) 138 | self.writer.add_summary(summary, index) 139 | 140 | return [ 141 | ('g_loss', g_loss), ('d_loss', d_loss) 142 | ] 143 | 144 | def predict(self, z_samples): 145 | x_sample = self.sess.run( 146 | self.x_test, 147 | feed_dict={self.z_test: z_samples} 148 | ) 149 | return x_sample 150 | 151 | def make_test_data(self): 152 | self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims)) 153 | 154 | def build_model(self): 155 | # Trainer 156 | self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm) 157 | self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm) 158 | 159 | x_shape = (None,) + self.input_shape 160 | z_shape = (None, self.z_dims) 161 | self.x_train = tf.placeholder(tf.float32, shape=x_shape) 162 | self.z_train = tf.placeholder(tf.float32, shape=z_shape) 163 | x_fake = self.f_gen(self.z_train) 164 | y_fake = self.f_dis(x_fake) 165 | y_real = self.f_dis(self.x_train) 166 | 167 | self.gen_loss = tf.reduce_mean(tf.square(y_fake - self.param_c)) 168 | self.dis_loss = tf.reduce_mean(tf.square(y_real - self.param_b)) + \ 169 | tf.reduce_mean(tf.square(y_fake - self.param_a)) 170 | 171 | gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 172 | dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 173 | 174 | gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables) 175 | dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables) 176 | 177 | with tf.control_dependencies([gen_train_op, dis_train_op] + \ 178 | self.f_dis.update_ops + \ 179 | self.f_gen.update_ops): 180 | self.train_op = tf.no_op(name='train') 181 | 182 | # Predictor 183 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 184 | self.x_test = self.f_gen(self.z_test, training=False) 185 | 186 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 187 | 188 | tf.summary.image('x_real', image_cast(self.x_train), 10) 189 | tf.summary.image('x_fake', image_cast(x_fake), 10) 190 | tf.summary.image('x_tile', image_cast(x_tile), 1) 191 | tf.summary.scalar('gen_loss', self.gen_loss) 192 | tf.summary.scalar('dis_loss', self.dis_loss) 193 | self.summary = tf.summary.merge_all() 194 | -------------------------------------------------------------------------------- /models/resnet_gan.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | from .wnorm import * 7 | 8 | def residual_plain_unit(x, filters, training=True): 9 | y = tf.identity(x) 10 | 11 | x = tf.layers.batch_normalization(x, training=training) 12 | x = tf.nn.relu(x) 13 | x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same') 14 | 15 | x = tf.layers.batch_normalization(x, training=training) 16 | x = tf.nn.relu(x + y) 17 | x = tf.layers.dropout(x, rate=0.5, training=training) 18 | 19 | x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same') 20 | 21 | return x + y 22 | 23 | class Generator(object): 24 | def __init__(self, input_shape, z_dims, use_wnorm=False): 25 | self.variables = None 26 | self.update_ops = None 27 | self.reuse = False 28 | self.use_wnorm = use_wnorm 29 | self.input_shape = input_shape 30 | self.z_dims = z_dims 31 | 32 | def __call__(self, inputs, training=True): 33 | with tf.variable_scope('generator', reuse=self.reuse): 34 | with tf.variable_scope('deconv1'): 35 | w = self.input_shape[0] // (2 ** 3) 36 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 37 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 'valid') 38 | x = tf.layers.batch_normalization(x, training=training) 39 | x = tf.nn.relu(x) 40 | 41 | with tf.variable_scope('deconv2'): 42 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same') 43 | x = residual_plain_unit(x, 256, training=training) 44 | x = tf.layers.batch_normalization(x, training=training) 45 | x = tf.nn.relu(x) 46 | 47 | with tf.variable_scope('deconv3'): 48 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same') 49 | x = residual_plain_unit(x, 128, training=training) 50 | x = tf.layers.batch_normalization(x, training=training) 51 | x = tf.nn.relu(x) 52 | 53 | with tf.variable_scope('deconv4'): 54 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same') 55 | x = residual_plain_unit(x, 64, training=training) 56 | x = tf.layers.batch_normalization(x, training=training) 57 | x = tf.nn.relu(x) 58 | 59 | with tf.variable_scope('deconv5'): 60 | d = self.input_shape[2] 61 | x = tf.layers.conv2d(x, d, (5, 5), (1, 1), 'same', 62 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 63 | x = tf.tanh(x) 64 | 65 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') 66 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator') 67 | self.reuse = True 68 | return x 69 | 70 | class Discriminator(object): 71 | def __init__(self, input_shape, use_wnorm=False): 72 | self.input_shape = input_shape 73 | self.variables = None 74 | self.update_ops = None 75 | self.use_wnorm = use_wnorm 76 | self.reuse = False 77 | 78 | def __call__(self, inputs, training=True): 79 | with tf.variable_scope('discriminator', reuse=self.reuse): 80 | with tf.variable_scope('conv1'): 81 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same') 82 | x = residual_plain_unit(x, 64, training=training) 83 | x = tf.layers.batch_normalization(x, training=training) 84 | x = lrelu(x) 85 | 86 | with tf.variable_scope('conv2'): 87 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 88 | x = residual_plain_unit(x, 128, training=training) 89 | x = tf.layers.batch_normalization(x, training=training) 90 | x = lrelu(x) 91 | 92 | with tf.variable_scope('conv3'): 93 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 94 | x = residual_plain_unit(x, 256, training=training) 95 | x = tf.layers.batch_normalization(x, training=training) 96 | x = lrelu(x) 97 | 98 | with tf.variable_scope('conv4'): 99 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same') 100 | x = residual_plain_unit(x, 512, training=training) 101 | x = tf.layers.batch_normalization(x, training=training) 102 | x = lrelu(x) 103 | 104 | with tf.variable_scope('conv5'): 105 | w = self.input_shape[0] // (2 ** 4) 106 | y = tf.layers.conv2d(x, 1, (w, w), (1, 1), 'valid') 107 | y = tf.reshape(y, [-1, 1]) 108 | 109 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 110 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator') 111 | self.reuse = True 112 | return y 113 | 114 | class ResNetGAN(BaseModel): 115 | def __init__(self, 116 | input_shape=(64, 64, 3), 117 | z_dims = 128, 118 | name='resnet', 119 | **kwargs 120 | ): 121 | super(ResNetGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 122 | 123 | self.z_dims = z_dims 124 | self.use_wnorm = True 125 | 126 | self.f_gen = None 127 | self.f_dis = None 128 | self.gen_loss = None 129 | self.dis_loss = None 130 | self.train_op = None 131 | 132 | self.gen_acc = None 133 | self.dis_acc = None 134 | 135 | self.x_train = None 136 | self.z_train = None 137 | 138 | self.z_test = None 139 | self.x_test = None 140 | 141 | self.build_model() 142 | 143 | def train_on_batch(self, x_batch, index): 144 | batchsize = x_batch.shape[0] 145 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 146 | 147 | _, g_loss, d_loss, g_acc, d_acc = self.sess.run( 148 | (self.train_op, self.gen_loss, self.dis_loss, self.gen_acc, self.dis_acc), 149 | feed_dict={self.x_train: x_batch, self.z_train: z_sample} 150 | ) 151 | 152 | summary_priod = 1000 153 | if index // summary_priod != (index - batchsize) // summary_priod: 154 | summary = self.sess.run( 155 | self.summary, 156 | feed_dict={self.x_train: x_batch, self.z_train: z_sample, self.z_test: self.test_data} 157 | ) 158 | self.writer.add_summary(summary, index) 159 | 160 | return [ 161 | ('g_loss', g_loss), ('d_loss', d_loss), 162 | ('g_acc', g_acc), ('d_acc', d_acc) 163 | ] 164 | 165 | def predict(self, z_samples): 166 | x_sample = self.sess.run( 167 | self.x_test, 168 | feed_dict={self.z_test: z_samples} 169 | ) 170 | return x_sample 171 | 172 | def make_test_data(self): 173 | self.test_data = np.random.uniform(-1.0, 1.0, size=(self.test_size * self.test_size, self.z_dims)) 174 | 175 | def build_model(self): 176 | # Trainer 177 | self.f_dis = Discriminator(self.input_shape, use_wnorm=self.use_wnorm) 178 | self.f_gen = Generator(self.input_shape, self.z_dims, use_wnorm=self.use_wnorm) 179 | 180 | x_shape = (None,) + self.input_shape 181 | z_shape = (None, self.z_dims) 182 | self.x_train = tf.placeholder(tf.float32, shape=x_shape) 183 | self.z_train = tf.placeholder(tf.float32, shape=z_shape) 184 | x_fake = self.f_gen(self.z_train) 185 | y_fake = self.f_dis(x_fake) 186 | y_real = self.f_dis(self.x_train) 187 | 188 | self.gen_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(y_fake), y_fake) 189 | self.dis_loss = 0.5 * tf.losses.sigmoid_cross_entropy(tf.ones_like(y_real), y_real) + \ 190 | 0.5 * tf.losses.sigmoid_cross_entropy(tf.zeros_like(y_fake), y_fake) 191 | 192 | gen_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 193 | dis_optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 194 | 195 | gen_train_op = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables) 196 | dis_train_op = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables) 197 | 198 | self.gen_acc = binary_accuracy(tf.ones_like(y_fake), y_fake) 199 | self.dis_acc = 0.5 * binary_accuracy(tf.ones_like(y_real), y_real) + \ 200 | 0.5 * binary_accuracy(tf.zeros_like(y_fake), y_fake) 201 | 202 | with tf.control_dependencies([gen_train_op, dis_train_op] + \ 203 | self.f_dis.update_ops + \ 204 | self.f_gen.update_ops): 205 | self.train_op = tf.no_op(name='train') 206 | 207 | # Predictor 208 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 209 | self.x_test = self.f_gen(self.z_test, training=False) 210 | 211 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 212 | 213 | tf.summary.image('x_real', image_cast(self.x_train), 10) 214 | tf.summary.image('x_fake', image_cast(x_fake), 10) 215 | tf.summary.image('x_tile', image_cast(x_tile), 1) 216 | tf.summary.scalar('gen_loss', self.gen_loss) 217 | tf.summary.scalar('dis_loss', self.dis_loss) 218 | tf.summary.scalar('gen_acc', self.gen_acc) 219 | tf.summary.scalar('dis_acc', self.dis_acc) 220 | self.summary = tf.summary.merge_all() 221 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def image_cast(img): 4 | return tf.cast(img * 127.5 + 127.5, tf.uint8) 5 | 6 | def kl_loss(avg, log_var): 7 | with tf.name_scope('KLLoss'): 8 | return tf.reduce_mean(-0.5 * tf.reduce_sum(1.0 + log_var - tf.square(avg) - tf.exp(log_var), axis=-1)) 9 | 10 | def lrelu(x, alpha=0.02): 11 | with tf.name_scope('LeakyReLU'): 12 | return tf.maximum(x, alpha * x) 13 | 14 | def binary_accuracy(y_true, y_pred): 15 | with tf.name_scope('BinaryAccuracy'): 16 | return tf.reduce_mean(tf.cast(tf.equal(y_true, tf.round(tf.sigmoid(y_pred))), dtype=tf.float32)) 17 | 18 | def sample_normal(avg, log_var): 19 | with tf.name_scope('SampleNormal'): 20 | epsilon = tf.random_normal(tf.shape(avg)) 21 | return tf.add(avg, tf.multiply(tf.exp(0.5 * log_var), epsilon)) 22 | 23 | def vgg_conv_unit(x, filters, layers, training=True): 24 | # Convolution 25 | for i in range(layers): 26 | x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same', 27 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 28 | x = tf.layers.batch_normalization(x, training=training) 29 | x = lrelu(x) 30 | 31 | # Downsample 32 | x = tf.layers.conv2d(x, filters, (2, 2), (2, 2), 'same', 33 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 34 | x = tf.layers.batch_normalization(x, training=training) 35 | x = lrelu(x) 36 | 37 | return x 38 | 39 | def vgg_deconv_unit(x, filters, layers, training=True): 40 | # Upsample 41 | x = tf.layers.conv2d_transpose(x, filters, (2, 2), (2, 2), 'same', 42 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 43 | x = tf.layers.batch_normalization(x, training=training) 44 | x = lrelu(x) 45 | 46 | # Convolution 47 | for i in range(layers): 48 | x = tf.layers.conv2d(x, filters, (3, 3), (1, 1), 'same', 49 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 50 | x = tf.layers.batch_normalization(x, training=training) 51 | x = lrelu(x) 52 | 53 | return x 54 | 55 | def time_format(t): 56 | m, s = divmod(t, 60) 57 | m = int(m) 58 | s = int(s) 59 | if m == 0: 60 | return '%d sec' % s 61 | else: 62 | return '%d min %d sec' % (m, s) 63 | -------------------------------------------------------------------------------- /models/vae.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from .base import BaseModel 5 | from .utils import * 6 | from .wnorm import * 7 | 8 | class Encoder(object): 9 | def __init__(self, input_shape, z_dims, use_wnorm=True): 10 | self.variables = None 11 | self.update_ops = None 12 | self.reuse = False 13 | self.input_shape = input_shape 14 | self.z_dims = z_dims 15 | self.use_wnorm = use_wnorm 16 | 17 | def __call__(self, inputs, training=True): 18 | with tf.variable_scope('encoder', reuse=self.reuse): 19 | with tf.variable_scope('conv1'): 20 | if self.use_wnorm: 21 | x = conv2d_wnorm(inputs, 64, (5, 5), (2, 2), 'same', use_scale=True) 22 | x = tf.layers.batch_normalization(x, scale=False, training=training) 23 | else: 24 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same') 25 | x = tf.layers.batch_normalization(x, training=training) 26 | x = lrelu(x) 27 | 28 | with tf.variable_scope('conv2'): 29 | if self.use_wnorm: 30 | x = conv2d_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True) 31 | x = tf.layers.batch_normalization(x, scale=False, training=training) 32 | else: 33 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same') 34 | x = tf.layers.batch_normalization(x, training=training) 35 | x = lrelu(x) 36 | 37 | with tf.variable_scope('conv3'): 38 | if self.use_wnorm: 39 | x = conv2d_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True) 40 | x = tf.layers.batch_normalization(x, scale=False, training=training) 41 | else: 42 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same') 43 | x = tf.layers.batch_normalization(x, training=training) 44 | x = lrelu(x) 45 | 46 | with tf.variable_scope('conv4'): 47 | if self.use_wnorm: 48 | x = conv2d_wnorm(x, 512, (5, 5), (2, 2), 'same', use_scale=True) 49 | x = tf.layers.batch_normalization(x, scale=False, training=training) 50 | else: 51 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same') 52 | x = tf.layers.batch_normalization(x, training=training) 53 | x = lrelu(x) 54 | 55 | with tf.variable_scope('fc1'): 56 | w = self.input_shape[0] // (2 ** 4) 57 | if self.use_wnorm: 58 | z_avg = conv2d_wnorm(x, self.z_dims, (w, w), (1, 1), 'valid', use_scale=True) 59 | z_log_var = conv2d_wnorm(x, self.z_dims, (w, w), (1, 1), 'valid', use_scale=True) 60 | else: 61 | z_avg = tf.layers.conv2d(x, self.z_dims, (w, w), (1, 1), 'valid') 62 | z_log_var = tf.layers.conv2d(x, self.z_dims, (w, w), (1, 1), 'valid') 63 | 64 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='encoder') 65 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='encoder') 66 | self.reuse = True 67 | 68 | return z_avg, z_log_var 69 | 70 | class Decoder(object): 71 | def __init__(self, input_shape, z_dims, use_wnorm=True): 72 | self.variables = None 73 | self.update_ops = None 74 | self.reuse = False 75 | self.input_shape = input_shape 76 | self.z_dims = z_dims 77 | self.use_wnorm = use_wnorm 78 | 79 | def __call__(self, inputs, training=True): 80 | with tf.variable_scope('decoder', reuse=self.reuse): 81 | with tf.variable_scope('deconv1'): 82 | w = self.input_shape[0] // (2 ** 3) 83 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 84 | if self.use_wnorm: 85 | x = conv2d_transpose_wnorm(x, 256, (w, w), (1, 1), 'valid', use_scale=True) 86 | x = tf.layers.batch_normalization(x, scale=False, training=training) 87 | else: 88 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 'valid') 89 | x = tf.layers.batch_normalization(x, training=training) 90 | x = tf.nn.relu(x) 91 | 92 | with tf.variable_scope('deconv2'): 93 | if self.use_wnorm: 94 | x = conv2d_transpose_wnorm(x, 256, (5, 5), (2, 2), 'same', use_scale=True) 95 | x = tf.layers.batch_normalization(x, scale=False, training=training) 96 | else: 97 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same') 98 | x = tf.layers.batch_normalization(x, training=training) 99 | x = tf.nn.relu(x) 100 | 101 | with tf.variable_scope('deconv3'): 102 | if self.use_wnorm: 103 | x = conv2d_transpose_wnorm(x, 128, (5, 5), (2, 2), 'same', use_scale=True) 104 | x = tf.layers.batch_normalization(x, scale=False, training=training) 105 | else: 106 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same') 107 | x = tf.layers.batch_normalization(x, training=training) 108 | x = tf.nn.relu(x) 109 | 110 | with tf.variable_scope('deconv4'): 111 | if self.use_wnorm: 112 | x = conv2d_transpose_wnorm(x, 64, (5, 5), (2, 2), 'same', use_scale=True) 113 | x = tf.layers.batch_normalization(x, scale=False, training=training) 114 | else: 115 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same') 116 | x = tf.layers.batch_normalization(x, training=training) 117 | x = tf.nn.relu(x) 118 | 119 | with tf.variable_scope('deconv5'): 120 | d = self.input_shape[2] 121 | if self.use_wnorm: 122 | x = conv2d_transpose_wnorm(x, d, (5, 5), (1, 1), 'same', use_scale=True) 123 | else: 124 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same') 125 | x = tf.tanh(x) 126 | 127 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder') 128 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='decoder') 129 | self.reuse = True 130 | 131 | return x 132 | 133 | class VAE(BaseModel): 134 | def __init__(self, 135 | input_shape=(64, 64, 3), 136 | z_dims = 128, 137 | name='vae', 138 | **kwargs 139 | ): 140 | super(VAE, self).__init__(input_shape=input_shape, name=name, **kwargs) 141 | 142 | self.z_dims = z_dims 143 | self.use_wnorm = False 144 | 145 | self.encoder = None 146 | self.decoder = None 147 | self.rec_loss = None 148 | self.kl_loss = None 149 | self.train_op = None 150 | 151 | self.x_train = None 152 | 153 | self.z_test = None 154 | self.x_test = None 155 | 156 | self.build_model() 157 | 158 | def train_on_batch(self, x_batch, index): 159 | _, rec_loss, kl_loss, summary = self.sess.run( 160 | (self.train_op, self.rec_loss, self.kl_loss, self.summary), 161 | feed_dict={self.x_train: x_batch, self.z_test: self.test_data} 162 | ) 163 | self.writer.add_summary(summary, index) 164 | return [ ('rec_loss', rec_loss), ('kl_loss', kl_loss) ] 165 | 166 | def predict(self, z_samples): 167 | x_sample = self.sess.run( 168 | self.x_test, 169 | feed_dict={self.z_test: z_samples} 170 | ) 171 | return x_sample 172 | 173 | def make_test_data(self): 174 | self.test_data = np.random.normal(size=(self.test_size * self.test_size, self.z_dims)) 175 | 176 | def build_model(self): 177 | self.encoder = Encoder(self.input_shape, self.z_dims, self.use_wnorm) 178 | self.decoder = Decoder(self.input_shape, self.z_dims, self.use_wnorm) 179 | 180 | # Trainer 181 | batch_shape = (None,) + self.input_shape 182 | self.x_train = tf.placeholder(tf.float32, shape=batch_shape) 183 | 184 | z_avg, z_log_var = self.encoder(self.x_train) 185 | z_sample = sample_normal(z_avg, z_log_var) 186 | x_sample = self.decoder(z_sample) 187 | 188 | rec_loss_scale = tf.constant(np.prod(self.input_shape), tf.float32) 189 | self.rec_loss = tf.losses.absolute_difference(self.x_train, x_sample) * rec_loss_scale 190 | self.kl_loss = kl_loss(z_avg, z_log_var) 191 | 192 | optim = tf.train.AdamOptimizer(learning_rate=2.0e-4, beta1=0.5) 193 | fmin = optim.minimize(self.rec_loss + self.kl_loss) 194 | 195 | with tf.control_dependencies([fmin] + self.encoder.update_ops + self.decoder.update_ops): 196 | self.train_op = tf.no_op(name='train') 197 | 198 | # Predictor 199 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 200 | self.x_test = self.decoder(self.z_test) 201 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 202 | 203 | # Summary 204 | tf.summary.image('x_real', image_cast(self.x_train), 10) 205 | tf.summary.image('x_fake', image_cast(x_sample), 10) 206 | tf.summary.image('x_tile', image_cast(x_tile), 1) 207 | tf.summary.scalar('rec_loss', self.rec_loss) 208 | tf.summary.scalar('kl_loss', self.kl_loss) 209 | 210 | self.summary = tf.summary.merge_all() 211 | -------------------------------------------------------------------------------- /models/wgan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wasserstain GAN: 3 | This is an implementation of "improved" version of Wasserstein GAN. 4 | Gulrajani et al., "Improved Training of Wasserstein GAN", arXiv preprint, 2017. 5 | """ 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from .base import BaseModel 11 | from .utils import * 12 | 13 | class Generator(object): 14 | def __init__(self, input_shape, z_dims): 15 | self.variables = None 16 | self.update_ops = None 17 | self.reuse = False 18 | self.name = 'generator' 19 | self.input_shape = input_shape 20 | self.z_dims = z_dims 21 | 22 | def __call__(self, inputs, training=True): 23 | with tf.variable_scope(self.name, reuse=self.reuse): 24 | with tf.variable_scope('fc1'): 25 | w = self.input_shape[0] // (2 ** 3) 26 | x = tf.reshape(inputs, [-1, 1, 1, self.z_dims]) 27 | x = tf.layers.conv2d_transpose(x, 256, (w, w), (1, 1), 'valid', 28 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 29 | x = tf.layers.batch_normalization(x, training=training) 30 | x = tf.nn.relu(x) 31 | 32 | with tf.variable_scope('conv1'): 33 | x = tf.layers.conv2d_transpose(x, 256, (5, 5), (2, 2), 'same', 34 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 35 | x = tf.layers.batch_normalization(x, training=training) 36 | x = tf.nn.relu(x) 37 | 38 | with tf.variable_scope('conv2'): 39 | x = tf.layers.conv2d_transpose(x, 128, (5, 5), (2, 2), 'same', 40 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 41 | x = tf.layers.batch_normalization(x, training=training) 42 | x = tf.nn.relu(x) 43 | 44 | with tf.variable_scope('conv3'): 45 | x = tf.layers.conv2d_transpose(x, 64, (5, 5), (2, 2), 'same', 46 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 47 | x = tf.layers.batch_normalization(x, training=training) 48 | x = tf.nn.relu(x) 49 | 50 | with tf.variable_scope('conv4'): 51 | d = self.input_shape[2] 52 | x = tf.layers.conv2d_transpose(x, d, (5, 5), (1, 1), 'same', 53 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 54 | x = tf.tanh(x) 55 | 56 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 57 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name) 58 | self.reuse = True 59 | return x 60 | 61 | class Discriminator(object): 62 | def __init__(self, input_shape): 63 | self.input_shape = input_shape 64 | self.variables = None 65 | self.update_ops = None 66 | self.name = 'discriminator' 67 | self.reuse = False 68 | 69 | def __call__(self, inputs, training=True): 70 | with tf.variable_scope(self.name, reuse=self.reuse): 71 | with tf.variable_scope('conv1'): 72 | x = tf.layers.conv2d(inputs, 64, (5, 5), (2, 2), 'same', 73 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 74 | x = tf.layers.batch_normalization(x, training=training) 75 | x = lrelu(x) 76 | 77 | with tf.variable_scope('conv2'): 78 | x = tf.layers.conv2d(x, 128, (5, 5), (2, 2), 'same', 79 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 80 | x = tf.layers.batch_normalization(x, training=training) 81 | x = lrelu(x) 82 | 83 | with tf.variable_scope('conv3'): 84 | x = tf.layers.conv2d(x, 256, (5, 5), (2, 2), 'same', 85 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 86 | x = tf.layers.batch_normalization(x, training=training) 87 | x = lrelu(x) 88 | 89 | with tf.variable_scope('conv4'): 90 | x = tf.layers.conv2d(x, 512, (5, 5), (2, 2), 'same', 91 | kernel_initializer=tf.contrib.layers.xavier_initializer()) 92 | x = tf.layers.batch_normalization(x, training=training) 93 | x = lrelu(x) 94 | 95 | with tf.variable_scope('conv5'): 96 | w = self.input_shape[0] // (2 ** 4) 97 | x = tf.layers.conv2d(x, 1, (w, w), (1, 1), 'valid', 98 | kernel_initializer=tf.random_normal_initializer(stddev=0.005)) 99 | y = tf.reshape(x, [-1, 1]) 100 | 101 | self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 102 | self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.name) 103 | self.reuse = True 104 | return y 105 | 106 | class WGAN(BaseModel): 107 | def __init__(self, 108 | input_shape=(64, 64, 3), 109 | z_dims = 128, 110 | name='wgan', 111 | **kwargs 112 | ): 113 | super(WGAN, self).__init__(input_shape=input_shape, name=name, **kwargs) 114 | 115 | self.z_dims = z_dims 116 | self.n_critic = 2 117 | self.lmbda = 10.0 118 | 119 | self.gen_loss = None 120 | self.dis_loss = None 121 | self.gen_train_op = None 122 | self.dis_train_op = None 123 | 124 | self.x_train = None 125 | self.e_random = None 126 | self.batch_idx = 0 127 | 128 | self.z_test = None 129 | self.x_test = None 130 | 131 | self.build_model() 132 | 133 | def train_on_batch(self, x_batch, index): 134 | batchsize = x_batch.shape[0] 135 | self.batch_idx += 1 136 | 137 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 138 | eps = float(np.random.uniform(0.0, 1.0, size=(1))) 139 | _, g_loss, d_loss = self.sess.run( 140 | (self.dis_train_op, self.gen_loss, self.dis_loss), 141 | feed_dict={ 142 | self.x_train: x_batch, 143 | self.z_train: z_sample, 144 | self.e_random: eps 145 | } 146 | ) 147 | 148 | if self.batch_idx % self.n_critic == 0: 149 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 150 | eps = float(np.random.uniform(0.0, 1.0, size=(1))) 151 | _, g_loss, d_loss = self.sess.run( 152 | (self.gen_train_op, self.gen_loss, self.dis_loss), 153 | feed_dict={ 154 | self.x_train: x_batch, 155 | self.z_train: z_sample, 156 | self.e_random: eps, 157 | self.z_test: self.test_data 158 | } 159 | ) 160 | 161 | # Summary update 162 | summary_priod = 1000 163 | if index // summary_priod != (index - batchsize) // summary_priod: 164 | z_sample = np.random.uniform(-1.0, 1.0, size=(batchsize, self.z_dims)) 165 | eps = float(np.random.uniform(0.0, 1.0, size=(1))) 166 | summary = self.sess.run( 167 | self.summary, 168 | feed_dict={ 169 | self.x_train: x_batch, 170 | self.z_train: z_sample, 171 | self.e_random: eps, 172 | self.z_test: self.test_data 173 | } 174 | ) 175 | self.writer.add_summary(summary, index) 176 | 177 | return [ 178 | ('g_loss', g_loss), ('d_loss', d_loss) 179 | ] 180 | 181 | def predict(self, z_samples): 182 | x_sample = self.sess.run( 183 | self.x_test, 184 | feed_dict={self.z_test: z_samples} 185 | ) 186 | return x_sample 187 | 188 | def make_test_data(self): 189 | self.test_data = np.random.uniform(-1, 1, size=(self.test_size * self.test_size, self.z_dims)) 190 | 191 | def build_model(self): 192 | # Trainer 193 | self.f_dis = Discriminator(self.input_shape) 194 | self.f_gen = Generator(self.input_shape, self.z_dims) 195 | 196 | x_shape = (None,) + self.input_shape 197 | z_shape = (None,) + (self.z_dims,) 198 | self.x_train = tf.placeholder(tf.float32, shape=x_shape) 199 | self.z_train = tf.placeholder(tf.float32, shape=z_shape) 200 | self.e_random = tf.placeholder(tf.float32, shape=()) 201 | 202 | x_fake = self.f_gen(self.z_train) 203 | y_fake = self.f_dis(x_fake) 204 | y_real = self.f_dis(self.x_train) 205 | 206 | gen_optim = tf.train.AdamOptimizer(learning_rate=1.0e-4, beta1=0.0, beta2=0.9) 207 | dis_optim = tf.train.AdamOptimizer(learning_rate=1.0e-4, beta1=0.0, beta2=0.9) 208 | 209 | x_hat = self.e_random * self.x_train + (1.0 - self.e_random) * x_fake 210 | y_hat = self.f_dis(x_hat) 211 | d_grad = tf.gradients(y_hat, [x_hat]) 212 | d_reg = tf.square(1.0 - tf.sqrt(tf.reduce_sum(tf.square(d_grad)))) 213 | 214 | self.gen_loss = -tf.reduce_mean(y_fake) 215 | self.dis_loss = -tf.reduce_mean(y_real) + tf.reduce_mean(y_fake) + self.lmbda * d_reg 216 | 217 | gen_optim_min = gen_optim.minimize(self.gen_loss, var_list=self.f_gen.variables) 218 | with tf.control_dependencies([gen_optim_min] + self.f_gen.update_ops): 219 | self.gen_train_op = tf.no_op(name='gen_train') 220 | 221 | dis_optim_min = dis_optim.minimize(self.dis_loss, var_list=self.f_dis.variables) 222 | 223 | with tf.control_dependencies([dis_optim_min] + self.f_dis.update_ops): 224 | self.dis_train_op = tf.no_op(name='dis_train') 225 | 226 | # Predictor 227 | self.z_test = tf.placeholder(tf.float32, shape=(None, self.z_dims)) 228 | self.x_test = self.f_gen(self.z_test, training=False) 229 | 230 | x_tile = self.image_tiling(self.x_test, self.test_size, self.test_size) 231 | 232 | tf.summary.image('x_real', image_cast(self.x_train), 10) 233 | tf.summary.image('x_fake', image_cast(x_fake), 10) 234 | tf.summary.image('x_tile', image_cast(x_tile), 1) 235 | tf.summary.scalar('gen_loss', self.gen_loss) 236 | tf.summary.scalar('dis_loss', self.dis_loss) 237 | self.summary = tf.summary.merge_all() 238 | -------------------------------------------------------------------------------- /models/wnorm.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.eager import context 3 | from tensorflow.python.framework import ops, tensor_shape 4 | from tensorflow.python.layers import base, utils 5 | from tensorflow.python.ops import nn, standard_ops, array_ops, init_ops, nn_ops 6 | 7 | class DenseWNorm(base.Layer): 8 | def __init__(self, units, 9 | activation=None, 10 | use_scale=True, 11 | use_bias=True, 12 | kernel_initializer=None, 13 | scale_initializer=None, 14 | bias_initializer=init_ops.zeros_initializer(), 15 | kernel_regularizer=None, 16 | scale_regularizer=None, 17 | bias_regularizer=None, 18 | activity_regularizer=None, 19 | kernel_constraint=None, 20 | scale_constraint=None, 21 | bias_constraint=None, 22 | trainable=True, 23 | name=None, 24 | **kwargs): 25 | super(DenseWNorm, self).__init__(trainable=trainable, name=name, 26 | activity_regularizer=activity_regularizer, 27 | **kwargs) 28 | self.units = units 29 | self.activation = activation 30 | self.use_scale = use_scale 31 | self.use_bias = use_bias 32 | self.kernel_initializer = kernel_initializer 33 | self.scale_initializer = scale_initializer 34 | self.bias_initializer = bias_initializer 35 | self.kernel_regularizer = kernel_regularizer 36 | self.scale_regularizer = scale_regularizer 37 | self.bias_regularizer = bias_regularizer 38 | self.kernel_constraint = kernel_constraint 39 | self.scale_constraint = scale_constraint 40 | self.bias_constraint = bias_constraint 41 | self.input_spec = base.InputSpec(min_ndim=2) 42 | 43 | def build(self, input_shape): 44 | input_shape = tensor_shape.TensorShape(input_shape) 45 | if input_shape[-1].value is None: 46 | raise ValueError('The last dimension of the inputs to `Dense` ' 47 | 'should be defined. Found `None`.') 48 | self.input_spec = base.InputSpec(min_ndim=2, 49 | axes={-1: input_shape[-1].value}) 50 | self.kernel = self.add_variable('kernel', 51 | shape=[input_shape[-1].value, self.units], 52 | initializer=self.kernel_initializer, 53 | regularizer=self.kernel_regularizer, 54 | constraint=self.kernel_constraint, 55 | dtype=self.dtype, 56 | trainable=True) 57 | 58 | if self.use_scale: 59 | self.scale = self.add_variable('scale', 60 | shape=[self.units,], 61 | initializer=self.scale_initializer, 62 | regularizer=self.scale_regularizer, 63 | constraint=self.scale_constraint, 64 | dtype=self.dtype, 65 | trainable=True) 66 | else: 67 | self.scale = 1.0 68 | 69 | if self.use_bias: 70 | self.bias = self.add_variable('bias', 71 | shape=[self.units,], 72 | initializer=self.bias_initializer, 73 | regularizer=self.bias_regularizer, 74 | constraint=self.bias_constraint, 75 | dtype=self.dtype, 76 | trainable=True) 77 | else: 78 | self.bias = None 79 | self.built = True 80 | 81 | def call(self, inputs): 82 | inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) 83 | shape = inputs.get_shape().as_list() 84 | 85 | if len(shape) > 2: 86 | # Broadcasting is required for the inputs. 87 | outputs = standard_ops.tensordot(inputs, self.kernel, [[len(shape) - 1], 88 | [0]]) 89 | # Reshape the output back to the original ndim of the input. 90 | if context.in_graph_mode(): 91 | output_shape = shape[:-1] + [self.units] 92 | outputs.set_shape(output_shape) 93 | else: 94 | outputs = standard_ops.matmul(inputs, self.kernel) 95 | 96 | scaler = self.scale / tf.sqrt(tf.reduce_sum(tf.square(self.kernel), [0])) 97 | outputs = scaler * outputs 98 | 99 | if self.use_bias: 100 | outputs = nn.bias_add(outputs, self.bias) 101 | if self.activation is not None: 102 | return self.activation(outputs) # pylint: disable=not-callable 103 | return outputs 104 | 105 | def _compute_output_shape(self, input_shape): 106 | input_shape = tensor_shape.TensorShape(input_shape) 107 | input_shape = input_shape.with_rank_at_least(2) 108 | if input_shape[-1].value is None: 109 | raise ValueError( 110 | 'The innermost dimension of input_shape must be defined, but saw: %s' 111 | % input_shape) 112 | return input_shape[:-1].concatenate(self.units) 113 | 114 | 115 | def dense_wnorm( 116 | inputs, units, 117 | activation=None, 118 | use_scale=True, 119 | use_bias=True, 120 | kernel_initializer=None, 121 | scale_initializer=None, 122 | bias_initializer=init_ops.zeros_initializer(), 123 | kernel_regularizer=None, 124 | scale_regularizer=None, 125 | bias_regularizer=None, 126 | activity_regularizer=None, 127 | kernel_constraint=None, 128 | scale_constraint=None, 129 | bias_constraint=None, 130 | trainable=True, 131 | name=None, 132 | reuse=None): 133 | 134 | layer = DenseWNorm(units, 135 | activation=activation, 136 | use_scale=use_scale, 137 | use_bias=use_bias, 138 | kernel_initializer=kernel_initializer, 139 | scale_initializer=scale_initializer, 140 | bias_initializer=bias_initializer, 141 | kernel_regularizer=kernel_regularizer, 142 | scale_regularizer=scale_regularizer, 143 | bias_regularizer=bias_regularizer, 144 | activity_regularizer=activity_regularizer, 145 | kernel_constraint=kernel_constraint, 146 | scale_constraint=scale_constraint, 147 | bias_constraint=bias_constraint, 148 | trainable=trainable, 149 | name=name, 150 | dtype=inputs.dtype.base_dtype, 151 | _scope=name, 152 | _reuse=reuse) 153 | return layer.apply(inputs) 154 | 155 | 156 | class _ConvWNorm(base.Layer): 157 | def __init__(self, rank, 158 | filters, 159 | kernel_size, 160 | strides=1, 161 | padding='valid', 162 | data_format='channels_last', 163 | dilation_rate=1, 164 | activation=None, 165 | use_scale=True, 166 | use_bias=True, 167 | kernel_initializer=None, 168 | scale_initializer=None, 169 | bias_initializer=init_ops.zeros_initializer(), 170 | scale_regularizer=None, 171 | kernel_regularizer=None, 172 | bias_regularizer=None, 173 | activity_regularizer=None, 174 | kernel_constraint=None, 175 | scale_constraint=None, 176 | bias_constraint=None, 177 | trainable=True, 178 | name=None, 179 | **kwargs): 180 | super(_ConvWNorm, self).__init__(trainable=trainable, name=name, 181 | activity_regularizer=activity_regularizer, 182 | **kwargs) 183 | self.rank = rank 184 | self.filters = filters 185 | self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size') 186 | self.strides = utils.normalize_tuple(strides, rank, 'strides') 187 | self.padding = utils.normalize_padding(padding) 188 | self.data_format = utils.normalize_data_format(data_format) 189 | self.dilation_rate = utils.normalize_tuple( 190 | dilation_rate, rank, 'dilation_rate') 191 | self.activation = activation 192 | self.use_scale = use_scale 193 | self.use_bias = use_bias 194 | self.kernel_initializer = kernel_initializer 195 | self.scale_initializer = scale_initializer 196 | self.bias_initializer = bias_initializer 197 | self.kernel_regularizer = kernel_regularizer 198 | self.scale_regularizer = scale_regularizer 199 | self.bias_regularizer = bias_regularizer 200 | self.kernel_constraint = kernel_constraint 201 | self.scale_constraint = scale_constraint 202 | self.bias_constraint = bias_constraint 203 | self.input_spec = base.InputSpec(ndim=self.rank + 2) 204 | 205 | def build(self, input_shape): 206 | input_shape = tensor_shape.TensorShape(input_shape) 207 | if self.data_format == 'channels_first': 208 | channel_axis = 1 209 | else: 210 | channel_axis = -1 211 | if input_shape[channel_axis].value is None: 212 | raise ValueError('The channel dimension of the inputs ' 213 | 'should be defined. Found `None`.') 214 | input_dim = input_shape[channel_axis].value 215 | kernel_shape = self.kernel_size + (input_dim, self.filters) 216 | 217 | self.kernel = self.add_variable(name='kernel', 218 | shape=kernel_shape, 219 | initializer=self.kernel_initializer, 220 | regularizer=self.kernel_regularizer, 221 | constraint=self.kernel_constraint, 222 | trainable=True, 223 | dtype=self.dtype) 224 | 225 | if self.use_scale: 226 | self.scale = self.add_variable(name='scale', 227 | shape=(self.filters,), 228 | initializer=self.scale_initializer, 229 | regularizer=self.scale_regularizer, 230 | constraint=self.scale_constraint, 231 | trainable=True, 232 | dtype=self.dtype) 233 | else: 234 | self.scale = None 235 | 236 | if self.use_bias: 237 | self.bias = self.add_variable(name='bias', 238 | shape=(self.filters,), 239 | initializer=self.bias_initializer, 240 | regularizer=self.bias_regularizer, 241 | constraint=self.bias_constraint, 242 | trainable=True, 243 | dtype=self.dtype) 244 | else: 245 | self.bias = None 246 | 247 | self.input_spec = base.InputSpec(ndim=self.rank + 2, 248 | axes={channel_axis: input_dim}) 249 | 250 | self._convolution_op = nn_ops.Convolution( 251 | input_shape, 252 | filter_shape=self.kernel.get_shape(), 253 | dilation_rate=self.dilation_rate, 254 | strides=self.strides, 255 | padding=self.padding.upper(), 256 | data_format=utils.convert_data_format(self.data_format, 257 | self.rank + 2)) 258 | self.built = True 259 | 260 | def call(self, inputs): 261 | kernel_norm = nn.l2_normalize(self.kernel, [0, 1, 2]) 262 | if self.use_scale: 263 | kernel_norm = tf.reshape(self.scale, [1, 1, 1, self.filters]) * kernel_norm 264 | outputs = self._convolution_op(inputs, kernel_norm) 265 | 266 | if self.use_bias: 267 | if self.data_format == 'channels_first': 268 | if self.rank == 1: 269 | # nn.bias_add does not accept a 1D input tensor. 270 | bias = array_ops.reshape(self.bias, (1, self.filters, 1)) 271 | outputs += bias 272 | if self.rank == 2: 273 | outputs = nn.bias_add(outputs, self.bias, data_format='NCHW') 274 | if self.rank == 3: 275 | # As of Mar 2017, direct addition is significantly slower than 276 | # bias_add when computing gradients. To use bias_add, we collapse Z 277 | # and Y into a single dimension to obtain a 4D input tensor. 278 | outputs_shape = outputs.shape.as_list() 279 | outputs_4d = array_ops.reshape(outputs, 280 | [outputs_shape[0], outputs_shape[1], 281 | outputs_shape[2] * outputs_shape[3], 282 | outputs_shape[4]]) 283 | outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW') 284 | outputs = array_ops.reshape(outputs_4d, outputs_shape) 285 | else: 286 | outputs = nn.bias_add(outputs, self.bias, data_format='NHWC') 287 | 288 | if self.activation is not None: 289 | return self.activation(outputs) 290 | return outputs 291 | 292 | def _compute_output_shape(self, input_shape): 293 | input_shape = tensor_shape.TensorShape(input_shape).as_list() 294 | if self.data_format == 'channels_last': 295 | space = input_shape[1:-1] 296 | new_space = [] 297 | for i in range(len(space)): 298 | new_dim = utils.conv_output_length( 299 | space[i], 300 | self.kernel_size[i], 301 | padding=self.padding, 302 | stride=self.strides[i], 303 | dilation=self.dilation_rate[i]) 304 | new_space.append(new_dim) 305 | return tensor_shape.TensorShape([input_shape[0]] + new_space + 306 | [self.filters]) 307 | else: 308 | space = input_shape[2:] 309 | new_space = [] 310 | for i in range(len(space)): 311 | new_dim = utils.conv_output_length( 312 | space[i], 313 | self.kernel_size[i], 314 | padding=self.padding, 315 | stride=self.strides[i], 316 | dilation=self.dilation_rate[i]) 317 | new_space.append(new_dim) 318 | return tensor_shape.TensorShape([input_shape[0], self.filters] + new_space) 319 | 320 | class Conv2DWNorm(_ConvWNorm): 321 | def __init__(self, filters, 322 | kernel_size, 323 | strides=(1, 1), 324 | padding='valid', 325 | data_format='channels_last', 326 | dilation_rate=(1, 1), 327 | activation=None, 328 | use_scale=True, 329 | use_bias=True, 330 | kernel_initializer=None, 331 | scale_initializer=None, 332 | bias_initializer=init_ops.zeros_initializer(), 333 | kernel_regularizer=None, 334 | scale_regularizer=None, 335 | bias_regularizer=None, 336 | activity_regularizer=None, 337 | kernel_constraint=None, 338 | scale_constraint=None, 339 | bias_constraint=None, 340 | trainable=True, 341 | name=None, 342 | **kwargs): 343 | super(Conv2DWNorm, self).__init__( 344 | rank=2, 345 | filters=filters, 346 | kernel_size=kernel_size, 347 | strides=strides, 348 | padding=padding, 349 | data_format=data_format, 350 | dilation_rate=dilation_rate, 351 | activation=activation, 352 | use_scale=use_scale, 353 | use_bias=use_bias, 354 | kernel_initializer=kernel_initializer, 355 | scale_initializer=scale_initializer, 356 | bias_initializer=bias_initializer, 357 | kernel_regularizer=kernel_regularizer, 358 | scale_regularizer=scale_regularizer, 359 | bias_regularizer=bias_regularizer, 360 | activity_regularizer=activity_regularizer, 361 | kernel_constraint=kernel_constraint, 362 | scale_constraint=scale_constraint, 363 | bias_constraint=bias_constraint, 364 | trainable=trainable, 365 | name=name, **kwargs) 366 | 367 | 368 | def conv2d_wnorm(inputs, 369 | filters, 370 | kernel_size, 371 | strides=(1, 1), 372 | padding='valid', 373 | data_format='channels_last', 374 | dilation_rate=(1, 1), 375 | activation=None, 376 | use_scale=True, 377 | use_bias=True, 378 | kernel_initializer=None, 379 | scale_initializer=None, 380 | bias_initializer=init_ops.zeros_initializer(), 381 | kernel_regularizer=None, 382 | scale_regularizer=None, 383 | bias_regularizer=None, 384 | activity_regularizer=None, 385 | kernel_constraint=None, 386 | scale_constraint=None, 387 | bias_constraint=None, 388 | trainable=True, 389 | name=None, 390 | reuse=None): 391 | 392 | layer = Conv2DWNorm( 393 | filters=filters, 394 | kernel_size=kernel_size, 395 | strides=strides, 396 | padding=padding, 397 | data_format=data_format, 398 | dilation_rate=dilation_rate, 399 | activation=activation, 400 | use_scale=use_scale, 401 | use_bias=use_bias, 402 | kernel_initializer=kernel_initializer, 403 | scale_initializer=scale_initializer, 404 | bias_initializer=bias_initializer, 405 | kernel_regularizer=kernel_regularizer, 406 | scale_regularizer=scale_regularizer, 407 | bias_regularizer=bias_regularizer, 408 | activity_regularizer=activity_regularizer, 409 | kernel_constraint=kernel_constraint, 410 | scale_constraint=scale_constraint, 411 | bias_constraint=bias_constraint, 412 | trainable=trainable, 413 | name=name, 414 | dtype=inputs.dtype.base_dtype, 415 | _reuse=reuse, 416 | _scope=name) 417 | return layer.apply(inputs) 418 | 419 | class Conv2DTransposeWNorm(Conv2DWNorm): 420 | def __init__(self, filters, 421 | kernel_size, 422 | strides=(1, 1), 423 | padding='valid', 424 | data_format='channels_last', 425 | activation=None, 426 | use_scale=True, 427 | use_bias=True, 428 | kernel_initializer=None, 429 | scale_initializer=None, 430 | bias_initializer=init_ops.zeros_initializer(), 431 | kernel_regularizer=None, 432 | scale_regularizer=None, 433 | bias_regularizer=None, 434 | activity_regularizer=None, 435 | kernel_constraint=None, 436 | scale_constraint=None, 437 | bias_constraint=None, 438 | trainable=True, 439 | name=None, 440 | **kwargs): 441 | super(Conv2DTransposeWNorm, self).__init__( 442 | filters, 443 | kernel_size, 444 | strides=strides, 445 | padding=padding, 446 | data_format=data_format, 447 | activation=activation, 448 | use_scale=use_scale, 449 | use_bias=use_bias, 450 | kernel_initializer=kernel_initializer, 451 | scale_initializer=scale_initializer, 452 | bias_initializer=bias_initializer, 453 | kernel_regularizer=kernel_regularizer, 454 | scale_regularizer=scale_regularizer, 455 | bias_regularizer=bias_regularizer, 456 | activity_regularizer=activity_regularizer, 457 | kernel_constraint=kernel_constraint, 458 | scale_constraint=scale_constraint, 459 | bias_constraint=bias_constraint, 460 | trainable=trainable, 461 | name=name, 462 | **kwargs) 463 | self.input_spec = base.InputSpec(ndim=4) 464 | 465 | def build(self, input_shape): 466 | if len(input_shape) != 4: 467 | raise ValueError('Inputs should have rank ' + 468 | str(4) + 469 | 'Received input shape:', str(input_shape)) 470 | if self.data_format == 'channels_first': 471 | channel_axis = 1 472 | else: 473 | channel_axis = -1 474 | if input_shape[channel_axis] is None: 475 | raise ValueError('The channel dimension of the inputs ' 476 | 'should be defined. Found `None`.') 477 | input_dim = input_shape[channel_axis] 478 | self.input_spec = base.InputSpec(ndim=4, axes={channel_axis: input_dim}) 479 | kernel_shape = self.kernel_size + (self.filters, input_dim) 480 | 481 | self.kernel = self.add_variable(name='kernel', 482 | shape=kernel_shape, 483 | initializer=self.kernel_initializer, 484 | regularizer=self.kernel_regularizer, 485 | constraint=self.kernel_constraint, 486 | trainable=True, 487 | dtype=self.dtype) 488 | 489 | if self.use_scale: 490 | self.scale = self.add_variable(name='scale', 491 | shape=(self.filters,), 492 | initializer=self.scale_initializer, 493 | regularizer=self.scale_regularizer, 494 | constraint=self.scale_constraint, 495 | trainable=True, 496 | dtype=self.dtype) 497 | else: 498 | self.scale = None 499 | 500 | 501 | if self.use_bias: 502 | self.bias = self.add_variable(name='bias', 503 | shape=(self.filters,), 504 | initializer=self.bias_initializer, 505 | regularizer=self.bias_regularizer, 506 | constraint=self.bias_constraint, 507 | trainable=True, 508 | dtype=self.dtype) 509 | else: 510 | self.bias = None 511 | self.built = True 512 | 513 | def call(self, inputs): 514 | inputs_shape = array_ops.shape(inputs) 515 | batch_size = inputs_shape[0] 516 | if self.data_format == 'channels_first': 517 | c_axis, h_axis, w_axis = 1, 2, 3 518 | else: 519 | c_axis, h_axis, w_axis = 3, 1, 2 520 | 521 | height, width = inputs_shape[h_axis], inputs_shape[w_axis] 522 | kernel_h, kernel_w = self.kernel_size 523 | stride_h, stride_w = self.strides 524 | 525 | # Infer the dynamic output shape: 526 | out_height = utils.deconv_output_length(height, 527 | kernel_h, 528 | self.padding, 529 | stride_h) 530 | out_width = utils.deconv_output_length(width, 531 | kernel_w, 532 | self.padding, 533 | stride_w) 534 | if self.data_format == 'channels_first': 535 | output_shape = (batch_size, self.filters, out_height, out_width) 536 | strides = (1, 1, stride_h, stride_w) 537 | else: 538 | output_shape = (batch_size, out_height, out_width, self.filters) 539 | strides = (1, stride_h, stride_w, 1) 540 | 541 | output_shape_tensor = array_ops.stack(output_shape) 542 | 543 | kernel_norm = nn.l2_normalize(self.kernel, [0, 1, 3]) 544 | if self.use_scale: 545 | kernel_norm = tf.reshape(self.scale, [1, 1, self.filters, 1]) * kernel_norm 546 | 547 | outputs = nn.conv2d_transpose( 548 | inputs, 549 | kernel_norm, 550 | output_shape_tensor, 551 | strides, 552 | padding=self.padding.upper(), 553 | data_format=utils.convert_data_format(self.data_format, ndim=4)) 554 | 555 | if context.in_graph_mode(): 556 | # Infer the static output shape: 557 | out_shape = inputs.get_shape().as_list() 558 | out_shape[c_axis] = self.filters 559 | out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis], 560 | kernel_h, 561 | self.padding, 562 | stride_h) 563 | out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis], 564 | kernel_w, 565 | self.padding, 566 | stride_w) 567 | outputs.set_shape(out_shape) 568 | 569 | if self.use_bias: 570 | outputs = nn.bias_add( 571 | outputs, 572 | self.bias, 573 | data_format=utils.convert_data_format(self.data_format, ndim=4)) 574 | 575 | if self.activation is not None: 576 | return self.activation(outputs) 577 | return outputs 578 | 579 | def _compute_output_shape(self, input_shape): 580 | input_shape = tensor_shape.TensorShape(input_shape).as_list() 581 | output_shape = list(input_shape) 582 | if self.data_format == 'channels_first': 583 | c_axis, h_axis, w_axis = 1, 2, 3 584 | else: 585 | c_axis, h_axis, w_axis = 3, 1, 2 586 | 587 | kernel_h, kernel_w = self.kernel_size 588 | stride_h, stride_w = self.strides 589 | 590 | output_shape[c_axis] = self.filters 591 | output_shape[h_axis] = utils.deconv_output_length( 592 | output_shape[h_axis], kernel_h, self.padding, stride_h) 593 | output_shape[w_axis] = utils.deconv_output_length( 594 | output_shape[w_axis], kernel_w, self.padding, stride_w) 595 | return tensor_shape.TensorShape(output_shape) 596 | 597 | 598 | def conv2d_transpose_wnorm( 599 | inputs, 600 | filters, 601 | kernel_size, 602 | strides=(1, 1), 603 | padding='valid', 604 | data_format='channels_last', 605 | activation=None, 606 | use_scale=True, 607 | use_bias=True, 608 | kernel_initializer=None, 609 | scale_initializer=None, 610 | bias_initializer=init_ops.zeros_initializer(), 611 | kernel_regularizer=None, 612 | scale_regularizer=None, 613 | bias_regularizer=None, 614 | activity_regularizer=None, 615 | kernel_constraint=None, 616 | scale_constraint=None, 617 | bias_constraint=None, 618 | trainable=True, 619 | name=None, 620 | reuse=None): 621 | layer = Conv2DTransposeWNorm( 622 | filters=filters, 623 | kernel_size=kernel_size, 624 | strides=strides, 625 | padding=padding, 626 | data_format=data_format, 627 | activation=activation, 628 | use_scale=use_scale, 629 | use_bias=use_bias, 630 | kernel_initializer=kernel_initializer, 631 | scale_initializer=scale_initializer, 632 | bias_initializer=bias_initializer, 633 | kernel_regularizer=kernel_regularizer, 634 | scale_regularizer=scale_regularizer, 635 | bias_regularizer=bias_regularizer, 636 | activity_regularizer=activity_regularizer, 637 | kernel_constraint=kernel_constraint, 638 | scale_constraint=scale_constraint, 639 | bias_constraint=bias_constraint, 640 | trainable=trainable, 641 | name=name, 642 | dtype=inputs.dtype.base_dtype, 643 | _reuse=reuse, 644 | _scope=name) 645 | return layer.apply(inputs) 646 | -------------------------------------------------------------------------------- /results/svhn_cvaegan_epoch_0050_batch_73257.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tatsy/tf-generative/5d7fe9e8a84d0d6f82553fc1eb32c4fdadd0d1b2/results/svhn_cvaegan_epoch_0050_batch_73257.png -------------------------------------------------------------------------------- /results/svhn_dcgan_epoch_0050_batch_73257.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tatsy/tf-generative/5d7fe9e8a84d0d6f82553fc1eb32c4fdadd0d1b2/results/svhn_dcgan_epoch_0050_batch_73257.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 5 | 6 | import numpy as np 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | 10 | import tensorflow as tf 11 | 12 | from models import * 13 | from datasets import load_data, mnist, svhn 14 | 15 | models = { 16 | 'vae': VAE, 17 | 'dcgan': DCGAN, 18 | 'improved': ImprovedGAN, 19 | 'resnet': ResNetGAN, 20 | 'began': BEGAN, 21 | 'wgan': WGAN, 22 | 'lsgan': LSGAN, 23 | 'cvae': CVAE, 24 | 'cvaegan': CVAEGAN 25 | } 26 | 27 | def main(_): 28 | # Parsing arguments 29 | parser = argparse.ArgumentParser(description='Training GANs or VAEs') 30 | parser.add_argument('--model', type=str, required=True) 31 | parser.add_argument('--dataset', type=str, required=True) 32 | parser.add_argument('--datasize', type=int, default=-1) 33 | parser.add_argument('--epoch', type=int, default=200) 34 | parser.add_argument('--batchsize', type=int, default=50) 35 | parser.add_argument('--output', default='output') 36 | parser.add_argument('--zdims', type=int, default=256) 37 | parser.add_argument('--gpu', type=int, default=0) 38 | parser.add_argument('--resume', type=str, default=None) 39 | parser.add_argument('--testmode', action='store_true') 40 | 41 | args = parser.parse_args() 42 | 43 | # select gpu 44 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 45 | 46 | # Make output direcotiry if not exists 47 | if not os.path.isdir(args.output): 48 | os.mkdir(args.output) 49 | 50 | # Load datasets 51 | if args.dataset == 'mnist': 52 | datasets = mnist.load_data() 53 | elif args.dataset == 'svhn': 54 | datasets = svhn.load_data() 55 | else: 56 | datasets = load_data(args.dataset, args.datasize) 57 | 58 | # Construct model 59 | if args.model not in models: 60 | raise Exception('Unknown model:', args.model) 61 | 62 | model = models[args.model]( 63 | batchsize=args.batchsize, 64 | input_shape=datasets.shape[1:], 65 | attr_names=None or datasets.attr_names, 66 | z_dims=args.zdims, 67 | output=args.output, 68 | resume=args.resume 69 | ) 70 | 71 | if args.testmode: 72 | model.test_mode = True 73 | 74 | tf.set_random_seed(12345) 75 | 76 | # Training loop 77 | datasets.images = datasets.images.astype('float32') * 2.0 - 1.0 78 | model.main_loop(datasets, 79 | epochs=args.epoch) 80 | 81 | if __name__ == '__main__': 82 | tf.app.run(main) --------------------------------------------------------------------------------