├── .gitignore ├── LICENSE.md ├── README.md ├── batch_gen.py ├── complex_layers ├── __pycache__ │ ├── bn.cpython-36.pyc │ ├── conv.cpython-36.pyc │ ├── init.cpython-36.pyc │ ├── norm.cpython-36.pyc │ └── utils.cpython-36.pyc ├── bn.py ├── conv.py ├── dense.py ├── fft.py ├── init.py ├── norm.py ├── pool.py └── utils.py ├── paper ├── DeepQuaternionNets.aux ├── DeepQuaternionNets.bbl ├── DeepQuaternionNets.blg ├── DeepQuaternionNets.log ├── DeepQuaternionNets.pdf ├── DeepQuaternionNets.tex ├── DeepQuaternionNets.tps ├── bib.bib ├── figures │ └── quatconv.png └── ieee_version │ ├── DeepQuaternionNets.aux │ ├── DeepQuaternionNets.bbl │ ├── DeepQuaternionNets.blg │ ├── DeepQuaternionNets.log │ ├── DeepQuaternionNets.pdf │ ├── DeepQuaternionNets.tex │ ├── IEEEtran.cls │ ├── bib.bib │ └── quatconv.png ├── quatconv.png ├── quaternion_layers ├── __pycache__ │ ├── bn.cpython-36.pyc │ ├── conv.cpython-36.pyc │ ├── dist.cpython-36.pyc │ ├── init.cpython-36.pyc │ ├── norm.cpython-36.pyc │ └── utils.cpython-36.pyc ├── bn.py ├── conv.py ├── dense.py ├── dist.py ├── init.py ├── norm.py └── utils.py ├── runner.py ├── scripts ├── complex_seg_train_loss.txt ├── complex_seg_val_loss.txt ├── complex_weights.hd5 ├── plot_results.py ├── quaternion_seg_train_loss.txt ├── quaternion_seg_val_loss.txt ├── quaternion_weights.hd5 ├── real_seg_train_loss.txt ├── real_seg_val_loss.txt ├── real_weights.hd5 ├── run_small_segmentation_experiments.bat ├── small_segmentation_complex.bat ├── small_segmentation_quaternion.bat └── small_segmentation_real.bat ├── training_classification.py ├── training_mnist.py └── training_segmentation.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.pyc 3 | *.cpython 4 | */_pycache_/* 5 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 
Chase J. Gaudet 4 | 5 | Copyright (c) 2018 Anthony S. Maida 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 8 | 9 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepQuaternionNetworks 2 | Paper and code for the paper Deep Quaternion Networks https://arxiv.org/abs/1712.04604 3 | 4 | The paper has been accepted to IJCNN 2018. 5 | 6 | Please feel free to email me with questions about using the code. I will be making documentation changes and cleaning up the code as soon as I have some more spare time. 
7 | 8 | 9 | If using this code or work presented in the paper please cite 10 | 11 | @article{gaudet2017deep, 12 | 13 | title={Deep Quaternion Networks}, 14 | 15 | author={Gaudet, Chase and Maida, Anthony}, 16 | 17 | journal={arXiv preprint arXiv:1712.04604}, 18 | 19 | year={2017} 20 | 21 | } 22 | -------------------------------------------------------------------------------- /batch_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.misc 4 | from keras.preprocessing.image import ImageDataGenerator 5 | np.random.seed(31337) 6 | 7 | 8 | def gen_batch(shape, batch_size): 9 | image_folder = 'D:/Projects/DeepQuaternionNets/data/segmentation/imgs/' 10 | mask_folder = 'D:/Projects/DeepQuaternionNets/data/segmentation/masks/' 11 | 12 | # Get list of all images 13 | image_names = [] 14 | for image_name in os.listdir(image_folder): 15 | image_names.append(image_name) 16 | 17 | # Augment object 18 | idg = ImageDataGenerator(rotation_range=0, 19 | width_shift_range=0.2, 20 | height_shift_range=0.2, 21 | shear_range=0, 22 | zoom_range=0.2, 23 | horizontal_flip=True, 24 | vertical_flip=False) 25 | 26 | # Batch Loop 27 | image_count = len(image_names) 28 | 29 | while True: 30 | X = np.zeros((batch_size, shape[0], shape[1], shape[2])) 31 | Y = np.zeros((batch_size, 1, shape[1], shape[2])) 32 | count = 0 33 | while count < batch_size: 34 | # Random image and mask 35 | randi = np.random.randint(0, image_count) 36 | image_path = os.path.join(image_folder, image_names[randi]) 37 | mask_path = image_path.replace('imgs', 'masks') 38 | mask_path = mask_path.replace('_', '_road_') 39 | 40 | rs = (shape[1], shape[2], shape[0]) 41 | image = scipy.misc.imread(image_path) 42 | image = scipy.misc.imresize(image, rs) 43 | image = image.transpose((2,0,1)) 44 | mask = scipy.misc.imread(mask_path) 45 | mask = scipy.misc.imresize(mask, rs) 46 | mask = mask.transpose((2,0,1)) 47 | mask = 
np.logical_and(mask[0,:,:] == 255, mask[2,:,:] == 255).astype('float32') 48 | 49 | X[count] = (image - 127.5) / 127.5 50 | Y[count] = mask 51 | count += 1 52 | 53 | rseed = np.random.randint(0, 1000000) 54 | xc = 0 55 | for x in idg.flow(X, batch_size=1, seed=rseed): 56 | X[xc] = x 57 | xc += 1 58 | if xc >= batch_size: 59 | break 60 | 61 | yc = 0 62 | for y in idg.flow(Y, batch_size=1, seed=rseed): 63 | Y[yc] = y 64 | yc += 1 65 | if yc >= batch_size: 66 | break 67 | 68 | Y = Y > 0.5 69 | 70 | yield X.astype('float32'), Y.astype('float32') 71 | 72 | 73 | if __name__ == '__main__': 74 | bg = gen_batch((3,187,621), 1) 75 | x,y = next(bg) 76 | mc = 0 -------------------------------------------------------------------------------- /complex_layers/__pycache__/bn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/complex_layers/__pycache__/bn.cpython-36.pyc -------------------------------------------------------------------------------- /complex_layers/__pycache__/conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/complex_layers/__pycache__/conv.cpython-36.pyc -------------------------------------------------------------------------------- /complex_layers/__pycache__/init.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/complex_layers/__pycache__/init.cpython-36.pyc -------------------------------------------------------------------------------- /complex_layers/__pycache__/norm.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/complex_layers/__pycache__/norm.cpython-36.pyc -------------------------------------------------------------------------------- /complex_layers/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/complex_layers/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /complex_layers/dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | # 7 | 8 | from keras import backend as K 9 | import sys; sys.path.append('.') 10 | from keras import backend as K 11 | from keras import activations, initializers, regularizers, constraints 12 | from keras.layers import Layer, InputSpec 13 | import numpy as np 14 | from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams 15 | 16 | 17 | class ComplexDense(Layer): 18 | """Regular complex densely-connected NN layer. 19 | `ComplexDense` implements the operation (the kernel enters conjugated): 20 | `real_preact = dot(real_input, real_kernel) + dot(imag_input, imag_kernel)` 21 | `imag_preact = dot(imag_input, real_kernel) - dot(real_input, imag_kernel)` 22 | `output = activation(K.concatenate([real_preact, imag_preact]) + bias)` 23 | where `activation` is the element-wise activation function 24 | passed as the `activation` argument, `kernel` is a weights matrix 25 | created by the layer, and `bias` is a bias vector created by the layer 26 | (only applicable if `use_bias` is `True`). 27 | Note: if the input to the layer has a rank other than 2, then 28 | an error is raised (`InputSpec(ndim=2)` and the assertion in `build`). 29 | # Arguments 30 | units: Positive integer, dimensionality of each of the real part
It is actually the number of complex units. 32 | activation: Activation function to use 33 | (see keras.activations). 34 | If you don't specify anything, no activation is applied 35 | (i.e. "linear" activation: `a(x) = x`). 36 | use_bias: Boolean, whether the layer uses a bias vector. 37 | kernel_initializer: Initializer for the complex `kernel` weights matrix. 38 | By default it is 'complex': a normal init whose std is set by `init_criterion` ('he' or 'glorot'); 39 | the usual Keras initializers can also be used 40 | (see keras.initializers and init.py). 41 | bias_initializer: Initializer for the bias vector 42 | (see keras.initializers). 43 | kernel_regularizer: Regularizer function applied to 44 | the `kernel` weights matrix 45 | (see keras.regularizers). 46 | bias_regularizer: Regularizer function applied to the bias vector 47 | (see keras.regularizers). 48 | activity_regularizer: Regularizer function applied to 49 | the output of the layer (its "activation"). 50 | (see keras.regularizers). 51 | kernel_constraint: Constraint function applied to the kernel matrix 52 | (see keras.constraints). 53 | bias_constraint: Constraint function applied to the bias vector 54 | (see keras.constraints). 55 | # Input shape 56 | a 2D input with shape `(batch_size, input_dim)`; `input_dim` must be even (real parts first, then imaginary parts). 57 | # Output shape 58 | For a 2D input with shape `(batch_size, input_dim)`, 59 | the output would have shape `(batch_size, 2 * units)` (real parts first, then imaginary parts).
60 | """ 61 | 62 | def __init__(self, units, 63 | activation=None, 64 | use_bias=True, 65 | init_criterion='he', 66 | kernel_initializer='complex', 67 | bias_initializer='zeros', 68 | kernel_regularizer=None, 69 | bias_regularizer=None, 70 | activity_regularizer=None, 71 | kernel_constraint=None, 72 | bias_constraint=None, 73 | seed=None, 74 | **kwargs): 75 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 76 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 77 | super(ComplexDense, self).__init__(**kwargs) 78 | self.units = units 79 | self.activation = activations.get(activation) 80 | self.use_bias = use_bias 81 | self.init_criterion = init_criterion 82 | if kernel_initializer in {'complex'}: 83 | self.kernel_initializer = kernel_initializer 84 | else: 85 | self.kernel_initializer = initializers.get(kernel_initializer) 86 | self.bias_initializer = initializers.get(bias_initializer) 87 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 88 | self.bias_regularizer = regularizers.get(bias_regularizer) 89 | self.activity_regularizer = regularizers.get(activity_regularizer) 90 | self.kernel_constraint = constraints.get(kernel_constraint) 91 | self.bias_constraint = constraints.get(bias_constraint) 92 | if seed is None: 93 | self.seed = np.random.randint(1, 10e6) 94 | else: 95 | self.seed = seed 96 | self.input_spec = InputSpec(ndim=2) 97 | self.supports_masking = True 98 | 99 | def build(self, input_shape): 100 | assert len(input_shape) == 2 101 | assert input_shape[-1] % 2 == 0 102 | input_dim = input_shape[-1] // 2 103 | data_format = K.image_data_format() 104 | kernel_shape = (input_dim, self.units) 105 | fan_in, fan_out = initializers._compute_fans( 106 | kernel_shape, 107 | data_format=data_format 108 | ) 109 | if self.init_criterion == 'he': 110 | s = K.sqrt(1. / fan_in) 111 | elif self.init_criterion == 'glorot': 112 | s = K.sqrt(1. 
/ (fan_in + fan_out)) 113 | rng = RandomStreams(seed=self.seed) 114 | 115 | # Equivalent initialization using amplitude phase representation: 116 | """modulus = rng.rayleigh(scale=s, size=kernel_shape) 117 | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) 118 | def init_w_real(shape, dtype=None): 119 | return modulus * K.cos(phase) 120 | def init_w_imag(shape, dtype=None): 121 | return modulus * K.sin(phase)""" 122 | 123 | # Initialization using euclidean representation: 124 | def init_w_real(shape, dtype=None): 125 | return rng.normal( 126 | size=kernel_shape, 127 | avg=0, 128 | std=s, 129 | dtype=dtype 130 | ) 131 | def init_w_imag(shape, dtype=None): 132 | return rng.normal( 133 | size=kernel_shape, 134 | avg=0, 135 | std=s, 136 | dtype=dtype 137 | ) 138 | if self.kernel_initializer in {'complex'}: 139 | real_init = init_w_real 140 | imag_init = init_w_imag 141 | else: 142 | real_init = self.kernel_initializer 143 | imag_init = self.kernel_initializer 144 | 145 | self.real_kernel = self.add_weight( 146 | shape=kernel_shape, 147 | initializer=real_init, 148 | name='real_kernel', 149 | regularizer=self.kernel_regularizer, 150 | constraint=self.kernel_constraint 151 | ) 152 | self.imag_kernel = self.add_weight( 153 | shape=kernel_shape, 154 | initializer=imag_init, 155 | name='imag_kernel', 156 | regularizer=self.kernel_regularizer, 157 | constraint=self.kernel_constraint 158 | ) 159 | 160 | if self.use_bias: 161 | self.bias = self.add_weight( 162 | shape=(2 * self.units,), 163 | initializer=self.bias_initializer, 164 | name='bias', 165 | regularizer=self.bias_regularizer, 166 | constraint=self.bias_constraint 167 | ) 168 | else: 169 | self.bias = None 170 | 171 | self.input_spec = InputSpec(ndim=2, axes={-1: 2 * input_dim}) 172 | self.built = True 173 | 174 | def call(self, inputs): 175 | input_shape = K.shape(inputs) 176 | input_dim = input_shape[-1] // 2 177 | real_input = inputs[:, :input_dim] 178 | imag_input = inputs[:, input_dim:] 179 | 180 | 
cat_kernels_4_real = K.concatenate( 181 | [self.real_kernel, -self.imag_kernel], 182 | axis=-1 183 | ) 184 | cat_kernels_4_imag = K.concatenate( 185 | [self.imag_kernel, self.real_kernel], 186 | axis=-1 187 | ) 188 | cat_kernels_4_complex = K.concatenate( 189 | [cat_kernels_4_real, cat_kernels_4_imag], 190 | axis=0 191 | ) 192 | 193 | output = K.dot(inputs, cat_kernels_4_complex) 194 | 195 | if self.use_bias: 196 | output = K.bias_add(output, self.bias) 197 | if self.activation is not None: 198 | output = self.activation(output) 199 | 200 | return output 201 | 202 | def compute_output_shape(self, input_shape): 203 | assert input_shape and len(input_shape) == 2 204 | assert input_shape[-1] 205 | output_shape = list(input_shape) 206 | output_shape[-1] = 2 * self.units 207 | return tuple(output_shape) 208 | 209 | def get_config(self): 210 | if self.kernel_initializer in {'complex'}: 211 | ki = self.kernel_initializer 212 | else: 213 | ki = initializers.serialize(self.kernel_initializer) 214 | config = { 215 | 'units': self.units, 216 | 'activation': activations.serialize(self.activation), 217 | 'use_bias': self.use_bias, 218 | 'init_criterion': self.init_criterion, 219 | 'kernel_initializer': ki, 220 | 'bias_initializer': initializers.serialize(self.bias_initializer), 221 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 222 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 223 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 224 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 225 | 'bias_constraint': constraints.serialize(self.bias_constraint), 226 | 'seed': self.seed, 227 | } 228 | base_config = super(ComplexDense, self).get_config() 229 | return dict(list(base_config.items()) + list(config.items())) 230 | 231 | -------------------------------------------------------------------------------- /complex_layers/fft.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Olexa Bilaniuk 6 | # 7 | 8 | import keras.backend as KB 9 | import keras.engine as KE 10 | import keras.layers as KL 11 | import keras.optimizers as KO 12 | import theano as T 13 | import theano.ifelse as TI 14 | import theano.tensor as TT 15 | import theano.tensor.fft as TTF 16 | import numpy as np 17 | 18 | 19 | # 20 | # FFT functions: 21 | # 22 | # fft(): Batched 1-D FFT (Input: (Batch, TimeSamples)) 23 | # ifft(): Batched 1-D IFFT (Input: (Batch, FreqSamples)) 24 | # fft2(): Batched 2-D FFT (Input: (Batch, TimeSamplesH, TimeSamplesW)) 25 | # ifft2(): Batched 2-D IFFT (Input: (Batch, FreqSamplesH, FreqSamplesW)) 26 | # 27 | 28 | def fft(z): 29 | B = z.shape[0]//2 30 | L = z.shape[1] 31 | C = TT.as_tensor_variable(np.asarray([[[1,-1]]], dtype=T.config.floatX)) 32 | Zr, Zi = TTF.rfft(z[:B], norm="ortho"), TTF.rfft(z[B:], norm="ortho") 33 | isOdd = TT.eq(L%2, 1) 34 | Zr = TI.ifelse(isOdd, TT.concatenate([Zr, C*Zr[:,1: ][:,::-1]], axis=1), 35 | TT.concatenate([Zr, C*Zr[:,1:-1][:,::-1]], axis=1)) 36 | Zi = TI.ifelse(isOdd, TT.concatenate([Zi, C*Zi[:,1: ][:,::-1]], axis=1), 37 | TT.concatenate([Zi, C*Zi[:,1:-1][:,::-1]], axis=1)) 38 | Zi = (C*Zi)[:,:,::-1] # Zi * i 39 | Z = Zr+Zi 40 | return TT.concatenate([Z[:,:,0], Z[:,:,1]], axis=0) 41 | def ifft(z): 42 | B = z.shape[0]//2 43 | L = z.shape[1] 44 | C = TT.as_tensor_variable(np.asarray([[[1,-1]]], dtype=T.config.floatX)) 45 | Zr, Zi = TTF.rfft(z[:B], norm="ortho"), TTF.rfft(z[B:]*-1, norm="ortho") 46 | isOdd = TT.eq(L%2, 1) 47 | Zr = TI.ifelse(isOdd, TT.concatenate([Zr, C*Zr[:,1: ][:,::-1]], axis=1), 48 | TT.concatenate([Zr, C*Zr[:,1:-1][:,::-1]], axis=1)) 49 | Zi = TI.ifelse(isOdd, TT.concatenate([Zi, C*Zi[:,1: ][:,::-1]], axis=1), 50 | TT.concatenate([Zi, C*Zi[:,1:-1][:,::-1]], axis=1)) 51 | Zi = (C*Zi)[:,:,::-1] # Zi * i 52 | Z = Zr+Zi 53 | return 
TT.concatenate([Z[:,:,0], Z[:,:,1]*-1], axis=0) 54 | def fft2(x): 55 | tt = x 56 | tt = KB.reshape(tt, (x.shape[0] *x.shape[1], x.shape[2])) 57 | tf = fft(tt) 58 | tf = KB.reshape(tf, (x.shape[0], x.shape[1], x.shape[2])) 59 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 60 | tf = KB.reshape(tf, (x.shape[0] *x.shape[2], x.shape[1])) 61 | ff = fft(tf) 62 | ff = KB.reshape(ff, (x.shape[0], x.shape[2], x.shape[1])) 63 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 64 | return ff 65 | def ifft2(x): 66 | ff = x 67 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 68 | ff = KB.reshape(ff, (x.shape[0] *x.shape[2], x.shape[1])) 69 | tf = ifft(ff) 70 | tf = KB.reshape(tf, (x.shape[0], x.shape[2], x.shape[1])) 71 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 72 | tf = KB.reshape(tf, (x.shape[0] *x.shape[1], x.shape[2])) 73 | tt = ifft(tf) 74 | tt = KB.reshape(tt, (x.shape[0], x.shape[1], x.shape[2])) 75 | return tt 76 | 77 | # 78 | # FFT Layers: 79 | # 80 | # FFT: Batched 1-D FFT (Input: (Batch, FeatureMaps, TimeSamples)) 81 | # IFFT: Batched 1-D IFFT (Input: (Batch, FeatureMaps, FreqSamples)) 82 | # FFT2: Batched 2-D FFT (Input: (Batch, FeatureMaps, TimeSamplesH, TimeSamplesW)) 83 | # IFFT2: Batched 2-D IFFT (Input: (Batch, FeatureMaps, FreqSamplesH, FreqSamplesW)) 84 | # 85 | 86 | class FFT(KL.Layer): 87 | def call(self, x, mask=None): 88 | a = KB.permute_dimensions(x, (1,0,2)) 89 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2])) 90 | a = fft(a) 91 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 92 | return KB.permute_dimensions(a, (1,0,2)) 93 | class IFFT(KL.Layer): 94 | def call(self, x, mask=None): 95 | a = KB.permute_dimensions(x, (1,0,2)) 96 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2])) 97 | a = ifft(a) 98 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 99 | return KB.permute_dimensions(a, (1,0,2)) 100 | class FFT2(KL.Layer): 101 | def call(self, x, mask=None): 102 | a = KB.permute_dimensions(x, (1,0,2,3)) 103 | a = KB.reshape(a, (x.shape[1] 
*x.shape[0], x.shape[2], x.shape[3])) 104 | a = fft2(a) 105 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 106 | return KB.permute_dimensions(a, (1,0,2,3)) 107 | class IFFT2(KL.Layer): 108 | def call(self, x, mask=None): 109 | a = KB.permute_dimensions(x, (1,0,2,3)) 110 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2], x.shape[3])) 111 | a = ifft2(a) 112 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 113 | return KB.permute_dimensions(a, (1,0,2,3)) 114 | 115 | 116 | 117 | # 118 | # Tests 119 | # 120 | # Note: The IFFT is the conjugate of the FFT of the conjugate. 121 | # 122 | # np.fft.ifft(x) == np.conj(np.fft.fft(np.conj(x))) 123 | # 124 | 125 | if __name__ == "__main__": 126 | # Numpy 127 | np.random.seed(1) 128 | L = 19 129 | r = np.random.normal(0.8, size=(L,)) 130 | i = np.random.normal(0.8, size=(L,)) 131 | x = r+i*1j 132 | R = np.fft.rfft(r, norm="ortho") 133 | I = np.fft.rfft(i, norm="ortho") 134 | X = np.fft.fft (x, norm="ortho") 135 | 136 | if L&1: 137 | R = np.concatenate([R, np.conj(R[1: ][::-1])]) 138 | I = np.concatenate([I, np.conj(I[1: ][::-1])]) 139 | else: 140 | R = np.concatenate([R, np.conj(R[1:-1][::-1])]) 141 | I = np.concatenate([I, np.conj(I[1:-1][::-1])]) 142 | Y = R+I*1j 143 | print(np.allclose(X, Y)) 144 | 145 | 146 | # Theano 147 | z = TT.dmatrix() 148 | f = T.function([z], ifft(fft(z))) 149 | v = np.concatenate([np.real(x)[np.newaxis,:], np.imag(x)[np.newaxis,:]], axis=0) 150 | print(v) 151 | print(f(v)) 152 | print(np.allclose(v, f(v))) 153 | 154 | 155 | # Keras 156 | x = i = KL.Input(shape=(128, 32,32)) 157 | x = IFFT2()(x) 158 | model = KE.Model([i],[x]) 159 | 160 | loss = "mse" 161 | opt = KO.Adam() 162 | 163 | model.compile(opt, loss) 164 | model._make_train_function() 165 | model._make_predict_function() 166 | model._make_test_function() 167 | 168 | v = np.random.normal(size=(13,128,32,32)) 169 | #print v 170 | V = model.predict(v) 171 | #print V 172 | print(V.shape) 173 | 
-------------------------------------------------------------------------------- /complex_layers/init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | 7 | import numpy as np 8 | from numpy.random import RandomState 9 | import keras.backend as K 10 | from keras import initializers 11 | from keras.initializers import Initializer 12 | from keras.utils.generic_utils import (serialize_keras_object, 13 | deserialize_keras_object) 14 | 15 | 16 | class IndependentFilters(Initializer): 17 | # This initialization constructs real-valued kernels 18 | # that are independent as much as possible from each other 19 | # while respecting either the He or the Glorot criterion. 20 | def __init__(self, kernel_size, input_dim, 21 | weight_dim, nb_filters=None, 22 | criterion='glorot', seed=None): 23 | 24 | # `weight_dim` is used as a parameter for sanity check 25 | # as we should not pass an integer as kernel_size when 26 | # the weight dimension is >= 2. 27 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 28 | # then in such a case, weight_dim = 2. 
29 | # (in case of 2D input): 30 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 31 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 32 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 33 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 34 | 35 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 36 | self.nb_filters = nb_filters 37 | self.kernel_size = kernel_size 38 | self.input_dim = input_dim 39 | self.weight_dim = weight_dim 40 | self.criterion = criterion 41 | self.seed = 1337 if seed is None else seed 42 | 43 | def __call__(self, shape, dtype=None): 44 | 45 | if self.nb_filters is not None: 46 | num_rows = self.nb_filters * self.input_dim 47 | num_cols = np.prod(self.kernel_size) 48 | else: 49 | # in case it is the kernel is a matrix and not a filter 50 | # which is the case of 2D input (No feature maps). 51 | num_rows = self.input_dim 52 | num_cols = self.kernel_size[-1] 53 | 54 | flat_shape = (num_rows, num_cols) 55 | rng = RandomState(self.seed) 56 | x = rng.uniform(size=flat_shape) 57 | u, _, v = np.linalg.svd(x) 58 | orthogonal_x = np.dot(u, np.dot(np.eye(num_rows, num_cols), v.T)) 59 | if self.nb_filters is not None: 60 | independent_filters = np.reshape(orthogonal_x, (num_rows,) + tuple(self.kernel_size)) 61 | fan_in, fan_out = initializers._compute_fans( 62 | tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 63 | ) 64 | else: 65 | independent_filters = orthogonal_x 66 | fan_in, fan_out = (self.input_dim, self.kernel_size[-1]) 67 | 68 | if self.criterion == 'glorot': 69 | desired_var = 2. / (fan_in + fan_out) 70 | elif self.criterion == 'he': 71 | desired_var = 2. 
/ fan_in 72 | else: 73 | raise ValueError('Invalid criterion: ' + self.criterion) 74 | 75 | multip_constant = np.sqrt (desired_var / np.var(independent_filters)) 76 | scaled_indep = multip_constant * independent_filters 77 | 78 | if self.weight_dim == 2 and self.nb_filters is None: 79 | weight_real = scaled_real 80 | weight_imag = scaled_imag 81 | else: 82 | kernel_shape = tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 83 | if self.weight_dim == 1: 84 | transpose_shape = (1, 0) 85 | elif self.weight_dim == 2 and self.nb_filters is not None: 86 | transpose_shape = (1, 2, 0) 87 | elif self.weight_dim == 3 and self.nb_filters is not None: 88 | transpose_shape = (1, 2, 3, 0) 89 | weight = np.transpose(scaled_indep, transpose_shape) 90 | weight = np.reshape(weight, kernel_shape) 91 | 92 | return weight 93 | 94 | def get_config(self): 95 | return {'nb_filters': self.nb_filters, 96 | 'kernel_size': self.kernel_size, 97 | 'input_dim': self.input_dim, 98 | 'weight_dim': self.weight_dim, 99 | 'criterion': self.criterion, 100 | 'seed': self.seed} 101 | 102 | 103 | class ComplexIndependentFilters(Initializer): 104 | # This initialization constructs complex-valued kernels 105 | # that are independent as much as possible from each other 106 | # while respecting either the He or the Glorot criterion. 107 | def __init__(self, kernel_size, input_dim, 108 | weight_dim, nb_filters=None, 109 | criterion='glorot', seed=None): 110 | 111 | # `weight_dim` is used as a parameter for sanity check 112 | # as we should not pass an integer as kernel_size when 113 | # the weight dimension is >= 2. 114 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 115 | # then in such a case, weight_dim = 2. 
116 | # (in case of 2D input): 117 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 118 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 119 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 120 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 121 | 122 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 123 | self.nb_filters = nb_filters 124 | self.kernel_size = kernel_size 125 | self.input_dim = input_dim 126 | self.weight_dim = weight_dim 127 | self.criterion = criterion 128 | self.seed = 1337 if seed is None else seed 129 | 130 | def __call__(self, shape, dtype=None): 131 | 132 | if self.nb_filters is not None: 133 | num_rows = self.nb_filters * self.input_dim 134 | num_cols = np.prod(self.kernel_size) 135 | else: 136 | # in case it is the kernel is a matrix and not a filter 137 | # which is the case of 2D input (No feature maps). 138 | num_rows = self.input_dim 139 | num_cols = self.kernel_size[-1] 140 | 141 | flat_shape = (int(num_rows), int(num_cols)) 142 | rng = RandomState(self.seed) 143 | r = rng.uniform(size=flat_shape) 144 | i = rng.uniform(size=flat_shape) 145 | z = r + 1j * i 146 | u, _, v = np.linalg.svd(z) 147 | unitary_z = np.dot(u, np.dot(np.eye(int(num_rows), int(num_cols)), np.conjugate(v).T)) 148 | real_unitary = unitary_z.real 149 | imag_unitary = unitary_z.imag 150 | if self.nb_filters is not None: 151 | indep_real = np.reshape(real_unitary, (num_rows,) + tuple(self.kernel_size)) 152 | indep_imag = np.reshape(imag_unitary, (num_rows,) + tuple(self.kernel_size)) 153 | fan_in, fan_out = initializers._compute_fans( 154 | tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 155 | ) 156 | else: 157 | indep_real = real_unitary 158 | indep_imag = imag_unitary 159 | fan_in, fan_out = (int(self.input_dim), self.kernel_size[-1]) 160 | 161 | if self.criterion == 'glorot': 162 | desired_var = 1. / (fan_in + fan_out) 163 | elif self.criterion == 'he': 164 | desired_var = 1. 
/ (fan_in) 165 | else: 166 | raise ValueError('Invalid criterion: ' + self.criterion) 167 | 168 | multip_real = np.sqrt(desired_var / np.var(indep_real)) 169 | multip_imag = np.sqrt(desired_var / np.var(indep_imag)) 170 | scaled_real = multip_real * indep_real 171 | scaled_imag = multip_imag * indep_imag 172 | 173 | if self.weight_dim == 2 and self.nb_filters is None: 174 | weight_real = scaled_real 175 | weight_imag = scaled_imag 176 | else: 177 | kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 178 | if self.weight_dim == 1: 179 | transpose_shape = (1, 0) 180 | elif self.weight_dim == 2 and self.nb_filters is not None: 181 | transpose_shape = (1, 2, 0) 182 | elif self.weight_dim == 3 and self.nb_filters is not None: 183 | transpose_shape = (1, 2, 3, 0) 184 | 185 | weight_real = np.transpose(scaled_real, transpose_shape) 186 | weight_imag = np.transpose(scaled_imag, transpose_shape) 187 | weight_real = np.reshape(weight_real, kernel_shape) 188 | weight_imag = np.reshape(weight_imag, kernel_shape) 189 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 190 | 191 | return weight 192 | 193 | def get_config(self): 194 | return {'nb_filters': self.nb_filters, 195 | 'kernel_size': self.kernel_size, 196 | 'input_dim': self.input_dim, 197 | 'weight_dim': self.weight_dim, 198 | 'criterion': self.criterion, 199 | 'seed': self.seed} 200 | 201 | 202 | class ComplexInit(Initializer): 203 | # The standard complex initialization using 204 | # either the He or the Glorot criterion. 205 | def __init__(self, kernel_size, input_dim, 206 | weight_dim, nb_filters=None, 207 | criterion='glorot', seed=None): 208 | 209 | # `weight_dim` is used as a parameter for sanity check 210 | # as we should not pass an integer as kernel_size when 211 | # the weight dimension is >= 2. 212 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 213 | # then in such a case, weight_dim = 2. 
214 | # (in case of 2D input): 215 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 216 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 217 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 218 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 219 | 220 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 221 | self.nb_filters = nb_filters 222 | self.kernel_size = kernel_size 223 | self.input_dim = input_dim 224 | self.weight_dim = weight_dim 225 | self.criterion = criterion 226 | self.seed = 1337 if seed is None else seed 227 | 228 | def __call__(self, shape, dtype=None): 229 | 230 | if self.nb_filters is not None: 231 | kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 232 | else: 233 | kernel_shape = (int(self.input_dim), self.kernel_size[-1]) 234 | 235 | fan_in, fan_out = initializers._compute_fans( 236 | tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 237 | ) 238 | 239 | if self.criterion == 'glorot': 240 | s = 1. / (fan_in + fan_out) 241 | elif self.criterion == 'he': 242 | s = 1. 
/ fan_in 243 | else: 244 | raise ValueError('Invalid criterion: ' + self.criterion) 245 | rng = RandomState(self.seed) 246 | modulus = rng.rayleigh(scale=s, size=kernel_shape) 247 | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) 248 | weight_real = modulus * np.cos(phase) 249 | weight_imag = modulus * np.sin(phase) 250 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 251 | 252 | return weight 253 | 254 | 255 | class SqrtInit(Initializer): 256 | def __call__(self, shape, dtype=None): 257 | return K.constant(1 / K.sqrt(2), shape=shape, dtype=dtype) 258 | 259 | 260 | # Aliases: 261 | sqrt_init = SqrtInit 262 | independent_filters = IndependentFilters 263 | complex_init = ComplexInit -------------------------------------------------------------------------------- /complex_layers/norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | 7 | # 8 | # Implementation of Layer Normalization and Complex Layer Normalization 9 | # 10 | 11 | import numpy as np 12 | from keras.layers import Layer, InputSpec 13 | from keras import initializers, regularizers, constraints 14 | import keras.backend as K 15 | from .bn import ComplexBN as complex_normalization 16 | from .bn import sqrt_init 17 | 18 | def layernorm(x, axis, epsilon, gamma, beta): 19 | # assert self.built, 'Layer must be built before being called' 20 | input_shape = K.shape(x) 21 | reduction_axes = list(range(K.ndim(x))) 22 | del reduction_axes[axis] 23 | del reduction_axes[0] 24 | broadcast_shape = [1] * K.ndim(x) 25 | broadcast_shape[axis] = input_shape[axis] 26 | broadcast_shape[0] = K.shape(x)[0] 27 | 28 | # Perform normalization: centering and reduction 29 | 30 | mean = K.mean(x, axis=reduction_axes) 31 | broadcast_mean = K.reshape(mean, broadcast_shape) 32 | x_centred = x - broadcast_mean 33 | variance = K.mean(x_centred ** 2, axis=reduction_axes) + epsilon 
34 | broadcast_variance = K.reshape(variance, broadcast_shape) 35 | 36 | x_normed = x_centred / K.sqrt(broadcast_variance) 37 | 38 | # Perform scaling and shifting 39 | 40 | broadcast_shape_params = [1] * K.ndim(x) 41 | broadcast_shape_params[axis] = K.shape(x)[axis] 42 | broadcast_gamma = K.reshape(gamma, broadcast_shape_params) 43 | broadcast_beta = K.reshape(beta, broadcast_shape_params) 44 | 45 | x_LN = broadcast_gamma * x_normed + broadcast_beta 46 | 47 | return x_LN 48 | 49 | class LayerNormalization(Layer): 50 | 51 | def __init__(self, 52 | epsilon=1e-4, 53 | axis=-1, 54 | beta_init='zeros', 55 | gamma_init='ones', 56 | gamma_regularizer=None, 57 | beta_regularizer=None, 58 | **kwargs): 59 | 60 | self.supports_masking = True 61 | self.beta_init = initializers.get(beta_init) 62 | self.gamma_init = initializers.get(gamma_init) 63 | self.epsilon = epsilon 64 | self.axis = axis 65 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 66 | self.beta_regularizer = regularizers.get(beta_regularizer) 67 | 68 | super(LayerNormalization, self).__init__(**kwargs) 69 | 70 | def build(self, input_shape): 71 | self.input_spec = InputSpec(ndim=len(input_shape), 72 | axes={self.axis: input_shape[self.axis]}) 73 | shape = (input_shape[self.axis],) 74 | 75 | self.gamma = self.add_weight(shape, 76 | initializer=self.gamma_init, 77 | regularizer=self.gamma_regularizer, 78 | name='{}_gamma'.format(self.name)) 79 | self.beta = self.add_weight(shape, 80 | initializer=self.beta_init, 81 | regularizer=self.beta_regularizer, 82 | name='{}_beta'.format(self.name)) 83 | 84 | self.built = True 85 | 86 | def call(self, x, mask=None): 87 | assert self.built, 'Layer must be built before being called' 88 | return layernorm(x, self.axis, self.epsilon, self.gamma, self.beta) 89 | 90 | def get_config(self): 91 | config = {'epsilon': self.epsilon, 92 | 'axis': self.axis, 93 | 'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None, 94 | 
'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None 95 | } 96 | base_config = super(LayerNormalization, self).get_config() 97 | return dict(list(base_config.items()) + list(config.items())) 98 | 99 | 100 | class ComplexLayerNorm(Layer): 101 | def __init__(self, 102 | epsilon=1e-4, 103 | axis=-1, 104 | center=True, 105 | scale=True, 106 | beta_initializer='zeros', 107 | gamma_diag_initializer=sqrt_init, 108 | gamma_off_initializer='zeros', 109 | beta_regularizer=None, 110 | gamma_diag_regularizer=None, 111 | gamma_off_regularizer=None, 112 | beta_constraint=None, 113 | gamma_diag_constraint=None, 114 | gamma_off_constraint=None, 115 | **kwargs): 116 | 117 | self.supports_masking = True 118 | self.epsilon = epsilon 119 | self.axis = axis 120 | self.center = center 121 | self.scale = scale 122 | self.beta_initializer = initializers.get(beta_initializer) 123 | self.gamma_diag_initializer = initializers.get(gamma_diag_initializer) 124 | self.gamma_off_initializer = initializers.get(gamma_off_initializer) 125 | self.beta_regularizer = regularizers.get(beta_regularizer) 126 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 127 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 128 | self.beta_constraint = constraints.get(beta_constraint) 129 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 130 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 131 | super(ComplexLayerNorm, self).__init__(**kwargs) 132 | 133 | def build(self, input_shape): 134 | 135 | ndim = len(input_shape) 136 | dim = input_shape[self.axis] 137 | if dim is None: 138 | raise ValueError('Axis ' + str(self.axis) + ' of ' 139 | 'input tensor should have a defined dimension ' 140 | 'but the layer received an input with shape ' + 141 | str(input_shape) + '.') 142 | self.input_spec = InputSpec(ndim=len(input_shape), 143 | axes={self.axis: dim}) 144 | 145 | gamma_shape = (input_shape[self.axis] 
// 2,) 146 | if self.scale: 147 | self.gamma_rr = self.add_weight( 148 | shape=gamma_shape, 149 | name='gamma_rr', 150 | initializer=self.gamma_diag_initializer, 151 | regularizer=self.gamma_diag_regularizer, 152 | constraint=self.gamma_diag_constraint 153 | ) 154 | self.gamma_ii = self.add_weight( 155 | shape=gamma_shape, 156 | name='gamma_ii', 157 | initializer=self.gamma_diag_initializer, 158 | regularizer=self.gamma_diag_regularizer, 159 | constraint=self.gamma_diag_constraint 160 | ) 161 | self.gamma_ri = self.add_weight( 162 | shape=gamma_shape, 163 | name='gamma_ri', 164 | initializer=self.gamma_off_initializer, 165 | regularizer=self.gamma_off_regularizer, 166 | constraint=self.gamma_off_constraint 167 | ) 168 | else: 169 | self.gamma_rr = None 170 | self.gamma_ii = None 171 | self.gamma_ri = None 172 | 173 | if self.center: 174 | self.beta = self.add_weight(shape=(input_shape[self.axis],), 175 | name='beta', 176 | initializer=self.beta_initializer, 177 | regularizer=self.beta_regularizer, 178 | constraint=self.beta_constraint) 179 | else: 180 | self.beta = None 181 | 182 | self.built = True 183 | 184 | def call(self, inputs): 185 | input_shape = K.shape(inputs) 186 | ndim = K.ndim(inputs) 187 | reduction_axes = list(range(ndim)) 188 | del reduction_axes[self.axis] 189 | del reduction_axes[0] 190 | input_dim = input_shape[self.axis] // 2 191 | mu = K.mean(inputs, axis=reduction_axes) 192 | broadcast_mu_shape = [1] * ndim 193 | broadcast_mu_shape[self.axis] = input_shape[self.axis] 194 | broadcast_mu_shape[0] = K.shape(inputs)[0] 195 | broadcast_mu = K.reshape(mu, broadcast_mu_shape) 196 | if self.center: 197 | input_centred = inputs - broadcast_mu 198 | else: 199 | input_centred = inputs 200 | centred_squared = input_centred ** 2 201 | if (self.axis == 1 and ndim != 3) or ndim == 2: 202 | centred_squared_real = centred_squared[:, :input_dim] 203 | centred_squared_imag = centred_squared[:, input_dim:] 204 | centred_real = input_centred[:, :input_dim] 205 | 
centred_imag = input_centred[:, input_dim:] 206 | elif ndim == 3: 207 | centred_squared_real = centred_squared[:, :, :input_dim] 208 | centred_squared_imag = centred_squared[:, :, input_dim:] 209 | centred_real = input_centred[:, :, :input_dim] 210 | centred_imag = input_centred[:, :, input_dim:] 211 | elif self.axis == -1 and ndim == 4: 212 | centred_squared_real = centred_squared[:, :, :, :input_dim] 213 | centred_squared_imag = centred_squared[:, :, :, input_dim:] 214 | centred_real = input_centred[:, :, :, :input_dim] 215 | centred_imag = input_centred[:, :, :, input_dim:] 216 | elif self.axis == -1 and ndim == 5: 217 | centred_squared_real = centred_squared[:, :, :, :, :input_dim] 218 | centred_squared_imag = centred_squared[:, :, :, :, input_dim:] 219 | centred_real = input_centred[:, :, :, :, :input_dim] 220 | centred_imag = input_centred[:, :, :, :, input_dim:] 221 | else: 222 | raise ValueError( 223 | 'Incorrect Layernorm combination of axis and dimensions. axis should be either 1 or -1. ' 224 | 'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.' 225 | ) 226 | if self.scale: 227 | Vrr = K.mean( 228 | centred_squared_real, 229 | axis=reduction_axes 230 | ) + self.epsilon 231 | Vii = K.mean( 232 | centred_squared_imag, 233 | axis=reduction_axes 234 | ) + self.epsilon 235 | # Vri contains the real and imaginary covariance for each feature map. 236 | Vri = K.mean( 237 | centred_real * centred_imag, 238 | axis=reduction_axes, 239 | ) 240 | elif self.center: 241 | Vrr = None 242 | Vii = None 243 | Vri = None 244 | else: 245 | raise ValueError('Error. 
Both scale and center in batchnorm are set to False.') 246 | 247 | return complex_normalization( 248 | input_centred, Vrr, Vii, Vri, 249 | self.beta, self.gamma_rr, self.gamma_ri, 250 | self.gamma_ii, self.scale, self.center, 251 | layernorm=True, axis=self.axis 252 | ) 253 | 254 | def get_config(self): 255 | config = { 256 | 'axis': self.axis, 257 | 'epsilon': self.epsilon, 258 | 'center': self.center, 259 | 'scale': self.scale, 260 | 'beta_initializer': initializers.serialize(self.beta_initializer), 261 | 'gamma_diag_initializer': initializers.serialize(self.gamma_diag_initializer), 262 | 'gamma_off_initializer': initializers.serialize(self.gamma_off_initializer), 263 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 264 | 'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer), 265 | 'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer), 266 | 'beta_constraint': constraints.serialize(self.beta_constraint), 267 | 'gamma_diag_constraint': constraints.serialize(self.gamma_diag_constraint), 268 | 'gamma_off_constraint': constraints.serialize(self.gamma_off_constraint), 269 | } 270 | base_config = super(ComplexLayerNorm, self).get_config() 271 | return dict(list(base_config.items()) + list(config.items())) 272 | -------------------------------------------------------------------------------- /complex_layers/pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Olexa Bilaniuk 6 | # 7 | 8 | import keras.backend as KB 9 | import keras.engine as KE 10 | import keras.layers as KL 11 | import keras.optimizers as KO 12 | import theano as T 13 | import theano.ifelse as TI 14 | import theano.tensor as TT 15 | import theano.tensor.fft as TTF 16 | import numpy as np 17 | 18 | 19 | # 20 | # Spectral Pooling Layer 21 | # 22 | 23 | class SpectralPooling1D(KL.Layer): 24 | def __init__(self, topf=(0,)): 25 | 
super(SpectralPooling1D, self).__init__() 26 | if "topf" in kwargs: 27 | self.topf = (int (kwargs["topf" ][0]),) 28 | self.topf = (self.topf[0]//2,) 29 | elif "gamma" in kwargs: 30 | self.gamma = (float(kwargs["gamma"][0]),) 31 | self.gamma = (self.gamma[0]/2,) 32 | else: 33 | raise RuntimeError("Must provide either topf= or gamma= !") 34 | def call(self, x, mask=None): 35 | xshape = x._keras_shape 36 | if hasattr(self, "topf"): 37 | topf = self.topf 38 | else: 39 | if KB.image_data_format() == "channels_first": 40 | topf = (int(self.gamma[0]*xshape[2]),) 41 | else: 42 | topf = (int(self.gamma[0]*xshape[1]),) 43 | 44 | if KB.image_data_format() == "channels_first": 45 | if topf[0] > 0 and xshape[2] >= 2*topf[0]: 46 | mask = [1]*(topf[0] ) +\ 47 | [0]*(xshape[2] - 2*topf[0]) +\ 48 | [1]*(topf[0] ) 49 | mask = [[mask]] 50 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,2)) 51 | mask = KB.constant(mask) 52 | x *= mask 53 | else: 54 | if topf[0] > 0 and xshape[1] >= 2*topf[0]: 55 | mask = [1]*(topf[0] ) +\ 56 | [0]*(xshape[1] - 2*topf[0]) +\ 57 | [1]*(topf[0] ) 58 | mask = [[mask]] 59 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,2,1)) 60 | mask = KB.constant(mask) 61 | x *= mask 62 | 63 | return x 64 | class SpectralPooling2D(KL.Layer): 65 | def __init__(self, **kwargs): 66 | super(SpectralPooling2D, self).__init__() 67 | if "topf" in kwargs: 68 | self.topf = (int (kwargs["topf" ][0]), int (kwargs["topf" ][1])) 69 | self.topf = (self.topf[0]//2, self.topf[1]//2) 70 | elif "gamma" in kwargs: 71 | self.gamma = (float(kwargs["gamma"][0]), float(kwargs["gamma"][1])) 72 | self.gamma = (self.gamma[0]/2, self.gamma[1]/2) 73 | else: 74 | raise RuntimeError("Must provide either topf= or gamma= !") 75 | def call(self, x, mask=None): 76 | xshape = x._keras_shape 77 | if hasattr(self, "topf"): 78 | topf = self.topf 79 | else: 80 | if KB.image_data_format() == "channels_first": 81 | topf = (int(self.gamma[0]*xshape[2]), int(self.gamma[1]*xshape[3])) 82 | 
else: 83 | topf = (int(self.gamma[0]*xshape[1]), int(self.gamma[1]*xshape[2])) 84 | 85 | if KB.image_data_format() == "channels_first": 86 | if topf[0] > 0 and xshape[2] >= 2*topf[0]: 87 | mask = [1]*(topf[0] ) +\ 88 | [0]*(xshape[2] - 2*topf[0]) +\ 89 | [1]*(topf[0] ) 90 | mask = [[[mask]]] 91 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,3,2)) 92 | mask = KB.constant(mask) 93 | x *= mask 94 | if topf[1] > 0 and xshape[3] >= 2*topf[1]: 95 | mask = [1]*(topf[1] ) +\ 96 | [0]*(xshape[3] - 2*topf[1]) +\ 97 | [1]*(topf[1] ) 98 | mask = [[[mask]]] 99 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,2,3)) 100 | mask = KB.constant(mask) 101 | x *= mask 102 | else: 103 | if topf[0] > 0 and xshape[1] >= 2*topf[0]: 104 | mask = [1]*(topf[0] ) +\ 105 | [0]*(xshape[1] - 2*topf[0]) +\ 106 | [1]*(topf[0] ) 107 | mask = [[[mask]]] 108 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,3,1,2)) 109 | mask = KB.constant(mask) 110 | x *= mask 111 | if topf[1] > 0 and xshape[2] >= 2*topf[1]: 112 | mask = [1]*(topf[1] ) +\ 113 | [0]*(xshape[2] - 2*topf[1]) +\ 114 | [1]*(topf[1] ) 115 | mask = [[[mask]]] 116 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,3,2)) 117 | mask = KB.constant(mask) 118 | x *= mask 119 | 120 | return x 121 | 122 | 123 | if __name__ == "__main__": 124 | import cv2, sys 125 | import __main__ as SP 126 | import fft as CF 127 | 128 | # Build Model 129 | x = i = KL.Input(shape=(6,512,512)) 130 | f = CF.FFT2()(x) 131 | p = SP.SpectralPooling2D(gamma=[0.15,0.15])(f) 132 | o = CF.IFFT2()(p) 133 | 134 | model = KE.Model([i], [f,p,o]) 135 | model.compile("sgd", "mse") 136 | 137 | # Use it 138 | img = cv2.imread(sys.argv[1]) 139 | imgBatch = img[np.newaxis,...].transpose((0,3,1,2)) 140 | imgBatch = np.concatenate([imgBatch, np.zeros_like(imgBatch)], axis=1) 141 | f,p,o = model.predict(imgBatch) 142 | ffted = np.sqrt(np.sum(f[:,:3]**2 + f[:,3:]**2, axis=1)) 143 | ffted = ffted .transpose((1,2,0))/255 144 | pooled = 
np.sqrt(np.sum(p[:,:3]**2 + p[:,3:]**2, axis=1)) 145 | pooled = pooled.transpose((1,2,0))/255 146 | filtered = np.clip(o,0,255).transpose((0,2,3,1))[0,:,:,:3].astype("uint8") 147 | 148 | # Display it 149 | cv2.imshow("Original", img) 150 | cv2.imshow("FFT", ffted) 151 | cv2.imshow("Pooled", pooled) 152 | cv2.imshow("Filtered", filtered) 153 | cv2.waitKey(0) 154 | -------------------------------------------------------------------------------- /complex_layers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Dmitriy Serdyuk, Olexa Bilaniuk, Chiheb Trabelsi 6 | 7 | import keras.backend as K 8 | from keras.layers import Layer, Lambda 9 | 10 | # 11 | # GetReal/GetImag Lambda layer Implementation 12 | # 13 | 14 | 15 | def get_realpart(x): 16 | image_format = K.image_data_format() 17 | ndim = K.ndim(x) 18 | input_shape = K.shape(x) 19 | 20 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 21 | input_dim = input_shape[1] // 2 22 | return x[:, :input_dim] 23 | 24 | input_dim = input_shape[-1] // 2 25 | if ndim == 3: 26 | return x[:, :, :input_dim] 27 | elif ndim == 4: 28 | return x[:, :, :, :input_dim] 29 | elif ndim == 5: 30 | return x[:, :, :, :, :input_dim] 31 | 32 | 33 | def get_imagpart(x): 34 | image_format = K.image_data_format() 35 | ndim = K.ndim(x) 36 | input_shape = K.shape(x) 37 | 38 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 39 | input_dim = input_shape[1] // 2 40 | return x[:, input_dim:] 41 | 42 | input_dim = input_shape[-1] // 2 43 | if ndim == 3: 44 | return x[:, :, input_dim:] 45 | elif ndim == 4: 46 | return x[:, :, :, input_dim:] 47 | elif ndim == 5: 48 | return x[:, :, :, :, input_dim:] 49 | 50 | 51 | def get_abs(x): 52 | real = get_realpart(x) 53 | imag = get_imagpart(x) 54 | 55 | return K.sqrt(real * real + imag * imag) 56 | 57 | 58 | def getpart_output_shape(input_shape): 59 | 
returned_shape = list(input_shape[:]) 60 | image_format = K.image_data_format() 61 | ndim = len(returned_shape) 62 | 63 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 64 | axis = 1 65 | else: 66 | axis = -1 67 | 68 | returned_shape[axis] = returned_shape[axis] // 2 69 | 70 | return tuple(returned_shape) 71 | 72 | 73 | class GetReal(Layer): 74 | def call(self, inputs): 75 | return get_realpart(inputs) 76 | def compute_output_shape(self, input_shape): 77 | return getpart_output_shape(input_shape) 78 | class GetImag(Layer): 79 | def call(self, inputs): 80 | return get_imagpart(inputs) 81 | def compute_output_shape(self, input_shape): 82 | return getpart_output_shape(input_shape) 83 | class GetAbs(Layer): 84 | def call(self, inputs): 85 | return get_abs(inputs) 86 | def compute_output_shape(self, input_shape): 87 | return getpart_output_shape(input_shape) 88 | 89 | -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \citation{ioffe2015batch} 3 | \citation{srivastava2015training} 4 | \citation{he2016deep} 5 | \citation{clevert2015fast} 6 | \citation{hochreiter1991untersuchungen} 7 | \citation{arjovsky2016unitary} 8 | \citation{danihelka2016associative} 9 | \citation{wisdom2016full} 10 | \citation{trabelsi2017deep} 11 | \citation{bulow1999hypercomplex} 12 | \citation{sangwine2000colour} 13 | \citation{bulow2001hypercomplex} 14 | \citation{rishiyur2006neural} 15 | \citation{kendall2015posenet} 16 | \citation{parcollet2016quaternion} 17 | \citation{minemoto2017feed} 18 | \@writefile{toc}{\contentsline {section}{\numberline {1}Introduction}{1}} 19 | \citation{trabelsi2017deep} 20 | \citation{kendall2015posenet} 21 | \citation{oppenheim1981importance} 22 | \citation{bulow2001hypercomplex} 23 | \citation{sangwine2000colour} 24 | \citation{shi2007quaternion} 25 | \@writefile{toc}{\contentsline 
{section}{\numberline {2}Motivation and Related Work}{2}} 26 | \citation{minemoto2017feed} 27 | \citation{trabelsi2017deep} 28 | \citation{bulow1999hypercomplex} 29 | \citation{sangwine2000colour} 30 | \citation{bulow2001hypercomplex} 31 | \@writefile{toc}{\contentsline {section}{\numberline {3}Quaternion Network Components}{3}} 32 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.1}Quaternion Representation}{3}} 33 | \newlabel{eq:quaternion1}{{1}{3}} 34 | \newlabel{eq:quarternion2}{{2}{3}} 35 | \newlabel{eq:quarternion3}{{3}{3}} 36 | \citation{trabelsi2017deep} 37 | \newlabel{eq:m4r}{{4}{4}} 38 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.2}Quaternion Differentiability}{4}} 39 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.3}Quaternion Convolution}{4}} 40 | \newlabel{s:qc}{{3.3}{4}} 41 | \newlabel{eq:convolve1}{{5}{4}} 42 | \citation{chollet2016xception} 43 | \citation{shi2007quaternion} 44 | \citation{ioffe2015batch} 45 | \citation{trabelsi2017deep} 46 | \newlabel{eq:}{{6}{5}} 47 | \@writefile{lof}{\contentsline {figure}{\numberline {1}{\ignorespaces An illustration of quaternion convolution.}}{5}} 48 | \newlabel{f:quatconv}{{1}{5}} 49 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.4}Quaternion Batch-Normalization}{6}} 50 | \newlabel{eq:white4d}{{7}{6}} 51 | \newlabel{eq:V4d}{{8}{6}} 52 | \newlabel{eq:gamma}{{9}{6}} 53 | \citation{glorot2010understanding} 54 | \citation{he2015delving} 55 | \citation{turner2002} 56 | \citation{glorot2010understanding} 57 | \citation{he2015delving} 58 | \citation{nair2010rectified} 59 | \newlabel{eq:qbn}{{10}{7}} 60 | \@writefile{toc}{\contentsline {subsection}{\numberline {3.5}Quaternion Weight Initialization}{7}} 61 | \newlabel{eq:quaternion_weight}{{11}{7}} 62 | \newlabel{eq:variance}{{12}{7}} 63 | \newlabel{eq:expected}{{13}{7}} 64 | \newlabel{eq:variance_sigma}{{14}{7}} 65 | \citation{trabelsi2017deep} 66 | \citation{he2016deep} 67 | \citation{trabelsi2017deep} 68 | 
\citation{nesterov1983method} 69 | \citation{trabelsi2017deep} 70 | \citation{he2016deep} 71 | \citation{trabelsi2017deep} 72 | \citation{he2016deep} 73 | \citation{trabelsi2017deep} 74 | \citation{he2016deep} 75 | \citation{trabelsi2017deep} 76 | \@writefile{toc}{\contentsline {section}{\numberline {4}Experimental Results}{8}} 77 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.1}Classification}{8}} 78 | \@writefile{lot}{\contentsline {table}{\numberline {1}{\ignorespaces Classification error on CIFAR-10 and CIFAR-100. Note that \cite {he2016deep} is a 110 layer residual network, \cite {trabelsi2017deep} is 118 layer complex network with the same design as the prior except additional initial layers to extract complex mappings.}}{9}} 79 | \newlabel{t:results1}{{1}{9}} 80 | \@writefile{toc}{\contentsline {subsection}{\numberline {4.2}Segmentation}{9}} 81 | \@writefile{lot}{\contentsline {table}{\numberline {2}{\ignorespaces IOU on KITTI Road Estimation benchmark.}}{9}} 82 | \newlabel{t:results2}{{2}{9}} 83 | \@writefile{toc}{\contentsline {section}{\numberline {5}Conclusions}{9}} 84 | \@writefile{toc}{\contentsline {section}{\numberline {6}Acknowledgments}{10}} 85 | \bibdata{bib} 86 | \bibcite{arjovsky2016unitary}{ASB16} 87 | \bibcite{bulow2001hypercomplex}{BS01} 88 | \bibcite{bulow1999hypercomplex}{B{\"u}l99} 89 | \bibcite{chollet2016xception}{Cho16} 90 | \bibcite{clevert2015fast}{CUH15} 91 | \bibcite{danihelka2016associative}{DWU{$^{+}$}16} 92 | \bibcite{glorot2010understanding}{GB10} 93 | \bibcite{hochreiter1991untersuchungen}{Hoc91} 94 | \bibcite{he2015delving}{HZRS15} 95 | \bibcite{he2016deep}{HZRS16} 96 | \bibcite{ioffe2015batch}{IS15} 97 | \bibcite{kendall2015posenet}{KGC15} 98 | \bibcite{kessy2017optimal}{KLS17} 99 | \bibcite{minemoto2017feed}{MINM17} 100 | \bibcite{nesterov1983method}{Nes} 101 | \bibcite{nair2010rectified}{NH10} 102 | \bibcite{oppenheim1981importance}{OL81} 103 | \bibcite{parcollet2016quaternion}{PMB{$^{+}$}16} 104 | 
\bibcite{turner2002}{PT02} 105 | \bibcite{rishiyur2006neural}{Ris06} 106 | \bibcite{sangwine2000colour}{SE00} 107 | \bibcite{shi2007quaternion}{SF07} 108 | \bibcite{srivastava2015training}{SGS15} 109 | \bibcite{trabelsi2017deep}{TBS{$^{+}$}17} 110 | \bibcite{wisdom2016full}{WPH{$^{+}$}16} 111 | \bibstyle{alpha} 112 | \@writefile{toc}{\contentsline {section}{\numberline {7}Appendix}{13}} 113 | \@writefile{toc}{\contentsline {subsection}{\numberline {7.1}The Generalized Quaternion Chain Rule for a Real-Valued Function}{13}} 114 | \newlabel{a:diff}{{7.1}{13}} 115 | \citation{kessy2017optimal} 116 | \newlabel{eq:diff1}{{16}{14}} 117 | \@writefile{toc}{\contentsline {subsection}{\numberline {7.2}Whitening a Matrix}{14}} 118 | \newlabel{a:whitening}{{7.2}{14}} 119 | \newlabel{eq:white1}{{17}{15}} 120 | \newlabel{eq:white2}{{18}{15}} 121 | \@writefile{toc}{\contentsline {subsection}{\numberline {7.3}Cholesky Decomposition}{15}} 122 | \newlabel{eq:cholesky1}{{19}{15}} 123 | \@writefile{toc}{\contentsline {subsection}{\numberline {7.4}4 DOF Independent Normal Distribution}{15}} 124 | \newlabel{eq:single_dists}{{20}{15}} 125 | \newlabel{eq:cumdist}{{21}{15}} 126 | \newlabel{eq:4dsphere}{{22}{16}} 127 | \newlabel{eq:polarint}{{23}{16}} 128 | \newlabel{eq:finaldist}{{24}{16}} 129 | -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.bbl: -------------------------------------------------------------------------------- 1 | \newcommand{\etalchar}[1]{$^{#1}$} 2 | \begin{thebibliography}{DWU{\etalchar{+}}16} 3 | 4 | \bibitem[ASB16]{arjovsky2016unitary} 5 | Martin Arjovsky, Amar Shah, and Yoshua Bengio. 6 | \newblock Unitary evolution recurrent neural networks. 7 | \newblock In {\em International Conference on Machine Learning}, pages 8 | 1120--1128, 2016. 9 | 10 | \bibitem[BS01]{bulow2001hypercomplex} 11 | Thomas Bulow and Gerald Sommer. 
12 | \newblock Hypercomplex signals---a novel extension of the analytic signal to the 13 | multidimensional case. 14 | \newblock {\em IEEE Transactions on Signal Processing}, 49(11):2844--2852, 15 | 2001. 16 | 17 | \bibitem[B{\"u}l99]{bulow1999hypercomplex} 18 | Thomas B{\"u}low. 19 | \newblock {\em Hypercomplex spectral signal representations for the processing 20 | and analysis of images}. 21 | \newblock Universit{\"a}t Kiel. Institut f{\"u}r Informatik und Praktische 22 | Mathematik, 1999. 23 | 24 | \bibitem[Cho16]{chollet2016xception} 25 | Fran{\c{c}}ois Chollet. 26 | \newblock Xception: Deep learning with depthwise separable convolutions. 27 | \newblock {\em arXiv preprint arXiv:1610.02357}, 2016. 28 | 29 | \bibitem[CUH15]{clevert2015fast} 30 | Djork-Arn{\'e} Clevert, Thomas Unterthiner, and Sepp Hochreiter. 31 | \newblock Fast and accurate deep network learning by exponential linear units 32 | (ELUs). 33 | \newblock {\em arXiv preprint arXiv:1511.07289}, 2015. 34 | 35 | \bibitem[DWU{\etalchar{+}}16]{danihelka2016associative} 36 | Ivo Danihelka, Greg Wayne, Benigno Uria, Nal Kalchbrenner, and Alex Graves. 37 | \newblock Associative long short-term memory. 38 | \newblock {\em arXiv preprint arXiv:1602.03032}, 2016. 39 | 40 | \bibitem[GB10]{glorot2010understanding} 41 | Xavier Glorot and Yoshua Bengio. 42 | \newblock Understanding the difficulty of training deep feedforward neural 43 | networks. 44 | \newblock In {\em Proceedings of the Thirteenth International Conference on 45 | Artificial Intelligence and Statistics}, pages 249--256, 2010. 46 | 47 | \bibitem[Hoc91]{hochreiter1991untersuchungen} 48 | Sepp Hochreiter. 49 | \newblock Untersuchungen zu dynamischen neuronalen Netzen. 50 | \newblock {\em Diploma, Technische Universit{\"a}t M{\"u}nchen}, 91, 1991. 51 | 52 | \bibitem[HZRS15]{he2015delving} 53 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
54 | \newblock Delving deep into rectifiers: Surpassing human-level performance on 55 | ImageNet classification. 56 | \newblock In {\em Proceedings of the IEEE International Conference on Computer 57 | Vision}, pages 1026--1034, 2015. 58 | 59 | \bibitem[HZRS16]{he2016deep} 60 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 61 | \newblock Deep residual learning for image recognition. 62 | \newblock In {\em Proceedings of the IEEE Conference on Computer Vision and 63 | Pattern Recognition}, pages 770--778, 2016. 64 | 65 | \bibitem[IS15]{ioffe2015batch} 66 | Sergey Ioffe and Christian Szegedy. 67 | \newblock Batch normalization: Accelerating deep network training by reducing 68 | internal covariate shift. 69 | \newblock In {\em International Conference on Machine Learning}, pages 70 | 448--456, 2015. 71 | 72 | \bibitem[KGC15]{kendall2015posenet} 73 | Alex Kendall, Matthew Grimes, and Roberto Cipolla. 74 | \newblock PoseNet: A convolutional network for real-time 6-DoF camera 75 | relocalization. 76 | \newblock In {\em Proceedings of the IEEE International Conference on Computer 77 | Vision}, pages 2938--2946, 2015. 78 | 79 | \bibitem[KLS17]{kessy2017optimal} 80 | Agnan Kessy, Alex Lewin, and Korbinian Strimmer. 81 | \newblock Optimal whitening and decorrelation. 82 | \newblock {\em The American Statistician}, (just-accepted), 2017. 83 | 84 | \bibitem[MINM17]{minemoto2017feed} 85 | Toshifumi Minemoto, Teijiro Isokawa, Haruhiko Nishimura, and Nobuyuki Matsui. 86 | \newblock Feed forward neural network with random quaternionic neurons. 87 | \newblock {\em Signal Processing}, 136:59--68, 2017. 88 | 89 | \bibitem[Nes]{nesterov1983method} 90 | Yurii Nesterov. 91 | \newblock A method of solving a convex programming problem with convergence 92 | rate $O(1/k^2)$. 93 | 94 | \bibitem[NH10]{nair2010rectified} 95 | Vinod Nair and Geoffrey~E Hinton. 96 | \newblock Rectified linear units improve restricted Boltzmann machines.
97 | \newblock In {\em Proceedings of the 27th International Conference on Machine 98 | Learning (ICML-10)}, pages 807--814, 2010. 99 | 100 | \bibitem[OL81]{oppenheim1981importance} 101 | Alan~V Oppenheim and Jae~S Lim. 102 | \newblock The importance of phase in signals. 103 | \newblock {\em Proceedings of the IEEE}, 69(5):529--541, 1981. 104 | 105 | \bibitem[PMB{\etalchar{+}}16]{parcollet2016quaternion} 106 | Titouan Parcollet, Mohamed Morchid, Pierre-Michel Bousquet, Richard Dufour, 107 | Georges Linar{\`e}s, and Renato De~Mori. 108 | \newblock Quaternion neural networks for spoken language understanding. 109 | \newblock In {\em Spoken Language Technology Workshop (SLT), 2016 IEEE}, pages 110 | 362--368. IEEE, 2016. 111 | 112 | \bibitem[PT02]{turner2002} 113 | Robert Piziak and Danny Turner. 114 | \newblock The polar form of a quaternion. 115 | \newblock 2002. 116 | 117 | \bibitem[Ris06]{rishiyur2006neural} 118 | Adityan Rishiyur. 119 | \newblock Neural networks with complex and quaternion inputs. 120 | \newblock {\em arXiv preprint cs/0607090}, 2006. 121 | 122 | \bibitem[SE00]{sangwine2000colour} 123 | Stephen~J Sangwine and Todd~A Ell. 124 | \newblock Colour image filters based on hypercomplex convolution. 125 | \newblock {\em IEE Proceedings-Vision, Image and Signal Processing}, 126 | 147(2):89--93, 2000. 127 | 128 | \bibitem[SF07]{shi2007quaternion} 129 | Lilong Shi and Brian Funt. 130 | \newblock Quaternion color texture segmentation. 131 | \newblock {\em Computer Vision and Image Understanding}, 107(1):88--96, 2007. 132 | 133 | \bibitem[SGS15]{srivastava2015training} 134 | Rupesh~K Srivastava, Klaus Greff, and J{\"u}rgen Schmidhuber. 135 | \newblock Training very deep networks. 136 | \newblock In {\em Advances in Neural Information Processing Systems}, pages 137 | 2377--2385, 2015.
138 | 139 | \bibitem[TBS{\etalchar{+}}17]{trabelsi2017deep} 140 | Chiheb Trabelsi, Olexa Bilaniuk, Dmitriy Serdyuk, Sandeep Subramanian, 141 | Jo{\~a}o~Felipe Santos, Soroush Mehri, Negar Rostamzadeh, Yoshua Bengio, and 142 | Christopher~J Pal. 143 | \newblock Deep complex networks. 144 | \newblock {\em arXiv preprint arXiv:1705.09792}, 2017. 145 | 146 | \bibitem[WPH{\etalchar{+}}16]{wisdom2016full} 147 | Scott Wisdom, Thomas Powers, John Hershey, Jonathan Le~Roux, and Les Atlas. 148 | \newblock Full-capacity unitary recurrent neural networks. 149 | \newblock In {\em Advances in Neural Information Processing Systems}, pages 150 | 4880--4888, 2016. 151 | 152 | \end{thebibliography} 153 | -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.blg: -------------------------------------------------------------------------------- 1 | This is BibTeX, Version 0.99dThe top-level auxiliary file: DeepQuaternionNets.aux 2 | The style file: alpha.bst 3 | Database file #1: bib.bib 4 | Warning--there's a number but no volume in kessy2017optimal 5 | Warning--empty booktitle in nesterov1983method 6 | Warning--empty year in nesterov1983method 7 | Warning--empty booktitle in turner2002 8 | (There were 4 warnings) 9 | -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.log: -------------------------------------------------------------------------------- 1 | This is pdfTeX, Version 3.14159265-2.6-1.40.17 (MiKTeX 2.9) (preloaded format=pdflatex 2016.6.6) 26 JAN 2018 22:01 2 | entering extended mode 3 | **./DeepQuaternionNets.tex 4 | (DeepQuaternionNets.tex 5 | LaTeX2e <2016/03/31> 6 | Babel <3.9r> and hyphenation patterns for 75 language(s) loaded. 
7 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\base\article.cls" 8 | Document Class: article 2014/09/29 v1.4h Standard LaTeX document class 9 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\base\size10.clo" 10 | File: size10.clo 2014/09/29 v1.4h Standard LaTeX file (size option) 11 | ) 12 | \c@part=\count79 13 | \c@section=\count80 14 | \c@subsection=\count81 15 | \c@subsubsection=\count82 16 | \c@paragraph=\count83 17 | \c@subparagraph=\count84 18 | \c@figure=\count85 19 | \c@table=\count86 20 | \abovecaptionskip=\skip41 21 | \belowcaptionskip=\skip42 22 | \bibindent=\dimen102 23 | ) 24 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\tools\array.sty" 25 | Package: array 2014/10/28 v2.4c Tabular extension package (FMi) 26 | \col@sep=\dimen103 27 | \extrarowheight=\dimen104 28 | \NC@list=\toks14 29 | \extratabsurround=\skip43 30 | \backup@length=\skip44 31 | ) 32 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsmath\amsmath.sty" 33 | Package: amsmath 2016/03/10 v2.15b AMS math features 34 | \@mathmargin=\skip45 35 | 36 | For additional information on amsmath, use the `?' option. 37 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsmath\amstext.sty" 38 | Package: amstext 2000/06/29 v2.01 AMS text 39 | 40 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsmath\amsgen.sty" 41 | File: amsgen.sty 1999/11/30 v2.0 generic functions 42 | \@emptytoks=\toks15 43 | \ex@=\dimen105 44 | )) 45 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsmath\amsbsy.sty" 46 | Package: amsbsy 1999/11/29 v1.2d Bold Symbols 47 | \pmbraise@=\dimen106 48 | ) 49 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsmath\amsopn.sty" 50 | Package: amsopn 2016/03/08 v2.02 operator names 51 | ) 52 | \inf@bad=\count87 53 | LaTeX Info: Redefining \frac on input line 199. 54 | \uproot@=\count88 55 | \leftroot@=\count89 56 | LaTeX Info: Redefining \overline on input line 297. 57 | \classnum@=\count90 58 | \DOTSCASE@=\count91 59 | LaTeX Info: Redefining \ldots on input line 394. 
60 | LaTeX Info: Redefining \dots on input line 397. 61 | LaTeX Info: Redefining \cdots on input line 518. 62 | \Mathstrutbox@=\box26 63 | \strutbox@=\box27 64 | \big@size=\dimen107 65 | LaTeX Font Info: Redeclaring font encoding OML on input line 634. 66 | LaTeX Font Info: Redeclaring font encoding OMS on input line 635. 67 | \macc@depth=\count92 68 | \c@MaxMatrixCols=\count93 69 | \dotsspace@=\muskip10 70 | \c@parentequation=\count94 71 | \dspbrk@lvl=\count95 72 | \tag@help=\toks16 73 | \row@=\count96 74 | \column@=\count97 75 | \maxfields@=\count98 76 | \andhelp@=\toks17 77 | \eqnshift@=\dimen108 78 | \alignsep@=\dimen109 79 | \tagshift@=\dimen110 80 | \tagwidth@=\dimen111 81 | \totwidth@=\dimen112 82 | \lineht@=\dimen113 83 | \@envbody=\toks18 84 | \multlinegap=\skip46 85 | \multlinetaggap=\skip47 86 | \mathdisplay@stack=\toks19 87 | LaTeX Info: Redefining \[ on input line 2739. 88 | LaTeX Info: Redefining \] on input line 2740. 89 | ) 90 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsfonts\amsfonts.sty" 91 | Package: amsfonts 2013/01/14 v3.01 Basic AMSFonts support 92 | \symAMSa=\mathgroup4 93 | \symAMSb=\mathgroup5 94 | LaTeX Font Info: Overwriting math alphabet `\mathfrak' in version `bold' 95 | (Font) U/euf/m/n --> U/euf/b/n on input line 106. 
96 | ) 97 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsfonts\amssymb.sty" 98 | Package: amssymb 2013/01/14 v3.01 AMS font symbols 99 | ) 100 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\jknappen\mathrsfs.sty" 101 | Package: mathrsfs 1996/01/01 Math RSFS package v1.0 (jk) 102 | \symrsfs=\mathgroup6 103 | ) 104 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\graphics\graphicx.sty" 105 | Package: graphicx 2014/10/28 v1.0g Enhanced LaTeX Graphics (DPC,SPQR) 106 | 107 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\graphics\keyval.sty" 108 | Package: keyval 2014/10/28 v1.15 key=value parser (DPC) 109 | \KV@toks@=\toks20 110 | ) 111 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\graphics\graphics.sty" 112 | Package: graphics 2016/05/09 v1.0r Standard LaTeX Graphics (DPC,SPQR) 113 | 114 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\graphics\trig.sty" 115 | Package: trig 2016/01/03 v1.10 sin cos tan (DPC) 116 | ) 117 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\00miktex\graphics.cfg" 118 | File: graphics.cfg 2016/01/02 v1.10 sample graphics configuration 119 | ) 120 | Package graphics Info: Driver file: pdftex.def on input line 96. 121 | 122 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\pdftex-def\pdftex.def" 123 | File: pdftex.def 2011/05/27 v0.06d Graphics/color for pdfTeX 124 | 125 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\infwarerr.sty" 126 | Package: infwarerr 2016/05/16 v1.4 Providing info/warning/error messages (HO) 127 | ) 128 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\ltxcmds.sty" 129 | Package: ltxcmds 2016/05/16 v1.23 LaTeX kernel commands for general use (HO) 130 | ) 131 | \Gread@gobject=\count99 132 | )) 133 | \Gin@req@height=\dimen114 134 | \Gin@req@width=\dimen115 135 | ) 136 | 137 | LaTeX Warning: Unused global option(s): 138 | [14pt]. 139 | 140 | (DeepQuaternionNets.aux) 141 | \openout1 = `DeepQuaternionNets.aux'. 142 | 143 | LaTeX Font Info: Checking defaults for OML/cmm/m/it on input line 12. 
144 | LaTeX Font Info: ... okay on input line 12. 145 | LaTeX Font Info: Checking defaults for T1/cmr/m/n on input line 12. 146 | LaTeX Font Info: ... okay on input line 12. 147 | LaTeX Font Info: Checking defaults for OT1/cmr/m/n on input line 12. 148 | LaTeX Font Info: ... okay on input line 12. 149 | LaTeX Font Info: Checking defaults for OMS/cmsy/m/n on input line 12. 150 | LaTeX Font Info: ... okay on input line 12. 151 | LaTeX Font Info: Checking defaults for OMX/cmex/m/n on input line 12. 152 | LaTeX Font Info: ... okay on input line 12. 153 | LaTeX Font Info: Checking defaults for U/cmr/m/n on input line 12. 154 | LaTeX Font Info: ... okay on input line 12. 155 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\context\base\supp-pdf.mkii" 156 | [Loading MPS to PDF converter (version 2006.09.02).] 157 | \scratchcounter=\count100 158 | \scratchdimen=\dimen116 159 | \scratchbox=\box28 160 | \nofMPsegments=\count101 161 | \nofMParguments=\count102 162 | \everyMPshowfont=\toks21 163 | \MPscratchCnt=\count103 164 | \MPscratchDim=\dimen117 165 | \MPnumerator=\count104 166 | \makeMPintoPDFobject=\count105 167 | \everyMPtoPDFconversion=\toks22 168 | ) ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\pdftexcmds.sty" 169 | Package: pdftexcmds 2016/05/21 v0.22 Utility functions of pdfTeX for LuaTeX (HO) 170 | 171 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\ifluatex.sty" 172 | Package: ifluatex 2016/05/16 v1.4 Provides the ifluatex switch (HO) 173 | Package ifluatex Info: LuaTeX not detected. 174 | ) 175 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\ifpdf.sty" 176 | Package: ifpdf 2016/05/14 v3.1 Provides the ifpdf switch 177 | ) 178 | Package pdftexcmds Info: LuaTeX not detected. 179 | Package pdftexcmds Info: \pdf@primitive is available. 180 | Package pdftexcmds Info: \pdf@ifprimitive is available. 181 | Package pdftexcmds Info: \pdfdraftmode found. 
182 | ) 183 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\oberdiek\epstopdf-base.sty" 184 | Package: epstopdf-base 2016/05/15 v2.6 Base part for package epstopdf 185 | 186 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\oberdiek\grfext.sty" 187 | Package: grfext 2016/05/16 v1.2 Manage graphics extensions (HO) 188 | 189 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\kvdefinekeys.sty" 190 | Package: kvdefinekeys 2016/05/16 v1.4 Define keys (HO) 191 | )) 192 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\oberdiek\kvoptions.sty" 193 | Package: kvoptions 2016/05/16 v3.12 Key value format for package options (HO) 194 | 195 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\kvsetkeys.sty" 196 | Package: kvsetkeys 2016/05/16 v1.17 Key value parser (HO) 197 | 198 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\generic\oberdiek\etexcmds.sty" 199 | Package: etexcmds 2016/05/16 v1.6 Avoid name clashes with e-TeX commands (HO) 200 | Package etexcmds Info: Could not find \expanded. 201 | (etexcmds) That can mean that you are not using pdfTeX 1.50 or 202 | (etexcmds) that some package has redefined \expanded. 203 | (etexcmds) In the latter case, load this package earlier. 204 | ))) 205 | Package grfext Info: Graphics extension search list: 206 | (grfext) [.png,.pdf,.jpg,.mps,.jpeg,.jbig2,.jb2,.PNG,.PDF,.JPG,.JPEG,.JBIG2,.JB2,.eps] 207 | (grfext) \AppendGraphicsExtensions on input line 456. 208 | ) 209 | LaTeX Font Info: Try loading font information for U+msa on input line 16. 210 | 211 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsfonts\umsa.fd" 212 | File: umsa.fd 2013/01/14 v3.01 AMS symbols A 213 | ) 214 | LaTeX Font Info: Try loading font information for U+msb on input line 16. 215 | 216 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\amsfonts\umsb.fd" 217 | File: umsb.fd 2013/01/14 v3.01 AMS symbols B 218 | ) 219 | LaTeX Font Info: Try loading font information for U+rsfs on input line 16. 
220 | 221 | ("C:\Program Files (x86)\MiKTeX 2.9\tex\latex\jknappen\ursfs.fd" 222 | File: ursfs.fd 1998/03/24 rsfs font definition file (jk) 223 | ) [1 224 | 225 | {C:/ProgramData/MiKTeX/2.9/pdftex/config/pdftex.map}] [2] [3] [4] 226 | File: figures/quatconv.png Graphic file (type png) 227 | 228 | Package pdftex.def Info: figures/quatconv.png used on input line 197. 229 | (pdftex.def) Requested size: 345.0pt x 413.83757pt. 230 | [5 <./figures/quatconv.png>] 231 | 232 | LaTeX Font Warning: Command \small invalid in math mode on input line 238. 233 | 234 | 235 | LaTeX Font Warning: Command \small invalid in math mode on input line 238. 236 | 237 | 238 | Overfull \hbox (31.37411pt too wide) in paragraph at lines 238--238 239 | [] 240 | [] 241 | 242 | [6] 243 | Overfull \hbox (2.52545pt too wide) in paragraph at lines 321--324 244 | []\OT1/cmr/m/n/10 To fol-low the Glo-rot and Ben-gio [[]] ini-tial-iza-tion we have $[](\OML/cmm/m/it/10 W\OT1/cmr/m/n/1 245 | 0 ) = 2\OML/cmm/m/it/10 =\OT1/cmr/m/n/10 (\OML/cmm/m/it/10 n[] \OT1/cmr/m/n/10 + 246 | [] 247 | 248 | [7] 249 | 250 | LaTeX Warning: `h' float specifier changed to `ht'. 251 | 252 | [8] [9] [10] (DeepQuaternionNets.bbl [11 253 | 254 | ]) [12] [13 255 | 256 | ] [14] [15] [16] (DeepQuaternionNets.aux) ) 257 | Here is how much of TeX's memory you used: 258 | 2558 strings out of 493335 259 | 33958 string characters out of 3136681 260 | 108312 words of memory out of 3000000 261 | 6028 multiletter control sequences out of 15000+200000 262 | 18315 words of font info for 71 fonts, out of 3000000 for 9000 263 | 1141 hyphenation exceptions out of 8191 264 | 37i,16n,24p,454b,250s stack positions out of 5000i,500n,10000p,200000b,50000s 265 | 266 | 280 | Output written on DeepQuaternionNets.pdf (16 pages, 424893 bytes). 281 | PDF statistics: 282 | 145 PDF objects out of 1000 (max. 8388607) 283 | 0 named destinations out of 1000 (max. 500000) 284 | 6 words of extra memory for PDF output out of 10000 (max. 
10000000) 285 | 286 | -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/paper/DeepQuaternionNets.pdf -------------------------------------------------------------------------------- /paper/DeepQuaternionNets.tps: -------------------------------------------------------------------------------- 1 | [FormatInfo] 2 | Type=TeXnicCenterProjectSessionInformation 3 | Version=2 4 | 5 | [SessionInfo] 6 | FrameCount=0 7 | ActiveFrame=-1 8 | 9 | -------------------------------------------------------------------------------- /paper/bib.bib: -------------------------------------------------------------------------------- 1 | @article{srivastava2014dropout, 2 | title={Dropout: a simple way to prevent neural networks from overfitting}, 3 | author={Srivastava, Nitish and Hinton, Geoffrey E and Krizhevsky, Alex and Sutskever, Ilya and Salakhutdinov, Ruslan}, 4 | journal={Journal of Machine Learning Research}, 5 | volume={15}, 6 | number={1}, 7 | pages={1929--1958}, 8 | year={2014} 9 | } 10 | 11 | 12 | @misc{chollet2015, 13 | author = {François Chollet}, 14 | title = {keras}, 15 | year = {2015}, 16 | publisher = {GitHub}, 17 | journal = {GitHub repository}, 18 | howpublished = {https://github.com/fchollet/keras} 19 | } 20 | 21 | 22 | @article{hochreiter1991untersuchungen, 23 | title={Untersuchungen zu dynamischen neuronalen Netzen}, 24 | author={Hochreiter, Sepp}, 25 | journal={Diploma, Technische Universit{\"a}t M{\"u}nchen}, 26 | volume={91}, 27 | year={1991} 28 | } 29 | 30 | 31 | @inproceedings{srivastava2015training, 32 | title={Training very deep networks}, 33 | author={Srivastava, Rupesh K and Greff, Klaus and Schmidhuber, J{\"u}rgen}, 34 | booktitle={Advances in neural information processing systems}, 35 | pages={2377--2385}, 36 | 
year={2015} 37 | } 38 | 39 | 40 | @article{clevert2015fast, 41 | title={Fast and accurate deep network learning by exponential linear units (elus)}, 42 | author={Clevert, Djork-Arn{\'e} and Unterthiner, Thomas and Hochreiter, Sepp}, 43 | journal={arXiv preprint arXiv:1511.07289}, 44 | year={2015} 45 | } 46 | 47 | 48 | @inproceedings{salimans2016weight, 49 | title={Weight normalization: A simple reparameterization to accelerate training of deep neural networks}, 50 | author={Salimans, Tim and Kingma, Diederik P}, 51 | booktitle={Advances in Neural Information Processing Systems}, 52 | pages={901--909}, 53 | year={2016} 54 | } 55 | 56 | 57 | @inproceedings{ioffe2015batch, 58 | title={Batch normalization: Accelerating deep network training by reducing internal covariate shift}, 59 | author={Ioffe, Sergey and Szegedy, Christian}, 60 | booktitle={International Conference on Machine Learning}, 61 | pages={448--456}, 62 | year={2015} 63 | } 64 | 65 | 66 | @inproceedings{he2016deep, 67 | title={Deep residual learning for image recognition}, 68 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 69 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, 70 | pages={770--778}, 71 | year={2016} 72 | } 73 | 74 | 75 | @inproceedings{arjovsky2016unitary, 76 | title={Unitary evolution recurrent neural networks}, 77 | author={Arjovsky, Martin and Shah, Amar and Bengio, Yoshua}, 78 | booktitle={International Conference on Machine Learning}, 79 | pages={1120--1128}, 80 | year={2016} 81 | } 82 | 83 | 84 | @article{danihelka2016associative, 85 | title={Associative long short-term memory}, 86 | author={Danihelka, Ivo and Wayne, Greg and Uria, Benigno and Kalchbrenner, Nal and Graves, Alex}, 87 | journal={arXiv preprint arXiv:1602.03032}, 88 | year={2016} 89 | } 90 | 91 | 92 | @inproceedings{wisdom2016full, 93 | title={Full-capacity unitary recurrent neural networks}, 94 | author={Wisdom, Scott and Powers, Thomas and Hershey, 
John and Le Roux, Jonathan and Atlas, Les}, 95 | booktitle={Advances in Neural Information Processing Systems}, 96 | pages={4880--4888}, 97 | year={2016} 98 | } 99 | 100 | @book{bulow1999hypercomplex, 101 | title={Hypercomplex spectral signal representations for the processing and analysis of images}, 102 | author={B{\"u}low, Thomas}, 103 | year={1999}, 104 | publisher={Universit{\"a}t Kiel. Institut f{\"u}r Informatik und Praktische Mathematik} 105 | } 106 | 107 | @article{oppenheim1981importance, 108 | title={The importance of phase in signals}, 109 | author={Oppenheim, Alan V and Lim, Jae S}, 110 | journal={Proceedings of the IEEE}, 111 | volume={69}, 112 | number={5}, 113 | pages={529--541}, 114 | year={1981}, 115 | publisher={IEEE} 116 | } 117 | 118 | @article{sangwine2000colour, 119 | title={Colour image filters based on hypercomplex convolution}, 120 | author={Sangwine, Stephen J and Ell, Todd A}, 121 | journal={IEE Proceedings-Vision, Image and Signal Processing}, 122 | volume={147}, 123 | number={2}, 124 | pages={89--93}, 125 | year={2000}, 126 | publisher={IET} 127 | } 128 | 129 | @article{bulow2001hypercomplex, 130 | title={Hypercomplex signals-a novel extension of the analytic signal to the multidimensional case}, 131 | author={Bulow, Thomas and Sommer, Gerald}, 132 | journal={IEEE Transactions on signal processing}, 133 | volume={49}, 134 | number={11}, 135 | pages={2844--2852}, 136 | year={2001}, 137 | publisher={IEEE} 138 | } 139 | 140 | @inproceedings{parcollet2016quaternion, 141 | title={Quaternion Neural Networks for Spoken Language Understanding}, 142 | author={Parcollet, Titouan and Morchid, Mohamed and Bousquet, Pierre-Michel and Dufour, Richard and Linar{\`e}s, Georges and De Mori, Renato}, 143 | booktitle={Spoken Language Technology Workshop (SLT), 2016 IEEE}, 144 | pages={362--368}, 145 | year={2016}, 146 | organization={IEEE} 147 | } 148 | 149 | @incollection{witten2006quaternion, 150 | title={Quaternion-based signal processing}, 151 | 
author={Witten, Ben and Shragge, Jeff}, 152 | booktitle={SEG Technical Program Expanded Abstracts 2006}, 153 | pages={2862--2866}, 154 | year={2006}, 155 | publisher={Society of Exploration Geophysicists} 156 | } 157 | 158 | @article{minemoto2017feed, 159 | title={Feed forward neural network with random quaternionic neurons}, 160 | author={Minemoto, Toshifumi and Isokawa, Teijiro and Nishimura, Haruhiko and Matsui, Nobuyuki}, 161 | journal={Signal Processing}, 162 | volume={136}, 163 | pages={59--68}, 164 | year={2017}, 165 | publisher={Elsevier} 166 | } 167 | 168 | @article{rishiyur2006neural, 169 | title={Neural Networks with Complex and Quaternion Inputs}, 170 | author={Rishiyur, Adityan}, 171 | journal={arXiv preprint cs/0607090}, 172 | year={2006} 173 | } 174 | 175 | 176 | @inproceedings{kendall2015posenet, 177 | title={Posenet: A convolutional network for real-time 6-dof camera relocalization}, 178 | author={Kendall, Alex and Grimes, Matthew and Cipolla, Roberto}, 179 | booktitle={Proceedings of the IEEE international conference on computer vision}, 180 | pages={2938--2946}, 181 | year={2015} 182 | } 183 | 184 | 185 | @article{trabelsi2017deep, 186 | title={Deep Complex Networks}, 187 | author={Trabelsi, Chiheb and Bilaniuk, Olexa and Serdyuk, Dmitriy and Subramanian, Sandeep and Santos, Jo{\~a}o Felipe and Mehri, Soroush and Rostamzadeh, Negar and Bengio, Yoshua and Pal, Christopher J}, 188 | journal={arXiv preprint arXiv:1705.09792}, 189 | year={2017} 190 | } 191 | 192 | @inproceedings{glorot2010understanding, 193 | title={Understanding the difficulty of training deep feedforward neural networks}, 194 | author={Glorot, Xavier and Bengio, Yoshua}, 195 | booktitle={Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics}, 196 | pages={249--256}, 197 | year={2010} 198 | } 199 | 200 | @article{chollet2016xception, 201 | title={Xception: Deep Learning with Depthwise Separable Convolutions}, 202 | author={Chollet, 
Fran{\c{c}}ois}, 203 | journal={arXiv preprint arXiv:1610.02357}, 204 | year={2016} 205 | } 206 | 207 | @article{shi2007quaternion, 208 | title={Quaternion color texture segmentation}, 209 | author={Shi, Lilong and Funt, Brian}, 210 | journal={Computer Vision and image understanding}, 211 | volume={107}, 212 | number={1}, 213 | pages={88--96}, 214 | year={2007}, 215 | publisher={Elsevier} 216 | } 217 | 218 | @inproceedings{he2015delving, 219 | title={Delving deep into rectifiers: Surpassing human-level performance on imagenet classification}, 220 | author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, 221 | booktitle={Proceedings of the IEEE international conference on computer vision}, 222 | pages={1026--1034}, 223 | year={2015} 224 | } 225 | 226 | @inproceedings{nair2010rectified, 227 | title={Rectified linear units improve restricted boltzmann machines}, 228 | author={Nair, Vinod and Hinton, Geoffrey E}, 229 | booktitle={Proceedings of the 27th international conference on machine learning (ICML-10)}, 230 | pages={807--814}, 231 | year={2010} 232 | } 233 | 234 | @inproceedings{turner2002, 235 | title={The polar form of a quaternion}, 236 | author={Piziak, Robert and Turner, Danny}, 237 | journal={The Mathematica Journal}, 238 | year={2002} 239 | } 240 | 241 | @inproceedings{nesterov1983method, 242 | title={A method of solving a convex programming problem with convergence rate O (1/k2)}, 243 | author={Nesterov, Yurii} 244 | } 245 | 246 | @article{kessy2017optimal, 247 | title={Optimal whitening and decorrelation}, 248 | author={Kessy, Agnan and Lewin, Alex and Strimmer, Korbinian}, 249 | journal={The American Statistician}, 250 | number={just-accepted}, 251 | year={2017}, 252 | publisher={Taylor \& Francis} 253 | } 254 | -------------------------------------------------------------------------------- /paper/figures/quatconv.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/paper/figures/quatconv.png -------------------------------------------------------------------------------- /paper/ieee_version/DeepQuaternionNets.aux: -------------------------------------------------------------------------------- 1 | \relax 2 | \citation{ioffe2015batch} 3 | \citation{srivastava2015training} 4 | \citation{he2016deep} 5 | \citation{clevert2015fast} 6 | \citation{hochreiter1991untersuchungen} 7 | \citation{arjovsky2016unitary,danihelka2016associative,wisdom2016full} 8 | \citation{trabelsi2017deep} 9 | \citation{bulow1999hypercomplex,sangwine2000colour,bulow2001hypercomplex} 10 | \citation{rishiyur2006neural,kendall2015posenet} 11 | \citation{parcollet2016quaternion} 12 | \citation{minemoto2017feed} 13 | \citation{trabelsi2017deep} 14 | \citation{kendall2015posenet} 15 | \citation{oppenheim1981importance} 16 | \citation{bulow2001hypercomplex} 17 | \citation{sangwine2000colour} 18 | \citation{shi2007quaternion} 19 | \@writefile{toc}{\contentsline {section}{\numberline {I}Introduction}{1}} 20 | \@writefile{toc}{\contentsline {section}{\numberline {II}Motivation and Related Work}{1}} 21 | \citation{minemoto2017feed} 22 | \citation{trabelsi2017deep} 23 | \citation{bulow1999hypercomplex,sangwine2000colour,bulow2001hypercomplex} 24 | \citation{trabelsi2017deep} 25 | \@writefile{toc}{\contentsline {section}{\numberline {III}Quaternion Network Components}{2}} 26 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {III-A}}Quaternion Representation}{2}} 27 | \newlabel{eq:quaternion1}{{1}{2}} 28 | \newlabel{eq:quarternion2}{{2}{2}} 29 | \newlabel{eq:quarternion3}{{3}{2}} 30 | \newlabel{eq:m4r}{{4}{2}} 31 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {III-B}}Quaternion Differentiability}{2}} 32 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox 
{III-C}}Quaternion Convolution}{2}} 33 | \newlabel{s:qc}{{\unhbox \voidb@x \hbox {III-C}}{2}} 34 | \citation{chollet2016xception} 35 | \citation{shi2007quaternion} 36 | \citation{ioffe2015batch} 37 | \citation{trabelsi2017deep} 38 | \citation{glorot2010understanding} 39 | \citation{he2015delving} 40 | \citation{turner2002} 41 | \newlabel{eq:qconvolve2}{{6}{3}} 42 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {III-D}}Quaternion Batch-Normalization}{3}} 43 | \newlabel{eq:white4d}{{7}{3}} 44 | \newlabel{eq:gamma}{{8}{3}} 45 | \newlabel{eq:qbn}{{9}{3}} 46 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {III-E}}Quaternion Weight Initialization}{3}} 47 | \newlabel{eq:quaternion_weight}{{10}{3}} 48 | \newlabel{eq:variance}{{11}{3}} 49 | \@writefile{lof}{\contentsline {figure}{\numberline {1}{\ignorespaces An illustration of quaternion convolution.}}{4}} 50 | \newlabel{f:quatconv}{{1}{4}} 51 | \citation{glorot2010understanding} 52 | \citation{he2015delving} 53 | \citation{nair2010rectified} 54 | \citation{krizhevsky2009learning} 55 | \citation{Fritsch2013ITSC} 56 | \citation{trabelsi2017deep} 57 | \citation{he2016deep} 58 | \citation{trabelsi2017deep} 59 | \citation{nesterov1983method} 60 | \citation{trabelsi2017deep} 61 | \newlabel{eq:expected}{{12}{5}} 62 | \newlabel{eq:variance_sigma}{{13}{5}} 63 | \@writefile{toc}{\contentsline {section}{\numberline {IV}Experimental Results}{5}} 64 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {IV-A}}Classification}{5}} 65 | \@writefile{lot}{\contentsline {table}{\numberline {I}{\ignorespaces Classification error on CIFAR-10 and CIFAR-100. 
Params is the total number of parameters.}}{5}} 66 | \newlabel{t:results1}{{I}{5}} 67 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {IV-B}}Segmentation}{5}} 68 | \bibdata{bib} 69 | \bibcite{ioffe2015batch}{1} 70 | \bibcite{srivastava2015training}{2} 71 | \bibcite{he2016deep}{3} 72 | \bibcite{clevert2015fast}{4} 73 | \bibcite{hochreiter1991untersuchungen}{5} 74 | \bibcite{arjovsky2016unitary}{6} 75 | \bibcite{danihelka2016associative}{7} 76 | \bibcite{wisdom2016full}{8} 77 | \bibcite{trabelsi2017deep}{9} 78 | \bibcite{bulow1999hypercomplex}{10} 79 | \bibcite{sangwine2000colour}{11} 80 | \bibcite{bulow2001hypercomplex}{12} 81 | \bibcite{rishiyur2006neural}{13} 82 | \bibcite{kendall2015posenet}{14} 83 | \bibcite{parcollet2016quaternion}{15} 84 | \bibcite{minemoto2017feed}{16} 85 | \bibcite{oppenheim1981importance}{17} 86 | \bibcite{shi2007quaternion}{18} 87 | \bibcite{chollet2016xception}{19} 88 | \bibcite{glorot2010understanding}{20} 89 | \bibcite{he2015delving}{21} 90 | \bibcite{turner2002}{22} 91 | \bibcite{nair2010rectified}{23} 92 | \bibcite{krizhevsky2009learning}{24} 93 | \bibcite{Fritsch2013ITSC}{25} 94 | \bibcite{nesterov1983method}{26} 95 | \bibcite{kessy2017optimal}{27} 96 | \bibstyle{ieeetr} 97 | \@writefile{lot}{\contentsline {table}{\numberline {II}{\ignorespaces IOU on KITTI Road Estimation benchmark.}}{6}} 98 | \newlabel{t:results2}{{II}{6}} 99 | \@writefile{toc}{\contentsline {section}{\numberline {V}Conclusions}{6}} 100 | \@writefile{toc}{\contentsline {section}{\numberline {VI}Acknowledgment}{6}} 101 | \@writefile{toc}{\contentsline {section}{References}{6}} 102 | \@writefile{toc}{\contentsline {section}{\numberline {VII}Appendix}{6}} 103 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {VII-A}}The Generalized Quaternion Chain Rule for a Real-Valued Function}{6}} 104 | \newlabel{a:diff}{{\unhbox \voidb@x \hbox {VII-A}}{6}} 105 | \citation{kessy2017optimal} 106 | 
\newlabel{eq:diff1}{{15}{7}} 107 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {VII-B}}Whitening a Matrix}{7}} 108 | \newlabel{a:whitening}{{\unhbox \voidb@x \hbox {VII-B}}{7}} 109 | \newlabel{eq:white1}{{16}{7}} 110 | \newlabel{eq:white2}{{17}{7}} 111 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {VII-C}}Cholesky Decomposition}{7}} 112 | \newlabel{eq:cholesky1}{{18}{7}} 113 | \@writefile{toc}{\contentsline {subsection}{\numberline {\unhbox \voidb@x \hbox {VII-D}}4 DOF Independent Normal Distribution}{7}} 114 | \newlabel{eq:single_dists}{{19}{7}} 115 | \newlabel{eq:cumdist}{{20}{7}} 116 | \newlabel{eq:4dsphere}{{21}{7}} 117 | \newlabel{eq:polarint}{{22}{7}} 118 | \newlabel{eq:finaldist}{{23}{8}} 119 | -------------------------------------------------------------------------------- /paper/ieee_version/DeepQuaternionNets.bbl: -------------------------------------------------------------------------------- 1 | \begin{thebibliography}{10} 2 | 3 | \bibitem{ioffe2015batch} 4 | S.~Ioffe and C.~Szegedy, ``Batch normalization: Accelerating deep network 5 | training by reducing internal covariate shift,'' {\em arXiv preprint 6 | arXiv:1502.03167}, 2015. 7 | 8 | \bibitem{srivastava2015training} 9 | R.~K. Srivastava, K.~Greff, and J.~Schmidhuber, ``Training very deep 10 | networks,'' in {\em Advances in neural information processing systems}, 11 | pp.~2377--2385, 2015. 12 | 13 | \bibitem{he2016deep} 14 | K.~He, X.~Zhang, S.~Ren, and J.~Sun, ``Deep residual learning for image 15 | recognition,'' in {\em Proceedings of the IEEE conference on computer vision 16 | and pattern recognition}, pp.~770--778, 2016. 17 | 18 | \bibitem{clevert2015fast} 19 | D.-A. Clevert, T.~Unterthiner, and S.~Hochreiter, ``Fast and accurate deep 20 | network learning by exponential linear units (elus),'' {\em arXiv preprint 21 | arXiv:1511.07289}, 2015. 
22 | 23 | \bibitem{hochreiter1991untersuchungen} 24 | S.~Hochreiter, ``Untersuchungen zu dynamischen neuronalen netzen,'' {\em 25 | Diploma, Technische Universit{\"a}t M{\"u}nchen}, vol.~91, 1991. 26 | 27 | \bibitem{arjovsky2016unitary} 28 | M.~Arjovsky, A.~Shah, and Y.~Bengio, ``Unitary evolution recurrent neural 29 | networks,'' in {\em International Conference on Machine Learning}, 30 | pp.~1120--1128, 2016. 31 | 32 | \bibitem{danihelka2016associative} 33 | I.~Danihelka, G.~Wayne, B.~Uria, N.~Kalchbrenner, and A.~Graves, ``Associative 34 | long short-term memory,'' {\em arXiv preprint arXiv:1602.03032}, 2016. 35 | 36 | \bibitem{wisdom2016full} 37 | S.~Wisdom, T.~Powers, J.~Hershey, J.~Le~Roux, and L.~Atlas, ``Full-capacity 38 | unitary recurrent neural networks,'' in {\em Advances in Neural Information 39 | Processing Systems}, pp.~4880--4888, 2016. 40 | 41 | \bibitem{trabelsi2017deep} 42 | C.~Trabelsi, O.~Bilaniuk, D.~Serdyuk, S.~Subramanian, J.~F. Santos, S.~Mehri, 43 | N.~Rostamzadeh, Y.~Bengio, and C.~J. Pal, ``Deep complex networks,'' {\em 44 | arXiv preprint arXiv:1705.09792}, 2017. 45 | 46 | \bibitem{bulow1999hypercomplex} 47 | T.~B{\"u}low, {\em Hypercomplex spectral signal representations for the 48 | processing and analysis of images}. 49 | \newblock Universit{\"a}t Kiel. Institut f{\"u}r Informatik und Praktische 50 | Mathematik, 1999. 51 | 52 | \bibitem{sangwine2000colour} 53 | S.~J. Sangwine and T.~A. Ell, ``Colour image filters based on hypercomplex 54 | convolution,'' {\em IEE Proceedings-Vision, Image and Signal Processing}, 55 | vol.~147, no.~2, pp.~89--93, 2000. 56 | 57 | \bibitem{bulow2001hypercomplex} 58 | T.~Bulow and G.~Sommer, ``Hypercomplex signals-a novel extension of the 59 | analytic signal to the multidimensional case,'' {\em IEEE Transactions on 60 | signal processing}, vol.~49, no.~11, pp.~2844--2852, 2001. 
61 | 62 | \bibitem{rishiyur2006neural} 63 | A.~Rishiyur, ``Neural networks with complex and quaternion inputs,'' {\em arXiv 64 | preprint cs/0607090}, 2006. 65 | 66 | \bibitem{kendall2015posenet} 67 | A.~Kendall, M.~Grimes, and R.~Cipolla, ``Posenet: A convolutional network for 68 | real-time 6-dof camera relocalization,'' in {\em Proceedings of the IEEE 69 | international conference on computer vision}, pp.~2938--2946, 2015. 70 | 71 | \bibitem{parcollet2016quaternion} 72 | T.~Parcollet, M.~Morchid, P.-M. Bousquet, R.~Dufour, G.~Linar{\`e}s, and 73 | R.~De~Mori, ``Quaternion neural networks for spoken language understanding,'' 74 | in {\em Spoken Language Technology Workshop (SLT), 2016 IEEE}, pp.~362--368, 75 | IEEE, 2016. 76 | 77 | \bibitem{minemoto2017feed} 78 | T.~Minemoto, T.~Isokawa, H.~Nishimura, and N.~Matsui, ``Feed forward neural 79 | network with random quaternionic neurons,'' {\em Signal Processing}, 80 | vol.~136, pp.~59--68, 2017. 81 | 82 | \bibitem{oppenheim1981importance} 83 | A.~V. Oppenheim and J.~S. Lim, ``The importance of phase in signals,'' {\em 84 | Proceedings of the IEEE}, vol.~69, no.~5, pp.~529--541, 1981. 85 | 86 | \bibitem{shi2007quaternion} 87 | L.~Shi and B.~Funt, ``Quaternion color texture segmentation,'' {\em Computer 88 | Vision and image understanding}, vol.~107, no.~1, pp.~88--96, 2007. 89 | 90 | \bibitem{chollet2016xception} 91 | F.~Chollet, ``Xception: Deep learning with depthwise separable convolutions,'' 92 | {\em arXiv preprint arXiv:1610.02357}, 2016. 93 | 94 | \bibitem{glorot2010understanding} 95 | X.~Glorot and Y.~Bengio, ``Understanding the difficulty of training deep 96 | feedforward neural networks,'' in {\em Proceedings of the Thirteenth 97 | International Conference on Artificial Intelligence and Statistics}, 98 | pp.~249--256, 2010. 
99 | 100 | \bibitem{he2015delving} 101 | K.~He, X.~Zhang, S.~Ren, and J.~Sun, ``Delving deep into rectifiers: Surpassing 102 | human-level performance on ImageNet classification,'' in {\em Proceedings of 103 | the IEEE International Conference on Computer Vision}, pp.~1026--1034, 2015. 104 | 105 | \bibitem{turner2002} 106 | R.~Piziak and D.~Turner, ``The polar form of a quaternion,'' 2002. 107 | 108 | \bibitem{nair2010rectified} 109 | V.~Nair and G.~E. Hinton, ``Rectified linear units improve restricted Boltzmann 110 | machines,'' in {\em Proceedings of the 27th International Conference on 111 | Machine Learning (ICML-10)}, pp.~807--814, 2010. 112 | 113 | \bibitem{krizhevsky2009learning} 114 | A.~Krizhevsky and G.~Hinton, ``Learning multiple layers of features from tiny 115 | images,'' 2009. 116 | 117 | \bibitem{Fritsch2013ITSC} 118 | J.~Fritsch, T.~Kuehnl, and A.~Geiger, ``A new performance measure and 119 | evaluation benchmark for road detection algorithms,'' in {\em International 120 | Conference on Intelligent Transportation Systems (ITSC)}, 2013. 121 | 122 | \bibitem{nesterov1983method} 123 | Y.~Nesterov, ``A method of solving a convex programming problem with 124 | convergence rate $O(1/k^2)$,'' 1983. 125 | 126 | \bibitem{kessy2017optimal} 127 | A.~Kessy, A.~Lewin, and K.~Strimmer, ``Optimal whitening and decorrelation,'' 128 | {\em The American Statistician}, no.~just-accepted, 2017.
129 | 130 | \end{thebibliography} 131 | -------------------------------------------------------------------------------- /paper/ieee_version/DeepQuaternionNets.blg: -------------------------------------------------------------------------------- 1 | This is BibTeX, Version 0.99dThe top-level auxiliary file: DeepQuaternionNets.aux 2 | The style file: ieeetr.bst 3 | Database file #1: bib.bib 4 | Repeated entry---line 659 of file bib.bib 5 | : @article{clevert2015fast 6 | : , 7 | I'm skipping whatever remains of this entry 8 | Repeated entry---line 692 of file bib.bib 9 | : @inproceedings{ioffe2015batch 10 | : , 11 | I'm skipping whatever remains of this entry 12 | Repeated entry---line 853 of file bib.bib 13 | : @inproceedings{he2015delving 14 | : , 15 | I'm skipping whatever remains of this entry 16 | Repeated entry---line 861 of file bib.bib 17 | : @inproceedings{nair2010rectified 18 | : , 19 | I'm skipping whatever remains of this entry 20 | Warning--empty booktitle in turner2002 21 | Warning--empty journal in krizhevsky2009learning 22 | Warning--empty booktitle in nesterov1983method 23 | Warning--empty year in nesterov1983method 24 | (There were 4 error messages) 25 | -------------------------------------------------------------------------------- /paper/ieee_version/DeepQuaternionNets.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/paper/ieee_version/DeepQuaternionNets.pdf -------------------------------------------------------------------------------- /paper/ieee_version/quatconv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/paper/ieee_version/quatconv.png -------------------------------------------------------------------------------- /quatconv.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quatconv.png -------------------------------------------------------------------------------- /quaternion_layers/__pycache__/bn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/bn.cpython-36.pyc -------------------------------------------------------------------------------- /quaternion_layers/__pycache__/conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/conv.cpython-36.pyc -------------------------------------------------------------------------------- /quaternion_layers/__pycache__/dist.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/dist.cpython-36.pyc -------------------------------------------------------------------------------- /quaternion_layers/__pycache__/init.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/init.cpython-36.pyc -------------------------------------------------------------------------------- /quaternion_layers/__pycache__/norm.cpython-36.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/norm.cpython-36.pyc
# --------------------------------------------------------------------------------
# /quaternion_layers/__pycache__/utils.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/quaternion_layers/__pycache__/utils.cpython-36.pyc
# --------------------------------------------------------------------------------
# /quaternion_layers/dist.py:
# --------------------------------------------------------------------------------
# Authors: Chase Gaudet


import numpy


class Chi4Random(object):
    """Sampler for the chi distribution with four degrees of freedom.

    With scale ``o`` the density is
    ``p(x) = x**3 / (2 * o**4) * exp(-x**2 / (2 * o**2))`` for ``x >= 0``.
    Samples are drawn by inverse-transform sampling from a CDF tabulated on
    the grid ``x`` and inverted at ``Nrl`` evenly spaced probabilities.
    """

    def __init__(self, o, x=None, Nrl=1000):
        """Initialize the lookup table (with default values if necessary).

        Inputs:
            o   = scale of dist, must be > 0
            x   = grid of random-variate values on which the pdf is tabulated;
                  defaults to ``arange(0, 5, .01) * o`` (identical to the old
                  fixed grid when o == 1)
            Nrl = number of reverse look up values between 0 and 1

        Raises:
            ValueError: if ``o`` is not positive.
        """
        if not o > 0:
            raise ValueError('scale o must be > 0, got {!r}'.format(o))
        if x is None:
            # Scale the grid with ``o``: a fixed [0, 5) grid with step .01
            # resolves a small-scale density (e.g. o ~ 0.01) with only a few
            # points, which badly distorts the sampled distribution.
            x = numpy.arange(0, 5, .01) * o
        x = numpy.asarray(x, dtype=float)
        p = (x**3.0 / (2.0 * o**4.0)) * numpy.exp(-x**2.0 / (2 * o**2.0))
        self.set_pdf(x, p, Nrl)

    def set_pdf(self, x, p, Nrl=1000):
        """Generate the lookup tables.

        x is the value of the random variate (expected to be increasing)
        pdf is its probability density, normalized here to sum to 1
        cdf is the cumulative pdf
        inversecdf is the inverse look up table, evaluated at the Nrl
        probabilities arange(Nrl) / Nrl

        Raises:
            ValueError: if ``p`` does not have a positive sum.
        """
        x = numpy.asarray(x, dtype=float)
        p = numpy.asarray(p, dtype=float)
        total = p.sum()
        if not total > 0:  # also rejects NaN
            raise ValueError('pdf values must have a positive sum')

        self.x = x
        self.pdf = p / total  # normalize it
        self.cdf = self.pdf.cumsum()
        self.inversecdfbins = Nrl
        self.Nrl = Nrl
        y = numpy.arange(Nrl) / float(Nrl)
        # Linear interpolation of x as a function of the cdf.  Probabilities
        # below cdf[0] clamp to x[0] (instead of wrapping around to x[-1]),
        # and every entry is filled no matter how len(x) compares to Nrl.
        self.inversecdf = numpy.interp(y, self.cdf, self.x)
        self.inversecdf[0] = self.x[0]
        self.delta_inversecdf = numpy.concatenate((numpy.diff(self.inversecdf), [0]))

    def random(self, N=1000, rng=None):
        """Give us N random numbers with the requested distribution.

        Args:
            N: number of samples.
            rng: optional ``numpy.random.RandomState``/``Generator``; when
                omitted the global ``numpy.random`` state is used.

        Returns:
            Array of shape (1, N).
        """
        if rng is None:
            rng = numpy.random
        idx_f = rng.uniform(size=N, high=self.Nrl - 1)
        # Integer table index with shape (1, N); the fractional remainder
        # linearly interpolates between neighbouring table entries.
        idx = numpy.array([idx_f], 'i')
        return self.inversecdf[idx] + (idx_f - idx) * self.delta_inversecdf[idx]

    def plot_pdf(self):
        """Plot the tabulated (normalized) pdf."""
        import pylab  # plotting only; sampling does not need matplotlib
        pylab.plot(self.x, self.pdf)
        pylab.show()

    def self_test(self, N=1000):
        """Plot the cdf, the inverse cdf and a histogram of N samples."""
        import pylab  # plotting only; sampling does not need matplotlib
        pylab.figure()
        # The cdf
        pylab.subplot(2, 2, 1)
        pylab.plot(self.x, self.cdf)
        # The inverse cdf
        pylab.subplot(2, 2, 2)
        y = numpy.arange(self.Nrl) / float(self.Nrl)
        pylab.plot(y, self.inversecdf)

        # The actual generated numbers
        pylab.subplot(2, 2, 3)
        y = self.random(N)
        # ``density`` replaces the ``normed`` argument removed from numpy.
        p1, edges = numpy.histogram(y, bins=50,
                                    range=(self.x.min(), self.x.max()),
                                    density=True)
        x1 = 0.5 * (edges[0:-1] + edges[1:])
        pylab.plot(x1, p1 / p1.max())
        pylab.plot(self.x, self.pdf / self.pdf.max())
        pylab.show()


if __name__ == '__main__':
    chi4 = Chi4Random(1)
    chi4.plot_pdf()
    chi4.self_test(N=1000)
# --------------------------------------------------------------------------------
# /quaternion_layers/init.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Authors: Chase Gaudet
# code based on work by Chiheb Trabelsi
# on Deep Complex Networks git source

import numpy as np
import scipy.stats as st
from random import gauss
from numpy.random import RandomState
from .dist import Chi4Random
import keras.backend as K
from keras import initializers
from keras.initializers import Initializer
from keras.utils.generic_utils import (serialize_keras_object,
                                       deserialize_keras_object)


class QuaternionInit(Initializer):
    """Quaternion weight initializer using the He or the Glorot criterion.

    Every weight is built as ``|q| * (cos(theta) + u * sin(theta))`` with
    ``|q| ~ chi_4(s)``, ``theta ~ U(-pi, pi)`` and ``u`` a random unit
    3-vector (the pure-quaternion axis). The r, i, j, k parts are
    concatenated along the last axis of the returned array.
    """

    def __init__(self, kernel_size, input_dim,
                 weight_dim, nb_filters=None,
                 criterion='he', seed=None):
        # `weight_dim` is a sanity check that `kernel_size` has one entry per
        # weight dimension:
        #   dense (nb_filters is None): len(kernel_size) == 2, weight_dim == 2
        #   conv1D: len(kernel_size) == 1 and weight_dim == 1
        #   conv2D: len(kernel_size) == 2 and weight_dim == 2
        #   conv3D: len(kernel_size) == 3 and weight_dim == 3
        assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3}
        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.weight_dim = weight_dim
        self.criterion = criterion
        self.seed = 31337 if seed is None else seed

    def __call__(self, shape, dtype=None):
        """Return a numpy array of quaternion weights.

        NOTE(review): `shape` and `dtype` are ignored; the generated shape is
        derived from the constructor arguments (last axis is 4x the number of
        filters/units) — presumably matching what the calling layer requests.

        Raises:
            ValueError: if `criterion` is neither 'he' nor 'glorot'.
        """
        if self.nb_filters is not None:
            kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters)
        else:
            kernel_shape = (int(self.input_dim), self.kernel_size[-1])

        # Fans come from the shape that is actually generated; for conv
        # kernels this equals kernel_size + (input_dim, nb_filters), and for
        # the dense case it avoids feeding nb_filters=None into the fans.
        fan_in, fan_out = initializers._compute_fans(kernel_shape)

        if self.criterion == 'glorot':
            s = 1. / np.sqrt(2 * (fan_in + fan_out))
        elif self.criterion == 'he':
            s = 1. / np.sqrt(2 * fan_in)
        else:
            raise ValueError('Invalid criterion: ' + self.criterion)

        flat_size = int(np.prod(kernel_shape))
        rng = RandomState(self.seed)
        # Phase is the first draw from the seeded RNG, as before.
        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

        # The norm of a 4-vector of iid N(0, s^2) entries is exactly chi_4(s)
        # (density x^3 / (2 s^4) exp(-x^2 / (2 s^2))); sampling it this way
        # is seeded and independent of any tabulation grid.
        modulus = np.linalg.norm(rng.normal(scale=s, size=(flat_size, 4)), axis=1)
        modulus = modulus.reshape(kernel_shape)

        # Random unit vector for the quaternion vector part (seeded).
        unit = rng.normal(size=(flat_size, 3))
        unit /= np.linalg.norm(unit, axis=1, keepdims=True)
        u_i = unit[:, 0].reshape(kernel_shape)
        u_j = unit[:, 1].reshape(kernel_shape)
        u_k = unit[:, 2].reshape(kernel_shape)

        vector_scale = modulus * np.sin(phase)
        weight_r = modulus * np.cos(phase)
        weight_i = vector_scale * u_i
        weight_j = vector_scale * u_j
        weight_k = vector_scale * u_k
        return np.concatenate([weight_r, weight_i, weight_j, weight_k], axis=-1)


class SqrtInit(Initializer):
    """Constant initializer with value 1 / sqrt(16) = 0.25."""

    def __call__(self, shape, dtype=None):
        # np.sqrt on a Python number; K.sqrt expects a tensor with a dtype.
        return K.constant(1 / np.sqrt(16), shape=shape, dtype=dtype)


# Aliases:
sqrt_init = SqrtInit
quaternion_init = QuaternionInit

# --------------------------------------------------------------------------------
# /quaternion_layers/norm.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Authors: Chase Gaudet
# code based on work by Chiheb Trabelsi
# on Deep Complex Networks git source
#
# Implementation of Layer Normalization and Quaternion Layer Normalization


import numpy as np
from keras.layers import Layer, InputSpec
from keras import initializers, regularizers, constraints
import keras.backend as K
from .bn import QuaternionBN as quaternion_normalization
from .bn import sqrt_init

# Suffixes of the QuaternionLayerNorm scaling weights, in creation order.
# The order is kept from the original implementation because Keras weight
# files are matched positionally.
_GAMMA_DIAG = ('rr', 'ii', 'jj', 'kk')
_GAMMA_OFF = ('ri', 'rj', 'rk', 'ij', 'ik', 'jk')


def layernorm(x, axis, epsilon, gamma, beta):
    """Normalize `x` and apply a per-feature scale and shift.

    Axis 0 is treated as the batch axis. Mean and variance are computed per
    sample and per entry of `axis`, over all remaining axes.
    NOTE(review): since `axis` is excluded from the reduction, this is
    instance-norm style; for 2-D inputs the reduction set is empty.

    Args:
        x: input tensor.
        axis: feature axis; `gamma` and `beta` have one value per entry.
        epsilon: added to the variance for numerical stability.
        gamma, beta: scale and shift tensors.

    Returns:
        gamma * (x - mean) / sqrt(var + epsilon) + beta
    """
    input_shape = K.shape(x)
    ndim = K.ndim(x)
    reduction_axes = list(range(ndim))
    del reduction_axes[axis]
    del reduction_axes[0]
    broadcast_shape = [1] * ndim
    broadcast_shape[axis] = input_shape[axis]
    broadcast_shape[0] = input_shape[0]

    # Perform normalization: centering and reduction
    mean = K.mean(x, axis=reduction_axes)
    x_centred = x - K.reshape(mean, broadcast_shape)
    variance = K.mean(x_centred ** 2, axis=reduction_axes) + epsilon
    x_normed = x_centred / K.sqrt(K.reshape(variance, broadcast_shape))

    # Perform scaling and shifting
    broadcast_shape_params = [1] * ndim
    broadcast_shape_params[axis] = input_shape[axis]
    broadcast_gamma = K.reshape(gamma, broadcast_shape_params)
    broadcast_beta = K.reshape(beta, broadcast_shape_params)
    return broadcast_gamma * x_normed + broadcast_beta


def _split_quaternion_parts(t, input_dim, axis, ndim):
    """Return the [r, i, j, k] slices of `t`, each `input_dim` wide.

    Raises:
        ValueError: for unsupported combinations of `axis` and `ndim`.
    """
    if (axis == 1 and ndim != 3) or ndim == 2:
        n_lead = 1
    elif ndim == 3:
        n_lead = 2
    elif axis == -1 and ndim == 4:
        n_lead = 3
    elif axis == -1 and ndim == 5:
        n_lead = 4
    else:
        raise ValueError(
            'Incorrect Layernorm combination of axis and dimensions. axis should be either 1 or -1. '
            'axis: ' + str(axis) + '; ndim: ' + str(ndim) + '.'
        )
    lead = (slice(None),) * n_lead
    bounds = ((None, input_dim),
              (input_dim, input_dim * 2),
              (input_dim * 2, input_dim * 3),
              (input_dim * 3, None))
    return [t[lead + (slice(start, stop),)] for start, stop in bounds]


class LayerNormalization(Layer):
    """Keras layer applying `layernorm` with trainable gamma and beta."""

    def __init__(self,
                 epsilon=1e-4,
                 axis=-1,
                 beta_init='zeros',
                 gamma_init='ones',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 **kwargs):
        # Layer.__init__ resets supports_masking, so set attributes after it.
        super(LayerNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.beta_init = initializers.get(beta_init)
        self.gamma_init = initializers.get(gamma_init)
        self.epsilon = epsilon
        self.axis = axis
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_regularizer = regularizers.get(beta_regularizer)

    def build(self, input_shape):
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: input_shape[self.axis]})
        shape = (input_shape[self.axis],)

        # Pass `shape` by keyword: the positional order of Layer.add_weight
        # differs between Keras 2 releases.
        self.gamma = self.add_weight(shape=shape,
                                     initializer=self.gamma_init,
                                     regularizer=self.gamma_regularizer,
                                     name='{}_gamma'.format(self.name))
        self.beta = self.add_weight(shape=shape,
                                    initializer=self.beta_init,
                                    regularizer=self.beta_regularizer,
                                    name='{}_beta'.format(self.name))
        self.built = True

    def call(self, x, mask=None):
        assert self.built, 'Layer must be built before being called'
        return layernorm(x, self.axis, self.epsilon, self.gamma, self.beta)

    def get_config(self):
        # serialize() output round-trips through initializers/regularizers.get.
        config = {'epsilon': self.epsilon,
                  'axis': self.axis,
                  'beta_init': initializers.serialize(self.beta_init),
                  'gamma_init': initializers.serialize(self.gamma_init),
                  'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
                  'beta_regularizer': regularizers.serialize(self.beta_regularizer)}
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class QuaternionLayerNorm(Layer):
    """Quaternion layer normalization.

    The feature axis holds r, i, j, k parts of equal width. The 4x4
    covariance statistics (Vrr ... Vkk) are computed per sample and passed,
    together with beta and the ten gamma weights, to `QuaternionBN` from bn.py.
    """

    def __init__(self,
                 epsilon=1e-4,
                 axis=-1,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_diag_initializer=sqrt_init,
                 gamma_off_initializer='zeros',
                 beta_regularizer=None,
                 gamma_diag_regularizer=None,
                 gamma_off_regularizer=None,
                 beta_constraint=None,
                 gamma_diag_constraint=None,
                 gamma_off_constraint=None,
                 **kwargs):
        # Layer.__init__ resets supports_masking, so set attributes after it.
        super(QuaternionLayerNorm, self).__init__(**kwargs)
        self.supports_masking = True
        self.epsilon = epsilon
        self.axis = axis
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_diag_initializer = initializers.get(gamma_diag_initializer)
        self.gamma_off_initializer = initializers.get(gamma_off_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer)
        self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_diag_constraint = constraints.get(gamma_diag_constraint)
        self.gamma_off_constraint = constraints.get(gamma_off_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        if dim % 4 != 0:
            raise ValueError('Axis ' + str(self.axis) + ' of the input must be '
                             'divisible by 4 (r, i, j, k parts); got ' + str(dim) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})

        gamma_shape = (dim // 4,)
        if self.scale:
            groups = (
                (_GAMMA_DIAG, self.gamma_diag_initializer,
                 self.gamma_diag_regularizer, self.gamma_diag_constraint),
                (_GAMMA_OFF, self.gamma_off_initializer,
                 self.gamma_off_regularizer, self.gamma_off_constraint),
            )
            for parts, initializer, regularizer, constraint in groups:
                for part in parts:
                    setattr(self, 'gamma_' + part, self.add_weight(
                        shape=gamma_shape,
                        name='gamma_' + part,
                        initializer=initializer,
                        regularizer=regularizer,
                        constraint=constraint
                    ))
        else:
            for part in _GAMMA_DIAG + _GAMMA_OFF:
                setattr(self, 'gamma_' + part, None)

        if self.center:
            self.beta = self.add_weight(shape=(dim,),
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None

        self.built = True

    def call(self, inputs):
        input_shape = K.shape(inputs)
        ndim = K.ndim(inputs)
        reduction_axes = list(range(ndim))
        del reduction_axes[self.axis]
        del reduction_axes[0]
        input_dim = input_shape[self.axis] // 4

        if self.center:
            mu = K.mean(inputs, axis=reduction_axes)
            broadcast_mu_shape = [1] * ndim
            broadcast_mu_shape[self.axis] = input_shape[self.axis]
            broadcast_mu_shape[0] = input_shape[0]
            input_centred = inputs - K.reshape(mu, broadcast_mu_shape)
        else:
            input_centred = inputs

        # Validates the axis/ndim combination even when scale is False.
        centred_r, centred_i, centred_j, centred_k = _split_quaternion_parts(
            input_centred, input_dim, self.axis, ndim)

        if self.scale:
            def _mean(t):
                return K.mean(t, axis=reduction_axes)

            Vrr = _mean(centred_r ** 2) + self.epsilon
            Vii = _mean(centred_i ** 2) + self.epsilon
            Vjj = _mean(centred_j ** 2) + self.epsilon
            Vkk = _mean(centred_k ** 2) + self.epsilon
            Vri = _mean(centred_r * centred_i)
            Vrj = _mean(centred_r * centred_j)
            Vrk = _mean(centred_r * centred_k)
            Vij = _mean(centred_i * centred_j)
            Vik = _mean(centred_i * centred_k)
            Vjk = _mean(centred_j * centred_k)
        elif self.center:
            Vrr = Vii = Vjj = Vkk = None
            Vri = Vrj = Vrk = Vij = Vik = Vjk = None
        else:
            raise ValueError('Error. Both scale and center in batchnorm are set to False.')

        return quaternion_normalization(
            input_centred,
            Vrr, Vri, Vrj, Vrk, Vii,
            Vij, Vik, Vjj, Vjk, Vkk,
            self.beta,
            self.gamma_rr, self.gamma_ri,
            self.gamma_rj, self.gamma_rk,
            self.gamma_ii, self.gamma_ij,
            self.gamma_ik, self.gamma_jj,
            self.gamma_jk, self.gamma_kk,
            self.scale, self.center,
            layernorm=True, axis=self.axis
        )

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_diag_initializer': initializers.serialize(self.gamma_diag_initializer),
            'gamma_off_initializer': initializers.serialize(self.gamma_off_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer),
            'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_diag_constraint': constraints.serialize(self.gamma_diag_constraint),
            'gamma_off_constraint': constraints.serialize(self.gamma_off_constraint),
        }
        base_config = super(QuaternionLayerNorm, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

# --------------------------------------------------------------------------------
# /quaternion_layers/utils.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import keras.backend as K
from keras.layers import Lambda, Layer


class Params:
    """Attribute-style view of a dict of hyper-parameters."""

    def __init__(self, dictionary):
        for k, v in dictionary.items():
            setattr(self, k, v)


def _quaternion_axis(ndim):
    """Axis holding the r|i|j|k blocks for a tensor of rank `ndim`.

    Axis 1 for 2-D tensors and for channels_first data (except rank 3);
    the last axis otherwise.
    """
    if (K.image_data_format() == 'channels_first' and ndim != 3) or ndim == 2:
        return 1
    return -1


def _get_part(x, index):
    """Return quaternion part `index` (0=r, 1=i, 2=j, 3=k) of tensor `x`.

    Raises:
        ValueError: for channels-last tensors whose rank is not 3, 4 or 5.
    """
    ndim = K.ndim(x)
    axis = _quaternion_axis(ndim)
    if axis == -1 and ndim not in (3, 4, 5):
        raise ValueError('Unsupported tensor rank for quaternion split: ' + str(ndim))
    # Width of one part along the axis that actually holds the quaternion
    # blocks (for channels_last this is the last axis, not axis 1).
    input_dim = K.shape(x)[axis] // 4
    start = input_dim * index if index > 0 else None
    stop = input_dim * (index + 1) if index < 3 else None
    lead = (slice(None),) * (1 if axis == 1 else ndim - 1)
    return x[lead + (slice(start, stop),)]


def get_r(x):
    """Real part of quaternion tensor `x`."""
    return _get_part(x, 0)


def get_i(x):
    """i part of quaternion tensor `x`."""
    return _get_part(x, 1)


def get_j(x):
    """j part of quaternion tensor `x`."""
    return _get_part(x, 2)


def get_k(x):
    """k part of quaternion tensor `x`."""
    return _get_part(x, 3)


def getpart_output_shape(input_shape):
    """Static output shape of GetR/GetI/GetJ/GetK: quaternion axis // 4."""
    returned_shape = list(input_shape[:])
    axis = _quaternion_axis(len(returned_shape))
    if returned_shape[axis] is not None:
        returned_shape[axis] = returned_shape[axis] // 4
    return tuple(returned_shape)


class GetR(Layer):
    def call(self, inputs):
        return get_r(inputs)

    def compute_output_shape(self, input_shape):
        return getpart_output_shape(input_shape)


class GetI(Layer):
    def call(self, inputs):
        return get_i(inputs)

    def compute_output_shape(self, input_shape):
        return getpart_output_shape(input_shape)


class GetJ(Layer):
    def call(self, inputs):
        return get_j(inputs)

    def compute_output_shape(self, input_shape):
        return getpart_output_shape(input_shape)


class GetK(Layer):
    def call(self, inputs):
        return get_k(inputs)

    def compute_output_shape(self, input_shape):
        return getpart_output_shape(input_shape)

# --------------------------------------------------------------------------------
# /runner.py:
# --------------------------------------------------------------------------------
# This file is used to run the CIFAR and KITTI experiments easily with different hyper parameters
# Note that the MNIST experiment has its own runner since no hyper parameter exploration was used

from training_classification import train as train_c
from training_classification import getModel as model_c
from training_segmentation import train as train_s
from training_segmentation import getModel as model_s
from quaternion_layers.utils import Params
import click
import numpy as np
np.random.seed(314)


@click.command()
@click.argument('task', type=click.Choice(['classification', 'segmentation']))
@click.option('--mode', default='quaternion', help='value type of model (real, complex, quaternion)')
@click.option('--num-blocks', '-nb', default=2, help='number of residual blocks per stage')
@click.option('--start-filters', '-sf', default=8, help='number of filters in first layer')
# Float defaults: click infers the option type from `default`, and an int
# default would reject fractional values such as --dropout 0.5.
@click.option('--dropout', '-d', default=0.0, help='dropout percent')
@click.option('--batch-size', '-bs', default=8, help='batch size')
@click.option('--num-epochs', '-e', default=200, help='total number of epochs')
@click.option('--dataset', '-ds', default='cifar10', help='dataset to train and test on')
@click.option('--activation', '-act', default='relu', help='activation function to use')
@click.option('--initialization', '-init', default='quaternion', help='initialization scheme to use')
@click.option('--learning-rate', '-lr', default=1e-3, help='learning rate for optimizer')
@click.option('--momentum', '-mn', default=0.9, help='momentum for batch norm')
@click.option('--decay', '-dc', default=0.0, help='decay rate of optimizer')
@click.option('--clipnorm', '-cn', default=1.0, help='maximum gradient size')
def runner(task, mode, num_blocks, start_filters, dropout, batch_size, num_epochs, dataset,
           activation, initialization, learning_rate, momentum, decay, clipnorm):
    """Build and train a model for TASK (classification or segmentation)."""
    param_dict = {"mode": mode,
                  "num_blocks": num_blocks,
                  "start_filter": start_filters,
                  "dropout": dropout,
                  "batch_size": batch_size,
                  "num_epochs": num_epochs,
                  "dataset": dataset,
                  "act": activation,
                  "init": initialization,
                  "lr": learning_rate,
                  "momentum": momentum,
                  "decay": decay,
                  "clipnorm": clipnorm
                  }
    params = Params(param_dict)

    # click.Choice already rejected unknown tasks with a usage error.
    tasks = {
        'classification': (model_c, train_c),
        'segmentation': (model_s, train_s),
    }
    get_model, train = tasks[task]
    model = get_model(params)
    print()
    print(model.count_params())
    train(params, model)


if __name__ == '__main__':
    runner()
-------------------------------------------------------------------------------- /scripts/complex_seg_train_loss.txt: -------------------------------------------------------------------------------- 1 | -1.316444612542788106e-01 2 | -2.142477285861968950e-01 3 | -3.449339691797892460e-01 4 | -4.444746887683868608e-01 5 | -4.959106159210204812e-01 6 | -5.088644226392110559e-01 7 | -5.479192670186360248e-01 8 | -5.636379094918568811e-01 9 | -5.997866884867349979e-01 10 | -5.754718605677286991e-01 11 | -5.981876031557719076e-01 12 | -5.951076332728068019e-01 13 | -5.992043916384378610e-01 14 | -5.799210000038147461e-01 15 | -6.214295339584350319e-01 16 | -5.925392444928486713e-01 17 | -6.028120040893554243e-01 18 | -6.280264441172281886e-01 19 | -5.726023368040720207e-01 20 | -6.298162476221720318e-01 21 | -5.879226827621459872e-01 22 | -5.935890889167785289e-01 23 | -6.156738313039143984e-01 24 | -6.286487054824828613e-01 25 | -6.265241479873657759e-01 26 | -6.279180065790812559e-01 27 | -6.284716248512267622e-01 28 | -6.404337938626607496e-01 29 | -6.395575674374898156e-01 30 | -6.103426845868428208e-01 31 | -6.416524116198222272e-01 32 | -6.209528549512227658e-01 33 | -6.275362888971964814e-01 34 | -6.276567641894023053e-01 35 | -6.292877666155497485e-01 36 | -6.471755838394165350e-01 37 | -6.481765627861022505e-01 38 | -6.219868238766987911e-01 39 | -6.502263482411702045e-01 40 | -6.360168560345967892e-01 41 | -6.226005180676777995e-01 42 | -6.472029495239257635e-01 43 | -6.533526627222696614e-01 44 | -6.524537968635558682e-01 45 | -6.506009149551391646e-01 46 | -6.178070664405822532e-01 47 | -6.399952268600463956e-01 48 | -6.599463534355163308e-01 49 | -6.179832649230957431e-01 50 | -6.519788773854573449e-01 51 | -6.428709220886230646e-01 52 | -6.471774419148762503e-01 53 | -6.507618196805318167e-01 54 | -6.359595227241515714e-01 55 | -6.600956058502197576e-01 56 | -6.379603370030720866e-01 57 | -6.512233066558837935e-01 58 | -6.273402023315429732e-01 59 | 
-6.476984699567158721e-01 60 | -6.498802542686462624e-01 61 | -6.413920728365579693e-01 62 | -6.539314174652099965e-01 63 | -6.461370873451233177e-01 64 | -6.711377739906311257e-01 65 | -6.582990169525146928e-01 66 | -6.733598534266154179e-01 67 | -6.722828308741252146e-01 68 | -6.490377147992452134e-01 69 | -6.707902844746908011e-01 70 | -6.777409458160400524e-01 71 | -6.759284218152363799e-01 72 | -6.730472318331400805e-01 73 | -6.571633736292521455e-01 74 | -6.597010922431946200e-01 75 | -6.754051907857259174e-01 76 | -6.652164816856384055e-01 77 | -6.800930253664652092e-01 78 | -6.892747068405151722e-01 79 | -6.712222973505655998e-01 80 | -6.661599238713582505e-01 81 | -6.796074914932250799e-01 82 | -6.697797266642252501e-01 83 | -6.674935309092203672e-01 84 | -6.721749194463093602e-01 85 | -6.843930490811666045e-01 86 | -6.917367124557495250e-01 87 | -6.918470517794290675e-01 88 | -6.963750116030374970e-01 89 | -6.820187648137410186e-01 90 | -6.796289610862732422e-01 91 | -6.972544407844543102e-01 92 | -7.005644663174946940e-01 93 | -7.031388719876606741e-01 94 | -7.039087136586507087e-01 95 | -6.788132858276366699e-01 96 | -6.729428275426229122e-01 97 | -6.866804369290669952e-01 98 | -6.771849012374877574e-01 99 | -7.075428668657938180e-01 100 | -6.952029649416605617e-01 101 | -6.951919301350911740e-01 102 | -7.008601498603820312e-01 103 | -7.122010302543639870e-01 104 | -7.202417182922363770e-01 105 | -7.100151403745015299e-01 106 | -7.068377391497293605e-01 107 | -7.096370172500610085e-01 108 | -7.010213057200114228e-01 109 | -6.978131516774495058e-01 110 | -6.839487298329671727e-01 111 | -6.871273406346638524e-01 112 | -6.917988101641336929e-01 113 | -6.972371395428975660e-01 114 | -6.962807043393453021e-01 115 | -7.010528731346130638e-01 116 | -7.043692763646444144e-01 117 | -7.182009434700011852e-01 118 | -7.063903379440307218e-01 119 | -6.866708366076151870e-01 120 | -7.181469639142353811e-01 121 | -6.885582502683004202e-01 122 | 
-7.115436013539632176e-01 123 | -7.150081539154052868e-01 124 | -7.299395593007406147e-01 125 | -7.030381313959757072e-01 126 | -7.140191984176635920e-01 127 | -7.267610319455464207e-01 128 | -7.192029571533202770e-01 129 | -7.163842121760050086e-01 130 | -7.040056546529134263e-01 131 | -7.167679635683695816e-01 132 | -6.961467663447061804e-01 133 | -6.934134014447530525e-01 134 | -7.013730518023173488e-01 135 | -7.170191081364949426e-01 136 | -7.050384720166523733e-01 137 | -7.093344561258951853e-01 138 | -6.955530722935994170e-01 139 | -7.210782909393310680e-01 140 | -7.106213871637979684e-01 141 | -7.055767432848611964e-01 142 | -7.284316809972127826e-01 143 | -7.018150393168131052e-01 144 | -7.463651363054911725e-01 145 | -7.266874790191650835e-01 146 | -7.201062758763631111e-01 147 | -7.154552173614502353e-01 148 | -7.143126360575358103e-01 149 | -6.987678408622741699e-01 150 | -7.084999656677246627e-01 151 | -7.260527698198954161e-01 152 | -7.271991340319315933e-01 153 | -7.272102435429891321e-01 154 | -7.011373766263325757e-01 155 | -7.277880930900573198e-01 156 | -7.096912733713786059e-01 157 | -7.022521026929219312e-01 158 | -7.148062737782796550e-01 159 | -7.217832938830057721e-01 160 | -7.313847200075784816e-01 161 | -7.154544965426127279e-01 162 | -7.250189288457234982e-01 163 | -7.262214040756225231e-01 164 | -7.196998874346415276e-01 165 | -7.210014645258585597e-01 166 | -7.218033075332641602e-01 167 | -7.217670146624247662e-01 168 | -7.250616224606831750e-01 169 | -7.162174852689107141e-01 170 | -7.259954516092935739e-01 171 | -7.316199453671773378e-01 172 | -7.365749700864155924e-01 173 | -7.229057852427164255e-01 174 | -7.209534017244975113e-01 175 | -7.270612827936808742e-01 176 | -7.274021228154500252e-01 177 | -7.193847060203552024e-01 178 | -7.257357112566630297e-01 179 | -7.318844167391459488e-01 180 | -7.300043447812398778e-01 181 | -7.262370729446411266e-01 182 | -7.188072816530863829e-01 183 | -7.120067159334818596e-01 184 | 
-7.255073658625285082e-01 185 | -7.130653079350789403e-01 186 | -7.365742023785909476e-01 187 | -7.284064110120137547e-01 188 | -7.445549209912617483e-01 189 | -7.113369703292846680e-01 190 | -7.267719197273254661e-01 191 | -7.323107059796650775e-01 192 | -7.318611176808674790e-01 193 | -7.403582096099853294e-01 194 | -7.270653502146402491e-01 195 | -7.258614834149678119e-01 196 | -7.250793759028116847e-01 197 | -7.213846906026204486e-01 198 | -7.347040510177612616e-01 199 | -7.229578574498494170e-01 200 | -7.281377665201822502e-01 201 | -------------------------------------------------------------------------------- /scripts/complex_seg_val_loss.txt: -------------------------------------------------------------------------------- 1 | -1.461381685733795233e-01 2 | -2.754126262664794988e-01 3 | -3.947364592552184970e-01 4 | -4.720196175575256392e-01 5 | -5.349624395370483620e-01 6 | -5.559992599487304510e-01 7 | -5.838403248786926403e-01 8 | -5.929917311668395641e-01 9 | -5.793518424034118652e-01 10 | -6.080713725090026722e-01 11 | -6.231574630737304332e-01 12 | -6.182611274719238548e-01 13 | -5.437477803230286177e-01 14 | -6.192388105392455655e-01 15 | -5.924358725547790305e-01 16 | -6.137154054641723588e-01 17 | -6.205016231536865101e-01 18 | -6.259360313415527344e-01 19 | -6.070324611663818759e-01 20 | -6.358899211883545233e-01 21 | -6.289268827438354581e-01 22 | -6.440449643135071067e-01 23 | -5.982728934288025213e-01 24 | -6.178746891021728027e-01 25 | -6.424169325828552601e-01 26 | -6.141177415847778320e-01 27 | -6.286828446388245117e-01 28 | -6.427838540077209784e-01 29 | -6.496171498298645153e-01 30 | -6.377745151519775835e-01 31 | -6.328751182556152655e-01 32 | -6.360844993591308061e-01 33 | -6.365111780166625710e-01 34 | -6.523978972434997248e-01 35 | -6.540223693847656561e-01 36 | -6.473586225509643244e-01 37 | -6.467662644386291682e-01 38 | -6.170506834983825462e-01 39 | -6.372461652755737616e-01 40 | -6.418756508827209162e-01 41 | 
-6.665806579589843572e-01 42 | -6.584801340103149103e-01 43 | -6.569112873077392889e-01 44 | -6.568776464462280362e-01 45 | -6.396222281455993253e-01 46 | -6.671596837043761719e-01 47 | -6.655937218666077015e-01 48 | -6.509510278701782227e-01 49 | -6.659518933296203480e-01 50 | -6.760296106338501421e-01 51 | -6.612827134132385209e-01 52 | -6.390809535980224165e-01 53 | -6.779685544967651856e-01 54 | -6.824875736236571955e-01 55 | -6.072277164459228826e-01 56 | -6.662547850608825373e-01 57 | -6.743702983856201261e-01 58 | -6.730909013748168634e-01 59 | -6.331337451934814231e-01 60 | -6.651067900657653409e-01 61 | -6.653561639785766646e-01 62 | -6.844003510475158425e-01 63 | -6.795467257499694380e-01 64 | -6.725332522392273304e-01 65 | -6.684042572975158469e-01 66 | -6.921418404579162686e-01 67 | -6.837698745727539551e-01 68 | -6.662281012535095082e-01 69 | -6.809846305847168102e-01 70 | -6.845344066619872825e-01 71 | -6.945492982864379661e-01 72 | -6.926962804794311035e-01 73 | -6.901262378692627486e-01 74 | -6.896827673912048207e-01 75 | -6.696479964256286666e-01 76 | -7.043341326713562500e-01 77 | -6.763520526885986595e-01 78 | -6.983466196060180708e-01 79 | -6.947831082344054776e-01 80 | -6.666097712516784179e-01 81 | -6.950817489624023793e-01 82 | -7.015632629394531472e-01 83 | -6.948651480674743475e-01 84 | -7.074588179588318093e-01 85 | -6.924183487892150435e-01 86 | -7.062314748764038086e-01 87 | -6.968796181678772461e-01 88 | -6.908558773994445401e-01 89 | -7.206291913986205833e-01 90 | -6.587990617752075284e-01 91 | -7.206112241744995206e-01 92 | -6.944504141807555930e-01 93 | -7.064751005172729137e-01 94 | -7.004324650764465199e-01 95 | -6.796376943588257058e-01 96 | -7.183857417106628285e-01 97 | -6.983022689819335938e-01 98 | -6.896026921272278010e-01 99 | -6.989464330673217374e-01 100 | -7.036971139907837181e-01 101 | -6.971960067749023438e-01 102 | -6.925298523902893022e-01 103 | -7.128389906883240146e-01 104 | -7.104511189460754883e-01 105 | 
-6.945311808586120961e-01 106 | -7.225813698768616122e-01 107 | -7.155652332305908470e-01 108 | -7.254408788681030451e-01 109 | -7.057176995277404385e-01 110 | -7.044340229034423695e-01 111 | -6.995967936515807839e-01 112 | -7.075223040580749245e-01 113 | -7.189946842193603027e-01 114 | -7.206727576255798073e-01 115 | -7.253643798828125488e-01 116 | -7.057449054718017978e-01 117 | -7.202923583984375266e-01 118 | -7.178186106681824219e-01 119 | -7.196004271507263628e-01 120 | -7.195854067802429643e-01 121 | -7.239996242523193404e-01 122 | -7.276273488998412642e-01 123 | -7.292190122604370606e-01 124 | -7.260328674316406383e-01 125 | -6.921277236938476740e-01 126 | -7.110942411422729315e-01 127 | -7.327606058120728072e-01 128 | -7.231938529014587669e-01 129 | -6.461170220375060946e-01 130 | -7.312947344779968661e-01 131 | -7.251522660255431685e-01 132 | -7.244056367874145419e-01 133 | -7.102389097213744895e-01 134 | -7.288829469680786488e-01 135 | -7.192525625228881614e-01 136 | -7.227883315086365279e-01 137 | -7.361877298355102850e-01 138 | -7.367080187797546698e-01 139 | -7.335247445106506392e-01 140 | -7.317865514755248491e-01 141 | -7.142522120475769398e-01 142 | -7.387028408050536621e-01 143 | -7.358820533752441495e-01 144 | -7.320012331008911577e-01 145 | -7.371281719207763983e-01 146 | -7.171958136558532759e-01 147 | -7.297581529617309659e-01 148 | -7.250421214103698331e-01 149 | -7.397181272506714311e-01 150 | -7.249681854248046786e-01 151 | -7.358610105514525879e-01 152 | -7.203029847145080433e-01 153 | -7.191020870208739701e-01 154 | -7.159489250183105336e-01 155 | -7.228564643859862748e-01 156 | -7.275077605247497470e-01 157 | -7.250951075553894176e-01 158 | -7.111243367195129172e-01 159 | -7.338865232467651767e-01 160 | -7.358608412742614346e-01 161 | -7.315016961097717818e-01 162 | -7.372112393379210982e-01 163 | -7.274785685539245650e-01 164 | -7.352986288070678667e-01 165 | -7.308194732666015714e-01 166 | -7.359233689308166682e-01 167 | 
-7.346414327621459961e-01 168 | -7.114872241020202548e-01 169 | -7.215720462799072310e-01 170 | -7.430076026916503373e-01 171 | -7.419008588790894088e-01 172 | -7.316324090957642134e-01 173 | -7.445248031616210627e-01 174 | -7.474663519859313654e-01 175 | -7.400720596313477007e-01 176 | -7.201358580589294345e-01 177 | -7.339479446411132368e-01 178 | -7.359149622917174893e-01 179 | -7.348461651802062899e-01 180 | -7.436160254478454412e-01 181 | -7.434402179718017090e-01 182 | -7.403450417518615545e-01 183 | -7.340902757644652832e-01 184 | -7.270309090614318626e-01 185 | -7.475925946235656427e-01 186 | -7.465321373939514116e-01 187 | -7.396259570121764826e-01 188 | -7.426653909683227361e-01 189 | -7.425438570976257147e-01 190 | -7.373210000991821111e-01 191 | -7.305530261993408159e-01 192 | -7.455754113197327060e-01 193 | -7.411987304687499556e-01 194 | -7.473372793197632369e-01 195 | -7.454401445388794123e-01 196 | -7.388975310325622825e-01 197 | -7.413851404190063832e-01 198 | -7.214165210723877397e-01 199 | -7.276318216323852672e-01 200 | -7.424168872833252220e-01 201 | -------------------------------------------------------------------------------- /scripts/complex_weights.hd5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/scripts/complex_weights.hd5 -------------------------------------------------------------------------------- /scripts/plot_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | r = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/real_seg_train_loss.txt') 6 | c = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/complex_seg_train_loss.txt') 7 | q = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/quaternion_seg_train_loss.txt') 8 | 9 | plt.plot(r, c='g') 10 | plt.plot(c, c='b') 11 | 
plt.plot(q, c='r') 12 | 13 | r = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/real_seg_val_loss.txt') 14 | c = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/complex_seg_val_loss.txt') 15 | q = np.genfromtxt('d:/Projects/DeepQuaternionNets/scripts/quaternion_seg_val_loss.txt') 16 | print(min(r)) 17 | print(min(c)) 18 | print(min(q)) 19 | 20 | plt.plot(r, '--', c='g') 21 | plt.plot(c, '--', c='b') 22 | plt.plot(q, '--', c='r') 23 | 24 | plt.title("Kitti Segmentation Loss Plot") 25 | plt.xlabel("Epochs") 26 | plt.ylabel("Loss") 27 | plt.show() -------------------------------------------------------------------------------- /scripts/quaternion_seg_train_loss.txt: -------------------------------------------------------------------------------- 1 | -2.210962605476379517e-01 2 | -2.233758552869160863e-01 3 | -2.166332970062891661e-01 4 | -2.311851676305134995e-01 5 | -2.419869061311085978e-01 6 | -2.245498585700988758e-01 7 | -2.549727010726928755e-01 8 | -2.621815029780070194e-01 9 | -3.367318256696065215e-01 10 | -4.544818548361460597e-01 11 | -5.288980142275492291e-01 12 | -5.900729099909464148e-01 13 | -6.071122097969054998e-01 14 | -6.576433515548706366e-01 15 | -6.700962344805400095e-01 16 | -6.604599205652872396e-01 17 | -6.791509373982747011e-01 18 | -6.603041052818298118e-01 19 | -6.818428866068522165e-01 20 | -6.647719987233480232e-01 21 | -6.362142022450765211e-01 22 | -6.930882930755615234e-01 23 | -6.538656278451283610e-01 24 | -6.751464343070984198e-01 25 | -7.023571364084879587e-01 26 | -6.831031894683837757e-01 27 | -6.941580049196879321e-01 28 | -6.780863229433695683e-01 29 | -6.705188616116841693e-01 30 | -7.045248524347941244e-01 31 | -7.004633863766988044e-01 32 | -7.161026692390441983e-01 33 | -7.069758947690327755e-01 34 | -6.907466944058736624e-01 35 | -6.671345694859822650e-01 36 | -6.427904740969340169e-01 37 | -7.003908284505208082e-01 38 | -6.952540183067321689e-01 39 | -6.949667994181315533e-01 40 | -7.140659419695536414e-01 41 
| -7.584248121579487689e-01 42 | -6.865922292073567412e-01 43 | -7.104235363006591308e-01 44 | -6.973271179199218572e-01 45 | -7.450848603248596547e-01 46 | -7.074235971768697562e-01 47 | -7.135244115193685177e-01 48 | -6.869020779927571541e-01 49 | -7.083198221524557026e-01 50 | -6.930143562952677527e-01 51 | -7.132703232765197354e-01 52 | -7.052889863650003699e-01 53 | -7.018871903419494629e-01 54 | -7.323981857299804998e-01 55 | -7.166639033953349225e-01 56 | -7.213183339436849417e-01 57 | -6.969462545712789003e-01 58 | -7.515413705507913988e-01 59 | -7.125568405787150450e-01 60 | -7.039222725232442412e-01 61 | -6.895993129412333333e-01 62 | -7.097756528854369806e-01 63 | -7.240562653541564586e-01 64 | -7.114666509628295721e-01 65 | -7.257175898551940785e-01 66 | -7.285569397608439024e-01 67 | -7.388322194417317190e-01 68 | -7.083978732426960967e-01 69 | -7.091550000508626272e-01 70 | -7.003294849395751420e-01 71 | -7.075128110249837565e-01 72 | -7.200009457270304392e-01 73 | -7.196713940302530421e-01 74 | -7.081926782925923591e-01 75 | -6.974529735247294582e-01 76 | -7.138437851270039713e-01 77 | -7.314249030749002678e-01 78 | -7.056641006469726030e-01 79 | -7.044917829831440770e-01 80 | -7.117012778917948257e-01 81 | -7.294589440027873239e-01 82 | -7.181863363583882442e-01 83 | -7.462839810053507605e-01 84 | -7.149203149477640906e-01 85 | -7.134806927045186375e-01 86 | -7.139801470438639219e-01 87 | -7.320807727177938151e-01 88 | -7.217690308888753403e-01 89 | -7.281682586669921964e-01 90 | -7.331942383448283307e-01 91 | -7.063894542058308801e-01 92 | -7.295375569661458615e-01 93 | -7.381533042589822902e-01 94 | -7.539917699495951586e-01 95 | -7.532496468226115294e-01 96 | -7.317548863093058520e-01 97 | -7.381191261609395537e-01 98 | -6.984258683522542865e-01 99 | -7.104344836870829516e-01 100 | -7.436220900217691554e-01 101 | -7.160330112775167288e-01 102 | -7.141154917081197206e-01 103 | -7.419718631108601636e-01 104 | -7.364107966423034668e-01 105 | 
-7.573796017964681271e-01 106 | -7.202972984313964488e-01 107 | -7.001168735822042022e-01 108 | -7.371897490819295795e-01 109 | -7.174032537142435162e-01 110 | -7.388535857200622115e-01 111 | -7.163667297363280895e-01 112 | -7.502346030871073213e-01 113 | -7.259620436032613311e-01 114 | -7.369739397366841249e-01 115 | -7.006660326321919596e-01 116 | -7.253900170326232910e-01 117 | -7.659898241360981741e-01 118 | -7.118948165575663678e-01 119 | -7.343677910168965317e-01 120 | -7.419559216499328480e-01 121 | -7.628063178062438610e-01 122 | -7.490060424804687589e-01 123 | -7.391238228480021544e-01 124 | -7.289070685704549524e-01 125 | -7.520839365323385151e-01 126 | -7.563581681251525524e-01 127 | -7.447696884473165024e-01 128 | -7.375380261739095111e-01 129 | -7.494929822285970600e-01 130 | -7.553880516688028512e-01 131 | -7.369872832298278720e-01 132 | -7.502279527982076424e-01 133 | -7.397765819231668649e-01 134 | -7.402728176116943892e-01 135 | -7.642692605654398674e-01 136 | -7.548838520050048517e-01 137 | -7.538332351048787627e-01 138 | -7.489019687970479566e-01 139 | -7.239911031723021972e-01 140 | -7.287813766797384130e-01 141 | -7.291173251469930117e-01 142 | -7.701757550239562988e-01 143 | -7.446107943852742217e-01 144 | -7.537238581975300722e-01 145 | -7.622298860549926447e-01 146 | -7.573139580090840761e-01 147 | -7.657498971621194972e-01 148 | -7.417957027753193655e-01 149 | -7.467275245984394960e-01 150 | -7.708066964149474787e-01 151 | -7.542292650540669552e-01 152 | -7.433906912803649458e-01 153 | -7.411453866958618297e-01 154 | -7.406248203913370354e-01 155 | -7.696298209826151204e-01 156 | -7.492570726076761556e-01 157 | -7.661705128351847804e-01 158 | -7.472316090265910260e-01 159 | -7.763949203491210538e-01 160 | -7.572306124369303504e-01 161 | -7.430765970547994481e-01 162 | -7.102789338429769117e-01 163 | -7.681571189562479418e-01 164 | -7.640059773127237941e-01 165 | -7.581763299306233383e-01 166 | -7.586235857009887562e-01 167 | 
-7.524496491750081173e-01 168 | -7.804384597142537405e-01 169 | -7.296914672851562145e-01 170 | -7.498949885368346724e-01 171 | -7.619346086184183298e-01 172 | -7.597652180989583615e-01 173 | -7.497602796554565519e-01 174 | -7.433770227432251465e-01 175 | -7.526148271560668679e-01 176 | -7.675181317329407005e-01 177 | -7.589963976542154489e-01 178 | -7.679107880592346280e-01 179 | -7.752526354789733842e-01 180 | -7.334603126843770582e-01 181 | -7.643346309661864790e-01 182 | -7.789428154627482392e-01 183 | -7.658903050422668501e-01 184 | -7.698625906308491640e-01 185 | -7.671705913543701572e-01 186 | -7.415080801645914654e-01 187 | -7.478060905138651515e-01 188 | -7.624707372983297038e-01 189 | -7.848402523994445490e-01 190 | -7.740849486986796091e-01 191 | -7.666910378138224180e-01 192 | -7.340811816851298133e-01 193 | -7.770356639226277951e-01 194 | -7.496495191256205004e-01 195 | -7.547866868972777921e-01 196 | -7.539351932207742912e-01 197 | -7.583150180180867617e-01 198 | -7.643956605593363873e-01 199 | -7.434091782569884899e-01 200 | -7.472510592142740382e-01 201 | -------------------------------------------------------------------------------- /scripts/quaternion_seg_val_loss.txt: -------------------------------------------------------------------------------- 1 | -2.240529578924179144e-01 2 | -2.262643581628799305e-01 3 | -2.288125544786453192e-01 4 | -2.323838192224502441e-01 5 | -2.379066467285156361e-01 6 | -2.449108302593231146e-01 7 | -2.624225437641143910e-01 8 | -3.047907972335815452e-01 9 | -4.518292927742004195e-01 10 | -5.646976089477538929e-01 11 | -5.892211747169494584e-01 12 | -6.871905803680420366e-01 13 | -7.094725394248961869e-01 14 | -7.141153573989867942e-01 15 | -7.414490747451781827e-01 16 | -7.468604063987731578e-01 17 | -7.565995812416076438e-01 18 | -7.360086297988891468e-01 19 | -7.545816135406494540e-01 20 | -7.397950053215026633e-01 21 | -7.647410440444946111e-01 22 | -7.567007803916930841e-01 23 | -7.500149846076965554e-01 24 | 
-7.499030590057372825e-01 25 | -7.563126397132873269e-01 26 | -7.629370713233947221e-01 27 | -7.629937243461608620e-01 28 | -7.758199071884155362e-01 29 | -7.511695146560668501e-01 30 | -7.618219804763793901e-01 31 | -7.670927333831787376e-01 32 | -7.875511860847472789e-01 33 | -7.721999335289001509e-01 34 | -7.853039741516113503e-01 35 | -6.166160011291503817e-01 36 | -7.423373222351073997e-01 37 | -7.832833671569824352e-01 38 | -5.746818733215331987e-01 39 | -7.707172060012816850e-01 40 | -7.559125351905823242e-01 41 | -7.731163048744201793e-01 42 | -7.777544975280761275e-01 43 | -7.805159091949462891e-01 44 | -7.840347695350646795e-01 45 | -7.617195296287536666e-01 46 | -7.943069148063659490e-01 47 | -7.056860661506653143e-01 48 | -7.210072541236877131e-01 49 | -7.790224671363830122e-01 50 | -7.696363878250122026e-01 51 | -7.972276186943054732e-01 52 | -7.556682658195496005e-01 53 | -7.719083786010741965e-01 54 | -7.163705658912659091e-01 55 | -6.644572472572326527e-01 56 | -7.611915063858032404e-01 57 | -7.750524187088012606e-01 58 | -7.861719226837158292e-01 59 | -7.712716245651245472e-01 60 | -7.834187006950378285e-01 61 | -7.845588707923889515e-01 62 | -7.502792572975158780e-01 63 | -7.788125824928283647e-01 64 | -7.913509035110473100e-01 65 | -7.913832163810730069e-01 66 | -7.125830888748169167e-01 67 | -7.530515956878661710e-01 68 | -7.473279881477356001e-01 69 | -8.024167776107787864e-01 70 | -7.869359326362609375e-01 71 | -7.882350277900695978e-01 72 | -7.561425447463989702e-01 73 | -7.901430225372314542e-01 74 | -7.941184878349304421e-01 75 | -7.618995571136474299e-01 76 | -7.990320491790771085e-01 77 | -7.596886301040649547e-01 78 | -6.167420196533203169e-01 79 | -7.917839217185974610e-01 80 | -7.024280548095702681e-01 81 | -7.772456479072570312e-01 82 | -8.024617600440978604e-01 83 | -7.488399553298950018e-01 84 | -7.953676033020019576e-01 85 | -7.630772280693054466e-01 86 | -7.810859847068786221e-01 87 | -7.639414787292480025e-01 88 | 
-7.402051997184753152e-01 89 | -7.867299318313598633e-01 90 | -6.096427392959594682e-01 91 | -7.936555862426757368e-01 92 | -7.541410684585571067e-01 93 | -7.777841949462890314e-01 94 | -7.926068878173828658e-01 95 | -7.792396092414856090e-01 96 | -8.029613804817199441e-01 97 | -7.658870720863342196e-01 98 | -7.512613725662231623e-01 99 | -7.574707412719726696e-01 100 | -7.750924682617187367e-01 101 | -7.570396661758422852e-01 102 | -7.928188967704773438e-01 103 | -7.940559101104736062e-01 104 | -7.862112474441528276e-01 105 | -7.959046697616577459e-01 106 | -8.087959432601928844e-01 107 | -7.672859048843383434e-01 108 | -7.914940953254699929e-01 109 | -8.122492289543151722e-01 110 | -7.484164166450500977e-01 111 | -7.972816848754882280e-01 112 | -7.829720926284789995e-01 113 | -8.080445766448974831e-01 114 | -8.075188088417053489e-01 115 | -8.109839034080504883e-01 116 | -8.097206735610962269e-01 117 | -8.127936625480651767e-01 118 | -7.740312194824219283e-01 119 | -8.126434564590454546e-01 120 | -8.001952624320983753e-01 121 | -7.903132843971252264e-01 122 | -8.106230902671813787e-01 123 | -8.096190237998962536e-01 124 | -7.989002370834350497e-01 125 | -8.010275793075561923e-01 126 | -8.053160595893860130e-01 127 | -8.096684098243713823e-01 128 | -8.096193313598633035e-01 129 | -8.117827272415161222e-01 130 | -8.109795331954956499e-01 131 | -7.052100110054015936e-01 132 | -7.397786402702331010e-01 133 | -8.000564789772033336e-01 134 | -7.935999059677123491e-01 135 | -8.136927652359008389e-01 136 | -8.075754499435424671e-01 137 | -8.032962560653686968e-01 138 | -7.970140814781189187e-01 139 | -7.641871213912964089e-01 140 | -8.120873117446899547e-01 141 | -8.043765497207641113e-01 142 | -7.840998649597168413e-01 143 | -8.126603388786315385e-01 144 | -8.128614354133606446e-01 145 | -7.725287389755248757e-01 146 | -8.111465883255004616e-01 147 | -7.936134171485901279e-01 148 | -8.133054280281066806e-01 149 | -8.013473844528198331e-01 150 | -7.765507793426513761e-01 
151 | -8.136809706687927468e-01 152 | -8.167744445800780850e-01 153 | -7.857774329185486240e-01 154 | -8.078253650665283558e-01 155 | -7.986056065559387074e-01 156 | -8.049237442016601296e-01 157 | -7.901003241539001909e-01 158 | -8.093737316131591752e-01 159 | -7.999979138374329057e-01 160 | -7.885196948051452770e-01 161 | -8.014851331710814986e-01 162 | -8.178468012809753107e-01 163 | -8.111778497695922852e-01 164 | -7.827858448028564009e-01 165 | -8.155226397514343528e-01 166 | -7.868782114982605380e-01 167 | -8.124265789985656294e-01 168 | -8.177014923095703436e-01 169 | -8.152568244934081942e-01 170 | -7.993021440505981845e-01 171 | -8.123673939704895153e-01 172 | -8.014588689804077459e-01 173 | -7.859247279167175249e-01 174 | -8.090605545043945135e-01 175 | -8.276852226257324086e-01 176 | -7.707122945785522816e-01 177 | -7.795216727256775169e-01 178 | -8.189749670028686479e-01 179 | -8.109907126426696422e-01 180 | -8.062738442420959162e-01 181 | -8.139301085472107067e-01 182 | -8.255951309204101474e-01 183 | -8.050788235664367853e-01 184 | -7.813599872589110928e-01 185 | -8.092006993293762163e-01 186 | -8.234179186820983487e-01 187 | -8.202013230323791371e-01 188 | -7.969183850288391602e-01 189 | -8.207704305648804155e-01 190 | -8.152652668952942161e-01 191 | -8.006108617782592418e-01 192 | -8.266062116622925338e-01 193 | -8.186352872848510209e-01 194 | -8.250190997123718617e-01 195 | -8.007596421241760520e-01 196 | -7.997216629981994673e-01 197 | -8.274273872375488281e-01 198 | -8.168161940574646396e-01 199 | -7.978584241867064941e-01 200 | -7.988081336021423118e-01 201 | -------------------------------------------------------------------------------- /scripts/quaternion_weights.hd5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/scripts/quaternion_weights.hd5 
-------------------------------------------------------------------------------- /scripts/real_seg_train_loss.txt: -------------------------------------------------------------------------------- 1 | -2.959348098436991314e-01 2 | -4.955247183640797970e-01 3 | -5.733685199419656930e-01 4 | -5.666635791460673532e-01 5 | -5.820373940467834517e-01 6 | -5.726233855883280377e-01 7 | -5.954131793975829812e-01 8 | -5.958782005310058194e-01 9 | -6.179989043871562115e-01 10 | -6.000128857294718676e-01 11 | -6.202239696184793649e-01 12 | -6.183926280339558934e-01 13 | -6.131155014038085493e-01 14 | -5.919458349545796416e-01 15 | -6.248141463597615131e-01 16 | -6.125180308024088527e-01 17 | -6.261520965894062929e-01 18 | -6.369548074404398630e-01 19 | -5.938185675938923991e-01 20 | -6.430803712209065548e-01 21 | -6.154083037376403276e-01 22 | -6.108150760332743401e-01 23 | -6.208854301770527773e-01 24 | -6.466703383127848648e-01 25 | -6.411654591560363992e-01 26 | -6.448401983578999630e-01 27 | -6.475065747896829871e-01 28 | -6.587741470336914151e-01 29 | -6.496193846066792288e-01 30 | -6.425242789586385062e-01 31 | -6.629912233352661666e-01 32 | -6.316851353645325240e-01 33 | -6.444036221504211293e-01 34 | -6.457151309649149518e-01 35 | -6.584996207555134928e-01 36 | -6.565856695175170676e-01 37 | -6.603302566210428326e-01 38 | -6.532426428794860795e-01 39 | -6.626841092109679865e-01 40 | -6.666416827837625680e-01 41 | -6.564746205012003122e-01 42 | -6.700534717241922733e-01 43 | -6.657854668299356682e-01 44 | -6.773550009727478116e-01 45 | -6.628152060508728516e-01 46 | -6.321075455347696659e-01 47 | -6.619066476821899858e-01 48 | -6.733073608080546446e-01 49 | -6.471793826421101681e-01 50 | -6.675770131746927571e-01 51 | -6.688065759340922067e-01 52 | -6.709548266728718735e-01 53 | -6.749565863609313654e-01 54 | -6.701649610201517726e-01 55 | -6.756351089477539151e-01 56 | -6.722525143623352406e-01 57 | -6.602286060651143540e-01 58 | -6.566987689336141187e-01 59 | 
-6.688571055730183623e-01 60 | -6.778854695955912613e-01 61 | -6.665628695487976207e-01 62 | -6.706072227160135846e-01 63 | -6.608721645673115708e-01 64 | -6.994011354446411088e-01 65 | -6.944684966405232363e-01 66 | -7.008503492673238133e-01 67 | -6.928470142682393584e-01 68 | -6.755962864557901737e-01 69 | -7.067713888486226725e-01 70 | -6.994002167383829294e-01 71 | -7.001434715588887858e-01 72 | -6.993697055180867483e-01 73 | -6.963288410504658854e-01 74 | -6.912881199518839948e-01 75 | -6.998349523544311612e-01 76 | -6.934567141532897994e-01 77 | -7.073820416132609035e-01 78 | -7.016189813613891157e-01 79 | -6.830039540926615693e-01 80 | -6.962775389353433875e-01 81 | -7.007660754521687352e-01 82 | -6.906678684552510772e-01 83 | -6.907519642512003344e-01 84 | -7.027443480491638583e-01 85 | -7.119829154014587491e-01 86 | -7.175606075922648097e-01 87 | -7.164623061815897920e-01 88 | -7.205971638361613474e-01 89 | -7.119654067357381599e-01 90 | -7.090898887316385846e-01 91 | -7.268584402402241684e-01 92 | -7.242086688677470407e-01 93 | -7.132574836413065889e-01 94 | -7.312013157208760994e-01 95 | -6.959148263931274059e-01 96 | -7.020681770642598352e-01 97 | -7.265542554855346502e-01 98 | -6.972573518753051314e-01 99 | -7.269949118296304968e-01 100 | -7.183194319407145700e-01 101 | -7.208761127789815593e-01 102 | -7.273418060938516838e-01 103 | -7.374287430445353175e-01 104 | -7.378353810310364080e-01 105 | -7.348496063550313817e-01 106 | -7.322956538200378063e-01 107 | -7.325818888346353752e-01 108 | -7.270671534538268865e-01 109 | -7.239332111676534121e-01 110 | -7.108332141240437663e-01 111 | -7.128904732068379602e-01 112 | -7.088620392481486121e-01 113 | -7.095807298024495724e-01 114 | -7.082957514127095067e-01 115 | -7.153621419270833615e-01 116 | -7.244481635093689142e-01 117 | -7.386618582407633582e-01 118 | -7.312042140960692826e-01 119 | -7.113063311576843795e-01 120 | -7.337549153963724979e-01 121 | -6.970552515983581277e-01 122 | 
-7.394967961311340598e-01 123 | -7.399484427769978945e-01 124 | -7.467218716939290069e-01 125 | -7.234985955556233472e-01 126 | -7.396000997225443685e-01 127 | -7.478599413235982318e-01 128 | -7.390732995669047067e-01 129 | -7.253084309895833082e-01 130 | -7.378735637664795011e-01 131 | -7.399956703186034712e-01 132 | -7.271890083948771011e-01 133 | -7.124831438064574973e-01 134 | -7.323803766568501850e-01 135 | -7.413973402976989924e-01 136 | -7.357390626271566036e-01 137 | -7.313564936319987275e-01 138 | -7.221279740333557129e-01 139 | -7.444977347056070949e-01 140 | -7.493731657663981194e-01 141 | -7.397181495030721310e-01 142 | -7.544502822558084576e-01 143 | -7.359531227747598825e-01 144 | -7.651097130775451616e-01 145 | -7.418463142712911074e-01 146 | -7.339362621307372603e-01 147 | -7.368064451217651856e-01 148 | -7.404483580589293901e-01 149 | -7.269970949490864864e-01 150 | -7.393708984057109079e-01 151 | -7.409558582305908470e-01 152 | -7.486385528246561893e-01 153 | -7.371239964167276559e-01 154 | -7.132590913772582919e-01 155 | -7.526524392763773719e-01 156 | -7.355866694450378107e-01 157 | -7.305958286921183653e-01 158 | -7.422937742869059052e-01 159 | -7.467340310414631865e-01 160 | -7.486980048815409239e-01 161 | -7.355762370427449959e-01 162 | -7.399656025568643880e-01 163 | -7.481936415036519739e-01 164 | -7.519494644800821526e-01 165 | -7.518146864573160837e-01 166 | -7.439274493853250680e-01 167 | -7.552961405118306937e-01 168 | -7.538083092371622440e-01 169 | -7.527331693967183046e-01 170 | -7.578772314389546905e-01 171 | -7.571828118960062248e-01 172 | -7.594695329666137917e-01 173 | -7.468008041381836160e-01 174 | -7.523836175600687737e-01 175 | -7.504926848411560547e-01 176 | -7.496330769856770493e-01 177 | -7.499448879559834635e-01 178 | -7.526632555325826202e-01 179 | -7.523740776379903572e-01 180 | -7.513967800140380904e-01 181 | -7.602710556983948154e-01 182 | -7.435343837738037642e-01 183 | -7.478741280237833911e-01 184 | 
-7.517301273345947665e-01 185 | -7.477754727999369466e-01 186 | -7.700378902753194366e-01 187 | -7.492882601420084443e-01 188 | -7.693681001663208452e-01 189 | -7.423701739311218128e-01 190 | -7.490922792752583437e-01 191 | -7.618717304865518791e-01 192 | -7.605762966473896824e-01 193 | -7.637581658363342685e-01 194 | -7.541938018798828303e-01 195 | -7.487571207682292007e-01 196 | -7.450183931986490427e-01 197 | -7.556896940867106061e-01 198 | -7.616882467269897594e-01 199 | -7.504704777399698878e-01 200 | -7.638806390762329368e-01 201 | -------------------------------------------------------------------------------- /scripts/real_seg_val_loss.txt: -------------------------------------------------------------------------------- 1 | -4.935189604759216420e-01 2 | -5.304528450965881881e-01 3 | -5.711652851104735973e-01 4 | -6.028618836402892533e-01 5 | -6.227564215660095215e-01 6 | -6.118804574012756126e-01 7 | -5.178080034255981845e-01 8 | -5.994946718215942161e-01 9 | -6.317859649658202903e-01 10 | -6.236528348922729670e-01 11 | -6.357336258888244274e-01 12 | -6.271789312362671120e-01 13 | -5.713862919807434215e-01 14 | -6.401888227462768644e-01 15 | -6.104484081268310991e-01 16 | -5.746538138389587491e-01 17 | -6.430161738395691051e-01 18 | -6.332989025115967063e-01 19 | -6.187552833557128373e-01 20 | -6.340448808670043901e-01 21 | -6.111176681518554199e-01 22 | -6.289192008972167569e-01 23 | -6.251671004295349388e-01 24 | -6.237866306304931108e-01 25 | -6.555559635162353516e-01 26 | -6.157449746131896662e-01 27 | -6.591001415252685236e-01 28 | -6.363238954544067072e-01 29 | -6.676479864120483665e-01 30 | -6.570162081718444735e-01 31 | -6.708783745765686257e-01 32 | -6.410228943824768377e-01 33 | -6.405593228340148437e-01 34 | -5.899071240425109997e-01 35 | -6.526860857009887384e-01 36 | -6.391472005844116566e-01 37 | -6.797835493087768244e-01 38 | -6.465482187271117676e-01 39 | -6.310751032829284846e-01 40 | -6.376750922203063832e-01 41 | -6.370572495460510298e-01 
42 | -5.715924835205078436e-01 43 | -6.675188684463501110e-01 44 | -6.839575862884521573e-01 45 | -6.840340828895569159e-01 46 | -6.666822767257690741e-01 47 | -6.529188132286072310e-01 48 | -6.634336853027343883e-01 49 | -6.830325460433960272e-01 50 | -6.861995911598205433e-01 51 | -6.271282720565796165e-01 52 | -6.651475143432616921e-01 53 | -6.882941579818725453e-01 54 | -6.974884891510009233e-01 55 | -7.032516908645629616e-01 56 | -7.030426859855651855e-01 57 | -6.837853121757507813e-01 58 | -6.755815958976745250e-01 59 | -6.907841491699219238e-01 60 | -7.058698463439941895e-01 61 | -6.970159602165222346e-01 62 | -6.837571001052856756e-01 63 | -5.962529110908508789e-01 64 | -6.911874103546142845e-01 65 | -7.122125411033630726e-01 66 | -7.083830332756042569e-01 67 | -4.821870470046997204e-01 68 | -7.090314173698425426e-01 69 | -7.041974496841431064e-01 70 | -7.165817499160767046e-01 71 | -7.095936465263367188e-01 72 | -6.916607999801636097e-01 73 | -7.010343408584595037e-01 74 | -6.811671447753906428e-01 75 | -7.164652776718140048e-01 76 | -7.008011651039123269e-01 77 | -7.135806250572204190e-01 78 | -6.538971805572509899e-01 79 | -6.751359343528747781e-01 80 | -7.176841378211975542e-01 81 | -7.193181347846985263e-01 82 | -7.000443577766418679e-01 83 | -7.165840959548950284e-01 84 | -7.045882105827331321e-01 85 | -7.241465330123901811e-01 86 | -7.152210140228271174e-01 87 | -7.292654204368591575e-01 88 | -7.078354859352111728e-01 89 | -7.155569744110107377e-01 90 | -7.291074371337891158e-01 91 | -7.153387975692748757e-01 92 | -6.358115863800049006e-01 93 | -7.288111305236816273e-01 94 | -6.954174971580505682e-01 95 | -7.304040551185607688e-01 96 | -7.033975172042846724e-01 97 | -6.956838440895080078e-01 98 | -6.912333297729492676e-01 99 | -7.259335970878600941e-01 100 | -6.817279386520386231e-01 101 | -6.977855658531189054e-01 102 | -6.654562497138977184e-01 103 | -7.201931548118590820e-01 104 | -7.353887414932250843e-01 105 | -7.391995668411255327e-01 106 | 
-7.265998530387878240e-01 107 | -7.310887598991393510e-01 108 | -7.374259614944458363e-01 109 | -6.845976138114928666e-01 110 | -7.370288991928100053e-01 111 | -7.340129804611206232e-01 112 | -7.125291037559509100e-01 113 | -7.064336633682251065e-01 114 | -7.223585796356201127e-01 115 | -7.305583977699279252e-01 116 | -7.334857416152954501e-01 117 | -7.322049403190612482e-01 118 | -7.297431612014770863e-01 119 | -7.288121080398559659e-01 120 | -7.276027917861938032e-01 121 | -7.361634492874145064e-01 122 | -7.429751086235046875e-01 123 | -7.464122176170349121e-01 124 | -7.455514121055603072e-01 125 | -7.385909485816956055e-01 126 | -7.321843290328979847e-01 127 | -7.493409085273742276e-01 128 | -7.128530311584472257e-01 129 | -7.368942427635193360e-01 130 | -7.256142592430114835e-01 131 | -7.406113076210022461e-01 132 | -7.319140815734863637e-01 133 | -7.503688716888428090e-01 134 | -7.330222344398498180e-01 135 | -7.452839112281799405e-01 136 | -7.376144194602965998e-01 137 | -7.525051736831664950e-01 138 | -7.230424547195434259e-01 139 | -7.520295858383179155e-01 140 | -7.293355512619018599e-01 141 | -7.512959074974060458e-01 142 | -7.480044293403625977e-01 143 | -6.951986980438232822e-01 144 | -7.503352832794188965e-01 145 | -7.537842702865600319e-01 146 | -7.456005311012268155e-01 147 | -7.344099950790404785e-01 148 | -7.469731068611145108e-01 149 | -7.489690852165221679e-01 150 | -7.331086683273315696e-01 151 | -6.828656792640686035e-01 152 | -7.436505961418151678e-01 153 | -7.180059742927551447e-01 154 | -7.348403906822205123e-01 155 | -7.397840380668639915e-01 156 | -7.487318563461303755e-01 157 | -7.405200934410095526e-01 158 | -6.937026882171630771e-01 159 | -7.521141958236694514e-01 160 | -7.546514415740966930e-01 161 | -7.594753503799438477e-01 162 | -7.481293654441834029e-01 163 | -6.776519346237183106e-01 164 | -7.251050782203674050e-01 165 | -7.581157827377319469e-01 166 | -7.618108463287353027e-01 167 | -7.569805479049682706e-01 168 | 
-7.592968225479126421e-01 169 | -7.438817167282104625e-01 170 | -7.449821853637694780e-01 171 | -7.535343027114868031e-01 172 | -7.625734257698059126e-01 173 | -7.584504365921020952e-01 174 | -7.436173343658447621e-01 175 | -7.479614067077636763e-01 176 | -7.344317674636841042e-01 177 | -7.346582889556885210e-01 178 | -7.618546414375305664e-01 179 | -7.689313435554504528e-01 180 | -7.623344254493713557e-01 181 | -7.579532122611999378e-01 182 | -7.565146398544311257e-01 183 | -7.515179419517517001e-01 184 | -7.607282495498657315e-01 185 | -7.666039609909057750e-01 186 | -7.514293932914734020e-01 187 | -7.424433588981628196e-01 188 | -7.160605978965759455e-01 189 | -7.578412413597106490e-01 190 | -7.635660099983215821e-01 191 | -7.621275186538696289e-01 192 | -7.476811671257018732e-01 193 | -7.574582457542419212e-01 194 | -7.283831214904785689e-01 195 | -7.673563861846923739e-01 196 | -7.691823983192443981e-01 197 | -7.527939701080321733e-01 198 | -7.364590477943420144e-01 199 | -7.560916757583617809e-01 200 | -7.658571457862853871e-01 201 | -------------------------------------------------------------------------------- /scripts/real_weights.hd5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaudetcj/DeepQuaternionNetworks/43b321e1701287ce9cf9af1eb16457bdd2c85175/scripts/real_weights.hd5 -------------------------------------------------------------------------------- /scripts/run_small_segmentation_experiments.bat: -------------------------------------------------------------------------------- 1 | call small_segmentation_real.bat 2 | call small_segmentation_complex.bat 3 | call small_segmentation_quaternion.bat 4 | 5 | pause -------------------------------------------------------------------------------- /scripts/small_segmentation_complex.bat: -------------------------------------------------------------------------------- 1 | python ../runner.py "segmentation" --mode "complex" -sf 16 -init 
"complex_independent" -------------------------------------------------------------------------------- /scripts/small_segmentation_quaternion.bat: -------------------------------------------------------------------------------- 1 | python ../runner.py "segmentation" --mode "quaternion" -------------------------------------------------------------------------------- /scripts/small_segmentation_real.bat: -------------------------------------------------------------------------------- 1 | python ../runner.py "segmentation" --mode "real" -sf 32 -init "he_normal" -------------------------------------------------------------------------------- /training_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Authors: Chase Gaudet 4 | # code based on work by Chiheb Trabelsi 5 | # on Deep Complex Networks git source 6 | 7 | # Imports 8 | import sys 9 | sys.setrecursionlimit(10000) 10 | import logging as L 11 | import numpy as np 12 | from complex_layers.utils import GetReal, GetImag 13 | from complex_layers.conv import ComplexConv2D 14 | from complex_layers.bn import ComplexBatchNormalization 15 | from quaternion_layers.utils import Params, GetR, GetI, GetJ, GetK 16 | from quaternion_layers.conv import QuaternionConv2D 17 | from quaternion_layers.bn import QuaternionBatchNormalization 18 | import keras 19 | from keras.callbacks import Callback, ModelCheckpoint, LearningRateScheduler 20 | from keras.datasets import cifar10, cifar100 21 | from keras.layers import Layer, AveragePooling2D, AveragePooling3D, add, Add, concatenate, Concatenate, Input, Flatten, Dense, Convolution2D, BatchNormalization, Activation, Reshape, ConvLSTM2D, Conv2D 22 | from keras.models import Model, load_model, save_model 23 | from keras.optimizers import SGD 24 | from keras.preprocessing.image import ImageDataGenerator 25 | from keras.regularizers import l2 26 | from keras.utils.np_utils import 
to_categorical 27 | import keras.backend as K 28 | K.set_image_data_format('channels_first') 29 | K.set_image_dim_ordering('th') 30 | 31 | 32 | # Callbacks: 33 | # Print a newline after each epoch. 34 | class PrintNewlineAfterEpochCallback(Callback): 35 | def on_epoch_end(self, epoch, logs={}): 36 | sys.stdout.write("\n") 37 | 38 | # Also evaluate performance on test set at each epoch end. 39 | class TestErrorCallback(Callback): 40 | def __init__(self, test_data): 41 | self.test_data = test_data 42 | self.loss_history = [] 43 | self.acc_history = [] 44 | 45 | def on_epoch_end(self, epoch, logs={}): 46 | x, y = self.test_data 47 | 48 | L.getLogger("train").info("Epoch {:5d} Evaluating on test set...".format(epoch+1)) 49 | test_loss, test_acc = self.model.evaluate(x, y, verbose=0) 50 | L.getLogger("train").info(" complete.") 51 | 52 | self.loss_history.append(test_loss) 53 | self.acc_history.append(test_acc) 54 | 55 | L.getLogger("train").info("Epoch {:5d} train_loss: {}, train_acc: {}, val_loss: {}, val_acc: {}, test_loss: {}, test_acc: {}".format( 56 | epoch+1, 57 | logs["loss"], logs["acc"], 58 | logs["val_loss"], logs["val_acc"], 59 | test_loss, test_acc)) 60 | 61 | # Keep a history of the validation performance. 62 | class TrainValHistory(Callback): 63 | def __init__(self): 64 | self.train_loss = [] 65 | self.train_acc = [] 66 | self.val_loss = [] 67 | self.val_acc = [] 68 | 69 | def on_epoch_end(self, epoch, logs={}): 70 | self.train_loss.append(logs.get('loss')) 71 | self.train_acc .append(logs.get('acc')) 72 | self.val_loss .append(logs.get('val_loss')) 73 | self.val_acc .append(logs.get('val_acc')) 74 | 75 | 76 | class LrDivisor(Callback): 77 | def __init__(self, patience=float(50000), division_cst=10.0, epsilon=1e-03, verbose=1, epoch_checkpoints={41, 61}): 78 | super(Callback, self).__init__() 79 | self.patience = patience 80 | self.checkpoints = epoch_checkpoints 81 | self.wait = 0 82 | self.previous_score = 0. 
83 | self.division_cst = division_cst 84 | self.epsilon = epsilon 85 | self.verbose = verbose 86 | self.iterations = 0 87 | 88 | def on_batch_begin(self, batch, logs={}): 89 | self.iterations += 1 90 | 91 | def on_epoch_end(self, epoch, logs={}): 92 | current_score = logs.get('val_acc') 93 | divide = False 94 | if (epoch + 1) in self.checkpoints: 95 | divide = True 96 | elif (current_score >= self.previous_score - self.epsilon and current_score <= self.previous_score + self.epsilon): 97 | self.wait +=1 98 | if self.wait == self.patience: 99 | divide = True 100 | else: 101 | self.wait = 0 102 | if divide == True: 103 | K.set_value(self.model.optimizer.lr, self.model.optimizer.lr.get_value() / self.division_cst) 104 | self.wait = 0 105 | if self.verbose > 0: 106 | L.getLogger("train").info("Current learning rate is divided by"+str(self.division_cst) + ' and his values is equal to: ' + str(self.model.optimizer.lr.get_value())) 107 | self.previous_score = current_score 108 | 109 | 110 | def schedule(epoch): 111 | if epoch >= 0 and epoch < 10: 112 | lrate = 0.01 113 | if epoch == 0: 114 | L.getLogger("train").info("Current learning rate value is "+str(lrate)) 115 | elif epoch >= 10 and epoch < 100: 116 | lrate = 0.01 117 | if epoch == 10: 118 | L.getLogger("train").info("Current learning rate value is "+str(lrate)) 119 | elif epoch >= 100 and epoch < 120: 120 | lrate = 0.01 121 | if epoch == 100: 122 | L.getLogger("train").info("Current learning rate value is "+str(lrate)) 123 | elif epoch >= 120 and epoch < 150: 124 | lrate = 0.001 125 | if epoch == 120: 126 | L.getLogger("train").info("Current learning rate value is "+str(lrate)) 127 | elif epoch >= 150: 128 | lrate = 0.0001 129 | if epoch == 150: 130 | L.getLogger("train").info("Current learning rate value is "+str(lrate)) 131 | return lrate 132 | 133 | 134 | def learnVectorBlock(I, featmaps, filter_size, act, bnArgs): 135 | """Learn initial vector component for input.""" 136 | 137 | O = 
BatchNormalization(**bnArgs)(I) 138 | O = Activation(act)(O) 139 | O = Convolution2D(featmaps, filter_size, 140 | padding='same', 141 | kernel_initializer='he_normal', 142 | use_bias=False, 143 | kernel_regularizer=l2(0.0001))(O) 144 | 145 | O = BatchNormalization(**bnArgs)(O) 146 | O = Activation(act)(O) 147 | O = Convolution2D(featmaps, filter_size, 148 | padding='same', 149 | kernel_initializer='he_normal', 150 | use_bias=False, 151 | kernel_regularizer=l2(0.0001))(O) 152 | 153 | return O 154 | 155 | 156 | def getResidualBlock(I, mode, filter_size, featmaps, activation, shortcut, convArgs, bnArgs): 157 | """Get residual block.""" 158 | 159 | if mode == "real": 160 | O = BatchNormalization(**bnArgs)(I) 161 | elif mode == "complex": 162 | O = ComplexBatchNormalization(**bnArgs)(I) 163 | elif mode == "quaternion": 164 | O = QuaternionBatchNormalization(**bnArgs)(I) 165 | O = Activation(activation)(O) 166 | 167 | if shortcut == 'regular': 168 | if mode == "real": 169 | O = Conv2D(featmaps, filter_size, **convArgs)(O) 170 | elif mode == "complex": 171 | O = ComplexConv2D(featmaps, filter_size, **convArgs)(O) 172 | elif mode == "quaternion": 173 | O = QuaternionConv2D(featmaps, filter_size, **convArgs)(O) 174 | elif shortcut == 'projection': 175 | if mode == "real": 176 | O = Conv2D(featmaps, filter_size, strides=(2, 2), **convArgs)(O) 177 | elif mode == "complex": 178 | O = ComplexConv2D(featmaps, filter_size, strides=(2, 2), **convArgs)(O) 179 | elif mode == "quaternion": 180 | O = QuaternionConv2D(featmaps, filter_size, strides=(2, 2), **convArgs)(O) 181 | 182 | if mode == "real": 183 | O = BatchNormalization(**bnArgs)(O) 184 | O = Activation(activation)(O) 185 | O = Conv2D(featmaps, filter_size, **convArgs)(O) 186 | elif mode == "complex": 187 | O = ComplexBatchNormalization(**bnArgs)(O) 188 | O = Activation(activation)(O) 189 | O = ComplexConv2D(featmaps, filter_size, **convArgs)(O) 190 | elif mode == "quaternion": 191 | O = 
QuaternionBatchNormalization(**bnArgs)(O) 192 | O = Activation(activation)(O) 193 | O = QuaternionConv2D(featmaps, filter_size, **convArgs)(O) 194 | 195 | if shortcut == 'regular': 196 | O = Add()([O, I]) 197 | elif shortcut == 'projection': 198 | if mode == "real": 199 | X = Conv2D(featmaps, (1, 1), strides = (2, 2), **convArgs)(I) 200 | O = Concatenate(1)([X, O]) 201 | elif mode == "complex": 202 | X = ComplexConv2D(featmaps, (1, 1), strides = (2, 2), **convArgs)(I) 203 | O_real = Concatenate(1)([GetReal()(X), GetReal()(O)]) 204 | O_imag = Concatenate(1)([GetImag()(X), GetImag()(O)]) 205 | O = Concatenate(1)([O_real, O_imag]) 206 | elif mode == "quaternion": 207 | X = QuaternionConv2D(featmaps, (1, 1), strides = (2, 2), **convArgs)(I) 208 | O_r = Concatenate(1)([GetR()(X), GetR()(O)]) 209 | O_i = Concatenate(1)([GetI()(X), GetI()(O)]) 210 | O_j = Concatenate(1)([GetJ()(X), GetJ()(O)]) 211 | O_k = Concatenate(1)([GetK()(X), GetK()(O)]) 212 | O = Concatenate(1)([O_r, O_i, O_j, O_k]) 213 | 214 | return O 215 | 216 | 217 | def getModel(params): 218 | mode = params.mode 219 | n = params.num_blocks 220 | sf = params.start_filter 221 | dataset = params.dataset 222 | activation = params.act 223 | inputShape = (3, 32, 32) 224 | channelAxis = 1 225 | filsize = (3, 3) 226 | convArgs = { 227 | "padding": "same", 228 | "use_bias": False, 229 | "kernel_regularizer": l2(0.0001), 230 | } 231 | bnArgs = { 232 | "axis": channelAxis, 233 | "momentum": 0.9, 234 | "epsilon": 1e-04 235 | } 236 | 237 | convArgs.update({"kernel_initializer": params.init}) 238 | 239 | # Create the vector channels 240 | R = Input(shape=inputShape) 241 | 242 | if mode != "quaternion": 243 | I = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 244 | O = concatenate([R, I], axis=channelAxis) 245 | else: 246 | I = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 247 | J = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 248 | K = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 249 | O = concatenate([R, I, J, K], 
axis=channelAxis) 250 | 251 | if mode == "real": 252 | O = Conv2D(sf, filsize, **convArgs)(O) 253 | O = BatchNormalization(**bnArgs)(O) 254 | elif mode == "complex": 255 | O = ComplexConv2D(sf, filsize, **convArgs)(O) 256 | O = ComplexBatchNormalization(**bnArgs)(O) 257 | else: 258 | O = QuaternionConv2D(sf, filsize, **convArgs)(O) 259 | O = QuaternionBatchNormalization(**bnArgs)(O) 260 | O = Activation(activation)(O) 261 | 262 | for i in range(n): 263 | O = getResidualBlock(O, mode, filsize, sf, activation, 'regular', convArgs, bnArgs) 264 | 265 | O = getResidualBlock(O, mode, filsize, sf, activation, 'projection', convArgs, bnArgs) 266 | 267 | for i in range(n-1): 268 | O = getResidualBlock(O, mode, filsize, sf*2, activation, 'regular', convArgs, bnArgs) 269 | 270 | O = getResidualBlock(O, mode, filsize, sf*2, activation, 'projection', convArgs, bnArgs) 271 | 272 | for i in range(n-1): 273 | O = getResidualBlock(O, mode, filsize, sf*4, activation, 'regular', convArgs, bnArgs) 274 | 275 | O = AveragePooling2D(pool_size=(8, 8))(O) 276 | 277 | # Flatten 278 | O = Flatten()(O) 279 | 280 | # Dense 281 | if dataset == 'cifar10': 282 | O = Dense(10, activation='softmax', kernel_regularizer=l2(0.0001))(O) 283 | elif dataset == 'cifar100': 284 | O = Dense(100, activation='softmax', kernel_regularizer=l2(0.0001))(O) 285 | 286 | model = Model(R, O) 287 | opt = SGD (lr = params.lr, 288 | momentum = params.momentum, 289 | decay = params.decay, 290 | nesterov = True, 291 | clipnorm = params.clipnorm) 292 | model.compile(opt, 'categorical_crossentropy', metrics=['accuracy']) 293 | return model 294 | 295 | 296 | def train(params, model): 297 | if params.dataset == 'cifar10': 298 | (X_train, y_train), (X_test, y_test) = cifar10.load_data() 299 | nb_classes = 10 300 | n_train = 45000 301 | elif params.dataset == 'cifar100': 302 | (X_train, y_train), (X_test, y_test) = cifar100.load_data() 303 | nb_classes = 100 304 | n_train = 45000 305 | 306 | X_train = X_train.astype('float32') 
/ 255.0 307 | X_test = X_test.astype('float32') / 255.0 308 | 309 | shuf_inds = np.arange(len(y_train)) 310 | np.random.seed(424242) 311 | np.random.shuffle(shuf_inds) 312 | train_inds = shuf_inds[:n_train] 313 | val_inds = shuf_inds[n_train:] 314 | 315 | X_train = X_train.astype('float32') / 255.0 316 | X_test = X_test.astype('float32') / 255.0 317 | 318 | X_train_split = X_train[train_inds] 319 | X_val_split = X_train[val_inds] 320 | y_train_split = y_train[train_inds] 321 | y_val_split = y_train[val_inds] 322 | 323 | pixel_mean = np.mean(X_train_split, axis=0) 324 | 325 | X_train = X_train_split.astype(np.float32) - pixel_mean 326 | X_val = X_val_split.astype(np.float32) - pixel_mean 327 | X_test = X_test.astype(np.float32) - pixel_mean 328 | 329 | Y_train = to_categorical(y_train_split, nb_classes) 330 | Y_val = to_categorical(y_val_split, nb_classes) 331 | Y_test = to_categorical(y_test, nb_classes) 332 | 333 | datagen = ImageDataGenerator(height_shift_range=0.125, 334 | width_shift_range=0.125, 335 | horizontal_flip=True) 336 | 337 | testErrCb = TestErrorCallback((X_test, Y_test)) 338 | trainValHistCb = TrainValHistory() 339 | lrSchedCb = LearningRateScheduler(schedule) 340 | callbacks = [ModelCheckpoint('{}_weights.hd5'.format(params.mode), monitor='val_loss', verbose=0, save_best_only=True), 341 | testErrCb, 342 | lrSchedCb, 343 | trainValHistCb] 344 | 345 | model.fit_generator(generator=datagen.flow(X_train, Y_train, batch_size=params.batch_size), 346 | steps_per_epoch=(len(X_train)+params.batch_size-1) // params.batch_size, 347 | epochs=params.num_epochs, 348 | verbose=1, 349 | callbacks=callbacks, 350 | validation_data=(X_val, Y_val)) 351 | 352 | # Dump histories. 
353 | np.savetxt('{}_test_loss.txt'.format(params.mode), np.asarray(testErrCb.loss_history)) 354 | np.savetxt('{}_test_acc.txt'.format(params.mode), np.asarray(testErrCb.acc_history)) 355 | np.savetxt('{}_train_loss.txt'.format(params.mode), np.asarray(trainValHistCb.train_loss)) 356 | np.savetxt('{}_train_acc.txt'.format(params.mode), np.asarray(trainValHistCb.train_acc)) 357 | np.savetxt('{}_val_loss.txt'.format(params.mode), np.asarray(trainValHistCb.val_loss)) 358 | np.savetxt('{}_val_acc.txt'.format(params.mode), np.asarray(trainValHistCb.val_acc)) 359 | 360 | 361 | if __name__ == '__main__': 362 | param_dict = {"mode": "quaternion", 363 | "num_blocks": 10, 364 | "start_filter": 24, 365 | "dropout": 0, 366 | "batch_size": 32, 367 | "num_epochs": 200, 368 | "dataset": "cifar100", 369 | "act": "relu", 370 | "init": "quaternion", 371 | "lr": 1e-3, 372 | "momentum": 0.9, 373 | "decay": 0, 374 | "clipnorm": 1.0 375 | } 376 | 377 | params = Params(param_dict) 378 | model = getModel(params) 379 | train(params, model) -------------------------------------------------------------------------------- /training_mnist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras 3 | from keras.models import Model 4 | from keras.datasets import mnist 5 | from keras.models import Sequential 6 | from keras.layers import Dense, Dropout, Flatten, Input, Activation, Convolution2D, concatenate 7 | from quaternion_layers.dense import QuaternionDense 8 | from quaternion_layers.conv import QuaternionConv2D 9 | from quaternion_layers.bn import QuaternionBatchNormalization 10 | from keras.preprocessing.image import ImageDataGenerator 11 | from keras import backend as K 12 | 13 | batch_size = 16 14 | num_classes = 10 15 | epochs = 1200 16 | 17 | def learnVectorBlock(I): 18 | """Learn initial vector component for input.""" 19 | 20 | O = Convolution2D(1, (5, 5), 21 | padding='same', activation='relu')(I) 22 | 23 | return O 24 
| 25 | # input image dimensions 26 | img_rows, img_cols = 28, 28 27 | 28 | # the data, split between train and test sets 29 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 30 | 31 | if K.image_data_format() == 'channels_first': 32 | x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 33 | x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 34 | input_shape = (1, img_rows, img_cols) 35 | else: 36 | x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 37 | x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 38 | input_shape = (img_rows, img_cols, 1) 39 | 40 | x_train = x_train.astype('float32') 41 | x_test = x_test.astype('float32') 42 | print('x_train shape:', x_train.shape) 43 | print(x_train.shape[0], 'train samples') 44 | print(x_test.shape[0], 'test samples') 45 | 46 | datagen = ImageDataGenerator( 47 | featurewise_center=True, 48 | featurewise_std_normalization=True, 49 | width_shift_range=0.1, 50 | height_shift_range=0.1) 51 | 52 | datagen.fit(x_train) 53 | 54 | # convert class vectors to binary class matrices 55 | y_train = keras.utils.to_categorical(y_train, num_classes) 56 | y_test = keras.utils.to_categorical(y_test, num_classes) 57 | 58 | R = Input(shape=input_shape) 59 | I = learnVectorBlock(R) 60 | J = learnVectorBlock(R) 61 | K = learnVectorBlock(R) 62 | O = concatenate([R, I, J, K], axis=-1) 63 | O = QuaternionConv2D(64, (5, 5), activation='relu', padding="same", kernel_initializer='quaternion')(O) 64 | O = QuaternionConv2D(64, (5, 5), activation='relu', padding="same", kernel_initializer='quaternion')(O) 65 | O = QuaternionConv2D(32, (5, 5), activation='relu', padding="same", kernel_initializer='quaternion')(O) 66 | 67 | # O = Convolution2D(256, (5, 5), activation='relu', padding="same")(R) 68 | # O = Convolution2D(256, (5, 5), activation='relu', padding="same")(O) 69 | # O = Convolution2D(128, (5, 5), activation='relu', padding="same")(O) 70 | 71 | O = Flatten()(O) 72 | O = 
QuaternionDense(82, activation='relu', kernel_initializer='quaternion')(O) 73 | O = QuaternionDense(48, activation='relu', kernel_initializer='quaternion')(O) 74 | #O = Dropout(0.5)(O) 75 | O = Dense(num_classes, activation='softmax')(O) 76 | 77 | model = Model(R, O) 78 | model.compile(loss=keras.losses.categorical_crossentropy, 79 | optimizer=keras.optimizers.Adam(), 80 | metrics=['accuracy']) 81 | 82 | hist = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), 83 | steps_per_epoch=len(x_train) / batch_size, epochs=epochs, 84 | validation_data=(x_test, y_test)) 85 | 86 | np.save('mnist_results.npy', hist.history['val_acc']) -------------------------------------------------------------------------------- /training_segmentation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Authors: Chase Gaudet 4 | # code based on work by Chiheb Trabelsi 5 | # on Deep Complex Networks git source 6 | 7 | # Imports 8 | import sys 9 | sys.setrecursionlimit(10000) 10 | import logging as L 11 | import numpy as np 12 | from complex_layers.utils import GetReal, GetImag 13 | from complex_layers.conv import ComplexConv2D 14 | from complex_layers.bn import ComplexBatchNormalization 15 | from quaternion_layers.utils import Params, GetR, GetI, GetJ, GetK 16 | from quaternion_layers.conv import QuaternionConv2D 17 | from quaternion_layers.bn import QuaternionBatchNormalization 18 | from batch_gen import gen_batch 19 | import keras 20 | from keras.callbacks import Callback, ModelCheckpoint, LearningRateScheduler 21 | from keras.datasets import cifar10, cifar100 22 | from keras.layers import Layer, AveragePooling2D, AveragePooling3D, add, Add, concatenate, Concatenate, Input, Flatten, Dense, Convolution2D, BatchNormalization, Activation, Reshape, ConvLSTM2D, Conv2D 23 | from keras.models import Model, load_model, save_model 24 | from keras.optimizers import SGD 25 | from 
keras.preprocessing.image import ImageDataGenerator 26 | from keras.regularizers import l2 27 | from keras.utils.np_utils import to_categorical 28 | import keras.backend as K 29 | K.set_image_data_format('channels_first') 30 | K.set_image_dim_ordering('th') 31 | 32 | 33 | # Callbacks: 34 | # Print a newline after each epoch. 35 | class PrintNewlineAfterEpochCallback(Callback): 36 | def on_epoch_end(self, epoch, logs={}): 37 | sys.stdout.write("\n") 38 | 39 | 40 | # Keep a history of the validation performance. 41 | class TrainValHistory(Callback): 42 | def __init__(self): 43 | self.train_loss = [] 44 | self.val_loss = [] 45 | 46 | def on_epoch_end(self, epoch, logs={}): 47 | self.train_loss.append(logs.get('loss')) 48 | self.val_loss.append(logs.get('val_loss')) 49 | 50 | 51 | def schedule(epoch): 52 | if epoch >= 0 and epoch < 10: 53 | lrate = 0.01 54 | elif epoch >= 10 and epoch < 50: 55 | lrate = 0.1 56 | elif epoch >= 50 and epoch < 100: 57 | lrate = 0.01 58 | elif epoch >= 100 and epoch < 150: 59 | lrate = 0.001 60 | elif epoch >= 150: 61 | lrate = 0.0001 62 | return lrate 63 | 64 | 65 | def learnVectorBlock(I, featmaps, filter_size, act, bnArgs): 66 | """Learn initial vector component for input.""" 67 | 68 | O = BatchNormalization(**bnArgs)(I) 69 | O = Activation(act)(O) 70 | O = Convolution2D(featmaps, filter_size, 71 | padding='same', 72 | kernel_initializer='he_normal', 73 | use_bias=False, 74 | kernel_regularizer=l2(0.0001))(O) 75 | 76 | O = BatchNormalization(**bnArgs)(O) 77 | O = Activation(act)(O) 78 | O = Convolution2D(featmaps, filter_size, 79 | padding='same', 80 | kernel_initializer='he_normal', 81 | use_bias=False, 82 | kernel_regularizer=l2(0.0001))(O) 83 | 84 | return O 85 | 86 | 87 | def getResidualBlock(I, mode, filter_size, featmaps, activation, dropout, shortcut, convArgs, bnArgs): 88 | """Get residual block.""" 89 | 90 | if mode == "real": 91 | O = BatchNormalization(**bnArgs)(I) 92 | elif mode == "complex": 93 | O = 
ComplexBatchNormalization(**bnArgs)(I) 94 | elif mode == "quaternion": 95 | O = QuaternionBatchNormalization(**bnArgs)(I) 96 | O = Activation(activation)(O) 97 | 98 | if shortcut == 'regular': 99 | if mode == "real": 100 | O = Conv2D(featmaps, filter_size, **convArgs)(O) 101 | elif mode == "complex": 102 | O = ComplexConv2D(featmaps, filter_size, **convArgs)(O) 103 | elif mode == "quaternion": 104 | O = QuaternionConv2D(featmaps, filter_size, **convArgs)(O) 105 | elif shortcut == 'projection': 106 | if mode == "real": 107 | O = Conv2D(featmaps, filter_size, **convArgs)(O) 108 | elif mode == "complex": 109 | O = ComplexConv2D(featmaps, filter_size, **convArgs)(O) 110 | elif mode == "quaternion": 111 | O = QuaternionConv2D(featmaps, filter_size, **convArgs)(O) 112 | 113 | if mode == "real": 114 | O = BatchNormalization(**bnArgs)(O) 115 | O = Activation(activation)(O) 116 | O = Conv2D(featmaps, filter_size, **convArgs)(O) 117 | elif mode == "complex": 118 | O = ComplexBatchNormalization(**bnArgs)(O) 119 | O = Activation(activation)(O) 120 | O = ComplexConv2D(featmaps, filter_size, **convArgs)(O) 121 | elif mode == "quaternion": 122 | O = QuaternionBatchNormalization(**bnArgs)(O) 123 | O = Activation(activation)(O) 124 | O = QuaternionConv2D(featmaps, filter_size, **convArgs)(O) 125 | 126 | if shortcut == 'regular': 127 | O = Add()([O, I]) 128 | elif shortcut == 'projection': 129 | if mode == "real": 130 | X = Conv2D(featmaps, (1, 1), **convArgs)(I) 131 | O = Concatenate(1)([X, O]) 132 | elif mode == "complex": 133 | X = ComplexConv2D(featmaps, (1, 1), **convArgs)(I) 134 | O_real = Concatenate(1)([GetReal()(X), GetReal()(O)]) 135 | O_imag = Concatenate(1)([GetImag()(X), GetImag()(O)]) 136 | O = Concatenate(1)([O_real, O_imag]) 137 | elif mode == "quaternion": 138 | X = QuaternionConv2D(featmaps, (1, 1), **convArgs)(I) 139 | O_r = Concatenate(1)([GetR()(X), GetR()(O)]) 140 | O_i = Concatenate(1)([GetI()(X), GetI()(O)]) 141 | O_j = Concatenate(1)([GetJ()(X), GetJ()(O)]) 
142 | O_k = Concatenate(1)([GetK()(X), GetK()(O)]) 143 | O = Concatenate(1)([O_r, O_i, O_j, O_k]) 144 | 145 | return O 146 | 147 | def dice_coef(y_true, y_pred): 148 | y_true_f = K.flatten(y_true) 149 | y_pred_f = K.flatten(y_pred) 150 | intersection = K.sum(y_true_f * y_pred_f) 151 | return (2. * intersection + 1) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1) 152 | 153 | 154 | def dice_coef_loss(y_true, y_pred): 155 | return -dice_coef(y_true, y_pred) 156 | 157 | def getModel(params): 158 | mode = params.mode 159 | n = params.num_blocks 160 | sf = params.start_filter 161 | activation = params.act 162 | dropout = params.dropout 163 | inputShape = (3, 93, 310) 164 | channelAxis = 1 165 | filsize = (3, 3) 166 | convArgs = { 167 | "padding": "same", 168 | "use_bias": False, 169 | "kernel_regularizer": l2(0.0001), 170 | } 171 | bnArgs = { 172 | "axis": channelAxis, 173 | "momentum": 0.9, 174 | "epsilon": 1e-04, 175 | "scale": False 176 | } 177 | 178 | convArgs.update({"kernel_initializer": params.init}) 179 | 180 | # Create the vector channels 181 | R = Input(shape=inputShape) 182 | 183 | if mode != "quaternion": 184 | I = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 185 | O = concatenate([R, I], axis=channelAxis) 186 | else: 187 | I = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 188 | J = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 189 | K = learnVectorBlock(R, 3, filsize, 'relu', bnArgs) 190 | O = concatenate([R, I, J, K], axis=channelAxis) 191 | 192 | if mode == "real": 193 | O = Conv2D(sf, filsize, **convArgs)(O) 194 | O = BatchNormalization(**bnArgs)(O) 195 | elif mode == "complex": 196 | O = ComplexConv2D(sf, filsize, **convArgs)(O) 197 | O = ComplexBatchNormalization(**bnArgs)(O) 198 | else: 199 | O = QuaternionConv2D(sf, filsize, **convArgs)(O) 200 | O = QuaternionBatchNormalization(**bnArgs)(O) 201 | O = Activation(activation)(O) 202 | 203 | for i in range(n): 204 | O = getResidualBlock(O, mode, filsize, sf, activation, dropout, 'regular', convArgs, 
bnArgs) 205 | 206 | O = getResidualBlock(O, mode, filsize, sf, activation, dropout, 'projection', convArgs, bnArgs) 207 | 208 | for i in range(n-1): 209 | O = getResidualBlock(O, mode, filsize, sf*2, activation, dropout, 'regular', convArgs, bnArgs) 210 | 211 | O = getResidualBlock(O, mode, filsize, sf*2, activation, dropout, 'projection', convArgs, bnArgs) 212 | 213 | for i in range(n-1): 214 | O = getResidualBlock(O, mode, filsize, sf*4, activation, dropout, 'regular', convArgs, bnArgs) 215 | 216 | # heatmap output 217 | O = Convolution2D(1, 1, activation='sigmoid')(O) 218 | 219 | model = Model(R, O) 220 | opt = SGD (lr = params.lr, 221 | momentum = params.momentum, 222 | decay = params.decay, 223 | nesterov = True, 224 | clipnorm = params.clipnorm) 225 | model.compile(opt, dice_coef_loss) 226 | return model 227 | 228 | 229 | def train(params, model): 230 | image_shape = (3, 93, 310) 231 | batch_size = params.batch_size 232 | epochs = params.num_epochs 233 | 234 | 235 | lrSchedCb = LearningRateScheduler(schedule) 236 | trainValHist = TrainValHistory() 237 | callbacks = [ModelCheckpoint('{}_weights.hd5'.format(params.mode), monitor='val_loss', verbose=0, save_best_only=True), 238 | lrSchedCb, 239 | trainValHist] 240 | 241 | t_gen = gen_batch(image_shape, 150) 242 | v_gen = gen_batch(image_shape, 50) 243 | for Xvb, Yvb in v_gen: 244 | Xv = Xvb 245 | Yv = Yvb 246 | break 247 | 248 | e = 1 249 | while e <= epochs: 250 | Xt, Yt = next(t_gen) 251 | print('\nEPOCH: {}'.format(e)) 252 | model.fit(Xt, Yt, 253 | batch_size=batch_size, 254 | epochs=1, 255 | verbose=1, 256 | callbacks=callbacks, 257 | validation_data=(Xv,Yv)) 258 | e += 1 259 | 260 | np.savetxt('{}_seg_train_loss.txt'.format(params.mode), trainValHist.train_loss) 261 | np.savetxt('{}_seg_val_loss.txt'.format(params.mode), trainValHist.val_loss) 262 | 263 | 264 | if __name__ == '__main__': 265 | param_dict = {"mode": "quaternion", 266 | "num_blocks": 3, 267 | "start_filter": 8, 268 | "dropout": 0, 269 | 
"batch_size": 8, 270 | "num_epochs": 200, 271 | "act": "relu", 272 | "init": "quaternion", 273 | "lr": 1e-3, 274 | "momentum": 0.9, 275 | "decay": 0, 276 | "clipnorm": 1.0 277 | } 278 | 279 | params = Params(param_dict) 280 | model = getModel(params) 281 | print(model.count_params()) 282 | train(params, model) --------------------------------------------------------------------------------