├── README.md ├── datasets ├── fashion-mnist │ ├── order_1.txt │ ├── order_2.txt │ ├── order_3.txt │ ├── order_4.txt │ └── order_5.txt ├── mnist │ ├── order_1.txt │ ├── order_2.txt │ ├── order_3.txt │ ├── order_4.txt │ └── order_5.txt └── svhn │ ├── order_1.txt │ ├── order_2.txt │ ├── order_3.txt │ ├── order_4.txt │ └── order_5.txt ├── gan ├── __init__.py ├── model_mnist_cnn.py ├── model_mnist_dgr.py ├── model_mnist_introgan.py ├── model_mnist_mergan.py ├── model_svhn_cnn.py ├── model_svhn_dgr.py ├── model_svhn_introgan.py ├── model_svhn_mergan.py └── tflib │ ├── __init__.py │ ├── cifar10.py │ ├── cifar100.py │ ├── inception_score.py │ ├── mnist.py │ ├── ops │ ├── __init__.py │ ├── batchnorm.py │ ├── cond_batchnorm.py │ ├── conv1d.py │ ├── conv2d.py │ ├── deconv2d.py │ ├── layernorm.py │ ├── linear.py │ └── norm.py │ ├── plot.py │ ├── save_images.py │ └── small_imagenet.py ├── mnist_train_dgr.py ├── mnist_train_introgan.py ├── mnist_train_lowerbound.py ├── mnist_train_mergan.py ├── mnist_train_upperbound.py ├── requirements.txt ├── svhn_train_dgr.py ├── svhn_train_introgan.py ├── svhn_train_lowerbound.py ├── svhn_train_mergan.py ├── svhn_train_upperbound.py └── utils ├── __init__.py ├── fashion_mnist.py ├── fid.py ├── mnist.py ├── resnet_v1_mod.py ├── svhn.py ├── vgg_preprocessing.py ├── visualize_embedding_protos_and_samples.py └── visualize_result_single.py /README.md: -------------------------------------------------------------------------------- 1 | # IntroGAN 2 | 3 | This is the official implementation of the paper ***Introspective GAN: Learning to Grow a GAN for Incremental Generation and Classification***, accepted by Pattern Recognition (PR) [[Paper](https://www.sciencedirect.com/science/article/abs/pii/S0031320324001341)] 4 | 5 | ## Introduction 6 | 7 | Lifelong learning, the ability to continually learn new concepts throughout our lives, is a hallmark of human intelligence. Humans generally learn a new concept by knowing both what it looks like and what distinguishes it from other concepts; in machine learning, these two correlated abilities correspond to generation and classification, respectively. In this paper, we carefully design a dynamically growing GAN called **Introspective GAN (IntroGAN)** that performs incremental generation and classification simultaneously under the guidance of prototypes, inspired by the role prototypes play in efficiently organizing information in human visual learning and by their strong performance in fields such as zero-shot, few-shot, and incremental learning. Specifically, we integrate prototype-based classification, which is robust to feature change in incremental learning, and a GAN that serves as a generative memory to alleviate forgetting, into a unified end-to-end framework. We propose a comprehensive benchmark for the joint incremental generation and classification task, on which our method demonstrates promising results. Additionally, we conduct comprehensive analyses of the properties of IntroGAN and verify that generation and classification can be mutually beneficial in incremental scenarios, a promising direction for further exploration. 8 | 9 | ## Usage 10 | 11 | ### 1. Requirements 12 | 13 | The code is implemented in Python 2.7. 14 | 15 | We use CUDA 8.0 and cuDNN 6.0. 16 | 17 | The TensorFlow version is 1.4. 18 | 19 | To install the required Python modules, simply run: 20 | 21 | ``pip install -r requirements.txt`` 22 | 23 | ### 2. 
Dataset Preparation 24 | 25 | #### 2.1 MNIST 26 | 27 | Download ``train-images-idx3-ubyte.gz``, ``train-labels-idx1-ubyte.gz``, ``t10k-images-idx3-ubyte.gz`` and ``t10k-labels-idx1-ubyte.gz`` from the official website of [MNIST](http://yann.lecun.com/exdb/mnist/) and move them to ``datasets/mnist``. 28 | 29 | #### 2.2 Fashion-MNIST 30 | 31 | Download ``train-images-idx3-ubyte.gz``, ``train-labels-idx1-ubyte.gz``, ``t10k-images-idx3-ubyte.gz`` and ``t10k-labels-idx1-ubyte.gz`` from the official website of [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) and move them to ``datasets/fashion-mnist``. 32 | 33 | #### 2.3 SVHN 34 | 35 | Download ``train_32x32.mat`` and ``test_32x32.mat`` from the official website of [SVHN](http://ufldl.stanford.edu/housenumbers/) and move them to ``datasets/svhn``. 36 | 37 | ### 3. Precomputed statistics for calculating FIDs 38 | 39 | Download the files for each dataset below and extract them into the ``precalc_fids/`` folder: 40 | 41 | **MNIST**: [[Google Drive]](https://drive.google.com/file/d/12UU537Y7sZkltiTMCVU1WApNQ_7PdKVS/view?usp=sharing) 42 | 43 | **Fashion-MNIST**: [[Google Drive]](https://drive.google.com/file/d/1ide3Be6ypqt0ymYamQQ_DpfFTJvq70dG/view?usp=sharing) 44 | 45 | **SVHN**: [[Google Drive]](https://drive.google.com/file/d/13naH1scHToqwcyvb9InaTeK705ejElID/view?usp=sharing) 46 | 47 | ### 4. Training 48 | 49 | **MNIST**: 50 | 51 | _IntroGAN_: `python mnist_train_introgan.py --dataset mnist` 52 | 53 | _DGR_: `python mnist_train_dgr.py --dataset mnist` 54 | 55 | _MeRGAN_: `python mnist_train_mergan.py --dataset mnist` 56 | 57 | **Fashion-MNIST** (``--dataset fashion-mnist`` is the default, so the flag can be omitted): 58 | 59 | _IntroGAN_: `python mnist_train_introgan.py` 60 | 61 | _DGR_: `python mnist_train_dgr.py` 62 | 63 | _MeRGAN_: `python mnist_train_mergan.py` 64 | 65 | **SVHN**: 66 | 67 | _IntroGAN_: `python svhn_train_introgan.py` 68 | 69 | _DGR_: `python svhn_train_dgr.py` 70 | 71 | _MeRGAN_: `python svhn_train_mergan.py` 72 | 73 | **After running the code above, the TA-ACC and TA-FID of this particular run can be found in the corresponding result folder, e.g. `result/introgan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/500/proto_static_random_20_weight_0.000000_0.000000_squared_l2_0.010000_min_select/finetune_improved_v2_noise_0.5_exemplars_dual_use_1`** 74 | 75 | ## Further 76 | 77 | If you have any questions, feel free to contact me. 
My email is chen.he@vipl.ict.ac.cn 78 | -------------------------------------------------------------------------------- /datasets/fashion-mnist/order_1.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 -------------------------------------------------------------------------------- /datasets/fashion-mnist/order_2.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 5 3 | 9 4 | 4 5 | 2 6 | 6 7 | 8 8 | 7 9 | 0 10 | 3 -------------------------------------------------------------------------------- /datasets/fashion-mnist/order_3.txt: -------------------------------------------------------------------------------- 1 | 7 2 | 2 3 | 4 4 | 3 5 | 9 6 | 5 7 | 6 8 | 8 9 | 1 10 | 0 -------------------------------------------------------------------------------- /datasets/fashion-mnist/order_4.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 1 3 | 5 4 | 9 5 | 7 6 | 2 7 | 4 8 | 6 9 | 0 10 | 8 -------------------------------------------------------------------------------- /datasets/fashion-mnist/order_5.txt: -------------------------------------------------------------------------------- 1 | 5 2 | 3 3 | 1 4 | 0 5 | 6 6 | 9 7 | 7 8 | 4 9 | 8 10 | 2 -------------------------------------------------------------------------------- /datasets/mnist/order_1.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 -------------------------------------------------------------------------------- /datasets/mnist/order_2.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 5 3 | 9 4 | 4 5 | 2 6 | 6 7 | 8 8 | 7 9 | 0 10 | 3 -------------------------------------------------------------------------------- /datasets/mnist/order_3.txt: -------------------------------------------------------------------------------- 1 | 7 2 | 2 3 | 4 4 | 3 5 | 9 6 | 5 7 | 6 8 | 8 9 | 1 10 | 0 -------------------------------------------------------------------------------- /datasets/mnist/order_4.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 1 3 | 5 4 | 9 5 | 7 6 | 2 7 | 4 8 | 6 9 | 0 10 | 8 -------------------------------------------------------------------------------- /datasets/mnist/order_5.txt: -------------------------------------------------------------------------------- 1 | 5 2 | 3 3 | 1 4 | 0 5 | 6 6 | 9 7 | 7 8 | 4 9 | 8 10 | 2 -------------------------------------------------------------------------------- /datasets/svhn/order_1.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | 6 8 | 7 9 | 8 10 | 9 -------------------------------------------------------------------------------- /datasets/svhn/order_2.txt: -------------------------------------------------------------------------------- 1 | 1 2 | 5 3 | 9 4 | 4 5 | 2 6 | 6 7 | 8 8 | 7 9 | 0 10 | 3 -------------------------------------------------------------------------------- /datasets/svhn/order_3.txt: -------------------------------------------------------------------------------- 1 | 7 2 | 2 3 | 4 4 | 3 5 | 9 6 | 5 7 | 6 8 | 8 9 | 1 10 | 0 -------------------------------------------------------------------------------- /datasets/svhn/order_4.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 1 
3 | 5 4 | 9 5 | 7 6 | 2 7 | 4 8 | 6 9 | 0 10 | 8 -------------------------------------------------------------------------------- /datasets/svhn/order_5.txt: -------------------------------------------------------------------------------- 1 | 5 2 | 3 3 | 1 4 | 0 5 | 6 6 | 9 7 | 7 8 | 4 9 | 8 10 | 2 -------------------------------------------------------------------------------- /gan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyPod/IntroGAN/80b1f4fdc3f3b034c7ae2005dc328286e81cddd8/gan/__init__.py -------------------------------------------------------------------------------- /gan/tflib/__init__.py: -------------------------------------------------------------------------------- 1 | import locale 2 | 3 | import tensorflow as tf 4 | 5 | locale.setlocale(locale.LC_ALL, '') 6 | 7 | _params = {} 8 | _param_aliases = {} 9 | 10 | 11 | def param(name, *args, **kwargs): 12 | """ 13 | A wrapper for `tf.Variable` which enables parameter sharing in models. 14 | 15 | Creates and returns TensorFlow variables similarly to `tf.Variable`, 16 | except that if you try to create a param with the same name as a 17 | previously-created one, `param(...)` will just return the old one instead of 18 | making a new one. 19 | 20 | This constructor also adds a `param` attribute to the variables it 21 | creates, so that you can easily search a graph for all params. 22 | """ 23 | 24 | if name not in _params: 25 | kwargs['name'] = name 26 | param = tf.Variable(*args, **kwargs) 27 | param.param = True 28 | _params[name] = param 29 | result = _params[name] 30 | i = 0 31 | while result in _param_aliases: 32 | # print 'following alias {}: {} to {}'.format(i, result, _param_aliases[result]) 33 | i += 1 34 | result = _param_aliases[result] 35 | return result 36 | 37 | 38 | def params_with_name(name): 39 | return [p for n, p in _params.items() if name in n] 40 | 41 | 42 | def delete_all_params(): 43 | _params.clear() 44 | 45 | 46 | def alias_params(replace_dict): 47 | for old, new in replace_dict.items(): 48 | # print "aliasing {} to {}".format(old,new) 49 | _param_aliases[old] = new 50 | 51 | 52 | def delete_param_aliases(): 53 | _param_aliases.clear() 54 | 55 | 56 | # def search(node, critereon): 57 | # """ 58 | # Traverse the Theano graph starting at `node` and return a list of all nodes 59 | # which match the `critereon` function. 
When optimizing a cost function, you 60 | # can use this to get a list of all of the trainable params in the graph, like 61 | # so: 62 | 63 | # `lib.search(cost, lambda x: hasattr(x, "param"))` 64 | # """ 65 | 66 | # def _search(node, critereon, visited): 67 | # if node in visited: 68 | # return [] 69 | # visited.add(node) 70 | 71 | # results = [] 72 | # if isinstance(node, T.Apply): 73 | # for inp in node.inputs: 74 | # results += _search(inp, critereon, visited) 75 | # else: # Variable node 76 | # if critereon(node): 77 | # results.append(node) 78 | # if node.owner is not None: 79 | # results += _search(node.owner, critereon, visited) 80 | # return results 81 | 82 | # return _search(node, critereon, set()) 83 | 84 | # def print_params_info(params): 85 | # """Print information about the parameters in the given param set.""" 86 | 87 | # params = sorted(params, key=lambda p: p.name) 88 | # values = [p.get_value(borrow=True) for p in params] 89 | # shapes = [p.shape for p in values] 90 | # print "Params for cost:" 91 | # for param, value, shape in zip(params, values, shapes): 92 | # print "\t{0} ({1})".format( 93 | # param.name, 94 | # ",".join([str(x) for x in shape]) 95 | # ) 96 | 97 | # total_param_count = 0 98 | # for shape in shapes: 99 | # param_count = 1 100 | # for dim in shape: 101 | # param_count *= dim 102 | # total_param_count += param_count 103 | # print "Total parameter count: {0}".format( 104 | # locale.format("%d", total_param_count, grouping=True) 105 | # ) 106 | 107 | def print_model_settings(locals_): 108 | print("Uppercase local vars:") 109 | all_vars = [(k, v) for (k, v) in locals_.items() if 110 | (k.isupper() and k != 'T' and k != 'SETTINGS' and k != 'ALL_SETTINGS')] 111 | all_vars = sorted(all_vars, key=lambda x: x[0]) 112 | for var_name, var_value in all_vars: 113 | print("\t{}: {}".format(var_name, var_value)) 114 | 115 | 116 | def print_model_settings_dict(settings): 117 | print("Settings dict:") 118 | all_vars = [(k, v) for (k, v) in settings.items()] 119 | all_vars = sorted(all_vars, key=lambda x: x[0]) 120 | for var_name, var_value in all_vars: 121 | print("\t{}: {}".format(var_name, var_value)) 122 | -------------------------------------------------------------------------------- /gan/tflib/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo) 11 | fo.close() 12 | return dict['data'], dict['labels'] 13 | 14 | def cifar_generator(filenames, batch_size, data_dir): 15 | all_data = [] 16 | all_labels = [] 17 | for filename in filenames: 18 | data, labels = unpickle(data_dir + '/' + filename) 19 | all_data.append(data) 20 | all_labels.append(labels) 21 | 22 | images = np.concatenate(all_data, axis=0) 23 | labels = np.concatenate(all_labels, axis=0) 24 | 25 | def get_epoch(): 26 | rng_state = np.random.get_state() 27 | np.random.shuffle(images) 28 | np.random.set_state(rng_state) 29 | np.random.shuffle(labels) 30 | 31 | for i in xrange(len(images) / batch_size): 32 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 33 | 34 | return get_epoch 35 | 36 | 37 | def load(batch_size, data_dir): 38 | return ( 39 | cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir), 40 | cifar_generator(['test_batch'], batch_size, data_dir) 41 | ) 
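
# Usage sketch (editor's addition, not in the original file). It assumes the
# CIFAR-10 "python version" batches were extracted to ./data/cifar-10-batches-py;
# each yielded batch is a (batch_size, 3072) uint8 image array plus a
# (batch_size,) label array:
#
#   train_gen, test_gen = load(batch_size=64, data_dir='./data/cifar-10-batches-py')
#   for images, labels in train_gen():
#       print images.shape, labels.shape  # (64, 3072) (64,)
#       break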
-------------------------------------------------------------------------------- /gan/tflib/cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | def unpickle(file): 9 | fo = open(file, 'rb') 10 | dict = pickle.load(fo) 11 | fo.close() 12 | return dict['data'], dict['fine_labels'] 13 | 14 | def cifar_generator(filenames, batch_size, data_dir): 15 | all_data = [] 16 | all_labels = [] 17 | for filename in filenames: 18 | data, labels = unpickle(data_dir + '/' + filename) 19 | all_data.append(data) 20 | all_labels.append(labels) 21 | 22 | images = np.concatenate(all_data, axis=0) 23 | labels = np.concatenate(all_labels, axis=0) 24 | 25 | def get_epoch(): 26 | rng_state = np.random.get_state() 27 | np.random.shuffle(images) 28 | np.random.set_state(rng_state) 29 | np.random.shuffle(labels) 30 | 31 | for i in xrange(len(images) / batch_size): 32 | yield (images[i*batch_size:(i+1)*batch_size], labels[i*batch_size:(i+1)*batch_size]) 33 | 34 | return get_epoch 35 | 36 | 37 | def load(batch_size, data_dir): 38 | return ( 39 | cifar_generator(['train'], batch_size, data_dir), 40 | cifar_generator(['test'], batch_size, data_dir) 41 | ) -------------------------------------------------------------------------------- /gan/tflib/inception_score.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py 2 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os.path 8 | import sys 9 | import tarfile 10 | 11 | import numpy as np 12 | from six.moves import urllib 13 | import tensorflow as tf 14 | import glob 15 | import scipy.misc 16 | import math 17 | import sys 18 | 19 | MODEL_DIR = './tmp/imagenet' 20 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 21 | softmax = None 22 | 23 | # Call this function with list of images. Each of elements should be a 24 | # numpy array with values ranging from 0 to 255. 25 | def get_inception_score(images, splits=10): 26 | assert(type(images) == list) 27 | assert(type(images[0]) == np.ndarray) 28 | assert(len(images[0].shape) == 3) 29 | assert(np.max(images[0]) > 10) 30 | assert(np.min(images[0]) >= 0.0) 31 | inps = [] 32 | for img in images: 33 | img = img.astype(np.float32) 34 | inps.append(np.expand_dims(img, 0)) 35 | batch_size = 100 36 | with tf.Session() as sess: 37 | preds = [] 38 | n_batches = int(math.ceil(float(len(inps)) / float(batch_size))) 39 | for i in range(n_batches): 40 | # sys.stdout.write(".") 41 | # sys.stdout.flush() 42 | inp = inps[(i * batch_size):min((i + 1) * batch_size, len(inps))] 43 | inp = np.concatenate(inp, 0) 44 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 45 | preds.append(pred) 46 | preds = np.concatenate(preds, 0) 47 | scores = [] 48 | for i in range(splits): 49 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] # 100x1008 50 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) # (100x1008) * ((100x1008) - (1x1008)) 51 | kl = np.mean(np.sum(kl, 1)) # (100x1008) -> (100,) -> (1) 52 | scores.append(np.exp(kl)) 53 | return np.mean(scores), np.std(scores) 54 | 55 | # This function is called automatically. 
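# It downloads the Inception-2015-12-05 graph on first use, patches the shapes
# downstream of pool_3 so the batch dimension is dynamic, and rebuilds the final
# softmax so get_inception_score() can feed minibatches of any size.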
56 | def _init_inception(): 57 | global softmax 58 | if not os.path.exists(MODEL_DIR): 59 | os.makedirs(MODEL_DIR) 60 | filename = DATA_URL.split('/')[-1] 61 | filepath = os.path.join(MODEL_DIR, filename) 62 | if not os.path.exists(filepath): 63 | def _progress(count, block_size, total_size): 64 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 65 | filename, float(count * block_size) / float(total_size) * 100.0)) 66 | sys.stdout.flush() 67 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 68 | print() 69 | statinfo = os.stat(filepath) 70 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 71 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 72 | with tf.gfile.FastGFile(os.path.join( 73 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 74 | graph_def = tf.GraphDef() 75 | graph_def.ParseFromString(f.read()) 76 | _ = tf.import_graph_def(graph_def, name='') 77 | # Works with an arbitrary minibatch size. 78 | with tf.Session() as sess: 79 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 80 | ops = pool3.graph.get_operations() 81 | for op_idx, op in enumerate(ops): 82 | for o in op.outputs: 83 | shape = o.get_shape() 84 | shape = [s.value for s in shape] 85 | new_shape = [] 86 | for j, s in enumerate(shape): 87 | if s == 1 and j == 0: 88 | new_shape.append(None) 89 | else: 90 | new_shape.append(s) 91 | o._shape = tf.TensorShape(new_shape) 92 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 93 | logits = tf.matmul(tf.squeeze(pool3), w) 94 | softmax = tf.nn.softmax(logits) 95 | 96 | if softmax is None: 97 | _init_inception() 98 | -------------------------------------------------------------------------------- /gan/tflib/mnist.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | import os 4 | import urllib 5 | import gzip 6 | import cPickle as pickle 7 | 8 | def mnist_generator(data, batch_size, n_labelled, limit=None): 9 | images, targets = data 10 | 11 | rng_state = numpy.random.get_state() 12 | numpy.random.shuffle(images) 13 | numpy.random.set_state(rng_state) 14 | numpy.random.shuffle(targets) 15 | if limit is not None: 16 | print "WARNING ONLY FIRST {} MNIST DIGITS".format(limit) 17 | images = images.astype('float32')[:limit] 18 | targets = targets.astype('int32')[:limit] 19 | if n_labelled is not None: 20 | labelled = numpy.zeros(len(images), dtype='int32') 21 | labelled[:n_labelled] = 1 22 | 23 | def get_epoch(): 24 | rng_state = numpy.random.get_state() 25 | numpy.random.shuffle(images) 26 | numpy.random.set_state(rng_state) 27 | numpy.random.shuffle(targets) 28 | 29 | if n_labelled is not None: 30 | numpy.random.set_state(rng_state) 31 | numpy.random.shuffle(labelled) 32 | 33 | image_batches = images.reshape(-1, batch_size, 784) 34 | target_batches = targets.reshape(-1, batch_size) 35 | 36 | if n_labelled is not None: 37 | labelled_batches = labelled.reshape(-1, batch_size) 38 | 39 | for i in xrange(len(image_batches)): 40 | # yield the per-batch labelled flags alongside images and targets 41 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i]), numpy.copy(labelled_batches[i])) 42 | 43 | else: 44 | 45 | for i in xrange(len(image_batches)): 46 | yield (numpy.copy(image_batches[i]), numpy.copy(target_batches[i])) 47 | 48 | return get_epoch 49 | 50 | def load(batch_size, test_batch_size, n_labelled=None): 51 | filepath = '/tmp/mnist.pkl.gz' 52 | url = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' 53 | 54 | if not os.path.isfile(filepath): 55 | print "Couldn't find MNIST dataset in /tmp, 
downloading..." 55 | urllib.urlretrieve(url, filepath) 56 | 57 | with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: 58 | train_data, dev_data, test_data = pickle.load(f) 59 | 60 | return ( 61 | mnist_generator(train_data, batch_size, n_labelled), 62 | mnist_generator(dev_data, test_batch_size, n_labelled), 63 | mnist_generator(test_data, test_batch_size, n_labelled) 64 | ) -------------------------------------------------------------------------------- /gan/tflib/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyPod/IntroGAN/80b1f4fdc3f3b034c7ae2005dc328286e81cddd8/gan/tflib/ops/__init__.py -------------------------------------------------------------------------------- /gan/tflib/ops/batchnorm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from gan import tflib as lib 5 | 6 | 7 | def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True): 8 | """ 9 | 10 | :param name: 11 | :param axes: the remaining axis represents CHANNEL, we want to normalize CHANNEL 12 | :param inputs: 13 | :param is_training: 14 | :param stats_iter: 15 | :param update_moving_stats: 16 | :param fused: 17 | :return: 18 | """ 19 | 20 | if ((axes == [0, 2, 3]) or (axes == [0, 2])) and fused == True: 21 | if axes == [0, 2]: 22 | inputs = tf.expand_dims(inputs, 3) 23 | 24 | # Variables declaration 25 | offset = lib.param(name + '.offset', np.zeros(inputs.get_shape()[1], dtype='float32')) 26 | scale = lib.param(name + '.scale', np.ones(inputs.get_shape()[1], dtype='float32')) 27 | moving_mean = lib.param(name + '.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), 28 | trainable=False) 29 | moving_variance = lib.param(name + '.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), 30 | trainable=False) 31 | 32 | # train 33 | def _fused_batch_norm_training(): 34 | return tf.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-5, data_format='NCHW') 35 | 36 | # test 37 | def _fused_batch_norm_inference(): 38 | # Version which blends in the current item's statistics 39 | batch_size = tf.cast(tf.shape(inputs)[0], 'float32') 40 | mean, var = tf.nn.moments(inputs, [2, 3], keep_dims=True) 41 | mean = ((1. / batch_size) * mean) + (((batch_size - 1.) / batch_size) * moving_mean)[None, :, None, None] 42 | var = ((1. / batch_size) * var) + (((batch_size - 1.) 
/ batch_size) * moving_variance)[None, :, None, None] 43 | return tf.nn.batch_normalization(inputs, mean, var, offset[None, :, None, None], scale[None, :, None, None], 44 | 1e-5), mean, var 45 | 46 | # Standard version 47 | # return tf.nn.fused_batch_norm( 48 | # inputs, 49 | # scale, 50 | # offset, 51 | # epsilon=1e-2, 52 | # mean=moving_mean, 53 | # variance=moving_variance, 54 | # is_training=False, 55 | # data_format='NCHW' 56 | # ) 57 | 58 | if is_training is None: 59 | outputs, batch_mean, batch_var = _fused_batch_norm_training() 60 | else: 61 | outputs, batch_mean, batch_var = tf.cond(is_training, 62 | _fused_batch_norm_training, 63 | _fused_batch_norm_inference) 64 | if update_moving_stats: 65 | no_updates = lambda: outputs 66 | 67 | def _force_updates(): 68 | """Internal function forces updates moving_vars if is_training.""" 69 | float_stats_iter = tf.cast(stats_iter, tf.float32) 70 | 71 | update_moving_mean = tf.assign(moving_mean, 72 | ((float_stats_iter / (float_stats_iter + 1)) * moving_mean) + ( 73 | (1 / (float_stats_iter + 1)) * batch_mean)) 74 | update_moving_variance = tf.assign(moving_variance, ( 75 | (float_stats_iter / (float_stats_iter + 1)) * moving_variance) + ( 76 | (1 / (float_stats_iter + 1)) * batch_var)) 77 | 78 | with tf.control_dependencies([update_moving_mean, update_moving_variance]): 79 | return tf.identity(outputs) 80 | 81 | outputs = tf.cond(is_training, _force_updates, no_updates) 82 | 83 | if axes == [0, 2]: 84 | return outputs[:, :, :, 0] # collapse last dim 85 | else: 86 | return outputs 87 | else: 88 | # raise Exception('old BN') 89 | # TODO we can probably use nn.fused_batch_norm here too for speedup 90 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 91 | shape = mean.get_shape().as_list() 92 | if 0 not in axes: 93 | print("WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format( 94 | name)) 95 | shape[0] = 1 96 | offset = lib.param(name + '.offset', np.zeros(shape, dtype='float32')) 97 | scale = lib.param(name + '.scale', np.ones(shape, dtype='float32')) 98 | # result = tf.cond(tf.equal(tf.shape(inputs)[0], 1), lambda: inputs, lambda: tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)) 99 | 100 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 101 | 102 | return result 103 | -------------------------------------------------------------------------------- /gan/tflib/ops/cond_batchnorm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from gan import tflib as lib 5 | 6 | 7 | def Cond_Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True, labels=None, n_labels=None): 8 | """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps""" 9 | if axes != [0,2,3]: 10 | raise Exception('unsupported') 11 | mean, var = tf.nn.moments(inputs, axes, keep_dims=True) 12 | shape = mean.get_shape().as_list() # shape is [1,n,1,1] 13 | offset_m = lib.param(name+'.offset', np.zeros([n_labels, shape[1]], dtype='float32')) 14 | scale_m = lib.param(name+'.scale', np.ones([n_labels, shape[1]], dtype='float32')) 15 | 16 | # check if labels is one hot 17 | if len(labels.shape) == 2: 18 | labels = tf.argmax(labels, axis=1) 19 | 20 | offset = tf.nn.embedding_lookup(offset_m, labels) # embedding_lookup is a indexing operation 21 | scale = tf.nn.embedding_lookup(scale_m, labels) 22 | result = tf.nn.batch_normalization(inputs, mean, 
var, offset[:,:,None,None], scale[:,:,None,None], 1e-5) 23 | return result -------------------------------------------------------------------------------- /gan/tflib/ops/conv1d.py: -------------------------------------------------------------------------------- 1 | from gan import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | def Conv1D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1.): 12 | """ 13 | inputs: tensor of shape (batch size, num channels, width) 14 | mask_type: one of None, 'a', 'b' 15 | 16 | returns: tensor of shape (batch size, num channels, width) 17 | """ 18 | with tf.name_scope(name) as scope: 19 | 20 | if mask_type is not None: 21 | mask_type, mask_n_channels = mask_type 22 | 23 | mask = np.ones( 24 | (filter_size, input_dim, output_dim), 25 | dtype='float32' 26 | ) 27 | center = filter_size // 2 28 | 29 | # Mask out future locations 30 | # filter shape is (width, input channels, output channels) 31 | mask[center+1:, :, :] = 0. 32 | 33 | # Mask out future channels 34 | for i in xrange(mask_n_channels): 35 | for j in xrange(mask_n_channels): 36 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 37 | mask[ 38 | center, 39 | i::mask_n_channels, 40 | j::mask_n_channels 41 | ] = 0. 42 | 43 | 44 | def uniform(stdev, size): 45 | return np.random.uniform( 46 | low=-stdev * np.sqrt(3), 47 | high=stdev * np.sqrt(3), 48 | size=size 49 | ).astype('float32') 50 | 51 | fan_in = input_dim * filter_size 52 | fan_out = output_dim * filter_size / stride 53 | 54 | if mask_type is not None: # only approximately correct 55 | fan_in /= 2. 56 | fan_out /= 2. 
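        # He init draws from twice the Glorot variance (4/(fan_in+fan_out) rather
        # than 2/(fan_in+fan_out)), compensating for ReLUs zeroing roughly half
        # of the activations (He et al., 2015).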
57 | 58 | if he_init: 59 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 60 | else: # Normalized init (Glorot & Bengio) 61 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 62 | 63 | filter_values = uniform( 64 | filters_stdev, 65 | (filter_size, input_dim, output_dim) 66 | ) 67 | # print "WARNING IGNORING GAIN" 68 | filter_values *= gain 69 | 70 | filters = lib.param(name+'.Filters', filter_values) 71 | 72 | if weightnorm==None: 73 | weightnorm = _default_weightnorm 74 | if weightnorm: 75 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1))) 76 | target_norms = lib.param( 77 | name + '.g', 78 | norm_values 79 | ) 80 | with tf.name_scope('weightnorm') as scope: 81 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1])) 82 | filters = filters * (target_norms / norms) 83 | 84 | if mask_type is not None: 85 | with tf.name_scope('filter_mask'): 86 | filters = filters * mask 87 | 88 | result = tf.nn.conv1d( 89 | value=inputs, 90 | filters=filters, 91 | stride=stride, 92 | padding='SAME', 93 | data_format='NCHW' 94 | ) 95 | 96 | if biases: 97 | _biases = lib.param( 98 | name+'.Biases', 99 | np.zeros([output_dim], dtype='float32') 100 | ) 101 | 102 | # result = result + _biases 103 | 104 | result = tf.expand_dims(result, 3) 105 | result = tf.nn.bias_add(result, _biases, data_format='NCHW') 106 | result = tf.squeeze(result) 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /gan/tflib/ops/conv2d.py: -------------------------------------------------------------------------------- 1 | from gan import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from gan.tflib.ops.norm import weights_spectral_norm 7 | 8 | _default_weightnorm = False 9 | def enable_default_weightnorm(): 10 | global _default_weightnorm 11 | _default_weightnorm = True 12 | 13 | _weights_stdev = None 14 | def set_weights_stdev(weights_stdev): 15 | global _weights_stdev 16 | _weights_stdev = weights_stdev 17 | 18 | def unset_weights_stdev(): 19 | global _weights_stdev 20 | _weights_stdev = None 21 | 22 | def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, spectral_norm=False, biases=True, gain=1.): 23 | """ 24 | inputs: tensor of shape (batch size, num channels, height, width) 25 | mask_type: one of None, 'a', 'b' 26 | 27 | returns: tensor of shape (batch size, num channels, height, width) 28 | """ 29 | with tf.name_scope(name) as scope: 30 | 31 | if mask_type is not None: 32 | mask_type, mask_n_channels = mask_type 33 | 34 | mask = np.ones( 35 | (filter_size, filter_size, input_dim, output_dim), 36 | dtype='float32' 37 | ) 38 | center = filter_size // 2 39 | 40 | # Mask out future locations 41 | # filter shape is (height, width, input channels, output channels) 42 | mask[center+1:, :, :, :] = 0. 43 | mask[center, center+1:, :, :] = 0. 44 | 45 | # Mask out future channels 46 | for i in xrange(mask_n_channels): 47 | for j in xrange(mask_n_channels): 48 | if (mask_type=='a' and i >= j) or (mask_type=='b' and i > j): 49 | mask[ 50 | center, 51 | center, 52 | i::mask_n_channels, 53 | j::mask_n_channels 54 | ] = 0. 
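            # The masks above implement PixelCNN-style autoregressive convolutions
            # (van den Oord et al., 2016): type 'a' also blocks each output's own
            # input channel group, while type 'b' allows it.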
55 | 56 | 57 | def uniform(stdev, size): 58 | return np.random.uniform( 59 | low=-stdev * np.sqrt(3), 60 | high=stdev * np.sqrt(3), 61 | size=size 62 | ).astype('float32') 63 | 64 | fan_in = input_dim * filter_size**2 65 | fan_out = output_dim * filter_size**2 / (stride**2) 66 | 67 | if mask_type is not None: # only approximately correct 68 | fan_in /= 2. 69 | fan_out /= 2. 70 | 71 | if he_init: 72 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 73 | else: # Normalized init (Glorot & Bengio) 74 | filters_stdev = np.sqrt(2./(fan_in+fan_out)) 75 | 76 | if _weights_stdev is not None: 77 | filter_values = uniform( 78 | _weights_stdev, 79 | (filter_size, filter_size, input_dim, output_dim) 80 | ) 81 | else: 82 | filter_values = uniform( 83 | filters_stdev, 84 | (filter_size, filter_size, input_dim, output_dim) 85 | ) 86 | 87 | # print "WARNING IGNORING GAIN" 88 | filter_values *= gain 89 | 90 | filters = lib.param(name+'.Filters', filter_values) 91 | 92 | if weightnorm==None: 93 | weightnorm = _default_weightnorm 94 | if weightnorm: 95 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,2))) 96 | target_norms = lib.param( 97 | name + '.g', 98 | norm_values 99 | ) 100 | with tf.name_scope('weightnorm') as scope: 101 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,2])) 102 | filters = filters * (target_norms / norms) 103 | 104 | if spectral_norm: 105 | filters = weights_spectral_norm(name, filters) 106 | 107 | if mask_type is not None: 108 | with tf.name_scope('filter_mask'): 109 | filters = filters * mask 110 | 111 | result = tf.nn.conv2d( 112 | input=inputs, 113 | filter=filters, 114 | strides=[1, 1, stride, stride], 115 | padding='SAME', 116 | data_format='NCHW' 117 | ) 118 | 119 | if biases: 120 | _biases = lib.param( 121 | name+'.Biases', 122 | np.zeros(output_dim, dtype='float32') 123 | ) 124 | 125 | result = tf.nn.bias_add(result, _biases, data_format='NCHW') 126 | 127 | 128 | return result 129 | -------------------------------------------------------------------------------- /gan/tflib/ops/deconv2d.py: -------------------------------------------------------------------------------- 1 | from gan import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | _default_weightnorm = False 7 | def enable_default_weightnorm(): 8 | global _default_weightnorm 9 | _default_weightnorm = True 10 | 11 | _weights_stdev = None 12 | def set_weights_stdev(weights_stdev): 13 | global _weights_stdev 14 | _weights_stdev = weights_stdev 15 | 16 | def unset_weights_stdev(): 17 | global _weights_stdev 18 | _weights_stdev = None 19 | 20 | def Deconv2D( 21 | name, 22 | input_dim, 23 | output_dim, 24 | filter_size, 25 | inputs, 26 | stride=2, 27 | he_init=True, 28 | weightnorm=None, 29 | biases=True, 30 | gain=1., 31 | mask_type=None, 32 | special=False, 33 | padding='SAME' 34 | ): 35 | """ 36 | inputs: tensor of shape (batch size, height, width, input_dim) 37 | returns: tensor of shape (batch size, 2*height, 2*width, output_dim) 38 | """ 39 | with tf.name_scope(name) as scope: 40 | 41 | if mask_type != None: 42 | raise Exception('Unsupported configuration') 43 | 44 | def uniform(stdev, size): 45 | return np.random.uniform( 46 | low=-stdev * np.sqrt(3), 47 | high=stdev * np.sqrt(3), 48 | size=size 49 | ).astype('float32') 50 | 51 | fan_in = input_dim * filter_size**2 / (stride**2) 52 | fan_out = output_dim * filter_size**2 53 | 54 | if he_init: 55 | filters_stdev = np.sqrt(4./(fan_in+fan_out)) 56 | else: # Normalized init (Glorot & Bengio) 57 | 
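            # Glorot & Bengio (2010): variance 2/(fan_in+fan_out), which keeps
            # activation magnitudes roughly constant through linear/tanh layers.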
filters_stdev = np.sqrt(2./(fan_in+fan_out)) 58 | 59 | 60 | if _weights_stdev is not None: 61 | filter_values = uniform( 62 | _weights_stdev, 63 | (filter_size, filter_size, output_dim, input_dim) 64 | ) 65 | else: 66 | filter_values = uniform( 67 | filters_stdev, 68 | (filter_size, filter_size, output_dim, input_dim) 69 | ) 70 | 71 | filter_values *= gain 72 | 73 | filters = lib.param( 74 | name+'.Filters', 75 | filter_values 76 | ) 77 | 78 | if weightnorm==None: 79 | weightnorm = _default_weightnorm 80 | if weightnorm: 81 | norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,3))) 82 | target_norms = lib.param( 83 | name + '.g', 84 | norm_values 85 | ) 86 | with tf.name_scope('weightnorm') as scope: 87 | norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,3])) 88 | filters = filters * tf.expand_dims(target_norms / norms, 1) 89 | 90 | 91 | inputs = tf.transpose(inputs, [0,2,3,1], name='NCHW_to_NHWC') 92 | 93 | input_shape = tf.shape(inputs) 94 | try: # tf pre-1.0 (top) vs 1.0 (bottom) 95 | if special: 96 | output_shape = tf.pack([input_shape[0], 4, 4, output_dim]) 97 | else: 98 | output_shape = tf.pack([input_shape[0], stride*input_shape[1], stride*input_shape[2], output_dim]) 99 | except Exception as e: 100 | if special: 101 | output_shape = tf.stack([input_shape[0], 4, 4, output_dim]) 102 | else: 103 | output_shape = tf.stack([input_shape[0], stride*input_shape[1], stride*input_shape[2], output_dim]) 104 | 105 | result = tf.nn.conv2d_transpose( 106 | value=inputs, 107 | filter=filters, 108 | output_shape=output_shape, 109 | strides=[1, stride, stride, 1], 110 | padding=padding 111 | ) 112 | 113 | if biases: 114 | _biases = lib.param( 115 | name+'.Biases', 116 | np.zeros(output_dim, dtype='float32') 117 | ) 118 | result = tf.nn.bias_add(result, _biases) 119 | 120 | result = tf.transpose(result, [0,3,1,2], name='NHWC_to_NCHW') 121 | 122 | 123 | return result 124 | -------------------------------------------------------------------------------- /gan/tflib/ops/layernorm.py: -------------------------------------------------------------------------------- 1 | from gan import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def Layernorm(name, norm_axes, inputs): 8 | mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True) 9 | 10 | # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers. 11 | n_neurons = inputs.get_shape().as_list()[norm_axes[0]] 12 | 13 | offset = lib.param(name+'.offset', np.zeros(n_neurons, dtype='float32')) 14 | scale = lib.param(name+'.scale', np.ones(n_neurons, dtype='float32')) 15 | 16 | # Add broadcasting dims to offset and scale (e.g. 
BCHW conv data) 17 | offset = tf.reshape(offset, [-1] + [1 for i in xrange(len(norm_axes)-1)]) 18 | scale = tf.reshape(scale, [-1] + [1 for i in xrange(len(norm_axes)-1)]) 19 | 20 | result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5) 21 | 22 | return result -------------------------------------------------------------------------------- /gan/tflib/ops/linear.py: -------------------------------------------------------------------------------- 1 | from gan import tflib as lib 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from gan.tflib.ops.norm import weights_spectral_norm 7 | 8 | _default_weightnorm = False 9 | def enable_default_weightnorm(): 10 | global _default_weightnorm 11 | _default_weightnorm = True 12 | 13 | def disable_default_weightnorm(): 14 | global _default_weightnorm 15 | _default_weightnorm = False 16 | 17 | _weights_stdev = None 18 | def set_weights_stdev(weights_stdev): 19 | global _weights_stdev 20 | _weights_stdev = weights_stdev 21 | 22 | def unset_weights_stdev(): 23 | global _weights_stdev 24 | _weights_stdev = None 25 | 26 | def Linear( 27 | name, 28 | input_dim, 29 | output_dim, 30 | inputs, 31 | biases=True, 32 | initialization=None, 33 | weightnorm=None, 34 | spectral_norm=False, 35 | gain=1. 36 | ): 37 | """ 38 | initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`, `("uniform", range)` 39 | """ 40 | with tf.name_scope(name) as scope: 41 | 42 | def uniform(stdev, size): 43 | if _weights_stdev is not None: 44 | stdev = _weights_stdev 45 | return np.random.uniform( 46 | low=-stdev * np.sqrt(3), 47 | high=stdev * np.sqrt(3), 48 | size=size 49 | ).astype('float32') 50 | 51 | if initialization == 'lecun':# and input_dim != output_dim): 52 | # disabling orth. init for now because it's too slow 53 | weight_values = uniform( 54 | np.sqrt(1./input_dim), 55 | (input_dim, output_dim) 56 | ) 57 | 58 | elif initialization == 'glorot' or (initialization == None): 59 | 60 | weight_values = uniform( 61 | np.sqrt(2./(input_dim+output_dim)), 62 | (input_dim, output_dim) 63 | ) 64 | 65 | elif initialization == 'he': 66 | 67 | weight_values = uniform( 68 | np.sqrt(2./input_dim), 69 | (input_dim, output_dim) 70 | ) 71 | 72 | elif initialization == 'glorot_he': 73 | 74 | weight_values = uniform( 75 | np.sqrt(4./(input_dim+output_dim)), 76 | (input_dim, output_dim) 77 | ) 78 | 79 | elif initialization == 'orthogonal' or \ 80 | (initialization == None and input_dim == output_dim): 81 | 82 | # From lasagne 83 | def sample(shape): 84 | if len(shape) < 2: 85 | raise RuntimeError("Only shapes of length 2 or more are " 86 | "supported.") 87 | flat_shape = (shape[0], np.prod(shape[1:])) 88 | # TODO: why normal and not uniform? 
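                # (Likely because the SVD of an i.i.d. Gaussian matrix yields
                # orthogonal factors that are uniformly (Haar) distributed, a
                # rotational symmetry that uniform entries would lack; cf. Saxe
                # et al., 2014.)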
89 | a = np.random.normal(0.0, 1.0, flat_shape) 90 | u, _, v = np.linalg.svd(a, full_matrices=False) 91 | # pick the one with the correct shape 92 | q = u if u.shape == flat_shape else v 93 | q = q.reshape(shape) 94 | return q.astype('float32') 95 | weight_values = sample((input_dim, output_dim)) 96 | 97 | elif initialization[0] == 'uniform': 98 | 99 | weight_values = np.random.uniform( 100 | low=-initialization[1], 101 | high=initialization[1], 102 | size=(input_dim, output_dim) 103 | ).astype('float32') 104 | 105 | else: 106 | 107 | raise Exception('Invalid initialization!') 108 | 109 | weight_values *= gain 110 | 111 | weight = lib.param( 112 | name + '.W', 113 | weight_values 114 | ) 115 | 116 | if weightnorm==None: 117 | weightnorm = _default_weightnorm 118 | if weightnorm: 119 | norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0)) 120 | # norm_values = np.linalg.norm(weight_values, axis=0) 121 | 122 | target_norms = lib.param( 123 | name + '.g', 124 | norm_values 125 | ) 126 | 127 | with tf.name_scope('weightnorm') as scope: 128 | norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0])) 129 | weight = weight * (target_norms / norms) 130 | 131 | if spectral_norm: 132 | weight = weights_spectral_norm(name, weight) 133 | 134 | 135 | # if 'Discriminator' in name: 136 | # print "WARNING weight constraint on {}".format(name) 137 | # weight = tf.nn.softsign(10.*weight)*.1 138 | 139 | if inputs.get_shape().ndims == 2: 140 | result = tf.matmul(inputs, weight) 141 | else: 142 | reshaped_inputs = tf.reshape(inputs, [-1, input_dim]) 143 | result = tf.matmul(reshaped_inputs, weight) 144 | result = tf.reshape(result, tf.pack(tf.unpack(tf.shape(inputs))[:-1] + [output_dim])) 145 | 146 | if biases: 147 | result = tf.nn.bias_add( 148 | result, 149 | lib.param( 150 | name + '.b', 151 | np.zeros((output_dim,), dtype='float32') 152 | ) 153 | ) 154 | 155 | return result 156 | 157 | 158 | def CondProjection( 159 | name, 160 | input_dim, 161 | output_dim, 162 | inputs, 163 | labels, 164 | initialization=None, 165 | weightnorm=None, 166 | spectral_norm=True, 167 | ): 168 | """ 169 | initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`, `("uniform", range)` 170 | """ 171 | with tf.name_scope(name) as scope: 172 | 173 | def uniform(stdev, size): 174 | if _weights_stdev is not None: 175 | stdev = _weights_stdev 176 | return np.random.uniform( 177 | low=-stdev * np.sqrt(3), 178 | high=stdev * np.sqrt(3), 179 | size=size 180 | ).astype('float32') 181 | 182 | if initialization == 'lecun': # and input_dim != output_dim): 183 | # disabling orth. init for now because it's too slow 184 | weight_values = uniform( 185 | np.sqrt(1. / input_dim), 186 | (input_dim, output_dim) 187 | ) 188 | 189 | elif initialization == 'glorot' or (initialization == None): 190 | 191 | weight_values = uniform( 192 | np.sqrt(2. / (input_dim + output_dim)), 193 | (input_dim, output_dim) 194 | ) 195 | 196 | elif initialization == 'he': 197 | 198 | weight_values = uniform( 199 | np.sqrt(2. / input_dim), 200 | (input_dim, output_dim) 201 | ) 202 | 203 | elif initialization == 'glorot_he': 204 | 205 | weight_values = uniform( 206 | np.sqrt(4. 
/ (input_dim + output_dim)), 207 | (input_dim, output_dim) 208 | ) 209 | 210 | elif initialization == 'orthogonal' or \ 211 | (initialization == None and input_dim == output_dim): 212 | 213 | # From lasagne 214 | def sample(shape): 215 | if len(shape) < 2: 216 | raise RuntimeError("Only shapes of length 2 or more are " 217 | "supported.") 218 | flat_shape = (shape[0], np.prod(shape[1:])) 219 | # TODO: why normal and not uniform? 220 | a = np.random.normal(0.0, 1.0, flat_shape) 221 | u, _, v = np.linalg.svd(a, full_matrices=False) 222 | # pick the one with the correct shape 223 | q = u if u.shape == flat_shape else v 224 | q = q.reshape(shape) 225 | return q.astype('float32') 226 | 227 | weight_values = sample((input_dim, output_dim)) 228 | 229 | elif initialization[0] == 'uniform': 230 | 231 | weight_values = np.random.uniform( 232 | low=-initialization[1], 233 | high=initialization[1], 234 | size=(input_dim, output_dim) 235 | ).astype('float32') 236 | 237 | else: 238 | 239 | raise Exception('Invalid initialization!') 240 | 241 | weight = lib.param( 242 | name + '.W', 243 | weight_values 244 | ) 245 | 246 | if weightnorm == None: 247 | weightnorm = _default_weightnorm 248 | if weightnorm: 249 | norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0)) 250 | # norm_values = np.linalg.norm(weight_values, axis=0) 251 | 252 | target_norms = lib.param( 253 | name + '.g', 254 | norm_values 255 | ) 256 | 257 | with tf.name_scope('weightnorm') as scope: 258 | norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0])) 259 | weight = weight * (target_norms / norms) 260 | 261 | if spectral_norm: 262 | weight = weights_spectral_norm(name, weight) 263 | 264 | # if 'Discriminator' in name: 265 | # print "WARNING weight constraint on {}".format(name) 266 | # weight = tf.nn.softsign(10.*weight)*.1 267 | 268 | if inputs.get_shape().ndims == 2: 269 | result = tf.matmul(inputs, tf.nn.embedding_lookup(weight, labels)) 270 | else: 271 | reshaped_inputs = tf.reshape(inputs, [-1, input_dim]) 272 | result = tf.matmul(reshaped_inputs, tf.nn.embedding_lookup(weight, labels)) 273 | result = tf.reshape(result, tf.pack(tf.unpack(tf.shape(inputs))[:-1] + [output_dim])) 274 | 275 | return result -------------------------------------------------------------------------------- /gan/tflib/ops/norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @time: 1/10/19 10:29 AM 5 | @author: Chen He 6 | @site: 7 | @file: norm.py 8 | @description: 9 | """ 10 | 11 | import tensorflow as tf 12 | 13 | 14 | # spectral_norm 15 | def l2_norm(input_x, epsilon=1e-12): 16 | input_x_norm = input_x / (tf.reduce_sum(input_x**2)**0.5 + epsilon) 17 | return input_x_norm 18 | 19 | 20 | def weights_spectral_norm(name, weights, u=None, iteration=1, update_collection=None, reuse=False): 21 | 22 | with tf.variable_scope(name) as scope: 23 | if reuse: 24 | scope.reuse_variables() 25 | 26 | w_shape = weights.get_shape().as_list() 27 | w_mat = tf.reshape(weights, [-1, w_shape[-1]]) 28 | if u is None: 29 | u = tf.get_variable('u', shape=[1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), 30 | trainable=False) 31 | 32 | def power_iteration(u, ite): 33 | v_ = tf.matmul(u, tf.transpose(w_mat)) 34 | v_hat = l2_norm(v_) 35 | u_ = tf.matmul(v_hat, w_mat) 36 | u_hat = l2_norm(u_) 37 | return u_hat, v_hat, ite + 1 38 | 39 | u_hat, v_hat, _ = power_iteration(u, iteration) 40 | 41 | sigma = tf.matmul(tf.matmul(v_hat, w_mat), tf.transpose(u_hat)) 42 | 43 | 
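        # sigma approximates the spectral norm (largest singular value) of w_mat
        # after one step of power iteration; dividing by it keeps the layer
        # roughly 1-Lipschitz (Miyato et al., 2018).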
w_mat = w_mat / sigma 44 | 45 | if update_collection is None: 46 | with tf.control_dependencies([u.assign(u_hat)]): 47 | w_norm = tf.reshape(w_mat, w_shape) 48 | else: 49 | if not (update_collection == 'NO_OPS'): 50 | print(update_collection) 51 | tf.add_to_collection(update_collection, u.assign(u_hat)) 52 | 53 | w_norm = tf.reshape(w_mat, w_shape) 54 | return w_norm -------------------------------------------------------------------------------- /gan/tflib/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | 4 | matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | 7 | import collections 8 | try: 9 | import cPickle as pickle 10 | except ImportError: 11 | import pickle 12 | import os 13 | 14 | _since_beginning = collections.defaultdict(lambda: {}) 15 | _since_last_flush = collections.defaultdict(lambda: {}) 16 | 17 | _iter = [0] 18 | 19 | 20 | def reset(): 21 | _since_beginning.clear() 22 | _since_last_flush.clear() 23 | _iter[0] = 0 24 | 25 | 26 | def tick(): 27 | _iter[0] += 1 28 | 29 | 30 | def plot(name, value): 31 | _since_last_flush[name][_iter[0]] = value 32 | 33 | 34 | def flush(dir, filename='log.pkl'): 35 | prints = [] 36 | 37 | log_folder = os.path.join(dir, 'log') 38 | if not os.path.exists(log_folder): 39 | os.makedirs(log_folder) 40 | 41 | for name, vals in _since_last_flush.items(): 42 | prints.append("{}\t{}".format(name, np.mean(vals.values()))) 43 | _since_beginning[name].update(vals) 44 | 45 | x_vals = np.sort(_since_beginning[name].keys()) 46 | y_vals = [_since_beginning[name][x] for x in x_vals] 47 | 48 | plt.clf() 49 | plt.plot(x_vals, y_vals) 50 | plt.xlabel('iteration') 51 | plt.ylabel(name) 52 | 53 | plt.savefig(os.path.join(log_folder, name.replace(' ', '_') + '.jpg')) 54 | 55 | print("iter {}\t{}".format(_iter[0], "\t".join(prints))) 56 | _since_last_flush.clear() 57 | 58 | with open(os.path.join(log_folder, filename), 'wb') as f: 59 | pickle.dump(dict(_since_beginning), f, pickle.HIGHEST_PROTOCOL) 60 | -------------------------------------------------------------------------------- /gan/tflib/save_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image grid saver, based on color_grid_vis from github.com/Newmu 3 | """ 4 | 5 | import numpy as np 6 | import scipy.misc 7 | from scipy.misc import imsave 8 | 9 | 10 | def save_images(X, save_path): 11 | # [0, 1] -> [0,255] 12 | if isinstance(X.flatten()[0], np.floating): 13 | X = (255.99 * X).astype('uint8') 14 | 15 | n_samples = X.shape[0] 16 | rows = int(np.sqrt(n_samples)) 17 | while n_samples % rows != 0: 18 | rows -= 1 19 | 20 | nh, nw = rows, n_samples / rows 21 | 22 | if X.ndim == 2: 23 | X = np.reshape(X, (X.shape[0], int(np.sqrt(X.shape[1])), int(np.sqrt(X.shape[1])))) 24 | 25 | if X.ndim == 4: 26 | # BCHW -> BHWC 27 | X = X.transpose(0, 2, 3, 1) 28 | h, w = X[0].shape[:2] 29 | img = np.zeros((h * nh, w * nw, 3)) 30 | elif X.ndim == 3: 31 | h, w = X[0].shape[:2] 32 | img = np.zeros((h * nh, w * nw)) 33 | 34 | for n, x in enumerate(X): 35 | j = n / nw 36 | i = n % nw 37 | img[j * h:j * h + h, i * w:i * w + w] = x 38 | 39 | imsave(save_path, img) 40 | -------------------------------------------------------------------------------- /gan/tflib/small_imagenet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc 3 | import time 4 | 5 | def make_generator(path, n_files, batch_size): 6 | 
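    # epoch_count is a one-element list so the nested get_epoch closure can
    # mutate it (Python 2 has no `nonlocal`); seeding a RandomState with it
    # makes each epoch's shuffle order deterministic across runs.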
epoch_count = [1] 7 | def get_epoch(): 8 | images = np.zeros((batch_size, 3, 64, 64), dtype='int32') 9 | files = range(n_files) 10 | random_state = np.random.RandomState(epoch_count[0]) 11 | random_state.shuffle(files) 12 | epoch_count[0] += 1 13 | for n, i in enumerate(files): 14 | image = scipy.misc.imread("{}/{}.png".format(path, str(i+1).zfill(len(str(n_files))))) 15 | images[n % batch_size] = image.transpose(2,0,1) 16 | if n > 0 and n % batch_size == 0: 17 | yield (images,) 18 | return get_epoch 19 | 20 | def load(batch_size, data_dir='/home/ishaan/data/imagenet64'): 21 | return ( 22 | make_generator(data_dir+'/train_64x64', 1281149, batch_size), 23 | make_generator(data_dir+'/valid_64x64', 49999, batch_size) 24 | ) 25 | 26 | if __name__ == '__main__': 27 | train_gen, valid_gen = load(64) 28 | t0 = time.time() 29 | for i, batch in enumerate(train_gen(), start=1): 30 | print "{}\t{}".format(str(time.time() - t0), batch[0][0,0,0,0]) 31 | if i == 1000: 32 | break 33 | t0 = time.time() -------------------------------------------------------------------------------- /mnist_train_dgr.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | tf.set_random_seed(1993) 6 | 7 | import numpy as np 8 | 9 | np.random.seed(1993) 10 | import os 11 | import pprint 12 | 13 | from gan.model_mnist_dgr import GAN 14 | 15 | from utils import fid 16 | from utils.visualize_result_single import vis_acc_and_fid 17 | 18 | flags = tf.app.flags 19 | 20 | flags.DEFINE_string("dataset", "fashion-mnist", "The name of dataset [mnist, fashion-mnist]") 21 | 22 | # Hyperparameters 23 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]") 24 | flags.DEFINE_integer("critic_iters", 1, 25 | "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]") 26 | flags.DEFINE_integer("batch_size", 100, "The size of batch images") 27 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train") 28 | flags.DEFINE_integer("output_dim", 28 * 28, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [768]") 29 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp") 30 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint(number of iters)') 31 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4') 32 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5') 33 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999') 34 | flags.DEFINE_boolean("gan_finetune", True, 'if gan finetuned from the previous model') 35 | flags.DEFINE_boolean("improved_finetune", True, 'if gan finetuned from the previous model') 36 | flags.DEFINE_string("improved_finetune_type", 'v2', '') 37 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation') 38 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level') 39 | flags.DEFINE_float("dgr_ratio", 0.5, "") 40 | flags.DEFINE_float("solver_adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4') 41 | flags.DEFINE_integer("test_interval", 500, 42 | "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 43 | 44 | # Add how many classes every time 45 | flags.DEFINE_integer('nb_cl', 2, '') 46 | 47 | # DEBUG 48 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 49 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 50 | flags.DEFINE_integer('order_idx', 1, 'class orders 
[1~5]') 51 | 52 | # Visualize 53 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 54 | 55 | FLAGS = flags.FLAGS 56 | 57 | pp = pprint.PrettyPrinter() 58 | 59 | 60 | def main(_): 61 | pp.pprint(flags.FLAGS.__flags) 62 | 63 | NUM_CLASSES = 10 64 | NUM_TRAIN_SAMPLES_PER_CLASS = 6000 # training samples per class 65 | NUM_TEST_SAMPLES_PER_CLASS = 1000 # testing samples per class 66 | 67 | if FLAGS.dataset == 'mnist': 68 | from utils import mnist 69 | raw_images_train, train_labels, train_one_hot_labels, order = mnist.load_data(kind='train', 70 | order_idx=FLAGS.order_idx) 71 | raw_images_test, test_labels, test_one_hot_labels, order = mnist.load_data(kind='t10k', 72 | order_idx=FLAGS.order_idx) 73 | elif FLAGS.dataset == 'fashion-mnist': 74 | from utils import fashion_mnist 75 | raw_images_train, train_labels, train_one_hot_labels, order = fashion_mnist.load_data(kind='train', 76 | order_idx=FLAGS.order_idx) 77 | raw_images_test, test_labels, test_one_hot_labels, order = fashion_mnist.load_data(kind='t10k', 78 | order_idx=FLAGS.order_idx) 79 | 80 | # Total training samples 81 | NUM_TRAIN_SAMPLES_TOTAL = NUM_CLASSES * NUM_TRAIN_SAMPLES_PER_CLASS 82 | NUM_TEST_SAMPLES_TOTAL = NUM_CLASSES * NUM_TEST_SAMPLES_PER_CLASS 83 | 84 | run_config = tf.ConfigProto() 85 | run_config.gpu_options.allow_growth = True 86 | 87 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 88 | result_dir = os.path.join('result', method_name) 89 | print('Result dir: %s' % result_dir) 90 | 91 | graph_fid = tf.Graph() 92 | with graph_fid.as_default(): 93 | inception_path = fid.check_or_download_inception('tmp/imagenet') 94 | fid.create_inception_graph(inception_path) 95 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 96 | 97 | ''' 98 | Class Incremental Learning 99 | ''' 100 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 101 | print('Adding %d categories every time' % FLAGS.nb_cl) 102 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 103 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 104 | 105 | to_category_idx = category_idx + FLAGS.nb_cl - 1 106 | if FLAGS.nb_cl == 1: 107 | print('Adding Category ' + str(category_idx + 1)) 108 | else: 109 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 110 | 111 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 112 | category_idx <= train_labels[idx] <= to_category_idx] 113 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 114 | test_labels[idx] <= to_category_idx] 115 | 116 | train_x = raw_images_train[train_indices_reals, :] 117 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 118 | test_x = raw_images_test[test_indices, :] 119 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 120 | 121 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 122 | if num_old_classes > 0: 123 | train_weights = np.ones(len(train_x)) * FLAGS.dgr_ratio 124 | else: 125 | train_weights = np.ones(len(train_x)) 126 | 127 | ''' 128 | Train generative model(DC-GAN) 129 | ''' 130 | # Mixed with generated samples of old classes 131 | if num_old_classes > 0: # first session or not 132 | llgan_obj.load(category_idx - 1) 133 | 134 | train_x_gens = [] 135 | train_y_gens = [] 136 | for old_class_idx in range(num_old_classes): 137 | train_x_gens_batch, _, _ = llgan_obj.test(NUM_TRAIN_SAMPLES_PER_CLASS) 138 | 
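                # Deep Generative Replay: samples drawn from the previous
                # generator are pseudo-labeled by the previous solver via
                # get_label() below, then mixed with the real data of the new
                # classes, weighted dgr_ratio vs. (1 - dgr_ratio).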
train_x_gens.extend(train_x_gens_batch) 139 | train_y_gens_batch = np.eye(to_category_idx + 1, dtype=float)[llgan_obj.get_label(train_x_gens_batch)] 140 | train_y_gens.extend(train_y_gens_batch) 141 | 142 | train_x = np \ 143 | .concatenate((train_x, np.uint8(train_x_gens))) 144 | train_y = np.concatenate((train_y, np.float64(train_y_gens))) 145 | train_weights = np.concatenate((train_weights, np.ones(len(train_x_gens)) * (1 - FLAGS.dgr_ratio))) 146 | 147 | sess_gan.close() 148 | del sess_gan 149 | 150 | graph_gen = tf.Graph() 151 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 152 | 153 | llgan_obj = GAN(sess_gan, graph_gen, sess_fid, 154 | dataset_name=FLAGS.dataset, 155 | mode=FLAGS.mode, 156 | batch_size=FLAGS.batch_size, 157 | output_dim=FLAGS.output_dim, 158 | lambda_param=FLAGS.lambda_param, 159 | critic_iters=FLAGS.critic_iters, 160 | iters=FLAGS.iters, 161 | solver_iters=FLAGS.iters, 162 | solver_adam_lr=FLAGS.solver_adam_lr, 163 | result_dir=result_dir, 164 | checkpoint_interval=FLAGS.gan_save_interval, 165 | adam_lr=FLAGS.adam_lr, 166 | adam_beta1=FLAGS.adam_beta1, 167 | adam_beta2=FLAGS.adam_beta2, 168 | finetune=FLAGS.gan_finetune, 169 | improved_finetune=FLAGS.improved_finetune, 170 | nb_cl=FLAGS.nb_cl, 171 | nb_output=(to_category_idx + 1), 172 | dgr_ratio=FLAGS.dgr_ratio, 173 | order_idx=FLAGS.order_idx, 174 | order=order, 175 | test_interval=FLAGS.test_interval, 176 | improved_finetune_type=FLAGS.improved_finetune_type, 177 | improved_finetune_noise=FLAGS.improved_finetune_noise, 178 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level) 179 | 180 | print(llgan_obj.model_dir) 181 | 182 | ''' 183 | Train generative model(GAN) 184 | ''' 185 | if llgan_obj.check_model(to_category_idx): 186 | model_exist = True 187 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 188 | category_idx + 1, to_category_idx + 1)) 189 | else: 190 | model_exist = False 191 | print(" [*] Model of Class %d-%d does not exist. 
Start the training process" % (
192 |                 category_idx + 1, to_category_idx + 1))
193 |         llgan_obj.train(train_x, train_y, train_weights, test_x, test_y, to_category_idx, model_exist=model_exist)
194 | 
195 |         if FLAGS.vis_result:
196 |             vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
197 |                             num_iters=FLAGS.iters)
198 | 
199 |     sess_fid.close()
200 | 
201 | 
202 | if __name__ == '__main__':
203 |     tf.app.run()
204 | 
-------------------------------------------------------------------------------- /mnist_train_introgan.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | '''
3 | On LL-GAN v1.3
4 | move the embedding one layer ahead
5 | '''
6 | 
7 | import tensorflow as tf
8 | 
9 | tf.set_random_seed(1993)
10 | 
11 | import numpy as np
12 | from tqdm import tqdm
13 | 
14 | np.random.seed(1993)
15 | import os
16 | import pprint
17 | 
18 | from gan.model_mnist_introgan import IntroGAN
19 | 
20 | from utils import fid
21 | from utils.visualize_result_single import vis_acc_and_fid
22 | 
23 | flags = tf.app.flags
24 | 
25 | flags.DEFINE_string("dataset", "fashion-mnist", "The name of dataset [mnist, fashion-mnist]")
26 | 
27 | # Hyperparameters
28 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]")
29 | flags.DEFINE_integer("critic_iters", 1,
30 |                      "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]")
31 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
32 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
33 | flags.DEFINE_integer("output_dim", 28 * 28, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [784]")
34 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp")
35 | flags.DEFINE_float("protogan_scale", 1.0, "")
36 | flags.DEFINE_float("protogan_scale_g", 0.1, "")
37 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint (number of iters)')
38 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
39 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
40 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
41 | flags.DEFINE_float("proto_weight_real", 0., "")
42 | flags.DEFINE_float("proto_weight_fake", 0., "")
43 | flags.DEFINE_integer("proto_importance", 1, "relative importance of protos versus ordinary samples")
44 | flags.DEFINE_string("dist_func", 'squared_l2', "cosine, squared_l2")
45 | flags.DEFINE_float("gamma", 1e-2, 'smoothness of the output probabilities')
46 | flags.DEFINE_string("center_type", 'rel_center', '[rel_center/fixed_center/multi_center]')
47 | flags.DEFINE_integer("fixed_center_idx", 0, '[0, proto_num - 1]')
48 | flags.DEFINE_integer("proto_num", 20, "")
49 | flags.DEFINE_float("margin", 0., "")
50 | flags.DEFINE_string("proto_select_criterion", 'random', "random or ori_kmeans or feat_kmeans")
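# [Annotation] The prototype flags above configure IntroGAN's prototype-based
# classifier: each class keeps proto_num prototypes, samples are scored by
# their dist_func distance to those prototypes, and gamma smooths the
# resulting probabilities. One plausible reading of these flags as a minimal
# numpy sketch (illustrative only; the function and shapes are hypothetical,
# not taken from the model code):
#
#     import numpy as np
#     def proto_probs(feat, protos, gamma=1e-2):
#         # feat: [feat_dim]; protos: [num_classes, proto_num, feat_dim]
#         d = ((protos - feat) ** 2).sum(axis=2)  # squared_l2 distance to every prototype
#         d_min = d.min(axis=1)                   # 'rigorous=min': nearest prototype per class
#         s = np.exp(-gamma * d_min)              # smaller gamma -> smoother probabilities
#         return s / s.sum()
#
51 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model')
52 | flags.DEFINE_boolean("improved_finetune", True, 'whether to use the improved finetuning variant')
53 | flags.DEFINE_string("improved_finetune_type", 'v2', '')
54 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
55 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
56 | flags.DEFINE_boolean("classification_only", False, "True for classification only")
57 | flags.DEFINE_integer("num_samples_per_class", -1, "number of samples per class in the training 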
set") 58 | flags.DEFINE_boolean("exemplars_dual_use", True, "if the exemplars also be added to the training set") 59 | flags.DEFINE_string("anti_imbalance", 'none', "[reweight|oversample|none]") 60 | flags.DEFINE_string("rigorous", 'min', '[min|max|random]') 61 | flags.DEFINE_boolean("train_rel_center", False, "") 62 | 63 | # Add how many classes every time 64 | flags.DEFINE_integer('nb_cl', 2, '') 65 | 66 | # DEBUG 67 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 68 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 69 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]') 70 | flags.DEFINE_integer("test_interval", 500, 71 | "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 72 | 73 | # Visualize 74 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 75 | 76 | FLAGS = flags.FLAGS 77 | 78 | pp = pprint.PrettyPrinter() 79 | 80 | 81 | def main(_): 82 | assert FLAGS.fixed_center_idx < FLAGS.proto_num 83 | 84 | if 0 < FLAGS.num_samples_per_class < FLAGS.proto_num: 85 | FLAGS.proto_num = FLAGS.num_samples_per_class 86 | 87 | pp.pprint(flags.FLAGS.__flags) 88 | 89 | NUM_CLASSES = 10 90 | NUM_TRAIN_SAMPLES_PER_CLASS = 6000 # training samples per class 91 | NUM_TEST_SAMPLES_PER_CLASS = 1000 # testing samples per class 92 | 93 | if FLAGS.dataset == 'mnist': 94 | from utils import mnist 95 | raw_images_train, train_labels, train_one_hot_labels, order = mnist.load_data(kind='train', 96 | order_idx=FLAGS.order_idx) 97 | raw_images_test, test_labels, test_one_hot_labels, order = mnist.load_data(kind='t10k', 98 | order_idx=FLAGS.order_idx) 99 | elif FLAGS.dataset == 'fashion-mnist': 100 | from utils import fashion_mnist 101 | raw_images_train, train_labels, train_one_hot_labels, order = fashion_mnist.load_data(kind='train', 102 | order_idx=FLAGS.order_idx) 103 | raw_images_test, test_labels, test_one_hot_labels, order = fashion_mnist.load_data(kind='t10k', 104 | order_idx=FLAGS.order_idx) 105 | elif FLAGS.dataset == 'svhn': 106 | from utils import svhn 107 | raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train', 108 | order_idx=FLAGS.order_idx) 109 | raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', 110 | order_idx=FLAGS.order_idx) 111 | 112 | # Total training samples 113 | NUM_TRAIN_SAMPLES_TOTAL = NUM_CLASSES * NUM_TRAIN_SAMPLES_PER_CLASS 114 | NUM_TEST_SAMPLES_TOTAL = NUM_CLASSES * NUM_TEST_SAMPLES_PER_CLASS 115 | 116 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333, allow_growth=True) 117 | run_config = tf.ConfigProto(gpu_options=gpu_options) 118 | # run_config.gpu_options.allow_growth = True 119 | 120 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 121 | result_dir = os.path.join('result', method_name) 122 | FLAGS.result_dir = result_dir 123 | print('Result dir: %s' % result_dir) 124 | 125 | graph_fid = tf.Graph() 126 | with graph_fid.as_default(): 127 | inception_path = fid.check_or_download_inception('tmp/imagenet') 128 | fid.create_inception_graph(inception_path) 129 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 130 | 131 | ''' 132 | Class Incremental Learning 133 | ''' 134 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 135 | print('Adding %d categories every time' % FLAGS.nb_cl) 136 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 137 | for category_idx in 
range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 138 | 139 | to_category_idx = category_idx + FLAGS.nb_cl - 1 140 | if FLAGS.nb_cl == 1: 141 | print('Adding Category ' + str(category_idx + 1)) 142 | else: 143 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 144 | 145 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 146 | category_idx <= train_labels[idx] <= to_category_idx] 147 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 148 | test_labels[idx] <= to_category_idx] 149 | 150 | train_x = raw_images_train[train_indices_reals, :] 151 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 152 | test_x = raw_images_test[test_indices, :] 153 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 154 | 155 | ''' 156 | Check model exist 157 | ''' 158 | if IntroGAN.check_model(FLAGS, to_category_idx): 159 | model_exist = True 160 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 161 | category_idx + 1, to_category_idx + 1)) 162 | else: 163 | model_exist = False 164 | print(" [*] Model of Class %d-%d does not exist. Start the training process" % ( 165 | category_idx + 1, to_category_idx + 1)) 166 | 167 | ''' 168 | Train generative model(DC-GAN) 169 | ''' 170 | # Mixed with generated samples of old classes 171 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 172 | if num_old_classes > 0: # first session or not 173 | if not model_exist: 174 | train_y_gens = np.repeat(range(num_old_classes), [NUM_TRAIN_SAMPLES_PER_CLASS]) 175 | train_y_gens_one_hot = np.eye(num_old_classes)[train_y_gens.reshape(-1)] 176 | llgan_obj.load(category_idx - 1) 177 | 178 | # train_x_gens, _, _ = llgan_obj.test(NUM_TRAIN_SAMPLES_PER_CLASS * num_old_classes, train_y_gens_one_hot) 179 | 180 | train_x_gens = [] 181 | for start_idx in tqdm(range(0, len(train_y_gens_one_hot), FLAGS.batch_size), desc='Generative replay'): 182 | train_y_gens_one_hot_cur_batch = train_y_gens_one_hot[start_idx: start_idx + FLAGS.batch_size] 183 | train_x_gens_batch, _, _ = llgan_obj.test(len(train_y_gens_one_hot_cur_batch), 184 | train_y_gens_one_hot_cur_batch) 185 | train_x_gens.extend(train_x_gens_batch) 186 | 187 | train_x = np.concatenate((train_x, np.uint8(train_x_gens))) 188 | train_y = np.concatenate((train_y, np.eye(to_category_idx + 1)[train_y_gens.reshape(-1)])) 189 | 190 | sess_gan.close() 191 | del sess_gan 192 | 193 | graph_gen = tf.Graph() 194 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 195 | 196 | llgan_obj = IntroGAN(sess_gan, graph_gen, sess_fid, 197 | dataset=FLAGS.dataset, 198 | mode=FLAGS.mode, 199 | batch_size=FLAGS.batch_size, 200 | output_dim=FLAGS.output_dim, 201 | lambda_param=FLAGS.lambda_param, 202 | critic_iters=FLAGS.critic_iters, 203 | iters=FLAGS.iters, 204 | result_dir=FLAGS.result_dir, 205 | checkpoint_interval=FLAGS.gan_save_interval, 206 | adam_lr=FLAGS.adam_lr, 207 | adam_beta1=FLAGS.adam_beta1, 208 | adam_beta2=FLAGS.adam_beta2, 209 | protogan_scale=FLAGS.protogan_scale, 210 | protogan_scale_g=FLAGS.protogan_scale_g, 211 | finetune=FLAGS.finetune, 212 | improved_finetune=FLAGS.improved_finetune, 213 | nb_cl=FLAGS.nb_cl, 214 | nb_output=(to_category_idx + 1), 215 | classification_only=FLAGS.classification_only, 216 | proto_weight_real=FLAGS.proto_weight_real, 217 | proto_weight_fake=FLAGS.proto_weight_fake, 218 | dist_func=FLAGS.dist_func, 219 | gamma=FLAGS.gamma, 220 | proto_num=FLAGS.proto_num, 221 | margin=FLAGS.margin, 222 | 
order_idx=FLAGS.order_idx,
223 |                              proto_select_criterion=FLAGS.proto_select_criterion,
224 |                              order=order,
225 |                              test_interval=FLAGS.test_interval,
226 |                              proto_importance=FLAGS.proto_importance,
227 |                              improved_finetune_type=FLAGS.improved_finetune_type,
228 |                              improved_finetune_noise=FLAGS.improved_finetune_noise,
229 |                              improved_finetune_noise_level=FLAGS.improved_finetune_noise_level,
230 |                              center_type=FLAGS.center_type,
231 |                              fixed_center_idx=FLAGS.fixed_center_idx,
232 |                              num_samples_per_class=FLAGS.num_samples_per_class,
233 |                              exemplars_dual_use=FLAGS.exemplars_dual_use,
234 |                              anti_imbalance=FLAGS.anti_imbalance,
235 |                              train_rel_center=FLAGS.train_rel_center,
236 |                              rigorous=FLAGS.rigorous)
237 | 
238 |         print(IntroGAN.model_dir_static(FLAGS))
239 | 
240 |         '''
241 |         Train generative model(GAN)
242 |         '''
243 |         llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist)
244 | 
245 |         if FLAGS.vis_result:
246 |             vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
247 |                             num_iters=FLAGS.iters, vis_fid=(not FLAGS.classification_only))
248 | 
249 |     sess_fid.close()
250 | 
251 | 
252 | if __name__ == '__main__':
253 |     tf.app.run()
254 | 
-------------------------------------------------------------------------------- /mnist_train_lowerbound.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import tensorflow as tf
4 | 
5 | tf.set_random_seed(1993)
6 | 
7 | import numpy as np
8 | 
9 | np.random.seed(1993)
10 | import os
11 | import pprint
12 | 
13 | from gan.model_mnist_cnn import ClsNet
14 | 
15 | from utils.visualize_result_single import vis_acc_and_fid
16 | 
17 | flags = tf.app.flags
18 | 
19 | flags.DEFINE_string("dataset", "fashion-mnist", "The name of dataset [mnist, fashion-mnist]")
20 | 
21 | # Hyperparameters
22 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
23 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
24 | flags.DEFINE_integer("output_dim", 28 * 28, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [784]")
25 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
26 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
27 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
28 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model')
29 | flags.DEFINE_boolean("improved_finetune", True, 'whether to use the improved finetuning variant')
30 | flags.DEFINE_string("improved_finetune_type", 'v2', '')
31 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
32 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
33 | 
34 | # Add how many classes every time
35 | flags.DEFINE_integer('nb_cl', 2, '')
36 | 
37 | # DEBUG
38 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx')
39 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx')
40 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]')
41 | flags.DEFINE_integer("test_interval", 500,
42 |                      "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)")
43 | 
44 | # Visualize
45 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure')
46 | 
47 | FLAGS = flags.FLAGS
48 | 
49 | pp = pprint.PrettyPrinter()
50 | 
51 | 
52 | def main(_):
53 |     pp.pprint(flags.FLAGS.__flags)
54 | 
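# [Annotation] This script is the sequential-finetuning lower bound: the
# classifier is trained session by session on the real data of the current
# classes only, with no replay or exemplars, so its accuracy bounds the other
# methods from below. Note the index-selection contrast with the upper bound
# (both predicates appear verbatim later in this listing):
#
#     lowerbound: category_idx <= train_labels[idx] <= to_category_idx   # current session only
#     upperbound: train_labels[idx] <= to_category_idx                   # all classes seen so far
#
55 |     if FLAGS.dataset == 'mnist':
56 |         from utils import mnist
57 |         raw_images_train, 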
train_labels, train_one_hot_labels, order = mnist.load_data(kind='train', 58 | order_idx=FLAGS.order_idx) 59 | raw_images_test, test_labels, test_one_hot_labels, order = mnist.load_data(kind='t10k', 60 | order_idx=FLAGS.order_idx) 61 | elif FLAGS.dataset == 'fashion-mnist': 62 | from utils import fashion_mnist 63 | raw_images_train, train_labels, train_one_hot_labels, order = fashion_mnist.load_data(kind='train', 64 | order_idx=FLAGS.order_idx) 65 | raw_images_test, test_labels, test_one_hot_labels, order = fashion_mnist.load_data(kind='t10k', 66 | order_idx=FLAGS.order_idx) 67 | 68 | # Total training samples 69 | NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train) 70 | NUM_TEST_SAMPLES_TOTAL = len(raw_images_test) 71 | 72 | run_config = tf.ConfigProto() 73 | run_config.gpu_options.allow_growth = True 74 | 75 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 76 | result_dir = os.path.join('result', method_name) 77 | print('Result dir: %s' % result_dir) 78 | 79 | ''' 80 | Class Incremental Learning 81 | ''' 82 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 83 | print('Adding %d categories every time' % FLAGS.nb_cl) 84 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 85 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 86 | 87 | to_category_idx = category_idx + FLAGS.nb_cl - 1 88 | if FLAGS.nb_cl == 1: 89 | print('Adding Category ' + str(category_idx + 1)) 90 | else: 91 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 92 | 93 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 94 | category_idx <= train_labels[idx] <= to_category_idx] 95 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 96 | test_labels[idx] <= to_category_idx] 97 | 98 | train_x = raw_images_train[train_indices_reals, :] 99 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 100 | test_x = raw_images_test[test_indices, :] 101 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 102 | 103 | graph = tf.Graph() 104 | sess = tf.Session(config=run_config, graph=graph) 105 | llgan_obj = ClsNet(sess, graph, 106 | dataset_name=FLAGS.dataset, 107 | batch_size=FLAGS.batch_size, 108 | output_dim=FLAGS.output_dim, 109 | iters=FLAGS.iters, 110 | result_dir=result_dir, 111 | adam_lr=FLAGS.adam_lr, 112 | adam_beta1=FLAGS.adam_beta1, 113 | adam_beta2=FLAGS.adam_beta2, 114 | nb_cl=FLAGS.nb_cl, 115 | nb_output=(to_category_idx + 1), 116 | order_idx=FLAGS.order_idx, 117 | order=order, 118 | test_interval=FLAGS.test_interval, 119 | finetune=FLAGS.finetune, 120 | improved_finetune=FLAGS.improved_finetune, 121 | improved_finetune_noise=FLAGS.improved_finetune_noise, 122 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level, 123 | improved_finetune_type=FLAGS.improved_finetune_type) 124 | 125 | print(llgan_obj.model_dir) 126 | 127 | ''' 128 | Train generative model(GAN) 129 | ''' 130 | if llgan_obj.check_model(to_category_idx): 131 | model_exist = True 132 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 133 | category_idx + 1, to_category_idx + 1)) 134 | else: 135 | model_exist = False 136 | print(" [*] Model of Class %d-%d does not exist. 
Start the training process" % (
137 |                 category_idx + 1, to_category_idx + 1))
138 |         llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist)
139 | 
140 |         if FLAGS.vis_result:
141 |             vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
142 |                             num_iters=FLAGS.iters, vis_fid=False)
143 | 
144 | 
145 | if __name__ == '__main__':
146 |     tf.app.run()
147 | 
-------------------------------------------------------------------------------- /mnist_train_mergan.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import tensorflow as tf
4 | 
5 | tf.set_random_seed(1993)
6 | 
7 | import numpy as np
8 | from tqdm import tqdm
9 | 
10 | np.random.seed(1993)
11 | import os
12 | import pprint
13 | 
14 | from gan.model_mnist_mergan import MeRGAN
15 | 
16 | from utils import fid
17 | from utils.visualize_result_single import vis_acc_and_fid
18 | 
19 | flags = tf.app.flags
20 | 
21 | flags.DEFINE_string("dataset", "fashion-mnist", "The name of dataset [mnist, fashion-mnist]")
22 | 
23 | # Hyperparameters
24 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]")
25 | flags.DEFINE_integer("critic_iters", 1,
26 |                      "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]")
27 | flags.DEFINE_integer("class_iters", 1, "How many class iterations per generator iteration")
28 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
29 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
30 | flags.DEFINE_integer("output_dim", 28 * 28, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [784]")
31 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp")
32 | flags.DEFINE_float("acgan_scale", 1.0, "")
33 | flags.DEFINE_float("acgan_scale_g", 0.1, "")
34 | flags.DEFINE_boolean("use_softmax", True, "")
35 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint (number of iters)')
36 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
37 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
38 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
39 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model')
40 | flags.DEFINE_boolean("improved_finetune", True, 'whether to use the improved finetuning variant')
41 | flags.DEFINE_string("improved_finetune_type", 'v2', '')
42 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
43 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
44 | flags.DEFINE_boolean("classification_only", False, "True for classification only")
45 | flags.DEFINE_integer("test_interval", 500,
46 |                      "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)")
47 | flags.DEFINE_integer("num_samples_per_class", -1, "number of samples per class in the training set")
48 | 
49 | # ablation study
50 | flags.DEFINE_boolean("use_protos", False, "use prototypes of IntroGAN for ablation studies")
51 | flags.DEFINE_string("protos_path",
52 |                     "protogan_v1_3_8/%s_order_%d/nb_cl_%d/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_%d_weight_0.000000_0.000000_squared_l2_0.010000/finetune_improved_v2_noise_0.5_exemplars_dual_use_%d",
53 |                     "the path of prototypes of IntroGAN")
54 | flags.DEFINE_integer("protos_num", 20, "number of prototypes")
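# [Annotation] acgan_scale and acgan_scale_g above are AC-GAN-style loss
# weights: MeRGAN trains the discriminator with an auxiliary classifier. A
# rough sketch of how such weights typically enter the losses (an assumption
# about the model code, shown for orientation only):
#
#     # d_loss = d_gan_loss + acgan_scale   * class_loss(real and fake)
#     # g_loss = g_gan_loss + acgan_scale_g * class_loss(fake)
#
55 | 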
flags.DEFINE_integer("protos_importance", 1, "relative importance of protos versus ordinary samples") 56 | 57 | # diversity promoting 58 | flags.DEFINE_float("diversity_promoting_weight", 0., "weight of the diversity promoting loss") 59 | 60 | # Add how many classes every time 61 | flags.DEFINE_integer('nb_cl', 2, '') 62 | 63 | # DEBUG 64 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 65 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 66 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]') 67 | 68 | # Visualize 69 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 70 | 71 | FLAGS = flags.FLAGS 72 | 73 | pp = pprint.PrettyPrinter() 74 | 75 | 76 | def main(_): 77 | FLAGS.use_diversity_promoting = (False if FLAGS.diversity_promoting_weight == 0. else True) 78 | 79 | pp.pprint(flags.FLAGS.__flags) 80 | 81 | NUM_CLASSES = 10 82 | NUM_TRAIN_SAMPLES_PER_CLASS = 6000 # training samples per class 83 | NUM_TEST_SAMPLES_PER_CLASS = 1000 # testing samples per class 84 | 85 | if FLAGS.dataset == 'mnist': 86 | from utils import mnist 87 | raw_images_train, train_labels, train_one_hot_labels, order = mnist.load_data(kind='train', 88 | order_idx=FLAGS.order_idx) 89 | raw_images_test, test_labels, test_one_hot_labels, order = mnist.load_data(kind='t10k', 90 | order_idx=FLAGS.order_idx) 91 | elif FLAGS.dataset == 'fashion-mnist': 92 | from utils import fashion_mnist 93 | raw_images_train, train_labels, train_one_hot_labels, order = fashion_mnist.load_data(kind='train', 94 | order_idx=FLAGS.order_idx) 95 | raw_images_test, test_labels, test_one_hot_labels, order = fashion_mnist.load_data(kind='t10k', 96 | order_idx=FLAGS.order_idx) 97 | 98 | # Total training samples 99 | NUM_TRAIN_SAMPLES_TOTAL = NUM_CLASSES * NUM_TRAIN_SAMPLES_PER_CLASS 100 | NUM_TEST_SAMPLES_TOTAL = NUM_CLASSES * NUM_TEST_SAMPLES_PER_CLASS 101 | 102 | run_config = tf.ConfigProto() 103 | run_config.gpu_options.allow_growth = True 104 | 105 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 106 | result_dir = os.path.join('result', method_name) 107 | FLAGS.result_dir = result_dir 108 | print('Result dir: %s' % result_dir) 109 | 110 | graph_fid = tf.Graph() 111 | with graph_fid.as_default(): 112 | inception_path = fid.check_or_download_inception('tmp/imagenet') 113 | fid.create_inception_graph(inception_path) 114 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 115 | 116 | ''' 117 | Class Incremental Learning 118 | ''' 119 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 120 | print('Adding %d categories every time' % FLAGS.nb_cl) 121 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 122 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 123 | 124 | to_category_idx = category_idx + FLAGS.nb_cl - 1 125 | if FLAGS.nb_cl == 1: 126 | print('Adding Category ' + str(category_idx + 1)) 127 | else: 128 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 129 | 130 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 131 | category_idx <= train_labels[idx] <= to_category_idx] 132 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 133 | test_labels[idx] <= to_category_idx] 134 | 135 | train_x = raw_images_train[train_indices_reals, :] 136 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 137 | test_x = raw_images_test[test_indices, :] 138 | test_y = 
test_one_hot_labels[test_indices, :(to_category_idx + 1)] 139 | 140 | ''' 141 | Check model exist 142 | ''' 143 | if MeRGAN.check_model(FLAGS, to_category_idx): 144 | model_exist = True 145 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 146 | category_idx + 1, to_category_idx + 1)) 147 | else: 148 | model_exist = False 149 | print(" [*] Model of Class %d-%d does not exist. Start the training process" % ( 150 | category_idx + 1, to_category_idx + 1)) 151 | 152 | ''' 153 | Train generative model(DC-GAN) 154 | ''' 155 | # Mixed with generated samples of old classes 156 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 157 | if num_old_classes > 0: # first session or not 158 | if not model_exist: 159 | train_y_gens = np.repeat(range(num_old_classes), [NUM_TRAIN_SAMPLES_PER_CLASS]) 160 | train_y_gens_one_hot = np.eye(num_old_classes)[train_y_gens.reshape(-1)] 161 | llgan_obj.load(category_idx - 1) 162 | 163 | # train_x_gens, _, _ = llgan_obj.test(NUM_TRAIN_SAMPLES_PER_CLASS * num_old_classes, train_y_gens_one_hot) 164 | 165 | train_x_gens = [] 166 | for start_idx in tqdm(range(0, len(train_y_gens_one_hot), FLAGS.batch_size), desc='Generative replay'): 167 | train_y_gens_one_hot_cur_batch = train_y_gens_one_hot[start_idx: start_idx + FLAGS.batch_size] 168 | train_x_gens_batch, _, _ = llgan_obj.test(len(train_y_gens_one_hot_cur_batch), 169 | train_y_gens_one_hot_cur_batch) 170 | train_x_gens.extend(train_x_gens_batch) 171 | 172 | train_x = np.concatenate((train_x, np.uint8(train_x_gens))) 173 | train_y = np.concatenate((train_y, np.eye(to_category_idx + 1)[train_y_gens.reshape(-1)])) 174 | 175 | sess_gan.close() 176 | del sess_gan 177 | 178 | graph_gen = tf.Graph() 179 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 180 | 181 | llgan_obj = MeRGAN(sess_gan, graph_gen, sess_fid, 182 | dataset=FLAGS.dataset, 183 | mode=FLAGS.mode, 184 | batch_size=FLAGS.batch_size, 185 | output_dim=FLAGS.output_dim, 186 | lambda_param=FLAGS.lambda_param, 187 | critic_iters=FLAGS.critic_iters, 188 | class_iters=FLAGS.class_iters, 189 | iters=FLAGS.iters, 190 | result_dir=FLAGS.result_dir, 191 | checkpoint_interval=FLAGS.gan_save_interval, 192 | adam_lr=FLAGS.adam_lr, 193 | adam_beta1=FLAGS.adam_beta1, 194 | adam_beta2=FLAGS.adam_beta2, 195 | acgan_scale=FLAGS.acgan_scale, 196 | acgan_scale_g=FLAGS.acgan_scale_g, 197 | finetune=FLAGS.finetune, 198 | improved_finetune=FLAGS.improved_finetune, 199 | nb_cl=FLAGS.nb_cl, 200 | nb_output=(to_category_idx + 1), 201 | classification_only=FLAGS.classification_only, 202 | order_idx=FLAGS.order_idx, 203 | order=order, 204 | test_interval=FLAGS.test_interval, 205 | use_softmax=FLAGS.use_softmax, 206 | use_diversity_promoting=FLAGS.use_diversity_promoting, 207 | diversity_promoting_weight=FLAGS.diversity_promoting_weight, 208 | improved_finetune_noise=FLAGS.improved_finetune_noise, 209 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level, 210 | num_samples_per_class=FLAGS.num_samples_per_class, 211 | use_protos=FLAGS.use_protos, 212 | protos_path=FLAGS.protos_path, 213 | protos_num=FLAGS.protos_num, 214 | protos_importance=FLAGS.protos_importance, 215 | improved_finetune_type=FLAGS.improved_finetune_type, 216 | ) 217 | 218 | print(MeRGAN.model_dir_static(FLAGS)) 219 | 220 | ''' 221 | Train generative model(GAN) 222 | ''' 223 | llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist) 224 | 225 | if FLAGS.vis_result: 226 | vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, 
test_interval=FLAGS.test_interval,
227 |                            num_iters=FLAGS.iters, vis_fid=(not FLAGS.classification_only))
228 | 
229 |     sess_fid.close()
230 | 
231 | 
232 | if __name__ == '__main__':
233 |     tf.app.run()
234 | 
-------------------------------------------------------------------------------- /mnist_train_upperbound.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import tensorflow as tf
4 | tf.set_random_seed(1993)
5 | 
6 | import numpy as np
7 | np.random.seed(1993)
8 | import os
9 | import pprint
10 | 
11 | from gan.model_mnist_cnn import ClsNet
12 | 
13 | from utils.visualize_result_single import vis_acc_and_fid
14 | 
15 | flags = tf.app.flags
16 | 
17 | flags.DEFINE_string("dataset", "fashion-mnist", "The name of dataset [mnist, fashion-mnist]")
18 | 
19 | # Hyperparameters
20 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
21 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
22 | flags.DEFINE_integer("output_dim", 28*28, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [784]")
23 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
24 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
25 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
26 | flags.DEFINE_boolean("finetune", False, '')
27 | flags.DEFINE_boolean("improved_finetune", True, '')
28 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
29 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
30 | 
31 | # Add how many classes every time
32 | flags.DEFINE_integer('nb_cl', 2, '')
33 | 
34 | # DEBUG
35 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx')
36 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx')
37 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]')
38 | flags.DEFINE_integer("test_interval", 500, "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)")
39 | 
40 | # Visualize
41 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure')
42 | 
43 | FLAGS = flags.FLAGS
44 | 
45 | pp = pprint.PrettyPrinter()
46 | 
47 | 
48 | def main(_):
49 | 
50 |     pp.pprint(flags.FLAGS.__flags)
51 | 
52 |     if FLAGS.dataset == 'mnist':
53 |         from utils import mnist
54 |         raw_images_train, train_labels, train_one_hot_labels, order = mnist.load_data(kind='train', order_idx=FLAGS.order_idx)
55 |         raw_images_test, test_labels, test_one_hot_labels, order = mnist.load_data(kind='t10k', order_idx=FLAGS.order_idx)
56 |     elif FLAGS.dataset == 'fashion-mnist':
57 |         from utils import fashion_mnist
58 |         raw_images_train, train_labels, train_one_hot_labels, order = fashion_mnist.load_data(kind='train', order_idx=FLAGS.order_idx)
59 |         raw_images_test, test_labels, test_one_hot_labels, order = fashion_mnist.load_data(kind='t10k', order_idx=FLAGS.order_idx)
60 | 
61 |     # Total training samples
62 |     NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train)
63 |     NUM_TEST_SAMPLES_TOTAL = len(raw_images_test)
64 | 
65 |     run_config = tf.ConfigProto()
66 |     run_config.gpu_options.allow_growth = True
67 | 
68 |     method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:])
69 |     result_dir = os.path.join('result', method_name)
70 |     print('Result dir: %s' % result_dir)
71 | 
72 |     '''
73 |     Class Incremental Learning
74 |     '''
75 |     print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1))
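# [Annotation] This upper bound trains jointly on the real data of every class
# seen so far (its selection below keeps all labels <= to_category_idx, with
# no lower cut-off), so nothing is ever forgotten; it is the oracle that the
# replay-based methods are compared against. Note finetune defaults to False
# here, i.e. each session's classifier is trained from scratch.
76 |     print('Adding %d categories 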
every time' % FLAGS.nb_cl) 77 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 78 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 79 | 80 | to_category_idx = category_idx + FLAGS.nb_cl - 1 81 | if FLAGS.nb_cl == 1: 82 | print('Adding Category ' + str(category_idx + 1)) 83 | else: 84 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 85 | 86 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 87 | train_labels[idx] <= to_category_idx] 88 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 89 | test_labels[idx] <= to_category_idx] 90 | 91 | train_x = raw_images_train[train_indices_reals, :] 92 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx+1)] 93 | test_x = raw_images_test[test_indices, :] 94 | test_y = test_one_hot_labels[test_indices, :(to_category_idx+1)] 95 | 96 | graph = tf.Graph() 97 | sess = tf.Session(config=run_config, graph=graph) 98 | llgan_obj = ClsNet(sess, graph, 99 | dataset_name=FLAGS.dataset, 100 | batch_size=FLAGS.batch_size, 101 | output_dim=FLAGS.output_dim, 102 | iters=FLAGS.iters, 103 | result_dir=result_dir, 104 | adam_lr=FLAGS.adam_lr, 105 | adam_beta1=FLAGS.adam_beta1, 106 | adam_beta2=FLAGS.adam_beta2, 107 | nb_cl=FLAGS.nb_cl, 108 | nb_output=(to_category_idx + 1), 109 | order_idx=FLAGS.order_idx, 110 | order=order, 111 | test_interval=FLAGS.test_interval, 112 | finetune=FLAGS.finetune, 113 | improved_finetune=FLAGS.improved_finetune, 114 | improved_finetune_noise=FLAGS.improved_finetune_noise, 115 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level) 116 | 117 | print(llgan_obj.model_dir) 118 | 119 | ''' 120 | Train generative model(GAN) 121 | ''' 122 | if llgan_obj.check_model(to_category_idx): 123 | model_exist = True 124 | print(" [*] Model of Class %d-%d exists. Skip the training process" % (category_idx + 1, to_category_idx + 1)) 125 | else: 126 | model_exist = False 127 | print(" [*] Model of Class %d-%d does not exist. Start the training process" % (category_idx + 1, to_category_idx + 1)) 128 | llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist) 129 | 130 | if FLAGS.vis_result: 131 | vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval, 132 | num_iters=FLAGS.iters, vis_fid=False) 133 | 134 | 135 | if __name__ == '__main__': 136 | tf.app.run() 137 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16. 
2 | matplotlib==2.1.0
3 | protobuf==3.7.1
4 | tensorflow-gpu==1.4.0
5 | tqdm==4.46.0
6 | scipy==1.2.1
7 | scikit-learn==0.20.4
8 | pathlib==1.0.1
9 | llvmlite==0.30.0
10 | numba==0.46.0
11 | umap-learn==0.3.9
-------------------------------------------------------------------------------- /svhn_train_dgr.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import tensorflow as tf
4 | 
5 | tf.set_random_seed(1993)
6 | 
7 | import numpy as np
8 | 
9 | np.random.seed(1993)
10 | import os
11 | import pprint
12 | 
13 | from gan.model_svhn_dgr import GAN
14 | 
15 | from utils import fid
16 | from utils.visualize_result_single import vis_acc_and_fid
17 | 
18 | flags = tf.app.flags
19 | 
20 | flags.DEFINE_string("dataset", "svhn", "The name of dataset [svhn]")
21 | 
22 | # Hyperparameters
23 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]")
24 | flags.DEFINE_integer("critic_iters", 1,
25 |                      "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]")
26 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
27 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
28 | flags.DEFINE_integer("output_dim", 3 * 32 * 32, "Number of pixels in SVHN (3*32*32) [3072]")
29 | flags.DEFINE_integer("dim", 64, "GAN dim")
30 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp")
31 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint (number of iters)')
32 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
33 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
34 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
35 | flags.DEFINE_boolean("gan_finetune", True, 'if gan finetuned from the previous model')
36 | flags.DEFINE_boolean("improved_finetune", True, 'whether to use the improved finetuning variant')
37 | flags.DEFINE_string("improved_finetune_type", 'v2', '')
38 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
39 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
40 | flags.DEFINE_float("dgr_ratio", 0.5, "")
41 | flags.DEFINE_float("solver_adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
42 | flags.DEFINE_integer("test_interval", 500,
43 |                      "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)")
44 | 
45 | # Add how many classes every time
46 | flags.DEFINE_integer('nb_cl', 2, '')
47 | 
48 | # DEBUG
49 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx')
50 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx')
51 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]')
52 | 
53 | # Visualize
54 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure')
55 | 
56 | FLAGS = flags.FLAGS
57 | 
58 | pp = pprint.PrettyPrinter()
59 | 
60 | 
61 | def main(_):
62 |     pp.pprint(flags.FLAGS.__flags)
63 | 
64 |     from utils import svhn
65 |     raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train',
66 |                                                                                  order_idx=FLAGS.order_idx)
67 |     raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', order_idx=FLAGS.order_idx)
68 | 
69 |     # Total training samples
70 |     NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train)
71 |     NUM_TEST_SAMPLES_TOTAL = len(raw_images_test)
72 | 
73 |     NUM_CLASSES = 10
74 | 
75 |     NUM_TRAIN_SAMPLES_PER_CLASS = NUM_TRAIN_SAMPLES_TOTAL / NUM_CLASSES
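# [Annotation] The division above relies on Python 2 integer division (the
# README pins Python 2.7), so NUM_TRAIN_SAMPLES_PER_CLASS stays an int. Under
# Python 3 the equivalent would be explicit floor division:
#
#     NUM_TRAIN_SAMPLES_PER_CLASS = NUM_TRAIN_SAMPLES_TOTAL // NUM_CLASSES
#
76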
| 77 | run_config = tf.ConfigProto() 78 | run_config.gpu_options.allow_growth = True 79 | 80 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 81 | result_dir = os.path.join('result', method_name) 82 | print('Result dir: %s' % result_dir) 83 | 84 | graph_fid = tf.Graph() 85 | with graph_fid.as_default(): 86 | inception_path = fid.check_or_download_inception('tmp/imagenet') 87 | fid.create_inception_graph(inception_path) 88 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 89 | 90 | ''' 91 | Class Incremental Learning 92 | ''' 93 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 94 | print('Adding %d categories every time' % FLAGS.nb_cl) 95 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 96 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 97 | 98 | to_category_idx = category_idx + FLAGS.nb_cl - 1 99 | if FLAGS.nb_cl == 1: 100 | print('Adding Category ' + str(category_idx + 1)) 101 | else: 102 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 103 | 104 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 105 | category_idx <= train_labels[idx] <= to_category_idx] 106 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 107 | test_labels[idx] <= to_category_idx] 108 | 109 | train_x = raw_images_train[train_indices_reals, :] 110 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 111 | test_x = raw_images_test[test_indices, :] 112 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 113 | 114 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 115 | if num_old_classes > 0: 116 | train_weights = np.ones(len(train_x)) * FLAGS.dgr_ratio 117 | else: 118 | train_weights = np.ones(len(train_x)) 119 | 120 | ''' 121 | Train generative model(DC-GAN) 122 | ''' 123 | # Mixed with generated samples of old classes 124 | if num_old_classes > 0: # first session or not 125 | llgan_obj.load(category_idx - 1) 126 | 127 | train_x_gens = [] 128 | train_y_gens = [] 129 | tmp_a, tmp_b = divmod(num_old_classes * NUM_TRAIN_SAMPLES_PER_CLASS, FLAGS.batch_size) 130 | batch_sizes = [FLAGS.batch_size] * tmp_a + [tmp_b] if tmp_b > 0 else [FLAGS.batch_size] * tmp_a 131 | for batch_size in batch_sizes: 132 | train_x_gens_batch, _, _ = llgan_obj.test(batch_size) 133 | train_y_gens_batch = np.eye(to_category_idx + 1, dtype=float)[llgan_obj.get_label(train_x_gens_batch)] 134 | train_x_gens.extend(train_x_gens_batch) 135 | train_y_gens.extend(train_y_gens_batch) 136 | 137 | train_x = np \ 138 | .concatenate((train_x, np.uint8(train_x_gens))) 139 | train_y = np.concatenate((train_y, np.float64(train_y_gens))) 140 | train_weights = np.concatenate((train_weights, np.ones(len(train_x_gens)) * (1 - FLAGS.dgr_ratio))) 141 | 142 | sess_gan.close() 143 | del sess_gan 144 | 145 | graph_gen = tf.Graph() 146 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 147 | 148 | llgan_obj = GAN(sess_gan, graph_gen, sess_fid, 149 | dataset_name=FLAGS.dataset, 150 | mode=FLAGS.mode, 151 | batch_size=FLAGS.batch_size, 152 | output_dim=FLAGS.output_dim, 153 | lambda_param=FLAGS.lambda_param, 154 | critic_iters=FLAGS.critic_iters, 155 | iters=FLAGS.iters, 156 | solver_iters=FLAGS.iters, 157 | solver_adam_lr=FLAGS.solver_adam_lr, 158 | result_dir=result_dir, 159 | checkpoint_interval=FLAGS.gan_save_interval, 160 | adam_lr=FLAGS.adam_lr, 161 | adam_beta1=FLAGS.adam_beta1, 162 | adam_beta2=FLAGS.adam_beta2, 
163 |                 finetune=FLAGS.gan_finetune,
164 |                 improved_finetune=FLAGS.improved_finetune,
165 |                 nb_cl=FLAGS.nb_cl,
166 |                 nb_output=(to_category_idx + 1),
167 |                 dgr_ratio=FLAGS.dgr_ratio,
168 |                 dim=FLAGS.dim,
169 |                 order_idx=FLAGS.order_idx,
170 |                 order=order,
171 |                 test_interval=FLAGS.test_interval,
172 |                 improved_finetune_type=FLAGS.improved_finetune_type,
173 |                 improved_finetune_noise=FLAGS.improved_finetune_noise,
174 |                 improved_finetune_noise_level=FLAGS.improved_finetune_noise_level)
175 | 
176 |         print(llgan_obj.model_dir)
177 | 
178 |         '''
179 |         Train generative model(GAN)
180 |         '''
181 |         if llgan_obj.check_model(to_category_idx):
182 |             model_exist = True
183 |             print(" [*] Model of Class %d-%d exists. Skip the training process" % (
184 |                 category_idx + 1, to_category_idx + 1))
185 |         else:
186 |             model_exist = False
187 |             print(" [*] Model of Class %d-%d does not exist. Start the training process" % (
188 |                 category_idx + 1, to_category_idx + 1))
189 |         llgan_obj.train(train_x, train_y, train_weights, test_x, test_y, to_category_idx, model_exist=model_exist)
190 | 
191 |         if FLAGS.vis_result:
192 |             vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
193 |                             num_iters=FLAGS.iters)
194 | 
195 |     sess_fid.close()
196 | 
197 | 
198 | if __name__ == '__main__':
199 |     tf.app.run()
200 | 
-------------------------------------------------------------------------------- /svhn_train_introgan.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | '''
3 | On LL-GAN v1.3
4 | move the embedding one layer ahead
5 | '''
6 | import time
7 | 
8 | import tensorflow as tf
9 | 
10 | tf.set_random_seed(1993)
11 | 
12 | import numpy as np
13 | from tqdm import tqdm
14 | 
15 | np.random.seed(1993)
16 | import os
17 | import pprint
18 | 
19 | from gan.model_svhn_introgan import ProtoGAN
20 | 
21 | from utils import fid
22 | from utils.visualize_result_single import vis_acc_and_fid
23 | 
24 | flags = tf.app.flags
25 | 
26 | flags.DEFINE_string("dataset", "svhn", "The name of dataset [svhn]")
27 | flags.DEFINE_boolean("rgb", True, "rgb or gray")
28 | 
29 | # Hyperparameters
30 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]")
31 | flags.DEFINE_integer("critic_iters", 1,
32 |                      "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]")
33 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
34 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
35 | flags.DEFINE_integer("output_dim", 3 * 32 * 32, "Number of pixels in SVHN (3*32*32) [3072]")
36 | flags.DEFINE_integer("dim", 64, "GAN dim")
37 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp")
38 | flags.DEFINE_float("protogan_scale", 1.0, "")
39 | flags.DEFINE_float("protogan_scale_g", 0.1, "")
40 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint (number of iters)')
41 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
42 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
43 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
44 | flags.DEFINE_float("proto_weight_real", 0., "")
45 | flags.DEFINE_float("proto_weight_fake", 0., "")
46 | flags.DEFINE_integer("proto_importance", 1, "relative importance of protos versus ordinary samples")
47 | flags.DEFINE_string("dist_func", 'squared_l2', "cosine, squared_l2")
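# [Annotation] gamma below acts as an inverse temperature on the prototype
# distances (see the sketch in mnist_train_introgan.py earlier in this
# listing): class scores behave roughly like exp(-gamma * distance), so a
# smaller gamma gives smoother, less peaked class probabilities.
48 | flags.DEFINE_float("gamma", 1e-2, 'smoothness of the output 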
probabilities') 49 | flags.DEFINE_string("center_type", 'rel_center', '[rel_center/fixed_center/multi_center]') 50 | flags.DEFINE_integer("fixed_center_idx", 0, '[0, proto_num - 1]') 51 | flags.DEFINE_integer("proto_num", 20, "") 52 | flags.DEFINE_float("margin", 0., "") 53 | flags.DEFINE_string("proto_select_criterion", 'random', "random or ori_kmeans or feat_kmeans") 54 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model') 55 | flags.DEFINE_boolean("improved_finetune", True, 'if gan finetuned from the previous model') 56 | flags.DEFINE_string("improved_finetune_type", 'v2', '') 57 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation') 58 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level') 59 | flags.DEFINE_boolean("classification_only", False, "True for classification only") 60 | flags.DEFINE_integer("num_samples_per_class", -1, "number of samples per class in the training set") 61 | flags.DEFINE_boolean("exemplars_dual_use", True, "if the exemplars also be added to the training set") 62 | flags.DEFINE_string("anti_imbalance", 'none', "[reweight|oversample|none]") 63 | flags.DEFINE_string("rigorous", 'min', '[min|max|random]') 64 | flags.DEFINE_boolean("train_rel_center", False, "") 65 | 66 | # Add how many classes every time 67 | flags.DEFINE_integer('nb_cl', 2, '') 68 | 69 | # DEBUG 70 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 71 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 72 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]') 73 | flags.DEFINE_integer("test_interval", 500, 74 | "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 75 | 76 | # Visualize 77 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 78 | 79 | FLAGS = flags.FLAGS 80 | 81 | pp = pprint.PrettyPrinter() 82 | 83 | 84 | def main(_): 85 | start_time = time.time() 86 | assert FLAGS.fixed_center_idx < FLAGS.proto_num 87 | 88 | if 0 < FLAGS.num_samples_per_class < FLAGS.proto_num: 89 | FLAGS.proto_num = FLAGS.num_samples_per_class 90 | 91 | pp.pprint(flags.FLAGS.__flags) 92 | 93 | NUM_CLASSES = 10 94 | 95 | from utils import svhn 96 | raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train', 97 | order_idx=FLAGS.order_idx, 98 | num_samples_per_class=FLAGS.num_samples_per_class) 99 | raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', order_idx=FLAGS.order_idx) 100 | 101 | # Total training samples 102 | NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train) 103 | NUM_TEST_SAMPLES_TOTAL = len(raw_images_test) 104 | 105 | NUM_TRAIN_SAMPLES_PER_CLASS = NUM_TRAIN_SAMPLES_TOTAL / NUM_CLASSES 106 | 107 | run_config = tf.ConfigProto() 108 | run_config.gpu_options.allow_growth = True 109 | 110 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 111 | result_dir = os.path.join('result', method_name) 112 | FLAGS.result_dir = result_dir 113 | print('Result dir: %s' % result_dir) 114 | 115 | graph_fid = tf.Graph() 116 | with graph_fid.as_default(): 117 | inception_path = fid.check_or_download_inception('tmp/imagenet') 118 | fid.create_inception_graph(inception_path) 119 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 120 | 121 | ''' 122 | Class Incremental Learning 123 | ''' 124 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 125 | 
print('Adding %d categories every time' % FLAGS.nb_cl) 126 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 127 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 128 | 129 | to_category_idx = category_idx + FLAGS.nb_cl - 1 130 | if FLAGS.nb_cl == 1: 131 | print('Adding Category ' + str(category_idx + 1)) 132 | else: 133 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 134 | 135 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 136 | category_idx <= train_labels[idx] <= to_category_idx] 137 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 138 | test_labels[idx] <= to_category_idx] 139 | 140 | train_x = raw_images_train[train_indices_reals, :] 141 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 142 | test_x = raw_images_test[test_indices, :] 143 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 144 | 145 | ''' 146 | Check model exist 147 | ''' 148 | if ProtoGAN.check_model(FLAGS, to_category_idx): 149 | model_exist = True 150 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 151 | category_idx + 1, to_category_idx + 1)) 152 | else: 153 | model_exist = False 154 | print(" [*] Model of Class %d-%d does not exist. Start the training process" % ( 155 | category_idx + 1, to_category_idx + 1)) 156 | 157 | ''' 158 | Train generative model(DC-GAN) 159 | ''' 160 | # Mixed with generated samples of old classes 161 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 162 | if num_old_classes > 0: # first session or not 163 | if not model_exist: 164 | train_y_gens = np.repeat(range(num_old_classes), [NUM_TRAIN_SAMPLES_PER_CLASS]) 165 | train_y_gens_one_hot = np.eye(num_old_classes)[train_y_gens.reshape(-1)] 166 | llgan_obj.load(category_idx - 1) 167 | 168 | # train_x_gens, _, _ = llgan_obj.test(NUM_TRAIN_SAMPLES_PER_CLASS * num_old_classes, train_y_gens_one_hot) 169 | 170 | train_x_gens = [] 171 | for start_idx in tqdm(range(0, len(train_y_gens_one_hot), FLAGS.batch_size), desc='Generative replay'): 172 | train_y_gens_one_hot_cur_batch = train_y_gens_one_hot[start_idx: start_idx + FLAGS.batch_size] 173 | train_x_gens_batch, _, _ = llgan_obj.test(len(train_y_gens_one_hot_cur_batch), 174 | train_y_gens_one_hot_cur_batch) 175 | train_x_gens.extend(train_x_gens_batch) 176 | 177 | train_x = np.concatenate((train_x, np.uint8(train_x_gens))) 178 | train_y = np.concatenate((train_y, np.eye(to_category_idx + 1)[train_y_gens.reshape(-1)])) 179 | 180 | sess_gan.close() 181 | del sess_gan 182 | 183 | graph_gen = tf.Graph() 184 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 185 | 186 | llgan_obj = ProtoGAN(sess_gan, graph_gen, sess_fid, 187 | dataset=FLAGS.dataset, 188 | mode=FLAGS.mode, 189 | batch_size=FLAGS.batch_size, 190 | output_dim=FLAGS.output_dim, 191 | lambda_param=FLAGS.lambda_param, 192 | critic_iters=FLAGS.critic_iters, 193 | iters=FLAGS.iters, 194 | result_dir=FLAGS.result_dir, 195 | checkpoint_interval=FLAGS.gan_save_interval, 196 | adam_lr=FLAGS.adam_lr, 197 | adam_beta1=FLAGS.adam_beta1, 198 | adam_beta2=FLAGS.adam_beta2, 199 | protogan_scale=FLAGS.protogan_scale, 200 | protogan_scale_g=FLAGS.protogan_scale_g, 201 | finetune=FLAGS.finetune, 202 | improved_finetune=FLAGS.improved_finetune, 203 | nb_cl=FLAGS.nb_cl, 204 | nb_output=(to_category_idx + 1), 205 | classification_only=FLAGS.classification_only, 206 | proto_weight_real=FLAGS.proto_weight_real, 207 | proto_weight_fake=FLAGS.proto_weight_fake, 208 | 
dist_func=FLAGS.dist_func,
209 |                              gamma=FLAGS.gamma,
210 |                              proto_num=FLAGS.proto_num,
211 |                              margin=FLAGS.margin,
212 |                              order_idx=FLAGS.order_idx,
213 |                              proto_select_criterion=FLAGS.proto_select_criterion,
214 |                              order=order,
215 |                              test_interval=FLAGS.test_interval,
216 |                              proto_importance=FLAGS.proto_importance,
217 |                              improved_finetune_type=FLAGS.improved_finetune_type,
218 |                              improved_finetune_noise=FLAGS.improved_finetune_noise,
219 |                              improved_finetune_noise_level=FLAGS.improved_finetune_noise_level,
220 |                              dim=FLAGS.dim,
221 |                              rgb=FLAGS.rgb,
222 |                              center_type=FLAGS.center_type,
223 |                              fixed_center_idx=FLAGS.fixed_center_idx,
224 |                              num_samples_per_class=FLAGS.num_samples_per_class,
225 |                              exemplars_dual_use=FLAGS.exemplars_dual_use,
226 |                              anti_imbalance=FLAGS.anti_imbalance,
227 |                              train_rel_center=FLAGS.train_rel_center,
228 |                              rigorous=FLAGS.rigorous)
229 | 
230 |         print(ProtoGAN.model_dir_static(FLAGS))
231 | 
232 |         '''
233 |         Train generative model(GAN)
234 |         '''
235 |         llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist)
236 | 
237 |         if FLAGS.vis_result:
238 |             vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
239 |                             num_iters=FLAGS.iters, vis_fid=(not FLAGS.classification_only))
240 | 
241 |     sess_fid.close()
242 |     stop_time = time.time()
243 |     print('TOTAL RUNNING TIME: %f' % (stop_time - start_time))
244 |     with open(os.path.join(ProtoGAN.model_dir_static(FLAGS), 'running_time.txt'), 'w') as fout:
245 |         fout.write('%f' % (stop_time - start_time))
246 | 
247 | 
248 | if __name__ == '__main__':
249 |     tf.app.run()
250 | 
-------------------------------------------------------------------------------- /svhn_train_lowerbound.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | import tensorflow as tf
4 | 
5 | tf.set_random_seed(1993)
6 | 
7 | import numpy as np
8 | 
9 | np.random.seed(1993)
10 | import os
11 | import pprint
12 | 
13 | from gan.model_svhn_cnn import ClsNet
14 | 
15 | from utils.visualize_result_single import vis_acc_and_fid
16 | 
17 | flags = tf.app.flags
18 | 
19 | flags.DEFINE_string("dataset", "svhn", "The name of dataset [svhn]")
20 | 
21 | # Hyperparameters
22 | flags.DEFINE_integer("batch_size", 100, "The size of batch images")
23 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train")
24 | flags.DEFINE_integer("output_dim", 3 * 32 * 32, "Number of pixels in SVHN (3*32*32) [3072]")
25 | flags.DEFINE_integer("dim", 64, "GAN dim")
26 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4')
27 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.5')
28 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.999')
29 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model')
30 | flags.DEFINE_boolean("improved_finetune", True, 'whether to use the improved finetuning variant')
31 | flags.DEFINE_string("improved_finetune_type", 'v2', '')
32 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation')
33 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level')
34 | 
35 | # Add how many classes every time
36 | flags.DEFINE_integer('nb_cl', 2, '')
37 | 
38 | # DEBUG
39 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx')
40 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx')
41 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]')
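# [Annotation] With iters = 10000 and test_interval = 500, evaluation runs at
# 20 evenly spaced checkpoints, e.g.:
#
#     eval_iters = range(test_interval, iters + 1, test_interval)
#     # -> [500, 1000, ..., 10000]; len(eval_iters) == 20
#
42 | flags.DEFINE_integer("test_interval", 500,
43 |                      "test interval 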
(since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 44 | 45 | # Visualize 46 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 47 | 48 | FLAGS = flags.FLAGS 49 | 50 | pp = pprint.PrettyPrinter() 51 | 52 | 53 | def main(_): 54 | pp.pprint(flags.FLAGS.__flags) 55 | 56 | from utils import svhn 57 | raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train', 58 | order_idx=FLAGS.order_idx) 59 | raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', order_idx=FLAGS.order_idx) 60 | 61 | # Total training samples 62 | NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train) 63 | NUM_TEST_SAMPLES_TOTAL = len(raw_images_test) 64 | 65 | run_config = tf.ConfigProto() 66 | run_config.gpu_options.allow_growth = True 67 | 68 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 69 | result_dir = os.path.join('result', method_name) 70 | print('Result dir: %s' % result_dir) 71 | 72 | ''' 73 | Class Incremental Learning 74 | ''' 75 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 76 | print('Adding %d categories every time' % FLAGS.nb_cl) 77 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 78 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 79 | 80 | to_category_idx = category_idx + FLAGS.nb_cl - 1 81 | if FLAGS.nb_cl == 1: 82 | print('Adding Category ' + str(category_idx + 1)) 83 | else: 84 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 85 | 86 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 87 | category_idx <= train_labels[idx] <= to_category_idx] 88 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 89 | test_labels[idx] <= to_category_idx] 90 | 91 | train_x = raw_images_train[train_indices_reals, :] 92 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx + 1)] 93 | test_x = raw_images_test[test_indices, :] 94 | test_y = test_one_hot_labels[test_indices, :(to_category_idx + 1)] 95 | 96 | graph = tf.Graph() 97 | sess = tf.Session(config=run_config, graph=graph) 98 | llgan_obj = ClsNet(sess, graph, 99 | dataset_name=FLAGS.dataset, 100 | batch_size=FLAGS.batch_size, 101 | output_dim=FLAGS.output_dim, 102 | iters=FLAGS.iters, 103 | result_dir=result_dir, 104 | adam_lr=FLAGS.adam_lr, 105 | adam_beta1=FLAGS.adam_beta1, 106 | adam_beta2=FLAGS.adam_beta2, 107 | nb_cl=FLAGS.nb_cl, 108 | nb_output=(to_category_idx + 1), 109 | order_idx=FLAGS.order_idx, 110 | order=order, 111 | test_interval=FLAGS.test_interval, 112 | finetune=FLAGS.finetune, 113 | improved_finetune=FLAGS.improved_finetune, 114 | improved_finetune_noise=FLAGS.improved_finetune_noise, 115 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level, 116 | improved_finetune_type=FLAGS.improved_finetune_type, 117 | dim=FLAGS.dim) 118 | 119 | print(llgan_obj.model_dir) 120 | 121 | ''' 122 | Train generative model(GAN) 123 | ''' 124 | if llgan_obj.check_model(to_category_idx): 125 | model_exist = True 126 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 127 | category_idx + 1, to_category_idx + 1)) 128 | else: 129 | model_exist = False 130 | print(" [*] Model of Class %d-%d does not exist. 
Start the training process" % ( 131 | category_idx + 1, to_category_idx + 1)) 132 | llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist) 133 | 134 | if FLAGS.vis_result: 135 | vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval, 136 | num_iters=FLAGS.iters, vis_fid=False) 137 | 138 | 139 | if __name__ == '__main__': 140 | tf.app.run() 141 | -------------------------------------------------------------------------------- /svhn_train_mergan.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | tf.set_random_seed(1993) 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | np.random.seed(1993) 11 | import os 12 | import pprint 13 | 14 | from gan.model_svhn_mergan import MeRGAN 15 | 16 | from utils import fid 17 | from utils.visualize_result_single import vis_acc_and_fid 18 | 19 | flags = tf.app.flags 20 | 21 | flags.DEFINE_string("dataset", "svhn", "The name of dataset [svhn]") 22 | 23 | # Hyperparameters 24 | flags.DEFINE_integer("lambda_param", 10, "Gradient penalty lambda hyperparameter [10]") 25 | flags.DEFINE_integer("critic_iters", 1, 26 | "How many critic iterations per generator iteration [1 for DCGAN and 5 for WGAN-GP]") 27 | flags.DEFINE_integer("class_iters", 1, "How many class iterations per generator iteration") 28 | flags.DEFINE_integer("batch_size", 100, "The size of batch images") 29 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train") 30 | flags.DEFINE_integer("output_dim", 3 * 32 * 32, "Number of pixels in MNIST/fashion-MNIST (1*28*28) [768]") 31 | flags.DEFINE_integer("dim", 64, "GAN dim") 32 | flags.DEFINE_string("mode", 'dcgan', "Valid options are dcgan or wgan-gp") 33 | flags.DEFINE_float("acgan_scale", 1.0, "") 34 | flags.DEFINE_float("acgan_scale_g", 0.1, "") 35 | flags.DEFINE_integer("gan_save_interval", 5000, 'interval to save a checkpoint(number of iters)') 36 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4') 37 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.0') 38 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.9') 39 | flags.DEFINE_boolean("finetune", True, 'if gan finetuned from the previous model') 40 | flags.DEFINE_boolean("improved_finetune", True, 'if gan finetuned from the previous model') 41 | flags.DEFINE_string("improved_finetune_type", 'v2', '') 42 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation') 43 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level') 44 | flags.DEFINE_boolean("classification_only", False, "True for classification only") 45 | flags.DEFINE_boolean("pre_aloc_class", False, '') 46 | flags.DEFINE_boolean("num_classes", 10, 'number of classes in SVHN') 47 | flags.DEFINE_integer("num_samples_per_class", -1, "number of samples per class in the training set") 48 | 49 | # ablation study 50 | flags.DEFINE_boolean("use_protos", False, "use prototypes of IntroGAN for ablation studies") 51 | flags.DEFINE_string("protos_path", 52 | "protogan_v1_3_8/%s_order_%d/nb_cl_%d/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_%d_weight_0.000000_0.000000_squared_l2_0.010000/gan_dim_64/finetune_improved_v2_noise_0.5_exemplars_dual_use_%d", 53 | "the path of prototypes of IntroGAN") 54 | flags.DEFINE_integer("protos_num", 20, "number of prototypes") 55 | flags.DEFINE_integer("protos_importance", 1, "relative importance of protos 
versus ordinary samples") 56 | 57 | # Add how many classes every time 58 | flags.DEFINE_integer('nb_cl', 2, '') 59 | 60 | # DEBUG 61 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 62 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 63 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]') 64 | flags.DEFINE_integer("test_interval", 500, 65 | "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 66 | 67 | # Visualize 68 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 69 | 70 | FLAGS = flags.FLAGS 71 | 72 | pp = pprint.PrettyPrinter() 73 | 74 | 75 | def main(_): 76 | pp.pprint(flags.FLAGS.__flags) 77 | 78 | from utils import svhn 79 | raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train', 80 | order_idx=FLAGS.order_idx, 81 | num_samples_per_class=FLAGS.num_samples_per_class) 82 | raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', order_idx=FLAGS.order_idx) 83 | 84 | # Total training samples 85 | NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train) 86 | NUM_TEST_SAMPLES_TOTAL = len(raw_images_test) 87 | 88 | NUM_TRAIN_SAMPLES_PER_CLASS = NUM_TRAIN_SAMPLES_TOTAL / FLAGS.num_classes 89 | 90 | run_config = tf.ConfigProto() 91 | run_config.gpu_options.allow_growth = True 92 | 93 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 94 | result_dir = os.path.join('result', method_name) 95 | FLAGS.result_dir = result_dir 96 | print('Result dir: %s' % result_dir) 97 | 98 | graph_fid = tf.Graph() 99 | with graph_fid.as_default(): 100 | inception_path = fid.check_or_download_inception('tmp/imagenet') 101 | fid.create_inception_graph(inception_path) 102 | sess_fid = tf.Session(config=run_config, graph=graph_fid) 103 | 104 | ''' 105 | Class Incremental Learning 106 | ''' 107 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 108 | print('Adding %d categories every time' % FLAGS.nb_cl) 109 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 110 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 111 | 112 | to_category_idx = category_idx + FLAGS.nb_cl - 1 113 | if FLAGS.nb_cl == 1: 114 | print('Adding Category ' + str(category_idx + 1)) 115 | else: 116 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 117 | 118 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 119 | category_idx <= train_labels[idx] <= to_category_idx] 120 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 121 | test_labels[idx] <= to_category_idx] 122 | 123 | train_x = raw_images_train[train_indices_reals, :] 124 | train_y = train_one_hot_labels[train_indices_reals, 125 | :(FLAGS.num_classes if FLAGS.pre_aloc_class else to_category_idx + 1)] 126 | test_x = raw_images_test[test_indices, :] 127 | test_y = test_one_hot_labels[test_indices, 128 | :(FLAGS.num_classes if FLAGS.pre_aloc_class else to_category_idx + 1)] 129 | 130 | ''' 131 | Check model exist 132 | ''' 133 | if MeRGAN.check_model(FLAGS, to_category_idx): 134 | model_exist = True 135 | print(" [*] Model of Class %d-%d exists. Skip the training process" % ( 136 | category_idx + 1, to_category_idx + 1)) 137 | else: 138 | model_exist = False 139 | print(" [*] Model of Class %d-%d does not exist. 
Start the training process" % ( 140 | category_idx + 1, to_category_idx + 1)) 141 | 142 | ''' 143 | Train generative model(DC-GAN) 144 | ''' 145 | # Mixed with generated samples of old classes 146 | num_old_classes = to_category_idx + 1 - FLAGS.nb_cl 147 | if num_old_classes > 0: # first session or not 148 | if not model_exist: 149 | train_y_gens = np.repeat(range(num_old_classes), [NUM_TRAIN_SAMPLES_PER_CLASS]) 150 | train_y_gens_one_hot = np.eye(FLAGS.num_classes if FLAGS.pre_aloc_class else num_old_classes)[ 151 | train_y_gens.reshape(-1)] 152 | llgan_obj.load(category_idx - 1) 153 | 154 | # train_x_gens, _, _ = llgan_obj.test(NUM_TRAIN_SAMPLES_PER_CLASS * num_old_classes, train_y_gens_one_hot) 155 | 156 | train_x_gens = [] 157 | for start_idx in tqdm(range(0, len(train_y_gens_one_hot), FLAGS.batch_size), desc='Generative replay'): 158 | train_y_gens_one_hot_cur_batch = train_y_gens_one_hot[start_idx: start_idx + FLAGS.batch_size] 159 | train_x_gens_batch, _, _ = llgan_obj.test(len(train_y_gens_one_hot_cur_batch), 160 | train_y_gens_one_hot_cur_batch) 161 | train_x_gens.extend(train_x_gens_batch) 162 | 163 | train_x = np.concatenate((train_x, np.uint8(train_x_gens))) 164 | train_y = np.concatenate((train_y, 165 | np.eye(FLAGS.num_classes if FLAGS.pre_aloc_class else to_category_idx + 1)[ 166 | train_y_gens.reshape(-1)])) 167 | 168 | sess_gan.close() 169 | del sess_gan 170 | 171 | graph_gen = tf.Graph() 172 | sess_gan = tf.Session(config=run_config, graph=graph_gen) 173 | 174 | llgan_obj = MeRGAN(sess_gan, graph_gen, sess_fid, 175 | dataset=FLAGS.dataset, 176 | mode=FLAGS.mode, 177 | batch_size=FLAGS.batch_size, 178 | output_dim=FLAGS.output_dim, 179 | lambda_param=FLAGS.lambda_param, 180 | critic_iters=FLAGS.critic_iters, 181 | class_iters=FLAGS.class_iters, 182 | iters=FLAGS.iters, 183 | result_dir=FLAGS.result_dir, 184 | checkpoint_interval=FLAGS.gan_save_interval, 185 | adam_lr=FLAGS.adam_lr, 186 | adam_beta1=FLAGS.adam_beta1, 187 | adam_beta2=FLAGS.adam_beta2, 188 | acgan_scale=FLAGS.acgan_scale, 189 | acgan_scale_g=FLAGS.acgan_scale_g, 190 | finetune=FLAGS.finetune, 191 | improved_finetune=FLAGS.improved_finetune, 192 | nb_cl=FLAGS.nb_cl, 193 | num_seen_classes=(to_category_idx + 1), 194 | classification_only=FLAGS.classification_only, 195 | dim=FLAGS.dim, 196 | order_idx=FLAGS.order_idx, 197 | order=order, 198 | test_interval=FLAGS.test_interval, 199 | num_classes=FLAGS.num_classes, 200 | pre_aloc_class=FLAGS.pre_aloc_class, 201 | improved_finetune_noise=FLAGS.improved_finetune_noise, 202 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level, 203 | num_samples_per_class=FLAGS.num_samples_per_class, 204 | use_protos=FLAGS.use_protos, 205 | protos_path=FLAGS.protos_path, 206 | protos_num=FLAGS.protos_num, 207 | protos_importance=FLAGS.protos_importance, 208 | improved_finetune_type=FLAGS.improved_finetune_type, 209 | ) 210 | 211 | print(MeRGAN.model_dir_static(FLAGS)) 212 | 213 | ''' 214 | Train generative model(GAN) 215 | ''' 216 | llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist) 217 | 218 | if FLAGS.vis_result: 219 | vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval, 220 | num_iters=FLAGS.iters, vis_fid=(not FLAGS.classification_only)) 221 | 222 | sess_fid.close() 223 | 224 | 225 | if __name__ == '__main__': 226 | tf.app.run() 227 | -------------------------------------------------------------------------------- /svhn_train_upperbound.py: 
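Note: unlike svhn_train_lowerbound.py, which trains each incremental session only on the newly added classes, this upper-bound baseline retrains on real data from all classes seen so far: its train_indices_reals filter below keeps every sample with train_labels[idx] <= to_category_idx.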
-------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | import tensorflow as tf 4 | tf.set_random_seed(1993) 5 | 6 | import numpy as np 7 | np.random.seed(1993) 8 | import os 9 | import pprint 10 | 11 | from gan.model_svhn_cnn import ClsNet 12 | 13 | from utils.visualize_result_single import vis_acc_and_fid 14 | 15 | flags = tf.app.flags 16 | 17 | flags.DEFINE_string("dataset", "svhn", "The name of dataset [svhn]") 18 | 19 | # Hyperparameters 20 | flags.DEFINE_integer("batch_size", 100, "The size of batch images") 21 | flags.DEFINE_integer("iters", 10000, "How many generator iters to train") 22 | flags.DEFINE_integer("output_dim", 3*32*32, "Number of pixels in MNIST/fashion-MNIST (3*32*32) [3072]") 23 | flags.DEFINE_integer("dim", 64, "GAN dim") 24 | flags.DEFINE_float("adam_lr", 2e-4, 'default: 1e-4, 2e-4, 3e-4') 25 | flags.DEFINE_float("adam_beta1", 0.5, 'default: 0.0') 26 | flags.DEFINE_float("adam_beta2", 0.999, 'default: 0.9') 27 | flags.DEFINE_boolean("finetune", False, '') 28 | flags.DEFINE_boolean("improved_finetune", True, '') 29 | flags.DEFINE_boolean("improved_finetune_noise", True, 'use the same weight or add some variation') 30 | flags.DEFINE_float("improved_finetune_noise_level", 0.5, 'noise level') 31 | 32 | # Add how many classes every time 33 | flags.DEFINE_integer('nb_cl', 2, '') 34 | 35 | # DEBUG 36 | flags.DEFINE_integer('from_class_idx', 0, 'starting category_idx') 37 | flags.DEFINE_integer('to_class_idx', 9, 'ending category_idx') 38 | flags.DEFINE_integer('order_idx', 1, 'class orders [1~5]') 39 | flags.DEFINE_integer("test_interval", 500, "test interval (since the total training consists of 10,000 iters, we need 20 points, so 10,000 / 20 = 500)") 40 | 41 | # Visualize 42 | flags.DEFINE_boolean('vis_result', True, 'visualize accuracy and fid figure') 43 | 44 | FLAGS = flags.FLAGS 45 | 46 | pp = pprint.PrettyPrinter() 47 | 48 | 49 | def main(_): 50 | 51 | pp.pprint(flags.FLAGS.__flags) 52 | 53 | from utils import svhn 54 | raw_images_train, train_labels, train_one_hot_labels, order = svhn.load_data(kind='train', order_idx=FLAGS.order_idx) 55 | raw_images_test, test_labels, test_one_hot_labels, order = svhn.load_data(kind='test', order_idx=FLAGS.order_idx) 56 | 57 | # Total training samples 58 | NUM_TRAIN_SAMPLES_TOTAL = len(raw_images_train) 59 | NUM_TEST_SAMPLES_TOTAL = len(raw_images_test) 60 | 61 | run_config = tf.ConfigProto() 62 | run_config.gpu_options.allow_growth = True 63 | 64 | method_name = '_'.join(os.path.basename(__file__).split('.')[0].split('_')[2:]) 65 | result_dir = os.path.join('result', method_name) 66 | print('Result dir: %s' % result_dir) 67 | 68 | ''' 69 | Class Incremental Learning 70 | ''' 71 | print('Starting from category ' + str(FLAGS.from_class_idx + 1) + ' to ' + str(FLAGS.to_class_idx + 1)) 72 | print('Adding %d categories every time' % FLAGS.nb_cl) 73 | assert (FLAGS.from_class_idx % FLAGS.nb_cl == 0) 74 | for category_idx in range(FLAGS.from_class_idx, FLAGS.to_class_idx + 1, FLAGS.nb_cl): 75 | 76 | to_category_idx = category_idx + FLAGS.nb_cl - 1 77 | if FLAGS.nb_cl == 1: 78 | print('Adding Category ' + str(category_idx + 1)) 79 | else: 80 | print('Adding Category %d-%d' % (category_idx + 1, to_category_idx + 1)) 81 | 82 | train_indices_reals = [idx for idx in range(NUM_TRAIN_SAMPLES_TOTAL) if 83 | train_labels[idx] <= to_category_idx] 84 | test_indices = [idx for idx in range(NUM_TEST_SAMPLES_TOTAL) if 85 | test_labels[idx] <= to_category_idx] 86 | 87 | train_x = 
raw_images_train[train_indices_reals, :]
88 | train_y = train_one_hot_labels[train_indices_reals, :(to_category_idx+1)]
89 | test_x = raw_images_test[test_indices, :]
90 | test_y = test_one_hot_labels[test_indices, :(to_category_idx+1)]
91 |
92 | graph = tf.Graph()
93 | sess = tf.Session(config=run_config, graph=graph)
94 | llgan_obj = ClsNet(sess, graph,
95 | dataset_name=FLAGS.dataset,
96 | batch_size=FLAGS.batch_size,
97 | output_dim=FLAGS.output_dim,
98 | iters=FLAGS.iters,
99 | result_dir=result_dir,
100 | adam_lr=FLAGS.adam_lr,
101 | adam_beta1=FLAGS.adam_beta1,
102 | adam_beta2=FLAGS.adam_beta2,
103 | nb_cl=FLAGS.nb_cl,
104 | nb_output=(to_category_idx + 1),
105 | order_idx=FLAGS.order_idx,
106 | order=order,
107 | test_interval=FLAGS.test_interval,
108 | finetune=FLAGS.finetune,
109 | improved_finetune=FLAGS.improved_finetune,
110 | improved_finetune_noise=FLAGS.improved_finetune_noise,
111 | improved_finetune_noise_level=FLAGS.improved_finetune_noise_level,
112 | dim=FLAGS.dim)
113 |
114 | print(llgan_obj.model_dir)
115 |
116 | '''
117 | Train classification model (CNN)
118 | '''
119 | if llgan_obj.check_model(to_category_idx):
120 | model_exist = True
121 | print(" [*] Model of Class %d-%d exists. Skip the training process" % (category_idx + 1, to_category_idx + 1))
122 | else:
123 | model_exist = False
124 | print(" [*] Model of Class %d-%d does not exist. Start the training process" % (category_idx + 1, to_category_idx + 1))
125 | llgan_obj.train(train_x, train_y, test_x, test_y, to_category_idx, model_exist=model_exist)
126 |
127 | if FLAGS.vis_result:
128 | vis_acc_and_fid(llgan_obj.model_dir, FLAGS.dataset, FLAGS.nb_cl, test_interval=FLAGS.test_interval,
129 | num_iters=FLAGS.iters, vis_fid=False)
130 |
131 |
132 | if __name__ == '__main__':
133 | tf.app.run()
134 |
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TonyPod/IntroGAN/80b1f4fdc3f3b034c7ae2005dc328286e81cddd8/utils/__init__.py
-------------------------------------------------------------------------------- /utils/fashion_mnist.py: --------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | """
4 | @time: 10/27/18 1:37 PM
5 | @author: Chen He
6 | @site:
7 | @file: fashion_mnist.py
8 | @description:
9 | """
10 |
11 | import gzip
12 | import numpy as np
13 | import os
14 | import pickle
15 |
16 | import tensorflow as tf
17 | import tensorflow.contrib.slim as slim
18 |
19 | from utils.vgg_preprocessing import preprocess_image
20 | from utils.resnet_v1_mod import resnet_arg_scope
21 | from utils.resnet_v1_mod import resnet_v1_50
22 |
23 | NUM_CLASSES = 10
24 |
25 |
26 | def load_feat(path='datasets/fashion-mnist', kind='train', order_idx=1):
27 | images, labels, one_hot_labels, _ = load_data(path, kind, order_idx)  # load_data also returns the class order, unused here
28 |
29 | pickle_file = os.path.join(path, '%s_feat_order_%d.pkl' % (kind, order_idx))
30 | if os.path.exists(pickle_file):
31 | with open(pickle_file, 'rb') as fin:
32 | feats = pickle.load(fin)
33 | else:
34 | with tf.Session() as sess:
35 | real_data = tf.placeholder(tf.float32, shape=[None, 784])
36 | inputs = tf.map_fn(
37 | lambda img: preprocess_image(tf.tile(tf.reshape(img, (28, 28, 1)), [1, 1, 3]), 224, 224,
38 | is_training=False), real_data)
39 | with slim.arg_scope(resnet_arg_scope()):
40 | _, end_points = resnet_v1_50(inputs, 1000, is_training=False)
41 | feat_tensor =
slim.flatten(end_points['pool5']) 42 | 43 | sess.run(tf.global_variables_initializer()) 44 | var_list = tf.trainable_variables() 45 | saver = tf.train.Saver(var_list=var_list) 46 | saver.restore(sess, 'utils/resnet_v1_50.ckpt') 47 | 48 | feats = [] 49 | for i in range(0, len(images), 100): 50 | feat_batch = sess.run(feat_tensor, feed_dict={real_data: images[i:i+100]}) 51 | feats.extend(feat_batch) 52 | feats = np.array(feats) 53 | 54 | with open(pickle_file, 'wb') as fout: 55 | pickle.dump(feats, fout) 56 | 57 | return feats, labels, one_hot_labels 58 | 59 | 60 | def load_data(path='datasets/fashion-mnist', kind='train', order_idx=1): 61 | 62 | """Load MNIST data from `path`""" 63 | labels_path = os.path.join(path, 64 | '%s-labels-idx1-ubyte.gz' 65 | % kind) 66 | images_path = os.path.join(path, 67 | '%s-images-idx3-ubyte.gz' 68 | % kind) 69 | 70 | with gzip.open(labels_path, 'rb') as lbpath: 71 | labels = np.frombuffer(lbpath.read(), dtype=np.uint8, 72 | offset=8) 73 | 74 | with gzip.open(images_path, 'rb') as imgpath: 75 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 76 | offset=16).reshape(len(labels), 784) 77 | 78 | order = [] 79 | with open(os.path.join(path, 'order_%d.txt' % order_idx)) as file_in: 80 | for line in file_in.readlines(): 81 | order.append(int(line)) 82 | order = np.array(order) 83 | 84 | labels = change_order(labels, order=order) 85 | 86 | one_hot_labels = np.eye(NUM_CLASSES, dtype=float)[labels] 87 | 88 | return images, labels, one_hot_labels, order 89 | 90 | 91 | def change_order(cls, order): 92 | order_dict = dict() 93 | for i in range(len(order)): 94 | order_dict[order[i]] = i 95 | 96 | reordered_cls = np.array([order_dict[cls[i]] for i in range(len(cls))]) 97 | return reordered_cls -------------------------------------------------------------------------------- /utils/fid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' Calculates the Frechet Inception Distance (FID) to evaluate GANs. 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (in pickle format). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real world 14 | samples respectively. 15 | 16 | See --help to see further details. 17 | ''' 18 | 19 | from __future__ import absolute_import, division, print_function 20 | import numpy as np 21 | import os 22 | import gzip, pickle 23 | import tensorflow as tf 24 | from scipy.misc import imread 25 | from scipy import linalg 26 | import pathlib 27 | import urllib 28 | import warnings 29 | 30 | 31 | class InvalidFIDException(Exception): 32 | pass 33 | 34 | 35 | def create_inception_graph(pth): 36 | """Creates a graph from saved GraphDef file.""" 37 | # Creates graph from saved graph_def.pb. 
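# Note: the GraphDef is imported under the name scope 'FID_Inception_Net', so the
# tensors used elsewhere in this module are addressed as
# 'FID_Inception_Net/pool_3:0' (activations) and 'FID_Inception_Net/ExpandDims:0' (input).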
38 | with tf.gfile.FastGFile(pth, 'rb') as f:
39 | graph_def = tf.GraphDef()
40 | graph_def.ParseFromString(f.read())
41 | _ = tf.import_graph_def(graph_def, name='FID_Inception_Net')
42 |
43 |
44 | # -------------------------------------------------------------------------------
45 |
46 |
47 | # code for handling inception net derived from
48 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
49 | def _get_inception_layer(sess):
50 | """Prepares inception net for batched usage and returns pool_3 layer. """
51 | layername = 'FID_Inception_Net/pool_3:0'
52 | pool3 = sess.graph.get_tensor_by_name(layername)
53 | ops = pool3.graph.get_operations()
54 | for op_idx, op in enumerate(ops):
55 | for o in op.outputs:
56 | shape = o.get_shape()
57 | if shape._dims is not None:
58 | shape = [s.value for s in shape]
59 | new_shape = []
60 | for j, s in enumerate(shape):
61 | if s == 1 and j == 0:
62 | new_shape.append(None)
63 | else:
64 | new_shape.append(s)
65 | o._shape = tf.TensorShape(new_shape)
66 | return pool3
67 |
68 |
69 | # -------------------------------------------------------------------------------
70 |
71 |
72 | def _get_input_layer(sess):
73 | """Prepares inception net for batched usage and returns the input (ExpandDims) layer. """
74 | layername = 'FID_Inception_Net/ExpandDims:0'
75 | input_layer = sess.graph.get_tensor_by_name(layername)
76 | return input_layer
77 |
78 |
79 | def get_activations(images, sess, batch_size=50, verbose=False):
80 | """Calculates the activations of the pool_3 layer for all images.
81 |
82 | Params:
83 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
84 | must lie between 0 and 255.
85 | -- sess : current session
86 | -- batch_size : the images numpy array is split into batches with batch size
87 | batch_size. A reasonable batch size depends on the available hardware.
88 | -- verbose : If set to True, the number of calculated
89 | batches is reported.
90 | Returns:
91 | -- A numpy array of dimension (num images, 2048) that contains the
92 | activations of the given tensor when feeding inception with the query tensor.
93 | """
94 | inception_layer = _get_inception_layer(sess)
95 | d0 = images.shape[0]
96 | if batch_size > d0:
97 | print("warning: batch size is bigger than the data size. setting batch size to data size")
98 | batch_size = d0
99 | n_batches = d0 // batch_size
100 | n_used_imgs = n_batches * batch_size
101 | pred_arr = np.empty((n_used_imgs, 2048))
102 | for i in range(n_batches):
103 | if verbose:
104 | print("\rPropagating batch %d/%d" % (i + 1, n_batches), end="", flush=True)
105 | start = i * batch_size
106 | end = start + batch_size
107 | batch = images[start:end]
108 | pred = sess.run(inception_layer, {'FID_Inception_Net/ExpandDims:0': batch})
109 | pred_arr[start:end] = pred.reshape(batch_size, -1)
110 | if verbose:
111 | print(" done")
112 | return pred_arr
113 |
114 |
115 | # -------------------------------------------------------------------------------
116 |
117 |
118 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
119 | """Numpy implementation of the Frechet Distance.
120 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
121 | and X_2 ~ N(mu_2, C_2) is
122 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
123 |
124 | Stable version by Dougal J. Sutherland.
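The matrix square root sqrt(C_1*C_2) is computed with scipy.linalg.sqrtm on the
product sigma1.dot(sigma2); if that product is close to singular, a small offset
eps is added to the diagonal of both covariance estimates and the square root is
recomputed (see the code below).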
125 |
126 | Params:
127 | -- mu1 : Numpy array containing the activations of the pool_3 layer of the
128 | inception net (like those returned by the function 'get_activations')
129 | for generated samples.
130 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated
131 | on a representative data set.
132 | -- sigma1: The covariance matrix over activations of the pool_3 layer for
133 | generated samples.
134 | -- sigma2: The covariance matrix over activations of the pool_3 layer,
135 | precalculated on a representative data set.
136 |
137 | Returns:
138 | -- : The Frechet Distance.
139 | """
140 |
141 | mu1 = np.atleast_1d(mu1)
142 | mu2 = np.atleast_1d(mu2)
143 |
144 | sigma1 = np.atleast_2d(sigma1)
145 | sigma2 = np.atleast_2d(sigma2)
146 |
147 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
148 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
149 |
150 | diff = mu1 - mu2
151 |
152 | # product might be almost singular
153 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
154 | if not np.isfinite(covmean).all():
155 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
156 | warnings.warn(msg)
157 | offset = np.eye(sigma1.shape[0]) * eps
158 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
159 |
160 | # numerical error might give slight imaginary component
161 | if np.iscomplexobj(covmean):
162 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
163 | m = np.max(np.abs(covmean.imag))
164 | raise ValueError("Imaginary component {}".format(m))
165 | covmean = covmean.real
166 |
167 | tr_covmean = np.trace(covmean)
168 |
169 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
170 |
171 |
172 | # -------------------------------------------------------------------------------
173 |
174 |
175 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
176 | """Calculation of the statistics used by the FID.
177 | Params:
178 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
179 | must lie between 0 and 255.
180 | -- sess : current session
181 | -- batch_size : the images numpy array is split into batches with batch size
182 | batch_size. A reasonable batch size depends on the available hardware.
183 | -- verbose : If set to True, the number of calculated
184 | batches is reported.
185 | Returns:
186 | -- mu : The mean over samples of the activations of the pool_3 layer of
187 | the inception model.
188 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
189 | the inception model.
190 | """
191 | act = get_activations(images, sess, batch_size, verbose)
192 | mu = np.mean(act, axis=0)
193 | sigma = np.cov(act, rowvar=False)
194 | return mu, sigma
195 |
196 |
197 | # -------------------------------------------------------------------------------
198 |
199 |
200 | # -------------------------------------------------------------------------------
201 | # The following functions aren't needed for calculating the FID
202 | # they're just here to make this module work as a stand-alone script
203 | # for calculating FID scores
204 | # -------------------------------------------------------------------------------
205 | def check_or_download_inception(inception_path):
206 | ''' Checks if the path to the inception file is valid, or downloads
207 | the file if it is not present.
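Returns the path to the GraphDef file 'classify_image_graph_def.pb' inside
inception_path.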
''' 208 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 209 | if inception_path is None: 210 | inception_path = './inception_model' 211 | inception_path = pathlib.Path(inception_path) 212 | model_file = inception_path / 'classify_image_graph_def.pb' 213 | if not model_file.exists(): 214 | print("Downloading Inception model") 215 | from urllib import request 216 | import tarfile 217 | fn, _ = request.urlretrieve(INCEPTION_URL) 218 | with tarfile.open(fn, mode='r') as f: 219 | f.extract('classify_image_graph_def.pb', str(model_file.parent)) 220 | return str(model_file) 221 | 222 | 223 | def _handle_path(path, sess): 224 | if path.endswith('.npz'): 225 | f = np.load(path) 226 | m, s = f['mu'][:], f['sigma'][:] 227 | f.close() 228 | else: 229 | path = pathlib.Path(path) 230 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 231 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 232 | m, s = calculate_activation_statistics(x, sess) 233 | return m, s 234 | 235 | 236 | def calculate_fid_given_paths(paths, inception_path='utils/inception_model'): 237 | ''' Calculates the FID of two paths. ''' 238 | inception_path = check_or_download_inception(inception_path) 239 | 240 | for p in paths: 241 | if not os.path.exists(p): 242 | raise RuntimeError("Invalid path: %s" % p) 243 | 244 | create_inception_graph(str(inception_path)) 245 | with tf.Session() as sess: 246 | sess.run(tf.global_variables_initializer()) 247 | m1, s1 = _handle_path(paths[0], sess) 248 | m2, s2 = _handle_path(paths[1], sess) 249 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 250 | return fid_value 251 | 252 | 253 | def calculate_fid_given_paths_with_sess(sess, paths): 254 | ''' Calculates the FID of two paths. 
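Unlike calculate_fid_given_paths, this variant reuses an existing session in which
the Inception graph has already been created, so the graph is not rebuilt on every
call. Each path may be either a directory of .jpg/.png images or an .npz file with
precomputed 'mu' and 'sigma' statistics (see _handle_path). A hypothetical call
(both paths are placeholders):

fid_value = calculate_fid_given_paths_with_sess(sess_fid, ['result/some_run/gen_images', 'precalc_fids/svhn/fid_stats.npz'])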
''' 255 | for p in paths: 256 | if not os.path.exists(p): 257 | raise RuntimeError("Invalid path: %s" % p) 258 | 259 | m1, s1 = _handle_path(paths[0], sess) 260 | m2, s2 = _handle_path(paths[1], sess) 261 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 262 | return fid_value 263 | 264 | 265 | if __name__ == "__main__": 266 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 267 | 268 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 269 | parser.add_argument("path", type=str, nargs=2, 270 | help='Path to the generated images or to .npz statistic files') 271 | parser.add_argument("-i", "--inception", type=str, default=None, 272 | help='Path to Inception model (will be downloaded if not provided)') 273 | parser.add_argument("--gpu", default="", type=str, 274 | help='GPU to use (leave blank for CPU only)') 275 | args = parser.parse_args() 276 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 277 | fid_value = calculate_fid_given_paths(args.path, args.inception) 278 | print("FID: ", fid_value) 279 | -------------------------------------------------------------------------------- /utils/mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @time: 10/27/18 1:37 PM 5 | @author: Chen He 6 | @site: 7 | @file: mnist.py 8 | @description: 9 | """ 10 | import gzip 11 | import os 12 | 13 | import numpy as np 14 | 15 | NUM_CLASSES = 10 16 | 17 | 18 | def load_data(path='datasets/mnist', kind='train', order_idx=1): 19 | """Load MNIST data from `path`""" 20 | labels_path = os.path.join(path, 21 | '%s-labels-idx1-ubyte.gz' 22 | % kind) 23 | images_path = os.path.join(path, 24 | '%s-images-idx3-ubyte.gz' 25 | % kind) 26 | 27 | with gzip.open(labels_path, 'rb') as lbpath: 28 | labels = np.frombuffer(lbpath.read(), dtype=np.uint8, 29 | offset=8) 30 | 31 | with gzip.open(images_path, 'rb') as imgpath: 32 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 33 | offset=16).reshape(len(labels), 784) 34 | 35 | order = [] 36 | with open(os.path.join(path, 'order_%d.txt' % order_idx)) as file_in: 37 | for line in file_in.readlines(): 38 | order.append(int(line)) 39 | order = np.array(order) 40 | 41 | labels = change_order(labels, order=order) 42 | 43 | one_hot_labels = np.eye(NUM_CLASSES, dtype=float)[labels] 44 | 45 | return images, labels, one_hot_labels, order 46 | 47 | 48 | def change_order(cls, order): 49 | order_dict = dict() 50 | for i in range(len(order)): 51 | order_dict[order[i]] = i 52 | 53 | reordered_cls = np.array([order_dict[cls[i]] for i in range(len(cls))]) 54 | return reordered_cls 55 | -------------------------------------------------------------------------------- /utils/resnet_v1_mod.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.python.slim.nets import 38 | resnet_v1 39 | 40 | ResNet-101 for image classification into 1000 classes: 41 | 42 | # inputs has shape [batch, 224, 224, 3] 43 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 44 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 45 | 46 | ResNet-101 for semantic segmentation into 21 classes: 47 | 48 | # inputs has shape [batch, 513, 513, 3] 49 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 50 | net, end_points = resnet_v1.resnet_v1_101(inputs, 51 | 21, 52 | is_training=False, 53 | global_pool=False, 54 | output_stride=16) 55 | """ 56 | 57 | from __future__ import absolute_import 58 | from __future__ import division 59 | from __future__ import print_function 60 | 61 | from tensorflow.contrib import layers 62 | from tensorflow.contrib.framework.python.ops import add_arg_scope 63 | from tensorflow.contrib.framework.python.ops import arg_scope 64 | from tensorflow.contrib.layers.python.layers import layers as layers_lib 65 | from tensorflow.contrib.layers.python.layers import utils 66 | from tensorflow.contrib.slim.python.slim.nets import resnet_utils 67 | from tensorflow.python.ops import math_ops 68 | from tensorflow.python.ops import nn_ops 69 | from tensorflow.python.ops import variable_scope 70 | 71 | resnet_arg_scope = resnet_utils.resnet_arg_scope 72 | 73 | 74 | @add_arg_scope 75 | def bottleneck(inputs, 76 | depth, 77 | depth_bottleneck, 78 | stride, 79 | rate=1, 80 | outputs_collections=None, 81 | scope=None): 82 | """Bottleneck residual unit variant with BN after convolutions. 83 | 84 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 85 | its definition. Note that we use here the bottleneck variant which has an 86 | extra bottleneck layer. 87 | 88 | When putting together two consecutive ResNet blocks that use this unit, one 89 | should use stride = 2 in the last unit of the first block. 90 | 91 | Args: 92 | inputs: A tensor of size [batch, height, width, channels]. 93 | depth: The depth of the ResNet unit output. 94 | depth_bottleneck: The depth of the bottleneck layers. 95 | stride: The ResNet unit's stride. Determines the amount of downsampling of 96 | the units output compared to its input. 97 | rate: An integer, rate for atrous convolution. 98 | outputs_collections: Collection to add the ResNet unit output. 
99 | scope: Optional variable_scope. 100 | 101 | Returns: 102 | The ResNet unit's output. 103 | """ 104 | with variable_scope.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 105 | depth_in = utils.last_dimension(inputs.get_shape(), min_rank=4) 106 | if depth == depth_in: 107 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 108 | else: 109 | shortcut = layers.conv2d( 110 | inputs, 111 | depth, [1, 1], 112 | stride=stride, 113 | activation_fn=None, 114 | scope='shortcut') 115 | 116 | residual = layers.conv2d( 117 | inputs, depth_bottleneck, [1, 1], stride=1, scope='conv1') 118 | residual = resnet_utils.conv2d_same( 119 | residual, depth_bottleneck, 3, stride, rate=rate, scope='conv2') 120 | residual = layers.conv2d( 121 | residual, depth, [1, 1], stride=1, activation_fn=None, scope='conv3') 122 | 123 | output = nn_ops.relu(shortcut + residual) 124 | 125 | return utils.collect_named_outputs(outputs_collections, sc.name, output) 126 | 127 | 128 | def resnet_v1(inputs, 129 | blocks, 130 | num_classes=None, 131 | is_training=True, 132 | global_pool=True, 133 | output_stride=None, 134 | include_root_block=True, 135 | reuse=None, 136 | scope=None): 137 | """Generator for v1 ResNet models. 138 | 139 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 140 | methods for specific model instantiations, obtained by selecting different 141 | block instantiations that produce ResNets of various depths. 142 | 143 | Training for image classification on Imagenet is usually done with [224, 224] 144 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 145 | block for the ResNets defined in [1] that have nominal stride equal to 32. 146 | However, for dense prediction tasks we advise that one uses inputs with 147 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 148 | this case the feature maps at the ResNet output will have spatial shape 149 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 150 | and corners exactly aligned with the input image corners, which greatly 151 | facilitates alignment of the features to the image. Using as input [225, 225] 152 | images results in [8, 8] feature maps at the output of the last ResNet block. 153 | 154 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 155 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 156 | have nominal stride equal to 32 and a good choice in FCN mode is to use 157 | output_stride=16 in order to increase the density of the computed features at 158 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 159 | 160 | Args: 161 | inputs: A tensor of size [batch, height_in, width_in, channels]. 162 | blocks: A list of length equal to the number of ResNet blocks. Each element 163 | is a resnet_utils.Block object describing the units in the block. 164 | num_classes: Number of predicted classes for classification tasks. If None 165 | we return the features before the logit layer. 166 | is_training: whether batch_norm layers are in training mode. 167 | global_pool: If True, we perform global average pooling before computing the 168 | logits. Set to True for image classification, False for dense prediction. 169 | output_stride: If None, then the output will be computed at the nominal 170 | network stride. If output_stride is not None, it specifies the requested 171 | ratio of input to output spatial resolution. 
172 | include_root_block: If True, include the initial convolution followed by 173 | max-pooling, if False excludes it. 174 | reuse: whether or not the network and its variables should be reused. To be 175 | able to reuse 'scope' must be given. 176 | scope: Optional variable_scope. 177 | 178 | Returns: 179 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 180 | If global_pool is False, then height_out and width_out are reduced by a 181 | factor of output_stride compared to the respective height_in and width_in, 182 | else both height_out and width_out equal one. If num_classes is None, then 183 | net is the output of the last ResNet block, potentially after global 184 | average pooling. If num_classes is not None, net contains the pre-softmax 185 | activations. 186 | end_points: A dictionary from components of the network to the corresponding 187 | activation. 188 | 189 | Raises: 190 | ValueError: If the target output_stride is not valid. 191 | """ 192 | with variable_scope.variable_scope( 193 | scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 194 | end_points_collection = sc.original_name_scope + '_end_points' 195 | with arg_scope( 196 | [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense], 197 | outputs_collections=end_points_collection): 198 | with arg_scope([layers.batch_norm], is_training=is_training): 199 | net = inputs 200 | if include_root_block: 201 | if output_stride is not None: 202 | if output_stride % 4 != 0: 203 | raise ValueError('The output_stride needs to be a multiple of 4.') 204 | output_stride /= 4 205 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 206 | net = layers_lib.max_pool2d(net, [3, 3], stride=2, scope='pool1') 207 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 208 | if global_pool: 209 | # Global average pooling. 210 | net = math_ops.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 211 | pool5 = net 212 | if num_classes is not None: 213 | net = layers.conv2d( 214 | net, 215 | num_classes, [1, 1], 216 | activation_fn=None, 217 | normalizer_fn=None, 218 | scope='logits') 219 | # Convert end_points_collection into a dictionary of end_points. 220 | end_points = utils.convert_collection_to_dict(end_points_collection) 221 | end_points['pool5'] = pool5 222 | if num_classes is not None: 223 | end_points['predictions'] = layers_lib.softmax( 224 | net, scope='predictions') 225 | return net, end_points 226 | resnet_v1.default_image_size = 224 227 | 228 | 229 | def resnet_v1_block(scope, base_depth, num_units, stride): 230 | """Helper function for creating a resnet_v1 bottleneck block. 231 | 232 | Args: 233 | scope: The scope of the block. 234 | base_depth: The depth of the bottleneck layer for each unit. 235 | num_units: The number of units in the block. 236 | stride: The stride of the block, implemented as a stride in the last unit. 237 | All other units have stride=1. 238 | 239 | Returns: 240 | A resnet_v1 bottleneck block. 241 | """ 242 | return resnet_utils.Block(scope, bottleneck, [{ 243 | 'depth': base_depth * 4, 244 | 'depth_bottleneck': base_depth, 245 | 'stride': 1 246 | }] * (num_units - 1) + [{ 247 | 'depth': base_depth * 4, 248 | 'depth_bottleneck': base_depth, 249 | 'stride': stride 250 | }]) 251 | 252 | 253 | def resnet_v1_50(inputs, 254 | num_classes=None, 255 | is_training=True, 256 | global_pool=True, 257 | output_stride=None, 258 | reuse=None, 259 | scope='resnet_v1_50'): 260 | """ResNet-50 model of [1]. 
See resnet_v1() for arg and return description.""" 261 | blocks = [ 262 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 263 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 264 | resnet_v1_block('block3', base_depth=256, num_units=6, stride=2), 265 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 266 | ] 267 | return resnet_v1( 268 | inputs, 269 | blocks, 270 | num_classes, 271 | is_training, 272 | global_pool, 273 | output_stride, 274 | include_root_block=True, 275 | reuse=reuse, 276 | scope=scope) 277 | 278 | 279 | def resnet_v1_101(inputs, 280 | num_classes=None, 281 | is_training=True, 282 | global_pool=True, 283 | output_stride=None, 284 | reuse=None, 285 | scope='resnet_v1_101'): 286 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 287 | blocks = [ 288 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 289 | resnet_v1_block('block2', base_depth=128, num_units=4, stride=2), 290 | resnet_v1_block('block3', base_depth=256, num_units=23, stride=2), 291 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 292 | ] 293 | return resnet_v1( 294 | inputs, 295 | blocks, 296 | num_classes, 297 | is_training, 298 | global_pool, 299 | output_stride, 300 | include_root_block=True, 301 | reuse=reuse, 302 | scope=scope) 303 | 304 | 305 | def resnet_v1_152(inputs, 306 | num_classes=None, 307 | is_training=True, 308 | global_pool=True, 309 | output_stride=None, 310 | reuse=None, 311 | scope='resnet_v1_152'): 312 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 313 | blocks = [ 314 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 315 | resnet_v1_block('block2', base_depth=128, num_units=8, stride=2), 316 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 317 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 318 | ] 319 | return resnet_v1( 320 | inputs, 321 | blocks, 322 | num_classes, 323 | is_training, 324 | global_pool, 325 | output_stride, 326 | include_root_block=True, 327 | reuse=reuse, 328 | scope=scope) 329 | 330 | 331 | def resnet_v1_200(inputs, 332 | num_classes=None, 333 | is_training=True, 334 | global_pool=True, 335 | output_stride=None, 336 | reuse=None, 337 | scope='resnet_v1_200'): 338 | """ResNet-200 model of [2]. 
See resnet_v1() for arg and return description.""" 339 | blocks = [ 340 | resnet_v1_block('block1', base_depth=64, num_units=3, stride=2), 341 | resnet_v1_block('block2', base_depth=128, num_units=24, stride=2), 342 | resnet_v1_block('block3', base_depth=256, num_units=36, stride=2), 343 | resnet_v1_block('block4', base_depth=512, num_units=3, stride=1), 344 | ] 345 | return resnet_v1( 346 | inputs, 347 | blocks, 348 | num_classes, 349 | is_training, 350 | global_pool, 351 | output_stride, 352 | include_root_block=True, 353 | reuse=reuse, 354 | scope=scope) 355 | -------------------------------------------------------------------------------- /utils/svhn.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @time: 10/27/18 1:37 PM 5 | @author: Chen He 6 | @site: 7 | @file: fashion_mnist.py 8 | @description: 9 | """ 10 | 11 | import gzip 12 | import numpy as np 13 | import os 14 | from scipy.io import loadmat 15 | import pickle 16 | 17 | NUM_CLASSES = 10 18 | 19 | 20 | def rgb2gray(rgb): 21 | r, g, b = rgb[0, :, :], rgb[1, :, :], rgb[2, :, :] 22 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 23 | 24 | return np.repeat(np.expand_dims(np.uint8(gray), axis=0), 3, axis=0) 25 | 26 | 27 | def load_data(path='datasets/svhn', kind='train', order_idx=1, rgb=True, num_samples_per_class=-1): 28 | """Load SVHN data from `path`""" 29 | 30 | file_path = os.path.join(path, '%s_32x32.mat' % kind) 31 | file_data = loadmat(file_path) 32 | 33 | labels = np.squeeze(file_data['y']) - 1 34 | images = np.transpose(file_data['X'], (3, 2, 0, 1)) 35 | if not rgb: 36 | images = np.array(map(rgb2gray, images)) 37 | images = np.reshape(images, (-1, 32 * 32 * 3)) 38 | 39 | order = [] 40 | with open(os.path.join(path, 'order_%d.txt' % order_idx)) as file_in: 41 | for line in file_in.readlines(): 42 | order.append(int(line)) 43 | order = np.array(order) 44 | 45 | labels = change_order(labels, order=order) 46 | 47 | one_hot_labels = np.eye(NUM_CLASSES, dtype=float)[labels] 48 | 49 | if kind == 'train' and num_samples_per_class > 0: 50 | indices_chosen = [] 51 | for class_idx in order: 52 | indices_chosen_cur_cls = np.random.choice(np.where(labels == class_idx)[0], num_samples_per_class, 53 | replace=False) 54 | indices_chosen.extend(indices_chosen_cur_cls) 55 | images, labels, one_hot_labels = images[indices_chosen], labels[indices_chosen], one_hot_labels[indices_chosen] 56 | 57 | return images, labels, one_hot_labels, order 58 | 59 | 60 | def change_order(cls, order): 61 | order_dict = dict() 62 | for i in range(len(order)): 63 | order_dict[order[i]] = i 64 | 65 | reordered_cls = np.array([order_dict[cls[i]] for i in range(len(cls))]) 66 | return reordered_cls 67 | -------------------------------------------------------------------------------- /utils/vgg_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images. 16 | 17 | The preprocessing steps for VGG were introduced in the following technical 18 | report: 19 | 20 | Very Deep Convolutional Networks For Large-Scale Image Recognition 21 | Karen Simonyan and Andrew Zisserman 22 | arXiv technical report, 2015 23 | PDF: http://arxiv.org/pdf/1409.1556.pdf 24 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 25 | CC-BY-4.0 26 | 27 | More information can be obtained from the VGG website: 28 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 29 | """ 30 | 31 | from __future__ import absolute_import 32 | from __future__ import division 33 | from __future__ import print_function 34 | 35 | import tensorflow as tf 36 | 37 | slim = tf.contrib.slim 38 | 39 | _R_MEAN = 123.68 40 | _G_MEAN = 116.78 41 | _B_MEAN = 103.94 42 | 43 | _RESIZE_SIDE_MIN = 256 44 | _RESIZE_SIDE_MAX = 512 45 | 46 | 47 | def _crop(image, offset_height, offset_width, crop_height, crop_width): 48 | """Crops the given image using the provided offsets and sizes. 49 | 50 | Note that the method doesn't assume we know the input image size but it does 51 | assume we know the input image rank. 52 | 53 | Args: 54 | image: an image of shape [height, width, channels]. 55 | offset_height: a scalar tensor indicating the height offset. 56 | offset_width: a scalar tensor indicating the width offset. 57 | crop_height: the height of the cropped image. 58 | crop_width: the width of the cropped image. 59 | 60 | Returns: 61 | the cropped (and resized) image. 62 | 63 | Raises: 64 | InvalidArgumentError: if the rank is not 3 or if the image dimensions are 65 | less than the crop size. 66 | """ 67 | original_shape = tf.shape(image) 68 | 69 | rank_assertion = tf.Assert( 70 | tf.equal(tf.rank(image), 3), 71 | ['Rank of image must be equal to 3.']) 72 | with tf.control_dependencies([rank_assertion]): 73 | cropped_shape = tf.stack([crop_height, crop_width, original_shape[2]]) 74 | 75 | size_assertion = tf.Assert( 76 | tf.logical_and( 77 | tf.greater_equal(original_shape[0], crop_height), 78 | tf.greater_equal(original_shape[1], crop_width)), 79 | ['Crop size greater than the image size.']) 80 | 81 | offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0])) 82 | 83 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to 84 | # define the crop size. 85 | with tf.control_dependencies([size_assertion]): 86 | image = tf.slice(image, offsets, cropped_shape) 87 | return tf.reshape(image, cropped_shape) 88 | 89 | 90 | def _random_crop(image_list, crop_height, crop_width): 91 | """Crops the given list of images. 92 | 93 | The function applies the same crop to each image in the list. This can be 94 | effectively applied when there are multiple image inputs of the same 95 | dimension such as: 96 | 97 | image, depths, normals = _random_crop([image, depths, normals], 120, 150) 98 | 99 | Args: 100 | image_list: a list of image tensors of the same dimension but possibly 101 | varying channel. 102 | crop_height: the new height. 103 | crop_width: the new width. 104 | 105 | Returns: 106 | the image_list with cropped images. 107 | 108 | Raises: 109 | ValueError: if there are multiple image inputs provided with different size 110 | or the images are smaller than the crop dimensions. 
111 | """ 112 | if not image_list: 113 | raise ValueError('Empty image_list.') 114 | 115 | # Compute the rank assertions. 116 | rank_assertions = [] 117 | for i in range(len(image_list)): 118 | image_rank = tf.rank(image_list[i]) 119 | rank_assert = tf.Assert( 120 | tf.equal(image_rank, 3), 121 | ['Wrong rank for tensor %s [expected] [actual]', 122 | image_list[i].name, 3, image_rank]) 123 | rank_assertions.append(rank_assert) 124 | 125 | with tf.control_dependencies([rank_assertions[0]]): 126 | image_shape = tf.shape(image_list[0]) 127 | image_height = image_shape[0] 128 | image_width = image_shape[1] 129 | crop_size_assert = tf.Assert( 130 | tf.logical_and( 131 | tf.greater_equal(image_height, crop_height), 132 | tf.greater_equal(image_width, crop_width)), 133 | ['Crop size greater than the image size.']) 134 | 135 | asserts = [rank_assertions[0], crop_size_assert] 136 | 137 | for i in range(1, len(image_list)): 138 | image = image_list[i] 139 | asserts.append(rank_assertions[i]) 140 | with tf.control_dependencies([rank_assertions[i]]): 141 | shape = tf.shape(image) 142 | height = shape[0] 143 | width = shape[1] 144 | 145 | height_assert = tf.Assert( 146 | tf.equal(height, image_height), 147 | ['Wrong height for tensor %s [expected][actual]', 148 | image.name, height, image_height]) 149 | width_assert = tf.Assert( 150 | tf.equal(width, image_width), 151 | ['Wrong width for tensor %s [expected][actual]', 152 | image.name, width, image_width]) 153 | asserts.extend([height_assert, width_assert]) 154 | 155 | # Create a random bounding box. 156 | # 157 | # Use tf.random_uniform and not numpy.random.rand as doing the former would 158 | # generate random numbers at graph eval time, unlike the latter which 159 | # generates random numbers at graph definition time. 160 | with tf.control_dependencies(asserts): 161 | max_offset_height = tf.reshape(image_height - crop_height + 1, []) 162 | with tf.control_dependencies(asserts): 163 | max_offset_width = tf.reshape(image_width - crop_width + 1, []) 164 | offset_height = tf.random_uniform( 165 | [], maxval=max_offset_height, dtype=tf.int32) 166 | offset_width = tf.random_uniform( 167 | [], maxval=max_offset_width, dtype=tf.int32) 168 | 169 | return [_crop(image, offset_height, offset_width, 170 | crop_height, crop_width) for image in image_list] 171 | 172 | 173 | def _central_crop(image_list, crop_height, crop_width): 174 | """Performs central crops of the given image list. 175 | 176 | Args: 177 | image_list: a list of image tensors of the same dimension but possibly 178 | varying channel. 179 | crop_height: the height of the image following the crop. 180 | crop_width: the width of the image following the crop. 181 | 182 | Returns: 183 | the list of cropped images. 184 | """ 185 | outputs = [] 186 | for image in image_list: 187 | image_height = tf.shape(image)[0] 188 | image_width = tf.shape(image)[1] 189 | 190 | offset_height = (image_height - crop_height) / 2 191 | offset_width = (image_width - crop_width) / 2 192 | 193 | outputs.append(_crop(image, offset_height, offset_width, 194 | crop_height, crop_width)) 195 | return outputs 196 | 197 | 198 | def _mean_image_subtraction(image, means): 199 | """Subtracts the given means from each image channel. 200 | 201 | For example: 202 | means = [123.68, 116.779, 103.939] 203 | image = _mean_image_subtraction(image, means) 204 | 205 | Note that the rank of `image` must be known. 206 | 207 | Args: 208 | image: a tensor of size [height, width, C]. 209 | means: a C-vector of values to subtract from each channel. 
210 | 
211 |   Returns:
212 |     the centered image.
213 | 
214 |   Raises:
215 |     ValueError: If the rank of `image` is unknown, if `image` has a rank other
216 |       than three or if the number of channels in `image` doesn't match the
217 |       number of values in `means`.
218 |   """
219 |   if image.get_shape().ndims != 3:
220 |     raise ValueError('Input must be of size [height, width, C>0]')
221 |   num_channels = image.get_shape().as_list()[-1]
222 |   if len(means) != num_channels:
223 |     raise ValueError('len(means) must match the number of channels')
224 | 
225 |   channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
226 |   for i in range(num_channels):
227 |     channels[i] -= means[i]
228 |   return tf.concat(axis=2, values=channels)
229 | 
230 | 
231 | def _smallest_size_at_least(height, width, smallest_side):
232 |   """Computes new shape with the smallest side equal to `smallest_side`.
233 | 
234 |   Computes new shape with the smallest side equal to `smallest_side` while
235 |   preserving the original aspect ratio.
236 | 
237 |   Args:
238 |     height: an int32 scalar tensor indicating the current height.
239 |     width: an int32 scalar tensor indicating the current width.
240 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
241 |       the smallest side after resize.
242 | 
243 |   Returns:
244 |     new_height: an int32 scalar tensor indicating the new height.
245 |     new_width: an int32 scalar tensor indicating the new width.
246 |   """
247 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
248 | 
249 |   height = tf.to_float(height)
250 |   width = tf.to_float(width)
251 |   smallest_side = tf.to_float(smallest_side)
252 | 
253 |   scale = tf.cond(tf.greater(height, width),
254 |                   lambda: smallest_side / width,
255 |                   lambda: smallest_side / height)
256 |   new_height = tf.to_int32(tf.rint(height * scale))
257 |   new_width = tf.to_int32(tf.rint(width * scale))
258 |   return new_height, new_width
259 | 
260 | 
261 | def _aspect_preserving_resize(image, smallest_side):
262 |   """Resize images preserving the original aspect ratio.
263 | 
264 |   Args:
265 |     image: A 3-D image `Tensor`.
266 |     smallest_side: A python integer or scalar `Tensor` indicating the size of
267 |       the smallest side after resize.
268 | 
269 |   Returns:
270 |     resized_image: A 3-D tensor containing the resized image.
271 |   """
272 |   smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
273 | 
274 |   shape = tf.shape(image)
275 |   height = shape[0]
276 |   width = shape[1]
277 |   new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
278 |   image = tf.expand_dims(image, 0)
279 |   resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
280 |                                            align_corners=False)
281 |   resized_image = tf.squeeze(resized_image)
282 |   resized_image.set_shape([None, None, 3])
283 |   return resized_image
284 | 
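As a sanity check on the two helpers above, here is the same shorter-side arithmetic in plain Python (a standalone sketch added for illustration; it mirrors `_smallest_size_at_least` without TensorFlow):

```python
# Plain-Python mirror of _smallest_size_at_least (illustrative sketch only):
# the shorter side is rescaled to `smallest_side` and the other side follows
# proportionally, so the aspect ratio is preserved.
def smallest_size_at_least(height, width, smallest_side):
    scale = float(smallest_side) / (width if height > width else height)
    return int(round(height * scale)), int(round(width * scale))

print(smallest_size_at_least(300, 400, 256))  # (256, 341): height is the shorter side
print(smallest_size_at_least(400, 300, 256))  # (341, 256): width is the shorter side
```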
285 | 
286 | def preprocess_for_train(image,
287 |                          output_height,
288 |                          output_width,
289 |                          resize_side_min=_RESIZE_SIDE_MIN,
290 |                          resize_side_max=_RESIZE_SIDE_MAX):
291 |   """Preprocesses the given image for training.
292 | 
293 |   Note that the actual resizing scale is sampled from
294 |   [`resize_side_min`, `resize_side_max`].
295 | 
296 |   Args:
297 |     image: A `Tensor` representing an image of arbitrary size.
298 |     output_height: The height of the image after preprocessing.
299 |     output_width: The width of the image after preprocessing.
300 |     resize_side_min: The lower bound for the smallest side of the image for
301 |       aspect-preserving resizing.
302 |     resize_side_max: The upper bound for the smallest side of the image for
303 |       aspect-preserving resizing.
304 | 
305 |   Returns:
306 |     A preprocessed image.
307 |   """
308 |   resize_side = tf.random_uniform(
309 |       [], minval=resize_side_min, maxval=resize_side_max+1, dtype=tf.int32)
310 | 
311 |   image = _aspect_preserving_resize(image, resize_side)
312 |   image = _random_crop([image], output_height, output_width)[0]
313 |   image.set_shape([output_height, output_width, 3])
314 |   image = tf.to_float(image)
315 |   image = tf.image.random_flip_left_right(image)
316 |   return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
317 | 
318 | 
319 | def preprocess_for_eval(image, output_height, output_width, resize_side):
320 |   """Preprocesses the given image for evaluation.
321 | 
322 |   Args:
323 |     image: A `Tensor` representing an image of arbitrary size.
324 |     output_height: The height of the image after preprocessing.
325 |     output_width: The width of the image after preprocessing.
326 |     resize_side: The smallest side of the image for aspect-preserving resizing.
327 | 
328 |   Returns:
329 |     A preprocessed image.
330 |   """
331 |   image = _aspect_preserving_resize(image, resize_side)
332 |   image = _central_crop([image], output_height, output_width)[0]
333 |   image.set_shape([output_height, output_width, 3])
334 |   image = tf.to_float(image)
335 |   return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
336 | 
337 | 
338 | def preprocess_image(image, output_height, output_width, is_training=False,
339 |                      resize_side_min=_RESIZE_SIDE_MIN,
340 |                      resize_side_max=_RESIZE_SIDE_MAX):
341 |   """Preprocesses the given image.
342 | 
343 |   Args:
344 |     image: A `Tensor` representing an image of arbitrary size.
345 |     output_height: The height of the image after preprocessing.
346 |     output_width: The width of the image after preprocessing.
347 |     is_training: `True` if we're preprocessing the image for training and
348 |       `False` otherwise.
349 |     resize_side_min: The lower bound for the smallest side of the image for
350 |       aspect-preserving resizing. If `is_training` is `False`, then this value
351 |       is used for rescaling.
352 |     resize_side_max: The upper bound for the smallest side of the image for
353 |       aspect-preserving resizing. If `is_training` is `False`, this value is
354 |       ignored. Otherwise, the resize side is sampled from
355 |       [resize_side_min, resize_side_max].
356 | 
357 |   Returns:
358 |     A preprocessed image.
359 | """ 360 | if is_training: 361 | return preprocess_for_train(image, output_height, output_width, 362 | resize_side_min, resize_side_max) 363 | else: 364 | return preprocess_for_eval(image, output_height, output_width, 365 | resize_side_min) -------------------------------------------------------------------------------- /utils/visualize_embedding_protos_and_samples.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | 3 | """ 4 | @time: 3/12/19 4:14 PM 5 | @author: Chen He 6 | @site: 7 | @file: proto_dim_reduction.py 8 | @description: 9 | """ 10 | 11 | import matplotlib as mpl 12 | 13 | mpl.use('Agg') 14 | import matplotlib.pyplot as plt 15 | 16 | from sklearn.decomposition import PCA 17 | from sklearn.manifold import TSNE 18 | 19 | import numpy as np 20 | import os 21 | import pickle 22 | 23 | num_classes_dict = { 24 | 'fashion-mnist': 10, 25 | 'svhn': 10, 26 | 'imagenet_64x64_dogs': 30, 27 | 'imagenet_64x64_birds': 30 28 | } 29 | 30 | test_intervals_dict = { 31 | 'fashion-mnist': 500, 32 | 'svhn': 500, 33 | 'imagenet_64x64_dogs': 500, 34 | 'imagenet_64x64_birds': 500 35 | } 36 | 37 | num_iters_dict = { 38 | 'fashion-mnist': 10000, 39 | 'svhn': 10000, 40 | 'imagenet_64x64_dogs': 10000, 41 | 'imagenet_64x64_birds': 10000 42 | } 43 | 44 | embedding_dim_dict = { 45 | 'fashion-mnist': 1024, 46 | 'svhn': 1024, 47 | 'imagenet_64x64_dogs': 1024, 48 | 'imagenet_64x64_birds': 1024 49 | } 50 | 51 | USE_KMEANS = False 52 | NUM_REAL = 128 53 | 54 | colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#42d4f4', '#f032e6', '#bfef45', '#fabebe', 55 | '#469990', '#e6beff', '#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#a9a9a9', 56 | '#ffffff', '#000000'] 57 | 58 | 59 | def visualize(folder_path, dataset_name, nb_cl, vis_pca=False, vis_tsne=True, embedding_dim=None, test_interval=None): 60 | num_classes = num_classes_dict[dataset_name] 61 | if test_interval is None: 62 | test_interval = test_intervals_dict[dataset_name] 63 | num_iters = num_iters_dict[dataset_name] 64 | if embedding_dim is None: 65 | embedding_dim = embedding_dim_dict[dataset_name] 66 | 67 | # session 68 | for num_seen_classes in range(nb_cl, num_classes + nb_cl, nb_cl): 69 | print('Task %d' % (num_seen_classes / nb_cl)) 70 | 71 | subfolder = os.path.join(folder_path, 'class_1-%d' % num_seen_classes) 72 | 73 | proto_folder = os.path.join(subfolder, 'protos') 74 | sample_folder = os.path.join(subfolder, 'samples') 75 | 76 | vis_folder = os.path.join(subfolder, 'vis_2d') 77 | if not os.path.exists(vis_folder): 78 | os.makedirs(vis_folder) 79 | 80 | # iteration 81 | for iter_idx in range(test_interval, num_iters + test_interval, test_interval): 82 | print('Iteration %d' % (iter_idx)) 83 | 84 | pca_protos_dict = dict() 85 | pca_samples_dict = dict() 86 | 87 | all_embedding_protos = np.zeros([0, embedding_dim], np.float) 88 | all_embedding_samples = np.zeros([0, embedding_dim], np.float) 89 | 90 | # class 91 | for category_idx in range(num_seen_classes): 92 | proto_embedding_file = os.path.join(proto_folder, 'class_%d' % (category_idx + 1), 93 | 'embedding_protos_%d.pkl' % iter_idx) 94 | with open(proto_embedding_file, 'rb') as fin: 95 | embedding_protos = pickle.load(fin) 96 | 97 | sample_embedding_file = os.path.join(sample_folder, 'class_%d' % (category_idx + 1), 98 | 'embedding_samples_%d.pkl' % iter_idx) 99 | with open(sample_embedding_file, 'rb') as fin: 100 | embedding_samples = pickle.load(fin) 101 | 102 | 
--------------------------------------------------------------------------------
/utils/visualize_embedding_protos_and_samples.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | """
4 | @time: 3/12/19 4:14 PM
5 | @author: Chen He
6 | @site: 
7 | @file: visualize_embedding_protos_and_samples.py
8 | @description:
9 | """
10 | 
11 | import matplotlib as mpl
12 | 
13 | mpl.use('Agg')
14 | import matplotlib.pyplot as plt
15 | 
16 | from sklearn.decomposition import PCA
17 | from sklearn.manifold import TSNE
18 | 
19 | import numpy as np
20 | import os
21 | import pickle
22 | 
23 | num_classes_dict = {
24 |     'fashion-mnist': 10,
25 |     'svhn': 10,
26 |     'imagenet_64x64_dogs': 30,
27 |     'imagenet_64x64_birds': 30
28 | }
29 | 
30 | test_intervals_dict = {
31 |     'fashion-mnist': 500,
32 |     'svhn': 500,
33 |     'imagenet_64x64_dogs': 500,
34 |     'imagenet_64x64_birds': 500
35 | }
36 | 
37 | num_iters_dict = {
38 |     'fashion-mnist': 10000,
39 |     'svhn': 10000,
40 |     'imagenet_64x64_dogs': 10000,
41 |     'imagenet_64x64_birds': 10000
42 | }
43 | 
44 | embedding_dim_dict = {
45 |     'fashion-mnist': 1024,
46 |     'svhn': 1024,
47 |     'imagenet_64x64_dogs': 1024,
48 |     'imagenet_64x64_birds': 1024
49 | }
50 | 
51 | USE_KMEANS = False
52 | NUM_REAL = 128
53 | 
54 | colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#42d4f4', '#f032e6', '#bfef45', '#fabebe',
55 |           '#469990', '#e6beff', '#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#a9a9a9',
56 |           '#ffffff', '#000000']
57 | 
58 | 
59 | def visualize(folder_path, dataset_name, nb_cl, vis_pca=False, vis_tsne=True, embedding_dim=None, test_interval=None):
60 |     num_classes = num_classes_dict[dataset_name]
61 |     if test_interval is None:
62 |         test_interval = test_intervals_dict[dataset_name]
63 |     num_iters = num_iters_dict[dataset_name]
64 |     if embedding_dim is None:
65 |         embedding_dim = embedding_dim_dict[dataset_name]
66 | 
67 |     # session
68 |     for num_seen_classes in range(nb_cl, num_classes + nb_cl, nb_cl):
69 |         print('Task %d' % (num_seen_classes / nb_cl))
70 | 
71 |         subfolder = os.path.join(folder_path, 'class_1-%d' % num_seen_classes)
72 | 
73 |         proto_folder = os.path.join(subfolder, 'protos')
74 |         sample_folder = os.path.join(subfolder, 'samples')
75 | 
76 |         vis_folder = os.path.join(subfolder, 'vis_2d')
77 |         if not os.path.exists(vis_folder):
78 |             os.makedirs(vis_folder)
79 | 
80 |         # iteration
81 |         for iter_idx in range(test_interval, num_iters + test_interval, test_interval):
82 |             print('Iteration %d' % (iter_idx))
83 | 
84 |             pca_protos_dict = dict()
85 |             pca_samples_dict = dict()
86 | 
87 |             all_embedding_protos = np.zeros([0, embedding_dim], np.float)
88 |             all_embedding_samples = np.zeros([0, embedding_dim], np.float)
89 | 
90 |             # class
91 |             for category_idx in range(num_seen_classes):
92 |                 proto_embedding_file = os.path.join(proto_folder, 'class_%d' % (category_idx + 1),
93 |                                                     'embedding_protos_%d.pkl' % iter_idx)
94 |                 with open(proto_embedding_file, 'rb') as fin:
95 |                     embedding_protos = pickle.load(fin)
96 | 
97 |                 sample_embedding_file = os.path.join(sample_folder, 'class_%d' % (category_idx + 1),
98 |                                                     'embedding_samples_%d.pkl' % iter_idx)
99 |                 with open(sample_embedding_file, 'rb') as fin:
100 |                     embedding_samples = pickle.load(fin)
101 | 
102 |                 pca_protos_dict[category_idx] = embedding_protos
103 |                 pca_samples_dict[category_idx] = embedding_samples
104 | 
105 |                 all_embedding_protos = np.concatenate((all_embedding_protos, embedding_protos))
106 |                 all_embedding_samples = np.concatenate((all_embedding_samples, embedding_samples))
107 | 
108 |             if vis_pca:
109 |                 pca = PCA(n_components=2)
110 |                 pca.fit(np.concatenate((all_embedding_protos, all_embedding_samples)))
111 |                 plt.figure(figsize=(6, 6), dpi=150)
112 |                 for category_idx in range(num_seen_classes):
113 |                     pca_protos = pca.transform(pca_protos_dict[category_idx])  # project onto the fitted 2-D basis
114 |                     pca_samples = pca.transform(pca_samples_dict[category_idx])
115 |                     plt.scatter(pca_protos[:, 0], pca_protos[:, 1], marker='+', alpha=1., color=colors[category_idx],
116 |                                 label='Class %d' % (category_idx + 1))
117 |                     # samples: same color, no label (avoids duplicate legend entries per class)
118 |                     plt.scatter(pca_samples[:, 0], pca_samples[:, 1], marker='.', alpha=1., color=colors[category_idx])
119 | 
120 |                 plt.legend()
121 |                 plt.savefig(os.path.join(vis_folder, 'pca_%d.pdf' % iter_idx))
122 |                 plt.close()
123 | 
124 |             if vis_tsne:
125 |                 tsne = TSNE(n_components=2)
126 |                 num_proto_per_class = len(embedding_protos)
127 |                 num_sample_per_class = len(embedding_samples)
128 |                 tsne_result = tsne.fit_transform(np.concatenate((all_embedding_protos, all_embedding_samples)))
129 |                 plt.figure(figsize=(6, 6), dpi=150)
130 |                 for category_idx in range(num_seen_classes):
131 |                     tsne_protos = tsne_result[
132 |                         category_idx * num_proto_per_class: (category_idx + 1) * num_proto_per_class]
133 |                     tsne_samples = tsne_result[len(all_embedding_protos) + category_idx * num_sample_per_class: len(
134 |                         all_embedding_protos) + (category_idx + 1) * num_sample_per_class]
135 |                     plt.scatter(tsne_protos[:, 0], tsne_protos[:, 1], marker='+', alpha=1., color=colors[category_idx],
136 |                                 label='Class %d' % (category_idx + 1))
137 |                     # samples: same color, no label (avoids duplicate legend entries per class)
138 |                     plt.scatter(tsne_samples[:, 0], tsne_samples[:, 1], marker='.', alpha=1.,
139 |                                 color=colors[category_idx])
140 | 
141 |                 plt.legend()
142 |                 plt.savefig(os.path.join(vis_folder, 'tsne_%d.pdf' % iter_idx))
143 |                 plt.close()
144 | 
145 | 
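The per-class slicing in the t-SNE branch above relies on every class contributing the same number of prototypes and samples, so the rows of `tsne_result` can be recovered per class. A toy recomputation of that row layout (hypothetical sizes, added for illustration):

```python
# Toy check of the row layout assumed by the t-SNE slicing (2 classes,
# 3 prototypes and 4 samples per class; all sizes are hypothetical).
import numpy as np

n_cls, n_proto, n_sample, dim = 2, 3, 4, 8
protos = np.random.randn(n_cls * n_proto, dim)    # class-major prototype rows
samples = np.random.randn(n_cls * n_sample, dim)  # class-major sample rows
stacked = np.concatenate((protos, samples))       # protos first, then samples

c = 1  # recover class c's rows, exactly as in the loop above
proto_rows = stacked[c * n_proto:(c + 1) * n_proto]
sample_rows = stacked[len(protos) + c * n_sample:len(protos) + (c + 1) * n_sample]
assert proto_rows.shape == (n_proto, dim) and sample_rows.shape == (n_sample, dim)
```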
146 | if __name__ == '__main__':
147 |     # visualize('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/from_scratch', dataset_name='fashion-mnist', nb_cl=2)
148 |     # visualize('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune_improved', dataset_name='fashion-mnist', nb_cl=2)
149 |     # visualize('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_non-trainable_20_weight_0.000100_squared_l2/dim_1024/finetune_improved', dataset_name='fashion-mnist', nb_cl=2)
150 | 
151 |     # visualize('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2_0.010000/dim_1024/from_scratch', dataset_name='fashion-mnist', nb_cl=2)
152 |     # visualize('../result/protogan/svhn_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2_0.010000/embedding_dim_1024_gan_dim_64/finetune_improved', dataset_name='svhn', nb_cl=2)
153 |     # visualize('../result/protogan_v2/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1_reconstr_0.000100/0.0002_0.5_0.999/10000/proto_static_20_squared_l2_0.010000_update_random_init/dim_128/from_scratch', dataset_name='fashion-mnist', nb_cl=2, embedding_dim=128)
154 |     # visualize('../result/protogan_v2/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1_reconstr_0.000100/0.0002_0.5_0.999/10000/__proto_static_20_squared_l2_0.010000_update/dim_128/from_scratch', dataset_name='fashion-mnist', nb_cl=2, embedding_dim=128)
155 |     # visualize(
156 |     #     '../result/protogan_v1_3_2/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_dup_2_20_weight_0.000000_0.000000_squared_l2_0.010000/finetune_improved_noise_0.0',
157 |     #     dataset_name='fashion-mnist', nb_cl=2, embedding_dim=4096)
158 |     # visualize(
159 |     #     '../result/protogan_v1_3_2/svhn_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_dup_2_20_weight_0.000000_0.000000_squared_l2_0.010000/gan_dim_64/finetune_improved_noise_0.5',
160 |     #     dataset_name='svhn', nb_cl=2, embedding_dim=4096)
161 |     # visualize(
162 |     #     '../result/protogan_v1_3_2/svhn_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_dup_2_20_weight_0.000000_0.000000_squared_l2_0.010000_multi_center/gan_dim_64/finetune_improved_noise_0.5',
163 |     #     dataset_name='svhn', nb_cl=2, embedding_dim=4096)
164 | 
165 |     # visualize(
166 |     #     '../result/protogan_v1_3_8/svhn_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_20_weight_0.000000_0.000000_squared_l2_0.010000/gan_dim_64/finetune_improved_v2_noise_0.5_exemplars_dual_use_1',
167 |     #     dataset_name='svhn', nb_cl=2, embedding_dim=4096, test_interval=2000)
168 |     visualize(
169 |         '../result/protogan_v1_3_8/svhn_order_1/nb_cl_2/dcgan_critic_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_random_20_weight_0.000000_0.000000_squared_l2_0.010000_train_rel_center/gan_dim_64/finetune_improved_v2_noise_0.5_exemplars_dual_use_1',
170 |         dataset_name='svhn', nb_cl=2, embedding_dim=4096, test_interval=2000)
171 | 
--------------------------------------------------------------------------------
/utils/visualize_result_single.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | 
3 | """
4 | @time: 1/25/18 9:48 AM
5 | @author: Chen He
6 | @site: 
7 | @file: visualize_result_single.py
8 | @description:
9 | """
10 | 
11 | import matplotlib as mpl
12 | 
13 | mpl.use('Agg')
14 | from pylab import *
15 | import numpy as np
16 | import os
17 | import pickle
18 | 
19 | num_classes_dict = {
20 |     'fashion-mnist': 10,
21 |     'mnist': 10,
22 |     'svhn': 10,
23 |     'cifar-10': 10,
24 |     'cifar-100': 100,
25 |     'imagenet_64x64_dogs': 30,
26 |     'imagenet_32x32_mini': 100
27 | }
28 | 
29 | test_intervals_dict = {
30 |     'fashion-mnist': 500,
31 |     'mnist': 500,
32 |     'svhn': 500,
33 |     'cifar-10': 500,
34 |     'cifar-100': 500,
35 |     'imagenet_64x64_dogs': 500,
36 |     'imagenet_32x32_mini': 500
37 | }
38 | 
39 | num_iters_dict = {
40 |     'fashion-mnist': 500,
41 |     'mnist': 500,
42 |     'svhn': 500,
43 |     'cifar-10': 10000,
44 |     'cifar-100': 10000,
45 |     'imagenet_64x64_dogs': 10000,
46 |     'imagenet_32x32_mini': 10000,
47 | }
48 | 
49 | 
50 | # Draw the accuracy and FID curves of a given method
51 | def vis_acc_and_fid(folder_path, dataset, nb_cl, test_interval=None, num_iters=None, vis_acc=True, vis_fid=True,
52 |                     num_classes=None, fid_num=10000):
53 |     if test_interval is None:
54 |         test_interval = test_intervals_dict[dataset]
55 | 
56 |     if num_iters is None:
57 |         num_iters = num_iters_dict[dataset]
58 | 
59 |     if num_classes is None:
60 |         num_classes = num_classes_dict[dataset]
61 |     test_fid_interval = num_iters
62 | 
63 |     '''
64 |     vis acc
65 |     '''
66 |     if vis_acc:
67 |         history_acc = dict()
68 |         task_num_acc = 0
69 | 
70 |         for num_seen_classes in range(nb_cl, num_classes + nb_cl, nb_cl):
71 |             subfolder = os.path.join(folder_path, 'class_1-%d' % num_seen_classes)
72 | 
73 |             # check whether the task has been trained or not
74 |             if not os.path.exists(os.path.join(subfolder, 'class_1-%d_conf_mat.pkl' % nb_cl)):
75 |                 break
76 | 
77 |             task_num_acc += 1
78 | 
79 |             if vis_acc:
80 |                 history_acc['task %d' % (num_seen_classes / nb_cl)] = dict()
81 | 
82 |                 # load acc
83 |                 for task_num_seen_classes in range(nb_cl, num_seen_classes + nb_cl, nb_cl):
84 |                     conf_mat_filename = os.path.join(subfolder, 'class_1-%d_conf_mat.pkl' % task_num_seen_classes)
85 |                     with open(conf_mat_filename, 'rb') as fin:
86 |                         conf_mat_over_time = pickle.load(fin)
87 | 
88 |                     for iter in range(test_interval, num_iters + test_interval, test_interval):
89 |                         conf_mat = conf_mat_over_time[iter]
90 |                         accs = np.diag(conf_mat) * 100.0 / np.sum(conf_mat, axis=0)
91 |                         acc = np.mean(accs)
92 | 
93 |                         history_acc['task %d' % (task_num_seen_classes / nb_cl)][
94 |                             (task_num_acc - 1) * num_iters + iter] = acc
95 | 
96 |         if task_num_acc < 10:
97 |             plt.figure(figsize=(18, 9), dpi=150)
98 |             for task_idx in range(task_num_acc):
99 |                 ax = plt.subplot('%d%d%d' % (task_num_acc, 1, task_idx + 1))
100 |                 if task_idx == 0:
101 |                     ax.set_title(dataset, fontdict={'size': 14, 'weight': 'bold'})
102 | 
103 |                 acc_over_time = history_acc['task %d' % (task_idx + 1)]
104 |                 x, z = zip(*sorted(acc_over_time.items(), key=lambda d: d[0]))
105 |                 ax.plot(x, z, marker='.')
106 | 
107 |                 # Vertical reference lines at the task boundaries
108 |                 for i in range(num_iters, num_iters * task_num_acc, num_iters):
109 |                     plt.vlines(i, 0, 100, colors="lightgray", linestyles="dashed")
110 | 
111 |                 ax.set_xlim(0, task_num_acc * num_iters)
112 |                 ax.set_ylim(np.max([np.min(z) - 1., 0]), np.min([np.max(z) + 1., 100]))
113 |                 ax.autoscale_view('tight')
114 | 
115 |                 ax.xaxis.set_visible(False)
116 |                 ax.set_ylabel('Task %d' % (task_idx + 1), fontdict={'size': 12, 'weight': 'bold'})
117 | 
118 |                 ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
119 | 
120 |                 ax.spines['top'].set_visible(False)
121 |                 ax.spines['right'].set_visible(False)
122 |                 ax.spines['bottom'].set_visible(False)
123 | 
124 |             # set the xaxis of the last subplot ON
125 |             ax.spines['bottom'].set_visible(True)
126 |             ax.xaxis.set_visible(True)
127 |             ax.set_xlabel('iterations', fontdict={'size': 12, 'weight': 'bold'})
128 | 
129 |             plt.margins(0)
130 | 
131 |             output_name = os.path.join(folder_path, 'accuracy_curve.pdf')
132 |             plt.savefig(output_name)
133 |             plt.close()
134 | 
135 |         # txt versions (only record the average acc like in iCaRL)
136 |         average_accs = []
137 |         for task_idx_acc in range(task_num_acc):
138 |             average_acc = history_acc['task %d' % (task_idx_acc + 1)][num_iters * (task_idx_acc + 1)]
139 |             average_accs.append(average_acc)
140 |         with open(os.path.join(folder_path, 'average_accuracy.txt'), 'w') as fout:
141 |             fout.write(os.linesep.join([str(elem) for elem in average_accs]))
142 | 
143 |         if num_classes / nb_cl == task_num_acc:
144 |             acc_arr = []
145 |             for task_idx in range(task_num_acc):
146 |                 acc_arr.append(np.mean(history_acc['task %d' % (task_idx + 1)].values()))
147 |             acc_auc = np.mean(acc_arr)
148 |             with open(os.path.join(folder_path, 'ta_acc.txt'), 'w') as fout:
149 |                 fout.write(str(acc_auc))
150 | 
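The `ta_acc.txt` written above holds TA-ACC: the mean over tasks of each task's time-averaged accuracy. A toy recomputation of that aggregation (the numbers here are made up, not real results):

```python
# TA-ACC aggregation with made-up numbers: average each task's accuracy
# over all of its evaluation points, then average across tasks.
import numpy as np

history_acc = {
    'task 1': {500: 98.0, 1000: 97.0},  # accuracy at each global iteration
    'task 2': {1000: 95.0},             # task 2 only exists from the 2nd stage
}
per_task_means = [np.mean(list(accs.values())) for accs in history_acc.values()]
ta_acc = np.mean(per_task_means)  # mean([97.5, 95.0]) = 96.25
```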
151 |     '''
152 |     vis fid
153 |     '''
154 |     if vis_fid:
155 |         history_fid = dict()
156 |         task_num_fid = 0
157 | 
158 |         sub_name = 'fid_%d' % fid_num if fid_num != 10000 else 'fid'
159 | 
160 |         for num_seen_classes in range(nb_cl, num_classes + nb_cl, nb_cl):
161 |             subfolder = os.path.join(folder_path, 'class_1-%d' % num_seen_classes)
162 | 
163 |             # check whether the task has been trained or not
164 |             if not os.path.exists(os.path.join(subfolder, 'cond_%s.pkl' % sub_name)):
165 |                 break
166 | 
167 |             task_num_fid += 1
168 | 
169 |             # load fid
170 |             if vis_fid:
171 |                 fid_filename = os.path.join(subfolder, 'cond_%s.pkl' % sub_name)
172 |                 with open(fid_filename, 'rb') as fin:
173 |                     fid = pickle.load(fin)
174 | 
175 |                 history_fid['task %d' % (num_seen_classes / nb_cl)] = dict()
176 | 
177 |                 for iter in range(test_fid_interval, num_iters + test_fid_interval, test_fid_interval):
178 |                     fid_vals = fid[iter]
179 |                     for task_num_seen_classes in range(nb_cl, num_seen_classes + nb_cl, nb_cl):
180 |                         fid_sum = []
181 |                         for i in range(task_num_seen_classes):
182 |                             fid_sum.append(fid_vals[i + 1])
183 |                         history_fid['task %d' % (task_num_seen_classes / nb_cl)][
184 |                             (task_num_fid - 1) * num_iters + iter] = np.mean(fid_sum)
185 | 
186 |         if task_num_fid < 10:
187 |             plt.figure(figsize=(18, 9), dpi=150)
188 |             for task_idx in range(task_num_fid):
189 |                 ax = plt.subplot('%d%d%d' % (task_num_fid, 1, task_idx + 1))
190 |                 if task_idx == 0:
191 |                     ax.set_title(dataset, fontdict={'size': 14, 'weight': 'bold'})
192 | 
193 |                 fid_over_time = history_fid['task %d' % (task_idx + 1)]
194 |                 x, z = zip(*sorted(fid_over_time.items(), key=lambda d: d[0]))
195 |                 ax.plot(x, z, marker='.')
196 | 
197 |                 # Vertical reference lines at the task boundaries
198 |                 for i in range(num_iters, num_iters * task_num_fid, num_iters):
199 |                     plt.vlines(i, 0, 100, colors="lightgray", linestyles="dashed")
200 | 
201 |                 ax.set_xlim(0, task_num_fid * num_iters)
202 |                 ax.set_ylim(np.min(z) - 1., np.max(z) + 1.)
203 |                 ax.autoscale_view('tight')
204 | 
205 |                 ax.xaxis.set_visible(False)
206 |                 ax.set_ylabel('Task %d' % (task_idx + 1), fontdict={'size': 12, 'weight': 'bold'})
207 | 
208 |                 ax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
209 | 
210 |                 ax.spines['top'].set_visible(False)
211 |                 ax.spines['right'].set_visible(False)
212 |                 ax.spines['bottom'].set_visible(False)
213 | 
214 |             # set the xaxis of the last subplot ON
215 |             ax.spines['bottom'].set_visible(True)
216 |             ax.xaxis.set_visible(True)
217 |             ax.set_xlabel('iterations', fontdict={'size': 12, 'weight': 'bold'})
218 | 
219 |             plt.margins(0)
220 | 
221 |             output_name = os.path.join(folder_path, '%s_curve.pdf' % sub_name)
222 |             plt.savefig(output_name)
223 |             plt.close()
224 | 
225 |         # txt versions (only record the average fid like in iCaRL)
226 |         average_fids = []
227 |         for task_idx_fid in range(task_num_fid):
228 |             average_fid = history_fid['task %d' % (task_idx_fid + 1)][num_iters * (task_idx_fid + 1)]
229 |             average_fids.append(average_fid)
230 |         with open(os.path.join(folder_path, 'average_%s.txt' % sub_name), 'w') as fout:
231 |             fout.write(os.linesep.join([str(elem) for elem in average_fids]))
232 | 
233 |         if num_classes / nb_cl == task_num_fid:
234 |             fid_arr = []
235 |             for task_idx in range(task_num_fid):
236 |                 fid_arr.append(np.mean(history_fid['task %d' % (task_idx + 1)].values()))
237 |             fid_auc = np.mean(fid_arr)
238 |             with open(os.path.join(folder_path, 'ta_%s.txt' % sub_name), 'w') as fout:
239 |                 fout.write(str(fid_auc))
240 | 
241 | 
242 | if __name__ == '__main__':
243 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune_improved', dataset='fashion-mnist', nb_cl=2)
244 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune', dataset='fashion-mnist', nb_cl=2)
245 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_1/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/from_scratch', dataset='fashion-mnist', nb_cl=2)
246 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_2/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune_improved', dataset='fashion-mnist', nb_cl=2)
247 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_2/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune', dataset='fashion-mnist', nb_cl=2)
248 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_2/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/from_scratch', dataset='fashion-mnist', nb_cl=2)
249 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_3/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune_improved', dataset='fashion-mnist', nb_cl=2)
250 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_3/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/finetune', dataset='fashion-mnist', nb_cl=2)
251 |     # vis_acc_and_fid('../result/protogan/fashion-mnist_order_3/nb_cl_2/dcgan_critic_1_class_1_ac_1.0_0.1/0.0002_0.5_0.999/10000/proto_static_20_weight_0.000100_squared_l2/dim_1024/from_scratch', dataset='fashion-mnist', nb_cl=2)
252 |     pass
--------------------------------------------------------------------------------
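For reference, here is the result-folder layout that the two loops in `vis_acc_and_fid` expect, reconstructed from its path handling (a sketch, not an exhaustive listing; directory names vary per run):

```python
# Inputs consumed by vis_acc_and_fid, inferred from its os.path.join calls
# (nb_cl=2 shown; <folder_path> is the run's result directory):
#
#   <folder_path>/
#       class_1-2/
#           class_1-2_conf_mat.pkl   # per-iteration confusion matrices (task 1)
#           cond_fid.pkl             # per-iteration conditional FIDs
#       class_1-4/
#           class_1-2_conf_mat.pkl
#           class_1-4_conf_mat.pkl
#           cond_fid.pkl
#       ...
#
# Outputs written into <folder_path>: accuracy_curve.pdf, average_accuracy.txt,
# ta_acc.txt, fid_curve.pdf, average_fid.txt and ta_fid.txt.
```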