├── Deformable_Conv ├── __init__.py ├── dataset.py ├── layers.py ├── network.py ├── callbacks.py └── deform_conv.py ├── BlurPooling ├── __init__.py ├── BlurPooling.py ├── MaxBlurPooling.py └── AverageBlurPooling.py ├── BlurPooling_test.py ├── SE_module.py ├── deform_conv_test.py ├── .gitignore ├── CBAM_module.py ├── sacled_mnist.py ├── README.md ├── non_local.py ├── stn_module.py ├── non_local_test.py └── STN_test.ipynb /Deformable_Conv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /BlurPooling/__init__.py: -------------------------------------------------------------------------------- 1 | from .AverageBlurPooling import * 2 | from .BlurPooling import * 3 | from .MaxBlurPooling import * 4 | -------------------------------------------------------------------------------- /Deformable_Conv/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import, division 3 | import tensorflow as tf 4 | from tensorflow.keras.datasets import mnist 5 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 6 | 7 | 8 | def get_mnist_dataset(): 9 | (X_train, y_train), (X_test, y_test) = mnist.load_data() 10 | X_train = X_train.astype('float32') / 255 11 | X_test = X_test.astype('float32') / 255 12 | X_train = X_train[..., None] 13 | X_test = X_test[..., None] 14 | Y_train = tf.keras.utils.to_categorical(y_train, 10) 15 | Y_test = tf.keras.utils.to_categorical(y_test, 10) 16 | 17 | return (X_train, Y_train), (X_test, Y_test) 18 | 19 | 20 | def get_gen(set_name, batch_size, translate, scale, 21 | shuffle=True): 22 | if set_name == 'train': 23 | (X, Y), _ = get_mnist_dataset() 24 | elif set_name == 'test': 25 | _, (X, Y) = get_mnist_dataset() 26 | 27 | image_gen = ImageDataGenerator( 28 | zoom_range=scale, 29 | width_shift_range=translate, 30 | height_shift_range=translate 31 | ) 32 | gen = image_gen.flow(X, Y, batch_size=batch_size, shuffle=shuffle) 33 | return gen -------------------------------------------------------------------------------- /BlurPooling_test.py: -------------------------------------------------------------------------------- 1 | import BlurPooling as pooling 2 | import tensorflow as tf 3 | from tensorflow.keras.layers import * 4 | from tensorflow.keras.models import Model 5 | import numpy as np 6 | 7 | def test_pooling_model(input_shape, pooling_type): 8 | input = Input(input_shape) 9 | layer_pool = eval('pooling.'+pooling_type)()(input) 10 | layer_flattern = Flatten()(layer_pool) 11 | output = Dense(1)(layer_flattern) 12 | model = Model(input, output) 13 | model.summary() 14 | return model 15 | 16 | if __name__=='__main__': 17 | poolingtype = [ 18 | 'MaxBlurPooling1D', 19 | 'MaxBlurPooling2D', 20 | 'AverageBlurPooling1D', 21 | 'AverageBlurPooling2D', 22 | 'BlurPool2D', 23 | 'BlurPool1D' 24 | ] 25 | for item in poolingtype: 26 | if '2D' in item: 27 | input_shape = (224, 224, 3) 28 | model = test_pooling_model(input_shape, item) 29 | model.predict([np.random.random((1, 224, 224, 3))]) 30 | else: 31 | input_shape = (224, 3) 32 | model = test_pooling_model(input_shape, item) 33 | model.predict([np.random.random((1, 224, 3))]) 34 | 35 | -------------------------------------------------------------------------------- /SE_module.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Activation, Reshape, Conv2D, BatchNormalization 3 | import numpy as np 4 | 5 | class Squeeze_excitation_layer(tf.keras.Model): 6 | def __init__(self, filter_sq): 7 | super().__init__() 8 | self.filter_sq = filter_sq 9 | 10 | def call(self, inputs): 11 | channel = inputs.shape[-1] 12 | squeeze = GlobalAveragePooling2D()(inputs) 13 | excitation = Dense(channel//self.filter_sq)(squeeze) 14 | excitation = Activation('relu')(excitation) 15 | excitation = Dense(channel)(excitation) 16 | excitation = Activation('sigmoid')(excitation) 17 | # reshape excitation: 1*1*input.shape[-1] 18 | excitation = Reshape((1, 1, channel))(excitation) 19 | # 获得通道权重 20 | outputs = inputs*excitation 21 | return outputs 22 | 23 | def SEBottleneck(input, filter_sq=16, stride=1): 24 | residual = inputs 25 | se_module = Squeeze_excitation_layer(16) 26 | 27 | x = Conv2D(16, kernel_size=1)(input) 28 | x = BatchNormalization()(x) 29 | x = Activation('relu')(x) 30 | 31 | x = Conv2D(16, kernel_size=3, strides=stride, padding='same')(input) 32 | x = BatchNormalization()(x) 33 | x = Activation('relu')(x) 34 | 35 | x = Conv2D(32, kernel_size=1)(input) 36 | x = BatchNormalization()(x) 37 | x = se_module(x) 38 | 39 | output = x+residual 40 | output = Activation('relu')(output) 41 | 42 | return output 43 | 44 | 45 | 46 | 47 | SE_module = Squeeze_excitation_layer(16) 48 | inputs = np.zeros((1, 32, 32, 32), dtype=np.float32) 49 | out_shape = SE_module(inputs).shape 50 | print(out_shape) 51 | 52 | 53 | inputs = np.zeros((1, 32, 32, 32), dtype=np.float32) 54 | SEB = SEBottleneck(inputs) 55 | print(SEB.shape) 56 | 57 | -------------------------------------------------------------------------------- /Deformable_Conv/layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | import tensorflow as tf 3 | from tensorflow.keras.layers import Conv2D 4 | from .deform_conv import tf_batch_map_offsets 5 | 6 | 7 | class ConvOffset2D(Conv2D): 8 | """ConvOffset2D""" 9 | 10 | def __init__(self, filters, init_normal_stddev=0.01, **kwargs): 11 | """Init""" 12 | 13 | self.filters = filters 14 | super(ConvOffset2D, self).__init__( 15 | self.filters * 2, (3, 3), padding='same', use_bias=False, 16 | # TODO gradients are near zero if init is zeros 17 | kernel_initializer='zeros', 18 | # kernel_initializer=RandomNormal(0, init_normal_stddev), 19 | **kwargs 20 | ) 21 | 22 | def call(self, x): 23 | # TODO offsets probably have no nonlinearity? 24 | x_shape = x.get_shape() 25 | offsets = super(ConvOffset2D, self).call(x) 26 | 27 | offsets = self._to_bc_h_w_2(offsets, x_shape) 28 | x = self._to_bc_h_w(x, x_shape) 29 | x_offset = tf_batch_map_offsets(x, offsets) 30 | x_offset = self._to_b_h_w_c(x_offset, x_shape) 31 | return x_offset 32 | 33 | def compute_output_shape(self, input_shape): 34 | return input_shape 35 | 36 | @staticmethod 37 | def _to_bc_h_w_2(x, x_shape): 38 | """(b, h, w, 2c) -> (b*c, h, w, 2)""" 39 | x = tf.transpose(x, [0, 3, 1, 2]) 40 | x = tf.reshape(x, (-1, int(x_shape[1]), int(x_shape[2]), 2)) 41 | return x 42 | 43 | @staticmethod 44 | def _to_bc_h_w(x, x_shape): 45 | """(b, h, w, c) -> (b*c, h, w)""" 46 | x = tf.transpose(x, [0, 3, 1, 2]) 47 | x = tf.reshape(x, (-1, int(x_shape[1]), int(x_shape[2]))) 48 | return x 49 | 50 | @staticmethod 51 | def _to_b_h_w_c(x, x_shape): 52 | """(b*c, h, w) -> (b, h, w, c)""" 53 | x = tf.reshape( 54 | x, (-1, int(x_shape[3]), int(x_shape[1]), int(x_shape[2])) 55 | ) 56 | x = tf.transpose(x, [0, 2, 3, 1]) 57 | return x -------------------------------------------------------------------------------- /deform_conv_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.keras.backend as K 3 | from scipy.ndimage.interpolation import map_coordinates 4 | from Deformable_Conv.deform_conv import ( 5 | tf_map_coordinates, 6 | sp_batch_map_coordinates, tf_batch_map_coordinates, 7 | sp_batch_map_offsets, tf_batch_map_offsets 8 | ) 9 | import tensorflow as tf 10 | tf.compat.v1.disable_eager_execution() 11 | 12 | def test_tf_map_coordinates(): 13 | np.random.seed(42) 14 | input = np.random.random((100, 100)) 15 | coords = np.random.random((200, 2)) * 99 16 | 17 | sp_mapped_vals = map_coordinates(input, coords.T, order=1) 18 | tf_mapped_vals = tf_map_coordinates( 19 | K.variable(input), K.variable(coords) 20 | ) 21 | assert np.allclose(sp_mapped_vals, K.eval(tf_mapped_vals), atol=1e-5) 22 | 23 | 24 | def test_tf_batch_map_coordinates(): 25 | np.random.seed(42) 26 | input = np.random.random((4, 100, 100)) 27 | coords = np.random.random((4, 200, 2)) * 99 28 | 29 | sp_mapped_vals = sp_batch_map_coordinates(input, coords) 30 | tf_mapped_vals = tf_batch_map_coordinates( 31 | K.variable(input), K.variable(coords) 32 | ) 33 | assert np.allclose(sp_mapped_vals, K.eval(tf_mapped_vals), atol=1e-5) 34 | 35 | 36 | def test_tf_batch_map_offsets(): 37 | np.random.seed(42) 38 | input = np.random.random((4, 100, 100)) 39 | offsets = np.random.random((4, 100, 100, 2)) * 2 40 | 41 | sp_mapped_vals = sp_batch_map_offsets(input, offsets) 42 | tf_mapped_vals = tf_batch_map_offsets( 43 | K.variable(input), K.variable(offsets) 44 | ) 45 | assert np.allclose(sp_mapped_vals, K.eval(tf_mapped_vals), atol=1e-5) 46 | 47 | 48 | def test_tf_batch_map_offsets_grad(): 49 | np.random.seed(42) 50 | input = np.random.random((4, 100, 100)) 51 | offsets = np.random.random((4, 100, 100, 2)) * 2 52 | 53 | input = K.variable(input) 54 | offsets = K.variable(offsets) 55 | 56 | tf_mapped_vals = tf_batch_map_offsets(input, offsets) 57 | grad = K.gradients(tf_mapped_vals, input)[0] 58 | grad = K.eval(grad) 59 | assert not np.allclose(grad, 0) 60 | 61 | if __name__ == '__main__': 62 | test_tf_map_coordinates() 63 | test_tf_batch_map_coordinates() 64 | test_tf_batch_map_offsets() 65 | test_tf_batch_map_offsets_grad() 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .history -------------------------------------------------------------------------------- /CBAM_module.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend as K 3 | import tensorflow.keras.layers as layer 4 | import numpy as np 5 | 6 | class CBAM_module(tf.keras.Model): 7 | def __init__(self, ratio=16, name=''): 8 | super().__init__() 9 | self._ratio = ratio 10 | self._name = name 11 | 12 | def channel_attention(self, input): 13 | channel = input.shape[-1] 14 | # 同时进行avg pooling和 max pooling 15 | avg_pool = layer.GlobalAveragePooling2D()(input) 16 | max_pool = layer.GlobalAveragePooling2D()(input) 17 | avg_pool = layer.Reshape((1, 1, channel))(avg_pool) 18 | max_pool = layer.Reshape((1, 1, channel))(max_pool) 19 | 20 | # 对pooling结果经过两层全连接层,第一层核数量为input的通道数//ratio,第二层则恢复到原通道数 21 | avg_pool = layer.Dense(channel//self._ratio, activation='relu', kernel_initializer='he_normal', name=self._name)(avg_pool) 22 | max_pool = layer.Dense(channel//self._ratio, activation='relu',kernel_initializer='he_normal', name=self._name)(max_pool) 23 | 24 | avg_pool = layer.Dense(channel, activation='relu', kernel_initializer='he_normal', name=self._name)(avg_pool) 25 | max_pool = layer.Dense(channel, activation='relu', kernel_initializer='he_normal', name=self._name)(max_pool) 26 | 27 | # 对avg_pool与max_pool相加做激活,得到(batchsize, 1,1 channel)的tensor,作为权重与input相乘 28 | output = layer.Add()([avg_pool, max_pool]) 29 | output = layer.Activation('sigmoid')(output) 30 | output = layer.multiply([input, output]) 31 | return output 32 | 33 | def spatial_attention(self, input, kernel_size=7): 34 | avg_pool = layer.Lambda(lambda x:K.mean(x,axis=3, keepdims=True))(input) 35 | max_pool = layer.Lambda(lambda x:K.max(x,axis=3, keepdims=True))(input) 36 | 37 | concat_feature = layer.Concatenate(axis=3)([avg_pool, max_pool]) 38 | output =layer.Conv2D(filters = 1, 39 | kernel_size=kernel_size, 40 | strides=1, 41 | padding='same', 42 | kernel_initializer='he_normal')(concat_feature) 43 | output = layer.Activation('sigmoid')(output) 44 | output = layer.multiply([input, output]) 45 | return output 46 | 47 | def call(self, input): 48 | cbam_feature = self.channel_attention(input) 49 | cbam_feature = self.spatial_attention(cbam_feature) 50 | return cbam_feature 51 | 52 | CBAM_module = CBAM_module() 53 | inputs = np.zeros((1, 32, 32, 32), dtype=np.float32) 54 | out_shape = CBAM_module(inputs).shape 55 | print(out_shape) -------------------------------------------------------------------------------- /Deformable_Conv/network.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | from tensorflow.keras.layers import Input, Conv2D, Activation, GlobalAvgPool2D, Dense, BatchNormalization 3 | from .layers import ConvOffset2D 4 | 5 | 6 | def build_cnn(): 7 | inputs = l = Input((28, 28, 1), name='input') 8 | 9 | # conv11 10 | l = Conv2D(32, (3, 3), padding='same', name='conv11')(l) 11 | l = Activation('relu', name='conv11_relu')(l) 12 | l = BatchNormalization(name='conv11_bn')(l) 13 | 14 | # conv12 15 | l = Conv2D(64, (3, 3), padding='same', strides=(2, 2), name='conv12')(l) 16 | l = Activation('relu', name='conv12_relu')(l) 17 | l = BatchNormalization(name='conv12_bn')(l) 18 | 19 | # conv21 20 | l = Conv2D(128, (3, 3), padding='same', name='conv21')(l) 21 | l = Activation('relu', name='conv21_relu')(l) 22 | l = BatchNormalization(name='conv21_bn')(l) 23 | 24 | # conv22 25 | l = Conv2D(128, (3, 3), padding='same', strides=(2, 2), name='conv22')(l) 26 | l = Activation('relu', name='conv22_relu')(l) 27 | l = BatchNormalization(name='conv22_bn')(l) 28 | 29 | # out 30 | l = GlobalAvgPool2D(name='avg_pool')(l) 31 | l = Dense(10, name='fc1')(l) 32 | outputs = l = Activation('softmax', name='out')(l) 33 | 34 | return inputs, outputs 35 | 36 | 37 | def build_deform_cnn(trainable): 38 | inputs = l = Input((28, 28, 1), name='input') 39 | 40 | # conv11 41 | l = Conv2D(32, (3, 3), padding='same', name='conv11', trainable=trainable)(l) 42 | l = Activation('relu', name='conv11_relu')(l) 43 | l = BatchNormalization(name='conv11_bn')(l) 44 | 45 | # conv12 46 | l_offset = ConvOffset2D(32, name='conv12_offset')(l) 47 | l = Conv2D(64, (3, 3), padding='same', strides=(2, 2), name='conv12', trainable=trainable)(l_offset) 48 | l = Activation('relu', name='conv12_relu')(l) 49 | l = BatchNormalization(name='conv12_bn')(l) 50 | 51 | # conv21 52 | l_offset = ConvOffset2D(64, name='conv21_offset')(l) 53 | l = Conv2D(128, (3, 3), padding='same', name='conv21', trainable=trainable)(l_offset) 54 | l = Activation('relu', name='conv21_relu')(l) 55 | l = BatchNormalization(name='conv21_bn')(l) 56 | 57 | # conv22 58 | l_offset = ConvOffset2D(128, name='conv22_offset')(l) 59 | l = Conv2D(128, (3, 3), padding='same', strides=(2, 2), name='conv22', trainable=trainable)(l_offset) 60 | l = Activation('relu', name='conv22_relu')(l) 61 | l = BatchNormalization(name='conv22_bn')(l) 62 | 63 | # out 64 | l = GlobalAvgPool2D(name='avg_pool')(l) 65 | l = Dense(10, name='fc1', trainable=trainable)(l) 66 | outputs = l = Activation('softmax', name='out')(l) 67 | 68 | return inputs, outputs -------------------------------------------------------------------------------- /BlurPooling/BlurPooling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Layer 3 | from tensorflow.keras import backend as K 4 | import numpy as np 5 | 6 | class BlurPool2D(Layer): 7 | def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs): 8 | self.pool_size = pool_size 9 | self.blur_kernel = None 10 | self.kernel_size = kernel_size 11 | 12 | super(BlurPool2D, self).__init__(**kwargs) 13 | 14 | def build(self, input_shape): 15 | 16 | if self.kernel_size == 3: 17 | blur_kernel = np.array([[1, 2, 1], 18 | [2, 4, 2], 19 | [1, 2, 1]]) 20 | blur_kernel = blur_kernel / np.sum(blur_kernel) 21 | elif self.kernel_size == 5: 22 | blur_kernel = np.array([[1, 4, 6, 4, 1], 23 | [4, 16, 24, 16, 4], 24 | [6, 24, 36, 24, 6], 25 | [4, 16, 24, 16, 4], 26 | [1, 4, 6, 4, 1]]) 27 | blur_kernel = blur_kernel / np.sum(blur_kernel) 28 | else: 29 | raise ValueError 30 | 31 | blur_kernel = np.repeat(blur_kernel, input_shape[3]) 32 | 33 | blur_kernel = np.reshape(blur_kernel, (self.kernel_size, self.kernel_size, input_shape[3], 1)) 34 | blur_init = tf.keras.initializers.constant(blur_kernel) 35 | 36 | self.blur_kernel = self.add_weight(name='blur_kernel', 37 | shape=(self.kernel_size, self.kernel_size, input_shape[3], 1), 38 | initializer=blur_init, 39 | trainable=False) 40 | 41 | super(BlurPool2D, self).build(input_shape) # Be sure to call this at the end 42 | 43 | def call(self, x): 44 | x = K.depthwise_conv2d(x, self.blur_kernel, padding='same', strides=(self.pool_size, self.pool_size)) 45 | 46 | return x 47 | 48 | def compute_output_shape(self, input_shape): 49 | return input_shape[0], int(np.ceil(input_shape[1] / 2)), int(np.ceil(input_shape[2] / 2)), input_shape[3] 50 | 51 | 52 | class BlurPool1D(Layer): 53 | 54 | def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs): 55 | self.pool_size = pool_size 56 | self.blur_kernel = None 57 | self.kernel_size = kernel_size 58 | 59 | super(BlurPool1D, self).__init__(**kwargs) 60 | 61 | def build(self, input_shape): 62 | 63 | if self.kernel_size == 3: 64 | blur_kernel = np.array([2, 4, 2]) 65 | elif self.kernel_size == 5: 66 | blur_kernel = np.array([6, 24, 36, 24, 6]) 67 | else: 68 | raise ValueError 69 | 70 | blur_kernel = blur_kernel / np.sum(blur_kernel) 71 | blur_kernel = np.repeat(blur_kernel, input_shape[2]) 72 | blur_kernel = np.reshape(blur_kernel, (self.kernel_size, 1, input_shape[2], 1)) 73 | blur_init = tf.keras.initializers.constant(blur_kernel) 74 | 75 | self.blur_kernel = self.add_weight(name='blur_kernel', 76 | shape=(self.kernel_size, 1, input_shape[2], 1), 77 | initializer=blur_init, 78 | trainable=False) 79 | 80 | super(BlurPool1D, self).build(input_shape) # Be sure to call this at the end 81 | 82 | def call(self, x): 83 | 84 | x = K.expand_dims(x, axis=-2) 85 | x = K.depthwise_conv2d(x, self.blur_kernel, padding='same', strides=(self.pool_size, self.pool_size)) 86 | x = K.squeeze(x, axis=-2) 87 | 88 | return x 89 | 90 | def compute_output_shape(self, input_shape): 91 | return input_shape[0], int(np.ceil(input_shape[1] / 2)), input_shape[2] -------------------------------------------------------------------------------- /BlurPooling/MaxBlurPooling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Layer 3 | from tensorflow.keras import backend as K 4 | import numpy as np 5 | 6 | class MaxBlurPooling1D(Layer): 7 | def __init__(self, pool_size=2, kernel_size=3, **kwargs): 8 | super().__init__(**kwargs) 9 | self._pool_size = pool_size 10 | self._kernel_size = kernel_size 11 | self._avg_kernel = None 12 | self._blur_kernel = None 13 | 14 | def build(self, input_shape): 15 | if self._kernel_size == 3: 16 | blur_kernel = np.array([2, 4, 2]) 17 | elif self._kernel_size == 5: 18 | blur_kernel = np.array([6, 24, 36, 24, 6]) 19 | else: 20 | raise ValueError 21 | blur_kernel = blur_kernel/np.sum(blur_kernel) 22 | blur_kernel = np.repeat(blur_kernel, input_shape[2]) 23 | blur_kernel = np.reshape(blur_kernel, (self._kernel_size, 1, input_shape[2], 1)) 24 | blur_init = tf.keras.initializers.constant(blur_kernel) 25 | 26 | self._blur_kernel = self.add_weight(name='blur_kernel', shape=(self._kernel_size, 1, input_shape[2], 1), initializer=blur_init, trainable=False) 27 | super(MaxBlurPooling1D,self).build(input_shape) 28 | 29 | def call(self, x): 30 | x = tf.nn.pool(x, (self._pool_size, ), strides=(1, ), padding='SAME', pooling_type='MAX',data_format='NWC') 31 | x = K.expand_dims(x, axis=-2) 32 | x = K.depthwise_conv2d(x, self._blur_kernel, padding='same', strides=(self._pool_size, self._pool_size)) 33 | x = K.squeeze(x, axis=-2) 34 | return x 35 | def compute_output_shape(self, input_shape): 36 | return input_shape[0], int(np.ceil(input_shape[1]/2)), input_shape[2] 37 | 38 | class MaxBlurPooling2D(Layer): 39 | 40 | def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs): 41 | self.pool_size = pool_size 42 | self.blur_kernel = None 43 | self.kernel_size = kernel_size 44 | 45 | super(MaxBlurPooling2D, self).__init__(**kwargs) 46 | 47 | def build(self, input_shape): 48 | 49 | if self.kernel_size == 3: 50 | blur_kernel = np.array([[1, 2, 1], 51 | [2, 4, 2], 52 | [1, 2, 1]]) 53 | blur_kernel = blur_kernel / np.sum(blur_kernel) 54 | elif self.kernel_size == 5: 55 | blur_kernel = np.array([[1, 4, 6, 4, 1], 56 | [4, 16, 24, 16, 4], 57 | [6, 24, 36, 24, 6], 58 | [4, 16, 24, 16, 4], 59 | [1, 4, 6, 4, 1]]) 60 | blur_kernel = blur_kernel / np.sum(blur_kernel) 61 | else: 62 | raise ValueError 63 | 64 | blur_kernel = np.repeat(blur_kernel, input_shape[3]) 65 | 66 | blur_kernel = np.reshape(blur_kernel, (self.kernel_size, self.kernel_size, input_shape[3], 1)) 67 | blur_init =tf.keras.initializers.constant(blur_kernel) 68 | 69 | self.blur_kernel = self.add_weight(name='blur_kernel', 70 | shape=(self.kernel_size, self.kernel_size, input_shape[3], 1), 71 | initializer=blur_init, 72 | trainable=False) 73 | 74 | super(MaxBlurPooling2D, self).build(input_shape) # Be sure to call this at the end 75 | 76 | def call(self, x): 77 | 78 | x = tf.nn.pool(x, (self.pool_size, self.pool_size), 79 | strides=(1, 1), padding='SAME', pooling_type='MAX', data_format='NHWC') 80 | x = K.depthwise_conv2d(x, self.blur_kernel, padding='same', strides=(self.pool_size, self.pool_size)) 81 | 82 | return x 83 | 84 | def compute_output_shape(self, input_shape): 85 | return input_shape[0], int(np.ceil(input_shape[1] / 2)), int(np.ceil(input_shape[2] / 2)), input_shape[3] -------------------------------------------------------------------------------- /BlurPooling/AverageBlurPooling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Layer 3 | from tensorflow.keras import backend as K 4 | import numpy as np 5 | 6 | class AverageBlurPooling1D(Layer): 7 | 8 | def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs): 9 | self.pool_size = pool_size 10 | self.blur_kernel = None 11 | self.kernel_size = kernel_size 12 | 13 | super(AverageBlurPooling1D, self).__init__(**kwargs) 14 | 15 | def build(self, input_shape): 16 | 17 | if self.kernel_size == 3: 18 | blur_kernel = np.array([2, 4, 2]) 19 | elif self.kernel_size == 5: 20 | blur_kernel = np.array([6, 24, 36, 24, 6]) 21 | else: 22 | raise ValueError 23 | 24 | blur_kernel = blur_kernel / np.sum(blur_kernel) 25 | blur_kernel = np.repeat(blur_kernel, input_shape[2]) 26 | blur_kernel = np.reshape(blur_kernel, (self.kernel_size, 1, input_shape[2], 1)) 27 | blur_init =tf.keras.initializers.constant(blur_kernel) 28 | 29 | self.blur_kernel = self.add_weight(name='blur_kernel', 30 | shape=(self.kernel_size, 1, input_shape[2], 1), 31 | initializer=blur_init, 32 | trainable=False) 33 | 34 | super(AverageBlurPooling1D, self).build(input_shape) # Be sure to call this at the end 35 | 36 | def call(self, x): 37 | 38 | x = tf.nn.pool(x, (self.pool_size,), strides=(1,), padding='SAME', pooling_type='AVG', 39 | data_format='NWC') 40 | x = K.expand_dims(x, axis=-2) 41 | x = K.depthwise_conv2d(x, self.blur_kernel, padding='same', strides=(self.pool_size, self.pool_size)) 42 | x = K.squeeze(x, axis=-2) 43 | 44 | return x 45 | 46 | def compute_output_shape(self, input_shape): 47 | return input_shape[0], int(np.ceil(input_shape[1] / 2)), input_shape[2] 48 | 49 | class AverageBlurPooling2D(Layer): 50 | 51 | def __init__(self, pool_size: int = 2, kernel_size: int = 3, **kwargs): 52 | self.pool_size = pool_size 53 | self.blur_kernel = None 54 | self.kernel_size = kernel_size 55 | 56 | super(AverageBlurPooling2D, self).__init__(**kwargs) 57 | 58 | def build(self, input_shape): 59 | 60 | if self.kernel_size == 3: 61 | blur_kernel = np.array([[1, 2, 1], 62 | [2, 4, 2], 63 | [1, 2, 1]]) 64 | blur_kernel = blur_kernel / np.sum(blur_kernel) 65 | elif self.kernel_size == 5: 66 | blur_kernel = np.array([[1, 4, 6, 4, 1], 67 | [4, 16, 24, 16, 4], 68 | [6, 24, 36, 24, 6], 69 | [4, 16, 24, 16, 4], 70 | [1, 4, 6, 4, 1]]) 71 | blur_kernel = blur_kernel / np.sum(blur_kernel) 72 | else: 73 | raise ValueError 74 | 75 | blur_kernel = np.repeat(blur_kernel, input_shape[3]) 76 | 77 | blur_kernel = np.reshape(blur_kernel, (self.kernel_size, self.kernel_size, input_shape[3], 1)) 78 | blur_init = tf.keras.initializers.constant(blur_kernel) 79 | 80 | self.blur_kernel = self.add_weight(name='blur_kernel', 81 | shape=(self.kernel_size, self.kernel_size, input_shape[3], 1), 82 | initializer=blur_init, 83 | trainable=False) 84 | 85 | super(AverageBlurPooling2D, self).build(input_shape) # Be sure to call this at the end 86 | 87 | def call(self, x): 88 | 89 | x = tf.nn.pool(x, (self.pool_size, self.pool_size), strides=(1, 1), padding='SAME', pooling_type='AVG', 90 | data_format='NHWC') 91 | x = K.depthwise_conv2d(x, self.blur_kernel, padding='same', strides=(self.pool_size, self.pool_size)) 92 | 93 | return x 94 | 95 | def compute_output_shape(self, input_shape): 96 | return input_shape[0], int(np.ceil(input_shape[1] / 2)), int(np.ceil(input_shape[2] / 2)), input_shape[3] -------------------------------------------------------------------------------- /sacled_mnist.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow.keras.backend as K 5 | from tensorflow.keras.models import Model 6 | from tensorflow.keras.losses import categorical_crossentropy 7 | from tensorflow.keras.optimizers import Adam, SGD 8 | from Deformable_Conv.layers import ConvOffset2D 9 | from Deformable_Conv.callbacks import TensorBoard 10 | from Deformable_Conv.network import build_cnn, build_deform_cnn 11 | from Deformable_Conv.dataset import get_gen 12 | config = tf.ConfigProto() 13 | config.gpu_options.allow_growth = True 14 | sess = tf.Session(config=config) 15 | K.set_session(sess) 16 | 17 | # --- 18 | # Config 19 | 20 | batch_size = 32 21 | n_train = 60000 22 | n_test = 10000 23 | steps_per_epoch = int(np.ceil(n_train / batch_size)) 24 | validation_steps = int(np.ceil(n_test / batch_size)) 25 | 26 | train_gen = get_gen( 27 | 'train', batch_size=batch_size, 28 | scale=(1.0, 1.0), translate=0.0, 29 | shuffle=True 30 | ) 31 | test_gen = get_gen( 32 | 'test', batch_size=batch_size, 33 | scale=(1.0, 1.0), translate=0.0, 34 | shuffle=False 35 | ) 36 | train_scaled_gen = get_gen( 37 | 'train', batch_size=batch_size, 38 | scale=(1.0, 2.5), translate=0.2, 39 | shuffle=True 40 | ) 41 | test_scaled_gen = get_gen( 42 | 'test', batch_size=batch_size, 43 | scale=(1.0, 2.5), translate=0.2, 44 | shuffle=False 45 | ) 46 | 47 | 48 | # --- 49 | # Normal CNN 50 | 51 | inputs, outputs = build_cnn() 52 | model = Model(inputs=inputs, outputs=outputs) 53 | model.summary() 54 | optim = Adam(1e-3) 55 | # optim = SGD(1e-3, momentum=0.99, nesterov=True) 56 | loss = categorical_crossentropy 57 | model.compile(optim, loss, metrics=['accuracy']) 58 | 59 | model.fit_generator( 60 | train_gen, steps_per_epoch=steps_per_epoch, 61 | epochs=10, verbose=1, 62 | validation_data=test_gen, validation_steps=validation_steps 63 | ) 64 | # model.save_weights('models/cnn.h5') 65 | # # 1875/1875 [==============================] - 24s - loss: 0.0090 - acc: 0.9969 - val_loss: 0.0528 - val_acc: 0.9858 66 | 67 | # # --- 68 | # # Evaluate normal CNN 69 | 70 | # model.load_weights('models/cnn.h5', by_name=True) 71 | 72 | val_loss, val_acc = model.evaluate_generator( 73 | test_gen, steps=validation_steps 74 | ) 75 | print('Test accuracy', val_acc) 76 | # 0.9874 77 | 78 | val_loss, val_acc = model.evaluate_generator( 79 | test_scaled_gen, steps=validation_steps 80 | ) 81 | print('Test accuracy with scaled images', val_acc) 82 | # 0.5701 83 | 84 | # --- 85 | # Deformable CNN 86 | 87 | inputs, outputs = build_deform_cnn(trainable=False) 88 | model = Model(inputs=inputs, outputs=outputs) 89 | # model.load_weights('models/cnn.h5', by_name=True) 90 | model.summary() 91 | optim = Adam(5e-4) 92 | # optim = SGD(1e-4, momentum=0.99, nesterov=True) 93 | loss = categorical_crossentropy 94 | model.compile(optim, loss, metrics=['accuracy']) 95 | 96 | model.fit_generator( 97 | train_scaled_gen, steps_per_epoch=steps_per_epoch, 98 | epochs=20, verbose=1, 99 | validation_data=test_scaled_gen, validation_steps=validation_steps 100 | ) 101 | # Epoch 20/20 102 | # 1875/1875 [==============================] - 504s - loss: 0.2838 - acc: 0.9122 - val_loss: 0.2359 - val_acc: 0.9231 103 | # model.save_weights('models/deform_cnn.h5') 104 | 105 | # # -- 106 | # # Evaluate deformable CNN 107 | 108 | # model.load_weights('models/deform_cnn.h5') 109 | 110 | val_loss, val_acc = model.evaluate_generator( 111 | test_scaled_gen, steps=validation_steps 112 | ) 113 | print('Test accuracy of deformable convolution with scaled images', val_acc) 114 | # 0.9255 115 | 116 | val_loss, val_acc = model.evaluate_generator( 117 | test_gen, steps=validation_steps 118 | ) 119 | print('Test accuracy of deformable convolution with regular images', val_acc) 120 | # 0.9727 121 | 122 | deform_conv_layers = [l for l in model.layers if isinstance(l, ConvOffset2D)] 123 | 124 | Xb, Yb = next(test_gen) 125 | for l in deform_conv_layers: 126 | print(l) 127 | _model = Model(inputs=inputs, outputs=l.output) 128 | offsets = _model.predict(Xb) 129 | offsets = offsets.reshape(offsets.shape[0], offsets.shape[1], offsets.shape[2], -1, 2) 130 | print(offsets.min()) 131 | print(offsets.mean()) 132 | print(offsets.max()) -------------------------------------------------------------------------------- /Deformable_Conv/callbacks.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.keras.callbacks import Callback 5 | import tensorflow.keras.backend as K 6 | 7 | 8 | class TensorBoard(Callback): 9 | """Tensorboard basic visualizations""" 10 | 11 | def __init__(self, log_dir='./logs', 12 | histogram_freq=0, 13 | write_graph=True, 14 | write_images=False): 15 | super(TensorBoard, self).__init__() 16 | if K.backend() != 'tensorflow': 17 | raise RuntimeError('TensorBoard callback only works ' 18 | 'with the TensorFlow backend.') 19 | self.log_dir = log_dir 20 | self.histogram_freq = histogram_freq 21 | self.merged = None 22 | self.write_graph = write_graph 23 | self.write_images = write_images 24 | 25 | def set_model(self, model): 26 | self.model = model 27 | self.sess = K.get_session() 28 | total_loss = self.model.total_loss 29 | if self.histogram_freq and self.merged is None: 30 | for layer in self.model.layers: 31 | for weight in layer.weights: 32 | # dense_1/bias:0 > dense_1/bias_0 33 | name = weight.name.replace(':', '_') 34 | tf.summary.histogram(name, weight) 35 | tf.summary.histogram( 36 | '{}_gradients'.format(name), 37 | K.gradients(total_loss, [weight])[0] 38 | ) 39 | if self.write_images: 40 | w_img = tf.squeeze(weight) 41 | shape = w_img.get_shape() 42 | if len(shape) > 1 and shape[0] > shape[1]: 43 | w_img = tf.transpose(w_img) 44 | if len(shape) == 1: 45 | w_img = tf.expand_dims(w_img, 0) 46 | w_img = tf.expand_dims(tf.expand_dims(w_img, 0), -1) 47 | tf.summary.image(name, w_img) 48 | 49 | if hasattr(layer, 'output'): 50 | tf.summary.histogram('{}_out'.format(layer.name), 51 | layer.output) 52 | self.merged = tf.summary.merge_all() 53 | 54 | if self.write_graph: 55 | self.writer = tf.summary.FileWriter(self.log_dir, 56 | self.sess.graph) 57 | else: 58 | self.writer = tf.summary.FileWriter(self.log_dir) 59 | 60 | def on_epoch_end(self, epoch, logs=None): 61 | logs = logs or {} 62 | 63 | if self.validation_data and self.histogram_freq: 64 | if epoch % self.histogram_freq == 0: 65 | # TODO: implement batched calls to sess.run 66 | # (current call will likely go OOM on GPU) 67 | if self.model.uses_learning_phase: 68 | cut_v_data = len(self.model.inputs) 69 | val_data = self.validation_data[:cut_v_data][:32] + [0] 70 | tensors = self.model.inputs + self.model.targets + [K.learning_phase()] 71 | else: 72 | val_data = self.validation_data 73 | tensors = self.model.inputs + self.model.targets 74 | 75 | feed_dict = dict(zip(tensors, val_data)) 76 | sample_weights = self.model.sample_weights 77 | for w in sample_weights: 78 | w_val = np.ones(len(val_data[0]), dtype=np.float32) 79 | feed_dict.update({w.name: w_val}) 80 | result = self.sess.run([self.merged], feed_dict=feed_dict) 81 | summary_str = result[0] 82 | self.writer.add_summary(summary_str, epoch) 83 | 84 | for name, value in logs.items(): 85 | if name in ['batch', 'size']: 86 | continue 87 | 88 | if name[:3] != 'val': 89 | name = 'train_' + name 90 | 91 | summary = tf.Summary() 92 | summary_value = summary.value.add() 93 | summary_value.simple_value = value.item() 94 | summary_value.tag = name 95 | self.writer.add_summary(summary, epoch) 96 | self.writer.flush() 97 | 98 | def on_train_end(self, _): 99 | self.writer.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CNNComponent 2 | 3 | 基于`Tensorflow2.X`实现卷积神经网络即插即用模块。 4 | 5 | ## 已实现 6 | - [x] STN 7 | - [x] SE 8 | - [x] CBAM 9 | - [x] non_local 10 | - [x] blur pooling 11 | - [x] Deformable Conv 12 | 13 | ## STN 14 | 主要参考:[Hands-on: implement a spatial transformer network by yourself](https://xeonqq.github.io/machine%20learning/spatial-transformer-networks/)。简单的理论部分可以参考我的博客:[深度学习 卷积神经网络即插即用的小插件](https://blog.csdn.net/u012655441/article/details/121919291)。STN结构如下图所示: 15 | ![image](https://user-images.githubusercontent.com/27406337/145952361-5d738cbc-ca73-40ce-bd89-4244b81358d6.png) 16 | 里面包括三个组件: 17 | - **Localization net**:该网络可以是卷积神经网络或者是全连接神经网络,它们有个特点是最后一层是一个回归层,主要生成6个值表示仿射变换的参数θ。 18 | - **Grid Generator**:它首先在目标图像V上生成一个网格,网格的每个点刚好对应目标图像中每个像素的像素坐标。然后它使用Localization net生成的θ来变换网格。 19 | - **Sampler**:变换后的网格就像源图像U上的遮罩,它检索遮罩下的像素。然而,变换的网格不再包含整数值,因此对源图像U执行双线性插值,以获得变换网格下的估计像素值。 20 | 21 | ### Localization Net 22 | 23 | Localization Net输入为\[批量大小、高度、宽度、通道]的输入图像,并为每个维度的输入图像生成转换参数。转换的维度为\[batch_size,6]。 24 | ```python 25 | def create_localization_head(inputs): 26 | x = Conv2D(14, (5,5),padding='valid',activation="relu")(inputs) 27 | x = MaxPooling2D((2, 2), strides=2)(x) 28 | x = Conv2D(32, (5,5), padding='valid',activation="relu")(x) 29 | x = MaxPooling2D((2, 2),strides=2)(x) 30 | x = Flatten()(x) 31 | x = Dense(120, activation='relu')(x) 32 | x = Dropout(0.2)(x) 33 | x = Dense(84, activation='relu')(x) 34 | x = Dense(6, activation="linear", kernel_initializer="zeros", 35 | bias_initializer=lambda shape, dtype: tf.constant([1,0,0,0,1,0], dtype=dtype))(x) # 6 elements to describe the transformation 36 | return tf.keras.Model(inputs, x) 37 | ``` 38 | 39 | ### Grid Generator 40 | 41 | 在网格生成器中,必须注意,变换θ应用于从目标图像V而不是源图像U生成的网格,在图像处理领域称为逆映射。另一方面,如果我们将源图像U转换为目标图像V,这个过程称为前向映射。 42 | 43 | **正向映射**迭代输入图像的每个像素,为其计算新坐标,并将其值复制到新位置。但新坐标可能不在输出图像的边界内,也可能不是整数。通过在复制像素值之前检查计算的坐标,前一个问题很容易解决。第二个问题通过将最近的整数指定给x′和y′并将其用作变换像素的输出坐标来解决。问题在于,每个输出像素可能会被寻址多次或根本不寻址(后一种情况会导致“孔”,其中输出图像中的像素没有赋值)。**逆映射**迭代输出图像的每个像素,并使用逆变换确定输入图像中必须从中采样值的位置。在这种情况下,确定的位置也可能不在输入图像的边界内,也可能不是整数。但是输出图像没有孔。 44 | 45 | ```python 46 | def generate_normalized_homo_meshgrids(inputs): 47 | # for x, y in grid, -1 <=x,y<=1 48 | batch_size = tf.shape(inputs)[0] 49 | _, H, W,_ = inputs.shape 50 | x_range = tf.range(W) 51 | y_range = tf.range(H) 52 | x_mesh, y_mesh = tf.meshgrid(x_range, y_range) 53 | x_mesh = (x_mesh/W-0.5)*2 54 | y_mesh = (y_mesh/H-0.5)*2 55 | y_mesh = tf.reshape(y_mesh, (*y_mesh.shape,1)) 56 | x_mesh = tf.reshape(x_mesh, (*x_mesh.shape,1)) 57 | ones_mesh = tf.ones_like(x_mesh) 58 | homogeneous_grid = tf.concat([x_mesh, y_mesh, ones_mesh],-1) 59 | homogeneous_grid = tf.reshape(homogeneous_grid, (-1, 3,1)) 60 | homogeneous_grid = tf.dtypes.cast(homogeneous_grid, tf.float32) 61 | homogeneous_grid = tf.expand_dims(homogeneous_grid, 0) 62 | return tf.tile(homogeneous_grid, [batch_size, 1,1,1]) 63 | ``` 64 | 65 | 在```generate_normalized_homo_meshgrid```s函数中,给定输入维数,我们可以生成一个```meshgrid```。然后在[-1,1]之间对网格网格进行规格化,以便相对于图像中心执行旋转或平移。每个网格还扩展了第三维,并填充了第三维,因此被称为均质网格,在以下变换网格中更方便地执行变换。 66 | 67 | 在变换网格中,我们将从本地化网络生成的变换应用到从generate_normalized_homo_meshgrids生成的网格上,以获得重新```reprojected_grids```。变换后,```reprojected_grids```将重新缩放回输入图像的宽度和高度范围内。 68 | 69 | ### Sampler 70 | ```python 71 | def generate_four_neighbors_from_reprojection(inputs, reprojected_grids): 72 | _, H, W, _ = inputs.shape 73 | x, y = tf.split(reprojected_grids, 2, axis=-1) 74 | x1 = tf.floor(x) 75 | x1 = tf.dtypes.cast(x1, tf.int32) 76 | x2 = x1 + tf.constant(1) 77 | y1 = tf.floor(y) 78 | y1 = tf.dtypes.cast(y1, tf.int32) 79 | y2 = y1 + tf.constant(1) 80 | y_max = tf.constant(H - 1, dtype=tf.int32) 81 | x_max = tf.constant(W - 1, dtype=tf.int32) 82 | zero = tf.zeros([1], dtype=tf.int32) 83 | x1_safe = tf.clip_by_value(x1, zero, x_max) 84 | y1_safe = tf.clip_by_value(y1, zero, y_max) 85 | x2_safe = tf.clip_by_value(x2, zero, x_max) 86 | y2_safe = tf.clip_by_value(y2, zero, y_max) 87 | return x1_safe, y1_safe, x2_safe, y2_safe 88 | 89 | def bilinear_sample(inputs, reprojected_grids): 90 | x1, y1, x2, y2 = generate_four_neighbors_from_reprojection(inputs, reprojected_grids) 91 | x1y1 = tf.concat([y1,x1],-1) 92 | x1y2 = tf.concat([y2,x1],-1) 93 | x2y1 = tf.concat([y1,x2],-1) 94 | x2y2 = tf.concat([y2,x2],-1) 95 | pixel_x1y1 = tf.gather_nd(inputs, x1y1, batch_dims=1) 96 | pixel_x1y2 = tf.gather_nd(inputs, x1y2, batch_dims=1) 97 | pixel_x2y1 = tf.gather_nd(inputs, x2y1, batch_dims=1) 98 | pixel_x2y2 = tf.gather_nd(inputs, x2y2, batch_dims=1) 99 | x, y = tf.split(reprojected_grids, 2, axis=-1) 100 | wx = tf.concat([tf.dtypes.cast(x2, tf.float32) - x, x -tf.dtypes.cast(x1, tf.float32)],-1) 101 | wx = tf.expand_dims(wx, -2) 102 | wy = tf.concat([tf.dtypes.cast(y2, tf.float32) - y, y - tf.dtypes.cast(y1, tf.float32)],-1) 103 | wy = tf.expand_dims(wy, -1) 104 | Q = tf.concat([pixel_x1y1, pixel_x1y2, pixel_x2y1, pixel_x2y2], -1) 105 | Q_shape = tf.shape(Q) 106 | Q = tf.reshape(Q, (Q_shape[0], Q_shape[1],2,2)) 107 | Q = tf.cast(Q, tf.float32) 108 | 109 | r = wx@Q@wy 110 | _, H, W, channels = inputs.shape 111 | r = tf.reshape(r, (-1,H,W,1)) 112 | return r 113 | ``` 114 | 115 | ## Non-Local 116 | ![image](https://user-images.githubusercontent.com/27406337/146329854-5e1f5d7c-b69d-493e-8f88-60019b0eaae8.png) 117 | 118 | 119 | ## Deformable Convolution 120 | 121 | 参考:https://github.com/kastnerkyle/deform-conv 122 | -------------------------------------------------------------------------------- /non_local.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.layers import Activation,Reshape, Lambda 3 | from tensorflow.keras.layers import Conv1D, Conv2D, Conv3D 4 | from tensorflow.keras.layers import MaxPool1D 5 | from tensorflow.keras import backend as K 6 | from tensorflow.python.keras.layers.merge import add, dot 7 | 8 | 9 | def _convND(input, rank, channels): 10 | assert rank in [3, 4, 5], "Rank of input must be 3, 4 or 5" 11 | if rank == 3: 12 | x = Conv1D(channels, 1, padding='same', kernel_initializer = 'he_normal')(input) 13 | elif rank == 4: 14 | x = Conv2D(channels, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal')(input) 15 | else: 16 | x = Conv3D(channels, (1, 1, 1), padding='same', use_bias=False, kernel_initializer='he_normal')(input) 17 | return x 18 | 19 | 20 | def non_local_block(input, intermediate_dim=None, compression=2, mode='embedded', add_residual=True): 21 | ''' 22 | Adds a Non-Local block for self attention to the input tensor. 23 | Input tensor can be or rank 3(temporal), 4(spatial) or 5(spatio-temporal) 24 | 25 | Arguments: 26 | input:input tensor 27 | intermediate_dim: The dimension of the intermediate representation 28 | compression: None or positive integer. 29 | mode: Mode of operation 30 | add_residual: Boolean value to decide if the residual connection should be added or not. 31 | 32 | Returns: 33 | a tensor of same shape of input 34 | ''' 35 | # 获取通道数所在的维度 36 | channel_dim =1 if K.image_data_format() == 'channel_first' else -1 37 | input_shape = K.int_shape(input) 38 | if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']: 39 | raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`') 40 | 41 | if compression is None: 42 | compression = 1 43 | 44 | # check rank and calculate the input shape 45 | if len(input_shape) == 3: # temporal / time series data 46 | rank = 3 47 | batchsize, dim1, channels = input_shape 48 | 49 | elif len(input_shape) == 4: # spatial / image data 50 | rank = 4 51 | 52 | if channel_dim == 1: 53 | batchsize, channels, dim1, dim2 = input_shape 54 | else: 55 | batchsize, dim1, dim2, channels = input_shape 56 | 57 | elif len(input_shape) == 5: # spatio-temporal / Video or Voxel data 58 | rank = 5 59 | 60 | if channel_dim == 1: 61 | batchsize, channels, dim1, dim2, dim3 = input_shape 62 | else: 63 | batchsize, dim1, dim2, dim3, channels = input_shape 64 | 65 | else: 66 | raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial) or 5 (spatio-temporal)') 67 | 68 | if intermediate_dim is None: 69 | intermediate_dim=channels//2 70 | 71 | if intermediate_dim<1: 72 | intermediate_dim=1 73 | else: 74 | intermediate_dim = int(intermediate_dim) 75 | if intermediate_dim<1: 76 | raise ValueError('`intermediate_dim` must be either `None` or positive integer greater than 1.') 77 | 78 | # instantiation 79 | if mode == 'gaussian': 80 | x1 = Reshape((-1, channels))(input) 81 | x2 = Reshape((-1, channels))(input) 82 | f = dot([x1, x2], axes=2) 83 | f = Activation('softmax')(f) 84 | elif mode == 'dot': 85 | # theta path 86 | theta = _convND(input, rank, intermediate_dim) 87 | theta = Reshape((-1, intermediate_dim))(theta) 88 | 89 | # phi path 90 | phi = _convND(input, rank, intermediate_dim) 91 | phi = Reshape((-1, intermediate_dim))(phi) 92 | 93 | f = dot([theta, phi], axes=2) 94 | 95 | size = K.int_shape(f) 96 | 97 | # scale the values to make it size invariant 98 | f = Lambda(lambda z: (1. / float(size[-1])) * z)(f) 99 | else: 100 | # theta path 101 | theta = _convND(input, rank, intermediate_dim) 102 | theta = Reshape((-1, intermediate_dim))(theta) 103 | 104 | # phi path 105 | phi = _convND(input, rank, intermediate_dim) 106 | phi = Reshape((-1, intermediate_dim))(phi) 107 | 108 | if compression > 1: 109 | # shielded computation 110 | phi = MaxPool1D(compression)(phi) 111 | 112 | f = dot([theta, phi], axes=2) 113 | f = Activation('softmax')(f) 114 | 115 | # g path 116 | g = _convND(input, rank, intermediate_dim) 117 | g = Reshape((-1, intermediate_dim))(g) 118 | 119 | if compression > 1 and mode == 'embedded': 120 | # shielded computation 121 | g = MaxPool1D(compression)(g) 122 | 123 | # compute output path 124 | y = dot([f, g], axes=[2, 1]) 125 | 126 | # reshape to input tensor format 127 | if rank == 3: 128 | y = Reshape((dim1, intermediate_dim))(y) 129 | elif rank == 4: 130 | if channel_dim == -1: 131 | y = Reshape((dim1, dim2, intermediate_dim))(y) 132 | else: 133 | y = Reshape((intermediate_dim, dim1, dim2))(y) 134 | else: 135 | if channel_dim == -1: 136 | y = Reshape((dim1, dim2, dim3, intermediate_dim))(y) 137 | else: 138 | y = Reshape((intermediate_dim, dim1, dim2, dim3))(y) 139 | 140 | # project filters 141 | y = _convND(y, rank, channels) 142 | 143 | # residual connection 144 | if add_residual: 145 | y = add([input, y]) 146 | 147 | return y 148 | -------------------------------------------------------------------------------- /stn_module.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.keras.layers import Conv2D, MaxPooling2D,Flatten,Dense,Input,Dropout 4 | ''' 5 | 以DenseNet为例,添加stn模块 6 | https://github.com/xeonqq/spatial_transformer_network 7 | ''' 8 | 9 | 10 | def create_localizaton_head(inputs): 11 | x = Conv2D(14, (5, 5), padding='valid', activation='relu')(inputs) 12 | x = MaxPooling2D((2, 2), strides=2)(x) 13 | x = Conv2D(32, (5, 5), padding='valid', activation='relu')(x) 14 | x = MaxPooling2D((2,2), strides=2)(x) 15 | x = Flatten()(x) 16 | x = Dense(120,activation='relu')(x) 17 | x = Dropout(0.2)(x) 18 | x = Dense(84, activation='relu')(x) 19 | # 6 elements to describe the transformation 20 | x = Dense(6, activation='linear', kernel_initializer='zeros', 21 | bias_initializer=lambda shape, dtype:tf.constant([1,0,0,0,1,0], dtype=dtype))(x) 22 | return tf.keras.Model(inputs, x) 23 | 24 | def generate_normalized_homo_meshgrids(inputs): 25 | # for x, y in grid, -1 <=x,y<=1 26 | batch_size = tf.shape(inputs)[0] 27 | _, H, W,_ = inputs.shape 28 | x_range = tf.range(W) 29 | y_range = tf.range(H) 30 | x_mesh, y_mesh = tf.meshgrid(x_range, y_range) 31 | x_mesh = (x_mesh/W-0.5)*2 32 | y_mesh = (y_mesh/H-0.5)*2 33 | y_mesh = tf.reshape(y_mesh, (*y_mesh.shape,1)) 34 | x_mesh = tf.reshape(x_mesh, (*x_mesh.shape,1)) 35 | ones_mesh = tf.ones_like(x_mesh) 36 | homogeneous_grid = tf.concat([x_mesh, y_mesh, ones_mesh],-1) 37 | homogeneous_grid = tf.reshape(homogeneous_grid, (-1, 3,1)) 38 | homogeneous_grid = tf.dtypes.cast(homogeneous_grid, tf.float32) 39 | homogeneous_grid = tf.expand_dims(homogeneous_grid, 0) 40 | return tf.tile(homogeneous_grid, [batch_size, 1,1,1]) 41 | 42 | def transform_grids(transformations, grids, inputs): 43 | with tf.name_scope("transform_grids"): 44 | trans_matrices=tf.reshape(transformations, (-1, 2,3)) 45 | batch_size = tf.shape(trans_matrices)[0] 46 | gs = tf.squeeze(grids, -1) 47 | 48 | reprojected_grids = tf.matmul(trans_matrices, gs, transpose_b=True) 49 | # transform grid range from [-1,1) to the range of [0,1) 50 | reprojected_grids = (tf.linalg.matrix_transpose(reprojected_grids) + 1)*0.5 51 | _, H, W, _ = inputs.shape 52 | reprojected_grids = tf.math.multiply(reprojected_grids, [W, H]) 53 | 54 | return reprojected_grids 55 | 56 | def generate_four_neighbors_from_reprojection(inputs, reprojected_grids): 57 | _, H, W, _ = inputs.shape 58 | 59 | x, y = tf.split(reprojected_grids, 2, axis=-1) 60 | 61 | x1 = tf.floor(x) 62 | x1 = tf.dtypes.cast(x1, tf.int32) 63 | 64 | x2 = x1 + tf.constant(1) 65 | 66 | y1 = tf.floor(y) 67 | y1 = tf.dtypes.cast(y1, tf.int32) 68 | y2 = y1 + tf.constant(1) 69 | 70 | y_max = tf.constant(H - 1, dtype=tf.int32) 71 | x_max = tf.constant(W - 1, dtype=tf.int32) 72 | zero = tf.zeros([1], dtype=tf.int32) 73 | 74 | x1_safe = tf.clip_by_value(x1, zero, x_max) 75 | y1_safe = tf.clip_by_value(y1, zero, y_max) 76 | x2_safe = tf.clip_by_value(x2, zero, x_max) 77 | y2_safe = tf.clip_by_value(y2, zero, y_max) 78 | return x1_safe, y1_safe, x2_safe, y2_safe 79 | 80 | def bilinear_sample(inputs, reprojected_grids): 81 | x1, y1, x2, y2 = generate_four_neighbors_from_reprojection(inputs, reprojected_grids) 82 | x1y1 = tf.concat([y1,x1],-1) 83 | x1y2 = tf.concat([y2,x1],-1) 84 | x2y1 = tf.concat([y1,x2],-1) 85 | x2y2 = tf.concat([y2,x2],-1) 86 | 87 | pixel_x1y1 = tf.gather_nd(inputs, x1y1, batch_dims=1) 88 | pixel_x1y2 = tf.gather_nd(inputs, x1y2, batch_dims=1) 89 | pixel_x2y1 = tf.gather_nd(inputs, x2y1, batch_dims=1) 90 | pixel_x2y2 = tf.gather_nd(inputs, x2y2, batch_dims=1) 91 | x, y = tf.split(reprojected_grids, 2, axis=-1) 92 | wx = tf.concat([tf.dtypes.cast(x2, tf.float32) - x, x -tf.dtypes.cast(x1, tf.float32)],-1) 93 | wx = tf.expand_dims(wx, -2) 94 | wy = tf.concat([tf.dtypes.cast(y2, tf.float32) - y, y - tf.dtypes.cast(y1, tf.float32)],-1) 95 | wy = tf.expand_dims(wy, -1) 96 | Q = tf.concat([pixel_x1y1, pixel_x1y2, pixel_x2y1, pixel_x2y2], -1) 97 | Q_shape = tf.shape(Q) 98 | Q = tf.reshape(Q, (Q_shape[0], Q_shape[1],2,2)) 99 | Q = tf.cast(Q, tf.float32) 100 | 101 | r = wx@Q@wy 102 | _, H, W, channels = inputs.shape 103 | 104 | r = tf.reshape(r, (-1,H,W,1)) 105 | return r 106 | 107 | def spatial_transform_input(inputs, transormations): 108 | grids = generate_normalized_homo_meshgrids(inputs) 109 | reprojected_grids = transform_grids(transormations, grids,inputs) 110 | result = bilinear_sample(inputs, reprojected_grids) 111 | return result 112 | 113 | def stn_module(inputs): 114 | localication_head = create_localizaton_head(inputs) 115 | x = spatial_transform_input(inputs, localication_head.output) 116 | return x 117 | 118 | def model(input_shape): 119 | inputs = Input(input_shape) 120 | inputs_stn = stn_module(inputs) 121 | x = Conv2D(6, (3,3),padding='valid',activation="relu")(inputs_stn) 122 | x = MaxPooling2D((2, 2))(x) 123 | x = Conv2D(16, (3,3),padding='valid',activation="relu")(x) 124 | x = MaxPooling2D((2, 2))(x) 125 | x = Flatten()(x) 126 | x = Dense(120, activation='relu')(x) 127 | x = Dense(84, activation='relu')(x) 128 | x = Dense(10)(x) 129 | return tf.keras.Model(inputs, x) 130 | 131 | input_shape = (28, 28, 1) 132 | st_model = model(input_shape) 133 | st_model.summary() 134 | 135 | 136 | -------------------------------------------------------------------------------- /Deformable_Conv/deform_conv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division 2 | import numpy as np 3 | from scipy.ndimage.interpolation import map_coordinates as sp_map_coordinates 4 | import tensorflow as tf 5 | 6 | 7 | def tf_flatten(a): 8 | """Flatten tensor""" 9 | return tf.reshape(a, [-1]) 10 | 11 | 12 | def tf_repeat(a, repeats, axis=0): 13 | """TensorFlow version of np.repeat for 1D""" 14 | # https://github.com/tensorflow/tensorflow/issues/8521 15 | assert len(a.get_shape()) == 1 16 | 17 | a = tf.expand_dims(a, -1) 18 | a = tf.tile(a, [1, repeats]) 19 | a = tf_flatten(a) 20 | return a 21 | 22 | 23 | def tf_repeat_2d(a, repeats): 24 | """Tensorflow version of np.repeat for 2D""" 25 | 26 | assert len(a.get_shape()) == 2 27 | a = tf.expand_dims(a, 0) 28 | a = tf.tile(a, [repeats, 1, 1]) 29 | return a 30 | 31 | 32 | def tf_map_coordinates(input, coords, order=1): 33 | """Tensorflow verion of scipy.ndimage.map_coordinates 34 | Note that coords is transposed and only 2D is supported 35 | Parameters 36 | ---------- 37 | input : tf.Tensor. shape = (s, s) 38 | coords : tf.Tensor. shape = (n_points, 2) 39 | """ 40 | 41 | assert order == 1 42 | 43 | coords_lt = tf.cast(tf.floor(coords), 'int32') 44 | coords_rb = tf.cast(tf.math.ceil(coords), 'int32') 45 | coords_lb = tf.stack([coords_lt[:, 0], coords_rb[:, 1]], axis=1) 46 | coords_rt = tf.stack([coords_rb[:, 0], coords_lt[:, 1]], axis=1) 47 | 48 | vals_lt = tf.gather_nd(input, coords_lt) 49 | vals_rb = tf.gather_nd(input, coords_rb) 50 | vals_lb = tf.gather_nd(input, coords_lb) 51 | vals_rt = tf.gather_nd(input, coords_rt) 52 | 53 | coords_offset_lt = coords - tf.cast(coords_lt, 'float32') 54 | vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[:, 0] 55 | vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[:, 0] 56 | mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[:, 1] 57 | 58 | return mapped_vals 59 | 60 | 61 | def sp_batch_map_coordinates(inputs, coords): 62 | """Reference implementation for batch_map_coordinates""" 63 | coords = coords.clip(0, inputs.shape[1] - 1) 64 | mapped_vals = np.array([ 65 | sp_map_coordinates(input, coord.T, mode='nearest', order=1) 66 | for input, coord in zip(inputs, coords) 67 | ]) 68 | return mapped_vals 69 | 70 | 71 | def tf_batch_map_coordinates(input, coords, order=1): 72 | """Batch version of tf_map_coordinates 73 | Only supports 2D feature maps 74 | Parameters 75 | ---------- 76 | input : tf.Tensor. shape = (b, s, s) 77 | coords : tf.Tensor. shape = (b, n_points, 2) 78 | """ 79 | 80 | input_shape = tf.shape(input) 81 | batch_size = input_shape[0] 82 | input_size = input_shape[1] 83 | n_coords = tf.shape(coords)[1] 84 | 85 | coords = tf.clip_by_value(coords, 0, tf.cast(input_size, 'float32') - 1) 86 | coords_lt = tf.cast(tf.floor(coords), 'int32') 87 | coords_rb = tf.cast(tf.math.ceil(coords), 'int32') 88 | coords_lb = tf.stack([coords_lt[..., 0], coords_rb[..., 1]], axis=-1) 89 | coords_rt = tf.stack([coords_rb[..., 0], coords_lt[..., 1]], axis=-1) 90 | 91 | idx = tf_repeat(tf.range(batch_size), n_coords) 92 | 93 | def _get_vals_by_coords(input, coords): 94 | indices = tf.stack([ 95 | idx, tf_flatten(coords[..., 0]), tf_flatten(coords[..., 1]) 96 | ], axis=-1) 97 | vals = tf.gather_nd(input, indices) 98 | vals = tf.reshape(vals, (batch_size, n_coords)) 99 | return vals 100 | 101 | vals_lt = _get_vals_by_coords(input, coords_lt) 102 | vals_rb = _get_vals_by_coords(input, coords_rb) 103 | vals_lb = _get_vals_by_coords(input, coords_lb) 104 | vals_rt = _get_vals_by_coords(input, coords_rt) 105 | 106 | coords_offset_lt = coords - tf.cast(coords_lt, 'float32') 107 | vals_t = vals_lt + (vals_rt - vals_lt) * coords_offset_lt[..., 0] 108 | vals_b = vals_lb + (vals_rb - vals_lb) * coords_offset_lt[..., 0] 109 | mapped_vals = vals_t + (vals_b - vals_t) * coords_offset_lt[..., 1] 110 | 111 | return mapped_vals 112 | 113 | 114 | def sp_batch_map_offsets(input, offsets): 115 | """Reference implementation for tf_batch_map_offsets""" 116 | 117 | batch_size = input.shape[0] 118 | input_size = input.shape[1] 119 | 120 | offsets = offsets.reshape(batch_size, -1, 2) 121 | grid = np.stack(np.mgrid[:input_size, :input_size], -1).reshape(-1, 2) 122 | grid = np.repeat([grid], batch_size, axis=0) 123 | coords = offsets + grid 124 | coords = coords.clip(0, input_size - 1) 125 | 126 | mapped_vals = sp_batch_map_coordinates(input, coords) 127 | return mapped_vals 128 | 129 | 130 | def tf_batch_map_offsets(input, offsets, order=1): 131 | """Batch map offsets into input 132 | Parameters 133 | --------- 134 | input : tf.Tensor. shape = (b, s, s) 135 | offsets: tf.Tensor. shape = (b, s, s, 2) 136 | """ 137 | 138 | input_shape = tf.shape(input) 139 | batch_size = input_shape[0] 140 | input_size = input_shape[1] 141 | 142 | offsets = tf.reshape(offsets, (batch_size, -1, 2)) 143 | grid = tf.meshgrid( 144 | tf.range(input_size), tf.range(input_size), indexing='ij' 145 | ) 146 | grid = tf.stack(grid, axis=-1) 147 | grid = tf.cast(grid, 'float32') 148 | grid = tf.reshape(grid, (-1, 2)) 149 | grid = tf_repeat_2d(grid, batch_size) 150 | coords = offsets + grid 151 | 152 | mapped_vals = tf_batch_map_coordinates(input, coords) 153 | return mapped_vals -------------------------------------------------------------------------------- /non_local_test.py: -------------------------------------------------------------------------------- 1 | import six 2 | from tensorflow.keras.models import Model 3 | from tensorflow.keras.layers import Input, Activation, Reshape, Dense, Conv2D, MaxPooling2D, GlobalMaxPooling2D, GlobalAveragePooling2D, Dropout, BatchNormalization 4 | from tensorflow.python.keras.layers.merge import add 5 | from tensorflow.keras.regularizers import l2 6 | from tensorflow.keras import backend as K 7 | from keras_applications.imagenet_utils import _obtain_input_shape 8 | 9 | from non_local import non_local_block 10 | 11 | 12 | def _bn_relu(x, bn_name=None, relu_name=None): 13 | """Helper to build a BN -> relu block 14 | """ 15 | norm = BatchNormalization(axis=CHANNEL_AXIS, name=bn_name)(x) 16 | return Activation("relu", name=relu_name)(norm) 17 | 18 | 19 | def _conv_bn_relu(**conv_params): 20 | """Helper to build a conv -> BN -> relu residual unit activation function. 21 | This is the original ResNet v1 scheme in https://arxiv.org/abs/1512.03385 22 | """ 23 | filters = conv_params["filters"] 24 | kernel_size = conv_params["kernel_size"] 25 | strides = conv_params.setdefault("strides", (1, 1)) 26 | dilation_rate = conv_params.setdefault("dilation_rate", (1, 1)) 27 | conv_name = conv_params.setdefault("conv_name", None) 28 | bn_name = conv_params.setdefault("bn_name", None) 29 | relu_name = conv_params.setdefault("relu_name", None) 30 | kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal") 31 | padding = conv_params.setdefault("padding", "same") 32 | kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4)) 33 | 34 | def f(x): 35 | x = Conv2D(filters=filters, kernel_size=kernel_size, 36 | strides=strides, padding=padding, 37 | dilation_rate=dilation_rate, 38 | kernel_initializer=kernel_initializer, 39 | kernel_regularizer=kernel_regularizer, 40 | name=conv_name)(x) 41 | return _bn_relu(x, bn_name=bn_name, relu_name=relu_name) 42 | 43 | return f 44 | 45 | 46 | def _bn_relu_conv(**conv_params): 47 | """Helper to build a BN -> relu -> conv residual unit with full pre-activation function. 48 | This is the ResNet v2 scheme proposed in http://arxiv.org/pdf/1603.05027v2.pdf 49 | """ 50 | filters = conv_params["filters"] 51 | kernel_size = conv_params["kernel_size"] 52 | strides = conv_params.setdefault("strides", (1, 1)) 53 | dilation_rate = conv_params.setdefault("dilation_rate", (1, 1)) 54 | conv_name = conv_params.setdefault("conv_name", None) 55 | bn_name = conv_params.setdefault("bn_name", None) 56 | relu_name = conv_params.setdefault("relu_name", None) 57 | kernel_initializer = conv_params.setdefault("kernel_initializer", "he_normal") 58 | padding = conv_params.setdefault("padding", "same") 59 | kernel_regularizer = conv_params.setdefault("kernel_regularizer", l2(1.e-4)) 60 | 61 | def f(x): 62 | activation = _bn_relu(x, bn_name=bn_name, relu_name=relu_name) 63 | return Conv2D(filters=filters, kernel_size=kernel_size, 64 | strides=strides, padding=padding, 65 | dilation_rate=dilation_rate, 66 | kernel_initializer=kernel_initializer, 67 | kernel_regularizer=kernel_regularizer, 68 | name=conv_name)(activation) 69 | 70 | return f 71 | 72 | 73 | def _shortcut(input_feature, residual, conv_name_base=None, bn_name_base=None): 74 | """Adds a shortcut between input and residual block and merges them with "sum" 75 | """ 76 | # Expand channels of shortcut to match residual. 77 | # Stride appropriately to match residual (width, height) 78 | # Should be int if network architecture is correctly configured. 79 | input_shape = K.int_shape(input_feature) 80 | residual_shape = K.int_shape(residual) 81 | stride_width = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS])) 82 | stride_height = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS])) 83 | equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS] 84 | 85 | shortcut = input_feature 86 | # 1 X 1 conv if shape is different. Else identity. 87 | if stride_width > 1 or stride_height > 1 or not equal_channels: 88 | print('reshaping via a convolution...') 89 | if conv_name_base is not None: 90 | conv_name_base = conv_name_base + '1' 91 | shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS], 92 | kernel_size=(1, 1), 93 | strides=(stride_width, stride_height), 94 | padding="valid", 95 | kernel_initializer="he_normal", 96 | kernel_regularizer=l2(0.0001), 97 | name=conv_name_base)(input_feature) 98 | if bn_name_base is not None: 99 | bn_name_base = bn_name_base + '1' 100 | shortcut = BatchNormalization(axis=CHANNEL_AXIS, name=bn_name_base)(shortcut) 101 | 102 | return add([shortcut, residual]) 103 | 104 | 105 | def _residual_block(block_function, filters, blocks, stage, 106 | transition_strides=None, transition_dilation_rates=None, 107 | dilation_rates=(1, 1), is_first_layer=False, dropout=None, 108 | residual_unit=_bn_relu_conv): 109 | """Builds a residual block with repeating bottleneck blocks. 110 | stage: integer, current stage label, used for generating layer names 111 | blocks: number of blocks 'a','b'..., current block label, used for generating layer names 112 | transition_strides: a list of tuples for the strides of each transition 113 | transition_dilation_rates: a list of tuples for the dilation rate of each transition 114 | """ 115 | if transition_dilation_rates is None: 116 | transition_dilation_rates = [(1, 1)] * blocks 117 | if transition_strides is None: 118 | transition_strides = [(1, 1)] * blocks 119 | 120 | def f(x): 121 | for i in range(blocks): 122 | x = block_function(filters=filters, stage=stage, block=i, 123 | transition_strides=transition_strides[i], 124 | dilation_rate=transition_dilation_rates[i], 125 | is_first_block_of_first_layer=(is_first_layer and i == 0), 126 | dropout=dropout, 127 | residual_unit=residual_unit)(x) 128 | 129 | # Non Local Blook 130 | if filters >= 256: 131 | print("Filters : ", filters, "Adding Non Local Blocks") 132 | x = non_local_block(x, mode='embedded', compression=2) 133 | 134 | return x 135 | 136 | return f 137 | 138 | 139 | def _block_name_base(stage, block): 140 | """Get the convolution name base and batch normalization name base defined by stage and block. 141 | If there are less than 26 blocks they will be labeled 'a', 'b', 'c' to match the paper and keras 142 | and beyond 26 blocks they will simply be numbered. 143 | """ 144 | if block < 27: 145 | block = '%c' % (block + 97) # 97 is the ascii number for lowercase 'a' 146 | conv_name_base = 'res' + str(stage) + block + '_branch' 147 | bn_name_base = 'bn' + str(stage) + block + '_branch' 148 | return conv_name_base, bn_name_base 149 | 150 | 151 | def basic_block(filters, stage, block, transition_strides=(1, 1), 152 | dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None, 153 | residual_unit=_bn_relu_conv): 154 | """Basic 3 X 3 convolution blocks for use on resnets with layers <= 34. 155 | Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf 156 | """ 157 | def f(input_features): 158 | conv_name_base, bn_name_base = _block_name_base(stage, block) 159 | if is_first_block_of_first_layer: 160 | # don't repeat bn->relu since we just did bn->relu->maxpool 161 | x = Conv2D(filters=filters, kernel_size=(3, 3), 162 | strides=transition_strides, 163 | dilation_rate=dilation_rate, 164 | padding="same", 165 | kernel_initializer="he_normal", 166 | kernel_regularizer=l2(1e-4), 167 | name=conv_name_base + '2a')(input_features) 168 | else: 169 | x = residual_unit(filters=filters, kernel_size=(3, 3), 170 | strides=transition_strides, 171 | dilation_rate=dilation_rate, 172 | conv_name_base=conv_name_base + '2a', 173 | bn_name_base=bn_name_base + '2a')(input_features) 174 | 175 | if dropout is not None: 176 | x = Dropout(dropout)(x) 177 | 178 | x = residual_unit(filters=filters, kernel_size=(3, 3), 179 | conv_name_base=conv_name_base + '2b', 180 | bn_name_base=bn_name_base + '2b')(x) 181 | 182 | return _shortcut(input_features, x) 183 | 184 | return f 185 | 186 | 187 | def bottleneck(filters, stage, block, transition_strides=(1, 1), 188 | dilation_rate=(1, 1), is_first_block_of_first_layer=False, dropout=None, 189 | residual_unit=_bn_relu_conv): 190 | """Bottleneck architecture for > 34 layer resnet. 191 | Follows improved proposed scheme in http://arxiv.org/pdf/1603.05027v2.pdf 192 | Returns: 193 | A final conv layer of filters * 4 194 | """ 195 | def f(input_feature): 196 | conv_name_base, bn_name_base = _block_name_base(stage, block) 197 | if is_first_block_of_first_layer: 198 | # don't repeat bn->relu since we just did bn->relu->maxpool 199 | x = Conv2D(filters=filters, kernel_size=(1, 1), 200 | strides=transition_strides, 201 | dilation_rate=dilation_rate, 202 | padding="same", 203 | kernel_initializer="he_normal", 204 | kernel_regularizer=l2(1e-4), 205 | name=conv_name_base + '2a')(input_feature) 206 | else: 207 | x = residual_unit(filters=filters, kernel_size=(1, 1), 208 | strides=transition_strides, 209 | dilation_rate=dilation_rate, 210 | conv_name_base=conv_name_base + '2a', 211 | bn_name_base=bn_name_base + '2a')(input_feature) 212 | 213 | if dropout is not None: 214 | x = Dropout(dropout)(x) 215 | 216 | x = residual_unit(filters=filters, kernel_size=(3, 3), 217 | conv_name_base=conv_name_base + '2b', 218 | bn_name_base=bn_name_base + '2b')(x) 219 | 220 | if dropout is not None: 221 | x = Dropout(dropout)(x) 222 | 223 | x = residual_unit(filters=filters * 4, kernel_size=(1, 1), 224 | conv_name_base=conv_name_base + '2c', 225 | bn_name_base=bn_name_base + '2c')(x) 226 | 227 | return _shortcut(input_feature, x) 228 | 229 | return f 230 | 231 | 232 | def _handle_dim_ordering(): 233 | global ROW_AXIS 234 | global COL_AXIS 235 | global CHANNEL_AXIS 236 | if K.image_data_format() == 'channels_last': 237 | ROW_AXIS = 1 238 | COL_AXIS = 2 239 | CHANNEL_AXIS = 3 240 | else: 241 | CHANNEL_AXIS = 1 242 | ROW_AXIS = 2 243 | COL_AXIS = 3 244 | 245 | 246 | def _string_to_function(identifier): 247 | if isinstance(identifier, six.string_types): 248 | res = globals().get(identifier) 249 | if not res: 250 | raise ValueError('Invalid {}'.format(identifier)) 251 | return res 252 | return identifier 253 | 254 | 255 | def NonLocalResNet(input_shape=None, classes=10, block='bottleneck', residual_unit='v2', repetitions=None, 256 | initial_filters=64, activation='softmax', include_top=True, input_tensor=None, dropout=None, 257 | transition_dilation_rate=(1, 1), initial_strides=(2, 2), initial_kernel_size=(7, 7), 258 | initial_pooling='max', final_pooling=None, top='classification'): 259 | """Builds a custom ResNet like architecture. Defaults to ResNet50 v2. 260 | Args: 261 | input_shape: optional shape tuple, only to be specified 262 | if `include_top` is False (otherwise the input shape 263 | has to be `(224, 224, 3)` (with `channels_last` dim ordering) 264 | or `(3, 224, 224)` (with `channels_first` dim ordering). 265 | It should have exactly 3 inputs channels, 266 | and width and height should be no smaller than 8. 267 | E.g. `(224, 224, 3)` would be one valid value. 268 | classes: The number of outputs at final softmax layer 269 | block: The block function to use. This is either `'basic'` or `'bottleneck'`. 270 | The original paper used `basic` for layers < 50. 271 | repetitions: Number of repetitions of various block units. 272 | At each block unit, the number of filters are doubled and the input size is halved. 273 | Default of None implies the ResNet50v2 values of [3, 4, 6, 3]. 274 | transition_dilation_rate: Used for pixel-wise prediction tasks such as image segmentation. 275 | residual_unit: the basic residual unit, 'v1' for conv bn relu, 'v2' for bn relu conv. 276 | See [Identity Mappings in Deep Residual Networks](https://arxiv.org/abs/1603.05027) 277 | for details. 278 | dropout: None for no dropout, otherwise rate of dropout from 0 to 1. 279 | Based on [Wide Residual Networks.(https://arxiv.org/pdf/1605.07146) paper. 280 | transition_dilation_rate: Dilation rate for transition layers. For semantic 281 | segmentation of images use a dilation rate of (2, 2). 282 | initial_strides: Stride of the very first residual unit and MaxPooling2D call, 283 | with default (2, 2), set to (1, 1) for small images like cifar. 284 | initial_kernel_size: kernel size of the very first convolution, (7, 7) for imagenet 285 | and (3, 3) for small image datasets like tiny imagenet and cifar. 286 | See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details. 287 | initial_pooling: Determine if there will be an initial pooling layer, 288 | 'max' for imagenet and None for small image datasets. 289 | See [ResNeXt](https://arxiv.org/abs/1611.05431) paper for details. 290 | final_pooling: Optional pooling mode for feature extraction at the final model layer 291 | when `include_top` is `False`. 292 | - `None` means that the output of the model 293 | will be the 4D tensor output of the 294 | last convolutional layer. 295 | - `avg` means that global average pooling 296 | will be applied to the output of the 297 | last convolutional layer, and thus 298 | the output of the model will be a 299 | 2D tensor. 300 | - `max` means that global max pooling will 301 | be applied. 302 | top: Defines final layers to evaluate based on a specific problem type. Options are 303 | 'classification' for ImageNet style problems, 'segmentation' for problems like 304 | the Pascal VOC dataset, and None to exclude these layers entirely. 305 | Returns: 306 | The keras `Model`. 307 | """ 308 | if activation not in ['softmax', 'sigmoid', None]: 309 | raise ValueError('activation must be one of "softmax", "sigmoid", or None') 310 | if activation == 'sigmoid' and classes != 1: 311 | raise ValueError('sigmoid activation can only be used when classes = 1') 312 | if repetitions is None: 313 | repetitions = [3, 4, 6, 3] 314 | # Determine proper input shape 315 | input_shape = _obtain_input_shape (input_shape, 316 | default_size=32, 317 | min_size=8, 318 | data_format=K.image_data_format(), 319 | require_flatten=include_top) 320 | _handle_dim_ordering() 321 | if len(input_shape) != 3: 322 | raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)") 323 | 324 | if block == 'basic': 325 | block_fn = basic_block 326 | elif block == 'bottleneck': 327 | block_fn = bottleneck 328 | elif isinstance(block, six.string_types): 329 | block_fn = _string_to_function(block) 330 | else: 331 | block_fn = block 332 | 333 | if residual_unit == 'v2': 334 | residual_unit = _bn_relu_conv 335 | elif residual_unit == 'v1': 336 | residual_unit = _conv_bn_relu 337 | elif isinstance(residual_unit, six.string_types): 338 | residual_unit = _string_to_function(residual_unit) 339 | else: 340 | residual_unit = residual_unit 341 | 342 | # Permute dimension order if necessary 343 | if K.image_data_format() == 'channels_first': 344 | input_shape = (input_shape[1], input_shape[2], input_shape[0]) 345 | # Determine proper input shape 346 | input_shape = _obtain_input_shape(input_shape, 347 | default_size=32, 348 | min_size=8, 349 | data_format=K.image_data_format(), 350 | require_flatten=include_top) 351 | 352 | img_input = Input(shape=input_shape, tensor=input_tensor) 353 | x = _conv_bn_relu(filters=initial_filters, kernel_size=initial_kernel_size, strides=initial_strides)(img_input) 354 | if initial_pooling == 'max': 355 | x = MaxPooling2D(pool_size=(3, 3), strides=initial_strides, padding="same")(x) 356 | 357 | 358 | block = x 359 | filters = initial_filters 360 | for i, r in enumerate(repetitions): 361 | transition_dilation_rates = [transition_dilation_rate] * r 362 | transition_strides = [(1, 1)] * r 363 | if transition_dilation_rate == (1, 1): 364 | transition_strides[0] = (2, 2) 365 | block = _residual_block(block_fn, filters=filters, 366 | stage=i, blocks=r, 367 | is_first_layer=(i == 0), 368 | dropout=dropout, 369 | transition_dilation_rates=transition_dilation_rates, 370 | transition_strides=transition_strides, 371 | residual_unit=residual_unit)(block) 372 | filters *= 2 373 | 374 | # Last activation 375 | x = _bn_relu(block) 376 | 377 | # Classifier block 378 | if include_top and top is 'classification': 379 | x = GlobalAveragePooling2D()(x) 380 | x = Dense(units=classes, activation=activation, kernel_initializer="he_normal")(x) 381 | elif include_top and top is 'segmentation': 382 | x = Conv2D(classes, (1, 1), activation='linear', padding='same')(x) 383 | 384 | if K.image_data_format() == 'channels_first': 385 | channel, row, col = input_shape 386 | else: 387 | row, col, channel = input_shape 388 | 389 | x = Reshape((row * col, classes))(x) 390 | x = Activation(activation)(x) 391 | x = Reshape((row, col, classes))(x) 392 | elif final_pooling == 'avg': 393 | x = GlobalAveragePooling2D()(x) 394 | elif final_pooling == 'max': 395 | x = GlobalMaxPooling2D()(x) 396 | 397 | model = Model(inputs=img_input, outputs=x) 398 | return model 399 | 400 | 401 | def NonLocalResNet18(input_shape, classes): 402 | """ResNet with 18 layers and v2 residual units 403 | """ 404 | return NonLocalResNet(input_shape, classes, basic_block, repetitions=[2, 2, 2, 2]) 405 | 406 | 407 | def NonLocalResNet34(input_shape, classes): 408 | """ResNet with 34 layers and v2 residual units 409 | """ 410 | return NonLocalResNet(input_shape, classes, basic_block, repetitions=[3, 4, 6, 3]) 411 | 412 | 413 | def NonLocalResNet50(input_shape, classes): 414 | """ResNet with 50 layers and v2 residual units 415 | """ 416 | return NonLocalResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 6, 3]) 417 | 418 | 419 | def NonLocalResNet101(input_shape, classes): 420 | """ResNet with 101 layers and v2 residual units 421 | """ 422 | return NonLocalResNet(input_shape, classes, bottleneck, repetitions=[3, 4, 23, 3]) 423 | 424 | 425 | def NonLocalResNet152(input_shape, classes): 426 | """ResNet with 152 layers and v2 residual units 427 | """ 428 | return NonLocalResNet(input_shape, classes, bottleneck, repetitions=[3, 8, 36, 3]) 429 | 430 | 431 | if __name__ == '__main__': 432 | model = NonLocalResNet18((128, 160, 3), classes=10) 433 | model.summary() -------------------------------------------------------------------------------- /STN_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## STN测试报告\n", 8 | "以公开mnist数据集为例做训练,比较有stn模块和没有stn模块的网络两者性能的差异" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "metadata": {}, 15 | "outputs": [], 16 | "source": [ 17 | "import tensorflow as tf\n", 18 | "import matplotlib\n", 19 | "import numpy as np\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "import imgaug.augmenters as iaa\n", 22 | "import imgaug as ia\n", 23 | "from datetime import datetime\n", 24 | "from math import cos, sin, pi" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# 加载数据集\n", 34 | "mnist = tf.keras.datasets.mnist\n", 35 | "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 36 | "# 归一化\n", 37 | "x_train, x_test = x_train / 255.0, x_test / 255.0\n", 38 | "# reshape\n", 39 | "x_train = x_train.reshape(-1, 28, 28, 1)\n", 40 | "x_test = x_test.reshape(-1, 28, 28, 1)\n", 41 | "# 获取图像的长宽\n", 42 | "H, W,_ = x_train[0].shape" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### 数据增强\n", 50 | "生成一些数据增强的数据,数据增强的方法是仿射变换与原数据构成新的训练集" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 3, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "# Apply affine transformations to some of the images\n", 60 | "# - scale to 80-120% of image height/width (each axis independently)\n", 61 | "# - translate by -20 to +20 relative to height/width (per axis)\n", 62 | "# - rotate by -45 to +45 degrees\n", 63 | "# - shear by -16 to +16 degrees\n", 64 | "# - order: use nearest neighbour or bilinear interpolation (fast)\n", 65 | "# - mode: use any available mode to fill newly created pixels\n", 66 | "# see API or scikit-image for which modes are available\n", 67 | "# - cval: if the mode is constant, then use a random brightness\n", 68 | "# for the newly created pixels (e.g. sometimes black,\n", 69 | "# sometimes white)\n", 70 | "seq = iaa.Sequential([\n", 71 | " iaa.OneOf([\n", 72 | " iaa.Affine(\n", 73 | " scale={\"x\": (0.6, 1.1), \"y\": (0.5, 1.1)},\n", 74 | " translate_percent={\"x\": (-0.2, 0.2), \"y\": (-0.2, 0.2)},\n", 75 | " rotate=(-30, 30),\n", 76 | " shear=(-15, 15),\n", 77 | " order=[0, 1],\n", 78 | " cval=(0),\n", 79 | " ),\n", 80 | " iaa.Affine(\n", 81 | " scale={\"x\": (0.6, 1.1), \"y\": (0.6, 1.1)},\n", 82 | " order=[0, 1],\n", 83 | " cval=(0),\n", 84 | " ),\n", 85 | " iaa.Affine(\n", 86 | " scale={\"x\": (0.6, 0.8), \"y\": (0.6, 0.8)},\n", 87 | " translate_percent={\"x\": (-0.2, 0.2), \"y\": (-0.2, 0.2)},\n", 88 | " order=[0, 1],\n", 89 | " cval=(0),\n", 90 | " ),\n", 91 | " iaa.Affine(\n", 92 | " rotate=(-60, 60),\n", 93 | " #shear=(-30, 30),\n", 94 | " order=[0, 1],\n", 95 | " cval=(0),\n", 96 | " ),\n", 97 | " iaa.Affine(\n", 98 | " shear=(-40, 40),\n", 99 | " order=[0, 1],\n", 100 | " cval=(0),\n", 101 | " ),\n", 102 | " ]\n", 103 | " )\n", 104 | "])\n", 105 | "x_train_aug = seq(images=x_train) \n", 106 | "x_test_aug = seq(images=x_test) " 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "metadata": {}, 112 | "source": [ 113 | "### 可视化\n", 114 | "可视化训练集的函数" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 4, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "from math import ceil\n", 124 | "def draw_samples(images, images_per_row=5):\n", 125 | " num = len(images)\n", 126 | " per_row = min(images_per_row, num)\n", 127 | " rows = ceil(num /per_row)\n", 128 | " fig, axs = plt.subplots(rows, per_row)\n", 129 | " count = 0 \n", 130 | " for i in range(rows):\n", 131 | " \n", 132 | " for j in range(images_per_row):\n", 133 | " count+=1\n", 134 | " if (count > num):\n", 135 | " break\n", 136 | " if rows == 1:\n", 137 | " axs[j+i*per_row].imshow(images[j+i*per_row], cmap='gray')\n", 138 | " else:\n", 139 | " axs[i,j].imshow(images[j+i*per_row], cmap='gray')\n", 140 | " \n", 141 | " plt.show()" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 5, 147 | "metadata": {}, 148 | "outputs": [ 149 | { 150 | "data": { 151 | "image/png": "", 152 | "text/plain": [ 153 | "
" 154 | ] 155 | }, 156 | "metadata": { 157 | "needs_background": "light" 158 | }, 159 | "output_type": "display_data" 160 | } 161 | ], 162 | "source": [ 163 | "# 原数据集数据\n", 164 | "draw_samples(x_train[:30])" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "metadata": {}, 171 | "outputs": [ 172 | { 173 | "data": { 174 | "image/png": "", 175 | "text/plain": [ 176 | "
" 177 | ] 178 | }, 179 | "metadata": { 180 | "needs_background": "light" 181 | }, 182 | "output_type": "display_data" 183 | } 184 | ], 185 | "source": [ 186 | "# 仿射变换的数据\n", 187 | "draw_samples(x_train_aug[:30])" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "### 神经网络\n", 195 | "定义简单的密集神经网络进行训练" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 7, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "Model: \"sequential\"\n", 208 | "_________________________________________________________________\n", 209 | "Layer (type) Output Shape Param # \n", 210 | "=================================================================\n", 211 | "flatten (Flatten) (None, 784) 0 \n", 212 | "_________________________________________________________________\n", 213 | "dense (Dense) (None, 128) 100480 \n", 214 | "_________________________________________________________________\n", 215 | "dropout (Dropout) (None, 128) 0 \n", 216 | "_________________________________________________________________\n", 217 | "dense_1 (Dense) (None, 10) 1290 \n", 218 | "=================================================================\n", 219 | "Total params: 101,770\n", 220 | "Trainable params: 101,770\n", 221 | "Non-trainable params: 0\n", 222 | "_________________________________________________________________\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "# baseline\n", 228 | "model = tf.keras.models.Sequential([\n", 229 | " tf.keras.layers.Flatten(input_shape=(H, W)),\n", 230 | " tf.keras.layers.Dense(128, activation='relu'),\n", 231 | " tf.keras.layers.Dropout(0.2),\n", 232 | " tf.keras.layers.Dense(10)\n", 233 | "])\n", 234 | "model.summary()" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 8, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "name": "stdout", 244 | "output_type": "stream", 245 | "text": [ 246 | "Epoch 1/5\n", 247 | "3750/3750 [==============================] - 5s 1ms/step - loss: 0.9359 - accuracy: 0.7041\n", 248 | "Epoch 2/5\n", 249 | "3750/3750 [==============================] - 4s 987us/step - loss: 0.4982 - accuracy: 0.8430\n", 250 | "Epoch 3/5\n", 251 | "3750/3750 [==============================] - 4s 1ms/step - loss: 0.4187 - accuracy: 0.8673\n", 252 | "Epoch 4/5\n", 253 | "3750/3750 [==============================] - 4s 1ms/step - loss: 0.3788 - accuracy: 0.8810\n", 254 | "Epoch 5/5\n", 255 | "3750/3750 [==============================] - 4s 1ms/step - loss: 0.3533 - accuracy: 0.8883\n" 256 | ] 257 | }, 258 | { 259 | "data": { 260 | "text/plain": [ 261 | "" 262 | ] 263 | }, 264 | "execution_count": 8, 265 | "metadata": {}, 266 | "output_type": "execute_result" 267 | } 268 | ], 269 | "source": [ 270 | "# 定义损失函数\n", 271 | "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 272 | "model.compile(optimizer='adam',\n", 273 | " loss=loss_fn,\n", 274 | " metrics=['accuracy'])\n", 275 | "# completely retrain the original model with mnist dataset + distorted mnist dataset\n", 276 | "model.fit(tf.concat([x_train_aug, x_train],0), tf.concat([y_train, y_train],0), epochs=5)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 11, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "name": "stdout", 286 | "output_type": "stream", 287 | "text": [ 288 | "313/313 [==============================] - 0s 929us/step - loss: 0.5379 - accuracy: 0.8329\n", 289 | "distorted data: [0.5378811955451965, 0.8328999876976013]\n", 290 | "313/313 [==============================] - 0s 663us/step - loss: 0.0755 - accuracy: 0.9776\n", 291 | "undistorted original data: [0.07546955347061157, 0.9775999784469604]\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "# not good for distorted data\n", 297 | "print('distorted data:',model.evaluate(x_test_aug, y_test))\n", 298 | "# still good for undistorted original data\n", 299 | "print('undistorted original data:',model.evaluate(x_test, y_test))" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "定义STN模块的网络进行训练" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 13, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "Model: \"model_3\"\n", 319 | "__________________________________________________________________________________________________\n", 320 | "Layer (type) Output Shape Param # Connected to \n", 321 | "==================================================================================================\n", 322 | "input_3 (InputLayer) [(None, 28, 28, 1)] 0 \n", 323 | "__________________________________________________________________________________________________\n", 324 | "conv2d_5 (Conv2D) (None, 24, 24, 14) 364 input_3[0][0] \n", 325 | "__________________________________________________________________________________________________\n", 326 | "max_pooling2d_4 (MaxPooling2D) (None, 12, 12, 14) 0 conv2d_5[0][0] \n", 327 | "__________________________________________________________________________________________________\n", 328 | "conv2d_6 (Conv2D) (None, 8, 8, 32) 11232 max_pooling2d_4[0][0] \n", 329 | "__________________________________________________________________________________________________\n", 330 | "max_pooling2d_5 (MaxPooling2D) (None, 4, 4, 32) 0 conv2d_6[0][0] \n", 331 | "__________________________________________________________________________________________________\n", 332 | "flatten_3 (Flatten) (None, 512) 0 max_pooling2d_5[0][0] \n", 333 | "__________________________________________________________________________________________________\n", 334 | "dense_8 (Dense) (None, 120) 61560 flatten_3[0][0] \n", 335 | "__________________________________________________________________________________________________\n", 336 | "dropout_2 (Dropout) (None, 120) 0 dense_8[0][0] \n", 337 | "__________________________________________________________________________________________________\n", 338 | "tf.compat.v1.shape_3 (TFOpLambd (4,) 0 input_3[0][0] \n", 339 | "__________________________________________________________________________________________________\n", 340 | "dense_9 (Dense) (None, 84) 10164 dropout_2[0][0] \n", 341 | "__________________________________________________________________________________________________\n", 342 | "tf.__operators__.getitem_4 (Sli () 0 tf.compat.v1.shape_3[0][0] \n", 343 | "__________________________________________________________________________________________________\n", 344 | "dense_10 (Dense) (None, 6) 510 dense_9[0][0] \n", 345 | "__________________________________________________________________________________________________\n", 346 | "tf.tile_1 (TFOpLambda) (None, 784, 3, 1) 0 tf.__operators__.getitem_4[0][0] \n", 347 | "__________________________________________________________________________________________________\n", 348 | "tf.reshape_3 (TFOpLambda) (None, 2, 3) 0 dense_10[0][0] \n", 349 | "__________________________________________________________________________________________________\n", 350 | "tf.compat.v1.squeeze_1 (TFOpLam (None, 784, 3) 0 tf.tile_1[0][0] \n", 351 | "__________________________________________________________________________________________________\n", 352 | "tf.linalg.matmul_3 (TFOpLambda) (None, 2, 784) 0 tf.reshape_3[0][0] \n", 353 | " tf.compat.v1.squeeze_1[0][0] \n", 354 | "__________________________________________________________________________________________________\n", 355 | "tf.linalg.matrix_transpose_1 (T (None, 784, 2) 0 tf.linalg.matmul_3[0][0] \n", 356 | "__________________________________________________________________________________________________\n", 357 | "tf.__operators__.add_3 (TFOpLam (None, 784, 2) 0 tf.linalg.matrix_transpose_1[0][0\n", 358 | "__________________________________________________________________________________________________\n", 359 | "tf.math.multiply_2 (TFOpLambda) (None, 784, 2) 0 tf.__operators__.add_3[0][0] \n", 360 | "__________________________________________________________________________________________________\n", 361 | "tf.math.multiply_3 (TFOpLambda) (None, 784, 2) 0 tf.math.multiply_2[0][0] \n", 362 | "__________________________________________________________________________________________________\n", 363 | "tf.split_2 (TFOpLambda) [(None, 784, 1), (No 0 tf.math.multiply_3[0][0] \n", 364 | "__________________________________________________________________________________________________\n", 365 | "tf.math.floor_2 (TFOpLambda) (None, 784, 1) 0 tf.split_2[0][0] \n", 366 | "__________________________________________________________________________________________________\n", 367 | "tf.math.floor_3 (TFOpLambda) (None, 784, 1) 0 tf.split_2[0][1] \n", 368 | "__________________________________________________________________________________________________\n", 369 | "tf.cast_7 (TFOpLambda) (None, 784, 1) 0 tf.math.floor_2[0][0] \n", 370 | "__________________________________________________________________________________________________\n", 371 | "tf.cast_8 (TFOpLambda) (None, 784, 1) 0 tf.math.floor_3[0][0] \n", 372 | "__________________________________________________________________________________________________\n", 373 | "tf.__operators__.add_4 (TFOpLam (None, 784, 1) 0 tf.cast_7[0][0] \n", 374 | "__________________________________________________________________________________________________\n", 375 | "tf.__operators__.add_5 (TFOpLam (None, 784, 1) 0 tf.cast_8[0][0] \n", 376 | "__________________________________________________________________________________________________\n", 377 | "tf.clip_by_value_6 (TFOpLambda) (None, 784, 1) 0 tf.__operators__.add_4[0][0] \n", 378 | "__________________________________________________________________________________________________\n", 379 | "tf.clip_by_value_4 (TFOpLambda) (None, 784, 1) 0 tf.cast_7[0][0] \n", 380 | "__________________________________________________________________________________________________\n", 381 | "tf.clip_by_value_5 (TFOpLambda) (None, 784, 1) 0 tf.cast_8[0][0] \n", 382 | "__________________________________________________________________________________________________\n", 383 | "tf.clip_by_value_7 (TFOpLambda) (None, 784, 1) 0 tf.__operators__.add_5[0][0] \n", 384 | "__________________________________________________________________________________________________\n", 385 | "tf.concat_7 (TFOpLambda) (None, 784, 2) 0 tf.clip_by_value_5[0][0] \n", 386 | " tf.clip_by_value_4[0][0] \n", 387 | "__________________________________________________________________________________________________\n", 388 | "tf.concat_8 (TFOpLambda) (None, 784, 2) 0 tf.clip_by_value_7[0][0] \n", 389 | " tf.clip_by_value_4[0][0] \n", 390 | "__________________________________________________________________________________________________\n", 391 | "tf.concat_9 (TFOpLambda) (None, 784, 2) 0 tf.clip_by_value_5[0][0] \n", 392 | " tf.clip_by_value_6[0][0] \n", 393 | "__________________________________________________________________________________________________\n", 394 | "tf.concat_10 (TFOpLambda) (None, 784, 2) 0 tf.clip_by_value_7[0][0] \n", 395 | " tf.clip_by_value_6[0][0] \n", 396 | "__________________________________________________________________________________________________\n", 397 | "tf.compat.v1.gather_nd_4 (TFOpL (None, 784, 1) 0 input_3[0][0] \n", 398 | " tf.concat_7[0][0] \n", 399 | "__________________________________________________________________________________________________\n", 400 | "tf.compat.v1.gather_nd_5 (TFOpL (None, 784, 1) 0 input_3[0][0] \n", 401 | " tf.concat_8[0][0] \n", 402 | "__________________________________________________________________________________________________\n", 403 | "tf.compat.v1.gather_nd_6 (TFOpL (None, 784, 1) 0 input_3[0][0] \n", 404 | " tf.concat_9[0][0] \n", 405 | "__________________________________________________________________________________________________\n", 406 | "tf.compat.v1.gather_nd_7 (TFOpL (None, 784, 1) 0 input_3[0][0] \n", 407 | " tf.concat_10[0][0] \n", 408 | "__________________________________________________________________________________________________\n", 409 | "tf.concat_13 (TFOpLambda) (None, 784, 4) 0 tf.compat.v1.gather_nd_4[0][0] \n", 410 | " tf.compat.v1.gather_nd_5[0][0] \n", 411 | " tf.compat.v1.gather_nd_6[0][0] \n", 412 | " tf.compat.v1.gather_nd_7[0][0] \n", 413 | "__________________________________________________________________________________________________\n", 414 | "tf.cast_9 (TFOpLambda) (None, 784, 1) 0 tf.clip_by_value_6[0][0] \n", 415 | "__________________________________________________________________________________________________\n", 416 | "tf.split_3 (TFOpLambda) [(None, 784, 1), (No 0 tf.math.multiply_3[0][0] \n", 417 | "__________________________________________________________________________________________________\n", 418 | "tf.cast_10 (TFOpLambda) (None, 784, 1) 0 tf.clip_by_value_4[0][0] \n", 419 | "__________________________________________________________________________________________________\n", 420 | "tf.compat.v1.shape_5 (TFOpLambd (3,) 0 tf.concat_13[0][0] \n", 421 | "__________________________________________________________________________________________________\n", 422 | "tf.math.subtract_4 (TFOpLambda) (None, 784, 1) 0 tf.cast_9[0][0] \n", 423 | " tf.split_3[0][0] \n", 424 | "__________________________________________________________________________________________________\n", 425 | "tf.math.subtract_5 (TFOpLambda) (None, 784, 1) 0 tf.split_3[0][0] \n", 426 | " tf.cast_10[0][0] \n", 427 | "__________________________________________________________________________________________________\n", 428 | "tf.__operators__.getitem_6 (Sli () 0 tf.compat.v1.shape_5[0][0] \n", 429 | "__________________________________________________________________________________________________\n", 430 | "tf.__operators__.getitem_7 (Sli () 0 tf.compat.v1.shape_5[0][0] \n", 431 | "__________________________________________________________________________________________________\n", 432 | "tf.cast_11 (TFOpLambda) (None, 784, 1) 0 tf.clip_by_value_7[0][0] \n", 433 | "__________________________________________________________________________________________________\n", 434 | "tf.cast_12 (TFOpLambda) (None, 784, 1) 0 tf.clip_by_value_5[0][0] \n", 435 | "__________________________________________________________________________________________________\n", 436 | "tf.concat_11 (TFOpLambda) (None, 784, 2) 0 tf.math.subtract_4[0][0] \n", 437 | " tf.math.subtract_5[0][0] \n", 438 | "__________________________________________________________________________________________________\n", 439 | "tf.reshape_4 (TFOpLambda) (None, None, 2, 2) 0 tf.concat_13[0][0] \n", 440 | " tf.__operators__.getitem_6[0][0] \n", 441 | " tf.__operators__.getitem_7[0][0] \n", 442 | "__________________________________________________________________________________________________\n", 443 | "tf.math.subtract_6 (TFOpLambda) (None, 784, 1) 0 tf.cast_11[0][0] \n", 444 | " tf.split_3[0][1] \n", 445 | "__________________________________________________________________________________________________\n", 446 | "tf.math.subtract_7 (TFOpLambda) (None, 784, 1) 0 tf.split_3[0][1] \n", 447 | " tf.cast_12[0][0] \n", 448 | "__________________________________________________________________________________________________\n", 449 | "tf.expand_dims_2 (TFOpLambda) (None, 784, 1, 2) 0 tf.concat_11[0][0] \n", 450 | "__________________________________________________________________________________________________\n", 451 | "tf.cast_13 (TFOpLambda) (None, None, 2, 2) 0 tf.reshape_4[0][0] \n", 452 | "__________________________________________________________________________________________________\n", 453 | "tf.concat_12 (TFOpLambda) (None, 784, 2) 0 tf.math.subtract_6[0][0] \n", 454 | " tf.math.subtract_7[0][0] \n", 455 | "__________________________________________________________________________________________________\n", 456 | "tf.linalg.matmul_4 (TFOpLambda) (None, 784, 1, 2) 0 tf.expand_dims_2[0][0] \n", 457 | " tf.cast_13[0][0] \n", 458 | "__________________________________________________________________________________________________\n", 459 | "tf.expand_dims_3 (TFOpLambda) (None, 784, 2, 1) 0 tf.concat_12[0][0] \n", 460 | "__________________________________________________________________________________________________\n", 461 | "tf.linalg.matmul_5 (TFOpLambda) (None, 784, 1, 1) 0 tf.linalg.matmul_4[0][0] \n", 462 | " tf.expand_dims_3[0][0] \n", 463 | "__________________________________________________________________________________________________\n", 464 | "tf.reshape_5 (TFOpLambda) (None, 28, 28, 1) 0 tf.linalg.matmul_5[0][0] \n", 465 | "__________________________________________________________________________________________________\n", 466 | "conv2d_7 (Conv2D) (None, 26, 26, 6) 60 tf.reshape_5[0][0] \n", 467 | "__________________________________________________________________________________________________\n", 468 | "max_pooling2d_6 (MaxPooling2D) (None, 13, 13, 6) 0 conv2d_7[0][0] \n", 469 | "__________________________________________________________________________________________________\n", 470 | "conv2d_8 (Conv2D) (None, 11, 11, 16) 880 max_pooling2d_6[0][0] \n", 471 | "__________________________________________________________________________________________________\n", 472 | "max_pooling2d_7 (MaxPooling2D) (None, 5, 5, 16) 0 conv2d_8[0][0] \n", 473 | "__________________________________________________________________________________________________\n", 474 | "flatten_4 (Flatten) (None, 400) 0 max_pooling2d_7[0][0] \n", 475 | "__________________________________________________________________________________________________\n", 476 | "dense_11 (Dense) (None, 120) 48120 flatten_4[0][0] \n", 477 | "__________________________________________________________________________________________________\n", 478 | "dense_12 (Dense) (None, 84) 10164 dense_11[0][0] \n", 479 | "__________________________________________________________________________________________________\n", 480 | "dense_13 (Dense) (None, 10) 850 dense_12[0][0] \n", 481 | "==================================================================================================\n", 482 | "Total params: 143,904\n", 483 | "Trainable params: 143,904\n", 484 | "Non-trainable params: 0\n", 485 | "__________________________________________________________________________________________________\n" 486 | ] 487 | } 488 | ], 489 | "source": [ 490 | "from stn_module import stn_module\n", 491 | "from tensorflow.keras.layers import Conv2D, MaxPooling2D,Flatten,Dense,Input,Dropout\n", 492 | "def model(input_shape):\n", 493 | " inputs = Input(input_shape)\n", 494 | " inputs_stn = stn_module(inputs)\n", 495 | " x = Conv2D(6, (3,3),padding='valid',activation=\"relu\")(inputs_stn)\n", 496 | " x = MaxPooling2D((2, 2))(x)\n", 497 | " x = Conv2D(16, (3,3),padding='valid',activation=\"relu\")(x)\n", 498 | " x = MaxPooling2D((2, 2))(x)\n", 499 | " x = Flatten()(x)\n", 500 | " x = Dense(120, activation='relu')(x)\n", 501 | " x = Dense(84, activation='relu')(x)\n", 502 | " x = Dense(10)(x)\n", 503 | " return tf.keras.Model(inputs, x)\n", 504 | "st= model(input_shape=(H, W, 1))\n", 505 | "st.summary()" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 15, 511 | "metadata": {}, 512 | "outputs": [ 513 | { 514 | "name": "stdout", 515 | "output_type": "stream", 516 | "text": [ 517 | "Epoch 1/5\n", 518 | "3750/3750 [==============================] - 76s 20ms/step - loss: 0.4957 - accuracy: 0.8421\n", 519 | "Epoch 2/5\n", 520 | "3750/3750 [==============================] - 76s 20ms/step - loss: 0.1169 - accuracy: 0.9625\n", 521 | "Epoch 3/5\n", 522 | "3750/3750 [==============================] - 74s 20ms/step - loss: 0.0902 - accuracy: 0.9719\n", 523 | "Epoch 4/5\n", 524 | "3750/3750 [==============================] - 72s 19ms/step - loss: 0.0817 - accuracy: 0.9746\n", 525 | "Epoch 5/5\n", 526 | "3750/3750 [==============================] - 73s 19ms/step - loss: 0.0713 - accuracy: 0.9778\n" 527 | ] 528 | }, 529 | { 530 | "data": { 531 | "text/plain": [ 532 | "" 533 | ] 534 | }, 535 | "execution_count": 15, 536 | "metadata": {}, 537 | "output_type": "execute_result" 538 | } 539 | ], 540 | "source": [ 541 | "# 训练\n", 542 | "loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 543 | "st.compile(optimizer=\"adam\",\n", 544 | " loss=loss_fn,\n", 545 | " metrics=['accuracy'])\n", 546 | "st.fit(tf.concat([x_train_aug, x_train],0), tf.concat([y_train, y_train],0), epochs=5)" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 16, 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "name": "stdout", 556 | "output_type": "stream", 557 | "text": [ 558 | "313/313 [==============================] - 3s 9ms/step - loss: 0.0346 - accuracy: 0.9900\n", 559 | "undistorted origin data: [0.034552980214357376, 0.9900000095367432]\n", 560 | "313/313 [==============================] - 3s 9ms/step - loss: 0.1000 - accuracy: 0.9684\n", 561 | "distorted data: [0.1000135987997055, 0.9684000015258789]\n" 562 | ] 563 | } 564 | ], 565 | "source": [ 566 | "# 评估模型\n", 567 | "print('undistorted origin data:', st.evaluate(x_test, y_test))\n", 568 | "print('distorted data:', st.evaluate(x_test_aug, y_test))" 569 | ] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": null, 574 | "metadata": {}, 575 | "outputs": [], 576 | "source": [] 577 | } 578 | ], 579 | "metadata": { 580 | "kernelspec": { 581 | "display_name": "Python 3.7.6 ('tf2')", 582 | "language": "python", 583 | "name": "python3" 584 | }, 585 | "language_info": { 586 | "codemirror_mode": { 587 | "name": "ipython", 588 | "version": 3 589 | }, 590 | "file_extension": ".py", 591 | "mimetype": "text/x-python", 592 | "name": "python", 593 | "nbconvert_exporter": "python", 594 | "pygments_lexer": "ipython3", 595 | "version": "3.7.6" 596 | }, 597 | "orig_nbformat": 4, 598 | "vscode": { 599 | "interpreter": { 600 | "hash": "c6c6e9ad919e43ea991096268ac22857d89ff5f05140928bb8d03f6bb8d6e7c0" 601 | } 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 2 606 | } 607 | --------------------------------------------------------------------------------