├── README.md
├── complexnn
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── bn.cpython-36.pyc
│   │   ├── bn.cpython-37.pyc
│   │   ├── bn.cpython-38.pyc
│   │   ├── conv.cpython-36.pyc
│   │   ├── conv.cpython-37.pyc
│   │   ├── dense.cpython-36.pyc
│   │   ├── dense.cpython-37.pyc
│   │   ├── fft.cpython-36.pyc
│   │   ├── fft.cpython-37.pyc
│   │   ├── init.cpython-36.pyc
│   │   ├── init.cpython-37.pyc
│   │   ├── norm.cpython-36.pyc
│   │   ├── norm.cpython-37.pyc
│   │   ├── pool.cpython-36.pyc
│   │   ├── pool.cpython-37.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── utils.cpython-37.pyc
│   ├── bn.py
│   ├── bn.pyc
│   ├── conv.py
│   ├── conv.pyc
│   ├── dense.py
│   ├── dense.pyc
│   ├── fft.py
│   ├── fft.pyc
│   ├── init.py
│   ├── init.pyc
│   ├── norm.py
│   ├── norm.pyc
│   ├── pool.py
│   ├── pool.pyc
│   ├── utils.py
│   └── utils.pyc
├── helper.py
├── main_cv-despecknet.ipynb
├── main_test.py
├── main_train.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
 1 | # Despeckling Polarimetric SAR Data Using a Multistream Complex-Valued Fully Convolutional Network
 2 |
 3 | A polarimetric synthetic aperture radar (PolSAR) sensor is able to collect images in different polarization states, making it a rich source of information for target characterization. PolSAR images are inherently affected by speckle. Therefore, before deriving ad hoc products from the data, the polarimetric covariance matrix needs to be estimated by reducing speckle. In recent years, deep learning-based despeckling methods have started to evolve from single-channel SAR images to PolSAR images. To this end, deep learning-based approaches separate the real and imaginary components of the complex-valued covariance matrix and use them as independent channels in standard convolutional neural networks (CNNs). However, this approach neglects the mathematical relationship that exists between the real and imaginary components, resulting in suboptimal output. Here, we propose a multistream complex-valued fully convolutional network (FCN) (CV-deSpeckNet) to reduce speckle and effectively estimate the PolSAR covariance matrix. To evaluate the performance of CV-deSpeckNet, we used Sentinel-1 dual polarimetric SAR images to compare against its real-valued counterpart, which separates the real and imaginary parts of the complex covariance matrix. CV-deSpeckNet was also compared against state-of-the-art PolSAR despeckling methods. The results show that CV-deSpeckNet could be trained with fewer samples, generalized better, and achieved higher accuracy than its real-valued counterpart and the state-of-the-art PolSAR despeckling methods. These results showcase the potential of complex-valued deep learning for PolSAR despeckling.
 4 |
 5 | The preprint of the paper describing CV-deSpeckNet can be found here: https://arxiv.org/abs/2103.07394
 6 |
 7 | In this implementation, both the noisy input images and their reference images should be in log form. Therefore, the noisy image reconstruction does not need to be converted back to the linear scale by taking the matrix exponent. This is the only difference between this implementation and the one used in the paper.
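A minimal sketch of what "log form" means for the intensity (diagonal) channels of the covariance matrix is shown below. It is an illustrative assumption rather than the repository's own preprocessing (the actual pipeline lives in helper.py); the function names and the small `eps` guard against log(0) are hypothetical.

```python
import numpy as np

def to_log_scale(intensity, eps=1e-10):
    """Map a positive SAR intensity channel (e.g. C11 or C22 of the
    covariance matrix) from the linear scale to the log scale."""
    return np.log(intensity + eps)

def to_linear_scale(log_intensity):
    """Inverse mapping; in this implementation the network output stays
    in log form, so this step is not applied to the reconstruction."""
    return np.exp(log_intensity)

# Example: a small patch of linear-scale C11 intensities
c11_linear = np.array([[0.12, 0.05],
                       [0.30, 0.08]])
c11_log = to_log_scale(c11_linear)  # what the network would be trained on
```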
 8 |
 9 | ![paper6_flowchart2](https://user-images.githubusercontent.com/48068921/112758977-4906ba00-8ff1-11eb-8e08-ce3cab3aaad7.png)
10 |
11 | ## Requirements
12 | To run the scripts locally, you need to install the keras-complex library (https://pypi.org/project/keras-complex/) along with keras==2.23 and tensorflow-gpu==1.13.1. There is a requirements text file included. For usage in Google Colab, the complexnn folder (supplied in this repo) should be uploaded to your Google Drive.
13 |
14 | The keras-complex library used in this paper and its documentation can be found at https://github.com/JesperDramsch/keras-complex
15 | ## Usage
16 | To train the model, run the script main_train.py. To test a trained model on a new image, run main_test.py.
17 |
18 | ## Reference
19 | If you use this implementation, please cite our work as follows:
20 |
21 | A. G. Mullissa, C. Persello and J. Reiche (2021).
22 | Despeckling Polarimetric SAR Data Using a Multistream Complex-Valued Fully Convolutional Network.
23 | IEEE Geoscience and Remote Sensing Letters, 1-5, doi:10.1109/LGRS.2021.3066311.
24 |
25 |
--------------------------------------------------------------------------------
/complexnn/__init__.py:
--------------------------------------------------------------------------------
 1 | #!/usr/bin/env python
 2 | # -*- coding: utf-8 -*-
 3 |
 4 | #
 5 | # Authors: Olexa Bilaniuk
 6 | #
 7 | # What this module includes by default:
 8 | from . import bn, conv, dense, init, norm, pool
 9 | # from . import fft
10 |
11 | from .bn import ComplexBatchNormalization as ComplexBN
12 | from .conv import (
13 |     ComplexConv,
14 |     ComplexConv1D,
15 |     ComplexConv2D,
16 |     ComplexConv3D,
17 |     WeightNorm_Conv,
18 | )
19 | from .dense import ComplexDense
20 | # from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
21 | from .init import (
22 |     ComplexIndependentFilters,
23 |     IndependentFilters,
24 |     ComplexInit,
25 |     SqrtInit,
26 | )
27 | from .norm import LayerNormalization, ComplexLayerNorm
28 | from .pool import SpectralPooling1D, SpectralPooling2D
29 | from .utils import (
30 |     get_realpart,
31 |     get_imagpart,
32 |     getpart_output_shape,
33 |     GetImag,
34 |     GetReal,
35 |     GetAbs,
36 | )
37 |
--------------------------------------------------------------------------------
/complexnn/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__init__.pyc
--------------------------------------------------------------------------------
/complexnn/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/complexnn/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/complexnn/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/bn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/bn.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/bn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/bn.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/bn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/bn.cpython-38.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/conv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/conv.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/conv.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/dense.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/dense.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/dense.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/dense.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/fft.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/fft.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/fft.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/fft.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/init.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/init.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/init.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/init.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/norm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/norm.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/norm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/norm.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/pool.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/pool.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/pool.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /complexnn/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /complexnn/bn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # pylint: disable=C0301,C0111 4 | # flake8: noqa 5 | 6 | # 7 | # Authors: Chiheb Trabelsi, Olexa Bilaniuk 8 | # 9 | # Note: The implementation of complex Batchnorm is based on 10 | # the Keras implementation of batch Normalization 11 | # available here: 12 | # https://github.com/fchollet/keras/blob/master/keras/layers/normalization.py 13 | 14 | import numpy as np 15 | from keras.layers import Layer, InputSpec 16 | from keras import initializers, regularizers, constraints 17 | import keras.backend as K 18 | 19 | 20 | def sqrt_init(shape, dtype=None): 21 | value = (1 / np.sqrt(2)) * K.ones(shape) 22 | return value 23 | 24 | 25 | def sanitizedInitGet(init): 26 | if init in ["sqrt_init"]: 27 | return sqrt_init 28 | else: 29 | return 
initializers.get(init) 30 | 31 | 32 | def sanitizedInitSer(init): 33 | if init in [sqrt_init]: 34 | return "sqrt_init" 35 | else: 36 | return initializers.serialize(init) 37 | 38 | 39 | def complex_standardization(input_centred, Vrr, Vii, Vri, 40 | layernorm=False, axis=-1): 41 | """Complex Standardization of input 42 | 43 | Arguments: 44 | input_centred -- Input Tensor 45 | Vrr -- Real component of covariance matrix V 46 | Vii -- Imaginary component of covariance matrix V 47 | Vri -- Non-diagonal component of covariance matrix V 48 | 49 | Keyword Arguments: 50 | layernorm {bool} -- Normalization (default: {False}) 51 | axis {int} -- Axis for Standardization (default: {-1}) 52 | 53 | Raises: 54 | ValueError: Mismatched dimensoins 55 | 56 | Returns: 57 | Complex standardized input 58 | """ 59 | 60 | ndim = K.ndim(input_centred) 61 | input_dim = K.shape(input_centred)[axis] // 2 62 | variances_broadcast = [1] * ndim 63 | variances_broadcast[axis] = input_dim 64 | if layernorm: 65 | variances_broadcast[0] = K.shape(input_centred)[0] 66 | 67 | # We require the covariance matrix's inverse square root. That first 68 | # requires square rooting, followed by inversion (I do this in that order 69 | # because during the computation of square root we compute the determinant 70 | # we'll need for inversion as well). 71 | 72 | # tau = Vrr + Vii = Trace. Guaranteed >= 0 because SPD 73 | tau = Vrr + Vii 74 | # delta = (Vrr * Vii) - (Vri ** 2) = Determinant. Guaranteed >= 0 because 75 | # SPD 76 | delta = (Vrr * Vii) - (Vri ** 2) 77 | 78 | s = K.sqrt(delta) # Determinant of square root matrix 79 | t = K.sqrt(tau + 2 * s) 80 | 81 | # The square root matrix could now be explicitly formed as 82 | # [ Vrr+s Vri ] 83 | # (1/t) [ Vir Vii+s ] 84 | # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix 85 | # but we don't need to do this immediately since we can also simultaneously 86 | # invert. We can do this because we've already computed the determinant of 87 | # the square root matrix, and can thus invert it using the analytical 88 | # solution for 2x2 matrices 89 | # [ A B ] [ D -B ] 90 | # inv( [ C D ] ) = (1/det) [ -C A ] 91 | # http://mathworld.wolfram.com/MatrixInverse.html 92 | # Thus giving us 93 | # [ Vii+s -Vri ] 94 | # (1/s)(1/t)[ -Vir Vrr+s ] 95 | # So we proceed as follows: 96 | 97 | inverse_st = 1.0 / (s * t) 98 | Wrr = (Vii + s) * inverse_st 99 | Wii = (Vrr + s) * inverse_st 100 | Wri = -Vri * inverse_st 101 | 102 | # And we have computed the inverse square root matrix W = sqrt(V)! 103 | # Normalization. We multiply, x_normalized = W.x. 
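    # (Illustrative numeric check, added as a comment and not part of the
    #  original computation: for a hypothetical 2x2 covariance
    #  V = [[Vrr, Vri], [Vri, Vii]] = [[2.0, 0.5], [0.5, 1.0]],
    #  tau = 3.0, delta = 1.75, s = sqrt(1.75) ~= 1.3229,
    #  t = sqrt(3.0 + 2 * 1.3229) ~= 2.3761 and 1 / (s * t) ~= 0.3181,
    #  giving Wrr ~= 0.7390, Wii ~= 1.0572 and Wri ~= -0.1591.
    #  Multiplying out W.V.W recovers (approximately) the identity,
    #  confirming that W is the inverse square root of V.)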
104 | 105 | # The returned result will be a complex standardized input 106 | # where the real and imaginary parts are obtained as follows: 107 | # x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred 108 | # x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred 109 | 110 | broadcast_Wrr = K.reshape(Wrr, variances_broadcast) 111 | broadcast_Wri = K.reshape(Wri, variances_broadcast) 112 | broadcast_Wii = K.reshape(Wii, variances_broadcast) 113 | 114 | cat_W_4_real = K.concatenate([broadcast_Wrr, broadcast_Wii], axis=axis) 115 | cat_W_4_imag = K.concatenate([broadcast_Wri, broadcast_Wri], axis=axis) 116 | 117 | if (axis == 1 and ndim != 3) or ndim == 2: 118 | centred_real = input_centred[:, :input_dim] 119 | centred_imag = input_centred[:, input_dim:] 120 | elif ndim == 3: 121 | centred_real = input_centred[:, :, :input_dim] 122 | centred_imag = input_centred[:, :, input_dim:] 123 | elif axis == -1 and ndim == 4: 124 | centred_real = input_centred[:, :, :, :input_dim] 125 | centred_imag = input_centred[:, :, :, input_dim:] 126 | elif axis == -1 and ndim == 5: 127 | centred_real = input_centred[:, :, :, :, :input_dim] 128 | centred_imag = input_centred[:, :, :, :, input_dim:] 129 | else: 130 | raise ValueError( 131 | 'Incorrect Batchnorm combination of axis and dimensions. axis ' 132 | 'should be either 1 or -1. ' 133 | 'axis: ' + str(axis) + '; ndim: ' + str(ndim) + '.' 134 | ) 135 | rolled_input = K.concatenate([centred_imag, centred_real], axis=axis) 136 | 137 | output = cat_W_4_real * input_centred + cat_W_4_imag * rolled_input 138 | 139 | # Wrr * x_real_centered | Wii * x_imag_centered 140 | # + Wri * x_imag_centered | Wri * x_real_centered 141 | # ----------------------------------------------- 142 | # = output 143 | 144 | return output 145 | 146 | 147 | def ComplexBN(input_centred, Vrr, Vii, Vri, beta, 148 | gamma_rr, gamma_ri, gamma_ii, scale=True, 149 | center=True, layernorm=False, axis=-1): 150 | """Complex Batch Normalization 151 | 152 | Arguments: 153 | input_centred -- input data 154 | Vrr -- Real component of covariance matrix V 155 | Vii -- Imaginary component of covariance matrix V 156 | Vri -- Non-diagonal component of covariance matrix V 157 | beta -- Lernable shift parameter beta 158 | gamma_rr -- Scaling parameter gamma - rr component of 2x2 matrix 159 | gamma_ri -- Scaling parameter gamma - ri component of 2x2 matrix 160 | gamma_ii -- Scaling parameter gamma - ii component of 2x2 matrix 161 | 162 | Keyword Arguments: 163 | scale {bool} {bool} -- Standardization of input (default: {True}) 164 | center {bool} -- Mean-shift correction (default: {True}) 165 | layernorm {bool} -- Normalization (default: {False}) 166 | axis {int} -- Axis for Standardization (default: {-1}) 167 | 168 | Raises: 169 | ValueError: Dimonsional mismatch 170 | 171 | Returns: 172 | Batch-Normalized Input 173 | """ 174 | 175 | ndim = K.ndim(input_centred) 176 | input_dim = K.shape(input_centred)[axis] // 2 177 | if scale: 178 | gamma_broadcast_shape = [1] * ndim 179 | gamma_broadcast_shape[axis] = input_dim 180 | if center: 181 | broadcast_beta_shape = [1] * ndim 182 | broadcast_beta_shape[axis] = input_dim * 2 183 | 184 | if scale: 185 | standardized_output = complex_standardization( 186 | input_centred, Vrr, Vii, Vri, 187 | layernorm, 188 | axis=axis 189 | ) 190 | 191 | # Now we perform th scaling and Shifting of the normalized x using 192 | # the scaling parameter 193 | # [ gamma_rr gamma_ri ] 194 | # Gamma = [ gamma_ri gamma_ii ] 195 | # and the shifting parameter 196 | # Beta = 
[beta_real beta_imag].T 197 | # where: 198 | # x_real_BN = gamma_rr * x_real_normed + 199 | # gamma_ri * x_imag_normed + beta_real 200 | # x_imag_BN = gamma_ri * x_real_normed + 201 | # gamma_ii * x_imag_normed + beta_imag 202 | 203 | broadcast_gamma_rr = K.reshape(gamma_rr, gamma_broadcast_shape) 204 | broadcast_gamma_ri = K.reshape(gamma_ri, gamma_broadcast_shape) 205 | broadcast_gamma_ii = K.reshape(gamma_ii, gamma_broadcast_shape) 206 | 207 | cat_gamma_4_real = K.concatenate([broadcast_gamma_rr, broadcast_gamma_ii], axis=axis) 208 | cat_gamma_4_imag = K.concatenate([broadcast_gamma_ri, broadcast_gamma_ri], axis=axis) 209 | if (axis == 1 and ndim != 3) or ndim == 2: 210 | centred_real = standardized_output[:, :input_dim] 211 | centred_imag = standardized_output[:, input_dim:] 212 | elif ndim == 3: 213 | centred_real = standardized_output[:, :, :input_dim] 214 | centred_imag = standardized_output[:, :, input_dim:] 215 | elif axis == -1 and ndim == 4: 216 | centred_real = standardized_output[:, :, :, :input_dim] 217 | centred_imag = standardized_output[:, :, :, input_dim:] 218 | elif axis == -1 and ndim == 5: 219 | centred_real = standardized_output[:, :, :, :, :input_dim] 220 | centred_imag = standardized_output[:, :, :, :, input_dim:] 221 | else: 222 | raise ValueError( 223 | 'Incorrect Batchnorm combination of axis and dimensions. axis' 224 | ' should be either 1 or -1. ' 225 | 'axis: ' + str(axis) + '; ndim: ' + str(ndim) + '.' 226 | ) 227 | rolled_standardized_output = K.concatenate([centred_imag, centred_real], axis=axis) 228 | if center: 229 | broadcast_beta = K.reshape(beta, broadcast_beta_shape) 230 | return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output + broadcast_beta 231 | else: 232 | return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output 233 | else: 234 | if center: 235 | broadcast_beta = K.reshape(beta, broadcast_beta_shape) 236 | return input_centred + broadcast_beta 237 | else: 238 | return input_centred 239 | 240 | 241 | class ComplexBatchNormalization(Layer): 242 | """Complex version of the real domain 243 | Batch normalization layer (Ioffe and Szegedy, 2014). 244 | Normalize the activations of the previous complex layer at each batch, 245 | i.e. applies a transformation that maintains the mean of a complex unit 246 | close to the null vector, the 2 by 2 covariance matrix of a complex unit close to identity 247 | and the 2 by 2 relation matrix, also called pseudo-covariance, close to the 248 | null matrix. 249 | # Arguments 250 | axis: Integer, the axis that should be normalized 251 | (typically the features axis). 252 | For instance, after a `Conv2D` layer with 253 | `data_format="channels_first"`, 254 | set `axis=2` in `ComplexBatchNormalization`. 255 | momentum: Momentum for the moving statistics related to the real and 256 | imaginary parts. 257 | epsilon: Small float added to each of the variances related to the 258 | real and imaginary parts in order to avoid dividing by zero. 259 | center: If True, add offset of `beta` to complex normalized tensor. 260 | If False, `beta` is ignored. 261 | (beta is formed by real_beta and imag_beta) 262 | scale: If True, multiply by the `gamma` matrix. 263 | If False, `gamma` is not used. 264 | beta_initializer: Initializer for the real_beta and the imag_beta weight. 265 | gamma_diag_initializer: Initializer for the diagonal elements of the gamma matrix. 266 | which are the variances of the real part and the imaginary part. 
267 | gamma_off_initializer: Initializer for the off-diagonal elements of the gamma matrix. 268 | moving_mean_initializer: Initializer for the moving means. 269 | moving_variance_initializer: Initializer for the moving variances. 270 | moving_covariance_initializer: Initializer for the moving covariance of 271 | the real and imaginary parts. 272 | beta_regularizer: Optional regularizer for the beta weights. 273 | gamma_regularizer: Optional regularizer for the gamma weights. 274 | beta_constraint: Optional constraint for the beta weights. 275 | gamma_constraint: Optional constraint for the gamma weights. 276 | # Input shape 277 | Arbitrary. Use the keyword argument `input_shape` 278 | (tuple of integers, does not include the samples axis) 279 | when using this layer as the first layer in a model. 280 | # Output shape 281 | Same shape as input. 282 | # References 283 | - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167) 284 | """ 285 | 286 | def __init__(self, 287 | axis=-1, 288 | momentum=0.9, 289 | epsilon=1e-4, 290 | center=True, 291 | scale=True, 292 | beta_initializer='zeros', 293 | gamma_diag_initializer='sqrt_init', 294 | gamma_off_initializer='zeros', 295 | moving_mean_initializer='zeros', 296 | moving_variance_initializer='sqrt_init', 297 | moving_covariance_initializer='zeros', 298 | beta_regularizer=None, 299 | gamma_diag_regularizer=None, 300 | gamma_off_regularizer=None, 301 | beta_constraint=None, 302 | gamma_diag_constraint=None, 303 | gamma_off_constraint=None, 304 | **kwargs): 305 | super(ComplexBatchNormalization, self).__init__(**kwargs) 306 | self.supports_masking = True 307 | self.axis = axis 308 | self.momentum = momentum 309 | self.epsilon = epsilon 310 | self.center = center 311 | self.scale = scale 312 | self.beta_initializer = sanitizedInitGet(beta_initializer) 313 | self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer) 314 | self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer) 315 | self.moving_mean_initializer = sanitizedInitGet(moving_mean_initializer) 316 | self.moving_variance_initializer = sanitizedInitGet(moving_variance_initializer) 317 | self.moving_covariance_initializer = sanitizedInitGet(moving_covariance_initializer) 318 | self.beta_regularizer = regularizers.get(beta_regularizer) 319 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 320 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 321 | self.beta_constraint = constraints .get(beta_constraint) 322 | self.gamma_diag_constraint = constraints .get(gamma_diag_constraint) 323 | self.gamma_off_constraint = constraints .get(gamma_off_constraint) 324 | 325 | def build(self, input_shape): 326 | 327 | ndim = len(input_shape) 328 | 329 | dim = input_shape[self.axis] 330 | if dim is None: 331 | raise ValueError('Axis ' + str(self.axis) + ' of ' 332 | 'input tensor should have a defined dimension ' 333 | 'but the layer received an input with shape ' + 334 | str(input_shape) + '.') 335 | self.input_spec = InputSpec(ndim=len(input_shape), 336 | axes={self.axis: dim}) 337 | 338 | param_shape = (input_shape[self.axis] // 2,) 339 | 340 | if self.scale: 341 | self.gamma_rr = self.add_weight(shape=param_shape, 342 | name='gamma_rr', 343 | initializer=self.gamma_diag_initializer, 344 | regularizer=self.gamma_diag_regularizer, 345 | constraint=self.gamma_diag_constraint) 346 | self.gamma_ii = self.add_weight(shape=param_shape, 347 | name='gamma_ii', 348 | 
initializer=self.gamma_diag_initializer, 349 | regularizer=self.gamma_diag_regularizer, 350 | constraint=self.gamma_diag_constraint) 351 | self.gamma_ri = self.add_weight(shape=param_shape, 352 | name='gamma_ri', 353 | initializer=self.gamma_off_initializer, 354 | regularizer=self.gamma_off_regularizer, 355 | constraint=self.gamma_off_constraint) 356 | self.moving_Vrr = self.add_weight(shape=param_shape, 357 | initializer=self.moving_variance_initializer, 358 | name='moving_Vrr', 359 | trainable=False) 360 | self.moving_Vii = self.add_weight(shape=param_shape, 361 | initializer=self.moving_variance_initializer, 362 | name='moving_Vii', 363 | trainable=False) 364 | self.moving_Vri = self.add_weight(shape=param_shape, 365 | initializer=self.moving_covariance_initializer, 366 | name='moving_Vri', 367 | trainable=False) 368 | else: 369 | self.gamma_rr = None 370 | self.gamma_ii = None 371 | self.gamma_ri = None 372 | self.moving_Vrr = None 373 | self.moving_Vii = None 374 | self.moving_Vri = None 375 | 376 | if self.center: 377 | self.beta = self.add_weight(shape=(input_shape[self.axis],), 378 | name='beta', 379 | initializer=self.beta_initializer, 380 | regularizer=self.beta_regularizer, 381 | constraint=self.beta_constraint) 382 | self.moving_mean = self.add_weight(shape=(input_shape[self.axis],), 383 | initializer=self.moving_mean_initializer, 384 | name='moving_mean', 385 | trainable=False) 386 | else: 387 | self.beta = None 388 | self.moving_mean = None 389 | 390 | self.built = True 391 | 392 | def call(self, inputs, training=None): 393 | input_shape = K.int_shape(inputs) 394 | ndim = len(input_shape) 395 | reduction_axes = list(range(ndim)) 396 | del reduction_axes[self.axis] 397 | input_dim = input_shape[self.axis] // 2 398 | mu = K.mean(inputs, axis=reduction_axes) 399 | broadcast_mu_shape = [1] * len(input_shape) 400 | broadcast_mu_shape[self.axis] = input_shape[self.axis] 401 | broadcast_mu = K.reshape(mu, broadcast_mu_shape) 402 | if self.center: 403 | input_centred = inputs - broadcast_mu 404 | else: 405 | input_centred = inputs 406 | centred_squared = input_centred ** 2 407 | if (self.axis == 1 and ndim != 3) or ndim == 2: 408 | centred_squared_real = centred_squared[:, :input_dim] 409 | centred_squared_imag = centred_squared[:, input_dim:] 410 | centred_real = input_centred[:, :input_dim] 411 | centred_imag = input_centred[:, input_dim:] 412 | elif ndim == 3: 413 | centred_squared_real = centred_squared[:, :, :input_dim] 414 | centred_squared_imag = centred_squared[:, :, input_dim:] 415 | centred_real = input_centred[:, :, :input_dim] 416 | centred_imag = input_centred[:, :, input_dim:] 417 | elif self.axis == -1 and ndim == 4: 418 | centred_squared_real = centred_squared[:, :, :, :input_dim] 419 | centred_squared_imag = centred_squared[:, :, :, input_dim:] 420 | centred_real = input_centred[:, :, :, :input_dim] 421 | centred_imag = input_centred[:, :, :, input_dim:] 422 | elif self.axis == -1 and ndim == 5: 423 | centred_squared_real = centred_squared[:, :, :, :, :input_dim] 424 | centred_squared_imag = centred_squared[:, :, :, :, input_dim:] 425 | centred_real = input_centred[:, :, :, :, :input_dim] 426 | centred_imag = input_centred[:, :, :, :, input_dim:] 427 | else: 428 | raise ValueError( 429 | 'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. ' 430 | 'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.' 
431 | ) 432 | if self.scale: 433 | Vrr = K.mean( 434 | centred_squared_real, 435 | axis=reduction_axes 436 | ) + self.epsilon 437 | Vii = K.mean( 438 | centred_squared_imag, 439 | axis=reduction_axes 440 | ) + self.epsilon 441 | # Vri contains the real and imaginary covariance for each feature map. 442 | Vri = K.mean( 443 | centred_real * centred_imag, 444 | axis=reduction_axes, 445 | ) 446 | elif self.center: 447 | Vrr = None 448 | Vii = None 449 | Vri = None 450 | else: 451 | raise ValueError('Error. Both scale and center in batchnorm are set to False.') 452 | 453 | input_bn = ComplexBN( 454 | input_centred, Vrr, Vii, Vri, 455 | self.beta, self.gamma_rr, self.gamma_ri, 456 | self.gamma_ii, self.scale, self.center, 457 | axis=self.axis 458 | ) 459 | if training in {0, False}: 460 | return input_bn 461 | else: 462 | update_list = [] 463 | if self.center: 464 | update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum)) 465 | if self.scale: 466 | update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum)) 467 | update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum)) 468 | update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum)) 469 | self.add_update(update_list, inputs) 470 | 471 | def normalize_inference(): 472 | if self.center: 473 | inference_centred = inputs - K.reshape(self.moving_mean, broadcast_mu_shape) 474 | else: 475 | inference_centred = inputs 476 | return ComplexBN( 477 | inference_centred, self.moving_Vrr, self.moving_Vii, 478 | self.moving_Vri, self.beta, self.gamma_rr, self.gamma_ri, 479 | self.gamma_ii, self.scale, self.center, axis=self.axis 480 | ) 481 | 482 | # Pick the normalized form corresponding to the training phase. 483 | return K.in_train_phase(input_bn, 484 | normalize_inference, 485 | training=training) 486 | 487 | def get_config(self): 488 | config = { 489 | 'axis': self.axis, 490 | 'momentum': self.momentum, 491 | 'epsilon': self.epsilon, 492 | 'center': self.center, 493 | 'scale': self.scale, 494 | 'beta_initializer': sanitizedInitSer(self.beta_initializer), 495 | 'gamma_diag_initializer': sanitizedInitSer(self.gamma_diag_initializer), 496 | 'gamma_off_initializer': sanitizedInitSer(self.gamma_off_initializer), 497 | 'moving_mean_initializer': sanitizedInitSer(self.moving_mean_initializer), 498 | 'moving_variance_initializer': sanitizedInitSer(self.moving_variance_initializer), 499 | 'moving_covariance_initializer': sanitizedInitSer(self.moving_covariance_initializer), 500 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 501 | 'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer), 502 | 'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer), 503 | 'beta_constraint': constraints .serialize(self.beta_constraint), 504 | 'gamma_diag_constraint': constraints .serialize(self.gamma_diag_constraint), 505 | 'gamma_off_constraint': constraints .serialize(self.gamma_off_constraint), 506 | } 507 | base_config = super(ComplexBatchNormalization, self).get_config() 508 | return dict(list(base_config.items()) + list(config.items())) 509 | -------------------------------------------------------------------------------- /complexnn/bn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/bn.pyc -------------------------------------------------------------------------------- 
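Before moving on to the convolution layers, a brief usage sketch may help connect the pieces: the snippet below wires ComplexConv2D together with ComplexBN (the package-level alias of the ComplexBatchNormalization layer defined above) into a tiny Keras model. The input size, filter counts and activation are illustrative assumptions only; the actual CV-deSpeckNet architecture is defined in main_train.py.

from keras.models import Model
from keras.layers import Input, Activation
from complexnn import ComplexConv2D, ComplexBN

# The channel axis holds the real parts followed by the imaginary parts,
# so 3 complex covariance channels become 6 real-valued input channels.
inputs = Input(shape=(64, 64, 6))

x = ComplexConv2D(filters=16, kernel_size=3, padding='same')(inputs)  # -> 32 channels (16 complex)
x = ComplexBN(axis=-1)(x)                                             # complex batch normalization
x = Activation('relu')(x)
outputs = ComplexConv2D(filters=3, kernel_size=3, padding='same')(x)  # back to 3 complex channels

model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='mse')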
/complexnn/conv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """conv.py""" 4 | # pylint: disable=protected-access, consider-using-enumerate, too-many-lines 5 | 6 | # 7 | # Authors: Chiheb Trabelsi 8 | 9 | from keras import backend as K 10 | from keras import activations, initializers, regularizers, constraints 11 | from keras.layers import ( 12 | Layer, 13 | InputSpec, 14 | ) 15 | from keras.layers.convolutional import _Conv 16 | from keras.utils import conv_utils 17 | from keras.backend.common import normalize_data_format 18 | import numpy as np 19 | from .fft import fft, ifft, fft2, ifft2 20 | from .bn import ComplexBN as complex_normalization 21 | from .bn import sqrt_init 22 | from .init import ComplexInit, ComplexIndependentFilters 23 | 24 | 25 | def conv2d_transpose( 26 | inputs, 27 | filter, # pylint: disable=redefined-builtin 28 | kernel_size=None, 29 | filters=None, 30 | strides=(1, 1), 31 | padding="SAME", 32 | output_padding=None, 33 | data_format="channels_last"): 34 | """Compatibility layer for K.conv2d_transpose 35 | 36 | Take a filter defined for forward convolution and adjusts it for a 37 | transposed convolution.""" 38 | input_shape = K.shape(inputs) 39 | batch_size = input_shape[0] 40 | if data_format == 'channels_first': 41 | h_axis, w_axis = 2, 3 42 | else: 43 | h_axis, w_axis = 1, 2 44 | 45 | height, width = input_shape[h_axis], input_shape[w_axis] 46 | kernel_h, kernel_w = kernel_size 47 | stride_h, stride_w = strides 48 | 49 | # Infer the dynamic output shape: 50 | out_height = conv_utils.deconv_length(height, 51 | stride_h, kernel_h, 52 | padding, output_padding) 53 | out_width = conv_utils.deconv_length(width, 54 | stride_w, kernel_w, 55 | padding, output_padding) 56 | if data_format == 'channels_first': 57 | output_shape = (batch_size, filters, out_height, out_width) 58 | else: 59 | output_shape = (batch_size, out_height, out_width, filters) 60 | 61 | filter = K.permute_dimensions(filter, (0, 1, 3, 2)) 62 | return K.conv2d_transpose(inputs, filter, output_shape, strides, 63 | padding=padding, 64 | data_format=data_format) 65 | 66 | 67 | def ifft(f): 68 | """Stub""" 69 | raise NotImplementedError(str(f)) 70 | 71 | 72 | def ifft2(f): 73 | """Stub""" 74 | raise NotImplementedError(str(f)) 75 | 76 | 77 | def conv_transpose_output_length( 78 | input_length, filter_size, padding, stride, dilation=1, output_padding=None): 79 | """Rearrange arguments for compatibility with conv_output_length.""" 80 | if dilation != 1: 81 | msg = "Dilation must be 1 for transposed convolution. 
" 82 | msg += "Got dilation = {dilation}" 83 | raise ValueError(msg) 84 | return conv_utils.deconv_length( 85 | input_length, # dim_size 86 | stride, # stride_size 87 | filter_size, # kernel_size 88 | padding, # padding 89 | output_padding, # output_padding 90 | ) 91 | 92 | 93 | def sanitizedInitGet(init): 94 | """sanitizedInitGet""" 95 | if init in ["sqrt_init"]: 96 | return sqrt_init 97 | elif init in [ 98 | "complex", "complex_independent", "glorot_complex", "he_complex" 99 | ]: 100 | return init 101 | else: 102 | return initializers.get(init) 103 | 104 | 105 | def sanitizedInitSer(init): 106 | """sanitizedInitSer""" 107 | if init in [sqrt_init]: 108 | return "sqrt_init" 109 | elif init == "complex" or isinstance(init, ComplexInit): 110 | return "complex" 111 | elif init == "complex_independent" or isinstance( 112 | init, ComplexIndependentFilters 113 | ): 114 | return "complex_independent" 115 | else: 116 | return initializers.serialize(init) 117 | 118 | 119 | class ComplexConv(Layer): 120 | """Abstract nD complex convolution layer. 121 | 122 | This layer creates a complex convolution kernel that is convolved with the 123 | layer input to produce a tensor of outputs. If `use_bias` is True, a bias 124 | vector is created and added to the outputs. Finally, if `activation` is not 125 | `None`, it is applied to the outputs as well. 126 | 127 | Arguments: 128 | rank: Integer, the rank of the convolution, e.g., "2" for 2D 129 | convolution. 130 | filters: Integer, the dimensionality of the output space, i.e., the 131 | number of complex feature maps. It is also the effective number of 132 | feature maps for each of the real and imaginary parts. (I.e., the 133 | number of complex filters in the convolution) The total effective 134 | number of filters is 2 x filters. 135 | kernel_size: An integer or tuple/list of n integers, specifying the 136 | dimensions of the convolution window. 137 | strides: An integer or tuple/list of n integers, specifying the strides 138 | of the convolution. Specifying any stride value != 1 is 139 | incompatible with specifying any `dilation_rate` value != 1. 140 | padding: One of `"valid"` or `"same"` (case-insensitive). 141 | data_format: A string, one of `channels_last` (default) or 142 | `channels_first`. The ordering of the dimensions in the inputs. 143 | `channels_last` corresponds to inputs with shape 144 | `(batch, ..., channels)` while `channels_first` corresponds to 145 | inputs with shape `(batch, channels, ...)`. It defaults to the 146 | `image_data_format` value found in your Keras config file at 147 | `~/.keras/keras.json`. If you never set it, then it will be 148 | "channels_last". 149 | dilation_rate: An integer or tuple/list of n integers, specifying 150 | the dilation rate to use for dilated convolution. Currently, 151 | specifying any `dilation_rate` value != 1 is incompatible with 152 | specifying any `strides` value != 1. 153 | activation: Activation function to use (see keras.activations). If you 154 | don't specify anything, no activation is applied (i.e., "linear" 155 | activation: `a(x) = x`). 156 | use_bias: Boolean, whether the layer uses a bias vector. 157 | normalize_weight: Boolean, whether the layer normalizes its complex 158 | weights before convolving the complex input. The complex 159 | normalization performed is similar to the one for the batchnorm. 160 | Each of the complex kernels is centred and multiplied by the 161 | inverse square root of the covariance matrix. 
Then a complex 162 | multiplication is performed as the normalized weights are 163 | multiplied by the complex scaling factor gamma. 164 | kernel_initializer: Initializer for the complex `kernel` weights 165 | matrix. By default it is 'complex'. The 'complex_independent' 166 | and the usual initializers could also be used. (See 167 | keras.initializers and init.py). 168 | bias_initializer: Initializer for the bias vector 169 | (see keras.initializers). 170 | kernel_regularizer: Regularizer function applied to the `kernel` 171 | weights matrix (see keras.regularizers). 172 | bias_regularizer: Regularizer function applied to the bias vector 173 | (see keras.regularizers). 174 | activity_regularizer: Regularizer function applied to the output of the 175 | layer (its "activation"). (See keras.regularizers). 176 | kernel_constraint: Constraint function applied to the kernel matrix 177 | (see keras.constraints). 178 | bias_constraint: Constraint function applied to the bias vector 179 | (see keras.constraints). 180 | spectral_parametrization: Boolean, whether or not to use a spectral 181 | parametrization of the parameters. 182 | transposed: Boolean, whether or not to use transposed convolution 183 | """ 184 | 185 | def __init__( 186 | self, 187 | rank, 188 | filters, 189 | kernel_size, 190 | strides=1, 191 | padding="valid", 192 | data_format=None, 193 | dilation_rate=1, 194 | activation=None, 195 | use_bias=True, 196 | normalize_weight=False, 197 | kernel_initializer="complex", 198 | bias_initializer="zeros", 199 | gamma_diag_initializer=sqrt_init, 200 | gamma_off_initializer="zeros", 201 | kernel_regularizer=None, 202 | bias_regularizer=None, 203 | gamma_diag_regularizer=None, 204 | gamma_off_regularizer=None, 205 | activity_regularizer=None, 206 | kernel_constraint=None, 207 | bias_constraint=None, 208 | gamma_diag_constraint=None, 209 | gamma_off_constraint=None, 210 | init_criterion="he", 211 | seed=None, 212 | spectral_parametrization=False, 213 | transposed=False, 214 | epsilon=1e-7, 215 | **kwargs): 216 | super(ComplexConv, self).__init__(**kwargs) 217 | self.rank = rank 218 | self.filters = filters 219 | self.kernel_size = conv_utils.normalize_tuple( 220 | kernel_size, rank, "kernel_size" 221 | ) 222 | self.strides = conv_utils.normalize_tuple(strides, rank, "strides") 223 | self.padding = conv_utils.normalize_padding(padding) 224 | self.data_format = "channels_last" \ 225 | if rank == 1 else normalize_data_format(data_format) 226 | self.dilation_rate = conv_utils.normalize_tuple( 227 | dilation_rate, rank, "dilation_rate" 228 | ) 229 | self.activation = activations.get(activation) 230 | self.use_bias = use_bias 231 | self.normalize_weight = normalize_weight 232 | self.init_criterion = init_criterion 233 | self.spectral_parametrization = spectral_parametrization 234 | self.transposed = transposed 235 | self.epsilon = epsilon 236 | self.kernel_initializer = sanitizedInitGet(kernel_initializer) 237 | self.bias_initializer = sanitizedInitGet(bias_initializer) 238 | self.gamma_diag_initializer = sanitizedInitGet(gamma_diag_initializer) 239 | self.gamma_off_initializer = sanitizedInitGet(gamma_off_initializer) 240 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 241 | self.bias_regularizer = regularizers.get(bias_regularizer) 242 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 243 | self.gamma_off_regularizer = regularizers.get(gamma_off_regularizer) 244 | self.activity_regularizer = regularizers.get(activity_regularizer) 245 | self.kernel_constraint 
= constraints.get(kernel_constraint) 246 | self.bias_constraint = constraints.get(bias_constraint) 247 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 248 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 249 | if seed is None: 250 | self.seed = np.random.randint(1, 10e6) 251 | else: 252 | self.seed = seed 253 | self.input_spec = InputSpec(ndim=self.rank + 2) 254 | 255 | # The following are initialized later 256 | self.kernel_shape = None 257 | self.kernel = None 258 | self.gamma_rr = None 259 | self.gamma_ii = None 260 | self.gamma_ri = None 261 | self.bias = None 262 | 263 | def build(self, input_shape): 264 | """build""" 265 | if self.data_format == "channels_first": 266 | channel_axis = 1 267 | else: 268 | channel_axis = -1 269 | if input_shape[channel_axis] is None: 270 | raise ValueError( 271 | "The channel dimension of the inputs " 272 | "should be defined. Found `None`." 273 | ) 274 | input_dim = input_shape[channel_axis] // 2 # Divide by 2 for real and complex input. 275 | if False and self.transposed: 276 | self.kernel_shape = self.kernel_size + (self.filters, input_dim) 277 | else: 278 | self.kernel_shape = self.kernel_size + (input_dim, self.filters) 279 | # The kernel shape here is a complex kernel shape: 280 | # nb of complex feature maps = input_dim; 281 | # nb of output complex feature maps = self.filters; 282 | # imaginary kernel size = real kernel size 283 | # = self.kernel_size 284 | # = complex kernel size 285 | if self.kernel_initializer in {"complex", "complex_independent"}: 286 | kls = { 287 | "complex": ComplexInit, 288 | "complex_independent": ComplexIndependentFilters, 289 | }[ 290 | self.kernel_initializer 291 | ] 292 | kern_init = kls( 293 | kernel_size=self.kernel_size, 294 | input_dim=input_dim, 295 | weight_dim=self.rank, 296 | nb_filters=self.filters, 297 | criterion=self.init_criterion, 298 | ) 299 | else: 300 | kern_init = self.kernel_initializer 301 | 302 | self.kernel = self.add_weight( 303 | "kernel", 304 | self.kernel_shape, 305 | initializer=kern_init, 306 | regularizer=self.kernel_regularizer, 307 | constraint=self.kernel_constraint, 308 | ) 309 | 310 | if self.normalize_weight: 311 | gamma_shape = (input_dim * self.filters,) 312 | self.gamma_rr = self.add_weight( 313 | shape=gamma_shape, 314 | name="gamma_rr", 315 | initializer=self.gamma_diag_initializer, 316 | regularizer=self.gamma_diag_regularizer, 317 | constraint=self.gamma_diag_constraint, 318 | ) 319 | self.gamma_ii = self.add_weight( 320 | shape=gamma_shape, 321 | name="gamma_ii", 322 | initializer=self.gamma_diag_initializer, 323 | regularizer=self.gamma_diag_regularizer, 324 | constraint=self.gamma_diag_constraint, 325 | ) 326 | self.gamma_ri = self.add_weight( 327 | shape=gamma_shape, 328 | name="gamma_ri", 329 | initializer=self.gamma_off_initializer, 330 | regularizer=self.gamma_off_regularizer, 331 | constraint=self.gamma_off_constraint, 332 | ) 333 | else: 334 | self.gamma_rr = None 335 | self.gamma_ii = None 336 | self.gamma_ri = None 337 | 338 | if self.use_bias: 339 | bias_shape = (2 * self.filters,) 340 | self.bias = self.add_weight( 341 | "bias", 342 | bias_shape, 343 | initializer=self.bias_initializer, 344 | regularizer=self.bias_regularizer, 345 | constraint=self.bias_constraint, 346 | ) 347 | 348 | else: 349 | self.bias = None 350 | 351 | # Set input spec. 
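        # (Added note: the channel axis of the input is expected to carry
        #  2 * input_dim values, i.e. the real feature maps followed by the
        #  imaginary feature maps of the complex input.)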
352 | self.input_spec = InputSpec( 353 | ndim=self.rank + 2, axes={channel_axis: input_dim * 2} 354 | ) 355 | self.built = True 356 | 357 | def call(self, inputs, **kwargs): 358 | if self.data_format == "channels_first": 359 | channel_axis = 1 360 | else: 361 | channel_axis = -1 362 | input_dim = K.shape(inputs)[channel_axis] // 2 363 | if False and self.transposed: 364 | if self.rank == 1: 365 | f_real = self.kernel[:, :self.filters, :] 366 | f_imag = self.kernel[:, self.filters:, :] 367 | elif self.rank == 2: 368 | f_real = self.kernel[:, :, :self.filters, :] 369 | f_imag = self.kernel[:, :, self.filters:, :] 370 | elif self.rank == 3: 371 | f_real = self.kernel[:, :, :, :self.filters, :] 372 | f_imag = self.kernel[:, :, :, self.filters:, :] 373 | else: 374 | if self.rank == 1: 375 | f_real = self.kernel[:, :, :self.filters] 376 | f_imag = self.kernel[:, :, self.filters:] 377 | elif self.rank == 2: 378 | f_real = self.kernel[:, :, :, :self.filters] 379 | f_imag = self.kernel[:, :, :, self.filters:] 380 | elif self.rank == 3: 381 | f_real = self.kernel[:, :, :, :, :self.filters] 382 | f_imag = self.kernel[:, :, :, :, self.filters:] 383 | 384 | convArgs = { 385 | "strides": self.strides[0] if self.rank == 1 else self.strides, 386 | "padding": self.padding, 387 | "data_format": self.data_format, 388 | "dilation_rate": 389 | self.dilation_rate[0] 390 | if self.rank == 1 391 | else self.dilation_rate, 392 | } 393 | if self.transposed: 394 | convArgs.pop("dilation_rate", None) 395 | convArgs["kernel_size"] = self.kernel_size 396 | convArgs["filters"] = 2 * self.filters 397 | convFunc = { 398 | 2: conv2d_transpose}[self.rank] 399 | else: 400 | convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank] 401 | 402 | # processing if the weights are assumed to be represented in the 403 | # spectral domain 404 | 405 | if self.spectral_parametrization: 406 | if self.rank == 1: 407 | f_real = K.permute_dimensions(f_real, (2, 1, 0)) 408 | f_imag = K.permute_dimensions(f_imag, (2, 1, 0)) 409 | f = K.concatenate([f_real, f_imag], axis=0) 410 | fshape = K.shape(f) 411 | f = K.reshape(f, (fshape[0] * fshape[1], fshape[2])) 412 | f = ifft(f) 413 | f = K.reshape(f, fshape) 414 | f_real = f[:fshape[0] // 2] 415 | f_imag = f[fshape[0] // 2:] 416 | f_real = K.permute_dimensions(f_real, (2, 1, 0)) 417 | f_imag = K.permute_dimensions(f_imag, (2, 1, 0)) 418 | elif self.rank == 2: 419 | f_real = K.permute_dimensions(f_real, (3, 2, 0, 1)) 420 | f_imag = K.permute_dimensions(f_imag, (3, 2, 0, 1)) 421 | f = K.concatenate([f_real, f_imag], axis=0) 422 | fshape = K.shape(f) 423 | f = K.reshape(f, (fshape[0] * fshape[1], fshape[2], fshape[3])) 424 | f = ifft2(f) 425 | f = K.reshape(f, fshape) 426 | f_real = f[:fshape[0] // 2] 427 | f_imag = f[fshape[0] // 2:] 428 | f_real = K.permute_dimensions(f_real, (2, 3, 1, 0)) 429 | f_imag = K.permute_dimensions(f_imag, (2, 3, 1, 0)) 430 | 431 | # In case of weight normalization, real and imaginary weights are 432 | # normalized 433 | 434 | if self.normalize_weight: 435 | ker_shape = self.kernel_shape 436 | nb_kernels = ker_shape[-2] * ker_shape[-1] 437 | kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels) 438 | reshaped_f_real = K.reshape(f_real, kernel_shape_4_norm) 439 | reshaped_f_imag = K.reshape(f_imag, kernel_shape_4_norm) 440 | reduction_axes = list(range(2)) 441 | del reduction_axes[-1] 442 | mu_real = K.mean(reshaped_f_real, axis=reduction_axes) 443 | mu_imag = K.mean(reshaped_f_imag, axis=reduction_axes) 444 | 445 | broadcast_mu_shape = [1] * 2 446 | 
broadcast_mu_shape[-1] = nb_kernels 447 | broadcast_mu_real = K.reshape(mu_real, broadcast_mu_shape) 448 | broadcast_mu_imag = K.reshape(mu_imag, broadcast_mu_shape) 449 | reshaped_f_real_centred = reshaped_f_real - broadcast_mu_real 450 | reshaped_f_imag_centred = reshaped_f_imag - broadcast_mu_imag 451 | Vrr = K.mean( 452 | reshaped_f_real_centred ** 2, axis=reduction_axes 453 | ) + self.epsilon 454 | Vii = K.mean( 455 | reshaped_f_imag_centred ** 2, axis=reduction_axes 456 | ) + self.epsilon 457 | Vri = K.mean( 458 | reshaped_f_real_centred * reshaped_f_imag_centred, 459 | axis=reduction_axes, 460 | ) + self.epsilon 461 | 462 | normalized_weight = complex_normalization( 463 | K.concatenate([reshaped_f_real, reshaped_f_imag], axis=-1), 464 | Vrr, 465 | Vii, 466 | Vri, 467 | beta=None, 468 | gamma_rr=self.gamma_rr, 469 | gamma_ri=self.gamma_ri, 470 | gamma_ii=self.gamma_ii, 471 | scale=True, 472 | center=False, 473 | axis=-1, 474 | ) 475 | 476 | normalized_real = normalized_weight[:, :nb_kernels] 477 | normalized_imag = normalized_weight[:, nb_kernels:] 478 | f_real = K.reshape(normalized_real, self.kernel_shape) 479 | f_imag = K.reshape(normalized_imag, self.kernel_shape) 480 | 481 | # Performing complex convolution 482 | 483 | f_real._keras_shape = self.kernel_shape 484 | f_imag._keras_shape = self.kernel_shape 485 | 486 | cat_kernels_4_real = K.concatenate([f_real, -f_imag], axis=-2) 487 | cat_kernels_4_imag = K.concatenate([f_imag, f_real], axis=-2) 488 | cat_kernels_4_complex = K.concatenate( 489 | [cat_kernels_4_real, cat_kernels_4_imag], axis=-1 490 | ) 491 | if False and self.transposed: 492 | cat_kernels_4_complex._keras_shape = \ 493 | self.kernel_size + (2 * self.filters, 2 * input_dim) 494 | else: 495 | cat_kernels_4_complex._keras_shape = \ 496 | self.kernel_size + 2 * input_dim, 2 * self.filters 497 | 498 | output = convFunc(inputs, cat_kernels_4_complex, **convArgs) 499 | 500 | if self.use_bias: 501 | output = K.bias_add( 502 | output, self.bias, data_format=self.data_format) 503 | 504 | if self.activation is not None: 505 | output = self.activation(output) 506 | 507 | return output 508 | 509 | def compute_output_shape(self, input_shape): 510 | if self.transposed: 511 | outputLengthFunc = conv_transpose_output_length 512 | else: 513 | outputLengthFunc = conv_utils.conv_output_length 514 | if self.data_format == "channels_last": 515 | space = input_shape[1:-1] 516 | new_space = [] 517 | for i in range(len(space)): 518 | new_dim = outputLengthFunc( 519 | space[i], 520 | self.kernel_size[i], 521 | padding=self.padding, 522 | stride=self.strides[i], 523 | dilation=self.dilation_rate[i], 524 | ) 525 | new_space.append(new_dim) 526 | return (input_shape[0],) + tuple(new_space) + (2 * self.filters,) 527 | if self.data_format == "channels_first": 528 | space = input_shape[2:] 529 | new_space = [] 530 | for i in range(len(space)): 531 | new_dim = outputLengthFunc( 532 | space[i], 533 | self.kernel_size[i], 534 | padding=self.padding, 535 | stride=self.strides[i], 536 | dilation=self.dilation_rate[i], 537 | ) 538 | new_space.append(new_dim) 539 | return (input_shape[0],) + (2 * self.filters,) + tuple(new_space) 540 | 541 | def get_config(self): 542 | config = { 543 | "rank": self.rank, 544 | "filters": self.filters, 545 | "kernel_size": self.kernel_size, 546 | "strides": self.strides, 547 | "padding": self.padding, 548 | "data_format": self.data_format, 549 | "dilation_rate": self.dilation_rate, 550 | "activation": activations.serialize(self.activation), 551 | "use_bias": 
self.use_bias, 552 | "normalize_weight": self.normalize_weight, 553 | "kernel_initializer": sanitizedInitSer(self.kernel_initializer), 554 | "bias_initializer": sanitizedInitSer(self.bias_initializer), 555 | "gamma_diag_initializer": sanitizedInitSer( 556 | self.gamma_diag_initializer 557 | ), 558 | "gamma_off_initializer": sanitizedInitSer( 559 | self.gamma_off_initializer 560 | ), 561 | "kernel_regularizer": regularizers.serialize( 562 | self.kernel_regularizer 563 | ), 564 | "bias_regularizer": regularizers.serialize(self.bias_regularizer), 565 | "gamma_diag_regularizer": regularizers.serialize( 566 | self.gamma_diag_regularizer 567 | ), 568 | "gamma_off_regularizer": regularizers.serialize( 569 | self.gamma_off_regularizer 570 | ), 571 | "activity_regularizer": regularizers.serialize( 572 | self.activity_regularizer 573 | ), 574 | "kernel_constraint": constraints.serialize(self.kernel_constraint), 575 | "bias_constraint": constraints.serialize(self.bias_constraint), 576 | "gamma_diag_constraint": constraints.serialize( 577 | self.gamma_diag_constraint 578 | ), 579 | "gamma_off_constraint": constraints.serialize( 580 | self.gamma_off_constraint 581 | ), 582 | "init_criterion": self.init_criterion, 583 | "spectral_parametrization": self.spectral_parametrization, 584 | "transposed": self.transposed, 585 | } 586 | base_config = super(ComplexConv, self).get_config() 587 | return dict(list(base_config.items()) + list(config.items())) 588 | 589 | 590 | class ComplexConv1D(ComplexConv): 591 | """1D complex convolution layer. 592 | This layer creates a complex convolution kernel that is convolved 593 | with a complex input layer over a single complex spatial (or temporal) 594 | dimension 595 | to produce a complex output tensor. 596 | If `use_bias` is True, a bias vector is created and added to the complex 597 | output. 598 | Finally, if `activation` is not `None`, 599 | it is applied each of the real and imaginary parts of the output. 600 | When using this layer as the first layer in a model, 601 | provide an `input_shape` argument 602 | (tuple of integers or `None`, e.g. 603 | `(10, 128)` for sequences of 10 vectors of 128-dimensional vectors, 604 | or `(None, 128)` for variable-length sequences of 128-dimensional vectors. 605 | # Arguments 606 | filters: Integer, the dimensionality of the output space, i.e, 607 | the number of complex feature maps. It is also the effective number 608 | of feature maps for each of the real and imaginary parts. 609 | (i.e. the number of complex filters in the convolution) 610 | The total effective number of filters is 2 x filters. 611 | kernel_size: An integer or tuple/list of n integers, specifying the 612 | dimensions of the convolution window. 613 | strides: An integer or tuple/list of a single integer, 614 | specifying the stride length of the convolution. 615 | Specifying any stride value != 1 is incompatible with specifying 616 | any `dilation_rate` value != 1. 617 | padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive). 618 | `"causal"` results in causal (dilated) convolutions, e.g. output[t] 619 | does not depend on input[t+1:]. Useful when modeling temporal data 620 | where the model should not violate the temporal order. 621 | See [WaveNet: A Generative Model for Raw Audio, section 2.1] 622 | (https://arxiv.org/abs/1609.03499). 623 | dilation_rate: an integer or tuple/list of a single integer, specifying 624 | the dilation rate to use for dilated convolution. 
625 | Currently, specifying any `dilation_rate` value != 1 is 626 | incompatible with specifying any `strides` value != 1. 627 | activation: Activation function to use 628 | (see keras.activations). 629 | If you don't specify anything, no activation is applied 630 | (ie. "linear" activation: `a(x) = x`). 631 | use_bias: Boolean, whether the layer uses a bias vector. 632 | normalize_weight: Boolean, whether the layer normalizes its complex 633 | weights before convolving the complex input. 634 | The complex normalization performed is similar to the one 635 | for the batchnorm. Each of the complex kernels are centred and 636 | multiplied by 637 | the inverse square root of covariance matrix. 638 | Then, a complex multiplication is perfromed as the normalized 639 | weights are 640 | multiplied by the complex scaling factor gamma. 641 | kernel_initializer: Initializer for the complex `kernel` weights 642 | matrix. 643 | By default it is 'complex'. The 'complex_independent' 644 | and the usual initializers could also be used. 645 | (see keras.initializers and init.py). 646 | bias_initializer: Initializer for the bias vector 647 | (see keras.initializers). 648 | kernel_regularizer: Regularizer function applied to 649 | the `kernel` weights matrix 650 | (see keras.regularizers). 651 | bias_regularizer: Regularizer function applied to the bias vector 652 | (see keras.regularizers). 653 | activity_regularizer: Regularizer function applied to 654 | the output of the layer (its "activation"). 655 | (see keras.regularizers). 656 | kernel_constraint: Constraint function applied to the kernel matrix 657 | (see keras.constraints). 658 | bias_constraint: Constraint function applied to the bias vector 659 | (see keras.constraints). 660 | spectral_parametrization: Whether or not to use a spectral 661 | parametrization of the parameters. 662 | transposed: Boolean, whether or not to use transposed convolution 663 | # Input shape 664 | 3D tensor with shape: `(batch_size, steps, input_dim)` 665 | # Output shape 666 | 3D tensor with shape: `(batch_size, new_steps, 2 x filters)` 667 | `steps` value might have changed due to padding or strides. 
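        # Example
            A minimal usage sketch (hypothetical sizes, not taken from the
            CV-deSpeckNet models): the last axis stacks the real parts of all
            complex channels first and the imaginary parts second, so 4
            complex input channels enter the layer as 8 real-valued channels,
            and `filters=16` yields 32 output channels.

                from keras.models import Sequential
                from complexnn.conv import ComplexConv1D

                model = Sequential()
                # 10 steps, 4 complex channels stored as 8 real-valued channels
                model.add(ComplexConv1D(filters=16, kernel_size=3,
                                        padding='same', activation='relu',
                                        input_shape=(10, 8)))
                # model.output_shape == (None, 10, 32), i.e. 2 x filters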
668 | """ 669 | 670 | def __init__( 671 | self, 672 | filters, 673 | kernel_size, 674 | strides=1, 675 | padding="valid", 676 | dilation_rate=1, 677 | activation=None, 678 | use_bias=True, 679 | kernel_initializer="complex", 680 | bias_initializer="zeros", 681 | kernel_regularizer=None, 682 | bias_regularizer=None, 683 | activity_regularizer=None, 684 | kernel_constraint=None, 685 | bias_constraint=None, 686 | seed=None, 687 | init_criterion="he", 688 | spectral_parametrization=False, 689 | transposed=False, 690 | **kwargs): 691 | super(ComplexConv1D, self).__init__( 692 | rank=1, 693 | filters=filters, 694 | kernel_size=kernel_size, 695 | strides=strides, 696 | padding=padding, 697 | data_format="channels_last", 698 | dilation_rate=dilation_rate, 699 | activation=activation, 700 | use_bias=use_bias, 701 | kernel_initializer=kernel_initializer, 702 | bias_initializer=bias_initializer, 703 | kernel_regularizer=kernel_regularizer, 704 | bias_regularizer=bias_regularizer, 705 | activity_regularizer=activity_regularizer, 706 | kernel_constraint=kernel_constraint, 707 | bias_constraint=bias_constraint, 708 | init_criterion=init_criterion, 709 | spectral_parametrization=spectral_parametrization, 710 | transposed=transposed, 711 | **kwargs) 712 | 713 | def get_config(self): 714 | config = super(ComplexConv1D, self).get_config() 715 | config.pop("rank") 716 | return config 717 | 718 | 719 | class ComplexConv2D(ComplexConv): 720 | """2D Complex convolution layer (e.g. spatial convolution over images). 721 | This layer creates a complex convolution kernel that is convolved 722 | with a complex input layer to produce a complex output tensor. If 723 | `use_bias` 724 | is True, a complex bias vector is created and added to the outputs. 725 | Finally, if `activation` is not `None`, it is applied to both the 726 | real and imaginary parts of the output. 727 | When using this layer as the first layer in a model, 728 | provide the keyword argument `input_shape` 729 | (tuple of integers, does not include the sample axis), 730 | e.g. `input_shape=(128, 128, 3)` for 128x128 RGB pictures 731 | in `data_format="channels_last"`. 732 | # Arguments 733 | filters: Integer, the dimensionality of the complex output space 734 | (i.e, the number complex feature maps in the convolution). The 735 | total effective number of filters or feature maps is 2 x filters. 736 | kernel_size: An integer or tuple/list of 2 integers, specifying the 737 | width and height of the 2D convolution window. 738 | Can be a single integer to specify the same value for 739 | all spatial dimensions. 740 | strides: An integer or tuple/list of 2 integers, 741 | specifying the strides of the convolution along the width and 742 | height. 743 | Can be a single integer to specify the same value for 744 | all spatial dimensions. 745 | Specifying any stride value != 1 is incompatible with specifying 746 | any `dilation_rate` value != 1. 747 | padding: one of `"valid"` or `"same"` (case-insensitive). 748 | data_format: A string, 749 | one of `channels_last` (default) or `channels_first`. 750 | The ordering of the dimensions in the inputs. 751 | `channels_last` corresponds to inputs with shape 752 | `(batch, height, width, channels)` while `channels_first` 753 | corresponds to inputs with shape 754 | `(batch, channels, height, width)`. 755 | It defaults to the `image_data_format` value found in your 756 | Keras config file at `~/.keras/keras.json`. 757 | If you never set it, then it will be "channels_last". 
758 | dilation_rate: an integer or tuple/list of 2 integers, specifying 759 | the dilation rate to use for dilated convolution. 760 | Can be a single integer to specify the same value for 761 | all spatial dimensions. 762 | Currently, specifying any `dilation_rate` value != 1 is 763 | incompatible with specifying any stride value != 1. 764 | activation: Activation function to use 765 | (see keras.activations). 766 | If you don't specify anything, no activation is applied 767 | (ie. "linear" activation: `a(x) = x`). 768 | use_bias: Boolean, whether the layer uses a bias vector. 769 | normalize_weight: Boolean, whether the layer normalizes its complex 770 | weights before convolving the complex input. 771 | The complex normalization performed is similar to the one 772 | for the batchnorm. Each of the complex kernels are centred and 773 | multiplied by 774 | the inverse square root of covariance matrix. 775 | Then, a complex multiplication is perfromed as the normalized 776 | weights are 777 | multiplied by the complex scaling factor gamma. 778 | kernel_initializer: Initializer for the complex `kernel` weights 779 | matrix. 780 | By default it is 'complex'. The 'complex_independent' 781 | and the usual initializers could also be used. 782 | (see keras.initializers and init.py). 783 | bias_initializer: Initializer for the bias vector 784 | (see keras.initializers). 785 | kernel_regularizer: Regularizer function applied to 786 | the `kernel` weights matrix 787 | (see keras.regularizers). 788 | bias_regularizer: Regularizer function applied to the bias vector 789 | (see keras.regularizers). 790 | activity_regularizer: Regularizer function applied to 791 | the output of the layer (its "activation"). 792 | (see keras.regularizers). 793 | kernel_constraint: Constraint function applied to the kernel matrix 794 | (see keras.constraints). 795 | bias_constraint: Constraint function applied to the bias vector 796 | (see keras.constraints). 797 | spectral_parametrization: Whether or not to use a spectral 798 | parametrization of the parameters. 799 | transposed: Boolean, whether or not to use transposed convolution 800 | # Input shape 801 | 4D tensor with shape: 802 | `(samples, channels, rows, cols)` if data_format='channels_first' 803 | or 4D tensor with shape: 804 | `(samples, rows, cols, channels)` if data_format='channels_last'. 805 | # Output shape 806 | 4D tensor with shape: 807 | `(samples, 2 x filters, new_rows, new_cols)` if 808 | data_format='channels_first' or 4D tensor with shape: 809 | `(samples, new_rows, new_cols, 2 x filters)` if 810 | data_format='channels_last'. `rows` and `cols` values might have 811 | changed due to padding. 
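        # Example
            A minimal sketch (illustrative shapes only; the 40x40 patch size
            and the 3-complex-channel input merely mirror the data tensors
            built by helper.py and are not prescriptive), assuming the default
            channels_last data format. Real parts of the complex channels come
            first along the last axis, imaginary parts second.

                from keras.layers import Input
                from keras.models import Model
                from complexnn.conv import ComplexConv2D

                # 40x40 patches, 3 complex channels stored as 6 real-valued channels
                inp = Input(shape=(40, 40, 6))
                x = ComplexConv2D(filters=32, kernel_size=3, padding='same',
                                  activation='relu')(inp)
                x = ComplexConv2D(filters=3, kernel_size=3, padding='same')(x)
                model = Model(inp, x)
                # model.output_shape == (None, 40, 40, 6), i.e. 3 complex channels again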
812 | """ 813 | 814 | def __init__( 815 | self, 816 | filters, 817 | kernel_size, 818 | strides=(1, 1), 819 | padding="valid", 820 | data_format=None, 821 | dilation_rate=(1, 1), 822 | activation=None, 823 | use_bias=True, 824 | kernel_initializer="complex", 825 | bias_initializer="zeros", 826 | kernel_regularizer=None, 827 | bias_regularizer=None, 828 | activity_regularizer=None, 829 | kernel_constraint=None, 830 | bias_constraint=None, 831 | seed=None, 832 | init_criterion="he", 833 | spectral_parametrization=False, 834 | transposed=False, 835 | **kwargs): 836 | super(ComplexConv2D, self).__init__( 837 | rank=2, 838 | filters=filters, 839 | kernel_size=kernel_size, 840 | strides=strides, 841 | padding=padding, 842 | data_format=data_format, 843 | dilation_rate=dilation_rate, 844 | activation=activation, 845 | use_bias=use_bias, 846 | kernel_initializer=kernel_initializer, 847 | bias_initializer=bias_initializer, 848 | kernel_regularizer=kernel_regularizer, 849 | bias_regularizer=bias_regularizer, 850 | activity_regularizer=activity_regularizer, 851 | kernel_constraint=kernel_constraint, 852 | bias_constraint=bias_constraint, 853 | init_criterion=init_criterion, 854 | spectral_parametrization=spectral_parametrization, 855 | transposed=transposed, 856 | **kwargs) 857 | 858 | def get_config(self): 859 | config = super(ComplexConv2D, self).get_config() 860 | config.pop("rank") 861 | return config 862 | 863 | 864 | class ComplexConv3D(ComplexConv): 865 | """3D convolution layer (e.g. spatial convolution over volumes). 866 | This layer creates a complex convolution kernel that is convolved 867 | with a complex layer input to produce a complex output tensor. 868 | If `use_bias` is True, 869 | a complex bias vector is created and added to the outputs. Finally, if 870 | `activation` is not `None`, it is applied to each of the real and imaginary 871 | parts of the output. 872 | When using this layer as the first layer in a model, 873 | provide the keyword argument `input_shape` 874 | (tuple of integers, does not include the sample axis), 875 | e.g. `input_shape=(2, 128, 128, 128, 3)` for 128x128x128 volumes 876 | with 3 channels, 877 | in `data_format="channels_last"`. 878 | # Arguments 879 | filters: Integer, the dimensionality of the complex output space 880 | (i.e, the number complex feature maps in the convolution). The 881 | total effective number of filters or feature maps is 2 x filters. 882 | kernel_size: An integer or tuple/list of 3 integers, specifying the 883 | width and height of the 3D convolution window. 884 | Can be a single integer to specify the same value for 885 | all spatial dimensions. 886 | strides: An integer or tuple/list of 3 integers, specifying 887 | the strides of the convolution along each spatial dimension. 888 | Can be a single integer to specify the same value for 889 | all spatial dimensions. 890 | Specifying any stride value != 1 is incompatible with specifying 891 | any `dilation_rate` value != 1. 892 | padding: one of `"valid"` or `"same"` (case-insensitive). 893 | data_format: A string, 894 | one of `channels_last` (default) or `channels_first`. 895 | The ordering of the dimensions in the inputs. 896 | `channels_last` corresponds to inputs with shape 897 | `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` 898 | while `channels_first` corresponds to inputs with shape 899 | `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`. 900 | It defaults to the `image_data_format` value found in your 901 | Keras config file at `~/.keras/keras.json`. 
902 | If you never set it, then it will be "channels_last". 903 | dilation_rate: an integer or tuple/list of 3 integers, specifying 904 | the dilation rate to use for dilated convolution. 905 | Can be a single integer to specify the same value for 906 | all spatial dimensions. 907 | Currently, specifying any `dilation_rate` value != 1 is 908 | incompatible with specifying any stride value != 1. 909 | activation: Activation function to use 910 | (see keras.activations). 911 | If you don't specify anything, no activation is applied 912 | (ie. "linear" activation: `a(x) = x`). 913 | use_bias: Boolean, whether the layer uses a bias vector. 914 | normalize_weight: Boolean, whether the layer normalizes its complex 915 | weights before convolving the complex input. 916 | The complex normalization performed is similar to the one 917 | for the batchnorm. Each of the complex kernels are centred and 918 | multiplied by 919 | the inverse square root of covariance matrix. 920 | Then, a complex multiplication is perfromed as the normalized 921 | weights are 922 | multiplied by the complex scaling factor gamma. 923 | kernel_initializer: Initializer for the complex `kernel` weights 924 | matrix. 925 | By default it is 'complex'. The 'complex_independent' 926 | and the usual initializers could also be used. 927 | (see keras.initializers and init.py). 928 | bias_initializer: Initializer for the bias vector 929 | (see keras.initializers). 930 | kernel_regularizer: Regularizer function applied to 931 | the `kernel` weights matrix 932 | (see keras.regularizers). 933 | bias_regularizer: Regularizer function applied to the bias vector 934 | (see keras.regularizers). 935 | activity_regularizer: Regularizer function applied to 936 | the output of the layer (its "activation"). 937 | (see keras.regularizers). 938 | kernel_constraint: Constraint function applied to the kernel matrix 939 | (see keras.constraints). 940 | bias_constraint: Constraint function applied to the bias vector 941 | (see keras.constraints). 942 | spectral_parametrization: Whether or not to use a spectral 943 | parametrization of the parameters. 944 | transposed: Boolean, whether or not to use transposed convolution 945 | # Input shape 946 | 5D tensor with shape: 947 | `(samples, channels, conv_dim1, conv_dim2, conv_dim3)` if 948 | data_format='channels_first' 949 | or 5D tensor with shape: 950 | `(samples, conv_dim1, conv_dim2, conv_dim3, channels)` if 951 | data_format='channels_last'. 952 | # Output shape 953 | 5D tensor with shape: 954 | `(samples, 2 x filters, new_conv_dim1, new_conv_dim2, new_conv_dim3)` 955 | if data_format='channels_first' 956 | or 5D tensor with shape: 957 | `(samples, new_conv_dim1, new_conv_dim2, new_conv_dim3, 2 x filters)` 958 | if data_format='channels_last'. 959 | `new_conv_dim1`, `new_conv_dim2` and `new_conv_dim3` values might have 960 | changed due to padding. 
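        # Example
            A short sketch (illustrative shapes, assuming the default
            channels_last data format) showing how the stacked real/imaginary
            output can be split back into its parts with the utility layers
            from complexnn.utils:

                from keras.layers import Input
                from keras.models import Model
                from complexnn.conv import ComplexConv3D
                from complexnn.utils import GetReal, GetImag, GetAbs

                inp = Input(shape=(8, 8, 8, 4))   # 2 complex channels as 4 real-valued ones
                y = ComplexConv3D(filters=8, kernel_size=3, padding='same')(inp)
                real = GetReal()(y)               # (None, 8, 8, 8, 8)
                imag = GetImag()(y)               # (None, 8, 8, 8, 8)
                mag = GetAbs()(y)                 # |y|, same shape as real/imag
                model = Model(inp, [real, imag, mag])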
961 | """ 962 | 963 | def __init__( 964 | self, 965 | filters, 966 | kernel_size, 967 | strides=(1, 1, 1), 968 | padding="valid", 969 | data_format=None, 970 | dilation_rate=(1, 1, 1), 971 | activation=None, 972 | use_bias=True, 973 | kernel_initializer="complex", 974 | bias_initializer="zeros", 975 | kernel_regularizer=None, 976 | bias_regularizer=None, 977 | activity_regularizer=None, 978 | kernel_constraint=None, 979 | bias_constraint=None, 980 | seed=None, 981 | init_criterion="he", 982 | spectral_parametrization=False, 983 | transposed=False, 984 | **kwargs): 985 | super(ComplexConv3D, self).__init__( 986 | rank=3, 987 | filters=filters, 988 | kernel_size=kernel_size, 989 | strides=strides, 990 | padding=padding, 991 | data_format=data_format, 992 | dilation_rate=dilation_rate, 993 | activation=activation, 994 | use_bias=use_bias, 995 | kernel_initializer=kernel_initializer, 996 | bias_initializer=bias_initializer, 997 | kernel_regularizer=kernel_regularizer, 998 | bias_regularizer=bias_regularizer, 999 | activity_regularizer=activity_regularizer, 1000 | kernel_constraint=kernel_constraint, 1001 | bias_constraint=bias_constraint, 1002 | init_criterion=init_criterion, 1003 | spectral_parametrization=spectral_parametrization, 1004 | transposed=transposed, 1005 | **kwargs) 1006 | 1007 | def get_config(self): 1008 | config = super(ComplexConv3D, self).get_config() 1009 | config.pop("rank") 1010 | return config 1011 | 1012 | 1013 | class WeightNorm_Conv(_Conv): 1014 | """WeightNorm_Conv""" 1015 | # Real-valued Convolutional Layer that normalizes its weights 1016 | # before convolving the input. 1017 | # The weight Normalization performed the one 1018 | # described in the following paper: 1019 | # Weight Normalization: A Simple Reparameterization to Accelerate Training 1020 | # of Deep Neural Networks 1021 | # (see https://arxiv.org/abs/1602.07868) 1022 | 1023 | def __init__( 1024 | self, 1025 | gamma_initializer="ones", 1026 | gamma_regularizer=None, 1027 | gamma_constraint=None, 1028 | epsilon=1e-07, 1029 | **kwargs): 1030 | super(WeightNorm_Conv, self).__init__(**kwargs) 1031 | if self.rank == 1: 1032 | self.data_format = "channels_last" 1033 | self.gamma_initializer = sanitizedInitGet(gamma_initializer) 1034 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 1035 | self.gamma_constraint = constraints.get(gamma_constraint) 1036 | self.epsilon = epsilon 1037 | self.gamma = None 1038 | 1039 | def build(self, input_shape): 1040 | super(WeightNorm_Conv, self).build(input_shape) 1041 | if self.data_format == "channels_first": 1042 | channel_axis = 1 1043 | else: 1044 | channel_axis = -1 1045 | if input_shape[channel_axis] is None: 1046 | raise ValueError( 1047 | "The channel dimension of the inputs " 1048 | "should be defined. Found `None`." 1049 | ) 1050 | input_dim = input_shape[channel_axis] 1051 | gamma_shape = (input_dim * self.filters,) 1052 | self.gamma = self.add_weight( 1053 | shape=gamma_shape, 1054 | name="gamma", 1055 | initializer=self.gamma_initializer, 1056 | regularizer=self.gamma_regularizer, 1057 | constraint=self.gamma_constraint, 1058 | ) 1059 | 1060 | def call(self, inputs): 1061 | input_shape = K.shape(inputs) 1062 | if self.data_format == "channels_first": 1063 | channel_axis = 1 1064 | else: 1065 | channel_axis = -1 1066 | if input_shape[channel_axis] is None: 1067 | raise ValueError( 1068 | "The channel dimension of the inputs " 1069 | "should be defined. Found `None`." 
1070 | ) 1071 | input_dim = input_shape[channel_axis] 1072 | ker_shape = self.kernel_size + (input_dim, self.filters) 1073 | nb_kernels = ker_shape[-2] * ker_shape[-1] 1074 | kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels) 1075 | reshaped_kernel = K.reshape(self.kernel, kernel_shape_4_norm) 1076 | normalized_weight = K.l2_normalize( 1077 | reshaped_kernel, axis=0, epsilon=self.epsilon 1078 | ) 1079 | normalized_weight = K.reshape( 1080 | self.gamma, (1, ker_shape[-2] * ker_shape[-1]) 1081 | ) * normalized_weight 1082 | shaped_kernel = K.reshape(normalized_weight, ker_shape) 1083 | shaped_kernel._keras_shape = ker_shape 1084 | 1085 | convArgs = { 1086 | "strides": self.strides[0] if self.rank == 1 else self.strides, 1087 | "padding": self.padding, 1088 | "data_format": self.data_format, 1089 | "dilation_rate": 1090 | self.dilation_rate[0] 1091 | if self.rank == 1 1092 | else self.dilation_rate, 1093 | } 1094 | convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank] 1095 | output = convFunc(inputs, shaped_kernel, **convArgs) 1096 | 1097 | if self.use_bias: 1098 | output = K.bias_add(output, 1099 | self.bias, 1100 | data_format=self.data_format) 1101 | 1102 | if self.activation is not None: 1103 | output = self.activation(output) 1104 | 1105 | return output 1106 | 1107 | def get_config(self): 1108 | config = { 1109 | "gamma_initializer": sanitizedInitSer(self.gamma_initializer), 1110 | "gamma_regularizer": 1111 | regularizers.serialize(self.gamma_regularizer), 1112 | "gamma_constraint": constraints.serialize(self.gamma_constraint), 1113 | "epsilon": self.epsilon, 1114 | } 1115 | base_config = super(WeightNorm_Conv, self).get_config() 1116 | return dict(list(base_config.items()) + list(config.items())) 1117 | 1118 | 1119 | # Aliases 1120 | ComplexConvolution1D = ComplexConv1D 1121 | ComplexConvolution2D = ComplexConv2D 1122 | ComplexConvolution3D = ComplexConv3D 1123 | -------------------------------------------------------------------------------- /complexnn/conv.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/conv.pyc -------------------------------------------------------------------------------- /complexnn/dense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | # 7 | 8 | from keras import backend as K 9 | import sys; sys.path.append('.') 10 | from keras import backend as K 11 | from keras import activations, initializers, regularizers, constraints 12 | from keras.layers import Layer, InputSpec 13 | import numpy as np 14 | from numpy.random import RandomState 15 | 16 | class ComplexDense(Layer): 17 | """Regular complex densely-connected NN layer. 18 | `Dense` implements the operation: 19 | `real_preact = dot(real_input, real_kernel) - dot(imag_input, imag_kernel)` 20 | `imag_preact = dot(real_input, imag_kernel) + dot(imag_input, real_kernel)` 21 | `output = activation(K.concatenate([real_preact, imag_preact]) + bias)` 22 | where `activation` is the element-wise activation function 23 | passed as the `activation` argument, `kernel` is a weights matrix 24 | created by the layer, and `bias` is a bias vector created by the layer 25 | (only applicable if `use_bias` is `True`). 26 | Note: if the input to the layer has a rank greater than 2, then 27 | AN ERROR MESSAGE IS PRINTED. 
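    For example (a minimal sketch with hypothetical sizes): the last input
    dimension must be even, with the real parts of the complex features in
    the first half and the imaginary parts in the second half; `units`
    complex units produce 2 x units real-valued outputs.

        from keras.models import Sequential
        from complexnn.dense import ComplexDense

        model = Sequential()
        # 8 complex input features stored as 16 real values
        model.add(ComplexDense(units=4, activation='relu', input_shape=(16,)))
        # model.output_shape == (None, 8), i.e. 4 complex units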
28 | # Arguments 29 | units: Positive integer, dimensionality of each of the real part 30 | and the imaginary part. It is actualy the number of complex units. 31 | activation: Activation function to use 32 | (see keras.activations). 33 | If you don't specify anything, no activation is applied 34 | (ie. "linear" activation: `a(x) = x`). 35 | use_bias: Boolean, whether the layer uses a bias vector. 36 | kernel_initializer: Initializer for the complex `kernel` weights matrix. 37 | By default it is 'complex'. 38 | and the usual initializers could also be used. 39 | (see keras.initializers and init.py). 40 | bias_initializer: Initializer for the bias vector 41 | (see keras.initializers). 42 | kernel_regularizer: Regularizer function applied to 43 | the `kernel` weights matrix 44 | (see keras.regularizers). 45 | bias_regularizer: Regularizer function applied to the bias vector 46 | (see keras.regularizers). 47 | activity_regularizer: Regularizer function applied to 48 | the output of the layer (its "activation"). 49 | (see keras.regularizers). 50 | kernel_constraint: Constraint function applied to the kernel matrix 51 | (see keras.constraints). 52 | bias_constraint: Constraint function applied to the bias vector 53 | (see keras.constraints). 54 | # Input shape 55 | a 2D input with shape `(batch_size, input_dim)`. 56 | # Output shape 57 | For a 2D input with shape `(batch_size, input_dim)`, 58 | the output would have shape `(batch_size, units)`. 59 | """ 60 | 61 | def __init__(self, units, 62 | activation=None, 63 | use_bias=True, 64 | init_criterion='he', 65 | kernel_initializer='complex', 66 | bias_initializer='zeros', 67 | kernel_regularizer=None, 68 | bias_regularizer=None, 69 | activity_regularizer=None, 70 | kernel_constraint=None, 71 | bias_constraint=None, 72 | seed=None, 73 | **kwargs): 74 | if 'input_shape' not in kwargs and 'input_dim' in kwargs: 75 | kwargs['input_shape'] = (kwargs.pop('input_dim'),) 76 | super(ComplexDense, self).__init__(**kwargs) 77 | self.units = units 78 | self.activation = activations.get(activation) 79 | self.use_bias = use_bias 80 | self.init_criterion = init_criterion 81 | if kernel_initializer in {'complex'}: 82 | self.kernel_initializer = kernel_initializer 83 | else: 84 | self.kernel_initializer = initializers.get(kernel_initializer) 85 | self.bias_initializer = initializers.get(bias_initializer) 86 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 87 | self.bias_regularizer = regularizers.get(bias_regularizer) 88 | self.activity_regularizer = regularizers.get(activity_regularizer) 89 | self.kernel_constraint = constraints.get(kernel_constraint) 90 | self.bias_constraint = constraints.get(bias_constraint) 91 | if seed is None: 92 | self.seed = np.random.randint(1, 10e6) 93 | else: 94 | self.seed = seed 95 | self.input_spec = InputSpec(ndim=2) 96 | self.supports_masking = True 97 | 98 | def build(self, input_shape): 99 | assert len(input_shape) == 2 100 | assert input_shape[-1] % 2 == 0 101 | input_dim = input_shape[-1] // 2 102 | data_format = K.image_data_format() 103 | kernel_shape = (input_dim, self.units) 104 | fan_in, fan_out = initializers._compute_fans( 105 | kernel_shape, 106 | data_format=data_format 107 | ) 108 | if self.init_criterion == 'he': 109 | s = np.sqrt(1. / fan_in) 110 | elif self.init_criterion == 'glorot': 111 | s = np.sqrt(1. 
/ (fan_in + fan_out)) 112 | rng = RandomState(seed=self.seed) 113 | 114 | # Equivalent initialization using amplitude phase representation: 115 | """modulus = rng.rayleigh(scale=s, size=kernel_shape) 116 | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) 117 | def init_w_real(shape, dtype=None): 118 | return modulus * K.cos(phase) 119 | def init_w_imag(shape, dtype=None): 120 | return modulus * K.sin(phase)""" 121 | 122 | # Initialization using euclidean representation: 123 | def init_w_real(shape, dtype=None): 124 | return rng.normal( 125 | size=kernel_shape, 126 | loc=0, 127 | scale=s, 128 | ).astype(dtype) 129 | def init_w_imag(shape, dtype=None): 130 | return rng.normal( 131 | size=kernel_shape, 132 | loc=0, 133 | scale=s 134 | ).astype(dtype) 135 | if self.kernel_initializer in {'complex'}: 136 | real_init = init_w_real 137 | imag_init = init_w_imag 138 | else: 139 | real_init = self.kernel_initializer 140 | imag_init = self.kernel_initializer 141 | 142 | self.real_kernel = self.add_weight( 143 | shape=kernel_shape, 144 | initializer=real_init, 145 | name='real_kernel', 146 | regularizer=self.kernel_regularizer, 147 | constraint=self.kernel_constraint 148 | ) 149 | self.imag_kernel = self.add_weight( 150 | shape=kernel_shape, 151 | initializer=imag_init, 152 | name='imag_kernel', 153 | regularizer=self.kernel_regularizer, 154 | constraint=self.kernel_constraint 155 | ) 156 | 157 | if self.use_bias: 158 | self.bias = self.add_weight( 159 | shape=(2 * self.units,), 160 | initializer=self.bias_initializer, 161 | name='bias', 162 | regularizer=self.bias_regularizer, 163 | constraint=self.bias_constraint 164 | ) 165 | else: 166 | self.bias = None 167 | 168 | self.input_spec = InputSpec(ndim=2, axes={-1: 2 * input_dim}) 169 | self.built = True 170 | 171 | def call(self, inputs): 172 | input_shape = K.shape(inputs) 173 | input_dim = input_shape[-1] // 2 174 | real_input = inputs[:, :input_dim] 175 | imag_input = inputs[:, input_dim:] 176 | 177 | cat_kernels_4_real = K.concatenate( 178 | [self.real_kernel, -self.imag_kernel], 179 | axis=-1 180 | ) 181 | cat_kernels_4_imag = K.concatenate( 182 | [self.imag_kernel, self.real_kernel], 183 | axis=-1 184 | ) 185 | cat_kernels_4_complex = K.concatenate( 186 | [cat_kernels_4_real, cat_kernels_4_imag], 187 | axis=0 188 | ) 189 | 190 | output = K.dot(inputs, cat_kernels_4_complex) 191 | 192 | if self.use_bias: 193 | output = K.bias_add(output, self.bias) 194 | if self.activation is not None: 195 | output = self.activation(output) 196 | 197 | return output 198 | 199 | def compute_output_shape(self, input_shape): 200 | assert input_shape and len(input_shape) == 2 201 | assert input_shape[-1] 202 | output_shape = list(input_shape) 203 | output_shape[-1] = 2 * self.units 204 | return tuple(output_shape) 205 | 206 | def get_config(self): 207 | if self.kernel_initializer in {'complex'}: 208 | ki = self.kernel_initializer 209 | else: 210 | ki = initializers.serialize(self.kernel_initializer) 211 | config = { 212 | 'units': self.units, 213 | 'activation': activations.serialize(self.activation), 214 | 'use_bias': self.use_bias, 215 | 'init_criterion': self.init_criterion, 216 | 'kernel_initializer': ki, 217 | 'bias_initializer': initializers.serialize(self.bias_initializer), 218 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 219 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 220 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 221 | 'kernel_constraint': 
constraints.serialize(self.kernel_constraint), 222 | 'bias_constraint': constraints.serialize(self.bias_constraint), 223 | 'seed': self.seed, 224 | } 225 | base_config = super(ComplexDense, self).get_config() 226 | return dict(list(base_config.items()) + list(config.items())) 227 | 228 | -------------------------------------------------------------------------------- /complexnn/dense.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/dense.pyc -------------------------------------------------------------------------------- /complexnn/fft.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Olexa Bilaniuk 6 | # 7 | 8 | import keras.backend as KB 9 | import keras.engine as KE 10 | import keras.layers as KL 11 | import keras.optimizers as KO 12 | import tensorflow as tf 13 | import numpy as np 14 | 15 | 16 | # 17 | # FFT functions: 18 | # 19 | # fft(): Batched 1-D FFT (Input: (Batch, TimeSamples)) 20 | # ifft(): Batched 1-D IFFT (Input: (Batch, FreqSamples)) 21 | # fft2(): Batched 2-D FFT (Input: (Batch, TimeSamplesH, TimeSamplesW)) 22 | # ifft2(): Batched 2-D IFFT (Input: (Batch, FreqSamplesH, FreqSamplesW)) 23 | # 24 | 25 | def fft(z): 26 | B = z.shape[0]//2 27 | L = z.shape[1] 28 | C = tf.Variable(np.asarray([[[1,-1]]], dtype=tf.float32)) 29 | Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:]) 30 | isOdd = tf.equal(L%2, 1) 31 | Zr = tf.cond(isOdd, tf.concat([Zr, C*Zr[:,1: ][:,::-1]], axis=1), 32 | tf.concat([Zr, C*Zr[:,1:-1][:,::-1]], axis=1)) 33 | Zi = tf.cond(isOdd, tf.concat([Zi, C*Zi[:,1: ][:,::-1]], axis=1), 34 | tf.concat([Zi, C*Zi[:,1:-1][:,::-1]], axis=1)) 35 | Zi = (C*Zi)[:,:,::-1] # Zi * i 36 | Z = Zr+Zi 37 | return tf.concat([Z[:,:,0], Z[:,:,1]], axis=0) 38 | 39 | 40 | def ifft(z): 41 | B = z.shape[0]//2 42 | L = z.shape[1] 43 | C = tf.Variable(np.asarray([[[1,-1]]], dtype=tf.float32)) 44 | Zr, Zi = tf.signal.rfft(z[:B]), tf.signal.rfft(z[B:]*-1) 45 | isOdd = tf.equal(L%2, 1) 46 | Zr = tf.cond(isOdd, tf.concat([Zr, C*Zr[:,1: ][:,::-1]], axis=1), 47 | tf.concat([Zr, C*Zr[:,1:-1][:,::-1]], axis=1)) 48 | Zi = tf.cond(isOdd, tf.concat([Zi, C*Zi[:,1: ][:,::-1]], axis=1), 49 | tf.concat([Zi, C*Zi[:,1:-1][:,::-1]], axis=1)) 50 | Zi = (C*Zi)[:,:,::-1] # Zi * i 51 | Z = Zr+Zi 52 | return tf.concat([Z[:,:,0], Z[:,:,1]*-1], axis=0) 53 | 54 | def fft2(x): 55 | tt = x 56 | tt = KB.reshape(tt, (x.shape[0] *x.shape[1], x.shape[2])) 57 | tf = fft(tt) 58 | tf = KB.reshape(tf, (x.shape[0], x.shape[1], x.shape[2])) 59 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 60 | tf = KB.reshape(tf, (x.shape[0] *x.shape[2], x.shape[1])) 61 | ff = fft(tf) 62 | ff = KB.reshape(ff, (x.shape[0], x.shape[2], x.shape[1])) 63 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 64 | return ff 65 | def ifft2(x): 66 | ff = x 67 | ff = KB.permute_dimensions(ff, (0, 2, 1)) 68 | ff = KB.reshape(ff, (x.shape[0] *x.shape[2], x.shape[1])) 69 | tf = ifft(ff) 70 | tf = KB.reshape(tf, (x.shape[0], x.shape[2], x.shape[1])) 71 | tf = KB.permute_dimensions(tf, (0, 2, 1)) 72 | tf = KB.reshape(tf, (x.shape[0] *x.shape[1], x.shape[2])) 73 | tt = ifft(tf) 74 | tt = KB.reshape(tt, (x.shape[0], x.shape[1], x.shape[2])) 75 | return tt 76 | 77 | # 78 | # FFT Layers: 79 | # 80 | # FFT: Batched 1-D FFT (Input: (Batch, FeatureMaps, TimeSamples)) 81 | # IFFT: Batched 1-D IFFT (Input: (Batch, 
FeatureMaps, FreqSamples)) 82 | # FFT2: Batched 2-D FFT (Input: (Batch, FeatureMaps, TimeSamplesH, TimeSamplesW)) 83 | # IFFT2: Batched 2-D IFFT (Input: (Batch, FeatureMaps, FreqSamplesH, FreqSamplesW)) 84 | # 85 | 86 | class FFT(KL.Layer): 87 | def call(self, x, mask=None): 88 | a = KB.permute_dimensions(x, (1,0,2)) 89 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2])) 90 | a = fft(a) 91 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 92 | return KB.permute_dimensions(a, (1,0,2)) 93 | class IFFT(KL.Layer): 94 | def call(self, x, mask=None): 95 | a = KB.permute_dimensions(x, (1,0,2)) 96 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2])) 97 | a = ifft(a) 98 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2])) 99 | return KB.permute_dimensions(a, (1,0,2)) 100 | class FFT2(KL.Layer): 101 | def call(self, x, mask=None): 102 | a = KB.permute_dimensions(x, (1,0,2,3)) 103 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2], x.shape[3])) 104 | a = fft2(a) 105 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 106 | return KB.permute_dimensions(a, (1,0,2,3)) 107 | class IFFT2(KL.Layer): 108 | def call(self, x, mask=None): 109 | a = KB.permute_dimensions(x, (1,0,2,3)) 110 | a = KB.reshape(a, (x.shape[1] *x.shape[0], x.shape[2], x.shape[3])) 111 | a = ifft2(a) 112 | a = KB.reshape(a, (x.shape[1], x.shape[0], x.shape[2], x.shape[3])) 113 | return KB.permute_dimensions(a, (1,0,2,3)) 114 | 115 | -------------------------------------------------------------------------------- /complexnn/fft.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/fft.pyc -------------------------------------------------------------------------------- /complexnn/init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | 7 | import numpy as np 8 | from numpy.random import RandomState 9 | import keras.backend as K 10 | from keras import initializers 11 | from keras.initializers import Initializer 12 | from keras.utils.generic_utils import (serialize_keras_object, 13 | deserialize_keras_object) 14 | 15 | 16 | class IndependentFilters(Initializer): 17 | # This initialization constructs real-valued kernels 18 | # that are independent as much as possible from each other 19 | # while respecting either the He or the Glorot criterion. 20 | def __init__(self, kernel_size, input_dim, 21 | weight_dim, nb_filters=None, 22 | criterion='glorot', seed=None): 23 | 24 | # `weight_dim` is used as a parameter for sanity check 25 | # as we should not pass an integer as kernel_size when 26 | # the weight dimension is >= 2. 27 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 28 | # then in such a case, weight_dim = 2. 
29 | # (in case of 2D input): 30 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 31 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 32 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 33 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 34 | 35 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 36 | self.nb_filters = nb_filters 37 | self.kernel_size = kernel_size 38 | self.input_dim = input_dim 39 | self.weight_dim = weight_dim 40 | self.criterion = criterion 41 | self.seed = 1337 if seed is None else seed 42 | 43 | def __call__(self, shape, dtype=None): 44 | 45 | if self.nb_filters is not None: 46 | num_rows = self.nb_filters * self.input_dim 47 | num_cols = np.prod(self.kernel_size) 48 | else: 49 | # in case it is the kernel is a matrix and not a filter 50 | # which is the case of 2D input (No feature maps). 51 | num_rows = self.input_dim 52 | num_cols = self.kernel_size[-1] 53 | 54 | flat_shape = (num_rows, num_cols) 55 | rng = RandomState(self.seed) 56 | x = rng.uniform(size=flat_shape) 57 | u, _, v = np.linalg.svd(x) 58 | orthogonal_x = np.dot(u, np.dot(np.eye(num_rows, num_cols), v.T)) 59 | if self.nb_filters is not None: 60 | independent_filters = np.reshape(orthogonal_x, (num_rows,) + tuple(self.kernel_size)) 61 | fan_in, fan_out = initializers._compute_fans( 62 | tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 63 | ) 64 | else: 65 | independent_filters = orthogonal_x 66 | fan_in, fan_out = (self.input_dim, self.kernel_size[-1]) 67 | 68 | if self.criterion == 'glorot': 69 | desired_var = 2. / (fan_in + fan_out) 70 | elif self.criterion == 'he': 71 | desired_var = 2. / fan_in 72 | else: 73 | raise ValueError('Invalid criterion: ' + self.criterion) 74 | 75 | multip_constant = np.sqrt (desired_var / np.var(independent_filters)) 76 | scaled_indep = multip_constant * independent_filters 77 | 78 | if self.weight_dim == 2 and self.nb_filters is None: 79 | weight_real = scaled_real 80 | weight_imag = scaled_imag 81 | else: 82 | kernel_shape = tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 83 | if self.weight_dim == 1: 84 | transpose_shape = (1, 0) 85 | elif self.weight_dim == 2 and self.nb_filters is not None: 86 | transpose_shape = (1, 2, 0) 87 | elif self.weight_dim == 3 and self.nb_filters is not None: 88 | transpose_shape = (1, 2, 3, 0) 89 | weight = np.transpose(scaled_indep, transpose_shape) 90 | weight = np.reshape(weight, kernel_shape) 91 | 92 | return weight 93 | 94 | def get_config(self): 95 | return {'nb_filters': self.nb_filters, 96 | 'kernel_size': self.kernel_size, 97 | 'input_dim': self.input_dim, 98 | 'weight_dim': self.weight_dim, 99 | 'criterion': self.criterion, 100 | 'seed': self.seed} 101 | 102 | 103 | class ComplexIndependentFilters(Initializer): 104 | # This initialization constructs complex-valued kernels 105 | # that are independent as much as possible from each other 106 | # while respecting either the He or the Glorot criterion. 107 | def __init__(self, kernel_size, input_dim, 108 | weight_dim, nb_filters=None, 109 | criterion='glorot', seed=None): 110 | 111 | # `weight_dim` is used as a parameter for sanity check 112 | # as we should not pass an integer as kernel_size when 113 | # the weight dimension is >= 2. 114 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 115 | # then in such a case, weight_dim = 2. 
116 | # (in case of 2D input): 117 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 118 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 119 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 120 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 121 | 122 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 123 | self.nb_filters = nb_filters 124 | self.kernel_size = kernel_size 125 | self.input_dim = input_dim 126 | self.weight_dim = weight_dim 127 | self.criterion = criterion 128 | self.seed = 1337 if seed is None else seed 129 | 130 | def __call__(self, shape, dtype=None): 131 | 132 | if self.nb_filters is not None: 133 | num_rows = self.nb_filters * self.input_dim 134 | num_cols = np.prod(self.kernel_size) 135 | else: 136 | # in case it is the kernel is a matrix and not a filter 137 | # which is the case of 2D input (No feature maps). 138 | num_rows = self.input_dim 139 | num_cols = self.kernel_size[-1] 140 | 141 | flat_shape = (int(num_rows), int(num_cols)) 142 | rng = RandomState(self.seed) 143 | r = rng.uniform(size=flat_shape) 144 | i = rng.uniform(size=flat_shape) 145 | z = r + 1j * i 146 | u, _, v = np.linalg.svd(z) 147 | unitary_z = np.dot(u, np.dot(np.eye(int(num_rows), int(num_cols)), np.conjugate(v).T)) 148 | real_unitary = unitary_z.real 149 | imag_unitary = unitary_z.imag 150 | if self.nb_filters is not None: 151 | indep_real = np.reshape(real_unitary, (num_rows,) + tuple(self.kernel_size)) 152 | indep_imag = np.reshape(imag_unitary, (num_rows,) + tuple(self.kernel_size)) 153 | fan_in, fan_out = initializers._compute_fans( 154 | tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 155 | ) 156 | else: 157 | indep_real = real_unitary 158 | indep_imag = imag_unitary 159 | fan_in, fan_out = (int(self.input_dim), self.kernel_size[-1]) 160 | 161 | if self.criterion == 'glorot': 162 | desired_var = 1. / (fan_in + fan_out) 163 | elif self.criterion == 'he': 164 | desired_var = 1. / (fan_in) 165 | else: 166 | raise ValueError('Invalid criterion: ' + self.criterion) 167 | 168 | multip_real = np.sqrt(desired_var / np.var(indep_real)) 169 | multip_imag = np.sqrt(desired_var / np.var(indep_imag)) 170 | scaled_real = multip_real * indep_real 171 | scaled_imag = multip_imag * indep_imag 172 | 173 | if self.weight_dim == 2 and self.nb_filters is None: 174 | weight_real = scaled_real 175 | weight_imag = scaled_imag 176 | else: 177 | kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), self.nb_filters) 178 | if self.weight_dim == 1: 179 | transpose_shape = (1, 0) 180 | elif self.weight_dim == 2 and self.nb_filters is not None: 181 | transpose_shape = (1, 2, 0) 182 | elif self.weight_dim == 3 and self.nb_filters is not None: 183 | transpose_shape = (1, 2, 3, 0) 184 | 185 | weight_real = np.transpose(scaled_real, transpose_shape) 186 | weight_imag = np.transpose(scaled_imag, transpose_shape) 187 | weight_real = np.reshape(weight_real, kernel_shape) 188 | weight_imag = np.reshape(weight_imag, kernel_shape) 189 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 190 | 191 | return weight 192 | 193 | def get_config(self): 194 | return {'nb_filters': self.nb_filters, 195 | 'kernel_size': self.kernel_size, 196 | 'input_dim': self.input_dim, 197 | 'weight_dim': self.weight_dim, 198 | 'criterion': self.criterion, 199 | 'seed': self.seed} 200 | 201 | 202 | class ComplexInit(Initializer): 203 | # The standard complex initialization using 204 | # either the He or the Glorot criterion. 
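    # Concretely, the kernel is drawn in polar form in __call__ below: the
    # modulus follows a Rayleigh distribution (rng.rayleigh) and the phase is
    # uniform on [-pi, pi); the real and imaginary parts are then
    # modulus * cos(phase) and modulus * sin(phase), concatenated along the
    # last axis (real half first, imaginary half second).
    # A useful sanity check: for a Rayleigh(scale=sigma) modulus,
    # E[modulus**2] = 2 * sigma**2, so the scale passed to rng.rayleigh,
    # derived from the He (1/fan_in) or Glorot (1/(fan_in + fan_out))
    # criterion, is what controls the variance of the complex weights.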
205 | def __init__(self, kernel_size, input_dim, 206 | weight_dim, nb_filters=None, 207 | criterion='glorot', seed=None): 208 | 209 | # `weight_dim` is used as a parameter for sanity check 210 | # as we should not pass an integer as kernel_size when 211 | # the weight dimension is >= 2. 212 | # nb_filters == 0 if weights are not convolutional (matrix instead of filters) 213 | # then in such a case, weight_dim = 2. 214 | # (in case of 2D input): 215 | # nb_filters == None and len(kernel_size) == 2 and_weight_dim == 2 216 | # conv1D: len(kernel_size) == 1 and weight_dim == 1 217 | # conv2D: len(kernel_size) == 2 and weight_dim == 2 218 | # conv3d: len(kernel_size) == 3 and weight_dim == 3 219 | 220 | assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3} 221 | self.nb_filters = nb_filters 222 | self.kernel_size = kernel_size 223 | self.input_dim = input_dim 224 | self.weight_dim = weight_dim 225 | self.criterion = criterion 226 | self.seed = 1337 if seed is None else seed 227 | 228 | def __call__(self, shape, dtype=None): 229 | 230 | if self.nb_filters is not None: 231 | kernel_shape = shape 232 | # kernel_shape = tuple(self.kernel_size) + (int(self.input_dim), 233 | # self.nb_filters) 234 | else: 235 | kernel_shape = (int(self.input_dim), self.kernel_size[-1]) 236 | 237 | fan_in, fan_out = initializers._compute_fans( 238 | # tuple(self.kernel_size) + (self.input_dim, self.nb_filters) 239 | kernel_shape 240 | ) 241 | 242 | if self.criterion == 'glorot': 243 | s = 1. / (fan_in + fan_out) 244 | elif self.criterion == 'he': 245 | s = 1. / fan_in 246 | else: 247 | raise ValueError('Invalid criterion: ' + self.criterion) 248 | rng = RandomState(self.seed) 249 | modulus = rng.rayleigh(scale=s, size=kernel_shape) 250 | phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape) 251 | weight_real = modulus * np.cos(phase) 252 | weight_imag = modulus * np.sin(phase) 253 | weight = np.concatenate([weight_real, weight_imag], axis=-1) 254 | 255 | return weight 256 | 257 | 258 | class SqrtInit(Initializer): 259 | def __call__(self, shape, dtype=None): 260 | return K.constant(1 / K.sqrt(2), shape=shape, dtype=dtype) 261 | 262 | 263 | # Aliases: 264 | sqrt_init = SqrtInit 265 | independent_filters = IndependentFilters 266 | complex_init = ComplexInit -------------------------------------------------------------------------------- /complexnn/init.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/init.pyc -------------------------------------------------------------------------------- /complexnn/norm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Chiheb Trabelsi 6 | 7 | # 8 | # Implementation of Layer Normalization and Complex Layer Normalization 9 | # 10 | 11 | import numpy as np 12 | from keras.layers import Layer, InputSpec 13 | from keras import initializers, regularizers, constraints 14 | import keras.backend as K 15 | from .bn import ComplexBN as complex_normalization 16 | from .bn import sqrt_init 17 | 18 | def layernorm(x, axis, epsilon, gamma, beta): 19 | # assert self.built, 'Layer must be built before being called' 20 | input_shape = K.shape(x) 21 | reduction_axes = list(range(K.ndim(x))) 22 | del reduction_axes[axis] 23 | del reduction_axes[0] 24 | broadcast_shape = [1] * K.ndim(x) 25 | broadcast_shape[axis] = 
input_shape[axis] 26 | broadcast_shape[0] = K.shape(x)[0] 27 | 28 | # Perform normalization: centering and reduction 29 | 30 | mean = K.mean(x, axis=reduction_axes) 31 | broadcast_mean = K.reshape(mean, broadcast_shape) 32 | x_centred = x - broadcast_mean 33 | variance = K.mean(x_centred ** 2, axis=reduction_axes) + epsilon 34 | broadcast_variance = K.reshape(variance, broadcast_shape) 35 | 36 | x_normed = x_centred / K.sqrt(broadcast_variance) 37 | 38 | # Perform scaling and shifting 39 | 40 | broadcast_shape_params = [1] * K.ndim(x) 41 | broadcast_shape_params[axis] = K.shape(x)[axis] 42 | broadcast_gamma = K.reshape(gamma, broadcast_shape_params) 43 | broadcast_beta = K.reshape(beta, broadcast_shape_params) 44 | 45 | x_LN = broadcast_gamma * x_normed + broadcast_beta 46 | 47 | return x_LN 48 | 49 | class LayerNormalization(Layer): 50 | 51 | def __init__(self, 52 | epsilon=1e-4, 53 | axis=-1, 54 | beta_init='zeros', 55 | gamma_init='ones', 56 | gamma_regularizer=None, 57 | beta_regularizer=None, 58 | **kwargs): 59 | 60 | self.supports_masking = True 61 | self.beta_init = initializers.get(beta_init) 62 | self.gamma_init = initializers.get(gamma_init) 63 | self.epsilon = epsilon 64 | self.axis = axis 65 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 66 | self.beta_regularizer = regularizers.get(beta_regularizer) 67 | 68 | super(LayerNormalization, self).__init__(**kwargs) 69 | 70 | def build(self, input_shape): 71 | self.input_spec = InputSpec(ndim=len(input_shape), 72 | axes={self.axis: input_shape[self.axis]}) 73 | shape = (input_shape[self.axis],) 74 | 75 | self.gamma = self.add_weight(shape, 76 | initializer=self.gamma_init, 77 | regularizer=self.gamma_regularizer, 78 | name='{}_gamma'.format(self.name)) 79 | self.beta = self.add_weight(shape, 80 | initializer=self.beta_init, 81 | regularizer=self.beta_regularizer, 82 | name='{}_beta'.format(self.name)) 83 | 84 | self.built = True 85 | 86 | def call(self, x, mask=None): 87 | assert self.built, 'Layer must be built before being called' 88 | return layernorm(x, self.axis, self.epsilon, self.gamma, self.beta) 89 | 90 | def get_config(self): 91 | config = {'epsilon': self.epsilon, 92 | 'axis': self.axis, 93 | 'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None, 94 | 'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None 95 | } 96 | base_config = super(LayerNormalization, self).get_config() 97 | return dict(list(base_config.items()) + list(config.items())) 98 | 99 | 100 | class ComplexLayerNorm(Layer): 101 | def __init__(self, 102 | epsilon=1e-4, 103 | axis=-1, 104 | center=True, 105 | scale=True, 106 | beta_initializer='zeros', 107 | gamma_diag_initializer=sqrt_init, 108 | gamma_off_initializer='zeros', 109 | beta_regularizer=None, 110 | gamma_diag_regularizer=None, 111 | gamma_off_regularizer=None, 112 | beta_constraint=None, 113 | gamma_diag_constraint=None, 114 | gamma_off_constraint=None, 115 | **kwargs): 116 | 117 | self.supports_masking = True 118 | self.epsilon = epsilon 119 | self.axis = axis 120 | self.center = center 121 | self.scale = scale 122 | self.beta_initializer = initializers.get(beta_initializer) 123 | self.gamma_diag_initializer = initializers.get(gamma_diag_initializer) 124 | self.gamma_off_initializer = initializers.get(gamma_off_initializer) 125 | self.beta_regularizer = regularizers.get(beta_regularizer) 126 | self.gamma_diag_regularizer = regularizers.get(gamma_diag_regularizer) 127 | self.gamma_off_regularizer = 
regularizers.get(gamma_off_regularizer) 128 | self.beta_constraint = constraints.get(beta_constraint) 129 | self.gamma_diag_constraint = constraints.get(gamma_diag_constraint) 130 | self.gamma_off_constraint = constraints.get(gamma_off_constraint) 131 | super(ComplexLayerNorm, self).__init__(**kwargs) 132 | 133 | def build(self, input_shape): 134 | 135 | ndim = len(input_shape) 136 | dim = input_shape[self.axis] 137 | if dim is None: 138 | raise ValueError('Axis ' + str(self.axis) + ' of ' 139 | 'input tensor should have a defined dimension ' 140 | 'but the layer received an input with shape ' + 141 | str(input_shape) + '.') 142 | self.input_spec = InputSpec(ndim=len(input_shape), 143 | axes={self.axis: dim}) 144 | 145 | gamma_shape = (input_shape[self.axis] // 2,) 146 | if self.scale: 147 | self.gamma_rr = self.add_weight( 148 | shape=gamma_shape, 149 | name='gamma_rr', 150 | initializer=self.gamma_diag_initializer, 151 | regularizer=self.gamma_diag_regularizer, 152 | constraint=self.gamma_diag_constraint 153 | ) 154 | self.gamma_ii = self.add_weight( 155 | shape=gamma_shape, 156 | name='gamma_ii', 157 | initializer=self.gamma_diag_initializer, 158 | regularizer=self.gamma_diag_regularizer, 159 | constraint=self.gamma_diag_constraint 160 | ) 161 | self.gamma_ri = self.add_weight( 162 | shape=gamma_shape, 163 | name='gamma_ri', 164 | initializer=self.gamma_off_initializer, 165 | regularizer=self.gamma_off_regularizer, 166 | constraint=self.gamma_off_constraint 167 | ) 168 | else: 169 | self.gamma_rr = None 170 | self.gamma_ii = None 171 | self.gamma_ri = None 172 | 173 | if self.center: 174 | self.beta = self.add_weight(shape=(input_shape[self.axis],), 175 | name='beta', 176 | initializer=self.beta_initializer, 177 | regularizer=self.beta_regularizer, 178 | constraint=self.beta_constraint) 179 | else: 180 | self.beta = None 181 | 182 | self.built = True 183 | 184 | def call(self, inputs): 185 | input_shape = K.shape(inputs) 186 | ndim = K.ndim(inputs) 187 | reduction_axes = list(range(ndim)) 188 | del reduction_axes[self.axis] 189 | del reduction_axes[0] 190 | input_dim = input_shape[self.axis] // 2 191 | mu = K.mean(inputs, axis=reduction_axes) 192 | broadcast_mu_shape = [1] * ndim 193 | broadcast_mu_shape[self.axis] = input_shape[self.axis] 194 | broadcast_mu_shape[0] = K.shape(inputs)[0] 195 | broadcast_mu = K.reshape(mu, broadcast_mu_shape) 196 | if self.center: 197 | input_centred = inputs - broadcast_mu 198 | else: 199 | input_centred = inputs 200 | centred_squared = input_centred ** 2 201 | if (self.axis == 1 and ndim != 3) or ndim == 2: 202 | centred_squared_real = centred_squared[:, :input_dim] 203 | centred_squared_imag = centred_squared[:, input_dim:] 204 | centred_real = input_centred[:, :input_dim] 205 | centred_imag = input_centred[:, input_dim:] 206 | elif ndim == 3: 207 | centred_squared_real = centred_squared[:, :, :input_dim] 208 | centred_squared_imag = centred_squared[:, :, input_dim:] 209 | centred_real = input_centred[:, :, :input_dim] 210 | centred_imag = input_centred[:, :, input_dim:] 211 | elif self.axis == -1 and ndim == 4: 212 | centred_squared_real = centred_squared[:, :, :, :input_dim] 213 | centred_squared_imag = centred_squared[:, :, :, input_dim:] 214 | centred_real = input_centred[:, :, :, :input_dim] 215 | centred_imag = input_centred[:, :, :, input_dim:] 216 | elif self.axis == -1 and ndim == 5: 217 | centred_squared_real = centred_squared[:, :, :, :, :input_dim] 218 | centred_squared_imag = centred_squared[:, :, :, :, input_dim:] 219 | centred_real = 
input_centred[:, :, :, :, :input_dim] 220 | centred_imag = input_centred[:, :, :, :, input_dim:] 221 | else: 222 | raise ValueError( 223 | 'Incorrect Layernorm combination of axis and dimensions. axis should be either 1 or -1. ' 224 | 'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.' 225 | ) 226 | if self.scale: 227 | Vrr = K.mean( 228 | centred_squared_real, 229 | axis=reduction_axes 230 | ) + self.epsilon 231 | Vii = K.mean( 232 | centred_squared_imag, 233 | axis=reduction_axes 234 | ) + self.epsilon 235 | # Vri contains the real and imaginary covariance for each feature map. 236 | Vri = K.mean( 237 | centred_real * centred_imag, 238 | axis=reduction_axes, 239 | ) 240 | elif self.center: 241 | Vrr = None 242 | Vii = None 243 | Vri = None 244 | else: 245 | raise ValueError('Error. Both scale and center in batchnorm are set to False.') 246 | 247 | return complex_normalization( 248 | input_centred, Vrr, Vii, Vri, 249 | self.beta, self.gamma_rr, self.gamma_ri, 250 | self.gamma_ii, self.scale, self.center, 251 | layernorm=True, axis=self.axis 252 | ) 253 | 254 | def get_config(self): 255 | config = { 256 | 'axis': self.axis, 257 | 'epsilon': self.epsilon, 258 | 'center': self.center, 259 | 'scale': self.scale, 260 | 'beta_initializer': initializers.serialize(self.beta_initializer), 261 | 'gamma_diag_initializer': initializers.serialize(self.gamma_diag_initializer), 262 | 'gamma_off_initializer': initializers.serialize(self.gamma_off_initializer), 263 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 264 | 'gamma_diag_regularizer': regularizers.serialize(self.gamma_diag_regularizer), 265 | 'gamma_off_regularizer': regularizers.serialize(self.gamma_off_regularizer), 266 | 'beta_constraint': constraints.serialize(self.beta_constraint), 267 | 'gamma_diag_constraint': constraints.serialize(self.gamma_diag_constraint), 268 | 'gamma_off_constraint': constraints.serialize(self.gamma_off_constraint), 269 | } 270 | base_config = super(ComplexLayerNorm, self).get_config() 271 | return dict(list(base_config.items()) + list(config.items())) 272 | -------------------------------------------------------------------------------- /complexnn/norm.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/norm.pyc -------------------------------------------------------------------------------- /complexnn/pool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Olexa Bilaniuk 6 | # 7 | 8 | import keras.backend as KB 9 | import keras.engine as KE 10 | import keras.layers as KL 11 | import keras.optimizers as KO 12 | import numpy as np 13 | 14 | 15 | # 16 | # Spectral Pooling Layer 17 | # 18 | 19 | class SpectralPooling1D(KL.Layer): 20 | def __init__(self, topf=(0,)): 21 | super(SpectralPooling1D, self).__init__() 22 | if "topf" in kwargs: 23 | self.topf = (int (kwargs["topf" ][0]),) 24 | self.topf = (self.topf[0]//2,) 25 | elif "gamma" in kwargs: 26 | self.gamma = (float(kwargs["gamma"][0]),) 27 | self.gamma = (self.gamma[0]/2,) 28 | else: 29 | raise RuntimeError("Must provide either topf= or gamma= !") 30 | def call(self, x, mask=None): 31 | xshape = x._keras_shape 32 | if hasattr(self, "topf"): 33 | topf = self.topf 34 | else: 35 | if KB.image_data_format() == "channels_first": 36 | topf = (int(self.gamma[0]*xshape[2]),) 37 | else: 38 | topf = 
(int(self.gamma[0]*xshape[1]),) 39 | 40 | if KB.image_data_format() == "channels_first": 41 | if topf[0] > 0 and xshape[2] >= 2*topf[0]: 42 | mask = [1]*(topf[0] ) +\ 43 | [0]*(xshape[2] - 2*topf[0]) +\ 44 | [1]*(topf[0] ) 45 | mask = [[mask]] 46 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,2)) 47 | mask = KB.constant(mask) 48 | x *= mask 49 | else: 50 | if topf[0] > 0 and xshape[1] >= 2*topf[0]: 51 | mask = [1]*(topf[0] ) +\ 52 | [0]*(xshape[1] - 2*topf[0]) +\ 53 | [1]*(topf[0] ) 54 | mask = [[mask]] 55 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,2,1)) 56 | mask = KB.constant(mask) 57 | x *= mask 58 | 59 | return x 60 | class SpectralPooling2D(KL.Layer): 61 | def __init__(self, **kwargs): 62 | super(SpectralPooling2D, self).__init__() 63 | if "topf" in kwargs: 64 | self.topf = (int (kwargs["topf" ][0]), int (kwargs["topf" ][1])) 65 | self.topf = (self.topf[0]//2, self.topf[1]//2) 66 | elif "gamma" in kwargs: 67 | self.gamma = (float(kwargs["gamma"][0]), float(kwargs["gamma"][1])) 68 | self.gamma = (self.gamma[0]/2, self.gamma[1]/2) 69 | else: 70 | raise RuntimeError("Must provide either topf= or gamma= !") 71 | def call(self, x, mask=None): 72 | xshape = x._keras_shape 73 | if hasattr(self, "topf"): 74 | topf = self.topf 75 | else: 76 | if KB.image_data_format() == "channels_first": 77 | topf = (int(self.gamma[0]*xshape[2]), int(self.gamma[1]*xshape[3])) 78 | else: 79 | topf = (int(self.gamma[0]*xshape[1]), int(self.gamma[1]*xshape[2])) 80 | 81 | if KB.image_data_format() == "channels_first": 82 | if topf[0] > 0 and xshape[2] >= 2*topf[0]: 83 | mask = [1]*(topf[0] ) +\ 84 | [0]*(xshape[2] - 2*topf[0]) +\ 85 | [1]*(topf[0] ) 86 | mask = [[[mask]]] 87 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,3,2)) 88 | mask = KB.constant(mask) 89 | x *= mask 90 | if topf[1] > 0 and xshape[3] >= 2*topf[1]: 91 | mask = [1]*(topf[1] ) +\ 92 | [0]*(xshape[3] - 2*topf[1]) +\ 93 | [1]*(topf[1] ) 94 | mask = [[[mask]]] 95 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,2,3)) 96 | mask = KB.constant(mask) 97 | x *= mask 98 | else: 99 | if topf[0] > 0 and xshape[1] >= 2*topf[0]: 100 | mask = [1]*(topf[0] ) +\ 101 | [0]*(xshape[1] - 2*topf[0]) +\ 102 | [1]*(topf[0] ) 103 | mask = [[[mask]]] 104 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,3,1,2)) 105 | mask = KB.constant(mask) 106 | x *= mask 107 | if topf[1] > 0 and xshape[2] >= 2*topf[1]: 108 | mask = [1]*(topf[1] ) +\ 109 | [0]*(xshape[2] - 2*topf[1]) +\ 110 | [1]*(topf[1] ) 111 | mask = [[[mask]]] 112 | mask = np.asarray(mask, dtype=KB.floatx()).transpose((0,1,3,2)) 113 | mask = KB.constant(mask) 114 | x *= mask 115 | 116 | return x 117 | 118 | 119 | if __name__ == "__main__": 120 | import cv2, sys 121 | import __main__ as SP 122 | import fft as CF 123 | 124 | # Build Model 125 | x = i = KL.Input(shape=(6,512,512)) 126 | f = CF.FFT2()(x) 127 | p = SP.SpectralPooling2D(gamma=[0.15,0.15])(f) 128 | o = CF.IFFT2()(p) 129 | 130 | model = KE.Model([i], [f,p,o]) 131 | model.compile("sgd", "mse") 132 | 133 | # Use it 134 | img = cv2.imread(sys.argv[1]) 135 | imgBatch = img[np.newaxis,...].transpose((0,3,1,2)) 136 | imgBatch = np.concatenate([imgBatch, np.zeros_like(imgBatch)], axis=1) 137 | f,p,o = model.predict(imgBatch) 138 | ffted = np.sqrt(np.sum(f[:,:3]**2 + f[:,3:]**2, axis=1)) 139 | ffted = ffted .transpose((1,2,0))/255 140 | pooled = np.sqrt(np.sum(p[:,:3]**2 + p[:,3:]**2, axis=1)) 141 | pooled = pooled.transpose((1,2,0))/255 142 | filtered = 
np.clip(o,0,255).transpose((0,2,3,1))[0,:,:,:3].astype("uint8") 143 | 144 | # Display it 145 | cv2.imshow("Original", img) 146 | cv2.imshow("FFT", ffted) 147 | cv2.imshow("Pooled", pooled) 148 | cv2.imshow("Filtered", filtered) 149 | cv2.waitKey(0) 150 | -------------------------------------------------------------------------------- /complexnn/pool.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/pool.pyc -------------------------------------------------------------------------------- /complexnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # 5 | # Authors: Dmitriy Serdyuk, Olexa Bilaniuk, Chiheb Trabelsi 6 | 7 | import keras.backend as K 8 | from keras.layers import Layer, Lambda 9 | 10 | # 11 | # GetReal/GetImag Lambda layer Implementation 12 | # 13 | 14 | 15 | def get_realpart(x): 16 | image_format = K.image_data_format() 17 | ndim = K.ndim(x) 18 | input_shape = K.shape(x) 19 | 20 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 21 | input_dim = input_shape[1] // 2 22 | return x[:, :input_dim] 23 | 24 | input_dim = input_shape[-1] // 2 25 | if ndim == 3: 26 | return x[:, :, :input_dim] 27 | elif ndim == 4: 28 | return x[:, :, :, :input_dim] 29 | elif ndim == 5: 30 | return x[:, :, :, :, :input_dim] 31 | 32 | 33 | def get_imagpart(x): 34 | image_format = K.image_data_format() 35 | ndim = K.ndim(x) 36 | input_shape = K.shape(x) 37 | 38 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 39 | input_dim = input_shape[1] // 2 40 | return x[:, input_dim:] 41 | 42 | input_dim = input_shape[-1] // 2 43 | if ndim == 3: 44 | return x[:, :, input_dim:] 45 | elif ndim == 4: 46 | return x[:, :, :, input_dim:] 47 | elif ndim == 5: 48 | return x[:, :, :, :, input_dim:] 49 | 50 | 51 | def get_abs(x): 52 | real = get_realpart(x) 53 | imag = get_imagpart(x) 54 | 55 | return K.sqrt(real * real + imag * imag) 56 | 57 | 58 | def getpart_output_shape(input_shape): 59 | returned_shape = list(input_shape[:]) 60 | image_format = K.image_data_format() 61 | ndim = len(returned_shape) 62 | 63 | if (image_format == 'channels_first' and ndim != 3) or ndim == 2: 64 | axis = 1 65 | else: 66 | axis = -1 67 | 68 | returned_shape[axis] = returned_shape[axis] // 2 69 | 70 | return tuple(returned_shape) 71 | 72 | 73 | class GetReal(Layer): 74 | def call(self, inputs): 75 | return get_realpart(inputs) 76 | def compute_output_shape(self, input_shape): 77 | return getpart_output_shape(input_shape) 78 | class GetImag(Layer): 79 | def call(self, inputs): 80 | return get_imagpart(inputs) 81 | def compute_output_shape(self, input_shape): 82 | return getpart_output_shape(input_shape) 83 | class GetAbs(Layer): 84 | def call(self, inputs): 85 | return get_abs(inputs) 86 | def compute_output_shape(self, input_shape): 87 | return getpart_output_shape(input_shape) 88 | 89 | -------------------------------------------------------------------------------- /complexnn/utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adugnag/CV-deSpeckNet/1f20307a7853ecbafd94bc7a3bd53afe7edd4679/complexnn/utils.pyc -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Version: v1.2 3 | Date: 2021-01-12 4 | Author: Mullissa A.G. 5 | Description: This script contains helper functions that are used in synthesizing data tensors. 6 | """ 7 | 8 | 9 | import glob 10 | import os 11 | import cv2 12 | import numpy as np 13 | import tifffile 14 | 15 | 16 | patch_size, stride = 40, 9 17 | aug_times = 1 18 | scales = [1, 0.9, 0.8, 0.7] 19 | batch_size = 128 20 | 21 | 22 | def gen_patches(file_name): 23 | """ This function prepares patches using the parameters patch_size and stride""" 24 | 25 | # read image 26 | img = tifffile.imread(file_name) 27 | img = np.array(img) 28 | h, w, d = img.shape 29 | patches = [] 30 | for s in scales: 31 | h_scaled, w_scaled = int(h*s),int(w*s) 32 | img_scaled = cv2.resize(img, (h_scaled,w_scaled), interpolation=cv2.INTER_CUBIC) 33 | # extract patches 34 | for i in range(0, h_scaled-patch_size+1, stride): 35 | for j in range(0, w_scaled-patch_size+1, stride): 36 | x = img_scaled[i:i+patch_size, j:j+patch_size,:] 37 | patches.append(x) 38 | 39 | return patches 40 | 41 | def make_dataTensor(data_dir,verbose=False): 42 | """ This function prepares the data tensor from a given directory by using gen_patches function""" 43 | 44 | file_list = glob.glob(data_dir+'/*.tif') # get name list of all .tif files 45 | # initrialize 46 | data = [] 47 | # generate patches 48 | for i in range(len(file_list)): 49 | patch = gen_patches(file_list[i]) 50 | data.append(patch) 51 | if verbose: 52 | print(str(i+1)+'/'+ str(len(file_list)) + ' is done ^_^') 53 | data = np.array(data) 54 | data = data.reshape((data.shape[0]*data.shape[1],data.shape[2],data.shape[3],6)) 55 | discard_n = len(data)-len(data)//batch_size*batch_size; 56 | data = np.delete(data,range(discard_n),axis = 0) 57 | print('^_^-training data finished-^_^') 58 | return data 59 | 60 | def get_steps(data_dir, batch_size=128): 61 | """ This function estimates the step per epoch in training the model from the given data tensor length""" 62 | if os.path.isfile(data_dir): 63 | noisy_files = [data_dir] 64 | else: 65 | noisy_files = glob.glob(data_dir + '/*.tif') 66 | num = 0 67 | #get number of steps per epoch to use in training 68 | for data_file in noisy_files: 69 | xs = make_dataTensor(data_dir) 70 | if xs is not None: 71 | num += len(xs) 72 | print("total number of patches: {}".format(num)) 73 | print("steps per epoch: {}".format(num//batch_size)) 74 | print("") 75 | return num // batch_size 76 | 77 | if __name__ == '__main__': 78 | 79 | data = make_dataTensor(data_dir='data/Train2') 80 | label = make_dataTensor(data_dir='data/Label2') 81 | -------------------------------------------------------------------------------- /main_cv-despecknet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "anaconda-cloud": {}, 7 | "colab": { 8 | "name": "Copy of main_train_corr_despecknet_cplx.ipynb", 9 | "provenance": [], 10 | "collapsed_sections": [] 11 | }, 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.6.1" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "metadata": { 34 | "id": "nPEY4wbfN6YH" 35 | }, 36 | 
"source": [ 37 | "**CV-DESPECKNET**\n", 38 | "\n", 39 | "**Version**: v1.2 \\\\\n", 40 | "**Date**: 2021-01-12 \\\\\n", 41 | "**Author**: Mullissa A.G. \\\\\n", 42 | "**Description**: This script trains a complex-valued multistream fully convolutional network for despeckling a polarimetric SAR covariance matrix as discussed in our paper A. G. Mullissa, C. Persello and J. Reiche, \"Despeckling Polarimetric SAR Data Using a Multistream Complex-Valued Fully Convolutional Network,\" in IEEE Geoscience and Remote Sensing Letters, doi: 10.1109/LGRS.2021.3066311. Some utility functions are adopted from https://github.com/cszn/DnCNN\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": { 48 | "id": "DRVu12-4aFMw" 49 | }, 50 | "source": [ 51 | "**SETTING UP THE ENVIORNMENT**" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "metadata": { 57 | "colab": { 58 | "base_uri": "https://localhost:8080/" 59 | }, 60 | "id": "SNab3sAtMNgw", 61 | "outputId": "96ad2226-3f13-41ee-c1df-70c0b48fcc5f" 62 | }, 63 | "source": [ 64 | "#mount drive\n", 65 | "from google.colab import drive\n", 66 | "drive.mount('/content/drive')" 67 | ], 68 | "execution_count": 1, 69 | "outputs": [ 70 | { 71 | "output_type": "stream", 72 | "text": [ 73 | "Mounted at /content/drive\n" 74 | ], 75 | "name": "stdout" 76 | } 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "colab": { 83 | "base_uri": "https://localhost:8080/" 84 | }, 85 | "id": "-SdJIBqCEkrS", 86 | "outputId": "f75c294c-d281-4bf6-82fa-6995a1939e32" 87 | }, 88 | "source": [ 89 | "#check GPU\n", 90 | "gpu_info = !nvidia-smi\n", 91 | "gpu_info = '\\n'.join(gpu_info)\n", 92 | "if gpu_info.find('failed') >= 0:\n", 93 | " print('Select the Runtime > \"Change runtime type\" menu to enable a GPU accelerator, ')\n", 94 | " print('and then re-execute this cell.')\n", 95 | "else:\n", 96 | " print(gpu_info)" 97 | ], 98 | "execution_count": 2, 99 | "outputs": [ 100 | { 101 | "output_type": "stream", 102 | "text": [ 103 | "Sun Mar 28 15:44:02 2021 \n", 104 | "+-----------------------------------------------------------------------------+\n", 105 | "| NVIDIA-SMI 460.56 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", 106 | "|-------------------------------+----------------------+----------------------+\n", 107 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 108 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 109 | "| | | MIG M. 
|\n", 110 | "|===============================+======================+======================|\n", 111 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", 112 | "| N/A 39C P8 9W / 70W | 0MiB / 15109MiB | 0% Default |\n", 113 | "| | | N/A |\n", 114 | "+-------------------------------+----------------------+----------------------+\n", 115 | " \n", 116 | "+-----------------------------------------------------------------------------+\n", 117 | "| Processes: |\n", 118 | "| GPU GI CI PID Type Process name GPU Memory |\n", 119 | "| ID ID Usage |\n", 120 | "|=============================================================================|\n", 121 | "| No running processes found |\n", 122 | "+-----------------------------------------------------------------------------+\n" 123 | ], 124 | "name": "stdout" 125 | } 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "metadata": { 131 | "id": "g5APznoIvIqp" 132 | }, 133 | "source": [ 134 | "#add path\n", 135 | "import sys\n", 136 | "sys.path.append('/content/drive/My Drive')" 137 | ], 138 | "execution_count": 3, 139 | "outputs": [] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "metadata": { 144 | "colab": { 145 | "base_uri": "https://localhost:8080/" 146 | }, 147 | "id": "UFe2OSJBvMHA", 148 | "outputId": "eeeaa8fc-0f7e-4969-b10b-e991f15b34bf" 149 | }, 150 | "source": [ 151 | "#Install the right tensorflow and keras versions\n", 152 | "%tensorflow_version 1.x\n", 153 | "!pip uninstall keras\n", 154 | "!pip install keras==2.2.3" 155 | ], 156 | "execution_count": 4, 157 | "outputs": [ 158 | { 159 | "output_type": "stream", 160 | "text": [ 161 | "TensorFlow 1.x selected.\n", 162 | "Uninstalling Keras-2.3.1:\n", 163 | " Would remove:\n", 164 | " /tensorflow-1.15.2/python3.7/Keras-2.3.1.dist-info/*\n", 165 | " /tensorflow-1.15.2/python3.7/docs/*\n", 166 | " /tensorflow-1.15.2/python3.7/keras/*\n", 167 | "Proceed (y/n)? 
y\n", 168 | " Successfully uninstalled Keras-2.3.1\n", 169 | "Collecting keras==2.2.3\n", 170 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/06/ea/ad52366ce566f7b54d36834f98868f743ea81a416b3665459a9728287728/Keras-2.2.3-py2.py3-none-any.whl (312kB)\n", 171 | "\u001b[K |████████████████████████████████| 317kB 18.5MB/s \n", 172 | "\u001b[?25hRequirement already satisfied: numpy>=1.9.1 in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (1.19.5)\n", 173 | "Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (3.13)\n", 174 | "Requirement already satisfied: h5py in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (2.10.0)\n", 175 | "Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (1.1.2)\n", 176 | "Requirement already satisfied: scipy>=0.14 in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (1.4.1)\n", 177 | "Requirement already satisfied: keras-applications>=1.0.6 in /tensorflow-1.15.2/python3.7 (from keras==2.2.3) (1.0.8)\n", 178 | "Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from keras==2.2.3) (1.15.0)\n", 179 | "Installing collected packages: keras\n", 180 | " Found existing installation: Keras 2.4.3\n", 181 | " Uninstalling Keras-2.4.3:\n", 182 | " Successfully uninstalled Keras-2.4.3\n", 183 | "Successfully installed keras-2.2.3\n" 184 | ], 185 | "name": "stdout" 186 | } 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "metadata": { 192 | "id": "9DSpC6MOZ38P" 193 | }, 194 | "source": [ 195 | "**HELPERS**" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "metadata": { 201 | "id": "9cKzIgBdMRJ3" 202 | }, 203 | "source": [ 204 | "import glob\n", 205 | "import os\n", 206 | "import cv2\n", 207 | "import numpy as np\n", 208 | "import tifffile\n", 209 | "\n", 210 | "patch_size, stride = 40, 9\n", 211 | "aug_times = 1\n", 212 | "scales = [1, 0.9, 0.8, 0.7]\n", 213 | "batch_size = 128\n", 214 | "\n", 215 | "def gen_patches(file_name):\n", 216 | "\n", 217 | " # read image\n", 218 | " img = tifffile.imread(file_name) \n", 219 | " img = np.array(img)\n", 220 | " h, w, d = img.shape\n", 221 | " patches = []\n", 222 | " for s in scales:\n", 223 | " h_scaled, w_scaled = int(h*s),int(w*s)\n", 224 | " img_scaled = cv2.resize(img, (h_scaled,w_scaled), interpolation=cv2.INTER_CUBIC)\n", 225 | " # extract patches\n", 226 | " for i in range(0, h_scaled-patch_size+1, stride):\n", 227 | " for j in range(0, w_scaled-patch_size+1, stride):\n", 228 | " x = img_scaled[i:i+patch_size, j:j+patch_size,:]\n", 229 | " patches.append(x) \n", 230 | "\n", 231 | " \n", 232 | " return patches\n", 233 | "\n", 234 | "def make_dataTensor(data_dir,verbose=False):\n", 235 | " \n", 236 | " file_list = glob.glob(data_dir+'/*.tif') # get name list of all .tif files\n", 237 | " # initrialize\n", 238 | " data = []\n", 239 | " # generate patches\n", 240 | " for i in range(len(file_list)):\n", 241 | " patch = gen_patches(file_list[i])\n", 242 | " data.append(patch)\n", 243 | " if verbose:\n", 244 | " print(str(i+1)+'/'+ str(len(file_list)) + ' is done ^_^')\n", 245 | " data = np.array(data)\n", 246 | " data = data.reshape((data.shape[0]*data.shape[1],data.shape[2],data.shape[3],6))\n", 247 | " discard_n = len(data)-len(data)//batch_size*batch_size;\n", 248 | " data = np.delete(data,range(discard_n),axis = 0)\n", 249 | " print(\"Finished generating data from {}\".format(data_dir))\n", 250 | " return data\n", 
251 | "\n", 252 | "def get_steps(data_dir, batch_size=128):\n", 253 | " if os.path.isfile(data_dir):\n", 254 | " noisy_files = [data_dir]\n", 255 | " else:\n", 256 | " noisy_files = glob.glob(data_dir + '/*.tif')\n", 257 | " num = 0\n", 258 | " #get number of steps per epoch to use in training\n", 259 | " for data_file in noisy_files:\n", 260 | " xs = make_dataTensor(data_dir)\n", 261 | " if xs is not None: \n", 262 | " num += len(xs)\n", 263 | " print(\"total number of patches: {}\".format(num))\n", 264 | " print(\"steps per epoch: {}\".format(num//batch_size))\n", 265 | " print(\"\")\n", 266 | " return num // batch_size" 267 | ], 268 | "execution_count": 6, 269 | "outputs": [] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": { 274 | "id": "tp-cgcPBabRJ" 275 | }, 276 | "source": [ 277 | "**DO THE JOB**" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "metadata": { 283 | "colab": { 284 | "base_uri": "https://localhost:8080/", 285 | "height": 1000 286 | }, 287 | "id": "6URbymYZ9lcj", 288 | "outputId": "8b06ad59-9876-440f-ba83-e5097af2a272" 289 | }, 290 | "source": [ 291 | "\n", 292 | "import complexnn\n", 293 | "import argparse\n", 294 | "import re\n", 295 | "import os, glob, datetime\n", 296 | "from keras.layers import Input,Conv2D,BatchNormalization,Activation,Multiply, Add\n", 297 | "from keras.models import Model, load_model\n", 298 | "from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler\n", 299 | "from keras.optimizers import Adam\n", 300 | "import keras.backend as K\n", 301 | "\n", 302 | "\n", 303 | "save_dir = os.path.join('models','/content/drive/My Drive/modelCPLX_despecknet_SSE') \n", 304 | "\n", 305 | "if not os.path.exists(save_dir):\n", 306 | " os.mkdir(save_dir)\n", 307 | "\n", 308 | "def cv_deSpeckNet(depth,filters=48,image_channels=6, use_bnorm=True):\n", 309 | " layer_count = 0\n", 310 | " inpt = Input(shape=(None,None,image_channels),name = 'input'+str(layer_count))\n", 311 | " # 1st layer, CV-Conv+Crelu\n", 312 | " layer_count += 1\n", 313 | " x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt)\n", 314 | " # depth-2 layers, CV-Conv+CV-BN+Crelu\n", 315 | " for i in range(depth-2):\n", 316 | " layer_count += 1\n", 317 | " x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x0)\n", 318 | " if use_bnorm:\n", 319 | " layer_count += 1\n", 320 | " x0 = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x0)\n", 321 | " # last layer, CV-Conv+Crelu\n", 322 | " layer_count += 1\n", 323 | " x0 = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'speckle'+str(1))(x0)\n", 324 | " layer_count += 1\n", 325 | " \n", 326 | " # 1st layer, CV-Conv+Crelu\n", 327 | " x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt)\n", 328 | " # depth-2 layers, CV-Conv+CV-BN+Crelu\n", 329 | " for i in range(depth-2):\n", 330 | " layer_count += 1\n", 331 | " x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x)\n", 332 | " if use_bnorm:\n", 333 | " layer_count += 1\n", 334 | " x = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x)\n", 335 | " # last layer, CV-Conv\n", 336 | " 
layer_count += 1\n", 337 | " x = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'clean'+str(1))(x)\n", 338 | " layer_count += 1\n", 339 | " x_orig = Add(name = 'noisy' + str(1))([x0,x])\n", 340 | " \n", 341 | " model = Model(inputs=inpt, outputs=[x,x_orig])\n", 342 | " \n", 343 | " return model\n", 344 | "\n", 345 | "\n", 346 | "def findLastCheckpoint(save_dir):\n", 347 | " file_list = glob.glob(os.path.join(save_dir,'model_*.hdf5')) # get name list of all .hdf5 files\n", 348 | " #file_list = os.listdir(save_dir)\n", 349 | " if file_list:\n", 350 | " epochs_exist = []\n", 351 | " for file_ in file_list:\n", 352 | " result = re.findall(\".*model_(.*).hdf5.*\",file_)\n", 353 | " #print(result[0])\n", 354 | " epochs_exist.append(int(result[0]))\n", 355 | " initial_epoch=max(epochs_exist) \n", 356 | " else:\n", 357 | " initial_epoch = 0\n", 358 | " return initial_epoch\n", 359 | "\n", 360 | "def log(args,kwargs):\n", 361 | " print(datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S:\"),args,kwargs)\n", 362 | "\n", 363 | "def lr_schedule(epoch):\n", 364 | " initial_lr = 1e-3\n", 365 | " if epoch<=30:\n", 366 | " lr = initial_lr\n", 367 | " elif epoch<=60:\n", 368 | " lr = initial_lr/10\n", 369 | " elif epoch<=80:\n", 370 | " lr = initial_lr/20 \n", 371 | " else:\n", 372 | " lr = initial_lr/20 \n", 373 | " #log('current learning rate is %2.8f' %lr)\n", 374 | " return lr\n", 375 | "\n", 376 | "def train_datagen(epoch_iter=2000,epoch_num=5,batch_size=64,data_dir='/content/drive/My Drive/data_cplx/New_train',label_dir='/content/drive/My Drive/data_cplx/New_label_final'):\n", 377 | " while(True):\n", 378 | " n_count = 0\n", 379 | " if n_count == 0:\n", 380 | " #print(n_count)\n", 381 | " xs = make_dataTensor(data_dir)\n", 382 | " xy = make_dataTensor(label_dir)\n", 383 | " assert len(xs)%batch_size ==0, \\\n", 384 | " log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!')\n", 385 | " xs = xs.astype('float32')\n", 386 | " xy = xy.astype('float32')\n", 387 | " indices = list(range(xs.shape[0]))\n", 388 | " n_count = 1\n", 389 | " for _ in range(epoch_num):\n", 390 | " np.random.shuffle(indices) # shuffle\n", 391 | " for i in range(0, len(indices), batch_size):\n", 392 | " batch_x = xs[indices[i:i+batch_size]]\n", 393 | " batch_y = xy[indices[i:i+batch_size]]\n", 394 | " yield batch_x, [batch_y, batch_x]\n", 395 | " \n", 396 | "# sum square error loss function\n", 397 | "def sum_squared_error(y_true, y_pred):\n", 398 | " return K.sum(K.square(y_pred - y_true))/2\n", 399 | " \n", 400 | "if __name__ == '__main__':\n", 401 | " # model selection\n", 402 | " model = cv_deSpeckNet(depth=17,filters=48,image_channels=6,use_bnorm=True)\n", 403 | " model.summary()\n", 404 | " \n", 405 | " # load the last model in matconvnet style\n", 406 | " initial_epoch = findLastCheckpoint(save_dir=save_dir)\n", 407 | " if initial_epoch > 0: \n", 408 | " print('resuming by loading epoch %03d'%initial_epoch)\n", 409 | " model = load_model(os.path.join(save_dir,'model_%03d.hdf5'%initial_epoch), custom_objects={'ComplexConv2D': complexnn.conv.ComplexConv2D, 'ComplexBatchNormalization': complexnn.bn.ComplexBatchNormalization, 'sum_squared_error': sum_squared_error})\n", 410 | "\n", 411 | " loss_funcs = {\n", 412 | " 'clean1': sum_squared_error,\n", 413 | " 'noisy1' : sum_squared_error}\n", 414 | " \n", 415 | " loss_weights = {'clean1': 100.0, 'noisy1': 1.0}\n", 416 | " \n", 417 | " # compile the model\n", 418 | " 
model.compile(optimizer=Adam(0.001), loss=loss_funcs, loss_weights=loss_weights)\n", 419 | " \n", 420 | " # use call back functions\n", 421 | " checkpointer = ModelCheckpoint(os.path.join(save_dir,'model_{epoch:03d}.hdf5'), \n", 422 | " verbose=1, save_weights_only=False, period=1)\n", 423 | " csv_logger = CSVLogger(os.path.join(save_dir,'log.csv'), append=True, separator=',')\n", 424 | " lr_scheduler = LearningRateScheduler(lr_schedule)\n", 425 | "\n", 426 | " nsteps = get_steps(data_dir='/content/drive/My Drive/data_cplx/New_train', batch_size=128)\n", 427 | " \n", 428 | " history = model.fit_generator(train_datagen(batch_size=64),\n", 429 | " steps_per_epoch=nsteps, epochs=52, verbose=1, initial_epoch=initial_epoch,\n", 430 | " callbacks=[checkpointer,csv_logger,lr_scheduler])\n", 431 | "\n" 432 | ], 433 | "execution_count": 7, 434 | "outputs": [ 435 | { 436 | "output_type": "stream", 437 | "text": [ 438 | "Using TensorFlow backend.\n" 439 | ], 440 | "name": "stderr" 441 | }, 442 | { 443 | "output_type": "stream", 444 | "text": [ 445 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.\n", 446 | "\n", 447 | "WARNING:tensorflow:From /tensorflow-1.15.2/python3.7/tensorflow_core/python/ops/variables.py:2825: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", 448 | "Instructions for updating:\n", 449 | "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n", 450 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:131: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.\n", 451 | "\n", 452 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:133: The name tf.placeholder_with_default is deprecated. 
Please use tf.compat.v1.placeholder_with_default instead.\n", 453 | "\n", 454 | "__________________________________________________________________________________________________\n", 455 | "Layer (type) Output Shape Param # Connected to \n", 456 | "==================================================================================================\n", 457 | "input0 (InputLayer) (None, None, None, 6 0 \n", 458 | "__________________________________________________________________________________________________\n", 459 | "conv33 (ComplexConv2D) (None, None, None, 9 2688 input0[0][0] \n", 460 | "__________________________________________________________________________________________________\n", 461 | "conv1 (ComplexConv2D) (None, None, None, 9 2688 input0[0][0] \n", 462 | "__________________________________________________________________________________________________\n", 463 | "conv34 (ComplexConv2D) (None, None, None, 9 41568 conv33[0][0] \n", 464 | "__________________________________________________________________________________________________\n", 465 | "conv2 (ComplexConv2D) (None, None, None, 9 41568 conv1[0][0] \n", 466 | "__________________________________________________________________________________________________\n", 467 | "bn35 (ComplexBatchNormalization (None, None, None, 9 480 conv34[0][0] \n", 468 | "__________________________________________________________________________________________________\n", 469 | "bn3 (ComplexBatchNormalization) (None, None, None, 9 480 conv2[0][0] \n", 470 | "__________________________________________________________________________________________________\n", 471 | "conv36 (ComplexConv2D) (None, None, None, 9 41568 bn35[0][0] \n", 472 | "__________________________________________________________________________________________________\n", 473 | "conv4 (ComplexConv2D) (None, None, None, 9 41568 bn3[0][0] \n", 474 | "__________________________________________________________________________________________________\n", 475 | "bn37 (ComplexBatchNormalization (None, None, None, 9 480 conv36[0][0] \n", 476 | "__________________________________________________________________________________________________\n", 477 | "bn5 (ComplexBatchNormalization) (None, None, None, 9 480 conv4[0][0] \n", 478 | "__________________________________________________________________________________________________\n", 479 | "conv38 (ComplexConv2D) (None, None, None, 9 41568 bn37[0][0] \n", 480 | "__________________________________________________________________________________________________\n", 481 | "conv6 (ComplexConv2D) (None, None, None, 9 41568 bn5[0][0] \n", 482 | "__________________________________________________________________________________________________\n", 483 | "bn39 (ComplexBatchNormalization (None, None, None, 9 480 conv38[0][0] \n", 484 | "__________________________________________________________________________________________________\n", 485 | "bn7 (ComplexBatchNormalization) (None, None, None, 9 480 conv6[0][0] \n", 486 | "__________________________________________________________________________________________________\n", 487 | "conv40 (ComplexConv2D) (None, None, None, 9 41568 bn39[0][0] \n", 488 | "__________________________________________________________________________________________________\n", 489 | "conv8 (ComplexConv2D) (None, None, None, 9 41568 bn7[0][0] \n", 490 | "__________________________________________________________________________________________________\n", 491 | "bn41 (ComplexBatchNormalization (None, None, None, 9 
480 conv40[0][0] \n", 492 | "__________________________________________________________________________________________________\n", 493 | "bn9 (ComplexBatchNormalization) (None, None, None, 9 480 conv8[0][0] \n", 494 | "__________________________________________________________________________________________________\n", 495 | "conv42 (ComplexConv2D) (None, None, None, 9 41568 bn41[0][0] \n", 496 | "__________________________________________________________________________________________________\n", 497 | "conv10 (ComplexConv2D) (None, None, None, 9 41568 bn9[0][0] \n", 498 | "__________________________________________________________________________________________________\n", 499 | "bn43 (ComplexBatchNormalization (None, None, None, 9 480 conv42[0][0] \n", 500 | "__________________________________________________________________________________________________\n", 501 | "bn11 (ComplexBatchNormalization (None, None, None, 9 480 conv10[0][0] \n", 502 | "__________________________________________________________________________________________________\n", 503 | "conv44 (ComplexConv2D) (None, None, None, 9 41568 bn43[0][0] \n", 504 | "__________________________________________________________________________________________________\n", 505 | "conv12 (ComplexConv2D) (None, None, None, 9 41568 bn11[0][0] \n", 506 | "__________________________________________________________________________________________________\n", 507 | "bn45 (ComplexBatchNormalization (None, None, None, 9 480 conv44[0][0] \n", 508 | "__________________________________________________________________________________________________\n", 509 | "bn13 (ComplexBatchNormalization (None, None, None, 9 480 conv12[0][0] \n", 510 | "__________________________________________________________________________________________________\n", 511 | "conv46 (ComplexConv2D) (None, None, None, 9 41568 bn45[0][0] \n", 512 | "__________________________________________________________________________________________________\n", 513 | "conv14 (ComplexConv2D) (None, None, None, 9 41568 bn13[0][0] \n", 514 | "__________________________________________________________________________________________________\n", 515 | "bn47 (ComplexBatchNormalization (None, None, None, 9 480 conv46[0][0] \n", 516 | "__________________________________________________________________________________________________\n", 517 | "bn15 (ComplexBatchNormalization (None, None, None, 9 480 conv14[0][0] \n", 518 | "__________________________________________________________________________________________________\n", 519 | "conv48 (ComplexConv2D) (None, None, None, 9 41568 bn47[0][0] \n", 520 | "__________________________________________________________________________________________________\n", 521 | "conv16 (ComplexConv2D) (None, None, None, 9 41568 bn15[0][0] \n", 522 | "__________________________________________________________________________________________________\n", 523 | "bn49 (ComplexBatchNormalization (None, None, None, 9 480 conv48[0][0] \n", 524 | "__________________________________________________________________________________________________\n", 525 | "bn17 (ComplexBatchNormalization (None, None, None, 9 480 conv16[0][0] \n", 526 | "__________________________________________________________________________________________________\n", 527 | "conv50 (ComplexConv2D) (None, None, None, 9 41568 bn49[0][0] \n", 528 | "__________________________________________________________________________________________________\n", 529 | "conv18 (ComplexConv2D) (None, None, None, 
9 41568 bn17[0][0] \n", 530 | "__________________________________________________________________________________________________\n", 531 | "bn51 (ComplexBatchNormalization (None, None, None, 9 480 conv50[0][0] \n", 532 | "__________________________________________________________________________________________________\n", 533 | "bn19 (ComplexBatchNormalization (None, None, None, 9 480 conv18[0][0] \n", 534 | "__________________________________________________________________________________________________\n", 535 | "conv52 (ComplexConv2D) (None, None, None, 9 41568 bn51[0][0] \n", 536 | "__________________________________________________________________________________________________\n", 537 | "conv20 (ComplexConv2D) (None, None, None, 9 41568 bn19[0][0] \n", 538 | "__________________________________________________________________________________________________\n", 539 | "bn53 (ComplexBatchNormalization (None, None, None, 9 480 conv52[0][0] \n", 540 | "__________________________________________________________________________________________________\n", 541 | "bn21 (ComplexBatchNormalization (None, None, None, 9 480 conv20[0][0] \n", 542 | "__________________________________________________________________________________________________\n", 543 | "conv54 (ComplexConv2D) (None, None, None, 9 41568 bn53[0][0] \n", 544 | "__________________________________________________________________________________________________\n", 545 | "conv22 (ComplexConv2D) (None, None, None, 9 41568 bn21[0][0] \n", 546 | "__________________________________________________________________________________________________\n", 547 | "bn55 (ComplexBatchNormalization (None, None, None, 9 480 conv54[0][0] \n", 548 | "__________________________________________________________________________________________________\n", 549 | "bn23 (ComplexBatchNormalization (None, None, None, 9 480 conv22[0][0] \n", 550 | "__________________________________________________________________________________________________\n", 551 | "conv56 (ComplexConv2D) (None, None, None, 9 41568 bn55[0][0] \n", 552 | "__________________________________________________________________________________________________\n", 553 | "conv24 (ComplexConv2D) (None, None, None, 9 41568 bn23[0][0] \n", 554 | "__________________________________________________________________________________________________\n", 555 | "bn57 (ComplexBatchNormalization (None, None, None, 9 480 conv56[0][0] \n", 556 | "__________________________________________________________________________________________________\n", 557 | "bn25 (ComplexBatchNormalization (None, None, None, 9 480 conv24[0][0] \n", 558 | "__________________________________________________________________________________________________\n", 559 | "conv58 (ComplexConv2D) (None, None, None, 9 41568 bn57[0][0] \n", 560 | "__________________________________________________________________________________________________\n", 561 | "conv26 (ComplexConv2D) (None, None, None, 9 41568 bn25[0][0] \n", 562 | "__________________________________________________________________________________________________\n", 563 | "bn59 (ComplexBatchNormalization (None, None, None, 9 480 conv58[0][0] \n", 564 | "__________________________________________________________________________________________________\n", 565 | "bn27 (ComplexBatchNormalization (None, None, None, 9 480 conv26[0][0] \n", 566 | "__________________________________________________________________________________________________\n", 567 | "conv60 (ComplexConv2D) (None, 
None, None, 9 41568 bn59[0][0] \n", 568 | "__________________________________________________________________________________________________\n", 569 | "conv28 (ComplexConv2D) (None, None, None, 9 41568 bn27[0][0] \n", 570 | "__________________________________________________________________________________________________\n", 571 | "bn61 (ComplexBatchNormalization (None, None, None, 9 480 conv60[0][0] \n", 572 | "__________________________________________________________________________________________________\n", 573 | "bn29 (ComplexBatchNormalization (None, None, None, 9 480 conv28[0][0] \n", 574 | "__________________________________________________________________________________________________\n", 575 | "conv62 (ComplexConv2D) (None, None, None, 9 41568 bn61[0][0] \n", 576 | "__________________________________________________________________________________________________\n", 577 | "conv30 (ComplexConv2D) (None, None, None, 9 41568 bn29[0][0] \n", 578 | "__________________________________________________________________________________________________\n", 579 | "bn63 (ComplexBatchNormalization (None, None, None, 9 480 conv62[0][0] \n", 580 | "__________________________________________________________________________________________________\n", 581 | "bn31 (ComplexBatchNormalization (None, None, None, 9 480 conv30[0][0] \n", 582 | "__________________________________________________________________________________________________\n", 583 | "clean1 (ComplexConv2D) (None, None, None, 6 2598 bn63[0][0] \n", 584 | "__________________________________________________________________________________________________\n", 585 | "speckle1 (ComplexConv2D) (None, None, None, 6 2598 bn31[0][0] \n", 586 | "__________________________________________________________________________________________________\n", 587 | "noisy1 (Add) (None, None, None, 6 0 speckle1[0][0] \n", 588 | " clean1[0][0] \n", 589 | "==================================================================================================\n", 590 | "Total params: 1,272,012\n", 591 | "Trainable params: 1,264,812\n", 592 | "Non-trainable params: 7,200\n", 593 | "__________________________________________________________________________________________________\n", 594 | "resuming by loading epoch 051\n" 595 | ], 596 | "name": "stdout" 597 | }, 598 | { 599 | "output_type": "stream", 600 | "text": [ 601 | "/usr/local/lib/python3.7/dist-packages/keras/utils/io_utils.py:186: H5pyDeprecationWarning: The default file mode will change to 'r' (read-only) in h5py 3.0. To suppress this warning, pass the mode you need to h5py.File(), or set the global default h5.get_config().default_file_mode, or set the environment variable H5PY_DEFAULT_READONLY=1. Available modes are: 'r', 'r+', 'w', 'w-'/'x', 'a'. See the docs for details.\n", 602 | " self.data = h5py.File(path,)\n" 603 | ], 604 | "name": "stderr" 605 | }, 606 | { 607 | "output_type": "stream", 608 | "text": [ 609 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:174: The name tf.get_default_session is deprecated. Please use tf.compat.v1.get_default_session instead.\n", 610 | "\n", 611 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:181: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.\n", 612 | "\n", 613 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:186: The name tf.Session is deprecated. 
Please use tf.compat.v1.Session instead.\n", 614 | "\n", 615 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:190: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.\n", 616 | "\n", 617 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:199: The name tf.is_variable_initialized is deprecated. Please use tf.compat.v1.is_variable_initialized instead.\n", 618 | "\n", 619 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:206: The name tf.variables_initializer is deprecated. Please use tf.compat.v1.variables_initializer instead.\n", 620 | "\n", 621 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n", 622 | "\n", 623 | "WARNING:tensorflow:From /tensorflow-1.15.2/python3.7/tensorflow_core/python/ops/math_grad.py:1424: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", 624 | "Instructions for updating:\n", 625 | "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", 626 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:986: The name tf.assign_add is deprecated. Please use tf.compat.v1.assign_add instead.\n", 627 | "\n", 628 | "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py:973: The name tf.assign is deprecated. Please use tf.compat.v1.assign instead.\n", 629 | "\n", 630 | "Finished generating data from /content/drive/My Drive/data_cplx/New_train\n", 631 | "Finished generating data from /content/drive/My Drive/data_cplx/New_train\n", 632 | "total number of patches: 133376\n", 633 | "steps per epoch: 1042\n", 634 | "\n", 635 | "Epoch 52/52\n", 636 | "Finished generating data from /content/drive/My Drive/data_cplx/New_train\n", 637 | "Finished generating data from /content/drive/My Drive/data_cplx/New_label_final\n", 638 | " 226/1042 [=====>........................] 
- ETA: 18:29 - loss: 435660.3356 - clean1_loss: 4335.5028 - noisy1_loss: 2110.0556" 639 | ], 640 | "name": "stdout" 641 | }, 642 | { 643 | "output_type": "error", 644 | "ename": "KeyboardInterrupt", 645 | "evalue": "ignored", 646 | "traceback": [ 647 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 648 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 649 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 137\u001b[0m history = model.fit_generator(train_datagen(batch_size=64),\n\u001b[1;32m 138\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnsteps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m52\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m callbacks=[checkpointer,csv_logger,lr_scheduler])\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 650 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/legacy/interfaces.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 89\u001b[0m warnings.warn('Update your `' + object_name + '` call to the ' +\n\u001b[1;32m 90\u001b[0m 'Keras 2 API: ' + signature, stacklevel=2)\n\u001b[0;32m---> 91\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 92\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_original_function\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 651 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 1416\u001b[0m \u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0muse_multiprocessing\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1417\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1418\u001b[0;31m initial_epoch=initial_epoch)\n\u001b[0m\u001b[1;32m 1419\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1420\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0minterfaces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_generator_methods_support\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 652 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/training_generator.py\u001b[0m in \u001b[0;36mfit_generator\u001b[0;34m(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m 
214\u001b[0m outs = model.train_on_batch(x, y,\n\u001b[1;32m 215\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 216\u001b[0;31m class_weight=class_weight)\n\u001b[0m\u001b[1;32m 217\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 653 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mtrain_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight)\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0mins\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0msample_weights\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1216\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_train_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1217\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1218\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0munpack_singleton\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1219\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 654 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 655 | "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 656 | "\u001b[0;32m/tensorflow-1.15.2/python3.7/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1470\u001b[0m ret = tf_session.TF_SessionRunCallable(self._session._session,\n\u001b[1;32m 1471\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1472\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1473\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1474\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 657 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 658 | ] 659 | } 660 | ] 661 | } 662 | ] 663 | } 664 | -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Version: v1.2 3 | Date: 2021-01-12 4 | Author: Mullissa A.G. 
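Usage: a typical invocation, assuming the argparse defaults below (test .tif images under 'Train',
a trained checkpoint at models/cv_despecknet/model_050.hdf5, outputs written to 'results'):
    python main_test.py --set_dir Train --model_dir models/cv_despecknet --model_name model_050.hdf5 --result_dir results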
5 | Description: This script tests a pre-trained cv-despecknet model on a properly formatted polarimetric SAR covariance matrix 6 | """ 7 | 8 | # run this to test the model 9 | import complexnn 10 | import argparse 11 | import os, time 12 | import numpy as np 13 | from keras.models import load_model 14 | import keras.backend as K 15 | from tifffile import imread, imwrite 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--set_dir', default='Train', type=str, help='directory of test dataset') 20 | parser.add_argument('--set_names', default=[''], type=list, help='name of test dataset') 21 | parser.add_argument('--model_dir', default=os.path.join('models','cv_despecknet'), type=str, help='directory of the model') 22 | parser.add_argument('--model_name', default='model_050.hdf5', type=str, help='the model name') 23 | parser.add_argument('--result_dir', default='results', type=str, help='directory of results') 24 | parser.add_argument('--save_result', default=1, type=int, help='save the denoised image, 1 or 0') 25 | return parser.parse_args() 26 | 27 | def to_tensor(img): 28 | if img.ndim == 2: 29 | return img[np.newaxis,...,np.newaxis] 30 | elif img.ndim == 3: 31 | return img[np.newaxis,...,] 32 | 33 | def from_tensor(img): 34 | return np.squeeze(img[0,...]) 35 | 36 | 37 | def sum_squared_error(y_true, y_pred): 38 | return K.sum(K.square(y_pred - y_true))/2 39 | 40 | 41 | if __name__ == '__main__': 42 | args = parse_args() 43 | 44 | model = load_model(os.path.join(args.model_dir, args.model_name), custom_objects={'ComplexConv2D': complexnn.conv.ComplexConv2D, 'ComplexBatchNormalization': complexnn.bn.ComplexBatchNormalization, 'sum_squared_error': sum_squared_error}) 45 | 46 | if not os.path.exists(args.result_dir): 47 | os.mkdir(args.result_dir) 48 | 49 | for set_cur in args.set_names: 50 | 51 | if not os.path.exists(os.path.join(args.result_dir,set_cur)): 52 | os.mkdir(os.path.join(args.result_dir,set_cur)) 53 | 54 | for im in os.listdir(os.path.join(args.set_dir,set_cur)): 55 | if im.endswith(".tif") : 56 | y = np.array(imread(os.path.join(args.set_dir,set_cur,im)), dtype=np.float32) 57 | np.random.seed(seed=0) # for reproducibility 58 | y = y.astype(np.float32) 59 | y_ = to_tensor(y) 60 | start_time = time.time() 61 | x_ = model.predict(y_) # filter 62 | elapsed_time = time.time() - start_time 63 | print('%10s : %10s : %2.4f second'%(set_cur,im,elapsed_time)) 64 | x_=from_tensor(x_[0]) 65 | if args.save_result: 66 | name, ext = os.path.splitext(im) 67 | imwrite(os.path.join(args.result_dir,set_cur,name+'_cv-despecknet'+ext) , x_, planarconfig='CONTIG') 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Version: v1.2 3 | Date: 2021-01-12 4 | Author: Mullissa A.G. 5 | Description: This script trains a complex-valued multistream fully convolutional network for despeckling a 6 | polarimetric SAR covariance matrix as discussed in 7 | our paper A. G. Mullissa, C. Persello and J. Reiche, 8 | "Despeckling Polarimetric SAR Data Using a Multistream Complex-Valued Fully Convolutional Network," 9 | in IEEE Geoscience and Remote Sensing Letters, doi: 10.1109/LGRS.2021.3066311. 
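Usage: a typical invocation, assuming the training and reference (label) .tif tiles are stored in the
argparse default locations data/Train2 and data/Label2:
    python main_train.py --train_image data/Train2 --train_label data/Label2 --batch_size 128 --epoch 50 --lr 1e-3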
10 | Some utility functions are adopted from https://github.com/cszn/DnCNN 11 | """ 12 | # ============================================================================= 13 | import complexnn 14 | import helper 15 | import argparse 16 | import re 17 | import os, glob, datetime 18 | import numpy as np 19 | from keras.layers import Input, Add 20 | from keras.models import Model, load_model 21 | from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler 22 | from keras.optimizers import Adam 23 | import keras.backend as K 24 | 25 | 26 | ## Params 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--model', default='cv-despecknet', type=str, help='choose a type of model') 29 | parser.add_argument('--batch_size', default=128, type=int, help='batch size') 30 | parser.add_argument('--train_image', default='data/Train2', type=str, help='path of train data real') 31 | parser.add_argument('--train_label', default='data/Label2', type=str, help='path of label data') 32 | parser.add_argument('--epoch', default=50, type=int, help='number of train epoches') 33 | parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam') 34 | parser.add_argument('--save_every', default=1, type=int, help='save model at every x epoches') 35 | args = parser.parse_args() 36 | 37 | 38 | save_dir = os.path.join('models',args.model) 39 | 40 | if not os.path.exists(save_dir): 41 | os.mkdir(save_dir) 42 | 43 | def cv_deSpeckNet(depth,filters=48,image_channels=6, use_bnorm=True): 44 | #FCN noise 45 | layer_count = 0 46 | inpt = Input(shape=(None,None,image_channels),name = 'input'+str(layer_count)) 47 | # 1st layer, CV-Conv+Crelu 48 | layer_count += 1 49 | x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt) 50 | # depth-2 layers, CV-Conv+CV-BN+Crelu 51 | for i in range(depth-2): 52 | layer_count += 1 53 | x0 = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x0) 54 | if use_bnorm: 55 | layer_count += 1 56 | x0 = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x0) 57 | # last layer, CV-Conv 58 | layer_count += 1 59 | x0 = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'speckle'+str(1))(x0) 60 | layer_count += 1 61 | 62 | #FCN clean 63 | # 1st layer, CV-Conv+Crelu 64 | x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1), activation='relu', padding='same',name = 'conv'+str(layer_count))(inpt) 65 | # depth-2 layers, CV-Conv+CV-BN+Crelu 66 | for i in range(depth-2): 67 | layer_count += 1 68 | x = complexnn.conv.ComplexConv2D(filters=filters, kernel_size=(3,3), strides=(1,1),activation='relu', padding='same',name = 'conv'+str(layer_count))(x) 69 | if use_bnorm: 70 | layer_count += 1 71 | x = complexnn.bn.ComplexBatchNormalization(name = 'bn'+str(layer_count))(x) 72 | # last layer, CV-Conv 73 | layer_count += 1 74 | x = complexnn.conv.ComplexConv2D(filters=3, kernel_size=(3,3), strides=(1,1),padding='same',name = 'clean'+str(1))(x) 75 | layer_count += 1 76 | 77 | x_orig = Add(name = 'noisy' + str(1))([x0,x]) 78 | model = Model(inputs=inpt, outputs=[x,x_orig]) 79 | 80 | return model 81 | 82 | 83 | def findLastCheckpoint(save_dir): 84 | file_list = glob.glob(os.path.join(save_dir,'model_*.hdf5')) # get name list of all .hdf5 files 85 | #file_list = os.listdir(save_dir) 86 | if 
file_list: 87 | epochs_exist = [] 88 | for file_ in file_list: 89 | result = re.findall(".*model_(.*).hdf5.*",file_) 90 | #print(result[0]) 91 | epochs_exist.append(int(result[0])) 92 | initial_epoch=max(epochs_exist) 93 | else: 94 | initial_epoch = 0 95 | return initial_epoch 96 | 97 | def log(args,kwargs): 98 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"),args,kwargs) 99 | 100 | def lr_schedule(epoch): 101 | initial_lr = args.lr 102 | if epoch<=30: 103 | lr = initial_lr 104 | elif epoch<=60: 105 | lr = initial_lr/10 106 | elif epoch<=80: 107 | lr = initial_lr/20 108 | else: 109 | lr = initial_lr/20 110 | #log('current learning rate is %2.8f' %lr) 111 | return lr 112 | 113 | def train_datagen(epoch_iter=2000,epoch_num=5,batch_size=64,data_dir=args.train_image,label_dir=args.train_label): 114 | while(True): 115 | n_count = 0 116 | if n_count == 0: 117 | #print(n_count) 118 | xs = helper.make_dataTensor(data_dir) 119 | xy = helper.make_dataTensor(label_dir) 120 | assert len(xs)%batch_size ==0, \ 121 | log('make sure the last iteration has a full batchsize, this is important if you use batch normalization!') 122 | xs = xs.astype('float32') 123 | xy = xy.astype('float32') 124 | indices = list(range(xs.shape[0])) 125 | n_count = 1 126 | for _ in range(epoch_num): 127 | np.random.shuffle(indices) # shuffle 128 | for i in range(0, len(indices), batch_size): 129 | batch_x = xs[indices[i:i+batch_size]] 130 | batch_y = xy[indices[i:i+batch_size]] 131 | yield batch_x, [batch_y, batch_x] 132 | 133 | # sum square error loss function 134 | def sum_squared_error(y_true, y_pred): 135 | return K.sum(K.square(y_pred - y_true))/2 136 | 137 | if __name__ == '__main__': 138 | # model selection 139 | 140 | model = cv_deSpeckNet(depth=17,filters=48,image_channels=6,use_bnorm=True) 141 | model.summary() 142 | 143 | # load the last model in matconvnet style 144 | initial_epoch = findLastCheckpoint(save_dir=save_dir) 145 | if initial_epoch > 0: 146 | print('resuming by loading epoch %03d'%initial_epoch) 147 | model = load_model(os.path.join(save_dir,'model_%03d.hdf5'%initial_epoch), custom_objects={'ComplexConv2D': complexnn.conv.ComplexConv2D, 'ComplexBatchNormalization': complexnn.bn.ComplexBatchNormalization, 'sum_squared_error': sum_squared_error}) 148 | 149 | loss_funcs = { 150 | 'clean1': sum_squared_error, 151 | 'noisy1' : sum_squared_error} 152 | 153 | loss_weights = {'clean1': 100.0, 'noisy1': 1.0} 154 | 155 | # compile the model 156 | model.compile(optimizer=Adam(0.001), loss=loss_funcs, loss_weights=loss_weights) 157 | 158 | # use call back functions 159 | checkpointer = ModelCheckpoint(os.path.join(save_dir,'model_{epoch:03d}.hdf5'), 160 | verbose=1, save_weights_only=False, period=1) 161 | csv_logger = CSVLogger(os.path.join(save_dir,'log.csv'), append=True, separator=',') 162 | lr_scheduler = LearningRateScheduler(lr_schedule) 163 | 164 | # numer of steps per epoch 165 | nsteps = helper.get_steps(args.train_image, batch_size=64) 166 | 167 | history = model.fit_generator(train_datagen(batch_size=64), 168 | steps_per_epoch=nsteps, epochs=51, verbose=1, initial_epoch=initial_epoch, 169 | callbacks=[checkpointer,csv_logger,lr_scheduler]) 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.1 2 | tensorflow-gpu==1.13.1 3 | keras==2.2.4 4 | 
keras-complex==0.1.3 5 | --------------------------------------------------------------------------------