# Repository layout:
# ├── README.md ├── data.py ├── data_t.py ├── dice_loss.py ├── hrnet_keras.py
# ├── main_train.py └── test_labl.py
#
# README.md: segmentation_hrnet_keras — translate the PyTorch HRNet
# segmentation code into Keras. (original note: "将pytorch代码翻译成keras")
#
# --- data.py: training image/label loading and an infinite batch generator ---
import numpy as np
import os
import random


def label_indices_to_onehot(idx):
    """Map per-pixel gray-value *rank* indices to 3-class one-hot planes.

    idx : 2-D int array, the inverse indices from
          ``np.unique(label_img, return_inverse=True)``.
    Returns a float32 array of shape ``idx.shape + (3,)``:
        idx < 40         -> [0, 0, 1]
        40 <= idx < 130  -> [0, 1, 0]
        idx >= 130       -> [1, 0, 0]

    NOTE(review): the 40/130 thresholds are applied to the rank indices of
    the unique gray values, not to the gray values themselves — confirm that
    this is the intended encoding.  Behavior kept identical to the original
    per-pixel loop, just vectorized.
    """
    onehot = np.zeros(idx.shape + (3,), dtype="float32")
    onehot[idx < 40] = [0, 0, 1]
    onehot[(idx >= 40) & (idx < 130)] = [0, 1, 0]
    onehot[idx >= 130] = [1, 0, 0]
    return onehot


def load_img_pairs(imgFile, imgLabelFile):
    """Load one (image, label) pair, both resized to 256x256.

    Returns ``(img, label)`` where ``img`` is float32 scaled to [0, 1] with
    shape (256, 256, 3) and ``label`` is a float32 one-hot mask of shape
    (256, 256, 3).
    """
    # scipy.misc.imread / imresize were removed in SciPy >= 1.3; this code
    # needs an old SciPy (+Pillow).  Imported locally so the rest of the
    # module stays importable on modern installs.
    import scipy.misc as im

    img = np.asarray(im.imresize(im.imread(imgFile), (256, 256)),
                     dtype="float32")
    img = img / 255

    labl = np.asarray(im.imresize(im.imread(imgLabelFile), (256, 256)),
                      dtype="float32")
    # Rank every pixel by its gray value, then one-hot encode by rank bands.
    _, inverse = np.unique(labl, return_inverse=True)
    inverse = np.reshape(inverse, (256, 256))
    return img, label_indices_to_onehot(inverse)


def load_data_gen(batch_size, begin, end, dataDir, lablDir):
    """Infinite batch generator over the file pairs in positions [begin, end).

    Images in ``dataDir`` and masks in ``lablDir`` are paired by their
    position in the directory listing.
    NOTE(review): ``os.listdir`` order is arbitrary; pairing may want
    ``sorted()`` — kept as-is to preserve the original behavior.

    Yields ``(imgs, labels)`` float32 arrays, each (batch_size, 256, 256, 3).
    """
    imgs = os.listdir(dataDir)
    labls = os.listdir(lablDir)
    pairs = list(zip(imgs[begin:end], labls[begin:end]))
    random.shuffle(pairs)

    i = 0
    while True:
        # Fresh buffers each batch: the original reused one pair of arrays,
        # so every yielded batch was silently overwritten by the next one.
        img_batch = np.zeros((batch_size, 256, 256, 3), dtype='float32')
        labl_batch = np.zeros((batch_size, 256, 256, 3), dtype='float32')
        for b in range(batch_size):
            img_batch[b], labl_batch[b] = load_img_pairs(
                os.path.join(dataDir, pairs[i][0]),
                os.path.join(lablDir, pairs[i][1]))
            i += 1
            if i == len(pairs):  # epoch exhausted: reshuffle and restart
                i = 0
                random.shuffle(pairs)
        yield (img_batch, labl_batch)
# --- data_t.py: test-image loading ---
import scipy.misc as im  # NOTE: imread/imresize need SciPy < 1.3 (+Pillow)
import numpy as np
import os


def load_test_img(imgFile):
    """Load one image, resize to 256x256, return float32 scaled to [0, 1]."""
    img = im.imresize(im.imread(imgFile), (256, 256))
    return np.asarray(img, dtype="float32") / 255


def load_test_data(dataDir):
    """Load every image in dataDir into one (N, 256, 256, 3) float32 array."""
    names = os.listdir(dataDir)
    return np.asarray([load_test_img(dataDir + '/' + n) for n in names])


# Example usage (kept from the original, commented out):
# dataDir = "d:/my code/deep_learning/dis_cup/REFUGE-Validation400"
# saveDir = "d:/my code/deep_learning/dis_cup/"
# data_test = load_test_data(dataDir)
# np.save(saveDir + "data_t.npy", data_test)


# --- dice_loss.py: class-weighted soft Dice coefficient and loss ---
from keras import backend as K
import numpy as np
import tensorflow as tf


def dice_coef(y_true, y_pred, smooth=1):
    """Class-weighted soft Dice coefficient.

    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         = 2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf

    The ratio is computed per channel (summing over batch, height and width),
    then the three class scores are averaged with equal weight 1/3 each.
    ``smooth`` keeps the ratio defined when a class is absent from both masks.
    """
    dice_weight = np.array([1.0 / 3, 1.0 / 3, 1.0 / 3], dtype="float32")
    dice_weight = tf.convert_to_tensor(dice_weight, dtype="float32")
    intersection = K.sum(K.abs(y_true * y_pred), axis=(0, 1, 2))
    result = (2. * intersection + smooth) / (
        K.sum(K.square(y_true), (0, 1, 2))
        + K.sum(K.square(y_pred), (0, 1, 2)) + smooth)
    return K.sum(dice_weight * result)


def dice_loss(y_true, y_pred):
    """Loss = 1 - dice_coef; minimized when prediction matches the mask."""
    return 1 - dice_coef(y_true, y_pred)


# --- hrnet_keras.py: HRNet segmentation model (Keras port of PyTorch HRNet) ---
from keras.models import Model
from keras.layers import Input
from keras.layers import Activation
from keras.layers import BatchNormalization
from keras.layers import Conv2D
from keras.layers import UpSampling2D
from keras.layers import Add
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from dice_loss import dice_coef, dice_loss


def conv(x, outsize, kernel_size, strides_=1, padding_='same', activation=None):
    """Bias-free Conv2D with the HRNet weight init (normal, stddev 0.001)."""
    return Conv2D(outsize, kernel_size, strides=strides_, padding=padding_,
                  kernel_initializer=RandomNormal(stddev=0.001),
                  use_bias=False, activation=activation)(x)


def Bottleneck(x, size, downsampe=False):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1, 4x channel expansion).

    When ``downsampe`` is True the shortcut is projected with a 1x1 conv so
    its channel count matches ``size * 4``.
    """
    residual = x

    out = conv(x, size, 1, padding_='valid')
    out = BatchNormalization(epsilon=1e-5, momentum=0.1)(out)
    out = Activation('relu')(out)

    out = conv(out, size, 3)
    out = BatchNormalization(epsilon=1e-5, momentum=0.1)(out)
    out = Activation('relu')(out)

    out = conv(out, size * 4, 1, padding_='valid')
    out = BatchNormalization(epsilon=1e-5, momentum=0.1)(out)

    if downsampe:
        residual = conv(x, size * 4, 1, padding_='valid')
        residual = BatchNormalization(epsilon=1e-5, momentum=0.1)(residual)

    out = Add()([out, residual])
    out = Activation('relu')(out)

    return out


def BasicBlock(x, size, downsampe=False):
    """Residual basic block (two 3x3 convs, same channel count)."""
    residual = x

    out = conv(x, size, 3)
    out = BatchNormalization(epsilon=1e-5, momentum=0.1)(out)
    out = Activation('relu')(out)

    out = conv(out, size, 3)
    out = BatchNormalization(epsilon=1e-5, momentum=0.1)(out)

    if downsampe:
        residual = conv(x, size, 1, padding_='valid')
        residual = BatchNormalization(epsilon=1e-5, momentum=0.1)(residual)

    out = Add()([out, residual])
    out = Activation('relu')(out)

    return out


def layer1(x):
    """HRNet stage 1: four bottleneck blocks, first one projecting to 256ch."""
    x = Bottleneck(x, 64, downsampe=True)
    x = Bottleneck(x, 64)
    x = Bottleneck(x, 64)
    x = Bottleneck(x, 64)

    return x


def transition_layer(x, in_channels, out_channels):
    """Adapt the branch list between stages.

    Existing branches get a 3x3 conv only when their channel count changes;
    each extra output branch is created from the last input branch with a
    stride-2 conv (halving resolution).
    """
    num_in = len(in_channels)
    num_out = len(out_channels)
    out = []

    for i in range(num_out):
        if i < num_in:
            if in_channels[i] != out_channels[i]:
                residual = conv(x[i], out_channels[i], 3)
                residual = BatchNormalization(
                    epsilon=1e-5, momentum=0.1)(residual)
                residual = Activation('relu')(residual)
                out.append(residual)
            else:
                out.append(x[i])
        else:
            residual = conv(x[-1], out_channels[i], 3, strides_=2)
            residual = BatchNormalization(epsilon=1e-5, momentum=0.1)(residual)
            residual = Activation('relu')(residual)
            out.append(residual)

    return out


def branches(x, block_num, channels):
    """Run ``block_num`` BasicBlocks independently on each resolution branch."""
    out = []
    for i in range(len(channels)):
        residual = x[i]
        for j in range(block_num):
            residual = BasicBlock(residual, channels[i])
        out.append(residual)
    return out


def fuse_layers(x, channels, multi_scale_output=True):
    """Exchange information across resolution branches.

    For each output branch i (only branch 0 when ``multi_scale_output`` is
    False), sum:
      * its own features,
      * lower-resolution branches j > i: 1x1 conv + BN, then upsample 2^(j-i),
      * higher-resolution branches j < i: a chain of stride-2 3x3 convs; the
        last conv in the chain switches to ``channels[i]`` and has no ReLU
        (mirrors the PyTorch HRNet fuse layer).
    """
    out = []

    for i in range(len(channels) if multi_scale_output else 1):
        residual = x[i]
        for j in range(len(channels)):
            if j > i:
                y = conv(x[j], channels[i], 1, padding_='valid')
                y = BatchNormalization(epsilon=1e-5, momentum=0.1)(y)
                y = UpSampling2D(size=2 ** (j - i))(y)
                residual = Add()([residual, y])
            elif j < i:
                y = x[j]
                for k in range(i - j):
                    if k == i - j - 1:
                        y = conv(y, channels[i], 3, strides_=2)
                        y = BatchNormalization(epsilon=1e-5, momentum=0.1)(y)
                    else:
                        y = conv(y, channels[j], 3, strides_=2)
                        y = BatchNormalization(epsilon=1e-5, momentum=0.1)(y)
                        y = Activation('relu')(y)
                residual = Add()([residual, y])

        residual = Activation('relu')(residual)
        out.append(residual)

    return out


def HighResolutionModule(x, channels, multi_scale_output=True):
    """One HRNet module: 4 BasicBlocks per branch, then cross-branch fusion."""
    residual = branches(x, 4, channels)
    out = fuse_layers(residual, channels,
                      multi_scale_output=multi_scale_output)
    return out


def stage(x, num_modules, channels, multi_scale_output=True):
    """Chain ``num_modules`` HighResolutionModules; only the last module may
    collapse to a single-scale output."""
    out = x
    for i in range(num_modules):
        if i == num_modules - 1 and not multi_scale_output:
            out = HighResolutionModule(out, channels, multi_scale_output=False)
        else:
            out = HighResolutionModule(out, channels)

    return out


def hrnet_keras(input_size=(256, 256, 3)):
    """Build and compile the HRNet segmentation model.

    Stem downsamples 4x, stages 2-4 grow to 2/3/4 parallel branches, then the
    highest-resolution output is upsampled back to the input size and mapped
    to a 3-class softmax.  Compiled with Adam(1e-4) and the Dice loss.

    NOTE(review): the stem has no ReLU between the first conv's BN and the
    second conv — the PyTorch original applies ReLU after each; kept as-is
    to preserve this port's behavior.
    """
    channels_2 = [32, 64]
    channels_3 = [32, 64, 128]
    channels_4 = [32, 64, 128, 256]
    num_modules_2 = 1
    num_modules_3 = 4
    num_modules_4 = 3

    inputs = Input(input_size)
    x = conv(inputs, 64, 3, strides_=2)
    x = BatchNormalization(epsilon=1e-5, momentum=0.1)(x)
    x = conv(x, 64, 3, strides_=2)
    x = BatchNormalization(epsilon=1e-5, momentum=0.1)(x)
    x = Activation('relu')(x)

    la1 = layer1(x)
    tr1 = transition_layer([la1], [256], channels_2)
    st2 = stage(tr1, num_modules_2, channels_2)
    tr2 = transition_layer(st2, channels_2, channels_3)
    st3 = stage(tr2, num_modules_3, channels_3)
    tr3 = transition_layer(st3, channels_3, channels_4)
    st4 = stage(tr3, num_modules_4, channels_4, multi_scale_output=False)

    # Decoder: two 2x upsamplings restore the stem's 4x downsampling.
    up1 = UpSampling2D()(st4[0])
    up1 = conv(up1, 32, 3)
    up1 = BatchNormalization(epsilon=1e-5, momentum=0.1)(up1)
    up1 = Activation('relu')(up1)
    up2 = UpSampling2D()(up1)
    up2 = conv(up2, 32, 3)
    up2 = BatchNormalization(epsilon=1e-5, momentum=0.1)(up2)
    up2 = Activation('relu')(up2)
    final = conv(up2, 3, 1, padding_='valid', activation='softmax')

    model = Model(inputs=inputs, outputs=final)

    model.compile(optimizer=Adam(lr=1e-4),
                  loss=dice_loss, metrics=[dice_coef])

    return model


# --- main_train.py: training driver script ---
import numpy as np
import pickle
from data import load_data_gen
from dice_loss import dice_coef, dice_loss
from keras.callbacks import TensorBoard
from keras.models import load_model
from hrnet_keras import hrnet_keras


modelFile = "../model.hdf5"
model_weight_path = "../model_weight.pkl"
lablDir = "../Annotation-Training400/Disc_Cup_Masks/label"
dataDir = "../Training400/imgs"
batch_size = 1
train_begin = 40
train_end = 120
val_begin = 200
val_end = 230
LoadWeight = False  # set True to resume from pickled weights
train_gen = load_data_gen(batch_size, train_begin, train_end, dataDir, lablDir)
val_gen = load_data_gen(batch_size, val_begin, val_end, dataDir, lablDir)
# FIX: integer division — Keras expects integer step counts, the original
# float `/` could break fit_generator's step bookkeeping.
train_steps = (train_end - train_begin) // batch_size
val_steps = (val_end - val_begin) // batch_size
callback = [TensorBoard(log_dir='../train_logs')]


model = hrnet_keras()

if LoadWeight:
    with open(model_weight_path, 'rb') as fpkl:
        weight = pickle.load(fpkl)
        model.set_weights(weight)

model.fit_generator(train_gen, steps_per_epoch=train_steps, callbacks=callback,
                    epochs=100,
                    validation_data=val_gen, validation_steps=val_steps)


# Persist the trained weights as a pickle for test_labl.py to reload.
with open(model_weight_path, 'wb') as fpkl:
    weight = model.get_weights()
    pickle.dump(weight, fpkl, protocol=pickle.HIGHEST_PROTOCOL)


# --- test_labl.py: predict one image and save a 3-level gray mask ---
import scipy.misc as im
import numpy as np
import pickle
from data_t import load_test_img
from hrnet_keras import hrnet_keras


model_weight = "../model_weight.pkl"
imgDir = "../Training400/imgs/n0094.jpg"
save_path = "../test_labl/"


img_test = load_test_img(imgDir)
img_test = np.reshape(img_test, (1, 256, 256, 3))

model = hrnet_keras()
with open(model_weight, 'rb') as fpkl:
    weight = pickle.load(fpkl)
    model.set_weights(weight)

result_test = model.predict(img_test, batch_size=1)

# Per pixel, pick the highest-scoring class and map it to a gray level:
# channel 0 -> 255, channel 1 -> 127, channel 2 -> 0.  np.argmax breaks ties
# toward the lower channel index, exactly matching the original chain of
# `>=` comparisons, so this vectorization preserves behavior while replacing
# the 256x256 triple Python loop.
gray_levels = np.array([255, 127, 0], dtype='float32')
labl_test = gray_levels[np.argmax(result_test, axis=-1)]

im.imsave(save_path + "n0094.jpg", labl_test[0])