# Repository listing -- several source files concatenated into one dump:
# ├── README.md
# ├── dataProcess.py
# ├── seg_metrics.py
# ├── seg_unet.py
# ├── test.py
# └── train.py
#
# ------------------------------------------------------------------------------
# /README.md
# ------------------------------------------------------------------------------
# # Unet_RSimage_Multi-band_Multi-class
# keras遥感图像Unet语义分割(支持多波段&多类)
#   (Keras U-Net semantic segmentation of remote-sensing images, supporting
#   multi-band input and multiple classes.)
# 文字说明详见知乎
#   (A detailed written explanation is available on Zhihu:)
# https://zhuanlan.zhihu.com/p/161925744?utm_source=zhihu&utm_medium=social&utm_oi=832719838752444416

# ------------------------------------------------------------------------------
# /dataProcess.py
# ------------------------------------------------------------------------------
import numpy as np
import os
import random
try:
    # Current GDAL ships its Python bindings in the ``osgeo`` package; the bare
    # top-level ``gdal`` module is a legacy alias missing from recent releases.
    from osgeo import gdal
except ImportError:
    import gdal
import cv2


def color_dict(labelFolder, classNum):
    """Collect the distinct label colours found in ``labelFolder``.

    The whole folder is scanned because a single label image may not contain
    every class. Scanning stops as soon as ``classNum`` distinct colours have
    been seen, or when the folder is exhausted.

    Args:
        labelFolder: folder containing the label images.
        classNum: total number of classes, background included.

    Returns:
        (colorDict_RGB, colorDict_GRAY):
        colorDict_RGB  -- (n, 3) integer array, one sorted row per colour.
                          NOTE: despite the name, the channels are in OpenCV's
                          B, G, R order because ``cv2.imread`` returns BGR.
                          That is also the order ``cv2.imwrite`` expects.
        colorDict_GRAY -- (n, 1) uint8 array with the gray value of each colour.
                          It is used to one-hot encode the labels.

    Raises:
        ValueError: if ``labelFolder`` contains no image OpenCV can read.
    """
    packed_colors = set()
    for name in sorted(os.listdir(labelFolder)):
        # IMREAD_COLOR (OpenCV's default) always yields a 3-channel BGR image.
        # Gray label files therefore arrive with B == G == R.
        img = cv2.imread(os.path.join(labelFolder, name), cv2.IMREAD_COLOR)
        if img is None:
            # Not an image OpenCV can decode (e.g. a world file): skip it.
            continue
        img = img.astype(np.uint32)
        # Pack each B, G, R triple into one integer (BBBGGGRRR in decimal).
        # This lets np.unique find the distinct colours in a single pass.
        packed = img[:, :, 0] * 1000000 + img[:, :, 1] * 1000 + img[:, :, 2]
        packed_colors.update(np.unique(packed).tolist())
        # Stop early once every class has been seen.
        if len(packed_colors) == classNum:
            break
    if not packed_colors:
        raise ValueError("No readable label image found in %s" % labelFolder)
    colorDict_RGB = np.array(
        [[c // 1000000, (c // 1000) % 1000, c % 1000] for c in sorted(packed_colors)])
    # cvtColor needs an image-shaped input, so treat the palette as an (n, 1) image.
    colorDict_GRAY = cv2.cvtColor(
        colorDict_RGB.reshape((-1, 1, 3)).astype(np.uint8), cv2.COLOR_BGR2GRAY)
    return colorDict_RGB, colorDict_GRAY


def readTif(fileName):
    """Read all bands of a raster with GDAL.

    Returns:
        The array from ``Dataset.ReadAsArray``: (bands, rows, cols) for
        multi-band rasters, (rows, cols) for single-band ones.

    Raises:
        IOError: if GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise IOError("GDAL could not open %s" % fileName)
    width = dataset.RasterXSize
    height = dataset.RasterYSize
    GdalImg_data = dataset.ReadAsArray(0, 0, width, height)
    # Dropping the last reference lets GDAL close the underlying file.
    dataset = None
    return GdalImg_data


def _to_hwc(array):
    """Return a GDAL raster array as (rows, cols, bands).

    A single-band (rows, cols) array gets a trailing band axis of length 1.
    This gives single-band and multi-band rasters the same rank.
    """
    if array.ndim == 2:
        return array[:, :, np.newaxis]
    return np.transpose(array, (1, 2, 0))


def dataPreprocess(img, label, classNum, colorDict_GRAY):
    """Normalise images and one-hot encode gray-valued labels.

    Args:
        img: image batch. It is divided by 255, so 8-bit data is assumed.
        label: label batch holding the gray values listed in colorDict_GRAY.
        classNum: number of classes, background included.
        colorDict_GRAY: gray value of class i at position i, shape (n,) or (n, 1).

    Returns:
        (img / 255.0, one_hot). ``one_hot`` is a float64 array of shape
        ``label.shape + (classNum,)``. A pixel whose gray value is not in
        the dictionary gets an all-zero vector. ``label`` itself is not modified.
    """
    img = img / 255.0
    gray_values = np.asarray(colorDict_GRAY).reshape(-1)
    # Read from the *unmodified* label array. Remapping in place
    # (label[label == g] = i) merges classes whenever an index assigned earlier
    # equals a gray value that is processed later.
    class_index = np.full(np.shape(label), -1, dtype=np.int64)
    for index, gray in enumerate(gray_values):
        class_index[label == gray] = index
    # One layer per class: layer i is 1 where the pixel belongs to class i.
    one_hot = (class_index[..., np.newaxis] == np.arange(classNum)).astype(np.float64)
    return (img, one_hot)


def trainGenerator(batch_size, train_image_path, train_label_path, classNum, colorDict_GRAY, resize_shape = None):
    """Yield random training batches forever.

    Images and labels are paired by position after sorting both folder
    listings by name. Corresponding files must therefore sort in the same
    order, e.g. by sharing a base name. Each batch is ``batch_size``
    consecutive pairs, starting at a random index.

    Args:
        batch_size: samples per batch.
        train_image_path: folder of input rasters, read with GDAL.
        train_label_path: folder of label rasters, read with GDAL.
        classNum: number of classes, background included.
        colorDict_GRAY: gray value per class, as returned by color_dict.
        resize_shape: optional (rows, cols[, bands]). Images are resized
            bilinearly; labels use nearest-neighbour so no new gray values appear.

    Yields:
        (images, one_hot_labels), as produced by dataPreprocess.

    Raises:
        ValueError: if the folders hold different file counts, or fewer files
            than ``batch_size``.
    """
    imageList = sorted(os.listdir(train_image_path))
    labelList = sorted(os.listdir(train_label_path))
    if len(imageList) != len(labelList):
        raise ValueError("%d images but %d labels" % (len(imageList), len(labelList)))
    if len(imageList) < batch_size:
        raise ValueError("batch_size %d exceeds the %d available images" % (batch_size, len(imageList)))
    # The first image fixes the band count, and also the size when not resizing.
    first = _to_hwc(readTif(os.path.join(train_image_path, imageList[0])))
    if resize_shape is None:
        image_shape = first.shape
        dsize = None
    else:
        bands = resize_shape[2] if len(resize_shape) > 2 else first.shape[2]
        image_shape = (resize_shape[0], resize_shape[1], bands)
        # cv2.resize takes dsize as (width, height) = (cols, rows).
        dsize = (resize_shape[1], resize_shape[0])
    while True:
        # uint8 buffers: 8-bit imagery is assumed (dataPreprocess divides by 255).
        img_generator = np.zeros((batch_size, ) + tuple(image_shape), np.uint8)
        label_generator = np.zeros((batch_size, ) + tuple(image_shape[:2]), np.uint8)
        start = random.randint(0, len(imageList) - batch_size)
        for j in range(batch_size):
            img = _to_hwc(readTif(os.path.join(train_image_path, imageList[start + j])))
            if dsize is not None:
                img = cv2.resize(np.ascontiguousarray(img), dsize)
                if img.ndim == 2:
                    # cv2.resize drops a length-1 channel axis; restore it.
                    img = img[:, :, np.newaxis]
            img_generator[j] = img

            label = readTif(os.path.join(train_label_path, labelList[start + j])).astype(np.uint8)
            if label.ndim == 3:
                # Colour label. Assumes the bands are stored as R, G, B (usual
                # for RGB files) -- TODO confirm for your data. RGB2GRAY then
                # yields the gray values color_dict derived from BGR.
                label = cv2.cvtColor(np.ascontiguousarray(_to_hwc(label)), cv2.COLOR_RGB2GRAY)
            if dsize is not None:
                label = cv2.resize(label, dsize, interpolation = cv2.INTER_NEAREST)
            label_generator[j] = label
        yield dataPreprocess(img_generator, label_generator, classNum, colorDict_GRAY)


def testGenerator(test_iamge_path, resize_shape = None):
    """Yield each test image, in sorted file-name order, as a batch of one.

    Images are divided by 255 (8-bit data assumed). They are optionally
    resized to ``resize_shape`` = (rows, cols[, bands]). Each yielded array
    has shape (1, rows, cols, bands), matching the training input layout.
    """
    for name in sorted(os.listdir(test_iamge_path)):
        img = _to_hwc(readTif(os.path.join(test_iamge_path, name))) / 255.0
        if resize_shape is not None:
            # cv2.resize takes dsize as (width, height) = (cols, rows).
            img = cv2.resize(img, (resize_shape[1], resize_shape[0]))
            if img.ndim == 2:
                img = img[:, :, np.newaxis]
        yield img[np.newaxis, ...]


def saveResult(test_image_path, test_predict_path, model_predict, color_dict, output_size):
    """Render model predictions and save them as lossless PNG files.

    Args:
        test_image_path: folder of test images. Its sorted listing supplies
            the output names, in the same order testGenerator used.
        test_predict_path: output folder.
        model_predict: iterable of per-pixel class scores (last axis = class).
        color_dict: palette indexed by class id, e.g. colorDict_RGB
            (BGR order, as cv2.imwrite expects) or colorDict_GRAY.
        output_size: passed to cv2.resize as dsize, i.e. (width, height).

    Raises:
        IOError: if OpenCV fails to write an output file.
    """
    imageList = sorted(os.listdir(test_image_path))
    palette = np.asarray(color_dict)
    for name, scores in zip(imageList, model_predict):
        class_map = np.argmax(scores, axis = -1)
        img_out = np.uint8(palette[class_map])
        # Nearest neighbour keeps every output pixel an exact palette colour.
        img_out = cv2.resize(img_out, (output_size[0], output_size[1]), interpolation = cv2.INTER_NEAREST)
        out_path = os.path.join(test_predict_path, os.path.splitext(name)[0] + ".png")
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(out_path, img_out):
            raise IOError("Could not write %s" % out_path)


# ------------------------------------------------------------------------------
# /seg_metrics.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Jul 21 15:29:02 2020

@author: 12624
"""

import numpy as np
import cv2
import os

# Confusion-matrix layout produced by ConfusionMatrix() and assumed below:
#   cm[i, j] = number of pixels whose LABEL is class i and PREDICTION is class j
#   (rows = ground truth, columns = prediction). For class c:
#   TP = cm[c, c], FN = row c sum - TP, FP = column c sum - TP.


def color_dict(labelFolder, classNum):
    """Collect the distinct label colours found in ``labelFolder``.

    The whole folder is scanned because a single label image may not contain
    every class. Scanning stops once ``classNum`` distinct colours are found.

    Returns:
        (colorDict_BGR, colorDict_GRAY): an (n, 3) B, G, R array (sorted) and
        an (n, 1) uint8 array of the matching gray values.

    Raises:
        ValueError: if ``labelFolder`` contains no image OpenCV can read.
    """
    packed_colors = set()
    for name in sorted(os.listdir(labelFolder)):
        # IMREAD_COLOR always yields 3 channels in B, G, R order.
        img = cv2.imread(os.path.join(labelFolder, name), cv2.IMREAD_COLOR)
        if img is None:
            continue
        img = img.astype(np.uint32)
        # Pack B, G, R into one integer (BBBGGGRRR) so np.unique finds distinct colours.
        packed = img[:, :, 0] * 1000000 + img[:, :, 1] * 1000 + img[:, :, 2]
        packed_colors.update(np.unique(packed).tolist())
        if len(packed_colors) == classNum:
            break
    if not packed_colors:
        raise ValueError("No readable label image found in %s" % labelFolder)
    colorDict_BGR = np.array(
        [[c // 1000000, (c // 1000) % 1000, c % 1000] for c in sorted(packed_colors)])
    colorDict_GRAY = cv2.cvtColor(
        colorDict_BGR.reshape((-1, 1, 3)).astype(np.uint8), cv2.COLOR_BGR2GRAY)
    return colorDict_BGR, colorDict_GRAY


def _safe_divide(numerator, denominator):
    """Divide element-wise; zero denominators give nan/inf without a RuntimeWarning."""
    with np.errstate(divide = 'ignore', invalid = 'ignore'):
        return np.true_divide(numerator, denominator)


def ConfusionMatrix(numClass, imgPredict, Label):
    """Return the (numClass, numClass) confusion matrix, rows = label, cols = prediction.

    Pixels whose label or prediction lies outside [0, numClass) are ignored.
    Both inputs are cast to int64 first, so ``numClass * label`` cannot
    overflow for small integer dtypes such as uint8.
    """
    label = np.asarray(Label).astype(np.int64).ravel()
    predict = np.asarray(imgPredict).astype(np.int64).ravel()
    mask = (label >= 0) & (label < numClass) & (predict >= 0) & (predict < numClass)
    count = np.bincount(numClass * label[mask] + predict[mask], minlength = numClass ** 2)
    return count.reshape(numClass, numClass)


def OverallAccuracy(confusionMatrix):
    """Overall accuracy OA = (TP + TN) / (TP + TN + FP + FN) = trace / total."""
    return _safe_divide(np.diag(confusionMatrix).sum(), confusionMatrix.sum())


def Precision(confusionMatrix):
    """Per-class precision TP / (TP + FP): diagonal over column (predicted) sums."""
    return _safe_divide(np.diag(confusionMatrix), confusionMatrix.sum(axis = 0))


def Recall(confusionMatrix):
    """Per-class recall TP / (TP + FN): diagonal over row (ground-truth) sums."""
    return _safe_divide(np.diag(confusionMatrix), confusionMatrix.sum(axis = 1))


def F1Score(confusionMatrix):
    """Per-class F1 = 2PR / (P + R); nan where both are zero or undefined."""
    precision = Precision(confusionMatrix)
    recall = Recall(confusionMatrix)
    return _safe_divide(2 * precision * recall, precision + recall)


def IntersectionOverUnion(confusionMatrix):
    """Per-class IoU = TP / (TP + FP + FN); nan for classes absent from both label and prediction."""
    intersection = np.diag(confusionMatrix)
    union = np.sum(confusionMatrix, axis = 1) + np.sum(confusionMatrix, axis = 0) - intersection
    return _safe_divide(intersection, union)


def MeanIntersectionOverUnion(confusionMatrix):
    """Mean IoU over the classes whose IoU is defined (nan entries skipped)."""
    return np.nanmean(IntersectionOverUnion(confusionMatrix))


def Frequency_Weighted_Intersection_over_Union(confusionMatrix):
    """FWIoU: per-class IoU weighted by each class's ground-truth pixel frequency."""
    freq = _safe_divide(np.sum(confusionMatrix, axis = 1), np.sum(confusionMatrix))
    iu = IntersectionOverUnion(confusionMatrix)
    present = freq > 0
    return (freq[present] * iu[present]).sum()


def _read_gray(path):
    """Read an image with OpenCV (BGR) and convert it to one gray channel."""
    image = cv2.imread(path)
    if image is None:
        raise IOError("Could not read image %s" % path)
    return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


#################################################################
# Folder with the ground-truth label images
LabelPath = os.path.join("Data", "test", "label1")
# Folder with the predicted images
PredictPath = os.path.join("Data", "test", "predict1")
# Number of classes (background included)
classNum = 3
#################################################################

# Class colour dictionary
colorDict_BGR, colorDict_GRAY = color_dict(LabelPath, classNum)

# Labels and predictions are paired by position after sorting by file name.
labelList = sorted(os.listdir(LabelPath))
PredictList = sorted(os.listdir(PredictPath))
if len(PredictList) != len(labelList):
    raise ValueError("%d label images but %d predicted images" % (len(labelList), len(PredictList)))

# Number of images
label_num = len(labelList)

# The first label fixes the image shape used to allocate the stacks.
Label0 = _read_gray(os.path.join(LabelPath, labelList[0]))

# Stack every label / prediction (as gray values) into one array each.
label_all = np.zeros((label_num, ) + Label0.shape, np.uint8)
predict_all = np.zeros((label_num, ) + Label0.shape, np.uint8)
for i, (label_name, predict_name) in enumerate(zip(labelList, PredictList)):
    label_all[i] = _read_gray(os.path.join(LabelPath, label_name))
    predict_all[i] = _read_gray(os.path.join(PredictPath, predict_name))

# Map colours to class indices 0, 1, 2, 3...
# Map every dictionary gray value to its class index. The new arrays are filled
# from the *unmodified* gray stacks. Remapping in place (x[x == gray] = i) merges
# classes whenever an index assigned earlier equals a gray value handled later.
# Gray values not in the dictionary become 255.
gray_values = np.asarray(colorDict_GRAY).reshape(-1)
label_index = np.full(label_all.shape, 255, np.uint8)
predict_index = np.full(predict_all.shape, 255, np.uint8)
for class_id, gray in enumerate(gray_values):
    label_index[label_all == gray] = class_id
    predict_index[predict_all == gray] = class_id

# Flatten to 1-D. Use int64 so ``numClass * label`` cannot overflow in ConfusionMatrix.
label_all = label_index.flatten().astype(np.int64)
predict_all = predict_index.flatten().astype(np.int64)

# Pixels whose label or prediction colour is not a known class cannot be
# placed in a classNum x classNum confusion matrix. Drop them, but say so.
valid = (label_all < classNum) & (predict_all < classNum)
ignored = int(valid.size - np.count_nonzero(valid))
if ignored:
    print("Warning: %d pixels have a colour outside the class dictionary and were ignored" % ignored)
label_all = label_all[valid]
predict_all = predict_all[valid]

# Confusion matrix and the derived accuracy measures
confusionMatrix = ConfusionMatrix(classNum, predict_all, label_all)
precision = Precision(confusionMatrix)
recall = Recall(confusionMatrix)
OA = OverallAccuracy(confusionMatrix)
IoU = IntersectionOverUnion(confusionMatrix)
FWIOU = Frequency_Weighted_Intersection_over_Union(confusionMatrix)
mIOU = MeanIntersectionOverUnion(confusionMatrix)
f1score = F1Score(confusionMatrix)

# Print a header naming each class. If the optional webcolors package is
# installed (pip install webcolors) and knows the colour, print its name;
# otherwise print the class's gray value.
try:
    import webcolors
except ImportError:
    webcolors = None
for bgr, gray in zip(colorDict_BGR, gray_values):
    color_name = None
    if webcolors is not None:
        try:
            # Palette rows are B, G, R; webcolors expects R, G, B.
            color_name = webcolors.rgb_to_name(tuple(int(v) for v in bgr[::-1]))
        except ValueError:
            # webcolors raises ValueError for colours that have no name.
            color_name = None
    print(color_name if color_name is not None else gray, end = " ")
print("")
print("混淆矩阵:")
print(confusionMatrix)
print("精确度:")
print(precision)
print("召回率:")
print(recall)
print("F1-Score:")
print(f1score)
print("整体精度:")
print(OA)
print("IoU:")
print(IoU)
print("mIoU:")
print(mIOU)
print("FWIoU:")
print(FWIOU)

# ------------------------------------------------------------------------------
# /seg_unet.py
# ------------------------------------------------------------------------------
from keras.models import Model
# The Keras 1 ``merge`` function is no longer imported. Recent Keras / tf.keras
# releases do not export it, which made this module fail to import.
# ``concatenate`` covers every use here.
from keras.layers import Input, BatchNormalization, Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D
from keras.optimizers import Adam


def _conv_bn(x, filters):
    """3x3 Conv2D (ReLU, 'same' padding, he_normal init) followed by BatchNormalization."""
    return BatchNormalization()(Conv2D(filters, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(x))


def _double_conv(x, filters):
    """Two stacked _conv_bn layers: one U-Net stage."""
    return _conv_bn(_conv_bn(x, filters), filters)


def _up_merge(x, skip, filters):
    """Upsample ``x`` 2x and apply a 2x2 Conv2D, then concatenate ``skip`` first on the channel axis."""
    # Upsampling followed by a convolution plays the role of a transposed convolution.
    up = Conv2D(filters, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2, 2))(x))
    return concatenate([skip, up], axis = 3)


def unet(pretrained_weights = None, input_size = (256, 256, 4), classNum = 2, learning_rate = 1e-5):
    """Build and compile a U-Net for multi-class segmentation.

    Args:
        pretrained_weights: optional weight file, loaded after compiling.
        input_size: (rows, cols, bands), channels-last (concatenation is on
            axis 3). Rows and cols must be divisible by 16 (four 2x2 poolings)
            so that the skip connections line up.
        classNum: number of output classes (softmax channels).
        learning_rate: initial Adam learning rate.

    Returns:
        The compiled Keras model (categorical cross-entropy, accuracy metric).
    """
    inputs = Input(input_size)
    # Encoder
    conv1 = _double_conv(inputs, 64)
    pool1 = MaxPooling2D(pool_size = (2, 2))(conv1)
    conv2 = _double_conv(pool1, 128)
    pool2 = MaxPooling2D(pool_size = (2, 2))(conv2)
    conv3 = _double_conv(pool2, 256)
    pool3 = MaxPooling2D(pool_size = (2, 2))(conv3)
    conv4 = _double_conv(pool3, 512)
    # Dropout regularisation against over-fitting
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size = (2, 2))(drop4)
    # Bottleneck
    conv5 = _double_conv(pool4, 1024)
    drop5 = Dropout(0.5)(conv5)
    # Decoder with skip connections
    conv6 = _double_conv(_up_merge(drop5, drop4, 512), 512)
    conv7 = _double_conv(_up_merge(conv6, conv3, 256), 256)
    conv8 = _double_conv(_up_merge(conv7, conv2, 128), 128)
    conv9 = _double_conv(_up_merge(conv8, conv1, 64), 64)
    # This layer used to have a hard-coded 2 filters (left over from a binary
    # template), squeezing every class through 2 ReLU features. It now has
    # classNum filters; for classNum == 2 nothing changes. NOTE: weights saved
    # by the old code with classNum > 2 will not load into this model.
    conv9 = Conv2D(classNum, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
    conv10 = Conv2D(classNum, 1, activation = 'softmax')(conv9)

    model = Model(inputs = inputs, outputs = conv10)

    # Keras >= 2.3 uses ``learning_rate``. Older releases only accept ``lr``
    # and reject unknown keyword arguments with a TypeError.
    try:
        optimizer = Adam(learning_rate = learning_rate)
    except TypeError:
        optimizer = Adam(lr = learning_rate)
    # Configure training: optimiser, loss and evaluation metric
    model.compile(optimizer = optimizer, loss = 'categorical_crossentropy', metrics = ['accuracy'])

    # Load pre-trained weights if provided
    if pretrained_weights:
        model.load_weights(pretrained_weights)

    return model

# ------------------------------------------------------------------------------
# /test.py
# ------------------------------------------------------------------------------
# NOTE(review): Model/seg_hrnet.py is not part of this listing. train.py trains
# seg_unet.unet and saves Model/unet_model.hdf5 -- confirm the intended model.
from Model.seg_hrnet import seg_hrnet
from dataProcess import testGenerator, saveResult, color_dict
import os

# Trained model file
model_path = os.path.join("Model", "hrnet_model.hdf5")
# Test image folder
test_iamge_path = os.path.join("Data", "test", "image")
# Output folder for the predictions
save_path = os.path.join("Data", "test", "predict")
# Number of test images
test_num = len(os.listdir(test_iamge_path))
# Number of classes (background included)
classNum = 2
# Model input size (rows, cols, bands)
input_size = (512, 512, 3)
# Size of the saved predictions (cv2 dsize: width, height)
output_size = (492, 492)
# Training label folder, used to rebuild the colour dictionary
train_label_path = os.path.join("Data", "train", "label")
# Label colour dictionary
colorDict_RGB, colorDict_GRAY = color_dict(train_label_path, classNum)

model = seg_hrnet(model_path)

testGene = testGenerator(test_iamge_path, input_size)

# Predicted class scores as a NumPy array
results = model.predict_generator(testGene,
                                  test_num,
                                  verbose = 1)

# cv2.imwrite fails silently into a missing folder, so create it first.
os.makedirs(save_path, exist_ok = True)
# Save the results
saveResult(test_iamge_path, save_path, results, colorDict_GRAY, output_size)

# ------------------------------------------------------------------------------
# /train.py
# ------------------------------------------------------------------------------
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from seg_unet import unet
#from Model.seg_hrnet import seg_hrnet
from dataProcess import trainGenerator, color_dict
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
import datetime
import math
import xlwt

# ---- Dataset parameters ----
# Training images
train_image_path = os.path.join("Data", "train", "image")
# Training labels
train_label_path = os.path.join("Data", "train", "label")
# Validation images
validation_image_path = os.path.join("Data", "validation", "image")
# Validation labels
validation_label_path = os.path.join("Data", "validation", "label")

# ---- Model parameters ----
# Batch size
batch_size = 2
# Number of classes (background included)
classNum = 2
# Model input size (rows, cols, bands)
input_size = (512, 512, 3)
# Maximum number of training epochs
epochs = 50
# Initial learning rate
learning_rate = 1e-4
# Pre-trained weights to start from (None = train from scratch)
premodel_path = None
# Where the best model is saved
model_path = os.path.join("Model", "unet_model.hdf5")

# Number of training / validation images
train_num = len(os.listdir(train_image_path))
validation_num = len(os.listdir(validation_image_path))
# Batches per epoch. Keras documents these as integers, so round a partial
# final batch up to a whole step.
steps_per_epoch = max(1, math.ceil(train_num / batch_size))
validation_steps = max(1, math.ceil(validation_num / batch_size))
# Label colour dictionary, used for one-hot encoding
colorDict_RGB, colorDict_GRAY = color_dict(train_label_path, classNum)

# Generators that produce batch_size samples per step
train_Generator = trainGenerator(batch_size,
                                 train_image_path,
                                 train_label_path,
                                 classNum,
                                 colorDict_GRAY,
                                 input_size)
validation_data = trainGenerator(batch_size,
                                 validation_image_path,
                                 validation_label_path,
                                 classNum,
                                 colorDict_GRAY,
                                 input_size)
# Model definition
model = unet(pretrained_weights = premodel_path,
             input_size = input_size,
             classNum = classNum,
             learning_rate = learning_rate)
#model = seg_hrnet(pretrained_weights = premodel_path,
#                  input_size = input_size,
#                  classNum = classNum,
#                  learning_rate = learning_rate)
# Print the model structure
model.summary()

# ---- Callbacks ----
# Stop when val_loss has not improved for 10 epochs
early_stopping = EarlyStopping(monitor = 'val_loss', patience = 10)
# Halve the learning rate when val_loss has not improved for 3 epochs
reduce_lr = ReduceLROnPlateau(monitor = 'val_loss', factor = 0.5, patience = 3, verbose = 1)
# Save whenever the *training* loss improves (verbose: 0 silent, 1 progress bar, 2 one line per epoch)
model_checkpoint = ModelCheckpoint(model_path,
                                   monitor = 'loss',
                                   verbose = 1,
                                   save_best_only = True)
# ModelCheckpoint does not create missing folders.
model_dir = os.path.dirname(model_path)
if model_dir:
    os.makedirs(model_dir, exist_ok = True)

start_time = datetime.datetime.now()

# Training. reduce_lr was previously never registered, while model_checkpoint
# was registered twice.
history = model.fit_generator(train_Generator,
                              steps_per_epoch = steps_per_epoch,
                              epochs = epochs,
                              callbacks = [early_stopping, model_checkpoint, reduce_lr],
                              validation_data = validation_data,
                              validation_steps = validation_steps)

# Total training time. total_seconds() also counts whole days, which .seconds drops.
end_time = datetime.datetime.now()
log_time = "训练总时间: " + str((end_time - start_time).total_seconds() / 60) + "m"
timestamp = datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d-%H%M%S')
print(log_time)
with open('TrainTime_%s.txt' % timestamp, 'w', encoding = 'utf-8') as f:
    f.write(log_time)

# Save and plot loss / accuracy. Newer Keras logs 'accuracy', older Keras logs 'acc'.
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
acc = history.history[acc_key]
val_acc = history.history['val_' + acc_key]
loss = history.history['loss']
val_loss = history.history['val_loss']
book = xlwt.Workbook(encoding = 'utf-8', style_compression = 0)
sheet = book.add_sheet('test', cell_overwrite_ok = True)
for row, values in enumerate(zip(acc, val_acc, loss, val_loss)):
    for col, value in enumerate(values):
        sheet.write(row, col, value)
book.save(r'AccAndLoss_%s.xls' % timestamp)
epoch_range = range(1, len(acc) + 1)
plt.plot(epoch_range, acc, 'r', label = 'Training acc')
plt.plot(epoch_range, val_acc, 'b', label = 'Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.savefig("accuracy_%s.png" % timestamp, dpi = 300)
plt.figure()
plt.plot(epoch_range, loss, 'r', label = 'Training loss')
plt.plot(epoch_range, val_loss, 'b', label = 'Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.savefig("loss_%s.png" % timestamp, dpi = 300)
plt.show()