├── .gitignore ├── LICENSE ├── README.md ├── assets ├── adaptaion.png ├── cls.png ├── cls2.png ├── distribution.png ├── transfer.png └── transfer_image.png └── batch_instance_norm.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Batch_Instance_Normalization-Tensorflow 2 | Simple TensorFlow implementation of [Batch-Instance Normalization for Adaptively Style-Invariant Neural Networks (NIPS 2018)](https://arxiv.org/abs/1805.07925) 3 | 4 | ## Code 5 | ```python 6 | 7 | import tensorflow as tf 8 | 9 | def batch_instance_norm(x, scope='batch_instance_norm'): 10 | with tf.variable_scope(scope): 11 | ch = x.shape[-1] 12 | eps = 1e-5 13 | 14 | batch_mean, batch_sigma = tf.nn.moments(x, axes=[0, 1, 2], keep_dims=True) 15 | x_batch = (x - batch_mean) / (tf.sqrt(batch_sigma + eps)) 16 | 17 | ins_mean, ins_sigma = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 18 | x_ins = (x - ins_mean) / (tf.sqrt(ins_sigma + eps)) 19 | 20 | rho = tf.get_variable("rho", [ch], initializer=tf.constant_initializer(1.0), constraint=lambda x: tf.clip_by_value(x, clip_value_min=0.0, clip_value_max=1.0)) 21 | gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0)) 22 | beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0)) 23 | 24 | x_hat = rho * x_batch + (1 - rho) * x_ins 25 | x_hat = x_hat * gamma + beta 26 | 27 | return x_hat 28 | 29 | ``` 30 | 31 | ## Usage 32 | 33 | ```python 34 | with tf.variable_scope('network') : 35 | x = conv(x, scope='conv_0') 36 | x = batch_instance_norm(x, scope='bin_norm_0') 37 | x = relu(x) 38 | ``` 39 | 40 | ## Distribution of ρ 41 | ![distribution](./assets/distribution.png) 42 | 43 | ## Results 44 | ### Classification 45 | ![cls](./assets/cls.png) 46 | ![cls2](./assets/cls2.png) 47 | 48 | ### Domain Adaptation 49 | ![adaptation](./assets/adaptaion.png) 50 | 51 | ### Style Transfer 52 | ![transfer](./assets/transfer.png) 53 | ![transfer_image](./assets/transfer_image.png) 54 | 55 | ## Related works 56 | *
[Switchable_Normalization](https://github.com/taki0112/Switchable_Normalization-Tensorflow) 57 | 58 | ## Author 59 | Junho Kim 60 | -------------------------------------------------------------------------------- /assets/adaptaion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/adaptaion.png -------------------------------------------------------------------------------- /assets/cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/cls.png -------------------------------------------------------------------------------- /assets/cls2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/cls2.png -------------------------------------------------------------------------------- /assets/distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/distribution.png -------------------------------------------------------------------------------- /assets/transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/transfer.png -------------------------------------------------------------------------------- /assets/transfer_image.png: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/taki0112/Batch_Instance_Normalization-Tensorflow/c26afbdd23e5ac2917770c6a0657037aa3d629f5/assets/transfer_image.png
# --------------------------------------------------------------------------------
# /batch_instance_norm.py:
# --------------------------------------------------------------------------------
import tensorflow as tf


def batch_instance_norm(x, scope='batch_instance_norm', eps=1e-5):
    """Apply Batch-Instance Normalization (arXiv:1805.07925) to ``x``.

    ``x`` is normalized twice: once with batch statistics (mean/variance over
    every axis except channels) and once with instance statistics (over the
    spatial axes only). The two results are blended per channel by a learnable
    gate ``rho`` (clipped to [0, 1]). A per-channel affine transform
    (``gamma``, ``beta``) is then applied.

    Args:
        x: Channels-last tensor of static rank >= 3, e.g. (N, H, W, C). Rank-3
            (N, L, C) and rank-5 (N, D, H, W, C) inputs are also supported.
            The channel dimension must be statically known because it sizes
            the per-channel variables.
        scope: Name of the ``tf.variable_scope`` that holds the variables
            ``rho``, ``gamma`` and ``beta``.
        eps: Small constant added to both variances before ``tf.sqrt`` to
            avoid division by zero.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has a static rank below 3. Instance statistics
            need at least one axis between the batch and channel axes.

    NOTE(review): batch statistics are always computed from the current
    mini-batch (no moving averages are tracked). Outputs at inference time
    therefore depend on the composition of the evaluation batch.
    """
    with tf.variable_scope(scope):
        ch = x.shape[-1]

        rank = x.shape.ndims
        if rank is None:
            # Static rank unknown: fall back to the original NHWC assumption.
            rank = 4
        if rank < 3:
            raise ValueError(
                'batch_instance_norm expects a channels-last input of rank >= 3, '
                'got rank %d' % rank)

        # Batch statistics reduce over every axis except the last (channel) one.
        batch_axes = list(range(rank - 1))
        # Instance statistics reduce over the axes between batch and channel.
        instance_axes = list(range(1, rank - 1))

        batch_mean, batch_sigma = tf.nn.moments(x, axes=batch_axes, keep_dims=True)
        x_batch = (x - batch_mean) / tf.sqrt(batch_sigma + eps)

        ins_mean, ins_sigma = tf.nn.moments(x, axes=instance_axes, keep_dims=True)
        x_ins = (x - ins_mean) / tf.sqrt(ins_sigma + eps)

        # rho = 1 selects the batch-normalized features, rho = 0 the
        # instance-normalized ones. The constraint clips rho to [0, 1] when
        # the optimizer applies variable constraints.
        rho = tf.get_variable(
            "rho", [ch], initializer=tf.constant_initializer(1.0),
            constraint=lambda v: tf.clip_by_value(v, clip_value_min=0.0, clip_value_max=1.0))
        gamma = tf.get_variable("gamma", [ch], initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable("beta", [ch], initializer=tf.constant_initializer(0.0))

        x_hat = rho * x_batch + (1 - rho) * x_ins
        x_hat = x_hat * gamma + beta

        return x_hat