├── BiasNet.py
├── README.md
├── architecture_nn.png
├── config.json
├── figure1_architecture.png
├── kernal.py
├── main.py
├── non_local.py
└── unet_vanilla.py

--------------------------------------------------------------------------------
/BiasNet.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from keras.layers import Layer


class BiasNet(Layer):
    """
    Learnable per-channel bias. Used after `tf.nn.convolution` in the
    hyper-convolution blocks, which has no bias term of its own.
    Assumes channels-last input (the bias shape is taken from axis 3).
    """

    def __init__(self, **kwargs):
        super(BiasNet, self).__init__(**kwargs)
        self.kernel = None

    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(1, 1, input_shape[3]),
                                      initializer='he_normal',
                                      trainable=True)
        super(BiasNet, self).build(input_shape)

    def call(self, x, **kwargs):
        return x + self.kernel

    def compute_output_shape(self, input_shape):
        return input_shape

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Hyper-Convolution Networks for Biomedical Image Segmentation

Code for our WACV 2022 paper:

*Hyper-Convolution Networks for Biomedical Image Segmentation* (https://arxiv.org/abs/2105.10559)

and our journal extension published in Medical Image Analysis:

*Hyper-convolutions via implicit kernels for medical image analysis*
(https://www.sciencedirect.com/science/article/pii/S1361841523000579)

Convolutional kernels are generated by a hyper-network instead of being learned independently.

The input to the hyper-network is a grid of the kernel's spatial coordinates.

## Requirements:

`tensorflow-gpu 1.15.0`

`python 3.6.13`

## Code:

To initiate training or testing, run:
`python main.py --mode train --config_path config.json`

`--mode train` for training, `--mode test` for testing

`--config_path` is the path to the JSON config file that contains all model-related settings

`kernal.py` builds the input to the hyper-network, which is a two-channel coordinate grid (x and y)

`unet_vanilla.py` contains all the networks, including the baseline UNet, the non-local UNet, and our method
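The sketch below (not part of the released code) condenses what `kernal.py` and `p_kernal` in `unet_vanilla.py` do together: build a two-channel (x, y) coordinate grid for a k × k kernel, pass it through a small MLP realized as 1×1 convolutions, and reshape the output into an ordinary convolution kernel. The layer widths here are illustrative; the actual hyper-network uses four 1×1 layers with LeakyReLU activations:

```python
import tensorflow as tf


def implicit_kernel(k, ch_in, ch_out):
    # (1, k, k, 2) grid of x/y offsets centered at 0, as in kernal.hyperNet
    r = tf.range(-(k - 1) / 2, (k + 1) / 2, dtype='float32')
    grid = tf.stack(tf.meshgrid(r, r, indexing='ij'), axis=-1)[None]
    # small MLP over coordinates, built from 1x1 convolutions (cf. p_kernal)
    h = tf.keras.layers.Conv2D(16, 1, activation=tf.nn.leaky_relu)(grid)
    h = tf.keras.layers.Conv2D(ch_in * ch_out, 1)(h)
    # reshape into a standard convolution kernel
    return tf.reshape(h, (k, k, ch_in, ch_out))
```

The resulting tensor is then applied to the feature map with `tf.nn.convolution(x, kernel, padding='SAME')`, exactly as `p_conv` does.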
## Citation:

If you find our code useful, please cite our work, thank you!

```
@inproceedings{ma2022hyper,
  title={Hyper-convolution networks for biomedical image segmentation},
  author={Ma, Tianyu and Dalca, Adrian V and Sabuncu, Mert R},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={1933--1942},
  year={2022}
}
```

```
@article{ma2023hyper,
  title={Hyper-convolutions via implicit kernels for medical image analysis},
  author={Ma, Tianyu and Wang, Alan Q and Dalca, Adrian V and Sabuncu, Mert R},
  journal={Medical Image Analysis},
  pages={102796},
  year={2023},
  publisher={Elsevier}
}
```

--------------------------------------------------------------------------------
/architecture_nn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tym002/Hyper-Convolution/fe5757239817508b6383476654ab1fe41d9977c8/architecture_nn.png

--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
  "img_shape_x": 128,
  "img_shape_y": 128,
  "img_shape_z": 4,
  "depth": 4,
  "dropout": 0.5,
  "activation": "relu",
  "start_ch": 32,
  "residual": false,
  "batchnorm": true,
  "attention": false,
  "non_local": false,
  "pos": false,
  "hyper": false,
  "x_train_path": "/data/x_train.npy",
  "y_train_path": "/data/y_train.npy",
  "x_val_path": "/data/x_val.npy",
  "y_val_path": "/data/y_val.npy",
  "x_test_path": "/data/x_test.npy",
  "y_test_path": "/data/y_test.npy"
}

--------------------------------------------------------------------------------
/figure1_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tym002/Hyper-Convolution/fe5757239817508b6383476654ab1fe41d9977c8/figure1_architecture.png

--------------------------------------------------------------------------------
/kernal.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf


def hyperNet(x_dim=3, y_dim=3, ch_in=64, ch_out=64):
    """
    Build the input to the hyper-network: a (1, x_dim, y_dim, 2) grid of the
    kernel's spatial coordinates, centered at zero. `ch_in` and `ch_out` are
    unused here; the kernel channels are produced later by `p_kernal` in
    unet_vanilla.py.
    """
    # coordinate range per axis, e.g. [-1, 0, 1] for a dimension of 3
    xx_range = tf.range(-(x_dim - 1) / 2, (x_dim + 1) / 2, dtype='float32')
    yy_range = tf.range(-(y_dim - 1) / 2, (y_dim + 1) / 2, dtype='float32')

    # broadcast both axes to a full (x_dim, y_dim) grid
    xx_range = tf.tile(tf.expand_dims(xx_range, -1), [1, y_dim])
    yy_range = tf.tile(tf.expand_dims(yy_range, 0), [x_dim, 1])

    xx_range = tf.expand_dims(xx_range, -1)
    yy_range = tf.expand_dims(yy_range, -1)

    # stack x and y into a two-channel grid and add a batch dimension
    pos = tf.concat([xx_range, yy_range], -1)

    pos = tf.expand_dims(pos, 0)

    return pos
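
if __name__ == '__main__':
    # Illustrative check, not part of the original script: under TF 1.x the
    # grid is a graph tensor, so a Session is needed to evaluate it.
    with tf.Session() as sess:
        grid = sess.run(hyperNet(x_dim=3, y_dim=3))
        print(grid.shape)  # (1, 3, 3, 2): x and y offsets of each kernel tap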

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from unet_vanilla import *
import os
import numpy as np
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pandas as pd
import argparse
import json


def load_train_data(x_train_path, y_train_path, x_val_path, y_val_path):
    """
    load the training and validation data with ground truth
    """
    x_train = np.load(x_train_path)
    y_train = np.load(y_train_path)
    x_val = np.load(x_val_path)
    y_val = np.load(y_val_path)

    return x_train, y_train, x_val, y_val


def load_test_data(x_test_path, y_test_path):
    """
    load the test data and ground truth
    """
    x_test = np.load(x_test_path)
    y_test = np.load(y_test_path)

    return x_test, y_test


def train_validate(gpu, b_size, weight_path, save_path, history_path, config_arg):
    print('----- Loading and preprocessing train data... -----')

    imgs_test, mask_test = load_test_data(config_arg["x_test_path"], config_arg["y_test_path"])
    x_train, y_train, x_val, y_val = load_train_data(config_arg["x_train_path"], config_arg["y_train_path"],
                                                     config_arg["x_val_path"], config_arg["y_val_path"])

    print('Number of train:', x_train.shape[0])
    print('Number of val:', x_val.shape[0])
    print('Number of test:', imgs_test.shape[0])

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    print('----- Creating and compiling model... -----')

    # augmentation is disabled by default; enable the flags below as needed
    data_gen_args = dict(featurewise_center=False,
                         featurewise_std_normalization=False,
                         horizontal_flip=False,
                         vertical_flip=False,
                         width_shift_range=0,
                         height_shift_range=0,
                         zoom_range=0,
                         rotation_range=0
                         )
    image_datagen = ImageDataGenerator(**data_gen_args)
    mask_datagen = ImageDataGenerator(**data_gen_args)

    # identical seeds keep images and masks aligned across the two generators
    seed = 42

    image_generator = image_datagen.flow(x_train, batch_size=b_size, seed=seed)
    mask_generator = mask_datagen.flow(y_train, batch_size=b_size, seed=seed)

    train_generator = (pair for pair in zip(image_generator, mask_generator))

    model = unet(img_shape=(config_arg["img_shape_x"], config_arg["img_shape_y"], config_arg["img_shape_z"]),
                 # size of the input image
                 depth=config_arg["depth"],  # number of max-pooling layers
                 dropout=config_arg["dropout"],  # dropout rate
                 activation=config_arg["activation"],  # non-linear activation type
                 start_ch=config_arg["start_ch"],  # initial channels
                 residual=config_arg["residual"],  # residual connection
                 batchnorm=config_arg["batchnorm"],  # batch normalization
                 att=config_arg["attention"],  # attention module
                 nl=config_arg["non_local"],  # non-local block at the bottom
                 pos=config_arg["pos"],  # positional encoding
                 hyper=config_arg["hyper"])  # hyper-conv kernel size; use false for a regular network

    model_checkpoint = ModelCheckpoint(weight_path,
                                       monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True)

    print('----- Fitting model... -----')

    mtrain = model.fit_generator(train_generator, steps_per_epoch=len(x_train) // b_size,
                                 epochs=1000, verbose=1, shuffle=True, callbacks=[model_checkpoint],
                                 validation_data=(x_val, y_val))

    # note: this predicts with the last-epoch weights; the best checkpoint is
    # saved separately at weight_path
    model_predict = model.predict([imgs_test], verbose=1, batch_size=16)
    np.save(save_path, model_predict)
    pd.DataFrame.from_dict(mtrain.history).to_csv(history_path, index=False)


def prediction(gpu, weight_path, save_path, config_arg):
    print('----- Loading and preprocessing test data... -----')

    imgs_test, mask_test = load_test_data(config_arg["x_test_path"], config_arg["y_test_path"])

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    print('----- Creating and compiling model... -----')

    # rebuild the model with the same config used for training before
    # loading the saved weights
    model = unet(img_shape=(config_arg["img_shape_x"], config_arg["img_shape_y"], config_arg["img_shape_z"]),
                 # size of the input image
                 depth=config_arg["depth"],  # number of max-pooling layers
                 dropout=config_arg["dropout"],  # dropout rate
                 activation=config_arg["activation"],  # non-linear activation type
                 start_ch=config_arg["start_ch"],  # initial channels
                 residual=config_arg["residual"],  # residual connection
                 batchnorm=config_arg["batchnorm"],  # batch normalization
                 att=config_arg["attention"],  # attention module
                 nl=config_arg["non_local"],  # non-local block at the bottom
                 pos=config_arg["pos"],  # positional encoding
                 hyper=config_arg["hyper"])  # hyper-conv kernel size; use false for a regular network

    model.load_weights(weight_path)
    print(model.summary())

    print('----- Predicting... -----')

    model_predict = model.predict([imgs_test], verbose=1, batch_size=1)
    np.save(save_path, model_predict)


def main(arg, config_arg):
    mode = arg.mode
    batch_size = arg.b_size
    gpu = arg.gpu
    file_name = arg.file_name
    save_folder = '/result/' + arg.folder_name + '/'
    weight_path = save_folder + file_name + '.hdf5'
    save_path = save_folder + 'Prediction_' + file_name + '.npy'
    history_path = save_folder + 'history_' + file_name + '.csv'

    if not os.path.exists(save_folder):
        print(f'making save folder {save_folder}')
        os.makedirs(save_folder)
    if mode == "train":
        train_validate(gpu, batch_size, weight_path, save_path, history_path, config_arg)
    elif mode == "test":
        prediction(gpu, weight_path, save_path, config_arg)
    else:
        print("mode should be either train or test")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="training and testing script")
    parser.add_argument("--mode", default="train", help="train or test")
    parser.add_argument("--config_path", default="config.json", help="path to config file")
    parser.add_argument("--folder_name", default="training", help="name of the folder to save results")
    parser.add_argument("--file_name", default="training", help="name of the trained model file")
    parser.add_argument("--gpu", default="0", help="which gpu to use")  # kept as a string for CUDA_VISIBLE_DEVICES
    parser.add_argument("--b_size", default=8, type=int, help="batch size")

    args = parser.parse_args()
    config_args = json.load(open(args.config_path))
    main(args, config_args)

--------------------------------------------------------------------------------
/non_local.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-


from keras.layers import Activation, Reshape, Lambda, dot, add
from keras.layers import Conv1D, Conv2D, Conv3D
from keras.layers import MaxPool1D
from keras import backend as K


def non_local_block(ip, intermediate_dim=None, compression=2,
                    mode='embedded', add_residual=True):
    """
    Adds a Non-Local block for self-attention to the input tensor.
    The input tensor can be of rank 3 (temporal), 4 (spatial) or
    5 (spatio-temporal).

    Arguments:
        ip: input tensor
        intermediate_dim: The dimension of the intermediate representation.
            Can be `None` or a positive integer. If `None`, the intermediate
            dimension is computed as half of the input channel dimension.
        compression: None or a positive integer. Compresses the intermediate
            representation during the dot products to reduce memory
            consumption. The default of 2 halves the time/space/spatio-time
            dimension for the intermediate step. None or 1 causes no
            reduction.
        mode: Mode of operation. Can be one of `embedded`, `gaussian`, `dot`
            or `concatenate`.
        add_residual: Boolean value to decide if the residual connection
            should be added or not. Default is True for ResNets, and False
            for Self Attention.

    Returns:
        a tensor of the same shape as the input
    """
    channel_dim = 1 if K.image_data_format() == 'channels_first' else -1
    ip_shape = K.int_shape(ip)

    if mode not in ['gaussian', 'embedded', 'dot', 'concatenate']:
        raise ValueError('`mode` must be one of `gaussian`, `embedded`, `dot` or `concatenate`')

    if compression is None:
        compression = 1

    dim1, dim2, dim3 = None, None, None

    # check rank and calculate the input shape
    if len(ip_shape) == 3:  # temporal / time series data
        rank = 3
        batchsize, dim1, channels = ip_shape

    elif len(ip_shape) == 4:  # spatial / image data
        rank = 4

        if channel_dim == 1:
            batchsize, channels, dim1, dim2 = ip_shape
        else:
            batchsize, dim1, dim2, channels = ip_shape

    elif len(ip_shape) == 5:  # spatio-temporal / video or voxel data
        rank = 5

        if channel_dim == 1:
            batchsize, channels, dim1, dim2, dim3 = ip_shape
        else:
            batchsize, dim1, dim2, dim3, channels = ip_shape

    else:
        raise ValueError('Input dimension has to be either 3 (temporal), 4 (spatial) or 5 (spatio-temporal)')

    # verify that a correct intermediate dimension was specified
    if intermediate_dim is None:
        intermediate_dim = channels // 2

        if intermediate_dim < 1:
            intermediate_dim = 1

    else:
        intermediate_dim = int(intermediate_dim)

        if intermediate_dim < 1:
            raise ValueError('`intermediate_dim` must be either `None` or a positive integer.')

    if mode == 'gaussian':  # Gaussian instantiation
        x1 = Reshape((-1, channels))(ip)  # xi
        x2 = Reshape((-1, channels))(ip)  # xj
        f = dot([x1, x2], axes=2)
        f = Activation('softmax')(f)

    elif mode == 'dot':  # Dot instantiation
        # theta path
        theta = _convND(ip, rank, intermediate_dim)
        theta = Reshape((-1, intermediate_dim))(theta)

        # phi path
        phi = _convND(ip, rank, intermediate_dim)
        phi = Reshape((-1, intermediate_dim))(phi)

        f = dot([theta, phi], axes=2)

        size = K.int_shape(f)

        # scale the values to make it size invariant
        f = Lambda(lambda z: (1. / float(size[-1])) * z)(f)

    elif mode == 'concatenate':  # Concatenation instantiation
        raise NotImplementedError('Concatenate mode has not been implemented yet')

    else:  # Embedded Gaussian instantiation
        # theta path
        theta = _convND(ip, rank, intermediate_dim)
        theta = Reshape((-1, intermediate_dim))(theta)

        # phi path
        phi = _convND(ip, rank, intermediate_dim)
        phi = Reshape((-1, intermediate_dim))(phi)

        if compression > 1:
            # shielded computation
            phi = MaxPool1D(compression)(phi)

        f = dot([theta, phi], axes=2)
        f = Activation('softmax')(f)

    # g path
    g = _convND(ip, rank, intermediate_dim)
    g = Reshape((-1, intermediate_dim))(g)

    if compression > 1 and mode == 'embedded':
        # shielded computation
        g = MaxPool1D(compression)(g)

    # compute output path
    y = dot([f, g], axes=[2, 1])

    # reshape to input tensor format
    if rank == 3:
        y = Reshape((dim1, intermediate_dim))(y)
    elif rank == 4:
        if channel_dim == -1:
            y = Reshape((dim1, dim2, intermediate_dim))(y)
        else:
            y = Reshape((intermediate_dim, dim1, dim2))(y)
    else:
        if channel_dim == -1:
            y = Reshape((dim1, dim2, dim3, intermediate_dim))(y)
        else:
            y = Reshape((intermediate_dim, dim1, dim2, dim3))(y)

    # project filters
    y = _convND(y, rank, channels)

    # residual connection
    if add_residual:
        y = add([ip, y])

    return y


def _convND(ip, rank, channels):
    assert rank in [3, 4, 5], "Rank of input must be 3, 4 or 5"

    if rank == 3:
        x = Conv1D(channels, 1, padding='same', use_bias=False, kernel_initializer='he_normal')(ip)
    elif rank == 4:
        x = Conv2D(channels, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal')(ip)
    else:
        x = Conv3D(channels, (1, 1, 1), padding='same', use_bias=False, kernel_initializer='he_normal')(ip)
    return x
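
if __name__ == '__main__':
    # Illustrative usage, not part of the original file: wrap a spatial
    # feature map in an embedded-Gaussian non-local block and confirm that
    # the output shape matches the input shape.
    from keras.layers import Input
    from keras.models import Model

    x = Input(shape=(16, 16, 32))
    y = non_local_block(x, compression=1, mode='embedded')
    print(Model(x, y).output_shape)  # (None, 16, 16, 32)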

--------------------------------------------------------------------------------
/unet_vanilla.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from keras.models import Model
from keras.layers import Softmax, Lambda, Input, Conv2D, UpSampling2D, Dropout, MaxPooling2D, Concatenate, \
    BatchNormalization, Activation, Conv2DTranspose, LeakyReLU, Reshape
from keras.optimizers import Adam
import tensorflow as tf
from keras import backend as K
from keras.losses import binary_crossentropy
from non_local import non_local_block
from kernal import hyperNet
from BiasNet import BiasNet

# default input geometry; overridden by the config passed to unet()
img_rows = 512
img_cols = 512
in_c = 1


def dice_coef(y_true, y_pred):
    # hard Dice on thresholded predictions; used as a metric, not a loss
    y_pred = tf.cast((y_pred > 0.5), tf.float32)
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true) + K.sum(y_pred)
    return (2. * intersection + 0.01) / (union + 0.01)

def soft_dice_loss(y_true, y_pred):
    # differentiable Dice with squared terms in the denominator and
    # additive smoothing of 1.0
    numerator = 2. * K.sum(y_pred * y_true) + 1.0
    denominator = K.sum(K.square(y_pred)) + K.sum(K.square(y_true)) + 1.0
    loss = 1 - (numerator / denominator)
    return loss


def combine_loss(y_true, y_pred):
    crossentropy = binary_crossentropy(y_true, y_pred)
    return soft_dice_loss(y_true, y_pred) + crossentropy


def Tversky_loss(b):
    # beta weights false negatives against false positives
    def loss(y_true, y_pred):
        beta = b
        TP = K.sum(y_pred * y_true)
        FN = beta * K.sum((1 - y_pred) * y_true)
        FP = (1 - beta) * K.sum(y_pred * (1 - y_true))
        return 1 - (TP + 1) / (TP + FN + FP + 1)

    return loss


def coverage(y_true, y_pred):
    # fraction of the ground-truth foreground recovered by the prediction
    y_pred = tf.cast((y_pred > 0.5), tf.float32)
    return tf.reduce_sum(y_true * y_pred) / (tf.reduce_sum(y_true) + K.epsilon())


def conv_block(m, dim, acti='relu', bn=False, res=False, do=0):
    # standard double 3x3 convolution block of the baseline UNet
    n = Conv2D(dim, (3, 3), padding='same', dilation_rate=(1, 1))(m)
    n = BatchNormalization()(n) if bn else n
    n = Activation(acti)(n)

    n = Dropout(do)(n) if do else n

    n = Conv2D(dim, (3, 3), padding='same', dilation_rate=(1, 1))(n)
    n = BatchNormalization()(n) if bn else n
    n = Activation(acti)(n)
    return Concatenate()([m, n]) if res else n


def p_conv(ip, kernal):
    # convolve the input with a kernel produced by the hyper-network
    pos = tf.squeeze(kernal, axis=0)
    out = tf.nn.convolution(ip, pos, padding='SAME')
    # out = Activation('relu')(out)
    return out


def p_kernal(ip, x_dim, y_dim, ch_in, ch_out):
    # hyper-network: map the (x_dim, y_dim, 2) coordinate grid to a
    # convolution kernel of shape (x_dim, y_dim, ch_in, ch_out) through a
    # stack of 1x1 convolutions
    num_c = int(ch_in * ch_out)
    pos = Conv2D(16, (1, 1), padding='same', activation=None, use_bias=True, kernel_initializer='he_normal')(ip)
    pos = LeakyReLU(alpha=0.1)(pos)

    pos = Conv2D(16, (1, 1), padding='same', activation=None, use_bias=True, kernel_initializer='he_normal')(pos)
    pos = LeakyReLU(alpha=0.1)(pos)

    pos = Conv2D(4, (1, 1), padding='same', activation=None, use_bias=True, kernel_initializer='he_normal')(pos)
    pos = LeakyReLU(alpha=0.1)(pos)

    pos = Conv2D(num_c, (1, 1), padding='same', activation=None, use_bias=True, kernel_initializer='he_normal')(pos)
    pos = Reshape((x_dim, y_dim, ch_in, ch_out))(pos)
    return pos


def hyper_block(ip, x_dim, y_dim, ch_in, ch_out, acti='relu', bn=False, do=0, mode='xy', multi=True, res=False):
    # hyper-convolution counterpart of conv_block: one or two convolutions
    # whose kernels are generated from the coordinate grid
    input_channel = ch_in
    kernal1 = Lambda(lambda x: hyperNet(x_dim, y_dim, input_channel, ch_out))(ip)
    kernal1 = p_kernal(kernal1, x_dim, y_dim, input_channel, ch_out)
    n = Lambda(lambda x: p_conv(x[0], x[1]))([ip, kernal1])
    n = BiasNet()(n)
    n = BatchNormalization()(n) if bn else n
    if acti:
        n = Activation(acti)(n)
    n = Dropout(do)(n) if do else n
    if multi:
        kernal2 = Lambda(lambda x: hyperNet(x_dim, y_dim, ch_out, ch_out))(ip)
        kernal2 = p_kernal(kernal2, x_dim, y_dim, ch_out, ch_out)
        n = Lambda(lambda x: p_conv(x[0], x[1]))([n, kernal2])
        n = BiasNet()(n)
        n = BatchNormalization()(n) if bn else n
        if acti:
            n = Activation(acti)(n)

    return Concatenate()([ip, n]) if res else n


def combine_block(ip, x_dim, y_dim, ch_in, ch_out, acti='relu', bn=False, do=0, mode='xy', multi=True, res=False):
    # mixed block: a regular 3x3 convolution followed by a hyper-convolution
    n = Conv2D(ch_out, (3, 3), padding='same')(ip)
    n = BatchNormalization()(n) if bn else n
    n = Activation(acti)(n)

    n = Dropout(do)(n) if do else n

    kernal1 = Lambda(lambda x: hyperNet(x_dim, y_dim, ch_out, ch_out))(n)
    kernal1 = p_kernal(kernal1, x_dim, y_dim, ch_out, ch_out)
    n = Lambda(lambda x: p_conv(x[0], x[1]))([n, kernal1])
    n = BiasNet()(n)
    n = BatchNormalization()(n) if bn else n
    if acti:
        n = Activation(acti)(n)
    return Concatenate()([ip, n]) if res else n


def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res, att, nl, pos, hyper):
    # recursively build one encoder/decoder level of the UNet
    if depth > 0:
        if hyper:
            ft = hyper
            in_c = 1 if res else 1 / 2
            n = hyper_block(m, ft, ft, int(in_c * dim), dim, acti=acti, bn=bn, do=do,
                            mode='xy',
                            multi=True,
                            res=res)

        else:
            n = conv_block(m, dim, acti, bn, res)
        if att:
            # spatial attention: softmax over H and W of a one-channel map
            n1 = Conv2D(1, 1, padding='same', activation='linear')(n)
            n1 = Softmax(axis=(1, 2))(n1)
            n = Lambda(lambda x: tf.math.multiply(x[0], x[1]))([n, n1])
        m = MaxPooling2D()(n) if mp else Conv2D(dim, 3, strides=2, padding='same')(n)
        m = level_block(m, int(inc * dim), depth - 1, inc, acti, do, bn, mp, up, res, att, nl, pos, hyper)
        if up:
            m = UpSampling2D()(m)
            if hyper:
                in_c = 4 if res else 2
                m = hyper_block(m, hyper, hyper, int(in_c * dim), dim, acti=acti, bn=bn, do=do,
                                mode='xy',
                                multi=False,
                                res=False)
            else:
                m = Conv2D(dim, 3, activation=acti, padding='same')(m)
        else:
            m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)
        n = Concatenate()([n, m])
        if hyper:
            in_c = 3 if res else 2
            m = hyper_block(n, hyper, hyper, int(in_c * dim), dim, acti=acti, bn=bn, do=do, mode='xy', multi=True,
                            res=res)
        else:
            m = conv_block(n, dim, acti, bn, res)
    else:
        # bottleneck level
        if hyper:
            in_c = 1 if res else 1 / 2
            m = hyper_block(m, hyper, hyper, int(in_c * dim), dim, acti=acti, bn=bn, do=do, mode='xy', multi=True,
                            res=res)
        else:
            m = conv_block(m, int(dim), acti, bn, res, do)
        if nl or pos:
            m = non_local_block(m, compression=1, mode='dot')
    return m


def unet(img_shape=(img_rows, img_cols, in_c), out_ch=1, start_ch=16, depth=4, inc_rate=2., activation='relu',
         dropout=0, batchnorm=False, maxpool=True, upconv=True, residual=False, att=False, nl=False, pos=False,
         hyper=False):
    i = Input(shape=img_shape)
    if hyper:
        # initial hyper-convolution applied to the input (assumes a
        # single-channel input, since ch_in is fixed to 1 here)
        i1 = hyper_block(i, hyper, hyper, 1, int(1 / 2 * start_ch), acti=activation, bn=batchnorm, do=dropout,
                         mode='xy', multi=False)
    else:
        i1 = i
    o1 = level_block(i1, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual, att, nl,
                     pos, hyper)
    o1 = Conv2D(out_ch, 1, activation='sigmoid')(o1)
    model = Model(inputs=i, outputs=[o1])
    model.compile(optimizer=Adam(lr=1e-4), loss=soft_dice_loss, metrics=[dice_coef, coverage])
    return model
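
if __name__ == '__main__':
    # Illustrative smoke test, not part of the original file: build the
    # hyper-convolution variant with 3x3 implicit kernels on a
    # single-channel input and print the layer summary.
    demo_model = unet(img_shape=(128, 128, 1), hyper=3)
    demo_model.summary()

--------------------------------------------------------------------------------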