├── .gitignore
├── Figures
│   ├── Autoencoder.pdf
│   ├── Autoencoder.png
│   ├── BAGAN-init.pdf
│   ├── BAGAN-init.png
│   ├── BAGAN-train.pdf
│   ├── BAGAN-train.png
│   ├── bagan_x5_minority.png
│   └── plot_class_0.png
├── LICENSE
├── README.md
├── bagan_train.py
├── balancing_gan.py
├── run.sh
├── rw
│   └── batch_generator.py
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | res_* 2 | dataset 3 | -------------------------------------------------------------------------------- /Figures/Autoencoder.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/Autoencoder.pdf -------------------------------------------------------------------------------- /Figures/Autoencoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/Autoencoder.png -------------------------------------------------------------------------------- /Figures/BAGAN-init.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/BAGAN-init.pdf -------------------------------------------------------------------------------- /Figures/BAGAN-init.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/BAGAN-init.png -------------------------------------------------------------------------------- /Figures/BAGAN-train.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/BAGAN-train.pdf -------------------------------------------------------------------------------- /Figures/BAGAN-train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/BAGAN-train.png -------------------------------------------------------------------------------- /Figures/bagan_x5_minority.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/bagan_x5_minority.png -------------------------------------------------------------------------------- /Figures/plot_class_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IBM/BAGAN/b57fc1513a1ab5f9050cf301a370fc7ff8c21822/Figures/plot_class_0.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Eclipse Public License - v 1.0 2 | 3 | THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC 4 | LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM 5 | CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. 6 | 7 | 1. 
DEFINITIONS 8 | 9 | "Contribution" means: 10 | 11 | a) in the case of the initial Contributor, the initial code and documentation 12 | distributed under this Agreement, and 13 | b) in the case of each subsequent Contributor: 14 | i) changes to the Program, and 15 | ii) additions to the Program; 16 | 17 | where such changes and/or additions to the Program originate from and are 18 | distributed by that particular Contributor. A Contribution 'originates' 19 | from a Contributor if it was added to the Program by such Contributor 20 | itself or anyone acting on such Contributor's behalf. Contributions do not 21 | include additions to the Program which: (i) are separate modules of 22 | software distributed in conjunction with the Program under their own 23 | license agreement, and (ii) are not derivative works of the Program. 24 | 25 | "Contributor" means any person or entity that distributes the Program. 26 | 27 | "Licensed Patents" mean patent claims licensable by a Contributor which are 28 | necessarily infringed by the use or sale of its Contribution alone or when 29 | combined with the Program. 30 | 31 | "Program" means the Contributions distributed in accordance with this 32 | Agreement. 33 | 34 | "Recipient" means anyone who receives the Program under this Agreement, 35 | including all Contributors. 36 | 37 | 2. GRANT OF RIGHTS 38 | a) Subject to the terms of this Agreement, each Contributor hereby grants 39 | Recipient a non-exclusive, worldwide, royalty-free copyright license to 40 | reproduce, prepare derivative works of, publicly display, publicly 41 | perform, distribute and sublicense the Contribution of such Contributor, 42 | if any, and such derivative works, in source code and object code form. 43 | b) Subject to the terms of this Agreement, each Contributor hereby grants 44 | Recipient a non-exclusive, worldwide, royalty-free patent license under 45 | Licensed Patents to make, use, sell, offer to sell, import and otherwise 46 | transfer the Contribution of such Contributor, if any, in source code and 47 | object code form. This patent license shall apply to the combination of 48 | the Contribution and the Program if, at the time the Contribution is 49 | added by the Contributor, such addition of the Contribution causes such 50 | combination to be covered by the Licensed Patents. The patent license 51 | shall not apply to any other combinations which include the Contribution. 52 | No hardware per se is licensed hereunder. 53 | c) Recipient understands that although each Contributor grants the licenses 54 | to its Contributions set forth herein, no assurances are provided by any 55 | Contributor that the Program does not infringe the patent or other 56 | intellectual property rights of any other entity. Each Contributor 57 | disclaims any liability to Recipient for claims brought by any other 58 | entity based on infringement of intellectual property rights or 59 | otherwise. As a condition to exercising the rights and licenses granted 60 | hereunder, each Recipient hereby assumes sole responsibility to secure 61 | any other intellectual property rights needed, if any. For example, if a 62 | third party patent license is required to allow Recipient to distribute 63 | the Program, it is Recipient's responsibility to acquire that license 64 | before distributing the Program. 65 | d) Each Contributor represents that to its knowledge it has sufficient 66 | copyright rights in its Contribution, if any, to grant the copyright 67 | license set forth in this Agreement. 68 | 69 | 3. 
REQUIREMENTS 70 | 71 | A Contributor may choose to distribute the Program in object code form under 72 | its own license agreement, provided that: 73 | 74 | a) it complies with the terms and conditions of this Agreement; and 75 | b) its license agreement: 76 | i) effectively disclaims on behalf of all Contributors all warranties 77 | and conditions, express and implied, including warranties or 78 | conditions of title and non-infringement, and implied warranties or 79 | conditions of merchantability and fitness for a particular purpose; 80 | ii) effectively excludes on behalf of all Contributors all liability for 81 | damages, including direct, indirect, special, incidental and 82 | consequential damages, such as lost profits; 83 | iii) states that any provisions which differ from this Agreement are 84 | offered by that Contributor alone and not by any other party; and 85 | iv) states that source code for the Program is available from such 86 | Contributor, and informs licensees how to obtain it in a reasonable 87 | manner on or through a medium customarily used for software exchange. 88 | 89 | When the Program is made available in source code form: 90 | 91 | a) it must be made available under this Agreement; and 92 | b) a copy of this Agreement must be included with each copy of the Program. 93 | Contributors may not remove or alter any copyright notices contained 94 | within the Program. 95 | 96 | Each Contributor must identify itself as the originator of its Contribution, 97 | if 98 | any, in a manner that reasonably allows subsequent Recipients to identify the 99 | originator of the Contribution. 100 | 101 | 4. COMMERCIAL DISTRIBUTION 102 | 103 | Commercial distributors of software may accept certain responsibilities with 104 | respect to end users, business partners and the like. While this license is 105 | intended to facilitate the commercial use of the Program, the Contributor who 106 | includes the Program in a commercial product offering should do so in a manner 107 | which does not create potential liability for other Contributors. Therefore, 108 | if a Contributor includes the Program in a commercial product offering, such 109 | Contributor ("Commercial Contributor") hereby agrees to defend and indemnify 110 | every other Contributor ("Indemnified Contributor") against any losses, 111 | damages and costs (collectively "Losses") arising from claims, lawsuits and 112 | other legal actions brought by a third party against the Indemnified 113 | Contributor to the extent caused by the acts or omissions of such Commercial 114 | Contributor in connection with its distribution of the Program in a commercial 115 | product offering. The obligations in this section do not apply to any claims 116 | or Losses relating to any actual or alleged intellectual property 117 | infringement. In order to qualify, an Indemnified Contributor must: 118 | a) promptly notify the Commercial Contributor in writing of such claim, and 119 | b) allow the Commercial Contributor to control, and cooperate with the 120 | Commercial Contributor in, the defense and any related settlement 121 | negotiations. The Indemnified Contributor may participate in any such claim at 122 | its own expense. 123 | 124 | For example, a Contributor might include the Program in a commercial product 125 | offering, Product X. That Contributor is then a Commercial Contributor. 
If 126 | that Commercial Contributor then makes performance claims, or offers 127 | warranties related to Product X, those performance claims and warranties are 128 | such Commercial Contributor's responsibility alone. Under this section, the 129 | Commercial Contributor would have to defend claims against the other 130 | Contributors related to those performance claims and warranties, and if a 131 | court requires any other Contributor to pay any damages as a result, the 132 | Commercial Contributor must pay those damages. 133 | 134 | 5. NO WARRANTY 135 | 136 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN 137 | "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR 138 | IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, 139 | NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each 140 | Recipient is solely responsible for determining the appropriateness of using 141 | and distributing the Program and assumes all risks associated with its 142 | exercise of rights under this Agreement , including but not limited to the 143 | risks and costs of program errors, compliance with applicable laws, damage to 144 | or loss of data, programs or equipment, and unavailability or interruption of 145 | operations. 146 | 147 | 6. DISCLAIMER OF LIABILITY 148 | 149 | EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY 150 | CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, 151 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION 152 | LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 153 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 154 | ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE 155 | EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY 156 | OF SUCH DAMAGES. 157 | 158 | 7. GENERAL 159 | 160 | If any provision of this Agreement is invalid or unenforceable under 161 | applicable law, it shall not affect the validity or enforceability of the 162 | remainder of the terms of this Agreement, and without further action by the 163 | parties hereto, such provision shall be reformed to the minimum extent 164 | necessary to make such provision valid and enforceable. 165 | 166 | If Recipient institutes patent litigation against any entity (including a 167 | cross-claim or counterclaim in a lawsuit) alleging that the Program itself 168 | (excluding combinations of the Program with other software or hardware) 169 | infringes such Recipient's patent(s), then such Recipient's rights granted 170 | under Section 2(b) shall terminate as of the date such litigation is filed. 171 | 172 | All Recipient's rights under this Agreement shall terminate if it fails to 173 | comply with any of the material terms or conditions of this Agreement and does 174 | not cure such failure in a reasonable period of time after becoming aware of 175 | such noncompliance. If all Recipient's rights under this Agreement terminate, 176 | Recipient agrees to cease use and distribution of the Program as soon as 177 | reasonably practicable. However, Recipient's obligations under this Agreement 178 | and any licenses granted by Recipient relating to the Program shall continue 179 | and survive. 
180 | 181 | Everyone is permitted to copy and distribute copies of this Agreement, but in 182 | order to avoid inconsistency the Agreement is copyrighted and may only be 183 | modified in the following manner. The Agreement Steward reserves the right to 184 | publish new versions (including revisions) of this Agreement from time to 185 | time. No one other than the Agreement Steward has the right to modify this 186 | Agreement. The Eclipse Foundation is the initial Agreement Steward. The 187 | Eclipse Foundation may assign the responsibility to serve as the Agreement 188 | Steward to a suitable separate entity. Each new version of the Agreement will 189 | be given a distinguishing version number. The Program (including 190 | Contributions) may always be distributed subject to the version of the 191 | Agreement under which it was received. In addition, after a new version of the 192 | Agreement is published, Contributor may elect to distribute the Program 193 | (including its Contributions) under the new version. Except as expressly 194 | stated in Sections 2(a) and 2(b) above, Recipient receives no rights or 195 | licenses to the intellectual property of any Contributor under this Agreement, 196 | whether expressly, by implication, estoppel or otherwise. All rights in the 197 | Program not expressly granted under this Agreement are reserved. 198 | 199 | This Agreement is governed by the laws of the State of New York and the 200 | intellectual property laws of the United States of America. No party to this 201 | Agreement will bring a legal action under this Agreement more than one year 202 | after the cause of action arose. Each party waives its rights to a jury trial in 203 | any resulting litigation. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BAGAN 2 | Keras implementation of [Balancing GAN (BAGAN)](https://arxiv.org/abs/1803.09655) applied to the MNIST example. 3 | 4 | The framework is meant as a tool for data augmentation for imbalanced image-classification datasets where some classes are under-represented. 5 | The generative model applied to sample new images for the minority class is trained in three steps: a) training a preliminary autoencoder, b) initialization of the generative adversarial framework by means of the pre-trained autoencoder modules, and c) fine-tuning of the generative model in adversarial mode. 6 | 7 | 8 | 9 | Along these steps, the generative model learns from all available data, including minority and majority classes. This enables the model to automatically figure out if and which features from over-represented classes can be used to draw new images for under-represented classes. 10 | For example, when considering a traffic sign recognition problem, all warning signs share the same external triangular shape. 11 | BAGAN can easily learn the triangular shape from any warning sign in the majority classes and reuse this pattern to draw other under-represented warning signs. 12 | 13 | The application of this approach toward fairness enhancement and bias mitigation in deep-learning AI systems is currently an active research topic. 14 | 15 | ## Example results 16 | 17 | The [German Traffic Sign Recognition benchmark](http://benchmark.ini.rub.de/) is an example of an imbalanced dataset composed of 43 classes, where the minority class appears 210 times, whereas the majority class appears 2250 times. 
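Under the hood, new images for a given class are drawn by sampling latent vectors from a class-conditional multivariate normal distribution (learned from the pre-trained encoder) and decoding them with the generator. The following minimal sketch is an illustrative re-statement of `BalancingGAN.generate_latent`/`generate` from `balancing_gan.py`; the standalone names (`generator`, `means`, `covariances`, `sample_class_images`) are hypothetical stand-ins for the corresponding `BalancingGAN` attributes and methods:

```python
import numpy as np

# `means[c]` and `covariances[c]` are the per-class latent statistics
# estimated from the pre-trained encoder; `generator` is the trained
# Keras decoder mapping latent vectors to images.
def sample_class_images(generator, means, covariances, c, n_samples=10):
    # Draw latent vectors from the class-c multivariate normal...
    latents = np.array([
        np.random.multivariate_normal(means[c], covariances[c])
        for _ in range(n_samples)
    ])
    # ...and decode them into images with values in [-1, 1].
    return generator.predict(latents)
```
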
18 | 19 | Here we show representative sample images generated with BAGAN for the three least represented classes. 20 | 21 | ![alt text](Figures/bagan_x5_minority.png) 22 | 23 | Refer to the original work (https://arxiv.org/abs/1803.09655) for a comparison to other state-of-the-art approaches. 24 | 25 | 26 | The code in this repository executes on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The dataset is originally balanced and, before training BAGAN, we force class imbalance by selecting a target class and removing from the training dataset a significant portion of its instances. 27 | The following figure shows 0-image samples generated when dropping 97.5\% of 0-images from the training set before training. 28 | 29 | ![alt text](Figures/plot_class_0.png) 30 | 31 | 32 | 33 | ## Running the MNIST example 34 | 35 | This software has been tested on `tensorflow-1.5.0`. 36 | 37 | To execute BAGAN for the MNIST example, run the command: 38 | `./run.sh` 39 | 40 | A directory named `res_MNIST_dmode_uniform_gmode_uniform_unbalance_0.05_epochs_150_lr_0.000050_seed_0` will be generated and results will be stored there. 41 | 42 | After the training, you will find in that directory a set of `h5` files storing the model weights, a set of `csv` files storing the loss functions measured for each epoch, a set of `npy` files storing the means and covariances of the class-conditional latent-vector generator, and a set of `cmp_class_X_epoch_Y.png` files showing example images obtained during training. 43 | 44 | The file `cmp_class_X_epoch_Y.png` shows images obtained when training the GAN for Y epochs and considering class X as the minority class. There are three rows per class: 1) real samples, 2) autoencoded reconstructed samples, 3) randomly-generated samples. 45 | Note that in BAGAN, after the initial autoencoder training, the generative model is trained in adversarial mode and the autoencoder loss is no longer taken into account. Thus, during the adversarial training the autoencoded images may start to deteriorate (row 2 may no longer match row 1), whereas the generated images (row 3) will improve in quality. 46 | 47 | For more information about available execution options: 48 | `python ./bagan_train.py --help` 49 | 50 | -------------------------------------------------------------------------------- /bagan_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright IBM Corporation 2018 3 | 4 | All rights reserved. This program and the accompanying materials 5 | are made available under the terms of the Eclipse Public License v1.0 6 | which accompanies this distribution, and is available at 7 | http://www.eclipse.org/legal/epl-v10.html 8 | """ 9 | 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | 14 | from optparse import OptionParser 15 | 16 | import balancing_gan as bagan 17 | from rw.batch_generator import BatchGenerator as BatchGenerator 18 | from utils import save_image_array 19 | 20 | import os 21 | 22 | 23 | 24 | if __name__ == '__main__': 25 | # Collect arguments 26 | argParser = OptionParser() 27 | 28 | argParser.add_option("-u", "--unbalance", default=0.2, 29 | action="store", type="float", dest="unbalance", 30 | help="Unbalance factor u. 
The minority class has at most u * otherClassSamples instances.") 31 | 32 | argParser.add_option("-s", "--random_seed", default=0, 33 | action="store", type="int", dest="seed", 34 | help="Random seed for repeatable subsampling.") 35 | 36 | argParser.add_option("-d", "--sampling_mode_for_discriminator", default="uniform", 37 | action="store", type="string", dest="dratio_mode", 38 | help="Dratio sampling mode (\"uniform\",\"rebalance\").") 39 | 40 | argParser.add_option("-g", "--sampling_mode_for_generator", default="uniform", 41 | action="store", type="string", dest="gratio_mode", 42 | help="Gratio sampling mode (\"uniform\",\"rebalance\").") 43 | 44 | argParser.add_option("-e", "--epochs", default=3, 45 | action="store", type="int", dest="epochs", 46 | help="Training epochs.") 47 | 48 | argParser.add_option("-l", "--learning_rate", default=0.00005, 49 | action="store", type="float", dest="adam_lr", 50 | help="Training learning rate.") 51 | 52 | argParser.add_option("-c", "--target_class", default=-1, 53 | action="store", type="int", dest="target_class", 54 | help="If greater than or equal to 0, the model is trained only for the specified class.") 55 | 56 | argParser.add_option("-D", "--dataset", default='MNIST', 57 | action="store", type="string", dest="dataset", 58 | help="Either 'MNIST' or 'CIFAR10'.") 59 | 60 | (options, args) = argParser.parse_args() 61 | 62 | assert (options.unbalance <= 1.0 and options.unbalance > 0.0), "Data unbalance factor must be > 0 and <= 1" 63 | 64 | print("Executing BAGAN.") 65 | 66 | # Read command line parameters 67 | np.random.seed(options.seed) 68 | unbalance = options.unbalance 69 | gratio_mode = options.gratio_mode 70 | dratio_mode = options.dratio_mode 71 | gan_epochs = options.epochs 72 | adam_lr = options.adam_lr 73 | opt_class = options.target_class 74 | batch_size = 128 75 | dataset_name = options.dataset 76 | 77 | # Set channels: 1 for MNIST, 3 for CIFAR10. 78 | channels = 1 if dataset_name == 'MNIST' else 3 79 | print('Using dataset: ', dataset_name) 80 | 81 | # Result directory 82 | res_dir = "./res_{}_dmode_{}_gmode_{}_unbalance_{}_epochs_{}_lr_{:f}_seed_{}".format( 83 | dataset_name, dratio_mode, gratio_mode, unbalance, options.epochs, adam_lr, options.seed 84 | ) 85 | if not os.path.exists(res_dir): 86 | os.makedirs(res_dir) 87 | 88 | # Read initial data. 89 | print("read input data...") 90 | bg_train_full = BatchGenerator(BatchGenerator.TRAIN, batch_size, 91 | class_to_prune=None, unbalance=None, dataset=dataset_name) 92 | bg_test = BatchGenerator(BatchGenerator.TEST, batch_size, 93 | class_to_prune=None, unbalance=None, dataset=dataset_name) 94 | 95 | print("input data loaded...") 96 | 97 | shape = bg_train_full.get_image_shape() 98 | 99 | min_latent_res = shape[-1] 100 | while min_latent_res > 8: 101 | min_latent_res = min_latent_res / 2 102 | min_latent_res = int(min_latent_res) 103 | 104 | classes = bg_train_full.get_label_table() 105 | 106 | # Initialize statistics information 107 | gan_train_losses = defaultdict(list) 108 | gan_test_losses = defaultdict(list) 109 | 110 | img_samples = defaultdict(list) 111 | 112 | # For all possible minority classes. 113 | target_classes = np.array(range(len(classes))) 114 | if opt_class >= 0: 115 | min_classes = np.array([opt_class]) 116 | else: 117 | min_classes = target_classes 118 | 119 | for c in min_classes: 120 | # If unbalance is 1.0, then the same BAGAN model can be applied to every class because 121 | # we do not drop any instance at training time. 
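# In that case, the model files produced for class 0 are reused for every
# other class via the symlinks created below, avoiding redundant retraining.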
122 | if unbalance == 1.0 and c > 0 and ( 123 | os.path.exists("{}/class_0_score.csv".format(res_dir)) and 124 | os.path.exists("{}/class_0_discriminator.h5".format(res_dir)) and 125 | os.path.exists("{}/class_0_generator.h5".format(res_dir)) and 126 | os.path.exists("{}/class_0_reconstructor.h5".format(res_dir)) 127 | ): 128 | # Without additional imbalance, BAGAN does not need to be retrained; we symlink the pregenerated model 129 | os.symlink("{}/class_0_score.csv".format(res_dir), "{}/class_{}_score.csv".format(res_dir, c)) 130 | os.symlink("{}/class_0_discriminator.h5".format(res_dir), "{}/class_{}_discriminator.h5".format(res_dir, c)) 131 | os.symlink("{}/class_0_generator.h5".format(res_dir), "{}/class_{}_generator.h5".format(res_dir, c)) 132 | os.symlink("{}/class_0_reconstructor.h5".format(res_dir), "{}/class_{}_reconstructor.h5".format(res_dir, c)) 133 | 134 | # Unbalance the training set. 135 | bg_train_partial = BatchGenerator(BatchGenerator.TRAIN, batch_size, 136 | class_to_prune=c, unbalance=unbalance, dataset=dataset_name) 137 | 138 | # Train the model (or reload it if already available). 139 | if not ( 140 | os.path.exists("{}/class_{}_score.csv".format(res_dir, c)) and 141 | os.path.exists("{}/class_{}_discriminator.h5".format(res_dir, c)) and 142 | os.path.exists("{}/class_{}_generator.h5".format(res_dir, c)) and 143 | os.path.exists("{}/class_{}_reconstructor.h5".format(res_dir, c)) 144 | ): 145 | # Training required 146 | print("Required GAN for class {}".format(c)) 147 | 148 | print('Class counters: ', bg_train_partial.per_class_count) 149 | 150 | # Train GAN to balance the data 151 | gan = bagan.BalancingGAN( 152 | target_classes, c, dratio_mode=dratio_mode, gratio_mode=gratio_mode, 153 | adam_lr=adam_lr, res_dir=res_dir, image_shape=shape, min_latent_res=min_latent_res 154 | ) 155 | gan.train(bg_train_partial, bg_test, epochs=gan_epochs) 156 | gan.save_history( 157 | res_dir, c 158 | ) 159 | 160 | else: # GAN pre-trained 161 | # Reload the pretrained GAN. 162 | print("Loading GAN for class {}".format(c)) 163 | 164 | gan = bagan.BalancingGAN(target_classes, c, dratio_mode=dratio_mode, gratio_mode=gratio_mode, 165 | adam_lr=adam_lr, res_dir=res_dir, image_shape=shape, min_latent_res=min_latent_res) 166 | gan.load_models( 167 | "{}/class_{}_generator.h5".format(res_dir, c), 168 | "{}/class_{}_discriminator.h5".format(res_dir, c), 169 | "{}/class_{}_reconstructor.h5".format(res_dir, c), 170 | bg_train=bg_train_partial # This is required to initialize the per-class mean and covariance matrix 171 | ) 172 | 173 | # Sample and save images 174 | img_samples['class_{}'.format(c)] = gan.generate_samples(c=c, samples=10) 175 | 176 | save_image_array(np.array([img_samples['class_{}'.format(c)]]), '{}/plot_class_{}.png'.format(res_dir, c)) 177 | 178 | 179 | -------------------------------------------------------------------------------- /balancing_gan.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright IBM Corporation 2018 3 | 4 | All rights reserved. 
This program and the accompanying materials 5 | are made available under the terms of the Eclipse Public License v1.0 6 | which accompanies this distribution, and is available at 7 | http://www.eclipse.org/legal/epl-v10.html 8 | """ 9 | 10 | import pickle 11 | from collections import defaultdict 12 | 13 | import keras.backend as K 14 | K.set_image_dim_ordering('th') 15 | 16 | import keras 17 | from keras.layers.advanced_activations import LeakyReLU 18 | from keras.layers.convolutional import UpSampling2D, Convolution2D, Conv2D 19 | from keras.models import Sequential, Model 20 | from keras.optimizers import Adam 21 | 22 | import os 23 | import sys 24 | import re 25 | import numpy as np 26 | 27 | from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout 28 | 29 | from keras.layers import multiply as kmultiply 30 | from keras.layers import add as kadd 31 | 32 | import csv 33 | 34 | from PIL import Image 35 | 36 | from utils import save_image_array 37 | 38 | 39 | class BalancingGAN: 40 | def build_generator(self, latent_size, init_resolution=8): 41 | resolution = self.resolution 42 | channels = self.channels 43 | 44 | # we will map a pair of (z, L), where z is a latent vector and L is a 45 | # label drawn from P_c, to image space (..., 3, resolution, resolution) 46 | cnn = Sequential() 47 | 48 | cnn.add(Dense(1024, input_dim=latent_size, activation='relu', use_bias=False)) 49 | cnn.add(Dense(128 * init_resolution * init_resolution, activation='relu', use_bias=False)) 50 | cnn.add(Reshape((128, init_resolution, init_resolution))) 51 | crt_res = init_resolution 52 | 53 | # upsample 54 | while crt_res != resolution: 55 | cnn.add(UpSampling2D(size=(2, 2))) 56 | if crt_res < resolution/2: 57 | cnn.add(Conv2D( 58 | 256, (5, 5), padding='same', 59 | activation='relu', kernel_initializer='glorot_normal', use_bias=False) 60 | ) 61 | 62 | else: 63 | cnn.add(Conv2D(128, (5, 5), padding='same', 64 | activation='relu', kernel_initializer='glorot_normal', use_bias=False)) 65 | 66 | crt_res = crt_res * 2 67 | assert crt_res <= resolution,\ 68 | "Error: final resolution [{}] must equal i*2^n. Initial resolution i is [{}]. 
n must be a natural number.".format(resolution, init_resolution) 69 | 70 | cnn.add(Conv2D(channels, (2, 2), padding='same', 71 | activation='tanh', kernel_initializer='glorot_normal', use_bias=False)) 72 | 73 | # This is the latent z space 74 | latent = Input(shape=(latent_size, )) 75 | 76 | fake_image_from_latent = cnn(latent) 77 | 78 | # The input-output interface 79 | self.generator = Model(inputs=latent, outputs=fake_image_from_latent) 80 | 81 | def _build_common_encoder(self, image, min_latent_res=8): 82 | resolution = self.resolution 83 | channels = self.channels 84 | 85 | # build a relatively standard conv net, with LeakyReLUs as suggested in ACGAN 86 | cnn = Sequential() 87 | 88 | cnn.add(Conv2D(32, (3, 3), padding='same', strides=(2, 2), 89 | input_shape=(channels, resolution, resolution), use_bias=True)) 90 | cnn.add(LeakyReLU()) 91 | cnn.add(Dropout(0.3)) 92 | 93 | cnn.add(Conv2D(64, (3, 3), padding='same', strides=(1, 1), use_bias=True)) 94 | cnn.add(LeakyReLU()) 95 | cnn.add(Dropout(0.3)) 96 | 97 | cnn.add(Conv2D(128, (3, 3), padding='same', strides=(2, 2), use_bias=True)) 98 | cnn.add(LeakyReLU()) 99 | cnn.add(Dropout(0.3)) 100 | 101 | cnn.add(Conv2D(256, (3, 3), padding='same', strides=(1, 1), use_bias=True)) 102 | cnn.add(LeakyReLU()) 103 | cnn.add(Dropout(0.3)) 104 | 105 | while cnn.output_shape[-1] > min_latent_res: 106 | cnn.add(Conv2D(256, (3, 3), padding='same', strides=(2, 2), use_bias=True)) 107 | cnn.add(LeakyReLU()) 108 | cnn.add(Dropout(0.3)) 109 | 110 | cnn.add(Conv2D(256, (3, 3), padding='same', strides=(1, 1), use_bias=True)) 111 | cnn.add(LeakyReLU()) 112 | cnn.add(Dropout(0.3)) 113 | 114 | cnn.add(Flatten()) 115 | 116 | features = cnn(image) 117 | return features 118 | 119 | # latent_size is the innermost latent vector size; min_latent_res is latent resolution (before the dense layer). 120 | def build_reconstructor(self, latent_size, min_latent_res=8): 121 | resolution = self.resolution 122 | channels = self.channels 123 | image = Input(shape=(channels, resolution, resolution)) 124 | features = self._build_common_encoder(image, min_latent_res) 125 | 126 | # Reconstructor specific 127 | latent = Dense(latent_size, activation='linear')(features) 128 | self.reconstructor = Model(inputs=image, outputs=latent) 129 | 130 | def build_discriminator(self, min_latent_res=8): 131 | resolution = self.resolution 132 | channels = self.channels 133 | image = Input(shape=(channels, resolution, resolution)) 134 | features = self._build_common_encoder(image, min_latent_res) 135 | 136 | # Discriminator specific 137 | aux = Dense( 138 | self.nclasses+1, activation='softmax', name='auxiliary' # nclasses+1. 
The last class is: FAKE 139 | )(features) 140 | self.discriminator = Model(inputs=image, outputs=aux) 141 | 142 | def generate_from_latent(self, latent): 143 | res = self.generator(latent) 144 | return res 145 | 146 | def generate(self, c, bg=None): # c is a vector of classes 147 | latent = self.generate_latent(c, bg) 148 | res = self.generator.predict(latent) 149 | return res 150 | 151 | def generate_latent(self, c, bg=None, n_mix=10): # c is a vector of classes 152 | res = np.array([ 153 | np.random.multivariate_normal(self.means[e], self.covariances[e]) 154 | for e in c 155 | ]) 156 | 157 | return res 158 | 159 | def discriminate(self, image): 160 | return self.discriminator(image) 161 | 162 | def __init__(self, classes, target_class_id, 163 | # Set dratio_mode, and gratio_mode to 'rebalance' to bias the sampling toward the minority class 164 | # No relevant difference noted 165 | dratio_mode="uniform", gratio_mode="uniform", 166 | adam_lr=0.00005, latent_size=100, 167 | res_dir = "./res-tmp", image_shape=[3,32,32], min_latent_res=8): 168 | self.gratio_mode = gratio_mode 169 | self.dratio_mode = dratio_mode 170 | self.classes = classes 171 | self.target_class_id = target_class_id # target_class_id is used only during saving, not to overwrite other class results. 172 | self.nclasses = len(classes) 173 | self.latent_size = latent_size 174 | self.res_dir = res_dir 175 | self.channels = image_shape[0] 176 | self.resolution = image_shape[1] 177 | if self.resolution != image_shape[2]: 178 | print("Error: only squared images currently supported by balancingGAN") 179 | exit(1) 180 | 181 | self.min_latent_res = min_latent_res 182 | 183 | # Initialize learning variables 184 | self.adam_lr = adam_lr 185 | self.adam_beta_1 = 0.5 186 | 187 | # Initialize stats 188 | self.train_history = defaultdict(list) 189 | self.test_history = defaultdict(list) 190 | self.trained = False 191 | 192 | # Build generator 193 | self.build_generator(latent_size, init_resolution=min_latent_res) 194 | self.generator.compile( 195 | optimizer=Adam(lr=self.adam_lr, beta_1=self.adam_beta_1), 196 | loss='sparse_categorical_crossentropy' 197 | ) 198 | 199 | latent_gen = Input(shape=(latent_size, )) 200 | 201 | # Build discriminator 202 | self.build_discriminator(min_latent_res=min_latent_res) 203 | self.discriminator.compile( 204 | optimizer=Adam(lr=self.adam_lr, beta_1=self.adam_beta_1), 205 | loss='sparse_categorical_crossentropy' 206 | ) 207 | 208 | # Build reconstructor 209 | self.build_reconstructor(latent_size, min_latent_res=min_latent_res) 210 | self.reconstructor.compile( 211 | optimizer=Adam(lr=self.adam_lr, beta_1=self.adam_beta_1), 212 | loss='mean_squared_error' 213 | ) 214 | 215 | # Define combined for training generator. 
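# The combined model below chains the generator into the frozen discriminator:
# only the generator weights are updated when it is trained, and the generator
# is optimized so that the discriminator assigns the sampled class labels
# (rather than the extra FAKE class) to the generated images.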
216 | fake = self.generator(latent_gen) 217 | 218 | self.discriminator.trainable = False 219 | self.reconstructor.trainable = False 220 | self.generator.trainable = True 221 | aux = self.discriminate(fake) 222 | 223 | self.combined = Model(inputs=latent_gen, outputs=aux) 224 | 225 | self.combined.compile( 226 | optimizer=Adam(lr=self.adam_lr, beta_1=self.adam_beta_1), 227 | loss='sparse_categorical_crossentropy' 228 | ) 229 | 230 | # Define initializer for autoencoder 231 | self.discriminator.trainable = False 232 | self.generator.trainable = True 233 | self.reconstructor.trainable = True 234 | 235 | img_for_reconstructor = Input(shape=(self.channels, self.resolution, self.resolution,)) 236 | img_reconstruct = self.generator(self.reconstructor(img_for_reconstructor)) 237 | 238 | self.autoenc_0 = Model(inputs=img_for_reconstructor, outputs=img_reconstruct) 239 | self.autoenc_0.compile( 240 | optimizer=Adam(lr=self.adam_lr, beta_1=self.adam_beta_1), 241 | loss='mean_squared_error' 242 | ) 243 | 244 | def _biased_sample_labels(self, samples, target_distribution="uniform"): 245 | distribution = self.class_uratio 246 | if target_distribution == "d": 247 | distribution = self.class_dratio 248 | elif target_distribution == "g": 249 | distribution = self.class_gratio 250 | 251 | sampled_labels = np.full(samples,0) 252 | sampled_labels_p = np.random.uniform(0, 1, samples) 253 | for c in list(range(self.nclasses)): 254 | mask = np.logical_and((sampled_labels_p > 0), (sampled_labels_p <= distribution[c])) 255 | sampled_labels[mask] = self.classes[c] 256 | sampled_labels_p = sampled_labels_p - distribution[c] 257 | 258 | return sampled_labels 259 | 260 | def _train_one_epoch(self, bg_train): 261 | epoch_disc_loss = [] 262 | epoch_gen_loss = [] 263 | 264 | for image_batch, label_batch in bg_train.next_batch(): 265 | 266 | crt_batch_size = label_batch.shape[0] 267 | 268 | ################## Train Discriminator ################## 269 | fake_size = int(np.ceil(crt_batch_size * 1.0/self.nclasses)) 270 | 271 | # sample some labels from p_c, then latent and images 272 | sampled_labels = self._biased_sample_labels(fake_size, "d") 273 | latent_gen = self.generate_latent(sampled_labels, bg_train) 274 | 275 | generated_images = self.generator.predict(latent_gen, verbose=0) 276 | 277 | X = np.concatenate((image_batch, generated_images)) 278 | aux_y = np.concatenate((label_batch, np.full(len(sampled_labels) , self.nclasses )), axis=0) 279 | 280 | epoch_disc_loss.append(self.discriminator.train_on_batch(X, aux_y)) 281 | 282 | ################## Train Generator ################## 283 | sampled_labels = self._biased_sample_labels(fake_size + crt_batch_size, "g") 284 | latent_gen = self.generate_latent(sampled_labels, bg_train) 285 | 286 | epoch_gen_loss.append(self.combined.train_on_batch( 287 | latent_gen, sampled_labels)) 288 | 289 | # return statistics: generator loss, 290 | return ( 291 | np.mean(np.array(epoch_disc_loss), axis=0), 292 | np.mean(np.array(epoch_gen_loss), axis=0) 293 | ) 294 | 295 | def _set_class_ratios(self): 296 | self.class_dratio = np.full(self.nclasses, 0.0) 297 | # Set uniform 298 | target = 1/self.nclasses 299 | self.class_uratio = np.full(self.nclasses, target) 300 | 301 | # Set gratio 302 | self.class_gratio = np.full(self.nclasses, 0.0) 303 | for c in range(self.nclasses): 304 | if self.gratio_mode == "uniform": 305 | self.class_gratio[c] = target 306 | elif self.gratio_mode == "rebalance": 307 | self.class_gratio[c] = 2 * target - self.class_aratio[c] 308 | else: 309 | print("Error 
while training bgan, unknown gmode " + self.gratio_mode) 310 | exit() 311 | 312 | # Set dratio 313 | self.class_dratio = np.full(self.nclasses, 0.0) 314 | for c in range(self.nclasses): 315 | if self.dratio_mode == "uniform": 316 | self.class_dratio[c] = target 317 | elif self.dratio_mode == "rebalance": 318 | self.class_dratio[c] = 2 * target - self.class_aratio[c] 319 | else: 320 | print("Error while training bgan, unknown dmode " + self.dratio_mode) 321 | exit() 322 | 323 | # if very unbalanced, the gratio might be negative for some classes. 324 | # In this case, we adjust.. 325 | if self.gratio_mode == "rebalance": 326 | self.class_gratio[self.class_gratio < 0] = 0 327 | self.class_gratio = self.class_gratio / sum(self.class_gratio) 328 | 329 | # if very unbalanced, the dratio might be negative for some classes. 330 | # In this case, we adjust.. 331 | if self.dratio_mode == "rebalance": 332 | self.class_dratio[self.class_dratio < 0] = 0 333 | self.class_dratio = self.class_dratio / sum(self.class_dratio) 334 | 335 | def init_autoenc(self, bg_train, gen_fname=None, rec_fname=None): 336 | if gen_fname is None: 337 | generator_fname = "{}/{}_decoder.h5".format(self.res_dir, self.target_class_id) 338 | else: 339 | generator_fname = gen_fname 340 | if rec_fname is None: 341 | reconstructor_fname = "{}/{}_encoder.h5".format(self.res_dir, self.target_class_id) 342 | else: 343 | reconstructor_fname = rec_fname 344 | 345 | multivariate_prelearnt = False 346 | 347 | # Preload the autoencoders 348 | if os.path.exists(generator_fname) and os.path.exists(reconstructor_fname): 349 | print("BAGAN: loading autoencoder: ", generator_fname, reconstructor_fname) 350 | self.generator.load_weights(generator_fname) 351 | self.reconstructor.load_weights(reconstructor_fname) 352 | 353 | # load the learned distribution 354 | if os.path.exists("{}/{}_means.npy".format(self.res_dir, self.target_class_id)) \ 355 | and os.path.exists("{}/{}_covariances.npy".format(self.res_dir, self.target_class_id)): 356 | multivariate_prelearnt = True 357 | 358 | cfname = "{}/{}_covariances.npy".format(self.res_dir, self.target_class_id) 359 | mfname = "{}/{}_means.npy".format(self.res_dir, self.target_class_id) 360 | print("BAGAN: loading multivariate: ", cfname, mfname) 361 | self.covariances = np.load(cfname) 362 | self.means = np.load(mfname) 363 | 364 | else: 365 | print("BAGAN: training autoencoder") 366 | autoenc_train_loss = [] 367 | for e in range(self.autoenc_epochs): 368 | print('Autoencoder train epoch: {}/{}'.format(e+1, self.autoenc_epochs)) 369 | autoenc_train_loss_crt = [] 370 | for image_batch, label_batch in bg_train.next_batch(): 371 | 372 | autoenc_train_loss_crt.append(self.autoenc_0.train_on_batch(image_batch, image_batch)) 373 | autoenc_train_loss.append(np.mean(np.array(autoenc_train_loss_crt), axis=0)) 374 | 375 | autoenc_loss_fname = "{}/{}_autoencoder.csv".format(self.res_dir, self.target_class_id) 376 | with open(autoenc_loss_fname, 'w') as csvfile: 377 | for item in autoenc_train_loss: 378 | csvfile.write("%s\n" % item) 379 | 380 | self.generator.save(generator_fname) 381 | self.reconstructor.save(reconstructor_fname) 382 | 383 | layers_r = self.reconstructor.layers 384 | layers_d = self.discriminator.layers 385 | 386 | for l in range(1, len(layers_r)-1): 387 | layers_d[l].set_weights( layers_r[l].get_weights() ) 388 | 389 | # Organize multivariate distribution 390 | if not multivariate_prelearnt: 391 | print("BAGAN: computing multivariate") 392 | self.covariances = [] 393 | self.means = [] 394 | 395 
| for c in range(self.nclasses): 396 | imgs = bg_train.dataset_x[bg_train.per_class_ids[c]] 397 | 398 | latent = self.reconstructor.predict(imgs) 399 | 400 | self.covariances.append(np.cov(np.transpose(latent))) 401 | self.means.append(np.mean(latent, axis=0)) 402 | 403 | self.covariances = np.array(self.covariances) 404 | self.means = np.array(self.means) 405 | 406 | # save the learned distribution 407 | cfname = "{}/{}_covariances.npy".format(self.res_dir, self.target_class_id) 408 | mfname = "{}/{}_means.npy".format(self.res_dir, self.target_class_id) 409 | print("BAGAN: saving multivariate: ", cfname, mfname) 410 | np.save(cfname, self.covariances) 411 | np.save(mfname, self.means) 412 | print("BAGAN: saved multivariate") 413 | 414 | def _get_lst_bck_name(self, element): 415 | # Find last bck name 416 | files = [ 417 | f for f in os.listdir(self.res_dir) 418 | if re.match(r'bck_c_{}'.format(self.target_class_id) + "_" + element, f) 419 | ] 420 | if len(files) > 0: 421 | fname = files[0] 422 | e_str = os.path.splitext(fname)[0].split("_")[-1] 423 | 424 | epoch = int(e_str) 425 | 426 | return epoch, fname 427 | 428 | else: 429 | return 0, None 430 | 431 | def init_gan(self): 432 | # Find last bck name 433 | epoch, generator_fname = self._get_lst_bck_name("generator") 434 | 435 | new_e, discriminator_fname = self._get_lst_bck_name("discriminator") 436 | 437 | if new_e != epoch: # Reload error, restart from scratch 438 | return 0 439 | 440 | # Load last bck 441 | try: 442 | self.generator.load_weights(os.path.join(self.res_dir, generator_fname)) 443 | self.discriminator.load_weights(os.path.join(self.res_dir, discriminator_fname)) 444 | return epoch 445 | 446 | # Return epoch 447 | except: # Reload error, restart from scratch (the first time we train we pass from here) 448 | return 0 449 | 450 | def backup_point(self, epoch): 451 | # Remove last bck 452 | _, old_bck_g = self._get_lst_bck_name("generator") 453 | _, old_bck_d = self._get_lst_bck_name("discriminator") 454 | try: 455 | os.remove(os.path.join(self.res_dir, old_bck_g)) 456 | os.remove(os.path.join(self.res_dir, old_bck_d)) 457 | except: 458 | pass 459 | 460 | # Bck 461 | generator_fname = "{}/bck_c_{}_generator_e_{}.h5".format(self.res_dir, self.target_class_id, epoch) 462 | discriminator_fname = "{}/bck_c_{}_discriminator_e_{}.h5".format(self.res_dir, self.target_class_id, epoch) 463 | 464 | self.generator.save(generator_fname) 465 | self.discriminator.save(discriminator_fname) 466 | 467 | def train(self, bg_train, bg_test, epochs=50): 468 | if not self.trained: 469 | self.autoenc_epochs = epochs 470 | 471 | # Class actual ratio 472 | self.class_aratio = bg_train.get_class_probability() 473 | 474 | # Class balancing ratio 475 | self._set_class_ratios() 476 | print("uratio set to: {}".format(self.class_uratio)) 477 | print("dratio set to: {}".format(self.class_dratio)) 478 | print("gratio set to: {}".format(self.class_gratio)) 479 | 480 | # Initialization 481 | print("BAGAN init_autoenc") 482 | self.init_autoenc(bg_train) 483 | print("BAGAN autoenc initialized, init gan") 484 | start_e = self.init_gan() 485 | print("BAGAN gan initialized, start_e: ", start_e) 486 | 487 | crt_c = 0 488 | act_img_samples = bg_train.get_samples_for_class(crt_c, 10) 489 | img_samples = np.array([ 490 | [ 491 | act_img_samples, 492 | self.generator.predict( 493 | self.reconstructor.predict( 494 | act_img_samples 495 | ) 496 | ), 497 | self.generate_samples(crt_c, 10, bg_train) 498 | ] 499 | ]) 500 | for crt_c in range(1, self.nclasses): 501 | 
act_img_samples = bg_train.get_samples_for_class(crt_c, 10) 502 | new_samples = np.array([ 503 | [ 504 | act_img_samples, 505 | self.generator.predict( 506 | self.reconstructor.predict( 507 | act_img_samples 508 | ) 509 | ), 510 | self.generate_samples(crt_c, 10, bg_train) 511 | ] 512 | ]) 513 | img_samples = np.concatenate((img_samples, new_samples), axis=0) 514 | 515 | shape = img_samples.shape 516 | img_samples = img_samples.reshape((-1, shape[-4], shape[-3], shape[-2], shape[-1])) 517 | 518 | save_image_array( 519 | img_samples, 520 | '{}/cmp_class_{}_init.png'.format(self.res_dir, self.target_class_id) 521 | ) 522 | 523 | # Train 524 | for e in range(start_e, epochs): 525 | print('GAN train epoch: {}/{}'.format(e+1, epochs)) 526 | # train_disc_loss, train_gen_loss = self._train_one_epoch(copy.deepcopy(bg_train)) 527 | train_disc_loss, train_gen_loss = self._train_one_epoch(bg_train) 528 | 529 | # Test: # generate a new batch of noise 530 | nb_test = bg_test.get_num_samples() 531 | fake_size = int(np.ceil(nb_test * 1.0/self.nclasses)) 532 | sampled_labels = self._biased_sample_labels(nb_test, "d") 533 | latent_gen = self.generate_latent(sampled_labels, bg_test) 534 | 535 | # sample some labels from p_c and generate images from them 536 | generated_images = self.generator.predict( 537 | latent_gen, verbose=False) 538 | 539 | X = np.concatenate( (bg_test.dataset_x, generated_images) ) 540 | aux_y = np.concatenate((bg_test.dataset_y, np.full(len(sampled_labels), self.nclasses )), axis=0) 541 | 542 | # see if the discriminator can figure itself out... 543 | test_disc_loss = self.discriminator.evaluate( 544 | X, aux_y, verbose=False) 545 | 546 | # make new latent 547 | sampled_labels = self._biased_sample_labels(fake_size + nb_test, "g") 548 | latent_gen = self.generate_latent(sampled_labels, bg_test) 549 | 550 | test_gen_loss = self.combined.evaluate( 551 | latent_gen, 552 | sampled_labels, verbose=False) 553 | 554 | # generate an epoch report on performance 555 | self.train_history['disc_loss'].append(train_disc_loss) 556 | self.train_history['gen_loss'].append(train_gen_loss) 557 | self.test_history['disc_loss'].append(test_disc_loss) 558 | self.test_history['gen_loss'].append(test_gen_loss) 559 | print("train_disc_loss {},\ttrain_gen_loss {},\ttest_disc_loss {},\ttest_gen_loss {}".format( 560 | train_disc_loss, train_gen_loss, test_disc_loss, test_gen_loss 561 | )) 562 | 563 | # Save sample images 564 | if e % 10 == 9: 565 | img_samples = np.array([ 566 | self.generate_samples(c, 10, bg_train) 567 | for c in range(0,self.nclasses) 568 | ]) 569 | 570 | save_image_array( 571 | img_samples, 572 | '{}/plot_class_{}_epoch_{}.png'.format(self.res_dir, self.target_class_id, e) 573 | ) 574 | 575 | # Generate whole evaluation plot (real img, autoencoded img, fake img) 576 | if e % 10 == 5: 577 | self.backup_point(e) 578 | crt_c = 0 579 | act_img_samples = bg_train.get_samples_for_class(crt_c, 10) 580 | img_samples = np.array([ 581 | [ 582 | act_img_samples, 583 | self.generator.predict( 584 | self.reconstructor.predict( 585 | act_img_samples 586 | ) 587 | ), 588 | self.generate_samples(crt_c, 10, bg_train) 589 | ] 590 | ]) 591 | for crt_c in range(1, self.nclasses): 592 | act_img_samples = bg_train.get_samples_for_class(crt_c, 10) 593 | new_samples = np.array([ 594 | [ 595 | act_img_samples, 596 | self.generator.predict( 597 | self.reconstructor.predict( 598 | act_img_samples 599 | ) 600 | ), 601 | self.generate_samples(crt_c, 10, bg_train) 602 | ] 603 | ]) 604 | img_samples = 
np.concatenate((img_samples, new_samples), axis=0) 605 | 606 | shape = img_samples.shape 607 | img_samples = img_samples.reshape((-1, shape[-4], shape[-3], shape[-2], shape[-1])) 608 | 609 | save_image_array( 610 | img_samples, 611 | '{}/cmp_class_{}_epoch_{}.png'.format(self.res_dir, self.target_class_id, e) 612 | ) 613 | 614 | self.trained = True 615 | 616 | def generate_samples(self, c, samples, bg = None): 617 | return self.generate(np.full(samples, c), bg) 618 | 619 | def save_history(self, res_dir, class_id): 620 | if self.trained: 621 | filename = "{}/class_{}_score.csv".format(res_dir, class_id) 622 | generator_fname = "{}/class_{}_generator.h5".format(res_dir, class_id) 623 | discriminator_fname = "{}/class_{}_discriminator.h5".format(res_dir, class_id) 624 | reconstructor_fname = "{}/class_{}_reconstructor.h5".format(res_dir, class_id) 625 | with open(filename, 'w') as csvfile: 626 | fieldnames = [ 627 | 'train_gen_loss', 'train_disc_loss', 628 | 'test_gen_loss', 'test_disc_loss' 629 | ] 630 | 631 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 632 | 633 | writer.writeheader() 634 | for e in range(len(self.train_history['gen_loss'])): 635 | row = [ 636 | self.train_history['gen_loss'][e], 637 | self.train_history['disc_loss'][e], 638 | self.test_history['gen_loss'][e], 639 | self.test_history['disc_loss'][e] 640 | ] 641 | 642 | writer.writerow(dict(zip(fieldnames,row))) 643 | 644 | self.generator.save(generator_fname) 645 | self.discriminator.save(discriminator_fname) 646 | self.reconstructor.save(reconstructor_fname) 647 | 648 | def load_models(self, fname_generator, fname_discriminator, fname_reconstructor, bg_train=None): 649 | self.init_autoenc(bg_train, gen_fname=fname_generator, rec_fname=fname_reconstructor) 650 | self.discriminator.load_weights(fname_discriminator) 651 | 652 | 653 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 ./bagan_train.py -u 0.05 -c 0 -e 150 4 | 5 | -------------------------------------------------------------------------------- /rw/batch_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright IBM Corporation 2018 3 | 4 | All rights reserved. This program and the accompanying materials 5 | are made available under the terms of the Eclipse Public License v1.0 6 | which accompanies this distribution, and is available at 7 | http://www.eclipse.org/legal/epl-v10.html 8 | """ 9 | 10 | from tensorflow.examples.tutorials.mnist import input_data 11 | import tensorflow as tf 12 | import numpy as np 13 | 14 | class BatchGenerator: 15 | 16 | TRAIN = 1 17 | TEST = 0 18 | 19 | def __init__(self, data_src, batch_size=32, class_to_prune=None, unbalance=0, dataset='MNIST'): 20 | assert dataset in ('MNIST', 'CIFAR10'), 'Unknown dataset: ' + dataset 21 | self.batch_size = batch_size 22 | self.data_src = data_src 23 | 24 | # Load data 25 | if dataset == 'MNIST': 26 | mnist = input_data.read_data_sets("dataset/mnist", one_hot=False) 27 | 28 | assert self.batch_size > 0, 'Batch size has to be a positive integer!' 
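# NOTE: the tensorflow.examples.tutorials.mnist loader used above ships with
# TF 1.x only (the README states this code was tested on tensorflow-1.5.0);
# it is not available in TensorFlow 2.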
29 | 30 | if self.data_src == self.TEST: 31 | self.dataset_x = mnist.test.images 32 | self.dataset_y = mnist.test.labels 33 | else: 34 | self.dataset_x = mnist.train.images 35 | self.dataset_y = mnist.train.labels 36 | 37 | # Normalize between -1 and 1 38 | self.dataset_x = (np.reshape(self.dataset_x, (self.dataset_x.shape[0], 28, 28)) - 0.5) * 2 39 | 40 | # Include 1 single color channel 41 | self.dataset_x = np.expand_dims(self.dataset_x, axis=1) 42 | 43 | elif dataset == 'CIFAR10': 44 | ((x, y), (x_test, y_test)) = tf.keras.datasets.cifar10.load_data() 45 | 46 | if self.data_src == self.TEST: 47 | self.dataset_x = x_test 48 | self.dataset_y = y_test 49 | else: 50 | self.dataset_x = x 51 | self.dataset_y = y 52 | 53 | 54 | # Arrange x: channel first 55 | self.dataset_x = np.transpose(self.dataset_x, axes=(0, 3, 1, 2)) 56 | 57 | # Normalize between -1 and 1 58 | self.dataset_x = (self.dataset_x - 127.5) / 127.5 59 | 60 | # Y 1D format 61 | self.dataset_y = self.dataset_y[:, 0] 62 | 63 | assert (self.dataset_x.shape[0] == self.dataset_y.shape[0]) 64 | 65 | # Compute per class instance count. 66 | classes = np.unique(self.dataset_y) 67 | self.classes = classes 68 | per_class_count = list() 69 | for c in classes: 70 | per_class_count.append(np.sum(np.array(self.dataset_y == c))) 71 | 72 | # Prune if needed! 73 | if class_to_prune is not None: 74 | all_ids = list(np.arange(len(self.dataset_x))) 75 | 76 | mask = [class_to_prune == lc for lc in self.dataset_y] 77 | all_ids_c = np.array(all_ids)[mask] 78 | np.random.shuffle(all_ids_c) 79 | 80 | other_class_count = np.array(per_class_count) 81 | other_class_count = np.delete(other_class_count, class_to_prune) 82 | to_keep = int(np.ceil(unbalance * max( 83 | other_class_count))) 84 | 85 | to_delete = all_ids_c[to_keep: len(all_ids_c)] 86 | 87 | self.dataset_x = np.delete(self.dataset_x, to_delete, axis=0) 88 | self.dataset_y = np.delete(self.dataset_y, to_delete, axis=0) 89 | 90 | # Recount after pruning 91 | per_class_count = list() 92 | for c in classes: 93 | per_class_count.append(np.sum(np.array(self.dataset_y == c))) 94 | self.per_class_count = per_class_count 95 | 96 | # List of labels 97 | self.label_table = [str(c) for c in range(10)] 98 | 99 | # Preload all the labels. 
100 | self.labels = self.dataset_y[:] 101 | 102 | # per class ids 103 | self.per_class_ids = dict() 104 | ids = np.array(range(len(self.dataset_x))) 105 | for c in classes: 106 | self.per_class_ids[c] = ids[self.labels == c] 107 | 108 | def get_samples_for_class(self, c, samples=None): 109 | if samples is None: 110 | samples = self.batch_size 111 | 112 | np.random.shuffle(self.per_class_ids[c]) 113 | to_return = self.per_class_ids[c][0:samples] 114 | return self.dataset_x[to_return] 115 | 116 | def get_label_table(self): 117 | return self.label_table 118 | 119 | def get_num_classes(self): 120 | return len( self.label_table ) 121 | 122 | def get_class_probability(self): 123 | return self.per_class_count/sum(self.per_class_count) 124 | 125 | ### ACCESS DATA AND SHAPES ### 126 | def get_num_samples(self): 127 | return self.dataset_x.shape[0] 128 | 129 | def get_image_shape(self): 130 | return [self.dataset_x.shape[1], self.dataset_x.shape[2], self.dataset_x.shape[3]] 131 | 132 | def next_batch(self): 133 | dataset_x = self.dataset_x 134 | labels = self.labels 135 | 136 | indices = np.arange(dataset_x.shape[0]) 137 | 138 | np.random.shuffle(indices) 139 | 140 | for start_idx in range(0, dataset_x.shape[0] - self.batch_size + 1, self.batch_size): 141 | access_pattern = indices[start_idx:start_idx + self.batch_size] 142 | access_pattern = sorted(access_pattern) 143 | 144 | yield dataset_x[access_pattern, :, :, :], labels[access_pattern] 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright IBM Corporation 2018 3 | 4 | All rights reserved. This program and the accompanying materials 5 | are made available under the terms of the Eclipse Public License v1.0 6 | which accompanies this distribution, and is available at 7 | http://www.eclipse.org/legal/epl-v10.html 8 | """ 9 | 10 | import numpy as np 11 | from PIL import Image 12 | 13 | 14 | def save_image_array(img_array, fname): 15 | channels = img_array.shape[2] 16 | resolution = img_array.shape[-1] 17 | img_rows = img_array.shape[0] 18 | img_cols = img_array.shape[1] 19 | 20 | img = np.full([channels, resolution * img_rows, resolution * img_cols], 0.0) 21 | for r in range(img_rows): 22 | for c in range(img_cols): 23 | img[:, 24 | (resolution * r): (resolution * (r + 1)), 25 | (resolution * (c % 10)): (resolution * ((c % 10) + 1)) 26 | ] = img_array[r, c] 27 | 28 | img = (img * 127.5 + 127.5).astype(np.uint8) 29 | if (img.shape[0] == 1): 30 | img = img[0] 31 | else: 32 | img = np.rollaxis(img, 0, 3) 33 | 34 | Image.fromarray(img).save(fname) 35 | 36 | --------------------------------------------------------------------------------
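For reference, `save_image_array` above expects a 5-D array shaped `(rows, cols, channels, resolution, resolution)` with pixel values in [-1, 1], matching the `tanh` output range of the generator. A minimal usage sketch follows; the random array is a hypothetical stand-in for generated images:

```python
import numpy as np
from utils import save_image_array

# A 3x10 grid of single-channel 28x28 "images" with values in [-1, 1];
# save_image_array tiles them into one PNG, mapping [-1, 1] to [0, 255].
img_array = np.random.uniform(-1, 1, size=(3, 10, 1, 28, 28))
save_image_array(img_array, "example_grid.png")
```

Note that the `(c % 10)` indexing in the implementation assumes at most 10 columns per row.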