├── .gitignore ├── README.md ├── bgan ├── __init__.py ├── component.py ├── mnist │ ├── __init__.py │ ├── components.py │ ├── models.py │ └── presets │ │ ├── __init__.py │ │ ├── discriminator │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── cnn_bn.py │ │ ├── mlp.py │ │ └── mlp_bn.py │ │ └── generator │ │ ├── __init__.py │ │ ├── cnn_bernoulli.py │ │ ├── cnn_real.py │ │ ├── cnn_round.py │ │ ├── mlp_bernoulli.py │ │ ├── mlp_real.py │ │ └── mlp_round.py ├── model.py └── utils │ ├── __init__.py │ ├── image_io.py │ ├── neuralnet.py │ └── ops.py ├── config.py ├── docs ├── _config.yml ├── _includes │ ├── audio_player.html │ ├── icon_link.html │ └── video_player.html ├── _layouts │ └── default.html ├── _sass │ └── jekyll-theme-minimal.scss ├── abstract.md ├── background.md ├── data.md ├── figs │ ├── cnn_dbn.png │ ├── cnn_sbn.png │ ├── colored_dbn.png │ ├── colored_sbn.png │ ├── colormap.png │ ├── formula_dbn.png │ ├── formula_sbn.png │ ├── histogram.png │ ├── mlp_dbn.png │ ├── mlp_dbn_gan.png │ ├── mlp_dbn_wgan.png │ ├── mlp_real.png │ ├── mlp_real_bernoulli.png │ ├── mlp_real_round.png │ ├── mlp_sbn.png │ ├── mlp_sbn_gan.png │ ├── mlp_sbn_wgan.png │ ├── system.png │ └── train.png ├── index.md ├── model.md ├── paper.md ├── pdf │ ├── binarygan-arxiv-paper.pdf │ └── binarygan-arxiv-slides.pdf └── results.md ├── train.py └── training_data ├── README.md ├── download_mnist.sh └── load_mnist_to_sa.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller 
builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # vscode 104 | .vscode/* 105 | 106 | # Experiments 107 | .backup/ 108 | .recycle/ 109 | analysis/ 110 | data/ 111 | exp/ 112 | logs/ 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BinaryGAN 2 | 3 | ## Prepare Training Data 4 | 5 | - Download MNIST database by running the script: 6 | 7 | ```sh 8 | ./training_data/download_mnist.sh 9 | ``` 10 | 11 | - or download it manually: 12 | 1. Download MNIST database [here](http://yann.lecun.com/exdb/mnist/) 13 | 2. Decompress all the `.gz` files 14 | 3. Move the decompressed files to `./training_data/mnist` 15 | 16 | - Store the data to shared memory (optional) 17 | 18 | > Make sure the SharedArray package has been installed. 
19 | 20 | ```sh 21 | python ./training_data/load_mnist_to_sa.py ./training_data/mnist/ \ 22 | --merge --binary 23 | ``` 24 | 25 | ## Configuration 26 | 27 | Modify `config.py` for configuration. 28 | 29 | - Quick setup 30 | 31 | Change the values in the dictionary `SETUP` for a quick setup. Documentation 32 | is provided right after each key. 33 | 34 | - More configuration options 35 | 36 | Four dictionaries `EXP_CONFIG`, `DATA_CONFIG`, `MODEL_CONFIG` and 37 | `TRAIN_CONFIG` define experiment-, data-, model- and training-related 38 | configuration variables, respectively. 39 | 40 | > The automatically-determined experiment name is based only on the values 41 | defined in the dictionary `SETUP`, so remember to provide the experiment name 42 | manually when you modify any other configuration variables so that you won't 43 | overwrite a trained model. 44 | 45 | ## Train the model 46 | 47 | ```sh 48 | python train.py 49 | ``` 50 | -------------------------------------------------------------------------------- /bgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/bgan/__init__.py -------------------------------------------------------------------------------- /bgan/component.py: -------------------------------------------------------------------------------- 1 | """Base class for the components.""" 2 | from collections import OrderedDict 3 | import tensorflow as tf 4 | from bgan.utils.neuralnet import NeuralNet 5 | 6 | class Component(object): 7 | """Base class for components.""" 8 | def __init__(self, tensor_in, condition, slope_tensor=None): 9 | if not isinstance(tensor_in, (tf.Tensor, list, dict)): 10 | raise TypeError("`tensor_in` must be of tf.Tensor type or a list " 11 | "(or dict) of tf.Tensor objects") 12 | if isinstance(tensor_in, list): 13 | for tensor in tensor_in: 14 | if not isinstance(tensor, tf.Tensor): 15 | raise 
TypeError("`tensor_in` must be of tf.Tensor type or " 16 | "a list (or dict) of tf.Tensor objects") 17 | if isinstance(tensor_in, dict): 18 | for key in tensor_in: 19 | if not isinstance(tensor_in[key], tf.Tensor): 20 | raise TypeError("`tensor_in` must be of tf.Tensor type or " 21 | "a list (or dict) of tf.Tensor objects") 22 | 23 | self.tensor_in = tensor_in 24 | self.condition = condition 25 | self.slope_tensor = slope_tensor 26 | 27 | self.scope = None 28 | self.tensor_out = tensor_in 29 | self.nets = OrderedDict() 30 | self.vars = None 31 | 32 | def __repr__(self): 33 | if isinstance(self.tensor_in, tf.Tensor): 34 | input_shape = self.tensor_in.get_shape() 35 | else: 36 | input_shape = ', '.join([ 37 | '{}: {}'.format(key, self.tensor_in[key].get_shape()) 38 | for key in self.tensor_in]) 39 | return "Component({}, input_shape={}, output_shape={})".format( 40 | self.scope.name, input_shape, str(self.tensor_out.get_shape())) 41 | 42 | def get_summary(self): 43 | """Return the summary string.""" 44 | cleansed_nets = [] 45 | for net in self.nets.values(): 46 | if isinstance(net, NeuralNet): 47 | if net.scope is not None: 48 | cleansed_nets.append(net) 49 | if isinstance(net, list): 50 | if net[0].scope is not None: 51 | cleansed_nets.append(net[0]) 52 | 53 | if isinstance(self.tensor_in, tf.Tensor): 54 | input_strs = ["{:50}{}".format('Input', self.tensor_in.get_shape())] 55 | else: 56 | input_strs = ["{:50}{}".format('Input - ' + key, 57 | self.tensor_in[key].get_shape()) 58 | for key in self.tensor_in] 59 | 60 | return '\n'.join( 61 | ["{:-^80}".format(' ' + self.scope.name + ' ')] + input_strs 62 | + ['-' * 80 + '\n' + x.get_summary() for x in cleansed_nets] 63 | ) 64 | -------------------------------------------------------------------------------- /bgan/mnist/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/bgan/mnist/__init__.py -------------------------------------------------------------------------------- /bgan/mnist/components.py: -------------------------------------------------------------------------------- 1 | """Classes that define the generator, discriminator and the end-to-end 2 | generator.""" 3 | from collections import OrderedDict 4 | import tensorflow as tf 5 | from bgan.component import Component 6 | from bgan.utils.neuralnet import NeuralNet 7 | 8 | class End2EndGenerator(Component): 9 | """Class that defines the end-to-end generator.""" 10 | def __init__(self, tensor_in, config, condition=None, slope_tensor=None, 11 | name='End2EndGenerator', reuse=None): 12 | super().__init__(tensor_in, condition, slope_tensor) 13 | with tf.variable_scope(name, reuse=reuse) as scope: 14 | self.scope = scope 15 | self.tensor_out, self.nets, self.preactivated = self.build(config) 16 | self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 17 | self.scope.name) 18 | 19 | def build(self, config): 20 | """Build the end-to-end generator.""" 21 | nets = OrderedDict() 22 | 23 | nets['main'] = NeuralNet(self.tensor_in, config['net_g']['main'], 24 | name='main') 25 | 26 | if config['net_g']['main'][-1][0] == 'reshape': 27 | preactivated = tf.reshape(nets['main'].layers[-2].preactivated, 28 | (-1, 28, 28, 1)) 29 | else: 30 | preactivated = tf.reshape(nets['main'].layers[-1].preactivated, 31 | (-1, 28, 28, 1)) 32 | 33 | return nets['main'].tensor_out, nets, preactivated 34 | 35 | class Generator(Component): 36 | """Class that defines the generator.""" 37 | def __init__(self, tensor_in, config, condition=None, name='Generator', 38 | reuse=None): 39 | super().__init__(tensor_in, condition) 40 | with tf.variable_scope(name, reuse=reuse) as scope: 41 | self.scope = scope 42 | self.tensor_out, self.nets = self.build(config) 43 | self.vars = 
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 44 | self.scope.name) 45 | 46 | def build(self, config): 47 | """Build the generator.""" 48 | nets = OrderedDict() 49 | 50 | nets['main'] = NeuralNet(self.tensor_in, config['net_g']['main'], 51 | name='main') 52 | 53 | return nets['main'].tensor_out, nets 54 | 55 | class Discriminator(Component): 56 | """Class that defines the discriminator.""" 57 | def __init__(self, tensor_in, config, condition=None, name='Discriminator', 58 | reuse=None): 59 | super().__init__(tensor_in, condition) 60 | with tf.variable_scope(name, reuse=reuse) as scope: 61 | self.scope = scope 62 | self.tensor_out, self.nets = self.build(config) 63 | self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 64 | self.scope.name) 65 | 66 | def build(self, config): 67 | """Build the discriminator.""" 68 | nets = OrderedDict() 69 | 70 | nets['main'] = NeuralNet(self.tensor_in, config['net_d']['main'], 71 | name='main') 72 | 73 | return nets['main'].tensor_out, nets 74 | -------------------------------------------------------------------------------- /bgan/mnist/models.py: -------------------------------------------------------------------------------- 1 | """Classes that define the proposed and the real-valued models.""" 2 | import os 3 | import time 4 | import numpy as np 5 | import tensorflow as tf 6 | from bgan.model import Model 7 | from bgan.mnist.components import End2EndGenerator, Discriminator, Generator 8 | 9 | class BinaryGAN(Model): 10 | """Class that defines the end-to-end model.""" 11 | def __init__(self, sess, config, name='BinaryGAN', reuse=None): 12 | super().__init__(sess, config, name) 13 | 14 | print('[*] Building BinaryGAN...') 15 | with tf.variable_scope(name, reuse=reuse) as scope: 16 | self.scope = scope 17 | self.build() 18 | 19 | def build(self): 20 | """Build the model.""" 21 | self.global_step = tf.Variable(0, trainable=False, name='global_step') 22 | 23 | # Create placeholders 24 | self.z = tf.placeholder( 25 | 
tf.float32, 26 | (self.config['batch_size'], self.config['net_g']['z_dim']), 'z' 27 | ) 28 | data_shape = (self.config['batch_size'], self.config['out_height'], 29 | self.config['out_width'], self.config['out_channel']) 30 | self.x = tf.placeholder(tf.bool, data_shape, 'x') 31 | self.x_ = tf.cast(self.x, tf.float32, 'x_') 32 | 33 | # Slope tensor for applying slope annealing trick to stochastic neurons 34 | self.slope_tensor = tf.Variable(1.0) 35 | 36 | # Components 37 | self.G = End2EndGenerator(self.z, self.config, 38 | slope_tensor=self.slope_tensor, name='G') 39 | self.D_fake = Discriminator(self.G.tensor_out, self.config, name='D') 40 | self.D_real = Discriminator(self.x_, self.config, name='D', reuse=True) 41 | self.components = (self.G, self.D_fake) 42 | 43 | # Losses 44 | self.g_loss, self.d_loss = self.get_adversarial_loss(Discriminator) 45 | 46 | # Optimizers 47 | with tf.variable_scope('Optimizer'): 48 | self.g_optimizer = self.get_optimizer() 49 | self.g_step = self.g_optimizer.minimize( 50 | self.g_loss, self.global_step, self.G.vars) 51 | 52 | self.d_optimizer = self.get_optimizer() 53 | self.d_step = self.d_optimizer.minimize( 54 | self.d_loss, self.global_step, self.D_fake.vars) 55 | 56 | # Apply weight clipping 57 | if self.config['gan']['type'] == 'wgan': 58 | with tf.control_dependencies([self.d_step]): 59 | self.d_step = tf.group( 60 | *(tf.assign(var, tf.clip_by_value( 61 | var, -self.config['gan']['clip_value'], 62 | self.config['gan']['clip_value'])) 63 | for var in self.D_fake.vars)) 64 | 65 | # Metrics 66 | # self.metrics = Metrics(self.config) 67 | 68 | # Saver 69 | self.saver = tf.train.Saver() 70 | 71 | # Print and save model information 72 | self.print_statistics() 73 | self.save_statistics() 74 | self.print_summary() 75 | self.save_summary() 76 | 77 | def train(self, x_train, train_config): 78 | """Train the model.""" 79 | # Initialize sampler 80 | self.z_sample = np.random.normal( 81 | size=(self.config['batch_size'], 
self.config['net_g']['z_dim'])) 82 | self.x_sample = x_train[np.random.choice( 83 | len(x_train), self.config['batch_size'], False)] 84 | feed_dict_sample = {self.x: self.x_sample, self.z: self.z_sample} 85 | 86 | # Save samples 87 | self.save_samples('x_train', x_train) 88 | self.save_samples('x_sample', self.x_sample) 89 | 90 | # Open log files and write headers 91 | log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') 92 | log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') 93 | log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') 94 | log_step.write('# epoch, step, negative_critic_loss\n') 95 | log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') 96 | log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') 97 | 98 | # Define slope annealing op 99 | if train_config['slope_annealing_rate'] != 1.: 100 | slope_annealing_op = tf.assign( 101 | self.slope_tensor, 102 | self.slope_tensor * train_config['slope_annealing_rate']) 103 | 104 | # Initialize counter 105 | counter = 0 106 | epoch_counter = 0 107 | num_batch = len(x_train) // self.config['batch_size'] 108 | 109 | # Start epoch iteration 110 | print('{:=^80}'.format(' Training Start ')) 111 | for epoch in range(train_config['num_epoch']): 112 | 113 | print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) 114 | epoch_start_time = time.time() 115 | 116 | # Prepare batched training data 117 | z_random_batch = np.random.normal( 118 | size=(num_batch, self.config['batch_size'], 119 | self.config['net_g']['z_dim']) 120 | ) 121 | x_random_batch = np.random.choice( 122 | len(x_train), (num_batch, self.config['batch_size']), False) 123 | 124 | # Start batch iteration 125 | for batch in range(num_batch): 126 | 127 | feed_dict_batch = {self.x: x_train[x_random_batch[batch]], 128 | self.z: z_random_batch[batch]} 129 | 130 | if (counter < 25) or (counter % 500 == 0): 131 | num_critics = 100 132 | else: 133 | num_critics = 5 134 | 135 | 
batch_start_time = time.time() 136 | 137 | # Update networks 138 | for _ in range(num_critics): 139 | _, d_loss = self.sess.run([self.d_step, self.d_loss], 140 | feed_dict_batch) 141 | log_step.write("{}, {:14.6f}\n".format( 142 | self.get_global_step_str(), -d_loss 143 | )) 144 | 145 | _, d_loss, g_loss = self.sess.run( 146 | [self.g_step, self.d_loss, self.g_loss], feed_dict_batch 147 | ) 148 | log_step.write("{}, {:14.6f}\n".format( 149 | self.get_global_step_str(), -d_loss 150 | )) 151 | 152 | time_batch = time.time() - batch_start_time 153 | 154 | # Print iteration summary 155 | if train_config['verbose']: 156 | if batch < 1: 157 | print("epoch | batch | time | - D_loss |" 158 | " G_loss") 159 | print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " 160 | "{:14.6f}".format(epoch, batch, num_batch, time_batch, 161 | -d_loss, g_loss)) 162 | 163 | log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( 164 | epoch, batch, time_batch, -d_loss, g_loss 165 | )) 166 | 167 | # run sampler 168 | if train_config['sample_along_training']: 169 | if counter%500 == 0 or (counter < 300 and counter%100 == 0): 170 | self.run_sampler(self.G.tensor_out, feed_dict_sample) 171 | self.run_sampler(self.G.preactivated, feed_dict_sample, 172 | postfix='preactivated') 173 | 174 | # # run evaluation 175 | # if train_config['evaluate_along_training']: 176 | # if counter%10 == 0: 177 | # self.run_eval(self.G.tensor_out, feed_dict_sample) 178 | 179 | counter += 1 180 | 181 | # print epoch info 182 | time_epoch = time.time() - epoch_start_time 183 | 184 | if not train_config['verbose']: 185 | if epoch < 1: 186 | print("epoch | time | - D_loss | G_loss") 187 | print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( 188 | epoch, time_epoch, -d_loss, g_loss)) 189 | 190 | log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( 191 | epoch, time_epoch, -d_loss, g_loss 192 | )) 193 | 194 | # save checkpoints 195 | self.save() 196 | 197 | if train_config['slope_annealing_rate'] != 1.: 198 | 
self.sess.run(slope_annealing_op) 199 | 200 | epoch_counter += 1 201 | 202 | print('{:=^80}'.format(' Training End ')) 203 | log_step.close() 204 | log_batch.close() 205 | log_epoch.close() 206 | 207 | class GAN(Model): 208 | """Class that defines the end-to-end model.""" 209 | def __init__(self, sess, config, name='GAN', reuse=None): 210 | super().__init__(sess, config, name) 211 | 212 | print('[*] Building GAN...') 213 | with tf.variable_scope(name, reuse=reuse) as scope: 214 | self.scope = scope 215 | self.build() 216 | 217 | def build(self): 218 | """Build the model.""" 219 | self.global_step = tf.Variable(0, trainable=False, name='global_step') 220 | 221 | # Create placeholders 222 | self.z = tf.placeholder( 223 | tf.float32, 224 | (self.config['batch_size'], self.config['net_g']['z_dim']), 'z' 225 | ) 226 | data_shape = (self.config['batch_size'], self.config['out_height'], 227 | self.config['out_width'], self.config['out_channel']) 228 | self.x = tf.placeholder(tf.bool, data_shape, 'x') 229 | self.x_ = tf.cast(self.x, tf.float32, 'x_') 230 | 231 | # Slope tensor for applying slope annealing trick to stochastic neurons 232 | self.slope_tensor = tf.Variable(1.0) 233 | 234 | # Components 235 | self.G = Generator(self.z, self.config, name='G') 236 | self.test_round = self.G.tensor_out > 0.5 237 | self.test_bernoulli = self.G.tensor_out > tf.random_uniform(data_shape) 238 | 239 | self.D_fake = Discriminator(self.G.tensor_out, self.config, name='D') 240 | self.D_real = Discriminator(self.x_, self.config, name='D', reuse=True) 241 | self.components = (self.G, self.D_fake) 242 | 243 | # Losses 244 | self.g_loss, self.d_loss = self.get_adversarial_loss(Discriminator) 245 | 246 | # Optimizers 247 | with tf.variable_scope('Optimizer'): 248 | self.g_optimizer = self.get_optimizer() 249 | self.g_step = self.g_optimizer.minimize( 250 | self.g_loss, self.global_step, self.G.vars) 251 | 252 | self.d_optimizer = self.get_optimizer() 253 | self.d_step = 
self.d_optimizer.minimize( 254 | self.d_loss, self.global_step, self.D_fake.vars) 255 | 256 | # Apply weight clipping 257 | if self.config['gan']['type'] == 'wgan': 258 | with tf.control_dependencies([self.d_step]): 259 | self.d_step = tf.group( 260 | *(tf.assign(var, tf.clip_by_value( 261 | var, -self.config['gan']['clip_value'], 262 | self.config['gan']['clip_value'])) 263 | for var in self.D_fake.vars)) 264 | 265 | # Metrics 266 | # self.metrics = Metrics(self.config) 267 | 268 | # Saver 269 | self.saver = tf.train.Saver() 270 | 271 | # Print and save model information 272 | self.print_statistics() 273 | self.save_statistics() 274 | self.print_summary() 275 | self.save_summary() 276 | 277 | def train(self, x_train, train_config): 278 | """Train the model.""" 279 | # Initialize sampler 280 | self.z_sample = np.random.normal( 281 | size=(self.config['batch_size'], self.config['net_g']['z_dim'])) 282 | self.x_sample = x_train[np.random.choice( 283 | len(x_train), self.config['batch_size'], False)] 284 | feed_dict_sample = {self.x: self.x_sample, self.z: self.z_sample} 285 | 286 | # Save samples 287 | self.save_samples('x_train', x_train) 288 | self.save_samples('x_sample', self.x_sample) 289 | 290 | # Open log files and write headers 291 | log_step = open(os.path.join(self.config['log_dir'], 'step.log'), 'w') 292 | log_batch = open(os.path.join(self.config['log_dir'], 'batch.log'), 'w') 293 | log_epoch = open(os.path.join(self.config['log_dir'], 'epoch.log'), 'w') 294 | log_step.write('# epoch, step, negative_critic_loss\n') 295 | log_batch.write('# epoch, batch, time, negative_critic_loss, g_loss\n') 296 | log_epoch.write('# epoch, time, negative_critic_loss, g_loss\n') 297 | 298 | # Initialize counter 299 | counter = 0 300 | epoch_counter = 0 301 | num_batch = len(x_train) // self.config['batch_size'] 302 | 303 | # Start epoch iteration 304 | print('{:=^80}'.format(' Training Start ')) 305 | for epoch in range(train_config['num_epoch']): 306 | 307 | 
print('{:-^80}'.format(' Epoch {} Start '.format(epoch))) 308 | epoch_start_time = time.time() 309 | 310 | # Prepare batched training data 311 | z_random_batch = np.random.normal( 312 | size=(num_batch, self.config['batch_size'], 313 | self.config['net_g']['z_dim']) 314 | ) 315 | x_random_batch = np.random.choice( 316 | len(x_train), (num_batch, self.config['batch_size']), False) 317 | 318 | # Start batch iteration 319 | for batch in range(num_batch): 320 | 321 | feed_dict_batch = {self.x: x_train[x_random_batch[batch]], 322 | self.z: z_random_batch[batch]} 323 | 324 | if (counter < 25) or (counter % 500 == 0): 325 | num_critics = 100 326 | else: 327 | num_critics = 5 328 | 329 | batch_start_time = time.time() 330 | 331 | # Update networks 332 | for _ in range(num_critics): 333 | _, d_loss = self.sess.run([self.d_step, self.d_loss], 334 | feed_dict_batch) 335 | log_step.write("{}, {:14.6f}\n".format( 336 | self.get_global_step_str(), -d_loss 337 | )) 338 | 339 | _, d_loss, g_loss = self.sess.run( 340 | [self.g_step, self.d_loss, self.g_loss], feed_dict_batch 341 | ) 342 | log_step.write("{}, {:14.6f}\n".format( 343 | self.get_global_step_str(), -d_loss 344 | )) 345 | 346 | time_batch = time.time() - batch_start_time 347 | 348 | # Print iteration summary 349 | if train_config['verbose']: 350 | if batch < 1: 351 | print("epoch | batch | time | - D_loss |" 352 | " G_loss") 353 | print(" {:2d} | {:4d}/{:4d} | {:6.2f} | {:14.6f} | " 354 | "{:14.6f}".format(epoch, batch, num_batch, time_batch, 355 | -d_loss, g_loss)) 356 | 357 | log_batch.write("{:d}, {:d}, {:f}, {:f}, {:f}\n".format( 358 | epoch, batch, time_batch, -d_loss, g_loss 359 | )) 360 | 361 | # run sampler 362 | if train_config['sample_along_training']: 363 | if counter%500 == 0 or (counter < 300 and counter%100 == 0): 364 | self.run_sampler(self.G.tensor_out, feed_dict_sample) 365 | self.run_sampler(self.test_round, feed_dict_sample, 366 | postfix='test_round') 367 | self.run_sampler(self.test_bernoulli, 
feed_dict_sample, 368 | postfix='test_bernoulli') 369 | 370 | # # run evaluation 371 | # if train_config['evaluate_along_training']: 372 | # if counter%10 == 0: 373 | # self.run_eval(self.G.tensor_out, feed_dict_sample) 374 | 375 | counter += 1 376 | 377 | # print epoch info 378 | time_epoch = time.time() - epoch_start_time 379 | 380 | if not train_config['verbose']: 381 | if epoch < 1: 382 | print("epoch | time | - D_loss | G_loss") 383 | print(" {:2d} | {:8.2f} | {:14.6f} | {:14.6f}".format( 384 | epoch, time_epoch, -d_loss, g_loss)) 385 | 386 | log_epoch.write("{:d}, {:f}, {:f}, {:f}\n".format( 387 | epoch, time_epoch, -d_loss, g_loss 388 | )) 389 | 390 | # save checkpoints 391 | self.save() 392 | 393 | epoch_counter += 1 394 | 395 | print('{:=^80}'.format(' Training End ')) 396 | log_step.close() 397 | log_batch.close() 398 | log_epoch.close() 399 | -------------------------------------------------------------------------------- /bgan/mnist/presets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/bgan/mnist/presets/__init__.py -------------------------------------------------------------------------------- /bgan/mnist/presets/discriminator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/bgan/mnist/presets/discriminator/__init__.py -------------------------------------------------------------------------------- /bgan/mnist/presets/discriminator/cnn.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the discriminator for the real-valued model 2 | implemented by CNNs.""" 3 | NET_D = {} 4 | 5 | NET_D['main'] = [ 6 | ('conv2d', (32, (3, 3), (1, 1)), None, 'lrelu'), # 0 (26, 26) 7 | ('maxpool2d', ((2, 2), (2, 2))), # 1 (13, 
13) 8 | ('conv2d', (64, (3, 3), (1, 1)), None, 'lrelu'), # 2 (11, 11) 9 | ('maxpool2d', ((2, 2), (2, 2), 'same')), # 3 (6, 6) 10 | ('reshape', (64*6*6)), # 4 11 | ('dense', (128), None, 'lrelu'), # 5 12 | ('dense', 1), # 6 13 | ] 14 | -------------------------------------------------------------------------------- /bgan/mnist/presets/discriminator/cnn_bn.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the discriminator for the real-valued model 2 | implemented by CNNs.""" 3 | NET_D = {} 4 | 5 | NET_D['main'] = [ 6 | ('conv2d', (32, (3, 3), (1, 1)), 'bn', 'lrelu'), # 0 (26, 26) 7 | ('maxpool2d', ((2, 2), (2, 2))), # 1 (13, 13) 8 | ('conv2d', (64, (3, 3), (1, 1)), 'bn', 'lrelu'), # 2 (11, 11) 9 | ('maxpool2d', ((2, 2), (2, 2), 'same')), # 3 (6, 6) 10 | ('reshape', (64*6*6)), # 4 11 | ('dense', (128), 'bn', 'lrelu'), # 5 12 | ('dense', 1), # 6 13 | ] 14 | -------------------------------------------------------------------------------- /bgan/mnist/presets/discriminator/mlp.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the discriminator for the real-valued model 2 | implemented by MLPs.""" 3 | NET_D = {} 4 | 5 | NET_D['main'] = [ 6 | ('reshape', 784), # 0 7 | ('dense', 512, None, 'lrelu'), # 1 8 | ('dense', 256, None, 'lrelu'), # 2 9 | ('dense', 1), # 3 10 | ] 11 | -------------------------------------------------------------------------------- /bgan/mnist/presets/discriminator/mlp_bn.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the discriminator for the real-valued model 2 | implemented by MLPs.""" 3 | NET_D = {} 4 | 5 | NET_D['main'] = [ 6 | ('reshape', 784), # 0 7 | ('dense', 512, 'bn', 'lrelu'), # 1 8 | ('dense', 256, 'bn', 'lrelu'), # 2 9 | ('dense', 1), # 3 10 | ] 11 | -------------------------------------------------------------------------------- 
/bgan/mnist/presets/generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/bgan/mnist/presets/generator/__init__.py -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/cnn_bernoulli.py: -------------------------------------------------------------------------------- 1 | """Network architecture of the generator for the proposed model implemented by 2 | CNNs and SBNs.""" 3 | NET_G = {} 4 | 5 | NET_G['z_dim'] = 128 6 | 7 | NET_G['main'] = [ 8 | ('reshape', (1, 1, 128)), # 0 (1, 1) 9 | ('transconv2d', (128, (2, 2), (1, 1)), 'bn', 'relu'), # 1 (2, 2) 10 | ('transconv2d', (64, (4, 4), (2, 2)), 'bn', 'relu'), # 2 (6, 6) 11 | ('transconv2d', (32, (3, 3), (2, 2)), 'bn', 'relu'), # 3 (13, 13) 12 | ('transconv2d', (1, (4, 4), (2, 2)), 'bn', 'bernoulli'), # 4 (28, 28) 13 | ] 14 | -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/cnn_real.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the generator for the real-valued model implemented 2 | by CNNs.""" 3 | NET_G = {} 4 | 5 | NET_G['z_dim'] = 128 6 | 7 | NET_G['main'] = [ 8 | ('reshape', (1, 1, 128)), # 0 (1, 1) 9 | ('transconv2d', (128, (2, 2), (1, 1)), 'bn', 'relu'), # 1 (2, 2) 10 | ('transconv2d', (64, (4, 4), (2, 2)), 'bn', 'relu'), # 2 (6, 6) 11 | ('transconv2d', (32, (3, 3), (2, 2)), 'bn', 'relu'), # 3 (13, 13) 12 | ('transconv2d', (1, (4, 4), (2, 2)), 'bn', 'sigmoid'), # 4 (28, 28) 13 | ] 14 | -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/cnn_round.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the generator for the proposed model implemented by 2 | CNNs and 
DBNs. 3 | """ 4 | NET_G = {} 5 | 6 | NET_G['z_dim'] = 128 7 | 8 | NET_G['main'] = [ 9 | ('reshape', (1, 1, 128)), # 0 (1, 1) 10 | ('transconv2d', (128, (2, 2), (1, 1)), 'bn', 'relu'), # 1 (2, 2) 11 | ('transconv2d', (64, (4, 4), (2, 2)), 'bn', 'relu'), # 2 (6, 6) 12 | ('transconv2d', (32, (3, 3), (2, 2)), 'bn', 'relu'), # 3 (13, 13) 13 | ('transconv2d', (1, (4, 4), (2, 2)), 'bn', 'round'), # 4 (28, 28) 14 | ] 15 | -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/mlp_bernoulli.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the generator for the proposed model implemented by 2 | MLPs and SBNs. 3 | """ 4 | NET_G = {} 5 | 6 | NET_G['z_dim'] = 128 7 | 8 | NET_G['main'] = [ 9 | ('dense', (1024), 'bn', 'relu'), # 0 10 | ('dense', (784), 'bn', 'bernoulli'), # 1 11 | ('reshape', (28, 28, 1)), # 2 12 | ] 13 | -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/mlp_real.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the generator for the real-valued model implemented 2 | by MLPs.""" 3 | NET_G = {} 4 | 5 | NET_G['z_dim'] = 128 6 | 7 | NET_G['main'] = [ 8 | ('dense', (1024), 'bn', 'relu'), # 0 9 | ('dense', (784), 'bn', 'sigmoid'), # 1 10 | ('reshape', (28, 28, 1)), # 2 11 | ] 12 | -------------------------------------------------------------------------------- /bgan/mnist/presets/generator/mlp_round.py: -------------------------------------------------------------------------------- 1 | """Network architecture for the generator for the proposed model implemented by 2 | MLPs and DBNs. 
class Model(object):
    """Base class for models.

    Stores the TensorFlow session, the configuration dictionary and the
    placeholders for the pieces a concrete model is expected to fill in
    (`scope`, `global_step`, `x_`, `G`, `D_real`, `D_fake`, `components`,
    `metrics` and `saver`), and provides shared helpers for building losses
    and optimizers, reporting statistics, checkpointing and sampling.
    """
    def __init__(self, sess, config, name='model'):
        self.sess = sess
        self.name = name
        self.config = config

        # The attributes below are only declared here; this base class never
        # assigns them anything else, so subclasses must set them before the
        # helper methods that read them are called.
        self.scope = None
        self.global_step = None
        self.x_ = None
        self.G = None
        self.D_real = None
        self.D_fake = None
        self.components = []
        self.metrics = None
        self.saver = None

    def init_all(self):
        """Initialize all global variables under the model scope."""
        print('[*] Initializing variables...')
        tf.variables_initializer(tf.global_variables(self.scope.name)).run()

    def get_adversarial_loss(self, discriminator, scope_to_reuse=None):
        """Return the adversarial losses for the generator and the
        discriminator as a tuple `(adv_loss_g, adv_loss_d)`.

        The objective is selected by `config['gan']['type']`, which must be
        one of 'gan', 'wgan' or 'wgan-gp'. For 'wgan-gp', `discriminator` is
        called (with `reuse=True`) on points interpolated between `self.x_`
        and the generator output to compute the gradient penalty; if
        `scope_to_reuse` is given, that call is made inside that variable
        scope.

        Raises a ValueError if the GAN type is not supported.
        """
        gan_type = self.config['gan']['type']
        # Fail early: otherwise the losses would be left undefined and the
        # return statement would raise an UnboundLocalError.
        if gan_type not in ('gan', 'wgan', 'wgan-gp'):
            raise ValueError("Unknown GAN type: {}".format(gan_type))

        if gan_type == 'gan':
            d_loss_real = tf.losses.sigmoid_cross_entropy(
                tf.ones_like(self.D_real.tensor_out), self.D_real.tensor_out)
            d_loss_fake = tf.losses.sigmoid_cross_entropy(
                tf.zeros_like(self.D_fake.tensor_out), self.D_fake.tensor_out)

            adv_loss_d = d_loss_real + d_loss_fake
            adv_loss_g = tf.losses.sigmoid_cross_entropy(
                tf.ones_like(self.D_fake.tensor_out), self.D_fake.tensor_out)
        else:  # 'wgan' or 'wgan-gp'
            adv_loss_d = (tf.reduce_mean(self.D_fake.tensor_out)
                          - tf.reduce_mean(self.D_real.tensor_out))
            adv_loss_g = -tf.reduce_mean(self.D_fake.tensor_out)

        if gan_type == 'wgan-gp':
            # One interpolation coefficient per sample, broadcast over the
            # remaining three axes.
            eps = tf.random_uniform(
                [tf.shape(self.x_)[0], 1, 1, 1], 0.0, 1.0)
            inter = eps * self.x_ + (1. - eps) * self.G.tensor_out
            if scope_to_reuse is None:
                D_inter = discriminator(inter, self.config, name='D',
                                        reuse=True)
            else:
                with tf.variable_scope(scope_to_reuse, reuse=True):
                    D_inter = discriminator(inter, self.config, name='D',
                                            reuse=True)
            gradient = tf.gradients(D_inter.tensor_out, inter)[0]
            # Per-sample gradient norm; the 1e-8 keeps the square root
            # away from zero.
            slopes = tf.sqrt(1e-8 + tf.reduce_sum(
                tf.square(gradient),
                tf.range(1, len(gradient.get_shape()))))
            gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0))
            adv_loss_d += (self.config['gan']['gp_coefficient']
                           * gradient_penalty)

        return adv_loss_g, adv_loss_d

    def get_optimizer(self):
        """Return the optimizer selected by `config['optimizer']['type']`
        ('adam' or 'rmsprop'), built from the hyperparameters stored in
        `config['optimizer']`.

        Raises a ValueError if the optimizer type is not supported.
        """
        opt_config = self.config['optimizer']
        if opt_config['type'] == 'adam':
            return tf.train.AdamOptimizer(
                opt_config['lr'],
                opt_config['beta1'],
                opt_config['beta2'],
                opt_config['epsilon'])
        if opt_config['type'] == 'rmsprop':
            return tf.train.RMSPropOptimizer(
                opt_config['lr'],
                opt_config['decay'],
                opt_config['momentum'],
                opt_config['epsilon'])
        # Returning None here would only postpone the failure to the first
        # use of the optimizer.
        raise ValueError(
            "Unknown optimizer type: {}".format(opt_config['type']))

    def get_statistics(self):
        """Return model statistics (number of parameters for each component).
        """
        def get_num_parameter(var_list):
            """Given the variable list, return the total number of parameters.
            """
            # `np.prod` replaces the deprecated alias `np.product`, which is
            # no longer available in NumPy 2.0.
            return int(np.sum([np.prod([x.value for x in var.get_shape()])
                               for var in var_list]))
        num_par = get_num_parameter(tf.trainable_variables(
            self.scope.name))
        num_par_g = get_num_parameter(self.G.vars)
        num_par_d = get_num_parameter(self.D_fake.vars)
        return ("Number of parameters: {}\nNumber of parameters in G: {}\n"
                "Number of parameters in D: {}".format(num_par, num_par_g,
                                                       num_par_d))

    def get_summary(self):
        """Return the model summary: a header line with the scope name
        followed by the summary of every entry in `self.components`."""
        return '\n'.join(
            ["{:-^80}".format(' < ' + self.scope.name + ' > ')]
            + [(x.get_summary() + '\n' + '-' * 80) for x in self.components])

    def get_global_step_str(self):
        """Return the global step as a string."""
        return str(tf.train.global_step(self.sess, self.global_step))

    def print_statistics(self):
        """Print model statistics (number of parameters for each component).
        """
        print("{:=^80}".format(' Model Statistics '))
        print(self.get_statistics())

    def print_summary(self):
        """Print model summary."""
        print("{:=^80}".format(' Model Summary '))
        print(self.get_summary())

    def save_statistics(self, filepath=None):
        """Save model statistics to file. Default to save to
        'model_statistics.txt' in `config['log_dir']`."""
        if filepath is None:
            filepath = os.path.join(self.config['log_dir'],
                                    'model_statistics.txt')
        with open(filepath, 'w') as f:
            f.write(self.get_statistics())

    def save_summary(self, filepath=None):
        """Save model summary to file. Default to save to
        'model_summary.txt' in `config['log_dir']`."""
        if filepath is None:
            filepath = os.path.join(self.config['log_dir'], 'model_summary.txt')
        with open(filepath, 'w') as f:
            f.write(self.get_summary())

    def save(self, filepath=None):
        """Save the model to a checkpoint file. Default to save to
        '<name>.model' in `config['checkpoint_dir']`. The global step is
        passed to the saver along with the path."""
        if filepath is None:
            filepath = os.path.join(self.config['checkpoint_dir'],
                                    self.name + '.model')
        print('[*] Saving checkpoint...')
        self.saver.save(self.sess, filepath, self.global_step)

    def load(self, filepath):
        """Load the model from the checkpoint at `filepath`."""
        print('[*] Loading checkpoint...')
        self.saver.restore(self.sess, filepath)

    def load_latest(self, checkpoint_dir=None):
        """Load the model from the latest checkpoint in a directory. Default
        to look in `config['checkpoint_dir']`.

        Raises a ValueError if no checkpoint is found in the directory.
        """
        if checkpoint_dir is None:
            checkpoint_dir = self.config['checkpoint_dir']
        print('[*] Loading checkpoint...')
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
        if checkpoint_path is None:
            raise ValueError("Checkpoint not found")
        self.saver.restore(self.sess, checkpoint_path)

    def save_samples(self, filename, samples, shape=None, postfix=None):
        """Save samples to an image file in `config['sample_dir']`.

        At most `config['num_sample']` samples are kept. `shape` is the grid
        shape and defaults to `config['sample_grid']`. The file is named
        '<filename>.png', or '<filename>_<postfix>.png' if `postfix` is
        given.
        """
        if shape is None:
            shape = self.config['sample_grid']
        if len(samples) > self.config['num_sample']:
            samples = samples[:self.config['num_sample']]
        if postfix is None:
            imagepath = os.path.join(self.config['sample_dir'],
                                     '{}.png'.format(filename))
        else:
            imagepath = os.path.join(self.config['sample_dir'],
                                     '{}_{}.png'.format(filename, postfix))
        image_io.save_image(imagepath, samples, shape)

    def run_sampler(self, targets, feed_dict, postfix=None):
        """Run the target operation(s) with feed_dict and save the samples.

        The results of the targets are interleaved sample by sample, so the
        saved grid is `len(targets)` times as wide as
        `config['sample_grid']`. The file is named after the global step,
        with `postfix` appended if given.
        """
        if not isinstance(targets, list):
            targets = [targets]
        results = self.sess.run(targets, feed_dict)
        results = [result[:self.config['num_sample']] for result in results]
        # Stack along a new axis 1 and merge it into axis 0 so that the i-th
        # outputs of all targets end up next to each other.
        samples = np.stack(results, 1).reshape((-1,) + results[0].shape[1:])
        shape = [self.config['sample_grid'][0],
                 self.config['sample_grid'][1] * len(results)]
        if postfix is None:
            filename = self.get_global_step_str()
        else:
            filename = self.get_global_step_str() + '_' + postfix
        self.save_samples(filename, samples, shape)

    def run_eval(self, target, feed_dict, postfix=None):
        """Run evaluation: evaluate `target`, binarize the result with a
        threshold of zero and pass it to `self.metrics.eval`, which is also
        given a '.npy' path in `config['eval_dir']` named after the global
        step (with `postfix` appended if given)."""
        result = self.sess.run(target, feed_dict)
        binarized = (result > 0)
        if postfix is None:
            filename = self.get_global_step_str()
        else:
            filename = self.get_global_step_str() + '_' + postfix
        # NOTE(review): this merges the first two axes of the result;
        # confirm that this is the layout `self.metrics.eval` expects.
        reshaped = binarized.reshape((-1,) + binarized.shape[2:])
        mat_path = os.path.join(self.config['eval_dir'], filename+'.npy')
        _ = self.metrics.eval(reshaped, mat_path=mat_path)
def get_image_grid(images, shape, grid_width=0, grid_color=0,
                   frame=False):
    """
    Merge the input images and return a merged grid image.

    Arguments
    ---------
    images : np.array, ndim=3 or 4
        The image array. Shape is (num_image, height, width, (num_channel)).
        `num_image` must equal `shape[0] * shape[1]`.
    shape : list or tuple of int
        Shape of the image grid. (height, width)
    grid_width : int
        Width of the grid lines. Default to 0.
    grid_color : int
        Value used to fill the grid lines. Default to 0.
    frame : bool
        True to add frame. Default to False.

    Returns
    -------
    merged : np.array, ndim=3
        The merged grid image. The last axis is the channel axis (of size
        one if `images` has no channel axis).

    Raises
    ------
    ValueError
        If `images` is neither 3- nor 4-dimensional, or if the number of
        images does not fill the grid exactly.
    """
    if images.ndim == 3:
        num_channel = 1
    elif images.ndim == 4:
        num_channel = images.shape[3]
    else:
        raise ValueError("Number of dimension for `images` must be 3 or 4")

    num_row, num_col = shape[0], shape[1]
    # Check explicitly so the caller gets a meaningful message instead of
    # numpy's generic reshape error (same exception type).
    if images.shape[0] != num_row * num_col:
        raise ValueError(
            "Number of images ({}) does not match the grid shape {} x {}"
            .format(images.shape[0], num_row, num_col))

    reshaped = images.reshape(num_row, num_col, images.shape[1],
                              images.shape[2], num_channel)

    # Put a grid line above and to the left of every image.
    pad_width = ((0, 0), (0, 0), (grid_width, 0), (grid_width, 0), (0, 0))
    padded = np.pad(reshaped, pad_width, 'constant', constant_values=grid_color)

    # (row, col, h, w, c) -> (row, h, col, w, c), then merge the grid axes
    # into the pixel axes.
    transposed = padded.transpose(0, 2, 1, 3, 4)
    merged = transposed.reshape(num_row * (images.shape[1] + grid_width),
                                num_col * (images.shape[2] + grid_width),
                                num_channel)
    if frame:
        # The top and left lines are already there; close the frame by adding
        # the bottom and right ones.
        return np.pad(merged, ((0, grid_width), (0, grid_width), (0, 0)),
                      'constant', constant_values=grid_color)
    # Without a frame, drop the outermost top and left grid lines.
    return merged[grid_width:, grid_width:]
SUPPORTED_LAYER_TYPES = (
    'reshape', 'mean', 'sum', 'dense', 'identity', 'conv1d', 'conv2d', 'conv3d',
    'transconv2d', 'transconv3d', 'avgpool2d', 'avgpool3d', 'maxpool2d',
    'maxpool3d'
)

class Layer(object):
    """Base class for layers.

    A layer is described by a `structure` tuple whose first entry is the
    layer type (one of `SUPPORTED_LAYER_TYPES`), followed by the layer
    parameters and, for dense, convolutional and identity layers, an optional
    normalization ('bn', 'in' or 'ln') and an optional activation ('tanh',
    'sigmoid', 'relu', 'lrelu', 'bernoulli' or 'round'). If `structure` is
    None, the layer is a bypass: it has no scope and no variables and its
    output is its input.
    """
    def __init__(self, tensor_in, structure=None, condition=None,
                 slope_tensor=None, name=None, reuse=None):
        if not isinstance(tensor_in, tf.Tensor):
            raise TypeError("`tensor_in` must be of tf.Tensor type")

        self.tensor_in = tensor_in

        if structure is not None:
            with tf.variable_scope(name, reuse=reuse) as scope:
                self.scope = scope
                if structure[0] not in SUPPORTED_LAYER_TYPES:
                    raise ValueError("Unknown layer type at " + self.scope.name)
                self.layer_type = structure[0]
                self.tensor_out = self.build(structure, condition, slope_tensor)
                self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              self.scope.name)
        else:
            self.scope = None
            self.layer_type = 'bypass'
            self.tensor_out = tensor_in
            self.vars = []

    def _scope_name(self):
        """Return the scope name, or a placeholder for bypass layers, which
        have no scope."""
        return '<unscoped>' if self.scope is None else self.scope.name

    def __repr__(self):
        return "Layer({}, type={}, input_shape={}, output_shape={})".format(
            self._scope_name(), self.layer_type, self.tensor_in.get_shape(),
            self.tensor_out.get_shape())

    def get_summary(self):
        """Return the summary string (scope name, layer type and output
        shape in fixed-width columns)."""
        return "{:36} {:12} {:30}".format(
            self._scope_name(), self.layer_type,
            str(self.tensor_out.get_shape()))

    def build(self, structure, condition, slope_tensor):
        """Build the layer and return its output tensor.

        'mean', 'sum', 'reshape' and pooling layers return right away. The
        other layers go through four stages, whose results are kept as
        `self.conditioned`, `self.core`, `self.normalized` and
        `self.activated`: concatenating the condition (if any), the core
        operation, the normalization and the activation. For the 'bernoulli'
        and 'round' activations, `slope_tensor` is forwarded to
        `binary_stochastic_ST` and its second output is kept as
        `self.preactivated`.

        Raises a ValueError for a bad reshape size, for a condition given to
        an identity layer and for an unknown normalization or activation.
        """
        # Mean layers
        if self.layer_type == 'mean':
            keepdims = structure[2] if len(structure) > 2 else None
            return tf.reduce_mean(self.tensor_in, structure[1], keepdims,
                                  name='mean')

        # Summation layers
        if self.layer_type == 'sum':
            keepdims = structure[2] if len(structure) > 2 else None
            return tf.reduce_sum(self.tensor_in, structure[1], keepdims,
                                 name='sum')

        # Reshape layers
        if self.layer_type == 'reshape':
            if np.prod(structure[1]) != np.prod(self.tensor_in.get_shape()[1:]):
                raise ValueError("Bad reshape size: {} to {} at {}".format(
                    self.tensor_in.get_shape()[1:], structure[1],
                    self.scope.name))
            # The first axis is left free.
            if isinstance(structure[1], int):
                reshape_shape = (-1, structure[1])
            else:
                reshape_shape = (-1,) + structure[1]
            return tf.reshape(self.tensor_in, reshape_shape, 'reshape')

        # Pooling layers: parameters are (pool_size[, strides[, padding]])
        if self.layer_type in ('avgpool2d', 'maxpool2d', 'avgpool3d',
                               'maxpool3d'):
            strides = structure[1][1] if len(structure[1]) > 1 else 1
            padding = structure[1][2] if len(structure[1]) > 2 else 'valid'

            if self.layer_type == 'avgpool2d':
                return tf.layers.average_pooling2d(
                    self.tensor_in, structure[1][0], strides, padding,
                    name='avgpool2d')
            if self.layer_type == 'maxpool2d':
                return tf.layers.max_pooling2d(
                    self.tensor_in, structure[1][0], strides, padding,
                    name='maxpool2d')
            if self.layer_type == 'avgpool3d':
                return tf.layers.average_pooling3d(
                    self.tensor_in, structure[1][0], strides, padding,
                    name='avgpool3d')
            if self.layer_type == 'maxpool3d':
                return tf.layers.max_pooling3d(
                    self.tensor_in, structure[1][0], strides, padding,
                    name='maxpool3d')

        # Condition
        if condition is None:
            self.conditioned = self.tensor_in
        elif self.layer_type == 'dense':
            self.conditioned = tf.concat([self.tensor_in, condition], 1)
        elif self.layer_type in ('conv1d', 'conv2d', 'transconv2d', 'conv3d',
                                 'transconv3d'):
            if self.layer_type == 'conv1d':
                reshape_shape = (-1, 1, condition.get_shape()[1])
            elif self.layer_type in ('conv2d', 'transconv2d'):
                reshape_shape = (-1, 1, 1, condition.get_shape()[1])
            else: # ('conv3d', 'transconv3d')
                reshape_shape = (-1, 1, 1, 1, condition.get_shape()[1])
            reshaped = tf.reshape(condition, reshape_shape)
            # NOTE(review): this builds a shape with a leading -1 out of a
            # list and a TensorShape and hands it to `tf.ones` -- confirm
            # that this works before relying on conditioned convolutions.
            out_shape = ([-1] + self.tensor_in.get_shape()[1:-1]
                         + [condition.get_shape()[1]])
            to_concat = reshaped * tf.ones(out_shape)
            self.conditioned = tf.concat([self.tensor_in, to_concat], -1)
        else:
            # Only identity layers get here. Without this branch
            # `self.conditioned` would stay unset and the next stage would
            # fail with an AttributeError.
            raise ValueError("Conditioning is not supported for {} layers at "
                             "{}".format(self.layer_type, self._scope_name()))

        # Core layers (dense, convolutional or identity layer)
        if self.layer_type == 'dense':
            kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)
            self.core = tf.layers.dense(self.conditioned, structure[1],
                                        kernel_initializer=kernel_initializer,
                                        name='dense')
        elif self.layer_type == 'identity':
            self.core = self.conditioned
        else:
            # Parameters are (filters, kernel_size[, strides[, padding]])
            filters = structure[1][0]
            kernel_size = structure[1][1]
            strides = structure[1][2] if len(structure[1]) > 2 else 1
            padding = structure[1][3] if len(structure[1]) > 3 else 'valid'
            kernel_initializer = tf.truncated_normal_initializer(stddev=0.02)

            if self.layer_type == 'conv1d':
                self.core = tf.layers.conv1d(
                    self.conditioned, filters, kernel_size, strides, padding,
                    kernel_initializer=kernel_initializer, name='conv1d')
            elif self.layer_type == 'conv2d':
                self.core = tf.layers.conv2d(
                    self.conditioned, filters, kernel_size, strides, padding,
                    kernel_initializer=kernel_initializer, name='conv2d')
            elif self.layer_type == 'transconv2d':
                self.core = tf.layers.conv2d_transpose(
                    self.conditioned, filters, kernel_size, strides, padding,
                    kernel_initializer=kernel_initializer, name='transconv2d')
            elif self.layer_type == 'conv3d':
                self.core = tf.layers.conv3d(
                    self.conditioned, filters, kernel_size, strides, padding,
                    kernel_initializer=kernel_initializer, name='conv3d')
            elif self.layer_type == 'transconv3d':
                self.core = tf.layers.conv3d_transpose(
                    self.conditioned, filters, kernel_size, strides, padding,
                    kernel_initializer=kernel_initializer, name='transconv3d')

        # normalization layer
        if len(structure) > 2:
            if structure[2] not in (None, 'bn', 'in', 'ln'):
                raise ValueError("Unknown normalization at " + self.scope.name)
            normalization = structure[2]
        else:
            normalization = None

        if normalization is None:
            self.normalized = self.core
        elif normalization == 'bn':
            self.normalized = tf.layers.batch_normalization(
                self.core, name='batch_norm')
        elif normalization == 'in':
            self.normalized = tf.contrib.layers.instance_norm(
                self.core, scope='instance_norm')
        elif normalization == 'ln':
            self.normalized = tf.contrib.layers.layer_norm(
                self.core, scope='layer_norm')

        # activation
        if len(structure) > 3:
            if structure[3] not in (None, 'tanh', 'sigmoid', 'relu', 'lrelu',
                                    'bernoulli', 'round'):
                raise ValueError("Unknown activation at " + self.scope.name)
            activation = structure[3]
        else:
            activation = None

        if activation is None:
            self.activated = self.normalized
        elif activation == 'tanh':
            self.activated = tf.nn.tanh(self.normalized, 'tanh')
        elif activation == 'sigmoid':
            self.activated = tf.nn.sigmoid(self.normalized, 'sigmoid')
        elif activation == 'relu':
            self.activated = tf.nn.relu(self.normalized, 'relu')
        elif activation == 'lrelu':
            self.activated = tf.nn.leaky_relu(self.normalized, name='lrelu')
        elif activation == 'bernoulli':
            # Stochastic binary neurons (pass_through=False, stochastic=True)
            self.activated, self.preactivated = binary_stochastic_ST(
                self.normalized, slope_tensor, False, True)
        elif activation == 'round':
            # Deterministic binary neurons (pass_through=False,
            # stochastic=False)
            self.activated, self.preactivated = binary_stochastic_ST(
                self.normalized, slope_tensor, False, False)

        return self.activated

class NeuralNet(object):
    """Base class for neural networks.

    A network is a chain of `Layer` objects described by `architecture`, a
    sequence of layer structures. If `architecture` is None, the network is a
    bypass: it has no scope, no layers and no variables and its output is its
    input.
    """
    def __init__(self, tensor_in, architecture=None, condition=None,
                 slope_tensor=None, name='NeuralNet', reuse=None):
        if not isinstance(tensor_in, tf.Tensor):
            raise TypeError("`tensor_in` must be of tf.Tensor type")

        self.tensor_in = tensor_in
        self.condition = condition
        self.slope_tensor = slope_tensor

        if architecture is not None:
            with tf.variable_scope(name, reuse=reuse) as scope:
                self.scope = scope
                self.layers = self.build(architecture)
                self.tensor_out = self.layers[-1].tensor_out
                self.vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                              self.scope.name)
        else:
            self.scope = None
            self.layers = []
            self.tensor_out = tensor_in
            self.vars = []

    def _scope_name(self):
        """Return the scope name, or a placeholder for bypass networks, which
        have no scope."""
        return '<unscoped>' if self.scope is None else self.scope.name

    def __repr__(self):
        return "NeuralNet({}, input_shape={}, output_shape={})".format(
            self._scope_name(), self.tensor_in.get_shape(),
            self.tensor_out.get_shape())

    def get_summary(self):
        """Return the summary string (the scope name, the input shape and one
        line per layer)."""
        return '\n'.join(
            ['[{}]'.format(self._scope_name()),
             "{:49} {}".format('Input', self.tensor_in.get_shape())]
            + [x.get_summary() for x in self.layers])

    def build(self, architecture):
        """Build the neural network and return the list of its layers.

        Each layer takes the output of the previous one (the first takes
        `self.tensor_in`). A fifth entry `(mode, index)` in a layer structure
        adds a skip connection: the output of the layer at `index` is added
        to ('add') or concatenated along the last axis with ('concat') the
        input of the layer.

        Raises a ValueError for an unknown skip connection mode.
        """
        layers = []
        for idx, structure in enumerate(architecture):
            if idx > 0:
                prev_layer = layers[idx-1].tensor_out
            else:
                prev_layer = self.tensor_in

            # Skip connections
            if len(structure) > 4:
                skip_connection = structure[4][0]
            else:
                skip_connection = None

            if skip_connection is None:
                connected = prev_layer
            elif skip_connection == 'add':
                connected = prev_layer + layers[structure[4][1]].tensor_out
            elif skip_connection == 'concat':
                connected = tf.concat(
                    [prev_layer, layers[structure[4][1]].tensor_out], -1)
            else:
                # Otherwise `connected` would be undefined or, worse, silently
                # keep its value from the previous layer.
                raise ValueError(
                    "Unknown skip connection type at Layer_{}".format(idx))

            # Build layer
            # NOTE(review): `self.condition` is stored but not forwarded to
            # the layers -- confirm whether that is intended.
            layers.append(Layer(connected, structure,
                                slope_tensor=self.slope_tensor,
                                name='Layer_{}'.format(idx)))
        return layers
def pass_through_sigmoid(x, slope=1):
    """Sigmoid whose gradient is replaced by the identity function.

    The sigmoid op is created inside a gradient override map that maps
    "Sigmoid" to "Identity". Note that `slope` is accepted but not used by
    this function.
    """
    graph = tf.get_default_graph()
    with ops.name_scope("PassThroughSigmoid") as scope_name, \
            graph.gradient_override_map({"Sigmoid": "Identity"}):
        return tf.sigmoid(x, name=scope_name)
def binary_stochastic_REINFORCE(x, loss_op_name="loss_by_example"):
    """
    Binary stochastic neuron trained with the REINFORCE estimator: a sigmoid
    followed by a random sample from a bernoulli distribution according to
    the result. See https://arxiv.org/abs/1308.3432.

    The name of the loss operation is recorded in the graph collection
    "REINFORCE" under the name of the sigmoid op, where the registered
    gradient function "BinaryStochastic_REINFORCE" looks it up.

    NOTE: Requires a loss operation with name matching the argument for
    loss_op_name in the graph. This loss operation should be broken out by
    example (i.e., not a single number for the entire batch).
    """
    graph = tf.get_default_graph()
    overrides = {"Sigmoid": "BinaryStochastic_REINFORCE", "Ceil": "Identity"}

    with ops.name_scope("BinaryStochasticREINFORCE"), \
            graph.gradient_override_map(overrides):
        prob = tf.sigmoid(x)

        # The collection holds one dictionary mapping sigmoid op names to
        # loss op names; create it on first use.
        if not graph.get_collection("REINFORCE"):
            graph.add_to_collection("REINFORCE", {})
        registry = graph.get_collection("REINFORCE")[0]
        registry[prob.op.name] = loss_op_name

        # The sample is 1 where the probability exceeds the uniform noise.
        return tf.ceil(prob - tf.random_uniform(tf.shape(x)))
def binary_wrapper(pre_activations_tensor, estimator,
                   stochastic_tensor=None, pass_through=True,
                   slope_tensor=None):
    """
    Turns a layer of pre-activations (logits) into a layer of binary
    stochastic neurons

    Keyword arguments:
    *estimator: either 'straight_through' or 'reinforce'
    *stochastic_tensor: a boolean tensor indicating whether to sample from a
        bernoulli distribution (True) or use a step_function (e.g., for
        inference). None (default) stands for a constant True.
    *pass_through: for ST only - boolean as to whether to substitute
        identity derivative on the backprop (True, default), or whether to
        use the derivative of the sigmoid
    *slope_tensor: for ST only - tensor specifying the slope for purposes of
        slope annealing trick. None (default) is forwarded as is to
        `binary_stochastic_ST`, which then uses a constant 1.0.

    Raises a ValueError if the estimator is not recognized.
    """
    # Checked first so that a bad estimator is rejected before anything is
    # added to the graph.
    if estimator not in ('straight_through', 'reinforce'):
        raise ValueError("Unrecognized estimator.")

    # Default tensors are created here rather than in the signature: default
    # values are evaluated only once, at definition time, which would tie
    # them to whatever graph is the default when the module is imported.
    if stochastic_tensor is None:
        stochastic_tensor = tf.constant(True)

    if estimator == 'straight_through':
        if pass_through:
            return tf.cond(
                stochastic_tensor,
                lambda: binary_stochastic_ST(pre_activations_tensor),
                lambda: binary_stochastic_ST(pre_activations_tensor,
                                             stochastic=False))
        return tf.cond(
            stochastic_tensor,
            lambda: binary_stochastic_ST(pre_activations_tensor,
                                         slope_tensor, False),
            lambda: binary_stochastic_ST(pre_activations_tensor,
                                         slope_tensor, False, False))

    # estimator == 'reinforce'
    # binary_stochastic_REINFORCE was designed to only be stochastic, so
    # using the ST version for the step fn for purposes of using step
    # fn at evaluation / not for training
    return tf.cond(
        stochastic_tensor,
        lambda: binary_stochastic_REINFORCE(pre_activations_tensor),
        lambda: binary_stochastic_ST(pre_activations_tensor,
                                     stochastic=False))
The automatically- 36 | # determined experiment name is based only on the values defined in the 37 | # dictionary `SETUP`, so remember to provide the experiment name manually 38 | # (so that you won't overwrite a trained model). 39 | 40 | 'training_data': 'herman_binarized_mnist_x', 41 | # Filename of the training data. The training data can be loaded from a npy 42 | # file in the hard disk or from the shared memory using SharedArray package. 43 | 44 | 'training_data_location': 'sa', 45 | # Location of the training data. 'hd' to load from a npy file stored in the 46 | # hard disk. 'sa' to load from shared array using SharedArray package. 47 | 48 | 'gpu': '0', 49 | # The GPU index in os.environ['CUDA_VISIBLE_DEVICES'] to use. 50 | 51 | 'prefix': 'fix_gan_loss', 52 | # Prefix for the experiment name. Useful when training with different 53 | # training data to avoid replacing the previous experiment outputs. 54 | 55 | 'sample_along_training': True, 56 | # True to generate samples along the training process. False for nothing. 57 | 58 | 'evaluate_along_training': True, 59 | # True to run evaluation along the training process. False for nothing. 60 | 61 | 'verbose': False, 62 | # True to print each batch details to stdout. False to print once an epoch. 63 | 64 | 'pretrained_dir': None, 65 | # The directory containing the pretrained model. None to retrain the 66 | # model from scratch. 67 | 68 | 'gan_type': 'gan', 69 | # {'gan', 'wgan', 'wgan-gp'} 70 | # The type of GAN objective to use. Currently support GAN, Wasserstein GAN 71 | # (WGAN), Wasserstein GAN with gradient penalties (WGAN-GP). 72 | 73 | 'optimizer': 'adam', 74 | # {'adam', 'rmsprop'} 75 | # The optimizer to use. Currently support Adam and RMSProp optimizers. 
76 | 77 | 'preset_g': 'mlp_bernoulli', 78 | # BinaryGAN: {'mlp_bernoulli', 'mlp_round', 'cnn_bernoulli', 'cnn_round'} 79 | # GAN: {'mlp_real', 'cnn_real'} 80 | # Use a preset network architecture for the generator or set to None and 81 | # setup `CONFIG['model']['net_g']` to define the network architecture. 82 | 83 | 'preset_d': 'mlp', 84 | # {'mlp', 'cnn', 'mlp_bn', 'cnn_bn'} 85 | # Use a preset network architecture for the discriminator or set to None 86 | # and setup `CONFIG['model']['net_d']` to define the network architecture. 87 | } 88 | 89 | CONFIG = {} 90 | 91 | #=============================================================================== 92 | #=========================== TensorFlow Configuration ========================== 93 | #=============================================================================== 94 | os.environ['CUDA_VISIBLE_DEVICES'] = SETUP['gpu'] 95 | CONFIG['tensorflow'] = tf.ConfigProto() 96 | CONFIG['tensorflow'].gpu_options.allow_growth = True 97 | 98 | #=============================================================================== 99 | #========================== Experiment Configuration =========================== 100 | #=============================================================================== 101 | CONFIG['exp'] = { 102 | 'model': None, 103 | 'exp_name': None, 104 | 'pretrained_dir': None, 105 | } 106 | 107 | for key in ('model', 'pretrained_dir'): 108 | if CONFIG['exp'][key] is None: 109 | CONFIG['exp'][key] = SETUP[key] 110 | 111 | # Set default experiment name 112 | if CONFIG['exp']['exp_name'] is None: 113 | if SETUP['exp_name'] is not None: 114 | CONFIG['exp']['exp_name'] = SETUP['exp_name'] 115 | else: 116 | CONFIG['exp']['exp_name'] = '_'.join( 117 | (SETUP['prefix'], SETUP['model'], SETUP['gan_type'], 118 | 'g', SETUP['preset_g'], 'd', SETUP['preset_d']) 119 | ) 120 | 121 | #=============================================================================== 122 | #============================= Data Configuration 
============================== 123 | #=============================================================================== 124 | CONFIG['data'] = { 125 | 'training_data': None, 126 | 'training_data_location': None, 127 | } 128 | 129 | for key in ('training_data', 'training_data_location'): 130 | if CONFIG['data'][key] is None: 131 | CONFIG['data'][key] = SETUP[key] 132 | 133 | #=============================================================================== 134 | #=========================== Training Configuration ============================ 135 | #=============================================================================== 136 | CONFIG['train'] = { 137 | 'sample_along_training': None, 138 | 'evaluate_along_training': None, 139 | 'verbose': None, 140 | 'num_epoch': 20, 141 | 'slope_annealing_rate': 1.1, 142 | } 143 | 144 | for key in ('verbose', 'sample_along_training', 'evaluate_along_training'): 145 | if CONFIG['train'][key] is None: 146 | CONFIG['train'][key] = SETUP[key] 147 | 148 | #=============================================================================== 149 | #============================= Model Configuration ============================= 150 | #=============================================================================== 151 | CONFIG['model'] = { 152 | # Parameters 153 | 'batch_size': 64, # Note: tf.layers.conv3d_transpose requires a fixed batch 154 | # size in TensorFlow < 1.6 155 | 'gan': { 156 | 'type': None, # 'gan', 'wgan', 'wgan-gp' 157 | 'clip_value': .01, 158 | 'gp_coefficient': 10. 
159 | }, 160 | 'optimizer': { 161 | 'type': None, 162 | 'lr': .0001, 163 | 'epsilon': 1e-8, 164 | # Parameters for Adam optimizers 165 | 'beta1': .5, 166 | 'beta2': .9, 167 | # Parameters for RMSProp optimizers 168 | 'momentum': 0.0, 169 | 'decay': .9, 170 | }, 171 | 172 | # Data 173 | 'out_width': 28, 174 | 'out_height': 28, 175 | 'out_channel': 1, 176 | 177 | # Network architectures (define them here if not using the presets) 178 | 'net_g': None, 179 | 'net_d': None, 180 | 'net_r': None, 181 | 182 | # Samples 183 | 'num_sample': 64, 184 | 'sample_grid': (8, 8), 185 | 186 | # Directories 187 | 'checkpoint_dir': None, 188 | 'sample_dir': None, 189 | 'eval_dir': None, 190 | 'log_dir': None, 191 | 'src_dir': None, 192 | } 193 | 194 | if CONFIG['model']['gan']['type'] is None: 195 | CONFIG['model']['gan']['type'] = SETUP['gan_type'] 196 | if CONFIG['model']['optimizer']['type'] is None: 197 | CONFIG['model']['optimizer']['type'] = SETUP['optimizer'] 198 | 199 | # Import preset network architectures 200 | if CONFIG['model']['net_g'] is None: 201 | IMPORTED = importlib.import_module('.'.join(( 202 | 'bgan.mnist.presets', 'generator', SETUP['preset_g'] 203 | ))) 204 | CONFIG['model']['net_g'] = IMPORTED.NET_G 205 | if CONFIG['model']['net_d'] is None: 206 | IMPORTED = importlib.import_module('.'.join(( 207 | 'bgan.mnist.presets', 'discriminator', SETUP['preset_d'] 208 | ))) 209 | CONFIG['model']['net_d'] = IMPORTED.NET_D 210 | 211 | # Set default directories 212 | for kv_pair in (('checkpoint_dir', 'checkpoints'), ('sample_dir', 'samples'), 213 | ('eval_dir', 'eval'), ('log_dir', 'logs'), ('src_dir', 'src')): 214 | if CONFIG['model'][kv_pair[0]] is None: 215 | CONFIG['model'][kv_pair[0]] = os.path.join( 216 | os.path.dirname(os.path.realpath(__file__)), 'exp', SETUP['model'], 217 | CONFIG['exp']['exp_name'], kv_pair[1]) 218 | 219 | #=============================================================================== 220 | #=================== Make directories & Backup source 
code ===================== 221 | #=============================================================================== 222 | # Make sure directories exist 223 | for path in (CONFIG['model']['checkpoint_dir'], CONFIG['model']['sample_dir'], 224 | CONFIG['model']['eval_dir'], CONFIG['model']['log_dir'], 225 | CONFIG['model']['src_dir']): 226 | if not os.path.exists(path): 227 | os.makedirs(path) 228 | 229 | # Backup source code 230 | for path in os.listdir(os.path.dirname(os.path.realpath(__file__))): 231 | if os.path.isfile(path): 232 | if path.endswith('.py'): 233 | shutil.copyfile(os.path.basename(path), 234 | os.path.join(CONFIG['model']['src_dir'], 235 | os.path.basename(path))) 236 | 237 | distutils.dir_util.copy_tree( 238 | os.path.join(os.path.dirname(os.path.realpath(__file__)), 'bgan'), 239 | os.path.join(CONFIG['model']['src_dir'], 'bgan') 240 | ) 241 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal 2 | title: BinaryGAN 3 | description: Training GANs with Binary Neurons by Backpropagation 4 | -------------------------------------------------------------------------------- /docs/_includes/audio_player.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_includes/icon_link.html: -------------------------------------------------------------------------------- 1 |  {{ include.text }} -------------------------------------------------------------------------------- /docs/_includes/video_player.html: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | {% seo %} 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 |
16 | 17 | 18 |
19 |

20 | Hao-Wen Dong, 21 | Yi-Hsuan Yang 22 |

23 |

24 | Music and AI Lab,
Research Center for IT Innovation,
Academia Sinica 25 |

26 |
27 | 28 | 29 | 30 | 31 | 49 |
50 |
51 | 52 | {{ content }} 53 | 54 |
55 | 59 |
60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /docs/_sass/jekyll-theme-minimal.scss: -------------------------------------------------------------------------------- 1 | @import "fonts"; 2 | @import "rouge-github"; 3 | 4 | // Main wrappers 5 | body { 6 | overflow-x: hidden; 7 | margin-left: calc(100vw - 100%); 8 | padding: 50px; 9 | background-color: #fff; 10 | font: 16px/1.5 "Noto Sans", "Helvetica Neue", Helvetica, Arial, sans-serif; 11 | font-weight: 400; 12 | color: #333; 13 | } 14 | 15 | .wrapper { 16 | margin: 0 auto; 17 | width: 1100px; 18 | } 19 | 20 | header { 21 | width: 260px; 22 | float: left; 23 | position: fixed; 24 | -webkit-font-smoothing: subpixel-antialiased; 25 | } 26 | 27 | section, footer { 28 | float: right; 29 | width: 740px; 30 | } 31 | 32 | footer { 33 | padding-bottom: 50px; 34 | -webkit-font-smoothing: subpixel-antialiased; 35 | } 36 | 37 | // Menu 38 | .main-menu-checkbox, .menu-toggle, .main-menu-close { 39 | display: none; 40 | } 41 | 42 | // Header 43 | .title a { 44 | color: #111; 45 | } 46 | 47 | .author-names { 48 | margin: 0 0 5px; 49 | } 50 | 51 | .main-menu .subpage-links { 52 | margin: 0 0 20px; 53 | } 54 | 55 | .main-menu .link-button:hover, .main-menu .link-button:focus { 56 | background: #eee; 57 | } 58 | 59 | // Video player 60 | .video-container { 61 | display: block; 62 | overflow: hidden; 63 | margin-left: auto; 64 | margin-right: auto; 65 | margin-bottom: 20px; 66 | width: 100%; 67 | max-width: 600px; 68 | } 69 | 70 | .video-inner-container { 71 | position: relative; 72 | overflow: hidden; 73 | width: 100%; 74 | height: 0; 75 | padding-bottom: 56.25%; 76 | } 77 | 78 | .video-container iframe { 79 | position: absolute; 80 | top: 0; 81 | left: 0; 82 | width: 100%; 83 | height: 100%; 84 | } 85 | 86 | // Special classes 87 | .caption { 88 | margin-top: -15px; 89 | text-align: center; 90 | } 91 | 92 | .caption-above { 93 | margin-bottom: 5px; 94 | text-align: center; 95 
| } 96 | 97 | .switch-on-hover .hide-on-hover { 98 | display: inline-block; 99 | } 100 | 101 | .switch-on-hover .show-on-hover { 102 | display: none; 103 | } 104 | 105 | .switch-on-hover:hover .hide-on-hover { 106 | display: none; 107 | } 108 | 109 | .switch-on-hover:hover .show-on-hover { 110 | display: inline-block; 111 | } 112 | 113 | // Main components 114 | h1, h2, h3, h4, h5, h6 { 115 | margin: 0 0 20px; 116 | } 117 | 118 | p, ul, ol, table, pre, dl, audio { 119 | margin: 0 0 20px; 120 | } 121 | 122 | h1, h2, h3 { 123 | line-height: 1.1; 124 | color: #111; 125 | } 126 | 127 | h1 { 128 | font-size: 28px; 129 | } 130 | 131 | h4, h5, h6 { 132 | color: #222; 133 | } 134 | 135 | small { 136 | font-size: 11px; 137 | } 138 | 139 | strong { 140 | font-weight: 700; 141 | } 142 | 143 | a { 144 | text-decoration: none; 145 | color: #267CB9; 146 | } 147 | 148 | a:not(.nohover):hover, a:not(.nohover):focus { 149 | text-shadow: 0 0.015em #069,0 -0.015em #069,0.01em 0 #069,-0.01em 0 #069; 150 | color: #069; 151 | } 152 | 153 | footer small a { 154 | color: #555; 155 | } 156 | 157 | blockquote { 158 | border-left: 1px solid #e5e5e5; 159 | margin: 0; 160 | padding: 0 0 0 20px; 161 | font-style: italic; 162 | } 163 | 164 | code, pre { 165 | font-family: Monaco, Bitstream Vera Sans Mono, Lucida Console, Terminal, Consolas, Liberation Mono, DejaVu Sans Mono, Courier New, monospace; 166 | color: #333; 167 | } 168 | 169 | pre { 170 | border: 1px solid #e5e5e5; 171 | border-radius: 5px; 172 | padding: 8px 15px; 173 | background: #f8f8f8; 174 | overflow-x: auto; 175 | } 176 | 177 | table { 178 | margin-left: auto; 179 | margin-right: auto; 180 | border-collapse: collapse; 181 | } 182 | 183 | table audio, table .video-container { 184 | margin: 0; 185 | } 186 | 187 | th, td { 188 | border-bottom: 1px solid #e5e5e5; 189 | padding: 5px 10px; 190 | text-align: left; 191 | } 192 | 193 | dt { 194 | color: #444; 195 | font-weight: 700; 196 | } 197 | 198 | th { 199 | color: #444; 200 | } 201 
| 202 | img, audio { 203 | display: block; 204 | margin-left: auto; 205 | margin-right: auto; 206 | } 207 | 208 | img { 209 | width: 100%; 210 | max-width: 600px; 211 | } 212 | 213 | audio { 214 | max-width: 100%; 215 | } 216 | 217 | hr { 218 | margin: 0 0 20px; 219 | border: 0; 220 | height: 1px; 221 | background: #e5e5e5; 222 | } 223 | 224 | footer hr { 225 | margin: 15px 0 0; 226 | } 227 | 228 | @media print, screen and (max-width: 1200px) { 229 | .wrapper { 230 | width: 860px; 231 | } 232 | 233 | header { 234 | width: 260px; 235 | } 236 | 237 | section, footer { 238 | width: 560px; 239 | } 240 | 241 | ul { 242 | padding-left: 20px; 243 | } 244 | } 245 | 246 | @media print, screen and (max-width: 960px) { 247 | // Main wrappers 248 | .wrapper { 249 | margin: 0; 250 | width: auto; 251 | } 252 | 253 | header, section, footer { 254 | position: static; 255 | float: none; 256 | width: auto; 257 | } 258 | 259 | section { 260 | margin: 0 0 20px; 261 | border: 1px solid #e5e5e5; 262 | border-width: 1px 0; 263 | padding: 20px 0; 264 | } 265 | 266 | // Menu 267 | .menu-toggle, .main-menu { 268 | position: absolute; 269 | right: 55px; 270 | top: 55px; 271 | } 272 | 273 | .menu-toggle { 274 | z-index: 998; 275 | display: block; 276 | padding: 5px 10px; 277 | color: #333; 278 | cursor: pointer; 279 | } 280 | 281 | .menu-toggle:hover { 282 | color:#069; 283 | } 284 | 285 | .main-menu { 286 | z-index: 999; 287 | display: none; 288 | min-width: 120px; 289 | background: #eee; 290 | } 291 | 292 | .main-menu .link-button { 293 | padding: 5px 10px; 294 | } 295 | 296 | .main-menu .link-button:hover, .main-menu .link-button:focus, .main-menu-close:hover { 297 | background: #ddd; 298 | } 299 | 300 | .main-menu-close { 301 | display: block; 302 | position: absolute; 303 | right: 0; 304 | top: 0; 305 | padding: 5px 10px; 306 | cursor: pointer; 307 | } 308 | 309 | .main-menu .subpage-links { 310 | margin: 0; 311 | } 312 | 313 | #main-menu-checkbox:checked ~ .main-menu { 314 | display: 
block; 315 | } 316 | 317 | #main-menu-checkbox:checked ~ .main-menu .main-menu-close { 318 | display: block; 319 | z-index: 1001; 320 | } 321 | 322 | #main-menu-checkbox:checked ~ .main-menu div { 323 | position: relative; 324 | z-index: 1000; 325 | } 326 | 327 | // Others 328 | .title { 329 | padding: 0 50px 0 0; 330 | } 331 | 332 | .author-info { 333 | font-size: small; 334 | } 335 | 336 | .author-info br { 337 | display: none; 338 | } 339 | 340 | .author-affiliations { 341 | margin: 0 0 5px; 342 | } 343 | 344 | hr { 345 | height: 0; 346 | border: 0; 347 | background: #fff; 348 | margin: 0; 349 | } 350 | } 351 | 352 | @media print, screen and (max-width: 720px) { 353 | body { 354 | word-wrap: break-word; 355 | } 356 | 357 | header { 358 | padding: 0; 359 | } 360 | 361 | header p.view { 362 | position: static; 363 | } 364 | 365 | pre, code { 366 | word-wrap: normal; 367 | } 368 | } 369 | 370 | @media print, screen and (max-width: 480px) { 371 | body { 372 | padding: 15px; 373 | } 374 | 375 | footer { 376 | padding-bottom: 15px; 377 | } 378 | 379 | .menu-toggle, .main-menu { 380 | right: 20px; 381 | top: 20px; 382 | } 383 | } 384 | 385 | @media print, screen and (max-height: 600px) { 386 | .author-info br { 387 | display: none; 388 | } 389 | 390 | .main-menu .subpage-links { 391 | margin: 0; 392 | } 393 | } 394 | 395 | @media print, screen and (max-height: 480px) { 396 | body { 397 | padding: 15px; 398 | } 399 | 400 | footer { 401 | padding-bottom: 15px; 402 | } 403 | 404 | .menu-toggle, .main-menu { 405 | right: 20px; 406 | top: 20px; 407 | } 408 | 409 | .author-info { 410 | font-size: small; 411 | } 412 | } 413 | 414 | @media print { 415 | body { 416 | padding: 0.4in; 417 | font-size: 12pt; 418 | color: #444; 419 | } 420 | 421 | ul { 422 | padding-left: 20px; 423 | } 424 | } 425 | -------------------------------------------------------------------------------- /docs/abstract.md: -------------------------------------------------------------------------------- 1 | 
# Abstract 2 | 3 | We propose the BinaryGAN, a novel generative adversarial network (GAN) that uses 4 | binary neurons at the output layer of the generator. We employ the 5 | sigmoid-adjusted straight-through estimators to estimate the gradients for the 6 | binary neurons and train the whole network by end-to-end backpropagation. The 7 | proposed model is able to directly generate binary-valued predictions at test 8 | time. We implement such a model to generate binarized MNIST digits and 9 | experimentally compare the performance for different types of binary neurons, 10 | GAN objectives and network architectures. Although the results are still 11 | preliminary, we show that it is possible to train a GAN that has binary neurons 12 | and that the use of gradient estimators can be a promising direction for 13 | modeling discrete distributions with GANs. 14 | -------------------------------------------------------------------------------- /docs/background.md: -------------------------------------------------------------------------------- 1 | # Background 2 | 3 | ## Stochastic and Deterministic Binary Neurons 4 | 5 | _Binary neurons_ (BNs) are neurons that output binary-valued predictions. In 6 | this work, we consider two types of BNs: 7 | 8 | - _Deterministic Binary Neurons_ (DBNs) act like neurons with hard thresholding 9 | functions as their activation functions. We define the output of a DBN for a 10 | real-valued input _x_ as: 11 | formula_dbn 12 | where __1__(·) is the indicator function. 13 | 14 | - _Stochastic Binary Neurons_ (SBNs) binarize a real-valued input _x_ according 15 | to a probability, defined as: 16 | formula_sbn 17 | where σ(·) is the logistic sigmoid function and _U_[0, 1] denotes 18 | a uniform distribution.
19 | 20 | ## Straight-through Estimator 21 | 22 | Computing the exact gradients for either DBNs or SBNs, however, is intractable, 23 | for doing so requires the computation of the average loss over all possible 24 | binary samplings of all the BNs, which is exponential in the total number of 25 | BNs. 26 | 27 | A few solutions have been proposed to address this issue [1, 2]. In this work, 28 | we resort to the sigmoid-adjusted _straight-through_ (ST) estimator when 29 | training networks with DBNs and SBNs. The ST estimator was first proposed in [3], 30 | which simply treats BNs as identity functions and ignores their gradients. The 31 | sigmoid-adjusted ST estimator is a variant that multiplies the gradients in the 32 | backward pass by the derivative of the sigmoid function. 33 | 34 | By replacing the non-differentiable functions, which are used in the forward 35 | pass, by differentiable functions (usually called the _estimators_) in the 36 | backward pass, we can then train the whole network with backpropagation. 37 | 38 | ## References 39 | 40 | 1. Silviu Pitis, "Binary Stochastic Neurons in Tensorflow," 2016. 41 | Blog post on R2RT blog.   42 | {% include icon_link.html text="link" icon="fas fa-globe-asia" href="https://r2rt.com/binary-stochastic-neurons-in-tensorflow.html" %} 43 | 44 | 2. Yoshua Bengio, Nicholas Léonard, and Aaron C. Courville, 45 | "Estimating or propagating gradients through stochastic neurons for 46 | conditional computation," 47 | _arXiv preprint arXiv:1308.3432_, 2013. 48 | 49 | 3. Geoffrey Hinton, 50 | "Neural networks for machine learning - using noise as a regularizer (lecture 51 | 9c)", 2012. 52 | Coursera, video lecture.   
53 | {% include icon_link.html text="video" icon="fab fa-youtube" href="https://www.coursera.org/lecture/neural-networks/using-noise-as-a-regularizer-7-min-wbw7b" %} -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | We use the binarized version of the MNIST handwritten digit database. 4 | Specifically, we 5 | 6 | - convert pixels with nonzero intensities to ones 7 | - convert pixels with zero intensities to zeros. 8 | 9 | The following figure shows some sample binarized MNIST digits seen in our 10 | training data. 11 | 12 | ![training_data](figs/train.png) 13 |

Sample binarized MNIST digits

14 | -------------------------------------------------------------------------------- /docs/figs/cnn_dbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/cnn_dbn.png -------------------------------------------------------------------------------- /docs/figs/cnn_sbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/cnn_sbn.png -------------------------------------------------------------------------------- /docs/figs/colored_dbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/colored_dbn.png -------------------------------------------------------------------------------- /docs/figs/colored_sbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/colored_sbn.png -------------------------------------------------------------------------------- /docs/figs/colormap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/colormap.png -------------------------------------------------------------------------------- /docs/figs/formula_dbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/formula_dbn.png -------------------------------------------------------------------------------- /docs/figs/formula_sbn.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/formula_sbn.png -------------------------------------------------------------------------------- /docs/figs/histogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/histogram.png -------------------------------------------------------------------------------- /docs/figs/mlp_dbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_dbn.png -------------------------------------------------------------------------------- /docs/figs/mlp_dbn_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_dbn_gan.png -------------------------------------------------------------------------------- /docs/figs/mlp_dbn_wgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_dbn_wgan.png -------------------------------------------------------------------------------- /docs/figs/mlp_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_real.png -------------------------------------------------------------------------------- /docs/figs/mlp_real_bernoulli.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_real_bernoulli.png -------------------------------------------------------------------------------- /docs/figs/mlp_real_round.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_real_round.png -------------------------------------------------------------------------------- /docs/figs/mlp_sbn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_sbn.png -------------------------------------------------------------------------------- /docs/figs/mlp_sbn_gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_sbn_gan.png -------------------------------------------------------------------------------- /docs/figs/mlp_sbn_wgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/mlp_sbn_wgan.png -------------------------------------------------------------------------------- /docs/figs/system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/system.png -------------------------------------------------------------------------------- /docs/figs/train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/figs/train.png 
-------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |

2 | colored_dbn 3 | colored_sbn 4 |

5 | 6 | [BinaryGAN](https://salu133445.github.io/binarygan/) is a novel generative 7 | adversarial network (GAN) that uses binary neurons at the output layer of the 8 | generator. We employ the sigmoid-adjusted straight-through estimators to 9 | estimate the gradients for the binary neurons and train the whole network by 10 | end-to-end backpropagation. The proposed model is able to directly generate 11 | binary-valued predictions at test time. 12 | 13 | We implement such a model to generate binarized MNIST digits and experimentally 14 | compare the performance for different types of binary neurons, GAN objectives 15 | and network architectures. Although the results are still preliminary, we show 16 | that it is possible to train a GAN that has binary neurons and that the use of 17 | gradient estimators can be a promising direction for modeling discrete 18 | distributions with GANs. 19 | -------------------------------------------------------------------------------- /docs/model.md: -------------------------------------------------------------------------------- 1 | # Model 2 | 3 | The proposed model consists of a _generator G_ and a _discriminator D_. The 4 | generator takes as input a random vector _z_ drawn from a prior distribution 5 | _p<sub>z</sub>_ and generates a fake sample _G_(_z_). The discriminator takes as 6 | input either a real sample drawn from the data distribution or a fake sample 7 | generated by the generator and outputs a scalar indicating the genuineness of 8 | that sample. The discriminator is trained to tell the fake data from the real 9 | ones. The generator is trained to fool the discriminator. 10 | 11 | In order to handle binary data, we propose to use [binary neurons](background), 12 | either deterministic or stochastic ones, at the output layer (i.e., the final 13 | layer) of the generator.
We employ the 14 | [sigmoid-adjusted straight-through estimators](background) to estimate the 15 | gradients for binary neurons and train the whole network by end-to-end 16 | backpropagation. 17 | 18 | The following is the system diagram for the proposed model implemented by 19 | multilayer perceptrons (MLPs). 20 | 21 | ![system](figs/system.png) 22 | -------------------------------------------------------------------------------- /docs/paper.md: -------------------------------------------------------------------------------- 1 | # Paper 2 | 3 | __Training Generative Adversarial Networks with Binary Neurons by End-to-end Backpropagation__
4 | 5 | Hao-Wen Dong and Yi-Hsuan Yang
6 | _arXiv preprint arXiv:1810.04714_, 2018.
7 | {% include icon_link.html text="website" icon="fas fa-globe-asia" href="https://salu133445.github.io/binarygan/" %}  8 | {% include icon_link.html text="arxiv" icon="fas fa-archive" href="https://arxiv.org/abs/1810.04714" %}  9 | {% include icon_link.html text="paper" icon="fas fa-file-pdf" href="https://salu133445.github.io/binarygan/pdf/binarygan-arxiv-paper.pdf" %}  10 | {% include icon_link.html text="slides" icon="fas fa-file-pdf" href="https://salu133445.github.io/binarygan/pdf/binarygan-arxiv-slides.pdf" %}  11 | {% include icon_link.html text="code" icon="fab fa-github" href="https://github.com/salu133445/binarygan" %} 12 |
13 | -------------------------------------------------------------------------------- /docs/pdf/binarygan-arxiv-paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/pdf/binarygan-arxiv-paper.pdf -------------------------------------------------------------------------------- /docs/pdf/binarygan-arxiv-slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salu133445/binarygan/6a068a23e151af052c32b35f4e436582ca17cf4c/docs/pdf/binarygan-arxiv-slides.pdf -------------------------------------------------------------------------------- /docs/results.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | ## Experiment I — Comparison of the proposed model using deterministic binary neurons and stochastic binary neurons 4 | 5 | We compare the performance of using deterministic binary neurons (DBNs) 6 | and stochastic binary neurons (SBNs) in the proposed model. 7 | 8 | | proposed model with DBNs | proposed model with SBNs | 9 | |:------------------------:|:------------------------:| 10 | | mlp_dbn | mlp_sbn | 11 | 12 | The following figures show the preactivated outputs (i.e., the real-valued, 13 | intermediate values right before the binarization operation). 14 | 15 | | proposed model with DBNs | proposed model with SBNs | 16 | |:------------------------:|:------------------------:| 17 | | mlp_dbn_prob | mlp_sbn_prob | 18 | 19 | colormap 20 | 21 | In order to see how DBNs and SBNs work differently, we compute the histograms of 22 | their preactivated outputs. 23 | 24 | ![histogram](figs/histogram.png) 25 | 26 | We can see that 27 | 28 | - DBNs tend to output more preactivated values in the middle of zero and one and 29 | avoid producing preactivated values around the decision boundary (i.e. 
the 30 | threshold, 0.5 here) 31 | - SBNs tend to output more preactivated values close to zero and one 32 | 33 | ## Experiment II — Comparison of the proposed model and the real-valued model 34 | 35 | We compare the proposed model with a variant that uses normal neurons at the 36 | output layer (with sigmoid functions as the activation functions). We refer to 37 | this model as the _real-valued model_. 38 | 39 | | raw prediction | hard thresholding | Bernoulli sampling | 40 | |:--------------:|:-----------------:|:------------------:| 41 | | mlp_real | mlp_round | mlp_bernoulli | 42 | 43 | We also show the histogram of its probabilistic predictions in the figure above. 44 | We can see that 45 | 46 | - the histogram of the real-valued model is more U-shaped than that of the 47 | proposed model with SBNs 48 | - there is no notch in the middle of the curve as compared to the proposed model 49 | with DBNs. 50 | 51 | From here we can see how different binarization strategies can shape the 52 | characteristics of the preactivated outputs of binary neurons. This also 53 | emphasizes the importance of including the binarization operations in the 54 | training so that the binarization operations themselves can also be optimized. 55 | 56 | ## Experiment III — Comparison of the proposed model trained with the GAN, WGAN and WGAN-GP objectives 57 | 58 | We compare the proposed model trained by the WGAN-GP objective with that trained 59 | by the GAN objective and by the WGAN objective. 60 | 61 | | GAN model with DBNs | GAN model with SBNs | 62 | |:-------------------:|:-------------------:| 63 | | mlp_dbn_gan | mlp_sbn_gan | 64 | 65 | | WGAN model with DBNs | WGAN model with SBNs | 66 | |:--------------------:|:--------------------:| 67 | | mlp_dbn_wgan | mlp_sbn_wgan | 68 | 69 | We can see that the WGAN model is able to generate digits with similar qualities 70 | as the WGAN-GP model does, while the GAN model suffers from the so-called mode 71 | collapse issue.
def load_data():
    """Load and return the training data selected by ``CONFIG['data']``.

    Two configuration entries are read:

    - ``CONFIG['data']['training_data_location']``: ``'sa'`` to attach to an
      array through the SharedArray package, or ``'hd'`` to load a NumPy file
      from disk.
    - ``CONFIG['data']['training_data']``: the SharedArray name, or the path
      of the file to load. A relative path is resolved against the
      ``training_data`` directory that sits next to this script.

    Returns:
        The array returned by ``sa.attach`` or ``np.load``.

    Raises:
        ValueError: If the configured location is neither ``'sa'`` nor
            ``'hd'``.
    """
    print('[*] Loading data...')

    location = CONFIG['data']['training_data_location']
    source = CONFIG['data']['training_data']

    # Load data from SharedArray
    if location == 'sa':
        # Imported lazily so that SharedArray is only required when used.
        import SharedArray as sa
        return sa.attach(source)

    # Load data from hard disk
    if location == 'hd':
        if not os.path.isabs(source):
            # Resolve against the directory containing this script; joining
            # onto the script path itself would yield a non-existent path
            # such as ".../train.py/training_data/<name>".
            script_dir = os.path.dirname(os.path.realpath(__file__))
            source = os.path.join(script_dir, 'training_data', source)
        return np.load(source)

    raise ValueError(
        "Unrecognizable training data location: {!r} (expected 'sa' or "
        "'hd')".format(location))


def main():
    """Build the configured model and train it on the loaded data.

    Reads the experiment settings from ``CONFIG``: the model name
    (``'gan'`` or ``'binarygan'``), the experiment name, an optional
    directory of pretrained weights, and the model/training/TensorFlow
    sub-configurations.

    Raises:
        ValueError: If ``CONFIG['exp']['model']`` is not a known model name.
    """
    model_name = CONFIG['exp']['model']
    if model_name not in ('binarygan', 'gan'):
        raise ValueError("Unrecognizable model name")

    print("Start experiment: {}".format(CONFIG['exp']['exp_name']))

    # Load training data
    x_train = load_data()

    # Open TensorFlow session
    with tf.Session(config=CONFIG['tensorflow']) as sess:

        # Create model
        model_classes = {'gan': GAN, 'binarygan': BinaryGAN}
        gan = model_classes[model_name](sess, CONFIG['model'])

        # Initialize all variables
        gan.init_all()

        # Load pretrained model if given
        if CONFIG['exp']['pretrained_dir'] is not None:
            gan.load_latest(CONFIG['exp']['pretrained_dir'])

        # Train the model
        gan.train(x_train, CONFIG['train'])


if __name__ == '__main__':
    main()
# Number of leading bytes skipped before the pixel / label payload of the
# IDX files (the original script slices the raw bytes at these offsets).
_IMAGE_HEADER_BYTES = 16
_LABEL_HEADER_BYTES = 8
# Width of the one-hot encoding used for the labels.
_NUM_CLASSES = 10


def parse_arguments(argv=None):
    """Parse and return the command line arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv`` by argparse.

    Returns:
        A tuple ``(dataset_root, prefix, merge, binary, labels, onehot)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_root', help="Root directory of the dataset.")
    parser.add_argument('--prefix', default='',
                        help="Prefix to the file name to save in SharedArray.")
    parser.add_argument('--merge', help='Merge train and test set',
                        action='store_true')
    parser.add_argument('--binary', help='Binarize the data.',
                        action='store_true')
    parser.add_argument('--labels', help='Store the labels as well.',
                        action='store_true')
    parser.add_argument('--onehot', help='Use onehot encoding for the labels',
                        action='store_true')
    args = parser.parse_args(argv)
    return (args.dataset_root, args.prefix, args.merge, args.binary,
            args.labels, args.onehot)


def save_to_sa(name, data):
    """Create a SharedArray called `name` and copy `data` into it."""
    arr = sa.create(name, data.shape, data.dtype)
    np.copyto(arr, data)


def _read_images(filepath, binary):
    """Read an image file as an (N, 28, 28, 1) array, optionally binarized."""
    # Passing the path lets NumPy read the raw bytes itself, instead of going
    # through a file object opened in text mode.
    raw = np.fromfile(filepath, dtype=np.uint8)
    images = raw[_IMAGE_HEADER_BYTES:].reshape((-1, 28, 28, 1))
    if binary:
        # Any non-zero pixel becomes True.
        images = (images > 0)
    return images


def _read_labels(filepath, onehot):
    """Read a label file as a 1-D array, or as an (N, 10) boolean one-hot."""
    raw = np.fromfile(filepath, dtype=np.uint8)
    labels = raw[_LABEL_HEADER_BYTES:]
    if not onehot:
        return labels
    encoded = np.zeros((labels.size, _NUM_CLASSES), np.bool_)
    encoded[np.arange(labels.size), labels] = True
    return encoded


def load(dataset_root, prefix, merge, binary, labels, onehot):
    """Load and save the dataset to SharedArray.

    Args:
        dataset_root: Directory containing the decompressed files
            ``train-images-idx3-ubyte``, ``t10k-images-idx3-ubyte`` and,
            when `labels` is set, ``train-labels-idx1-ubyte`` and
            ``t10k-labels-idx1-ubyte``.
        prefix: String joined with ``'_'`` in front of every array name (an
            empty prefix therefore yields names starting with ``'_'``).
        merge: If true, save the concatenation of the train and test sets as
            one array; otherwise save ``*_train`` and ``*_test`` separately.
        binary: If true, binarize the images (pixel > 0) and use the
            ``binarized_mnist_x`` name stem instead of ``mnist_x``.
        labels: If true, store the labels as well.
        onehot: If true, store the labels one-hot encoded, under names
            carrying an ``_onehot`` suffix.
    """
    def sa_name(stem):
        return '_'.join((prefix, stem))

    x_stem = 'binarized_mnist_x' if binary else 'mnist_x'

    train_x = _read_images(
        os.path.join(dataset_root, 'train-images-idx3-ubyte'), binary)
    if not merge:
        save_to_sa(sa_name(x_stem + '_train'), train_x)

    test_x = _read_images(
        os.path.join(dataset_root, 't10k-images-idx3-ubyte'), binary)
    if not merge:
        save_to_sa(sa_name(x_stem + '_test'), test_x)

    if merge:
        save_to_sa(sa_name(x_stem), np.concatenate((train_x, test_x)))

    if not labels:
        return

    y_suffix = '_onehot' if onehot else ''

    train_y = _read_labels(
        os.path.join(dataset_root, 'train-labels-idx1-ubyte'), onehot)
    if not merge:
        save_to_sa(sa_name('mnist_y_train' + y_suffix), train_y)

    test_y = _read_labels(
        os.path.join(dataset_root, 't10k-labels-idx1-ubyte'), onehot)
    if not merge:
        save_to_sa(sa_name('mnist_y_test' + y_suffix), test_y)

    if merge:
        save_to_sa(sa_name('mnist_y' + y_suffix),
                   np.concatenate((train_y, test_y)))


def main():
    """Parse the command line arguments and load the dataset accordingly."""
    dataset_root, prefix, merge, binary, labels, onehot = parse_arguments()
    load(dataset_root, prefix, merge, binary, labels, onehot)


if __name__ == '__main__':
    main()