├── cfgs ├── cvae1.yaml ├── __init__.py ├── dcgan │ └── celeba.json ├── aae │ ├── aae_mnist_ssl_acc.png │ └── mnist.json ├── wgan │ ├── res │ │ ├── cifar10exp1_generator_loss.PNG │ │ ├── cifar10exp1_inception_score.PNG │ │ └── cifar10exp1_discriminator_loss.PNG │ ├── README.md │ ├── cifar10_2.json │ ├── wgan_gp2.json │ ├── toy.json │ ├── wgan_gp.json │ └── voc.json ├── cvaegan1.json ├── mnist1.json ├── seg │ ├── voc.json │ ├── voc2.json │ └── voc3.json ├── networkconfig.py ├── cla │ ├── imagenet1.json │ ├── violence2.json │ ├── violence.json │ ├── mnist1.json │ ├── cifar10.json │ ├── mnist2.json │ ├── mnist3.json │ ├── mnist4.json │ ├── mnist_test.json │ └── mnist5.json ├── tianchi │ ├── guangdong2.json │ ├── guangdong3.json │ └── guangdong.json ├── cvae1.json ├── vae1.json ├── cifar10.json ├── cvae3.json ├── vae │ └── mnist.json ├── cvae2.json └── impgan │ └── celeba.json ├── model ├── catgan.py ├── mmd_gan.py ├── __init__.py ├── base_detection_model.py └── model.py ├── netutils ├── util.py ├── __init__.py ├── activation.py ├── learning_rate.py ├── weightsinit.py └── sample.py ├── tester └── tester.py ├── validator ├── __init__.py ├── pca_plot_validator.py ├── validator.py ├── base_validator.py ├── test_dataset_validator.py ├── gan_toy_plot.py ├── scatter_plot.py └── tensorboard_embedding.py ├── dataset ├── __init__.py ├── kitti.py ├── dataset.py ├── violence.py ├── base_mil_dataset.py ├── svhn.py ├── gan_toy.py ├── gan_toy_ssl.py ├── mnist.py └── cifar10.py ├── decoder ├── __init__.py ├── decoder.py ├── decoder_simple.py └── decoder_pixel.py ├── encoder ├── __init__.py ├── encoder.py └── encoder_simple.py ├── generator ├── __init__.py ├── generator.py ├── generator_simple.py ├── generator_cifar10_resnet.py └── generator_conv.py ├── trainer ├── __init__.py ├── trainer.py ├── supervised_mil.py ├── unsupervised.py └── supervised.py ├── classifier ├── __init__.py ├── classifier.py ├── classifier_unet.py └── classifier_simple.py ├── network ├── __init__.py └── 
network.py ├── discriminator ├── __init__.py ├── discriminator.py └── discriminator_simple.py ├── VAE-GAN.sublime-project ├── .gitignore ├── test ├── test_cifar10.py ├── test_inception.py ├── test_imagenet.py ├── test_unet.py ├── test_ms_coco.py ├── test_mnist.py ├── test_production.py ├── test_celeba.py ├── test_resnet.py ├── test_pascal_voc.py ├── test_vgg.py └── test_violence.py ├── LICENSE ├── README.md ├── test.py ├── train.py └── train_batch.py /cfgs/cvae1.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/catgan.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/mmd_gan.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /netutils/util.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tester/tester.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cfgs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /cfgs/dcgan/celeba.json: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /netutils/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /validator/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /decoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /encoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /generator/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /classifier/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /discriminator/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /VAE-GAN.sublime-project: 
-------------------------------------------------------------------------------- 1 | { 2 | "folders": 3 | [ 4 | { 5 | "path": "." 6 | } 7 | ] 8 | } 9 | -------------------------------------------------------------------------------- /cfgs/aae/aae_mnist_ssl_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanzhicong/VAE-GAN/HEAD/cfgs/aae/aae_mnist_ssl_acc.png -------------------------------------------------------------------------------- /cfgs/wgan/res/cifar10exp1_generator_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanzhicong/VAE-GAN/HEAD/cfgs/wgan/res/cifar10exp1_generator_loss.PNG -------------------------------------------------------------------------------- /cfgs/wgan/res/cifar10exp1_inception_score.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanzhicong/VAE-GAN/HEAD/cfgs/wgan/res/cifar10exp1_inception_score.PNG -------------------------------------------------------------------------------- /cfgs/wgan/res/cifar10exp1_discriminator_loss.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanzhicong/VAE-GAN/HEAD/cfgs/wgan/res/cifar10exp1_discriminator_loss.PNG -------------------------------------------------------------------------------- /network/network.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | # from classifier.classifier import get_classifier 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *pycache* 2 | assets* 3 | dataset/extra_files* 4 | *.pyc 5 | .vscode* 6 | sftp-config.json 7 | lib/* 8 | 
validator/inception_score/imagenet* 9 | notebook* 10 | -------------------------------------------------------------------------------- /decoder/decoder.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | sys.path.append('../') 7 | 8 | 9 | 10 | def get_decoder(name, config, is_training): 11 | 12 | if name in ['decoder', 'simple decoder', 'decoder_simple']: 13 | from .decoder_simple import DecoderSimple 14 | return DecoderSimple(config, is_training) 15 | else : 16 | raise Exception("None decoder named " + name) 17 | -------------------------------------------------------------------------------- /test/test_cifar10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('../') 6 | 7 | import tensorflow as tf 8 | import matplotlib.pyplot as plt 9 | 10 | from dataset.cifar10 import Cifar10 11 | 12 | if __name__ == '__main__': 13 | config = { 14 | 'batch_size' : 16, 15 | 'output shape' : [32, 32, 3] 16 | } 17 | 18 | dataset = Cifar10(config) 19 | 20 | for ind, x_batch, y_batch in dataset.iter_train_images(): 21 | plt.figure(0) 22 | plt.imshow(x_batch[0, :, :, :]) 23 | plt.pause(1) 24 | -------------------------------------------------------------------------------- /classifier/classifier.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('../') 6 | 7 | 8 | def get_classifier(name, config, is_training): 9 | if name in ['classifier', 'simple classifier', 'ClassifierSimple', 'vgg', 'VGG16']: 10 | from .classifier_simple import ClassifierSimple 11 | return ClassifierSimple(config, is_training) 12 | elif name == 'classifier_unet' or name == 'unet classifier': 13 | from .classifier_unet import ClassifierUNet 14 | return ClassifierUNet(config, is_training) 15 | else: 16 | raise Exception("No classifier 
named " + name) 17 | 18 | -------------------------------------------------------------------------------- /test/test_inception.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | 10 | from network.inception_v3 import InceptionV3 11 | 12 | if __name__ == '__main__': 13 | config = { 14 | "output dims" : 1000, 15 | 'output_activation' : 'softmax' 16 | } 17 | 18 | inception_model = InceptionV3(config, False) 19 | 20 | x = tf.placeholder(tf.float32, shape=(None, 299, 299, 3), name='input') 21 | 22 | y, end_points = inception_model(x) 23 | 24 | for name, value in end_points.items(): 25 | print(name, ' --> ', value.get_shape()) 26 | 27 | -------------------------------------------------------------------------------- /test/test_imagenet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | import matplotlib 10 | matplotlib.use('Agg') 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | from dataset.imagenet import ImageNet 15 | 16 | if __name__ == '__main__': 17 | config = { 18 | 'batch_size' : 16, 19 | 'output shape' : [224, 224, 3] 20 | } 21 | dataset = ImageNet(config) 22 | 23 | for ind, x_batch, y_batch in dataset.iter_train_images(): 24 | print(ind, x_batch.shape, y_batch.shape) 25 | plt.figure(0) 26 | plt.imshow(x_batch[0, :, :, :]) 27 | plt.pause(1) 28 | 29 | 30 | -------------------------------------------------------------------------------- /test/test_unet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | from network.unet import UNet 10 | 11 | if __name__ == '__main__': 12 | config = { 13 | 
"output dims" : 10, 14 | 'name' : 'UNet' 15 | } 16 | 17 | is_training = tf.placeholder(tf.bool, name='is_training') 18 | 19 | model = UNet(config, is_training) 20 | x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') 21 | y, end_points = model(x) 22 | 23 | 24 | for name, value in end_points.items(): 25 | print(name, ' --> ', value.get_shape()) 26 | 27 | for var in model.vars: 28 | print(var.name, ' --> ', var.get_shape()) 29 | 30 | -------------------------------------------------------------------------------- /test/test_ms_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('../') 6 | 7 | import tensorflow as tf 8 | import matplotlib.pyplot as plt 9 | 10 | from dataset.coco import MSCOCO 11 | 12 | if __name__ == '__main__': 13 | config = { 14 | "output shape" : [224, 224, 3], 15 | "show warning" : True, 16 | } 17 | 18 | dataset = MSCOCO(config) 19 | indices = dataset.get_image_indices('trainval') 20 | 21 | for ind in indices: 22 | img, anno = dataset.read_image_by_index(ind, phase='trainval', method='supervised') 23 | 24 | if img is not None: 25 | plt.figure(0) 26 | plt.clf() 27 | dataset.show_image_and_anno(plt, img, anno) 28 | plt.pause(3) 29 | 30 | -------------------------------------------------------------------------------- /cfgs/wgan/README.md: -------------------------------------------------------------------------------- 1 | # WGAN实验 cifar10数据生成 # 2 | 3 | 4 | ## 实验1: 5 | 6 | 参数设置: 7 | 1. 生成器:1层全连接加3层反卷积,每层反卷积stride均为2,全连接输出(4,4,512) 8 | 2. 判别器:3层卷积加一层全连接, 9 | 3. 优化器:Adam优化器,learning rate=0.0001,beta1=0.5, beta2=0.9 10 | 11 | 实验结果: 12 | 13 | ![inception score](./res/cifar10exp1_inception_score.PNG) 14 | ![generator loss](./res/cifar10exp1_generator_loss.PNG) 15 | ![discriminator loss](./res/cifar10exp1_discriminator_loss.PNG) 16 | 17 | 问题记录: 18 | 1. 
一直出现inception_score上不去的问题,论文提供代码的inception score能达到5.5左右,本实验中只能达到3.0到3.5之间 19 | 2. 在大约20k左右的时候,inception score开始稳定不变,貌似生成器生成图片的真实度已经无法再提升,随后生成器的损失逐渐增加,在23k的时候,生成器损失开始出现大幅震荡,判别器损失同时上升,inception score开始有下降的趋势。随后一直到100k,生成器损失一直在大幅震荡之中 20 | 21 | 22 | 问题修复: 23 | 1. -------------------------------------------------------------------------------- /discriminator/discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import os 5 | import sys 6 | sys.path.append('../') 7 | 8 | 9 | 10 | 11 | def get_discriminator(name, config, is_training): 12 | if name == 'discriminator': 13 | from .discriminator_simple import DiscriminatorSimple 14 | return DiscriminatorSimple(config, is_training) 15 | elif name == 'cifar10 discriminator' or name == 'discriminator_cifar10': 16 | from .discriminator_cifar10 import DiscriminatorCifar10 17 | return DiscriminatorCifar10(config, is_training) 18 | 19 | elif name == 'discriminator_conv': 20 | from .discriminator_conv import D_conv 21 | return D_conv(config, is_training) 22 | else : 23 | raise Exception("None discriminator named " + name) 24 | 25 | -------------------------------------------------------------------------------- /test/test_mnist.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | from dataset.mnist import MNIST 14 | from dataset.svhn import SVHN 15 | 16 | if __name__ == '__main__': 17 | 18 | config = { 19 | 'output shape' : [28, 28] 20 | } 21 | dataset = MNIST(config) 22 | 23 | # config = { 24 | # 'output shape' : [32, 32, 3] 25 | # } 26 | # dataset = SVHN(config) 27 | 28 | indices = dataset.get_image_indices('train', 'supervised') 29 | for i, ind in enumerate(indices): 30 | img, label = dataset.read_image_by_index(ind, 'train', 'supervised') 31 | 
32 | print(label, np.argmax(label)) 33 | 34 | plt.figure(0) 35 | plt.imshow(img) 36 | plt.pause(1) 37 | 38 | -------------------------------------------------------------------------------- /test/test_production.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('../') 6 | 7 | import tensorflow as tf 8 | import matplotlib.pyplot as plt 9 | 10 | from dataset.production import ChipProduction 11 | 12 | 13 | 14 | 15 | 16 | 17 | if __name__ == '__main__': 18 | config = { 19 | 'batch_size' : 16, 20 | 'output shape' : [64, 64, 2] 21 | } 22 | 23 | dataset = ChipProduction(config) 24 | 25 | indices = dataset.get_image_indices('train') 26 | 27 | 28 | for ind in indices: 29 | image_list = dataset.read_image_by_index(ind, phase='train', method='supervised') 30 | # for image in image_list: 31 | plt.figure(0) 32 | 33 | for ind, image in enumerate(image_list): 34 | if ind < 16: 35 | plt.subplot(4, 4, ind+1) 36 | plt.imshow(image[:, :, 0]) 37 | plt.pause(3) 38 | 39 | -------------------------------------------------------------------------------- /test/test_celeba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('../') 6 | 7 | import tensorflow as tf 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | from dataset.celeba import CelebA 14 | 15 | if __name__ == '__main__': 16 | 17 | config = { 18 | "output shape" : [256, 256, 3], 19 | "output scalar range" : [0, 1] 20 | } 21 | 22 | 23 | dataset = CelebA(config) 24 | indices = dataset.get_image_indices(phase='train', method='supervised') 25 | 26 | print(indices.shape) 27 | 28 | for ind in indices: 29 | 30 | img, attr = dataset.read_image_by_index(ind, phase='train', method='supervised') 31 | 32 | print(img.shape) 33 | print(attr.shape) 34 | print(attr.max(), 
attr.min()) 35 | 36 | plt.figure(0) 37 | plt.imshow(img) 38 | plt.pause(4) 39 | 40 | 41 | -------------------------------------------------------------------------------- /classifier/classifier_unet.py: -------------------------------------------------------------------------------- 1 | # SOFTWARE. 2 | # ============================================================================== 3 | 4 | import os 5 | import sys 6 | 7 | 8 | import tensorflow as tf 9 | import tensorflow.contrib.layers as tcl 10 | 11 | 12 | sys.path.append('../') 13 | 14 | 15 | from netutils.weightsinit import get_weightsinit 16 | from netutils.activation import get_activation 17 | from netutils.normalization import get_normalization 18 | 19 | from network.base_network import BaseNetwork 20 | from network.unet import UNet 21 | 22 | 23 | class ClassifierUNet(BaseNetwork): 24 | def __init__(self, config, is_training): 25 | BaseNetwork.__init__(self, config, is_training) 26 | self.network = UNet(config, is_training) 27 | 28 | def __call__(self, i): 29 | x, end_points = self.network(i) 30 | return x 31 | 32 | def features(self, i): 33 | x, end_points = self.network(i) 34 | return x, end_points 35 | 36 | -------------------------------------------------------------------------------- /decoder/decoder_simple.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.layers as tcl 6 | 7 | sys.path.append('../') 8 | 9 | 10 | from netutils.weightsinit import get_weightsinit 11 | from netutils.activation import get_activation 12 | from netutils.normalization import get_normalization 13 | 14 | from network.devgg import DEVGG 15 | from network.base_network import BaseNetwork 16 | 17 | 18 | class DecoderSimple(BaseNetwork): 19 | 20 | def __init__(self, config, is_training): 21 | super(DecoderSimple, self).__init__(config, is_training) 22 | self.name = config.get('name', 'DecoderSimple') 23 | 
self.config = config 24 | network_config = config.copy() 25 | self.network = DEVGG(network_config, is_training) 26 | 27 | def __call__(self, x, condition=None): 28 | if condition is not None: 29 | x = tf.concatenate([x, condition], axis=-1) 30 | x, end_points = self.network(x) 31 | 32 | return x 33 | 34 | -------------------------------------------------------------------------------- /test/test_resnet.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | from network.resnet import Resnet 10 | 11 | if __name__ == '__main__': 12 | config = { 13 | 'output_classes' : 10, 14 | 'name' : 'Resnet50', 15 | 'normalization' : 'batch_norm', 16 | 'load pretrained weights' : 'resnet50', 17 | "output dims" : 100, 18 | 'debug' : True, 19 | } 20 | 21 | model = Resnet(config, True) 22 | x = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') 23 | 24 | y, end_points = model(x) 25 | 26 | 27 | # for name, value in end_points.items(): 28 | # print(name, ' --> ', value.get_shape()) 29 | 30 | for var in model.all_vars: 31 | print(var.name, ' --> ', var.get_shape()) 32 | 33 | tfconfig = tf.ConfigProto() 34 | tfconfig.gpu_options.allow_growth = True 35 | 36 | with tf.Session(config=tfconfig) as sess: 37 | model.load_pretrained_weights(sess) 38 | 39 | 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 ZhicongYan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the 
Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cfgs/cvaegan1.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cvaegan1", 3 | 4 | "dataset name" : "production", 5 | "dataset params" : { 6 | }, 7 | 8 | "model" : "cvaegan", 9 | "model params" : { 10 | 11 | "input shape" : [128, 128, 3], 12 | "nb_classes" : 10, 13 | "z_dim" : 256, 14 | "is_training" : true, 15 | 16 | "sample_func" : "normal", 17 | 18 | "encoder" : "EncoderPixel", 19 | "encoder params" : { 20 | "output dims" : 32 21 | }, 22 | 23 | "decoder" : "DecoderPixel", 24 | "decoder params" : { 25 | "output dims" : 3 26 | }, 27 | 28 | "classifier" : "ClassifierPixel", 29 | "classifier params" : { 30 | "output dims" : 2 31 | }, 32 | 33 | "discriminator" : "DiscriminatorVGG16", 34 | "discriminator params" : { 35 | "output dims" : 1, 36 | "including top" : true, 37 | "including_top_params" : [512, 256], 38 | "out_activation" : null, 39 | "out_activation_params" : "" 40 | } 41 | } 42 | 43 | } 44 | 45 | 46 | -------------------------------------------------------------------------------- /test/test_pascal_voc.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | import matplotlib.pyplot as plt 10 | 11 | from dataset.pascal_voc import PASCAL_VOC 12 | 13 | if __name__ == '__main__': 14 | 15 | # config = { 16 | # "output shape" : [64, 64, 3], 17 | # "scaling range" : [0.15, 0.25], 18 | # "crop range" : [0.3, 0.7], 19 | # "task" : "classification", 20 | # "random mirroring" : False 21 | # } 22 | 23 | config = { 24 | "output shape" : [256, 256, 3], 25 | "scaling range" : [0.5, 1.5], 26 | "crop range" : [0.3, 0.7], 27 | "task" : "segmentation_class_aug", 28 | # "random mirroring" : False 29 | } 30 | 31 | 32 | dataset = PASCAL_VOC(config) 33 | indices = dataset.get_image_indices(phase='train', method='supervised') 34 | 35 | print(indices.shape) 36 | 37 | for ind in indices: 38 | 39 | img, mask = dataset.read_image_by_index(ind, phase='train', method='supervised') 40 | 41 | print(img.shape) 42 | print(mask.shape) 43 | print(mask.max(), mask.min()) 44 | 45 | plt.figure(0) 46 | plt.imshow(img) 47 | plt.figure(1) 48 | plt.imshow(mask) 49 | plt.pause(4) 50 | 51 | 52 | -------------------------------------------------------------------------------- /netutils/activation.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import tensorflow as tf 4 | import tensorflow.contrib.layers as tcl 5 | 6 | 7 | 8 | def get_lrelu(params): 9 | if params == None: 10 | leak=0.1 11 | else : 12 | leak=float(params[0]) 13 | 14 | def lrelu(x, leak=leak, name="lrelu"): 15 | with tf.variable_scope(name): 16 | return tf.maximum(x, x*leak) 17 | return lrelu 18 | 19 | 20 | def get_activation(name_config): 21 | name = name_config.split()[0] 22 | 23 | if len(name_config.split()) > 1: 24 | params = name_config.split()[1:] 25 | else: 26 | params = None 27 | 28 | if name == 'relu': 29 | return tf.nn.relu 30 | elif name 
== 'lrelu' or name == 'leaky_relu': 31 | return get_lrelu(params) 32 | elif name == 'softmax' : 33 | return tf.nn.softmax 34 | elif name == 'sigmoid': 35 | return tf.nn.sigmoid 36 | elif name == 'tanh': 37 | return tf.tanh 38 | elif name == 'softplus': 39 | return tf.nn.softplus 40 | elif name == 'elu': 41 | return tf.nn.elu 42 | elif name == 'none' : 43 | return None 44 | else : 45 | raise Exception("None actiavtion named " + name) 46 | 47 | -------------------------------------------------------------------------------- /test/test_vgg.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | sys.path.append('.') 6 | sys.path.append('../') 7 | 8 | import tensorflow as tf 9 | from network.vgg import VGG 10 | 11 | if __name__ == '__main__': 12 | config = { 13 | 14 | "normalization" : "fused_batch_norm", 15 | 16 | "including conv" : True, 17 | "conv nb blocks" : 6, 18 | "conv nb layers" : [2, 2, 3, 3, 3, 0], 19 | "conv nb filters" : [64, 128, 256, 512, 512], 20 | "conv ksize" : [3, 3, 3, 3, 3], 21 | 22 | "including top" : True, 23 | "fc nb nodes" : [1024, 1024], 24 | 25 | "output dims" : 12, 26 | 'name' : 'VGG16', 27 | 28 | 29 | 'load pretrained weights' : 'config tianchi/guangdong3 classifier' 30 | } 31 | 32 | model = VGG(config, True) 33 | x = tf.placeholder(tf.float32, shape=(None, 256, 256, 3), name='input') 34 | y, end_points = model(x) 35 | 36 | 37 | for name, value in end_points.items(): 38 | print(name, ' --> ', value.get_shape()) 39 | 40 | for var in model.vars: 41 | print(var.name, ' --> ', var.get_shape()) 42 | 43 | tfconfig = tf.ConfigProto() 44 | tfconfig.gpu_options.allow_growth = True 45 | 46 | with tf.Session(config=tfconfig) as sess: 47 | ret = model.load_pretrained_weights(sess) 48 | 49 | print(ret) 50 | 51 | 52 | -------------------------------------------------------------------------------- /cfgs/mnist1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "config name" : "inceptionV3_1", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "semi-supervised" : true, 7 | "nb_labelled_image_per_class" : 100, 8 | "input shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist_100", 13 | 14 | "ganmodel" : "classification", 15 | "ganmodel params" : { 16 | "name" : "classify", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb_classes" : 10, 20 | 21 | "optimizer" : "sgd", 22 | "lr" : 0.001, 23 | "lr_scheme" : "exponential", 24 | "lr_params" : { 25 | "decay_steps" : 10000, 26 | "decay_rate" : 0.9 27 | }, 28 | 29 | "classification loss" : "cross entropy", 30 | 31 | "summary" : true, 32 | 33 | "classifier" : "VGG", 34 | "classifier params" : { 35 | "no maxpooling" : true, 36 | "nb_conv_blocks" : 3, 37 | "nb_conv_layers" : [2, 2, 2], 38 | "nb_conv_filters" : [16, 32, 64], 39 | 40 | "including top" : true, 41 | "nb_fc_nodes" : [256, 128], 42 | 43 | "output dims" : 10 44 | } 45 | 46 | }, 47 | 48 | "trainer" : "supervised", 49 | "trainer params" : { 50 | "continue train" : false, 51 | "train steps" : 10000, 52 | "summary steps" : 1000, 53 | "log steps" : 100, 54 | "save checkpoint steps" : 1000 55 | } 56 | } 57 | 58 | -------------------------------------------------------------------------------- /dataset/kitti.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished 
to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | sys.path.append('./') 29 | sys.path.append('../') 30 | sys.path.append('./lib') 31 | 32 | import numpy as np 33 | # import matplotlib.pyplot as plt 34 | 35 | from skimage import io 36 | # import pickle 37 | 38 | from .base_dataset import BaseDataset 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /cfgs/seg/voc.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "pascal_voc_segmentation", 3 | 4 | "dataset" : "pascal_voc", 5 | "dataset params" : { 6 | "output shape" : [320, 320, 3] 7 | }, 8 | 9 | "assets dir" : "assets/pascal_voc/unet", 10 | 11 | "model" : "segmentation", 12 | "model params" : { 13 | "name" : "segmentation", 14 | 15 | "input shape" : [320, 320, 3], 16 | "mask shape" : [320, 320, 21], 17 | "nb classes" : 21, 18 | 19 | "optimizer" : "adam", 20 | "optimizer params" : { 21 | "lr" : 0.0001, 22 | "lr scheme" : "exponential", 23 | "lr params" : { 24 | "decay_steps" : 10000, 25 | "decay_rate" : 0.2 26 | } 27 | }, 28 | 29 | "segmentation loss" : "cross entropy", 30 | 31 | "summary" : true, 32 | 33 | "classifier" : 
"classifier_unet", 34 | "classifier params" : { 35 | "weightinit" : "normal 0.00 0.1", 36 | "debug" : true 37 | } 38 | }, 39 | 40 | "trainer" : "supervised", 41 | "trainer params" : { 42 | 43 | "summary hyperparams string" : "learning_rate_0_0001", 44 | 45 | "continue train" : false, 46 | "multi thread" : true, 47 | "batch_size" : 4, 48 | "train steps" : 30000, 49 | "summary steps" : 500, 50 | "log steps" : 100, 51 | "save checkpoint steps" : 10000, 52 | 53 | "validators" : [ 54 | { 55 | "validator" : "validate_segmentation", 56 | "validate steps" : 100, 57 | "validator params" : { 58 | } 59 | } 60 | ] 61 | } 62 | } 63 | 64 | -------------------------------------------------------------------------------- /validator/pca_plot_validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | import os 26 | import numpy as np 27 | import matplotlib.pyplot as plt 28 | from scipy.stats import norm 29 | 30 | 31 | from .base_validator import BaseValidator 32 | 33 | class PCAPlotValidator(object): 34 | 35 | def __init__(self, config): 36 | super(PCAPlotValidator, self).__init__(config) 37 | self.assets_dir = config['assets dir'] 38 | 39 | def validate(self, model, dataset, sess, step): 40 | return NotImplementedError 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /cfgs/seg/voc2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "pascal_voc_segmentation", 3 | 4 | "dataset" : "pascal_voc", 5 | "dataset params" : { 6 | "output shape" : [160, 160, 3], 7 | "scaling range" : [0.25, 0.75] 8 | }, 9 | 10 | "assets dir" : "assets/pascal_voc/unet2", 11 | 12 | "model" : "segmentation", 13 | "model params" : { 14 | "name" : "segmentation", 15 | 16 | "input shape" : [160, 160, 3], 17 | "mask shape" : [160, 160, 21], 18 | "nb classes" : 21, 19 | 20 | "optimizer" : "adam", 21 | "optimizer params" : { 22 | "lr" : 0.0001 23 | // "lr scheme" : "exponential", 24 | // "lr params" : { 25 | // "decay_steps" : 30000, 26 | // "decay_rate" : 0.2 27 | // } 28 | }, 29 | 30 | "segmentation loss" : "cross entropy", 31 | 32 | "summary" : true, 33 | 34 | "classifier" : "classifier_unet", 35 | "classifier params" : { 36 | "normalization" : "none", 37 | "weightsinit" : "xavier", 38 | "conv nb blocks" : 3, 39 | "debug" : true 40 | } 41 | }, 42 | 43 | "trainer" : "supervised", 44 | "trainer params" : { 
45 | 46 | "summary hyperparams string" : "learning_rate_0_0001", 47 | 48 | "continue train" : true, 49 | "multi thread" : true, 50 | "batch_size" : 1, 51 | "train steps" : 100000, 52 | "summary steps" : 500, 53 | "log steps" : 100, 54 | "save checkpoint steps" : 10000, 55 | 56 | "validators" : [ 57 | { 58 | "validator" : "validate_segmentation", 59 | "validate steps" : 500, 60 | "validator params" : { 61 | "log dir" : "val_seg_learning_rate_0_0001" 62 | } 63 | } 64 | ] 65 | } 66 | } 67 | 68 | -------------------------------------------------------------------------------- /encoder/encoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | 29 | 30 | 31 | def get_encoder(name, config, is_training): 32 | if name == 'encoder' or name == 'EncoderSimple': 33 | from .encoder_simple import EncoderSimple 34 | return EncoderSimple(config, is_training) 35 | elif name == 'mnist encoder' or name == 'encoder_mnist': 36 | from .encoder_mnist import EncoderMnist 37 | return EncoderMnist(config, is_training) 38 | elif name == 'cifar10 encoder' or name == 'encoder_cifar10': 39 | from .encoder_cifar10 import EncoderCifar10 40 | return EncoderCifar10(config, is_training) 41 | else: 42 | raise Exception("None encoder named " + name) 43 | -------------------------------------------------------------------------------- /test/test_violence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('../') 6 | 7 | import numpy as np 8 | import cv2 9 | import tensorflow as tf 10 | import matplotlib.pyplot as plt 11 | 12 | from dataset.violence import Violence 13 | from dataset.tianchi_guangdong_defect import TianChiGuangdongDefect 14 | 15 | 16 | if __name__ == '__main__': 17 | config = { 18 | "output shape" : [224, 224, 3], 19 | "mil" : False, 20 | "use cache" : True, 21 | "one hot" : True, 22 | "show warning" : True 23 | } 24 | 25 | dataset = TianChiGuangdongDefect(config) 26 | indices = dataset.get_image_indices('trainval') 27 | 28 | print(len(indices)) 29 | 30 | img_list = [] 31 | 32 | for ind in indices: 33 | img, label = dataset.read_image_by_index(ind, phase='trainval', method='supervised') 34 | # print(label) 35 | 36 | dataset.time1 = 0.0 37 | dataset.count = 0 38 | 39 | print("") 40 | print("") 41 | print("round 2") 42 | print("") 43 | print("") 44 | for ind in indices: 45 | img, label = dataset.read_image_by_index(ind, phase='trainval', method='supervised') 46 | # print(label) 47 | # 
if img is not None: 48 | # plt.figure(0) 49 | # plt.clf() 50 | # plt.imshow(img) 51 | # plt.pause(1) 52 | 53 | config = { 54 | "output shape" : [224, 224, 3], 55 | } 56 | 57 | 58 | dataset = TianChiGuangdongDefect(config) 59 | indices = dataset.get_image_indices('trainval') 60 | 61 | 62 | # for ind in indices: 63 | # img_bag, label = dataset.read_image_by_index(ind, phase='trainval', method='supervised') 64 | # print(label) 65 | # if img_bag is not None: 66 | 67 | # plt.figure(0) 68 | # plt.clf() 69 | 70 | # row = 4 71 | # col = int(len(img_bag) / row) 72 | 73 | # print(len(img_bag), row, col) 74 | 75 | # for i in range(row): 76 | # for j in range(col): 77 | # plt.subplot(row, col, i * col+j+1) 78 | # plt.imshow(img_bag[i*col+j]) 79 | # plt.pause(3) 80 | -------------------------------------------------------------------------------- /generator/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | 29 | 30 | def get_generator(name, config, is_training): 31 | if name == 'generator': 32 | from .generator_simple import GeneratorSimple 33 | return GeneratorSimple(config, is_training) 34 | 35 | elif name == 'cifar10 resnet generator' or name == 'generator_cifar10_resnet': 36 | from .generator_cifar10_resnet import GeneratorCifar10ResNet 37 | return GeneratorCifar10ResNet(config, is_training) 38 | 39 | elif name == 'generator_conv': 40 | from .generator_conv import G_conv 41 | return G_conv(config, is_training) 42 | 43 | else: 44 | raise Exception("None Generator named " + name) 45 | 46 | -------------------------------------------------------------------------------- /generator/generator_simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | sys.path.append('../') 28 | 29 | from network.devgg import DEVGG 30 | from network.base_network import BaseNetwork 31 | 32 | class GeneratorSimple(BaseNetwork): 33 | 34 | def __init__(self, config, is_training): 35 | BaseNetwork.__init__(self, config, is_training) 36 | 37 | self.name = config.get('name', 'GeneratorSimple') 38 | self.config = config 39 | 40 | network_config = config.copy() 41 | network_config['name'] = self.name 42 | self.network = DEVGG(network_config, is_training) 43 | self.reuse=False 44 | 45 | def __call__(self, i): 46 | x, end_points = self.network(i) 47 | return x 48 | 49 | -------------------------------------------------------------------------------- /cfgs/networkconfig.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 
12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | import json 29 | import yaml 30 | import re 31 | 32 | def get_config(name, disp=False): 33 | with open(os.path.join('cfgs', name + '.json'), 'r') as config_file: 34 | 35 | # remove the comment in json file 36 | jsonstring = [] 37 | for line in config_file: 38 | pure_line = re.sub('([^:]//.*"?$)|(/\*(.*?)\*/)','',line) 39 | jsonstring.append(pure_line) 40 | 41 | if disp: 42 | for i, line in enumerate(jsonstring): 43 | print('%d: %s'%(i, line[:-1])) 44 | 45 | jsonstring = ''.join(jsonstring) 46 | config_json = json.loads(jsonstring) 47 | return config_json 48 | 49 | 50 | def print_config(name): 51 | pass 52 | 53 | 54 | -------------------------------------------------------------------------------- /cfgs/cla/imagenet1.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "config name" : "imagenet_resnet50", 4 | 5 | "dataset" : "imagenet", 6 | "dataset params" : { 7 | "output shape" : [224, 224, 3] 8 | }, 9 | 10 | "assets dir" : "assets/imagenet/resnet50", 11 | 12 | "model" : "classification", 13 | "model params" : { 14 | "name" : "imagenet", 15 | 16 | "input shape" : [224, 224, 3], 17 | "nb classes" : 12, 18 | 19 | "optimizer" : "adam", 20 
| "optimizer params" : { 21 | "lr" : 0.0001, 22 | "lr scheme" : "exponential", 23 | "lr params" : { 24 | "decay_steps" : 70000, 25 | "decay_rate" : 0.1 26 | } 27 | }, 28 | 29 | "classification loss" : "cross entropy", 30 | 31 | "summary" : true, 32 | 33 | "classifier" : "classifier", 34 | "classifier params" : { 35 | "base network" : "resnet", 36 | 37 | "architecture" : "resnet50", 38 | 39 | "load pretrained weight" : "resnet50", 40 | "debug" : true 41 | } 42 | 43 | }, 44 | 45 | "trainer" : "supervised", 46 | "trainer params" : { 47 | "summary hyperparams string" : "lr_0_0001_adam", 48 | 49 | "multi thread" : true, 50 | "buffer depth" : 100, 51 | "nb threads" : 16, 52 | "batch_size" : 16, 53 | 54 | "train steps" : 150000, 55 | "summary steps" : 1000, 56 | "log steps" : 500, 57 | "save checkpoint steps" : 10000, 58 | 59 | "validators" : [ 60 | { 61 | "validator" : "dataset_validator", 62 | "validate steps" : 2000, 63 | "validator params" : { 64 | "metric" : "accuracy", 65 | "metric type" : "top1", 66 | "nb samples" : 200, 67 | "batch_size" : 8, 68 | "mil" : false 69 | } 70 | } 71 | ] 72 | } 73 | } 74 | 75 | 76 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial 
portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | # 27 | 28 | def get_trainer(name, config, model, sess): 29 | 30 | if name == 'supervised': 31 | from .supervised import SupervisedTrainer 32 | return SupervisedTrainer(config, model, sess) 33 | 34 | elif name == 'unsupervised' : 35 | from .unsupervised import UnsupervisedTrainer 36 | return UnsupervisedTrainer(config, model, sess) 37 | 38 | elif name == 'semisupervised' or name == 'semi-supervised': 39 | from .semisupervised import SemiSupervisedTrainer 40 | return SemiSupervisedTrainer(config, model, sess) 41 | 42 | elif name == 'supervised mil' or name == 'supervised_mil': 43 | from .supervised_mil import SupervisedMILTrainer 44 | return SupervisedMILTrainer(config, model, sess) 45 | 46 | else: 47 | raise Exception('None trainer named ' + name) 48 | 49 | 50 | -------------------------------------------------------------------------------- /cfgs/cla/violence2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "violence_vgg", 3 | 4 | "dataset" : "violence", 5 | "dataset params" : { 6 | "output shape" : [224, 224, 3] 7 | }, 8 | 9 | "assets dir" : "assets/violence/resnet", 10 | 11 | "model" : "classification", 12 | "model params" : { 13 | "name" : "vgg", 14 | 15 | "input shape" : [224, 224, 3], 16 | "nb classes" : 2, 17 | 18 | "optimizer" : "adam", 19 | "optimizer params" : 
{ 20 | "lr" : 0.001, 21 | "lr_scheme" : "exponential", 22 | "lr_params" : { 23 | "decay_steps" : 300000, 24 | "decay_rate" : 0.2 25 | } 26 | }, 27 | 28 | "classification loss" : "cross entropy", 29 | 30 | "summary" : true, 31 | 32 | "classifier" : "classifier", 33 | "classifier params" : { 34 | "base network" : "resnet", 35 | 36 | "architecture" : "resnet50", 37 | 38 | "activation" : "relu", 39 | "normalization" : "fused_batch_norm", 40 | 41 | "including top" : true, 42 | "fc nb nodes" : [2048, 2048], 43 | 44 | "output dims" : 2, 45 | "output_activation" : "none", 46 | 47 | "debug" : true 48 | } 49 | }, 50 | 51 | "trainer" : "supervised", 52 | "trainer params" : { 53 | 54 | "summary hyperparams string" : "learning_rate_0_001_adam", 55 | 56 | "continue train" : false, 57 | "multi thread" : true, 58 | "buffer depth" : 100, 59 | 60 | "batch_size" : 8, 61 | 62 | "train steps" : 1000000, 63 | "summary steps" : 30000, 64 | "log steps" : 100, 65 | "save checkpoint steps" : 30000, 66 | 67 | "validators" : [ 68 | { 69 | "validator" : "dataset_validator", 70 | "validate steps" : 2000, 71 | "has summary" : true, 72 | "validator params" : { 73 | "metric" : "accuracy", 74 | "metric type" : "top1", 75 | "batch_size" : 8 76 | } 77 | } 78 | ] 79 | } 80 | } 81 | 82 | -------------------------------------------------------------------------------- /cfgs/tianchi/guangdong2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "guangdong_defect_classification", 3 | 4 | "dataset" : "tianchi guangdong defect", 5 | "dataset params" : { 6 | "output shape" : [256, 256, 3], 7 | "one hot" : true, 8 | "use cache" : true, 9 | "mil" : false, 10 | "show warning" : false 11 | }, 12 | 13 | "assets dir" : "assets/tianchi_guangdong/result6", 14 | 15 | "model" : "classification", 16 | "model params" : { 17 | "name" : "tianchi", 18 | 19 | "input shape" : [256, 256, 3], 20 | "nb classes" : 12, 21 | 22 | "optimizer" : "adam", 23 | "optimizer 
params" : { 24 | "lr" : 0.0001, 25 | "lr scheme" : "exponential", 26 | "lr params" : { 27 | "decay_steps" : 30000, 28 | "decay_rate" : 0.2 29 | } 30 | }, 31 | 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | 39 | "base network" : "resnet", 40 | "normalization" : "batch_norm", 41 | 42 | "load pretrained weights" : "resnet50", 43 | "debug" : true 44 | } 45 | }, 46 | 47 | "trainer" : "supervised", 48 | "trainer params" : { 49 | "summary hyperparams string" : "lr_0_0001_adam_resnet", 50 | 51 | "continue train" : false, 52 | 53 | "multi thread" : true, 54 | "buffer depth" : 100, 55 | "nb threads" : 16, 56 | "batch_size" : 16, 57 | 58 | "train steps" : 50000, 59 | "summary steps" : 1000, 60 | "log steps" : 500, 61 | "save checkpoint steps" : 10000, 62 | 63 | "validators" : [ 64 | { 65 | "validator" : "dataset_validator", 66 | "validate steps" : 2000, 67 | "validator params" : { 68 | "metric" : "accuracy", 69 | "metric type" : "top1", 70 | "nb samples" : 200, 71 | "batch_size" : 8, 72 | "mil" : false 73 | } 74 | } 75 | ] 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VAE-GAN 2 | 3 | This repo implements many recently emerged generative models, such as GAN, VAE, DCGAN, WGAN, WGAN-GP, and some semi-supervised model such as AAE, SemiDGM, this code is just for learning the generative models and for fast developing algorithms. 4 | 5 | There is some problem with my code, for WGAN, I found it very likely to cause model collpase(after 30k iters on Cifar10 dataset) and the generated sample quality goes worse. For semi-supervised model AAE, I achieved 96% accuracy on MNIST dataset with 10 labels per class which is below the paper claimed accuracy 98.1%. 
I will keep refining this repo to support more generative and semi-supervised algorithms. 6 | 7 | Some code is outdated and may cause bug in running, please email to me: yznzhicong1069163331@outlook.com 8 | 9 | ***** 10 | 11 | this code is running with python3 and tensorflow1.9.0 on both Windows and Ubuntu 12 | 13 | ***** 14 | 15 | # Currently Implemented Models 16 | 17 | ## GAN 18 | 19 | [DCGAN](http://arxiv.org/abs/1511.06434) 20 | 21 | [Improved-GAN](http://arxiv.org/abs/1606.03498) 22 | 23 | [WGAN-GP](http://arxiv.org/abs/1704.00028) 24 | 25 | ## VAE 26 | 27 | [VAE]() 28 | 29 | [AAE](http://arxiv.org/abs/1511.05644) 30 | 31 | ***** 32 | 33 | # Training 34 | 35 | before training, you must specify the dataset location. the currently supported dataset is 36 | 1. MNIST 37 | 2. Cifar10 38 | 3. Imagenet 39 | 4. PASCAL_VOC 40 | 5. 41 | 42 | open the py file under dataset folder, for each dataset I write some if control flow to find the dataset location. please add the dataset location to the control flow. 43 | 44 | the config files is under the cfgs folder, I write config file in json format, the dataset, model and train method are all specified in this file. you can train it with train.py. If you want to run the models, just run the following command: 45 | 46 | python(3) train.py --config= --gpu= 47 | 48 | the "assets dir" in config file is the folder where the result stores. 
the tensorboard log file is under the 'log' folder 49 | 50 | # Result 51 | 52 | ## AAE Semi-supervised Classification with 100 labels 53 | 54 | python3 train.py --config=aae/mnist_ssl 55 | 56 | ![aae_mnist_ssl](./cfgs/aae/aae_mnist_ssl_acc.png) 57 | 58 | 59 | -------------------------------------------------------------------------------- /cfgs/cvae1.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cvae_1", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | }, 7 | 8 | "assets dir" : "assets/cvae_1_2", 9 | 10 | "model" : "cvae", 11 | "model params" : { 12 | "name" : "cvae", 13 | 14 | "input shape" : [28, 28, 1], 15 | "flatten" : true, 16 | "nb_classes" : 10, 17 | 18 | "z_dim" : 2, 19 | "is_training" : true, 20 | "sample_func" : "normal", 21 | 22 | "optimizer" : "rmsprop", 23 | "lr" : 0.001, 24 | "lr_scheme" : "exponential", 25 | "lr_params" : { 26 | "decay_steps" : 1000, 27 | "decay_rate" : 0.9 28 | }, 29 | 30 | "kl loss" : "gaussian", 31 | "kl loss prod" : 0.00001, 32 | "reconstruction loss" : "mse", 33 | "reconstruction loss prod" : 1, 34 | 35 | "summary" : false, 36 | 37 | "x encoder" : "EncoderSimple", 38 | "x encoder params" : { 39 | "nb_conv_blocks" : 0, 40 | "batch_norm" : "none", 41 | "nb_fc_nodes" : [256], 42 | "output_distribution": "gaussian" 43 | }, 44 | 45 | "y encoder" : "EncoderSimple", 46 | "y encoder params" : { 47 | "nb_conv_blocks" : 0, 48 | "batch_norm" : "none", 49 | "nb_fc_nodes" : [], 50 | "output_distribution": "mean" 51 | }, 52 | 53 | "decoder" : "DecoderSimple", 54 | "decoder params" : { 55 | } 56 | }, 57 | 58 | "trainer" : "supervised", 59 | "trainer params" : { 60 | "continue train" : false, 61 | "train steps" : 20000, 62 | "summary steps" : 1000, 63 | "log steps" : 100, 64 | "save checkpoint steps" : 1000, 65 | 66 | "validators" : [ 67 | { 68 | "validator" : "scatter_plot_validator", 69 | "validate steps" : 1000, 70 | "validator params" : { 71 | "watch variable" : 
"hidden dist", 72 | "x dim" : 0, 73 | "y dim" : 1, 74 | "log dir" : "scatter" 75 | } 76 | } 77 | ] 78 | } 79 | } 80 | 81 | 82 | -------------------------------------------------------------------------------- /cfgs/seg/voc3.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "pascal_voc_segmentation", 3 | 4 | "dataset" : "pascal_voc", 5 | "dataset params" : { 6 | "output shape" : [80, 80, 3], 7 | "scaling range" : [0.125, 0.50] 8 | }, 9 | 10 | "assets dir" : "assets/pascal_voc/unet3", 11 | 12 | "model" : "segmentation", 13 | "model params" : { 14 | "name" : "segmentation", 15 | 16 | "input shape" : [80, 80, 3], 17 | "mask shape" : [80, 80, 21], 18 | "nb classes" : 21, 19 | 20 | "optimizer" : "adam", 21 | "optimizer params" : { 22 | "lr" : 0.0001 23 | // "lr scheme" : "exponential", 24 | // "lr params" : { 25 | // "decay_steps" : 30000, 26 | // "decay_rate" : 0.2 27 | // } 28 | }, 29 | 30 | "segmentation loss" : "cross entropy", 31 | 32 | "summary" : true, 33 | 34 | "classifier" : "classifier", 35 | "classifier params" : { 36 | "normalization" : "batch_norm", 37 | "weightsinit" : "xavier", 38 | "activation" : "relu", 39 | "padding" : "SAME", 40 | 41 | "including conv" : true, 42 | "conv nb blocks" : 1, 43 | "conv nb layers" : [7], 44 | "conv nb filters" : [128], 45 | "conv ksize" : [7], 46 | 47 | "including top" : false, 48 | 49 | "output dims" : 21, 50 | 51 | "debug" : true 52 | } 53 | }, 54 | 55 | "trainer" : "supervised", 56 | "trainer params" : { 57 | 58 | "summary hyperparams string" : "learning_rate_0_0001", 59 | 60 | "continue train" : false, 61 | "multi thread" : true, 62 | "batch_size" : 1, 63 | "train steps" : 30000, 64 | "summary steps" : 500, 65 | "log steps" : 100, 66 | "save checkpoint steps" : 10000, 67 | 68 | "validators" : [ 69 | { 70 | "validator" : "validate_segmentation", 71 | "validate steps" : 500, 72 | "validator params" : { 73 | "log dir" : "val_seg_learning_rate_0_0001" 74 | } 75 | 
} 76 | ] 77 | } 78 | } 79 | 80 | -------------------------------------------------------------------------------- /discriminator/discriminator_simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | 33 | sys.path.append('../') 34 | 35 | from netutils.weightsinit import get_weightsinit 36 | from netutils.activation import get_activation 37 | from netutils.normalization import get_normalization 38 | 39 | 40 | from network.vgg import VGG 41 | from network.base_network import BaseNetwork 42 | 43 | 44 | class DiscriminatorSimple(BaseNetwork): 45 | def __init__(self, config, is_training): 46 | BaseNetwork.__init__(self, config, is_training) 47 | self.name = config.get('name', 'DiscriminatorSimple') 48 | self.config = config 49 | 50 | network_config = config.copy() 51 | self.network = VGG(network_config, is_training) 52 | 53 | def __call__(self, i): 54 | x, end_points = self.network(i) 55 | return x 56 | 57 | def features(self, i, condition=None): 58 | x, end_points = self.network(i) 59 | return x, end_points 60 | 61 | -------------------------------------------------------------------------------- /cfgs/vae1.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "vae_1", 3 | 4 | 5 | "dataset" : "mnist", 6 | "dataset params" : { 7 | }, 8 | 9 | "assets dir" : "assets/vae_1", 10 | 11 | "model" : "vae", 12 | "model params" : { 13 | "name" : "vae", 14 | 15 | "input shape" : [28, 28, 1], 16 | "flatten" : true, 17 | 18 | "z_dim" : 2, 19 | "is_training" : true, 20 | "sample_func" : "normal", 21 | 22 | "optimizer" : "rmsprop", 23 | "lr" : 0.001, 24 | "lr_scheme" : "exponential", 25 | "lr_params" : { 26 | "decay_steps" : 1000, 27 | "decay_rate" : 0.9 28 | }, 29 | 30 | "kl loss" : "gaussian", 31 | "kl loss prod" : 0.01, 32 | "reconstruction loss" : "mse", 33 | "reconstruction loss prod" : 1, 34 | 35 | "summary" : true, 36 | "summary dir" : "log", 37 | 38 | "encoder" : "EncoderSimple", 39 | "encoder params" : 
{ 40 | "nb_conv_blocks" : 0, 41 | "batch_norm" : "none", 42 | "nb_fc_nodes" : [256] 43 | }, 44 | 45 | "decoder" : "DecoderSimple", 46 | "decoder params" : { 47 | "nb_conv_blocks" : 0, 48 | "batch_norm" : "none", 49 | "nb_fc_nodes" : [256] 50 | } 51 | }, 52 | 53 | "trainer" : "unsupervised", 54 | "trainer params" : { 55 | "continue train" : false, 56 | "train steps" : 20000, 57 | "summary steps" : 1000, 58 | "log steps" : 100, 59 | "save checkpoint steps" : 1000, 60 | 61 | "validators" : [ 62 | { 63 | "validator" : "hidden_variable_validator", 64 | "validate steps" : 1000, 65 | "validator params" : { 66 | "z_dim" : 2, 67 | "num_samples" : 15 68 | } 69 | }, 70 | { 71 | "validator" : "scatter_plot_validator", 72 | "validate steps" : 1000, 73 | "validator params" : { 74 | "watch variable" : "hidden dist", 75 | "x dim" : 0, 76 | "y dim" : 1 77 | } 78 | } 79 | ] 80 | } 81 | } 82 | 83 | 84 | -------------------------------------------------------------------------------- /cfgs/cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cifar10_classify", 3 | 4 | "dataset" : "cifar10", 5 | "dataset params" : { 6 | "input shape" : [32, 32, 3], 7 | "batch_size" : 128 8 | }, 9 | 10 | "assets dir" : "assets/cifar10/cifar10_2", 11 | 12 | "ganmodel" : "classification", 13 | "ganmodel params" : { 14 | "name" : "classification", 15 | 16 | "input shape" : [32, 32, 3], 17 | "nb_classes" : 10, 18 | 19 | "optimizer" : "sgd", 20 | "optimizer params" : { 21 | "lr" : 0.0002, 22 | "lr scheme" : "exponential", 23 | "lr params" : { 24 | "decay_steps" : 10000, 25 | "decay_rate" : 0.1 26 | } 27 | }, 28 | 29 | "classification loss" : "cross entropy", 30 | 31 | "summary" : true, 32 | 33 | "classifier" : "classifier", 34 | "classifier params" : { 35 | "normalization" : "batch_norm", 36 | 37 | "including conv" : true, 38 | "nb_conv_blocks" : 4, 39 | "nb_conv_layers" : [2, 2, 3, 3], 40 | "nb_conv_filters" : [32, 64, 128, 256], 41 | 
"nb_conv_ksize" : [3, 3, 3, 3], 42 | "no maxpooling" : false, 43 | 44 | "including top" : true, 45 | "nb_fc_nodes" : [1024, 1024], 46 | 47 | "output dims" : 10, 48 | "output_activation" : "none", 49 | 50 | 51 | // "conv1_0 activation" : "lrelu 0.2", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "learning_rate_0_0002", 61 | 62 | "continue train" : false, 63 | "multi thread" : true, 64 | "batch_size" : 8, 65 | "train steps" : 30000, 66 | "summary steps" : 1000, 67 | "log steps" : 100, 68 | "save checkpoint steps" : 10000, 69 | 70 | "validators" : [ 71 | { 72 | "validator" : "dataset_validator", 73 | "validate steps" : 500, 74 | "has summary" : true, 75 | "validator params" : { 76 | "metric" : "accuracy", 77 | "metric type" : "top1" 78 | } 79 | } 80 | ] 81 | } 82 | } 83 | 84 | -------------------------------------------------------------------------------- /generator/generator_cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | sys.path.append('../') 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | from netutils.weightsinit import get_weightsinit 33 | from netutils.activation import get_activation 34 | from netutils.normalization import get_normalization 35 | 36 | 37 | # from network.vgg import VGG 38 | from network.devgg import DEVGG 39 | from network.base_network import BaseNetwork 40 | 41 | class GeneratorSimple(BaseNetwork): # Generator wrapper around a DEVGG network. NOTE(review): this file is generator_cifar10_resnet.py, yet the class is named GeneratorSimple and builds DEVGG rather than a ResNet; confirm which name/backbone the generator factory expects. 42 | 43 | def __init__(self, config, is_training): 44 | BaseNetwork.__init__(self, config, is_training) 45 | 46 | self.name = config.get('name', 'GeneratorSimple') 47 | self.config = config 48 | 49 | network_config = config.copy() 50 | network_config['name'] = self.name # Pass this wrapper's name to the DEVGG backbone through the copied config. 51 | self.network = DEVGG(network_config, is_training) 52 | 53 | self.reuse=False # NOTE(review): not read by any method of this class. 54 | 55 | def __call__(self, i): # Returns only the network output; end_points is discarded. 56 | x, end_points = self.network(i) 57 | return x 58 | 59 | # @property 60 | # def vars(self): 61 | # return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 62 | 63 | 64 | -------------------------------------------------------------------------------- /cfgs/cla/violence.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "violence_vgg", 3 | 4 | "dataset" : "violence", 5 | "dataset params" : { 6 | "output shape" : [224, 224, 3] 7 |
}, 8 | 9 | "assets dir" : "assets/violence/vgg16", 10 | 11 | "model" : "classification", 12 | "model params" : { 13 | "name" : "vgg", 14 | 15 | "input shape" : [224, 224, 3], 16 | "nb classes" : 2, 17 | 18 | "optimizer" : "adam", 19 | "optimizer params" : { 20 | "lr" : 0.001, 21 | "lr_scheme" : "exponential", 22 | "lr_params" : { 23 | "decay_steps" : 300000, 24 | "decay_rate" : 0.2 25 | } 26 | }, 27 | 28 | "classification loss" : "cross entropy", 29 | 30 | "summary" : true, 31 | 32 | "classifier" : "classifier", 33 | "classifier params" : { 34 | "activation" : "relu", 35 | "normalization" : "fused_batch_norm", 36 | 37 | "including conv" : true, 38 | "conv nb blocks" : 5, 39 | "conv nb layers" : [2, 2, 3, 3, 3], 40 | "conv nb filters" : [64, 128, 256, 512, 512], 41 | "conv ksize" : [3, 3, 3, 3, 3], 42 | "no maxpooling" : true, 43 | 44 | "including top" : true, 45 | "fc nb nodes" : [2048, 2048], 46 | 47 | "output dims" : 2, 48 | "output_activation" : "none", 49 | 50 | "debug" : true 51 | } 52 | }, 53 | 54 | "trainer" : "supervised", 55 | "trainer params" : { 56 | 57 | "summary hyperparams string" : "learning_rate_0_001_adam2", 58 | 59 | "continue train" : false, 60 | "multi thread" : true, 61 | "buffer depth" : 100, 62 | 63 | "batch_size" : 8, 64 | 65 | "train steps" : 1000000, 66 | "summary steps" : 30000, 67 | "log steps" : 100, 68 | "save checkpoint steps" : 30000, 69 | 70 | "validators" : [ 71 | { 72 | "validator" : "dataset_validator", 73 | "validate steps" : 2000, 74 | "has summary" : true, 75 | "validator params" : { 76 | "metric" : "accuracy", 77 | "metric type" : "top1", 78 | "batch_size" : 8 79 | } 80 | } 81 | ] 82 | } 83 | } 84 | 85 | -------------------------------------------------------------------------------- /cfgs/cvae3.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cvae_3", 3 | 4 | "dataset" : "cifar10", 5 | "dataset params" : { 6 | }, 7 | 8 | "assets dir" : "assets/cvae_3_2", 9 | 10 
| "model" : "cvae", 11 | "model params" : { 12 | "name" : "cvae", 13 | 14 | "input shape" : [32, 32, 3], 15 | "flatten" : true, 16 | "nb_classes" : 10, 17 | 18 | "z_dim" : 2, 19 | "is_training" : true, 20 | "sample_func" : "normal", 21 | 22 | "optimizer" : "adam", 23 | "lr" : 0.001, 24 | "lr_scheme" : "exponential", 25 | "lr_params" : { 26 | "decay_steps" : 1000, 27 | "decay_rate" : 0.9 28 | }, 29 | 30 | "kl loss" : "gaussian", 31 | "kl loss prod" : 0.00001, 32 | "reconstruction loss" : "mse", 33 | "reconstruction loss prod" : 1, 34 | 35 | "summary" : true, 36 | "summary dir" : "log", 37 | 38 | "x encoder" : "EncoderSimple", 39 | "x encoder params" : { 40 | "nb_conv_blocks" : 0, 41 | "batch_norm" : "none", 42 | "nb_fc_nodes" : [256], 43 | "output_distribution": "gaussian", 44 | "_comment" : "no convolution layers and batch normalization, just a single hidden layer with 256 nodes" 45 | }, 46 | 47 | "y encoder" : "EncoderSimple", 48 | "y encoder params" : { 49 | "nb_conv_blocks" : 0, 50 | "batch_norm" : "none", 51 | "nb_fc_nodes" : [256], 52 | "output_distribution": "mean" 53 | }, 54 | 55 | "decoder" : "DecoderSimple", 56 | "decoder params" : { 57 | } 58 | }, 59 | 60 | "trainer" : "supervised", 61 | "trainer params" : { 62 | "continue train" : false, 63 | "train steps" : 20000, 64 | "summary steps" : 1000, 65 | "log steps" : 100, 66 | "save checkpoint steps" : 1000, 67 | 68 | "validators" : [ 69 | { 70 | "validator" : "scatter_plot_validator", 71 | "validate steps" : 1000, 72 | "validator params" : { 73 | "watch variable" : "hidden dist", 74 | "x dim" : 0, 75 | "y dim" : 1, 76 | "log dir" : "scatter1" 77 | } 78 | } 79 | ] 80 | } 81 | } 82 | 83 | 84 | -------------------------------------------------------------------------------- /cfgs/cla/mnist1.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "semi-supervised" : true, 7 | 
"nb_labelled_images_per_class" : 100, 8 | "output shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist/result1", 13 | 14 | "model" : "classification", 15 | "model params" : { 16 | "name" : "mnist", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb classes" : 10, 20 | 21 | "optimizer" : "adam", 22 | "optimizer params" : { 23 | "lr" : 0.001, 24 | "lr scheme" : "exponential", 25 | "lr params" : { 26 | "decay_steps" : 10000, 27 | "decay_rate" : 0.1 28 | }, 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "batch_norm" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 3, 42 | "conv nb layers" : [2, 2, 2], 43 | "conv nb filters" : [32, 64, 128], 44 | "conv ksize" : [3, 3, 3], 45 | "no maxpooling" : true, 46 | 47 | "including top" : true, 48 | "fc nb nodes" : [600, 600], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "learning_rate_0_001_adam", 61 | 62 | "continue train" : true, 63 | "multi thread" : true, 64 | "batch_size" : 32, 65 | "train steps" : 20000, 66 | "summary steps" : 1000, 67 | "log steps" : 100, 68 | "save checkpoint steps" : 10000, 69 | 70 | "validators" : [ 71 | { 72 | "validator" : "dataset_validator", 73 | "validate steps" : 500, 74 | "has summary" : true, 75 | "validator params" : { 76 | "metric" : "accuracy", 77 | "metric type" : "top1" 78 | } 79 | } 80 | ] 81 | } 82 | } 83 | 84 | -------------------------------------------------------------------------------- /cfgs/cla/cifar10.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cifar10_classify", 3 | 4 | "dataset" : "cifar10", 5 | "dataset params" : { 6 | 7 | "output shape" 
: [32, 32, 3], 8 | "batch_size" : 128 9 | }, 10 | 11 | "assets dir" : "assets/cifar10/cifar10_2", 12 | 13 | "model" : "classification", 14 | "model params" : { 15 | "name" : "classification", 16 | 17 | "input shape" : [32, 32, 3], 18 | "nb classes" : 10, 19 | 20 | "optimizer" : "sgd", 21 | "optimizer params" : { 22 | "lr" : 0.0002, 23 | "lr scheme" : "exponential", 24 | "lr params" : { 25 | "decay_steps" : 10000, 26 | "decay_rate" : 0.1 27 | } 28 | }, 29 | 30 | "classification loss" : "cross entropy", 31 | 32 | "summary" : true, 33 | 34 | "classifier" : "classifier", 35 | "classifier params" : { 36 | "activation" : "relu", 37 | "normalization" : "batch_norm", 38 | 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 4, 42 | "conv nb layers" : [2, 2, 3, 3], 43 | "conv nb filters" : [32, 64, 128, 256], 44 | "conv ksize" : [3, 3, 3, 3], 45 | "no maxpooling" : false, 46 | 47 | "including top" : true, 48 | "nb_fc_nodes" : [1024, 1024], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | 54 | // "conv1_0 activation" : "lrelu 0.2", 55 | 56 | "debug" : true 57 | } 58 | }, 59 | 60 | "trainer" : "supervised", 61 | "trainer params" : { 62 | 63 | "summary hyperparams string" : "learning_rate_0_0002", 64 | 65 | "continue train" : false, 66 | "multi thread" : true, 67 | "batch_size" : 8, 68 | "train steps" : 30000, 69 | "summary steps" : 1000, 70 | "log steps" : 100, 71 | "save checkpoint steps" : 10000, 72 | 73 | "validators" : [ 74 | { 75 | "validator" : "dataset_validator", 76 | "validate steps" : 500, 77 | "has summary" : true, 78 | "validator params" : { 79 | "metric" : "accuracy", 80 | "metric type" : "top1" 81 | } 82 | } 83 | ] 84 | } 85 | } 86 | 87 | -------------------------------------------------------------------------------- /cfgs/vae/mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "vae", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | }, 7 | 8 | "assets 
dir" : "assets/vae/mnist", 9 | 10 | "model" : "vae", 11 | "model params" : { 12 | "name" : "vae", 13 | 14 | "input shape" : [28, 28, 1], 15 | 16 | "z_dim" : 2, 17 | "is_training" : true, 18 | 19 | "optimizer" : "rmsprop", 20 | "optimizer params" : { 21 | "lr" : 0.001, 22 | "lr_scheme" : "exponential", 23 | "lr_params" : { 24 | "decay_steps" : 1000, 25 | "decay_rate" : 0.9 26 | } 27 | }, 28 | 29 | "kl loss" : "gaussian", 30 | "kl loss prod" : 0.01, 31 | "reconstruction loss" : "mse", 32 | "reconstruction loss prod" : 1, 33 | 34 | "summary" : true, 35 | "summary dir" : "log", 36 | 37 | "encoder" : "encoder", 38 | "encoder params" : { 39 | "normalization" : "none", 40 | 41 | "including conv" : false, 42 | 43 | "including top" : true, 44 | "nb_fc_nodes" : [256] 45 | }, 46 | 47 | "decoder" : "decoder", 48 | "decoder params" : { 49 | "normalization" : "none", 50 | 51 | "including_bottom" : true, 52 | "fc nb nodes" : [256], 53 | 54 | "including_deconv" : false 55 | } 56 | }, 57 | 58 | "trainer" : "unsupervised", 59 | "trainer params" : { 60 | "continue train" : false, 61 | "train steps" : 20000, 62 | "summary steps" : 1000, 63 | "log steps" : 100, 64 | "save checkpoint steps" : 1000, 65 | 66 | "validators" : [ 67 | { 68 | "validator" : "hidden_variable_validator", 69 | "validate steps" : 1000, 70 | "validator params" : { 71 | "z_dim" : 2, 72 | "num_samples" : 15 73 | } 74 | }, 75 | { 76 | "validator" : "scatter_plot_validator", 77 | "validate steps" : 1000, 78 | "validator params" : { 79 | "watch variable" : "hidden dist", 80 | "x dim" : 0, 81 | "y dim" : 1 82 | } 83 | } 84 | ] 85 | } 86 | } 87 | 88 | 89 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of
this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | def get_dataset(name, config): # Return the dataset object registered under `name`, constructed with `config`; each dataset module is imported only when its branch is taken. 27 | 28 | if name == 'imagenet': 29 | from .imagenet import ImageNet 30 | return ImageNet(config) 31 | 32 | elif name == 'celeba': 33 | from .celeba import CelebA 34 | return CelebA(config) 35 | 36 | elif name == 'mnist': 37 | from .mnist import MNIST 38 | return MNIST(config) 39 | 40 | elif name == 'cifar10': 41 | from .cifar10 import Cifar10 42 | return Cifar10(config) 43 | 44 | elif name == 'pascal_voc': 45 | from .pascal_voc import PASCAL_VOC 46 | return PASCAL_VOC(config) 47 | 48 | elif name == 'production': 49 | from .production import ChipProduction 50 | return ChipProduction(config) 51 | 52 | elif name == 'gan_toy' or name == 'gan toy' or name == 'toy gan': # three accepted spellings for the toy GAN dataset 53 | from .gan_toy import GanToy 54 | return GanToy(config) 55 | 56 | elif name == 'violence': 57 | from .violence import Violence 58 | return Violence(config) 59 | 60 | elif name == 'tianchi guangdong defect': 61 | from .tianchi_guangdong_defect import TianChiGuangdongDefect 62 | return TianChiGuangdongDefect(config) 63 | 64 | else: # unknown names raise a generic Exception; `name` must be a str for the message concatenation 65 | raise Exception('None dataset named ' + name) 66 | 67 | -------------------------------------------------------------------------------- /cfgs/cla/mnist2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | // "semi-supervised" : true, 7 | // "nb_labelled_images_per_class" : 10, 8 | "output shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist/result2", 13 | 14 | "model" : "classification", 15 | "model params" : { 16 | "name" : "mnist", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb classes" : 10, 20 | 21 | "optimizer" : "adam", 22 | "optimizer params" : { 23 | "lr" : 0.001, 24 | "lr scheme" : "exponential", 25 | "lr params" : { 26 | "decay_steps" : 10000, 27 |
"decay_rate" : 0.1 28 | }, 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "normalization" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 3, 42 | "conv nb layers" : [2, 2, 2], 43 | "conv nb filters" : [32, 64, 128], 44 | "conv ksize" : [3, 3, 3], 45 | "no maxpooling" : true, 46 | 47 | "including top" : true, 48 | "fc nb nodes" : [600, 600], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "learning_rate_0_001_adam", 61 | 62 | "continue train" : false, 63 | "multi thread" : true, 64 | 65 | 66 | "batch_size" : 32, 67 | "train steps" : 20000, 68 | "summary steps" : 1000, 69 | "log steps" : 100, 70 | "save checkpoint steps" : 10000, 71 | 72 | "validators" : [ 73 | { 74 | "validator" : "dataset_validator", 75 | "validate steps" : 500, 76 | "has summary" : true, 77 | "validator params" : { 78 | "metric" : "accuracy", 79 | "metric type" : "top1" 80 | } 81 | } 82 | ] 83 | } 84 | } 85 | 86 | -------------------------------------------------------------------------------- /cfgs/cla/mnist3.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "semi-supervised" : true, 7 | "nb_labelled_images_per_class" : 50, 8 | "output shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist/result3", 13 | 14 | "model" : "classification", 15 | "model params" : { 16 | "name" : "mnist", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb classes" : 10, 20 | 21 | "optimizer" : "adam", 22 | "optimizer params" : { 23 | "lr" : 0.001, 24 | "lr scheme" : "exponential", 25 | "lr params" : { 26 | 
"decay_steps" : 10000, 27 | "decay_rate" : 0.1 28 | }, 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "normalization" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 3, 42 | "conv nb layers" : [2, 2, 2], 43 | "conv nb filters" : [32, 64, 128], 44 | "conv ksize" : [3, 3, 3], 45 | "no maxpooling" : true, 46 | 47 | "including top" : true, 48 | "fc nb nodes" : [600, 600], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "learning_rate_0_001_adam", 61 | 62 | "continue train" : false, 63 | "multi thread" : true, 64 | "buffer depth" : 200, 65 | "batch_size" : 32, 66 | "train steps" : 20000, 67 | "summary steps" : 1000, 68 | "log steps" : 100, 69 | "save checkpoint steps" : 10000, 70 | 71 | "validators" : [ 72 | { 73 | "validator" : "dataset_validator", 74 | "validate steps" : 500, 75 | "has summary" : true, 76 | "validator params" : { 77 | "metric" : "accuracy", 78 | "metric type" : "top1" 79 | } 80 | } 81 | ] 82 | } 83 | } 84 | 85 | -------------------------------------------------------------------------------- /cfgs/cla/mnist4.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "semi-supervised" : true, 7 | "nb_labelled_images_per_class" : 10, 8 | "output shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist/result4", 13 | 14 | "model" : "classification", 15 | "model params" : { 16 | "name" : "mnist", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb classes" : 10, 20 | 21 | "optimizer" : "adam", 22 | "optimizer params" : { 23 | "lr" : 0.001, 24 | "lr scheme" : 
"exponential", 25 | "lr params" : { 26 | "decay_steps" : 10000, 27 | "decay_rate" : 0.1 28 | }, 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "normalization" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 3, 42 | "conv nb layers" : [2, 2, 2], 43 | "conv nb filters" : [32, 64, 128], 44 | "conv ksize" : [3, 3, 3], 45 | "no maxpooling" : true, 46 | 47 | "including top" : true, 48 | "fc nb nodes" : [600, 600], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "learning_rate_0_001_adam", 61 | 62 | "continue train" : false, 63 | "multi thread" : true, 64 | "buffer depth" : 200, 65 | 66 | "batch_size" : 32, 67 | "train steps" : 20000, 68 | "summary steps" : 1000, 69 | "log steps" : 100, 70 | "save checkpoint steps" : 10000, 71 | 72 | "validators" : [ 73 | { 74 | "validator" : "dataset_validator", 75 | "validate steps" : 500, 76 | "has summary" : true, 77 | "validator params" : { 78 | "metric" : "accuracy", 79 | "metric type" : "top1" 80 | } 81 | } 82 | ] 83 | } 84 | } 85 | 86 | -------------------------------------------------------------------------------- /cfgs/cla/mnist_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "semi-supervised" : true, 7 | "labelled indices filepath" : "./mnist_temp_data_dir/method_1_9.pkl", 8 | "output shape" : [28, 28, 1], 9 | "batch_size" : 128 10 | }, 11 | 12 | "assets dir" : "assets/mnist/tests", 13 | 14 | "model" : "classification", 15 | "model params" : { 16 | "name" : "mnist", 17 | 18 | "input shape" : [28, 28, 1], 19 | "nb classes" : 10, 20 | 21 | 
"optimizer" : "adam", 22 | "optimizer params" : { 23 | "lr" : 0.001, 24 | "lr scheme" : "exponential", 25 | "lr params" : { 26 | "decay_steps" : 10000, 27 | "decay_rate" : 0.1 28 | }, 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "batch_norm" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 3, 42 | "conv nb layers" : [2, 2, 2], 43 | "conv nb filters" : [32, 64, 128], 44 | "conv ksize" : [3, 3, 3], 45 | "no maxpooling" : true, 46 | 47 | "including top" : true, 48 | "fc nb nodes" : [600, 600], 49 | 50 | "output dims" : 10, 51 | "output_activation" : "none", 52 | 53 | "debug" : true 54 | } 55 | }, 56 | 57 | "trainer" : "supervised", 58 | "trainer params" : { 59 | 60 | "summary hyperparams string" : "method_1_9", 61 | 62 | "continue train" : false, 63 | "multi thread" : true, 64 | "batch_size" : 32, 65 | "train steps" : 20000, 66 | "summary steps" : 1000, 67 | "log steps" : 100, 68 | "save checkpoint steps" : 10000, 69 | 70 | "debug" : true, 71 | 72 | "validators" : [ 73 | { 74 | "validator" : "dataset_validator", 75 | "validate steps" : 1000, 76 | "has summary" : true, 77 | "validator params" : { 78 | "metric" : "accuracy", 79 | "metric type" : "top1" 80 | } 81 | } 82 | ] 83 | } 84 | } 85 | 86 | -------------------------------------------------------------------------------- /netutils/learning_rate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | 27 | import tensorflow as tf 28 | 29 | def get_learning_rate(name, initial_learning_rate, global_step, config): # Return a learning-rate tensor for scheme `name` ('constant', 'exponential' or 'piecewise'); `config` carries the scheme-specific parameters. 30 | 31 | initial_learning_rate = float(initial_learning_rate) # normalise to a Python float before building tensors 32 | 33 | if name == 'constant': 34 | return tf.constant(initial_learning_rate) 35 | 36 | elif name == 'exponential': 37 | return tf.train.exponential_decay(initial_learning_rate, global_step, 38 | decay_steps=config['decay_steps'], 39 | decay_rate=config['decay_rate'], 40 | staircase=config.get('staircase', True)) # staircase defaults to True, i.e. the rate decays in discrete steps 41 | 42 | elif name == 'piecewise': 43 | ''' 44 | config parameters: 45 | e.g.
46 | boundaries: [10000, 30000] 47 | values : [1.0, 0.5, 0.1] 48 | ''' 49 | return tf.train.piecewise_constant(global_step, 50 | boundaries=config['boundaries'], 51 | values=[value * initial_learning_rate for value in config['values']] # configured values are multipliers of initial_learning_rate 52 | ) 53 | else: 54 | raise Exception('None learning rate scheme named ' + name) # unknown scheme name 55 | 56 | 57 | def get_global_step(name='global_step'): # Return (non-trainable integer counter initialised to 0, op that increments it by 1). 58 | global_step = tf.Variable(0, trainable=False, name=name) 59 | global_step_update = tf.assign(global_step, global_step+1) 60 | return global_step, global_step_update 61 | 62 | -------------------------------------------------------------------------------- /model/base_detection_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE.
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | sys.path.append('.') 28 | sys.path.append("../") 29 | 30 | import tensorflow as tf 31 | import tensorflow.contrib.layers as tcl 32 | import numpy as np 33 | 34 | from netutils.learning_rate import get_global_step 35 | from netutils.loss import get_loss 36 | 37 | from .base_model import BaseModel 38 | 39 | 40 | def BaseDetectionModel(BaseModel): 41 | def __init__(self, config): 42 | BaseModel.__init__(self, config) 43 | self.config = config 44 | 45 | 46 | def build_proposal_layer(self, inputs, proposal_count, mns_threshold): 47 | """ 48 | """ 49 | post_nms_rois_training = self.config.get('nb post-nms rois in training', 2000) 50 | post_nms_rois_inference = self.config.get('nb post-nms rois in inference', 1000) 51 | 52 | rpn_nms_threshold = self.config.get('proposal nms threshold', 0.7) 53 | 54 | def train_proposal(): 55 | 56 | scores = inputs[0][:, :, 1] 57 | deltas = inputs[1] 58 | deltas = deltas * np.reshape(self.config) 59 | 60 | pass 61 | def test_proposal(): 62 | pass 63 | 64 | pass 65 | 66 | 67 | def detect(self, batch_x): 68 | pass 69 | 70 | 71 | 72 | # def get_anchors(): 73 | # pass 74 | 75 | -------------------------------------------------------------------------------- /classifier/classifier_simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following 
conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | sys.path.append('../') 33 | 34 | from netutils.weightsinit import get_weightsinit 35 | from netutils.activation import get_activation 36 | from netutils.normalization import get_normalization 37 | 38 | from network.base_network import BaseNetwork 39 | 40 | class ClassifierSimple(BaseNetwork): # Classifier wrapper whose backbone is chosen by config['base network']: 'vgg' (default) or 'resnet'. 41 | def __init__(self, config, is_training): 42 | BaseNetwork.__init__(self, config, is_training) 43 | 44 | self.base_network = config.get('base network', 'vgg') # NOTE(review): unlike DiscriminatorSimple and GeneratorSimple, this class does not set self.name or self.config itself; confirm BaseNetwork.__init__ covers what callers need. 45 | if self.base_network == 'vgg': 46 | from network.vgg import VGG 47 | self.network = VGG(config, is_training) # the config itself (not a copy) is handed to the backbone 48 | elif self.base_network == 'resnet': 49 | from network.resnet import Resnet 50 | self.network = Resnet(config, is_training) 51 | else: 52 | raise ValueError("no base network named " + self.base_network ) 53 | 54 | def __call__(self, i): # Returns only the backbone output; end_points is discarded. 55 | x, end_points = self.network(i) 56 | return x 57 | 58 | def features(self, i): # Returns (output, end_points) from the backbone. 59 | x, end_points = self.network(i) 60 | return x, end_points 61 | 62 | def load_pretrained_weights(self, sess): # Delegates to the backbone's load_pretrained_weights. 63 | return self.network.load_pretrained_weights(sess) 64 |
-------------------------------------------------------------------------------- /cfgs/tianchi/guangdong3.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "guangdong_defect_classification", 3 | 4 | "dataset" : "tianchi guangdong defect", 5 | "dataset params" : { 6 | "output shape" : [256, 256, 3], 7 | "one hot" : true, 8 | "use cache" : true, 9 | "mil" : false, 10 | "show warning" : false 11 | }, 12 | 13 | "assets dir" : "assets/tianchi_guangdong/result6", 14 | 15 | "model" : "classification", 16 | "model params" : { 17 | "name" : "tianchi", 18 | 19 | "input shape" : [256, 256, 3], 20 | "nb classes" : 12, 21 | 22 | "optimizer" : "adam", 23 | "optimizer params" : { 24 | "lr" : 0.001, 25 | "lr scheme" : "exponential", 26 | "lr params" : { 27 | "decay_steps" : 30000, 28 | "decay_rate" : 0.2 29 | } 30 | }, 31 | 32 | "classification loss" : "cross entropy", 33 | 34 | "summary" : true, 35 | 36 | "classifier" : "classifier", 37 | "classifier params" : { 38 | "normalization" : "fused_batch_norm", 39 | 40 | "including conv" : true, 41 | "conv nb blocks" : 6, 42 | "conv nb layers" : [2, 2, 3, 3, 3, 0], 43 | "conv nb filters" : [64, 128, 256, 512, 512], 44 | "conv ksize" : [3, 3, 3, 3, 3], 45 | 46 | "including top" : true, 47 | "fc nb nodes" : [1024, 1024], 48 | 49 | "load pretrained weights" : "vgg16", 50 | 51 | "debug" : true 52 | } 53 | }, 54 | 55 | "trainer" : "supervised", 56 | "trainer params" : { 57 | "summary hyperparams string" : "lr_0_001_adam", 58 | 59 | "continue train" : true, 60 | 61 | "multi thread" : true, 62 | "buffer depth" : 100, 63 | "nb threads" : 16, 64 | "batch_size" : 16, 65 | 66 | "train steps" : 50000, 67 | "summary steps" : 1000, 68 | "log steps" : 500, 69 | "save checkpoint steps" : 10000, 70 | 71 | "validators" : [ 72 | { 73 | "validator" : "dataset_validator", 74 | "validate steps" : 2000, 75 | "validator params" : { 76 | "metric" : "accuracy", 77 | "metric type" : "top1", 78 | "nb 
samples" : 200, 79 | "batch_size" : 8, 80 | "mil" : false 81 | } 82 | } 83 | ] 84 | } 85 | } 86 | 87 | -------------------------------------------------------------------------------- /cfgs/cla/mnist5.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_classification", 3 | // differ from "cla/mnist4.json" 4 | // change the activation from relu to leaky relu 5 | 6 | "dataset" : "mnist", 7 | "dataset params" : { 8 | "semi-supervised" : true, 9 | "nb_labelled_images_per_class" : 100, 10 | "output shape" : [28, 28, 1], 11 | "batch_size" : 128 12 | }, 13 | 14 | "assets dir" : "assets/mnist/result4", 15 | 16 | "model" : "classification", 17 | "model params" : { 18 | "name" : "mnist", 19 | 20 | "input shape" : [28, 28, 1], 21 | "nb classes" : 10, 22 | 23 | "optimizer" : "adam", 24 | "optimizer params" : { 25 | "lr" : 0.001, 26 | "lr scheme" : "exponential", 27 | "lr params" : { 28 | "decay_steps" : 10000, 29 | "decay_rate" : 0.1 30 | }, 31 | "beta1" : 0.5, 32 | "beta2" : 0.9 33 | }, 34 | "classification loss" : "cross entropy", 35 | 36 | "summary" : true, 37 | 38 | "classifier" : "classifier", 39 | "classifier params" : { 40 | "activation" : "lrelu 0.2", 41 | "normalization" : "fused_batch_norm", 42 | 43 | "including conv" : true, 44 | "conv nb blocks" : 3, 45 | "conv nb layers" : [2, 2, 2], 46 | "conv nb filters" : [32, 64, 128], 47 | "conv ksize" : [3, 3, 3], 48 | "no maxpooling" : true, 49 | 50 | "including top" : true, 51 | "fc nb nodes" : [600, 600], 52 | 53 | "output dims" : 10, 54 | "output_activation" : "none", 55 | 56 | "debug" : true 57 | } 58 | }, 59 | 60 | "trainer" : "supervised", 61 | "trainer params" : { 62 | 63 | "summary hyperparams string" : "learning_rate_0_001_adam", 64 | 65 | "continue train" : false, 66 | "multi thread" : true, 67 | "batch_size" : 32, 68 | "train steps" : 20000, 69 | "summary steps" : 1000, 70 | "log steps" : 100, 71 | "save checkpoint steps" : 10000, 72 | 73 | 
"validators" : [ 74 | { 75 | "validator" : "dataset_validator", 76 | "validate steps" : 500, 77 | "has summary" : true, 78 | "validator params" : { 79 | "metric" : "accuracy", 80 | "metric type" : "top1" 81 | } 82 | } 83 | ] 84 | } 85 | } 86 | 87 | -------------------------------------------------------------------------------- /validator/validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | def get_validator(name, config): 26 | if name == 'hidden_variable': 27 | # view the data generated from hidden variable 28 | from .hidden_variable import HiddenVariable 29 | return HiddenVariable(config) 30 | elif name == 'scatter_plot': 31 | from .scatter_plot import ScatterPlot 32 | return ScatterPlot(config) 33 | elif name == 'dataset_validator': 34 | from .dataset_validator import DatasetValidator 35 | return DatasetValidator(config) 36 | elif name == 'random_generate': 37 | from .random_generate import RandomGenerate 38 | return RandomGenerate(config) 39 | elif name == 'embedding_visualize' or name == 'tensorboard embedding': 40 | from .tensorboard_embedding import TensorboardEmbedding 41 | return TensorboardEmbedding(config) 42 | elif name == 'validate_segmentation': 43 | from .valid_segmentation import ValidSegmentation 44 | return ValidSegmentation(config) 45 | elif name == 'gan_toy_plot': 46 | from .gan_toy_plot import GanToyPlot 47 | return GanToyPlot(config) 48 | elif name == 'inception_score': 49 | from .inception_score import InceptionScore 50 | return InceptionScore(config) 51 | else: 52 | raise Exception("None validator named " + name) 53 | 54 | 55 | -------------------------------------------------------------------------------- /dataset/violence.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject 
to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | from .base_imagelist_dataset import BaseImageListDataset 28 | 29 | class Violence(BaseImageListDataset): 30 | def __init__(self, config): 31 | super(Violence, self).__init__(config) 32 | self.name = 'Violence' 33 | self.time_string = config.get('time string', '201808142212') 34 | self.nb_classes = 2 35 | 36 | self._dataset_dir = 'F:\\Documents\\new BK\\已整理好--与BK有关的复杂场景-肉眼不好区分的等等情形都作为负面样本' 37 | if not os.path.exists(self._dataset_dir): 38 | self._dataset_dir = '/mnt/data03/dataset/new BK/已整理好--与BK有关的复杂场景-肉眼不好区分的等等情形都作为负面样本' 39 | if not os.path.exists(self._dataset_dir): 40 | raise Exception("Violence : the dataset dir " + self._dataset_dir + " is not exist") 41 | 42 | self._imagelist_dir = 'F:\\Documents\\new BK\\Proj' 43 | if not os.path.exists(self._imagelist_dir): 44 | self._imagelist_dir = '/mnt/data03/dataset/new BK/Proj' 45 | if not os.path.exists(self._imagelist_dir): 46 | raise Exception("Violence : the imagelist dir " + self._imagelist_dir + " is not exist") 47 | 48 | self.train_imagelist_fp = os.path.join(self._imagelist_dir, 'train_' + self.time_string + '.txt') 49 | self.val_imagelist_fp = os.path.join(self._imagelist_dir, 'val_' + self.time_string + '.txt') 
50 | 51 | assert(os.path.exists(self.train_imagelist_fp)) 52 | assert(os.path.exists(self.val_imagelist_fp)) 53 | 54 | self.build_dataset() 55 | 56 | -------------------------------------------------------------------------------- /cfgs/cvae2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "cvae_2", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | }, 7 | 8 | "assets dir" : "assets/cvae_2", 9 | 10 | "model" : "cvae", 11 | "model params" : { 12 | "name" : "cvae", 13 | 14 | "input shape" : [28, 28, 1], 15 | "flatten" : true, 16 | "nb_classes" : 10, 17 | 18 | "z_dim" : 3, 19 | "is_training" : true, 20 | "sample_func" : "normal", 21 | 22 | "optimizer" : "rmsprop", 23 | "lr" : 0.001, 24 | "lr_scheme" : "exponential", 25 | "lr_params" : { 26 | "decay_steps" : 1000, 27 | "decay_rate" : 0.9 28 | }, 29 | 30 | "kl loss" : "gaussian", 31 | "kl loss prod" : 0.01, 32 | "reconstruction loss" : "mse", 33 | "reconstruction loss prod" : 1, 34 | 35 | "summary" : true, 36 | "summary dir" : "log", 37 | 38 | "x encoder" : "EncoderSimple", 39 | "x encoder params" : { 40 | "nb_conv_blocks" : 0, 41 | "batch_norm" : "none", 42 | "nb_fc_nodes" : [256], 43 | "output_distribution": "gaussian" 44 | }, 45 | 46 | "y encoder" : "EncoderSimple", 47 | "y encoder params" : { 48 | "nb_conv_blocks" : 0, 49 | "batch_norm" : "none", 50 | "nb_fc_nodes" : [], 51 | "output_distribution": "mean" 52 | }, 53 | 54 | "decoder" : "DecoderSimple", 55 | "decoder params" : { 56 | } 57 | }, 58 | 59 | "trainer" : "supervised", 60 | "trainer params" : { 61 | "continue train" : false, 62 | "train steps" : 20000, 63 | "summary steps" : 1000, 64 | "log steps" : 100, 65 | "save checkpoint steps" : 1000, 66 | 67 | "validators" : [ 68 | { 69 | "validator" : "scatter_plot_validator", 70 | "validate steps" : 1000, 71 | "validator params" : { 72 | "watch variable" : "hidden dist", 73 | "x dim" : 0, 74 | "y dim" : 1, 75 | "log dir" : "scatter1" 76 | } 77 | 
}, 78 | { 79 | "validator" : "scatter_plot_validator", 80 | "validate steps" : 1000, 81 | "validator params" : { 82 | "watch variable" : "hidden dist", 83 | "x dim" : 0, 84 | "y dim" : 2, 85 | "log dir" : "scatter2" 86 | } 87 | } 88 | ] 89 | } 90 | } 91 | 92 | 93 | -------------------------------------------------------------------------------- /dataset/base_mil_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import numpy as np 27 | from skimage import io 28 | import cv2 29 | 30 | from .base_dataset import BaseDataset 31 | 32 | class BaseMILDataset(BaseDataset): 33 | """ The base dataset class for supporting multiple-instance learning. 34 | 35 | """ 36 | def __init__(self, config): 37 | 38 | super(BaseMILDataset, self).__init__(config) 39 | self.config = config 40 | 41 | def crop_image_to_bag(self, img, output_shape, *, 42 | max_nb_crops=None, nb_crops=None, nb_crop_col=None, nb_crop_row=None): 43 | 44 | output_h, output_w = output_shape[0:2] 45 | image_h, image_w = img.shape[0:2] 46 | 47 | nb_col = int(np.ceil(image_w / output_w)) 48 | nb_row = int(np.ceil(image_h / output_h)) 49 | 50 | step_h = float(image_h) / float(nb_row) 51 | step_w = float(image_w) / float(nb_col) 52 | 53 | img_bag = [] 54 | img_bbox = [] 55 | 56 | for i in range(nb_row): 57 | for j in range(nb_col): 58 | 59 | x1 = int(j * step_w) 60 | y1 = int(i * step_h) 61 | x2 = int(j * step_w) + output_w 62 | y2 = int(i * step_h) + output_h 63 | 64 | crop_image = self.crop_and_pad_image(img, [x1,y1,x2,y2]) 65 | 66 | if crop_image.shape[0] != output_h or crop_image.shape[1] != output_w or crop_image.shape[2] != output_shape[2]: 67 | print("Warning : Crop out image shape " + str(crop_image.shape) ) 68 | 69 | img_bbox.append([x1,y1,x2,y2]) 70 | img_bag.append(crop_image) 71 | 72 | return img_bag, img_bbox, nb_col, nb_row 73 | -------------------------------------------------------------------------------- /dataset/svhn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, 
including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import numpy as np 28 | import pickle 29 | 30 | from scipy.io import loadmat 31 | 32 | from .base_dataset import BaseDataset 33 | from .base_simple_dataset import BaseSimpleDataset 34 | 35 | class SVHN(BaseSimpleDataset): 36 | 37 | def __init__(self, config): 38 | super(SVHN, self).__init__(config) 39 | 40 | self._dataset_dir = "C:\\Data\\SVHN" 41 | if not os.path.exists(self._dataset_dir): 42 | self._dataset_dir = self.config.get("dataset dir", "") 43 | if not os.path.exists(self._dataset_dir): 44 | raise Exception("SVHN : the dataset dir is not exist") 45 | 46 | self.name = "SVHN" 47 | self.train_mat_fp = os.path.join(self._dataset_dir, "train_32x32.mat") 48 | self.test_mat_fp = os.path.join(self._dataset_dir, "test_32x32.mat") 49 | 50 | train_data = loadmat(self.train_mat_fp) 51 | test_data = loadmat(self.test_mat_fp) 52 | 53 | self.x_train = np.array(train_data["X"]).transpose((3, 0, 1, 2)).astype(np.float32) / 255.0 54 | self.y_train = 
np.array(train_data["y"]).reshape([-1,]) 55 | 56 | self.x_test = np.array(test_data["X"]).transpose((3, 0, 1, 2)).astype(np.float32) / 255.0 57 | self.y_test = np.array(test_data["y"]).reshape([-1,]) 58 | 59 | indices = np.where(self.y_train == 10)[0] 60 | self.y_train[indices] = 0 61 | indices = np.where(self.y_test == 10)[0] 62 | self.y_test[indices] = 0 63 | 64 | self.output_shape = config.get('output shape', [32, 32, 3]) 65 | self.nb_classes = 10 66 | 67 | 68 | self.build_dataset() 69 | 70 | -------------------------------------------------------------------------------- /netutils/weightsinit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import tensorflow as tf 27 | import tensorflow.contrib.layers as tcl 28 | 29 | 30 | 31 | def get_weightsinit(name_config): 32 | ''' 33 | get tensorflow initializer function according to its name, 34 | for example: 35 | winit = get_weightsinit('normal 0.00 0.02') where 0.00 is mean and 0.02 is variance 36 | or 37 | winit = get_weightsinit('zeros') 38 | ''' 39 | split = name_config.split() 40 | name = split[0] 41 | if len(split) > 1: 42 | params = split[1] 43 | 44 | if name == 'normal': 45 | if len(split) == 3: 46 | init_mean = float(split[1]) 47 | init_var = float(split[2]) 48 | else: 49 | init_mean = 0.0 50 | init_var = 0.02 51 | return tf.random_normal_initializer(init_mean, init_var) 52 | elif name == 'uniform' or name == 'uni': 53 | if len(split) == 3: 54 | init_min = float(split[1]) 55 | init_max = float(split[2]) 56 | else: 57 | init_min = 0.0 58 | init_max = 1.0 59 | return tf.random_uniform_initializer(init_min, init_max) 60 | 61 | elif name == 'he_uniform': 62 | return tf.keras.initializers.he_uniform() 63 | elif name == 'he_normal': 64 | return tf.keras.initializers.he_normal() 65 | 66 | elif name == 'glorot_uniform' or name == 'glorot_uni': 67 | return tf.keras.initializers.glorot_uniform() 68 | elif name == 'glorot_normal': 69 | return tf.keras.initializers.glorot_normal() 70 | 71 | elif name == 'xavier': 72 | return tf.contrib.layers.xavier_initializer() 73 | 74 | elif name == 'zeros' or name == 'zero': 75 | return tf.zeros_initializer() 76 | elif name == 'ones' or name == 'one': 77 | return tf.ones_initializer() 78 | 79 | else : 80 | raise Exception("None weights initializer named " + name) 81 | 82 | -------------------------------------------------------------------------------- /cfgs/wgan/cifar10_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "cifar10", 
5 | "dataset params" : { 6 | "output_shape" : [32, 32, 3], 7 | "output scalar range" : [-1.0, 1.0] 8 | }, 9 | 10 | "assets dir" : "assets/wgan_gp/cifar10_3", 11 | "model" : "wgan_gp", 12 | "model params" : { 13 | "name" : "wgan_gp", 14 | 15 | "input shape" : [32, 32, 3], 16 | "z_dim" : 100, 17 | 18 | "discriminator optimizer" : "adam", 19 | "discriminator optimizer params" : { 20 | "lr" : 0.0001, 21 | "lr scheme" : "constant", 22 | "beta1" : 0.5, 23 | "beta2" : 0.9 24 | }, 25 | 26 | "generator optimizer" : "adam", 27 | "generator optimizer params" : { 28 | "lr" : 0.0001, 29 | "lr scheme" : "constant", 30 | "beta1" : 0.5, 31 | "beta2" : 0.9 32 | }, 33 | 34 | "gradient penalty loss weight" : 10.0, 35 | "summary" : true, 36 | 37 | "generator" : "generator_cifar10", 38 | "generator params" : { 39 | 40 | }, 41 | 42 | "discriminator" : "discriminator_cifar10", 43 | "discriminator params" : { 44 | 45 | } 46 | }, 47 | 48 | "trainer" : "unsupervised", 49 | "trainer params" : { 50 | 51 | "summary dir" : "log", 52 | "summary hyperparams string" : "lr_0_0001", 53 | 54 | "multi thread" : true, 55 | "continue train" : true, 56 | "train steps" : 150000, 57 | 58 | "summary steps" : 1000, 59 | "log steps" : 1000, 60 | "save checkpoint steps" : 10000, 61 | 62 | "batch_size" : 64, 63 | 64 | "debug" : true, 65 | "validators" : [ 66 | { 67 | "validator" : "random_generate", 68 | "validate steps" : 1000, 69 | "validator params" : { 70 | "log dir" : "generated_lr_0_0001", 71 | "z shape" : [100], 72 | "x shape" : [32, 32, 3], 73 | "output scalar range" : [-1.0, 1.0], 74 | "nb row" : 8, 75 | "nb col" : 8 76 | } 77 | }, 78 | { 79 | "validator" : "embedding_visualize", 80 | "validate steps" : 5000, 81 | "validator params" : { 82 | "z shape" : [100], 83 | "x shape" : [32, 32, 3], 84 | "log dir" : "log_lr_0_0001" 85 | } 86 | }, 87 | { 88 | "validator" : "inception_score", 89 | "validate steps" : 500, 90 | "validator params" : { 91 | "z shape" : [100], 92 | "x shape" : [32, 32, 3], 93 | "output 
scalar range" : [-1.0, 1.0] 94 | } 95 | } 96 | ] 97 | } 98 | } 99 | 100 | 101 | -------------------------------------------------------------------------------- /validator/base_validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | import queue 29 | import threading 30 | 31 | sys.path.append('.') 32 | sys.path.append('../') 33 | 34 | 35 | class BaseValidator(object): 36 | """ The base of validator classes. 
37 | validator is another import module in this project 38 | During training process, the validator.validate is called at intervals to view the 39 | model performence and detect bugs in model. 40 | """ 41 | def __init__(self, config): 42 | self.config = config 43 | self.assets_dir = config['assets dir'] 44 | self.has_summary = False 45 | 46 | # 47 | # Please override the following functions in derived class 48 | # 49 | def build_summary(self, model): 50 | pass 51 | 52 | def validate(self, model, dataset, sess, step): 53 | return NotImplementedError 54 | 55 | # 56 | # Util functions 57 | # 58 | def parallel_data_reading(self, dataset, indices, phase, method, buffer_depth, nb_threads=4): 59 | 60 | self.t_should_stop = False 61 | 62 | data_queue = queue.Queue(maxsize=buffer_depth) 63 | 64 | def read_data_inner_loop(dataset, data_queue, indices, t_ind, nb_threads): 65 | for i, ind in enumerate(indices): 66 | if i % nb_threads == t_ind: 67 | # read img and label by its index 68 | img, label = dataset.read_image_by_index(ind, 'val', 'supervised') 69 | if isinstance(img, list) and isinstance(label, list): 70 | for _img, _label in zip(img, label): 71 | data_queue.put((img, label)) 72 | elif img is not None: 73 | data_queue.put((img, label)) 74 | 75 | 76 | def read_data_loop(indices, dataset, data_queue, nb_threads): 77 | threads = [threading.Thread(target=read_data_inner_loop, 78 | args=(dataset, data_queue, indices, t_ind, nb_threads)) for t_ind in range(nb_threads)] 79 | for t in threads: 80 | t.start() 81 | for t in threads: 82 | t.join() 83 | self.t_should_stop = True 84 | 85 | 86 | t = threading.Thread(target=read_data_loop, args=(indices, dataset, data_queue, nb_threads)) 87 | t.start() 88 | 89 | return t, data_queue 90 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright 
(c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | def get_model(model_name, model_config): 27 | if model_name == 'cvaegan': 28 | from .cvaegan import CVAEGAN 29 | return CVAEGAN(model_config) 30 | 31 | elif model_name == 'vae': 32 | from .vae import VAE 33 | return VAE(model_config) 34 | 35 | elif model_name == 'cvae': 36 | from .cvae import CVAE 37 | return CVAE(model_config) 38 | 39 | elif model_name == 'aae': 40 | from .aae import AAE 41 | return AAE(model_config) 42 | 43 | elif model_name == 'aae_ssl' or model_name == 'aae_semi': 44 | from .aae_ssl import AAESemiSupervised 45 | return AAESemiSupervised(model_config) 46 | 47 | elif model_name == 'classification': 48 | from .classification import Classification 49 | return Classification(model_config) 50 | 51 | elif model_name == 'segmentation': 52 | from .segmentation import Segmentation 53 | return Segmentation(model_config) 54 | 55 | elif model_name == 'stargan': 56 | from .stargan import StarGAN 57 | return StarGAN(model_config) 58 | 59 | elif model_name == 'semidgm': 60 | from .semi_dgm import SemiDeepGenerativeModel 61 | return SemiDeepGenerativeModel(model_config) 62 | 63 | elif model_name == 'semidgm2': 64 | from .semi_dgm2 import SemiDeepGenerativeModel2 65 | return SemiDeepGenerativeModel2(model_config) 66 | 67 | elif model_name == 'dcgan': 68 | from .dcgan import DCGAN 69 | return DCGAN(model_config) 70 | 71 | elif model_name == 'wgan_gp' or model_name == 'wgan': 72 | from .wgan_gp import WGAN_GP 73 | return WGAN_GP(model_config) 74 | 75 | elif model_name == 'improved_gan': 76 | from .improved_gan import ImprovedGAN 77 | return ImprovedGAN(model_config) 78 | 79 | elif model_name == 'attention_mil': 80 | from .attention_mil import AttentionMIL 81 | return AttentionMIL(model_config) 82 | 83 | else: 84 | raise Exception("None model named " + model_name) 85 | -------------------------------------------------------------------------------- 
/cfgs/wgan/wgan_gp2.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "output_shape" : [28, 28, 1] 7 | }, 8 | 9 | 10 | "assets dir" : "assets/wgan_gp2/result1", 11 | "model" : "wgan_gp", 12 | "model params" : { 13 | "name" : "wgan_gp", 14 | 15 | "input shape" : [28, 28, 1], 16 | "z_dim" : 20, 17 | 18 | "discriminator optimizer" : "rmsprop", 19 | "discriminator optimizer params" : { 20 | "lr" : 0.0001, 21 | "lr scheme" : "constant" 22 | }, 23 | 24 | "generator optimizer" : "rmsprop", 25 | "generator optimizer params" : { 26 | "lr" : 0.0001, 27 | "lr scheme" : "constant" 28 | }, 29 | 30 | "gradient penalty loss weight" : 10.0, 31 | "summary" : true, 32 | 33 | "generator" : "generator", 34 | "generator params" : { 35 | "batch_norm" : "batch_norm", 36 | 37 | "including_bottom" : true, 38 | "nb_fc_nodes" : [], 39 | "fc_output_reshape" : [7, 7, 32], 40 | 41 | "including_deconv" : true, 42 | "nb_deconv_blocks" : 3, 43 | "nb_deconv_layers" : [2, 2, 1], 44 | "nb_deconv_filters" : [32, 16, 8], 45 | 46 | "output dims" : 1, 47 | "output_activation" : "sigmoid", 48 | "debug" : true 49 | }, 50 | 51 | "discriminator" : "discriminator", 52 | "discriminator params" : { 53 | 54 | "activation" : "lrelu 0.2", 55 | "batch_norm" : "none", 56 | "no maxpooling" : true, 57 | 58 | "nb_conv_blocks" : 3, 59 | "nb_conv_layers" : [2, 2, 1], 60 | "nb_conv_filters" : [8, 16, 32], 61 | "nb_fc_nodes" : [256], 62 | 63 | "output dims" : 1, 64 | "output_activation" : "none", 65 | "debug" : true 66 | } 67 | }, 68 | 69 | "trainer" : "unsupervised", 70 | "trainer params" : { 71 | 72 | "summary dir" : "log", 73 | "summary hyperparams string" : "lr_0_0001_3", 74 | 75 | "multi thread" : true, 76 | "continue train" : false, 77 | "train steps" : 20000, 78 | "summary steps" : 300, 79 | "log steps" : 100 , 80 | "debug" : true, 81 | "batch_size" : 128, 82 | // "save checkpoint 
steps" : 1000, 83 | "validators" : [ 84 | { 85 | "validator" : "random_generate", 86 | "validate steps" : 1000, 87 | "validator params" : { 88 | "log dir" : "generated_lr_0_0001_3", 89 | "z shape" : [20], 90 | "x shape" : [28, 28, 1], 91 | "nb row" : 8, 92 | "nb col" : 8 93 | } 94 | }, 95 | { 96 | "validator" : "embedding_visualize", 97 | "validate steps" : 500, 98 | "validator params" : { 99 | "z shape" : [20], 100 | "x shape" : [28, 28, 1], 101 | "log dir" : "log_lr_0_0001_3" 102 | } 103 | } 104 | ] 105 | } 106 | } 107 | 108 | 109 | -------------------------------------------------------------------------------- /generator/generator_conv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | sys.path.append('../') 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | from netutils.weightsinit import get_weightsinit 33 | from netutils.activation import get_activation 34 | from netutils.normalization import get_normalization 35 | 36 | 37 | # from network.vgg import VGG 38 | from network.devgg import DEVGG 39 | from network.base_network import BaseNetwork 40 | 41 | class G_conv(BaseNetwork): 42 | 43 | def __init__(self, config, is_training): 44 | BaseNetwork.__init__(self, config, is_training) 45 | 46 | self.name = config.get('name', 'GeneratorSimple') 47 | self.config = config 48 | 49 | network_config = config.copy() 50 | network_config['name'] = self.name 51 | self.reuse = False 52 | self.size = 64//16 53 | self.channel = 3 54 | 55 | def __call__(self, i): 56 | with tf.variable_scope(self.name): 57 | if self.reuse: 58 | tf.get_variable_scope().reuse_variables() 59 | else: 60 | assert tf.get_variable_scope().reuse is False 61 | self.reuse = True 62 | g = tcl.fully_connected(i, self.size * self.size * 1024, activation_fn=tf.nn.relu, 63 | normalizer_fn=tcl.batch_norm) 64 | g = tf.reshape(g, (-1, self.size, self.size, 1024)) # size 65 | g = tcl.conv2d_transpose(g, 512, 3, stride=2, # size*2 66 | activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02)) 67 | g = tcl.conv2d_transpose(g, 256, 3, stride=2, # size*4 68 | activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02)) 69 | g = tcl.conv2d_transpose(g, 128, 3, stride=2, # size*8 70 | activation_fn=tf.nn.relu, normalizer_fn=tcl.batch_norm, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02)) 71 | 72 | g = tcl.conv2d_transpose(g, self.channel, 3, stride=2, # size*16 73 | 
activation_fn=tf.nn.sigmoid, padding='SAME', weights_initializer=tf.random_normal_initializer(0, 0.02)) 74 | return g 75 | 76 | return x 77 | 78 | 79 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | import argparse 28 | from shutil import copyfile 29 | 30 | import tensorflow as tf 31 | 32 | 33 | sys.path.append('./') 34 | sys.path.append('./lib') 35 | sys.path.append('../') 36 | 37 | from cfgs.networkconfig import get_config 38 | from dataset.dataset import get_dataset 39 | from model.model import get_model 40 | from trainer.trainer import get_trainer 41 | 42 | parser = argparse.ArgumentParser(description='') 43 | parser.add_argument('--gpu', type=str, default='0') 44 | parser.add_argument('--config_file', type=str, default='cvae1') # target config file, stored in ./cfgs 45 | parser.add_argument('--disp_config', type=bool, default=False) # if there is error in config file, set True to print the config file with line number 46 | 47 | 48 | args = parser.parse_args() 49 | 50 | 51 | if __name__ == '__main__': 52 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 53 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 54 | tf.reset_default_graph() 55 | 56 | # load config file 57 | config = get_config(args.config_file, args.disp_config) 58 | 59 | # make the assets directory and copy the config file to it 60 | # so if you want to reproduce the result in assets dir 61 | # just copy the config_file.json to ./cfgs folder and run python3 train.py --config=(config_file) 62 | if not os.path.exists(config['assets dir']): 63 | os.makedirs(config['assets dir']) 64 | copyfile(os.path.join('./cfgs', args.config_file + '.json'), 65 | os.path.join(config['assets dir'], 'config_file.json')) 66 | 67 | # prepare dataset 68 | dataset = get_dataset(config['dataset'], config['dataset params']) 69 | 70 | tfconfig = tf.ConfigProto() 71 | tfconfig.gpu_options.allow_growth = True 72 | 73 | with tf.Session(config=tfconfig) as sess: 74 | 75 | # build model 76 | config['model params']['assets dir'] = config['assets dir'] 77 | model = get_model(config['model'], 
config['model params']) 78 | 79 | # start testing 80 | config['tester params']['assets dir'] = config['assets dir'] 81 | trainer = get_trainer(config['tester'], config['tester params'], model) 82 | trainer.train(sess, dataset, model) 83 | -------------------------------------------------------------------------------- /encoder/encoder_simple.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | 33 | sys.path.append('../') 34 | 35 | 36 | from netutils.weightsinit import get_weightsinit 37 | from netutils.activation import get_activation 38 | from netutils.normalization import get_normalization 39 | 40 | 41 | from network.vgg import VGG 42 | from network.base_network import BaseNetwork 43 | 44 | class EncoderSimple(BaseNetwork): 45 | 46 | def __init__(self, config, is_training): 47 | 48 | super(EncoderSimple, self).__init__(config, is_training) 49 | 50 | self.name = config.get('name', 'EncoderSimple') 51 | self.config = config 52 | 53 | network_config = config.copy() 54 | network_config['name'] = self.name 55 | network_config["output dims"] = 0 56 | 57 | self.network = VGG(network_config, is_training) 58 | self.output_distribution = self.config.get('output_distribution', 'gaussian') 59 | self.reuse=False 60 | 61 | def __call__(self, i, condition=None): 62 | 63 | output_dims = self.config.get("output dims", 3) 64 | output_act_fn = get_activation(self.config.get('output_activation', 'none')) 65 | 66 | 67 | x, end_points = self.network(i) 68 | 69 | x = tcl.flatten(x) 70 | if condition is not None: 71 | x = tf.concatenate([x, condition], axis=-1) 72 | 73 | with tf.variable_scope(self.name): 74 | if self.reuse: 75 | tf.get_variable_scope().reuse_variables() 76 | else: 77 | assert tf.get_variable_scope().reuse is False 78 | self.reuse=True 79 | 80 | if self.output_distribution == 'gaussian': 81 | mean = self.fc('fc_out_mean', x, output_dims, **self.out_fc_args) 82 | log_var = self.fc('fc_out_log_var', x, output_dims, **self.out_fc_args) 83 | return mean, log_var 84 | 85 | elif self.output_distribution == 'mean': 86 | mean = self.fc('fc_out_mean', x, output_dims, **self.out_fc_args) 87 | return mean 88 | 89 | elif self.output_distribution == 'none': 90 
| out = self.fc('fc_out_mean', x, output_dims, **self.out_fc_args) 91 | return out 92 | else: 93 | raise Exception("None output distribution named " + self.output_distribution) 94 | 95 | 96 | -------------------------------------------------------------------------------- /trainer/supervised_mil.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | import queue 28 | import threading 29 | import numpy as np 30 | 31 | sys.path.append('.') 32 | sys.path.append('../') 33 | 34 | import tensorflow as tf 35 | 36 | from validator.validator import get_validator 37 | 38 | from .base_trainer import BaseTrainer 39 | 40 | class SupervisedMILTrainer(BaseTrainer): 41 | ''' Supervised Trainer for Multiple Instance Learning, 42 | the input data is a bag of images and corresponding label, 43 | it trains the model with a bag and a label per step, 44 | 45 | The difference between SupervisedTrainer and SupervisedMILTrainer is that 46 | SupervisedMILTrainer do not need to group the input data into a batch 47 | 48 | ''' 49 | def __init__(self, config, model, sess): 50 | super(SupervisedMILTrainer, self).__init__(config, model, sess) 51 | 52 | self.config = config 53 | self.model = model 54 | 55 | self.dataset_phase = self.config.get('dataset phase', 'train') 56 | 57 | self.nb_threads = int(self.config.get('nb reading threads', 4)) 58 | self.buffer_depth = int(self.config.get('buffer depth', 50)) 59 | self.train_data_queue = queue.Queue(maxsize=self.buffer_depth) 60 | 61 | 62 | def train(self, sess, dataset, model): 63 | 64 | self.train_initialize(sess, model) 65 | 66 | # start threads for read data 67 | self.coord = tf.train.Coordinator() 68 | threads = [threading.Thread(target=self.read_data_loop, 69 | args=(self.coord, dataset, self.train_data_queue, self.dataset_phase, 'supervised', self.nb_threads))] 70 | for t in threads: 71 | t.start() 72 | 73 | while True: 74 | epoch, x_bag, y_label = self.train_data_queue.get() 75 | step = self.train_inner_step(epoch, model, dataset, x_bag, y_label) 76 | 77 | if self.train_data_queue.empty() and step % 100 == 0: 78 | print('info : train data buffer empty') 79 | if step > int(self.config['train steps']): 80 | break 81 | 82 | # join all thread when in multi thread 
model 83 | self.coord.request_stop() 84 | while not self.train_data_queue.empty(): 85 | epoch, x_bag, y_label = self.train_data_queue.get() 86 | self.train_data_queue.task_done() 87 | self.coord.join(threads) 88 | 89 | -------------------------------------------------------------------------------- /cfgs/wgan/toy.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "gan toy", 5 | "dataset params" : { 6 | "dataset" : "25gaussians" 7 | }, 8 | 9 | "assets dir" : "assets/wgan_gp/toy2", 10 | "model" : "wgan_gp", 11 | "model params" : { 12 | "name" : "wgan_gp", 13 | "input shape" : [2], 14 | "z_dim" : 2, 15 | 16 | "discriminator optimizer" : "adam", 17 | "discriminator optimizer params" : { 18 | "lr" : 0.0001, 19 | "lr scheme" : "constant", 20 | "beta1" : 0.5, // important params, please do not use the default param 0.9 21 | "beta2" : 0.9 // important params, please do not use the default param 0.99 22 | // use smaller parameters can stablize the training process. 
23 | }, 24 | 25 | "generator optimizer" : "adam", 26 | "generator optimizer params" : { 27 | "lr" : 0.0001, 28 | "lr scheme" : "constant", 29 | "beta1" : 0.5, 30 | "beta2" : 0.9 31 | }, 32 | 33 | 34 | "gradient penalty loss weight" : 1.0, 35 | "summary" : true, 36 | 37 | "generator" : "generator", 38 | "generator params" : { 39 | "activation" : "relu", 40 | "normalization" : "none", // 41 | "weightsinit" : "he_normal", 42 | 43 | "including_bottom" : true, 44 | "fc nb nodes" : [512, 512, 512], 45 | 46 | "including_deconv" : false, 47 | 48 | "output dims" : 2, 49 | "output_activation" : "none", 50 | "debug" : true 51 | }, 52 | 53 | "discriminator" : "discriminator", 54 | "discriminator params" : { 55 | "activation" : "relu", 56 | "normalization" : "none", 57 | "weightsinit" : "he_normal", 58 | 59 | "including conv" : false, 60 | 61 | "including top" : true, 62 | "fc nb nodes" : [512, 512, 512], 63 | 64 | "output dims" : 1, 65 | "output_activation" : "none", 66 | "debug" : true 67 | } 68 | }, 69 | 70 | "trainer" : "unsupervised", 71 | "trainer params" : { 72 | 73 | "summary dir" : "log", 74 | "summary hyperparams string" : "gp_w_1_0", 75 | 76 | "multi thread" : true, 77 | "continue train" : false, 78 | "train steps" : 10000, 79 | "summary steps" : 100, 80 | "log steps" : 100, 81 | // "save checkpoint steps" : 10000, 82 | 83 | "batch_size" : 256, 84 | 85 | "debug" : true, 86 | "validators" : [ 87 | { 88 | "validator" : "gan_toy_plot", 89 | "validate steps" : 100, 90 | "validator params" : { 91 | "plot range" : 4, 92 | "log dir" : "gan_toy_gp_w_1_0" 93 | } 94 | } 95 | // { 96 | // "validator" : "embedding_visualize", 97 | // "validate steps" : 5000, 98 | // "validator params" : { 99 | // "z shape" : [2], 100 | // "x shape" : [2], 101 | // "log dir" : "log_lr_0_0001" 102 | // } 103 | // } 104 | 105 | 106 | ] 107 | } 108 | } 109 | 110 | 111 | -------------------------------------------------------------------------------- /dataset/gan_toy.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import struct 28 | import gzip 29 | import numpy as np 30 | # import matplotlib.pyplot as plt 31 | import pickle 32 | 33 | from .base_dataset import BaseDataset 34 | 35 | 36 | 37 | class GanToy(BaseDataset): 38 | 39 | def __init__(self, config): 40 | 41 | super(GanToy, self).__init__(config) 42 | self.config = config 43 | 44 | self.name = 'gan_toy' 45 | self.batch_size = int(config.get('batch_size', 128)) 46 | self.variance = float(config.get('variance', 0.02)) 47 | self.dataset = config.get('dataset', '8gaussians') 48 | 49 | if self.dataset == '8gaussians': 50 | scale = 2.0 51 | centers = [ 52 | (1, 0), 53 | (-1, 0), 54 | (0, 1), 55 | (0, -1), 56 | ( 1.0/np.sqrt(2), 1.0/np.sqrt(2)), 57 | ( -1.0/np.sqrt(2), 1.0/np.sqrt(2)), 58 | ( 1.0/np.sqrt(2), -1.0/np.sqrt(2)), 59 | ( -1.0/np.sqrt(2), -1.0/np.sqrt(2)) 60 | ] 61 | self.centers = np.array([ (scale*x, scale*y) for x, y in centers]) 62 | self.centers = np.tile(self.centers, (1000, 1)) 63 | self.centers = self.centers + np.random.randn(*self.centers.shape) * self.variance 64 | 65 | elif self.dataset == '25gaussians': 66 | # for i in range(25): 67 | centers = [(x, y) for x in range(-2, 3) 68 | for y in range(-2, 3)] 69 | self.centers = np.tile(centers, (1000, 1)) 70 | self.centers = self.centers + np.random.randn(*self.centers.shape) * self.variance 71 | 72 | 73 | def get_image_indices(self, phase=None, method=None): 74 | indices = np.arange(self.centers.shape[0]) 75 | if self.shuffle_train: 76 | np.random.shuffle(indices) 77 | return indices 78 | 79 | def read_image_by_index(self, ind, phase=None, method='unsupervised'): 80 | assert(method in ['unsupervised']) 81 | ind = ind % self.centers.shape[0] 82 | return self.centers[ind] 83 | 84 | 85 | 86 | def iter_train_images(self, method): 87 | assert method == 'unsupervised' 88 | 89 | indices = np.array(self.get_image_indices()) 90 | centers = 
np.array(self.centers) 91 | for i in range(int(len(indices) // self.batch_size)): 92 | batch_ind = indices[i*self.batch_size : (i+1)*self.batch_size] 93 | batch_x = centers[batch_ind, :] 94 | yield i, batch_x 95 | -------------------------------------------------------------------------------- /dataset/gan_toy_ssl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import struct 28 | import gzip 29 | import numpy as np 30 | # import matplotlib.pyplot as plt 31 | import pickle 32 | 33 | from .base_dataset import BaseDataset 34 | 35 | 36 | 37 | class GanToy(BaseDataset): 38 | 39 | def __init__(self, config): 40 | 41 | super(GanToy, self).__init__(config) 42 | self.config = config 43 | 44 | self.name = 'gan_toy_ssl' 45 | self.batch_size = int(config.get('batch_size', 128)) 46 | self.variance = float(config.get('variance', 0.02)) 47 | self.dataset = config.get('dataset', 'circles') 48 | 49 | raise NotImplementedError 50 | 51 | # if self.dataset == 'circles': 52 | # scale = 2.0 53 | # centers = [ 54 | # (1, 0), 55 | # (-1, 0), 56 | # (0, 1), 57 | # (0, -1), 58 | # ( 1.0/np.sqrt(2), 1.0/np.sqrt(2)), 59 | # ( -1.0/np.sqrt(2), 1.0/np.sqrt(2)), 60 | # ( 1.0/np.sqrt(2), -1.0/np.sqrt(2)), 61 | # ( -1.0/np.sqrt(2), -1.0/np.sqrt(2)) 62 | # ] 63 | # self.centers = np.array([ (scale*x, scale*y) for x, y in centers]) 64 | # self.centers = np.tile(self.centeres, (1000, 1)) 65 | # self.centers = self.centers + np.random.randn(*self.centers.shape) * self.variance 66 | # elif self.dataset == '25gaussians': 67 | # # for i in range(25): 68 | # centers = [(x, y) for x in range(-2, 3) 69 | # for y in range(-2, 3)] 70 | # self.centers = np.tile(centers, (1000, 1)) 71 | # self.centers = self.centers + np.random.randn(*self.centers.shape) * self.variance 72 | 73 | 74 | def get_image_indices(self, phase=None, method=None): 75 | indices = np.arange(self.centers.shape[0]) 76 | if self.shuffle_train: 77 | np.random.shuffle(indices) 78 | return indices 79 | 80 | def read_image_by_index(self, ind, phase=None, method='unsupervised'): 81 | assert(method in ['unsupervised']) 82 | ind = ind % self.centers.shape[0] 83 | return self.centers[ind] 84 | 85 | 86 | 87 | def iter_train_images(self, method): 88 | assert method == 'unsupervised' 89 | 90 | 
indices = np.array(self.get_image_indices()) 91 | centers = np.array(self.centers) 92 | for i in range(int(len(indices) // self.batch_size)): 93 | batch_ind = indices[i*self.batch_size : (i+1)*self.batch_size] 94 | batch_x = centers[batch_ind, :] 95 | yield i, batch_x 96 | -------------------------------------------------------------------------------- /dataset/mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import struct 28 | import gzip 29 | import numpy as np 30 | 31 | from .base_dataset import BaseDataset 32 | from .base_simple_dataset import BaseSimpleDataset 33 | 34 | 35 | 36 | class MNIST(BaseSimpleDataset): 37 | 38 | def __init__(self, config): 39 | 40 | super(MNIST, self).__init__(config) 41 | self.config = config 42 | 43 | self._dataset_dir = 'D:/Data/MNIST' 44 | if not os.path.exists(self._dataset_dir): 45 | self._dataset_dir = 'D:/dataset/MNIST' 46 | if not os.path.exists(self._dataset_dir): 47 | self._dataset_dir = 'C:/Data/MNIST' 48 | if not os.path.exists(self._dataset_dir): 49 | self._dataset_dir = 'G:/dataset/MNIST' 50 | if not os.path.exists(self._dataset_dir): 51 | self._dataset_dir = '/mnt/data01/dataset/MNIST' 52 | if not os.path.exists(self._dataset_dir): 53 | self._dataset_dir = '/mnt/sh_flex_storage/zhicongy/tmpdataset/MNIST' 54 | if not os.path.exists(self._dataset_dir): 55 | self._dataset_dir = config.get('dataset dir', '') 56 | if not os.path.exists(self._dataset_dir): 57 | raise Exception("MNIST : the dataset dir " + self._dataset_dir + " is not exist") 58 | 59 | self.name = 'mnist' 60 | self.output_shape = config.get('output shape', [28, 28, 1]) 61 | self.nb_classes = 10 62 | 63 | self.y_train, self.x_train = self._read_data( 64 | os.path.join(self._dataset_dir, 'train-labels-idx1-ubyte.gz'), 65 | os.path.join(self._dataset_dir, 'train-images-idx3-ubyte.gz') 66 | ) 67 | 68 | self.y_test, self.x_test = self._read_data( 69 | os.path.join(self._dataset_dir, 't10k-labels-idx1-ubyte.gz'), 70 | os.path.join(self._dataset_dir, 't10k-images-idx3-ubyte.gz') 71 | ) 72 | 73 | self.x_train = self.x_train.astype(np.float32) / 255.0 74 | self.x_test = self.x_test.astype(np.float32) / 255.0 75 | 76 | self.build_dataset() 77 | 78 | def _read_data(self, label_url, image_url): 79 | with gzip.open(label_url) as flbl: 80 | magic, num = 
struct.unpack(">II",flbl.read(8)) 81 | label = np.fromstring(flbl.read(),dtype=np.int8) 82 | with gzip.open(image_url,'rb') as fimg: 83 | magic, num, rows, cols = struct.unpack(">IIII",fimg.read(16)) 84 | image = np.fromstring(fimg.read(),dtype=np.uint8).reshape(len(label),rows,cols) 85 | return (label, image) 86 | 87 | 88 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | import argparse 28 | import time 29 | from datetime import datetime 30 | from shutil import copyfile 31 | 32 | import tensorflow as tf 33 | 34 | sys.path.append('./') 35 | sys.path.append('./lib') 36 | sys.path.append('../') 37 | 38 | from cfgs.networkconfig import get_config 39 | from dataset.dataset import get_dataset 40 | from model.model import get_model 41 | from trainer.trainer import get_trainer 42 | 43 | parser = argparse.ArgumentParser(description='') 44 | parser.add_argument('--gpu', type=str, default='0') 45 | parser.add_argument('--config_file', type=str, default='cvae1') # target config file, stored in ./cfgs 46 | parser.add_argument('--disp_config', type=bool, default=False) # if there is error in config file, set True to print the config file with line number 47 | 48 | args = parser.parse_args() 49 | 50 | if __name__ == '__main__': 51 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 52 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 53 | tf.reset_default_graph() 54 | 55 | # load config file 56 | config = get_config(args.config_file, args.disp_config) 57 | 58 | # make the assets directory and copy the config file to it 59 | # so if you want to reproduce the result in assets dir 60 | # just copy the config_file.json to ./cfgs folder and run python3 train.py --config=(config_file) 61 | if not os.path.exists(config['assets dir']): 62 | os.makedirs(config['assets dir']) 63 | cfg_filename = datetime.now().strftime('config_file_%y-%m-%d_%H-%M-%S.json') 64 | copyfile(os.path.join('./cfgs', args.config_file + '.json'), 65 | os.path.join(config['assets dir'], cfg_filename)) 66 | 67 | # prepare dataset 68 | dataset = get_dataset(config['dataset'], config['dataset params']) 69 | 70 | tfconfig = tf.ConfigProto() 71 | tfconfig.gpu_options.allow_growth = True 72 | 73 | with tf.Session(config=tfconfig) as sess: 74 | 75 | # build 
model 76 | config['model params']['assets dir'] = config['assets dir'] 77 | model = get_model(config['model'], config['model params']) 78 | 79 | # start training 80 | config['trainer params']['assets dir'] = config['assets dir'] 81 | trainer = get_trainer(config['trainer'], config['trainer params'], model, sess) 82 | 83 | trainer.train(sess, dataset, model) 84 | 85 | -------------------------------------------------------------------------------- /train_batch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | import argparse 28 | import time 29 | from datetime import datetime 30 | from shutil import copyfile 31 | 32 | import tensorflow as tf 33 | 34 | sys.path.append('./') 35 | sys.path.append('./lib') 36 | sys.path.append('../') 37 | 38 | from cfgs.networkconfig import get_config 39 | from dataset.dataset import get_dataset 40 | from model.model import get_model 41 | from trainer.trainer import get_trainer 42 | 43 | parser = argparse.ArgumentParser(description='') 44 | parser.add_argument('--gpu', type=str, default='0') 45 | parser.add_argument('--config_file', type=str, default='cvae1') # target config file, stored in ./cfgs 46 | parser.add_argument('--disp_config', type=bool, default=False) # if there is error in config file, set True to print the config file with line number 47 | 48 | args = parser.parse_args() 49 | 50 | if __name__ == '__main__': 51 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 52 | os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu 53 | tf.reset_default_graph() 54 | 55 | # load config file 56 | config = get_config(args.config_file, args.disp_config) 57 | 58 | # make the assets directory and copy the config file to it 59 | # so if you want to reproduce the result in assets dir 60 | # just copy the config_file.json to ./cfgs folder and run python3 train.py --config=(config_file) 61 | if not os.path.exists(config['assets dir']): 62 | os.makedirs(config['assets dir']) 63 | cfg_filename = datetime.now().strftime('config_file_%y-%m-%d_%H-%M-%S.json') 64 | copyfile(os.path.join('./cfgs', args.config_file + '.json'), 65 | os.path.join(config['assets dir'], cfg_filename)) 66 | 67 | # prepare dataset 68 | dataset = get_dataset(config['dataset'], config['dataset params']) 69 | 70 | tfconfig = tf.ConfigProto() 71 | tfconfig.gpu_options.allow_growth = True 72 | 73 | with tf.Session(config=tfconfig) as sess: 74 | 75 | # build 
model 76 | config['model params']['assets dir'] = config['assets dir'] 77 | model = get_model(config['model'], config['model params']) 78 | 79 | # start training 80 | config['trainer params']['assets dir'] = config['assets dir'] 81 | trainer = get_trainer(config['trainer'], config['trainer params'], model, sess) 82 | 83 | trainer.train(sess, dataset, model) 84 | 85 | -------------------------------------------------------------------------------- /cfgs/tianchi/guangdong.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "guangdong_defect_classification", 3 | 4 | "dataset" : "tianchi guangdong defect", 5 | "dataset params" : { 6 | "output shape" : [256, 256, 3], 7 | "area height" : 768, 8 | "one hot" : false, 9 | "show warning" : false, 10 | "use cache" : true 11 | }, 12 | 13 | "assets dir" : "assets/tianchi_guangdong/result7", 14 | 15 | "model" : "attention_mil", 16 | "model params" : { 17 | "name" : "mnist", 18 | 19 | "input shape" : [256, 256, 3], 20 | "z dims" : 500, 21 | "nb classes" : 11, 22 | 23 | "finetune steps": 10000, 24 | 25 | "finetune optimizer" : "adam", 26 | "finetune optimizer params" : { 27 | "lr" : 0.0001, 28 | "lr scheme" : "exponential", 29 | "lr params" : { 30 | "decay_steps" : 70000, 31 | "decay_rate" : 0.1 32 | } 33 | }, 34 | 35 | 36 | "optimizer" : "adam", 37 | "optimizer params" : { 38 | "lr" : 0.00001, 39 | "lr scheme" : "exponential", 40 | "lr params" : { 41 | "decay_steps" : 70000, 42 | "decay_rate" : 0.1 43 | } 44 | }, 45 | 46 | "summary" : true, 47 | 48 | "feature_ext" : "classifier", 49 | "feature_ext params" : { 50 | "normalization" : "fused_batch_norm", 51 | 52 | "including conv" : true, 53 | "conv nb blocks" : 6, 54 | "conv nb layers" : [2, 2, 3, 3, 3, 0], 55 | "conv nb filters" : [64, 128, 256, 512, 512], 56 | "conv ksize" : [3, 3, 3, 3, 3], 57 | 58 | "including top" : true, 59 | "fc nb nodes" : [1024, 1024], 60 | 61 | "load pretrained weights" : "config 
tianchi/guangdong3 classifier", 62 | 63 | "debug" : true 64 | }, 65 | 66 | "attention_net" : "classifier", 67 | "attention_net params" : { 68 | "activation" : "none", 69 | "normalization" : "none", 70 | "weightsinit" : "xavier", 71 | 72 | "including conv" : false, 73 | "including top" : true, 74 | 75 | "fc nb nodes" : [600], 76 | 77 | "has bias": false, 78 | 79 | "debug" : true 80 | }, 81 | 82 | "classifier" : "classifier", 83 | "classifier params" : { 84 | "normalization" : "none", 85 | "activation" : "none", 86 | 87 | "including conv" : false, 88 | 89 | "including top" : true, 90 | "fc nb nodes" : [], 91 | 92 | "has bias": false, 93 | "debug" : true 94 | } 95 | }, 96 | 97 | "trainer" : "supervised_mil", 98 | "trainer params" : { 99 | "summary hyperparams string" : "lr_0_0001_adam", 100 | 101 | 102 | 103 | "train steps" : 150000, 104 | "summary steps" : 1000, 105 | "log steps" : 500, 106 | "save checkpoint steps" : 10000, 107 | 108 | "validators" : [ 109 | { 110 | "validator" : "dataset_validator", 111 | "validate steps" : 2000, 112 | "validator params" : { 113 | "metric" : "accuracy", 114 | "metric type" : "multi-class acc2", 115 | "nb samples" : 200, 116 | "mil" : true 117 | } 118 | } 119 | ] 120 | } 121 | } 122 | 123 | -------------------------------------------------------------------------------- /dataset/cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 
13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import numpy as np 28 | import pickle 29 | 30 | from .base_dataset import BaseDataset 31 | from .base_simple_dataset import BaseSimpleDataset 32 | 33 | class Cifar10(BaseSimpleDataset): 34 | 35 | def __init__(self, config): 36 | 37 | super(Cifar10, self).__init__(config) 38 | 39 | self._dataset_dir = 'D:\\Data\\Cifar10' 40 | if not os.path.exists(self._dataset_dir): 41 | self._dataset_dir = 'C:\\Data\\Cifar10' 42 | if not os.path.exists(self._dataset_dir): 43 | self._dataset_dir = 'G:/dataset/Cifar10' 44 | if not os.path.exists(self._dataset_dir): 45 | self._dataset_dir = 'D:\\dataset\\Cifar10' 46 | if not os.path.exists(self._dataset_dir): 47 | self._dataset_dir = '/mnt/data01/dataset/Cifar10' 48 | if not os.path.exists(self._dataset_dir): 49 | self._dataset_dir = '/mnt/sh_flex_storage/zhicongy/dataset/Cifar10' 50 | if not os.path.exists(self._dataset_dir): 51 | self._dataset_dir = self.config.get('dataset dir', self._dataset_dir) 52 | if not os.path.exists(self._dataset_dir): 53 | raise Exception("Cifar10 : the dataset dir " + self._dataset_dir + " is not exist") 54 | 55 | self.name = 'cifar10' 56 | self.output_shape = config.get('output shape', [32, 32, 3]) 57 | self.nb_classes = 10 58 | 59 | 
train_batch_list = [ 60 | 'data_batch_1', 61 | 'data_batch_2', 62 | 'data_batch_3', 63 | 'data_batch_4', 64 | 'data_batch_5', 65 | ] 66 | test_batch_file = 'test_batch' 67 | 68 | train_data = [] 69 | train_label = [] 70 | for train_file in train_batch_list: 71 | image_data, image_label = self.read_data(train_file, self._dataset_dir) 72 | train_data.append(image_data) 73 | train_label.append(image_label) 74 | self.x_train = np.vstack(train_data).reshape([-1, 3, 32, 32]).transpose([0, 2, 3, 1]).astype(np.float32) / 255.0 75 | self.y_train = np.hstack(train_label) 76 | 77 | test_data, test_label = self.read_data(test_batch_file, self._dataset_dir) 78 | self.x_test = np.reshape(test_data, [-1, 3, 32, 32]).transpose([0, 2, 3, 1]).astype(np.float32) / 255.0 79 | self.y_test = test_label 80 | 81 | self.build_dataset() 82 | 83 | 84 | def read_data(self, filename, data_path): 85 | with open(os.path.join(data_path, filename), 'rb') as datafile: 86 | data_dict = pickle.load(datafile, encoding='bytes') 87 | image_data = np.array(data_dict[b'data']) 88 | image_label = np.array(data_dict[b'labels']) 89 | return image_data, image_label 90 | 91 | -------------------------------------------------------------------------------- /trainer/unsupervised.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | 
# copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | 29 | sys.path.append('.') 30 | sys.path.append('../') 31 | 32 | import queue 33 | import threading 34 | import tensorflow as tf 35 | 36 | from validator.validator import get_validator 37 | 38 | from .base_trainer import BaseTrainer 39 | 40 | 41 | class UnsupervisedTrainer(BaseTrainer): 42 | ''' 43 | ''' 44 | def __init__(self, config, model, sess): 45 | self.config = config 46 | self.model = model 47 | 48 | super(UnsupervisedTrainer, self).__init__(config, model, sess) 49 | self.multi_thread = self.config.get('multi thread', False) 50 | 51 | if self.multi_thread: 52 | self.buffer_depth = self.config.get('buffer depth', 50) 53 | self.train_data_queue = queue.Queue(maxsize=self.buffer_depth) 54 | self.train_data_inner_queue = queue.Queue(maxsize=self.batch_size*self.buffer_depth) 55 | 56 | def train(self, sess, dataset, model): 57 | 58 | self.train_initialize(sess, model) 59 | 60 | if self.multi_thread: 61 | self.coord = tf.train.Coordinator() 62 | threads = [threading.Thread(target=self.read_data_loop, 63 | args=(self.coord, dataset, self.train_data_inner_queue, 'train', 'unsupervised')), 64 | threading.Thread(target=self.read_data_transport_loop, 65 | args=(self.coord, self.train_data_inner_queue, self.train_data_queue, 'train', 'unsupervised'))] 66 | for t in threads: 67 | 
t.start() 68 | 69 | if self.multi_thread : 70 | # in multi thread model, the image data were read in by dataset.get_train_indices() 71 | # and dataset.read_train_image_by_index() 72 | while True: 73 | epoch, batch_x = self.train_data_queue.get() 74 | step = self.train_inner_step(epoch, model, dataset, batch_x) 75 | if self.train_data_queue.empty() and step % 10 == 0: 76 | print('info : train data buffer empty') 77 | if step > int(self.config['train steps']): 78 | break 79 | else: 80 | epoch = 0 81 | while True: 82 | # in single thread model, the image data were read in by dataset.iter_train_images() 83 | for ind, batch_x in dataset.iter_train_images(method='unsupervised'): 84 | step = self.train_inner_step(epoch, model, dataset, batch_x) 85 | if step > int(self.config['train steps']): 86 | return 87 | epoch += 1 88 | 89 | # join all thread when in multi thread model 90 | self.coord.request_stop() 91 | while not self.train_data_queue.empty(): 92 | epoch, batch_x = self.train_data_queue.get() 93 | self.train_data_inner_queue.task_done() 94 | self.train_data_queue.task_done() 95 | self.coord.join(threads) 96 | 97 | 98 | -------------------------------------------------------------------------------- /validator/test_dataset_validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # 
copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | import os 26 | import numpy as np 27 | import matplotlib 28 | matplotlib.use('Agg') 29 | import matplotlib.pyplot as plt 30 | from scipy.stats import norm 31 | 32 | from .base_validator import BaseValidator 33 | 34 | class ScatterPlotValidator(object): 35 | 36 | def __init__(self, config): 37 | 38 | super(ScatterPlotValidator, self).__init__(config) 39 | 40 | self.assets_dir = config['assets dir'] 41 | self.log_dir = config.get('log dir', 'test') 42 | self.log_dir = os.path.join(self.assets_dir, self.log_dir) 43 | 44 | self.x_dim = int(config.get('x dim', 0)) 45 | self.y_dim = int(config.get('y dim', 1)) 46 | 47 | if not os.path.exists(self.log_dir): 48 | os.mkdir(self.log_dir) 49 | 50 | self.watch_variable = config.get('watch_variable', 'pred') 51 | 52 | 53 | def validate(self, model, dataset, sess, step): 54 | 55 | x_pos_array = [] 56 | y_pos_array = [] 57 | label_array = [] 58 | 59 | for ind, batch_x, batch_y in dataset.iter_val_images(): 60 | 61 | if self.watch_variable == 'pred': 62 | y_pred = model.predict(sess, batch_y) 63 | 64 | x_pos_array.append(y_pred[:, self.x_dim]) 65 | y_pos_array.append(y_pred[:, self.y_dim]) 66 | label_array.append(batch_y) 67 | 68 | elif self.watch_variable == 'hidden_dist': 69 | z_mean, z_log_var = model.hidden_variable(sess, batch_x) 70 | 71 | x_pos_array.append( 72 | 
np.concatenate([ 73 | z_mean[:, self.x_dim:self.x_dim+1], 74 | np.exp(z_log_var[:, self.x_dim:self.x_dim+1]) 75 | ], axis=1) 76 | ) 77 | y_pos_array.append( 78 | np.concatenate([ 79 | z_mean[:, self.y_dim:self.y_dim+1], 80 | np.exp(z_log_var[:, self.y_dim:self.y_dim+1]) 81 | ], axis=1) 82 | ) 83 | label_array.append(batch_y) 84 | else: 85 | raise Exception("None watch variable named " + self.watch_variable) 86 | 87 | x_pos_array = np.concatenate(x_pos_array, axis=0) 88 | y_pos_array = np.concatenate(y_pos_array, axis=0) 89 | label_array = np.concatenate(label_array, axis=0) 90 | 91 | if len(x_pos_array.shape) == 2: 92 | for i in range(x_pos_array.shape[1]): 93 | plt.figure(figsize=(6, 6)) 94 | plt.clf() 95 | plt.scatter(x_pos_array[:, i], y_pos_array[:, i], c=label_array) 96 | plt.colorbar() 97 | plt.savefig(os.path.join(self.log_dir, '%07d_%d.png'%(step, i))) 98 | 99 | else: 100 | plt.figure(figsize=(6, 6)) 101 | plt.clf() 102 | plt.scatter(x_pos_array, y_pos_array, c=label_array) 103 | plt.colorbar() 104 | plt.savefig(os.path.join(self.log_dir, '%07d.png'%step)) 105 | pass 106 | 107 | 108 | -------------------------------------------------------------------------------- /cfgs/wgan/wgan_gp.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "cifar10", 5 | "dataset params" : { 6 | "output_shape" : [32, 32, 3], 7 | "output scalar range" : [-1.0, 1.0] 8 | }, 9 | 10 | "assets dir" : "assets/wgan_gp/wgan_gp5", 11 | "model" : "wgan_gp", 12 | "model params" : { 13 | "name" : "wgan_gp", 14 | 15 | "input shape" : [32, 32, 3], 16 | "z_dim" : 100, 17 | 18 | "discriminator optimizer" : "adam", 19 | "discriminator optimizer params" : { 20 | "lr" : 0.0001, 21 | "lr scheme" : "constant" 22 | // "lr params" : { 23 | // "decay_steps" : 20000, 24 | // "decay_rate" : 0.5 25 | // } 26 | }, 27 | 28 | "generator optimizer" : "adam", 29 | "generator optimizer params" : { 30 | "lr" : 0.0001, 31 | 
"lr scheme" : "constant" 32 | // "lr params" : { 33 | // "decay_steps" : 20000, 34 | // "decay_rate" : 0.5 35 | // } 36 | }, 37 | 38 | "gradient penalty loss weight" : 10.0, 39 | "summary" : true, 40 | 41 | "generator" : "generator", 42 | "generator params" : { 43 | "normalization" : "fused_batch_norm", // 44 | 45 | "including_bottom" : true, 46 | "fc nb nodes" : [], 47 | "fc_output_reshape" : [4, 4, 512], 48 | 49 | "including_deconv" : true, 50 | "deconv nb blocks" : 4, 51 | "deconv nb layers" : [1, 1, 1, 1], 52 | "deconv nb filters" : [256, 128, 64, 64], 53 | "deconv_ksize" : [5, 5, 5, 5], 54 | 55 | "output dims" : 3, 56 | "output_activation" : "tanh", 57 | "debug" : true 58 | }, 59 | 60 | "discriminator" : "discriminator", 61 | "discriminator params" : { 62 | "activation" : "lrelu 0.2", 63 | "normalization" : "none", 64 | 65 | "including conv" : true, 66 | "conv nb blocks" : 4, 67 | "conv nb layers" : [1, 1, 1, 0], 68 | "conv nb filters" : [128, 256, 512], 69 | "conv ksize" : [5, 5, 5, 5], 70 | "no maxpooling" : true, 71 | 72 | "including top" : true, 73 | "fc nb nodes" : [], 74 | 75 | "output dims" : 1, 76 | "output_activation" : "none", 77 | "debug" : true 78 | } 79 | }, 80 | 81 | "trainer" : "unsupervised", 82 | "trainer params" : { 83 | 84 | "summary dir" : "log", 85 | "summary hyperparams string" : "lr_0_0001", 86 | 87 | "multi thread" : true, 88 | "continue train" : true, 89 | "train steps" : 100000, 90 | 91 | "summary steps" : 1000, 92 | "log steps" : 1000, 93 | "save checkpoint steps" : 10000, 94 | 95 | "batch_size" : 64, 96 | 97 | "debug" : true, 98 | "validators" : [ 99 | { 100 | "validator" : "random_generate", 101 | "validate steps" : 1000, 102 | "validator params" : { 103 | "log dir" : "generated_lr_0_0001", 104 | "z shape" : [100], 105 | "x shape" : [32, 32, 3], 106 | "output scalar range" : [-1.0, 1.0], 107 | "nb row" : 8, 108 | "nb col" : 8 109 | } 110 | }, 111 | { 112 | "validator" : "embedding_visualize", 113 | "validate steps" : 5000, 114 | 
"validator params" : { 115 | "z shape" : [100], 116 | "x shape" : [32, 32, 3], 117 | "log dir" : "log_lr_0_0001" 118 | } 119 | } 120 | ] 121 | } 122 | } 123 | 124 | 125 | -------------------------------------------------------------------------------- /validator/gan_toy_plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import numpy as np 28 | import matplotlib 29 | matplotlib.use('Agg') 30 | import matplotlib.pyplot as plt 31 | from scipy.stats import norm 32 | 33 | from .base_validator import BaseValidator 34 | 35 | class GanToyPlot(BaseValidator): 36 | 37 | def __init__(self, config): 38 | super(GanToyPlot, self).__init__(config) 39 | self.config = config 40 | self.assets_dir = config['assets dir'] 41 | self.log_dir = self.config.get('log dir', 'gan_toy') 42 | self.log_dir = os.path.join(self.assets_dir, self.log_dir) 43 | 44 | self.z_shape = [int(i) for i in self.config.get('z shape', [2])] 45 | self.nb_points = int(self.config.get('nb points', 100)) 46 | self.plot_range = int(self.config.get('plot range', 3)) 47 | if not os.path.exists(self.log_dir): 48 | os.mkdir(self.log_dir) 49 | 50 | def validate(self, model, dataset, sess, step): 51 | 52 | points = np.zeros([self.nb_points, self.nb_points,] + self.z_shape, dtype=np.float32) 53 | points[:, :, 0] = np.linspace(-self.plot_range, self.plot_range, self.nb_points)[:, None] 54 | points[:, :, 1] = np.linspace(-self.plot_range, self.plot_range, self.nb_points)[None, :] 55 | points = points.reshape((-1, 2)) 56 | 57 | disc_map = model.discriminate(sess, points) 58 | disc_map = disc_map.reshape([self.nb_points, self.nb_points]).transpose() 59 | 60 | dataset_indices = dataset.get_image_indices(phase='train', method='supervised') 61 | dataset_indices = np.random.choice(dataset_indices, size=100) 62 | dataset_x = np.array([dataset.read_image_by_index(ind, phase='train', method='supervised') for ind in dataset_indices]) 63 | indices = np.where(np.logical_and( 64 | np.logical_and(dataset_x[:, 0] >= -self.plot_range, dataset_x[:, 0] <= self.plot_range), 65 | np.logical_and(dataset_x[:, 1] >= -self.plot_range, dataset_x[:, 1] <= self.plot_range) 66 | ))[0] 67 | dataset_x = dataset_x[indices, :] 68 | 69 | random_z = 
np.random.randn(100, *self.z_shape) 70 | generated_x = model.generate(sess, random_z) 71 | indices = np.where(np.logical_and( 72 | np.logical_and(generated_x[:, 0] >= -self.plot_range, generated_x[:, 0] <= self.plot_range), 73 | np.logical_and(generated_x[:, 1] >= -self.plot_range, generated_x[:, 1] <= self.plot_range) 74 | ))[0] 75 | generated_x = generated_x[indices, :] 76 | 77 | plt.clf() 78 | x = np.linspace(-self.plot_range, self.plot_range, self.nb_points) 79 | y = np.linspace(-self.plot_range, self.plot_range, self.nb_points) 80 | 81 | plt.contour(x, y, disc_map) 82 | plt.scatter(dataset_x[:, 0], dataset_x[:, 1], c='orange', marker='+') 83 | plt.scatter(generated_x[:, 0], generated_x[:, 1], c='green', marker='*') 84 | 85 | plt.savefig(os.path.join(self.log_dir, '%07d.jpg'%step)) 86 | 87 | -------------------------------------------------------------------------------- /cfgs/aae/mnist.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "mnist_aae", 3 | 4 | "dataset" : "mnist", 5 | "dataset params" : { 6 | "output shape" : [28, 28, 1], 7 | "output scalar range" : [-1, 1] 8 | }, 9 | 10 | "assets dir" : "assets/aae/mnist6", 11 | 12 | "model" : "aae", 13 | "model params" : { 14 | "name" : "mnist", 15 | 16 | "input shape" : [28, 28, 1], 17 | "nb classes" : 10, 18 | "z_dim" : 3, 19 | "has label" : false, 20 | "prior distribution" : "normal", 21 | 22 | "auto-encoder optimizer" : "adam", 23 | "auto-encoder optimizer params" : { 24 | "lr" : 0.0001, 25 | "lr scheme" : "constant", 26 | "beta1" : 0.5, 27 | "beta2" : 0.9 28 | }, 29 | 30 | "discriminator optimizer" : "adam", 31 | "discriminator optimizer params" : { 32 | "lr" : 0.0001, 33 | "lr scheme" : "constant", 34 | "beta1" : 0.5, 35 | "beta2" : 0.9 36 | }, 37 | 38 | "encoder optimizer" : "adam", 39 | "encoder optimizer params" : { 40 | "lr" : 0.0001, 41 | "lr scheme" : "constant", 42 | "beta1" : 0.5, 43 | "beta2" : 0.9 44 | }, 45 | 46 | "summary" : true, 47 
| 48 | "encoder" : "encoder", 49 | "encoder params" : { 50 | "normalization" : "fused_batch_norm", 51 | 52 | "including conv" : true, 53 | "conv nb blocks" : 3, 54 | "conv nb layers" : [2, 2, 2], 55 | "conv nb filters" : [32, 64, 128], 56 | "conv ksize" : [3, 3, 3], 57 | "no maxpooling" : true, 58 | 59 | "including top" : true, 60 | "fc nb nodes" : [600, 600], 61 | 62 | "output_activation" : "none", 63 | "output_distribution" : "none", 64 | 65 | "debug" : true 66 | }, 67 | 68 | "decoder" : "decoder", 69 | "decoder params" : { 70 | "normalization" : "none", 71 | 72 | "including_bottom" : true, 73 | "fc nb nodes" : [600, 600], 74 | 75 | "including_deconv" : false, 76 | 77 | "output dims" : 784, 78 | "output_shape" : [28, 28, 1], 79 | "output_activation" : "tanh", 80 | 81 | "debug" : true 82 | }, 83 | 84 | "discriminator" : "discriminator", 85 | "discriminator params" : { 86 | "normalization" : "none", 87 | 88 | "including conv" : false, 89 | "including top" : true, 90 | "fc nb nodes" : [600, 600], 91 | 92 | "output_activation" : "none", 93 | 94 | "debug" : true 95 | } 96 | }, 97 | 98 | "trainer" : "unsupervised", 99 | "trainer params" : { 100 | 101 | "summary hyperparams string" : "lr0_0001_adam", 102 | 103 | "continue train" : false, 104 | "multi thread" : true, 105 | 106 | "batch_size" : 32, 107 | "train steps" : 20000, 108 | "summary steps" : 1000, 109 | "log steps" : 100, 110 | "save checkpoint steps" : 10000, 111 | 112 | "validators" : [ 113 | { 114 | "validator" : "scatter_plot", 115 | "validate steps" : 1000, 116 | "validator params" : { 117 | "watch variable" : "hidden dist", 118 | "distribution" : "none", 119 | "x dim" : 0, 120 | "y dim" : 1 121 | } 122 | }, 123 | { 124 | "validator" : "hidden_variable", 125 | "validate steps" : 1000, 126 | "validator params" : { 127 | "z_dim" : 3, 128 | "num_samples" : 15, 129 | "x shape" : [28, 28, 1] 130 | } 131 | } 132 | ] 133 | } 134 | } 135 | 136 | 
-------------------------------------------------------------------------------- /cfgs/wgan/voc.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "pascal_voc", 5 | "dataset params" : { 6 | "output shape" : [64, 64, 3], 7 | "scaling range" : [0.15, 0.25], 8 | "crop range" : [0.3, 0.7], 9 | "task" : "classification", 10 | "random mirroring" : false 11 | }, 12 | 13 | "assets dir" : "assets/wgan_gp/voc", 14 | "model" : "wgan_gp", 15 | "model params" : { 16 | "name" : "wgan_gp", 17 | 18 | "input shape" : [64, 64, 3], 19 | "z_dim" : 100, 20 | 21 | "discriminator optimizer" : "rmsprop", 22 | "discriminator optimizer params" : { 23 | "lr" : 0.0001, 24 | "lr scheme" : "exponential", 25 | "lr params" : { 26 | "decay_steps" : 20000, 27 | "decay_rate" : 0.5 28 | } 29 | }, 30 | 31 | "generator optimizer" : "rmsprop", 32 | "generator optimizer params" : { 33 | "lr" : 0.0001, 34 | "lr scheme" : "exponential", 35 | "lr params" : { 36 | "decay_steps" : 20000, 37 | "decay_rate" : 0.5 38 | } 39 | }, 40 | 41 | "gradient penalty loss weight" : 10.0, 42 | "summary" : true, 43 | 44 | "generator" : "generator_conv", 45 | "generator params" : { 46 | "normalization" : "batch_norm", // 47 | "weightsinit" : "normal 0.00 0.02", 48 | 49 | "including_bottom" : true, 50 | "fc nb nodes" : [], 51 | "fc_output_reshape" : [4, 4, 1024], 52 | 53 | "including_deconv" : true, 54 | "deconv nb blocks" : 5, 55 | "deconv nb layers" : [1, 1, 1, 1, 0], // output size = 4 * 2^(5-1) 56 | "deconv nb filters" : [512, 256, 128, 64], 57 | "deconv_ksize" : [5, 5, 5, 5], 58 | 59 | "output dims" : 1, 60 | "output_activation" : "sigmoid", 61 | "debug" : true 62 | }, 63 | 64 | "discriminator" : "discriminator_conv", 65 | "discriminator params" : { 66 | "activation" : "lrelu 0.1", 67 | "normalization" : "none", 68 | "weightsinit" : "normal 0.00 0.02", 69 | 70 | "including conv" : true, 71 | "conv nb blocks" : 5, 72 | "conv nb 
layers" : [1, 1, 1, 1, 0], 73 | "conv nb filters" : [64, 128, 256, 512], 74 | "conv ksize" : [5, 5, 5, 5], 75 | "no maxpooling" : true, 76 | 77 | "including top" : true, 78 | "fc nb nodes" : [], 79 | 80 | "output dims" : 1, 81 | "output_activation" : "none", 82 | "debug" : true 83 | } 84 | }, 85 | 86 | "trainer" : "unsupervised", 87 | "trainer params" : { 88 | 89 | "summary dir" : "log", 90 | "summary hyperparams string" : "lr_0_0001", 91 | 92 | "multi thread" : true, 93 | "continue train" : true, 94 | "train steps" : 100000, 95 | 96 | "summary steps" : 1000, 97 | "log steps" : 100, 98 | "save checkpoint steps" : 10000, 99 | 100 | "batch_size" : 64, 101 | 102 | "debug" : true, 103 | "validators" : [ 104 | { 105 | "validator" : "random_generate", 106 | "validate steps" : 200, 107 | "validator params" : { 108 | "log dir" : "generated_lr_0_0001", 109 | "z shape" : [100], 110 | "x shape" : [32, 32, 3], 111 | "nb row" : 8, 112 | "nb col" : 8 113 | } 114 | }, 115 | { 116 | "validator" : "embedding_visualize", 117 | "validate steps" : 5000, 118 | "validator params" : { 119 | "z shape" : [100], 120 | "x shape" : [32, 32, 3], 121 | "log dir" : "log_lr_0_0001" 122 | } 123 | } 124 | ] 125 | } 126 | } 127 | 128 | 129 | -------------------------------------------------------------------------------- /netutils/sample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above 
copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | from functools import partial 27 | import numpy as np 28 | 29 | import tensorflow as tf 30 | import tensorflow.contrib.layers as tcl 31 | 32 | 33 | 34 | 35 | 36 | 37 | def sample_mix_gaussian(batch_size, nb_classes, z_dim, radis=2.0, x_var=0.5, y_var=0.1): 38 | assert z_dim == 2 39 | def sample(x, y, label, n_labels): 40 | shift = 3 41 | r = radis * np.pi / n_labels * label 42 | new_x = x * np.cos(r) - y * np.sin(r) 43 | new_y = x * np.sin(r) + y * np.cos(r) 44 | new_x += shift * np.cos(r) 45 | new_y += shift * np.sin(r) 46 | return new_x, new_y 47 | x = np.random.normal(0, x_var, [batch_size, 1]) 48 | y = np.random.normal(0, y_var, [batch_size, 1]) 49 | label = np.random.randint(0, nb_classes, size=[batch_size, 1]).astype(np.float32) 50 | label_onehot = np.zeros(shape=(batch_size, nb_classes)).astype(np.float32) 51 | for i in range(batch_size): 52 | label_onehot[i, int(label[i, 0])] = 1 53 | 54 | x, y = sample(x, y, label, nb_classes) 55 | return np.concatenate([x, y], axis=1).astype(np.float32), label_onehot 56 | 57 | 58 | def sample_normal(batch_size, z_dim, var=1.0): 59 | if isinstance(z_dim, list): 60 | shape = [batch_size] + z_dim 61 | elif isinstance(z_dim, int): 62 | shape = [batch_size, z_dim] 63 | return np.random.normal(0, 
var, shape).astype(np.float32) 64 | 65 | 66 | def sample_categorical(batch_size, nb_classes): 67 | def to_categorical(y, nb_classes): 68 | input_shape = y.shape 69 | y = y.ravel().astype(np.int32) 70 | n = y.shape[0] 71 | ret = np.zeros((n, nb_classes), dtype=np.float32) 72 | indices = np.where(y >= 0)[0] 73 | ret[np.arange(n)[indices], y[indices]] = 1.0 74 | ret = ret.reshape(list(input_shape) + [nb_classes, ]) 75 | return ret 76 | 77 | label = np.random.randint(0, nb_classes, size=[ 78 | batch_size]).astype(np.float32) 79 | label_onehot = to_categorical(label, nb_classes) 80 | return label_onehot 81 | 82 | 83 | sample_dict = { 84 | 'mixGaussian' : sample_mix_gaussian, 85 | 'normal' : sample_normal, 86 | 'categorical' : sample_categorical, 87 | 'mix gaussian' : sample_mix_gaussian, 88 | 'gaussian' : sample_normal 89 | } 90 | 91 | 92 | def get_sampler(name, **kwargs): 93 | """ get sample function by name 94 | """ 95 | if name in sample_dict: 96 | return partial(sample_dict[name], **kwargs) 97 | else: 98 | raise ValueError('No sampler named ' + name) 99 | 100 | 101 | def get_sample(name, args): 102 | if name == 'normal' : 103 | z_avg, z_log_var = args 104 | eps = tf.random_normal(shape=tf.shape(z_avg), mean=0.0, stddev=1.0) 105 | return z_avg + tf.exp(z_log_var / 2.0) * eps 106 | else: 107 | raise Exception("None sample function named " + name) 108 | 109 | 110 | -------------------------------------------------------------------------------- /validator/scatter_plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 
10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import numpy as np 28 | import matplotlib 29 | matplotlib.use('Agg') 30 | import matplotlib.pyplot as plt 31 | from scipy.stats import norm 32 | 33 | from .base_validator import BaseValidator 34 | 35 | class ScatterPlot(BaseValidator): 36 | 37 | def __init__(self, config): 38 | 39 | super(ScatterPlot, self).__init__(config) 40 | self.assets_dir = config['assets dir'] 41 | self.log_dir = config.get('log dir', 'scatter') 42 | self.log_dir = os.path.join(self.assets_dir, self.log_dir) 43 | 44 | self.x_dim = int(config.get('x dim', 0)) 45 | self.y_dim = int(config.get('y dim', 1)) 46 | 47 | if not os.path.exists(self.log_dir): 48 | os.mkdir(self.log_dir) 49 | 50 | self.watch_variable = config.get('watch variable', 'pred') 51 | self.distribution = config.get('distribution', 'normal') 52 | 53 | def validate(self, model, dataset, sess, step): 54 | 55 | x_pos_array = [] 56 | y_pos_array = [] 57 | label_array = [] 58 | 59 | for ind, batch_x, batch_y in dataset.iter_val_images(): 60 | if self.watch_variable == 'pred': 61 | y_pred = model.predict(sess, batch_y) 62 | 
x_pos_array.append(y_pred[:, self.x_dim]) 63 | y_pos_array.append(y_pred[:, self.y_dim]) 64 | label_array.append(np.argmax(batch_y, axis=1)) 65 | 66 | elif self.watch_variable == 'hidden dist': 67 | if self.distribution == 'normal': 68 | z_mean, z_log_var = model.hidden_variable(sess, batch_x) 69 | x_pos_array.append( 70 | np.concatenate([ z_mean[:, self.x_dim:self.x_dim+1], 71 | np.exp(z_log_var[:, self.x_dim:self.x_dim+1]) ], axis=1) 72 | ) 73 | y_pos_array.append( 74 | np.concatenate([ z_mean[:, self.y_dim:self.y_dim+1], 75 | np.exp(z_log_var[:, self.y_dim:self.y_dim+1]) ], axis=1) 76 | ) 77 | label_array.append(np.argmax(batch_y, axis=1)) 78 | elif self.distribution == 'none': 79 | z_sample = model.hidden_variable(sess, batch_x) 80 | x_pos_array.append(z_sample[:, self.x_dim]) 81 | y_pos_array.append(z_sample[:, self.y_dim]) 82 | label_array.append(np.argmax(batch_y, axis=1)) 83 | else: 84 | raise Exception("None watch variable named " + self.watch_variable) 85 | 86 | 87 | x_pos_array = np.concatenate(x_pos_array, axis=0) 88 | y_pos_array = np.concatenate(y_pos_array, axis=0) 89 | label_array = np.concatenate(label_array, axis=0) 90 | 91 | if len(x_pos_array.shape) == 2: 92 | for i in range(x_pos_array.shape[1]): 93 | plt.figure(figsize=(6, 6)) 94 | plt.clf() 95 | plt.scatter(x_pos_array[:, i], y_pos_array[:, i], c=label_array) 96 | plt.colorbar() 97 | plt.savefig(os.path.join(self.log_dir, '%07d_%d.png'%(step, i))) 98 | else: 99 | plt.figure(figsize=(6, 6)) 100 | plt.clf() 101 | plt.scatter(x_pos_array, y_pos_array, c=label_array) 102 | plt.colorbar() 103 | plt.savefig(os.path.join(self.log_dir, '%07d.png'%step)) 104 | return None 105 | 106 | 107 | -------------------------------------------------------------------------------- /trainer/supervised.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, 
free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | # ============================================================================== 24 | 25 | import os 26 | import sys 27 | import queue 28 | import threading 29 | import numpy as np 30 | 31 | sys.path.append('.') 32 | sys.path.append('../') 33 | 34 | import tensorflow as tf 35 | 36 | from validator.validator import get_validator 37 | 38 | from .base_trainer import BaseTrainer 39 | 40 | class SupervisedTrainer(BaseTrainer): 41 | ''' 42 | 43 | ''' 44 | def __init__(self, config, model, sess): 45 | self.config = config 46 | self.model = model 47 | super(SupervisedTrainer, self).__init__(config, model, sess) 48 | self.dataset_phase = self.config.get('dataset phase', 'train') 49 | 50 | self.multi_thread = self.config.get('multi thread', False) 51 | if self.multi_thread: 52 | self.buffer_depth = self.config.get('buffer depth', 50) 53 | self.train_data_queue = queue.Queue(maxsize=self.buffer_depth) 54 | self.train_data_inner_queue = queue.Queue(maxsize=self.batch_size*self.buffer_depth) 55 | 56 | 57 | def train(self, sess, dataset, model): 58 | 59 | self.train_initialize(sess, model) 60 | 61 | # if in multi thread model, start threads for read data 62 | if self.multi_thread: 63 | self.coord = tf.train.Coordinator() 64 | threads = [ # this thread is for reading data and put data into self.train_data_inner_queue 65 | threading.Thread(target=self.read_data_loop, 66 | args=(self.coord, dataset, self.train_data_inner_queue, self.dataset_phase, 'supervised')), 67 | # this thread is for grouping data into batches and put into self.train_data_queue 68 | threading.Thread(target=self.read_data_transport_loop, 69 | args=(self.coord, self.train_data_inner_queue, self.train_data_queue, self.dataset_phase, 'supervised'))] 70 | for t in threads: 71 | t.start() 72 | 73 | if self.multi_thread : 74 | # in multi thread model, the image data were read in by dataset.get_train_indices() 75 | # and dataset.read_image_by_index() 76 | while True: 77 | epoch, batch_x, batch_y = 
self.train_data_queue.get() 78 | step = self.train_inner_step(epoch, model, dataset, batch_x, batch_y) 79 | if self.train_data_queue.empty() and step % 100 == 0: 80 | print('info : train data buffer empty') 81 | if step > int(self.config['train steps']): 82 | break 83 | else: 84 | epoch = 0 85 | while True: 86 | # in single thread model, the image data were read in by dataset.iter_train_images() 87 | for ind, batch_x, batch_y in dataset.iter_train_images(method='supervised'): 88 | step = self.train_inner_step(epoch, model, dataset, batch_x, batch_y) 89 | if step > int(self.config['train steps']): 90 | return 91 | epoch += 1 92 | 93 | # join all thread when in multi thread model 94 | self.coord.request_stop() 95 | while not self.train_data_queue.empty(): 96 | epoch, batch_x, batch_y = self.train_data_queue.get() 97 | self.train_data_inner_queue.task_done() 98 | self.train_data_queue.task_done() 99 | self.coord.join(threads) 100 | 101 | -------------------------------------------------------------------------------- /decoder/decoder_pixel.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | 6 | import tensorflow as tf 7 | import tensorflow.contrib.layers as tcl 8 | 9 | 10 | sys.path.append('../') 11 | 12 | 13 | 14 | from network.weightsinit import get_weightsinit 15 | from network.activation import get_activation 16 | from network.normalization import get_normalization 17 | 18 | 19 | 20 | class DecoderPixel(object): 21 | 22 | def __init__(self, config, model_config, name="DecoderPixel"): 23 | 24 | self.name = name 25 | self.training = model_config["is_training"] 26 | self.normalizer_params = { 27 | 'decay' : 0.999, 28 | 'center' : True, 29 | 'scale' : False, 30 | 'is_training' : self.training 31 | } 32 | 33 | self.config = config 34 | self.model_config = model_config 35 | 36 | 37 | def __call__(self, i, reuse=False): 38 | 39 | if 'activation' in self.config: 40 | act_fn = 
get_activation(self.config['activation'], self.config['activation_params']) 41 | elif 'activation' in self.model_config: 42 | act_fn = get_activation(self.model_config['activation'], self.model_config['activation_params']) 43 | else: 44 | act_fn = get_activation('lrelu', '0.2') 45 | 46 | if 'batch_norm' in self.config: 47 | norm_fn, norm_params = get_normalization(self.config['batch_norm']) 48 | elif 'batch_norm' in self.model_config: 49 | norm_fn, norm_params = get_normalization(self.model_config['batch_norm']) 50 | else: 51 | norm_fn = tcl.batch_norm 52 | 53 | if 'weightsinit' in self.config: 54 | winit_fn = get_weightsinit(self.config['weightsinit'], self.config['weightsinit_params']) 55 | elif 'weightsinit' in self.model_config: 56 | winit_fn = get_weightsinit(self.model_config['weightsinit'], self.config['weightsinit_params']) 57 | else: 58 | winit_fn = tf.random_normal_initializer(0, 0.02) 59 | 60 | 61 | if 'nb_filters' in self.config: 62 | filters = int(self.config['nb_filters']) 63 | else: 64 | filters = 64 65 | 66 | if 'out_activation' in self.config: 67 | out_act_fn = get_activation(self.config['out_activation'], self.config['out_activation_params']) 68 | else: 69 | out_act_fn = None 70 | 71 | output_classes = self.config['output_classes'] 72 | 73 | with tf.variable_scope(self.name): 74 | if reuse: 75 | tf.get_variable_scope().reuse_variables() 76 | else: 77 | assert tf.get_variable_scope().reuse is False 78 | 79 | 80 | x = tcl.conv2d(i, filters, 3, 81 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 82 | padding='SAME', weights_initializer=winit_fn, scope='conv1_0') 83 | x = tcl.conv2d(x, filters, 3, 84 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 85 | padding='SAME', weights_initializer=winit_fn, scope='conv1_1') 86 | 87 | x = tcl.conv2d_transpose(x, filters*2, 3, 88 | stride=2, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 89 | 
padding='SAME', weights_initializer=winit_fn, scope='conv2_0') 90 | x = tcl.conv2d(x, filters*2, 3, 91 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 92 | padding='SAME', weights_initializer=winit_fn, scope='conv2_1') 93 | 94 | 95 | x = tcl.conv2d_transpose(x, filters*4, 3, 96 | stride=2, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 97 | padding='SAME', weights_initializer=winit_fn, scope='conv3_0') 98 | x = tcl.conv2d(x, filters*4, 3, 99 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 100 | padding='SAME', weights_initializer=winit_fn, scope='conv3_1') 101 | x = tcl.conv2d(x, filters*4, 3, 102 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 103 | padding='SAME', weights_initializer=winit_fn, scope='conv3_2') 104 | x = tcl.conv2d(x, filters*4, 3, 105 | stride=1, activation_fn=act_fn, normalizer_fn=norm_fn, normalizer_params=norm_params, 106 | padding='SAME', weights_initializer=winit_fn, scope='conv3_3') 107 | 108 | x = tcl.conv2d(x, output_classes, 1, 109 | stride=1, activation_fn=out_act_fn, 110 | padding='SAME', weights_initializer=winit_fn, scope='conv_out') 111 | 112 | return x 113 | 114 | @property 115 | def vars(self): 116 | return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.name) 117 | 118 | 119 | -------------------------------------------------------------------------------- /cfgs/impgan/celeba.json: -------------------------------------------------------------------------------- 1 | { 2 | "config name" : "wgan_gp", 3 | 4 | "dataset" : "celeba", 5 | "dataset params" : { 6 | "output shape" : [64, 64, 3], 7 | "output scalar range" : [-1.0, 1.0], 8 | "unsupervised" : true 9 | }, 10 | 11 | "assets dir" : "assets/wgan_gp/celeba", 12 | "model" : "wgan_gp", 13 | "model params" : { 14 | "name" : "wgan_gp", 15 | 16 | "input shape" : [64, 64, 3], 17 | "z_dim" : 100, 18 | 19 | "discriminator 
optimizer" : "adam", 20 | "discriminator optimizer params" : { 21 | "lr" : 0.0001, 22 | "lr scheme" : "constant", 23 | "beta1" : 0.5, 24 | "beta2" : 0.9 25 | }, 26 | 27 | "generator optimizer" : "adam", 28 | "generator optimizer params" : { 29 | "lr" : 0.0001, 30 | "lr scheme" : "constant", 31 | "beta1" : 0.5, 32 | "beta2" : 0.9 33 | }, 34 | 35 | "summary" : true, 36 | 37 | "generator" : "generator", 38 | "generator params" : { 39 | "activation" : "relu", 40 | "normalization" : "fused_batch_norm", // 41 | "weightsinit" : "he_uniform", 42 | 43 | "including_bottom" : true, 44 | "fc nb nodes" : [], 45 | "fc_output_reshape" : [4, 4, 1024], 46 | 47 | "including_deconv" : true, 48 | "deconv nb blocks" : 4, 49 | "deconv nb layers" : [1, 1, 1, 0], // output size = 4 * 2^(5-1) 50 | "deconv nb filters" : [512, 256, 128, 64], 51 | "deconv_ksize" : [5, 5, 5, 5], 52 | 53 | "output dims" : 3, 54 | "output_stride" : 2, 55 | "output_ksize" : 5, 56 | "output_activation" : "tanh", 57 | "debug" : true 58 | }, 59 | 60 | "discriminator" : "discriminator", 61 | "discriminator params" : { 62 | "activation" : "lrelu 0.2", 63 | "normalization" : "none", 64 | "weightsinit" : "he_uniform", 65 | 66 | "including conv" : true, 67 | "conv nb blocks" : 5, 68 | "conv nb layers" : [1, 1, 1, 1, 0], 69 | "conv nb filters" : [64, 128, 256, 512], 70 | "conv ksize" : [5, 5, 5, 5], 71 | "no maxpooling" : true, 72 | 73 | "including top" : true, 74 | "fc nb nodes" : [], 75 | 76 | "output dims" : 1, 77 | "output_activation" : "none", 78 | "debug" : true 79 | } 80 | }, 81 | 82 | "trainer" : "unsupervised", 83 | "trainer params" : { 84 | 85 | "summary dir" : "log", 86 | "summary hyperparams string" : "bs32_adam_fm", 87 | 88 | "multi thread" : true, 89 | "continue train" : false, 90 | "train steps" : 100000, 91 | 92 | "summary steps" : 1000, 93 | "log steps" : 100, 94 | "save checkpoint steps" : 10000, 95 | 96 | "batch_size" : 32, 97 | 98 | "debug" : true, 99 | "validators" : [ 100 | { 101 | "validator" : 
"random_generate", 102 | "validate steps" : 1000, 103 | "validator params" : { 104 | "log dir" : "generated_lr_0_0001", 105 | "z shape" : [100], 106 | "x shape" : [64, 64, 3], 107 | "output scalar range" : [-1.0, 1.0], 108 | "nb row" : 8, 109 | "nb col" : 8 110 | } 111 | }, 112 | { 113 | "validator" : "random_generate", 114 | "validate steps" : 1000, 115 | "validator params" : { 116 | "log dir" : "generated_lr_0_0001_fixed", 117 | "z shape" : [100], 118 | "x shape" : [64, 64, 3], 119 | "output scalar range" : [-1.0, 1.0], 120 | "nb row" : 8, 121 | "nb col" : 8, 122 | "fix z" : true 123 | } 124 | }, 125 | { 126 | "validator" : "embedding_visualize", 127 | "validate steps" : 5000, 128 | "validator params" : { 129 | "z shape" : [100], 130 | "x shape" : [64, 64, 3], 131 | "log dir" : "log_lr_0_0001" 132 | } 133 | } 134 | ] 135 | } 136 | } 137 | 138 | 139 | -------------------------------------------------------------------------------- /validator/tensorboard_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MIT License 3 | # 4 | # Copyright (c) 2018 ZhicongYan 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 
15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # ============================================================================== 24 | 25 | 26 | import os 27 | import sys 28 | import numpy as np 29 | 30 | sys.path.append('.') 31 | sys.path.append('../') 32 | 33 | import tensorflow as tf 34 | import tensorflow.contrib.layers as tcl 35 | 36 | 37 | from tensorflow.contrib.tensorboard.plugins import projector 38 | 39 | 40 | from .base_validator import BaseValidator 41 | 42 | class TensorboardEmbedding(BaseValidator): 43 | """ Plot the model output or mediate features into tensorboard embedding panel. 
44 | 45 | """ 46 | 47 | def __init__(self, config): 48 | super(TensorboardEmbedding, self).__init__(config) 49 | 50 | self.assets_dir = self.config['assets dir'] 51 | self.log_dir = self.config.get('log dir', 'embedding') 52 | self.log_dir = os.path.join(self.assets_dir, self.log_dir) 53 | 54 | self.z_shape = list(self.config['z shape']) 55 | self.x_shape = list(self.config['x shape']) 56 | self.nb_samples = self.config.get('nb samples', 1000) 57 | self.batch_size = self.config.get('batch_size', 100) 58 | 59 | self.nb_samples = self.nb_samples // self.batch_size * self.batch_size 60 | 61 | 62 | if not os.path.exists(self.log_dir): 63 | os.mkdir(self.log_dir) 64 | 65 | with open(os.path.join(self.log_dir, 'metadata.tsv'), 'w') as f: 66 | f.write("Index\tLabel\n") 67 | for i in range(self.nb_samples): 68 | f.write("%d\t%d\n"%(i, 0)) 69 | for i in range(self.nb_samples): 70 | f.write("%d\t%d\n"%(i+self.nb_samples, 1)) 71 | 72 | summary_writer = tf.summary.FileWriter(self.log_dir) 73 | config = projector.ProjectorConfig() 74 | embedding = config.embeddings.add() 75 | embedding.tensor_name = "test" 76 | embedding.metadata_path = "metadata.tsv" 77 | projector.visualize_embeddings(summary_writer, config) 78 | 79 | self.plot_array_var = tf.get_variable('test', shape=[self.nb_samples*2, int(np.product(self.x_shape))]) 80 | self.saver = tf.train.Saver([self.plot_array_var]) 81 | 82 | def validate(self, model, dataset, sess, step): 83 | 84 | plot_array_list = [] 85 | indices = dataset.get_image_indices(phase='train', method='unsupervised') 86 | indices = np.random.choice(indices, size=self.nb_samples) 87 | 88 | for i, ind in enumerate(indices): 89 | test_x = dataset.read_image_by_index(ind, phase='train', method='unsupervised') 90 | if isinstance(test_x, list): 91 | for x in test_x: 92 | x = x.reshape([-1,]) 93 | plot_array_list.append(x) 94 | if len(plot_array_list) >= self.nb_samples: 95 | break 96 | elif test_x is not None: 97 | test_x = test_x.reshape([-1,]) 98 | 
plot_array_list.append(test_x) 99 | if len(plot_array_list) >= self.nb_samples: 100 | break 101 | 102 | for i in range(self.nb_samples // self.batch_size): 103 | batch_z = np.random.randn(*([self.nb_samples,] + self.z_shape)) 104 | batch_x = model.generate(sess, batch_z) 105 | for i in range(self.batch_size): 106 | plot_array_list.append(batch_x[i].reshape([-1])) 107 | 108 | plot_array_list = np.array(plot_array_list) 109 | 110 | sess.run(self.plot_array_var.assign(plot_array_list)) 111 | 112 | self.saver.save(sess, os.path.join(self.log_dir, 'model.ckpt'), 113 | global_step=step, 114 | write_meta_graph=False, 115 | strip_default_attrs=True) 116 | 117 | --------------------------------------------------------------------------------