├── .idea ├── dictionaries │ ├── omnisky.xml │ └── root.xml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── semantic_segment_RSImage.iml └── vcs.xml ├── README.md ├── batch_predict ├── __init__.py ├── batch_base_predict_functions.py ├── batch_predict_binary_jaccard.py ├── batch_predict_binary_notonehot.py ├── batch_predict_binary_onehot.py ├── batch_predict_binary_onlyjaccard.py ├── batch_predict_multiclass.py └── batch_smooth_tiled_predictions.py ├── model_build ├── config.json ├── config.py ├── config_WHU_buildings.json ├── config_multiclass_global.json ├── config_multiclass_manmade.json ├── config_scrs_buildings.json ├── config_tuitiantu.json ├── mlp_model.json ├── model.json ├── test_1.py ├── train_main.py └── utils.py ├── model_predict ├── config_pred.json ├── config_pred.py ├── config_pred_WHU.json ├── config_pred_bieshu.json ├── config_pred_multiclass.json ├── config_pred_multiclass_manmade.json ├── config_pred_scrsbuilings.json ├── config_pred_tuitiantu.json ├── predict_backbone.py ├── predict_main.py ├── predict_main_blocks.py └── test_predict_1.py ├── postprocess ├── __init__.py ├── acc_evaluate.py ├── combine_diffclass_for_singlemodel_result.py ├── mismatch_analyze.py ├── raster_to_vector.py └── vote.py ├── predict ├── __init__.py ├── base_predict_functions.py ├── predict_binary_jaccard.py ├── predict_binary_notonehot.py ├── predict_binary_onehot.py ├── predict_binary_onlyjaccard.py ├── predict_multiclass.py └── smooth_tiled_predictions.py ├── samples_produce ├── __init__.py ├── check_original_labels_froNodata.py ├── label_visulise.py ├── sample_produce_for_singleimage.py ├── traindata_generate_byCV.py ├── traindata_generate_bygdal.py └── traindata_generate_common.py ├── segmentation_models ├── __init__.py ├── __version__.py ├── backbones │ ├── __init__.py │ ├── inception_resnet_v2.py │ ├── inception_v3.py │ ├── mobilenet.py │ └── mobilenetv2.py ├── common │ ├── __init__.py │ ├── blocks.py │ ├── functions.py │ └── layers.py ├── fpn │ ├── __init__.py │ ├── blocks.py │ ├── builder.py │ └── model.py ├── linknet │ ├── __init__.py │ ├── blocks.py │ ├── builder.py │ └── model.py ├── losses.py ├── metrics.py ├── pspnet │ ├── __init__.py │ ├── blocks.py │ ├── builder.py │ └── model.py ├── unet │ ├── __init__.py │ ├── blocks.py │ ├── builder.py │ └── model.py └── utils.py ├── temp ├── __init__.py ├── all_predict.py ├── band4_image.py ├── change_geotransform.py ├── change_label_zym.py ├── compose_labels.py ├── fcn8_train_binary.py ├── main.py ├── modify_segnet_train_labels.py ├── multibans_saveas_RGB.py ├── predict_from_xuhuimin.py ├── segnet_predict.py ├── segnet_train_binary.py ├── segnet_train_multiclass.py ├── test_cv2read.py ├── test_for_jaccrad_predict.py ├── test_unet_multiclass_predict.py ├── train_binary_4orMorebands.py ├── train_binary_jaccard_4orMorebands.py ├── train_multiclass_jaccard.py ├── unet_predict.py ├── unet_train_binary.py ├── unet_train_multiclass.py └── unet_train_qiwenchao.py ├── train ├── Unet_resnet.py ├── Unet_resnet_test.py ├── __init__.py ├── semantic_segmentation_networks.py ├── train_binary_jaccard.py ├── train_binary_jaccard_2.py ├── train_binary_notOneHot.py ├── train_binary_onehot.py ├── train_binary_onlyjaccard.py ├── train_binary_onlyjaccard_2.py └── train_multiclass.py ├── ui ├── MainWin.py ├── MainWin.ui ├── about.py ├── about.ui ├── classifyUi │ ├── PredictBackend.py │ ├── PredictBinaryBatch.py │ ├── PredictBinaryBatch.ui │ ├── PredictBinaryForSingleimage.py │ ├── PredictBinaryForSingleimage.ui │ ├── 
PredictMulticlassBatch.py │ ├── PredictMulticlassBatch.ui │ ├── PredictMulticlassForSingleimage.py │ ├── PredictMulticlassForSingleimage.ui │ └── predict_implements.py ├── else │ ├── manual.docx │ ├── scrslogo.png │ └── 中文标签.docx ├── main_gui.py ├── mysrc.qrc ├── mysrc_rc.py ├── postProcess │ ├── AccuracyEvaluate.py │ ├── AccuracyEvaluate.ui │ ├── Binarization.py │ ├── Binarization.ui │ ├── CombineMulticlassFromSingleModelResults.py │ ├── CombineMulticlassFromSingleModelResults.ui │ ├── PostPrecessBackend.py │ ├── VoteMultimodleResults.py │ ├── VoteMultimodleResults.ui │ └── postProcess_implements.py ├── preProcess │ ├── ImageClip.py │ ├── ImageClip.ui │ ├── ImageStretch.py │ ├── ImageStretch.ui │ ├── hist_process.py │ ├── label_check.py │ ├── label_check.ui │ ├── preprocess_backend.py │ └── preprocess_implements.py ├── sampleProduce │ ├── SampleGenCommon.py │ ├── SampleGenCommon.ui │ ├── SampleGenSelfAdapt.py │ ├── SampleGenSelfAdapt.ui │ ├── sampleProcess_backend.py │ └── sampleProcess_implements.py ├── tmp │ ├── new_train_backend.py │ └── new_train_implements.py └── trainUi │ ├── TrainBinaryCommon.py │ ├── TrainBinaryCommon.ui │ ├── TrainBinaryCrossentropy.py │ ├── TrainBinaryCrossentropy.ui │ ├── TrainBinaryJaccard.py │ ├── TrainBinaryJaccard.ui │ ├── TrainBinaryJaccardCrossentropy.py │ ├── TrainBinaryJaccardCrossentropy.ui │ ├── TrainBinaryOnehot.py │ ├── TrainBinaryOnehot.ui │ ├── TrainMulticlass.py │ ├── TrainMulticlass.ui │ ├── modelTrainBackend.py │ ├── trainModels_implements.py │ └── untitled.ui ├── ulitities ├── __init__.py ├── band_compose.py ├── base_functions.py ├── base_predict_functions.py ├── ecogToPredict.py ├── image_clip.py ├── image_stretch.py ├── resample_image.py ├── smooth_tiled_predictions.py ├── test_augument.py └── xml_prec.py └── venv ├── bin ├── activate ├── activate.csh ├── activate.fish ├── python └── python3 ├── lib64 └── pyvenv.cfg /.idea/dictionaries/omnisky.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | labelencoder 5 | multiclass 6 | unet 7 | 8 | 9 | -------------------------------------------------------------------------------- /.idea/dictionaries/root.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | dataset 5 | imagenet 6 | jaccard 7 | notonehot 8 | originaldata 9 | segnet 10 | 11 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 16 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /.idea/semantic_segment_RSImage.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 34 | -------------------------------------------------------------------------------- /.idea/vcs.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | We use several semantic segmentation models to handle remote sensing image classification. 2 | 3 | Semantic segmentation networks: U-Net, PSPNet, FPN, LinkNet, DeepLab V3+ 4 | 5 | Backbones: 6 | VGG:'vgg16' 'vgg19' 7 | 8 | ResNet:'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' 9 | 10 | SE-ResNet:'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152' 11 | 12 | ResNeXt:'resnext50' 'resnext101' 13 | 14 | SE-ResNeXt:'seresnext50' 'seresnext101' 15 | 16 | SENet154:'senet154' 17 | 18 | DenseNet:'densenet121' 'densenet169' 'densenet201' 19 | 20 | Inception:'inceptionv3' 'inceptionresnetv2' 21 | 22 | MobileNet:'mobilenet' 'mobilenetv2' 23 | 24 | EfficientNet:'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 25 | 26 | Requirements: 27 | 28 | tensorflow-gpu==1.9.0 (only tested on v1.9.0) 29 | keras>=2.2.4 30 | keras_applications==1.0.7 31 | image-classifiers==0.2.0 32 | efficientnet>=0.0.3 33 | cuda==9.0 34 | qt==5.6 35 | numpy==1.12.0 36 | scipy==0.19.1 37 | tqdm==4.11.2 38 | 39 | 40 | -------------------------------------------------------------------------------- /batch_predict/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/batch_predict/__init__.py -------------------------------------------------------------------------------- /batch_predict/batch_predict_binary_jaccard.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation. 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | from ulitities.base_functions import get_file 27 | """ 28 | The following global variables should be moved into a metadata file. 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 # 224, 256, 288, 320 36 | # step = 128 37 | 38 | im_bands =3 39 | im_type = UINT8 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | input_path = '/home/omnisky/PycharmProjects/data/test/APtest/images/' 49 | output_path = 
''.join(['/home/omnisky/PycharmProjects/data/test/APtest/pred_', str(window_size)]) 50 | 51 | 52 | # model_file = ''.join(['../../data/models/sat_urban_rgb/',dict_network[FLAG_USING_NETWORK], '_', 53 | # dict_target[FLAG_TARGET_CLASS],'_binary_jaccard_', str(window_size), '_final.h5']) 54 | model_file ='/home/omnisky/PycharmProjects/data/models/APsamples/unet_buildings_binary_jaccard_256_final.h5' 55 | 56 | 57 | print("model: {}".format(model_file)) 58 | 59 | def predict_binary_jaccard(img_file, output_file): 60 | 61 | print("[INFO] opening image...") 62 | input_img = load_img_by_gdal(img_file) 63 | if im_type == UINT8: 64 | input_img = input_img / 255.0 65 | elif im_type == UINT10: 66 | input_img = input_img / 1024.0 67 | elif im_type == UINT16: 68 | input_img = input_img / 65535.0 69 | 70 | input_img = np.clip(input_img, 0.0, 1.0) 71 | input_img = input_img.astype(np.float16) 72 | 73 | model = load_model(model_file) 74 | 75 | if FLAG_APPROACH_PREDICT==0: 76 | print("[INFO] predict image by original approach\n") 77 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 78 | abs_filename = os.path.split(img_file)[1] 79 | abs_filename = abs_filename.split(".")[0] 80 | # output_file = ''.join([output_path, '/original_pred_', 81 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png']) 82 | output_file = ''.join([output_path, '/mask_binary_', 83 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_original.png']) 84 | print("result saved to: {}".format(output_file)) 85 | cv2.imwrite(output_file, result*128) 86 | 87 | elif FLAG_APPROACH_PREDICT==1: 88 | print("[INFO] predict image by smooth approach\n") 89 | result = predict_img_with_smooth_windowing_multiclassbands( 90 | input_img, 91 | model, 92 | window_size=window_size, 93 | subdivisions=2, 94 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 95 | pred_func=smooth_predict_for_binary_notonehot, 96 | PLOT_PROGRESS=False 97 | ) 98 | 99 | cv2.imwrite(output_file, result) 100 | print("Saved to: {}".format(output_file)) 101 | 102 | gc.collect() 103 | 104 | 105 | if __name__ == '__main__': 106 | 107 | all_files, num = get_file(input_path) 108 | if num == 0: 109 | print("There is no file in path: {}".format(input_path)) 110 | sys.exit(-1) 111 | 112 | """check model file""" 113 | print("model file: {}".format(model_file)) 114 | if not os.path.isfile(model_file): 115 | print("model does not exist: {}".format(model_file)) 116 | sys.exit(-2) 117 | 118 | for in_file in all_files: 119 | abs_filename = os.path.split(in_file)[1] 120 | abs_filename = abs_filename.split(".")[0] 121 | print(abs_filename) 122 | out_file = ''.join([output_path, '/mask_binary_', 123 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard.png']) 124 | predict_binary_jaccard(in_file, out_file) -------------------------------------------------------------------------------- /batch_predict/batch_predict_binary_notonehot.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation. 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf')
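# note: 'tf' dim ordering means channels-last arrays (height, width, bands); clear_session() below resets any leftover Keras/TensorFlow graph state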
20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | from ulitities.base_functions import get_file 27 | """ 28 | The following global variables should be moved into a metadata file. 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 36 | # step = 128 37 | 38 | im_bands =4 39 | im_type = UINT10 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | input_path = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/images//' 49 | output_path = ''.join(['/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/pred/pred_', str(window_size)]) 50 | 51 | model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_', 52 | dict_target[FLAG_TARGET_CLASS],'_binary_notonehot_final.h5']) 53 | 54 | print("model: {}".format(model_file)) 55 | 56 | def predict_binary_notonehot(img_file, output_file): 57 | 58 | print("[INFO] opening image...") 59 | 60 | input_img = load_img_by_gdal(img_file) 61 | if im_type == UINT8: 62 | input_img = input_img / 255.0 63 | elif im_type == UINT10: 64 | input_img = input_img / 1024.0 65 | elif im_type == UINT16: 66 | input_img = input_img / 65535.0 67 | 68 | input_img = np.clip(input_img, 0.0, 1.0) 69 | input_img = input_img.astype(np.float16) 70 | 71 | """check model file""" 72 | print("model file: {}".format(model_file)) 73 | if not os.path.isfile(model_file): 74 | print("model does not exist: {}".format(model_file)) 75 | sys.exit(-2) 76 | 77 | model = load_model(model_file) 78 | 79 | if FLAG_APPROACH_PREDICT==0: 80 | print("[INFO] predict image by original approach\n") 81 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 82 | abs_filename = os.path.split(img_file)[1] 83 | abs_filename = abs_filename.split(".")[0] 84 | output_file = ''.join([output_path, '/original_pred_', 85 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard.png']) 86 | print("result saved to: {}".format(output_file)) 87 | cv2.imwrite(output_file, result*128) 88 | 89 | elif FLAG_APPROACH_PREDICT==1: 90 | print("[INFO] predict image by smooth approach\n") 91 | result = predict_img_with_smooth_windowing_multiclassbands( 92 | input_img, 93 | model, 94 | window_size=window_size, 95 | subdivisions=2, 96 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 97 | pred_func=smooth_predict_for_binary_notonehot, 98 | PLOT_PROGRESS=False 99 | ) 100 | 101 | cv2.imwrite(output_file, result) 102 | print("Saved to: {}".format(output_file)) 103 | 104 | gc.collect() 105 | 106 | 107 | if __name__ == '__main__': 108 | 109 | all_files, num = get_file(input_path) 110 | if num == 0: 111 | print("There is no file in path: {}".format(input_path)) 112 | sys.exit(-1) 113 | 114 | """check model file""" 115 | print("model file: {}".format(model_file)) 116 | if not os.path.isfile(model_file): 117 | print("model does not exist: {}".format(model_file)) 118 | 
sys.exit(-2) 119 | 120 | for in_file in all_files: 121 | abs_filename = os.path.split(in_file)[1] 122 | abs_filename = abs_filename.split(".")[0] 123 | print(abs_filename) 124 | out_file = ''.join([output_path, '/mask_binary_', 125 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png']) 126 | predict_binary_notonehot(in_file, out_file) 127 | 128 | 129 | -------------------------------------------------------------------------------- /batch_predict/batch_predict_binary_onehot.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation. 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_binary_onehot 23 | from ulitities.base_functions import load_img_normalization, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | from ulitities.base_functions import get_file 27 | """ 28 | The following global variables should be moved into a metadata file. 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 36 | # step = 128 37 | im_bands = 3 38 | im_type = UINT8 39 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 40 | dict_target={0: 'roads', 1: 'buildings'} 41 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 42 | 43 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings 44 | 45 | FLAG_APPROACH_PREDICT = 0 # 0: original predict, 1: smooth predict 46 | 47 | input_path = '../../data/test/paper/images/' 48 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)]) 49 | 50 | # model_file = ''.join(['../../data/models/sat_urban_nrg/',dict_network[FLAG_USING_NETWORK], '_', dict_target[FLAG_TARGET_CLASS],'_binary.h5']) 51 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_buildings_binary2_onehot.h5' 52 | model_file='/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onehot.h5' 53 | print("model: {}".format(model_file)) 54 | 55 | def predict_binary_onehot(img_file, output_file): 56 | 57 | print("[INFO] opening image...") 58 | input_img = load_img_by_gdal(img_file) 59 | if im_type == UINT8: 60 | input_img = input_img / 255.0 61 | elif im_type == UINT10: 62 | input_img = input_img / 1024.0 63 | elif im_type == UINT16: 64 | input_img = input_img / 65535.0 65 | input_img = np.clip(input_img, 0.0, 1.0) 66 | 67 | """check model file""" 68 | print("model file: {}".format(model_file)) 69 | if not os.path.isfile(model_file): 70 | print("model does not exist: {}".format(model_file)) 71 | sys.exit(-2) 72 | 73 | model = load_model(model_file) 74 | 75 | if FLAG_APPROACH_PREDICT==0: 76 | print("[INFO] predict image by original approach\n") 77 | result = orignal_predict_onehot(input_img, im_bands, model, window_size) 78 | abs_filename = os.path.split(img_file)[1] 79 | abs_filename = abs_filename.split(".")[0] 80 | 
print(abs_filename) 81 | output_file = ''.join(['../../data/predict/original_predict_',abs_filename, '_onehot.png']) 82 | print("result saved to: {}".format(output_file)) 83 | cv2.imwrite(output_file, result*100) 84 | 85 | elif FLAG_APPROACH_PREDICT==1: 86 | print("[INFO] predict image by smooth approach\n") 87 | result = predict_img_with_smooth_windowing_multiclassbands( 88 | input_img, 89 | model, 90 | window_size=window_size, 91 | subdivisions=2, 92 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 93 | pred_func=smooth_predict_for_binary_onehot, 94 | PLOT_PROGRESS=False 95 | ) 96 | 97 | cv2.imwrite(output_file, result) 98 | print("Saved to: {}".format(output_file)) 99 | 100 | gc.collect() 101 | 102 | 103 | if __name__ == '__main__': 104 | 105 | all_files, num = get_file(input_path) 106 | if num == 0: 107 | print("There is no file in path: {}".format(input_path)) 108 | sys.exit(-1) 109 | 110 | """check model file""" 111 | print("model file: {}".format(model_file)) 112 | if not os.path.isfile(model_file): 113 | print("model does not exist: {}".format(model_file)) 114 | sys.exit(-2) 115 | 116 | for in_file in all_files: 117 | abs_filename = os.path.split(in_file)[1] 118 | abs_filename = abs_filename.split(".")[0] 119 | print(abs_filename) 120 | out_file = ''.join([output_path, '/mask_binary_', 121 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onehot.png']) 122 | predict_binary_onehot(in_file, out_file) -------------------------------------------------------------------------------- /batch_predict/batch_predict_binary_onlyjaccard.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation. 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | from ulitities.base_functions import get_file 27 | """ 28 | The following global variables should be moved into a metadata file. 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 # 224, 256, 288 36 | # step = 128 37 | 38 | im_bands =4 39 | im_type = UINT10 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | 49 | input_path = '../../data/test/paper/images/' 50 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)]) 51 | 52 | model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_', 53 | dict_target[FLAG_TARGET_CLASS],'_binary_onlyjaccard_final.h5']) 54
| 55 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onlyjaccard_2018-09-29_18-55-11.h5' 56 | print("model: {}".format(model_file)) 57 | 58 | def predict_binary_only_jaccard(img_file, output_file): 59 | 60 | print("[INFO] opening image...") 61 | 62 | input_img = load_img_by_gdal(img_file) 63 | if im_type == UINT8: 64 | input_img = input_img / 255.0 65 | elif im_type == UINT10: 66 | input_img = input_img / 1024.0 67 | elif im_type == UINT16: 68 | input_img = input_img / 65535.0 69 | 70 | input_img = np.clip(input_img, 0.0, 1.0) 71 | 72 | model = load_model(model_file) 73 | 74 | if FLAG_APPROACH_PREDICT==0: 75 | print("[INFO] predict image by original approach\n") 76 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 77 | abs_filename = os.path.split(img_file)[1].split(".")[0] # abs_filename was previously undefined in this branch 78 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_', abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png']) 79 | print("result saved to: {}".format(output_file)) 80 | cv2.imwrite(output_file, result*128) 81 | 82 | elif FLAG_APPROACH_PREDICT==1: 83 | print("[INFO] predict image by smooth approach\n") 84 | result = predict_img_with_smooth_windowing_multiclassbands( 85 | input_img, 86 | model, 87 | window_size=window_size, 88 | subdivisions=2, 89 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 90 | pred_func=smooth_predict_for_binary_notonehot, 91 | PLOT_PROGRESS=False 92 | ) 93 | 94 | cv2.imwrite(output_file, result) 95 | print("Saved to: {}".format(output_file)) 96 | 97 | gc.collect() 98 | 99 | 100 | 101 | if __name__=='__main__': 102 | 103 | all_files, num= get_file(input_path) 104 | if num==0: 105 | print("There is no file in path: {}".format(input_path)) 106 | sys.exit(-1) 107 | 108 | 109 | """check model file""" 110 | print("model file: {}".format(model_file)) 111 | if not os.path.isfile(model_file): 112 | print("model does not exist: {}".format(model_file)) 113 | sys.exit(-2) 114 | 115 | for in_file in all_files: 116 | abs_filename = os.path.split(in_file)[1] 117 | abs_filename = abs_filename.split(".")[0] 118 | print(abs_filename) 119 | out_file = ''.join([output_path, '/mask_binary_', 120 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png']) 121 | predict_binary_only_jaccard(in_file, out_file) 122 | 123 | 124 | -------------------------------------------------------------------------------- /batch_predict/batch_predict_multiclass.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation. 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_multiclass 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | from ulitities.base_functions import get_file 26 | """ 27 | The following global variables should be moved into a metadata file. 28 | 
""" 29 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 30 | 31 | 32 | window_size = 256 33 | step = 128 34 | 35 | im_bands = 4 36 | im_type = UINT10 #UINT8, UINT10, UINT16 37 | 38 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 39 | dict_target={0: 'roads', 1: 'buildings'} 40 | target_class=len(dict_target) 41 | 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_APPROACH_PREDICT=1 # 0: original predict, 1: smooth predict 45 | 46 | input_path = '../../data/test/paper/images/' 47 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)]) 48 | 49 | model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_multiclass_final.h5']) 50 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_multiclass_2018-09-11_14-05-31.h5' 51 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_multiclass.h5' 52 | 53 | def predict_multiclass(img_file, out_path): 54 | 55 | print("[INFO] opening image...") 56 | input_img = load_img_by_gdal(img_file) 57 | if im_type == UINT8: 58 | input_img = input_img / 255.0 59 | elif im_type == UINT10: 60 | input_img = input_img / 1024.0 61 | elif im_type == UINT16: 62 | input_img = input_img / 65535.0 63 | input_img = np.clip(input_img, 0.0, 1.0) 64 | 65 | """check model file""" 66 | print("model file: {}".format(model_file)) 67 | if not os.path.isfile(model_file): 68 | print("model does not exist: {}".format(model_file)) 69 | sys.exit(-2) 70 | 71 | model= load_model(model_file) 72 | abs_filename = os.path.split(img_file)[1] 73 | abs_filename = abs_filename.split(".")[0] 74 | print(abs_filename) 75 | 76 | if FLAG_APPROACH_PREDICT==0: 77 | print("[INFO] predict image by original approach\n") 78 | result = orignal_predict_onehot(input_img, im_bands, model, window_size) 79 | output_file = ''.join([out_path, '/original_predict_',abs_filename, '_multiclass.png']) 80 | print("result saved to: {}".format(output_file)) 81 | cv2.imwrite(output_file, result*128) 82 | 83 | elif FLAG_APPROACH_PREDICT==1: 84 | print("[INFO] predict image by smooth approach\n") 85 | result = predict_img_with_smooth_windowing_multiclassbands( 86 | input_img, 87 | model, 88 | window_size=window_size, 89 | subdivisions=2, 90 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 91 | pred_func=smooth_predict_for_multiclass, 92 | PLOT_PROGRESS=False 93 | ) 94 | 95 | for b in range(target_class): 96 | output_file = ''.join([out_path, '/mask_multiclass_', 97 | abs_filename, '_', dict_target[b], '.png']) 98 | cv2.imwrite(output_file, result[:,:,b]) 99 | print("Saved to: {}".format(output_file)) 100 | gc.collect() 101 | 102 | if __name__ == '__main__': 103 | 104 | all_files, num = get_file(input_path) 105 | if num == 0: 106 | print("There is no file in path: {}".format(input_path)) 107 | sys.exit(-1) 108 | 109 | """check model file""" 110 | print("model file: {}".format(model_file)) 111 | if not os.path.isfile(model_file): 112 | print("model does not exist: {}".format(model_file)) 113 | sys.exit(-2) 114 | 115 | for in_file in all_files: 116 | abs_filename = os.path.split(in_file)[1] 117 | abs_filename = abs_filename.split(".")[0] 118 | print(abs_filename) 119 | # out_file = ''.join([output_path, '/mask_binary_', 120 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png']) 121 | predict_multiclass(in_file, output_path) 122 | 123 | 124 | -------------------------------------------------------------------------------- /model_build/config.json:
-------------------------------------------------------------------------------- 1 | { 2 | "train_data_path":"/home/omnisky/PycharmProjects/data/traindata/rice/", 3 | "img_w":256, 4 | "img_h":256, 5 | "im_bands":3, 6 | "im_type":"UINT8", 7 | "target_name":"rice", 8 | "val_rate":0.25, 9 | 10 | "network":"unet", 11 | "BACKBONE":"resnet34", 12 | "activation":"sigmoid", 13 | "encoder_weights":null, 14 | "nb_classes":1, 15 | "batch_size":32, 16 | "epochs": 50, 17 | 18 | "loss":"binary_crossentropy", 19 | "metrics":"accuracy", 20 | "optimizer": "sgd", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/rice/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"max", 34 | "factor":0.1, 35 | "patience":5, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "0312_rice_size", 43 | "predict_batch_size": 16, 44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | "train_images_unlabel":"", 60 | "train_masks":"label_half_nodata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"folds_rice.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/config.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | Config = namedtuple("Config", [ 5 | "train_data_path", 6 | "img_w", 7 | "img_h", 8 | "im_bands", 9 | "im_type", 10 | "target_name", 11 | "val_rate", 12 | "network", 13 | "BACKBONE", 14 | "activation", 15 | "encoder_weights", 16 | "nb_classes", 17 | "batch_size", 18 | "epochs", 19 | "optimizer", 20 | "loss", 21 | "metrics", 22 | "lr", 23 | "lr_steps", 24 | "lr_gamma", 25 | "lr_scheduler", 26 | "nb_epoch", 27 | "old_epoch", 28 | "test_pad", 29 | "model_dir", 30 | "base_model", 31 | "monitor", 32 | "save_best_only", 33 | "mode", 34 | "factor", 35 | "patience", 36 | "epsilon", 37 | "cooldown", 38 | "min_lr", 39 | "log_dir", 40 | "iter_size", 41 | "folder", 42 | "predict_batch_size", 43 | "results_dir", 44 | "loss_x", 45 | "loss_ema", 46 | "loss_kl", 47 | "loss_student", 48 | "ignore_target_size", 49 | "warmup", 50 | "lovasz", 51 | "unlabel_size", 52 | "negative_rate", 53 | "external_weights", 54 | "class_weights", 55 | "train_fold", 56 | "use_lb", 57 | "train_images", 58 | "train_images_unlabel", 59 | "train_masks", 60 | "train_image_suffix", 61 | "csv_file", 62 | "num_workers", 63 | "use_border", 64 | "class_picture", 65 | "use_good_image", 66 | "hard_negative_miner", 67 | "mixup" 68 | ]) 69 | 70 | 71 | -------------------------------------------------------------------------------- /model_build/config_WHU_buildings.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/train_my/", 3 | "img_w":512, 4 | "img_h":512, 5 | "im_bands":3, 6 | "im_type":"UINT8", 7 | "target_name":"WHU_building", 8 | "val_rate":0.25, 9 | 10 | "network":"fpn", 11 | "BACKBONE":"seresnet18", 12 | "activation":"sigmoid", 13 | "encoder_weights":"imagenet", 14 | "nb_classes":1, 15 | "batch_size":4, 16 | "epochs": 80, 17 | 18 | "loss":"bce_jaccard_loss", 19 | "metrics":"iou_score", 20 | "optimizer": "sgd", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/WHU/pre-trained/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"min", 34 | "factor":0.1, 35 | "patience":10, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "0312_rice_size", 43 | "predict_batch_size": 16, 44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | "train_images_unlabel":"", 60 | "train_masks":"label_half_nodata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"folds_rice.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/config_multiclass_global.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data_path":"/home/omnisky/PycharmProjects/data/traindata/global/multiclass_miandiantaiguo_histM/", 3 | "img_w":288, 4 | "img_h":288, 5 | "im_bands":4, 6 | "im_type":"uint10", 7 | "target_name":"global", 8 | "val_rate":0.25, 9 | 10 | "network":"unet", 11 | "BACKBONE":"resnet34", 12 | "activation":"softmax", 13 | "encoder_weights":null, 14 | "nb_classes":6, 15 | "batch_size":4, 16 | "epochs": 100, 17 | 18 | "loss":"categorical_crossentropy", 19 | "metrics":"accuracy", 20 | "optimizer": "adam", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/global/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"min", 34 | "factor":0.1, 35 | "patience":10, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "0312_rice_size", 43 | "predict_batch_size": 16, 44 | "results_dir": "", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | 
"train_images_unlabel":"", 60 | "train_masks":"label_half_nodata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"folds_rice.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/config_multiclass_manmade.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/sat_4bands_288/multiclass/", 3 | "img_w":256, 4 | "img_h":256, 5 | "im_bands":4, 6 | "im_type":"uint10", 7 | "target_name":"manmade", 8 | "val_rate":0.25, 9 | 10 | "network":"fpn", 11 | "BACKBONE":"resnet34", 12 | "activation":"softmax", 13 | "encoder_weights":null, 14 | "nb_classes":3, 15 | "batch_size":32, 16 | "epochs": 100, 17 | 18 | "loss":"categorical_crossentropy", 19 | "metrics":"accuracy", 20 | "optimizer": "adam", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"min", 34 | "factor":0.1, 35 | "patience":5, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "0312_rice_size", 43 | "predict_batch_size": 16, 44 | "results_dir": "", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | "train_images_unlabel":"", 60 | "train_masks":"label_half_nodata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"folds_rice.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/config_scrs_buildings.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/buildings_576_normal/", 3 | "img_w":576, 4 | "img_h":576, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"scrs_building", 8 | "val_rate":0.25, 9 | 10 | "network":"unet", 11 | "BACKBONE":"inceptionresnetv2", 12 | "activation":"sigmoid", 13 | "encoder_weights":null, 14 | "nb_classes":1, 15 | "batch_size":4, 16 | "epochs": 80, 17 | 18 | "loss":"bce_jaccard_loss", 19 | "metrics":"iou_score", 20 | "optimizer": "sgd", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/scrs_buildings/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"min", 34 | "factor":0.1, 35 | "patience":10, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": 
"/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "0312_rice_size", 43 | "predict_batch_size": 16, 44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | "train_images_unlabel":"", 60 | "train_masks":"label_half_nodata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"folds_rice.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/config_tuitiantu.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_data_path":"/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/traindata/tuitiantu_480/", 3 | "img_w":480, 4 | "img_h":480, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"tuitiantu", 8 | "val_rate":0.25, 9 | 10 | "network":"unet", 11 | "BACKBONE":"resnet101", 12 | "activation":"sigmoid", 13 | "encoder_weights":null, 14 | "nb_classes":1, 15 | "batch_size":8, 16 | "epochs": 100, 17 | 18 | "loss":"bce_jaccard_loss", 19 | "metrics":"iou_score", 20 | "optimizer": "sgd", 21 | "lr": 0.0001, 22 | "lr_steps": [540, 560], 23 | "lr_gamma": 0.1, 24 | "lr_scheduler":"CosineAnnealingLR", 25 | "nb_epoch": 48, 26 | "old_epoch":0, 27 | "test_pad": 64, 28 | 29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu/", 30 | "base_model":"", 31 | "monitor":"val_loss", 32 | "save_best_only":true, 33 | "mode":"min", 34 | "factor":0.1, 35 | "patience":10, 36 | "epsilon":0.0001, 37 | "cooldown":0, 38 | "min_lr":0, 39 | 40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/", 41 | "iter_size": 1, 42 | "folder": "", 43 | "predict_batch_size": 16, 44 | "results_dir": "", 45 | "loss_x": {"dice": 0.9, "bce": 0.1}, 46 | "loss_ema":0.3, 47 | "loss_kl":25, 48 | "loss_student":1, 49 | "ignore_target_size": false, 50 | "warmup": 0, 51 | "lovasz":false, 52 | "unlabel_size":1, 53 | "negative_rate":1, 54 | "external_weights":"", 55 | "class_weights":[], 56 | "train_fold":-1, 57 | "use_lb":true, 58 | "train_images":"images_clip", 59 | "train_images_unlabel":"", 60 | "train_masks":"label_half_nata", 61 | "train_image_suffix":"*.png", 62 | "csv_file":"fe.csv", 63 | "num_workers":4, 64 | "use_border":false, 65 | "class_picture":false, 66 | "use_good_image":true, 67 | "hard_negative_miner": false, 68 | "mixup": false 69 | 70 | } 71 | -------------------------------------------------------------------------------- /model_build/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import shutil 4 | import time 5 | import os 6 | from config import Config 7 | def get_csv_folds(path, d): 8 | df = pd.read_csv(path, index_col=0) 9 | m = df.max()[0] + 1 10 | train = [[] for i in range(m)] 11 | test = [[] for i in range(m)] 12 | 13 | folds = {} 14 | for i in range(m): 15 | fold_ids = list(df[df['fold'].isin([i])].index) 16 | folds.update({i: [n for n, l in enumerate(d) if l in fold_ids]}) 17 | 18 | for k, v in folds.items(): 19 | for i in range(m): 20 | if i 
!= k: 21 | train[i].extend(v) 22 | test[k] = v 23 | 24 | return list(zip(np.array(train), np.array(test))) 25 | def update_config(config, **kwargs): 26 | d = config._asdict() 27 | d.update(**kwargs) 28 | print(d) 29 | return Config(**d) 30 | def save(path,network,jsonPath=None): 31 | 32 | folder=time.strftime("%Y%m%d%H%M", time.localtime()) 33 | new_path=os.path.join(path,"history",folder+"_"+network) 34 | try: 35 | os.makedirs(new_path) 36 | except: 37 | new_path=new_path+"_"+str(np.random.randint(0,100)) 38 | os.makedirs ( new_path ) 39 | shutil.copy(os.path.join(path,"train.py"),new_path) 40 | if jsonPath is None: 41 | shutil.copy ( os.path.join ( path , "config.json" ) , new_path ) 42 | else: 43 | shutil.copy ( jsonPath , new_path ) -------------------------------------------------------------------------------- /model_predict/config_pred.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/home/omnisky/PycharmProjects/data/test/rice/normal/", 3 | "img_w":256, 4 | "img_h":256, 5 | "im_bands":3, 6 | "im_type":"UINT8", 7 | "target_name":"rice", 8 | "activation":"sigmoid", 9 | "mask_classes":1, 10 | "strategy":"smooth", 11 | "window_size":256, 12 | "subdivisions":2, 13 | "slices":4, 14 | "block_size":100000000, 15 | "nodata":65535, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/rice/unet_test_Unet_resnet2019-03-25_11-57-12.h5", 17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/rice/pred/", 18 | "suffix":".png" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | Config_Pred = namedtuple("Config_pred", [ 5 | "img_input", 6 | "img_w", 7 | "img_h", 8 | "im_bands", 9 | "im_type", 10 | "target_name", 11 | "model_path", 12 | "activation", 13 | "mask_classes", 14 | "strategy", 15 | "window_size", 16 | "subdivisions", 17 | "slices", 18 | "block_size", 19 | "nodata", 20 | "mask_dir", 21 | "suffix" 22 | ]) 23 | 24 | 25 | -------------------------------------------------------------------------------- /model_predict/config_pred_WHU.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/", 3 | "img_w":512, 4 | "img_h":512, 5 | "im_bands":3, 6 | "im_type":"UINT8", 7 | "target_name":"buildings", 8 | "activation":"sigmoid", 9 | "mask_classes":1, 10 | "strategy":"smooth", 11 | "window_size":512, 12 | "subdivisions":2, 13 | "slices":1, 14 | "block_size":200000000, 15 | "nodata":256, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/WHU/pre-trained/WHU_building_unet_resnet101_bce_jaccard_loss_512_2019-05-09_14-01-47.h5", 17 | "mask_dir": "/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/pred/", 18 | "suffix":".png" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred_bieshu.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/media/omnisky/579215ea-c392-4d2f-b136-6b62d9fb31ac/test/bieshu/norm_images/", 3 | "img_w":480, 4 | "img_h":480, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"bieshu", 8 | "activation":"sigmoid", 9 | "mask_classes":1, 10 | "strategy":"smooth", 11 | "window_size":480, 12 | "subdivisions":2, 13 | 
"slices":8, 14 | "block_size":5000000000, 15 | "nodata":65535, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/bieshu/bieshu_unet_vgg16_binary_crossentropy_480_2019-06-01_22-49-00.h5", 17 | "mask_dir": "/media/omnisky/579215ea-c392-4d2f-b136-6b62d9fb31ac/test/bieshu/pred/", 18 | "suffix":".tif" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred_multiclass.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/images_forClass/test_normal/", 3 | "img_w":288, 4 | "img_h":288, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"global", 8 | "activation":"softmax", 9 | "mask_classes":6, 10 | "strategy":"smooth", 11 | "window_size":288, 12 | "subdivisions":2, 13 | "slices":8, 14 | "block_size":500000000, 15 | "nodata":65535, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/global/global16bits_miandiantaiguosamples_fpn_resnet34_categorical_crossentropy_288_2019-04-06_17-25-00.h5", 17 | "mask_dir": "/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/", 18 | "suffix":".tif" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred_multiclass_manmade.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/home/omnisky/PycharmProjects/data/test/paper/images/", 3 | "img_w":256, 4 | "img_h":256, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"manmade", 8 | "activation":"softmax", 9 | "mask_classes":3, 10 | "strategy":"smooth", 11 | "window_size":256, 12 | "subdivisions":2, 13 | "slices":8, 14 | "block_size":100000000, 15 | "nodata":65535, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/manmade_fpn_resnet34_categorical_crossentropy_256_2019-04-19_14-09-20.h5", 17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/paper/", 18 | "suffix":".png" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred_scrsbuilings.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/home/omnisky/PycharmProjects/data/test/tianfuxinqu/stretched/", 3 | "img_w":480, 4 | "img_h":480, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"tuitiantu", 8 | "activation":"sigmoid", 9 | "mask_classes":1, 10 | "strategy":"smooth", 11 | "window_size":480, 12 | "subdivisions":2, 13 | "slices":8, 14 | "block_size":500000000, 15 | "nodata":65535, 16 | "model_path":"/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu/tuitiantu_unet_vgg16_bce_jaccard_loss_480_2019-05-09_08-40-41.h5", 17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/tianfuxinqu/tuitiantu_pred_2019/", 18 | "suffix":".png" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/config_pred_tuitiantu.json: -------------------------------------------------------------------------------- 1 | { 2 | "img_input":"/home/omnisky/PycharmProjects/data/test/tianfuxinqu/stretched/", 3 | "img_w":576, 4 | "img_h":576, 5 | "im_bands":4, 6 | "im_type":"UINT10", 7 | "target_name":"tuitiantu", 8 | "activation":"sigmoid", 9 | "mask_classes":1, 10 | "strategy":"smooth", 11 | "window_size":576, 12 | "subdivisions":2, 13 | "slices":8, 14 | "block_size":500000000, 15 | "nodata":65535, 16 | 
"model_path":"/home/omnisky/PycharmProjects/data/models/scrs_buildings/scrs_building_unet_inceptionv3_bce_jaccard_loss_576_2019-05-21_16-02-24.h5", 17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/tianfuxinqu/scrs_buildings/", 18 | "suffix":".png" 19 | } 20 | -------------------------------------------------------------------------------- /model_predict/test_predict_1.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """" 3 | This is main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | from keras.models import load_model 13 | from sklearn.preprocessing import LabelEncoder 14 | 15 | 16 | from keras import backend as K 17 | K.set_image_dim_ordering('tf') 18 | K.clear_session() 19 | from segmentation_models.losses import bce_jaccard_loss 20 | from segmentation_models.metrics import iou_score 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from predict_backbone import predict_img_with_smooth_windowing_multiclassbands 25 | 26 | """ 27 | The following global variables should be put into meta data file 28 | """ 29 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 30 | 31 | target_class =1 32 | 33 | window_size = 256 34 | # step = 128 35 | 36 | im_bands =3 37 | im_type = UINT8 # UINT10,UINT8,UINT16 38 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 39 | dict_target={0: 'roads', 1: 'buildings'} 40 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 41 | 42 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings 43 | 44 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 45 | 46 | img_file = '/home/omnisky/PycharmProjects/data/test/rice/normal/testsrc_1.png' 47 | 48 | model_file = '/home/omnisky/PycharmProjects/data/models/rice/rice_unet_resnet34_bce_jaccard_loss_256_2019-03-28_20-42-58.h5' 49 | print("model: {}".format(model_file)) 50 | 51 | if __name__ == '__main__': 52 | 53 | print("[INFO] opening image...") 54 | 55 | input_img = load_img_by_gdal(img_file) 56 | if im_type == UINT8: 57 | input_img = input_img / 255.0 58 | elif im_type == UINT10: 59 | input_img = input_img / 1024.0 60 | elif im_type == UINT16: 61 | input_img = input_img / 65535.0 62 | 63 | input_img = np.clip(input_img, 0.0, 1.0) 64 | input_img = input_img.astype(np.float32) 65 | 66 | abs_filename = os.path.split(img_file)[1] 67 | abs_filename = abs_filename.split(".")[0] 68 | print (abs_filename) 69 | 70 | """checke model file""" 71 | print("model file: {}".format(model_file)) 72 | if not os.path.isfile(model_file): 73 | print("model does not exist:{}".format(model_file)) 74 | sys.exit(-2) 75 | 76 | model = load_model(model_file) 77 | 78 | if FLAG_APPROACH_PREDICT==0: 79 | print("[INFO] predict image by orignal approach\n") 80 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 81 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_', 82 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png']) 83 | print("result save as to: {}".format(output_file)) 84 | cv2.imwrite(output_file, result*128) 85 | 86 | elif FLAG_APPROACH_PREDICT==1: 87 | print("[INFO] predict image by smooth approach\n") 88 | result = predict_img_with_smooth_windowing_multiclassbands( 89 | input_img, 90 | 
model, 91 | window_size=window_size, 92 | subdivisions=2, 93 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 94 | pred_func=smooth_predict_for_binary_notonehot 95 | ) 96 | """for single class test""" 97 | result[result<128]=0 98 | result[result>=128]=1 99 | 100 | output_file = '/home/omnisky/PycharmProjects/data/test/rice/newpred/test_1_unet_resnet34_jaccard.png' 101 | 102 | print("result saved to: {}".format(output_file)) 103 | 104 | 105 | 106 | cv2.imwrite(output_file, result) 107 | 108 | gc.collect() 109 | 110 | 111 | -------------------------------------------------------------------------------- /postprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/postprocess/__init__.py -------------------------------------------------------------------------------- /postprocess/combine_diffclass_for_singlemodel_result.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | from ulitities.base_functions import load_img_by_cv2 11 | 12 | FOREGROUND = 127 # for segnet: 40; for unet: 127; defines the foreground threshold 13 | 14 | ROAD_VALUE=127 15 | BUILDING_VALUE=255 16 | 17 | """for unet""" 18 | # input_path = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/pred/pred_256/' 19 | input_path = '../../data/test/tianfuxinqu/pred/pred_256/' 20 | mask_pool = ['mask_binary_fw01_buildings_jaccard.png', 'mask_binary_fw01_roads_jaccard.png'] 21 | output_file = input_path+'/combined/mask_2018_256_fw0_smooth_combined.png' 22 | print(output_file) 23 | 24 | # mask_pool = ['mask_multiclass_3_buildings.png','mask_multiclass_3_roads.png'] 25 | # output_file = '../../data/predict/unet/unet_multiclass_combined_3.png' 26 | 27 | """for segnet""" 28 | # input_path = '../../data/predict/segnet/' 29 | # mask_pool = ['mask_binary_3_buildings.png','mask_binary_3_roads.png'] 30 | # output_file = '../../data/predict/segnet/segnet_binary_combined_3.png' 31 | 32 | # mask_pool = ['mask_multiclass_3_buildings.png','mask_multiclass_3_roads.png'] 33 | # output_file = '../../data/predict/segnet/segnet_multiclass_combined_3.png' 34 | 35 | 36 | def check_input_file(path,masks): 37 | ret, img_1 = load_img_by_cv2(path+masks[0], grayscale=True) 38 | assert (ret == 0) 39 | 40 | height, width = img_1.shape 41 | num_img = len(masks) 42 | 43 | for next_index in range(1,num_img): 44 | next_ret, next_img=load_img_by_cv2(path+masks[next_index],grayscale=True) 45 | assert (next_ret ==0 ) 46 | next_height, next_width = next_img.shape 47 | assert(height==next_height and width==next_width) 48 | return height, width 49 | 50 | 51 | 52 | def combine_all_mask(height, width,input_path,mask_pool): 53 | """ 54 | 55 | :param height: 56 | :param width: 57 | :param input_path: 58 | :param mask_pool: 59 | :return: final mask combined from roads_mask and buildings_mask 60 | 61 | priority: road(1) > buildings(2) 62 | """ 63 | final_mask=np.zeros((height,width),np.uint8) 64 | for idx,file in enumerate(mask_pool): 65 | ret,img = load_img_by_cv2(input_path+file,grayscale=True) 66 | assert (ret == 0) 67 | label_value=0 68 | if 'road' in file: 69 | label_value =ROAD_VALUE 70 | elif 'building' in file: 71 | label_value=BUILDING_VALUE 72 | # label_value = idx+1 73 | # print("buildings prior") 74 | print("Roads prior") 75 | for
i in tqdm(range(height)): 76 | for j in range(width): 77 | if img[i,j]>=FOREGROUND: 78 | # print ("img[{},{}]:{}".format(i,j,img[i,j])) 79 | 80 | if label_value == ROAD_VALUE: 81 | final_mask[i, j] = label_value 82 | elif label_value == BUILDING_VALUE and final_mask[i, j] != ROAD_VALUE: 83 | final_mask[i, j] = label_value 84 | 85 | # if label_value == BUILDING_VALUE: 86 | # final_mask[i, j] = label_value 87 | # elif label_value == ROAD_VALUE and final_mask[i, j] != BUILDING_VALUE: 88 | # final_mask[i, j] = label_value 89 | 90 | final_mask[final_mask == ROAD_VALUE] = 1 91 | final_mask[final_mask == BUILDING_VALUE] = 2 92 | return final_mask 93 | 94 | 95 | 96 | if __name__=='__main__': 97 | 98 | x,y=check_input_file(input_path,mask_pool) 99 | 100 | result_mask=combine_all_mask(x,y,input_path,mask_pool) 101 | 102 | plt.imshow(result_mask, cmap='gray') 103 | plt.title("combined mask") 104 | plt.show() 105 | 106 | cv2.imwrite(output_file,result_mask) 107 | print("Saved to: {}".format(output_file)) -------------------------------------------------------------------------------- /postprocess/mismatch_analyze.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import cv2 4 | import os 5 | import sys 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from tqdm import tqdm 9 | from ulitities.base_functions import load_img_by_cv2, compare_two_image_size 10 | 11 | ref_file ='../../data/test/paper/label/cuiping_label.png' 12 | # 1) jian11_test_label, 2) jiangyou_label, 3) yujiang_test_label, 13 | # 4) cuiping_label, 5) shuangliu_1test_label, 6) tongchuan_test_label 14 | 15 | pred_file ='../../data/test/paper/voted/' \ 16 | 'unet_cuiping_4bands_voted0.png' 17 | 18 | output_file = '../../data/test/paper/voted/' \ 19 | 'unet_cuiping_error_map_7c.png' 20 | 21 | 22 | if __name__=='__main__': 23 | print("[INFO] Reading images") 24 | ret, ref_img = load_img_by_cv2(ref_file, grayscale=True) 25 | if ret !=0: 26 | print("Open file failed:{}".format(ref_file)) 27 | sys.exit(-1) 28 | 29 | ret, pred_img = load_img_by_cv2(pred_file, grayscale=True) 30 | if ret !=0: 31 | print("Open file failed:{}".format(pred_file)) 32 | sys.exit(-2) 33 | 34 | compare_two_image_size(ref_img, pred_img, grayscale=True) 35 | 36 | height, width = ref_img.shape 37 | print("height,width: {},{}".format(height, width)) 38 | 39 | match_img = np.zeros((height, width), np.uint8) 40 | 41 | for j in tqdm(range(height)): 42 | for i in range(width): 43 | # if ref_img[j,i]!=0: 44 | # if pred_img[j,i]==0: 45 | # match_img[j,i]=3 # missed target (false negative) 46 | # if ref_img[j,i]==pred_img[j,i]: 47 | # match_img[j,i]=2 # true positive: correctly detected target 48 | # elif ref_img[j,i]==1 and pred_img[j,i]==2: 49 | # match_img[j,i]=4 # road misclassified as building 50 | # elif ref_img[j,i]==2 and pred_img[j,i]==1: 51 | # match_img[j,i]=5 # building misclassified as road 52 | # else: 53 | # if pred_img[j,i]!=0: 54 | # match_img[j,i]=1 # false positive: background misclassified as target 55 | # #ref_img[j,i]=pred_img[j,i]=0 # true negative 56 | 57 | if ref_img[j,i]!=0: 58 | if pred_img[j,i]==0: 59 | match_img[j,i]=4 # missed target (false negative) 60 | if ref_img[j,i]==pred_img[j,i]: 61 | match_img[j,i]=3 # true positive: correctly detected target 62 | elif ref_img[j,i]==1 and pred_img[j,i]==2: 63 | match_img[j,i]=5 # road misclassified as building 64 | elif ref_img[j,i]==2 and pred_img[j,i]==1: 65 | match_img[j,i]=6 # building misclassified as road 66 | else: 67 | if pred_img[j, i] == 1: 68 | match_img[j, i] = 1 # false positive: background misclassified as road 69 | if pred_img[j, i] == 2: 70 | match_img[j, i] = 2 # false positive: background misclassified as building 71 | 72 | 73 | 74 | plt.imshow(match_img) 75 | plt.show() 76 | 77 | 
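# Error-map legend for the values written into match_img above:
#   0 = true negative (background correctly kept as background)
#   1 = background misclassified as road (false positive)
#   2 = background misclassified as building (false positive)
#   3 = true positive (target class predicted correctly)
#   4 = missed target (false negative)
#   5 = road misclassified as building
#   6 = building misclassified as road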
cv2.imwrite(output_file, match_img) 78 | print("Saved to: {}".format(output_file)) 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /postprocess/raster_to_vector.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | 4 | 5 | input_raster_file = '/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/2019-04-11_19-31-32--miandiantaiguo-histmatched/c-ZY302918320140104s.tif' 6 | output_shp_dir = '/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/2019-04-11_19-31-32--miandiantaiguo-histmatched/' 7 | 8 | import gdal,osr,ogr 9 | def polygonize(rasterTemp, outShp, sieveSize=1): 10 | sourceRaster = gdal.Open(rasterTemp) 11 | band = sourceRaster.GetRasterBand(1) 12 | driver = ogr.GetDriverByName("ESRI Shapefile") 13 | # If the shapefile already exists, delete it 14 | if os.path.exists(outShp): 15 | driver.DeleteDataSource(outShp) 16 | 17 | outDatasource = driver.CreateDataSource(outShp) 18 | # get projection from the raster 19 | srs = osr.SpatialReference() 20 | srs.ImportFromWkt(sourceRaster.GetProjectionRef()) 21 | # create layer with that projection 22 | outLayer = outDatasource.CreateLayer(outShp, srs) 23 | # Add class column (1,2...) to the shapefile 24 | 25 | newField = ogr.FieldDefn('grid_code', ogr.OFTInteger) 26 | outLayer.CreateField(newField) 27 | 28 | gdal.Polygonize(band, None, outLayer, 0, [], callback=None) 29 | 30 | outDatasource.Destroy() 31 | sourceRaster = None 32 | band = None 33 | 34 | try: 35 | # Add an Area field for each feature 36 | ioShpFile = ogr.Open(outShp, update=1) 37 | 38 | lyr = ioShpFile.GetLayerByIndex(0) 39 | lyr.ResetReading() 40 | 41 | field_defn = ogr.FieldDefn("Area", ogr.OFTReal) 42 | lyr.CreateField(field_defn) 43 | except Exception: 44 | print("Cannot add field 'Area'!") 45 | return outShp  # bail out early: `lyr` is undefined if opening the shapefile failed 46 | for i in lyr: 47 | # feat = lyr.GetFeature(i) 48 | geom = i.GetGeometryRef() 49 | area = round(geom.GetArea()) 50 | 51 | lyr.SetFeature(i) 52 | i.SetField("Area", area) 53 | lyr.SetFeature(i) 54 | # if the area is smaller than sieveSize, remove the polygon 55 | if area < sieveSize: 56 | lyr.DeleteFeature(i.GetFID()) 57 | ioShpFile.Destroy() 58 | 59 | return outShp 60 | 61 | 62 | if __name__=='__main__': 63 | 64 | if not os.path.isfile(input_raster_file): 65 | print("Error: Please input a raster file!") 66 | sys.exit(-1) 67 | 68 | if not os.path.isdir(output_shp_dir): 69 | print("Warning: output directory does not exist, creating it!") 70 | os.mkdir(output_shp_dir) 71 | 72 | absname = os.path.split(input_raster_file)[1] 73 | print('\n\t[Info] image: {}'.format(absname)) 74 | absname = absname.split('.')[0] 75 | absname = ''.join([absname, '4.shp']) 76 | shp_file = os.path.join(output_shp_dir,absname) 77 | 78 | polygonize(input_raster_file, shp_file) 79 | 80 | -------------------------------------------------------------------------------- /postprocess/vote.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from tqdm import tqdm 4 | import matplotlib.pyplot as plt 5 | 6 | from ulitities.base_functions import load_img_by_cv2 7 | 8 | target_values=[0, 1, 2] # 0: background, 1: road, 2: building 9 | 10 | input_path = '../../data/test/paper/' 11 | input_masks=['pred_224/combined/unet_jaccard_yujiang_test_4bands_combined.png', 12 | 'pred_288/combined/unet_jaccard_yujiang_test_4bands_combined.png', 13 | 'pred_256/combined/unet_jaccard_yujiang_4bands_combined.png', 14 | 'pred_256/combined/unet_multiclass_yujiang_4bands_combined.png', 15 | 
'pred_256/combined/unet_notonehot_yujiang_test_4bands_combined.png', 16 | 'pred_256/combined/unet_onlyjaccard_yujiang_4bands_combined.png'] 17 | 18 | output_file = '../../data/test/paper/voted/unet_yujiang_4bands_voted0.png' 19 | 20 | 21 | def check_input_file(path, masks): 22 | ret, img_1 = load_img_by_cv2(path+masks[0],grayscale=True) 23 | assert (ret == 0) 24 | 25 | height, width = img_1.shape 26 | num_img = len(masks) 27 | 28 | for next_index in range(1,num_img): 29 | next_ret, next_img=load_img_by_cv2(path+masks[next_index],grayscale=True) 30 | assert (next_ret ==0 ) 31 | next_height, next_width = next_img.shape 32 | assert(height==next_height and width==next_width) 33 | return height, width 34 | 35 | 36 | 37 | def vote_per_image(height, width, path, masks): 38 | num_target = len(target_values) 39 | 40 | mask_list = [] 41 | for tt in range(len(masks)): 42 | ret, img = load_img_by_cv2(path+masks[tt],grayscale=True) 43 | assert(ret ==0) 44 | mask_list.append(img) 45 | 46 | vote_mask=np.zeros((height,width), np.uint8) 47 | 48 | for i in tqdm(range(height)): 49 | for j in range(width): 50 | # record=np.zeros(256,np.uint8) 51 | record = np.zeros(num_target, np.uint8) 52 | for n in range(len(mask_list)): 53 | mask=mask_list[n] 54 | pixel=mask[i,j] 55 | record[pixel] +=1 56 | 57 | # tie-break in favour of the foreground classes: 58 | if record.argmax()==0: # if background (0) wins the vote, weaken it by one vote so a tied foreground class wins instead 59 | record[0] -=1 60 | # print("record:{}".format(record)) 61 | # a = record[1:] 62 | # print("else:{}".format(a)) 63 | # if a.any()>1: 64 | # record[0]=0 65 | 66 | label=record.argmax() 67 | # print ("{},{} label={}".format(i,j,label)) 68 | vote_mask[i,j]=label 69 | # vote_mask[vote_mask==125]=1 70 | # vote_mask[vote_mask == 255] = 2 71 | print(np.unique(vote_mask)) 72 | 73 | return vote_mask 74 | 75 | 76 | 77 | if __name__=='__main__': 78 | x,y = check_input_file(input_path, input_masks) 79 | 80 | final_mask = vote_per_image(x,y,input_path, input_masks) 81 | plt.imshow(final_mask) 82 | plt.show() 83 | 84 | cv2.imwrite(output_file, final_mask) -------------------------------------------------------------------------------- /predict/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/predict/__init__.py -------------------------------------------------------------------------------- /predict/predict_binary_jaccard.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | 27 | """ 28 | The following 
global variables should be moved into a metadata (config) file 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 # 224, 256, 288, 320 36 | # step = 128 37 | 38 | im_bands =4 39 | im_type = UINT10 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | # position = 'shuangliu_1test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test, 49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test 50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test 51 | # img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024. 52 | img_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src.png' 53 | # img_file='/home/omnisky/PycharmProjects/data/test/sample1_12.png' 54 | 55 | 56 | # model_file = ''.join(['../../data/models/sat_urban_rgb/',dict_network[FLAG_USING_NETWORK], '_', 57 | # dict_target[FLAG_TARGET_CLASS],'_binary_jaccard_', str(window_size), '_final.h5']) 58 | model_file ='/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu_jaccardandCross_2018-12-29_09-14-05.h5' 59 | 60 | print("model: {}".format(model_file)) 61 | 62 | if __name__ == '__main__': 63 | 64 | print("[INFO] opening image...") 65 | # ret, input_img = load_img_normalization_by_cv2(img_file) 66 | # if ret !=0: 67 | # print("Open input file failed: {}".format(img_file)) 68 | # sys.exit(-1) 69 | 70 | input_img = load_img_by_gdal(img_file) 71 | if im_type == UINT8: 72 | input_img = input_img / 255.0 73 | elif im_type == UINT10: 74 | input_img = input_img / 1024.0 75 | elif im_type == UINT16: 76 | input_img = input_img / 65535.0 77 | 78 | input_img = np.clip(input_img, 0.0, 1.0) 79 | input_img = input_img.astype(np.float16) # test accuracy 80 | 81 | 82 | abs_filename = os.path.split(img_file)[1] 83 | abs_filename = abs_filename.split(".")[0] 84 | print (abs_filename) 85 | 86 | """check model file""" 87 | print("model file: {}".format(model_file)) 88 | if not os.path.isfile(model_file): 89 | print("model does not exist:{}".format(model_file)) 90 | sys.exit(-2) 91 | 92 | model = load_model(model_file) 93 | 94 | if FLAG_APPROACH_PREDICT==0: 95 | print("[INFO] predict image by original approach\n") 96 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 97 | # output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_', 98 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png']) 99 | output_file = ''.join(['../../data/test/tianfuxinqu/pred/pred_', str(window_size), '/mask_binary_', 100 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_original.png']) 101 | output_file = '/home/omnisky/PycharmProjects/data/originaldata/zs/pred/b_pred_original.png' 102 | print("result saved to: {}".format(output_file)) 103 | cv2.imwrite(output_file, result*128) 104 | 105 | elif FLAG_APPROACH_PREDICT==1: 106 | print("[INFO] predict image by smooth approach\n") 107 | result = predict_img_with_smooth_windowing_multiclassbands( 108 | input_img, 109 | model, 110 | window_size=window_size, 111 | subdivisions=2, 112 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 113 | pred_func=smooth_predict_for_binary_notonehot 114 | ) 115 | # output_file = 
''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_rgb/mask_binary_',str(window_size), 116 | # '_', abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png']) 117 | # output_file = ''.join(['../../data/test/tianfuxinqu/pred/pred_', str(window_size), '/mask_binary_', 118 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_smooth.png']) 119 | output_file = '/home/omnisky/PycharmProjects/data/originaldata/zs/pred/b_pred_512.png' 120 | print("result saved to: {}".format(output_file)) 121 | 122 | cv2.imwrite(output_file, result) 123 | print("Saved to {}".format(output_file)) 124 | 125 | gc.collect() 126 | 127 | 128 | -------------------------------------------------------------------------------- /predict/predict_binary_notonehot.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | 27 | """ 28 | The following global variables should be moved into a metadata (config) file 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 36 | # step = 128 37 | 38 | im_bands =4 39 | im_type = UINT10 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | position = 'tongchuan_test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test, 49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test 50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test 51 | # img_file = '../../data/test/sat_test/'+position+'_4bands1024.png' # _rgb, _nrg, __4bands1024. 52 | # img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024. 
53 | img_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src.png' 54 | 55 | # model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_', 56 | # dict_target[FLAG_TARGET_CLASS],'_binary_notonehot_final.h5']) 57 | 58 | model_file = '/home/omnisky/PycharmProjects/data/models/ducha/tuitiantuunet_Crossentropy_256_2018-12-29_14-59-59.h5' 59 | # print("model: {}".format(model_file)) 60 | 61 | if __name__ == '__main__': 62 | 63 | print("[INFO] opening image...") 64 | 65 | input_img = load_img_by_gdal(img_file) 66 | if im_type == UINT8: 67 | input_img = input_img / 255.0 68 | elif im_type == UINT10: 69 | input_img = input_img / 1024.0 70 | elif im_type == UINT16: 71 | input_img = input_img / 65535.0 72 | 73 | input_img = np.clip(input_img, 0.0, 1.0) 74 | input_img = input_img.astype(np.float32) 75 | 76 | 77 | abs_filename = os.path.split(img_file)[1] 78 | abs_filename = abs_filename.split(".")[0] 79 | print (abs_filename) 80 | 81 | """check model file""" 82 | print("model file: {}".format(model_file)) 83 | if not os.path.isfile(model_file): 84 | print("model does not exist:{}".format(model_file)) 85 | sys.exit(-2) 86 | 87 | model = load_model(model_file) 88 | 89 | if FLAG_APPROACH_PREDICT==0: 90 | print("[INFO] predict image by original approach\n") 91 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 92 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_', 93 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png']) 94 | print("result saved to: {}".format(output_file)) 95 | cv2.imwrite(output_file, result*128) 96 | 97 | elif FLAG_APPROACH_PREDICT==1: 98 | print("[INFO] predict image by smooth approach\n") 99 | result = predict_img_with_smooth_windowing_multiclassbands( 100 | input_img, 101 | model, 102 | window_size=window_size, 103 | subdivisions=2, 104 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 105 | pred_func=smooth_predict_for_binary_notonehot 106 | ) 107 | """for single class test""" 108 | result[result<128]=0 109 | result[result>=128]=1 110 | # output_file = '//home/omnisky/PycharmProjects/data/test/shuidao/GF2shuitian22_test_pred.png' 111 | 112 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_4bands/mask_binary_', 113 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png']) 114 | 115 | # output_file = ''.join(['../../data/test/paper/pred/mask_binary_', 116 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png']) 117 | output_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src_pred.png' 118 | 119 | print("result saved to: {}".format(output_file)) 120 | 121 | 122 | 123 | cv2.imwrite(output_file, result) 124 | 125 | gc.collect() 126 | 127 | 128 | -------------------------------------------------------------------------------- /predict/predict_binary_onehot.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | 
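# NOTE: K.set_image_dim_ordering('tf') below is the legacy Keras backend call;
# on newer Keras versions the equivalent is K.set_image_data_format('channels_last').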
K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_binary_onehot 23 | from ulitities.base_functions import load_img_normalization, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | 27 | """ 28 | The following global variables should be moved into a metadata (config) file 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 36 | # step = 128 37 | im_bands = 3 38 | im_type = UINT8 39 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 40 | dict_target={0: 'roads', 1: 'buildings'} 41 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 42 | 43 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings 44 | 45 | FLAG_APPROACH_PREDICT = 0 # 0: original predict, 1: smooth predict 46 | 47 | # img_file = '../../data/test/GF2_yilong11.png' 48 | # img_file = '../../data/test/sample1.png' 49 | img_file = '../../data/test/lizhou_test_4bands255.png' # sample1_nrg, lizhou_test_4bands255 50 | 51 | # model_file = ''.join(['../../data/models/sat_urban_nrg/',dict_network[FLAG_USING_NETWORK], '_', dict_target[FLAG_TARGET_CLASS],'_binary.h5']) 52 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_buildings_binary2_onehot.h5' 53 | model_file='/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onehot.h5' 54 | print("model: {}".format(model_file)) 55 | 56 | if __name__ == '__main__': 57 | 58 | print("[INFO] opening image...") 59 | if not os.path.isfile(img_file): 60 | print("Please check the input: {}".format(img_file)) 61 | sys.exit(-1) 62 | # ret, input_img = load_img_normalization(img_file) 63 | 64 | input_img = load_img_by_gdal(img_file) 65 | if im_type == UINT8: 66 | input_img = input_img / 255.0 67 | elif im_type == UINT10: 68 | input_img = input_img / 1024.0 69 | elif im_type == UINT16: 70 | input_img = input_img / 65535.0 71 | input_img = np.clip(input_img, 0.0, 1.0) 72 | 73 | abs_filename = os.path.split(img_file)[1] 74 | abs_filename = abs_filename.split(".")[0] 75 | print (abs_filename) 76 | 77 | """check model file""" 78 | print("model file: {}".format(model_file)) 79 | if not os.path.isfile(model_file): 80 | print("model does not exist:{}".format(model_file)) 81 | sys.exit(-2) 82 | 83 | model = load_model(model_file) 84 | 85 | if FLAG_APPROACH_PREDICT==0: 86 | print("[INFO] predict image by original approach\n") 87 | result = orignal_predict_onehot(input_img, im_bands, model, window_size) 88 | output_file = ''.join(['../../data/predict/original_predict_',abs_filename, '.png']) 89 | print("result saved to: {}".format(output_file)) 90 | cv2.imwrite(output_file, result*100) 91 | 92 | elif FLAG_APPROACH_PREDICT==1: 93 | print("[INFO] predict image by smooth approach\n") 94 | result = predict_img_with_smooth_windowing_multiclassbands( 95 | input_img, 96 | model, 97 | window_size=window_size, 98 | subdivisions=2, 99 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 100 | pred_func=smooth_predict_for_binary_onehot 101 | ) 102 | output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/mask_binary_', 103 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'.png']) 104 | print("result saved to: {}".format(output_file)) 105 | 106 | cv2.imwrite(output_file, result) 107 | 108 | gc.collect() 109 | 110 | 111 | 
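The predict_binary_* scripts above all share one flow: load the image with GDAL, scale it into [0, 1] according to its bit depth, then run either the plain sliding-window predictor or the smoothed overlapping-window predictor. A condensed sketch of that shared flow, reusing the repo's own helpers; the function wrapper, its name, and the scale table are illustrative only:

import cv2
import numpy as np
from keras.models import load_model

from base_predict_functions import smooth_predict_for_binary_notonehot
from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
from ulitities.base_functions import load_img_by_gdal, UINT8, UINT10, UINT16

SCALES = {UINT8: 255.0, UINT10: 1024.0, UINT16: 65535.0}  # divisor per image bit depth

def predict_binary(img_file, model_file, out_file, im_type=UINT10, window_size=256):
    # load all bands with GDAL and normalize into [0, 1]
    img = np.clip(load_img_by_gdal(img_file) / SCALES[im_type], 0.0, 1.0)
    model = load_model(model_file)
    # overlapping windows (subdivisions=2 -> 50% overlap), blended to avoid tile seams
    result = predict_img_with_smooth_windowing_multiclassbands(
        img, model,
        window_size=window_size,
        subdivisions=2,
        real_classes=1,  # binary task: one foreground channel
        pred_func=smooth_predict_for_binary_notonehot)
    cv2.imwrite(out_file, result)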
-------------------------------------------------------------------------------- /predict/predict_binary_onlyjaccard.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import gc 11 | import argparse 12 | # from keras.preprocessing.image import img_to_array 13 | from keras.models import load_model 14 | from sklearn.preprocessing import LabelEncoder 15 | from PIL import Image 16 | from keras.preprocessing.image import img_to_array 17 | 18 | from keras import backend as K 19 | K.set_image_dim_ordering('tf') 20 | K.clear_session() 21 | 22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot 23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16 24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int 26 | 27 | """ 28 | The following global variables should be moved into a metadata (config) file 29 | """ 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 31 | 32 | 33 | target_class =1 34 | 35 | window_size = 256 36 | # step = 128 37 | 38 | im_bands =4 39 | im_type = UINT10 # UINT10,UINT8,UINT16 40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 41 | dict_target={0: 'roads', 1: 'buildings'} 42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 43 | 44 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings 45 | 46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict 47 | 48 | position = 'jiangyou' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test, 49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test 50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test 51 | img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024. 
52 | # img_file = '../../data/test/shuidao.png' 53 | 54 | 55 | model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_', 56 | dict_target[FLAG_TARGET_CLASS],'_binary_onlyjaccard_final.h5']) 57 | 58 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onlyjaccard_2018-09-29_18-55-11.h5' 59 | print("model: {}".format(model_file)) 60 | 61 | if __name__ == '__main__': 62 | 63 | print("[INFO] opening image...") 64 | # ret, input_img = load_img_normalization_by_cv2(img_file) 65 | # if ret !=0: 66 | # print("Open input file failed: {}".format(img_file)) 67 | # sys.exit(-1) 68 | 69 | input_img = load_img_by_gdal(img_file) 70 | if im_type == UINT8: 71 | input_img = input_img / 255.0 72 | elif im_type == UINT10: 73 | input_img = input_img / 1024.0 74 | elif im_type == UINT16: 75 | input_img = input_img / 65535.0 76 | 77 | input_img = np.clip(input_img, 0.0, 1.0) 78 | 79 | 80 | abs_filename = os.path.split(img_file)[1] 81 | abs_filename = abs_filename.split(".")[0] 82 | print (abs_filename) 83 | 84 | """check model file""" 85 | print("model file: {}".format(model_file)) 86 | if not os.path.isfile(model_file): 87 | print("model does not exist:{}".format(model_file)) 88 | sys.exit(-2) 89 | 90 | model = load_model(model_file) 91 | 92 | if FLAG_APPROACH_PREDICT==0: 93 | print("[INFO] predict image by original approach\n") 94 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size) 95 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_', 96 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png']) 97 | print("result saved to: {}".format(output_file)) 98 | cv2.imwrite(output_file, result*128) 99 | 100 | elif FLAG_APPROACH_PREDICT==1: 101 | print("[INFO] predict image by smooth approach\n") 102 | result = predict_img_with_smooth_windowing_multiclassbands( 103 | input_img, 104 | model, 105 | window_size=window_size, 106 | subdivisions=2, 107 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 108 | pred_func=smooth_predict_for_binary_notonehot 109 | ) 110 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_4bands/mask_binary_', 111 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png']) 112 | 113 | output_file = ''.join(['../../data/test/paper/pred/mask_binary_', 114 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png']) 115 | print("result saved to: {}".format(output_file)) 116 | 117 | cv2.imwrite(output_file, result) 118 | 119 | gc.collect() 120 | 121 | 122 | -------------------------------------------------------------------------------- /predict/predict_multiclass.py: -------------------------------------------------------------------------------- 1 | #coding:utf8 2 | """ 3 | This is the main procedure for remote sensing image semantic segmentation 4 | 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | import sys 10 | import argparse 11 | # from keras.preprocessing.image import img_to_array 12 | from keras.models import load_model 13 | from sklearn.preprocessing import LabelEncoder 14 | from PIL import Image 15 | from keras.preprocessing.image import img_to_array 16 | 17 | from keras import backend as K 18 | K.set_image_dim_ordering('tf') 19 | K.clear_session() 20 | 21 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_multiclass 22 | from ulitities.base_functions import load_img_normalization_by_cv2, 
load_img_by_gdal, UINT10,UINT8,UINT16 23 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 24 | 25 | """ 26 | The following global variables should be moved into a metadata (config) file 27 | """ 28 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 29 | 30 | 31 | window_size = 256 32 | step = 128 33 | 34 | im_bands = 4 35 | im_type = UINT10 #UINT8, UINT10, UINT16 36 | 37 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'} 38 | dict_target={0: 'roads', 1: 'buildings'} 39 | target_class=len(dict_target) 40 | 41 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet; 42 | 43 | FLAG_APPROACH_PREDICT=1 # 0: original predict, 1: smooth predict 44 | 45 | position = 'tongchuan_test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test, 46 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test 47 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test 48 | # img_file = '../../data/test/sat_test/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024. 49 | img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024. 50 | # img_file = '../../data/test/shuidao.png' 51 | 52 | # img_file = '../../data/test/sat_test/cuiping_4bands1024.png' # jian11_test_nrg, sample1_nrg 53 | 54 | # model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_multiclass_final.h5']) 55 | model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_multiclass_final.h5' 56 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_multiclass.h5' 57 | 58 | if __name__ == '__main__': 59 | 60 | print("[INFO] opening image...") 61 | 62 | # ret, input_img = load_img_normalization_by_cv2(img_file) 63 | 64 | input_img = load_img_by_gdal(img_file) 65 | if im_type == UINT8: 66 | input_img = input_img / 255.0 67 | elif im_type == UINT10: 68 | input_img = input_img / 1024.0 69 | elif im_type == UINT16: 70 | input_img = input_img / 65535.0 71 | input_img = np.clip(input_img, 0.0, 1.0) 72 | 73 | 74 | abs_filename = os.path.split(img_file)[1] 75 | abs_filename = abs_filename.split(".")[0] 76 | print (abs_filename) 77 | 78 | """check model file""" 79 | print("model file: {}".format(model_file)) 80 | if not os.path.isfile(model_file): 81 | print("model does not exist:{}".format(model_file)) 82 | sys.exit(-2) 83 | 84 | model= load_model(model_file) 85 | 86 | if FLAG_APPROACH_PREDICT==0: 87 | print("[INFO] predict image by original approach\n") 88 | result = orignal_predict_onehot(input_img, im_bands, model, window_size) 89 | output_file = ''.join(['../../data/predict/original_predict_',abs_filename, '.png']) 90 | print("result saved to: {}".format(output_file)) 91 | cv2.imwrite(output_file, result*128) 92 | 93 | elif FLAG_APPROACH_PREDICT==1: 94 | print("[INFO] predict image by smooth approach\n") 95 | result = predict_img_with_smooth_windowing_multiclassbands( 96 | input_img, 97 | model, 98 | window_size=window_size, 99 | subdivisions=2, 100 | real_classes=target_class, # output channels = number of real classes (total classes minus background) 101 | pred_func=smooth_predict_for_multiclass 102 | ) 103 | 104 | for b in range(target_class): 105 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK], '/sat_4bands/mask_multiclass_', 106 | # abs_filename, '_', dict_target[b], '.png']) 107 | 108 | output_file = ''.join(['../../data/test/paper/pred/mask_multiclass_', 109 | abs_filename, '_', dict_target[b], '.png']) 110 | print("result saved to: {}".format(output_file)) 111 | cv2.imwrite(output_file, result[:,:,b]) 112 | 
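# The loop above writes one grayscale mask per class channel. If a single
# label map is wanted instead, the channels can be collapsed with an argmax.
# A minimal sketch; it assumes the smoothed output uses the same 0-255 scale
# that the binary scripts threshold at 128, and the threshold is illustrative:
def channels_to_labelmap(result, thresh=128):
    labels = np.argmax(result, axis=-1).astype(np.uint8) + 1  # 1-based class ids
    labels[np.max(result, axis=-1) < thresh] = 0  # weak pixels stay background
    return labels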
113 | 114 | -------------------------------------------------------------------------------- /samples_produce/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/samples_produce/__init__.py -------------------------------------------------------------------------------- /samples_produce/check_original_labels_froNodata.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import numpy as np 4 | import sys 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | from ulitities.base_functions import load_img_by_cv2, get_file 9 | 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | # segnet_labels = [0, 1, 2] # not tested for segnet labels yet 14 | # unet_labels = [0, 1] 15 | 16 | 17 | # """for unet""" 18 | # input_src_path = '../../data/originaldata/unet/roads/src/' 19 | # input_label_path = '../../data/originaldata/unet/roads/label/' 20 | 21 | # """for segnet""" 22 | # input_src_path = '../../data/originaldata/segnet/src/' 23 | # input_label_path = '../../data/originaldata/segnet/label/' 24 | 25 | 26 | input_label_path = '../../data/originaldata/SatRGB/label/' 27 | valid_labels = [0, 1, 2] 28 | 29 | HAS_INVALID_VALUE = False 30 | # FLAG_BINARY_LABELS = True 31 | 32 | 33 | def make_label_valid(img, false_values): 34 | height, width = img.shape 35 | 36 | tp_img = img.reshape((height*width)) 37 | for inv_lab in false_values: 38 | index = np.where(tp_img==inv_lab) 39 | tp_img[index]=0 40 | tp_img = tp_img.reshape((height, width)) 41 | 42 | return tp_img 43 | 44 | 45 | 46 | # 47 | # for i in range(height): 48 | # for j in range(width): 49 | # tmp = img[i,j] 50 | # if not tmp in true_values: 51 | # print("img[{},{}]: {}".format(i,j,tmp)) 52 | # img[i,j]=0 53 | # return img 54 | 55 | 56 | if __name__ == '__main__': 57 | files,num = get_file(input_label_path) 58 | assert (num!=0) 59 | 60 | # valid_labels = [] 61 | # if FLAG_USING_UNET: 62 | # valid_labels = unet_labels 63 | # else: 64 | # valid_labels = segnet_labels 65 | 66 | for label_file in tqdm(files): 67 | # label_file = input_label_path + os.path.split(src_file)[1] 68 | # 69 | # ret,src_img = load_img(src_file) 70 | # assert(ret==0) 71 | 72 | ret,label_img = load_img_by_cv2(label_file, grayscale=True) 73 | assert (ret == 0) 74 | 75 | local_labels = np.unique(label_img) 76 | invalid_labels=[] 77 | 78 | for tmp in local_labels: 79 | if tmp not in valid_labels: 80 | invalid_labels.append(tmp) 81 | print ("\nWarning: invalid label value found") 82 | print ("\nFile: {}".format(label_file)) 83 | HAS_INVALID_VALUE = True 84 | 85 | 86 | if HAS_INVALID_VALUE == True: 87 | new_label_img = make_label_valid(label_img, invalid_labels) 88 | new_label_file = os.path.split(label_file)[0]+'/new_'+os.path.split(label_file)[1] 89 | cv2.imwrite(new_label_file, new_label_img) 90 | HAS_INVALID_VALUE = False 91 | label_img = new_label_img 92 | 93 | plt.imshow(label_img, cmap='gray') 94 | plt.show() 95 | 96 | print("Check completed!\n") 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /samples_produce/label_visulise.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import cv2 4 | import sys, os 5 | from ulitities.base_functions import load_img,get_file 6 | 7 | input_path = '../../data/traindata/unet/buildings/label/' 8 | output_path = 
'../../data/traindata/unet/buildings/visulize/' 9 | 10 | 11 | if __name__ == '__main__': 12 | 13 | if not os.path.isdir(input_path): 14 | print("No input directory:{}".format(input_path)) 15 | sys.exit(-1) 16 | if not os.path.isdir(output_path): 17 | print("Output directory does not exist, creating: {}".format(output_path)) 18 | os.mkdir(output_path) 19 | 20 | srcfiles, tt= get_file(input_path) 21 | assert(tt!=0) 22 | 23 | for index, file in enumerate(srcfiles): 24 | ret,img = load_img(file,grayscale=True) 25 | assert(ret==0) 26 | 27 | img = img*100 # stretch small label values so they are visible in an image viewer 28 | filename = os.path.split(file)[1] 29 | outfile = os.path.join(output_path,filename) 30 | print(outfile) 31 | 32 | cv2.imwrite(outfile, img) 33 | 34 | -------------------------------------------------------------------------------- /segmentation_models/__init__.py: -------------------------------------------------------------------------------- 1 | name = "segmentation_models" 2 | 3 | from .__version__ import __version__ 4 | 5 | from .unet import Unet 6 | from .fpn import FPN 7 | from .linknet import Linknet 8 | from .pspnet import PSPNet 9 | 10 | from . import metrics 11 | from . import losses -------------------------------------------------------------------------------- /segmentation_models/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 2, 0) 2 | 3 | __version__ = '.'.join(map(str, VERSION)) -------------------------------------------------------------------------------- /segmentation_models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from classification_models import Classifiers 2 | from classification_models import resnext 3 | 4 | from . import inception_resnet_v2 as irv2 5 | from . import inception_v3 as iv3 6 | from . import mobilenet as mbn 7 | from . import mobilenetv2 as mbn2 8 | 9 | # replace backbones with others, which have corrected padding mode in first pooling 10 | Classifiers._models.update({ 11 | 'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input], 12 | 'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input], 13 | 'resnext50': [resnext.ResNeXt50, resnext.models.preprocess_input], 14 | 'resnext101': [resnext.ResNeXt101, resnext.models.preprocess_input], 15 | 'mobilenet': [mbn.MobileNet, mbn.preprocess_input], 16 | 'mobilenetv2': [mbn2.MobileNetV2, mbn2.preprocess_input], 17 | }) 18 | 19 | DEFAULT_FEATURE_LAYERS = { 20 | 21 | # List of layers to take features from backbone in the following order: 22 | # (x16, x8, x4, x2, x1) - `x4` means the feature map has 4 times lower spatial 23 | # resolution (Height x Width) than the input image. 
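# Example: get_feature_layers('resnet34', n=4) (defined below) returns the first
# four entries for that backbone, i.e. ('stage4_unit1_relu1', 'stage3_unit1_relu1',
# 'stage2_unit1_relu1', 'relu0').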
24 | 25 | # VGG 26 | 'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2', 'block1_conv2'), 27 | 'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2', 'block1_conv2'), 28 | 29 | # ResNets 30 | 'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 31 | 'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 32 | 'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 33 | 'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 34 | 'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 35 | 36 | # ResNeXt 37 | 'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 38 | 'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 39 | 40 | # Inception 41 | 'inceptionv3': (228, 86, 16, 9), 42 | 'inceptionresnetv2': (594, 260, 16, 9), 43 | 44 | # DenseNet 45 | 'densenet121': (311, 139, 51, 4), 46 | 'densenet169': (367, 139, 51, 4), 47 | 'densenet201': (479, 139, 51, 4), 48 | 49 | # SE models 50 | 'seresnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 51 | 'seresnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'), 52 | 'seresnet50': (233, 129, 59, 4), 53 | 'seresnet101': (522, 129, 59, 4), 54 | 'seresnet152': (811, 197, 59, 4), 55 | 'seresnext50': (1065, 577, 251, 4), 56 | 'seresnext101': (2442, 577, 251, 4), 57 | 'senet154': (6837, 1614, 451, 12), 58 | 59 | # Mobile Nets 60 | 'mobilenet': ('conv_pw_11_relu', 'conv_pw_5_relu', 'conv_pw_3_relu', 'conv_pw_1_relu'), 61 | 'mobilenetv2': ('block_13_expand_relu', 'block_6_expand_relu', 'block_3_expand_relu', 'block_1_expand_relu'), 62 | 63 | } 64 | 65 | 66 | def get_names(): 67 | return list(DEFAULT_FEATURE_LAYERS.keys()) 68 | 69 | 70 | def get_feature_layers(name, n=5): 71 | return DEFAULT_FEATURE_LAYERS[name][:n] 72 | 73 | 74 | def get_backbone(name, *args, **kwargs): 75 | return Classifiers.get_classifier(name)(*args, **kwargs) 76 | 77 | 78 | def get_preprocessing(name): 79 | return Classifiers.get_preprocessing(name) 80 | -------------------------------------------------------------------------------- /segmentation_models/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import Conv2DBlock 2 | from .layers import ResizeImage -------------------------------------------------------------------------------- /segmentation_models/common/blocks.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D 2 | from keras.layers import Activation 3 | from keras.layers import BatchNormalization 4 | 5 | 6 | def Conv2DBlock(n_filters, kernel_size, 7 | activation='relu', 8 | use_batchnorm=True, 9 | name='conv_block', 10 | **kwargs): 11 | """Extension of Conv2D layer with batchnorm""" 12 | def layer(input_tensor): 13 | 14 | x = Conv2D(n_filters, kernel_size, use_bias=not(use_batchnorm), 15 | name=name+'_conv', **kwargs)(input_tensor) 16 | if use_batchnorm: 17 | x = BatchNormalization(name=name+'_bn',)(x) 18 | x = Activation(activation, name=name+'_'+activation)(x) 19 | 20 | return x 21 | return layer -------------------------------------------------------------------------------- /segmentation_models/common/functions.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def transpose_shape(shape, target_format, spatial_axes): 6 | """Converts a tuple or a list to the correct `data_format`. 7 | It does so by switching the positions of its elements. 8 | # Arguments 9 | shape: Tuple or list, often representing shape, 10 | corresponding to `'channels_last'`. 11 | target_format: A string, either `'channels_first'` or `'channels_last'`. 12 | spatial_axes: A tuple of integers. 13 | Correspond to the indexes of the spatial axes. 14 | For example, if you pass a shape 15 | representing (batch_size, timesteps, rows, cols, channels), 16 | then `spatial_axes=(2, 3)`. 17 | # Returns 18 | A tuple or list, with the elements permuted according 19 | to `target_format`. 20 | # Example 21 | # Raises 22 | ValueError: if `value` or the global `data_format` invalid. 23 | """ 24 | if target_format == 'channels_first': 25 | new_values = shape[:spatial_axes[0]] 26 | new_values += (shape[-1],) 27 | new_values += tuple(shape[x] for x in spatial_axes) 28 | 29 | if isinstance(shape, list): 30 | return list(new_values) 31 | return new_values 32 | elif target_format == 'channels_last': 33 | return shape 34 | else: 35 | raise ValueError('The `data_format` argument must be one of ' 36 | '"channels_first", "channels_last". Received: ' + 37 | str(target_format)) 38 | 39 | 40 | def permute_dimensions(x, pattern): 41 | """Permutes axes in a tensor. 42 | # Arguments 43 | x: Tensor or variable. 44 | pattern: A tuple of 45 | dimension indices, e.g. `(0, 2, 1)`. 46 | # Returns 47 | A tensor. 48 | """ 49 | return tf.transpose(x, perm=pattern) 50 | 51 | 52 | def int_shape(x): 53 | """Returns the shape of tensor or variable as a tuple of int or None entries. 54 | # Arguments 55 | x: Tensor or variable. 56 | # Returns 57 | A tuple of integers (or None entries). 58 | """ 59 | if hasattr(x, '_keras_shape'): 60 | return x._keras_shape 61 | try: 62 | return tuple(x.get_shape().as_list()) 63 | except ValueError: 64 | return None 65 | 66 | 67 | def resize_images(x, 68 | height_factor, 69 | width_factor, 70 | data_format, 71 | interpolation='nearest'): 72 | """Resizes the images contained in a 4D tensor. 73 | # Arguments 74 | x: Tensor or variable to resize. 75 | height_factor: Positive integer. 76 | width_factor: Positive integer. 77 | data_format: string, `"channels_last"` or `"channels_first"`. 78 | interpolation: A string, one of `nearest` or `bilinear`. 79 | # Returns 80 | A tensor. 81 | # Raises 82 | ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`. 
83 | """ 84 | if data_format == 'channels_first': 85 | rows, cols = 2, 3 86 | else: 87 | rows, cols = 1, 2 88 | 89 | original_shape = int_shape(x) 90 | new_shape = tf.shape(x)[rows:cols + 1] 91 | new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32')) 92 | 93 | if data_format == 'channels_first': 94 | x = permute_dimensions(x, [0, 2, 3, 1]) 95 | if interpolation == 'nearest': 96 | x = tf.image.resize_nearest_neighbor(x, new_shape) 97 | elif interpolation == 'bilinear': 98 | x = tf.image.resize_bilinear(x, new_shape, align_corners=True) 99 | else: 100 | raise ValueError('interpolation should be one ' 101 | 'of "nearest" or "bilinear".') 102 | if data_format == 'channels_first': 103 | x = permute_dimensions(x, [0, 3, 1, 2]) 104 | 105 | if original_shape[rows] is None: 106 | new_height = None 107 | else: 108 | new_height = original_shape[rows] * height_factor 109 | 110 | if original_shape[cols] is None: 111 | new_width = None 112 | else: 113 | new_width = original_shape[cols] * width_factor 114 | 115 | output_shape = (None, new_height, new_width, None) 116 | x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2))) 117 | return x -------------------------------------------------------------------------------- /segmentation_models/common/layers.py: -------------------------------------------------------------------------------- 1 | from keras.engine import Layer 2 | from keras.engine import InputSpec 3 | from keras.utils import conv_utils 4 | from keras.legacy import interfaces 5 | from keras.utils.generic_utils import get_custom_objects 6 | 7 | from .functions import resize_images 8 | 9 | 10 | class ResizeImage(Layer): 11 | """ResizeImage layer for 2D inputs. 12 | Repeats the rows and columns of the data 13 | by factor[0] and factor[1] respectively. 14 | # Arguments 15 | factor: int, or tuple of 2 integers. 16 | The upsampling factors for rows and columns. 17 | data_format: A string, 18 | one of `"channels_last"` or `"channels_first"`. 19 | The ordering of the dimensions in the inputs. 20 | `"channels_last"` corresponds to inputs with shape 21 | `(batch, height, width, channels)` while `"channels_first"` 22 | corresponds to inputs with shape 23 | `(batch, channels, height, width)`. 24 | It defaults to the `image_data_format` value found in your 25 | Keras config file at `~/.keras/keras.json`. 26 | If you never set it, then it will be "channels_last". 27 | interpolation: A string, one of `nearest` or `bilinear`. 28 | Note that CNTK does not support yet the `bilinear` upscaling 29 | and that with Theano, only `factor=(2, 2)` is possible. 
30 | # Input shape 31 | 4D tensor with shape: 32 | - If `data_format` is `"channels_last"`: 33 | `(batch, rows, cols, channels)` 34 | - If `data_format` is `"channels_first"`: 35 | `(batch, channels, rows, cols)` 36 | # Output shape 37 | 4D tensor with shape: 38 | - If `data_format` is `"channels_last"`: 39 | `(batch, upsampled_rows, upsampled_cols, channels)` 40 | - If `data_format` is `"channels_first"`: 41 | `(batch, channels, upsampled_rows, upsampled_cols)` 42 | """ 43 | 44 | @interfaces.legacy_upsampling2d_support 45 | def __init__(self, factor=(2, 2), data_format='channels_last', interpolation='nearest', **kwargs): 46 | super(ResizeImage, self).__init__(**kwargs) 47 | self.data_format = data_format 48 | self.factor = conv_utils.normalize_tuple(factor, 2, 'factor') 49 | self.input_spec = InputSpec(ndim=4) 50 | if interpolation not in ['nearest', 'bilinear']: 51 | raise ValueError('interpolation should be one ' 52 | 'of "nearest" or "bilinear".') 53 | self.interpolation = interpolation 54 | 55 | def compute_output_shape(self, input_shape): 56 | if self.data_format == 'channels_first': 57 | height = self.factor[0] * input_shape[2] if input_shape[2] is not None else None 58 | width = self.factor[1] * input_shape[3] if input_shape[3] is not None else None 59 | return (input_shape[0], 60 | input_shape[1], 61 | height, 62 | width) 63 | elif self.data_format == 'channels_last': 64 | height = self.factor[0] * input_shape[1] if input_shape[1] is not None else None 65 | width = self.factor[1] * input_shape[2] if input_shape[2] is not None else None 66 | return (input_shape[0], 67 | height, 68 | width, 69 | input_shape[3]) 70 | 71 | def call(self, inputs): 72 | return resize_images(inputs, self.factor[0], self.factor[1], 73 | self.data_format, self.interpolation) 74 | 75 | def get_config(self): 76 | config = {'factor': self.factor, 77 | 'data_format': self.data_format} 78 | base_config = super(ResizeImage, self).get_config() 79 | return dict(list(base_config.items()) + list(config.items())) 80 | 81 | 82 | get_custom_objects().update({'ResizeImage': ResizeImage}) 83 | -------------------------------------------------------------------------------- /segmentation_models/fpn/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FPN 2 | 3 | -------------------------------------------------------------------------------- /segmentation_models/fpn/blocks.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Add 2 | 3 | from ..common import Conv2DBlock 4 | from ..common import ResizeImage 5 | from ..utils import to_tuple 6 | 7 | 8 | def pyramid_block(pyramid_filters=256, segmentation_filters=128, upsample_rate=2, 9 | use_batchnorm=False, stage=0): 10 | """ 11 | Pyramid block according to: 12 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 13 | 14 | This block generates the `M` and `P` blocks. 
15 | 16 | Args: 17 | pyramid_filters: integer, filters in `M` block of top-down FPN branch 18 | segmentation_filters: integer, number of filters in segmentation head, 19 | basically filters in convolution layers between `M` and `P` blocks 20 | upsample_rate: integer, upsample rate for `M` block of top-down FPN branch 21 | use_batchnorm: bool, include batchnorm in convolution blocks 22 | 23 | Returns: 24 | Pyramid block function (as Keras layers functional API) 25 | """ 26 | def layer(c, m=None): 27 | 28 | x = Conv2DBlock(pyramid_filters, (1, 1), 29 | padding='same', 30 | use_batchnorm=use_batchnorm, 31 | name='pyramid_stage_{}'.format(stage))(c) 32 | 33 | if m is not None: 34 | up = ResizeImage(to_tuple(upsample_rate))(m) 35 | x = Add()([x, up]) 36 | 37 | # segmentation head 38 | p = Conv2DBlock(segmentation_filters, (3, 3), 39 | padding='same', 40 | use_batchnorm=use_batchnorm, 41 | name='segm1_stage_{}'.format(stage))(x) 42 | 43 | p = Conv2DBlock(segmentation_filters, (3, 3), 44 | padding='same', 45 | use_batchnorm=use_batchnorm, 46 | name='segm2_stage_{}'.format(stage))(p) 47 | m = x 48 | 49 | return m, p 50 | return layer 51 | -------------------------------------------------------------------------------- /segmentation_models/fpn/builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.layers import Conv2D 3 | from keras.layers import Concatenate 4 | from keras.layers import Activation 5 | from keras.layers import SpatialDropout2D 6 | from keras.models import Model 7 | 8 | from .blocks import pyramid_block 9 | from ..common import ResizeImage 10 | from ..common import Conv2DBlock 11 | from ..utils import extract_outputs, to_tuple 12 | 13 | 14 | def build_fpn(backbone, 15 | fpn_layers, 16 | classes=21, 17 | activation='softmax', 18 | upsample_rates=(2,2,2), 19 | last_upsample=4, 20 | pyramid_filters=256, 21 | segmentation_filters=128, 22 | use_batchnorm=False, 23 | dropout=None, 24 | interpolation='bilinear'): 25 | """ 26 | Implementation of FPN head for segmentation models according to: 27 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 28 | 29 | Args: 30 | backbone: Keras `Model`, some classification model without top 31 | fpn_layers: list of layer names or indexes, used for pyramid building 32 | classes: int, number of output feature maps 33 | activation: activation in last layer, e.g. 
'sigmoid' or 'softmax' 34 | upsample_rates: tuple of integers, scaling rates between pyramid blocks 35 | pyramid_filters: int, number of filters in `M` blocks of top-down FPN branch 36 | segmentation_filters: int, number of filters in `P` blocks of FPN 37 | last_upsample: rate for upsampling concatenated pyramid predictions to 38 | match spatial resolution of input data 39 | interpolation: interpolation type for upsampling layers, 'nearest' or 'bilinear' 40 | dropout: float [0, 1), dropout rate 41 | use_batchnorm: bool, include batch normalization to FPN between `conv` 42 | and `relu` layers 43 | 44 | Returns: 45 | model: Keras `Model` 46 | """ 47 | 48 | if len(upsample_rates) != len(fpn_layers): 49 | raise ValueError('Number of intermediate feature maps and upsample steps should match') 50 | 51 | # extract model layer outputs 52 | outputs = extract_outputs(backbone, fpn_layers, include_top=True) 53 | 54 | # add upsample rate `1` for first block 55 | upsample_rates = [1] + list(upsample_rates) 56 | 57 | # top-down path, build pyramid 58 | m = None 59 | pyramid = [] 60 | for i, c in enumerate(outputs): 61 | m, p = pyramid_block(pyramid_filters=pyramid_filters, 62 | segmentation_filters=segmentation_filters, 63 | upsample_rate=upsample_rates[i], 64 | use_batchnorm=use_batchnorm, 65 | stage=i)(c, m) 66 | pyramid.append(p) 67 | 68 | 69 | # upsample and concatenate all pyramid layers 70 | upsampled_pyramid = [] 71 | 72 | for i, p in enumerate(pyramid[::-1]): 73 | if upsample_rates[i] > 1: 74 | upsample_rate = to_tuple(np.prod(upsample_rates[:i+1])) 75 | p = ResizeImage(upsample_rate, interpolation=interpolation)(p) 76 | upsampled_pyramid.append(p) 77 | 78 | x = Concatenate()(upsampled_pyramid) 79 | 80 | # final convolution 81 | n_filters = segmentation_filters * len(pyramid) 82 | x = Conv2DBlock(n_filters, (3, 3), use_batchnorm=use_batchnorm, padding='same')(x) 83 | if dropout is not None: 84 | x = SpatialDropout2D(dropout)(x) 85 | 86 | x = Conv2D(classes, (3, 3), padding='same')(x) 87 | 88 | # upsampling to original spatial resolution 89 | x = ResizeImage(to_tuple(last_upsample), interpolation=interpolation)(x) 90 | 91 | # activation 92 | x = Activation(activation)(x) 93 | 94 | model = Model(backbone.input, x) 95 | return model 96 | -------------------------------------------------------------------------------- /segmentation_models/fpn/model.py: -------------------------------------------------------------------------------- 1 | from .builder import build_fpn 2 | from ..backbones import get_backbone, get_feature_layers 3 | from ..utils import freeze_model 4 | from ..utils import legacy_support 5 | 6 | old_args_map = { 7 | 'freeze_encoder': 'encoder_freeze', 8 | 'fpn_layers': 'encoder_features', 9 | 'use_batchnorm': 'pyramid_use_batchnorm', 10 | 'dropout': 'pyramid_dropout', 11 | 'interpolation': 'final_interpolation', 12 | 'upsample_rates': None, # removed 13 | 'last_upsample': None, # removed 14 | } 15 | 16 | 17 | @legacy_support(old_args_map) 18 | def FPN(backbone_name='vgg16', 19 | input_shape=(None, None, 3), 20 | input_tensor=None, 21 | classes=21, 22 | activation='softmax', 23 | encoder_weights='imagenet', 24 | encoder_freeze=False, 25 | encoder_features='default', 26 | pyramid_block_filters=256, 27 | pyramid_use_batchnorm=True, 28 | pyramid_dropout=None, 29 | final_interpolation='bilinear', 30 | **kwargs): 31 | """FPN_ is a fully convolutional neural network for image semantic segmentation 32 | 33 | Args: 34 | backbone_name: name of classification model (without last dense layers) used as feature 35 | extractor to build 
segmentation model. 36 | input_shape: shape of input data/image ``(H, W, C)``, in general 37 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be 38 | able to process images af any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 39 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model 40 | (works only if ``encoder_weights`` is ``None``). 41 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 42 | activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``). 43 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 44 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 45 | encoder_features: a list of layer numbers or names starting from top of the model. 46 | Each of these layers will be used to build features pyramid. If ``default`` is used 47 | layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``. 48 | pyramid_block_filters: a number of filters in Feature Pyramid Block of FPN_. 49 | pyramid_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 50 | is used. 51 | pyramid_dropout: spatial dropout rate for feature pyramid in range (0, 1). 52 | final_interpolation: interpolation type for upsampling layers, on of ``nearest``, ``bilinear``. 53 | 54 | Returns: 55 | ``keras.models.Model``: **FPN** 56 | 57 | .. _FPN: 58 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf 59 | 60 | """ 61 | 62 | backbone = get_backbone(backbone_name, 63 | input_shape=input_shape, 64 | input_tensor=input_tensor, 65 | weights=encoder_weights, 66 | include_top=False) 67 | 68 | if encoder_features == 'default': 69 | encoder_features = get_feature_layers(backbone_name, n=3) 70 | 71 | upsample_rates = [2] * len(encoder_features) 72 | last_upsample = 2 ** (5 - len(encoder_features)) 73 | 74 | model = build_fpn(backbone, encoder_features, 75 | classes=classes, 76 | pyramid_filters=pyramid_block_filters, 77 | segmentation_filters=pyramid_block_filters // 2, 78 | upsample_rates=upsample_rates, 79 | use_batchnorm=pyramid_use_batchnorm, 80 | dropout=pyramid_dropout, 81 | last_upsample=last_upsample, 82 | interpolation=final_interpolation, 83 | activation=activation) 84 | 85 | if encoder_freeze: 86 | freeze_model(backbone) 87 | 88 | model.name = 'fpn-{}'.format(backbone.name) 89 | 90 | return model 91 | -------------------------------------------------------------------------------- /segmentation_models/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Linknet 2 | -------------------------------------------------------------------------------- /segmentation_models/linknet/blocks.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.layers import Conv2DTranspose as Transpose 3 | from keras.layers import UpSampling2D 4 | from keras.layers import Conv2D 5 | from keras.layers import BatchNormalization 6 | from keras.layers import Activation 7 | from keras.layers import Add 8 | 9 | 10 | def handle_block_names(stage): 11 | conv_name = 'decoder_stage{}_conv'.format(stage) 12 | bn_name = 'decoder_stage{}_bn'.format(stage) 13 | relu_name = 'decoder_stage{}_relu'.format(stage) 14 | up_name = 
'decoder_stage{}_upsample'.format(stage) 15 | return conv_name, bn_name, relu_name, up_name 16 | 17 | 18 | def ConvRelu(filters, 19 | kernel_size, 20 | use_batchnorm=False, 21 | conv_name='conv', 22 | bn_name='bn', 23 | relu_name='relu'): 24 | 25 | def layer(x): 26 | 27 | x = Conv2D(filters, 28 | kernel_size, 29 | padding="same", 30 | name=conv_name, 31 | use_bias=not(use_batchnorm))(x) 32 | 33 | if use_batchnorm: 34 | x = BatchNormalization(name=bn_name)(x) 35 | 36 | x = Activation('relu', name=relu_name)(x) 37 | 38 | return x 39 | return layer 40 | 41 | 42 | def Conv2DUpsample(filters, 43 | upsample_rate, 44 | kernel_size=(3,3), 45 | up_name='up', 46 | conv_name='conv', 47 | **kwargs): 48 | 49 | def layer(input_tensor): 50 | x = UpSampling2D(upsample_rate, name=up_name)(input_tensor) 51 | x = Conv2D(filters, 52 | kernel_size, 53 | padding='same', 54 | name=conv_name, 55 | **kwargs)(x) 56 | return x 57 | return layer 58 | 59 | 60 | def Conv2DTranspose(filters, 61 | upsample_rate, 62 | kernel_size=(4,4), 63 | up_name='up', 64 | **kwargs): 65 | 66 | if not tuple(upsample_rate) == (2,2): 67 | raise NotImplementedError( 68 | 'Conv2DTranspose support only upsample_rate=(2, 2), got {}'.format(upsample_rate)) 69 | 70 | def layer(input_tensor): 71 | x = Transpose(filters, 72 | kernel_size=kernel_size, 73 | strides=upsample_rate, 74 | padding='same', 75 | name=up_name)(input_tensor) 76 | return x 77 | return layer 78 | 79 | 80 | def UpsampleBlock(filters, 81 | upsample_rate, 82 | kernel_size, 83 | use_batchnorm=False, 84 | upsample_layer='upsampling', 85 | conv_name='conv', 86 | bn_name='bn', 87 | relu_name='relu', 88 | up_name='up', 89 | **kwargs): 90 | 91 | if upsample_layer == 'upsampling': 92 | UpBlock = Conv2DUpsample 93 | 94 | elif upsample_layer == 'transpose': 95 | UpBlock = Conv2DTranspose 96 | 97 | else: 98 | raise ValueError('Not supported up layer type {}'.format(upsample_layer)) 99 | 100 | def layer(input_tensor): 101 | 102 | x = UpBlock(filters, 103 | upsample_rate=upsample_rate, 104 | kernel_size=kernel_size, 105 | use_bias=not(use_batchnorm), 106 | conv_name=conv_name, 107 | up_name=up_name, 108 | **kwargs)(input_tensor) 109 | 110 | if use_batchnorm: 111 | x = BatchNormalization(name=bn_name)(x) 112 | 113 | x = Activation('relu', name=relu_name)(x) 114 | 115 | return x 116 | return layer 117 | 118 | 119 | def DecoderBlock(stage, 120 | filters=None, 121 | kernel_size=(3,3), 122 | upsample_rate=(2,2), 123 | use_batchnorm=False, 124 | skip=None, 125 | upsample_layer='upsampling'): 126 | 127 | def layer(input_tensor): 128 | 129 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage) 130 | input_filters = K.int_shape(input_tensor)[-1] 131 | 132 | if skip is not None: 133 | output_filters = K.int_shape(skip)[-1] 134 | else: 135 | output_filters = filters 136 | 137 | x = ConvRelu(input_filters // 4, 138 | kernel_size=(1, 1), 139 | use_batchnorm=use_batchnorm, 140 | conv_name=conv_name + '1', 141 | bn_name=bn_name + '1', 142 | relu_name=relu_name + '1')(input_tensor) 143 | 144 | x = UpsampleBlock(filters=input_filters // 4, 145 | kernel_size=kernel_size, 146 | upsample_layer=upsample_layer, 147 | upsample_rate=upsample_rate, 148 | use_batchnorm=use_batchnorm, 149 | conv_name=conv_name + '2', 150 | bn_name=bn_name + '2', 151 | up_name=up_name + '2', 152 | relu_name=relu_name + '2')(x) 153 | 154 | x = ConvRelu(output_filters, 155 | kernel_size=(1, 1), 156 | use_batchnorm=use_batchnorm, 157 | conv_name=conv_name + '3', 158 | bn_name=bn_name + '3', 159 | relu_name=relu_name + 
'3')(x) 160 | 161 | if skip is not None: 162 | x = Add()([x, skip]) 163 | 164 | return x 165 | return layer 166 | -------------------------------------------------------------------------------- /segmentation_models/linknet/builder.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D 2 | from keras.layers import Activation 3 | from keras.models import Model 4 | 5 | from .blocks import DecoderBlock 6 | from ..utils import get_layer_number, to_tuple 7 | 8 | 9 | def build_linknet(backbone, 10 | classes, 11 | skip_connection_layers, 12 | decoder_filters=(None, None, None, None, 16), 13 | upsample_rates=(2, 2, 2, 2, 2), 14 | n_upsample_blocks=5, 15 | upsample_kernel_size=(3, 3), 16 | upsample_layer='upsampling', 17 | activation='sigmoid', 18 | use_batchnorm=True): 19 | 20 | input = backbone.input 21 | x = backbone.output 22 | 23 | # convert layer names to indices 24 | skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l 25 | for l in skip_connection_layers]) 26 | 27 | for i in range(n_upsample_blocks): 28 | 29 | # check if there is a skip connection 30 | skip_connection = None 31 | if i < len(skip_connection_idx): 32 | skip_connection = backbone.layers[skip_connection_idx[i]].output 33 | 34 | upsample_rate = to_tuple(upsample_rates[i]) 35 | 36 | x = DecoderBlock(stage=i, 37 | filters=decoder_filters[i], 38 | kernel_size=upsample_kernel_size, 39 | upsample_rate=upsample_rate, 40 | use_batchnorm=use_batchnorm, 41 | upsample_layer=upsample_layer, 42 | skip=skip_connection)(x) 43 | 44 | x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(x) 45 | x = Activation(activation, name=activation)(x) 46 | 47 | model = Model(input, x) 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /segmentation_models/linknet/model.py: -------------------------------------------------------------------------------- 1 | from .builder import build_linknet 2 | from ..utils import freeze_model 3 | from ..utils import legacy_support 4 | from ..backbones import get_backbone, get_feature_layers 5 | 6 | old_args_map = { 7 | 'freeze_encoder': 'encoder_freeze', 8 | 'skip_connections': 'encoder_features', 9 | 'upsample_layer': 'decoder_block_type', 10 | 'n_upsample_blocks': None, # removed 11 | 'input_tensor': None, # removed 12 | 'upsample_kernel_size': None, # removed 13 | } 14 | 15 | 16 | @legacy_support(old_args_map) 17 | def Linknet(backbone_name='vgg16', 18 | input_shape=(None, None, 3), 19 | classes=1, 20 | activation='sigmoid', 21 | encoder_weights='imagenet', 22 | encoder_freeze=False, 23 | encoder_features='default', 24 | decoder_filters=(None, None, None, None, 16), 25 | decoder_use_batchnorm=True, 26 | decoder_block_type='upsampling', 27 | **kwargs): 28 | """Linknet_ is a fully convolution neural network for fast image semantic segmentation 29 | 30 | Note: 31 | This implementation by default has 4 skip connections (original - 3). 32 | 33 | Args: 34 | backbone_name: name of classification model (without last dense layers) used as feature 35 | extractor to build segmentation model. 36 | input_shape: shape of input data/image ``(H, W, C)``, in general 37 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be 38 | able to process images af any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 39 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 
40 | activation: name of one of ``keras.activations`` for last model layer 41 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 42 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 43 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 44 | encoder_features: a list of layer numbers or names starting from top of the model. 45 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used 46 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``. 47 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks, 48 | for block with skip connection a number of filters is equal to number of filters in 49 | corresponding encoder block (estimates automatically and can be passed as ``None`` value). 50 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 51 | is used. 52 | decoder_block_type: one of 53 | - `upsampling`: use ``Upsampling2D`` keras layer 54 | - `transpose`: use ``Transpose2D`` keras layer 55 | 56 | Returns: 57 | ``keras.models.Model``: **Linknet** 58 | 59 | .. _Linknet: 60 | https://arxiv.org/pdf/1707.03718.pdf 61 | """ 62 | 63 | backbone = get_backbone(backbone_name, 64 | input_shape=input_shape, 65 | input_tensor=None, 66 | weights=encoder_weights, 67 | include_top=False) 68 | 69 | if encoder_features == 'default': 70 | encoder_features = get_feature_layers(backbone_name, n=4) 71 | 72 | model = build_linknet(backbone, 73 | classes, 74 | encoder_features, 75 | decoder_filters=decoder_filters, 76 | upsample_layer=decoder_block_type, 77 | activation=activation, 78 | n_upsample_blocks=len(decoder_filters), 79 | upsample_rates=(2, 2, 2, 2, 2), 80 | upsample_kernel_size=(3, 3), 81 | use_batchnorm=decoder_use_batchnorm) 82 | 83 | # lock encoder weights for fine-tuning 84 | if encoder_freeze: 85 | freeze_model(backbone) 86 | 87 | model.name = 'link-{}'.format(backbone_name) 88 | 89 | return model 90 | -------------------------------------------------------------------------------- /segmentation_models/losses.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.losses import binary_crossentropy 3 | from keras.losses import categorical_crossentropy 4 | from keras.utils.generic_utils import get_custom_objects 5 | 6 | from .metrics import jaccard_score, f_score 7 | 8 | SMOOTH = 1e-12 9 | 10 | __all__ = [ 11 | 'jaccard_loss', 'bce_jaccard_loss', 'cce_jaccard_loss', 12 | 'dice_loss', 'bce_dice_loss', 'cce_dice_loss', 13 | ] 14 | 15 | 16 | # ============================== Jaccard Losses ============================== 17 | 18 | def jaccard_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True): 19 | r"""Jaccard loss function for imbalanced datasets: 20 | 21 | .. math:: L(A, B) = 1 - \frac{A \cap B}{A \cup B} 22 | 23 | Args: 24 | gt: ground truth 4D keras tensor (B, H, W, C) 25 | pr: prediction 4D keras tensor (B, H, W, C) 26 | class_weights: 1. 
or list of class weights, len(weights) = C 27 | smooth: value to avoid division by zero 28 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 29 | else over whole batch 30 | 31 | Returns: 32 | Jaccard loss in range [0, 1] 33 | 34 | """ 35 | return 1 - jaccard_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image) 36 | 37 | 38 | def bce_jaccard_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True): 39 | bce = K.mean(binary_crossentropy(gt, pr)) 40 | loss = bce_weight * bce + jaccard_loss(gt, pr, smooth=smooth, per_image=per_image) 41 | return loss 42 | 43 | 44 | def cce_jaccard_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True): 45 | cce = categorical_crossentropy(gt, pr) * class_weights 46 | cce = K.mean(cce) 47 | return cce_weight * cce + jaccard_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image) 48 | 49 | 50 | # Update custom objects 51 | get_custom_objects().update({ 52 | 'jaccard_loss': jaccard_loss, 53 | 'bce_jaccard_loss': bce_jaccard_loss, 54 | 'cce_jaccard_loss': cce_jaccard_loss, 55 | }) 56 | 57 | 58 | # ============================== Dice Losses ================================ 59 | 60 | def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True): 61 | r"""Dice loss function for imbalanced datasets: 62 | 63 | .. math:: L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall} 64 | {\beta^2 \cdot precision + recall} 65 | 66 | Args: 67 | gt: ground truth 4D keras tensor (B, H, W, C) 68 | pr: prediction 4D keras tensor (B, H, W, C) 69 | class_weights: 1. or list of class weights, len(weights) = C 70 | smooth: value to avoid division by zero 71 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 72 | else over whole batch 73 | 74 | Returns: 75 | Dice loss in range [0, 1] 76 | 77 | """ 78 | return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=1.) 
79 | 80 | 81 | def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True): 82 | bce = K.mean(binary_crossentropy(gt, pr)) 83 | loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image) 84 | return loss 85 | 86 | 87 | def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True): 88 | cce = categorical_crossentropy(gt, pr) * class_weights 89 | cce = K.mean(cce) 90 | return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image) 91 | 92 | 93 | # Update custom objects 94 | get_custom_objects().update({ 95 | 'dice_loss': dice_loss, 96 | 'bce_dice_loss': bce_dice_loss, 97 | 'cce_dice_loss': cce_dice_loss, 98 | }) 99 | -------------------------------------------------------------------------------- /segmentation_models/metrics.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.utils.generic_utils import get_custom_objects 3 | 4 | __all__ = [ 5 | 'iou_score', 'jaccard_score', 'f1_score', 'f2_score', 'dice_score', 6 | 'get_f_score', 'get_iou_score', 'get_jaccard_score', 7 | ] 8 | 9 | SMOOTH = 1e-12 10 | 11 | 12 | # ============================ Jaccard/IoU score ============================ 13 | 14 | 15 | def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True): 16 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient 17 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the 18 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets, 19 | and is defined as the size of the intersection divided by the size of the union of the sample sets: 20 | 21 | .. math:: J(A, B) = \frac{A \cap B}{A \cup B} 22 | 23 | Args: 24 | gt: ground truth 4D keras tensor (B, H, W, C) 25 | pr: prediction 4D keras tensor (B, H, W, C) 26 | class_weights: 1. or list of class weights, len(weights) = C 27 | smooth: value to avoid division by zero 28 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 29 | else over whole batch 30 | 31 | Returns: 32 | IoU/Jaccard score in range [0, 1] 33 | 34 | .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index 35 | 36 | """ 37 | if per_image: 38 | axes = [1, 2] 39 | else: 40 | axes = [0, 1, 2] 41 | 42 | intersection = K.sum(gt * pr, axis=axes) 43 | union = K.sum(gt + pr, axis=axes) - intersection 44 | iou = (intersection + smooth) / (union + smooth) 45 | 46 | # mean per image 47 | if per_image: 48 | iou = K.mean(iou, axis=0) 49 | 50 | # weighted mean per class 51 | iou = K.mean(iou * class_weights) 52 | 53 | return iou 54 | 55 | 56 | def get_iou_score(class_weights=1., smooth=SMOOTH, per_image=True): 57 | """Change default parameters of IoU/Jaccard score 58 | 59 | Args: 60 | class_weights: 1. 
or list of class weights, len(weights) = C 61 | smooth: value to avoid division by zero 62 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 63 | else over whole batch 64 | 65 | Returns: 66 | ``callable``: IoU/Jaccard score 67 | """ 68 | def score(gt, pr): 69 | return iou_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image) 70 | 71 | return score 72 | 73 | 74 | jaccard_score = iou_score 75 | get_jaccard_score = get_iou_score 76 | 77 | # Update custom objects 78 | get_custom_objects().update({ 79 | 'iou_score': iou_score, 80 | 'jaccard_score': jaccard_score, 81 | }) 82 | 83 | 84 | # ============================== F/Dice - score ============================== 85 | 86 | def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True): 87 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall, 88 | where an F-score reaches its best value at 1 and worst score at 0. 89 | The relative contribution of ``precision`` and ``recall`` to the F1-score are equal. 90 | The formula for the F score is: 91 | 92 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall} 93 | {\beta^2 \cdot precision + recall} 94 | 95 | The formula in terms of *Type I* and *Type II* errors: 96 | 97 | .. math:: F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP} 98 | 99 | 100 | where: 101 | TP - true positive; 102 | FP - false positive; 103 | FN - false negative; 104 | 105 | Args: 106 | gt: ground truth 4D keras tensor (B, H, W, C) 107 | pr: prediction 4D keras tensor (B, H, W, C) 108 | class_weights: 1. or list of class weights, len(weights) = C 109 | beta: f-score coefficient 110 | smooth: value to avoid division by zero 111 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 112 | else over whole batch 113 | 114 | Returns: 115 | F-score in range [0, 1] 116 | 117 | """ 118 | if per_image: 119 | axes = [1, 2] 120 | else: 121 | axes = [0, 1, 2] 122 | 123 | tp = K.sum(gt * pr, axis=axes) 124 | fp = K.sum(pr, axis=axes) - tp 125 | fn = K.sum(gt, axis=axes) - tp 126 | 127 | score = ((1 + beta ** 2) * tp + smooth) \ 128 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth) 129 | 130 | # mean per image 131 | if per_image: 132 | score = K.mean(score, axis=0) 133 | 134 | # weighted mean per class 135 | score = K.mean(score * class_weights) 136 | 137 | return score 138 | 139 | 140 | def get_f_score(class_weights=1, beta=1, smooth=SMOOTH, per_image=True): 141 | """Change default parameters of F-score score 142 | 143 | Args: 144 | class_weights: 1. 
or list of class weights, len(weights) = C 145 | smooth: value to avoid division by zero 146 | beta: f-score coefficient 147 | per_image: if ``True``, metric is calculated as mean over images in batch (B), 148 | else over whole batch 149 | 150 | Returns: 151 | ``callable``: F-score 152 | """ 153 | def score(gt, pr): 154 | return f_score(gt, pr, class_weights=class_weights, beta=beta, smooth=smooth, per_image=per_image) 155 | 156 | return score 157 | 158 | 159 | f1_score = get_f_score(beta=1) 160 | f2_score = get_f_score(beta=2) 161 | dice_score = f1_score 162 | 163 | # Update custom objects 164 | get_custom_objects().update({ 165 | 'f1_score': f1_score, 166 | 'f2_score': f2_score, 167 | 'dice_score': dice_score, 168 | }) 169 | -------------------------------------------------------------------------------- /segmentation_models/pspnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PSPNet -------------------------------------------------------------------------------- /segmentation_models/pspnet/blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras.layers import MaxPool2D 3 | from keras.layers import AveragePooling2D 4 | from keras.layers import Concatenate 5 | from keras.layers import Permute 6 | from keras.layers import Reshape 7 | from keras.backend import int_shape 8 | 9 | from ..common import Conv2DBlock 10 | from ..common import ResizeImage 11 | 12 | 13 | def InterpBlock(level, feature_map_shape, 14 | conv_filters=512, 15 | conv_kernel_size=(1,1), 16 | conv_padding='same', 17 | pooling_type='avg', 18 | pool_padding='same', 19 | use_batchnorm=True, 20 | activation='relu', 21 | interpolation='bilinear'): 22 | 23 | if pooling_type == 'max': 24 | Pool2D = MaxPool2D 25 | elif pooling_type == 'avg': 26 | Pool2D = AveragePooling2D 27 | else: 28 | raise ValueError('Unsupported pooling type - `{}`.'.format(pooling_type) + 29 | 'Use `avg` or `max`.') 30 | 31 | def layer(input_tensor): 32 | # Compute the kernel and stride sizes according to how large the final feature map will be 33 | # When the kernel factor and strides are equal, then we can compute the final feature map factor 34 | # by simply dividing the current factor by the kernel or stride factor 35 | # The final feature map sizes are 1x1, 2x2, 3x3, and 6x6. 
We round to the closest integer 36 | pool_size = [int(np.round(feature_map_shape[0] / level)), 37 | int(np.round(feature_map_shape[1] / level))] 38 | strides = pool_size 39 | 40 | x = Pool2D(pool_size, strides=strides, padding=pool_padding)(input_tensor) 41 | x = Conv2DBlock(conv_filters, 42 | kernel_size=conv_kernel_size, 43 | padding=conv_padding, 44 | use_batchnorm=use_batchnorm, 45 | activation=activation, 46 | name='level{}'.format(level))(x) 47 | x = ResizeImage(strides, interpolation=interpolation)(x) 48 | return x 49 | return layer 50 | 51 | 52 | def DUC(factor=(8, 8)): 53 | 54 | if factor[0] != factor[1]: 55 | raise ValueError('DUC upconvolution support only equal factors, ' 56 | 'got {}'.format(factor)) 57 | factor = factor[0] 58 | 59 | def layer(input_tensor): 60 | 61 | h, w, c = int_shape(input_tensor)[1:] 62 | H = h * factor 63 | W = w * factor 64 | 65 | x = Conv2DBlock(c*factor**2, (1,1), 66 | padding='same', 67 | name='duc_{}'.format(factor))(input_tensor) 68 | x = Permute((3, 1, 2))(x) 69 | x = Reshape((c, factor, factor, h, w))(x) 70 | x = Permute((1, 4, 2, 5, 3))(x) 71 | x = Reshape((c, H, W))(x) 72 | x = Permute((2, 3, 1))(x) 73 | return x 74 | return layer 75 | 76 | 77 | def PyramidPoolingModule(**params): 78 | """ 79 | Build the Pyramid Pooling Module. 80 | """ 81 | 82 | _params = { 83 | 'conv_filters': 512, 84 | 'conv_kernel_size': (1, 1), 85 | 'conv_padding': 'same', 86 | 'pooling_type': 'avg', 87 | 'pool_padding': 'same', 88 | 'use_batchnorm': True, 89 | 'activation': 'relu', 90 | 'interpolation': 'bilinear', 91 | } 92 | 93 | _params.update(params) 94 | 95 | def module(input_tensor): 96 | 97 | feature_map_shape = int_shape(input_tensor)[1:3] 98 | 99 | x1 = InterpBlock(1, feature_map_shape, **_params)(input_tensor) 100 | x2 = InterpBlock(2, feature_map_shape, **_params)(input_tensor) 101 | x3 = InterpBlock(3, feature_map_shape, **_params)(input_tensor) 102 | x6 = InterpBlock(6, feature_map_shape, **_params)(input_tensor) 103 | 104 | x = Concatenate()([input_tensor, x1, x2, x3, x6]) 105 | return x 106 | return module -------------------------------------------------------------------------------- /segmentation_models/pspnet/builder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code is constructed based on following repositories: 3 | https://github.com/ykamikawa/PSPNet/ 4 | https://github.com/hujh14/PSPNet-Keras/ 5 | https://github.com/Vladkryvoruchko/PSPNet-Keras-tensorflow/ 6 | 7 | And original paper of PSPNet: 8 | https://arxiv.org/pdf/1612.01105.pdf 9 | """ 10 | 11 | from keras.layers import Conv2D 12 | from keras.layers import Activation 13 | from keras.layers import SpatialDropout2D 14 | from keras.models import Model 15 | 16 | from .blocks import PyramidPoolingModule, DUC 17 | from ..common import Conv2DBlock 18 | from ..common import ResizeImage 19 | from ..utils import extract_outputs 20 | from ..utils import to_tuple 21 | 22 | 23 | def build_psp(backbone, 24 | psp_layer, 25 | last_upsampling_factor, 26 | classes=21, 27 | activation='softmax', 28 | conv_filters=512, 29 | pooling_type='avg', 30 | dropout=None, 31 | final_interpolation='bilinear', 32 | use_batchnorm=True): 33 | 34 | input = backbone.input 35 | 36 | x = extract_outputs(backbone, [psp_layer])[0] 37 | 38 | x = PyramidPoolingModule( 39 | conv_filters=conv_filters, 40 | pooling_type=pooling_type, 41 | use_batchnorm=use_batchnorm)(x) 42 | 43 | x = Conv2DBlock(512, (1, 1), activation='relu', padding='same', 44 | use_batchnorm=use_batchnorm)(x) 45 | 46 
| if dropout is not None: 47 | x = SpatialDropout2D(dropout)(x) 48 | 49 | x = Conv2D(classes, (3,3), padding='same', name='final_conv')(x) 50 | 51 | if final_interpolation == 'bilinear': 52 | x = ResizeImage(to_tuple(last_upsampling_factor))(x) 53 | elif final_interpolation == 'duc': 54 | x = DUC(to_tuple(last_upsampling_factor))(x) 55 | else: 56 | raise ValueError('Unsupported interpolation type {}. '.format(final_interpolation) + 57 | 'Use `duc` or `bilinear`.') 58 | 59 | x = Activation(activation, name=activation)(x) 60 | 61 | model = Model(input, x) 62 | 63 | return model 64 | -------------------------------------------------------------------------------- /segmentation_models/pspnet/model.py: -------------------------------------------------------------------------------- 1 | from .builder import build_psp 2 | from ..utils import freeze_model 3 | from ..utils import legacy_support 4 | from ..backbones import get_backbone, get_feature_layers 5 | 6 | 7 | def _get_layer_by_factor(backbone_name, factor): 8 | feature_layers = get_feature_layers(backbone_name, n=3) 9 | if factor == 4: 10 | return feature_layers[-1] 11 | elif factor == 8: 12 | return feature_layers[-2] 13 | elif factor == 16: 14 | return feature_layers[-3] 15 | else: 16 | raise ValueError('Unsupported factor - `{}`, Use 4, 8 or 16.'.format(factor)) 17 | 18 | 19 | def _shape_guard(factor, shape): 20 | h, w = shape[:2] 21 | min_size = factor * 6 22 | 23 | res = (h % min_size != 0 or w % min_size != 0 or 24 | h < min_size or w < min_size) 25 | if res: 26 | raise ValueError('Wrong shape {}, input H and W should '.format(shape) + 27 | 'be divisible by `{}`'.format(min_size)) 28 | 29 | 30 | old_args_map = { 31 | 'freeze_encoder': 'encoder_freeze', 32 | 'use_batchnorm': 'psp_use_batchnorm', 33 | 'dropout': 'psp_dropout', 34 | 'input_tensor': None, # removed 35 | } 36 | 37 | 38 | @legacy_support(old_args_map) 39 | def PSPNet(backbone_name='vgg16', 40 | input_shape=(384, 384, 3), 41 | classes=21, 42 | activation='softmax', 43 | encoder_weights='imagenet', 44 | encoder_freeze=False, 45 | downsample_factor=8, 46 | psp_conv_filters=512, 47 | psp_pooling_type='avg', 48 | psp_use_batchnorm=True, 49 | psp_dropout=None, 50 | final_interpolation='bilinear', 51 | **kwargs): 52 | """PSPNet_ is a fully convolution neural network for image semantic segmentation 53 | 54 | Args: 55 | backbone_name: name of classification model used as feature 56 | extractor to build segmentation model. 57 | input_shape: shape of input data/image ``(H, W, C)``. 58 | ``H`` and ``W`` should be divisible by ``6 * downsample_factor`` and **NOT** ``None``! 59 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 60 | activation: name of one of ``keras.activations`` for last model layer 61 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 62 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 63 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 64 | downsample_factor: one of 4, 8 and 16. Downsampling rate or in other words backbone depth 65 | to construct PSP module on it. 66 | psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block. 67 | psp_pooling_type: one of 'avg', 'max'. PSP block pooling type (maximum or average). 68 | psp_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 69 | is used. 70 | psp_dropout: dropout rate between 0 and 1. 
71 | final_interpolation: ``duc`` or ``bilinear`` - interpolation type for final 72 | upsampling layer. 73 | 74 | Returns: 75 | ``keras.models.Model``: **PSPNet** 76 | 77 | .. _PSPNet: 78 | https://arxiv.org/pdf/1612.01105.pdf 79 | 80 | """ 81 | 82 | # control image input shape 83 | _shape_guard(downsample_factor, input_shape) 84 | 85 | backbone = get_backbone(backbone_name, 86 | input_shape=input_shape, 87 | input_tensor=None, 88 | weights=encoder_weights, 89 | include_top=False) 90 | 91 | psp_layer = _get_layer_by_factor(backbone_name, downsample_factor) 92 | 93 | model = build_psp(backbone, 94 | psp_layer, 95 | last_upsampling_factor=downsample_factor, 96 | classes=classes, 97 | conv_filters=psp_conv_filters, 98 | pooling_type=psp_pooling_type, 99 | activation=activation, 100 | use_batchnorm=psp_use_batchnorm, 101 | dropout=psp_dropout, 102 | final_interpolation=final_interpolation) 103 | 104 | # lock encoder weights for fine-tuning 105 | if encoder_freeze: 106 | freeze_model(backbone) 107 | 108 | model.name = 'psp-{}'.format(backbone_name) 109 | 110 | return model 111 | -------------------------------------------------------------------------------- /segmentation_models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet 2 | -------------------------------------------------------------------------------- /segmentation_models/unet/blocks.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2DTranspose 2 | from keras.layers import UpSampling2D 3 | from keras.layers import Conv2D 4 | from keras.layers import BatchNormalization 5 | from keras.layers import Activation 6 | from keras.layers import Concatenate 7 | 8 | 9 | def handle_block_names(stage): 10 | conv_name = 'decoder_stage{}_conv'.format(stage) 11 | bn_name = 'decoder_stage{}_bn'.format(stage) 12 | relu_name = 'decoder_stage{}_relu'.format(stage) 13 | up_name = 'decoder_stage{}_upsample'.format(stage) 14 | return conv_name, bn_name, relu_name, up_name 15 | 16 | 17 | def ConvRelu(filters, kernel_size, use_batchnorm=False, conv_name='conv', bn_name='bn', relu_name='relu'): 18 | def layer(x): 19 | x = Conv2D(filters, kernel_size, padding="same", name=conv_name, use_bias=not(use_batchnorm))(x) 20 | if use_batchnorm: 21 | x = BatchNormalization(name=bn_name)(x) 22 | x = Activation('relu', name=relu_name)(x) 23 | return x 24 | return layer 25 | 26 | 27 | def Upsample2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2), 28 | use_batchnorm=False, skip=None): 29 | 30 | def layer(input_tensor): 31 | 32 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage) 33 | 34 | x = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor) 35 | 36 | if skip is not None: 37 | x = Concatenate()([x, skip]) 38 | 39 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm, 40 | conv_name=conv_name + '1', bn_name=bn_name + '1', relu_name=relu_name + '1')(x) 41 | 42 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm, 43 | conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x) 44 | 45 | return x 46 | return layer 47 | 48 | 49 | def Transpose2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2), 50 | transpose_kernel_size=(4,4), use_batchnorm=False, skip=None): 51 | 52 | def layer(input_tensor): 53 | 54 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage) 55 | 56 | x = Conv2DTranspose(filters, transpose_kernel_size, 
strides=upsample_rate, 57 | padding='same', name=up_name, use_bias=not(use_batchnorm))(input_tensor) 58 | if use_batchnorm: 59 | x = BatchNormalization(name=bn_name+'1')(x) 60 | x = Activation('relu', name=relu_name+'1')(x) 61 | 62 | if skip is not None: 63 | x = Concatenate()([x, skip]) 64 | 65 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm, 66 | conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x) 67 | 68 | return x 69 | return layer -------------------------------------------------------------------------------- /segmentation_models/unet/builder.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D 2 | from keras.layers import Activation 3 | from keras.models import Model 4 | 5 | from .blocks import Transpose2D_block 6 | from .blocks import Upsample2D_block 7 | from ..utils import get_layer_number, to_tuple 8 | 9 | 10 | def build_unet(backbone, classes, skip_connection_layers, 11 | decoder_filters=(256,128,64,32,16), 12 | upsample_rates=(2,2,2,2,2), 13 | n_upsample_blocks=5, 14 | block_type='upsampling', 15 | activation='sigmoid', 16 | use_batchnorm=True): 17 | 18 | input = backbone.input 19 | x = backbone.output 20 | 21 | if block_type == 'transpose': 22 | up_block = Transpose2D_block 23 | else: 24 | up_block = Upsample2D_block 25 | 26 | # convert layer names to indices 27 | skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l 28 | for l in skip_connection_layers]) 29 | 30 | for i in range(n_upsample_blocks): 31 | 32 | # check if there is a skip connection 33 | skip_connection = None 34 | if i < len(skip_connection_idx): 35 | skip_connection = backbone.layers[skip_connection_idx[i]].output 36 | 37 | upsample_rate = to_tuple(upsample_rates[i]) 38 | 39 | x = up_block(decoder_filters[i], i, upsample_rate=upsample_rate, 40 | skip=skip_connection, use_batchnorm=use_batchnorm)(x) 41 | 42 | x = Conv2D(classes, (3,3), padding='same', name='final_conv')(x) 43 | x = Activation(activation, name=activation)(x) 44 | 45 | model = Model(input, x) 46 | 47 | return model 48 | -------------------------------------------------------------------------------- /segmentation_models/unet/model.py: -------------------------------------------------------------------------------- 1 | from .builder import build_unet 2 | from ..utils import freeze_model 3 | from ..utils import legacy_support 4 | from ..backbones import get_backbone, get_feature_layers 5 | 6 | old_args_map = { 7 | 'freeze_encoder': 'encoder_freeze', 8 | 'skip_connections': 'encoder_features', 9 | 'upsample_rates': None, # removed 10 | 'input_tensor': None, # removed 11 | } 12 | 13 | 14 | @legacy_support(old_args_map) 15 | def Unet(backbone_name='vgg16', 16 | input_shape=(None, None, 3), 17 | classes=1, 18 | activation='sigmoid', 19 | encoder_weights='imagenet', 20 | encoder_freeze=False, 21 | encoder_features='default', 22 | decoder_block_type='upsampling', 23 | decoder_filters=(256, 128, 64, 32, 16), 24 | decoder_use_batchnorm=True, 25 | **kwargs): 26 | """ Unet_ is a fully convolution neural network for image semantic segmentation 27 | 28 | Args: 29 | backbone_name: name of classification model (without last dense layers) used as feature 30 | extractor to build segmentation model. 
31 | input_shape: shape of input data/image ``(H, W, C)``, in general 32 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be 33 | able to process images af any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``. 34 | classes: a number of classes for output (output shape - ``(h, w, classes)``). 35 | activation: name of one of ``keras.activations`` for last model layer 36 | (e.g. ``sigmoid``, ``softmax``, ``linear``). 37 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 38 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. 39 | encoder_features: a list of layer numbers or names starting from top of the model. 40 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used 41 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``. 42 | decoder_block_type: one of blocks with following layers structure: 43 | 44 | - `upsampling`: ``Upsampling2D`` -> ``Conv2D`` -> ``Conv2D`` 45 | - `transpose`: ``Transpose2D`` -> ``Conv2D`` 46 | 47 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks 48 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 49 | is used. 50 | 51 | Returns: 52 | ``keras.models.Model``: **Unet** 53 | 54 | .. _Unet: 55 | https://arxiv.org/pdf/1505.04597 56 | 57 | """ 58 | 59 | backbone = get_backbone(backbone_name, 60 | input_shape=input_shape, 61 | input_tensor=None, 62 | weights=encoder_weights, 63 | include_top=False) 64 | 65 | if encoder_features == 'default': 66 | encoder_features = get_feature_layers(backbone_name, n=4) 67 | 68 | model = build_unet(backbone, 69 | classes, 70 | encoder_features, 71 | decoder_filters=decoder_filters, 72 | block_type=decoder_block_type, 73 | activation=activation, 74 | n_upsample_blocks=len(decoder_filters), 75 | upsample_rates=(2, 2, 2, 2, 2), 76 | use_batchnorm=decoder_use_batchnorm) 77 | 78 | # lock encoder weights for fine-tuning 79 | if encoder_freeze: 80 | freeze_model(backbone) 81 | 82 | model.name = 'u-{}'.format(backbone_name) 83 | 84 | return model 85 | -------------------------------------------------------------------------------- /temp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/temp/__init__.py -------------------------------------------------------------------------------- /temp/band4_image.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import gdal 4 | import numpy as np 5 | import keras.backend as K 6 | K.set_image_dim_ordering('tf') 7 | from keras.preprocessing.image import img_to_array 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | img_path = '../../data/originaldata/NatureProtected/src/sample1.png' 12 | 13 | out_dir='../../data/' 14 | 15 | window_size=2048 16 | 17 | if __name__=='__main__': 18 | img = cv2.imread(img_path) 19 | 20 | print(img.shape) 21 | 22 | dataset = gdal.Open(img_path) 23 | if dataset==None: 24 | print("open failed!\n") 25 | 26 | height = dataset.RasterYSize 27 | width=dataset.RasterXSize 28 | bands = dataset.RasterCount 29 | all_data=dataset.ReadAsArray(0,0,width,height) 30 | # all_data = np.array(all_data) 31 | # new_data = img_to_array(all_data) 32 | 33 | # 
print("shape:{}".format(new_data.shape)) 34 | 35 | x = np.random.randint(0, height - window_size - 1) 36 | y = np.random.randint(0, width - window_size - 1) 37 | 38 | output_img = all_data[:, x:x + window_size, y:y + window_size] 39 | print(output_img.shape) 40 | result = output_img[1:4,:,:] 41 | result = np.transpose(result,(1,2,0)) 42 | plt.imshow(result) 43 | plt.show() 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /temp/change_geotransform.py: -------------------------------------------------------------------------------- 1 | import gdal 2 | import os ,sys 3 | from ulitities.base_functions import load_img_by_gdal_geo 4 | 5 | 6 | input_file = '/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/test_2w.png' 7 | output_file = '/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/test_2w-C.png' 8 | 9 | 10 | if __name__=="__main__": 11 | try: 12 | data, geotransform = load_img_by_gdal_geo(input_file) 13 | print("Geotransform:{}".format(geotransform)) 14 | except: 15 | print("Error: Failde load image..") 16 | sys.exit(-1) 17 | tmp = list(geotransform) 18 | tmp[-1]=-1*tmp[-1] 19 | 20 | new_geo=tuple(tmp) 21 | print("new Geotransform:{}".format(new_geo)) 22 | 23 | w,h,c = data.shape 24 | 25 | driver = gdal.GetDriverByName("GTiff") 26 | # driver = gdal.GetDriverByName("PNG") 27 | # outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, gdal.GDT_UInt16) 28 | outdataset = driver.Create(output_file, w, h, c, gdal.GDT_Byte) 29 | outdataset.SetGeoTransform(new_geo) 30 | if outdataset == None: 31 | print("create dataset failed!\n") 32 | sys.exit(-2) 33 | if c == 1: 34 | outdataset.GetRasterBand(1).WriteArray(data) 35 | else: 36 | for i in range(c): 37 | outdataset.GetRasterBand(i + 1).WriteArray(data[:,:,i]) 38 | del outdataset 39 | -------------------------------------------------------------------------------- /temp/change_label_zym.py: -------------------------------------------------------------------------------- 1 | import os, sys, gdal 2 | import numpy as np 3 | from tqdm import tqdm 4 | import matplotlib.pyplot as plt 5 | import cv2 6 | 7 | from ulitities.base_functions import get_file, load_img_by_gdal 8 | 9 | 10 | def write_img_by_gdal(path, data, bands, dtype): 11 | data = np.array(data) 12 | if bands >1: 13 | a,b,c = data.shape 14 | if c= 4) 22 | 23 | img = dataset.ReadAsArray(0, 0) 24 | rgb_img = img[:3] 25 | rgb_img = np.transpose(rgb_img, (1, 2, 0)) 26 | 27 | p_band=dataset.GetRasterBand(1) 28 | data_type = p_band.DataType 29 | # data_type = gdal.GetDataTypeByName(p_band.DataType) 30 | if data_type==1: 31 | 32 | elif data_type==2: 33 | # img = dataset.ReadAsArray(0, 0) 34 | # rgb_img = img[:3] 35 | # rgb_img = np.transpose(rgb_img, (1, 2, 0)) 36 | rgb_img = rgb_img*255/6 37 | rgb_img = rgb_img.astype(np.uint8) 38 | plt.imshow(rgb_img, cmap='jet') 39 | plt.show() 40 | rgb_filename = ''.join([outdir_rgb, os.path.split(img_path)[1]]) 41 | 42 | cv2.imwrite(rgb_filename, rgb_img) 43 | 44 | nrg_img = img[1:4] 45 | nrg_img = np.transpose(nrg_img, (1, 2, 0)) 46 | nrg_img = nrg_img.astype(np.uint8) 47 | plt.imshow(nrg_img) 48 | plt.show() 49 | rgb_filename = ''.join([outdir_nrg, os.path.split(img_path)[1]]) 50 | 51 | cv2.imwrite(rgb_filename, nrg_img) 52 | 53 | 54 | 55 | 56 | 57 | print("complete!\n") 58 | 59 | -------------------------------------------------------------------------------- /temp/predict_from_xuhuimin.py: 
-------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import cv2 4 | 5 | from test_utils.predict import predict 6 | 7 | 8 | def pred(image, net, step, ctx):#step为样本选取间隔 9 | h, w, channel = image.shape 10 | image = image.astype('float32') 11 | size = int(step *0.75) #取样本中size尺寸为最终预测尺寸 12 | margin = int((step - size) / 2) 13 | inhang = int(np.ceil(h/size)) 14 | inlie = int(np.ceil(w / size)) 15 | 16 | # newimage0=np.zeros((inhang*size, inlie*size,channel)) 17 | # borderType = cv2.BORDER_REFLECT 18 | # newimage = cv2.copyMakeBorder(newimage0, margin, margin, margin, margin, borderType) 19 | 20 | newimage = np.zeros((inhang*size + margin*2 , inlie*size +2*margin,channel)) 21 | newimage[margin : h + margin,margin : w + margin ,:] = image 22 | newimage /= 255 23 | predictions = np.zeros((inhang*size , inlie*size), dtype=np.int64) 24 | for i in range(inhang): 25 | for j in range(inlie): 26 | patch = newimage[ i*size: i*size+step ,j*size: j*size+step ,:] 27 | patch = np.transpose(patch, axes=(2, 0, 1)).astype(np.float32) 28 | patch = mx.nd.array(np.expand_dims(patch, 0), ctx=ctx) 29 | pred = predict(patch, net)#预测 30 | predictions[ i*size: (i+1)*size ,j*size: (j+1)*size] = pred[margin:size+margin,margin:size+margin] 31 | result = predictions[:h,:w] 32 | return result 33 | 34 | -------------------------------------------------------------------------------- /temp/test_cv2read.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import cv2 4 | 5 | if __name__=='__main__': 6 | img = cv2.imread('/home/omnisky/PycharmProjects/data/originaldata/SatRGB/label/ruoergai_8.png') 7 | 8 | print("ok") 9 | -------------------------------------------------------------------------------- /temp/test_for_jaccrad_predict.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import matplotlib 3 | 4 | matplotlib.use("Agg") 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from keras.models import Sequential,load_model 8 | from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input 9 | from keras.utils.np_utils import to_categorical 10 | from keras.preprocessing.image import img_to_array 11 | from keras.callbacks import ModelCheckpoint, EarlyStopping, History,ReduceLROnPlateau 12 | from keras.models import Model 13 | from keras.layers.merge import concatenate 14 | from PIL import Image 15 | import matplotlib.pyplot as plt 16 | import cv2 17 | import random 18 | import sys 19 | import os 20 | from tqdm import tqdm 21 | from keras.models import * 22 | from keras.layers import * 23 | from keras.optimizers import * 24 | 25 | from keras import backend as K 26 | K.set_image_dim_ordering('tf') 27 | 28 | 29 | from semantic_segmentation_networks import binary_unet_jaccard, binary_fcnnet_jaccard, binary_segnet_jaccard 30 | from ulitities.base_functions import load_img_normalization 31 | 32 | os.environ["CUDA_VISIBLE_DEVICES"] = "5" 33 | seed = 7 34 | np.random.seed(seed) 35 | 36 | img_w = 256 37 | img_h = 256 38 | 39 | n_label = 1 40 | 41 | model_save_path='/home/omnisky/PycharmProjects/data/models/sat_urban_rgb/unet_buildings_binary_jaccard.h5' 42 | 43 | 44 | window_size=256 45 | 46 | def test_predict(image,model): 47 | stride = window_size 48 | 49 | h, w, _ = image.shape 50 | print('h,w:', h, w) 51 | padding_h = (h // stride + 1) * stride 52 | padding_w = (w // stride + 1) * stride 53 | 
padding_img = np.zeros((padding_h, padding_w, 3)) 54 | padding_img[0:h, 0:w, :] = image[:, :, :] 55 | 56 | padding_img = img_to_array(padding_img) 57 | 58 | mask_whole = np.zeros((padding_h, padding_w), dtype=np.float32) 59 | for i in list(range(padding_h // stride)): 60 | for j in list(range(padding_w // stride)): 61 | crop = padding_img[i * stride:i * stride + window_size, j * stride:j * stride + window_size, :3] 62 | 63 | crop = np.expand_dims(crop, axis=0) 64 | print('crop:{}'.format(crop.shape)) 65 | 66 | # pred = model.predict(crop, verbose=2) 67 | pred = model.predict(crop, verbose=2) 68 | # pred = np.argmax(pred, axis=2) #for one hot encoding 69 | 70 | pred = pred.reshape(256, 256) 71 | # pred = pred[0] 72 | # pred = pred[:,:,0] 73 | print(np.unique(pred)) 74 | 75 | 76 | mask_whole[i * stride:i * stride + window_size, j * stride:j * stride + window_size] = pred[:, :] 77 | 78 | outputresult =mask_whole[0:h,0:w] 79 | # outputresult = outputresult.astype(np.uint8) 80 | 81 | plt.imshow(outputresult, cmap='gray') 82 | plt.title("Original predicted result") 83 | plt.show() 84 | cv2.imwrite('../../data/predict/test_model.png',outputresult*255) 85 | return outputresult 86 | 87 | 88 | 89 | 90 | if __name__ == '__main__': 91 | 92 | print("test ....................predict by trained model .....\n") 93 | test_img_path = '../../data/test/sample1.png' 94 | import sys 95 | 96 | if not os.path.isfile(test_img_path): 97 | print("no file: {}".format(test_img_path)) 98 | sys.exit(-1) 99 | 100 | ret, input_img = load_img_normalization(test_img_path) 101 | # model_save_path ='../../data/models/unet_buildings_onehot.h5' 102 | 103 | new_model = load_model(model_save_path) 104 | 105 | test_predict(input_img, new_model) 106 | -------------------------------------------------------------------------------- /temp/test_unet_multiclass_predict.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import cv2 4 | import numpy as np 5 | import os 6 | import sys 7 | import argparse 8 | # from keras.preprocessing.image import img_to_array 9 | from keras.models import load_model 10 | from sklearn.preprocessing import LabelEncoder 11 | from keras.preprocessing.image import img_to_array 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | from predict.smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands 16 | 17 | from keras import backend as K 18 | K.set_image_dim_ordering('th') 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 20 | 21 | # segnet_classes = [0., 1., 2., 3., 4.] 22 | unet_classes = [0., 1., 2.] 
23 | 24 | labelencoder = LabelEncoder() 25 | labelencoder.fit(unet_classes) 26 | 27 | input_image = '../../data/test/1.png' 28 | 29 | 30 | """(1.1) for unet road predict""" 31 | unet_model_path = '../../data/models/unet_channel_first_multiclass.h5' 32 | unet_output_mask = '../../data/predict/unet/mask_unet_roads_'+os.path.split(input_image)[1] 33 | 34 | 35 | window_size = 256 36 | 37 | 38 | 39 | 40 | 41 | 42 | def cheap_predict(input_img, model,label_encoder): 43 | stride = window_size 44 | 45 | h, w, _ = input_img.shape 46 | print 'h,w:', h, w 47 | padding_h = (h // stride + 1) * stride 48 | padding_w = (w // stride + 1) * stride 49 | padding_img = np.zeros((padding_h, padding_w, 3)) 50 | padding_img[0:h, 0:w, :] = input_img[:, :, :] 51 | 52 | # Using "img_to_array" to convert the dimensions ordering, to adapt "K.set_image_dim_ordering('**') " 53 | padding_img = img_to_array(padding_img) 54 | print 'src:', padding_img.shape 55 | 56 | mask_whole = np.zeros((padding_h, padding_w), dtype=np.float32) 57 | for i in range(padding_h // stride): 58 | for j in range(padding_w // stride): 59 | crop = padding_img[:3, i * stride:i * stride + window_size, j * stride:j * stride + window_size] 60 | # crop = padding_img[i * stride:i * stride + window_size, j * stride:j * stride + window_size, :3] 61 | cb, ch, cw = crop.shape # for channel_first 62 | 63 | # print ('crop:{}'.format(crop.shape)) 64 | 65 | crop = np.expand_dims(crop, axis=0) 66 | # print ('crop:{}'.format(crop.shape)) 67 | pred = model.predict(crop, verbose=2) 68 | # print (np.unique(pred)) 69 | # pred = label_encoder.inverse_transform(pred[0]) 70 | # print (np.unique(pred)) 71 | pred = np.argmax(pred,axis=2) 72 | pred = pred.reshape(window_size, window_size) 73 | 74 | 75 | mask_whole[i * stride:i * stride + window_size, j * stride:j * stride + window_size] = pred[:, :] 76 | 77 | outputresult = mask_whole[0:h, 0:w] * 255 78 | # print (np.unique(outputresult)) 79 | # print (np.unique(outputresult[:,:,0])) 80 | # print (np.unique(outputresult[:,:,1])) 81 | 82 | # plt.imshow(outputresult[:,:,2]) 83 | plt.imshow(outputresult) 84 | plt.title("Original predicted result") 85 | plt.show() 86 | cv2.imwrite('../../data/predict/unet/mask_multiclass_test.png', outputresult) 87 | 88 | 89 | def new_predict_for_unet_multiclass(small_img_patches, model, real_classes,labelencoder): 90 | """ 91 | 92 | :param small_img_patches: input image 4D array (patches, row,column, channels) 93 | :param model: pretrained model 94 | :param real_classes: the number of classes and the channels of output mask 95 | :param labelencoder: 96 | :return: predict mask 4D array (patches, row,column, real_classes) 97 | """ 98 | 99 | # assert(real_classes ==1 ) # only usefully for one class 100 | 101 | small_img_patches = np.array(small_img_patches) 102 | print (small_img_patches.shape) 103 | assert (len(small_img_patches.shape) == 4) 104 | 105 | patches,row,column,input_channels = small_img_patches.shape 106 | 107 | mask_output = [] 108 | for p in range(patches): 109 | # crop = np.zeros((row, column, input_channels), np.uint8) 110 | crop = small_img_patches[p,:,:,:] 111 | crop = img_to_array(crop) 112 | crop = np.expand_dims(crop, axis=0) 113 | # print ('crop:{}'.format(crop.shape)) 114 | pred = model.predict(crop, verbose=2) 115 | pred = pred[0].reshape((row,column,real_classes)) 116 | 117 | # 将预测结果2D expand to 3D 118 | # res_pred = np.expand_dims(pred, axis=-1) 119 | 120 | mask_output.append(pred) 121 | 122 | mask_output = np.array(mask_output) 123 | print ("Shape of 
mask_output:{}".format(mask_output.shape)) 124 | 125 | return mask_output 126 | 127 | 128 | if __name__=='__main__': 129 | input_img = cv2.imread(input_image) 130 | input_img = np.array(input_img, dtype="float") / 255.0 # must do it 131 | model = load_model(unet_model_path) 132 | 133 | cheap_predict(input_img,model,labelencoder) 134 | 135 | # predictions_smooth = predict_img_with_smooth_windowing_multiclassbands( 136 | # input_img, 137 | # model, 138 | # window_size=window_size, 139 | # subdivisions=2, 140 | # real_classes=3, # output channels = 是真的类别, 141 | # pred_func=new_predict_for_unet_multiclass, 142 | # labelencoder=labelencoder 143 | # ) 144 | # plt.imshow(predictions_smooth) 145 | # plt.title("Original predicted result") 146 | # plt.show() 147 | 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/train/__init__.py -------------------------------------------------------------------------------- /ui/about.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'about.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.5.1 6 | # 7 | # WARNING! All changes made in this file will be lost! 8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | 11 | class Ui_Dialog_about(object): 12 | def setupUi(self, Dialog_about): 13 | Dialog_about.setObjectName("Dialog_about") 14 | Dialog_about.resize(353, 143) 15 | self.textEdit = QtWidgets.QTextEdit(Dialog_about) 16 | self.textEdit.setGeometry(QtCore.QRect(-10, 0, 371, 151)) 17 | self.textEdit.setObjectName("textEdit") 18 | 19 | self.retranslateUi(Dialog_about) 20 | QtCore.QMetaObject.connectSlotsByName(Dialog_about) 21 | 22 | def retranslateUi(self, Dialog_about): 23 | _translate = QtCore.QCoreApplication.translate 24 | Dialog_about.setWindowTitle(_translate("Dialog_about", "About")) 25 | self.textEdit.setHtml(_translate("Dialog_about", "\n" 26 | "\n" 29 | "

<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd">
<html><head><meta name="qrichtext" content="1" /><style type="text/css">
p, li { white-space: pre-wrap; }
</style></head><body style=" font-family:'Sans Serif'; font-size:9pt; font-weight:400; font-style:normal;">
<p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"> <img src=":/log/scrslogo.png" /> <span style=" font-size:16pt;">Copyright SCRS</span><span style=" font-size:20pt;"> </span></p></body></html>
")) 30 | 31 | import mysrc_rc 32 | -------------------------------------------------------------------------------- /ui/about.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Dialog_about 4 | 5 | 6 | 7 | 0 8 | 0 9 | 353 10 | 143 11 | 12 | 13 | 14 | About 15 | 16 | 17 | 18 | 19 | -10 20 | 0 21 | 371 22 | 151 23 | 24 | 25 | 26 | <!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.0//EN" "http://www.w3.org/TR/REC-html40/strict.dtd"> 27 | <html><head><meta name="qrichtext" content="1" /><style type="text/css"> 28 | p, li { white-space: pre-wrap; } 29 | </style></head><body style=" font-family:'Sans Serif'; font-size:9pt; font-weight:400; font-style:normal;"> 30 | <p style=" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;"> <img src=":/log/scrslogo.png" /> <span style=" font-size:16pt;">Copyright SCRS</span><span style=" font-size:20pt;"> </span></p></body></html> 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /ui/else/manual.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/manual.docx -------------------------------------------------------------------------------- /ui/else/scrslogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/scrslogo.png -------------------------------------------------------------------------------- /ui/else/中文标签.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/中文标签.docx -------------------------------------------------------------------------------- /ui/mysrc.qrc: -------------------------------------------------------------------------------- 1 | 2 | 3 | scrslogo.png 4 | 5 | 6 | -------------------------------------------------------------------------------- /ui/postProcess/Binarization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'Binarization.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.5.1 6 | # 7 | # WARNING! All changes made in this file will be lost! 
8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | 11 | class Ui_Dialog_binarization(object): 12 | def setupUi(self, Dialog_binarization): 13 | Dialog_binarization.setObjectName("Dialog_binarization") 14 | Dialog_binarization.resize(408, 214) 15 | self.buttonBox = QtWidgets.QDialogButtonBox(Dialog_binarization) 16 | self.buttonBox.setGeometry(QtCore.QRect(180, 180, 221, 32)) 17 | self.buttonBox.setOrientation(QtCore.Qt.Horizontal) 18 | self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok) 19 | self.buttonBox.setObjectName("buttonBox") 20 | self.layoutWidget = QtWidgets.QWidget(Dialog_binarization) 21 | self.layoutWidget.setGeometry(QtCore.QRect(0, 10, 401, 161)) 22 | self.layoutWidget.setObjectName("layoutWidget") 23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget) 24 | self.verticalLayout.setObjectName("verticalLayout") 25 | self.horizontalLayout_8 = QtWidgets.QHBoxLayout() 26 | self.horizontalLayout_8.setObjectName("horizontalLayout_8") 27 | self.label_6 = QtWidgets.QLabel(self.layoutWidget) 28 | self.label_6.setMinimumSize(QtCore.QSize(55, 23)) 29 | self.label_6.setObjectName("label_6") 30 | self.horizontalLayout_8.addWidget(self.label_6) 31 | self.lineEdit_grayscale_mask = QtWidgets.QLineEdit(self.layoutWidget) 32 | self.lineEdit_grayscale_mask.setMinimumSize(QtCore.QSize(201, 23)) 33 | self.lineEdit_grayscale_mask.setObjectName("lineEdit_grayscale_mask") 34 | self.horizontalLayout_8.addWidget(self.lineEdit_grayscale_mask) 35 | self.pushButton_grayscale_mask = QtWidgets.QPushButton(self.layoutWidget) 36 | self.pushButton_grayscale_mask.setMinimumSize(QtCore.QSize(0, 23)) 37 | self.pushButton_grayscale_mask.setObjectName("pushButton_grayscale_mask") 38 | self.horizontalLayout_8.addWidget(self.pushButton_grayscale_mask) 39 | self.verticalLayout.addLayout(self.horizontalLayout_8) 40 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout() 41 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 42 | self.label = QtWidgets.QLabel(self.layoutWidget) 43 | self.label.setObjectName("label") 44 | self.horizontalLayout_2.addWidget(self.label) 45 | self.spinBox_forground = QtWidgets.QSpinBox(self.layoutWidget) 46 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 47 | sizePolicy.setHorizontalStretch(0) 48 | sizePolicy.setVerticalStretch(23) 49 | sizePolicy.setHeightForWidth(self.spinBox_forground.sizePolicy().hasHeightForWidth()) 50 | self.spinBox_forground.setSizePolicy(sizePolicy) 51 | self.spinBox_forground.setMinimum(1) 52 | self.spinBox_forground.setMaximum(100000) 53 | self.spinBox_forground.setSingleStep(1) 54 | self.spinBox_forground.setProperty("value", 127) 55 | self.spinBox_forground.setObjectName("spinBox_forground") 56 | self.horizontalLayout_2.addWidget(self.spinBox_forground) 57 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 58 | self.horizontalLayout_2.addItem(spacerItem) 59 | self.verticalLayout.addLayout(self.horizontalLayout_2) 60 | self.horizontalLayout = QtWidgets.QHBoxLayout() 61 | self.horizontalLayout.setObjectName("horizontalLayout") 62 | self.label_7 = QtWidgets.QLabel(self.layoutWidget) 63 | self.label_7.setMinimumSize(QtCore.QSize(55, 23)) 64 | self.label_7.setObjectName("label_7") 65 | self.horizontalLayout.addWidget(self.label_7) 66 | self.lineEdit_binary_mask = QtWidgets.QLineEdit(self.layoutWidget) 67 | self.lineEdit_binary_mask.setMinimumSize(QtCore.QSize(201, 23)) 68 | 
self.lineEdit_binary_mask.setObjectName("lineEdit_binary_mask") 69 | self.horizontalLayout.addWidget(self.lineEdit_binary_mask) 70 | self.pushButton_binary_mask = QtWidgets.QPushButton(self.layoutWidget) 71 | self.pushButton_binary_mask.setMinimumSize(QtCore.QSize(0, 23)) 72 | self.pushButton_binary_mask.setObjectName("pushButton_binary_mask") 73 | self.horizontalLayout.addWidget(self.pushButton_binary_mask) 74 | self.verticalLayout.addLayout(self.horizontalLayout) 75 | 76 | self.retranslateUi(Dialog_binarization) 77 | self.pushButton_grayscale_mask.clicked.connect(Dialog_binarization.slot_get_grayscale_mask) 78 | self.pushButton_binary_mask.clicked.connect(Dialog_binarization.slot_get_saving_binary_mask_path) 79 | self.buttonBox.accepted.connect(Dialog_binarization.slot_ok) 80 | self.buttonBox.rejected.connect(Dialog_binarization.reject) 81 | QtCore.QMetaObject.connectSlotsByName(Dialog_binarization) 82 | 83 | def retranslateUi(self, Dialog_binarization): 84 | _translate = QtCore.QCoreApplication.translate 85 | Dialog_binarization.setWindowTitle(_translate("Dialog_binarization", "Dialog")) 86 | self.label_6.setText(_translate("Dialog_binarization", "Grayscale mask:")) 87 | self.pushButton_grayscale_mask.setText(_translate("Dialog_binarization", "Open")) 88 | self.label.setText(_translate("Dialog_binarization", "Threshold value:")) 89 | self.label_7.setText(_translate("Dialog_binarization", "Binary Mask:")) 90 | self.pushButton_binary_mask.setText(_translate("Dialog_binarization", "Open")) 91 | 92 | -------------------------------------------------------------------------------- /ui/postProcess/VoteMultimodleResults.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'VoteMultimodleResults.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.5.1 6 | # 7 | # WARNING! All changes made in this file will be lost! 
8 | 9 | from PyQt5 import QtCore, QtGui, QtWidgets 10 | 11 | class Ui_Dialog_vote_multimodels(object): 12 | def setupUi(self, Dialog_vote_multimodels): 13 | Dialog_vote_multimodels.setObjectName("Dialog_vote_multimodels") 14 | Dialog_vote_multimodels.resize(419, 232) 15 | self.buttonBox = QtWidgets.QDialogButtonBox(Dialog_vote_multimodels) 16 | self.buttonBox.setGeometry(QtCore.QRect(150, 190, 201, 32)) 17 | self.buttonBox.setOrientation(QtCore.Qt.Horizontal) 18 | self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok) 19 | self.buttonBox.setObjectName("buttonBox") 20 | self.layoutWidget = QtWidgets.QWidget(Dialog_vote_multimodels) 21 | self.layoutWidget.setGeometry(QtCore.QRect(10, 10, 401, 171)) 22 | self.layoutWidget.setObjectName("layoutWidget") 23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget) 24 | self.verticalLayout.setObjectName("verticalLayout") 25 | self.horizontalLayout_8 = QtWidgets.QHBoxLayout() 26 | self.horizontalLayout_8.setObjectName("horizontalLayout_8") 27 | self.label_6 = QtWidgets.QLabel(self.layoutWidget) 28 | self.label_6.setMinimumSize(QtCore.QSize(55, 23)) 29 | self.label_6.setObjectName("label_6") 30 | self.horizontalLayout_8.addWidget(self.label_6) 31 | self.lineEdit_inputs = QtWidgets.QLineEdit(self.layoutWidget) 32 | self.lineEdit_inputs.setMinimumSize(QtCore.QSize(201, 23)) 33 | self.lineEdit_inputs.setObjectName("lineEdit_inputs") 34 | self.horizontalLayout_8.addWidget(self.lineEdit_inputs) 35 | self.pushButton_inputs = QtWidgets.QPushButton(self.layoutWidget) 36 | self.pushButton_inputs.setMinimumSize(QtCore.QSize(0, 23)) 37 | self.pushButton_inputs.setObjectName("pushButton_inputs") 38 | self.horizontalLayout_8.addWidget(self.pushButton_inputs) 39 | self.verticalLayout.addLayout(self.horizontalLayout_8) 40 | self.horizontalLayout = QtWidgets.QHBoxLayout() 41 | self.horizontalLayout.setObjectName("horizontalLayout") 42 | self.label_7 = QtWidgets.QLabel(self.layoutWidget) 43 | self.label_7.setMinimumSize(QtCore.QSize(55, 23)) 44 | self.label_7.setObjectName("label_7") 45 | self.horizontalLayout.addWidget(self.label_7) 46 | self.lineEdit_mask = QtWidgets.QLineEdit(self.layoutWidget) 47 | self.lineEdit_mask.setMinimumSize(QtCore.QSize(201, 23)) 48 | self.lineEdit_mask.setObjectName("lineEdit_mask") 49 | self.horizontalLayout.addWidget(self.lineEdit_mask) 50 | self.pushButton_mask = QtWidgets.QPushButton(self.layoutWidget) 51 | self.pushButton_mask.setMinimumSize(QtCore.QSize(0, 23)) 52 | self.pushButton_mask.setObjectName("pushButton_mask") 53 | self.horizontalLayout.addWidget(self.pushButton_mask) 54 | self.verticalLayout.addLayout(self.horizontalLayout) 55 | self.horizontalLayout_3 = QtWidgets.QHBoxLayout() 56 | self.horizontalLayout_3.setObjectName("horizontalLayout_3") 57 | self.groupBox = QtWidgets.QGroupBox(self.layoutWidget) 58 | self.groupBox.setObjectName("groupBox") 59 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox) 60 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 61 | self.label = QtWidgets.QLabel(self.groupBox) 62 | self.label.setObjectName("label") 63 | self.horizontalLayout_2.addWidget(self.label) 64 | self.spinBox_min = QtWidgets.QSpinBox(self.groupBox) 65 | self.spinBox_min.setObjectName("spinBox_min") 66 | self.horizontalLayout_2.addWidget(self.spinBox_min) 67 | self.label_2 = QtWidgets.QLabel(self.groupBox) 68 | self.label_2.setObjectName("label_2") 69 | self.horizontalLayout_2.addWidget(self.label_2) 70 | self.spinBox_max = 
QtWidgets.QSpinBox(self.groupBox) 71 | self.spinBox_max.setMinimum(1) 72 | self.spinBox_max.setProperty("value", 2) 73 | self.spinBox_max.setObjectName("spinBox_max") 74 | self.horizontalLayout_2.addWidget(self.spinBox_max) 75 | self.horizontalLayout_3.addWidget(self.groupBox) 76 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 77 | self.horizontalLayout_3.addItem(spacerItem) 78 | self.verticalLayout.addLayout(self.horizontalLayout_3) 79 | 80 | self.retranslateUi(Dialog_vote_multimodels) 81 | self.buttonBox.accepted.connect(Dialog_vote_multimodels.accept) 82 | self.pushButton_inputs.clicked.connect(Dialog_vote_multimodels.slot_select_input_files) 83 | self.pushButton_mask.clicked.connect(Dialog_vote_multimodels.slot_get_save_mask) 84 | self.buttonBox.accepted.connect(Dialog_vote_multimodels.slot_ok) 85 | QtCore.QMetaObject.connectSlotsByName(Dialog_vote_multimodels) 86 | 87 | def retranslateUi(self, Dialog_vote_multimodels): 88 | _translate = QtCore.QCoreApplication.translate 89 | Dialog_vote_multimodels.setWindowTitle(_translate("Dialog_vote_multimodels", "Vote result")) 90 | self.label_6.setText(_translate("Dialog_vote_multimodels", "Input masks:")) 91 | self.pushButton_inputs.setText(_translate("Dialog_vote_multimodels", "Open")) 92 | self.label_7.setText(_translate("Dialog_vote_multimodels", "Output Mask:")) 93 | self.pushButton_mask.setText(_translate("Dialog_vote_multimodels", "Open")) 94 | self.groupBox.setTitle(_translate("Dialog_vote_multimodels", "Values range")) 95 | self.label.setText(_translate("Dialog_vote_multimodels", "min:")) 96 | self.label_2.setText(_translate("Dialog_vote_multimodels", "max:")) 97 | 98 |
-------------------------------------------------------------------------------- /ui/tmp/new_train_implements.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np # explicit import; np was only available via the star import below 4 | from keras.layers import * 5 | 6 | from keras import backend as K 7 | K.set_image_dim_ordering('tf') 8 | 9 | from PyQt5.QtCore import Qt 10 | from PyQt5.QtWidgets import QDialog, QFileDialog, QMessageBox 11 | 12 | from TrainBinaryCommon import Ui_Dialog_train_binary_common 13 | from tmp.new_train_backend import train_binary_for_ui 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | seed = 7 17 | np.random.seed(seed) 18 | 19 | trainBinary_dict={'trainData_path':'', 'saveModel_path':'', 'baseModel':'', 'im_bands':3, 'dtype':0, 20 | 'windsize':256, 'network':'unet', 'target_function':'crossentropy', 21 | 'class_name':'default', 'label_value': 1, 'BS':16, 'EPOCHS':100, 'GPUID':"0"} 22 | 23 | 24 | 25 | 26 | class child_trainBinaryCommon(QDialog, Ui_Dialog_train_binary_common): 27 | def __init__(self): 28 | super(child_trainBinaryCommon, self).__init__() 29 | self.setupUi(self) 30 | 31 | def slot_traindatapath(self): 32 | path = QFileDialog.getExistingDirectory(self, "Train data path", '../../data/') # renamed from 'str' to avoid shadowing the builtin 33 | self.lineEdit_traindata_path.setText(path) 34 | # QDir.setCurrent(path) 35 | 36 | def slot_savemodelpath(self): 37 | path = QFileDialog.getExistingDirectory(self, "Save model", '../../data/') 38 | self.lineEdit_savemodel.setText(path) 39 | 40 | def slot_basemodel(self): 41 | path, _ = QFileDialog.getOpenFileName(self, "Select base model", '../../data/', self.tr("Models(*.h5)")) 42 | if not path == '': 43 | self.lineEdit_basemodel.setText(path) 44 | 45 | def slot_ok(self): 46 | self.setWindowModality(Qt.ApplicationModal) 47 | input_dict = trainBinary_dict 48 | if
os.path.isdir(self.lineEdit_traindata_path.text()): 49 | input_dict['trainData_path'] = self.lineEdit_traindata_path.text() 50 | if os.path.isdir(self.lineEdit_savemodel.text()): 51 | input_dict['saveModel_path'] = self.lineEdit_savemodel.text() 52 | if os.path.isfile(self.lineEdit_basemodel.text()): 53 | input_dict['baseModel'] = self.lineEdit_basemodel.text() 54 | 55 | input_dict['im_bands'] = int(self.spinBox_bands.value()) 56 | input_dict['dtype'] = self.comboBox_dtype.currentIndex() 57 | input_dict['windsize'] = int(self.spinBox_windsize.value()) 58 | if self.radioButton_unet.isChecked(): 59 | input_dict['network'] = 'unet' 60 | elif self.radioButton_fcnnet.isChecked(): 61 | input_dict['network'] = 'fcnnet' 62 | elif self.radioButton_segnet.isChecked(): 63 | input_dict['network']='segnet' 64 | else: 65 | print("other network") 66 | sys.exit(-1) 67 | 68 | if self.radioButton_cross_entropy.isChecked(): 69 | input_dict['target_function'] = 'crossentropy' 70 | elif self.radioButton_jaccard.isChecked(): 71 | input_dict['target_function'] = 'jaccard' 72 | elif self.radioButton_jaccard_crossentropy.isChecked(): 73 | input_dict['target_function']='jacc_and_cross' 74 | else: 75 | print("other function") 76 | sys.exit(-1) 77 | 78 | input_dict['class_name'] = self.lineEdit_class_name.text() 79 | input_dict['label_value'] = self.spinBox_label_value.value() 80 | input_dict['BS'] = self.spinBox_BS.value() 81 | input_dict['EPOCHS'] = self.spinBox_epoch.value() 82 | input_dict['GPUID'] = self.comboBox_gupid.currentText() 83 | 84 | ret =-1 85 | ret = train_binary_for_ui(input_dict) 86 | if ret ==0: 87 | QMessageBox.information(self, "Prompt", self.tr("Model trained successfully!")) 88 | 89 | 90 | self.setWindowModality(Qt.NonModal)
-------------------------------------------------------------------------------- /ulitities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ulitities/__init__.py -------------------------------------------------------------------------------- /ulitities/band_compose.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | import os 4 | import sys 5 | import gdal 6 | import numpy as np 7 | 8 | 9 | 10 | root = '../../data/originaldata/zs/test/stretched/' 11 | files = ['extract_bottom_16bit.png','dem_r05_b.png', 'Int_slope100_Resample05_b.png'] 12 | 13 | 14 | output_file =root+'composed.png' 15 | 16 | if __name__=='__main__': 17 | all_data=[] 18 | file_path = root + files[0] 19 | if not os.path.isfile(file_path): 20 | print("File does not exist:{}".format(file_path)) 21 | sys.exit(-1) 22 | # print("processing file: {}".format(file_path)) 23 | dataset = gdal.Open(file_path) 24 | if dataset == None: 25 | print("Open failed:{}".format(files[0])) 26 | sys.exit(-2) 27 | 28 | x = dataset.RasterXSize 29 | y = dataset.RasterYSize 30 | # im_band = dataset.RasterCount 31 | d_type = dataset.GetRasterBand(1).DataType 32 | del dataset 33 | all_bands=0 34 | 35 | for file in files: 36 | file_path = root+file 37 | if not os.path.isfile(file_path): 38 | print("File does not exist:{}".format(file_path)) 39 | sys.exit(-3) 40 | print("processing file: {}".format(file_path)) 41 | dataset = gdal.Open(file_path) 42 | if dataset==None: 43 | print("Open failed:{}".format(file)) 44 | continue 45 | width = dataset.RasterXSize 46 | height = dataset.RasterYSize 47 | if x!=width or y!=height: 48 | print("Error: 
input files have different width and height\n") 49 | sys.exit(-4) 50 | im_band = dataset.RasterCount 51 | all_bands +=im_band 52 | im_type = dataset.GetRasterBand(1).DataType 53 | if d_type !=im_type: 54 | print("Error: data types are not the same!\n") 55 | sys.exit(-5) 56 | im_data = dataset.ReadAsArray(0,0,width,height) 57 | im_data = np.array(im_data) 58 | all_data.append(im_data) 59 | del dataset 60 | 61 | # all_data = np.array(all_data) 62 | # a,b,c =all_data.shape 63 | # print("allbands : {}".format(c)) 64 | 65 | my_driver = gdal.GetDriverByName("GTiff") 66 | out_dataset = my_driver.Create(output_file, x, y, all_bands, d_type) 67 | which_band =1 68 | for i in range(len(all_data)): 69 | dims= len(all_data[i].shape) 70 | print("dimension :{}".format(dims)) 71 | if dims <3: 72 | out_dataset.GetRasterBand(which_band).WriteArray(all_data[i]) 73 | which_band +=1 74 | else: 75 | im_bands = all_data[i].shape[0] 76 | for j in range(im_bands): 77 | out_dataset.GetRasterBand(which_band).WriteArray(all_data[i][j]) 78 | which_band +=1 79 | 80 | print("Saved to: {}".format(output_file)) 81 | 82 | del out_dataset 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /ulitities/ecogToPredict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import gdal 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import cv2 8 | 9 | 10 | input_file = '/home/omnisky/PycharmProjects/data/test/tianfuxinqu/ecognition/original/c2017_GJ_Clip_SVM.v1.tif' 11 | output_file = '/home/omnisky/PycharmProjects/data/test/tianfuxinqu/ecognition/pred/2017_GJ_Clip_SVM.png' 12 | if __name__=='__main__': 13 | if not os.path.isfile(input_file): 14 | print("Not valid file path or name:{}".format(input_file)) 15 | sys.exit(-1) 16 | 17 | dataset_in = gdal.Open(input_file) 18 | if None==dataset_in: 19 | print("Open file failed!") 20 | sys.exit(-2) 21 | width = dataset_in.RasterXSize 22 | height = dataset_in.RasterYSize 23 | im_bands = dataset_in.RasterCount 24 | # im_type = dataset_in. 
25 | 26 | img= dataset_in.ReadAsArray(0,0,width,height) 27 | mask = np.zeros((height,width,), np.uint8) 28 | 29 | # for b in im_bands: 30 | # tmp = img[b,:,:] 31 | # indx = np.where(tmp) 32 | bk_data = img[2, :, :] 33 | # plt.imshow(bk_data, cmap='gray') 34 | # plt.show() 35 | 36 | 37 | roads_data = img[0,:,:] 38 | indx = np.where(roads_data == 255) 39 | mask[indx] =1 40 | # plt.imshow(roads_data) 41 | # plt.show() 42 | 43 | buildings_data = img[1, :, :] 44 | indx = np.where(buildings_data == 255) 45 | mask[indx] = 2 46 | # plt.imshow(buildings_data) 47 | # plt.show() 48 | plt.imshow(mask) 49 | plt.show() 50 | 51 | cv2.imwrite(output_file, mask) 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /ulitities/image_clip.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | 3 | from PIL import Image 4 | import cv2 5 | import numpy as np 6 | from base_functions import load_img_by_gdal 7 | import matplotlib.pyplot as plt 8 | import gdal 9 | import sys 10 | 11 | 12 | # input_src_file = '/home/omnisky/PycharmProjects/data/test/paper/label/yujiang_test_label.png' 13 | input_src_file ='/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/ducha/DCtuitiantu/label/cd13.png' 14 | # clip_src_file = '/home/omnisky/PycharmProjects/data/test/paper/new/yujiang_test_label.png' 15 | clip_src_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_label.png' 16 | 17 | window_size = 8000 18 | # h_clip = 5000 19 | 20 | if __name__=='__main__': 21 | # img = load_img_by_gdal(input_src_file) 22 | dataset = gdal.Open(input_src_file) 23 | if dataset==None: 24 | print("Open file failed:{}".format(input_src_file)) 25 | sys.exit(-1) 26 | # assert (ret==0) 27 | height = dataset.RasterYSize 28 | width = dataset.RasterXSize 29 | im_bands = dataset.RasterCount 30 | d_type = dataset.GetRasterBand(1).DataType 31 | img = dataset.ReadAsArray(0,0,width,height) 32 | del dataset 33 | 34 | # x = np.random.randint(0, height-window_size-1) 35 | # y = np.random.randint(0, width - window_size - 1) 36 | x =15000 37 | y=3000 38 | # h_clip = int(0.5*width+0.5) 39 | # print("cliped pixels:{}".format(h_clip)) 40 | 41 | if im_bands ==1: 42 | output_img = img[y:y + window_size, x:x + window_size] 43 | # output_img = img[100:5000+100, 100:5500+100] 44 | output_img = np.array(output_img, np.uint16) 45 | # output_img[output_img > 2] = 127 46 | tp = output_img 47 | tp[tp>2]=0 48 | print(np.unique(tp)) 49 | output_img = np.array(output_img, np.uint8) 50 | plt.imshow(output_img) 51 | plt.show() 52 | cv2.imwrite(clip_src_file, output_img) # for label clip 53 | else: 54 | output_img = img[:,y:y + window_size, x:x + window_size] 55 | # output_img = img[:, :, :h_clip] 56 | # output_img = img[:, :, h_clip:] 57 | plt.imshow(output_img[0]) 58 | plt.show() 59 | driver = gdal.GetDriverByName("GTiff") 60 | outdataset = driver.Create(clip_src_file, window_size, window_size, im_bands, d_type) 61 | # outdataset = driver.Create(clip_src_file, h_clip, height, im_bands, d_type) 62 | if outdataset == None: 63 | print("create dataset failed!\n") 64 | sys.exit(-2) 65 | if im_bands == 1: 66 | outdataset.GetRasterBand(1).WriteArray(output_img) 67 | else: 68 | for i in range(im_bands): 69 | outdataset.GetRasterBand(i + 1).WriteArray(output_img[i]) 70 | del outdataset 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /ulitities/resample_image.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | import cv2, os, sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | input_file = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/originalLabelandImages/rice/test/testlabel_2.png' 9 | output_file ='/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/originalLabelandImages/rice/test/testlabel_2_resampled.png' 10 | 11 | s = 2 12 | 13 | if __name__=="__main__": 14 | img = cv2.imread(input_file, cv2.IMREAD_GRAYSCALE) 15 | a,b = img.shape[:2] 16 | m = int(a/2) 17 | n=int(b/2) 18 | 19 | # 20 | # result = [] 21 | # for x in range(m): 22 | # for y in range(n): 23 | # t = img[s*x:s*(x+1), s*y:s*(y+1)].mean() 24 | # result.append(t) 25 | # down_img = np.array((img), np.uint8).reshape(m,n) 26 | 27 | result = cv2.resize(img, (n,m)) 28 | 29 | plt.imshow(result) 30 | plt.show() 31 | 32 | cv2.imwrite(output_file,result) 33 | -------------------------------------------------------------------------------- /ulitities/test_augument.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import sys 7 | import gdal 8 | from keras.preprocessing.image import img_to_array 9 | from scipy.signal import medfilt, medfilt2d 10 | from skimage import exposure 11 | 12 | 13 | filename = '/home/omnisky/PycharmProjects/data/traindata/APsamples/binary/buildings/3.png' 14 | 15 | outputfile = '/home/omnisky/PycharmProjects/data/traindata/APsamples/binary/buildings/3_medfit5.png' 16 | 17 | def rotate(xb, angle): 18 | xb = np.rot90(np.array(xb), k=angle) 19 | # yb = np.rot90(np.array(yb), k=angle) 20 | return xb 21 | 22 | def add_noise(xb): 23 | for i in range(1000): 24 | temp_x = np.random.randint(0, xb.shape[1]) 25 | temp_y = np.random.randint(0, xb.shape[2]) 26 | xb[:, temp_x, temp_y] =255 27 | return xb 28 | 29 | 30 | 31 | def data_augment(xb): 32 | # xb = exposure.adjust_gamma(xb, 1.0) 33 | 34 | # xb = np.transpose(xb,(1,2,0)) 35 | # xb = rotate(xb, 1) 36 | 37 | # xb = rotate(xb, 2) 38 | # 39 | # xb = rotate(xb, 3) 40 | # xb = np.transpose(xb, (2, 0, 1)) 41 | # 42 | # xb = np.transpose(xb, (1, 2, 0)) 43 | # xb = np.fliplr(xb) # flip an array horizontally 44 | # xb = np.transpose(xb, (2, 0, 1)) 45 | # 46 | # xb = np.transpose(xb, (1, 2, 0)) 47 | # xb = np.flipud(xb) # flip an array vertically (up down directory) 48 | # xb = np.transpose(xb, (2, 0, 1)) 49 | # 50 | # xb = exposure.adjust_gamma(xb, 2.0) 51 | # 52 | 53 | xb = np.transpose(xb, (1, 2, 0)) 54 | for i in range(3): 55 | xb[:,:,i] = medfilt(xb[:,:,i], (5, 5)) 56 | xb = np.transpose(xb, (2, 0, 1)) 57 | # 58 | # xb = add_noise(xb) 59 | 60 | 61 | return xb 62 | 63 | 64 | 65 | if __name__=='__main__': 66 | print("[INFO] open file") 67 | 68 | dataset = gdal.Open(filename) 69 | if dataset == None: 70 | print("open failed!\n") 71 | 72 | Y_height = dataset.RasterYSize 73 | X_width = dataset.RasterXSize 74 | im_bands = dataset.RasterCount 75 | data_type = dataset.GetRasterBand(1).DataType 76 | 77 | src_img = dataset.ReadAsArray(0, 0, X_width, Y_height) 78 | src_img = np.array(src_img) 79 | del dataset 80 | # src_img = np.transpose(src_img, (2, 1, 0)) 81 | 82 | print("[INFO] augmentation ") 83 | 84 | src_img = data_augment(src_img) 85 | 86 | # src_img = np.transpose(src_img, (1, 2, 0)) 87 | 88 | driver = gdal.GetDriverByName("GTiff") 89 | # outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, gdal.GDT_UInt16) 90 | 
outdataset = driver.Create(outputfile, X_width,Y_height, im_bands, data_type) 91 | if im_bands == 1: 92 | outdataset.GetRasterBand(1).WriteArray(src_img) 93 | else: 94 | for i in range(im_bands): 95 | outdataset.GetRasterBand(i + 1).WriteArray(src_img[i]) 96 | del outdataset 97 | 98 |
-------------------------------------------------------------------------------- /ulitities/xml_prec.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | import os 4 | import sys 5 | import xml.etree.ElementTree as ET 6 | import xml.dom.minidom as Document # note: this binds the minidom module, not a Document class 7 | import xml.dom 8 | 9 | 10 | # def dict_to_xml(input_dict,root_tag,node_tag): 11 | # """ Define the root node root_tag and the second-level node node_tag; 12 | # at the third level, map the dict's key/value pairs to element names and values. 13 | # return: the xml tree structure """ 14 | # root_name = ET.Element(root_tag) 15 | # for (k, v) in input_dict.items(): 16 | # node_name = ET.SubElement(root_name, node_tag) 17 | # for (key, val) in sorted(v.items(), key=lambda e:e[0], reverse=True): 18 | # key = ET.SubElement(node_name, key) 19 | # key.text = val 20 | # return root_name 21 | 22 | 23 | # doc = Document() 24 | # root = doc.createElement('root') 25 | 26 | 27 | 28 | 29 | def generate_xml_from_dict(input_dict, xml_file): 30 | """Write each key/value pair of input_dict as a child element of <instance> under <root> in xml_file.""" 31 | 32 | impl = xml.dom.getDOMImplementation() 33 | dom = impl.createDocument(None, 'root', None) 34 | root = dom.documentElement 35 | instance = dom.createElement('instance') 36 | root.appendChild(instance) 37 | 38 | for (key, value) in sorted(input_dict.items()): 39 | nameE = dom.createElement(key) 40 | nameT = dom.createTextNode(str(value)) 41 | nameE.appendChild(nameT) 42 | instance.appendChild(nameE) 43 | 44 | with open(xml_file, 'w', encoding='utf-8') as tp: 45 | dom.writexml(tp, addindent=' ', newl='\n', encoding='utf-8') 46 | 47 | 48 | def parse_xml_to_dict(xml_file): 49 | """Parse xml_file back into a dict of dicts keyed by instance index.""" 50 | if not os.path.isfile(xml_file): 51 | print("input file does not exist: {}".format(xml_file)) 52 | sys.exit(-1) 53 | 54 | tree = ET.parse(xml_file) 55 | root = tree.getroot() 56 | dict_new = {} 57 | for key, value in enumerate(root): 58 | dict_init = {} 59 | list_init = [] 60 | for item in value: 61 | list_init.append([item.tag, item.text]) 62 | for lists in list_init: 63 | dict_init[lists[0]] = lists[1] 64 | dict_new[key] = dict_init 65 | return dict_new 66 | 67 | 68 | 69 | 70 | 71 | if __name__=='__main__': 72 | # imgStretch_dict = {'input_dir': '', 'output_dir': '', 'NoData': 65535, 'OutBits': '16bits', 73 | # 'StretchRange': '1024', 74 | # 'CutValue': '100'} 75 | save_file = 'imgstretch.xml' 76 | # generate_xml_from_dict(imgStretch_dict, save_file) 77 | dict_one = parse_xml_to_dict(save_file) 78 | print(dict_one[0])
-------------------------------------------------------------------------------- /venv/bin/activate: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate" *from bash* 2 | # you cannot run it directly 3 | 4 | deactivate () { 5 | # reset old environment variables 6 | if [ -n "$_OLD_VIRTUAL_PATH" ] ; then 7 | PATH="$_OLD_VIRTUAL_PATH" 8 | export PATH 9 | unset _OLD_VIRTUAL_PATH 10 | fi 11 | if [ -n "$_OLD_VIRTUAL_PYTHONHOME" ] ; then 12 | PYTHONHOME="$_OLD_VIRTUAL_PYTHONHOME" 13 | export PYTHONHOME 14 | unset _OLD_VIRTUAL_PYTHONHOME 15 | fi 16 | 17 | # This should detect bash and zsh, which have a hash command that must 18 | # be called to get it to forget past commands. 
Without forgetting 19 | # past commands the $PATH changes we made may not be respected 20 | if [ -n "$BASH" -o -n "$ZSH_VERSION" ] ; then 21 | hash -r 22 | fi 23 | 24 | if [ -n "$_OLD_VIRTUAL_PS1" ] ; then 25 | PS1="$_OLD_VIRTUAL_PS1" 26 | export PS1 27 | unset _OLD_VIRTUAL_PS1 28 | fi 29 | 30 | unset VIRTUAL_ENV 31 | if [ ! "$1" = "nondestructive" ] ; then 32 | # Self destruct! 33 | unset -f deactivate 34 | fi 35 | } 36 | 37 | # unset irrelavent variables 38 | deactivate nondestructive 39 | 40 | VIRTUAL_ENV="/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv" 41 | export VIRTUAL_ENV 42 | 43 | _OLD_VIRTUAL_PATH="$PATH" 44 | PATH="$VIRTUAL_ENV/bin:$PATH" 45 | export PATH 46 | 47 | # unset PYTHONHOME if set 48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway) 49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash 50 | if [ -n "$PYTHONHOME" ] ; then 51 | _OLD_VIRTUAL_PYTHONHOME="$PYTHONHOME" 52 | unset PYTHONHOME 53 | fi 54 | 55 | if [ -z "$VIRTUAL_ENV_DISABLE_PROMPT" ] ; then 56 | _OLD_VIRTUAL_PS1="$PS1" 57 | if [ "x(venv) " != x ] ; then 58 | PS1="(venv) $PS1" 59 | else 60 | if [ "`basename \"$VIRTUAL_ENV\"`" = "__" ] ; then 61 | # special case for Aspen magic directories 62 | # see http://www.zetadev.com/software/aspen/ 63 | PS1="[`basename \`dirname \"$VIRTUAL_ENV\"\``] $PS1" 64 | else 65 | PS1="(`basename \"$VIRTUAL_ENV\"`)$PS1" 66 | fi 67 | fi 68 | export PS1 69 | fi 70 | 71 | # This should detect bash and zsh, which have a hash command that must 72 | # be called to get it to forget past commands. Without forgetting 73 | # past commands the $PATH changes we made may not be respected 74 | if [ -n "$BASH" -o -n "$ZSH_VERSION" ] ; then 75 | hash -r 76 | fi 77 | -------------------------------------------------------------------------------- /venv/bin/activate.csh: -------------------------------------------------------------------------------- 1 | # This file must be used with "source bin/activate.csh" *from csh*. 2 | # You cannot run it directly. 3 | # Created by Davide Di Blasi . 4 | # Ported to Python 3.3 venv by Andrew Svetlov 5 | 6 | alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate' 7 | 8 | # Unset irrelavent variables. 9 | deactivate nondestructive 10 | 11 | setenv VIRTUAL_ENV "/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv" 12 | 13 | set _OLD_VIRTUAL_PATH="$PATH" 14 | setenv PATH "$VIRTUAL_ENV/bin:$PATH" 15 | 16 | 17 | set _OLD_VIRTUAL_PROMPT="$prompt" 18 | 19 | if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then 20 | if ("venv" != "") then 21 | set env_name = "venv" 22 | else 23 | if (`basename "VIRTUAL_ENV"` == "__") then 24 | # special case for Aspen magic directories 25 | # see http://www.zetadev.com/software/aspen/ 26 | set env_name = `basename \`dirname "$VIRTUAL_ENV"\`` 27 | else 28 | set env_name = `basename "$VIRTUAL_ENV"` 29 | endif 30 | endif 31 | set prompt = "[$env_name] $prompt" 32 | unset env_name 33 | endif 34 | 35 | alias pydoc python -m pydoc 36 | 37 | rehash 38 | -------------------------------------------------------------------------------- /venv/bin/activate.fish: -------------------------------------------------------------------------------- 1 | # This file must be used with ". 
bin/activate.fish" *from fish* (http://fishshell.org) 2 | # you cannot run it directly 3 | 4 | function deactivate -d "Exit virtualenv and return to normal shell environment" 5 | # reset old environment variables 6 | if test -n "$_OLD_VIRTUAL_PATH" 7 | set -gx PATH $_OLD_VIRTUAL_PATH 8 | set -e _OLD_VIRTUAL_PATH 9 | end 10 | if test -n "$_OLD_VIRTUAL_PYTHONHOME" 11 | set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME 12 | set -e _OLD_VIRTUAL_PYTHONHOME 13 | end 14 | 15 | if test -n "$_OLD_FISH_PROMPT_OVERRIDE" 16 | functions -e fish_prompt 17 | set -e _OLD_FISH_PROMPT_OVERRIDE 18 | . ( begin 19 | printf "function fish_prompt\n\t#" 20 | functions _old_fish_prompt 21 | end | psub ) 22 | functions -e _old_fish_prompt 23 | end 24 | 25 | set -e VIRTUAL_ENV 26 | if test "$argv[1]" != "nondestructive" 27 | # Self destruct! 28 | functions -e deactivate 29 | end 30 | end 31 | 32 | # unset irrelavent variables 33 | deactivate nondestructive 34 | 35 | set -gx VIRTUAL_ENV "/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv" 36 | 37 | set -gx _OLD_VIRTUAL_PATH $PATH 38 | set -gx PATH "$VIRTUAL_ENV/bin" $PATH 39 | 40 | # unset PYTHONHOME if set 41 | if set -q PYTHONHOME 42 | set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME 43 | set -e PYTHONHOME 44 | end 45 | 46 | if test -z "$VIRTUAL_ENV_DISABLE_PROMPT" 47 | # fish uses a function instead of an env var to generate the prompt. 48 | 49 | # save the current fish_prompt function as the function _old_fish_prompt 50 | . ( begin 51 | printf "function _old_fish_prompt\n\t#" 52 | functions fish_prompt 53 | end | psub ) 54 | 55 | # with the original prompt function renamed, we can override with our own. 56 | function fish_prompt 57 | # Prompt override? 58 | if test -n "$(venv) " 59 | printf "%s%s%s" "$(venv) " (set_color normal) (_old_fish_prompt) 60 | return 61 | end 62 | # ...Otherwise, prepend env 63 | set -l _checkbase (basename "$VIRTUAL_ENV") 64 | if test $_checkbase = "__" 65 | # special case for Aspen magic directories 66 | # see http://www.zetadev.com/software/aspen/ 67 | printf "%s[%s]%s %s" (set_color -b blue white) (basename (dirname "$VIRTUAL_ENV")) (set_color normal) (_old_fish_prompt) 68 | else 69 | printf "%s(%s)%s%s" (set_color -b blue white) (basename "$VIRTUAL_ENV") (set_color normal) (_old_fish_prompt) 70 | end 71 | end 72 | 73 | set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV" 74 | end 75 | -------------------------------------------------------------------------------- /venv/bin/python: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/venv/bin/python -------------------------------------------------------------------------------- /venv/bin/python3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/venv/bin/python3 -------------------------------------------------------------------------------- /venv/lib64: -------------------------------------------------------------------------------- 1 | lib -------------------------------------------------------------------------------- /venv/pyvenv.cfg: -------------------------------------------------------------------------------- 1 | home = /usr/bin 2 | include-system-site-packages = true 3 | version = 3.5.2 4 | --------------------------------------------------------------------------------