├── .gitignore
├── KalmanNorm
│   ├── batch_norm_split_gpu.py
│   ├── cifar-bn-largebatch.py
│   ├── cifar-bn-microbatch.py
│   ├── cifar-gn.py
│   ├── cifar-kn-largebatch.py
│   ├── cifar-kn-microbatch.py
│   ├── group_norm.py
│   └── kalman_norm.py
├── README.md
├── results
│   ├── bkn_bn_large_batch.png
│   └── bn_gn_bkn_micro_batch.png
└── tensorpack-installed
    ├── .travis.yml
    ├── CHANGES.md
    ├── LICENSE
    ├── MANIFEST.in
    ├── README.md
    ├── docs
    │   ├── Makefile
    │   ├── README.md
    │   ├── _static
    │   │   ├── build_toc.js
    │   │   ├── build_toc_group.js
    │   │   ├── jquery-3.2.1.min.js
    │   │   └── sanitize_desc_name.js
    │   ├── _templates
    │   │   └── layout.html
    │   ├── casestudies
    │   │   ├── colorize.md
    │   │   └── index.rst
    │   ├── conf.py
    │   ├── index.rst
    │   ├── modules
    │   │   ├── callbacks.rst
    │   │   ├── dataflow.dataset.rst
    │   │   ├── dataflow.imgaug.rst
    │   │   ├── dataflow.rst
    │   │   ├── graph_builder.rst
    │   │   ├── index.rst
    │   │   ├── input_source.rst
    │   │   ├── models.rst
    │   │   ├── predict.rst
    │   │   ├── tfutils.rst
    │   │   ├── train.rst
    │   │   └── utils.rst
    │   └── tutorial
    │       ├── callback.md
    │       ├── dataflow.md
    │       ├── efficient-dataflow.md
    │       ├── extend
    │       │   ├── augmentor.md
    │       │   ├── callback.md
    │       │   ├── dataflow.md
    │       │   ├── model.md
    │       │   └── trainer.md
    │       ├── faq.md
    │       ├── index.rst
    │       ├── inference.md
    │       ├── input-source.md
    │       ├── intro.rst
    │       ├── performance-tuning.md
    │       ├── save-load.md
    │       ├── summary.md
    │       ├── symbolic.md
    │       ├── trainer.md
    │       └── training-interface.md
    ├── readthedocs.yml
    ├── requirements.txt
    ├── scripts
    │   ├── checkpoint-manipulate.py
    │   ├── checkpoint-prof.py
    │   ├── dump-model-params.py
    │   └── ls-checkpoint.py
    ├── setup.cfg
    ├── setup.py
    ├── tensorpack
    │   ├── __init__.py
    │   ├── callbacks
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── concurrency.py
    │   │   ├── graph.py
    │   │   ├── group.py
    │   │   ├── hooks.py
    │   │   ├── inference.py
    │   │   ├── inference_runner.py
    │   │   ├── misc.py
    │   │   ├── monitor.py
    │   │   ├── param.py
    │   │   ├── prof.py
    │   │   ├── saver.py
    │   │   ├── stats.py
    │   │   ├── steps.py
    │   │   ├── summary.py
    │   │   └── trigger.py
    │   ├── contrib
    │   │   ├── __init__.py
    │   │   └── keras.py
    │   ├── dataflow
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── common.py
    │   │   ├── dataset
    │   │   │   ├── __init__.py
    │   │   │   ├── bsds500.py
    │   │   │   ├── cifar.py
    │   │   │   ├── ilsvrc.py
    │   │   │   ├── mnist.py
    │   │   │   ├── smallnorb.py
    │   │   │   └── svhn.py
    │   │   ├── dftools.py
    │   │   ├── format.py
    │   │   ├── image.py
    │   │   ├── imgaug
    │   │   │   ├── __init__.py
    │   │   │   ├── _test.py
    │   │   │   ├── base.py
    │   │   │   ├── convert.py
    │   │   │   ├── crop.py
    │   │   │   ├── deform.py
    │   │   │   ├── geometry.py
    │   │   │   ├── gr_rotate.py
    │   │   │   ├── imgproc.py
    │   │   │   ├── meta.py
    │   │   │   ├── misc.py
    │   │   │   ├── noise.py
    │   │   │   ├── paste.py
    │   │   │   └── transform.py
    │   │   ├── parallel.py
    │   │   ├── parallel_map.py
    │   │   ├── raw.py
    │   │   └── remote.py
    │   ├── graph_builder
    │   │   ├── __init__.py
    │   │   ├── distributed.py
    │   │   ├── model_desc.py
    │   │   ├── predict.py
    │   │   ├── training.py
    │   │   └── utils.py
    │   ├── input_source
    │   │   ├── __init__.py
    │   │   ├── input_source.py
    │   │   └── input_source_base.py
    │   ├── libinfo.py
    │   ├── models
    │   │   ├── __init__.py
    │   │   ├── _test.py
    │   │   ├── batch_norm.py
    │   │   ├── common.py
    │   │   ├── conv2d.py
    │   │   ├── fc.py
    │   │   ├── image_sample.py
    │   │   ├── layer_norm.py
    │   │   ├── linearwrap.py
    │   │   ├── nonlin.py
    │   │   ├── pool.py
    │   │   ├── registry.py
    │   │   ├── regularize.py
    │   │   ├── shape_utils.py
    │   │   ├── shapes.py
    │   │   ├── softmax.py
    │   │   ├── tflayer.py
    │   │   └── utils.py
    │   ├── predict
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── concurrency.py
    │   │   ├── config.py
    │   │   ├── dataset.py
    │   │   └── multigpu.py
    │   ├── tfutils
    │   │   ├── __init__.py
    │   │   ├── argscope.py
    │   │   ├── collection.py
    │   │   ├── common.py
    │   │   ├── distributed.py
    │   │   ├── export.py
    │   │   ├── gradproc.py
    │   │   ├── model_utils.py
    │   │   ├── optimizer.py
    │   │   ├── scope_utils.py
    │   │   ├── sesscreate.py
    │   │   ├── sessinit.py
    │   │   ├── summary.py
    │   │   ├── symbolic_functions.py
    │   │   ├── tower.py
    │   │   ├── varmanip.py
    │   │   └── varreplace.py
    │   ├── train
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── config.py
    │   │   ├── interface.py
    │   │   ├── tower.py
    │   │   ├── trainers.py
    │   │   └── utility.py
    │   ├── trainv1
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── config.py
    │   │   ├── distributed.py
    │   │   ├── interface.py
    │   │   ├── multigpu.py
    │   │   ├── simple.py
    │   │   └── utility.py
    │   └── utils
    │       ├── __init__.py
    │       ├── argtools.py
    │       ├── concurrency.py
    │       ├── debug.py
    │       ├── develop.py
    │       ├── fs.py
    │       ├── globvars.py
    │       ├── gpu.py
    │       ├── loadcaffe.py
    │       ├── logger.py
    │       ├── naming.py
    │       ├── nvml.py
    │       ├── palette.py
    │       ├── rect.py
    │       ├── serialize.py
    │       ├── stats.py
    │       ├── timer.py
    │       ├── utils.py
    │       └── viz.py
    ├── tests
    │   ├── case_script.py
    │   ├── dev
    │   │   └── git-hooks
    │   │       └── pre-commit
    │   ├── install-tensorflow.sh
    │   ├── run-tests.sh
    │   ├── test_char_rnn.py
    │   ├── test_infogan.py
    │   ├── test_mnist.py
    │   ├── test_mnist_similarity.py
    │   └── test_resnet.py
    └── tox.ini

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# tensorpack-specific stuff
train_log
tensorpack/user_ops/obj
*.npy
*.npz
*.caffemodel
*.tfmodel
*.meta
*.log*
*.bin
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
*.tgz
*.gz

# my personal stuff
snippet
examples/private
examples-old
TODO.md
.gitignore
.vimrc.local


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/
*.dat

.idea/
*.diff

--------------------------------------------------------------------------------
/KalmanNorm/group_norm.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: group_norm.py
# Author: Guangrun Wang, Jiefeng Peng


import tensorflow as tf

from tensorpack.models.common import layer_register

__all__ = ['GroupNorm']

# decay: being too close to 1 leads to slow start-up. torch uses 0.9.
# eps: torch: 1e-5. Lasagne: 1e-4


def get_bn_variables(n_out, use_scale, use_bias, gamma_init):
    if use_bias:
        beta = tf.get_variable('beta', [1, n_out, 1, 1], initializer=tf.constant_initializer())
    else:
        beta = tf.zeros([1, n_out, 1, 1], name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [1, n_out, 1, 1], initializer=gamma_init)
    else:
        gamma = tf.ones([1, n_out, 1, 1], name='gamma')

    return beta, gamma


@layer_register()
def GroupNorm(inputs, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              gamma_initializer=tf.ones_initializer(),
              data_format='channels_first',
              internal_update=False):
    """
    Group Normalization over a 4D NCHW tensor or a 2D [N, C] tensor.
    `training`, `momentum` and `internal_update` are unused; they are kept
    only so the signature matches tensorpack's BatchNorm.
    """
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    if ndims == 2:
        data_format = 'channels_first'
    if data_format == 'channels_first':
        n_out = shape[1]
    else:
        n_out = shape[-1]  # channel
    assert n_out is not None, "Input to GroupNorm cannot have unknown channels!"
    beta, gamma = get_bn_variables(n_out, scale, center, gamma_initializer)

    if ndims == 2:
        inputs = tf.reshape(inputs, [-1, n_out, 1, 1])  # treat 2D input as [N, C, 1, 1]

    # The grouping below assumes channels_first (NCHW) layout.
    input_shape = inputs.get_shape().as_list()
    N = tf.shape(inputs)[0]  # the batch size may be dynamic
    C = input_shape[1]
    H = input_shape[2]
    W = input_shape[3]

    G = 4  # number of groups; C must be divisible by G
    # Normalize with per-group statistics computed over the (C // G, H, W) axes.
    inputs = tf.reshape(inputs, [N, G, C // G, H, W])
    batch_mean, batch_var = tf.nn.moments(inputs, [2, 3, 4], keep_dims=True)
    xn = (inputs - batch_mean) / tf.sqrt(batch_var + epsilon)
    xn = tf.reshape(xn, [N, C, H, W])
    xn = xn * gamma + beta

    if ndims == 2:
        xn = tf.squeeze(xn, [2, 3])  # back to [N, C]; the singleton axes are 2 and 3 in NCHW

    ret = tf.identity(xn, name='output')
    return ret
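
# Usage sketch (illustrative; not part of the original file). Because GroupNorm
# is registered through tensorpack's @layer_register(), the first positional
# argument at the call site is the layer/variable-scope name. The input shape
# and the name 'gn1' below are assumptions for illustration only.
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # NCHW; C=64 is divisible by G=4
    y = GroupNorm('gn1', x)   # beta/gamma are created under variable scope 'gn1'
    print(y)                  # a Tensor named 'gn1/output:0' with shape (?, 64, 32, 32)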
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Kalman Normalization

By [Guangrun Wang](https://wanggrun.github.io/), [Jiefeng Peng](http://www.sysu-hcp.net/people/), [Ping Luo](http://personal.ie.cuhk.edu.hk/~pluo/), Xinjiang Wang and [Liang Lin](http://www.linliang.net/).

Sun Yat-sen University (SYSU), the Chinese University of Hong Kong (CUHK), SenseTime Group Ltd.

### Table of Contents
0. [Results](#results)
0. [Introduction](#introduction)
0. [Citation](#citation)
0. [Dependencies](#dependencies)
0. [Usage](#usage)


### Results
+ Under the context of micro-batches (batch size = 2), the validation curves on CIFAR10:

  top line: [Batch Normalization (BN)](https://arxiv.org/abs/1502.03167); mid line: [Group Normalization](https://arxiv.org/abs/1803.08494); bottom line: [Kalman Normalization](https://arxiv.org/abs/1802.03133)

  ![Training curves](https://github.com/wanggrun/Kalman-Normalization/blob/master/results/bn_gn_bkn_micro_batch.png)

+ Under the context of large batches (batch size = 128), the validation curves on CIFAR10:

  top line: [Batch Normalization (BN)](https://arxiv.org/abs/1502.03167); bottom line: [Kalman Normalization](https://arxiv.org/abs/1802.03133)

  ![Training curves](https://github.com/wanggrun/Kalman-Normalization/blob/master/results/bkn_bn_large_batch.png)


### Introduction

This repository contains the original models described in the paper "Batch Kalman Normalization: Towards Training Deep Neural Networks with Micro-Batches" (https://arxiv.org/abs/1802.03133). These are the models used in the [ILSVRC](http://image-net.org/challenges/LSVRC/2015/) and [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) experiments.
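
At a high level, Batch Kalman Normalization (BKN) estimates a layer's statistics not only from the current micro-batch but also from the statistics already estimated at the preceding layer, blended through a Kalman-filter-style gain. The snippet below is only a schematic paraphrase of that idea under simplified assumptions (a fixed scalar gain, no state-transition model); the repository's actual implementation is `KalmanNorm/kalman_norm.py`, and all names here are illustrative:

```
def bkn_estimate(batch_mean, batch_var, prev_mean, prev_var, gain):
    # Blend the noisy micro-batch statistics of this layer with the
    # estimate propagated from the preceding layer. gain = 1.0 reduces
    # to plain BN statistics; smaller gains trust the previous layer more.
    est_mean = gain * batch_mean + (1.0 - gain) * prev_mean
    est_var = gain * batch_var + (1.0 - gain) * prev_var
    return est_mean, est_var
```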

### Citation

If you use these models in your research, please cite:

    @article{wang2018batch,
      title={Batch Kalman Normalization: Towards Training Deep Neural Networks with Micro-Batches},
      author={Wang, Guangrun and Peng, Jiefeng and Luo, Ping and Wang, Xinjiang and Lin, Liang},
      journal={arXiv preprint arXiv:1802.03133},
      year={2018}
    }


### Dependencies
+ Python 2.7 or 3
+ TensorFlow >= 1.3.0
+ [Tensorpack](https://github.com/ppwwyyxx/tensorpack)

The code depends on Yuxin Wu's Tensorpack. For convenience, we provide a stable version 'tensorpack-installed' in this repository.
```
# install tensorpack locally:
cd tensorpack-installed
python setup.py install --user
```

### Usage
+ To run Group Normalization, use:
```
cd KalmanNorm
python cifar-gn.py --gpu 0 -n 5 --log_dir gn
```
+ To run Batch Normalization under the context of micro-batches, use:
```
cd KalmanNorm
python cifar-bn-microbatch.py --gpu 0 -n 5 --log_dir bn-microbatch
```
+ To run Kalman Normalization under the context of micro-batches, use:
```
cd KalmanNorm
python cifar-kn-microbatch.py --gpu 0 -n 5 --log_dir kn-microbatch
```
+ To run Batch Normalization under the context of large batches, use:
```
cd KalmanNorm
python cifar-bn-largebatch.py --gpu 0 -n 18 --log_dir bn-largebatch
```
+ To run Kalman Normalization under the context of large batches, use:
```
cd KalmanNorm
python cifar-kn-largebatch.py --gpu 0 -n 18 --log_dir kn-largebatch
```

--------------------------------------------------------------------------------
/results/bkn_bn_large_batch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wanggrun/Kalman-Normalization/77fb6503beca34b25f1b4798c4a103b7fdcf9c47/results/bkn_bn_large_batch.png

--------------------------------------------------------------------------------
/results/bn_gn_bkn_micro_batch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wanggrun/Kalman-Normalization/77fb6503beca34b25f1b4798c4a103b7fdcf9c47/results/bn_gn_bkn_micro_batch.png

--------------------------------------------------------------------------------
/tensorpack-installed/MANIFEST.in:
--------------------------------------------------------------------------------
include requirements.txt

--------------------------------------------------------------------------------
/tensorpack-installed/README.md:
--------------------------------------------------------------------------------
![Tensorpack](.github/tensorpack.png)

Tensorpack is a training interface based on TensorFlow.

[![Build Status](https://travis-ci.org/ppwwyyxx/tensorpack.svg?branch=master)](https://travis-ci.org/ppwwyyxx/tensorpack)
[![ReadTheDoc](https://readthedocs.org/projects/tensorpack/badge/?version=latest)](http://tensorpack.readthedocs.io/en/latest/index.html)
[![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/tensorpack/users)
[![model-zoo](https://img.shields.io/badge/model-zoo-brightgreen.svg)](http://models.tensorpack.com)

## Features:

It's Yet Another TF high-level API, with __speed__, __readability__ and __flexibility__ built together.

1. Focus on __training speed__.
   + Speed comes for free with tensorpack -- it uses TensorFlow in the __efficient way__ with no extra overhead.
     On different CNNs, it runs training [1.2~5x faster](https://github.com/tensorpack/benchmarks/tree/master/other-wrappers) than the equivalent Keras code.

   + Data-parallel multi-GPU training is off-the-shelf to use. It scales as well as Google's [official benchmark](https://www.tensorflow.org/performance/benchmarks).

   + Distributed data-parallel training is also supported and scales well. See [tensorpack/benchmarks](https://github.com/tensorpack/benchmarks) for more benchmark scripts.

2. Focus on __large datasets__.
   + It's unnecessary to read/preprocess data with a new language called TF.
     Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization (see the sketch below).

3. It's not a model wrapper.
   + There are too many symbolic function wrappers in the world.
     Tensorpack includes only a few common models.
     But you can use any symbolic function library inside tensorpack, including tf.layers/Keras/slim/tflearn/tensorlayer/....

See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html) to know more about these features.
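
As an illustration of point 2, a minimal DataFlow sketch against this snapshot's API (the batch size and worker count are arbitrary choices):

```
from tensorpack.dataflow import dataset, BatchData, PrefetchDataZMQ

df = dataset.Cifar10('train')         # yields [image, label] datapoints
df = BatchData(df, 128)               # group datapoints into batches of 128
df = PrefetchDataZMQ(df, nr_proc=4)   # run the pipeline in 4 worker processes
df.reset_state()                      # required once before reading
for images, labels in df.get_data():  # iterate in pure Python
    break
```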

## [Examples](examples):

We refuse toy examples.
Instead of showing you 10 arbitrary networks trained on toy datasets,
[tensorpack examples](examples) faithfully replicate papers and care about reproducing numbers,
demonstrating its flexibility for actual research.

### Vision:
+ [Train ResNet](examples/ResNet) and [other models](examples/ImageNetModels) on ImageNet.
+ [Train Faster-RCNN / Mask-RCNN on COCO object detection](examples/FasterRCNN)
+ [Generative Adversarial Network (GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
+ [DoReFa-Net: train binary / low-bitwidth CNN on ImageNet](examples/DoReFa-Net)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection (HED)](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Visualize CNN saliency maps](examples/Saliency)
+ [Similarity learning on MNIST](examples/SimilarityLearning)

### Reinforcement Learning:
+ [Deep Q-Network (DQN) variants on Atari games](examples/DeepQNetwork), including DQN, DoubleDQN, DuelingDQN.
+ [Asynchronous Advantage Actor-Critic (A3C) with demos on OpenAI Gym](examples/A3C-Gym)

### Speech / NLP:
+ [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
+ [char-rnn for fun](examples/Char-RNN)
+ [LSTM language model on PennTreebank](examples/PennTreebank)

## Install:

Dependencies:

+ Python 2.7 or 3
+ Python bindings for OpenCV (Optional, but required by a lot of features)
+ TensorFlow >= 1.3.0 (Optional if you only want to use `tensorpack.dataflow` alone as a data processing library)
```
# install git, then:
pip install -U git+https://github.com/ppwwyyxx/tensorpack.git
# or add `--user` to avoid system-wide installation.
```

## Citing Tensorpack:

If you use Tensorpack in your research or wish to refer to the examples, please cite with:
```
@misc{wu2016tensorpack,
  title={Tensorpack},
  author={Wu, Yuxin and others},
  howpublished={\url{https://github.com/tensorpack/}},
  year={2016}
}
```

--------------------------------------------------------------------------------
/tensorpack-installed/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SPHINXPROJ    = tensorpack
SOURCEDIR     = .
BUILDDIR      = build

.PHONY: help Makefile docset clean

all: html

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docset: html
	doc2dash -d ./ -n $(SPHINXPROJ) --enable-js --force $(BUILDDIR)/html/ -I tutorial/index.html
	tar czvf tensorpack.docset.tgz tensorpack.docset

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
html: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

clean:
	rm -rf build

--------------------------------------------------------------------------------
/tensorpack-installed/docs/README.md:
--------------------------------------------------------------------------------

## Build the docs:

### Dependencies:
1. python3
2. `pip install -r requirements.txt`

### Build HTML docs:
`make html` will build the docs in `build/html`.

### Build Dash/Zeal docset

`make docset` produces `tensorpack.docset`.

--------------------------------------------------------------------------------
/tensorpack-installed/docs/_static/build_toc.js:
--------------------------------------------------------------------------------
// modified from
// https://stackoverflow.com/questions/12150491/toc-list-with-all-classes-generated-by-automodule-in-sphinx-docs

$(function (){
  var createList = function(selected) {
    var ul = $('