├── .idea
│   ├── dictionaries
│   │   ├── omnisky.xml
│   │   └── root.xml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── semantic_segment_RSImage.iml
│   └── vcs.xml
├── README.md
├── batch_predict
│   ├── __init__.py
│   ├── batch_base_predict_functions.py
│   ├── batch_predict_binary_jaccard.py
│   ├── batch_predict_binary_notonehot.py
│   ├── batch_predict_binary_onehot.py
│   ├── batch_predict_binary_onlyjaccard.py
│   ├── batch_predict_multiclass.py
│   └── batch_smooth_tiled_predictions.py
├── model_build
│   ├── config.json
│   ├── config.py
│   ├── config_WHU_buildings.json
│   ├── config_multiclass_global.json
│   ├── config_multiclass_manmade.json
│   ├── config_scrs_buildings.json
│   ├── config_tuitiantu.json
│   ├── mlp_model.json
│   ├── model.json
│   ├── test_1.py
│   ├── train_main.py
│   └── utils.py
├── model_predict
│   ├── config_pred.json
│   ├── config_pred.py
│   ├── config_pred_WHU.json
│   ├── config_pred_bieshu.json
│   ├── config_pred_multiclass.json
│   ├── config_pred_multiclass_manmade.json
│   ├── config_pred_scrsbuilings.json
│   ├── config_pred_tuitiantu.json
│   ├── predict_backbone.py
│   ├── predict_main.py
│   ├── predict_main_blocks.py
│   └── test_predict_1.py
├── postprocess
│   ├── __init__.py
│   ├── acc_evaluate.py
│   ├── combine_diffclass_for_singlemodel_result.py
│   ├── mismatch_analyze.py
│   ├── raster_to_vector.py
│   └── vote.py
├── predict
│   ├── __init__.py
│   ├── base_predict_functions.py
│   ├── predict_binary_jaccard.py
│   ├── predict_binary_notonehot.py
│   ├── predict_binary_onehot.py
│   ├── predict_binary_onlyjaccard.py
│   ├── predict_multiclass.py
│   └── smooth_tiled_predictions.py
├── samples_produce
│   ├── __init__.py
│   ├── check_original_labels_froNodata.py
│   ├── label_visulise.py
│   ├── sample_produce_for_singleimage.py
│   ├── traindata_generate_byCV.py
│   ├── traindata_generate_bygdal.py
│   └── traindata_generate_common.py
├── segmentation_models
│   ├── __init__.py
│   ├── __version__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── inception_resnet_v2.py
│   │   ├── inception_v3.py
│   │   ├── mobilenet.py
│   │   └── mobilenetv2.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── functions.py
│   │   └── layers.py
│   ├── fpn
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── builder.py
│   │   └── model.py
│   ├── linknet
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── builder.py
│   │   └── model.py
│   ├── losses.py
│   ├── metrics.py
│   ├── pspnet
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── builder.py
│   │   └── model.py
│   ├── unet
│   │   ├── __init__.py
│   │   ├── blocks.py
│   │   ├── builder.py
│   │   └── model.py
│   └── utils.py
├── temp
│   ├── __init__.py
│   ├── all_predict.py
│   ├── band4_image.py
│   ├── change_geotransform.py
│   ├── change_label_zym.py
│   ├── compose_labels.py
│   ├── fcn8_train_binary.py
│   ├── main.py
│   ├── modify_segnet_train_labels.py
│   ├── multibans_saveas_RGB.py
│   ├── predict_from_xuhuimin.py
│   ├── segnet_predict.py
│   ├── segnet_train_binary.py
│   ├── segnet_train_multiclass.py
│   ├── test_cv2read.py
│   ├── test_for_jaccrad_predict.py
│   ├── test_unet_multiclass_predict.py
│   ├── train_binary_4orMorebands.py
│   ├── train_binary_jaccard_4orMorebands.py
│   ├── train_multiclass_jaccard.py
│   ├── unet_predict.py
│   ├── unet_train_binary.py
│   ├── unet_train_multiclass.py
│   └── unet_train_qiwenchao.py
├── train
│   ├── Unet_resnet.py
│   ├── Unet_resnet_test.py
│   ├── __init__.py
│   ├── semantic_segmentation_networks.py
│   ├── train_binary_jaccard.py
│   ├── train_binary_jaccard_2.py
│   ├── train_binary_notOneHot.py
│   ├── train_binary_onehot.py
│   ├── train_binary_onlyjaccard.py
│   ├── train_binary_onlyjaccard_2.py
│   └── train_multiclass.py
├── ui
│   ├── MainWin.py
│   ├── MainWin.ui
│   ├── about.py
│   ├── about.ui
│   ├── classifyUi
│   │   ├── PredictBackend.py
│   │   ├── PredictBinaryBatch.py
│   │   ├── PredictBinaryBatch.ui
│   │   ├── PredictBinaryForSingleimage.py
│   │   ├── PredictBinaryForSingleimage.ui
│   │   ├── PredictMulticlassBatch.py
│   │   ├── PredictMulticlassBatch.ui
│   │   ├── PredictMulticlassForSingleimage.py
│   │   ├── PredictMulticlassForSingleimage.ui
│   │   └── predict_implements.py
│   ├── else
│   │   ├── manual.docx
│   │   ├── scrslogo.png
│   │   └── 中文标签.docx
│   ├── main_gui.py
│   ├── mysrc.qrc
│   ├── mysrc_rc.py
│   ├── postProcess
│   │   ├── AccuracyEvaluate.py
│   │   ├── AccuracyEvaluate.ui
│   │   ├── Binarization.py
│   │   ├── Binarization.ui
│   │   ├── CombineMulticlassFromSingleModelResults.py
│   │   ├── CombineMulticlassFromSingleModelResults.ui
│   │   ├── PostPrecessBackend.py
│   │   ├── VoteMultimodleResults.py
│   │   ├── VoteMultimodleResults.ui
│   │   └── postProcess_implements.py
│   ├── preProcess
│   │   ├── ImageClip.py
│   │   ├── ImageClip.ui
│   │   ├── ImageStretch.py
│   │   ├── ImageStretch.ui
│   │   ├── hist_process.py
│   │   ├── label_check.py
│   │   ├── label_check.ui
│   │   ├── preprocess_backend.py
│   │   └── preprocess_implements.py
│   ├── sampleProduce
│   │   ├── SampleGenCommon.py
│   │   ├── SampleGenCommon.ui
│   │   ├── SampleGenSelfAdapt.py
│   │   ├── SampleGenSelfAdapt.ui
│   │   ├── sampleProcess_backend.py
│   │   └── sampleProcess_implements.py
│   ├── tmp
│   │   ├── new_train_backend.py
│   │   └── new_train_implements.py
│   └── trainUi
│       ├── TrainBinaryCommon.py
│       ├── TrainBinaryCommon.ui
│       ├── TrainBinaryCrossentropy.py
│       ├── TrainBinaryCrossentropy.ui
│       ├── TrainBinaryJaccard.py
│       ├── TrainBinaryJaccard.ui
│       ├── TrainBinaryJaccardCrossentropy.py
│       ├── TrainBinaryJaccardCrossentropy.ui
│       ├── TrainBinaryOnehot.py
│       ├── TrainBinaryOnehot.ui
│       ├── TrainMulticlass.py
│       ├── TrainMulticlass.ui
│       ├── modelTrainBackend.py
│       ├── trainModels_implements.py
│       └── untitled.ui
├── ulitities
│   ├── __init__.py
│   ├── band_compose.py
│   ├── base_functions.py
│   ├── base_predict_functions.py
│   ├── ecogToPredict.py
│   ├── image_clip.py
│   ├── image_stretch.py
│   ├── resample_image.py
│   ├── smooth_tiled_predictions.py
│   ├── test_augument.py
│   └── xml_prec.py
└── venv
    ├── bin
    │   ├── activate
    │   ├── activate.csh
    │   ├── activate.fish
    │   ├── python
    │   └── python3
    ├── lib64
    └── pyvenv.cfg
/.idea/dictionaries/omnisky.xml:
--------------------------------------------------------------------------------
1 | <component name="ProjectDictionaryState">
2 |   <dictionary name="omnisky">
3 |     <words>
4 |       <w>labelencoder</w>
5 |       <w>multiclass</w>
6 |       <w>unet</w>
7 |     </words>
8 |   </dictionary>
9 | </component>
--------------------------------------------------------------------------------
/.idea/dictionaries/root.xml:
--------------------------------------------------------------------------------
1 | <component name="ProjectDictionaryState">
2 |   <dictionary name="root">
3 |     <words>
4 |       <w>dataset</w>
5 |       <w>imagenet</w>
6 |       <w>jaccard</w>
7 |       <w>notonehot</w>
8 |       <w>originaldata</w>
9 |       <w>segnet</w>
10 |     </words>
11 |   </dictionary>
12 | </component>
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | We use several semantic segmentation models for remote sensing image classification.
2 |
3 | Semantic segmentation networks: U-Net, PSPNet, FPN, LinkNet, DeepLab V3+
4 |
5 | Backbones:
6 | VGG:'vgg16' 'vgg19'
7 |
8 | ResNet:'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
9 |
10 | SE-ResNet:'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
11 |
12 | ResNeXt:'resnext50' 'resnext101'
13 |
14 | SE-ResNeXt:'seresnext50' 'seresnext101'
15 |
16 | SENet154:'senet154'
17 |
18 | DenseNet:'densenet121' 'densenet169' 'densenet201'
19 |
20 | Inception:'inceptionv3' 'inceptionresnetv2'
21 |
22 | MobileNet:'mobilenet' 'mobilenetv2'
23 |
24 | EfficientNet:'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3'
25 |
26 | Requirements:
27 |
28 | tensorflow-gpu==1.9.0 (only tested on v1.9.0)
29 | keras>=2.2.4
30 | keras_applications==1.0.7
31 | image-classifiers==0.2.0
32 | efficientnet>=0.0.3
33 | cuda==9.0
34 | qt==5.6
35 | numpy==1.12.0
36 | scipy==0.19.1
37 | tqdm==4.11.2
38 |
39 |
40 |
--------------------------------------------------------------------------------
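A minimal model-building sketch to accompany the README above, using the bundled segmentation_models package. The network/backbone pair, input shape, loss, and optimizer here are illustrative picks from the lists in the README, not values prescribed by the repository (its actual training entry point is model_build/train_main.py):

    # Illustrative sketch only: any backbone from the README lists can be substituted.
    from segmentation_models import Unet
    from segmentation_models.losses import bce_jaccard_loss
    from segmentation_models.metrics import iou_score

    model = Unet(backbone_name='resnet34',
                 input_shape=(256, 256, 3),
                 classes=1,
                 activation='sigmoid',
                 encoder_weights='imagenet')
    model.compile(optimizer='sgd', loss=bce_jaccard_loss, metrics=[iou_score])
    model.summary()

The same call shape applies to FPN, Linknet, and PSPNet, which the bundled package exposes alongside Unet.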
/batch_predict/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/batch_predict/__init__.py
--------------------------------------------------------------------------------
/batch_predict/batch_predict_binary_jaccard.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """
3 | This is the main procedure for remote sensing image semantic segmentation.
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10, UINT8, UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 | from ulitities.base_functions import get_file
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
31 |
32 |
33 | target_class = 1
34 |
35 | window_size = 256  # 224, 256, 288, 320
36 | # step = 128
37 |
38 | im_bands = 3
39 | im_type = UINT8  # UINT10, UINT8, UINT16
40 | dict_network = {0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target = {0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0  # 0: unet; 1: fcn; 2: segnet
43 |
44 | FLAG_TARGET_CLASS = 1  # 0: roads; 1: buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1  # 0: original predict, 1: smooth predict
47 |
48 | input_path = '/home/omnisky/PycharmProjects/data/test/APtest/images/'
49 | output_path = ''.join(['/home/omnisky/PycharmProjects/data/test/APtest/pred_', str(window_size)])
50 |
51 |
52 | # model_file = ''.join(['../../data/models/sat_urban_rgb/',dict_network[FLAG_USING_NETWORK], '_',
53 | #                       dict_target[FLAG_TARGET_CLASS],'_binary_jaccard_', str(window_size), '_final.h5'])
54 | model_file = '/home/omnisky/PycharmProjects/data/models/APsamples/unet_buildings_binary_jaccard_256_final.h5'
55 |
56 |
57 | print("model: {}".format(model_file))
58 |
59 | def predict_binary_jaccard(img_file, output_file):
60 |
61 |     print("[INFO] opening image...")
62 |     input_img = load_img_by_gdal(img_file)
63 |     if im_type == UINT8:
64 |         input_img = input_img / 255.0
65 |     elif im_type == UINT10:
66 |         input_img = input_img / 1024.0
67 |     elif im_type == UINT16:
68 |         input_img = input_img / 65535.0
69 |
70 |     input_img = np.clip(input_img, 0.0, 1.0)
71 |     input_img = input_img.astype(np.float16)
72 |
73 |     model = load_model(model_file)
74 |
75 |     if FLAG_APPROACH_PREDICT == 0:
76 |         print("[INFO] predict image by original approach\n")
77 |         result = orignal_predict_notonehot(input_img, im_bands, model, window_size)
78 |         abs_filename = os.path.split(img_file)[1]
79 |         abs_filename = abs_filename.split(".")[0]
80 |         # output_file = ''.join([output_path, '/original_pred_',
81 |         #                        abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png'])
82 |         output_file = ''.join([output_path, '/mask_binary_',
83 |                                abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_original.png'])
84 |         print("result saved to: {}".format(output_file))
85 |         cv2.imwrite(output_file, result * 128)
86 |
87 |     elif FLAG_APPROACH_PREDICT == 1:
88 |         print("[INFO] predict image by smooth approach\n")
89 |         result = predict_img_with_smooth_windowing_multiclassbands(
90 |             input_img,
91 |             model,
92 |             window_size=window_size,
93 |             subdivisions=2,
94 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
95 |             pred_func=smooth_predict_for_binary_notonehot,
96 |             PLOT_PROGRESS=False
97 |         )
98 |
99 |         cv2.imwrite(output_file, result)
100 |         print("Saved to: {}".format(output_file))
101 |
102 |     gc.collect()
103 |
104 |
105 | if __name__ == '__main__':
106 |
107 |     all_files, num = get_file(input_path)
108 |     if num == 0:
109 |         print("There is no file in path: {}".format(input_path))
110 |         sys.exit(-1)
111 |
112 |     """check model file"""
113 |     print("model file: {}".format(model_file))
114 |     if not os.path.isfile(model_file):
115 |         print("model does not exist: {}".format(model_file))
116 |         sys.exit(-2)
117 |
118 |     for in_file in all_files:
119 |         abs_filename = os.path.split(in_file)[1]
120 |         abs_filename = abs_filename.split(".")[0]
121 |         print(abs_filename)
122 |         out_file = ''.join([output_path, '/mask_binary_',
123 |                             abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard.png'])
124 |         predict_binary_jaccard(in_file, out_file)
--------------------------------------------------------------------------------
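The im_type branch above (divide by 255/1024/65535, then clip) recurs verbatim in every predict script in this directory. A hypothetical helper, sketched under the same UINT8/UINT10/UINT16 constants that ulitities.base_functions exports, shows how it could be factored out; normalize_by_type is not a function that exists in the repository:

    # Hypothetical refactoring sketch; not present in the repository.
    import numpy as np
    from ulitities.base_functions import UINT8, UINT10, UINT16

    def normalize_by_type(img, im_type, dtype=np.float16):
        # Scale raw pixel values into [0, 1] according to the radiometric type,
        # then clip to guard against over-range pixels.
        scale = {UINT8: 255.0, UINT10: 1024.0, UINT16: 65535.0}[im_type]
        return np.clip(img / scale, 0.0, 1.0).astype(dtype)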
/batch_predict/batch_predict_binary_notonehot.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """
3 | This is the main procedure for remote sensing image semantic segmentation.
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10, UINT8, UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 | from ulitities.base_functions import get_file
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class = 1
34 |
35 | window_size = 256
36 | # step = 128
37 |
38 | im_bands = 4
39 | im_type = UINT10  # UINT10, UINT8, UINT16
40 | dict_network = {0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target = {0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0  # 0: unet; 1: fcn; 2: segnet
43 |
44 | FLAG_TARGET_CLASS = 0  # 0: roads; 1: buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1  # 0: original predict, 1: smooth predict
47 |
48 | input_path = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/images//'
49 | output_path = ''.join(['/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/pred/pred_', str(window_size)])
50 |
51 | model_file = ''.join(['../../data/models/sat_urban_4bands/', dict_network[FLAG_USING_NETWORK], '_',
52 |                       dict_target[FLAG_TARGET_CLASS], '_binary_notonehot_final.h5'])
53 |
54 | print("model: {}".format(model_file))
55 |
56 | def predict_binary_notonehot(img_file, output_file):
57 |
58 |     print("[INFO] opening image...")
59 |
60 |     input_img = load_img_by_gdal(img_file)
61 |     if im_type == UINT8:
62 |         input_img = input_img / 255.0
63 |     elif im_type == UINT10:
64 |         input_img = input_img / 1024.0
65 |     elif im_type == UINT16:
66 |         input_img = input_img / 65535.0
67 |
68 |     input_img = np.clip(input_img, 0.0, 1.0)
69 |     input_img = input_img.astype(np.float16)
70 |
71 |     """check model file"""
72 |     print("model file: {}".format(model_file))
73 |     if not os.path.isfile(model_file):
74 |         print("model does not exist: {}".format(model_file))
75 |         sys.exit(-2)
76 |
77 |     model = load_model(model_file)
78 |
79 |     if FLAG_APPROACH_PREDICT == 0:
80 |         print("[INFO] predict image by original approach\n")
81 |         result = orignal_predict_notonehot(input_img, im_bands, model, window_size)
82 |         abs_filename = os.path.split(img_file)[1]
83 |         abs_filename = abs_filename.split(".")[0]
84 |         output_file = ''.join([output_path, '/original_pred_',
85 |                                abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard.png'])
86 |         print("result saved to: {}".format(output_file))
87 |         cv2.imwrite(output_file, result * 128)
88 |
89 |     elif FLAG_APPROACH_PREDICT == 1:
90 |         print("[INFO] predict image by smooth approach\n")
91 |         result = predict_img_with_smooth_windowing_multiclassbands(
92 |             input_img,
93 |             model,
94 |             window_size=window_size,
95 |             subdivisions=2,
96 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
97 |             pred_func=smooth_predict_for_binary_notonehot,
98 |             PLOT_PROGRESS=False
99 |         )
100 |
101 |         cv2.imwrite(output_file, result)
102 |         print("Saved to: {}".format(output_file))
103 |
104 |     gc.collect()
105 |
106 |
107 | if __name__ == '__main__':
108 |
109 |     all_files, num = get_file(input_path)
110 |     if num == 0:
111 |         print("There is no file in path: {}".format(input_path))
112 |         sys.exit(-1)
113 |
114 |     """check model file"""
115 |     print("model file: {}".format(model_file))
116 |     if not os.path.isfile(model_file):
117 |         print("model does not exist: {}".format(model_file))
118 |         sys.exit(-2)
119 |
120 |     for in_file in all_files:
121 |         abs_filename = os.path.split(in_file)[1]
122 |         abs_filename = abs_filename.split(".")[0]
123 |         print(abs_filename)
124 |         out_file = ''.join([output_path, '/mask_binary_',
125 |                             abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png'])
126 |         predict_binary_notonehot(in_file, out_file)
127 |
128 |
129 |
--------------------------------------------------------------------------------
/batch_predict/batch_predict_binary_onehot.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """
3 | This is the main procedure for remote sensing image semantic segmentation.
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_binary_onehot
23 | from ulitities.base_functions import load_img_normalization, load_img_by_gdal, UINT10, UINT8, UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 | from ulitities.base_functions import get_file
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class = 1
34 |
35 | window_size = 256
36 | # step = 128
37 | im_bands = 3
38 | im_type = UINT8
39 | dict_network = {0: 'unet', 1: 'fcnnet', 2: 'segnet'}
40 | dict_target = {0: 'roads', 1: 'buildings'}
41 | FLAG_USING_NETWORK = 0  # 0: unet; 1: fcn; 2: segnet
42 |
43 | FLAG_TARGET_CLASS = 0  # 0: roads; 1: buildings
44 |
45 | FLAG_APPROACH_PREDICT = 0  # 0: original predict, 1: smooth predict
46 |
47 | input_path = '../../data/test/paper/images/'
48 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)])
49 |
50 | # model_file = ''.join(['../../data/models/sat_urban_nrg/',dict_network[FLAG_USING_NETWORK], '_', dict_target[FLAG_TARGET_CLASS],'_binary.h5'])
51 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_buildings_binary2_onehot.h5'
52 | model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onehot.h5'
53 | print("model: {}".format(model_file))
54 |
55 | def predict_binary_onehot(img_file, output_file):
56 |
57 |     print("[INFO] opening image...")
58 |     input_img = load_img_by_gdal(img_file)
59 |     if im_type == UINT8:
60 |         input_img = input_img / 255.0
61 |     elif im_type == UINT10:
62 |         input_img = input_img / 1024.0
63 |     elif im_type == UINT16:
64 |         input_img = input_img / 65535.0
65 |     input_img = np.clip(input_img, 0.0, 1.0)
66 |
67 |     """check model file"""
68 |     print("model file: {}".format(model_file))
69 |     if not os.path.isfile(model_file):
70 |         print("model does not exist: {}".format(model_file))
71 |         sys.exit(-2)
72 |
73 |     model = load_model(model_file)
74 |
75 |     if FLAG_APPROACH_PREDICT == 0:
76 |         print("[INFO] predict image by original approach\n")
77 |         result = orignal_predict_onehot(input_img, im_bands, model, window_size)
78 |         abs_filename = os.path.split(img_file)[1]
79 |         abs_filename = abs_filename.split(".")[0]
80 |         print(abs_filename)
81 |         output_file = ''.join(['../../data/predict/original_predict_', abs_filename, '_onehot.png'])
82 |         print("result saved to: {}".format(output_file))
83 |         cv2.imwrite(output_file, result * 100)
84 |
85 |     elif FLAG_APPROACH_PREDICT == 1:
86 |         print("[INFO] predict image by smooth approach\n")
87 |         result = predict_img_with_smooth_windowing_multiclassbands(
88 |             input_img,
89 |             model,
90 |             window_size=window_size,
91 |             subdivisions=2,
92 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
93 |             pred_func=smooth_predict_for_binary_onehot,
94 |             PLOT_PROGRESS=False
95 |         )
96 |
97 |         cv2.imwrite(output_file, result)
98 |         print("Saved to: {}".format(output_file))
99 |
100 |     gc.collect()
101 |
102 |
103 | if __name__ == '__main__':
104 |
105 |     all_files, num = get_file(input_path)
106 |     if num == 0:
107 |         print("There is no file in path: {}".format(input_path))
108 |         sys.exit(-1)
109 |
110 |     """check model file"""
111 |     print("model file: {}".format(model_file))
112 |     if not os.path.isfile(model_file):
113 |         print("model does not exist: {}".format(model_file))
114 |         sys.exit(-2)
115 |
116 |     for in_file in all_files:
117 |         abs_filename = os.path.split(in_file)[1]
118 |         abs_filename = abs_filename.split(".")[0]
119 |         print(abs_filename)
120 |         out_file = ''.join([output_path, '/mask_binary_',
121 |                             abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onehot.png'])
122 |         predict_binary_onehot(in_file, out_file)
--------------------------------------------------------------------------------
/batch_predict/batch_predict_binary_onlyjaccard.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """
3 | This is the main procedure for remote sensing image semantic segmentation.
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10, UINT8, UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 | from ulitities.base_functions import get_file
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class = 1
34 |
35 | window_size = 256  # 224, 256, 288
36 | # step = 128
37 |
38 | im_bands = 4
39 | im_type = UINT10  # UINT10, UINT8, UINT16
40 | dict_network = {0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target = {0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0  # 0: unet; 1: fcn; 2: segnet
43 |
44 | FLAG_TARGET_CLASS = 1  # 0: roads; 1: buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1  # 0: original predict, 1: smooth predict
47 |
48 |
49 | input_path = '../../data/test/paper/images/'
50 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)])
51 |
52 | model_file = ''.join(['../../data/models/sat_urban_4bands/', dict_network[FLAG_USING_NETWORK], '_',
53 |                       dict_target[FLAG_TARGET_CLASS], '_binary_onlyjaccard_final.h5'])
54 |
55 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onlyjaccard_2018-09-29_18-55-11.h5'
56 | print("model: {}".format(model_file))
57 |
58 | def predict_binary_only_jaccard(img_file, output_file):
59 |
60 |     print("[INFO] opening image...")
61 |
62 |     input_img = load_img_by_gdal(img_file)
63 |     if im_type == UINT8:
64 |         input_img = input_img / 255.0
65 |     elif im_type == UINT10:
66 |         input_img = input_img / 1024.0
67 |     elif im_type == UINT16:
68 |         input_img = input_img / 65535.0
69 |
70 |     input_img = np.clip(input_img, 0.0, 1.0)
71 |
72 |     model = load_model(model_file)
73 |
74 |     if FLAG_APPROACH_PREDICT == 0:
75 |         print("[INFO] predict image by original approach\n")
76 |         result = orignal_predict_notonehot(input_img, im_bands, model, window_size)
77 |         output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK], '/sat_4bands/original_pred_',
78 |                                os.path.split(img_file)[1].split(".")[0], '_', dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png'])  # abs_filename was undefined in this branch; derive it from img_file
79 |         print("result saved to: {}".format(output_file))
80 |         cv2.imwrite(output_file, result * 128)
81 |
82 |     elif FLAG_APPROACH_PREDICT == 1:
83 |         print("[INFO] predict image by smooth approach\n")
84 |         result = predict_img_with_smooth_windowing_multiclassbands(
85 |             input_img,
86 |             model,
87 |             window_size=window_size,
88 |             subdivisions=2,
89 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
90 |             pred_func=smooth_predict_for_binary_notonehot,
91 |             PLOT_PROGRESS=False
92 |         )
93 |
94 |         cv2.imwrite(output_file, result)
95 |         print("Saved to: {}".format(output_file))
96 |
97 |     gc.collect()
98 |
99 |
100 |
101 | if __name__ == '__main__':
102 |
103 |     all_files, num = get_file(input_path)
104 |     if num == 0:
105 |         print("There is no file in path: {}".format(input_path))
106 |         sys.exit(-1)
107 |
108 |
109 |     """check model file"""
110 |     print("model file: {}".format(model_file))
111 |     if not os.path.isfile(model_file):
112 |         print("model does not exist: {}".format(model_file))
113 |         sys.exit(-2)
114 |
115 |     for in_file in all_files:
116 |         abs_filename = os.path.split(in_file)[1]
117 |         abs_filename = abs_filename.split(".")[0]
118 |         print(abs_filename)
119 |         out_file = ''.join([output_path, '/mask_binary_',
120 |                             abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png'])
121 |         predict_binary_only_jaccard(in_file, out_file)
122 |
123 |
124 |
--------------------------------------------------------------------------------
/batch_predict/batch_predict_multiclass.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """
3 | This is the main procedure for remote sensing image semantic segmentation.
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_multiclass
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10, UINT8, UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | from ulitities.base_functions import get_file
26 | """
27 | The following global variables should be put into a metadata file.
28 | """
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
30 |
31 |
32 | window_size = 256
33 | step = 128
34 |
35 | im_bands = 4
36 | im_type = UINT10  # UINT8, UINT10, UINT16
37 |
38 | dict_network = {0: 'unet', 1: 'fcnnet', 2: 'segnet'}
39 | dict_target = {0: 'roads', 1: 'buildings'}
40 | target_class = len(dict_target)
41 |
42 | FLAG_USING_NETWORK = 0  # 0: unet; 1: fcn; 2: segnet
43 |
44 | FLAG_APPROACH_PREDICT = 1  # 0: original predict, 1: smooth predict
45 |
46 | input_path = '../../data/test/paper/images/'
47 | output_path = ''.join(['../../data/test/paper/pred_', str(window_size)])
48 |
49 | model_file = ''.join(['../../data/models/sat_urban_4bands/', dict_network[FLAG_USING_NETWORK], '_multiclass_final.h5'])
50 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_multiclass_2018-09-11_14-05-31.h5'
51 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_multiclass.h5'
52 |
53 | def predict_multiclass(img_file, out_path):
54 |
55 |     print("[INFO] opening image...")
56 |     input_img = load_img_by_gdal(img_file)
57 |     if im_type == UINT8:
58 |         input_img = input_img / 255.0
59 |     elif im_type == UINT10:
60 |         input_img = input_img / 1024.0
61 |     elif im_type == UINT16:
62 |         input_img = input_img / 65535.0
63 |     input_img = np.clip(input_img, 0.0, 1.0)
64 |
65 |     """check model file"""
66 |     print("model file: {}".format(model_file))
67 |     if not os.path.isfile(model_file):
68 |         print("model does not exist: {}".format(model_file))
69 |         sys.exit(-2)
70 |
71 |     model = load_model(model_file)
72 |     abs_filename = os.path.split(img_file)[1]
73 |     abs_filename = abs_filename.split(".")[0]
74 |     print(abs_filename)
75 |
76 |     if FLAG_APPROACH_PREDICT == 0:
77 |         print("[INFO] predict image by original approach\n")
78 |         result = orignal_predict_onehot(input_img, im_bands, model, window_size)
79 |         output_file = ''.join([out_path, '/original_predict_', abs_filename, '_multiclass.png'])
80 |         print("result saved to: {}".format(output_file))
81 |         cv2.imwrite(output_file, result * 128)
82 |
83 |     elif FLAG_APPROACH_PREDICT == 1:
84 |         print("[INFO] predict image by smooth approach\n")
85 |         result = predict_img_with_smooth_windowing_multiclassbands(
86 |             input_img,
87 |             model,
88 |             window_size=window_size,
89 |             subdivisions=2,
90 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
91 |             pred_func=smooth_predict_for_multiclass,
92 |             PLOT_PROGRESS=False
93 |         )
94 |
95 |         for b in range(target_class):
96 |             output_file = ''.join([out_path, '/mask_multiclass_',
97 |                                    abs_filename, '_', dict_target[b], '.png'])
98 |             cv2.imwrite(output_file, result[:, :, b])
99 |             print("Saved to: {}".format(output_file))
100 |     gc.collect()
101 |
102 | if __name__ == '__main__':
103 |
104 |     all_files, num = get_file(input_path)
105 |     if num == 0:
106 |         print("There is no file in path: {}".format(input_path))
107 |         sys.exit(-1)
108 |
109 |     """check model file"""
110 |     print("model file: {}".format(model_file))
111 |     if not os.path.isfile(model_file):
112 |         print("model does not exist: {}".format(model_file))
113 |         sys.exit(-2)
114 |
115 |     for in_file in all_files:
116 |         abs_filename = os.path.split(in_file)[1]
117 |         abs_filename = abs_filename.split(".")[0]
118 |         print(abs_filename)
119 |         # out_file = ''.join([output_path, '/mask_binary_',
120 |         #                     abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png'])
121 |         predict_multiclass(in_file, output_path)
122 |
123 |
124 |
--------------------------------------------------------------------------------
/model_build/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/home/omnisky/PycharmProjects/data/traindata/rice/",
3 | "img_w":256,
4 | "img_h":256,
5 | "im_bands":3,
6 | "im_type":"UINT8",
7 | "target_name":"rice",
8 | "val_rate":0.25,
9 |
10 | "network":"unet",
11 | "BACKBONE":"resnet34",
12 | "activation":"sigmoid",
13 | "encoder_weights":null,
14 | "nb_classes":1,
15 | "batch_size":32,
16 | "epochs": 50,
17 |
18 | "loss":"binary_crossentropy",
19 | "metrics":"accuracy",
20 | "optimizer": "sgd",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/rice/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"max",
34 | "factor":0.1,
35 | "patience":5,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "0312_rice_size",
43 | "predict_batch_size": 16,
44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nodata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"folds_rice.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/config.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 |
4 | Config = namedtuple("Config", [
5 | "train_data_path",
6 | "img_w",
7 | "img_h",
8 | "im_bands",
9 | "im_type",
10 | "target_name",
11 | "val_rate",
12 | "network",
13 | "BACKBONE",
14 | "activation",
15 | "encoder_weights",
16 | "nb_classes",
17 | "batch_size",
18 | "epochs",
19 | "optimizer",
20 | "loss",
21 | "metrics",
22 | "lr",
23 | "lr_steps",
24 | "lr_gamma",
25 | "lr_scheduler",
26 | "nb_epoch",
27 | "old_epoch",
28 | "test_pad",
29 | "model_dir",
30 | "base_model",
31 | "monitor",
32 | "save_best_only",
33 | "mode",
34 | "factor",
35 | "patience",
36 | "epsilon",
37 | "cooldown",
38 | "min_lr",
39 | "log_dir",
40 | "iter_size",
41 | "folder",
42 | "predict_batch_size",
43 | "results_dir",
44 | "loss_x",
45 | "loss_ema",
46 | "loss_kl",
47 | "loss_student",
48 | "ignore_target_size",
49 | "warmup",
50 | "lovasz",
51 | "unlabel_size",
52 | "negative_rate",
53 | "external_weights",
54 | "class_weights",
55 | "train_fold",
56 | "use_lb",
57 | "train_images",
58 | "train_images_unlabel",
59 | "train_masks",
60 | "train_image_suffix",
61 | "csv_file",
62 | "num_workers",
63 | "use_border",
64 | "class_picture",
65 | "use_good_image",
66 | "hard_negative_miner",
67 | "mixup"
68 | ])
69 |
70 |
71 |
--------------------------------------------------------------------------------
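config.py above only declares the field list; pairing it with the JSON files is a plain json.load followed by keyword expansion, the same Config(**d) pattern that update_config in model_build/utils.py uses. A hedged sketch of that loading step (load_config is illustrative, not a function defined in the repository, and it assumes the JSON keys exactly match the namedtuple fields):

    # Illustrative loader for the config_*.json files above.
    import json
    from config import Config

    def load_config(json_path):
        with open(json_path) as f:
            d = json.load(f)
        return Config(**d)  # raises TypeError if keys and fields diverge

    cfg = load_config("config.json")
    print(cfg.network, cfg.BACKBONE, cfg.lr)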
/model_build/config_WHU_buildings.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/train_my/",
3 | "img_w":512,
4 | "img_h":512,
5 | "im_bands":3,
6 | "im_type":"UINT8",
7 | "target_name":"WHU_building",
8 | "val_rate":0.25,
9 |
10 | "network":"fpn",
11 | "BACKBONE":"seresnet18",
12 | "activation":"sigmoid",
13 | "encoder_weights":"imagenet",
14 | "nb_classes":1,
15 | "batch_size":4,
16 | "epochs": 80,
17 |
18 | "loss":"bce_jaccard_loss",
19 | "metrics":"iou_score",
20 | "optimizer": "sgd",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/WHU/pre-trained/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"min",
34 | "factor":0.1,
35 | "patience":10,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "0312_rice_size",
43 | "predict_batch_size": 16,
44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nodata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"folds_rice.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/config_multiclass_global.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/home/omnisky/PycharmProjects/data/traindata/global/multiclass_miandiantaiguo_histM/",
3 | "img_w":288,
4 | "img_h":288,
5 | "im_bands":4,
6 | "im_type":"uint10",
7 | "target_name":"global",
8 | "val_rate":0.25,
9 |
10 | "network":"unet",
11 | "BACKBONE":"resnet34",
12 | "activation":"softmax",
13 | "encoder_weights":null,
14 | "nb_classes":6,
15 | "batch_size":4,
16 | "epochs": 100,
17 |
18 | "loss":"categorical_crossentropy",
19 | "metrics":"accuracy",
20 | "optimizer": "adam",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/global/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"min",
34 | "factor":0.1,
35 | "patience":10,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "0312_rice_size",
43 | "predict_batch_size": 16,
44 | "results_dir": "",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nodata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"folds_rice.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/config_multiclass_manmade.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/sat_4bands_288/multiclass/",
3 | "img_w":256,
4 | "img_h":256,
5 | "im_bands":4,
6 | "im_type":"uint10",
7 | "target_name":"manmade",
8 | "val_rate":0.25,
9 |
10 | "network":"fpn",
11 | "BACKBONE":"resnet34",
12 | "activation":"softmax",
13 | "encoder_weights":null,
14 | "nb_classes":3,
15 | "batch_size":32,
16 | "epochs": 100,
17 |
18 | "loss":"categorical_crossentropy",
19 | "metrics":"accuracy",
20 | "optimizer": "adam",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"min",
34 | "factor":0.1,
35 | "patience":5,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "0312_rice_size",
43 | "predict_batch_size": 16,
44 | "results_dir": "",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nodata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"folds_rice.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/config_scrs_buildings.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/traindata/buildings_576_normal/",
3 | "img_w":576,
4 | "img_h":576,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"scrs_building",
8 | "val_rate":0.25,
9 |
10 | "network":"unet",
11 | "BACKBONE":"inceptionresnetv2",
12 | "activation":"sigmoid",
13 | "encoder_weights":null,
14 | "nb_classes":1,
15 | "batch_size":4,
16 | "epochs": 80,
17 |
18 | "loss":"bce_jaccard_loss",
19 | "metrics":"iou_score",
20 | "optimizer": "sgd",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/scrs_buildings/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"min",
34 | "factor":0.1,
35 | "patience":10,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "0312_rice_size",
43 | "predict_batch_size": 16,
44 | "results_dir": "/media/scrs/Data1/chenjun/data/rice/results/",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nodata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"folds_rice.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/config_tuitiantu.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_data_path":"/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/traindata/tuitiantu_480/",
3 | "img_w":480,
4 | "img_h":480,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"tuitiantu",
8 | "val_rate":0.25,
9 |
10 | "network":"unet",
11 | "BACKBONE":"resnet101",
12 | "activation":"sigmoid",
13 | "encoder_weights":null,
14 | "nb_classes":1,
15 | "batch_size":8,
16 | "epochs": 100,
17 |
18 | "loss":"bce_jaccard_loss",
19 | "metrics":"iou_score",
20 | "optimizer": "sgd",
21 | "lr": 0.0001,
22 | "lr_steps": [540, 560],
23 | "lr_gamma": 0.1,
24 | "lr_scheduler":"CosineAnnealingLR",
25 | "nb_epoch": 48,
26 | "old_epoch":0,
27 | "test_pad": 64,
28 |
29 | "model_dir": "/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu/",
30 | "base_model":"",
31 | "monitor":"val_loss",
32 | "save_best_only":true,
33 | "mode":"min",
34 | "factor":0.1,
35 | "patience":10,
36 | "epsilon":0.0001,
37 | "cooldown":0,
38 | "min_lr":0,
39 |
40 | "log_dir": "/home/omnisky/PycharmProjects/data/tmp/",
41 | "iter_size": 1,
42 | "folder": "",
43 | "predict_batch_size": 16,
44 | "results_dir": "",
45 | "loss_x": {"dice": 0.9, "bce": 0.1},
46 | "loss_ema":0.3,
47 | "loss_kl":25,
48 | "loss_student":1,
49 | "ignore_target_size": false,
50 | "warmup": 0,
51 | "lovasz":false,
52 | "unlabel_size":1,
53 | "negative_rate":1,
54 | "external_weights":"",
55 | "class_weights":[],
56 | "train_fold":-1,
57 | "use_lb":true,
58 | "train_images":"images_clip",
59 | "train_images_unlabel":"",
60 | "train_masks":"label_half_nata",
61 | "train_image_suffix":"*.png",
62 | "csv_file":"fe.csv",
63 | "num_workers":4,
64 | "use_border":false,
65 | "class_picture":false,
66 | "use_good_image":true,
67 | "hard_negative_miner": false,
68 | "mixup": false
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/model_build/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import shutil
4 | import time
5 | import os
6 | from config import Config
7 | def get_csv_folds(path, d):
8 |     df = pd.read_csv(path, index_col=0)  # CSV maps sample id -> fold number
9 |     m = df.max()[0] + 1  # number of folds
10 |     train = [[] for i in range(m)]
11 |     test = [[] for i in range(m)]
12 |
13 |     folds = {}
14 |     for i in range(m):
15 |         fold_ids = list(df[df['fold'].isin([i])].index)
16 |         folds.update({i: [n for n, l in enumerate(d) if l in fold_ids]})  # positions in d that belong to fold i
17 |
18 |     for k, v in folds.items():  # fold k is the test split; every other fold trains on it
19 |         for i in range(m):
20 |             if i != k:
21 |                 train[i].extend(v)
22 |         test[k] = v
23 |
24 |     return list(zip(np.array(train), np.array(test)))
25 | def update_config(config, **kwargs):
26 |     d = config._asdict()  # namedtuple -> OrderedDict
27 |     d.update(**kwargs)
28 |     print(d)
29 |     return Config(**d)
30 | def save(path, network, jsonPath=None):
31 |
32 |     folder = time.strftime("%Y%m%d%H%M", time.localtime())
33 |     new_path = os.path.join(path, "history", folder + "_" + network)
34 |     try:
35 |         os.makedirs(new_path)
36 |     except OSError:  # directory already exists: fall back to a randomized suffix
37 |         new_path = new_path + "_" + str(np.random.randint(0, 100))
38 |         os.makedirs(new_path)
39 |     shutil.copy(os.path.join(path, "train.py"), new_path)
40 |     if jsonPath is None:
41 |         shutil.copy(os.path.join(path, "config.json"), new_path)
42 |     else:
43 |         shutil.copy(jsonPath, new_path)
--------------------------------------------------------------------------------
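For orientation, a hedged usage sketch of the helpers above; folds.csv and the sample ids are placeholders, not files shipped with the repository:

    # Illustrative only.
    from utils import get_csv_folds, update_config

    ids = ['img_001', 'img_002', 'img_003', 'img_004']   # sample identifiers, as in the csv_file index
    splits = get_csv_folds('folds.csv', ids)             # one (train_indices, test_indices) pair per fold
    for fold, (train_idx, test_idx) in enumerate(splits):
        print(fold, len(train_idx), len(test_idx))

    # update_config returns a new Config with selected fields overridden, e.g.:
    # cfg = update_config(cfg, lr=0.001, batch_size=16)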
/model_predict/config_pred.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/home/omnisky/PycharmProjects/data/test/rice/normal/",
3 | "img_w":256,
4 | "img_h":256,
5 | "im_bands":3,
6 | "im_type":"UINT8",
7 | "target_name":"rice",
8 | "activation":"sigmoid",
9 | "mask_classes":1,
10 | "strategy":"smooth",
11 | "window_size":256,
12 | "subdivisions":2,
13 | "slices":4,
14 | "block_size":100000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/rice/unet_test_Unet_resnet2019-03-25_11-57-12.h5",
17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/rice/pred/",
18 | "suffix":".png"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 |
4 | Config_Pred = namedtuple("Config_pred", [
5 | "img_input",
6 | "img_w",
7 | "img_h",
8 | "im_bands",
9 | "im_type",
10 | "target_name",
11 | "model_path",
12 | "activation",
13 | "mask_classes",
14 | "strategy",
15 | "window_size",
16 | "subdivisions",
17 | "slices",
18 | "block_size",
19 | "nodata",
20 | "mask_dir",
21 | "suffix"
22 | ])
23 |
24 |
25 |
--------------------------------------------------------------------------------
/model_predict/config_pred_WHU.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/",
3 | "img_w":512,
4 | "img_h":512,
5 | "im_bands":3,
6 | "im_type":"UINT8",
7 | "target_name":"buildings",
8 | "activation":"sigmoid",
9 | "mask_classes":1,
10 | "strategy":"smooth",
11 | "window_size":512,
12 | "subdivisions":2,
13 | "slices":1,
14 | "block_size":200000000,
15 | "nodata":256,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/WHU/pre-trained/WHU_building_unet_resnet101_bce_jaccard_loss_512_2019-05-09_14-01-47.h5",
17 | "mask_dir": "/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/pred/",
18 | "suffix":".png"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred_bieshu.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/media/omnisky/579215ea-c392-4d2f-b136-6b62d9fb31ac/test/bieshu/norm_images/",
3 | "img_w":480,
4 | "img_h":480,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"bieshu",
8 | "activation":"sigmoid",
9 | "mask_classes":1,
10 | "strategy":"smooth",
11 | "window_size":480,
12 | "subdivisions":2,
13 | "slices":8,
14 | "block_size":5000000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/bieshu/bieshu_unet_vgg16_binary_crossentropy_480_2019-06-01_22-49-00.h5",
17 | "mask_dir": "/media/omnisky/579215ea-c392-4d2f-b136-6b62d9fb31ac/test/bieshu/pred/",
18 | "suffix":".tif"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred_multiclass.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/images_forClass/test_normal/",
3 | "img_w":288,
4 | "img_h":288,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"global",
8 | "activation":"softmax",
9 | "mask_classes":6,
10 | "strategy":"smooth",
11 | "window_size":288,
12 | "subdivisions":2,
13 | "slices":8,
14 | "block_size":500000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/global/global16bits_miandiantaiguosamples_fpn_resnet34_categorical_crossentropy_288_2019-04-06_17-25-00.h5",
17 | "mask_dir": "/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/",
18 | "suffix":".tif"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred_multiclass_manmade.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/home/omnisky/PycharmProjects/data/test/paper/images/",
3 | "img_w":256,
4 | "img_h":256,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"manmade",
8 | "activation":"softmax",
9 | "mask_classes":3,
10 | "strategy":"smooth",
11 | "window_size":256,
12 | "subdivisions":2,
13 | "slices":8,
14 | "block_size":100000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/manmade_fpn_resnet34_categorical_crossentropy_256_2019-04-19_14-09-20.h5",
17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/paper/",
18 | "suffix":".png"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred_scrsbuilings.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/home/omnisky/PycharmProjects/data/test/tianfuxinqu/stretched/",
3 | "img_w":480,
4 | "img_h":480,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"tuitiantu",
8 | "activation":"sigmoid",
9 | "mask_classes":1,
10 | "strategy":"smooth",
11 | "window_size":480,
12 | "subdivisions":2,
13 | "slices":8,
14 | "block_size":500000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu/tuitiantu_unet_vgg16_bce_jaccard_loss_480_2019-05-09_08-40-41.h5",
17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/tianfuxinqu/tuitiantu_pred_2019/",
18 | "suffix":".png"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/config_pred_tuitiantu.json:
--------------------------------------------------------------------------------
1 | {
2 | "img_input":"/home/omnisky/PycharmProjects/data/test/tianfuxinqu/stretched/",
3 | "img_w":576,
4 | "img_h":576,
5 | "im_bands":4,
6 | "im_type":"UINT10",
7 | "target_name":"tuitiantu",
8 | "activation":"sigmoid",
9 | "mask_classes":1,
10 | "strategy":"smooth",
11 | "window_size":576,
12 | "subdivisions":2,
13 | "slices":8,
14 | "block_size":500000000,
15 | "nodata":65535,
16 | "model_path":"/home/omnisky/PycharmProjects/data/models/scrs_buildings/scrs_building_unet_inceptionv3_bce_jaccard_loss_576_2019-05-21_16-02-24.h5",
17 | "mask_dir": "/home/omnisky/PycharmProjects/data/test/tianfuxinqu/scrs_buildings/",
18 | "suffix":".png"
19 | }
20 |
--------------------------------------------------------------------------------
/model_predict/test_predict_1.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | from keras.models import load_model
13 | from sklearn.preprocessing import LabelEncoder
14 |
15 |
16 | from keras import backend as K
17 | K.set_image_dim_ordering('tf')
18 | K.clear_session()
19 | from segmentation_models.losses import bce_jaccard_loss
20 | from segmentation_models.metrics import iou_score
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16
24 | from predict_backbone import predict_img_with_smooth_windowing_multiclassbands
25 |
26 | """
27 | The following global variables should be put into meta data file
28 | """
29 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
30 |
31 | target_class =1
32 |
33 | window_size = 256
34 | # step = 128
35 |
36 | im_bands =3
37 | im_type = UINT8 # UINT10,UINT8,UINT16
38 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
39 | dict_target={0: 'roads', 1: 'buildings'}
40 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
41 |
42 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings
43 |
44 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict
45 |
46 | img_file = '/home/omnisky/PycharmProjects/data/test/rice/normal/testsrc_1.png'
47 |
48 | model_file = '/home/omnisky/PycharmProjects/data/models/rice/rice_unet_resnet34_bce_jaccard_loss_256_2019-03-28_20-42-58.h5'
49 | print("model: {}".format(model_file))
50 |
51 | if __name__ == '__main__':
52 |
53 | print("[INFO] opening image...")
54 |
55 | input_img = load_img_by_gdal(img_file)
56 | if im_type == UINT8:
57 | input_img = input_img / 255.0
58 | elif im_type == UINT10:
59 | input_img = input_img / 1024.0
60 | elif im_type == UINT16:
61 | input_img = input_img / 65535.0
62 |
63 | input_img = np.clip(input_img, 0.0, 1.0)
64 | input_img = input_img.astype(np.float32)
65 |
66 | abs_filename = os.path.split(img_file)[1]
67 | abs_filename = abs_filename.split(".")[0]
68 | print (abs_filename)
69 |
70 | """checke model file"""
71 | print("model file: {}".format(model_file))
72 | if not os.path.isfile(model_file):
73 | print("model does not exist:{}".format(model_file))
74 | sys.exit(-2)
75 |
76 | model = load_model(model_file)
77 |
78 | if FLAG_APPROACH_PREDICT==0:
79 | print("[INFO] predict image by orignal approach\n")
80 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size)
81 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_',
82 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png'])
83 | print("result save as to: {}".format(output_file))
84 | cv2.imwrite(output_file, result*128)
85 |
86 | elif FLAG_APPROACH_PREDICT==1:
87 | print("[INFO] predict image by smooth approach\n")
88 | result = predict_img_with_smooth_windowing_multiclassbands(
89 | input_img,
90 | model,
91 | window_size=window_size,
92 | subdivisions=2,
93 | real_classes=target_class, # output channels = 是真的类别,总类别-背景
94 | pred_func=smooth_predict_for_binary_notonehot
95 | )
96 | """for single class test"""
97 | result[result<128]=0
98 | result[result>=128]=1
99 |
100 | output_file = '/home/omnisky/PycharmProjects/data/test/rice/newpred/test_1_unet_resnet34_jaccard.png'
101 |
102 | print("result save as to: {}".format(output_file))
103 |
104 |
105 |
106 | cv2.imwrite(output_file, result)
107 |
108 | gc.collect()
109 |
110 |
111 |
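
The UINT8/UINT10/UINT16 scaling block above is duplicated in every predict script. A minimal helper sketch that could centralize it, assuming the UINT* constants exported by ulitities.base_functions are distinct hashable tags (as the imports suggest):

    import numpy as np
    from ulitities.base_functions import UINT8, UINT10, UINT16

    _SCALES = {UINT8: 255.0, UINT10: 1024.0, UINT16: 65535.0}

    def normalize_by_type(img, im_type):
        # scale raw pixel values into [0, 1] according to the radiometric depth
        return np.clip(img / _SCALES[im_type], 0.0, 1.0).astype(np.float32)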
--------------------------------------------------------------------------------
/postprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/postprocess/__init__.py
--------------------------------------------------------------------------------
/postprocess/combine_diffclass_for_singlemodel_result.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import os
4 | import cv2
5 | import numpy as np
6 |
7 | from tqdm import tqdm
8 | import matplotlib.pyplot as plt
9 |
10 | from ulitities.base_functions import load_img_by_cv2
11 |
12 | FOREGROUND = 127  # foreground threshold (segnet: 40, unet: 127)
13 |
14 | ROAD_VALUE=127
15 | BUILDING_VALUE=255
16 |
17 | """for unet"""
18 | # input_path = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/Tianfuxinqu/pred/pred_256/'
19 | input_path = '../../data/test/tianfuxinqu/pred/pred_256/'
20 | mask_pool = ['mask_binary_fw01_buildings_jaccard.png', 'mask_binary_fw01_roads_jaccard.png']
21 | output_file = input_path+'/combined/mask_2018_256_fw0_smooth_combined.png'
22 | print(output_file)
23 |
24 | # mask_pool = ['mask_multiclass_3_buildings.png','mask_multiclass_3_roads.png']
25 | # output_file = '../../data/predict/unet/unet_multiclass_combined_3.png'
26 |
27 | """for segnet"""
28 | # input_path = '../../data/predict/segnet/'
29 | # mask_pool = ['mask_binary_3_buildings.png','mask_binary_3_roads.png']
30 | # output_file = '../../data/predict/segnet/segnet_binary_combined_3.png'
31 |
32 | # mask_pool = ['mask_multiclass_3_buildings.png','mask_multiclass_3_roads.png']
33 | # output_file = '../../data/predict/segnet/segnet_multiclass_combined_3.png'
34 |
35 |
36 | def check_input_file(path,masks):
37 | ret, img_1 = load_img_by_cv2(path+masks[0], grayscale=True)
38 | assert (ret == 0)
39 |
40 | height, width = img_1.shape
41 | num_img = len(masks)
42 |
43 | for next_index in range(1,num_img):
44 | next_ret, next_img=load_img_by_cv2(path+masks[next_index],grayscale=True)
45 | assert (next_ret ==0 )
46 | next_height, next_width = next_img.shape
47 | assert(height==next_height and width==next_width)
48 | return height, width
49 |
50 |
51 |
52 | def combine_all_mask(height, width,input_path,mask_pool):
53 | """
54 |
55 | :param height:
56 | :param width:
57 | :param input_path:
58 | :param mask_pool:
59 | :return: final mask from roads_mask and buildings_mask
60 |
61 | prior: road(1)>bulidings(2)
62 | """
63 | final_mask=np.zeros((height,width),np.uint8)
64 | for idx,file in enumerate(mask_pool):
65 | ret,img = load_img_by_cv2(input_path+file,grayscale=True)
66 | assert (ret == 0)
67 | label_value=0
68 | if 'road' in file:
69 | label_value =ROAD_VALUE
70 | elif 'building' in file:
71 | label_value=BUILDING_VALUE
72 | # label_value = idx+1
73 | # print("buildings prior")
74 | print("Roads prior")
75 | for i in tqdm(range(height)):
76 | for j in range(width):
77 | if img[i,j]>=FOREGROUND:
78 | # print ("img[{},{}]:{}".format(i,j,img[i,j]))
79 |
80 | if label_value == ROAD_VALUE:
81 | final_mask[i, j] = label_value
82 | elif label_value == BUILDING_VALUE and final_mask[i, j] != ROAD_VALUE:
83 | final_mask[i, j] = label_value
84 |
85 | # if label_value == BUILDING_VALUE:
86 | # final_mask[i, j] = label_value
87 | # elif label_value == ROAD_VALUE and final_mask[i, j] != BUILDING_VALUE:
88 | # final_mask[i, j] = label_value
89 |
90 | final_mask[final_mask == ROAD_VALUE] = 1
91 | final_mask[final_mask == BUILDING_VALUE] = 2
92 | return final_mask
93 |
94 |
95 |
96 | if __name__=='__main__':
97 |
98 | x,y=check_input_file(input_path,mask_pool)
99 |
100 | result_mask=combine_all_mask(x,y,input_path,mask_pool)
101 |
102 | plt.imshow(result_mask, cmap='gray')
103 | plt.title("combined mask")
104 | plt.show()
105 |
106 | cv2.imwrite(output_file,result_mask)
107 | print("Saved to : {}".format(output_file))
--------------------------------------------------------------------------------
/postprocess/mismatch_analyze.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import cv2
4 | import os
5 | import sys
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from tqdm import tqdm
9 | from ulitities.base_functions import load_img_by_cv2, compare_two_image_size
10 |
11 | ref_file ='../../data/test/paper/label/cuiping_label.png'
12 | # 1) jian11_test_label, 2) jiangyou_label, 3) yujiang_test_label,
13 | # 4) cuiping_label, 5) shuangliu_1test_label, 6) tongchuan_test_label
14 |
15 | pred_file ='../../data/test/paper/voted/' \
16 | 'unet_cuiping_4bands_voted0.png'
17 |
18 | output_file = '../../data/test/paper/voted/' \
19 | 'unet_cuiping_error_map_7c.png'
20 |
21 |
22 | if __name__=='__main__':
23 | print("[INFO]Reading images")
24 | ret, ref_img = load_img_by_cv2(ref_file, grayscale=True)
25 | if ret !=0:
26 | print("Open file failed:{}".format(ref_file))
27 | sys.exit(-1)
28 |
29 | ret, pred_img = load_img_by_cv2(pred_file, grayscale=True)
30 | if ret !=0:
31 | print("Open file failed:{}".format(pred_file))
32 | sys.exit(-2)
33 |
34 | compare_two_image_size(ref_img, pred_img, grayscale=True)
35 |
36 | height, width = ref_img.shape
37 | print("height,width: {},{}".format(height, width))
38 |
39 | match_img = np.zeros((height, width), np.uint8)
40 |
41 | for j in tqdm(range(height)):
42 | for i in range(width):
56 |
57 |             if ref_img[j,i]!=0:
58 |                 if pred_img[j,i]==0:
59 |                     match_img[j,i]=4 # false negative: missed target
60 |                 if ref_img[j,i]==pred_img[j,i]:
61 |                     match_img[j,i]=3 # true positive: correctly detected target
62 |                 elif ref_img[j,i]==1 and pred_img[j,i]==2:
63 |                     match_img[j,i]=5 # road misclassified as building
64 |                 elif ref_img[j,i]==2 and pred_img[j,i]==1:
65 |                     match_img[j,i]=6 # building misclassified as road
66 |             else:
67 |                 if pred_img[j, i] == 1:
68 |                     match_img[j, i] = 1 # false positive: background misclassified as road
69 |                 if pred_img[j, i] == 2:
70 |                     match_img[j, i] = 2 # false positive: background misclassified as building
71 |
72 |
73 |
74 | plt.imshow(match_img)
75 | plt.show()
76 |
77 | cv2.imwrite(output_file, match_img)
78 | print("Saving into: {}".format(output_file))
79 |
80 |
81 |
82 |
83 |
84 |
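
The same six-code error map can be produced without the Python loops via numpy.select; a sketch, assuming labels follow the 0=background, 1=road, 2=building convention used above:

    import numpy as np

    def error_map(ref, pred):
        conditions = [
            (ref == 0) & (pred == 1),    # 1: background misclassified as road
            (ref == 0) & (pred == 2),    # 2: background misclassified as building
            (ref != 0) & (ref == pred),  # 3: true positive
            (ref != 0) & (pred == 0),    # 4: missed target
            (ref == 1) & (pred == 2),    # 5: road misclassified as building
            (ref == 2) & (pred == 1),    # 6: building misclassified as road
        ]
        return np.select(conditions, [1, 2, 3, 4, 5, 6], default=0).astype(np.uint8)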
--------------------------------------------------------------------------------
/postprocess/raster_to_vector.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys
3 |
4 |
5 | input_raster_file = '/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/2019-04-11_19-31-32--miandiantaiguo-histmatched/c-ZY302918320140104s.tif'
6 | output_shp_dir = '/media/omnisky/b1aca4b8-81b8-4751-8dee-24f70574dae9/test_global/pred/2019-04-11_19-31-32--miandiantaiguo-histmatched/'
7 |
8 | import gdal, osr, ogr
9 | def polygonize(rasterTemp, outShp, sieveSize=1):
10 | sourceRaster = gdal.Open(rasterTemp)
11 | band = sourceRaster.GetRasterBand(1)
12 | driver = ogr.GetDriverByName("ESRI Shapefile")
13 | # If shapefile already exist, delete it
14 | if os.path.exists(outShp):
15 | driver.DeleteDataSource(outShp)
16 |
17 | outDatasource = driver.CreateDataSource(outShp)
18 | # get proj from raster
19 | srs = osr.SpatialReference()
20 | srs.ImportFromWkt(sourceRaster.GetProjectionRef())
21 | # create layer with proj
22 | outLayer = outDatasource.CreateLayer(outShp, srs)
23 | # Add class column (1,2...) to shapefile
24 |
25 | newField = ogr.FieldDefn('grid_code', ogr.OFTInteger)
26 | outLayer.CreateField(newField)
27 |
28 | gdal.Polygonize(band, None, outLayer, 0, [], callback=None)
29 |
30 | outDatasource.Destroy()
31 | sourceRaster = None
32 | band = None
33 |
34 |     # Add an "Area" attribute for each feature
35 |     ioShpFile = ogr.Open(outShp, update=1)
36 |     lyr = ioShpFile.GetLayerByIndex(0)
37 |     lyr.ResetReading()
38 |
39 |     try:
40 |         field_defn = ogr.FieldDefn("Area", ogr.OFTReal)
41 |         lyr.CreateField(field_defn)
42 |     except Exception:
43 |         # the field may already exist; continue and overwrite its values
44 |         print("Cannot add field 'Area'!")
45 |
46 |     for i in lyr:
47 |         geom = i.GetGeometryRef()
48 |         # GetArea() returns squared units of the layer's CRS
49 |         area = round(geom.GetArea())
50 |
51 |         i.SetField("Area", area)
52 |         lyr.SetFeature(i)
53 |
54 |         # remove sliver polygons smaller than sieveSize
55 |         if area < sieveSize:
56 |             lyr.DeleteFeature(i.GetFID())
57 | ioShpFile.Destroy()
58 |
59 | return outShp
60 |
61 |
62 | if __name__=='__main__':
63 |
64 | if not os.path.isfile(input_raster_file):
65 | print("Error: Please input a raster file!")
66 | sys.exit(-1)
67 |
68 | if not os.path.isdir(output_shp_dir):
69 | print("Warning: output directory do not exist!")
70 | os.mkdir(output_shp_dir)
71 |
72 | absname = os.path.split(input_raster_file)[1]
73 | print('\n\t[Info] images:{}'.format(absname))
74 | absname = absname.split('.')[0]
75 | absname = ''.join([absname, '4.shp'])
76 | shp_file = os.path.join(output_shp_dir,absname)
77 |
78 | polygonize(input_raster_file, shp_file)
79 |
80 |
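
A hypothetical invocation (both paths are placeholders) that vectorizes a predicted mask and drops polygons smaller than 10 squared map units:

    shp = polygonize('/path/to/pred_mask.tif', '/path/to/pred_mask.shp', sieveSize=10)
    print("vector output: {}".format(shp))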
--------------------------------------------------------------------------------
/postprocess/vote.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from tqdm import tqdm
4 | import matplotlib.pyplot as plt
5 |
6 | from ulitities.base_functions import load_img_by_cv2
7 |
8 | target_values = [0, 1, 2]  # 0: background, 1: roads, 2: buildings
9 |
10 | input_path = '../../data/test/paper/'
11 | input_masks=['pred_224/combined/unet_jaccard_yujiang_test_4bands_combined.png',
12 | 'pred_288/combined/unet_jaccard_yujiang_test_4bands_combined.png',
13 | 'pred_256/combined/unet_jaccard_yujiang_4bands_combined.png',
14 | 'pred_256/combined/unet_multiclass_yujiang_4bands_combined.png',
15 | 'pred_256/combined/unet_notonehot_yujiang_test_4bands_combined.png',
16 | 'pred_256/combined/unet_onlyjaccard_yujiang_4bands_combined.png']
17 |
18 | output_file = '../../data/test/paper/voted/unet_yujiang_4bands_voted0.png'
19 |
20 |
21 | def check_input_file(path, masks):
22 | ret, img_1 = load_img_by_cv2(path+masks[0],grayscale=True)
23 | assert (ret == 0)
24 |
25 | height, width = img_1.shape
26 | num_img = len(masks)
27 |
28 | for next_index in range(1,num_img):
29 | next_ret, next_img=load_img_by_cv2(path+masks[next_index],grayscale=True)
30 | assert (next_ret ==0 )
31 | next_height, next_width = next_img.shape
32 | assert(height==next_height and width==next_width)
33 | return height, width
34 |
35 |
36 |
37 | def vote_per_image(height, width, path, masks):
38 | num_target = len(target_values)
39 |
40 | mask_list = []
41 | for tt in range(len(masks)):
42 | ret, img = load_img_by_cv2(path+masks[tt],grayscale=True)
43 | assert(ret ==0)
44 | mask_list.append(img)
45 |
46 | vote_mask=np.zeros((height,width), np.uint8)
47 |
48 | for i in tqdm(range(height)):
49 | for j in range(width):
50 | # record=np.zeros(256,np.uint8)
51 | record = np.zeros(num_target, np.uint8)
52 | for n in range(len(mask_list)):
53 | mask=mask_list[n]
54 | pixel=mask[i,j]
55 | record[pixel] +=1
56 |
57 | # """Alarming"""
58 | if record.argmax()==0: # if argmax of 0 = 125 or 255, not prior considering background(0)
59 | record[0] -=1
60 | # print("record:{}".format(record))
61 | # a = record[1:]
62 | # print("else:{}".format(a))
63 | # if a.any()>1:
64 | # record[0]=0
65 |
66 | label=record.argmax()
67 | # print ("{},{} label={}".format(i,j,label))
68 | vote_mask[i,j]=label
69 | # vote_mask[vote_mask==125]=1
70 | # vote_mask[vote_mask == 255] = 2
71 | print(np.unique(vote_mask))
72 |
73 | return vote_mask
74 |
75 |
76 |
77 | if __name__=='__main__':
78 | x,y = check_input_file(input_path, input_masks)
79 |
80 | final_mask = vote_per_image(x,y,input_path, input_masks)
81 | plt.imshow(final_mask)
82 | plt.show()
83 |
84 | cv2.imwrite(output_file, final_mask)
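
For large mosaics the per-pixel vote dominates runtime. A vectorized sketch of the same majority vote, including the background down-weighting, assuming every mask shares one shape and labels lie in range(num_classes):

    import numpy as np

    def majority_vote(mask_list, num_classes=3):
        stack = np.stack(mask_list)                      # (n_masks, height, width)
        votes = np.stack([(stack == c).sum(axis=0)       # vote count per class
                          for c in range(num_classes)])  # (num_classes, height, width)
        # wherever background currently wins, remove one of its votes
        votes[0] -= (votes.argmax(axis=0) == 0)
        return votes.argmax(axis=0).astype(np.uint8)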
--------------------------------------------------------------------------------
/predict/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/predict/__init__.py
--------------------------------------------------------------------------------
/predict/predict_binary_jaccard.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 |
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
31 |
32 |
33 | target_class =1
34 |
35 | window_size = 256  # 224, 256, 288, 320
36 | # step = 128
37 |
38 | im_bands =4
39 | im_type = UINT10 # UINT10,UINT8,UINT16
40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target={0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
43 |
44 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict
47 |
48 | # position = 'shuangliu_1test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test,
49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test
50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test
51 | # img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024.
52 | img_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src.png'
53 | # img_file='/home/omnisky/PycharmProjects/data/test/sample1_12.png'
54 |
55 |
56 | # model_file = ''.join(['../../data/models/sat_urban_rgb/',dict_network[FLAG_USING_NETWORK], '_',
57 | # dict_target[FLAG_TARGET_CLASS],'_binary_jaccard_', str(window_size), '_final.h5'])
58 | model_file ='/home/omnisky/PycharmProjects/data/models/ducha/tuitiantu_jaccardandCross_2018-12-29_09-14-05.h5'
59 |
60 | print("model: {}".format(model_file))
61 |
62 | if __name__ == '__main__':
63 |
64 | print("[INFO] opening image...")
65 | # ret, input_img = load_img_normalization_by_cv2(img_file)
66 | # if ret !=0:
67 | # print("Open input file failed: {}".format(img_file))
68 | # sys.exit(-1)
69 |
70 | input_img = load_img_by_gdal(img_file)
71 | if im_type == UINT8:
72 | input_img = input_img / 255.0
73 | elif im_type == UINT10:
74 | input_img = input_img / 1024.0
75 | elif im_type == UINT16:
76 | input_img = input_img / 65535.0
77 |
78 | input_img = np.clip(input_img, 0.0, 1.0)
79 |     input_img = input_img.astype(np.float16)  # float16 halves memory; accuracy impact still under test
80 |
81 |
82 | abs_filename = os.path.split(img_file)[1]
83 | abs_filename = abs_filename.split(".")[0]
84 | print (abs_filename)
85 |
86 | """checke model file"""
87 | print("model file: {}".format(model_file))
88 | if not os.path.isfile(model_file):
89 | print("model does not exist:{}".format(model_file))
90 | sys.exit(-2)
91 |
92 | model = load_model(model_file)
93 |
94 | if FLAG_APPROACH_PREDICT==0:
95 | print("[INFO] predict image by orignal approach\n")
96 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size)
97 | # output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_',
98 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png'])
99 | output_file = ''.join(['../../data/test/tianfuxinqu/pred/pred_', str(window_size), '/mask_binary_',
100 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_original.png'])
101 | output_file = '/home/omnisky/PycharmProjects/data/originaldata/zs/pred/b_pred_original.png'
102 | print("result save as to: {}".format(output_file))
103 | cv2.imwrite(output_file, result*128)
104 |
105 | elif FLAG_APPROACH_PREDICT==1:
106 | print("[INFO] predict image by smooth approach\n")
107 | result = predict_img_with_smooth_windowing_multiclassbands(
108 | input_img,
109 | model,
110 | window_size=window_size,
111 | subdivisions=2,
112 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
113 | pred_func=smooth_predict_for_binary_notonehot
114 | )
115 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_rgb/mask_binary_',str(window_size),
116 | # '_', abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_jaccard.png'])
117 | # output_file = ''.join(['../../data/test/tianfuxinqu/pred/pred_', str(window_size), '/mask_binary_',
118 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_jaccard_smooth.png'])
119 | output_file = '/home/omnisky/PycharmProjects/data/originaldata/zs/pred/b_pred_512.png'
120 | print("result save as to: {}".format(output_file))
121 |
122 | cv2.imwrite(output_file, result)
123 | print("Saved to {}".format(output_file))
124 |
125 | gc.collect()
126 |
127 |
128 |
--------------------------------------------------------------------------------
/predict/predict_binary_notonehot.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 |
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class =1
34 |
35 | window_size = 256
36 | # step = 128
37 |
38 | im_bands =4
39 | im_type = UINT10 # UINT10,UINT8,UINT16
40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target={0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
43 |
44 | FLAG_TARGET_CLASS = 0 # 0:roads; 1:buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict
47 |
48 | position = 'tongchuan_test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test,
49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test
50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test
51 | # img_file = '../../data/test/sat_test/'+position+'_4bands1024.png' # _rgb, _nrg, __4bands1024.
52 | # img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024.
53 | img_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src.png'
54 |
55 | # model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_',
56 | # dict_target[FLAG_TARGET_CLASS],'_binary_notonehot_final.h5'])
57 |
58 | model_file = '/home/omnisky/PycharmProjects/data/models/ducha/tuitiantuunet_Crossentropy_256_2018-12-29_14-59-59.h5'
59 | # print("model: {}".format(model_file))
60 |
61 | if __name__ == '__main__':
62 |
63 | print("[INFO] opening image...")
64 |
65 | input_img = load_img_by_gdal(img_file)
66 | if im_type == UINT8:
67 | input_img = input_img / 255.0
68 | elif im_type == UINT10:
69 | input_img = input_img / 1024.0
70 | elif im_type == UINT16:
71 | input_img = input_img / 65535.0
72 |
73 | input_img = np.clip(input_img, 0.0, 1.0)
74 | input_img = input_img.astype(np.float32)
75 |
76 |
77 | abs_filename = os.path.split(img_file)[1]
78 | abs_filename = abs_filename.split(".")[0]
79 | print (abs_filename)
80 |
81 | """checke model file"""
82 | print("model file: {}".format(model_file))
83 | if not os.path.isfile(model_file):
84 | print("model does not exist:{}".format(model_file))
85 | sys.exit(-2)
86 |
87 | model = load_model(model_file)
88 |
89 | if FLAG_APPROACH_PREDICT==0:
90 | print("[INFO] predict image by orignal approach\n")
91 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size)
92 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_',
93 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png'])
94 | print("result save as to: {}".format(output_file))
95 | cv2.imwrite(output_file, result*128)
96 |
97 | elif FLAG_APPROACH_PREDICT==1:
98 | print("[INFO] predict image by smooth approach\n")
99 | result = predict_img_with_smooth_windowing_multiclassbands(
100 | input_img,
101 | model,
102 | window_size=window_size,
103 | subdivisions=2,
104 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
105 | pred_func=smooth_predict_for_binary_notonehot
106 | )
107 | """for single class test"""
108 | result[result<128]=0
109 | result[result>=128]=1
110 | # output_file = '//home/omnisky/PycharmProjects/data/test/shuidao/GF2shuitian22_test_pred.png'
111 |
112 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_4bands/mask_binary_',
113 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_notonehot.png'])
114 |
115 | # output_file = ''.join(['../../data/test/paper/pred/mask_binary_',
116 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_notonehot.png'])
117 | output_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_src_pred.png'
118 |
119 | print("result save as to: {}".format(output_file))
120 |
121 |
122 |
123 | cv2.imwrite(output_file, result)
124 |
125 | gc.collect()
126 |
127 |
128 |
--------------------------------------------------------------------------------
/predict/predict_binary_onehot.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_binary_onehot
23 | from ulitities.base_functions import load_img_normalization, load_img_by_gdal, UINT10,UINT8,UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 |
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class =1
34 |
35 | window_size = 256
36 | # step = 128
37 | im_bands = 3
38 | im_type = UINT8
39 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
40 | dict_target={0: 'roads', 1: 'buildings'}
41 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
42 |
43 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings
44 |
45 | FLAG_APPROACH_PREDICT = 0 # 0: original predict, 1: smooth predict
46 |
47 | # img_file = '../../data/test/GF2_yilong11.png'
48 | # img_file = '../../data/test/sample1.png'
49 | img_file = '../../data/test/lizhou_test_4bands255.png' # sample1_nrg, lizhou_test_4bands255
50 |
51 | # model_file = ''.join(['../../data/models/sat_urban_nrg/',dict_network[FLAG_USING_NETWORK], '_', dict_target[FLAG_TARGET_CLASS],'_binary.h5'])
52 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_buildings_binary2_onehot.h5'
53 | model_file='/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onehot.h5'
54 | print("model: {}".format(model_file))
55 |
56 | if __name__ == '__main__':
57 |
58 | print("[INFO] opening image...")
59 | if not os.path.isfile(img_file):
60 | print("Please check the input: {}".format(img_file))
61 | sys.exit(-1)
62 | # ret, input_img = load_img_normalization(img_file)
63 |
64 | input_img = load_img_by_gdal(img_file)
65 | if im_type == UINT8:
66 | input_img = input_img / 255.0
67 | elif im_type == UINT10:
68 | input_img = input_img / 1024.0
69 | elif im_type == UINT16:
70 | input_img = input_img / 65535.0
71 | input_img = np.clip(input_img, 0.0, 1.0)
72 |
73 | abs_filename = os.path.split(img_file)[1]
74 | abs_filename = abs_filename.split(".")[0]
75 | print (abs_filename)
76 |
77 | """checke model file"""
78 | print("model file: {}".format(model_file))
79 | if not os.path.isfile(model_file):
80 | print("model does not exist:{}".format(model_file))
81 | sys.exit(-2)
82 |
83 | model = load_model(model_file)
84 |
85 | if FLAG_APPROACH_PREDICT==0:
86 | print("[INFO] predict image by orignal approach\n")
87 | result = orignal_predict_onehot(input_img, im_bands, model, window_size)
88 | output_file = ''.join(['../../data/predict/original_predict_',abs_filename, '.png'])
89 | print("result save as to: {}".format(output_file))
90 | cv2.imwrite(output_file, result*100)
91 |
92 | elif FLAG_APPROACH_PREDICT==1:
93 | print("[INFO] predict image by smooth approach\n")
94 | result = predict_img_with_smooth_windowing_multiclassbands(
95 | input_img,
96 | model,
97 | window_size=window_size,
98 | subdivisions=2,
99 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
100 | pred_func=smooth_predict_for_binary_onehot
101 | )
102 | output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/mask_binary_',
103 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'.png'])
104 | print("result save as to: {}".format(output_file))
105 |
106 | cv2.imwrite(output_file, result)
107 |
108 | gc.collect()
109 |
110 |
111 |
--------------------------------------------------------------------------------
/predict/predict_binary_onlyjaccard.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import gc
11 | import argparse
12 | # from keras.preprocessing.image import img_to_array
13 | from keras.models import load_model
14 | from sklearn.preprocessing import LabelEncoder
15 | from PIL import Image
16 | from keras.preprocessing.image import img_to_array
17 |
18 | from keras import backend as K
19 | K.set_image_dim_ordering('tf')
20 | K.clear_session()
21 |
22 | from base_predict_functions import orignal_predict_notonehot, smooth_predict_for_binary_notonehot
23 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16
24 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
25 | # from semantic_segmentation_networks import jaccard_coef,jaccard_coef_int
26 |
27 | """
28 | The following global variables should be put into a metadata file.
29 | """
30 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
31 |
32 |
33 | target_class =1
34 |
35 | window_size = 256
36 | # step = 128
37 |
38 | im_bands =4
39 | im_type = UINT10 # UINT10,UINT8,UINT16
40 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
41 | dict_target={0: 'roads', 1: 'buildings'}
42 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
43 |
44 | FLAG_TARGET_CLASS = 1 # 0:roads; 1:buildings
45 |
46 | FLAG_APPROACH_PREDICT = 1 # 0: original predict, 1: smooth predict
47 |
48 | position = 'jiangyou' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test,
49 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test
50 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test
51 | img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024.
52 | # img_file = '../../data/test/shuidao.png'
53 |
54 |
55 | model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_',
56 | dict_target[FLAG_TARGET_CLASS],'_binary_onlyjaccard_final.h5'])
57 |
58 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_buildings_binary_onlyjaccard_2018-09-29_18-55-11.h5'
59 | print("model: {}".format(model_file))
60 |
61 | if __name__ == '__main__':
62 |
63 | print("[INFO] opening image...")
64 | # ret, input_img = load_img_normalization_by_cv2(img_file)
65 | # if ret !=0:
66 | # print("Open input file failed: {}".format(img_file))
67 | # sys.exit(-1)
68 |
69 | input_img = load_img_by_gdal(img_file)
70 | if im_type == UINT8:
71 | input_img = input_img / 255.0
72 | elif im_type == UINT10:
73 | input_img = input_img / 1024.0
74 | elif im_type == UINT16:
75 | input_img = input_img / 65535.0
76 |
77 | input_img = np.clip(input_img, 0.0, 1.0)
78 |
79 |
80 | abs_filename = os.path.split(img_file)[1]
81 | abs_filename = abs_filename.split(".")[0]
82 | print (abs_filename)
83 |
84 | """checke model file"""
85 | print("model file: {}".format(model_file))
86 | if not os.path.isfile(model_file):
87 | print("model does not exist:{}".format(model_file))
88 | sys.exit(-2)
89 |
90 | model = load_model(model_file)
91 |
92 | if FLAG_APPROACH_PREDICT==0:
93 | print("[INFO] predict image by orignal approach\n")
94 | result = orignal_predict_notonehot(input_img,im_bands, model, window_size)
95 | output_file = ''.join(['../../data/predict/',dict_network[FLAG_USING_NETWORK],'/sat_4bands/original_pred_',
96 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png'])
97 | print("result save as to: {}".format(output_file))
98 | cv2.imwrite(output_file, result*128)
99 |
100 | elif FLAG_APPROACH_PREDICT==1:
101 | print("[INFO] predict image by smooth approach\n")
102 | result = predict_img_with_smooth_windowing_multiclassbands(
103 | input_img,
104 | model,
105 | window_size=window_size,
106 | subdivisions=2,
107 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
108 | pred_func=smooth_predict_for_binary_notonehot
109 | )
110 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK],'/sat_4bands/mask_binary_',
111 | # abs_filename, '_', dict_target[FLAG_TARGET_CLASS],'_onlyjaccard.png'])
112 |
113 | output_file = ''.join(['../../data/test/paper/pred/mask_binary_',
114 | abs_filename, '_', dict_target[FLAG_TARGET_CLASS], '_onlyjaccard.png'])
115 | print("result save as to: {}".format(output_file))
116 |
117 | cv2.imwrite(output_file, result)
118 |
119 | gc.collect()
120 |
121 |
122 |
--------------------------------------------------------------------------------
/predict/predict_multiclass.py:
--------------------------------------------------------------------------------
1 | #coding:utf8
2 | """"
3 | This is main procedure for remote sensing image semantic segmentation
4 |
5 | """
6 | import cv2
7 | import numpy as np
8 | import os
9 | import sys
10 | import argparse
11 | # from keras.preprocessing.image import img_to_array
12 | from keras.models import load_model
13 | from sklearn.preprocessing import LabelEncoder
14 | from PIL import Image
15 | from keras.preprocessing.image import img_to_array
16 |
17 | from keras import backend as K
18 | K.set_image_dim_ordering('tf')
19 | K.clear_session()
20 |
21 | from base_predict_functions import orignal_predict_onehot, smooth_predict_for_multiclass
22 | from ulitities.base_functions import load_img_normalization_by_cv2, load_img_by_gdal, UINT10,UINT8,UINT16
23 | from smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
24 |
25 | """
26 | The following global variables should be put into a metadata file.
27 | """
28 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
29 |
30 |
31 | window_size = 256
32 | step = 128
33 |
34 | im_bands = 4
35 | im_type = UINT10 #UINT8, UINT10, UINT16
36 |
37 | dict_network={0: 'unet', 1: 'fcnnet', 2: 'segnet'}
38 | dict_target={0: 'roads', 1: 'buildings'}
39 | target_class=len(dict_target)
40 |
41 | FLAG_USING_NETWORK = 0 # 0:unet; 1:fcn; 2:segnet;
42 |
43 | FLAG_APPROACH_PREDICT=1 # 0: original predict, 1: smooth predict
44 |
45 | position = 'tongchuan_test' # 1)jian11_test, , 2)jiangyou, 3)yujiang_test,
46 | # 4)cuiping, 5)shuangliu_1test, 6) tongchuan_test
47 | # 7) lizhou_test, 8) jianyang, 9)yushui22_test, 10) sample1, 11)ruoergai_52test
48 | # img_file = '../../data/test/sat_test/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024.
49 | img_file = '../../data/test/paper/images/'+position+'_4bands1024.png' # _rgb, _nrg, _4bands1024.
50 | # img_file = '../../data/test/shuidao.png'
51 |
52 | # img_file = '../../data/test/sat_test/cuiping_4bands1024.png' # jian11_test_nrg, sample1_nrg
53 |
54 | # model_file = ''.join(['../../data/models/sat_urban_4bands/',dict_network[FLAG_USING_NETWORK], '_multiclass_final.h5'])
55 | model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_4bands/unet_multiclass_final.h5'
56 | # model_file = '/home/omnisky/PycharmProjects/data/models/sat_urban_nrg/unet_multiclass.h5'
57 |
58 | if __name__ == '__main__':
59 |
60 | print("[INFO] opening image...")
61 |
62 | # ret, input_img = load_img_normalization_by_cv2(img_file)
63 |
64 | input_img = load_img_by_gdal(img_file)
65 | if im_type == UINT8:
66 | input_img = input_img / 255.0
67 | elif im_type == UINT10:
68 | input_img = input_img / 1024.0
69 | elif im_type == UINT16:
70 | input_img = input_img / 65535.0
71 | input_img = np.clip(input_img, 0.0, 1.0)
72 |
73 |
74 | abs_filename = os.path.split(img_file)[1]
75 | abs_filename = abs_filename.split(".")[0]
76 | print (abs_filename)
77 |
78 | """checke model file"""
79 | print("model file: {}".format(model_file))
80 | if not os.path.isfile(model_file):
81 | print("model does not exist:{}".format(model_file))
82 | sys.exit(-2)
83 |
84 | model= load_model(model_file)
85 |
86 | if FLAG_APPROACH_PREDICT==0:
87 | print("[INFO] predict image by orignal approach\n")
88 | result = orignal_predict_onehot(input_img, im_bands, model, window_size)
89 | output_file = ''.join(['../../data/predict/original_predict_',abs_filename, '.png'])
90 | print("result save as to: {}".format(output_file))
91 | cv2.imwrite(output_file, result*128)
92 |
93 | elif FLAG_APPROACH_PREDICT==1:
94 | print("[INFO] predict image by smooth approach\n")
95 | result = predict_img_with_smooth_windowing_multiclassbands(
96 | input_img,
97 | model,
98 | window_size=window_size,
99 | subdivisions=2,
100 |             real_classes=target_class,  # output channels = number of real classes (total classes minus background)
101 | pred_func=smooth_predict_for_multiclass
102 | )
103 |
104 | for b in range(target_class):
105 | # output_file = ''.join(['../../data/predict/', dict_network[FLAG_USING_NETWORK], '/sat_4bands/mask_multiclass_',
106 | # abs_filename, '_', dict_target[b], '.png'])
107 |
108 | output_file = ''.join(['../../data/test/paper/pred/mask_multiclass_',
109 | abs_filename, '_', dict_target[b], '.png'])
110 | print("result save as to: {}".format(output_file))
111 | cv2.imwrite(output_file, result[:,:,b])
112 |
113 |
114 |
--------------------------------------------------------------------------------
/samples_produce/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/samples_produce/__init__.py
--------------------------------------------------------------------------------
/samples_produce/check_original_labels_froNodata.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import numpy as np
4 | import sys
5 | import os
6 | import cv2
7 | from tqdm import tqdm
8 | from ulitities.base_functions import load_img_by_cv2, get_file
9 |
10 | import matplotlib.pyplot as plt
11 |
12 |
13 | # segnet_labels = [0, 1, 2] # have not test for segnetlabels
14 | # unet_labels = [0, 1]
15 |
16 |
17 | # """for unet"""
18 | # input_src_path = '../../data/originaldata/unet/roads/src/'
19 | # input_label_path = '../../data/originaldata/unet/roads/label/'
20 |
21 | # """for segnet"""
22 | # input_src_path = '../../data/originaldata/segnet/src/'
23 | # input_label_path = '../../data/originaldata/segnet/label/'
24 |
25 |
26 | input_label_path = '../../data/originaldata/SatRGB/label/'
27 | valid_labels = [0, 1, 2]
28 |
29 | HAS_INVALID_VALUE = False
30 | # FLAG_BINARY_LABELS = True
31 |
32 |
33 | def make_label_valid(img, false_values):
34 | height, width = img.shape
35 |
36 | tp_img = img.reshape((height*width))
37 | for inv_lab in false_values:
38 | index = np.where(tp_img==inv_lab)
39 | tp_img[index]=0
40 | tp_img = tp_img.reshape((height, width))
41 |
42 | return tp_img
43 |
44 |
45 |
46 | #
47 | # for i in range(height):
48 | # for j in range(width):
49 | # tmp = img[i,j]
50 | # if not tmp in true_values:
51 | # print("img[{},{}]: {}".format(i,j,tmp))
52 | # img[i,j]=0
53 | # return img
54 |
55 |
56 | if __name__ == '__main__':
57 | files,num = get_file(input_label_path)
58 | assert (num!=0)
59 |
60 | # valid_labels = []
61 | # if FLAG_USING_UNET:
62 | # valid_labels = unet_labels
63 | # else:
64 | # valid_labels = segnet_labels
65 |
66 | for label_file in tqdm(files):
67 | # label_file = input_label_path + os.path.split(src_file)[1]
68 | #
69 | # ret,src_img = load_img(src_file)
70 | # assert(ret==0)
71 |
72 | ret,label_img = load_img_by_cv2(label_file, grayscale=True)
73 | assert (ret == 0)
74 |
75 | local_labels = np.unique(label_img)
76 | invalid_labels=[]
77 |
78 | for tmp in local_labels:
79 | if tmp not in valid_labels:
80 | invalid_labels.append(tmp)
81 | print ("\nWarning: some label is not valid value")
82 | print ("\nFile: {}".format(label_file))
83 | HAS_INVALID_VALUE = True
84 |
85 |
86 | if HAS_INVALID_VALUE == True:
87 | new_label_img = make_label_valid(label_img, invalid_labels)
88 | new_label_file = os.path.split(label_file)[0]+'/new_'+os.path.split(label_file)[1]
89 | cv2.imwrite(new_label_file, new_label_img)
90 | HAS_INVALID_VALUE = False
91 | label_img = new_label_img
92 |
93 | plt.imshow(label_img, cmap='gray')
94 | plt.show()
95 |
96 | print("Check completely!\n")
97 |
98 |
99 |
100 |
101 |
102 |
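
make_label_valid zeroes the invalid values one label at a time; the whole cleanup can also be done in a single vectorized expression (a sketch using numpy.isin):

    import numpy as np

    # keep pixels whose value is in valid_labels, zero out everything else
    cleaned = np.where(np.isin(label_img, valid_labels), label_img, 0).astype(np.uint8)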
--------------------------------------------------------------------------------
/samples_produce/label_visulise.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import cv2
4 | import sys, os
5 | from ulitities.base_functions import load_img,get_file
6 |
7 | input_path = '../../data/traindata/unet/buildings/label/'
8 | output_path = '../../data/traindata/unet/buildings/visulize/'
9 |
10 |
11 | if __name__ == '__main__':
12 |
13 | if not os.path.isdir(input_path):
14 | print("No input directory:{}".format(input_path))
15 | sys.exit(-1)
16 | if not os.path.isdir(output_path):
17 | print("No output directory:{}".format(output_path))
18 | os.mkdir(output_path)
19 |
20 | srcfiles, tt= get_file(input_path)
21 | assert(tt!=0)
22 |
23 | for index, file in enumerate(srcfiles):
24 | ret,img = load_img(file,grayscale=True)
25 | assert(ret==0)
26 |
27 |         img = img * 100  # stretch label values (e.g. 0/1/2) into visible gray levels
28 | filename = os.path.split(file)[1]
29 | outfile = os.path.join(output_path,filename)
30 | print(outfile)
31 |
32 | cv2.imwrite(outfile, img)
33 |
34 |
--------------------------------------------------------------------------------
/segmentation_models/__init__.py:
--------------------------------------------------------------------------------
1 | name = "segmentation_models"
2 |
3 | from .__version__ import __version__
4 |
5 | from .unet import Unet
6 | from .fpn import FPN
7 | from .linknet import Linknet
8 | from .pspnet import PSPNet
9 |
10 | from . import metrics
11 | from . import losses
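
A minimal usage sketch for this version of the package; the keyword arguments follow the upstream segmentation_models 0.2.x API and should be treated as assumptions for this vendored copy:

    from segmentation_models import Unet

    # U-Net with a ResNet-34 encoder; encoder_weights=None avoids a weight download
    model = Unet(backbone_name='resnet34', input_shape=(256, 256, 3),
                 classes=1, activation='sigmoid', encoder_weights=None)
    model.compile('Adam', loss='binary_crossentropy', metrics=['accuracy'])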
--------------------------------------------------------------------------------
/segmentation_models/__version__.py:
--------------------------------------------------------------------------------
1 | VERSION = (0, 2, 0)
2 |
3 | __version__ = '.'.join(map(str, VERSION))
--------------------------------------------------------------------------------
/segmentation_models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from classification_models import Classifiers
2 | from classification_models import resnext
3 |
4 | from . import inception_resnet_v2 as irv2
5 | from . import inception_v3 as iv3
6 | from . import mobilenet as mbn
7 | from . import mobilenetv2 as mbn2
8 |
9 | # replace backbones with others, which have corrected padding mode in first pooling
10 | Classifiers._models.update({
11 | 'inceptionresnetv2': [irv2.InceptionResNetV2, irv2.preprocess_input],
12 | 'inceptionv3': [iv3.InceptionV3, iv3.preprocess_input],
13 | 'resnext50': [resnext.ResNeXt50, resnext.models.preprocess_input],
14 | 'resnext101': [resnext.ResNeXt101, resnext.models.preprocess_input],
15 | 'mobilenet': [mbn.MobileNet, mbn.preprocess_input],
16 | 'mobilenetv2': [mbn2.MobileNetV2, mbn2.preprocess_input],
17 | })
18 |
19 | DEFAULT_FEATURE_LAYERS = {
20 |
21 | # List of layers to take features from backbone in the following order:
22 |     # (x16, x8, x4, x2, x1) - `x4` means that the features have 4 times lower spatial
23 |     # resolution (Height x Width) than the input image.
24 |
25 | # VGG
26 | 'vgg16': ('block5_conv3', 'block4_conv3', 'block3_conv3', 'block2_conv2', 'block1_conv2'),
27 | 'vgg19': ('block5_conv4', 'block4_conv4', 'block3_conv4', 'block2_conv2', 'block1_conv2'),
28 |
29 | # ResNets
30 | 'resnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
31 | 'resnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
32 | 'resnet50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
33 | 'resnet101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
34 | 'resnet152': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
35 |
36 | # ResNeXt
37 | 'resnext50': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
38 | 'resnext101': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
39 |
40 | # Inception
41 | 'inceptionv3': (228, 86, 16, 9),
42 | 'inceptionresnetv2': (594, 260, 16, 9),
43 |
44 | # DenseNet
45 | 'densenet121': (311, 139, 51, 4),
46 | 'densenet169': (367, 139, 51, 4),
47 | 'densenet201': (479, 139, 51, 4),
48 |
49 | # SE models
50 | 'seresnet18': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
51 | 'seresnet34': ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0'),
52 | 'seresnet50': (233, 129, 59, 4),
53 | 'seresnet101': (522, 129, 59, 4),
54 | 'seresnet152': (811, 197, 59, 4),
55 | 'seresnext50': (1065, 577, 251, 4),
56 | 'seresnext101': (2442, 577, 251, 4),
57 | 'senet154': (6837, 1614, 451, 12),
58 |
59 | # Mobile Nets
60 | 'mobilenet': ('conv_pw_11_relu', 'conv_pw_5_relu', 'conv_pw_3_relu', 'conv_pw_1_relu'),
61 | 'mobilenetv2': ('block_13_expand_relu', 'block_6_expand_relu', 'block_3_expand_relu', 'block_1_expand_relu'),
62 |
63 | }
64 |
65 |
66 | def get_names():
67 | return list(DEFAULT_FEATURE_LAYERS.keys())
68 |
69 |
70 | def get_feature_layers(name, n=5):
71 | return DEFAULT_FEATURE_LAYERS[name][:n]
72 |
73 |
74 | def get_backbone(name, *args, **kwargs):
75 | return Classifiers.get_classifier(name)(*args, **kwargs)
76 |
77 |
78 | def get_preprocessing(name):
79 | return Classifiers.get_preprocessing(name)
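
A short sketch of how these helpers compose; the constructor keywords follow the classification_models convention and are assumptions here:

    # fetch a ResNet-34 feature extractor plus its default skip-connection layers
    backbone = get_backbone('resnet34', input_shape=(256, 256, 3),
                            weights=None, include_top=False)
    skips = get_feature_layers('resnet34', n=4)
    # -> ('stage4_unit1_relu1', 'stage3_unit1_relu1', 'stage2_unit1_relu1', 'relu0')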
80 |
--------------------------------------------------------------------------------
/segmentation_models/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .blocks import Conv2DBlock
2 | from .layers import ResizeImage
--------------------------------------------------------------------------------
/segmentation_models/common/blocks.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D
2 | from keras.layers import Activation
3 | from keras.layers import BatchNormalization
4 |
5 |
6 | def Conv2DBlock(n_filters, kernel_size,
7 | activation='relu',
8 | use_batchnorm=True,
9 | name='conv_block',
10 | **kwargs):
11 | """Extension of Conv2D layer with batchnorm"""
12 | def layer(input_tensor):
13 |
14 | x = Conv2D(n_filters, kernel_size, use_bias=not(use_batchnorm),
15 | name=name+'_conv', **kwargs)(input_tensor)
16 | if use_batchnorm:
17 | x = BatchNormalization(name=name+'_bn',)(x)
18 | x = Activation(activation, name=name+'_'+activation)(x)
19 |
20 | return x
21 | return layer
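
Conv2DBlock returns a closure rather than a Layer instance, so it is applied like any Keras layer; a minimal sketch:

    from keras.layers import Input

    inp = Input(shape=(128, 128, 3))
    # produces layers named block1_conv, block1_bn and block1_relu
    x = Conv2DBlock(64, (3, 3), padding='same', name='block1')(inp)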
--------------------------------------------------------------------------------
/segmentation_models/common/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 |
5 | def transpose_shape(shape, target_format, spatial_axes):
6 | """Converts a tuple or a list to the correct `data_format`.
7 | It does so by switching the positions of its elements.
8 | # Arguments
9 | shape: Tuple or list, often representing shape,
10 | corresponding to `'channels_last'`.
11 | target_format: A string, either `'channels_first'` or `'channels_last'`.
12 | spatial_axes: A tuple of integers.
13 | Correspond to the indexes of the spatial axes.
14 | For example, if you pass a shape
15 | representing (batch_size, timesteps, rows, cols, channels),
16 | then `spatial_axes=(2, 3)`.
17 | # Returns
18 | A tuple or list, with the elements permuted according
19 | to `target_format`.
20 |     # Example: transpose_shape((16, 128, 128, 32), 'channels_first', spatial_axes=(1, 2)) -> (16, 32, 128, 128)
21 | # Raises
22 | ValueError: if `value` or the global `data_format` invalid.
23 | """
24 | if target_format == 'channels_first':
25 | new_values = shape[:spatial_axes[0]]
26 | new_values += (shape[-1],)
27 | new_values += tuple(shape[x] for x in spatial_axes)
28 |
29 | if isinstance(shape, list):
30 | return list(new_values)
31 | return new_values
32 | elif target_format == 'channels_last':
33 | return shape
34 | else:
35 | raise ValueError('The `data_format` argument must be one of '
36 | '"channels_first", "channels_last". Received: ' +
37 | str(target_format))
38 |
39 |
40 | def permute_dimensions(x, pattern):
41 | """Permutes axes in a tensor.
42 | # Arguments
43 | x: Tensor or variable.
44 | pattern: A tuple of
45 | dimension indices, e.g. `(0, 2, 1)`.
46 | # Returns
47 | A tensor.
48 | """
49 | return tf.transpose(x, perm=pattern)
50 |
51 |
52 | def int_shape(x):
53 | """Returns the shape of tensor or variable as a tuple of int or None entries.
54 | # Arguments
55 | x: Tensor or variable.
56 | # Returns
57 | A tuple of integers (or None entries).
58 | """
59 | if hasattr(x, '_keras_shape'):
60 | return x._keras_shape
61 | try:
62 | return tuple(x.get_shape().as_list())
63 | except ValueError:
64 | return None
65 |
66 |
67 | def resize_images(x,
68 | height_factor,
69 | width_factor,
70 | data_format,
71 | interpolation='nearest'):
72 | """Resizes the images contained in a 4D tensor.
73 | # Arguments
74 | x: Tensor or variable to resize.
75 | height_factor: Positive integer.
76 | width_factor: Positive integer.
77 | data_format: string, `"channels_last"` or `"channels_first"`.
78 | interpolation: A string, one of `nearest` or `bilinear`.
79 | # Returns
80 | A tensor.
81 | # Raises
82 | ValueError: if `data_format` is neither `"channels_last"` or `"channels_first"`.
83 | """
84 | if data_format == 'channels_first':
85 | rows, cols = 2, 3
86 | else:
87 | rows, cols = 1, 2
88 |
89 | original_shape = int_shape(x)
90 | new_shape = tf.shape(x)[rows:cols + 1]
91 | new_shape *= tf.constant(np.array([height_factor, width_factor], dtype='int32'))
92 |
93 | if data_format == 'channels_first':
94 | x = permute_dimensions(x, [0, 2, 3, 1])
95 | if interpolation == 'nearest':
96 | x = tf.image.resize_nearest_neighbor(x, new_shape)
97 | elif interpolation == 'bilinear':
98 | x = tf.image.resize_bilinear(x, new_shape, align_corners=True)
99 | else:
100 | raise ValueError('interpolation should be one '
101 | 'of "nearest" or "bilinear".')
102 | if data_format == 'channels_first':
103 | x = permute_dimensions(x, [0, 3, 1, 2])
104 |
105 | if original_shape[rows] is None:
106 | new_height = None
107 | else:
108 | new_height = original_shape[rows] * height_factor
109 |
110 | if original_shape[cols] is None:
111 | new_width = None
112 | else:
113 | new_width = original_shape[cols] * width_factor
114 |
115 | output_shape = (None, new_height, new_width, None)
116 | x.set_shape(transpose_shape(output_shape, data_format, spatial_axes=(1, 2)))
117 | return x
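
A TF1-style sketch of resize_images on a symbolic tensor (shapes are illustrative):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=(None, 64, 64, 3))
    y = resize_images(x, 2, 2, data_format='channels_last', interpolation='bilinear')
    # y carries the static shape (None, 128, 128, 3)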
--------------------------------------------------------------------------------
/segmentation_models/common/layers.py:
--------------------------------------------------------------------------------
1 | from keras.engine import Layer
2 | from keras.engine import InputSpec
3 | from keras.utils import conv_utils
4 | from keras.legacy import interfaces
5 | from keras.utils.generic_utils import get_custom_objects
6 |
7 | from .functions import resize_images
8 |
9 |
10 | class ResizeImage(Layer):
11 | """ResizeImage layer for 2D inputs.
12 | Repeats the rows and columns of the data
13 | by factor[0] and factor[1] respectively.
14 | # Arguments
15 | factor: int, or tuple of 2 integers.
16 | The upsampling factors for rows and columns.
17 | data_format: A string,
18 | one of `"channels_last"` or `"channels_first"`.
19 | The ordering of the dimensions in the inputs.
20 | `"channels_last"` corresponds to inputs with shape
21 | `(batch, height, width, channels)` while `"channels_first"`
22 | corresponds to inputs with shape
23 | `(batch, channels, height, width)`.
24 | It defaults to the `image_data_format` value found in your
25 | Keras config file at `~/.keras/keras.json`.
26 | If you never set it, then it will be "channels_last".
27 | interpolation: A string, one of `nearest` or `bilinear`.
28 | Note that CNTK does not support yet the `bilinear` upscaling
29 | and that with Theano, only `factor=(2, 2)` is possible.
30 | # Input shape
31 | 4D tensor with shape:
32 | - If `data_format` is `"channels_last"`:
33 | `(batch, rows, cols, channels)`
34 | - If `data_format` is `"channels_first"`:
35 | `(batch, channels, rows, cols)`
36 | # Output shape
37 | 4D tensor with shape:
38 | - If `data_format` is `"channels_last"`:
39 | `(batch, upsampled_rows, upsampled_cols, channels)`
40 | - If `data_format` is `"channels_first"`:
41 | `(batch, channels, upsampled_rows, upsampled_cols)`
42 | """
43 |
44 | @interfaces.legacy_upsampling2d_support
45 | def __init__(self, factor=(2, 2), data_format='channels_last', interpolation='nearest', **kwargs):
46 | super(ResizeImage, self).__init__(**kwargs)
47 | self.data_format = data_format
48 | self.factor = conv_utils.normalize_tuple(factor, 2, 'factor')
49 | self.input_spec = InputSpec(ndim=4)
50 | if interpolation not in ['nearest', 'bilinear']:
51 | raise ValueError('interpolation should be one '
52 | 'of "nearest" or "bilinear".')
53 | self.interpolation = interpolation
54 |
55 | def compute_output_shape(self, input_shape):
56 | if self.data_format == 'channels_first':
57 | height = self.factor[0] * input_shape[2] if input_shape[2] is not None else None
58 | width = self.factor[1] * input_shape[3] if input_shape[3] is not None else None
59 | return (input_shape[0],
60 | input_shape[1],
61 | height,
62 | width)
63 | elif self.data_format == 'channels_last':
64 | height = self.factor[0] * input_shape[1] if input_shape[1] is not None else None
65 | width = self.factor[1] * input_shape[2] if input_shape[2] is not None else None
66 | return (input_shape[0],
67 | height,
68 | width,
69 | input_shape[3])
70 |
71 | def call(self, inputs):
72 | return resize_images(inputs, self.factor[0], self.factor[1],
73 | self.data_format, self.interpolation)
74 |
75 | def get_config(self):
76 |         config = {'factor': self.factor, 'data_format': self.data_format,
77 |                   'interpolation': self.interpolation}  # keep interpolation so the layer round-trips
78 | base_config = super(ResizeImage, self).get_config()
79 | return dict(list(base_config.items()) + list(config.items()))
80 |
81 |
82 | get_custom_objects().update({'ResizeImage': ResizeImage})
83 |
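
Because the layer is registered in get_custom_objects above, models containing it reload without a custom_objects argument; a minimal usage sketch:

    from keras.layers import Input

    inp = Input(shape=(None, None, 256))
    # double the spatial resolution of a feature map inside the graph
    up = ResizeImage(factor=(2, 2), interpolation='bilinear')(inp)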
--------------------------------------------------------------------------------
/segmentation_models/fpn/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FPN
2 |
3 |
--------------------------------------------------------------------------------
/segmentation_models/fpn/blocks.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Add
2 |
3 | from ..common import Conv2DBlock
4 | from ..common import ResizeImage
5 | from ..utils import to_tuple
6 |
7 |
8 | def pyramid_block(pyramid_filters=256, segmentation_filters=128, upsample_rate=2,
9 | use_batchnorm=False, stage=0):
10 | """
11 | Pyramid block according to:
12 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
13 |
14 |     This block generates the `M` and `P` blocks.
15 |
16 | Args:
17 | pyramid_filters: integer, filters in `M` block of top-down FPN branch
18 | segmentation_filters: integer, number of filters in segmentation head,
19 | basically filters in convolution layers between `M` and `P` blocks
20 |         upsample_rate: integer, upsample rate for `M` block of top-down FPN branch
21 | use_batchnorm: bool, include batchnorm in convolution blocks
22 |
23 | Returns:
24 | Pyramid block function (as Keras layers functional API)
25 | """
26 | def layer(c, m=None):
27 |
28 | x = Conv2DBlock(pyramid_filters, (1, 1),
29 | padding='same',
30 | use_batchnorm=use_batchnorm,
31 | name='pyramid_stage_{}'.format(stage))(c)
32 |
33 | if m is not None:
34 | up = ResizeImage(to_tuple(upsample_rate))(m)
35 | x = Add()([x, up])
36 |
37 | # segmentation head
38 | p = Conv2DBlock(segmentation_filters, (3, 3),
39 | padding='same',
40 | use_batchnorm=use_batchnorm,
41 | name='segm1_stage_{}'.format(stage))(x)
42 |
43 | p = Conv2DBlock(segmentation_filters, (3, 3),
44 | padding='same',
45 | use_batchnorm=use_batchnorm,
46 | name='segm2_stage_{}'.format(stage))(p)
47 | m = x
48 |
49 | return m, p
50 | return layer
51 |
--------------------------------------------------------------------------------
/segmentation_models/fpn/builder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.layers import Conv2D
3 | from keras.layers import Concatenate
4 | from keras.layers import Activation
5 | from keras.layers import SpatialDropout2D
6 | from keras.models import Model
7 |
8 | from .blocks import pyramid_block
9 | from ..common import ResizeImage
10 | from ..common import Conv2DBlock
11 | from ..utils import extract_outputs, to_tuple
12 |
13 |
14 | def build_fpn(backbone,
15 | fpn_layers,
16 | classes=21,
17 | activation='softmax',
18 | upsample_rates=(2,2,2),
19 | last_upsample=4,
20 | pyramid_filters=256,
21 | segmentation_filters=128,
22 | use_batchnorm=False,
23 | dropout=None,
24 | interpolation='bilinear'):
25 | """
26 | Implementation of FPN head for segmentation models according to:
27 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
28 |
29 | Args:
30 | backbone: Keras `Model`, some classification model without top
31 |         fpn_layers: list of layer names or indexes, used for pyramid building
32 | classes: int, number of output feature maps
33 | activation: activation in last layer, e.g. 'sigmoid' or 'softmax'
34 | upsample_rates: tuple of integers, scaling rates between pyramid blocks
35 | pyramid_filters: int, number of filters in `M` blocks of top-down FPN branch
36 | segmentation_filters: int, number of filters in `P` blocks of FPN
37 |         last_upsample: rate for upsampling concatenated pyramid predictions to
38 |             match spatial resolution of input data
39 |         interpolation: 'nearest' or 'bilinear', interpolation type for upsampling layers
40 | dropout: float [0, 1), dropout rate
41 | use_batchnorm: bool, include batch normalization to FPN between `conv`
42 | and `relu` layers
43 |
44 | Returns:
45 | model: Keras `Model`
46 | """
47 |
48 | if len(upsample_rates) != len(fpn_layers):
49 | raise ValueError('Number of intermediate feature maps and upsample steps should match')
50 |
51 | # extract model layer outputs
52 | outputs = extract_outputs(backbone, fpn_layers, include_top=True)
53 |
54 | # add upsample rate `1` for first block
55 | upsample_rates = [1] + list(upsample_rates)
56 |
57 | # top - down path, build pyramid
58 | m = None
59 | pyramid = []
60 | for i, c in enumerate(outputs):
61 | m, p = pyramid_block(pyramid_filters=pyramid_filters,
62 | segmentation_filters=segmentation_filters,
63 | upsample_rate=upsample_rates[i],
64 | use_batchnorm=use_batchnorm,
65 | stage=i)(c, m)
66 | pyramid.append(p)
67 |
68 |
69 | # upsample and concatenate all pyramid layer
70 | upsampled_pyramid = []
71 |
72 | for i, p in enumerate(pyramid[::-1]):
73 | if upsample_rates[i] > 1:
74 | upsample_rate = to_tuple(np.prod(upsample_rates[:i+1]))
75 | p = ResizeImage(upsample_rate, interpolation=interpolation)(p)
76 | upsampled_pyramid.append(p)
77 |
78 | x = Concatenate()(upsampled_pyramid)
79 |
80 | # final convolution
81 | n_filters = segmentation_filters * len(pyramid)
82 | x = Conv2DBlock(n_filters, (3, 3), use_batchnorm=use_batchnorm, padding='same')(x)
83 | if dropout is not None:
84 | x = SpatialDropout2D(dropout)(x)
85 |
86 | x = Conv2D(classes, (3, 3), padding='same')(x)
87 |
88 | # upsampling to original spatial resolution
89 | x = ResizeImage(to_tuple(last_upsample), interpolation=interpolation)(x)
90 |
91 | # activation
92 | x = Activation(activation)(x)
93 |
94 | model = Model(backbone.input, x)
95 | return model
96 |
--------------------------------------------------------------------------------
/segmentation_models/fpn/model.py:
--------------------------------------------------------------------------------
1 | from .builder import build_fpn
2 | from ..backbones import get_backbone, get_feature_layers
3 | from ..utils import freeze_model
4 | from ..utils import legacy_support
5 |
6 | old_args_map = {
7 | 'freeze_encoder': 'encoder_freeze',
8 | 'fpn_layers': 'encoder_features',
9 | 'use_batchnorm': 'pyramid_use_batchnorm',
10 | 'dropout': 'pyramid_dropout',
11 | 'interpolation': 'final_interpolation',
12 | 'upsample_rates': None, # removed
13 | 'last_upsample': None, # removed
14 | }
15 |
16 |
17 | @legacy_support(old_args_map)
18 | def FPN(backbone_name='vgg16',
19 | input_shape=(None, None, 3),
20 | input_tensor=None,
21 | classes=21,
22 | activation='softmax',
23 | encoder_weights='imagenet',
24 | encoder_freeze=False,
25 | encoder_features='default',
26 | pyramid_block_filters=256,
27 | pyramid_use_batchnorm=True,
28 | pyramid_dropout=None,
29 | final_interpolation='bilinear',
30 | **kwargs):
31 | """FPN_ is a fully convolution neural network for image semantic segmentation
32 |
33 | Args:
34 | backbone_name: name of classification model (without last dense layers) used as feature
35 | extractor to build segmentation model.
36 | input_shape: shape of input data/image ``(H, W, C)``, in general
37 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be
38 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
39 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model
40 | (works only if ``encoder_weights`` is ``None``).
41 | classes: a number of classes for output (output shape - ``(h, w, classes)``).
42 | activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``).
43 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
44 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
45 | encoder_features: a list of layer numbers or names starting from top of the model.
46 | Each of these layers will be used to build features pyramid. If ``default`` is used
47 | layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``.
48 | pyramid_block_filters: a number of filters in Feature Pyramid Block of FPN_.
49 | pyramid_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
50 | is used.
51 | pyramid_dropout: spatial dropout rate for feature pyramid in range (0, 1).
52 |         final_interpolation: interpolation type for upsampling layers, one of ``nearest``, ``bilinear``.
53 |
54 | Returns:
55 | ``keras.models.Model``: **FPN**
56 |
57 | .. _FPN:
58 | http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf
59 |
60 | """
61 |
62 | backbone = get_backbone(backbone_name,
63 | input_shape=input_shape,
64 | input_tensor=input_tensor,
65 | weights=encoder_weights,
66 | include_top=False)
67 |
68 | if encoder_features == 'default':
69 | encoder_features = get_feature_layers(backbone_name, n=3)
70 |
71 | upsample_rates = [2] * len(encoder_features)
72 | last_upsample = 2 ** (5 - len(encoder_features))
73 |
74 | model = build_fpn(backbone, encoder_features,
75 | classes=classes,
76 | pyramid_filters=pyramid_block_filters,
77 | segmentation_filters=pyramid_block_filters // 2,
78 | upsample_rates=upsample_rates,
79 | use_batchnorm=pyramid_use_batchnorm,
80 | dropout=pyramid_dropout,
81 | last_upsample=last_upsample,
82 | interpolation=final_interpolation,
83 | activation=activation)
84 |
85 | if encoder_freeze:
86 | freeze_model(backbone)
87 |
88 | model.name = 'fpn-{}'.format(backbone.name)
89 |
90 | return model
91 |
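A minimal usage sketch, assuming ``FPN`` is re-exported from the package root (as in upstream ``segmentation_models``) and the default ``vgg16`` backbone with ImageNet weights is available:

    from segmentation_models import FPN

    # 320 is divisible by 32, so the spatial constraint from the docstring holds
    model = FPN(backbone_name='vgg16',
                input_shape=(320, 320, 3),
                classes=4,
                activation='softmax')
    model.compile('adam', 'categorical_crossentropy', ['categorical_accuracy'])
    model.summary()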
--------------------------------------------------------------------------------
/segmentation_models/linknet/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Linknet
2 |
--------------------------------------------------------------------------------
/segmentation_models/linknet/blocks.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.layers import Conv2DTranspose as Transpose
3 | from keras.layers import UpSampling2D
4 | from keras.layers import Conv2D
5 | from keras.layers import BatchNormalization
6 | from keras.layers import Activation
7 | from keras.layers import Add
8 |
9 |
10 | def handle_block_names(stage):
11 | conv_name = 'decoder_stage{}_conv'.format(stage)
12 | bn_name = 'decoder_stage{}_bn'.format(stage)
13 | relu_name = 'decoder_stage{}_relu'.format(stage)
14 | up_name = 'decoder_stage{}_upsample'.format(stage)
15 | return conv_name, bn_name, relu_name, up_name
16 |
17 |
18 | def ConvRelu(filters,
19 | kernel_size,
20 | use_batchnorm=False,
21 | conv_name='conv',
22 | bn_name='bn',
23 | relu_name='relu'):
24 |
25 | def layer(x):
26 |
27 | x = Conv2D(filters,
28 | kernel_size,
29 | padding="same",
30 | name=conv_name,
31 | use_bias=not(use_batchnorm))(x)
32 |
33 | if use_batchnorm:
34 | x = BatchNormalization(name=bn_name)(x)
35 |
36 | x = Activation('relu', name=relu_name)(x)
37 |
38 | return x
39 | return layer
40 |
41 |
42 | def Conv2DUpsample(filters,
43 | upsample_rate,
44 | kernel_size=(3,3),
45 | up_name='up',
46 | conv_name='conv',
47 | **kwargs):
48 |
49 | def layer(input_tensor):
50 | x = UpSampling2D(upsample_rate, name=up_name)(input_tensor)
51 | x = Conv2D(filters,
52 | kernel_size,
53 | padding='same',
54 | name=conv_name,
55 | **kwargs)(x)
56 | return x
57 | return layer
58 |
59 |
60 | def Conv2DTranspose(filters,
61 | upsample_rate,
62 | kernel_size=(4,4),
63 | up_name='up',
64 | **kwargs):
65 |
66 |     if tuple(upsample_rate) != (2, 2):
67 |         raise NotImplementedError(
68 |             'Conv2DTranspose supports only upsample_rate=(2, 2), got {}'.format(upsample_rate))
69 |
70 | def layer(input_tensor):
71 | x = Transpose(filters,
72 | kernel_size=kernel_size,
73 | strides=upsample_rate,
74 | padding='same',
75 | name=up_name)(input_tensor)
76 | return x
77 | return layer
78 |
79 |
80 | def UpsampleBlock(filters,
81 | upsample_rate,
82 | kernel_size,
83 | use_batchnorm=False,
84 | upsample_layer='upsampling',
85 | conv_name='conv',
86 | bn_name='bn',
87 | relu_name='relu',
88 | up_name='up',
89 | **kwargs):
90 |
91 | if upsample_layer == 'upsampling':
92 | UpBlock = Conv2DUpsample
93 |
94 | elif upsample_layer == 'transpose':
95 | UpBlock = Conv2DTranspose
96 |
97 | else:
98 |         raise ValueError('Unsupported upsample layer type `{}`'.format(upsample_layer))
99 |
100 | def layer(input_tensor):
101 |
102 | x = UpBlock(filters,
103 | upsample_rate=upsample_rate,
104 | kernel_size=kernel_size,
105 | use_bias=not(use_batchnorm),
106 | conv_name=conv_name,
107 | up_name=up_name,
108 | **kwargs)(input_tensor)
109 |
110 | if use_batchnorm:
111 | x = BatchNormalization(name=bn_name)(x)
112 |
113 | x = Activation('relu', name=relu_name)(x)
114 |
115 | return x
116 | return layer
117 |
118 |
119 | def DecoderBlock(stage,
120 | filters=None,
121 | kernel_size=(3,3),
122 | upsample_rate=(2,2),
123 | use_batchnorm=False,
124 | skip=None,
125 | upsample_layer='upsampling'):
126 |
127 | def layer(input_tensor):
128 |
129 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage)
130 | input_filters = K.int_shape(input_tensor)[-1]
131 |
132 | if skip is not None:
133 | output_filters = K.int_shape(skip)[-1]
134 | else:
135 | output_filters = filters
136 |
137 | x = ConvRelu(input_filters // 4,
138 | kernel_size=(1, 1),
139 | use_batchnorm=use_batchnorm,
140 | conv_name=conv_name + '1',
141 | bn_name=bn_name + '1',
142 | relu_name=relu_name + '1')(input_tensor)
143 |
144 | x = UpsampleBlock(filters=input_filters // 4,
145 | kernel_size=kernel_size,
146 | upsample_layer=upsample_layer,
147 | upsample_rate=upsample_rate,
148 | use_batchnorm=use_batchnorm,
149 | conv_name=conv_name + '2',
150 | bn_name=bn_name + '2',
151 | up_name=up_name + '2',
152 | relu_name=relu_name + '2')(x)
153 |
154 | x = ConvRelu(output_filters,
155 | kernel_size=(1, 1),
156 | use_batchnorm=use_batchnorm,
157 | conv_name=conv_name + '3',
158 | bn_name=bn_name + '3',
159 | relu_name=relu_name + '3')(x)
160 |
161 | if skip is not None:
162 | x = Add()([x, skip])
163 |
164 | return x
165 | return layer
166 |
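The decoder block above is the LinkNet bottleneck: a 1x1 conv down to a quarter of the input channels, an upsample, then a 1x1 conv back up, with an additive (not concatenated) skip. A small sketch with hypothetical shapes:

    from keras.layers import Input
    from segmentation_models.linknet.blocks import DecoderBlock

    x = Input((16, 16, 256))     # decoder input (hypothetical)
    skip = Input((32, 32, 128))  # encoder feature map; sets the output channel count

    # 256 -> 64 (1x1) -> upsample 2x -> 128 (1x1), then Add() with the skip
    y = DecoderBlock(stage=0, skip=skip)(x)   # -> (None, 32, 32, 128)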
--------------------------------------------------------------------------------
/segmentation_models/linknet/builder.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D
2 | from keras.layers import Activation
3 | from keras.models import Model
4 |
5 | from .blocks import DecoderBlock
6 | from ..utils import get_layer_number, to_tuple
7 |
8 |
9 | def build_linknet(backbone,
10 | classes,
11 | skip_connection_layers,
12 | decoder_filters=(None, None, None, None, 16),
13 | upsample_rates=(2, 2, 2, 2, 2),
14 | n_upsample_blocks=5,
15 | upsample_kernel_size=(3, 3),
16 | upsample_layer='upsampling',
17 | activation='sigmoid',
18 | use_batchnorm=True):
19 |
20 | input = backbone.input
21 | x = backbone.output
22 |
23 | # convert layer names to indices
24 | skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
25 | for l in skip_connection_layers])
26 |
27 | for i in range(n_upsample_blocks):
28 |
29 | # check if there is a skip connection
30 | skip_connection = None
31 | if i < len(skip_connection_idx):
32 | skip_connection = backbone.layers[skip_connection_idx[i]].output
33 |
34 | upsample_rate = to_tuple(upsample_rates[i])
35 |
36 | x = DecoderBlock(stage=i,
37 | filters=decoder_filters[i],
38 | kernel_size=upsample_kernel_size,
39 | upsample_rate=upsample_rate,
40 | use_batchnorm=use_batchnorm,
41 | upsample_layer=upsample_layer,
42 | skip=skip_connection)(x)
43 |
44 | x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(x)
45 | x = Activation(activation, name=activation)(x)
46 |
47 | model = Model(input, x)
48 |
49 | return model
50 |
--------------------------------------------------------------------------------
/segmentation_models/linknet/model.py:
--------------------------------------------------------------------------------
1 | from .builder import build_linknet
2 | from ..utils import freeze_model
3 | from ..utils import legacy_support
4 | from ..backbones import get_backbone, get_feature_layers
5 |
6 | old_args_map = {
7 | 'freeze_encoder': 'encoder_freeze',
8 | 'skip_connections': 'encoder_features',
9 | 'upsample_layer': 'decoder_block_type',
10 | 'n_upsample_blocks': None, # removed
11 | 'input_tensor': None, # removed
12 | 'upsample_kernel_size': None, # removed
13 | }
14 |
15 |
16 | @legacy_support(old_args_map)
17 | def Linknet(backbone_name='vgg16',
18 | input_shape=(None, None, 3),
19 | classes=1,
20 | activation='sigmoid',
21 | encoder_weights='imagenet',
22 | encoder_freeze=False,
23 | encoder_features='default',
24 | decoder_filters=(None, None, None, None, 16),
25 | decoder_use_batchnorm=True,
26 | decoder_block_type='upsampling',
27 | **kwargs):
28 | """Linknet_ is a fully convolution neural network for fast image semantic segmentation
29 |
30 | Note:
31 |         This implementation has 4 skip connections by default (the original paper uses 3).
32 |
33 | Args:
34 | backbone_name: name of classification model (without last dense layers) used as feature
35 | extractor to build segmentation model.
36 | input_shape: shape of input data/image ``(H, W, C)``, in general
37 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be
38 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
39 | classes: a number of classes for output (output shape - ``(h, w, classes)``).
40 | activation: name of one of ``keras.activations`` for last model layer
41 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
42 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
43 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
44 | encoder_features: a list of layer numbers or names starting from top of the model.
45 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
46 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
47 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks,
48 | for block with skip connection a number of filters is equal to number of filters in
49 | corresponding encoder block (estimates automatically and can be passed as ``None`` value).
50 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
51 | is used.
52 | decoder_block_type: one of
53 |             - `upsampling`: use ``UpSampling2D`` keras layer
54 |             - `transpose`: use ``Conv2DTranspose`` keras layer
55 |
56 | Returns:
57 | ``keras.models.Model``: **Linknet**
58 |
59 | .. _Linknet:
60 | https://arxiv.org/pdf/1707.03718.pdf
61 | """
62 |
63 | backbone = get_backbone(backbone_name,
64 | input_shape=input_shape,
65 | input_tensor=None,
66 | weights=encoder_weights,
67 | include_top=False)
68 |
69 | if encoder_features == 'default':
70 | encoder_features = get_feature_layers(backbone_name, n=4)
71 |
72 | model = build_linknet(backbone,
73 | classes,
74 | encoder_features,
75 | decoder_filters=decoder_filters,
76 | upsample_layer=decoder_block_type,
77 | activation=activation,
78 | n_upsample_blocks=len(decoder_filters),
79 | upsample_rates=(2, 2, 2, 2, 2),
80 | upsample_kernel_size=(3, 3),
81 | use_batchnorm=decoder_use_batchnorm)
82 |
83 | # lock encoder weights for fine-tuning
84 | if encoder_freeze:
85 | freeze_model(backbone)
86 |
87 | model.name = 'link-{}'.format(backbone_name)
88 |
89 | return model
90 |
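A usage sketch, assuming ``Linknet`` is re-exported from the package root and the default ``vgg16`` backbone is available:

    from segmentation_models import Linknet

    model = Linknet(backbone_name='vgg16',
                    input_shape=(256, 256, 3),
                    classes=1,
                    activation='sigmoid',
                    decoder_block_type='transpose')  # Conv2DTranspose decoder blocks
    model.compile('adam', 'binary_crossentropy')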
--------------------------------------------------------------------------------
/segmentation_models/losses.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.losses import binary_crossentropy
3 | from keras.losses import categorical_crossentropy
4 | from keras.utils.generic_utils import get_custom_objects
5 |
6 | from .metrics import jaccard_score, f_score
7 |
8 | SMOOTH = 1e-12
9 |
10 | __all__ = [
11 | 'jaccard_loss', 'bce_jaccard_loss', 'cce_jaccard_loss',
12 | 'dice_loss', 'bce_dice_loss', 'cce_dice_loss',
13 | ]
14 |
15 |
16 | # ============================== Jaccard Losses ==============================
17 |
18 | def jaccard_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
19 | r"""Jaccard loss function for imbalanced datasets:
20 |
21 | .. math:: L(A, B) = 1 - \frac{A \cap B}{A \cup B}
22 |
23 | Args:
24 | gt: ground truth 4D keras tensor (B, H, W, C)
25 | pr: prediction 4D keras tensor (B, H, W, C)
26 | class_weights: 1. or list of class weights, len(weights) = C
27 | smooth: value to avoid division by zero
28 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
29 | else over whole batch
30 |
31 | Returns:
32 | Jaccard loss in range [0, 1]
33 |
34 | """
35 | return 1 - jaccard_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image)
36 |
37 |
38 | def bce_jaccard_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
39 | bce = K.mean(binary_crossentropy(gt, pr))
40 | loss = bce_weight * bce + jaccard_loss(gt, pr, smooth=smooth, per_image=per_image)
41 | return loss
42 |
43 |
44 | def cce_jaccard_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
45 | cce = categorical_crossentropy(gt, pr) * class_weights
46 | cce = K.mean(cce)
47 | return cce_weight * cce + jaccard_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
48 |
49 |
50 | # Update custom objects
51 | get_custom_objects().update({
52 | 'jaccard_loss': jaccard_loss,
53 | 'bce_jaccard_loss': bce_jaccard_loss,
54 | 'cce_jaccard_loss': cce_jaccard_loss,
55 | })
56 |
57 |
58 | # ============================== Dice Losses ================================
59 |
60 | def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
61 | r"""Dice loss function for imbalanced datasets:
62 |
63 |     .. math:: L(precision, recall) = 1 - 2 \frac{precision \cdot recall}
64 |         {precision + recall}
65 |
66 | Args:
67 | gt: ground truth 4D keras tensor (B, H, W, C)
68 | pr: prediction 4D keras tensor (B, H, W, C)
69 | class_weights: 1. or list of class weights, len(weights) = C
70 | smooth: value to avoid division by zero
71 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
72 | else over whole batch
73 |
74 | Returns:
75 | Dice loss in range [0, 1]
76 |
77 | """
78 | return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=1.)
79 |
80 |
81 | def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
82 | bce = K.mean(binary_crossentropy(gt, pr))
83 | loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image)
84 | return loss
85 |
86 |
87 | def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
88 | cce = categorical_crossentropy(gt, pr) * class_weights
89 | cce = K.mean(cce)
90 | return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
91 |
92 |
93 | # Update custom objects
94 | get_custom_objects().update({
95 | 'dice_loss': dice_loss,
96 | 'bce_dice_loss': bce_dice_loss,
97 | 'cce_dice_loss': cce_dice_loss,
98 | })
99 |
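A quick numeric sanity check of the loss definitions above, assuming a TensorFlow backend: a perfect prediction drives ``dice_loss`` to ~0 and a fully disjoint one to ~1.

    import numpy as np
    import keras.backend as K
    from segmentation_models.losses import dice_loss

    gt = K.constant(np.array([[[[1.], [1.]], [[0.], [0.]]]]))  # (1, 2, 2, 1) binary mask
    print(K.eval(dice_loss(gt, gt)))       # ~0.0 (perfect overlap)
    print(K.eval(dice_loss(gt, 1. - gt)))  # ~1.0 (no overlap)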
--------------------------------------------------------------------------------
/segmentation_models/metrics.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.utils.generic_utils import get_custom_objects
3 |
4 | __all__ = [
5 | 'iou_score', 'jaccard_score', 'f1_score', 'f2_score', 'dice_score',
6 | 'get_f_score', 'get_iou_score', 'get_jaccard_score',
7 | ]
8 |
9 | SMOOTH = 1e-12
10 |
11 |
12 | # ============================ Jaccard/IoU score ============================
13 |
14 |
15 | def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
16 | r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient
17 | (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the
18 | similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets,
19 | and is defined as the size of the intersection divided by the size of the union of the sample sets:
20 |
21 | .. math:: J(A, B) = \frac{A \cap B}{A \cup B}
22 |
23 | Args:
24 | gt: ground truth 4D keras tensor (B, H, W, C)
25 | pr: prediction 4D keras tensor (B, H, W, C)
26 | class_weights: 1. or list of class weights, len(weights) = C
27 | smooth: value to avoid division by zero
28 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
29 | else over whole batch
30 |
31 | Returns:
32 | IoU/Jaccard score in range [0, 1]
33 |
34 | .. _`Jaccard index`: https://en.wikipedia.org/wiki/Jaccard_index
35 |
36 | """
37 | if per_image:
38 | axes = [1, 2]
39 | else:
40 | axes = [0, 1, 2]
41 |
42 | intersection = K.sum(gt * pr, axis=axes)
43 | union = K.sum(gt + pr, axis=axes) - intersection
44 | iou = (intersection + smooth) / (union + smooth)
45 |
46 | # mean per image
47 | if per_image:
48 | iou = K.mean(iou, axis=0)
49 |
50 | # weighted mean per class
51 | iou = K.mean(iou * class_weights)
52 |
53 | return iou
54 |
55 |
56 | def get_iou_score(class_weights=1., smooth=SMOOTH, per_image=True):
57 | """Change default parameters of IoU/Jaccard score
58 |
59 | Args:
60 | class_weights: 1. or list of class weights, len(weights) = C
61 | smooth: value to avoid division by zero
62 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
63 | else over whole batch
64 |
65 | Returns:
66 | ``callable``: IoU/Jaccard score
67 | """
68 | def score(gt, pr):
69 | return iou_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image)
70 |
71 | return score
72 |
73 |
74 | jaccard_score = iou_score
75 | get_jaccard_score = get_iou_score
76 |
77 | # Update custom objects
78 | get_custom_objects().update({
79 | 'iou_score': iou_score,
80 | 'jaccard_score': jaccard_score,
81 | })
82 |
83 |
84 | # ============================== F/Dice - score ==============================
85 |
86 | def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
87 | r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
88 | where an F-score reaches its best value at 1 and worst score at 0.
89 | The relative contribution of ``precision`` and ``recall`` to the F1-score are equal.
90 | The formula for the F score is:
91 |
92 | .. math:: F_\beta(precision, recall) = (1 + \beta^2) \frac{precision \cdot recall}
93 | {\beta^2 \cdot precision + recall}
94 |
95 | The formula in terms of *Type I* and *Type II* errors:
96 |
97 | .. math:: F_\beta(A, B) = \frac{(1 + \beta^2) TP} {(1 + \beta^2) TP + \beta^2 FN + FP}
98 |
99 |
100 | where:
101 | TP - true positive;
102 | FP - false positive;
103 | FN - false negative;
104 |
105 | Args:
106 | gt: ground truth 4D keras tensor (B, H, W, C)
107 | pr: prediction 4D keras tensor (B, H, W, C)
108 | class_weights: 1. or list of class weights, len(weights) = C
109 | beta: f-score coefficient
110 | smooth: value to avoid division by zero
111 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
112 | else over whole batch
113 |
114 | Returns:
115 | F-score in range [0, 1]
116 |
117 | """
118 | if per_image:
119 | axes = [1, 2]
120 | else:
121 | axes = [0, 1, 2]
122 |
123 | tp = K.sum(gt * pr, axis=axes)
124 | fp = K.sum(pr, axis=axes) - tp
125 | fn = K.sum(gt, axis=axes) - tp
126 |
127 | score = ((1 + beta ** 2) * tp + smooth) \
128 | / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
129 |
130 | # mean per image
131 | if per_image:
132 | score = K.mean(score, axis=0)
133 |
134 | # weighted mean per class
135 | score = K.mean(score * class_weights)
136 |
137 | return score
138 |
139 |
140 | def get_f_score(class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
141 | """Change default parameters of F-score score
142 |
143 | Args:
144 | class_weights: 1. or list of class weights, len(weights) = C
145 | smooth: value to avoid division by zero
146 | beta: f-score coefficient
147 | per_image: if ``True``, metric is calculated as mean over images in batch (B),
148 | else over whole batch
149 |
150 | Returns:
151 | ``callable``: F-score
152 | """
153 | def score(gt, pr):
154 | return f_score(gt, pr, class_weights=class_weights, beta=beta, smooth=smooth, per_image=per_image)
155 |
156 | return score
157 |
158 |
159 | f1_score = get_f_score(beta=1)
160 | f2_score = get_f_score(beta=2)
161 | dice_score = f1_score
162 |
163 | # Update custom objects
164 | get_custom_objects().update({
165 | 'f1_score': f1_score,
166 | 'f2_score': f2_score,
167 | 'dice_score': dice_score,
168 | })
169 |
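A worked example for ``iou_score`` (TensorFlow backend assumed): a left-half mask against a top-half mask on a 4x4 image overlaps in a 2x2 corner, so IoU = 4 / (8 + 8 - 4) = 1/3.

    import numpy as np
    import keras.backend as K
    from segmentation_models.metrics import iou_score

    gt = np.zeros((1, 4, 4, 1), dtype='float32'); gt[0, :, :2, 0] = 1.  # left half
    pr = np.zeros((1, 4, 4, 1), dtype='float32'); pr[0, :2, :, 0] = 1.  # top half

    print(K.eval(iou_score(K.constant(gt), K.constant(pr))))  # ~0.3333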
--------------------------------------------------------------------------------
/segmentation_models/pspnet/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import PSPNet
--------------------------------------------------------------------------------
/segmentation_models/pspnet/blocks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from keras.layers import MaxPool2D
3 | from keras.layers import AveragePooling2D
4 | from keras.layers import Concatenate
5 | from keras.layers import Permute
6 | from keras.layers import Reshape
7 | from keras.backend import int_shape
8 |
9 | from ..common import Conv2DBlock
10 | from ..common import ResizeImage
11 |
12 |
13 | def InterpBlock(level, feature_map_shape,
14 | conv_filters=512,
15 | conv_kernel_size=(1,1),
16 | conv_padding='same',
17 | pooling_type='avg',
18 | pool_padding='same',
19 | use_batchnorm=True,
20 | activation='relu',
21 | interpolation='bilinear'):
22 |
23 | if pooling_type == 'max':
24 | Pool2D = MaxPool2D
25 | elif pooling_type == 'avg':
26 | Pool2D = AveragePooling2D
27 | else:
28 |         raise ValueError('Unsupported pooling type - `{}`. '.format(pooling_type) +
29 |                          'Use `avg` or `max`.')
30 |
31 | def layer(input_tensor):
32 | # Compute the kernel and stride sizes according to how large the final feature map will be
33 | # When the kernel factor and strides are equal, then we can compute the final feature map factor
34 | # by simply dividing the current factor by the kernel or stride factor
35 | # The final feature map sizes are 1x1, 2x2, 3x3, and 6x6. We round to the closest integer
36 | pool_size = [int(np.round(feature_map_shape[0] / level)),
37 | int(np.round(feature_map_shape[1] / level))]
38 | strides = pool_size
39 |
40 | x = Pool2D(pool_size, strides=strides, padding=pool_padding)(input_tensor)
41 | x = Conv2DBlock(conv_filters,
42 | kernel_size=conv_kernel_size,
43 | padding=conv_padding,
44 | use_batchnorm=use_batchnorm,
45 | activation=activation,
46 | name='level{}'.format(level))(x)
47 | x = ResizeImage(strides, interpolation=interpolation)(x)
48 | return x
49 | return layer
50 |
51 |
52 | def DUC(factor=(8, 8)):
53 |
54 | if factor[0] != factor[1]:
55 | raise ValueError('DUC upconvolution support only equal factors, '
56 | 'got {}'.format(factor))
57 | factor = factor[0]
58 |
59 | def layer(input_tensor):
60 |
61 | h, w, c = int_shape(input_tensor)[1:]
62 | H = h * factor
63 | W = w * factor
64 |
65 | x = Conv2DBlock(c*factor**2, (1,1),
66 | padding='same',
67 | name='duc_{}'.format(factor))(input_tensor)
68 | x = Permute((3, 1, 2))(x)
69 | x = Reshape((c, factor, factor, h, w))(x)
70 | x = Permute((1, 4, 2, 5, 3))(x)
71 | x = Reshape((c, H, W))(x)
72 | x = Permute((2, 3, 1))(x)
73 | return x
74 | return layer
75 |
76 |
77 | def PyramidPoolingModule(**params):
78 | """
79 | Build the Pyramid Pooling Module.
80 | """
81 |
82 | _params = {
83 | 'conv_filters': 512,
84 | 'conv_kernel_size': (1, 1),
85 | 'conv_padding': 'same',
86 | 'pooling_type': 'avg',
87 | 'pool_padding': 'same',
88 | 'use_batchnorm': True,
89 | 'activation': 'relu',
90 | 'interpolation': 'bilinear',
91 | }
92 |
93 | _params.update(params)
94 |
95 | def module(input_tensor):
96 |
97 | feature_map_shape = int_shape(input_tensor)[1:3]
98 |
99 | x1 = InterpBlock(1, feature_map_shape, **_params)(input_tensor)
100 | x2 = InterpBlock(2, feature_map_shape, **_params)(input_tensor)
101 | x3 = InterpBlock(3, feature_map_shape, **_params)(input_tensor)
102 | x6 = InterpBlock(6, feature_map_shape, **_params)(input_tensor)
103 |
104 | x = Concatenate()([input_tensor, x1, x2, x3, x6])
105 | return x
106 | return module
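A shape-level sketch of the pooling module; the 48x48x256 input is hypothetical (it just needs H and W divisible by 6 so the 1/2/3/6 pooling levels tile evenly):

    from keras.layers import Input
    from segmentation_models.pspnet.blocks import PyramidPoolingModule

    x = Input((48, 48, 256))
    y = PyramidPoolingModule(conv_filters=128)(x)
    # Four branches pool to 1x1, 2x2, 3x3 and 6x6, convolve to 128 channels,
    # resize back to 48x48, and concatenate with the input:
    # output shape (None, 48, 48, 256 + 4 * 128)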
--------------------------------------------------------------------------------
/segmentation_models/pspnet/builder.py:
--------------------------------------------------------------------------------
1 | """
2 | Code is constructed based on the following repositories:
3 | https://github.com/ykamikawa/PSPNet/
4 | https://github.com/hujh14/PSPNet-Keras/
5 | https://github.com/Vladkryvoruchko/PSPNet-Keras-tensorflow/
6 |
7 | And original paper of PSPNet:
8 | https://arxiv.org/pdf/1612.01105.pdf
9 | """
10 |
11 | from keras.layers import Conv2D
12 | from keras.layers import Activation
13 | from keras.layers import SpatialDropout2D
14 | from keras.models import Model
15 |
16 | from .blocks import PyramidPoolingModule, DUC
17 | from ..common import Conv2DBlock
18 | from ..common import ResizeImage
19 | from ..utils import extract_outputs
20 | from ..utils import to_tuple
21 |
22 |
23 | def build_psp(backbone,
24 | psp_layer,
25 | last_upsampling_factor,
26 | classes=21,
27 | activation='softmax',
28 | conv_filters=512,
29 | pooling_type='avg',
30 | dropout=None,
31 | final_interpolation='bilinear',
32 | use_batchnorm=True):
33 |
34 | input = backbone.input
35 |
36 | x = extract_outputs(backbone, [psp_layer])[0]
37 |
38 | x = PyramidPoolingModule(
39 | conv_filters=conv_filters,
40 | pooling_type=pooling_type,
41 | use_batchnorm=use_batchnorm)(x)
42 |
43 | x = Conv2DBlock(512, (1, 1), activation='relu', padding='same',
44 | use_batchnorm=use_batchnorm)(x)
45 |
46 | if dropout is not None:
47 | x = SpatialDropout2D(dropout)(x)
48 |
49 | x = Conv2D(classes, (3,3), padding='same', name='final_conv')(x)
50 |
51 | if final_interpolation == 'bilinear':
52 | x = ResizeImage(to_tuple(last_upsampling_factor))(x)
53 | elif final_interpolation == 'duc':
54 | x = DUC(to_tuple(last_upsampling_factor))(x)
55 | else:
56 | raise ValueError('Unsupported interpolation type {}. '.format(final_interpolation) +
57 | 'Use `duc` or `bilinear`.')
58 |
59 | x = Activation(activation, name=activation)(x)
60 |
61 | model = Model(input, x)
62 |
63 | return model
64 |
--------------------------------------------------------------------------------
/segmentation_models/pspnet/model.py:
--------------------------------------------------------------------------------
1 | from .builder import build_psp
2 | from ..utils import freeze_model
3 | from ..utils import legacy_support
4 | from ..backbones import get_backbone, get_feature_layers
5 |
6 |
7 | def _get_layer_by_factor(backbone_name, factor):
8 | feature_layers = get_feature_layers(backbone_name, n=3)
9 | if factor == 4:
10 | return feature_layers[-1]
11 | elif factor == 8:
12 | return feature_layers[-2]
13 | elif factor == 16:
14 | return feature_layers[-3]
15 | else:
16 |         raise ValueError('Unsupported factor - `{}`. Use 4, 8 or 16.'.format(factor))
17 |
18 |
19 | def _shape_guard(factor, shape):
20 | h, w = shape[:2]
21 | min_size = factor * 6
22 |
23 | res = (h % min_size != 0 or w % min_size != 0 or
24 | h < min_size or w < min_size)
25 | if res:
26 | raise ValueError('Wrong shape {}, input H and W should '.format(shape) +
27 | 'be divisible by `{}`'.format(min_size))
28 |
29 |
30 | old_args_map = {
31 | 'freeze_encoder': 'encoder_freeze',
32 | 'use_batchnorm': 'psp_use_batchnorm',
33 | 'dropout': 'psp_dropout',
34 | 'input_tensor': None, # removed
35 | }
36 |
37 |
38 | @legacy_support(old_args_map)
39 | def PSPNet(backbone_name='vgg16',
40 | input_shape=(384, 384, 3),
41 | classes=21,
42 | activation='softmax',
43 | encoder_weights='imagenet',
44 | encoder_freeze=False,
45 | downsample_factor=8,
46 | psp_conv_filters=512,
47 | psp_pooling_type='avg',
48 | psp_use_batchnorm=True,
49 | psp_dropout=None,
50 | final_interpolation='bilinear',
51 | **kwargs):
52 | """PSPNet_ is a fully convolution neural network for image semantic segmentation
53 |
54 | Args:
55 | backbone_name: name of classification model used as feature
56 | extractor to build segmentation model.
57 | input_shape: shape of input data/image ``(H, W, C)``.
58 | ``H`` and ``W`` should be divisible by ``6 * downsample_factor`` and **NOT** ``None``!
59 | classes: a number of classes for output (output shape - ``(h, w, classes)``).
60 | activation: name of one of ``keras.activations`` for last model layer
61 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
62 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
63 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
64 |         downsample_factor: one of 4, 8 or 16. Downsampling rate, i.e. the backbone
65 |             depth at which the PSP module is constructed.
66 | psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block.
67 | psp_pooling_type: one of 'avg', 'max'. PSP block pooling type (maximum or average).
68 | psp_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
69 | is used.
70 | psp_dropout: dropout rate between 0 and 1.
71 | final_interpolation: ``duc`` or ``bilinear`` - interpolation type for final
72 | upsampling layer.
73 |
74 | Returns:
75 | ``keras.models.Model``: **PSPNet**
76 |
77 | .. _PSPNet:
78 | https://arxiv.org/pdf/1612.01105.pdf
79 |
80 | """
81 |
82 | # control image input shape
83 | _shape_guard(downsample_factor, input_shape)
84 |
85 | backbone = get_backbone(backbone_name,
86 | input_shape=input_shape,
87 | input_tensor=None,
88 | weights=encoder_weights,
89 | include_top=False)
90 |
91 | psp_layer = _get_layer_by_factor(backbone_name, downsample_factor)
92 |
93 | model = build_psp(backbone,
94 | psp_layer,
95 | last_upsampling_factor=downsample_factor,
96 | classes=classes,
97 | conv_filters=psp_conv_filters,
98 | pooling_type=psp_pooling_type,
99 | activation=activation,
100 | use_batchnorm=psp_use_batchnorm,
101 | dropout=psp_dropout,
102 | final_interpolation=final_interpolation)
103 |
104 | # lock encoder weights for fine-tuning
105 | if encoder_freeze:
106 | freeze_model(backbone)
107 |
108 | model.name = 'psp-{}'.format(backbone_name)
109 |
110 | return model
111 |
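A usage sketch, assuming ``PSPNet`` is re-exported from the package root. With the default ``downsample_factor=8`` the shape guard requires H and W divisible by 6 * 8 = 48, which 384 satisfies:

    from segmentation_models import PSPNet

    model = PSPNet(backbone_name='vgg16',
                   input_shape=(384, 384, 3),
                   classes=5,
                   downsample_factor=8)
    model.compile('adam', 'categorical_crossentropy')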
--------------------------------------------------------------------------------
/segmentation_models/unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Unet
2 |
--------------------------------------------------------------------------------
/segmentation_models/unet/blocks.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2DTranspose
2 | from keras.layers import UpSampling2D
3 | from keras.layers import Conv2D
4 | from keras.layers import BatchNormalization
5 | from keras.layers import Activation
6 | from keras.layers import Concatenate
7 |
8 |
9 | def handle_block_names(stage):
10 | conv_name = 'decoder_stage{}_conv'.format(stage)
11 | bn_name = 'decoder_stage{}_bn'.format(stage)
12 | relu_name = 'decoder_stage{}_relu'.format(stage)
13 | up_name = 'decoder_stage{}_upsample'.format(stage)
14 | return conv_name, bn_name, relu_name, up_name
15 |
16 |
17 | def ConvRelu(filters, kernel_size, use_batchnorm=False, conv_name='conv', bn_name='bn', relu_name='relu'):
18 | def layer(x):
19 | x = Conv2D(filters, kernel_size, padding="same", name=conv_name, use_bias=not(use_batchnorm))(x)
20 | if use_batchnorm:
21 | x = BatchNormalization(name=bn_name)(x)
22 | x = Activation('relu', name=relu_name)(x)
23 | return x
24 | return layer
25 |
26 |
27 | def Upsample2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2),
28 | use_batchnorm=False, skip=None):
29 |
30 | def layer(input_tensor):
31 |
32 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage)
33 |
34 | x = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor)
35 |
36 | if skip is not None:
37 | x = Concatenate()([x, skip])
38 |
39 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
40 | conv_name=conv_name + '1', bn_name=bn_name + '1', relu_name=relu_name + '1')(x)
41 |
42 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
43 | conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
44 |
45 | return x
46 | return layer
47 |
48 |
49 | def Transpose2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2),
50 | transpose_kernel_size=(4,4), use_batchnorm=False, skip=None):
51 |
52 | def layer(input_tensor):
53 |
54 | conv_name, bn_name, relu_name, up_name = handle_block_names(stage)
55 |
56 | x = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate,
57 | padding='same', name=up_name, use_bias=not(use_batchnorm))(input_tensor)
58 | if use_batchnorm:
59 | x = BatchNormalization(name=bn_name+'1')(x)
60 | x = Activation('relu', name=relu_name+'1')(x)
61 |
62 | if skip is not None:
63 | x = Concatenate()([x, skip])
64 |
65 | x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
66 | conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
67 |
68 | return x
69 | return layer
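A small sketch of one decoder step with hypothetical shapes; unlike the LinkNet blocks earlier, the skip here is concatenated rather than added:

    from keras.layers import Input
    from segmentation_models.unet.blocks import Upsample2D_block

    x = Input((16, 16, 256))     # decoder input (hypothetical)
    skip = Input((32, 32, 128))  # matching encoder feature map

    y = Upsample2D_block(64, stage=0, skip=skip)(x)
    # upsample 2x -> concat (256 + 128 channels) -> two 3x3 ConvRelu -> (None, 32, 32, 64)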
--------------------------------------------------------------------------------
/segmentation_models/unet/builder.py:
--------------------------------------------------------------------------------
1 | from keras.layers import Conv2D
2 | from keras.layers import Activation
3 | from keras.models import Model
4 |
5 | from .blocks import Transpose2D_block
6 | from .blocks import Upsample2D_block
7 | from ..utils import get_layer_number, to_tuple
8 |
9 |
10 | def build_unet(backbone, classes, skip_connection_layers,
11 | decoder_filters=(256,128,64,32,16),
12 | upsample_rates=(2,2,2,2,2),
13 | n_upsample_blocks=5,
14 | block_type='upsampling',
15 | activation='sigmoid',
16 | use_batchnorm=True):
17 |
18 | input = backbone.input
19 | x = backbone.output
20 |
21 | if block_type == 'transpose':
22 | up_block = Transpose2D_block
23 | else:
24 | up_block = Upsample2D_block
25 |
26 | # convert layer names to indices
27 | skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
28 | for l in skip_connection_layers])
29 |
30 | for i in range(n_upsample_blocks):
31 |
32 | # check if there is a skip connection
33 | skip_connection = None
34 | if i < len(skip_connection_idx):
35 | skip_connection = backbone.layers[skip_connection_idx[i]].output
36 |
37 | upsample_rate = to_tuple(upsample_rates[i])
38 |
39 | x = up_block(decoder_filters[i], i, upsample_rate=upsample_rate,
40 | skip=skip_connection, use_batchnorm=use_batchnorm)(x)
41 |
42 | x = Conv2D(classes, (3,3), padding='same', name='final_conv')(x)
43 | x = Activation(activation, name=activation)(x)
44 |
45 | model = Model(input, x)
46 |
47 | return model
48 |
--------------------------------------------------------------------------------
/segmentation_models/unet/model.py:
--------------------------------------------------------------------------------
1 | from .builder import build_unet
2 | from ..utils import freeze_model
3 | from ..utils import legacy_support
4 | from ..backbones import get_backbone, get_feature_layers
5 |
6 | old_args_map = {
7 | 'freeze_encoder': 'encoder_freeze',
8 | 'skip_connections': 'encoder_features',
9 | 'upsample_rates': None, # removed
10 | 'input_tensor': None, # removed
11 | }
12 |
13 |
14 | @legacy_support(old_args_map)
15 | def Unet(backbone_name='vgg16',
16 | input_shape=(None, None, 3),
17 | classes=1,
18 | activation='sigmoid',
19 | encoder_weights='imagenet',
20 | encoder_freeze=False,
21 | encoder_features='default',
22 | decoder_block_type='upsampling',
23 | decoder_filters=(256, 128, 64, 32, 16),
24 | decoder_use_batchnorm=True,
25 | **kwargs):
26 | """ Unet_ is a fully convolution neural network for image semantic segmentation
27 |
28 | Args:
29 | backbone_name: name of classification model (without last dense layers) used as feature
30 | extractor to build segmentation model.
31 | input_shape: shape of input data/image ``(H, W, C)``, in general
32 | case you do not need to set ``H`` and ``W`` shapes, just pass ``(None, None, C)`` to make your model be
33 |             able to process images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
34 | classes: a number of classes for output (output shape - ``(h, w, classes)``).
35 | activation: name of one of ``keras.activations`` for last model layer
36 | (e.g. ``sigmoid``, ``softmax``, ``linear``).
37 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
38 | encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
39 | encoder_features: a list of layer numbers or names starting from top of the model.
40 | Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
41 | layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
42 | decoder_block_type: one of blocks with following layers structure:
43 |
44 |             - `upsampling`: ``UpSampling2D`` -> ``Conv2D`` -> ``Conv2D``
45 |             - `transpose`: ``Conv2DTranspose`` -> ``Conv2D``
46 |
47 | decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks
48 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
49 | is used.
50 |
51 | Returns:
52 | ``keras.models.Model``: **Unet**
53 |
54 | .. _Unet:
55 | https://arxiv.org/pdf/1505.04597
56 |
57 | """
58 |
59 | backbone = get_backbone(backbone_name,
60 | input_shape=input_shape,
61 | input_tensor=None,
62 | weights=encoder_weights,
63 | include_top=False)
64 |
65 | if encoder_features == 'default':
66 | encoder_features = get_feature_layers(backbone_name, n=4)
67 |
68 | model = build_unet(backbone,
69 | classes,
70 | encoder_features,
71 | decoder_filters=decoder_filters,
72 | block_type=decoder_block_type,
73 | activation=activation,
74 | n_upsample_blocks=len(decoder_filters),
75 | upsample_rates=(2, 2, 2, 2, 2),
76 | use_batchnorm=decoder_use_batchnorm)
77 |
78 | # lock encoder weights for fine-tuning
79 | if encoder_freeze:
80 | freeze_model(backbone)
81 |
82 | model.name = 'u-{}'.format(backbone_name)
83 |
84 | return model
85 |
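A usage sketch, assuming ``Unet`` is re-exported from the package root and the default ``vgg16`` backbone is available:

    from segmentation_models import Unet

    model = Unet(backbone_name='vgg16',
                 input_shape=(None, None, 3),   # any H, W divisible by 32
                 classes=1,
                 activation='sigmoid',
                 encoder_freeze=True)           # train only the decoder at first
    model.compile('adam', 'binary_crossentropy')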
--------------------------------------------------------------------------------
/temp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/temp/__init__.py
--------------------------------------------------------------------------------
/temp/band4_image.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import gdal
4 | import numpy as np
5 | import keras.backend as K
6 | K.set_image_dim_ordering('tf')
7 | from keras.preprocessing.image import img_to_array
8 | import matplotlib.pyplot as plt
9 |
10 |
11 | img_path = '../../data/originaldata/NatureProtected/src/sample1.png'
12 |
13 | out_dir='../../data/'
14 |
15 | window_size=2048
16 |
17 | if __name__=='__main__':
18 | img = cv2.imread(img_path)
19 |
20 | print(img.shape)
21 |
22 | dataset = gdal.Open(img_path)
23 |     if dataset is None:
24 | print("open failed!\n")
25 |
26 | height = dataset.RasterYSize
27 | width=dataset.RasterXSize
28 | bands = dataset.RasterCount
29 | all_data=dataset.ReadAsArray(0,0,width,height)
30 | # all_data = np.array(all_data)
31 | # new_data = img_to_array(all_data)
32 |
33 | # print("shape:{}".format(new_data.shape))
34 |
35 | x = np.random.randint(0, height - window_size - 1)
36 | y = np.random.randint(0, width - window_size - 1)
37 |
38 | output_img = all_data[:, x:x + window_size, y:y + window_size]
39 | print(output_img.shape)
40 |     result = output_img[1:4,:,:]  # keep bands 2-4 for an RGB-like display
41 | result = np.transpose(result,(1,2,0))
42 | plt.imshow(result)
43 | plt.show()
44 |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/temp/change_geotransform.py:
--------------------------------------------------------------------------------
1 | import gdal
2 | import os ,sys
3 | from ulitities.base_functions import load_img_by_gdal_geo
4 |
5 |
6 | input_file = '/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/test_2w.png'
7 | output_file = '/media/omnisky/6b62a451-463c-41e2-b06c-57f95571fdec/Backups/data/test/WHU/images/test_2w-C.png'
8 |
9 |
10 | if __name__=="__main__":
11 | try:
12 | data, geotransform = load_img_by_gdal_geo(input_file)
13 | print("Geotransform:{}".format(geotransform))
14 | except:
15 | print("Error: Failde load image..")
16 | sys.exit(-1)
17 | tmp = list(geotransform)
18 | tmp[-1]=-1*tmp[-1]
19 |
20 | new_geo=tuple(tmp)
21 | print("new Geotransform:{}".format(new_geo))
22 |
23 |     h, w, c = data.shape  # image arrays are (rows, cols, bands)
24 |
25 |     driver = gdal.GetDriverByName("GTiff")
26 |     # driver = gdal.GetDriverByName("PNG")
27 |     # outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, gdal.GDT_UInt16)
28 |     outdataset = driver.Create(output_file, w, h, c, gdal.GDT_Byte)  # Create expects (xsize=w, ysize=h)
29 |     if outdataset is None:
30 |         print("create dataset failed!\n")
31 |         sys.exit(-2)
32 |     outdataset.SetGeoTransform(new_geo)
33 | if c == 1:
34 | outdataset.GetRasterBand(1).WriteArray(data)
35 | else:
36 | for i in range(c):
37 | outdataset.GetRasterBand(i + 1).WriteArray(data[:,:,i])
38 | del outdataset
39 |
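For reference, a GDAL geotransform is the 6-tuple ``(origin_x, pixel_w, rot_x, origin_y, rot_y, pixel_h)``; north-up rasters carry a negative ``pixel_h``, so negating the last element, as this script does, flips between the two conventions. A minimal check with hypothetical values:

    geo = (500000.0, 2.0, 0.0, 3500000.0, 0.0, 2.0)  # south-up: positive pixel height
    fixed = geo[:5] + (-geo[5],)                     # north-up: (..., -2.0)

    # pixel (row, col) -> map coordinates:
    #   x = geo[0] + col * geo[1]
    #   y = geo[3] + row * geo[5]
    print(fixed)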
--------------------------------------------------------------------------------
/temp/change_label_zym.py:
--------------------------------------------------------------------------------
1 | import os, sys, gdal
2 | import numpy as np
3 | from tqdm import tqdm
4 | import matplotlib.pyplot as plt
5 | import cv2
6 |
7 | from ulitities.base_functions import get_file, load_img_by_gdal
8 |
9 |
10 | def write_img_by_gdal(path, data, bands, dtype):
11 | data = np.array(data)
12 | if bands >1:
13 | a,b,c = data.shape
14 |             # (original lines 14-22 unrecoverable: an HTML-escaped '<...>' span was stripped in extraction)
22 |
23 | img = dataset.ReadAsArray(0, 0)
24 | rgb_img = img[:3]
25 | rgb_img = np.transpose(rgb_img, (1, 2, 0))
26 |
27 | p_band=dataset.GetRasterBand(1)
28 | data_type = p_band.DataType
29 | # data_type = gdal.GetDataTypeByName(p_band.DataType)
30 | if data_type==1:
31 |         pass
32 | elif data_type==2:
33 | # img = dataset.ReadAsArray(0, 0)
34 | # rgb_img = img[:3]
35 | # rgb_img = np.transpose(rgb_img, (1, 2, 0))
36 | rgb_img = rgb_img*255/6
37 | rgb_img = rgb_img.astype(np.uint8)
38 | plt.imshow(rgb_img, cmap='jet')
39 | plt.show()
40 | rgb_filename = ''.join([outdir_rgb, os.path.split(img_path)[1]])
41 |
42 | cv2.imwrite(rgb_filename, rgb_img)
43 |
44 | nrg_img = img[1:4]
45 | nrg_img = np.transpose(nrg_img, (1, 2, 0))
46 | nrg_img = nrg_img.astype(np.uint8)
47 | plt.imshow(nrg_img)
48 | plt.show()
49 | rgb_filename = ''.join([outdir_nrg, os.path.split(img_path)[1]])
50 |
51 | cv2.imwrite(rgb_filename, nrg_img)
52 |
53 |
54 |
55 |
56 |
57 | print("complete!\n")
58 |
59 |
--------------------------------------------------------------------------------
/temp/predict_from_xuhuimin.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | import numpy as np
3 | import cv2
4 |
5 | from test_utils.predict import predict
6 |
7 |
8 | def pred(image, net, step, ctx):  # step: spacing between sample windows
9 | h, w, channel = image.shape
10 | image = image.astype('float32')
11 |     size = int(step * 0.75)  # keep the central `size` region of each window as the final prediction
12 | margin = int((step - size) / 2)
13 | inhang = int(np.ceil(h/size))
14 | inlie = int(np.ceil(w / size))
15 |
16 | # newimage0=np.zeros((inhang*size, inlie*size,channel))
17 | # borderType = cv2.BORDER_REFLECT
18 | # newimage = cv2.copyMakeBorder(newimage0, margin, margin, margin, margin, borderType)
19 |
20 | newimage = np.zeros((inhang*size + margin*2 , inlie*size +2*margin,channel))
21 | newimage[margin : h + margin,margin : w + margin ,:] = image
22 | newimage /= 255
23 | predictions = np.zeros((inhang*size , inlie*size), dtype=np.int64)
24 | for i in range(inhang):
25 | for j in range(inlie):
26 | patch = newimage[ i*size: i*size+step ,j*size: j*size+step ,:]
27 | patch = np.transpose(patch, axes=(2, 0, 1)).astype(np.float32)
28 | patch = mx.nd.array(np.expand_dims(patch, 0), ctx=ctx)
29 |             pred = predict(patch, net)  # run inference
30 | predictions[ i*size: (i+1)*size ,j*size: (j+1)*size] = pred[margin:size+margin,margin:size+margin]
31 | result = predictions[:h,:w]
32 | return result
33 |
34 |
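The window arithmetic in ``pred``: each ``step``-sized patch is predicted, but only the central 75% is kept, so adjacent windows overlap by ``2 * margin`` pixels and border artifacts are cropped away. With ``step=256``, for example:

    step = 256
    size = int(step * 0.75)       # 192 px kept from each window
    margin = (step - size) // 2   # 32 px discarded on every side
    print(size, margin)           # 192 32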
--------------------------------------------------------------------------------
/temp/test_cv2read.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import cv2
4 |
5 | if __name__=='__main__':
6 | img = cv2.imread('/home/omnisky/PycharmProjects/data/originaldata/SatRGB/label/ruoergai_8.png')
7 |
8 | print("ok")
9 |
--------------------------------------------------------------------------------
/temp/test_for_jaccrad_predict.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import matplotlib
3 |
4 | matplotlib.use("Agg")
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from keras.models import Sequential,load_model
8 | from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation, Input
9 | from keras.utils.np_utils import to_categorical
10 | from keras.preprocessing.image import img_to_array
11 | from keras.callbacks import ModelCheckpoint, EarlyStopping, History,ReduceLROnPlateau
12 | from keras.models import Model
13 | from keras.layers.merge import concatenate
14 | from PIL import Image
15 |
16 | import cv2
17 | import random
18 | import sys
19 | import os
20 | from tqdm import tqdm
21 | from keras.models import *
22 | from keras.layers import *
23 | from keras.optimizers import *
24 |
25 | from keras import backend as K
26 | K.set_image_dim_ordering('tf')
27 |
28 |
29 | from semantic_segmentation_networks import binary_unet_jaccard, binary_fcnnet_jaccard, binary_segnet_jaccard
30 | from ulitities.base_functions import load_img_normalization
31 |
32 | os.environ["CUDA_VISIBLE_DEVICES"] = "5"
33 | seed = 7
34 | np.random.seed(seed)
35 |
36 | img_w = 256
37 | img_h = 256
38 |
39 | n_label = 1
40 |
41 | model_save_path='/home/omnisky/PycharmProjects/data/models/sat_urban_rgb/unet_buildings_binary_jaccard.h5'
42 |
43 |
44 | window_size=256
45 |
46 | def test_predict(image,model):
47 | stride = window_size
48 |
49 | h, w, _ = image.shape
50 | print('h,w:', h, w)
51 | padding_h = (h // stride + 1) * stride
52 | padding_w = (w // stride + 1) * stride
53 | padding_img = np.zeros((padding_h, padding_w, 3))
54 | padding_img[0:h, 0:w, :] = image[:, :, :]
55 |
56 | padding_img = img_to_array(padding_img)
57 |
58 | mask_whole = np.zeros((padding_h, padding_w), dtype=np.float32)
59 | for i in list(range(padding_h // stride)):
60 | for j in list(range(padding_w // stride)):
61 | crop = padding_img[i * stride:i * stride + window_size, j * stride:j * stride + window_size, :3]
62 |
63 | crop = np.expand_dims(crop, axis=0)
64 | print('crop:{}'.format(crop.shape))
65 |
66 | # pred = model.predict(crop, verbose=2)
67 | pred = model.predict(crop, verbose=2)
68 | # pred = np.argmax(pred, axis=2) #for one hot encoding
69 |
70 | pred = pred.reshape(256, 256)
71 | # pred = pred[0]
72 | # pred = pred[:,:,0]
73 | print(np.unique(pred))
74 |
75 |
76 | mask_whole[i * stride:i * stride + window_size, j * stride:j * stride + window_size] = pred[:, :]
77 |
78 | outputresult =mask_whole[0:h,0:w]
79 | # outputresult = outputresult.astype(np.uint8)
80 |
81 | plt.imshow(outputresult, cmap='gray')
82 | plt.title("Original predicted result")
83 | plt.show()
84 | cv2.imwrite('../../data/predict/test_model.png',outputresult*255)
85 | return outputresult
86 |
87 |
88 |
89 |
90 | if __name__ == '__main__':
91 |
92 | print("test ....................predict by trained model .....\n")
93 | test_img_path = '../../data/test/sample1.png'
94 |
95 |
96 | if not os.path.isfile(test_img_path):
97 | print("no file: {}".format(test_img_path))
98 | sys.exit(-1)
99 |
100 | ret, input_img = load_img_normalization(test_img_path)
101 | # model_save_path ='../../data/models/unet_buildings_onehot.h5'
102 |
103 | new_model = load_model(model_save_path)
104 |
105 | test_predict(input_img, new_model)
106 |
--------------------------------------------------------------------------------
/temp/test_unet_multiclass_predict.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import cv2
4 | import numpy as np
5 | import os
6 | import sys
7 | import argparse
8 | # from keras.preprocessing.image import img_to_array
9 | from keras.models import load_model
10 | from sklearn.preprocessing import LabelEncoder
11 | from keras.preprocessing.image import img_to_array
12 |
13 | import matplotlib.pyplot as plt
14 |
15 | from predict.smooth_tiled_predictions import predict_img_with_smooth_windowing_multiclassbands
16 |
17 | from keras import backend as K
18 | K.set_image_dim_ordering('th')
19 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
20 |
21 | # segnet_classes = [0., 1., 2., 3., 4.]
22 | unet_classes = [0., 1., 2.]
23 |
24 | labelencoder = LabelEncoder()
25 | labelencoder.fit(unet_classes)
26 |
27 | input_image = '../../data/test/1.png'
28 |
29 |
30 | """(1.1) for unet road predict"""
31 | unet_model_path = '../../data/models/unet_channel_first_multiclass.h5'
32 | unet_output_mask = '../../data/predict/unet/mask_unet_roads_'+os.path.split(input_image)[1]
33 |
34 |
35 | window_size = 256
36 |
37 |
38 |
39 |
40 |
41 |
42 | def cheap_predict(input_img, model,label_encoder):
43 | stride = window_size
44 |
45 | h, w, _ = input_img.shape
46 |     print('h,w:', h, w)
47 | padding_h = (h // stride + 1) * stride
48 | padding_w = (w // stride + 1) * stride
49 | padding_img = np.zeros((padding_h, padding_w, 3))
50 | padding_img[0:h, 0:w, :] = input_img[:, :, :]
51 |
52 | # Using "img_to_array" to convert the dimensions ordering, to adapt "K.set_image_dim_ordering('**') "
53 | padding_img = img_to_array(padding_img)
54 |     print('src:', padding_img.shape)
55 |
56 | mask_whole = np.zeros((padding_h, padding_w), dtype=np.float32)
57 | for i in range(padding_h // stride):
58 | for j in range(padding_w // stride):
59 | crop = padding_img[:3, i * stride:i * stride + window_size, j * stride:j * stride + window_size]
60 | # crop = padding_img[i * stride:i * stride + window_size, j * stride:j * stride + window_size, :3]
61 | cb, ch, cw = crop.shape # for channel_first
62 |
63 | # print ('crop:{}'.format(crop.shape))
64 |
65 | crop = np.expand_dims(crop, axis=0)
66 | # print ('crop:{}'.format(crop.shape))
67 | pred = model.predict(crop, verbose=2)
68 | # print (np.unique(pred))
69 | # pred = label_encoder.inverse_transform(pred[0])
70 | # print (np.unique(pred))
71 | pred = np.argmax(pred,axis=2)
72 | pred = pred.reshape(window_size, window_size)
73 |
74 |
75 | mask_whole[i * stride:i * stride + window_size, j * stride:j * stride + window_size] = pred[:, :]
76 |
77 | outputresult = mask_whole[0:h, 0:w] * 255
78 | # print (np.unique(outputresult))
79 | # print (np.unique(outputresult[:,:,0]))
80 | # print (np.unique(outputresult[:,:,1]))
81 |
82 | # plt.imshow(outputresult[:,:,2])
83 | plt.imshow(outputresult)
84 | plt.title("Original predicted result")
85 | plt.show()
86 | cv2.imwrite('../../data/predict/unet/mask_multiclass_test.png', outputresult)
87 |
88 |
89 | def new_predict_for_unet_multiclass(small_img_patches, model, real_classes,labelencoder):
90 | """
91 |
92 | :param small_img_patches: input image 4D array (patches, row,column, channels)
93 | :param model: pretrained model
94 | :param real_classes: the number of classes and the channels of output mask
95 | :param labelencoder:
96 | :return: predict mask 4D array (patches, row,column, real_classes)
97 | """
98 |
99 | # assert(real_classes == 1)  # only useful for one class
100 |
101 | small_img_patches = np.array(small_img_patches)
102 | print (small_img_patches.shape)
103 | assert (len(small_img_patches.shape) == 4)
104 |
105 | patches,row,column,input_channels = small_img_patches.shape
106 |
107 | mask_output = []
108 | for p in range(patches):
109 | # crop = np.zeros((row, column, input_channels), np.uint8)
110 | crop = small_img_patches[p,:,:,:]
111 | crop = img_to_array(crop)
112 | crop = np.expand_dims(crop, axis=0)
113 | # print ('crop:{}'.format(crop.shape))
114 | pred = model.predict(crop, verbose=2)
115 | pred = pred[0].reshape((row,column,real_classes))
116 |
117 | # expand the 2D prediction to 3D
118 | # res_pred = np.expand_dims(pred, axis=-1)
119 |
120 | mask_output.append(pred)
121 |
122 | mask_output = np.array(mask_output)
123 | print ("Shape of mask_output:{}".format(mask_output.shape))
124 |
125 | return mask_output
126 |
127 |
128 | if __name__=='__main__':
129 | input_img = cv2.imread(input_image)
130 | input_img = np.array(input_img, dtype="float") / 255.0  # normalizing to [0, 1] is required
131 | model = load_model(unet_model_path)
132 |
133 | cheap_predict(input_img,model,labelencoder)
134 |
135 | # predictions_smooth = predict_img_with_smooth_windowing_multiclassbands(
136 | # input_img,
137 | # model,
138 | # window_size=window_size,
139 | # subdivisions=2,
140 | # real_classes=3, # output channels = the true number of classes,
141 | # pred_func=new_predict_for_unet_multiclass,
142 | # labelencoder=labelencoder
143 | # )
144 | # plt.imshow(predictions_smooth)
145 | # plt.title("Original predicted result")
146 | # plt.show()
147 |
148 |
149 |
150 |
151 |
152 |
153 |
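For reference, the padding-and-tiling pattern that cheap_predict implements, distilled into a self-contained sketch. predict_tiles and predict_fn are illustrative names, not part of the repo; predict_fn is assumed to map one (window, window, C) tile to a (window, window) label array.

import numpy as np

def predict_tiles(img, predict_fn, window=256):
    # Pad a (H, W, C) image up to the next multiple of `window`
    # (always one full extra tile, mirroring cheap_predict above).
    h, w, c = img.shape
    ph = (h // window + 1) * window
    pw = (w // window + 1) * window
    padded = np.zeros((ph, pw, c), dtype=img.dtype)
    padded[:h, :w, :] = img
    mask = np.zeros((ph, pw), dtype=np.float32)
    # Predict tile by tile and stitch the results into the padded mask.
    for i in range(ph // window):
        for j in range(pw // window):
            tile = padded[i * window:(i + 1) * window, j * window:(j + 1) * window, :]
            mask[i * window:(i + 1) * window, j * window:(j + 1) * window] = predict_fn(tile)
    return mask[:h, :w]  # crop the padding away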
--------------------------------------------------------------------------------
/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/train/__init__.py
--------------------------------------------------------------------------------
/ui/about.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'about.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.5.1
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 |
11 | class Ui_Dialog_about(object):
12 | def setupUi(self, Dialog_about):
13 | Dialog_about.setObjectName("Dialog_about")
14 | Dialog_about.resize(353, 143)
15 | self.textEdit = QtWidgets.QTextEdit(Dialog_about)
16 | self.textEdit.setGeometry(QtCore.QRect(-10, 0, 371, 151))
17 | self.textEdit.setObjectName("textEdit")
18 |
19 | self.retranslateUi(Dialog_about)
20 | QtCore.QMetaObject.connectSlotsByName(Dialog_about)
21 |
22 | def retranslateUi(self, Dialog_about):
23 | _translate = QtCore.QCoreApplication.translate
24 | Dialog_about.setWindowTitle(_translate("Dialog_about", "About"))
25 | self.textEdit.setHtml(_translate("Dialog_about", "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
26 | "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
27 | "p, li { white-space: pre-wrap; }\n"
28 | "</style></head><body style=\" font-family:\'Sans Serif\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
29 | "<p style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"> <img src=\":/log/scrslogo.png\" /> <span style=\" font-size:16pt;\">Copyright SCRS</span><span style=\" font-size:20pt;\"> </span></p></body></html>"))
30 |
31 | import mysrc_rc
32 |
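All the pyuic5-generated classes in /ui are used the same way: instantiate a plain QDialog and let setupUi build the widgets onto it. A minimal usage sketch (the commented import path is illustrative):

import sys
from PyQt5.QtWidgets import QApplication, QDialog
# from ui.about import Ui_Dialog_about  # illustrative import path

app = QApplication(sys.argv)
dialog = QDialog()
ui = Ui_Dialog_about()
ui.setupUi(dialog)  # builds the text edit and window title onto the dialog
dialog.exec_()      # show modally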
--------------------------------------------------------------------------------
/ui/about.ui:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <ui version="4.0">
3 |  <class>Dialog_about</class>
4 |  <widget class="QDialog" name="Dialog_about">
5 |   <property name="geometry">
6 |    <rect>
7 |     <x>0</x>
8 |     <y>0</y>
9 |     <width>353</width>
10 |     <height>143</height>
11 |    </rect>
12 |   </property>
13 |   <property name="windowTitle">
14 |    <string>About</string>
15 |   </property>
16 |   <widget class="QTextEdit" name="textEdit">
17 |    <property name="geometry">
18 |     <rect>
19 |      <x>-10</x>
20 |      <y>0</y>
21 |      <width>371</width>
22 |      <height>151</height>
23 |     </rect>
24 |    </property>
25 |    <property name="html">
26 |     <string>&lt;!DOCTYPE HTML PUBLIC &quot;-//W3C//DTD HTML 4.0//EN&quot; &quot;http://www.w3.org/TR/REC-html40/strict.dtd&quot;&gt;
27 | &lt;html&gt;&lt;head&gt;&lt;meta name=&quot;qrichtext&quot; content=&quot;1&quot; /&gt;&lt;style type=&quot;text/css&quot;&gt;
28 | p, li { white-space: pre-wrap; }
29 | &lt;/style&gt;&lt;/head&gt;&lt;body style=&quot; font-family:'Sans Serif'; font-size:9pt; font-weight:400; font-style:normal;&quot;&gt;
30 | &lt;p style=&quot; margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;&quot;&gt; &lt;img src=&quot;:/log/scrslogo.png&quot; /&gt; &lt;span style=&quot; font-size:16pt;&quot;&gt;Copyright SCRS&lt;/span&gt;&lt;span style=&quot; font-size:20pt;&quot;&gt; &lt;/span&gt;&lt;/p&gt;&lt;/body&gt;&lt;/html&gt;</string>
31 |    </property>
32 |   </widget>
33 |  </widget>
34 |  <resources>
35 |   <include location="mysrc.qrc"/>
36 |  </resources>
37 |  <connections/>
38 | </ui>
39 |
--------------------------------------------------------------------------------
/ui/else/manual.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/manual.docx
--------------------------------------------------------------------------------
/ui/else/scrslogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/scrslogo.png
--------------------------------------------------------------------------------
/ui/else/中文标签.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ui/else/中文标签.docx
--------------------------------------------------------------------------------
/ui/mysrc.qrc:
--------------------------------------------------------------------------------
1 | <RCC>
2 |  <qresource prefix="log">
3 |   <file>scrslogo.png</file>
4 |  </qresource>
5 | </RCC>
6 |
--------------------------------------------------------------------------------
/ui/postProcess/Binarization.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'Binarization.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.5.1
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 |
11 | class Ui_Dialog_binarization(object):
12 | def setupUi(self, Dialog_binarization):
13 | Dialog_binarization.setObjectName("Dialog_binarization")
14 | Dialog_binarization.resize(408, 214)
15 | self.buttonBox = QtWidgets.QDialogButtonBox(Dialog_binarization)
16 | self.buttonBox.setGeometry(QtCore.QRect(180, 180, 221, 32))
17 | self.buttonBox.setOrientation(QtCore.Qt.Horizontal)
18 | self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok)
19 | self.buttonBox.setObjectName("buttonBox")
20 | self.layoutWidget = QtWidgets.QWidget(Dialog_binarization)
21 | self.layoutWidget.setGeometry(QtCore.QRect(0, 10, 401, 161))
22 | self.layoutWidget.setObjectName("layoutWidget")
23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget)
24 | self.verticalLayout.setObjectName("verticalLayout")
25 | self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
26 | self.horizontalLayout_8.setObjectName("horizontalLayout_8")
27 | self.label_6 = QtWidgets.QLabel(self.layoutWidget)
28 | self.label_6.setMinimumSize(QtCore.QSize(55, 23))
29 | self.label_6.setObjectName("label_6")
30 | self.horizontalLayout_8.addWidget(self.label_6)
31 | self.lineEdit_grayscale_mask = QtWidgets.QLineEdit(self.layoutWidget)
32 | self.lineEdit_grayscale_mask.setMinimumSize(QtCore.QSize(201, 23))
33 | self.lineEdit_grayscale_mask.setObjectName("lineEdit_grayscale_mask")
34 | self.horizontalLayout_8.addWidget(self.lineEdit_grayscale_mask)
35 | self.pushButton_grayscale_mask = QtWidgets.QPushButton(self.layoutWidget)
36 | self.pushButton_grayscale_mask.setMinimumSize(QtCore.QSize(0, 23))
37 | self.pushButton_grayscale_mask.setObjectName("pushButton_grayscale_mask")
38 | self.horizontalLayout_8.addWidget(self.pushButton_grayscale_mask)
39 | self.verticalLayout.addLayout(self.horizontalLayout_8)
40 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
41 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
42 | self.label = QtWidgets.QLabel(self.layoutWidget)
43 | self.label.setObjectName("label")
44 | self.horizontalLayout_2.addWidget(self.label)
45 | self.spinBox_forground = QtWidgets.QSpinBox(self.layoutWidget)
46 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
47 | sizePolicy.setHorizontalStretch(0)
48 | sizePolicy.setVerticalStretch(23)
49 | sizePolicy.setHeightForWidth(self.spinBox_forground.sizePolicy().hasHeightForWidth())
50 | self.spinBox_forground.setSizePolicy(sizePolicy)
51 | self.spinBox_forground.setMinimum(1)
52 | self.spinBox_forground.setMaximum(100000)
53 | self.spinBox_forground.setSingleStep(1)
54 | self.spinBox_forground.setProperty("value", 127)
55 | self.spinBox_forground.setObjectName("spinBox_forground")
56 | self.horizontalLayout_2.addWidget(self.spinBox_forground)
57 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
58 | self.horizontalLayout_2.addItem(spacerItem)
59 | self.verticalLayout.addLayout(self.horizontalLayout_2)
60 | self.horizontalLayout = QtWidgets.QHBoxLayout()
61 | self.horizontalLayout.setObjectName("horizontalLayout")
62 | self.label_7 = QtWidgets.QLabel(self.layoutWidget)
63 | self.label_7.setMinimumSize(QtCore.QSize(55, 23))
64 | self.label_7.setObjectName("label_7")
65 | self.horizontalLayout.addWidget(self.label_7)
66 | self.lineEdit_binary_mask = QtWidgets.QLineEdit(self.layoutWidget)
67 | self.lineEdit_binary_mask.setMinimumSize(QtCore.QSize(201, 23))
68 | self.lineEdit_binary_mask.setObjectName("lineEdit_binary_mask")
69 | self.horizontalLayout.addWidget(self.lineEdit_binary_mask)
70 | self.pushButton_binary_mask = QtWidgets.QPushButton(self.layoutWidget)
71 | self.pushButton_binary_mask.setMinimumSize(QtCore.QSize(0, 23))
72 | self.pushButton_binary_mask.setObjectName("pushButton_binary_mask")
73 | self.horizontalLayout.addWidget(self.pushButton_binary_mask)
74 | self.verticalLayout.addLayout(self.horizontalLayout)
75 |
76 | self.retranslateUi(Dialog_binarization)
77 | self.pushButton_grayscale_mask.clicked.connect(Dialog_binarization.slot_get_grayscale_mask)
78 | self.pushButton_binary_mask.clicked.connect(Dialog_binarization.slot_get_saving_binary_mask_path)
79 | self.buttonBox.accepted.connect(Dialog_binarization.slot_ok)
80 | self.buttonBox.rejected.connect(Dialog_binarization.reject)
81 | QtCore.QMetaObject.connectSlotsByName(Dialog_binarization)
82 |
83 | def retranslateUi(self, Dialog_binarization):
84 | _translate = QtCore.QCoreApplication.translate
85 | Dialog_binarization.setWindowTitle(_translate("Dialog_binarization", "Dialog"))
86 | self.label_6.setText(_translate("Dialog_binarization", "Grayscale mask:"))
87 | self.pushButton_grayscale_mask.setText(_translate("Dialog_binarization", "Open"))
88 | self.label.setText(_translate("Dialog_binarization", "Threshold value:"))
89 | self.label_7.setText(_translate("Dialog_binarization", "Binary Mask:"))
90 | self.pushButton_binary_mask.setText(_translate("Dialog_binarization", "Open"))
91 |
92 |
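The dialog above only collects a grayscale mask path, a threshold, and an output path; the actual binarization happens in the slot handlers elsewhere. A minimal sketch of what that backend step amounts to (binarize_mask is an illustrative name, not the repo's API):

import cv2

def binarize_mask(gray_path, out_path, threshold=127):
    # Pixels above `threshold` become 255 (foreground), the rest 0.
    gray = cv2.imread(gray_path, cv2.IMREAD_GRAYSCALE)
    _, binary = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    cv2.imwrite(out_path, binary)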
--------------------------------------------------------------------------------
/ui/postProcess/VoteMultimodleResults.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'VoteMultimodleResults.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.5.1
6 | #
7 | # WARNING! All changes made in this file will be lost!
8 |
9 | from PyQt5 import QtCore, QtGui, QtWidgets
10 |
11 | class Ui_Dialog_vote_multimodels(object):
12 | def setupUi(self, Dialog_vote_multimodels):
13 | Dialog_vote_multimodels.setObjectName("Dialog_vote_multimodels")
14 | Dialog_vote_multimodels.resize(419, 232)
15 | self.buttonBox = QtWidgets.QDialogButtonBox(Dialog_vote_multimodels)
16 | self.buttonBox.setGeometry(QtCore.QRect(150, 190, 201, 32))
17 | self.buttonBox.setOrientation(QtCore.Qt.Horizontal)
18 | self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok)
19 | self.buttonBox.setObjectName("buttonBox")
20 | self.layoutWidget = QtWidgets.QWidget(Dialog_vote_multimodels)
21 | self.layoutWidget.setGeometry(QtCore.QRect(10, 10, 401, 171))
22 | self.layoutWidget.setObjectName("layoutWidget")
23 | self.verticalLayout = QtWidgets.QVBoxLayout(self.layoutWidget)
24 | self.verticalLayout.setObjectName("verticalLayout")
25 | self.horizontalLayout_8 = QtWidgets.QHBoxLayout()
26 | self.horizontalLayout_8.setObjectName("horizontalLayout_8")
27 | self.label_6 = QtWidgets.QLabel(self.layoutWidget)
28 | self.label_6.setMinimumSize(QtCore.QSize(55, 23))
29 | self.label_6.setObjectName("label_6")
30 | self.horizontalLayout_8.addWidget(self.label_6)
31 | self.lineEdit_inputs = QtWidgets.QLineEdit(self.layoutWidget)
32 | self.lineEdit_inputs.setMinimumSize(QtCore.QSize(201, 23))
33 | self.lineEdit_inputs.setObjectName("lineEdit_inputs")
34 | self.horizontalLayout_8.addWidget(self.lineEdit_inputs)
35 | self.pushButton_inputs = QtWidgets.QPushButton(self.layoutWidget)
36 | self.pushButton_inputs.setMinimumSize(QtCore.QSize(0, 23))
37 | self.pushButton_inputs.setObjectName("pushButton_inputs")
38 | self.horizontalLayout_8.addWidget(self.pushButton_inputs)
39 | self.verticalLayout.addLayout(self.horizontalLayout_8)
40 | self.horizontalLayout = QtWidgets.QHBoxLayout()
41 | self.horizontalLayout.setObjectName("horizontalLayout")
42 | self.label_7 = QtWidgets.QLabel(self.layoutWidget)
43 | self.label_7.setMinimumSize(QtCore.QSize(55, 23))
44 | self.label_7.setObjectName("label_7")
45 | self.horizontalLayout.addWidget(self.label_7)
46 | self.lineEdit_mask = QtWidgets.QLineEdit(self.layoutWidget)
47 | self.lineEdit_mask.setMinimumSize(QtCore.QSize(201, 23))
48 | self.lineEdit_mask.setObjectName("lineEdit_mask")
49 | self.horizontalLayout.addWidget(self.lineEdit_mask)
50 | self.pushButton_mask = QtWidgets.QPushButton(self.layoutWidget)
51 | self.pushButton_mask.setMinimumSize(QtCore.QSize(0, 23))
52 | self.pushButton_mask.setObjectName("pushButton_mask")
53 | self.horizontalLayout.addWidget(self.pushButton_mask)
54 | self.verticalLayout.addLayout(self.horizontalLayout)
55 | self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
56 | self.horizontalLayout_3.setObjectName("horizontalLayout_3")
57 | self.groupBox = QtWidgets.QGroupBox(self.layoutWidget)
58 | self.groupBox.setObjectName("groupBox")
59 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox)
60 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
61 | self.label = QtWidgets.QLabel(self.groupBox)
62 | self.label.setObjectName("label")
63 | self.horizontalLayout_2.addWidget(self.label)
64 | self.spinBox_min = QtWidgets.QSpinBox(self.groupBox)
65 | self.spinBox_min.setObjectName("spinBox_min")
66 | self.horizontalLayout_2.addWidget(self.spinBox_min)
67 | self.label_2 = QtWidgets.QLabel(self.groupBox)
68 | self.label_2.setObjectName("label_2")
69 | self.horizontalLayout_2.addWidget(self.label_2)
70 | self.spinBox_max = QtWidgets.QSpinBox(self.groupBox)
71 | self.spinBox_max.setMinimum(1)
72 | self.spinBox_max.setProperty("value", 2)
73 | self.spinBox_max.setObjectName("spinBox_max")
74 | self.horizontalLayout_2.addWidget(self.spinBox_max)
75 | self.horizontalLayout_3.addWidget(self.groupBox)
76 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
77 | self.horizontalLayout_3.addItem(spacerItem)
78 | self.verticalLayout.addLayout(self.horizontalLayout_3)
79 |
80 | self.retranslateUi(Dialog_vote_multimodels)
81 | self.buttonBox.accepted.connect(Dialog_vote_multimodels.accept)
82 | self.pushButton_inputs.clicked.connect(Dialog_vote_multimodels.slot_select_input_files)
83 | self.pushButton_mask.clicked.connect(Dialog_vote_multimodels.slot_get_save_mask)
84 | self.buttonBox.accepted.connect(Dialog_vote_multimodels.slot_ok)
85 | QtCore.QMetaObject.connectSlotsByName(Dialog_vote_multimodels)
86 |
87 | def retranslateUi(self, Dialog_vote_multimodels):
88 | _translate = QtCore.QCoreApplication.translate
89 | Dialog_vote_multimodels.setWindowTitle(_translate("Dialog_vote_multimodels", "Vote result"))
90 | self.label_6.setText(_translate("Dialog_vote_multimodels", "Input masks:"))
91 | self.pushButton_inputs.setText(_translate("Dialog_vote_multimodels", "Open"))
92 | self.label_7.setText(_translate("Dialog_vote_multimodels", "Output Mask:"))
93 | self.pushButton_mask.setText(_translate("Dialog_vote_multimodels", "Open"))
94 | self.groupBox.setTitle(_translate("Dialog_vote_multimodels", "Values range"))
95 | self.label.setText(_translate("Dialog_vote_multimodels", "min:"))
96 | self.label_2.setText(_translate("Dialog_vote_multimodels", "max:"))
97 |
98 |
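This dialog gathers several model masks plus a min/max vote range for the voting backend (postprocess/vote.py). The core per-pixel majority vote reduces to something like the following sketch (vote_masks is an illustrative name):

import numpy as np

def vote_masks(masks, min_votes):
    # masks: list of equally sized 0/1 arrays, one per model.
    # A pixel is kept as foreground if at least `min_votes` models agree.
    votes = np.stack(masks, axis=0).sum(axis=0)
    return (votes >= min_votes).astype(np.uint8)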
--------------------------------------------------------------------------------
/ui/tmp/new_train_implements.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import numpy as np
5 | from keras.layers import *
6 | from keras import backend as K
7 | K.set_image_dim_ordering('tf')
8 |
9 | from PyQt5.QtCore import Qt
10 | from PyQt5.QtWidgets import QDialog, QFileDialog, QMessageBox
11 |
12 | from TrainBinaryCommon import Ui_Dialog_train_binary_common
13 | from tmp.new_train_backend import train_binary_for_ui
14 |
15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
16 | seed = 7
17 | np.random.seed(seed)
18 |
19 | trainBinary_dict={'trainData_path':'', 'saveModel_path':'', 'baseModel':'', 'im_bands':3, 'dtype':0,
20 | 'windsize':256, 'network':'unet', 'target_function':'crossentropy',
21 | 'class_name':'default', 'label_value': 1, 'BS':16, 'EPOCHS':100, 'GPUID':"0"}
22 |
23 |
24 |
25 |
26 | class child_trainBinaryCommon(QDialog, Ui_Dialog_train_binary_common):
27 | def __init__(self):
28 | super(child_trainBinaryCommon, self).__init__()
29 | self.setupUi(self)
30 |
31 | def slot_traindatapath(self):
32 | path = QFileDialog.getExistingDirectory(self, "Train data path", '../../data/')
33 | self.lineEdit_traindata_path.setText(path)
34 | # QDir.setCurrent(path)
35 |
36 | def slot_savemodelpath(self):
37 | path = QFileDialog.getExistingDirectory(self, "Save model", '../../data/')
38 | self.lineEdit_savemodel.setText(path)
39 |
40 | def slot_basemodel(self):
41 | path, _ = QFileDialog.getOpenFileName(self, "Select base model", '../../data/', self.tr("Models(*.h5)"))
42 | if path != '':
43 | self.lineEdit_basemodel.setText(path)
44 |
45 | def slot_ok(self):
46 | self.setWindowModality(Qt.ApplicationModal)
47 | input_dict = trainBinary_dict
48 | if os.path.isdir(self.lineEdit_traindata_path.text()):
49 | input_dict['trainData_path'] = self.lineEdit_traindata_path.text()
50 | if os.path.isdir(self.lineEdit_savemodel.text()):
51 | input_dict['saveModel_path'] = self.lineEdit_savemodel.text()
52 | if os.path.isfile(self.lineEdit_basemodel.text()):
53 | input_dict['baseModel'] = self.lineEdit_basemodel.text()
54 |
55 | input_dict['im_bands'] = int(self.spinBox_bands.value())
56 | input_dict['dtype'] = self.comboBox_dtype.currentIndex()
57 | input_dict['windsize'] = int(self.spinBox_windsize.value())
58 | if self.radioButton_unet.isChecked():
59 | input_dict['network'] = 'unet'
60 | elif self.radioButton_fcnnet.isChecked():
61 | input_dict['network'] = 'fcnnet'
62 | elif self.radioButton_segnet.isChecked():
63 | input_dict['network']='segnet'
64 | else:
65 | print("other network")
66 | sys.exit(-1)
67 |
68 | if self.radioButton_cross_entropy.isChecked():
69 | input_dict['target_function'] = 'crossentropy'
70 | elif self.radioButton_jaccard.isChecked():
71 | input_dict['target_function'] = 'jaccard'
72 | elif self.radioButton_jaccard_crossentropy.isChecked():
73 | input_dict['target_function']='jacc_and_cross'
74 | else:
75 | print("other function")
76 | sys.exit(-1)
77 |
78 | input_dict['class_name'] = self.lineEdit_class_name.text()
79 | input_dict['label_value'] = self.spinBox_label_value.value()
80 | input_dict['BS'] = self.spinBox_BS.value()
81 | input_dict['EPOCHS'] = self.spinBox_epoch.value()
82 | input_dict['GPUID'] = self.comboBox_gupid.currentText()
83 |
84 | ret = -1
85 | ret = train_binary_for_ui(input_dict)
86 | if ret == 0:
87 | QMessageBox.information(self, "Prompt", self.tr("Model trained successfully!"))
88 |
89 |
90 | self.setWindowModality(Qt.NonModal)
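slot_ok above just fills trainBinary_dict from the widgets and hands it to train_binary_for_ui, which returns 0 on success. Called outside the UI, the same contract looks like this sketch (the paths are illustrative):

params = dict(trainBinary_dict)  # copy the defaults
params.update({'trainData_path': '../../data/traindata',
               'saveModel_path': '../../data/models',
               'network': 'unet',
               'target_function': 'crossentropy'})
ret = train_binary_for_ui(params)
print("training ok" if ret == 0 else "training failed")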
--------------------------------------------------------------------------------
/ulitities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/ulitities/__init__.py
--------------------------------------------------------------------------------
/ulitities/band_compose.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | import os
4 | import sys
5 | import gdal
6 | import numpy as np
7 |
8 |
9 |
10 | root = '../../data/originaldata/zs/test/stretched/'
11 | files = ['extract_bottom_16bit.png','dem_r05_b.png', 'Int_slope100_Resample05_b.png']
12 |
13 |
14 | output_file =root+'composed.png'
15 |
16 | if __name__=='__main__':
17 | all_data=[]
18 | file_path = root + files[0]
19 | if not os.path.isfile(file_path):
20 | print("File does not exist:{}".format(file_path))
21 | sys.exit(-1)
22 | # print("processing file: {}".format(file_path))
23 | dataset = gdal.Open(file_path)
24 | if dataset is None:
25 | print("Open failed:{}".format(files[0]))
26 | sys.exit(-2)
27 |
28 | x = dataset.RasterXSize
29 | y = dataset.RasterYSize
30 | # im_band = dataset.RasterCount
31 | d_type = dataset.GetRasterBand(1).DataType
32 | del dataset
33 | all_bands=0
34 |
35 | for file in files:
36 | file_path = root+file
37 | if not os.path.isfile(file_path):
38 | print("File does not exist:{}".format(file_path))
39 | sys.exit(-3)
40 | print("processing file: {}".format(file_path))
41 | dataset = gdal.Open(file_path)
42 | if dataset is None:
43 | print("Open failed:{}".format(file))
44 | continue
45 | width = dataset.RasterXSize
46 | height = dataset.RasterYSize
47 | if x!=width or y!=height:
48 | print("Error: input files have different width and height\n")
49 | sys.exit(-4)
50 | im_band = dataset.RasterCount
51 | all_bands +=im_band
52 | im_type = dataset.GetRasterBand(1).DataType
53 | if d_type !=im_type:
54 | print("Error: data types are not the same!\n")
55 | sys.exit(-5)
56 | im_data = dataset.ReadAsArray(0,0,width,height)
57 | im_data = np.array(im_data)
58 | all_data.append(im_data)
59 | del dataset
60 |
61 | # all_data = np.array(all_data)
62 | # a,b,c =all_data.shape
63 | # print("allbands : {}".format(c))
64 |
65 | my_driver = gdal.GetDriverByName("GTiff")
66 | out_dataset = my_driver.Create(output_file, x, y, all_bands, d_type)
67 | which_band =1
68 | for i in range(len(all_data)):
69 | dims= len(all_data[i].shape)
70 | print("dimension :{}".format(dims))
71 | if dims <3:
72 | out_dataset.GetRasterBand(which_band).WriteArray(all_data[i])
73 | which_band +=1
74 | else:
75 | im_bands = all_data[i].shape[0]
76 | for j in range(im_bands):
77 | out_dataset.GetRasterBand(which_band).WriteArray(all_data[i][j])
78 | which_band +=1
79 |
80 | print("Saved to: {}".format(output_file))
81 |
82 | del out_dataset
83 |
84 |
85 |
86 |
87 |
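A quick sanity check of the composed raster with GDAL (a sketch reusing the script's own output_file):

import gdal
ds = gdal.Open(output_file)
if ds is not None:
    # the band count should equal the sum over all input files
    print("bands:", ds.RasterCount, "size:", ds.RasterXSize, "x", ds.RasterYSize)
del ds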
--------------------------------------------------------------------------------
/ulitities/ecogToPredict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import gdal
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import cv2
8 |
9 |
10 | input_file = '/home/omnisky/PycharmProjects/data/test/tianfuxinqu/ecognition/original/c2017_GJ_Clip_SVM.v1.tif'
11 | output_file = '/home/omnisky/PycharmProjects/data/test/tianfuxinqu/ecognition/pred/2017_GJ_Clip_SVM.png'
12 | if __name__=='__main__':
13 | if not os.path.isfile(input_file):
14 | print("Not valid file path or name:{}".format(input_file))
15 | sys.exit(-1)
16 |
17 | dataset_in = gdal.Open(input_file)
18 | if dataset_in is None:
19 | print("Open file failed!")
20 | sys.exit(-2)
21 | width = dataset_in.RasterXSize
22 | height = dataset_in.RasterYSize
23 | im_bands = dataset_in.RasterCount
24 | # im_type = dataset_in.
25 |
26 | img= dataset_in.ReadAsArray(0,0,width,height)
27 | mask = np.zeros((height,width,), np.uint8)
28 |
29 | # for b in im_bands:
30 | # tmp = img[b,:,:]
31 | # indx = np.where(tmp)
32 | bk_data = img[2, :, :]
33 | # plt.imshow(bk_data, cmap='gray')
34 | # plt.show()
35 |
36 |
37 | roads_data = img[0,:,:]
38 | indx = np.where(roads_data == 255)
39 | mask[indx] =1
40 | # plt.imshow(roads_data)
41 | # plt.show()
42 |
43 | buildings_data = img[1, :, :]
44 | indx = np.where(buildings_data == 255)
45 | mask[indx] = 2
46 | # plt.imshow(buildings_data)
47 | # plt.show()
48 | plt.imshow(mask)
49 | plt.show()
50 |
51 | cv2.imwrite(output_file, mask)
52 |
53 |
54 |
55 |
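The two hard-coded blocks above (roads band -> label 1, buildings band -> label 2) generalize to a small loop; a sketch, assuming each class band marks its pixels with 255 as in this file (channels_to_labels is an illustrative name):

import numpy as np

def channels_to_labels(img, class_values=(1, 2)):
    # img: (bands, H, W); band b flags class b's pixels with 255.
    mask = np.zeros(img.shape[1:], np.uint8)
    for band, value in enumerate(class_values):
        mask[img[band] == 255] = value
    return mask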
--------------------------------------------------------------------------------
/ulitities/image_clip.py:
--------------------------------------------------------------------------------
1 | #coding:utf-8
2 |
3 | from PIL import Image
4 | import cv2
5 | import numpy as np
6 | from base_functions import load_img_by_gdal
7 | import matplotlib.pyplot as plt
8 | import gdal
9 | import sys
10 |
11 |
12 | # input_src_file = '/home/omnisky/PycharmProjects/data/test/paper/label/yujiang_test_label.png'
13 | input_src_file ='/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/ducha/DCtuitiantu/label/cd13.png'
14 | # clip_src_file = '/home/omnisky/PycharmProjects/data/test/paper/new/yujiang_test_label.png'
15 | clip_src_file = '/home/omnisky/PycharmProjects/data/test/ducha/cd13_test_label.png'
16 |
17 | window_size = 8000
18 | # h_clip = 5000
19 |
20 | if __name__=='__main__':
21 | # img = load_img_by_gdal(input_src_file)
22 | dataset = gdal.Open(input_src_file)
23 | if dataset is None:
24 | print("Open file failed:{}".format(input_src_file))
25 | sys.exit(-1)
26 | # assert (ret==0)
27 | height = dataset.RasterYSize
28 | width = dataset.RasterXSize
29 | im_bands = dataset.RasterCount
30 | d_type = dataset.GetRasterBand(1).DataType
31 | img = dataset.ReadAsArray(0,0,width,height)
32 | del dataset
33 |
34 | # x = np.random.randint(0, height-window_size-1)
35 | # y = np.random.randint(0, width - window_size - 1)
36 | x =15000
37 | y=3000
38 | # h_clip = int(0.5*width+0.5)
39 | # print("cliped pixels:{}".format(h_clip))
40 |
41 | if im_bands ==1:
42 | output_img = img[y:y + window_size, x:x + window_size]
43 | # output_img = img[100:5000+100, 100:5500+100]
44 | output_img = np.array(output_img, np.uint16)
45 | # output_img[output_img > 2] = 127
46 | tp = output_img
47 | tp[tp>2]=0
48 | print(np.unique(tp))
49 | output_img = np.array(output_img, np.uint8)
50 | plt.imshow(output_img)
51 | plt.show()
52 | cv2.imwrite(clip_src_file, output_img) # for label clip
53 | else:
54 | output_img = img[:,y:y + window_size, x:x + window_size]
55 | # output_img = img[:, :, :h_clip]
56 | # output_img = img[:, :, h_clip:]
57 | plt.imshow(output_img[0])
58 | plt.show()
59 | driver = gdal.GetDriverByName("GTiff")
60 | outdataset = driver.Create(clip_src_file, window_size, window_size, im_bands, d_type)
61 | # outdataset = driver.Create(clip_src_file, h_clip, height, im_bands, d_type)
62 | if outdataset is None:
63 | print("create dataset failed!\n")
64 | sys.exit(-2)
65 | if im_bands == 1:
66 | outdataset.GetRasterBand(1).WriteArray(output_img)
67 | else:
68 | for i in range(im_bands):
69 | outdataset.GetRasterBand(i + 1).WriteArray(output_img[i])
70 | del outdataset
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
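The commented-out random offsets above can be wrapped into a helper; a sketch (random_window is an illustrative name; it assumes the window fits inside the raster):

import numpy as np

def random_window(height, width, window_size):
    # Pick a top-left corner so the window stays inside the image.
    x = np.random.randint(0, width - window_size)
    y = np.random.randint(0, height - window_size)
    return x, y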
--------------------------------------------------------------------------------
/ulitities/resample_image.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import cv2, os, sys
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | input_file = '/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/originalLabelandImages/rice/test/testlabel_2.png'
9 | output_file ='/media/omnisky/e0331d4a-a3ea-4c31-90ab-41f5b0ee2663/originalLabelandImages/rice/test/testlabel_2_resampled.png'
10 |
11 | s = 2
12 |
13 | if __name__=="__main__":
14 | img = cv2.imread(input_file, cv2.IMREAD_GRAYSCALE)
15 | a,b = img.shape[:2]
16 | m = int(a/2)
17 | n=int(b/2)
18 |
19 | #
20 | # result = []
21 | # for x in range(m):
22 | # for y in range(n):
23 | # t = img[s*x:s*(x+1), s*y:s*(y+1)].mean()
24 | # result.append(t)
25 | # down_img = np.array((img), np.uint8).reshape(m,n)
26 |
27 | result = cv2.resize(img, (n,m))
28 |
29 | plt.imshow(result)
30 | plt.show()
31 |
32 | cv2.imwrite(output_file,result)
33 |
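Note that cv2.resize defaults to bilinear interpolation, which can blend neighbouring class values when the input is a label mask as here. If that matters, nearest-neighbour resampling keeps labels intact; a sketch reusing this script's names:

import cv2
label = cv2.imread(input_file, cv2.IMREAD_GRAYSCALE)
h, w = label.shape[:2]
# INTER_NEAREST never invents intermediate class values
down = cv2.resize(label, (w // s, h // s), interpolation=cv2.INTER_NEAREST)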
--------------------------------------------------------------------------------
/ulitities/test_augument.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import os
4 | import numpy as np
5 | from tqdm import tqdm
6 | import sys
7 | import gdal
8 | from keras.preprocessing.image import img_to_array
9 | from scipy.signal import medfilt, medfilt2d
10 | from skimage import exposure
11 |
12 |
13 | filename = '/home/omnisky/PycharmProjects/data/traindata/APsamples/binary/buildings/3.png'
14 |
15 | outputfile = '/home/omnisky/PycharmProjects/data/traindata/APsamples/binary/buildings/3_medfit5.png'
16 |
17 | def rotate(xb, angle):
18 | xb = np.rot90(np.array(xb), k=angle)
19 | # yb = np.rot90(np.array(yb), k=angle)
20 | return xb
21 |
22 | def add_noise(xb):
23 | for i in range(1000):
24 | temp_x = np.random.randint(0, xb.shape[1])
25 | temp_y = np.random.randint(0, xb.shape[2])
26 | xb[:, temp_x, temp_y] =255
27 | return xb
28 |
29 |
30 |
31 | def data_augment(xb):
32 | # xb = exposure.adjust_gamma(xb, 1.0)
33 |
34 | # xb = np.transpose(xb,(1,2,0))
35 | # xb = rotate(xb, 1)
36 |
37 | # xb = rotate(xb, 2)
38 | #
39 | # xb = rotate(xb, 3)
40 | # xb = np.transpose(xb, (2, 0, 1))
41 | #
42 | # xb = np.transpose(xb, (1, 2, 0))
43 | # xb = np.fliplr(xb) # flip an array horizontally
44 | # xb = np.transpose(xb, (2, 0, 1))
45 | #
46 | # xb = np.transpose(xb, (1, 2, 0))
47 | # xb = np.flipud(xb) # flip an array vertically (up down directory)
48 | # xb = np.transpose(xb, (2, 0, 1))
49 | #
50 | # xb = exposure.adjust_gamma(xb, 2.0)
51 | #
52 |
53 | xb = np.transpose(xb, (1, 2, 0))
54 | for i in range(3):
55 | xb[:,:,i] = medfilt(xb[:,:,i], (5, 5))
56 | xb = np.transpose(xb, (2, 0, 1))
57 | #
58 | # xb = add_noise(xb)
59 |
60 |
61 | return xb
62 |
63 |
64 |
65 | if __name__=='__main__':
66 | print("[INFO] open file")
67 |
68 | dataset = gdal.Open(filename)
69 | if dataset is None:
70 | print("open failed!\n")
71 | sys.exit(-1)
72 | Y_height = dataset.RasterYSize
73 | X_width = dataset.RasterXSize
74 | im_bands = dataset.RasterCount
75 | data_type = dataset.GetRasterBand(1).DataType
76 |
77 | src_img = dataset.ReadAsArray(0, 0, X_width, Y_height)
78 | src_img = np.array(src_img)
79 | del dataset
80 | # src_img = np.transpose(src_img, (2, 1, 0))
81 |
82 | print("[INFO] augmentation ")
83 |
84 | src_img = data_augment(src_img)
85 |
86 | # src_img = np.transpose(src_img, (1, 2, 0))
87 |
88 | driver = gdal.GetDriverByName("GTiff")
89 | # outdataset = driver.Create(src_sample_file, img_w, img_h, im_bands, gdal.GDT_UInt16)
90 | outdataset = driver.Create(outputfile, X_width,Y_height, im_bands, data_type)
91 | if im_bands == 1:
92 | outdataset.GetRasterBand(1).WriteArray(src_img)
93 | else:
94 | for i in range(im_bands):
95 | outdataset.GetRasterBand(i + 1).WriteArray(src_img[i])
96 | del outdataset
97 |
98 |
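The flips commented out in data_augment all follow the same transpose-flip-transpose dance for channel-first arrays; a compact sketch (flip_channel_first is an illustrative name):

import numpy as np

def flip_channel_first(xb, horizontal=True):
    # xb: (bands, H, W). Flip in HWC space, then restore channel-first order.
    hwc = np.transpose(xb, (1, 2, 0))
    hwc = np.fliplr(hwc) if horizontal else np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))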
--------------------------------------------------------------------------------
/ulitities/xml_prec.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 |
3 | import os
4 | import sys
5 | import xml.etree.ElementTree as ET
6 | import xml.dom.minidom as Document
7 | import xml.dom
8 |
9 |
10 | # def dict_to_xml(input_dict,root_tag,node_tag):
11 | # """ 定义根节点root_tag,定义第二层节点node_tag
12 | # 第三层中将字典中键值对对应参数名和值
13 | # return: xml的tree结构 """
14 | # root_name = ET.Element(root_tag)
15 | # for (k, v) in input_dict.items():
16 | # node_name = ET.SubElement(root_name, node_tag)
17 | # for (key, val) in sorted(v.items(), key=lambda e:e[0], reverse=True):
18 | # key = ET.SubElement(node_name, key)
19 | # key.text = val
20 | # return root_name
21 |
22 |
23 | # doc = Document()
24 | # root = doc.createElement('root')
25 |
26 |
27 |
28 |
29 | def generate_xml_from_dict(input_dict, xml_file):
30 |
31 |
32 | impl = xml.dom.getDOMImplementation()
33 | dom = impl.createDocument(None, 'root', None)
34 | root = dom.documentElement
35 | instance = dom.createElement('instance')
36 | root.appendChild(instance)
37 |
38 | for (key, value) in sorted(input_dict.items()):
39 | nameE = dom.createElement(key)
40 | nameT = dom.createTextNode(str(value))
41 | nameE.appendChild(nameT)
42 | instance.appendChild(nameE)
43 |
44 | with open(xml_file, 'w', encoding='utf-8') as tp:
45 | dom.writexml(tp, addindent=' ', newl='\n', encoding='utf-8')
46 |
47 |
48 | def parse_xml_to_dict(xml_file):
49 |
50 | if not os.path.isfile(xml_file):
51 | print("input file is not existed!:{}".format(xml_file))
52 | sys.exit(-1)
53 |
54 | tree = ET.parse(xml_file)
55 | root = tree.getroot()
56 | dict_new = {}
57 | for key, value in enumerate(root):
58 | dict_init = {}
59 | list_init = []
60 | for item in value:
61 | list_init.append([item.tag, item.text])
62 | for lists in list_init:
63 | dict_init[lists[0]] = lists[1]
64 | dict_new[key] = dict_init
65 | return dict_new
66 |
67 |
68 |
69 |
70 |
71 | if __name__=='__main__':
72 | # imgStretch_dict = {'input_dir': '', 'output_dir': '', 'NoData': 65535, 'OutBits': '16bits',
73 | # 'StretchRange': '1024',
74 | # 'CutValue': '100'}
75 | save_file = 'imgstretch.xml'
76 | # generate_xml_from_dict(imgStretch_dict, save_file)
77 | dict_one = parse_xml_to_dict(save_file)
78 | print(dict_one[0])
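Round trip of the two helpers above in one go (a sketch; the values mirror the commented-out imgStretch_dict):

params = {'NoData': 65535, 'OutBits': '16bits', 'StretchRange': '1024', 'CutValue': '100'}
generate_xml_from_dict(params, 'imgstretch.xml')
restored = parse_xml_to_dict('imgstretch.xml')
print(restored[0])  # note: all values come back as strings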
--------------------------------------------------------------------------------
/venv/bin/activate:
--------------------------------------------------------------------------------
1 | # This file must be used with "source bin/activate" *from bash*
2 | # you cannot run it directly
3 |
4 | deactivate () {
5 | # reset old environment variables
6 | if [ -n "$_OLD_VIRTUAL_PATH" ] ; then
7 | PATH="$_OLD_VIRTUAL_PATH"
8 | export PATH
9 | unset _OLD_VIRTUAL_PATH
10 | fi
11 | if [ -n "$_OLD_VIRTUAL_PYTHONHOME" ] ; then
12 | PYTHONHOME="$_OLD_VIRTUAL_PYTHONHOME"
13 | export PYTHONHOME
14 | unset _OLD_VIRTUAL_PYTHONHOME
15 | fi
16 |
17 | # This should detect bash and zsh, which have a hash command that must
18 | # be called to get it to forget past commands. Without forgetting
19 | # past commands the $PATH changes we made may not be respected
20 | if [ -n "$BASH" -o -n "$ZSH_VERSION" ] ; then
21 | hash -r
22 | fi
23 |
24 | if [ -n "$_OLD_VIRTUAL_PS1" ] ; then
25 | PS1="$_OLD_VIRTUAL_PS1"
26 | export PS1
27 | unset _OLD_VIRTUAL_PS1
28 | fi
29 |
30 | unset VIRTUAL_ENV
31 | if [ ! "$1" = "nondestructive" ] ; then
32 | # Self destruct!
33 | unset -f deactivate
34 | fi
35 | }
36 |
37 | # unset irrelevant variables
38 | deactivate nondestructive
39 |
40 | VIRTUAL_ENV="/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv"
41 | export VIRTUAL_ENV
42 |
43 | _OLD_VIRTUAL_PATH="$PATH"
44 | PATH="$VIRTUAL_ENV/bin:$PATH"
45 | export PATH
46 |
47 | # unset PYTHONHOME if set
48 | # this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
49 | # could use `if (set -u; : $PYTHONHOME) ;` in bash
50 | if [ -n "$PYTHONHOME" ] ; then
51 | _OLD_VIRTUAL_PYTHONHOME="$PYTHONHOME"
52 | unset PYTHONHOME
53 | fi
54 |
55 | if [ -z "$VIRTUAL_ENV_DISABLE_PROMPT" ] ; then
56 | _OLD_VIRTUAL_PS1="$PS1"
57 | if [ "x(venv) " != x ] ; then
58 | PS1="(venv) $PS1"
59 | else
60 | if [ "`basename \"$VIRTUAL_ENV\"`" = "__" ] ; then
61 | # special case for Aspen magic directories
62 | # see http://www.zetadev.com/software/aspen/
63 | PS1="[`basename \`dirname \"$VIRTUAL_ENV\"\``] $PS1"
64 | else
65 | PS1="(`basename \"$VIRTUAL_ENV\"`)$PS1"
66 | fi
67 | fi
68 | export PS1
69 | fi
70 |
71 | # This should detect bash and zsh, which have a hash command that must
72 | # be called to get it to forget past commands. Without forgetting
73 | # past commands the $PATH changes we made may not be respected
74 | if [ -n "$BASH" -o -n "$ZSH_VERSION" ] ; then
75 | hash -r
76 | fi
77 |
--------------------------------------------------------------------------------
/venv/bin/activate.csh:
--------------------------------------------------------------------------------
1 | # This file must be used with "source bin/activate.csh" *from csh*.
2 | # You cannot run it directly.
3 | # Created by Davide Di Blasi <davidedb@gmail.com>.
4 | # Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
5 |
6 | alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; test "\!:*" != "nondestructive" && unalias deactivate'
7 |
8 | # Unset irrelevant variables.
9 | deactivate nondestructive
10 |
11 | setenv VIRTUAL_ENV "/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv"
12 |
13 | set _OLD_VIRTUAL_PATH="$PATH"
14 | setenv PATH "$VIRTUAL_ENV/bin:$PATH"
15 |
16 |
17 | set _OLD_VIRTUAL_PROMPT="$prompt"
18 |
19 | if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
20 | if ("venv" != "") then
21 | set env_name = "venv"
22 | else
23 | if (`basename "VIRTUAL_ENV"` == "__") then
24 | # special case for Aspen magic directories
25 | # see http://www.zetadev.com/software/aspen/
26 | set env_name = `basename \`dirname "$VIRTUAL_ENV"\``
27 | else
28 | set env_name = `basename "$VIRTUAL_ENV"`
29 | endif
30 | endif
31 | set prompt = "[$env_name] $prompt"
32 | unset env_name
33 | endif
34 |
35 | alias pydoc python -m pydoc
36 |
37 | rehash
38 |
--------------------------------------------------------------------------------
/venv/bin/activate.fish:
--------------------------------------------------------------------------------
1 | # This file must be used with ". bin/activate.fish" *from fish* (http://fishshell.org)
2 | # you cannot run it directly
3 |
4 | function deactivate -d "Exit virtualenv and return to normal shell environment"
5 | # reset old environment variables
6 | if test -n "$_OLD_VIRTUAL_PATH"
7 | set -gx PATH $_OLD_VIRTUAL_PATH
8 | set -e _OLD_VIRTUAL_PATH
9 | end
10 | if test -n "$_OLD_VIRTUAL_PYTHONHOME"
11 | set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
12 | set -e _OLD_VIRTUAL_PYTHONHOME
13 | end
14 |
15 | if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
16 | functions -e fish_prompt
17 | set -e _OLD_FISH_PROMPT_OVERRIDE
18 | . ( begin
19 | printf "function fish_prompt\n\t#"
20 | functions _old_fish_prompt
21 | end | psub )
22 | functions -e _old_fish_prompt
23 | end
24 |
25 | set -e VIRTUAL_ENV
26 | if test "$argv[1]" != "nondestructive"
27 | # Self destruct!
28 | functions -e deactivate
29 | end
30 | end
31 |
32 | # unset irrelevant variables
33 | deactivate nondestructive
34 |
35 | set -gx VIRTUAL_ENV "/home/omnisky/PycharmProjects/semantic_segment_RSImage/venv"
36 |
37 | set -gx _OLD_VIRTUAL_PATH $PATH
38 | set -gx PATH "$VIRTUAL_ENV/bin" $PATH
39 |
40 | # unset PYTHONHOME if set
41 | if set -q PYTHONHOME
42 | set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
43 | set -e PYTHONHOME
44 | end
45 |
46 | if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
47 | # fish uses a function instead of an env var to generate the prompt.
48 |
49 | # save the current fish_prompt function as the function _old_fish_prompt
50 | . ( begin
51 | printf "function _old_fish_prompt\n\t#"
52 | functions fish_prompt
53 | end | psub )
54 |
55 | # with the original prompt function renamed, we can override with our own.
56 | function fish_prompt
57 | # Prompt override?
58 | if test -n "$(venv) "
59 | printf "%s%s%s" "$(venv) " (set_color normal) (_old_fish_prompt)
60 | return
61 | end
62 | # ...Otherwise, prepend env
63 | set -l _checkbase (basename "$VIRTUAL_ENV")
64 | if test $_checkbase = "__"
65 | # special case for Aspen magic directories
66 | # see http://www.zetadev.com/software/aspen/
67 | printf "%s[%s]%s %s" (set_color -b blue white) (basename (dirname "$VIRTUAL_ENV")) (set_color normal) (_old_fish_prompt)
68 | else
69 | printf "%s(%s)%s%s" (set_color -b blue white) (basename "$VIRTUAL_ENV") (set_color normal) (_old_fish_prompt)
70 | end
71 | end
72 |
73 | set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
74 | end
75 |
--------------------------------------------------------------------------------
/venv/bin/python:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/venv/bin/python
--------------------------------------------------------------------------------
/venv/bin/python3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scrssys/semantic_segment_RSImage/c4148e63eb7b4bcace1ea49209b388699536936a/venv/bin/python3
--------------------------------------------------------------------------------
/venv/lib64:
--------------------------------------------------------------------------------
1 | lib
--------------------------------------------------------------------------------
/venv/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = /usr/bin
2 | include-system-site-packages = true
3 | version = 3.5.2
4 |
--------------------------------------------------------------------------------