├── .gitignore
├── instance_normalization
│   ├── __init__.py
│   ├── link.py
│   └── function.py
├── comparison.png
├── README.md
├── sample.py
├── not_layer_instance_norm_sample.py
└── v3
    └── on_BN.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc

/__pycache__/
--------------------------------------------------------------------------------
/instance_normalization/__init__.py:
--------------------------------------------------------------------------------
from .link import InstanceNormalization  # NOQA
--------------------------------------------------------------------------------
/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crcrpar/instance_normalization_chainer/HEAD/comparison.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Instance Normalization
2017/12/18: Added a new version that implements Instance Normalization on top of Batch Normalization. It may be faster when CUDA and cuDNN are available.

This is a [Chainer v2](https://chainer.org) implementation of Instance Normalization.
Note that this implementation does not work with Chainer versions below 2.0.
Instance normalization is regarded as more suitable than batch normalization for `style transfer` tasks.
In instance normalization, each sample in a mini-batch is normalized with the mean and variance computed over its own spatial dimensions, separately for each channel.
So the shapes of the mean and the variance are `(batch_size, n_channel)`.

The original paper can be found [here](http://arxiv.org/abs/1607.08022).

I'm looking forward to your review and/or correction.
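For illustration, the statistics described above can be computed as follows (this NumPy snippet is not part of the repository; the array shape is arbitrary):

```python
import numpy as np

x = np.random.randn(8, 3, 32, 32).astype(np.float32)  # (batch_size, n_channel, H, W)

# Per-sample, per-channel statistics over the spatial axes.
mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
x_hat = (x - mean) / np.sqrt(var + 2e-5)

print(mean.squeeze().shape)  # (8, 3), i.e. (batch_size, n_channel)
```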
# Comparison
This is a comparison of the `InstanceNormalization` layer and a combination of `chainer.functions` and `chainer.variable.Parameter`.
![Comparison](https://raw.githubusercontent.com/crcrpar/instance_normalization_chainer/master/comparison.png)

As shown in `not_layer_instance_norm_sample.py`, the latter implementation is the more naive of the two.
```python
def prepare_beta(size, init=0, dtype=np.float32):
    initial_beta = chainer.initializers._get_initializer(init)
    initial_beta.dtype = dtype
    beta = chainer.variable.Parameter(initial_beta, size)
    return beta


def prepare_gamma(size, init=1, dtype=np.float32):
    initial_gamma = chainer.initializers._get_initializer(init)
    initial_gamma.dtype = dtype
    gamma = chainer.variable.Parameter(initial_gamma, size)
    return gamma


def instance_norm(self, x, gamma=None, beta=None):
    # Mean and variance are computed per sample and per channel over the
    # spatial axes, then broadcast back to the shape of x.
    mean = F.mean(x, axis=-1)
    mean = F.mean(mean, axis=-1)
    mean = F.broadcast_to(mean[Ellipsis, None, None], x.shape)
    var = F.mean(F.squared_difference(x, mean), axis=-1)
    var = F.mean(var, axis=-1)
    std = F.broadcast_to(F.sqrt(var + 1e-5)[Ellipsis, None, None], x.shape)
    x_hat = (x - mean) / std
    if gamma is not None:
        gamma = F.broadcast_to(gamma[None, Ellipsis, None, None], x.shape)
        beta = F.broadcast_to(beta[None, Ellipsis, None, None], x.shape)
        return gamma * x_hat + beta
    else:
        return x_hat
```
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

from instance_normalization import InstanceNormalization


class ShallowConv(chainer.Chain):

    """Shallow Conv

    This is a shallow convolutional network to check whether
    InstanceNormalization works or not.
    """

    def __init__(self):
        super(ShallowConv, self).__init__()
        with self.init_scope():
            self.c_1 = L.Convolution2D(1, 3, 7, 2, 3)
            self.i_1 = InstanceNormalization(3)
            self.c_2 = L.Convolution2D(3, 6, 7, 4, 4)
            self.i_2 = InstanceNormalization(6)
            self.l_1 = L.Linear(None, 10)

    def __call__(self, x):
        h = F.relu(self.i_1(self.c_1(x)))
        h = F.relu(self.i_2(self.c_2(h)))
        bs = len(h)
        h = F.reshape(h, (bs, -1))
        return self.l_1(h)


def main(gpu_id=-1, bs=32, epoch=20, out='./result', resume=''):
    net = ShallowConv()
    model = L.Classifier(net)
    if gpu_id >= 0:
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist(ndim=3)
    train_iter = chainer.iterators.SerialIterator(train, bs)
    test_iter = chainer.iterators.SerialIterator(
        test, bs, repeat=False, shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)
    trainer.extend(extensions.ParameterStatistics(model.predictor))
    trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
    trainer.extend(extensions.LogReport(log_name='parameter_statistics'))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar())

    if resume:
        chainer.serializers.load_npz(resume, trainer)

    trainer.run()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/instance_normalization/link.py:
--------------------------------------------------------------------------------
import numpy

from chainer import configuration
from chainer import cuda
from chainer import initializers
from chainer import link
from chainer import variable

from .function import fixed_instance_normalization
from .function import InstanceNormalizationFunction


class InstanceNormalization(link.Link):

    """Instance normalization layer on outputs of convolution functions.

    It is recommended to use this normalization instead of batch
    normalization in generative models for style transfer.
    """

    def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 valid_test=False, use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        super(InstanceNormalization, self).__init__()
        self.valid_test = valid_test
        self.avg_mean = None
        self.avg_var = None
        self.N = 0
        if valid_test:
            self.register_persistent('avg_mean')
            self.register_persistent('avg_var')
            self.register_persistent('N')
        self.decay = decay
        self.eps = eps

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                initial_gamma = initializers._get_initializer(initial_gamma)
                initial_gamma.dtype = dtype
                self.gamma = variable.Parameter(initial_gamma, size)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                initial_beta = initializers._get_initializer(initial_beta)
                initial_beta.dtype = dtype
                self.beta = variable.Parameter(initial_beta, size)

    def __call__(self, x, gamma_=None, beta_=None):
        if hasattr(self, 'gamma'):
            gamma = self.gamma
        elif gamma_ is not None:
            gamma = gamma_
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        elif beta_ is not None:
            beta = beta_
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        decay = self.decay
        if (not configuration.config.train) and self.valid_test:
            mean = variable.Variable(self.avg_mean)
            var = variable.Variable(self.avg_var)
            ret = fixed_instance_normalization(
                x, gamma, beta, mean, var, self.eps)
        else:
            func = InstanceNormalizationFunction(
                self.eps, self.avg_mean, self.avg_var, decay)
            ret = func(x, gamma, beta)
            self.avg_mean = func.running_mean
            self.avg_var = func.running_var

        return ret
--------------------------------------------------------------------------------
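A quick sanity check of the `InstanceNormalization` link defined above (this snippet is not part of the repository; run it from the repository root, and note that the input shape is arbitrary):

```python
import numpy as np
import chainer
from instance_normalization import InstanceNormalization

layer = InstanceNormalization(3)
x = chainer.Variable(np.random.randn(4, 3, 16, 16).astype(np.float32))
y = layer(x)
print(y.shape)  # (4, 3, 16, 16): the output keeps the input shape
```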
/not_layer_instance_norm_sample.py:
--------------------------------------------------------------------------------
"""Note that this script requires Chainer v3."""
import matplotlib
matplotlib.use('Agg')

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


def prepare_beta(size, init=0, dtype=np.float32):
    initial_beta = chainer.initializers._get_initializer(init)
    initial_beta.dtype = dtype
    beta = chainer.variable.Parameter(initial_beta, size)
    return beta


def prepare_gamma(size, init=1, dtype=np.float32):
    initial_gamma = chainer.initializers._get_initializer(init)
    initial_gamma.dtype = dtype
    gamma = chainer.variable.Parameter(initial_gamma, size)
    return gamma


class ShallowConv(chainer.Chain):

    def __init__(self):
        super(ShallowConv, self).__init__()
        with self.init_scope():
            self.c_1 = L.Convolution2D(1, 3, 7, 2, 3, nobias=False)
            self.c_2 = L.Convolution2D(3, 6, 7, 4, 3, nobias=False)
            self.prob = L.Linear(None, 10)
            self.gamma_1 = prepare_gamma(3)
            self.beta_1 = prepare_beta(3)
            self.gamma_2 = prepare_gamma(6)
            self.beta_2 = prepare_beta(6)

    def __call__(self, x):
        h = F.relu(self.instance_norm(self.c_1(x), self.gamma_1, self.beta_1))
        h = F.relu(self.instance_norm(self.c_2(h), self.gamma_2, self.beta_2))
        bs = len(x)
        h = F.reshape(h, (bs, -1))
        return self.prob(h)

    def instance_norm(self, x, gamma=None, beta=None):
        # Mean and variance are computed per sample and per channel over the
        # spatial axes, then broadcast back to the shape of x.
        mean = F.mean(x, axis=-1)
        mean = F.mean(mean, axis=-1)
        mean = F.broadcast_to(mean[Ellipsis, None, None], x.shape)
        var = F.mean(F.squared_difference(x, mean), axis=-1)
        var = F.mean(var, axis=-1)
        std = F.broadcast_to(F.sqrt(var + 1e-5)[Ellipsis, None, None], x.shape)
        x_hat = (x - mean) / std
        if gamma is not None:
            gamma = F.broadcast_to(gamma[None, Ellipsis, None, None], x.shape)
            beta = F.broadcast_to(beta[None, Ellipsis, None, None], x.shape)
            return gamma * x_hat + beta
        else:
            return x_hat


def main(gpu_id=-1, bs=32, epoch=20, out='./not_layer_result', resume=''):
    net = ShallowConv()
    model = L.Classifier(net)
    if gpu_id >= 0:
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    train, test = chainer.datasets.get_mnist(ndim=3)
    train_iter = chainer.iterators.SerialIterator(train, bs)
    test_iter = chainer.iterators.SerialIterator(test, bs, repeat=False,
                                                 shuffle=False)

    updater = training.StandardUpdater(train_iter, optimizer, device=gpu_id)
    trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)
    trainer.extend(extensions.ParameterStatistics(model.predictor))
    trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
    trainer.extend(extensions.ProgressBar())

    if resume:
        chainer.serializers.load_npz(resume, trainer)

    trainer.run()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
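The repository does not record how `comparison.png` was produced. One plausible way to make a similar plot from the two training runs is to read the JSON logs written by `LogReport` in `sample.py` and `not_layer_instance_norm_sample.py` (the paths below follow the defaults in those two scripts; the output file name here is arbitrary):

```python
import json

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Log files written by chainer.training.extensions.LogReport in the two scripts.
with open('result/parameter_statistics') as f:
    layer_log = json.load(f)
with open('not_layer_result/log') as f:
    naive_log = json.load(f)

for log, label in [(layer_log, 'InstanceNormalization link'),
                   (naive_log, 'functions + Parameter')]:
    plt.plot([entry['epoch'] for entry in log],
             [entry['validation/main/accuracy'] for entry in log],
             label=label)
plt.xlabel('epoch')
plt.ylabel('validation accuracy')
plt.legend()
plt.savefig('my_comparison.png')
```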
/instance_normalization/function.py:
--------------------------------------------------------------------------------
import numpy

import chainer
from chainer import configuration
from chainer import cuda
from chainer import function
from chainer.utils import type_check


def _xhat(x, mean, std, expander):
    x_mu = x - mean[expander]
    x_mu /= std[expander]
    return x_mu


class InstanceNormalizationFunction(function.Function):

    """Instance Normalization function.

    This is similar to Batch Normalization, but differs in that it does not
    require running_mean or running_var, and the mean and variance are
    computed for each sample in the mini-batch.
    """

    def __init__(self, eps=2e-5, mean=None, var=None, decay=0.9,
                 valid_test=False):
        self.running_mean = mean
        self.running_var = var
        self.eps = eps
        self.decay = decay
        self.valid_test = valid_test

        self.mean_cache = None

    def check_type_forward(self, in_types):
        n_in = type_check.eval(in_types.size())
        if n_in != 3:
            raise type_check.InvalidType(
                '%s == %s' % (in_types.size(), 3),
                '%s == %s' % (in_types.size(), n_in))
        x_type, gamma_type, beta_type = in_types[:3]
        M = type_check.eval(gamma_type.ndim)
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim >= gamma_type.ndim + 1,
            x_type.shape[1:1 + M] == gamma_type.shape,
            gamma_type.dtype == x_type.dtype,
            beta_type.dtype == x_type.dtype,
            gamma_type.shape == beta_type.shape,
        )

    def forward(self, inputs):
        xp = cuda.get_array_module(*inputs)
        x, gamma, beta = inputs[:3]
        if configuration.config.train:
            if self.running_mean is None:
                self.running_mean = xp.zeros(x.shape[:2], dtype=x.dtype)
                self.running_var = xp.zeros_like(self.running_mean)
            else:
                self.running_mean = xp.array(self.running_mean)
                self.running_var = xp.array(self.running_var)
        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        gamma = gamma[expander]
        beta = beta[expander]
        mean_var_expander = (Ellipsis, None, None)

        axis = (2, 3)
        mean = x.mean(axis=axis)
        var = x.var(axis=axis)
        var += self.eps

        if (not configuration.config.train) and self.valid_test:
            # Use the fixed statistics given to the constructor instead of
            # the per-instance statistics computed above.
            mean = self.running_mean
            var = self.running_var + self.eps

        self.std = xp.sqrt(var, dtype=var.dtype)

        if xp is numpy:
            self.x_hat = _xhat(x, mean, self.std, mean_var_expander)
            y = gamma * self.x_hat + beta
        else:
            self.x_hat, y = cuda.elementwise(
                'T x, T mean, T std, T gamma, T beta', 'T x_hat, T y',
                '''
                x_hat = (x - mean) / std;
                y = gamma * x_hat + beta;
                ''',
                'in_fwd')(x, mean[mean_var_expander],
                          self.std[mean_var_expander], gamma, beta)

        if configuration.config.train:
            m = x.size // gamma.size
            adjust = m / max(m - 1., 1.)
            self.running_mean *= self.decay
            tmp_ary = (1 - self.decay) * xp.array(mean)
            self.running_mean += tmp_ary
            del tmp_ary
            self.running_var *= self.decay
            tmp_ary = (1 - self.decay) * adjust * xp.array(var)
            self.running_var += tmp_ary
            del tmp_ary
        return y,

    def backward(self, inputs, grad_outputs):
        x, gamma, beta = inputs[:3]
        gy = grad_outputs[0]
        head_ndim = gamma.ndim + 1
        expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
        m = gamma.dtype.type(x.size // gamma.size)
        axis = (2, 3)
        gamma_beta_axis = (0, 2, 3)
        mean_var_expander = (Ellipsis, None, None)
        xp = cuda.get_array_module(x)

        gbeta = gy.sum(axis=gamma_beta_axis)
        ggamma = (gy * self.x_hat).sum(axis=gamma_beta_axis)
        if xp is numpy:
            gx = (gamma / self.std)[mean_var_expander] * (
                gy - (self.x_hat * ggamma[mean_var_expander] +
                      gbeta[mean_var_expander]) / m)
        else:
            inv_m = numpy.float32(1) / m
            gx = cuda.elementwise(
                'T gy, T x_hat, T gamma, T std, T ggamma, T gbeta, \
                T inv_m',
                'T gx',
                'gx = (gamma / std) * (gy - (x_hat * ggamma + gbeta) * \
                inv_m)',
                'bn_bwd')(gy, self.x_hat, gamma[expander],
                          self.std[mean_var_expander],
                          ggamma[mean_var_expander],
                          gbeta[mean_var_expander], inv_m)
        return gx, ggamma, gbeta


def fixed_instance_normalization(x, gamma, beta, mean, var, eps=2e-5):
    # Normalize with the given fixed statistics instead of per-instance ones.
    if isinstance(mean, chainer.Variable):
        mean = mean.data
    if isinstance(var, chainer.Variable):
        var = var.data
    with configuration.using_config('train', False):
        return InstanceNormalizationFunction(
            eps, mean, var, valid_test=True)(x, gamma, beta)
--------------------------------------------------------------------------------
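The forward pass of `InstanceNormalizationFunction` can be checked against a plain NumPy reference like this (not part of the repository; the shapes are arbitrary):

```python
import numpy as np
from instance_normalization.function import InstanceNormalizationFunction

x = np.random.randn(4, 3, 8, 8).astype(np.float32)
gamma = np.random.randn(3).astype(np.float32)
beta = np.random.randn(3).astype(np.float32)

y = InstanceNormalizationFunction(eps=2e-5)(x, gamma, beta)

# NumPy reference: per-sample, per-channel statistics over the spatial axes.
mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
y_ref = (gamma[None, :, None, None] * (x - mean) / np.sqrt(var + 2e-5)
         + beta[None, :, None, None])
print(np.abs(y.data - y_ref).max())  # should be tiny, on the order of 1e-6
```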
/v3/on_BN.py:
--------------------------------------------------------------------------------
import numpy

import chainer
from chainer import configuration
from chainer import cuda
from chainer import functions
from chainer import initializers
from chainer import link
from chainer.utils import argument
from chainer import variable


class InstanceNormalization(link.Link):

    def __init__(self, size, decay=0.9, eps=2e-5, dtype=numpy.float32,
                 valid_test=False, use_gamma=True, use_beta=True,
                 initial_gamma=None, initial_beta=None):
        super(InstanceNormalization, self).__init__()
        self.valid_test = valid_test
        self.avg_mean = numpy.zeros(size, dtype=dtype)
        self.avg_var = numpy.zeros(size, dtype=dtype)
        self.N = 0
        self.register_persistent('avg_mean')
        self.register_persistent('avg_var')
        self.register_persistent('N')
        self.decay = decay
        self.eps = eps

        with self.init_scope():
            if use_gamma:
                if initial_gamma is None:
                    initial_gamma = 1
                initial_gamma = initializers._get_initializer(initial_gamma)
                initial_gamma.dtype = dtype
                self.gamma = variable.Parameter(initial_gamma, size)
            if use_beta:
                if initial_beta is None:
                    initial_beta = 0
                initial_beta = initializers._get_initializer(initial_beta)
                initial_beta.dtype = dtype
                self.beta = variable.Parameter(initial_beta, size)

    def __call__(self, x, **kwargs):
        """__call__(self, x, finetune=False)

        Invokes the forward propagation of this link, which computes
        instance normalization on top of batch normalization.

        In training mode, this link computes moving averages of the mean and
        variance for evaluation during training, and normalizes the input
        using batch statistics of the reshaped input.

        .. warning::
           ``test`` argument is not supported anymore since v2.
           Instead, use ``chainer.using_config('train', False)``.
           See :func:`chainer.using_config`.

        Args:
            x (Variable): Input variable.
            finetune (bool): If it is in the training mode and ``finetune`` is
                ``True``, BatchNormalization runs in fine-tuning mode; it
                accumulates the input array to compute population statistics
                for normalization, and normalizes the input using batch
                statistics.
        """
        # check argument
        argument.check_unexpected_kwargs(
            kwargs, test='test argument is not supported anymore. '
            'Use chainer.using_config')
        finetune, = argument.parse_kwargs(kwargs, ('finetune', False))

        # reshape input x
        original_shape = x.shape
        batch_size, n_ch = original_shape[:2]
        new_shape = (1, batch_size * n_ch) + original_shape[2:]
        reshaped_x = functions.reshape(x, new_shape)

        if hasattr(self, 'gamma'):
            gamma = self.gamma
        else:
            with cuda.get_device_from_id(self._device_id):
                gamma = variable.Variable(self.xp.ones(
                    self.avg_mean.shape, dtype=x.dtype))
        if hasattr(self, 'beta'):
            beta = self.beta
        else:
            with cuda.get_device_from_id(self._device_id):
                beta = variable.Variable(self.xp.zeros(
                    self.avg_mean.shape, dtype=x.dtype))

        mean = chainer.as_variable(self.xp.hstack([self.avg_mean] * batch_size))
        var = chainer.as_variable(self.xp.hstack([self.avg_var] * batch_size))
        gamma = chainer.as_variable(self.xp.hstack([gamma.array] * batch_size))
        beta = chainer.as_variable(self.xp.hstack([beta.array] * batch_size))
        if configuration.config.train:
            if finetune:
                self.N += 1
                decay = 1. - 1. / self.N
            else:
                decay = self.decay

            ret = functions.batch_normalization(
                reshaped_x, gamma, beta, eps=self.eps, running_mean=mean,
                running_var=var, decay=decay)
        else:
            # Use running average statistics or fine-tuned statistics.
            ret = functions.fixed_batch_normalization(
                reshaped_x, gamma, beta, mean, var, self.eps)

        # ret is normalized input x
        return functions.reshape(ret, original_shape)


if __name__ == '__main__':
    import numpy as np
    base_shape = [10, 3]
    with chainer.using_config('debug', True):
        for i, n_element in enumerate([32, 32, 32]):
            base_shape.append(n_element)
            print('# {} th: input shape: {}'.format(i, base_shape))
            x_array = np.random.normal(size=base_shape).astype(np.float32)
            x = chainer.as_variable(x_array)
            layer = InstanceNormalization(base_shape[1])
            y = layer(x)
            # calculate y_hat manually
            axes = tuple(range(2, len(base_shape)))
            x_mean = np.mean(x_array, axis=axes, keepdims=True)
            x_var = np.var(x_array, axis=axes, keepdims=True) + 1e-5
            x_std = np.sqrt(x_var)
            y_hat = (x_array - x_mean) / x_std
            diff = y.array - y_hat
            print('*** diff ***')
            print('\tmean: {:03f},\n\tstd: {:.03f}'.format(
                np.mean(diff), np.std(diff)))

    # In test mode each freshly built layer normalizes with its running
    # averages, which are still their initial zeros here, so the output is
    # roughly x / sqrt(eps) and differs greatly from the per-instance
    # normalization computed above; this explains the large diffs below.
    base_shape = [10, 3]
    with chainer.using_config('train', False):
        print('\n# test mode\n')
        for i, n_element in enumerate([32, 32, 32]):
            base_shape.append(n_element)
            print('# {} th: input shape: {}'.format(i, base_shape))
            x_array = np.random.normal(size=base_shape).astype(np.float32)
            x = chainer.as_variable(x_array)
            layer = InstanceNormalization(base_shape[1])
            y = layer(x)
            axes = tuple(range(2, len(base_shape)))
            x_mean = np.mean(x_array, axis=axes, keepdims=True)
            x_var = np.var(x_array, axis=axes, keepdims=True) + 1e-5
            x_std = np.sqrt(x_var)
            y_hat = (x_array - x_mean) / x_std
            diff = y.array - y_hat
            print('*** diff ***')
            print('\tmean: {:03f},\n\tstd: {:.03f}'.format(
                np.mean(diff), np.std(diff)))


"""
○ → python instance_norm.py
# 0 th: input shape: [10, 3, 32]
*** diff ***
    mean: -0.000000,
    std: 0.000
# 1 th: input shape: [10, 3, 32, 32]
*** diff ***
    mean: -0.000000,
    std: 0.000
# 2 th: input shape: [10, 3, 32, 32, 32]
*** diff ***
    mean: -0.000000,
    std: 0.000

# test mode
# 0 th: input shape: [10, 3, 32]
*** diff ***
    mean: 14.126040,
    std: 227.823
# 1 th: input shape: [10, 3, 32, 32]
*** diff ***
    mean: -0.286635,
    std: 221.926
# 2 th: input shape: [10, 3, 32, 32, 32]
*** diff ***
    mean: -0.064297,
    std: 222.492
"""
--------------------------------------------------------------------------------