├── mxgan
│   ├── __init__.py
│   ├── viz.py
│   ├── encoder.py
│   ├── ops.py
│   ├── generator.py
│   ├── custom_ops.py
│   └── module.py
├── .gitignore
├── example
│   ├── gan_mnist.py
│   ├── gan_mnist_semisupervised.py
│   ├── gan_cifar10.py
│   ├── gan_mnist_minibatch_discrimination.py
│   └── gan_cifar10_semisupervised.py
├── README.md
└── LICENSE

/mxgan/__init__.py:
--------------------------------------------------------------------------------
"""T-test: random experiment code by Tianqi."""
--------------------------------------------------------------------------------
/mxgan/viz.py:
--------------------------------------------------------------------------------
"""Visualization modules."""
import cv2
import numpy as np


def _fill_buf(buf, i, img, shape):
    # column and row of image i in the grid (integer arithmetic for indexing)
    m = buf.shape[1] // shape[0]
    sx = (i % m) * shape[0]
    sy = (i // m) * shape[1]
    buf[sy:sy+shape[1], sx:sx+shape[0], :] = img


def layout(X, flip=False):
    """Tile a batch of NCHW images with values in [0, 1] into one uint8 grid."""
    assert len(X.shape) == 4
    X = X.transpose((0, 2, 3, 1))
    X = np.clip(X * 255.0, 0, 255).astype(np.uint8)
    n = int(np.ceil(np.sqrt(X.shape[0])))
    buff = np.zeros((n*X.shape[1], n*X.shape[2], X.shape[3]), dtype=np.uint8)
    for i, img in enumerate(X):
        img = np.flipud(img) if flip else img
        _fill_buf(buff, i, img, X.shape[1:3])
    if buff.shape[-1] == 1:
        return buff.reshape(buff.shape[0], buff.shape[1])
    if X.shape[-1] != 1:
        buff = cv2.cvtColor(buff, cv2.COLOR_BGR2RGB)
    return buff


def imshow(title, X, waitsec=1, flip=False):
    """Show the images in X and wait waitsec milliseconds (cv2.waitKey)."""
    buff = layout(X, flip=flip)
    cv2.imshow(title, buff)
    cv2.waitKey(waitsec)
--------------------------------------------------------------------------------
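
A quick smoke test for the grid layout above, assuming only NumPy input (no trained model):

```python
import numpy as np
from mxgan import viz

X = np.random.uniform(size=(9, 1, 28, 28))  # nine grayscale "images" in [0, 1]
grid = viz.layout(X)
print(grid.shape)  # (84, 84): a 3x3 tile of 28x28 images
```
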
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
*~

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject
--------------------------------------------------------------------------------
/mxgan/encoder.py:
--------------------------------------------------------------------------------
"""Collection of encoder symbols.

An encoder maps an input to a vector space.
"""
import mxnet as mx
from .ops import conv2d_bn_leaky

def lenet(data=None):
    """LeNet up to (but not including) the classification layer."""
    data = mx.sym.Variable("data") if data is None else data
    # input 28x28
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20, name="conv1")
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50, name="conv2")
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
                              kernel=(2,2), stride=(2,2))
    d5 = mx.sym.Flatten(pool2)
    d5 = mx.sym.FullyConnected(d5, num_hidden=500, name="fc1")
    d5 = mx.sym.Activation(d5, act_type="tanh")
    return d5


def dcgan(data=None, ngf=128):
    """Conv net used in the original DCGAN discriminator."""
    data = mx.sym.Variable("data") if data is None else data
    # ngf, 16, 16
    net = mx.sym.Convolution(data, kernel=(4, 4), stride=(2, 2), pad=(1, 1), num_filter=ngf, name="e1_conv")
    net = mx.sym.LeakyReLU(net, slope=0.2, act_type="leaky", name="e1_act")
    # ngf*2, 8, 8
    net = conv2d_bn_leaky(net, kernel=(4, 4), stride=(2, 2), pad=(1, 1), num_filter=ngf*2, prefix="e2")
    # ngf*4, 4, 4
    net = conv2d_bn_leaky(net, kernel=(4, 4), stride=(2, 2), pad=(1, 1), num_filter=ngf*4, prefix="e3")
    net = mx.sym.Flatten(net)
    return net
--------------------------------------------------------------------------------
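
The embedding sizes of these encoders can be checked symbolically; a minimal sketch (batch size 100 is illustrative):

```python
import mxnet as mx
from mxgan import encoder

net = encoder.lenet()
_, out_shapes, _ = net.infer_shape(data=(100, 1, 28, 28))
print(out_shapes)  # [(100, 500)]: lenet flattens to a 500-unit tanh embedding
```
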
/example/gan_mnist.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import mxnet as mx
import sys

sys.path.append("..")

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    """Fraction of examples the discriminator gets wrong at threshold 0.5."""
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

gmod = module.GANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            # normalize the input gradient for display
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
--------------------------------------------------------------------------------
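
After training, fresh samples can be drawn from the generator alone; a minimal sketch built on the `GANModule` internals used above (`modG`, `temp_rbatch`):

```python
mx.random.normal(0, 1.0, out=gmod.temp_rbatch.data[0])  # new random codes
gmod.modG.forward(gmod.temp_rbatch, is_train=False)
samples = gmod.modG.get_outputs()[0].asnumpy()          # (100, 1, 28, 28)
viz.imshow("samples", samples, 0)                       # wait for a key press
```
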
/example/gan_mnist_semisupervised.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import mxnet as mx
import sys

sys.path.append("..")

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

gmod = module.SemiGANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    num_class=10,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        # can switch between labeled and unlabeled.
        gmod.update(batch, is_labeled=True)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
--------------------------------------------------------------------------------
/example/gan_cifar10.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import mxnet as mx
import sys

sys.path.append("..")

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

ngf = 64
lr = 0.0003
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 3, 32, 32)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan32x32(oshape=data_shape, ngf=ngf, final_act="tanh")
sym_dec = encoder.dcgan(ngf=ngf // 2)  # integer division: num_filter must be an int
gmod = module.GANModule(
    sym_gen,
    sym_dec,
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.modG.init_params(mx.init.Normal(0.05))
gmod.modD.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/cifar10/'
train = mx.io.ImageRecordIter(
    path_imgrec=data_dir + "train.rec",
    data_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        # rescale pixels to roughly [-0.5, 0.5]
        batch.data[0] = batch.data[0] * (1.0 / 255.0) - 0.5
        gmod.update(batch)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 50 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy() + 0.5, 2, flip=True)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff, flip=True)
            viz.imshow("data", batch.data[0].asnumpy() + 0.5, 2, flip=True)
--------------------------------------------------------------------------------
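
Unlike the MNIST scripts, the CIFAR-10 loops rescale pixels before the update and shift them back for display; the round trip, as a quick check:

```python
import numpy as np

x = np.random.randint(0, 256, size=(100, 3, 32, 32)).astype(np.float32)
scaled = x * (1.0 / 255.0) - 0.5   # what gmod.update sees, roughly [-0.5, 0.5]
shown = scaled + 0.5               # what viz.imshow receives, back in [0, 1]
assert shown.min() >= 0.0 and shown.max() <= 1.0
```
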
/example/gan_mnist_minibatch_discrimination.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import mxnet as mx
import sys

sys.path.append("..")

from mxgan import module, generator, encoder, viz, ops

def ferr(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")
# append minibatch-discrimination features to the lenet encoder
sym_enc = encoder.lenet()
sym_enc = ops.minibatch_layer(sym_enc, batch_size, num_kernels=100)

gmod = module.GANModule(
    sym_gen,
    symbol_encoder=sym_enc,
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
--------------------------------------------------------------------------------
/example/gan_cifar10_semisupervised.py:
--------------------------------------------------------------------------------
import logging
import numpy as np
import mxnet as mx
import sys

sys.path.append("..")

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

ngf = 64
lr = 0.0003
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 3, 32, 32)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan32x32(oshape=data_shape, ngf=ngf, final_act="tanh")
sym_dec = encoder.dcgan(ngf=ngf // 2)  # integer division: num_filter must be an int
gmod = module.SemiGANModule(
    sym_gen,
    sym_dec,
    num_class=10,
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.modG.init_params(mx.init.Normal(0.05))
gmod.modD.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/cifar10/'
train = mx.io.ImageRecordIter(
    path_imgrec=data_dir + "train.rec",
    data_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        # rescale pixels to roughly [-0.5, 0.5]
        batch.data[0] = batch.data[0] * (1.0 / 255.0) - 0.5
        gmod.update(batch, is_labeled=True)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 50 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy() + 0.5, 2, flip=True)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff, flip=True)
            viz.imshow("data", batch.data[0].asnumpy() + 0.5, 2, flip=True)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MXNet GAN

[MXNet](https://github.com/dmlc/mxnet) module implementation of multi-GPU compatible generative models.

## List of Methods

- Unsupervised training
- Semisupervised training
- Minibatch discrimination

## Usage

```python
import logging
import numpy as np
import mxnet as mx

from mxgan import module, generator, encoder, viz

def ferr(label, pred):
    pred = pred.ravel()
    label = label.ravel()
    return np.abs(label - (pred > 0.5)).sum() / label.shape[0]

lr = 0.0005
beta1 = 0.5
batch_size = 100
rand_shape = (batch_size, 100)
num_epoch = 100
data_shape = (batch_size, 1, 28, 28)
context = mx.gpu()

logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')
sym_gen = generator.dcgan28x28(oshape=data_shape, ngf=32, final_act="sigmoid")

gmod = module.GANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    code_shape=rand_shape)

gmod.init_params(mx.init.Xavier(factor_type="in", magnitude=2.34))

gmod.init_optimizer(
    optimizer="adam",
    optimizer_params={
        "learning_rate": lr,
        "wd": 0.,
        "beta1": beta1,
    })

data_dir = './../../mxnet/example/image-classification/mnist/'
train = mx.io.MNISTIter(
    image=data_dir + "train-images-idx3-ubyte",
    label=data_dir + "train-labels-idx1-ubyte",
    input_shape=data_shape[1:],
    batch_size=batch_size,
    shuffle=True)

metric_acc = mx.metric.CustomMetric(ferr)

for epoch in range(num_epoch):
    train.reset()
    metric_acc.reset()
    for t, batch in enumerate(train):
        gmod.update(batch)
        gmod.temp_label[:] = 0.0
        metric_acc.update([gmod.temp_label], gmod.outputs_fake)
        gmod.temp_label[:] = 1.0
        metric_acc.update([gmod.temp_label], gmod.outputs_real)

        if t % 100 == 0:
            logging.info("epoch: %d, iter %d, metric=%s", epoch, t, metric_acc.get())
            viz.imshow("gout", gmod.temp_outG[0].asnumpy(), 2)
            diff = gmod.temp_diffD[0].asnumpy()
            diff = (diff - diff.mean()) / diff.std() + 0.5
            viz.imshow("diff", diff)
            viz.imshow("data", batch.data[0].asnumpy(), 2)
```
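
The semisupervised and minibatch-discrimination variants follow the same pattern; the deltas below are condensed from `example/gan_mnist_semisupervised.py` and `example/gan_mnist_minibatch_discrimination.py`:

```python
from mxgan import module, encoder, ops

# Semisupervised: a (num_class + 1)-way discriminator head.
gmod = module.SemiGANModule(
    sym_gen,
    symbol_encoder=encoder.lenet(),
    context=context,
    data_shape=data_shape,
    num_class=10,
    code_shape=rand_shape)
gmod.update(batch, is_labeled=True)   # pass False for unlabeled batches

# Minibatch discrimination: append minibatch features to the encoder.
sym_enc = ops.minibatch_layer(encoder.lenet(), batch_size, num_kernels=100)
gmod = module.GANModule(sym_gen, symbol_encoder=sym_enc,
                        context=context, data_shape=data_shape,
                        code_shape=rand_shape)
```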
--------------------------------------------------------------------------------
/mxgan/ops.py:
--------------------------------------------------------------------------------
"""Collection of operators."""
import numpy as np
import mxnet as mx
from .custom_ops import log_sum_exp
from .custom_ops import constant

BatchNorm = mx.sym.BatchNorm
eps = 1e-5 + 1e-12

def deconv2d(data, ishape, oshape, kshape, name, stride=(2, 2)):
    """A deconv layer that enlarges the feature map to oshape.

    ishape is accepted for symmetry with the callers; only oshape is used.
    """
    target_shape = (oshape[-2], oshape[-1])
    net = mx.sym.Deconvolution(data,
                               kernel=kshape,
                               stride=stride,
                               target_shape=target_shape,
                               num_filter=oshape[0],
                               no_bias=True,
                               name=name)
    return net


def deconv2d_bn_relu(data, prefix, **kwargs):
    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
    net = BatchNorm(net, fix_gamma=True, eps=eps, name="%s_bn" % prefix)
    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type='relu')
    return net


def deconv2d_act(data, prefix, act_type="relu", **kwargs):
    net = deconv2d(data, name="%s_deconv" % prefix, **kwargs)
    net = mx.sym.Activation(net, name="%s_act" % prefix, act_type=act_type)
    return net


def conv2d_bn_leaky(data, prefix, use_global_stats=False, **kwargs):
    net = mx.sym.Convolution(data, name="%s_conv" % prefix, **kwargs)
    net = BatchNorm(net, fix_gamma=True, eps=eps,
                    use_global_stats=use_global_stats,
                    name="%s_bn" % prefix)
    net = mx.sym.LeakyReLU(net, act_type="leaky", name="%s_leaky" % prefix)
    return net


def minibatch_layer(data, batch_size, num_kernels, num_dim=5):
    """Minibatch discrimination: concat mean pairwise L1 kernel distances."""
    net = mx.sym.FullyConnected(data,
                                num_hidden=num_kernels*num_dim,
                                no_bias=True)
    net = mx.sym.Reshape(net, shape=(-1, num_kernels, num_dim))
    # a: (batch, num_kernels, num_dim, 1), b: (1, num_kernels, num_dim, batch)
    a = mx.sym.expand_dims(net, axis=3)
    b = mx.sym.expand_dims(
        mx.sym.transpose(net, axes=(1, 2, 0)), axis=0)
    abs_dif = mx.sym.abs(mx.sym.broadcast_minus(a, b))
    # batch, num_kernels, batch
    abs_dif = mx.sym.sum(abs_dif, axis=2)
    # mask out the distance of each example to itself
    mask = np.eye(batch_size)
    mask = np.expand_dims(mask, 1)
    mask = 1.0 - mask
    rscale = 1.0 / np.sum(mask)
    # multiply by mask and rescale
    out = mx.sym.sum(mx.sym.broadcast_mul(abs_dif, constant(mask)), axis=2) * rscale
    return mx.sym.Concat(data, out)
--------------------------------------------------------------------------------
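
To make the broadcasting in `minibatch_layer` concrete, a NumPy sketch of the same computation (`B`, `K`, `D` are illustrative sizes):

```python
import numpy as np

B, K, D = 4, 3, 5                       # batch, kernels, dims per kernel
M = np.random.randn(B, K, D)            # per-example kernel features

a = M[:, :, :, None]                    # (B, K, D, 1)
b = M.transpose(1, 2, 0)[None]          # (1, K, D, B)
l1 = np.abs(a - b).sum(axis=2)          # (B, K, B): pairwise L1 distances

mask = 1.0 - np.eye(B)[:, None, :]      # zero out each example's self-distance
feat = (l1 * mask).sum(axis=2) / mask.sum()   # (B, K), concatenated onto data
```
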
/mxgan/generator.py:
--------------------------------------------------------------------------------
"""Collection of generator symbols.

A generator takes random inputs and generates data samples.
"""
import numpy as np
import mxnet as mx

from .ops import deconv2d_bn_relu, deconv2d_act


def dcgan32x32(oshape, final_act, ngf=128, code=None):
    """DCGAN generator that produces 32x32 images."""
    assert oshape[-1] == 32
    assert oshape[-2] == 32
    code = mx.sym.Variable("code") if code is None else code
    net = mx.sym.FullyConnected(code, name="g1", num_hidden=4*4*ngf*4, no_bias=True)
    net = mx.sym.Activation(net, name="gact1", act_type="relu")
    # 4 x 4
    net = mx.sym.Reshape(net, shape=(-1, ngf * 4, 4, 4))
    # 8 x 8
    net = deconv2d_bn_relu(
        net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(4, 4), prefix="g2")
    # 16 x 16
    net = deconv2d_bn_relu(
        net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 16, 16), kshape=(4, 4), prefix="g3")
    # 32 x 32
    net = deconv2d_act(
        net, ishape=(ngf, 16, 16), oshape=oshape[-3:], kshape=(4, 4), prefix="g4", act_type=final_act)
    return net


def dcgan28x28(oshape, final_act, ngf=128, code=None):
    """DCGAN generator that produces 28x28 images."""
    assert oshape[-1] == 28
    assert oshape[-2] == 28
    code = mx.sym.Variable("code") if code is None else code
    net = mx.sym.FullyConnected(code, name="g1", num_hidden=4*4*ngf*4, no_bias=True)
    net = mx.sym.Activation(net, name="gact1", act_type="relu")
    # 4 x 4
    net = mx.sym.Reshape(net, shape=(-1, ngf*4, 4, 4))
    # 8 x 8
    net = deconv2d_bn_relu(
        net, ishape=(ngf * 4, 4, 4), oshape=(ngf * 2, 8, 8), kshape=(3, 3), prefix="g2")
    # 14 x 14
    net = deconv2d_bn_relu(
        net, ishape=(ngf * 2, 8, 8), oshape=(ngf, 14, 14), kshape=(4, 4), prefix="g3")
    # 28 x 28
    net = deconv2d_act(
        net, ishape=(ngf, 14, 14), oshape=oshape[-3:], kshape=(4, 4), prefix="g4", act_type=final_act)
    return net


def fcgan(oshape, final_act, code=None):
    """Generator for 28x28 images using fully connected nets.

    Note: returns a flat (batch, C*H*W) output rather than an image tensor.
    """
    # Q: whether to add BN
    code = mx.sym.Variable("code") if code is None else code
    net = mx.sym.FullyConnected(code, name="g1", num_hidden=500, no_bias=True)
    net = mx.sym.Activation(net, name="a1", act_type="relu")  # relu assumed; act_type was unspecified
    net = mx.sym.FullyConnected(net, name="g2", num_hidden=500, no_bias=True)
    net = mx.sym.Activation(net, name="a2", act_type="relu")
    s = oshape[-3:]
    net = mx.sym.FullyConnected(net, name="g3",
                                num_hidden=(s[-3] * s[-2] * s[-1]),
                                no_bias=True)
    net = mx.sym.Activation(net, name="gout", act_type=final_act)
    return net
--------------------------------------------------------------------------------
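
As with the encoders, generator output shapes can be sanity-checked symbolically; a small sketch:

```python
import mxnet as mx
from mxgan import generator

net = generator.dcgan28x28(oshape=(100, 1, 28, 28), ngf=32, final_act="sigmoid")
_, out_shapes, _ = net.infer_shape(code=(100, 100))
print(out_shapes)  # [(100, 1, 28, 28)]
```
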
/mxgan/custom_ops.py:
--------------------------------------------------------------------------------
"""Customized operators using the NDArray GPU API."""
import numpy as np
import mxnet as mx
import pickle as pkl

class LogSumExpOp(mx.operator.CustomOp):
    """Log-sum-exp with the max subtracted for numerical stability."""
    def __init__(self, axis):
        self.axis = axis

    def forward(self, is_train, req, in_data, out_data, aux):
        x = in_data[0]
        max_x = mx.nd.max_axis(x, axis=self.axis, keepdims=True)
        sum_x = mx.nd.sum(mx.nd.exp(x - max_x), axis=self.axis, keepdims=True)
        y = mx.nd.log(sum_x) + max_x
        y = y.reshape(out_data[0].shape)
        self.assign(out_data[0], req[0], y)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        y = out_grad[0]
        x = in_data[0]
        max_x = mx.nd.max_axis(x, axis=self.axis, keepdims=True)
        y = y.reshape(max_x.shape)
        x = mx.nd.exp(x - max_x)
        prob = x / mx.nd.sum(x, axis=self.axis, keepdims=True)
        self.assign(in_grad[0], req[0], prob * y)


@mx.operator.register("log_sum_exp")
class LogSumExpProp(mx.operator.CustomOpProp):
    def __init__(self, axis, keepdims=False):
        super(LogSumExpProp, self).__init__(need_top_grad=True)
        self.axis = int(axis)
        self.keepdims = keepdims in ('True',)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        oshape = []
        for i, x in enumerate(data_shape):
            if i == self.axis:
                # reduced axis: keep as size 1 only if keepdims
                if self.keepdims:
                    oshape.append(1)
            else:
                oshape.append(x)
        return [data_shape], [tuple(oshape)], []

    def create_operator(self, ctx, shapes, dtypes):
        return LogSumExpOp(self.axis)


def log_sum_exp(in_sym, axis, keepdims=False, name="log_sum_exp"):
    return mx.symbol.Custom(in_sym, name=name,
                            op_type="log_sum_exp",
                            axis=axis, keepdims=keepdims)
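
# For reference, the identity LogSumExpOp computes and the gradient its
# backward pass implements:
#   log(sum_i exp(x_i)) = m + log(sum_i exp(x_i - m)),   m = max_i x_i
#   d/dx_j log(sum_i exp(x_i)) = exp(x_j) / sum_i exp(x_i) = softmax(x)_j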

class ConstantOp(mx.operator.CustomOp):
    """Operator that outputs a constant array (used for the minibatch-layer mask)."""
    def __init__(self, data):
        self.data = data

    def forward(self, is_train, req, in_data, out_data, aux):
        if self.data.context != out_data[0].context:
            self.data = self.data.copyto(out_data[0].context)
        self.assign(out_data[0], req[0], self.data)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        raise RuntimeError("cannot bp to constant")


@mx.operator.register("constant")
class ConstantOpProp(mx.operator.CustomOpProp):
    def __init__(self, pkl_data):
        super(ConstantOpProp, self).__init__(need_top_grad=False)
        self.data = pkl.loads(pkl_data)

    def list_arguments(self):
        return []

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        return in_shape, [self.data.shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return ConstantOp(mx.nd.array(self.data))

def constant(data, name="constant"):
    if isinstance(data, mx.nd.NDArray):
        data = data.asnumpy()
    pkl_data = pkl.dumps(data)
    return mx.symbol.Custom(name=name,
                            op_type="constant",
                            pkl_data=pkl_data)


# test cases below
def np_softmax(x, axis):
    max_x = np.max(x, axis=axis, keepdims=True)
    x = np.exp(x - max_x)
    x = x / np.sum(x, axis=axis, keepdims=True)
    return x


def np_log_sum_exp(x, axis, keepdims=False):
    max_x = np.max(x, axis=axis, keepdims=True)
    x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
    x = x + max_x
    if not keepdims:
        x = np.squeeze(x, axis=axis)
    return x


def test_log_sum_exp():
    xpu = mx.gpu()
    shape = (2, 2, 100)
    axis = 2
    keepdims = True
    X = mx.sym.Variable('X')
    Y = log_sum_exp(X, axis=axis, keepdims=keepdims)
    x = mx.nd.array(np.random.normal(size=shape))
    x[:] = 1
    xgrad = mx.nd.empty(x.shape)
    exec1 = Y.bind(xpu, args=[x], args_grad={'X': xgrad})
    exec1.forward()
    y = exec1.outputs[0]
    np.testing.assert_allclose(
        y.asnumpy(),
        np_log_sum_exp(x.asnumpy(), axis=axis, keepdims=keepdims))
    y[:] = 1
    exec1.backward([y])
    np.testing.assert_allclose(
        xgrad.asnumpy(),
        np_softmax(x.asnumpy(), axis=axis) * y.asnumpy())


def test_constant():
    xpu = mx.gpu()
    shape = (2, 2, 100)
    x = mx.nd.ones(shape, ctx=xpu)
    y = mx.nd.ones(shape, ctx=xpu)
    gy = mx.nd.zeros(shape, ctx=xpu)
    X = constant(x) + mx.sym.Variable('Y')
    xexec = X.bind(xpu,
                   {'Y': y},
                   {'Y': gy})
    xexec.forward()
    np.testing.assert_allclose(
        xexec.outputs[0].asnumpy(), (x + y).asnumpy())
    xexec.backward([y])
    np.testing.assert_allclose(
        gy.asnumpy(), y.asnumpy())


if __name__ == "__main__":
    test_constant()
    test_log_sum_exp()
--------------------------------------------------------------------------------
/mxgan/module.py:
--------------------------------------------------------------------------------
"""Modules for training GANs; works with multiple GPUs."""
import mxnet as mx
import numpy as np
from . import ops

class GANBaseModule(object):
    """Base class holding the shared GAN state.
    """
    def __init__(self,
                 symbol_generator,
                 context,
                 code_shape):
        # generator
        self.modG = mx.mod.Module(symbol=symbol_generator,
                                  data_names=("code",),
                                  label_names=None,
                                  context=context)
        self.modG.bind(data_shapes=[("code", code_shape)])
        # the discriminator is set up by subclasses
        self.temp_outG = None
        self.temp_diffD = None
        self.temp_gradD = None
        self.context = context if isinstance(context, list) else [context]
        self.outputs_fake = None
        self.outputs_real = None
        self.temp_rbatch = mx.io.DataBatch(
            [mx.nd.zeros(code_shape, ctx=self.context[-1])], None)

    def _save_temp_gradD(self):
        if self.temp_gradD is None:
            self.temp_gradD = [
                [grad.copyto(grad.context) for grad in grads]
                for grads in self.modD._exec_group.grad_arrays]
        else:
            for gradsr, gradsf in zip(self.modD._exec_group.grad_arrays, self.temp_gradD):
                for gradr, gradf in zip(gradsr, gradsf):
                    gradr.copyto(gradf)

    def _add_temp_gradD(self):
        # add back saved gradient
        for gradsr, gradsf in zip(self.modD._exec_group.grad_arrays, self.temp_gradD):
            for gradr, gradf in zip(gradsr, gradsf):
                gradr += gradf

    def init_params(self, *args, **kwargs):
        self.modG.init_params(*args, **kwargs)
        self.modD.init_params(*args, **kwargs)

    def init_optimizer(self, *args, **kwargs):
        self.modG.init_optimizer(*args, **kwargs)
        self.modD.init_optimizer(*args, **kwargs)


class GANModule(GANBaseModule):
    """A thin wrapper of module to group generator and discriminator together in GAN.

    Example
    -------
    lr = 0.0005
    mod = GANModule(generator, encoder, context=mx.gpu(),
                    data_shape=(100, 3, 32, 32), code_shape=(100, 100))
    mod.init_params(mx.init.Xavier())
    mod.init_optimizer(optimizer="adam", optimizer_params={
        "learning_rate": lr,
    })

    for t, batch in enumerate(train_data):
        mod.update(batch)
        # update metrics
        mod.temp_label[:] = 0.0
        metricG.update([mod.temp_label], mod.outputs_fake)
        mod.temp_label[:] = 1.0
        metricD.update([mod.temp_label], mod.outputs_real)
        # visualize
        if t % 100 == 0:
            gen_image = mod.temp_outG[0].asnumpy()
            gen_diff = mod.temp_diffD[0].asnumpy()
            viz.imshow("gen_image", gen_image)
            viz.imshow("gen_diff", gen_diff)
    """
    def __init__(self,
                 symbol_generator,
                 symbol_encoder,
                 context,
                 data_shape,
                 code_shape,
                 pos_label=0.9):
        super(GANModule, self).__init__(
            symbol_generator, context, code_shape)
        context = context if isinstance(context, list) else [context]
        self.batch_size = data_shape[0]
        label_shape = (self.batch_size,)
        encoder = symbol_encoder
        encoder = mx.sym.FullyConnected(encoder, num_hidden=1, name="fc_dloss")
        encoder = mx.sym.LogisticRegressionOutput(encoder, name='dloss')
        self.modD = mx.mod.Module(symbol=encoder,
                                  data_names=("data",),
                                  label_names=("dloss_label",),
                                  context=context)
        self.modD.bind(data_shapes=[("data", data_shape)],
                       label_shapes=[("dloss_label", label_shape)],
                       inputs_need_grad=True)
        self.pos_label = pos_label
        self.temp_label = mx.nd.zeros(
            label_shape, ctx=context[-1])

    def update(self, dbatch):
        """Update the model for a single batch."""
        # generate fake image
        mx.random.normal(0, 1.0, out=self.temp_rbatch.data[0])
        self.modG.forward(self.temp_rbatch)
        outG = self.modG.get_outputs()
        # discriminator gradient on the fake batch (label 0); saved for later
        self.temp_label[:] = 0
        self.modD.forward(mx.io.DataBatch(outG, [self.temp_label]), is_train=True)
        self.modD.backward()
        self._save_temp_gradD()
        # update generator: label the fakes as real to get the generator gradient
        self.temp_label[:] = 1
        self.modD.forward(mx.io.DataBatch(outG, [self.temp_label]), is_train=True)
        self.modD.backward()
        diffD = self.modD.get_input_grads()
        self.modG.backward(diffD)
        self.modG.update()
        self.outputs_fake = [x.copyto(x.context) for x in self.modD.get_outputs()]
        # update discriminator: real batch with smoothed positive label,
        # plus the saved gradient from the fake batch
        self.temp_label[:] = self.pos_label
        dbatch.label = [self.temp_label]
        self.modD.forward(dbatch, is_train=True)
        self.modD.backward()
        self._add_temp_gradD()
        self.modD.update()
        self.outputs_real = self.modD.get_outputs()
        self.temp_outG = outG
        self.temp_diffD = diffD
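
# Note on the objective implemented by GANModule.update: it is the usual
# non-saturating GAN loss with one-sided label smoothing (real label
# pos_label = 0.9 instead of 1):
#   L_D = -[t*log D(x) + (1-t)*log(1 - D(x))] - log(1 - D(G(z))),  t = pos_label
#   L_G = -log D(G(z))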
139 | """ 140 | def __init__(self, 141 | symbol_generator, 142 | symbol_encoder, 143 | context, 144 | data_shape, 145 | code_shape, 146 | num_class, 147 | pos_label=0.9): 148 | super(SemiGANModule, self).__init__( 149 | symbol_generator, context, code_shape) 150 | # the discriminator encoder 151 | context = context if isinstance(context, list) else [context] 152 | batch_size = data_shape[0] 153 | self.num_class = num_class 154 | encoder = symbol_encoder 155 | encoder = mx.sym.FullyConnected( 156 | encoder, num_hidden=num_class + 1, name="energy") 157 | self.modD = mx.mod.Module(symbol=encoder, 158 | data_names=("data",), 159 | label_names=None, 160 | context=context) 161 | self.modD.bind(data_shapes=[("data", data_shape)], 162 | inputs_need_grad=True) 163 | self.pos_label = pos_label 164 | # discriminator loss 165 | energy = mx.sym.Variable("energy") 166 | label_out = mx.sym.SoftmaxOutput(energy, name="softmax") 167 | ul_pos_energy = mx.sym.slice_axis( 168 | energy, axis=1, begin=0, end=num_class) 169 | ul_pos_energy = ops.log_sum_exp( 170 | ul_pos_energy, axis=1, keepdims=True, name="ul_pos") 171 | ul_neg_energy = mx.sym.slice_axis( 172 | energy, axis=1, begin=num_class, end=num_class + 1) 173 | ul_pos_prob = mx.sym.LogisticRegressionOutput( 174 | ul_pos_energy - ul_neg_energy, name="dloss") 175 | # use module to bind the 176 | self.mod_label_out = mx.mod.Module( 177 | symbol=label_out, 178 | data_names=("energy",), 179 | label_names=("softmax_label",), 180 | context=context) 181 | self.mod_label_out.bind( 182 | data_shapes=[("energy", (batch_size, num_class + 1))], 183 | label_shapes=[("softmax_label", (batch_size,))], 184 | inputs_need_grad=True) 185 | self.mod_ul_out = mx.mod.Module( 186 | symbol=ul_pos_prob, 187 | data_names=("energy",), 188 | label_names=("dloss_label",), 189 | context=context) 190 | self.mod_ul_out.bind( 191 | data_shapes=[("energy", (batch_size, num_class + 1))], 192 | label_shapes=[("dloss_label", (batch_size,))], 193 | inputs_need_grad=True) 194 | self.mod_ul_out.init_params() 195 | self.mod_label_out.init_params() 196 | self.temp_label = mx.nd.zeros( 197 | (batch_size,), ctx=context[0]) 198 | 199 | def update(self, dbatch, is_labeled): 200 | """Update the model for a single batch.""" 201 | # generate fake image 202 | mx.random.normal(0, 1.0, out=self.temp_rbatch.data[0]) 203 | self.modG.forward(self.temp_rbatch) 204 | outG = self.modG.get_outputs() 205 | self.temp_label[:] = self.num_class 206 | self.modD.forward(mx.io.DataBatch(outG, []), is_train=True) 207 | self.mod_label_out.forward( 208 | mx.io.DataBatch(self.modD.get_outputs(), [self.temp_label]), is_train=True) 209 | self.mod_label_out.backward() 210 | self.modD.backward(self.mod_label_out.get_input_grads()) 211 | self._save_temp_gradD() 212 | # update generator 213 | self.temp_label[:] = 1 214 | self.modD.forward(mx.io.DataBatch(outG, []), is_train=True) 215 | self.mod_ul_out.forward( 216 | mx.io.DataBatch(self.modD.get_outputs(), [self.temp_label]), is_train=True) 217 | self.mod_ul_out.backward() 218 | self.modD.backward(self.mod_ul_out.get_input_grads()) 219 | diffD = self.modD.get_input_grads() 220 | self.modG.backward(diffD) 221 | self.modG.update() 222 | self.outputs_fake = [x.copyto(x.context) for x in self.mod_ul_out.get_outputs()] 223 | # update discriminator 224 | self.modD.forward(mx.io.DataBatch(dbatch.data, []), is_train=True) 225 | outD = self.modD.get_outputs() 226 | self.temp_label[:] = self.pos_label 227 | self.mod_ul_out.forward( 228 | mx.io.DataBatch(outD, [self.temp_label]), 
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------