├── .gitignore ├── Algorithm_Instructions.pdf ├── LICENSE ├── classify.py ├── lib ├── __init__.py ├── attribute_calculations.pyx ├── create_clsf_raster.pyx ├── debug_tools.py ├── rescale_intensity.pyx └── utils.py ├── ossp_process.py ├── preprocess.py ├── readme.md ├── segment.py ├── setup.py ├── training_datasets ├── icebridge_v5_training_data.h5 ├── icebridge_v7_training_data.h5 ├── pan_v2_training_data.h5 ├── wv02_ms_v2_training_data.h5 └── wv02_ms_v3.1_training_data.h5 └── training_gui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.c 9 | *.html 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # Mac stuff: 105 | .DS_Store 106 | init.sh 107 | 108 | -------------------------------------------------------------------------------- /Algorithm_Instructions.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/Algorithm_Instructions.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Nicholas Wright 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /classify.py: -------------------------------------------------------------------------------- 1 | # title: Random Forest Classifier 2 | # author: Nick Wright 3 | 4 | from ctypes import * 5 | import numpy as np 6 | from sklearn.ensemble import RandomForestClassifier 7 | from sklearn import metrics 8 | import matplotlib.pyplot as plt 9 | 10 | from lib import utils 11 | from lib import attribute_calculations as attr_calc 12 | from lib import create_clsf_raster as ccr 13 | 14 | 15 | def classify_image(input_image, watershed_data, training_dataset, meta_data, wb_ref, bp_ref): 16 | ''' 17 | Run a random forest classification. 18 | Input: 19 | input_image: preprocessed image data (preprocess.py) 20 | watershed_image: Image objects created with the segmentation 21 | algorithm. (segment.py) 22 | training_dataset: Tuple of training data in the form: 23 | (label_vector, attribute_matrix) 24 | meta_data: [im_type, im_date] 25 | Returns: 26 | Raster of classified data. 
27 | ''' 28 | 29 | #### Prepare Data and Variables 30 | image_type = meta_data[0] 31 | image_date = meta_data[1] 32 | 33 | ## Parse training_dataset input 34 | label_vector = training_dataset[0] 35 | training_feature_matrix = training_dataset[1] 36 | 37 | #### Construct the random forest decision tree using the training data set 38 | rfc = RandomForestClassifier(n_estimators=100) 39 | rfc.fit(training_feature_matrix, label_vector) 40 | 41 | clsf_block = classify_block(input_image, watershed_data, image_type, image_date, rfc, wb_ref, bp_ref) 42 | 43 | return clsf_block 44 | 45 | 46 | def classify_block(image_block, watershed_block, image_type, image_date, rfc, wb_ref, bp_ref): 47 | 48 | # Cast data as C int. 49 | watershed_block = watershed_block.astype(c_uint32, copy=False) 50 | 51 | ## If the block contains no data, set the classification values to 0 52 | if np.amax(image_block) < 2: 53 | clsf_block = np.zeros(np.shape(image_block)[1:3]) 54 | return clsf_block 55 | 56 | ## We need the object labels to start at 0. This shifts the entire 57 | # label image down so that the first label is 0, if it isn't already. 58 | if np.amin(watershed_block) > 0: 59 | watershed_block -= np.amin(watershed_block) 60 | ## Calculate the features of each segment within the block. This 61 | # calculation is unique for each image type. 
62 | if image_type == 'wv02_ms': 63 | input_feature_matrix = attr_calc.analyze_ms_image(image_block, watershed_block, 64 | wb_ref, bp_ref) 65 | elif image_type == 'srgb': 66 | input_feature_matrix = attr_calc.analyze_srgb_image(image_block,watershed_block) 67 | elif image_type == 'pan': 68 | input_feature_matrix = attr_calc.analyze_pan_image( 69 | image_block, watershed_block, image_date) 70 | 71 | input_feature_matrix = np.array(input_feature_matrix) 72 | 73 | # Predict the classification of each segment 74 | ws_predictions = rfc.predict(input_feature_matrix) 75 | ws_predictions = np.ndarray.astype(ws_predictions, dtype=c_int, copy=False) 76 | # Create the classified image by replacing watershed id's with 77 | # classification values. 78 | # If there is more than one band, we have to select one (using 2 for 79 | # no particular reason). 80 | 81 | clsf_block = ccr.create_clsf_raster(ws_predictions, image_block, 82 | watershed_block) 83 | # clsf_block = ccr.filter_small_segments(clsf_block) 84 | return clsf_block 85 | 86 | 87 | def plot_confusion_matrix(y_pred, y): 88 | plt.imshow(metrics.confusion_matrix(y_pred, y), 89 | cmap=plt.cm.binary, interpolation='nearest') 90 | plt.colorbar() 91 | plt.xlabel("true value") 92 | plt.ylabel("predicted value") 93 | plt.show() 94 | print(metrics.confusion_matrix(y_pred, y)) 95 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/lib/__init__.py -------------------------------------------------------------------------------- /lib/attribute_calculations.pyx: -------------------------------------------------------------------------------- 1 | # cython: cdivision=True 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | cimport cython 5 | import numpy as np 6 | from scipy import stats as spstats 7 | from 
ctypes import * 8 | 9 | 10 | def analyze_srgb_image(input_image, watershed_image, segment_id=False): 11 | ''' 12 | Cacluate the attributes for each segment given in watershed_image 13 | using the raw pixel values in input image. Attributes calculated for 14 | srgb type images. 15 | ''' 16 | cdef int num_ws 17 | cdef int x_dim, y_dim, num_bands 18 | cdef int ws, b, i, sid 19 | cdef int ws_size 20 | 21 | # If no segment id is provided, analyze the features for every watershed 22 | # in the input image. If a segment id is provided, just analyze the features 23 | # for that one segment. 24 | # We have to add +1 to num_ws because if the maximum value in watershed_image 25 | # is 500, then there are 501 total watersheds Sum(0,1,...,499,500) = 500+1 26 | if segment_id == False: 27 | num_ws = int(np.amax(watershed_image) + 1) 28 | sid = 0 29 | else: 30 | num_ws = 1 31 | sid = segment_id 32 | 33 | num_bands, x_dim, y_dim = np.shape(input_image) 34 | 35 | feature_matrix = np.zeros((num_ws,16), dtype=c_float) 36 | cdef float [:, :] fm_view = feature_matrix 37 | 38 | internal, external, internal_ext, external_ext = pixel_sort_extended(input_image, watershed_image, 39 | sid, x_dim, y_dim, 40 | num_ws, num_bands) 41 | 42 | for ws in range(num_ws): 43 | # If there are no pixels associated with this watershed, skip this iteration 44 | if internal[ws, 0, 0] < 1: 45 | continue 46 | 47 | # Average and Variance of Pixel Intensity for each band 48 | for b in range(3): 49 | count = internal[ws, b, 0] 50 | mean = internal[ws, b, 1] 51 | M2 = internal[ws, b, 2] 52 | variance = M2 / count 53 | if mean < 1: 54 | mean = 1 55 | fm_view[ws, b] = mean 56 | fm_view[ws, b+3] = variance**(1./2) 57 | 58 | # See Miao et al for band ratios 59 | # Division by zero is not possible because fm_view[ws,0:3] have a forced min of 1 60 | # Band Ratio 1 61 | fm_view[ws, 6] = ((fm_view[ws, 2] - fm_view[ws, 0]) / 62 | (fm_view[ws, 2] + fm_view[ws, 0])) 63 | # Band Ratio 2 64 | fm_view[ws, 7] = ((fm_view[ws, 2] 
- fm_view[ws, 1]) / 65 | (fm_view[ws, 2] + fm_view[ws, 1])) 66 | # Band Ratio 3 67 | # Prevent division by 0 68 | if (2 * fm_view[ws, 2] - fm_view[ws, 1] - fm_view[ws, 0]) < 1: 69 | fm_view[ws, 8] = 0 70 | else: 71 | fm_view[ws, 8] = ((fm_view[ws, 1] - fm_view[ws, 0]) / 72 | (2 * fm_view[ws, 2] - fm_view[ws, 1] - fm_view[ws, 0])) 73 | 74 | # Size of Superpixel 75 | fm_view[ws, 9] = internal[ws, 0, 0] 76 | 77 | # Entropy 78 | histogram_i = internal_ext[ws]#[:last_index(internal_ext[ws])] #np.bincount(internal[1][ws]) 79 | fm_view[ws, 10] = spstats.entropy(histogram_i, base=2) 80 | 81 | # If there are no external pixels (usually when whole images is black) 82 | # skip assigning these values. 83 | if external[ws, 0, 0] < 1: 84 | continue 85 | ## Neighborhood Values 86 | # N. Average Intensity 87 | n_mean = external[ws, 1, 1] 88 | fm_view[ws, 11] = n_mean 89 | # N. Standard Deviation 90 | n_var = (external[ws, 1, 2] / external[ws, 1, 0])**(1./2) 91 | fm_view[ws, 12] = n_var 92 | # N. Maximum Single Value 93 | n_max = last_index(external_ext[ws]) 94 | fm_view[ws, 13] = n_max 95 | # N. Entropy 96 | histogram_e = external_ext[ws] 97 | # histogram_e = np.bincount(external[1][ws]) 98 | fm_view[ws, 14] = spstats.entropy(histogram_e, base=2) 99 | 100 | # Date of image acquisition (removed, but need placeholder) 101 | fm_view[ws, 15] = 0 102 | 103 | return np.copy(fm_view) 104 | 105 | 106 | def analyze_ms_image(input_image, watershed_image, wb_ref, bp_ref, segment_id=False): 107 | ''' 108 | Cacluate the attributes for each segment given in watershed_image 109 | using the raw pixel values in input image. Attributes calculated for 110 | multispectral type WorldView 2 images. 
111 | ''' 112 | cdef int num_bands, x_dim, y_dim 113 | cdef int ws, b, sid 114 | cdef int num_ws 115 | cdef double mean, n_mean, M2, variance, count 116 | cdef double wb_point, wb_rel, bp_point, bp_rel 117 | 118 | # If no segment id was given, set sn to zero to signal pixel_sort that 119 | # all segments should be analyzed. Otherwise only the segment with 120 | # the number == sn will be processed 121 | # We have to add +1 to num_ws because if the maximum value in watershed_image 122 | # is 500, then there are 501 total watersheds Sum(0,1,...,499,500) = 500+1 123 | if segment_id == False: 124 | num_ws = int(np.amax(watershed_image) + 1) 125 | sid = 0 126 | else: 127 | num_ws = 1 128 | sid = segment_id 129 | 130 | num_bands, x_dim, y_dim = np.shape(input_image) 131 | 132 | internal, external, internal_ext, external_ext = pixel_sort_extended(input_image, watershed_image, 133 | sid, x_dim, y_dim, 134 | num_ws, num_bands) 135 | 136 | feature_matrix = np.zeros((num_ws, 31), dtype=c_float) 137 | cdef float[:, :] fm_view = feature_matrix 138 | cdef float[:, :, :] in_view = internal 139 | cdef float[:, :, :] ex_view = external 140 | 141 | # wb_ref = wb_ref 142 | # for b in range(8): 143 | # wb_ref[b] = wb_reference[b] 144 | 145 | for ws in range(num_ws): 146 | # If there are no pixels associated with this watershed, skip this iteration 147 | if in_view[ws, 0, 0] < 1: 148 | continue 149 | 150 | # Average Pixel Intensity of each band 151 | for b in range(8): 152 | count = in_view[ws, b, 0] 153 | mean = in_view[ws, b, 1] 154 | if mean < 1: 155 | mean = 1 156 | fm_view[ws, b] = mean 157 | 158 | # Variance of band 7 (emperically the most useful) 159 | count = in_view[ws, 6, 0] 160 | M2 = in_view[ws, 6, 2] 161 | variance = M2 / count 162 | fm_view[ws, 8] = variance**(1./2) #11 14 15 13 8 9 12 10 163 | 164 | # Important band ratios 165 | fm_view[ws, 9] = fm_view[ws, 0] / fm_view[ws, 2] 166 | fm_view[ws, 10] = fm_view[ws, 1] / fm_view[ws, 6] # 167 | #fm_view[ws, 18] = fm_view[ws, 
4] / fm_view[ws, 6] # 168 | #fm_view[ws, 19] = fm_view[ws, 3] / fm_view[ws, 5] # 169 | fm_view[ws, 11] = fm_view[ws, 3] / fm_view[ws, 6] 170 | #fm_view[ws, 21] = fm_view[ws, 3] / fm_view[ws, 7] # 171 | #fm_view[ws, 22] = fm_view[ws, 4] / fm_view[ws, 6] # 172 | 173 | # If there are no external pixels (usually when whole images is black) 174 | # skip assigning this value. 175 | if ex_view[ws, 4, 0] >= 1: 176 | # N. Average Intensity 177 | n_mean = ex_view[ws, 3, 1] 178 | fm_view[ws, 12] = n_mean 179 | n_mean = ex_view[ws, 7, 1] 180 | fm_view[ws, 13] = n_mean 181 | 182 | # b1-b7 / b1+b7 183 | fm_view[ws, 14] = ((fm_view[ws, 0] - fm_view[ws, 6]) / (fm_view[ws, 0] + fm_view[ws, 6])) 184 | 185 | # b3-b5 / b3+b5 186 | fm_view[ws, 15] = ((fm_view[ws, 2] - fm_view[ws, 4]) / (fm_view[ws, 2] + fm_view[ws, 4])) 187 | 188 | # Relative to the white balance point (b8 ignored emperically) 189 | for b in range(7): 190 | wb_point = wb_ref[b] 191 | wb_rel = fm_view[ws, b] / wb_point 192 | fm_view[ws, 16+b] = wb_rel #35 193 | 194 | # Relative to the dark reference point 195 | for b in range(8): 196 | bp_point = bp_ref[b] 197 | bp_rel = fm_view[ws, b] / bp_point 198 | fm_view[ws, 23+b] = bp_rel #35 199 | 200 | return np.copy(fm_view) 201 | 202 | 203 | def analyze_pan_image(input_image, watershed_image, date, segment_id=False): 204 | ''' 205 | Cacluate the attributes for each segment given in watershed_image 206 | using the raw pixel values in input image. Attributes calculated for 207 | srgb type images. 208 | ''' 209 | feature_matrix = [] 210 | 211 | cdef int num_ws 212 | cdef int x_dim, y_dim, num_bands 213 | cdef double features[12] 214 | cdef int ws, b, sid 215 | 216 | # If no segment id was given, set sn to zero to signal pixel_sort that 217 | # all segments should be analyzed. 
Otherwise only the segment with 218 | # the number == sn will be processed 219 | # We have to add +1 to num_ws because if the maximum value in watershed_image 220 | # is 500, then there are 501 total watersheds Sum(0,1,...,499,500) = 500+1 221 | 222 | if segment_id == False: 223 | num_ws = int(np.amax(watershed_image) + 1) 224 | sid = 0 225 | else: 226 | num_ws = 1 227 | sid = segment_id 228 | 229 | x_dim, y_dim, num_bands = np.shape(input_image) 230 | 231 | internal, external = pixel_sort(input_image, watershed_image, 232 | sid, x_dim, y_dim, 233 | num_ws, num_bands) 234 | 235 | for ws in range(num_ws): 236 | 237 | # Check for empty watershed labels 238 | if internal[0][ws] == []: 239 | features = [0 for _ in range(12)] 240 | feature_matrix.append(features) 241 | continue 242 | 243 | # Average Pixel Intensity 244 | features[0] = np.average(internal[0][ws]) 245 | if features[0] < 1: 246 | features[0] = 1 247 | 248 | # Median Pixel Value 249 | features[1] = np.median(internal[0][ws]) 250 | # Segment Minimum 251 | features[2] = np.amin(internal[0][ws]) 252 | # Segment Maximum 253 | features[3] = np.amax(internal[0][ws]) 254 | # Standard Deviation 255 | features[4] = np.std(internal[0][ws]) 256 | # Size 257 | features[5] = len(internal[0][ws]) 258 | 259 | # Entropy 260 | histogram_i = np.bincount(internal[0][ws]) 261 | features[6] = spstats.entropy(histogram_i, base=2) 262 | 263 | ## Neighborhood Values 264 | # N. Average Intensity 265 | features[7] = np.average(external[0][ws]) 266 | # N. Standard Deviation 267 | features[8] = np.std(external[0][ws]) 268 | # N. Maximum Single Value 269 | features[9] = np.amax(external[0][ws]) 270 | # N. 
Entropy 271 | histogram_e = np.bincount(external[0][ws]) 272 | features[10] = spstats.entropy(histogram_e, base=2) 273 | 274 | # Date of image acquisition 275 | features[11] = int(date) 276 | 277 | feature_matrix.append(features) 278 | 279 | return feature_matrix 280 | 281 | 282 | def pixel_sort(const unsigned char[:,:,:] intensity_image_view, 283 | const unsigned int[:,:] label_image_view, 284 | unsigned int segment_id, int x_dim, int y_dim, int num_ws, int num_bands): 285 | ''' 286 | Given an intensity image and label image of the same dimension, sort 287 | pixels into a list of internal and external intensity pixels for every 288 | label in the label image. 289 | Returns: 290 | Internal: Array of length (number of labels), each element is a list 291 | of intensity values for that label number. 292 | External: Array of length (number of labels), each element is a list 293 | of intensity values that are adjacent to that label number. 294 | ''' 295 | cdef int x, y, i, w, b 296 | cdef unsigned int sn 297 | cdef unsigned char new_val 298 | cdef float count, mean, M2 299 | cdef float delta, delta2 300 | cdef char window[4] 301 | 302 | # Output variables. 303 | internal = np.zeros((num_ws, num_bands, 3), dtype=c_float) 304 | cdef float[:, :, :] in_view = internal 305 | external = np.zeros((num_ws, num_bands, 3), dtype=c_float) 306 | cdef float[:, :, :] ex_view = external 307 | 308 | # Moving window that defines the neighboring region for each pixel 309 | window = [-4, -3, 3, 4] 310 | 311 | for y in range(y_dim): 312 | for x in range(x_dim): 313 | # Ignore pixels whose value is 0 (no data) 314 | if intensity_image_view[0, x, y] == 0: 315 | continue 316 | 317 | # Set the current segment number 318 | sn = label_image_view[x, y] 319 | 320 | # If a segment_id was given 321 | # Select only the ws with the correct label. 
322 | # set sn to zero to index in_view properly 323 | if segment_id != 0: 324 | if segment_id == sn: 325 | sn = 0 326 | else: 327 | continue 328 | 329 | # Assign the internal pixel 330 | for b in range(num_bands): 331 | # Find the new pixel 332 | new_val = intensity_image_view[b, x, y] 333 | # Read the previous values 334 | count = in_view[sn, b, 0] 335 | mean = in_view[sn, b, 1] 336 | M2 = in_view[sn, b, 2] 337 | 338 | # Update the stored values 339 | count += 1 340 | delta = new_val - mean 341 | mean += delta / count 342 | delta2 = new_val - mean 343 | M2 += delta * delta2 344 | 345 | # Update the internal list 346 | in_view[sn, b, 0] = count 347 | in_view[sn, b, 1] = mean 348 | in_view[sn, b, 2] = M2 349 | 350 | # Determine the external values within the window 351 | for w in range(4): 352 | i = window[w] 353 | # Determine the external values in the x-axis 354 | # Check for edge conditions 355 | if (x + i < 0) or (x + i >= x_dim): 356 | continue 357 | if label_image_view[x + i, y] != sn: 358 | for b in range(num_bands): 359 | new_val = intensity_image_view[b, x + i, y] 360 | # Read the previous values 361 | count = ex_view[sn, b, 0] 362 | mean = ex_view[sn, b, 1] 363 | M2 = ex_view[sn, b, 2] 364 | 365 | # Update the stored values 366 | count += 1 367 | delta = new_val - mean 368 | mean += delta / count 369 | delta2 = new_val - mean 370 | M2 += delta * delta2 371 | 372 | # Update the internal list 373 | ex_view[sn, b, 0] = count 374 | ex_view[sn, b, 1] = mean 375 | ex_view[sn, b, 2] = M2 376 | 377 | # Determine the external values in the y-axis 378 | # Check for edge conditions 379 | if (y + i < 0) or (y + i >= y_dim): 380 | continue 381 | if label_image_view[x, y + i] != sn: 382 | for b in range(num_bands): 383 | new_val = intensity_image_view[b, x, y + i] 384 | # Read the previous values 385 | count = ex_view[sn, b, 0] 386 | mean = ex_view[sn, b, 1] 387 | M2 = ex_view[sn, b, 2] 388 | 389 | # Update the stored values 390 | count += 1 391 | delta = new_val - 
mean 392 | mean += delta / count 393 | delta2 = new_val - mean 394 | M2 += delta * delta2 395 | 396 | # Update the internal list 397 | ex_view[sn, b, 0] = count 398 | ex_view[sn, b, 1] = mean 399 | ex_view[sn, b, 2] = M2 400 | 401 | return internal, external 402 | 403 | 404 | def pixel_sort_extended(const unsigned char[:,:,:] intensity_image_view, 405 | const unsigned int[:,:] label_image_view, 406 | unsigned int segment_id, int x_dim, int y_dim, int num_ws, int num_bands): 407 | ''' 408 | Given an intensity image and label image of the same dimension, sort 409 | pixels into a list of internal and external intensity pixels for every 410 | label in the label image. 411 | Returns: 412 | Internal: Array of length (number of labels), each element is a list 413 | of intensity values for that label number. 414 | External: Array of length (number of labels), each element is a list 415 | of intensity values that are adjacent to that label number. 416 | ''' 417 | cdef int x, y, i, w, b 418 | cdef int h_count 419 | cdef unsigned int sn 420 | cdef unsigned char new_val 421 | cdef float count, mean, M2 422 | cdef float delta, delta2 423 | cdef char window[4] 424 | 425 | # Output statistical variables. 
426 | internal = np.zeros((num_ws, num_bands, 3), dtype=c_float) 427 | cdef float[:, :, :] in_view = internal 428 | external = np.zeros((num_ws, num_bands, 3), dtype=c_float) 429 | cdef float[:, :, :] ex_view = external 430 | 431 | # Output histogram of each segment 432 | internal_ext = np.zeros((num_ws, 256), dtype=c_int) 433 | cdef int[:, :] in_ext_view = internal_ext 434 | external_ext = np.zeros((num_ws, 256), dtype=c_int) 435 | cdef int[:, :] ex_ext_view = external_ext 436 | 437 | # Moving window that defines the neighboring region for each pixel 438 | window = [-4, -3, 3, 4] 439 | 440 | for y in range(y_dim): 441 | for x in range(x_dim): 442 | # Ignore pixels whose value is 0 (no data) 443 | if intensity_image_view[0, x, y] == 0: 444 | continue 445 | 446 | # Set the current segment number 447 | sn = label_image_view[x, y] 448 | 449 | # If a segment_id was given 450 | # Select only the ws with the correct label. 451 | # set sn to zero to index properly 452 | if segment_id != 0: 453 | if segment_id == sn: 454 | sn = 0 455 | else: 456 | continue 457 | 458 | # Assign the internal pixel 459 | for b in range(num_bands): 460 | # Find the new pixel 461 | new_val = intensity_image_view[b, x, y] 462 | # Read the previous values 463 | count = in_view[sn, b, 0] 464 | mean = in_view[sn, b, 1] 465 | M2 = in_view[sn, b, 2] 466 | 467 | # Update the stored values 468 | count += 1 469 | delta = new_val - mean 470 | mean += delta / count 471 | delta2 = new_val - mean 472 | M2 += delta * delta2 473 | 474 | # Update the internal list 475 | in_view[sn, b, 0] = count 476 | in_view[sn, b, 1] = mean 477 | in_view[sn, b, 2] = M2 478 | 479 | # Increment this pixel value in the b0 histogram 480 | if b == 1: 481 | h_count = in_ext_view[sn, new_val] 482 | h_count += 1 483 | in_ext_view[sn, new_val] = h_count 484 | 485 | # Determine the external values within the window 486 | for w in range(4): 487 | i = window[w] 488 | # Determine the external values in the x-axis 489 | # Check for edge 
conditions 490 | if (x + i < 0) or (x + i >= x_dim): 491 | continue 492 | if label_image_view[x + i, y] != sn: 493 | for b in range(num_bands): 494 | new_val = intensity_image_view[b, x + i, y] 495 | # Read the previous values 496 | count = ex_view[sn, b, 0] 497 | mean = ex_view[sn, b, 1] 498 | M2 = ex_view[sn, b, 2] 499 | 500 | # Update the stored values 501 | count += 1 502 | delta = new_val - mean 503 | mean += delta / count 504 | delta2 = new_val - mean 505 | M2 += delta * delta2 506 | 507 | # Update the internal list 508 | ex_view[sn, b, 0] = count 509 | ex_view[sn, b, 1] = mean 510 | ex_view[sn, b, 2] = M2 511 | 512 | # Increment this pixel value in the b0 histogram 513 | if b == 1: 514 | h_count = ex_ext_view[sn, new_val] 515 | h_count += 1 516 | ex_ext_view[sn, new_val] = h_count 517 | 518 | # Determine the external values in the y-axis 519 | # Check for edge conditions 520 | if (y + i < 0) or (y + i >= y_dim): 521 | continue 522 | if label_image_view[x, y + i] != sn: 523 | for b in range(num_bands): 524 | new_val = intensity_image_view[b, x, y + i] 525 | # Read the previous values 526 | count = ex_view[sn, b, 0] 527 | mean = ex_view[sn, b, 1] 528 | M2 = ex_view[sn, b, 2] 529 | 530 | # Update the stored values 531 | count += 1 532 | delta = new_val - mean 533 | mean += delta / count 534 | delta2 = new_val - mean 535 | M2 += delta * delta2 536 | 537 | # Update the internal list 538 | ex_view[sn, b, 0] = count 539 | ex_view[sn, b, 1] = mean 540 | ex_view[sn, b, 2] = M2 541 | 542 | # Increment this pixel value in the b0 histogram 543 | if b == 1: 544 | h_count = ex_ext_view[sn, new_val] 545 | h_count += 1 546 | ex_ext_view[sn, new_val] = h_count 547 | 548 | return internal, external, internal_ext, external_ext 549 | 550 | 551 | cdef int last_index(int[:] lst): 552 | cdef int i 553 | for i in range(255,-1,-1): 554 | if lst[i] != 0: 555 | return i 556 | 557 | return 0 558 | 559 | # From wikipedia: Welfords algorithm 560 | # for a new value newValue, compute the 
new count, new mean, the new M2. 561 | # mean accumulates the mean of the entire dataset 562 | # M2 aggregates the squared distance from the mean 563 | # count aggregates the number of samples seen so far 564 | # cdef float update(float count, float mean, float M2, char newValue): 565 | # cdef float delta, delta2 566 | # count += 1 567 | # delta = newValue - mean 568 | # mean += delta / count 569 | # delta2 = newValue - mean 570 | # M2 += delta * delta2 571 | # 572 | # return (count, mean, M2) 573 | 574 | # retrieve the mean, variance and sample variance from an aggregate 575 | # def finalize(float count, float M2): 576 | # cdef float variance sample_variance 577 | # if count < 2: 578 | # return 1 579 | # 580 | # variance = M2 / count 581 | # sampleVariance = M2 / (count - 1) 582 | # 583 | # return variance, sampleVariance -------------------------------------------------------------------------------- /lib/create_clsf_raster.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ctypes import * 3 | import skimage.morphology as morph 4 | 5 | 6 | def create_clsf_raster(int[:] prediction, 7 | unsigned char[:,:,:] intensity_image_view, 8 | unsigned int[:,:] label_image_view): 9 | ''' 10 | Transfer classified results from a list of segment:classification pairs 11 | to a raster where pixel values are the classification result. 12 | ''' 13 | cdef int num_ws 14 | cdef int y, x 15 | cdef int x_dim, y_dim 16 | cdef int band_list[3] 17 | 18 | # Create a blank image that we will assign values based on the prediction for each 19 | # watershed. 20 | num_bands, x_dim, y_dim = np.shape(intensity_image_view) 21 | clsf_block = np.empty((x_dim,y_dim), dtype=c_byte) 22 | cdef char [:, :] clsf_block_view = clsf_block 23 | 24 | # Watershed indexes start at 0, so we have to add 1 to get the number. 
25 | num_ws = np.amax(label_image_view) + 1 26 | 27 | if num_bands == 1: 28 | band_list = [0,0,0] 29 | else: 30 | band_list = [0,1,2] 31 | # Check to see if the whole block is one segment 32 | if num_ws >= 2: 33 | # Assign all segments to their predicted classification 34 | for y in range(y_dim): 35 | for x in range(x_dim): 36 | # Setting the empty pixels (at least 3 bands have values of 0) to 0 37 | if ((intensity_image_view[band_list[0], x, y] == 0) 38 | & (intensity_image_view[band_list[1], x, y] == 0) 39 | & (intensity_image_view[band_list[2], x, y] == 0)): 40 | clsf_block_view[x,y] = 0 41 | else: 42 | clsf_block_view[x,y] = prediction[label_image_view[x,y]] 43 | else: 44 | # Assign all segments to their predicted classification 45 | for y in range(y_dim): 46 | for x in range(x_dim): 47 | # Set the empty pixels (at least 3 bands have values of 0) to 0 48 | if ((intensity_image_view[band_list[0], x, y] == 0) 49 | & (intensity_image_view[band_list[1], x, y] == 0) 50 | & (intensity_image_view[band_list[2], x, y] == 0)): 51 | clsf_block_view[x,y] = 0 52 | else: 53 | clsf_block_view[x,y] = prediction[0] 54 | 55 | clsf_block = np.copy(clsf_block_view) 56 | 57 | return clsf_block 58 | 59 | 60 | def filter_small_segments(clsf_block): 61 | ''' 62 | Remove small segments from the classified image. 63 | All regions smaller than the defined structuring element will be removed 64 | so long as the surrounding classification is a single category. 65 | ''' 66 | # Structuring element. 67 | strel = morph.disk(2) 68 | 69 | # Sequentially perform both an opening and closing operation to 70 | # remove both 'dark' and 'light' speckle. 
71 | clsf_block_o = morph.opening(clsf_block,strel) 72 | clsf_block_oc = morph.closing(clsf_block_o,strel) 73 | 74 | return clsf_block_oc -------------------------------------------------------------------------------- /lib/debug_tools.py: -------------------------------------------------------------------------------- 1 | from skimage import segmentation, exposure 2 | import matplotlib.pyplot as plt 3 | import matplotlib.colors as colors 4 | import numpy as np 5 | from sklearn.ensemble import RandomForestClassifier 6 | from sklearn import metrics 7 | 8 | from lib import utils 9 | 10 | 11 | def display_image(raw,watershed,classified,type): 12 | 13 | # Save a color 14 | empty_color = [.1,.1,.1] #Almost black 15 | snow_color = [.9,.9,.9] #Almost white 16 | pond_color = [.31,.431,.647] #Blue 17 | gray_color = [.65,.65,.65] #Gray 18 | water_color = [0.,0.,0.] #Black 19 | shadow_color = [.100, .545, .0]#Orange 20 | 21 | custom_colormap = [empty_color,snow_color,gray_color,pond_color,water_color,shadow_color] 22 | custom_colormap = colors.ListedColormap(custom_colormap) 23 | 24 | #Making sure there is atleast one of every pixel so the colors map properly (only changes 25 | # display image, not saved data) 26 | classified[0][0] = 0 27 | classified[1][0] = 1 28 | classified[2][0] = 2 29 | classified[3][0] = 3 30 | classified[4][0] = 4 31 | classified[5][0] = 5 32 | 33 | # Figure that show 3 images: raw, segmented, and classified 34 | if type == 1: 35 | # Creating the watershed display image with borders highlighted 36 | ws_bound = segmentation.find_boundaries(watershed) 37 | ws_display = utils.create_composite([raw,raw,raw]) 38 | ws_display[:,:,0][ws_bound] = 255 39 | ws_display[:,:,1][ws_bound] = 255 40 | ws_display[:,:,2][ws_bound] = 22 41 | 42 | fig, axes = plt.subplots(1,3,subplot_kw={'xticks':[], 'yticks':[]}) 43 | fig.subplots_adjust(left=0.05,right=0.99,bottom=0.05,top=0.90,wspace=0.02,hspace=0.2) 44 | 45 | tnrfont = {'fontname':'Times New Roman'} 46 | 47 | 
axes[0].imshow(raw,cmap='gray',interpolation='None') 48 | axes[0].set_title("Raw Image", **tnrfont) 49 | axes[1].imshow(ws_display,interpolation='None') 50 | axes[1].set_title("Image Segments", **tnrfont) 51 | axes[2].imshow(classified,cmap=custom_colormap,interpolation='None') 52 | axes[2].set_title("Classification Output", **tnrfont) 53 | 54 | # Figure that shows 2 images: raw and classified. 55 | if type == 2: 56 | fig, axes = plt.subplots(1,2,subplot_kw={'xticks':[], 'yticks':[]}) 57 | fig.subplots_adjust(hspace=0.3,wspace=0.05) 58 | axes[0].imshow(raw,interpolation='None') 59 | axes[0].set_title("Raw Image") 60 | axes[1].imshow(classified,cmap=custom_colormap,interpolation='None') 61 | axes[1].set_title("Classification Output") 62 | 63 | plt.show() 64 | 65 | 66 | # Plots a watershed image on top of and beside the original image 67 | ## Used for debugging 68 | def display_watershed(original_data, watershed_data, block=5): 69 | 70 | # block = 5 71 | watershed = watershed_data[block] 72 | original_1 = original_data[6][block] 73 | original_2 = original_data[4][block] 74 | original_3 = original_data[1][block] 75 | 76 | # randcolor = colors.ListedColormap(np.random.rand(256,3)) 77 | ws_bound = segmentation.find_boundaries(watershed) 78 | ws_display = utils.create_composite([original_1,original_2,original_3]) 79 | ws_display[:,:,0][ws_bound] = 240 80 | ws_display[:,:,1][ws_bound] = 80 81 | ws_display[:,:,2][ws_bound] = 80 82 | 83 | display_im = utils.create_composite([original_1,original_2,original_3]) 84 | 85 | fig, axes = plt.subplots(1,2,subplot_kw={'xticks':[], 'yticks':[]}) 86 | fig.subplots_adjust(hspace=0.3,wspace=0.05) 87 | 88 | # axes[1].imshow(self.sobel_image,interpolation='none',cmap='gray') 89 | axes[0].imshow(display_im,interpolation='none') 90 | axes[1].imshow(ws_display,interpolation='none') 91 | plt.show() 92 | 93 | 94 | def display_histogram(image_band): 95 | ''' 96 | Displays a histogram of the given band's data. 97 | Ignores zero values. 
98 | ''' 99 | hist, bin_centers = exposure.histogram(image_band[image_band>0],nbins=1000) 100 | 101 | plt.figure(1) 102 | plt.bar(bin_centers, hist) 103 | # plt.xlim((0,np.max(image_band))) 104 | # plt.ylim((0,100000)) 105 | plt.xlabel("Pixel Intensity") 106 | plt.ylabel("Frequency") 107 | plt.show() 108 | 109 | 110 | # Method to assess the training set and classification tree used for this classification 111 | def test_training(label_vector, training_feature_matrix): 112 | 113 | print("Size of training set: %i" %len(label_vector)) 114 | print(np.shape(training_feature_matrix)) 115 | 116 | # Add a random number to the training data as a reference point 117 | # Anything less important than a random number is obviously useless 118 | tfm_new = [] 119 | for i in range(len(training_feature_matrix)): 120 | tf = training_feature_matrix[i] 121 | tf.append(np.random.rand(1)[0]) 122 | tfm_new.append(tf) 123 | 124 | training_feature_matrix = tfm_new 125 | print(np.shape(training_feature_matrix)) 126 | 127 | rfc = RandomForestClassifier(n_estimators=100,oob_score=True) 128 | rfc.fit(training_feature_matrix, label_vector) 129 | print("OOB Score: %f" %rfc.oob_score_) 130 | 131 | training_feature_matrix = np.array(training_feature_matrix) 132 | importances = rfc.feature_importances_ 133 | std = np.std([tree.feature_importances_ for tree in rfc.estimators_], 134 | axis=0) 135 | 136 | feature_names = ['b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'std7', 'b1/b3', 'b2/b7', 'b4/b7', 137 | 'ex.b4', 'ex.b8', r'$\frac{b1-b7}{b1+b7}$', r'$\frac{b3-b5}{b3+b5}$', 138 | 'wb.b1', 'wb.b2', 'wb.b3', 'wb.b4', 'wb.b5', 'wb.b6', 'wb.b7', 139 | 'bp.b1', 'bp.b2', 'bp.b3', 'bp.b4', 'bp.b5', 'bp.b6', 'bp.b7', 'bp.b8', 'random'] 140 | 141 | print(len(feature_names)) 142 | # feature_names = range(len(training_feature_matrix)) 143 | 144 | indices = np.argsort(importances)[::-1] 145 | 146 | # feature_names = ['Mean Intensity','Standard Deviation','Size','Entropy','Neighbor Mean Intensity' 147 | # 
'Neighbor Standard Deviation','Neighbor Maximum Intensity','Neighbor Entropy','Date'] 148 | 149 | # Print the feature ranking 150 | print("Feature ranking:") 151 | 152 | feature_names_sorted = [] 153 | for f in range(training_feature_matrix.shape[1]): 154 | print("%d. feature %s (%f)" % (f+1, feature_names[indices[f]], importances[indices[f]])) 155 | feature_names_sorted.append(feature_names[indices[f]]) 156 | 157 | # Plot the feature importances of the forest 158 | plt.figure() 159 | plt.title("Feature importances") 160 | plt.bar(range(training_feature_matrix.shape[1]), importances[indices], 161 | color=[.161,.333,.608], yerr=std[indices], align="center", 162 | error_kw=dict(ecolor=[.922,.643,.173], lw=2, capsize=3, capthick=2)) 163 | # plt.xticks(range(training_feature_matrix.shape[1]), feature_names_sorted)#, rotation='45') 164 | plt.xticks(range(training_feature_matrix.shape[1]), feature_names_sorted, rotation='45') 165 | plt.xlim([-1, training_feature_matrix.shape[1]]) 166 | plt.show() 167 | -------------------------------------------------------------------------------- /lib/rescale_intensity.pyx: -------------------------------------------------------------------------------- 1 | # cython: cdivision=True 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | cimport cython 5 | import numpy as np 6 | from ctypes import * 7 | 8 | 9 | def white_balance(src_ds, reference, double imax): 10 | ''' 11 | src_ds: input image to balance (ndim must == 3) 12 | reference: array of length equal to src_ds dim 0, image will be scaled by this reference 13 | ''' 14 | cdef int x, y, b 15 | cdef int x_dim, y_dim, num_bands 16 | cdef float val 17 | cdef unsigned short new_val_short 18 | cdef unsigned char new_val 19 | 20 | cdef unsigned char [:, :, :] src_view = src_ds 21 | dst_ds = np.empty_like(src_ds) 22 | cdef unsigned char [:, :, :] dst_view = dst_ds 23 | cdef double [:] ref_view = reference 24 | 25 | num_bands, x_dim, y_dim = np.shape(src_ds) 26 | 27 | # Check 
that the user provided the correct number of reference points 28 | if np.shape(reference)[0] != num_bands: 29 | return src_ds 30 | 31 | for y in range(y_dim): 32 | for x in range(x_dim): 33 | for b in range(num_bands): 34 | val = src_view[b, x, y] 35 | if val == 0: 36 | new_val = 0 37 | else: 38 | new_val_short = int((imax / ref_view[b]) * val) 39 | if new_val_short < 1: 40 | new_val = 1 41 | elif new_val_short > 255: 42 | new_val = 255 43 | else: 44 | new_val = new_val_short 45 | dst_view[b, x, y] = new_val 46 | 47 | return np.copy(dst_view) 48 | 49 | 50 | def rescale_intensity(src_ds, int imin, int imax, int omin, int omax): 51 | 52 | # Check raster structure for panchromatic images: 53 | if src_ds.ndim == 3: 54 | return _rescale_intensity_3d(src_ds, imin, imax, omin, omax) 55 | else: 56 | return _rescale_intensity_2d(src_ds, imin, imax, omin, omax) 57 | 58 | 59 | def _rescale_intensity_3d(src_ds, int imin, int imax, int omin, int omax): 60 | ''' 61 | Rescales the input image intensity values. 
62 | While omin and omax are arguments, this function currently only converts 63 | to uint8 64 | ''' 65 | cdef int x, y, b 66 | cdef int x_dim, y_dim, num_bands 67 | cdef float val 68 | cdef unsigned char new_val 69 | 70 | cdef unsigned short [:, :, :] src_view = src_ds 71 | dst_ds = np.empty_like(src_ds, dtype=c_uint8) 72 | cdef unsigned char [:, :, :] dst_view = dst_ds 73 | 74 | num_bands, x_dim, y_dim = np.shape(src_ds) 75 | 76 | for y in range(y_dim): 77 | for x in range(x_dim): 78 | for b in range(num_bands): 79 | val = src_view[b, x, y] 80 | if val == 0: 81 | new_val = 0 82 | else: 83 | if val < imin: 84 | val = imin 85 | elif val > imax: 86 | val = imax 87 | new_val = int(((val - imin) / (imax - imin)) * (omax - omin) + omin) 88 | dst_view[b, x, y] = new_val 89 | 90 | return np.copy(dst_view) 91 | 92 | 93 | def _rescale_intensity_2d(src_ds, int imin, int imax, int omin, int omax): 94 | ''' 95 | Rescales the input image intensity values 96 | While omin and omax are arguments, this function currently only converts 97 | to uint8 98 | ''' 99 | cdef int x, y, b 100 | cdef int x_dim, y_dim, num_bands 101 | cdef float val 102 | cdef unsigned char new_val 103 | 104 | cdef unsigned short [:, :] src_view = src_ds 105 | dst_ds = np.empty_like(src_ds, dtype=c_uint8) 106 | cdef unsigned char [:, :] dst_view = dst_ds 107 | 108 | x_dim, y_dim = np.shape(src_ds) 109 | num_bands = 1 110 | 111 | for y in range(y_dim): 112 | for x in range(x_dim): 113 | val = src_view[x, y] 114 | if val == 0: 115 | new_val = 0 116 | else: 117 | if val < imin: 118 | val = imin 119 | elif val > imax: 120 | val = imax 121 | new_val = int(((val - imin) / (imax - imin)) * (omax - omin) + omin) 122 | dst_view[x, y] = new_val 123 | 124 | return np.copy(dst_view) 125 | 126 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | import csv 4 | 
# (continuation of /lib/utils.py -- its first imports, `h5py`, `os` and `csv`,
#  appear just before this section and are used below)
import math
import itertools
import sqlite3
import sys
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.image as mimg
from ctypes import *

# Lower-case file extensions that create_task_list accepts as input images.
valid_extensions = ['.tif', '.tiff', '.jpg']


class Task:
    """
    A single image to process and where its output should be written.

    Attributes:
        task_id:  image file name (relative to task_dir)
        task_dir: directory containing the image
        dst_dir:  output directory (None until set_dst_dir is called)
        complete: True once mark_complete() has been called
    """

    def __init__(self, name, directory):
        self.task_id = name
        self.task_dir = directory
        self.dst_dir = None
        self.complete = False

    def __repr__(self):
        return "{}({!r}, {!r})".format(self.__class__.__name__,
                                       self.task_id, self.task_dir)

    def get_id(self):
        return self.task_id

    def change_id(self, name):
        self.task_id = name

    def set_src_dir(self, directory):
        self.task_dir = directory

    def get_src_dir(self):
        return self.task_dir

    def set_dst_dir(self, dst_dir):
        self.dst_dir = dst_dir

    def get_dst_dir(self):
        return self.dst_dir

    def mark_complete(self):
        self.complete = True

    def is_complete(self):
        return self.complete


def _resolve_dst_dir(src_dir, dst_dir):
    """Return dst_dir, or <src_dir>/classified when dst_dir is "default"."""
    if dst_dir == "default":
        return os.path.join(src_dir, "classified")
    return dst_dir


def create_task_list(src_dir, dst_dir):
    '''
    Build the list of Task objects to process.

    INPUT:
        src_dir: path to a single image file, or to a directory of images
        dst_dir: output directory, or "default" for <image directory>/classified
    RETURNS:
        For a file path: a one-element list (the extension is not checked).
        For a directory: one Task per non-hidden entry whose lower-cased
        extension is in valid_extensions, in os.listdir order.
    '''
    # If the input is a file, return that file as the only task
    if os.path.isfile(src_dir):
        src_dir, src_file = os.path.split(src_dir)
        task = Task(src_file, src_dir)
        task.set_dst_dir(_resolve_dst_dir(src_dir, dst_dir))
        return [task]

    task_list = []
    # Loop through contents of the given directory
    for file_name in os.listdir(src_dir):
        # Skip hidden files
        if file_name.startswith('.'):
            continue

        # Check that the file is .tif, .tiff or .jpg format
        ext = os.path.splitext(file_name)[1].lower()
        if ext not in valid_extensions:
            continue

        ## Create the task object for this image
        task = Task(file_name, src_dir)
        task.set_dst_dir(_resolve_dst_dir(src_dir, dst_dir))

        # NOTE: an earlier version marked a task complete (and skipped it) when
        # <image_name>_classified.tif already existed in the output directory;
        # that check is disabled, so every matching image is returned.
        task_list.append(task)

    return task_list


#### Load Training Dataset (TDS) (Label Vector and Feature Matrix)
def load_tds(file_name, list_name, image_type):
    '''
    INPUT:
        file_name: path of the .h5 training data file
        list_name: name of the label vector contained within file_name
        image_type: image type string (e.g. 'srgb', 'wv02_ms', 'pan'); also used
            as the fallback label dataset name
    RETURNS:
        tds = [label_vector, training_feature_matrix] as python lists.
        Samples labeled 0 ("unknown") and 6 ("mixed") are removed; samples
        labeled 5 are removed too unless list_name is 'spring' or image_type
        is 'wv02_ms'.
    '''
    if image_type == 'srgb' and list_name != 'srgb':
        list_prefix = list_name + "_"
        label_name = "{}labels".format(list_prefix)
    else:
        list_prefix = ""
        label_name = list_name

    ## Load the training data
    with h5py.File(file_name, 'r') as training_file:
        # Try loading the dataset with the provided name. If that doesn't work,
        # try loading with the default name
        if label_name in training_file.keys():
            label_vector = training_file[label_name][:]
            training_feature_matrix = training_file['{}feature_matrix'.format(list_prefix)][:]
        else:
            label_vector = training_file[image_type][:]
            training_feature_matrix = training_file['feature_matrix'][:]

    ## Convert inputs to python lists
    label_vector = label_vector.tolist()
    # Remove feature lists that don't have an associated label
    training_feature_matrix = training_feature_matrix.tolist()[:len(label_vector)]

    # Labels to discard: 0 = "unknown", 6 = "mixed"; 5 is only kept for the
    # 'spring' list or wv02_ms imagery. (A list, not a set, so comparison is
    # by == exactly as before.)
    excluded = [0, 6]
    if list_name != 'spring' and image_type != 'wv02_ms':
        excluded.append(5)

    # Filter both lists in one pass, preserving order (the previous
    # index()/pop() loops were quadratic in the training set size).
    keep = [i for i, label in enumerate(label_vector) if label not in excluded]
    label_vector = [label_vector[i] for i in keep]
    training_feature_matrix = [training_feature_matrix[i] for i in keep]

    # Combine the label vector and training feature matrix into one variable.
    tds = [label_vector, training_feature_matrix]

    return tds


def _grow_block(dim, block_size, step):
    """Grow block_size by step until dim % block_size > block_size / 2; cap at dim."""
    # Stop as soon as the block covers the whole dimension: the result is
    # clamped to dim in that case anyway, and without this guard the loop never
    # terminates when step >= 2 * dim (e.g. dim=100, block_size=50, step=256).
    while block_size < dim and (dim % block_size) <= (block_size / 2):
        block_size += step
    return min(block_size, dim)


def find_blocksize(x_dim, y_dim, desired_size, step=256):
    """
    Choose a block read size for an image of the given dimensions.

    Starting at desired_size, each block dimension grows in increments of
    `step` until the leftover strip at the image edge (dim % block) is more
    than half a block wide, so the image does not end in a sliver-sized block.
    Each block dimension is capped at the image dimension.

    INPUT:
        x_dim, y_dim: image dimensions in pixels
        desired_size: starting block size
        step: growth increment (default 256, the previously hard-coded value)
    RETURNS:
        (block_size_x, block_size_y); the full image dimensions when either
        dimension is <= desired_size.
    """
    if x_dim <= desired_size or y_dim <= desired_size:
        return x_dim, y_dim

    return (_grow_block(x_dim, desired_size, step),
            _grow_block(y_dim, desired_size, step))


def _open_csv(path, mode):
    """Open `path` for csv.writer in mode 'w' or 'a' on Python 2 and 3."""
    if sys.version_info[0] < 3:
        # Python 2's csv module expects binary-mode files
        return open(path, mode + 'b')
    # Python 3's csv module needs text mode with newline=''; a binary-mode
    # file makes writerow raise TypeError
    return open(path, mode, newline='')


#### Save classification results
def write_to_csv(csv_name, path, image_name, pixel_counts):
    '''
    INPUT:
        csv_name: output file name (any extension is replaced with .csv)
        path: location where the output will write
        image_name: name of the image that was classified
        pixel_counts: number of pixels in each classification category

    Saves a csv with the input information. Writes a header row when the csv
    is created and appends to it if it already exists. Failures are reported
    on stdout rather than raised (best effort).

    NOTES:
        Only works with 5 classification categories:
        [white ice, gray ice, melt ponds, open water, shadow]
    '''

    csv_name = os.path.splitext(csv_name)[0] + ".csv"

    num_pixels = sum(pixel_counts, 1)  # start at 1 to prevent division by 0
    # float() avoids integer division when the counts are ints (Python 2)
    percentages = [count / float(num_pixels) for count in pixel_counts]

    try:
        output_csv = os.path.join(path, csv_name)
        row = ([image_name]
               + [pixel_counts[i] for i in range(5)]
               + [percentages[i] for i in range(5)])
        is_new_file = not os.path.isfile(output_csv)
        with _open_csv(output_csv, "w" if is_new_file else "a") as csvfile:
            writer = csv.writer(csvfile)
            if is_new_file:
                writer.writerow(["Source", "White Ice", "Gray Ice", "Melt Ponds", "Open Water", "Shadow",
                                 "Prcnt White Ice", "Prcnt Gray Ice", "Prcnt Melt Ponds", "Prcnt Open Water", "Prcnt Shadow"])
            writer.writerow(row)
    except Exception as err:
        # Best effort: report the failure (including its cause) and carry on
        print("error saving csv")
        print(err)
        print(pixel_counts)


def write_to_database(db_name, path, image_id, part, pixel_counts):
    '''
    INPUT:
        db_name: filename of the database
        path: location where the database is stored
        image_id: value of the NAME column of the row to update
        part: value stored in the PART column
        pixel_counts: number of pixels in each classification category
            [snow, gray, melt, water, shadow]

    Writes the total area and the snow/gray/melt/water fractions to the
    DigitalGlobe table row whose NAME equals image_id.
    NOTES:
        For now, this overwrites existing data. The shadow fraction is not stored.
    FUTURE:
        Develop method that checks for existing data and appends the current
        data to that, record which parts contributed to the total.
    '''
    # Convert pixel_counts into percentages and total area
    area = sum(pixel_counts, 1)  # start at 1 to prevent division by 0
    percentages = [count / float(area) for count in pixel_counts]

    conn = sqlite3.connect(os.path.join(path, db_name))
    try:
        # Bound parameters quote image_id and part correctly; the former
        # str.format query left PART unquoted and broke on quotes in NAME.
        conn.execute("UPDATE DigitalGlobe "
                     "SET AREA = ?, SNOW = ?, GRAY = ?, MP = ?, OW = ?, PART = ? "
                     "WHERE NAME = ?",
                     (int(area), percentages[0], percentages[1],
                      percentages[2], percentages[3], part, image_id))
        conn.commit()
    finally:
        # Always release the database handle, even if the update fails
        conn.close()


#### Recombine classified image splits
def stitch(image_files, save_path=None):
    '''
    INPUT:
        image_files: list of .h5 image splits (each with a 'classified'
            dataset) for recombination. The list is sorted in place.
        save_path: currently unused; kept for interface compatibility
    RETURN:
        full_classification: full image stitched back together, or None if
        the number of splits is not a non-zero perfect square

    NOTES:
        Currently only implemented to recombine classified images, but the
        method could work with any image data.
        There are two levels of recombination. Recompiling the subimages (see
        method below) and recompiling the splits (this method)
    '''
    # The splits must form a square grid; check with integer arithmetic so
    # the result does not depend on floating-point sqrt precision.
    num_images = len(image_files)
    box_side = int(round(math.sqrt(num_images)))
    if num_images == 0 or box_side * box_side != num_images:
        print("Incomplete set of images!")
        return None

    # Sorted file names give the left-to-right, top-to-bottom order
    # expected by compile_subimages
    image_files.sort()

    ## Read the classified data from each split
    classified_list = []
    for image in image_files:
        with h5py.File(image, 'r') as inputfile:
            classified_list.append(inputfile['classified'][:])

    # Stitch the classified image back together
    full_classification = compile_subimages(classified_list, box_side, box_side)

    return full_classification


def compile_subimages(subimage_list, num_x_subimages, num_y_subimages, bands=1,
                      dtype='uint8'):
    '''
    Compiles the subimages (i.e. blocks) of a split into one raster
    INPUT:
        subimage_list: the list of subimages, in left to right top to bottom
            order; all are assumed to share the first subimage's shape
        num_x_subimages: number of subimages in the x dimension
        num_y_subimages: number of subimages in the y dimension
        bands: number of spectral bands of the input image
        dtype: dtype of the compiled raster (default 'uint8', as before)
    RETURNS:
        compiled_image: single [y, x] raster when bands == 1,
        otherwise [y, x, bands]
    '''
    y_size, x_size = np.shape(subimage_list[0])[:2]

    if bands != 1:
        shape = (num_y_subimages * y_size, num_x_subimages * x_size, bands)
    else:
        shape = (num_y_subimages * y_size, num_x_subimages * x_size)
    compiled_image = np.zeros(shape, dtype=dtype)

    counter = 0
    for y in range(num_y_subimages):
        for x in range(num_x_subimages):
            # Slicing only the first two axes also covers the band axis
            compiled_image[y*y_size:(y+1)*y_size,
                           x*x_size:(x+1)*x_size] = subimage_list[counter]
            counter += 1

    return compiled_image


#### Saves an image with custom colormap
def save_color(image, save_name, custom_colormap=False):
    '''
    INPUTS:
        image: The image you want to save (not modified)
        save_name: full name and filepath where you want the image to go
        custom_colormap: matplotlib colormap if you want to use your own
            defaults to nwright's class colormap

    Saves a .png of the input image with desired colormap
    '''

    if custom_colormap is False:
        # Colors for the output image
        empty_color = [.1, .1, .1]        # Almost black
        snow_color = [.9, .9, .9]         # Almost white
        pond_color = [.31, .431, .647]    # Blue
        gray_color = [.5, .5, .5]         # Gray
        water_color = [0., 0., 0.]        # Black
        cloud_color = [.27, .15, .50]     # Purple

        # Class values 0..6; values 3 and 4 are both drawn in the pond color
        custom_colormap = colors.ListedColormap(
            [empty_color, snow_color, gray_color, pond_color, pond_color,
             water_color, cloud_color])

    # Work on a copy: the marker pixels below used to be written into the
    # caller's array, altering the classification data.
    display = np.array(image, copy=True)
    # Put one pixel of each class value 0..6 in the first column so the colors
    # map properly. Only existing rows are marked, so short images no longer
    # raise IndexError.
    for value in range(min(7, len(display))):
        display[value][0] = value

    mimg.imsave(save_name, display, format='png', cmap=custom_colormap)


#### Count the number of pixels in each classification category of given image
def count_features(classified_image):
    '''
    Count the pixels of each class in a classified image array.

    RETURNS:
        (snow, gray_ice, melt_ponds, open_water, shadow) as floats, i.e. the
        number of pixels equal to 1, 2, 3, 4 and 5 respectively.
    '''
    return tuple(float(np.count_nonzero(classified_image == value))
                 for value in (1.0, 2.0, 3.0, 4.0, 5.0))


def get_image_paths(folder, keyword='.h5', strict=True):
    '''
    Code from http://chriskiehl.com/article/parallelism-in-one-line/
    Returns a generator of paths to the non-hidden files in the given folder
    whose name contains `keyword` (case-insensitive).
    Strict flag restricts keyword to the extension, non strict will find the
    keyword anywhere in the filename
    '''
    # Lower-case the keyword too, since it is compared to lower-cased names
    keyword = keyword.lower()
    if strict:
        def matches(name):
            return keyword in os.path.splitext(name)[1].lower()
    else:
        def matches(name):
            return keyword in name.lower()

    # os.listdir runs immediately (so a bad folder raises here); the filtering
    # is lazy.
    return (os.path.join(folder, f)
            for f in os.listdir(folder)
            if matches(f) and not f.startswith('.'))


# Remove hidden folders and files from the given list of strings (mac)
def remove_hidden(folder):
    '''Remove names starting with '.' from the list in place and return it.'''
    folder[:] = [name for name in folder if not name.startswith('.')]
    return folder


# Combines multiple bands (RBG) into one 3D array
# Adapted from: http://gis.stackexchange.com/questions/120951/merging-multiple-16-bit-image-bands-to-create-a-true-color-tiff
# Useful band combinations: http://c-agg.org/cm_vault/files/docs/WorldView_band_combs__2_.pdf
def create_composite(band_list, dtype=np.uint8):
    '''
    Stack equally shaped 2-D bands into a [rows, cols, len(band_list)] array
    of the given dtype.
    '''
    img_dim = np.shape(band_list[0])
    num_bands = len(band_list)
    img = np.zeros((img_dim[0], img_dim[1], num_bands), dtype=dtype)
    for i, band in enumerate(band_list):
        img[:, :, i] = band

    return img


# Plots a confusion matrix. Adapted from
# http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
#
def plot_confusion_matrix(cm, categories, ylabel, xlabel,
                          normalize=False,
                          title='',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization (each row divided by its sum) can be applied by setting
    `normalize=True`. Sets the global matplotlib font and calls plt.show().
    """
    font = {'family' : 'Times New Roman',
            'weight' : 'bold',
            'size'   : 12}

    matplotlib.rc('font', **font)

    # Normalize before drawing so the colour scale and the cell labels agree
    # (previously the image showed raw counts while the text showed rates).
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(categories))
    plt.xticks(tick_marks, categories, rotation=45)
    plt.yticks(tick_marks, categories)

    thresh = cm.max() / 4.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        label = '{:.2f}'.format(cm[i, j]) if normalize else cm[i, j]
        plt.text(j, i, label,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.show()

# --------------------------------------------------------------------------------
# /ossp_process.py:
# --------------------------------------------------------------------------------
# 1 | # OSSP Process
# 2 | # Usage: Fully processes all images in the given directory with the given training data.
3 | # Nicholas Wright 4 | 5 | import os 6 | import time 7 | import argparse 8 | import csv 9 | import numpy as np 10 | from multiprocessing import Process, RLock, Queue 11 | import preprocess as pp 12 | from segment import segment_image 13 | from classify import classify_image 14 | from lib import utils 15 | import gdal 16 | 17 | 18 | def main(): 19 | # Set Up Arguments 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("input_dir", 22 | help='''directory path containing date directories of 23 | images to be processed''') 24 | parser.add_argument("image_type", type=str, choices=["srgb", "wv02_ms", "pan"], 25 | help="image type: 'srgb', 'wv02_ms', 'pan'") 26 | parser.add_argument("training_dataset", 27 | help="training data file") 28 | parser.add_argument("--training_label", type=str, default=None, 29 | help="name of training classification list") 30 | parser.add_argument("-o", "--output_dir", type=str, default="default", 31 | help="directory to place output results.") 32 | parser.add_argument("-v", "--verbose", action="store_true", 33 | help="display text information and progress") 34 | parser.add_argument("-c", "--stretch", 35 | type=str, 36 | choices=["hist", "pansh", "none"], 37 | default='hist', 38 | help='''Apply image correction/stretch to input: \n 39 | hist: Histogram stretch \n 40 | pansh: Orthorectify / Pansharpen for MS WV images \n 41 | none: No correction''') 42 | parser.add_argument("--pgc_script", type=str, default=None, 43 | help="Path for the pansharpening script if needed") 44 | parser.add_argument("-t", "--threads", type=int, default=1, 45 | help="Number of subprocesses to start") 46 | 47 | # Parse Arguments 48 | args = parser.parse_args() 49 | 50 | # System filepath that contains the directories or files for batch processing 51 | user_input = args.input_dir 52 | if os.path.isdir(user_input): 53 | src_dir = user_input 54 | src_file = '' 55 | elif os.path.isfile(user_input): 56 | src_dir, src_file = os.path.split(user_input) 57 | 
else: 58 | raise IOError('Invalid input') 59 | # Image type, choices are 'srgb', 'pan', or 'wv02_ms' 60 | image_type = args.image_type 61 | # File with the training data 62 | tds_file = args.training_dataset 63 | # Default tds label is the image type 64 | if args.training_label is None: 65 | tds_label = image_type 66 | else: 67 | tds_label = args.training_label 68 | # Default output directory 69 | # (if not provided this gets set when the tasks are created) 70 | dst_dir = args.output_dir 71 | threads = args.threads 72 | verbose = args.verbose 73 | stretch = args.stretch 74 | 75 | 76 | # Use the given pansh script path, otherwise search for the correct folder 77 | # in the same directory as this script. 78 | if args.pgc_script: 79 | pansh_script_path = args.pgc_script 80 | else: 81 | current_path = os.path.dirname(os.path.realpath(__file__)) 82 | pansh_script_path = os.path.join(os.path.split(current_path)[0], 'imagery_utils') 83 | 84 | # For Ames OIB Processing: 85 | # White balance flag (To add as user option in future, presently only used on oib imagery) 86 | if image_type == 'srgb': 87 | assess_quality = True 88 | white_balance = True 89 | else: 90 | assess_quality = False 91 | white_balance = False 92 | # Set a default quality score until this value is calculated 93 | quality_score = 1. 94 | 95 | # Prepare a list of images to be processed based on the user input 96 | # list of task objects based on the files in the input directory. 97 | # Each task is an image to process, and has a subtask for each split 98 | # of that image. 
99 | task_list = utils.create_task_list(os.path.join(src_dir, src_file), dst_dir) 100 | 101 | for task in task_list: 102 | 103 | # ASP: Restrict processing to the frame range 104 | # try: 105 | # frameNum = getFrameNumberFromFilename(file) 106 | # except Exception, e: 107 | # continue 108 | # if (frameNum < args.min_frame) or (frameNum > args.max_frame): 109 | # continue 110 | 111 | # Skip this task if it is already marked as complete 112 | if task.is_complete(): 113 | continue 114 | 115 | # Make the output directory if it doesnt already exist 116 | if not os.path.isdir(task.get_dst_dir()): 117 | os.makedirs(task.get_dst_dir()) 118 | 119 | # Run Ortho/Pan scripts if necessary 120 | if stretch == 'pansh': 121 | if verbose: 122 | print("Orthorectifying and Pansharpening image...") 123 | 124 | full_image_name = os.path.join(task.get_src_dir(), task.get_id()) 125 | pansh_filepath = pp.run_pgc_pansharpen(pansh_script_path, 126 | full_image_name, 127 | task.get_dst_dir()) 128 | 129 | # Set the image name/dir to the pan output name/dir 130 | task.set_src_dir(task.get_dst_dir()) 131 | task.change_id(pansh_filepath) 132 | 133 | # Open the image dataset with gdal 134 | full_image_name = os.path.join(task.get_src_dir(), task.get_id()) 135 | if os.path.isfile(full_image_name): 136 | if verbose: 137 | print("Loading image {}...".format(task.get_id())) 138 | src_ds = gdal.Open(full_image_name, gdal.GA_ReadOnly) 139 | else: 140 | print("File not found: {}".format(full_image_name)) 141 | continue 142 | 143 | # Read metadata to get image date and keep only the metadata we need 144 | metadata = src_ds.GetMetadata() 145 | image_date = pp.parse_metadata(metadata, image_type) 146 | metadata = [image_type, image_date] 147 | 148 | # For processing icebridge imagery: 149 | if image_type == 'srgb': 150 | if image_date <= 150: 151 | tds_label = 'spring' 152 | white_balance = True 153 | else: 154 | tds_label = 'summer' 155 | 156 | # Load Training Data 157 | tds = utils.load_tds(tds_file, 
tds_label, image_type) 158 | # tds = utils.load_tds(tds_file, 'srgb', image_type) 159 | 160 | if verbose: 161 | print("Size of training set: {}".format(len(tds[1]))) 162 | 163 | # Set necessary parameters for reading image 1 block at a time 164 | x_dim = src_ds.RasterXSize 165 | y_dim = src_ds.RasterYSize 166 | desired_block_size = 6400 167 | 168 | src_dtype = gdal.GetDataTypeSize(src_ds.GetRasterBand(1).DataType) 169 | # Analyze input image histogram (if applying correction) 170 | if stretch == 'hist': 171 | stretch_params = pp.histogram_threshold(src_ds, src_dtype) 172 | else: # stretch == 'none': 173 | # WV Images are actually 11bit stored in 16bit files 174 | if src_dtype > 12: 175 | src_dtype = 11 176 | stretch_params = [1, 2**src_dtype - 1, 177 | [2 ** src_dtype - 1 for _ in range(src_ds.RasterCount)], 178 | [1 for _ in range(src_ds.RasterCount)]] 179 | 180 | # Create a blank output image dataset 181 | # Save the classified image output as a geotiff 182 | fileformat = "GTiff" 183 | image_name_noext = os.path.splitext(task.get_id())[0] 184 | dst_filename = os.path.join(task.get_dst_dir(), image_name_noext + '_classified.tif') 185 | driver = gdal.GetDriverByName(fileformat) 186 | dst_ds = driver.Create(dst_filename, xsize=x_dim, ysize=y_dim, 187 | bands=1, eType=gdal.GDT_Byte, options=["TILED=YES", "COMPRESS=LZW"]) 188 | 189 | # Transfer the metadata from input image 190 | # dst_ds.SetMetadata(src_ds.GetMetadata()) 191 | # Transfer the input projection and geotransform if they are different than the default 192 | if src_ds.GetGeoTransform() != (0, 1, 0, 0, 0, 1): 193 | dst_ds.SetGeoTransform(src_ds.GetGeoTransform()) # sets same geotransform as input 194 | if src_ds.GetProjection() != '': 195 | dst_ds.SetProjection(src_ds.GetProjection()) # sets same projection as input 196 | 197 | # Find the appropriate image block read size 198 | block_size_x, block_size_y = utils.find_blocksize(x_dim, y_dim, desired_block_size) 199 | if verbose: 200 | print("block size: 
[{},{}]".format(block_size_x, block_size_y)) 201 | 202 | # close the source dataset so that it can be loaded by each thread seperately 203 | src_ds = None 204 | lock = RLock() 205 | block_queue, qsize = construct_block_queue(block_size_x, block_size_y, x_dim, y_dim) 206 | dst_queue = Queue() 207 | 208 | # Display a progress bar 209 | if verbose: 210 | try: 211 | from tqdm import tqdm 212 | except ImportError: 213 | print("Install tqdm to display progress bar.") 214 | verbose = False 215 | else: 216 | pbar = tqdm(total=qsize, unit='block') 217 | 218 | # Set an empty value for the pixel counter 219 | pixel_counts = [0, 0, 0, 0, 0] 220 | 221 | NUMBER_OF_PROCESSES = threads 222 | block_procs = [Process(target=process_block_queue, 223 | args=(lock, block_queue, dst_queue, full_image_name, 224 | assess_quality, stretch_params, white_balance, tds, metadata)) 225 | for _ in range(NUMBER_OF_PROCESSES)] 226 | 227 | for proc in block_procs: 228 | # Add a stop command to the end of the queue for each of the 229 | # processes started. This will signal for the process to stop. 
230 | block_queue.put('STOP') 231 | # Start the process 232 | proc.start() 233 | 234 | # Collect data from processes as they complete tasks 235 | finished_threads = 0 236 | while finished_threads < NUMBER_OF_PROCESSES: 237 | 238 | if not dst_queue.empty(): 239 | val = dst_queue.get() 240 | if val is None: 241 | finished_threads += 1 242 | else: 243 | # Keep only the lowest quality score found 244 | quality_score_block = val[0] 245 | if quality_score_block < quality_score: 246 | quality_score = quality_score_block 247 | # Add the pixel counts to the master list 248 | pixel_counts_block = val[1] 249 | for i in range(len(pixel_counts)): 250 | pixel_counts[i] += pixel_counts_block[i] 251 | # Write image data to output dataset 252 | x = val[2] 253 | y = val[3] 254 | classified_block = val[4] 255 | dst_ds.GetRasterBand(1).WriteArray(classified_block, xoff=x, yoff=y) 256 | dst_ds.FlushCache() 257 | # Update the progress bar 258 | if verbose: pbar.update() 259 | # Give the other threads some time to finish their tasks. 
260 | else: 261 | time.sleep(10) 262 | 263 | # Update the progress bar 264 | if verbose: pbar.update() 265 | 266 | # Join all of the processes back together 267 | for proc in block_procs: 268 | proc.join() 269 | 270 | # Close dataset and write to disk 271 | dst_ds = None 272 | 273 | # Write extra data (total pixel counts and quality score to the database (or csv) 274 | output_csv = os.path.join(task.get_dst_dir(), image_name_noext + '_md.csv') 275 | with open(output_csv, "w") as csvfile: 276 | writer = csv.writer(csvfile) 277 | writer.writerow(["Quality Score", "White Ice", "Gray Ice", "Melt Ponds", "Open Water", "Shadow"]) 278 | writer.writerow([quality_score, pixel_counts[0], pixel_counts[1], pixel_counts[2], 279 | pixel_counts[3], pixel_counts[4]]) 280 | 281 | # Close the progress bar 282 | if verbose: 283 | pbar.close() 284 | print("Finished Processing.") 285 | 286 | 287 | def construct_block_queue(block_size_x, block_size_y, x_dim, y_dim): 288 | # Convert the block size into a list of the top (y) left (x) coordinate of each block 289 | # and iterate over both lists to process each block 290 | y_blocks = range(0, y_dim, block_size_y) 291 | x_blocks = range(0, x_dim, block_size_x) 292 | qsize = 0 293 | # Construct a queue of block coordinates 294 | block_queue = Queue() 295 | for y in y_blocks: 296 | for x in x_blocks: 297 | # Check that this block will lie within the image dimensions 298 | read_size_y = check_read_size(y, block_size_y, y_dim) 299 | read_size_x = check_read_size(x, block_size_x, x_dim) 300 | # Store variables needed to read each block from source dataset in queue 301 | block_queue.put((x, y, read_size_x, read_size_y)) 302 | qsize += 1 303 | 304 | return block_queue, qsize 305 | 306 | 307 | def process_block_queue(lock, block_queue, dst_queue, full_image_name, 308 | assess_quality, stretch_params, white_balance, tds, im_metadata): 309 | ''' 310 | Function run by each process. 
Will process blocks placed in the block_queue until the 'STOP' command is reached. 311 | ''' 312 | # Parse input arguments 313 | lower, upper, wb_reference, bp_reference = stretch_params 314 | wb_reference = np.array(wb_reference, dtype=np.float) 315 | bp_reference = np.array(bp_reference, dtype=np.float) 316 | image_type = im_metadata[0] 317 | 318 | for block_indices in iter(block_queue.get, 'STOP'): 319 | 320 | x, y, read_size_x, read_size_y = block_indices 321 | # Load block data with gdal (offset and block size) 322 | lock.acquire() 323 | src_ds = gdal.Open(full_image_name, gdal.GA_ReadOnly) 324 | image_data = src_ds.ReadAsArray(x, y, read_size_x, read_size_y) 325 | src_ds = None 326 | lock.release() 327 | 328 | # Restructure raster for panchromatic images: 329 | if image_data.ndim == 2: 330 | image_data = np.reshape(image_data, (1, read_size_y, read_size_x)) 331 | 332 | # Calculate the quality score on an arbitrary band 333 | if assess_quality: 334 | quality_score = pp.calc_q_score(image_data[0]) 335 | else: 336 | quality_score = 1. 337 | # Apply correction to block based on earlier histogram analysis (if applying correction) 338 | # Converts image to 8 bit by rescaling lower -> 1 and upper -> 255 339 | image_data = pp.rescale_band(image_data, lower, upper) 340 | if white_balance: 341 | # Applies a white balance correction 342 | image_data = pp.white_balance(image_data, wb_reference, np.amax(wb_reference)) 343 | 344 | # Segment image 345 | segmented_blocks = segment_image(image_data, image_type=image_type) 346 | 347 | # Classify image 348 | classified_block = classify_image(image_data, segmented_blocks, 349 | tds, im_metadata, wb_reference, bp_reference) 350 | 351 | # Add the pixel counts from this classified split to the 352 | # running total. 
353 | pixel_counts_block = utils.count_features(classified_block) 354 | 355 | # Pass the data back to the main thread for writing 356 | dst_queue.put((quality_score, pixel_counts_block, x, y, classified_block)) 357 | 358 | dst_queue.put(None) 359 | 360 | 361 | def check_read_size(y, block_size_y, y_dim): 362 | if y + block_size_y < y_dim: 363 | return block_size_y 364 | else: 365 | return y_dim - y 366 | 367 | 368 | if __name__ == "__main__": 369 | main() 370 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # Code for preprocessing imagery for analysis by the OSSP image processing 2 | # algorithm. Includes methods to split imagery into more manageable sections, 3 | # and methods for histogram stretching to scale input images to the full 0,255 4 | # range. 5 | # Nicholas Wright 6 | # 11/30/17 7 | 8 | import os 9 | import datetime 10 | import subprocess 11 | import numpy as np 12 | import matplotlib.image as mimg 13 | from skimage.measure import block_reduce 14 | from lib import utils, rescale_intensity 15 | 16 | 17 | def rescale_band(band, bottom, top): 18 | """ 19 | Rescale and image data from range [bottom,top] to uint8 ([0,255]) 20 | """ 21 | imin, imax = (bottom, top) 22 | omin, omax = (1, 255) 23 | 24 | # Rescale intensity takes a uint16 dtype input 25 | band = band.astype(np.uint16) 26 | 27 | return rescale_intensity.rescale_intensity(band, imin, imax, omin, omax) 28 | 29 | 30 | def white_balance(band, reference, omax): 31 | 32 | return rescale_intensity.white_balance(band, reference, omax) 33 | 34 | 35 | def run_pgc_pansharpen(script_path, input_filepath, output_dir): 36 | 37 | base_cmd = os.path.join(script_path, 'pgc_pansharpen.py') 38 | 39 | cmd = 'python {} --epsg 3413 -c rf -t Byte --resample cubic {} {}'.format( 40 | base_cmd, 41 | input_filepath, 42 | output_dir) 43 | 44 | # Spawn a subprocess to execute the above 
command 45 | proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, 46 | stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 47 | 48 | proc.wait() 49 | 50 | # Copying PGC Naming convention, written to match above command 51 | basename = os.path.splitext(os.path.split(input_filepath)[-1])[0] 52 | pansh_filename = "{}_{}{}{}_pansh.tif".format(basename, 'u08', 'rf', '3413') 53 | 54 | return pansh_filename 55 | 56 | 57 | def find_blocksize(x_dim, y_dim, desired_size): 58 | """ 59 | Finds the appropriate block size for an input image of a given dimensions. 60 | Method returns the first factor of the input dimension that is greater than 61 | the desired size. 62 | """ 63 | block_size_x = desired_size 64 | block_size_y = desired_size 65 | 66 | # Ensure that chosen block size divides into the image dimension with a remainder that is 67 | # at least half a standard block in width. 68 | while (x_dim % block_size_x) <= (block_size_x / 2): 69 | block_size_x += 256 70 | # Make sure the blocks don't get too big. 71 | if block_size_x >= x_dim: 72 | block_size_x = x_dim 73 | break 74 | 75 | while (y_dim % block_size_y) <= (block_size_y / 2): 76 | block_size_y += 256 77 | if block_size_y >= y_dim: 78 | block_size_y = y_dim 79 | break 80 | 81 | return block_size_x, block_size_y 82 | 83 | 84 | def calc_q_score(image): 85 | """ 86 | Calculates a quality score of an input image by determining the number of 87 | high frequency peaks in the fourier transformed image relative to the 88 | image size. 
89 | QA Score < 0.025 poor 90 | 0.25 < QA Score < 0.035 medium 91 | QA Score > 0.035 fine 92 | 93 | """ 94 | # Calculate the 2D fourier transform of the image 95 | im_fft = np.fft.fft2(image) 96 | # Find the maximum frequency peak in the fft image 97 | max_freq = np.amax(np.abs(im_fft)) 98 | # Set a threshold that is a fraction of the max peak 99 | # (Fraction determined empirically) 100 | thresh = max_freq / 100000 101 | # Determine the number of pixels above the threshold 102 | th = np.sum([im_fft>thresh]) 103 | # QA Score is the percent of the pixels that are greater than the threshold 104 | qa_score = float(th) / np.size(image) 105 | 106 | return qa_score 107 | 108 | 109 | def parse_metadata(metadata, image_type): 110 | """ 111 | Parse image metadata information to find date. 112 | If image date cannot be found, return mean date of melt season (June1). 113 | This is likely to have to least impact on decision tree outcomes, as there 114 | will be less bias towards no melt vs lots of melt, but this should be tested. 115 | If 0 were to be used, then the decision tree would see that as very early 116 | season, since date is a numeric feature, and not a categorical one. 
117 | """ 118 | try: 119 | if image_type == 'srgb': 120 | header_date = metadata['EXIF_DateTime'] 121 | image_date = header_date[5:7] + header_date[8:10] 122 | yyyy = 2014 123 | mm = image_date[:2] 124 | dd = image_date[2:] 125 | elif image_type == 'pan' or image_type == 'wv02_ms': 126 | # image_date = metadata['NITF_STDIDC_ACQUISITION_DATE'][4:8] 127 | image_date = metadata['NITF_IDATIM'][0:8] 128 | yyyy = image_date[0:4] 129 | mm = image_date[4:6] 130 | dd = image_date[6:] 131 | except KeyError: 132 | # Use June 1 as default date 133 | yyyy = 2014 134 | mm = 6 135 | dd = 1 136 | 137 | # Convert the date to julian day format (number of days since Jan 1) 138 | d = datetime.date(int(yyyy), int(mm), int(dd)) 139 | doy = d.toordinal() - datetime.date(d.year, 1, 1).toordinal() + 1 140 | 141 | return doy 142 | 143 | 144 | def histogram_threshold(gdal_dataset, src_dtype): 145 | # Set the percentile thresholds at a temporary value until finding the 146 | # appropriate ones considering all bands. 
These numbers are chosen to 147 | # always get reset on first loop (for bitdepth <= uint16) 148 | lower = 2048 149 | upper = -1 150 | 151 | # Determine the number of bands in the dataset 152 | band_count = gdal_dataset.RasterCount 153 | # White balance reference points 154 | wb_reference = [0 for _ in range(band_count)] 155 | bp_reference = [0 for _ in range(band_count)] 156 | # Determine the input datatype 157 | 158 | if src_dtype > 8: 159 | max_bit = 2047 160 | upper_limit = 0.25 161 | else: 162 | max_bit = 255 163 | upper_limit = 0.8 164 | 165 | total_peaks = 0 166 | 167 | # First for loop finds the threshold based on all bands 168 | for b in range(1, band_count + 1): 169 | 170 | # Read the band information from the gdal dataset 171 | band = gdal_dataset.GetRasterBand(b) 172 | 173 | # Find the min and max image values 174 | bmin, bmax = band.ComputeRasterMinMax() 175 | 176 | # Determine the histogram using gdal 177 | nbins = int(bmax - bmin) 178 | hist = band.GetHistogram(bmin, bmax, nbins, approx_ok=0) 179 | bin_centers = range(int(bmin), int(bmax)) 180 | bin_centers = np.array(bin_centers) 181 | 182 | # Remove the image data from memory for now 183 | band = None 184 | 185 | # Find the strongest (3) peaks in the band histogram 186 | peaks = find_peaks(hist, bin_centers) 187 | # Tally the total number of peaks found across all bands 188 | total_peaks += len(peaks) 189 | # Find the high and low threshold for rescaling image intensity 190 | lower_b, upper_b, auto_wb, auto_bpr = find_threshold(hist, bin_centers, 191 | peaks, src_dtype) 192 | wb_reference[b-1] = auto_wb 193 | bp_reference[b-1] = auto_bpr 194 | # For sRGB we want to scale each band by the min and max of all 195 | # bands. Check thresholds found for this band against any that 196 | # have been previously found, and adjust if necessary. 
197 | if lower_b < lower: 198 | lower = lower_b 199 | if upper_b > upper: 200 | upper = upper_b 201 | 202 | # If there is only a single peak per band, we need an upper limit. The upper 203 | # limit is to prevent open water only images from being stretched. 204 | if total_peaks <= band_count: 205 | max_range = int(max_bit * upper_limit) 206 | if upper < max_range: 207 | upper = max_range 208 | 209 | return lower, upper, wb_reference, bp_reference 210 | 211 | 212 | def find_peaks(hist, bin_centers): 213 | """ 214 | Finds the three strongest peaks in a given band. 215 | Criteria for each peak: 216 | Distance to the nearest neighboring peak is greater than one third the approx. dynamic range of the input image 217 | Has a minimum number of pixels in that peak, loosely based on image size 218 | Is greater than the directly adjacent bins, and the bins +/- 5 away 219 | """ 220 | 221 | # Roughly define the smallest acceptable size of a peak based on the number of pixels 222 | # in the largest bin. 
223 | # min_count = int(max(hist)*.06) 224 | min_count = int(np.sum(hist)*.004) 225 | 226 | # First find all potential peaks in the histogram 227 | peaks = [] 228 | 229 | # Check the lowest histogram bin 230 | if hist[0] >= hist[1] and hist[0] >= hist[5]: 231 | if hist[-1] > min_count: 232 | peaks.append(bin_centers[0]) 233 | 234 | # Check the middle bins 235 | for i in range(1, len(bin_centers) - 1): 236 | # Acceptable width of peak is +/-5, except in edge cases 237 | if i < 5: 238 | w_l = i 239 | w_u = 5 240 | elif i > len(bin_centers) - 6: 241 | w_l = 5 242 | w_u = len(bin_centers) - 1 - i 243 | else: 244 | w_l = 5 245 | w_u = 5 246 | # Check neighboring peaks 247 | if (hist[i] >= hist[i + 1] and hist[i] >= hist[i - 1] 248 | and hist[i] >= hist[i - w_l] and hist[i] >= hist[i + w_u]): 249 | if hist[i] > min_count: 250 | peaks.append(bin_centers[i]) 251 | # Check the highest histogram bin 252 | if hist[-1] >= hist[-2] and hist[-1] >= hist[-6]: 253 | if hist[-1] > min_count: 254 | peaks.append(bin_centers[-1]) 255 | 256 | num_peaks = len(peaks) 257 | distance = 5 # Initial distance threshold 258 | # One third the 'dynamic range' (radius from peak) 259 | distance_threshold = int((peaks[-1] - peaks[0]) / 6) 260 | # Min threshold 261 | if distance_threshold <= 5: 262 | distance_threshold = 5 263 | # Looking for three main peaks corresponding to the main surface types: 264 | # open water, MP and snow/ice 265 | # But any peak that passes the criteria is fine. 266 | while distance <= distance_threshold: 267 | i = 0 268 | to_remove = [] 269 | # Cycle through all of the peaks 270 | while i < num_peaks - 1: 271 | # Check the current peak against the adjacent one. 
If they are closer 272 | # than the threshold distance, delete the lower valued peak 273 | if peaks[i + 1] - peaks[i] < distance: 274 | if (hist[np.where(bin_centers == peaks[i])[0][0]] 275 | < hist[np.where(bin_centers == peaks[i + 1])[0][0]]): 276 | to_remove.append(peaks[i]) 277 | else: 278 | to_remove.append(peaks[i + 1]) 279 | # Because we don't need to check the next peak again: 280 | i += 1 281 | i += 1 282 | 283 | # Remove all of the peaks that did not meet the criteria above 284 | for j in to_remove: 285 | peaks.remove(j) 286 | 287 | # Recalculate the number of peaks left, and increase the distance threshold 288 | num_peaks = len(peaks) 289 | distance += 5 290 | return peaks 291 | 292 | 293 | def find_threshold(hist, bin_centers, peaks, src_dtype, top=0.15, bottom=0.5): 294 | """ 295 | Finds the upper and lower threshold for histogram stretching. 296 | Using the indices of the highest and lowest peak (by intensity, not # of pixels), this searches for an upper 297 | threshold that is both greater than the highest peak and has fewer than 15% the number of pixels, and a lower 298 | threshold that is both less than the lowest peak and has fewer than 50% the number of pixels. 299 | 10% and 50% picked empirically to give good results. 300 | """ 301 | max_peak = np.where(bin_centers == peaks[-1])[0][0] # Max intensity 302 | thresh_top = max_peak 303 | while hist[thresh_top] > hist[max_peak] * top: 304 | thresh_top += 2 # Upper limit is less sensitive, so step 2 at a time 305 | # In the case that the top peak is already at/near the max bit value, limit the top 306 | # threshold to be the top bin of the histogram. 307 | if thresh_top >= len(hist)-1: 308 | thresh_top = len(hist)-1 309 | break 310 | 311 | min_peak = np.where(bin_centers == peaks[0])[0][0] # Min intensity 312 | thresh_bot = min_peak 313 | while hist[thresh_bot] > hist[min_peak] * bottom: 314 | thresh_bot -= 1 315 | # Similar to above, limit the bottom threshold to the lowest histogram bin. 
316 | if thresh_bot <= 0: 317 | thresh_bot = 0 318 | break 319 | 320 | # Convert the histogram bin index to an intensity value 321 | lower = bin_centers[thresh_bot] 322 | upper = bin_centers[thresh_top] 323 | 324 | # Save the upper value for the auto white balance function 325 | auto_wb = upper 326 | # Save the lower value for the black point reference 327 | auto_bpr = lower 328 | 329 | # Determine the width of the lower peak. 330 | lower_width = min_peak - thresh_bot 331 | dynamic_range = max_peak - min_peak 332 | 333 | # Limit the amount of stretch to a percentage of the total dynamic range 334 | # in the case that all three main surface types are not represented (fewer 335 | # than 3 peaks) 336 | # 8 bit vs 11 bit (WorldView) 337 | # 256 or 2048 338 | # While WV images are 11bit, white ice tends to be ~600-800 intensity 339 | # Provide a floor to the amount of stretch allowed 340 | if src_dtype > 8: 341 | max_bit = 2047 342 | else: 343 | max_bit = 255 344 | 345 | # If the width of the lowest peak is less than 3% of the bit depth, 346 | # then the lower peak is likely open water. 3% determined visually, but 347 | # ocean has a much narrower peak than ponds or ice. 348 | if (float(lower_width)/max_bit >= 0.03) or (dynamic_range < max_bit / 3): 349 | min_range = int(max_bit * .08) 350 | if lower > min_range: 351 | lower = min_range 352 | 353 | return lower, upper, auto_wb, auto_bpr 354 | 355 | 356 | def save_color_image(image_data, output_name, image_type, block_cols, block_rows): 357 | """ 358 | Write a rgb color image (as png) of the raw image data to disk. 
359 | """ 360 | holder = [] 361 | # Find the appropriate bands to use for an rgb representation 362 | if image_type == 'wv02_ms': 363 | rgb = [5, 3, 2] 364 | elif image_type == 'srgb': 365 | rgb = [1, 2, 3] 366 | else: 367 | rgb = [1, 1, 1] 368 | 369 | red_band = image_data[rgb[0]] 370 | green_band = image_data[rgb[1]] 371 | blue_band = image_data[rgb[2]] 372 | 373 | for i in range(len(red_band)): 374 | holder.append(utils.create_composite([ 375 | red_band[i], green_band[i], blue_band[i]])) 376 | 377 | colorfullimg = utils.compile_subimages(holder, block_cols, block_rows, 3) 378 | mimg.imsave(output_name, colorfullimg) 379 | colorfullimg = None 380 | 381 | 382 | def downsample(band, factor): 383 | """ 384 | 'Downsample' an image by the given factor. Every pixel in the resulting image 385 | is the result of an average of the NxN kernel centered at that pixel, 386 | where N is factor. 387 | """ 388 | 389 | band_downsample = block_reduce(band, block_size=(factor, factor, 3), func=np.mean) 390 | 391 | band_copy = np.zeros(np.shape(band)) 392 | for i in range(np.shape(band_downsample)[0]): 393 | for j in range(np.shape(band_downsample)[1]): 394 | band_copy[i * factor:(i * factor) + factor, j * factor:j * factor + factor, :] = band_downsample[i, j, :] 395 | 396 | return band_copy 397 | 398 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # OSSP 2 | ## Open Source Sea-ice Processing 3 | ### Open Source Algorithm for Detecting Sea Ice Surface Features in High Resolution Optical Imagery 4 | 5 | ### Nicholas Wright and Chris Polashenski 6 | 7 | ## Introduction 8 | 9 | Welcome to OSSP; a set of tools for detecting surface features in high resolution optical imagery of sea ice. The primary focus is on the detection of and differentiation between open water, melt ponds, and snow/ice. 
10 | 11 | The Anaconda distribution of Python is recommended, but any distribution with the appropriate packages will work. You can download Anaconda (Python 3.6 version) here: https://www.continuum.io/downloads 12 | 13 | 14 | ## Dependencies 15 | 16 | * gdal (v2.0 or above) 17 | * numpy 18 | * scipy 19 | * h5py 20 | * scikit-image 21 | * scikit-learn (sklearn) 22 | * matplotlib 23 | * tkinter 24 | 25 | #### Optional 26 | * tqdm (for progress bar) 27 | * PGC imagery_utils (for WV pansharpening) (https://github.com/PolarGeospatialCenter/imagery_utils) 28 | 29 | ## Usage 30 | 31 | For detailed usage and installation instructions, see the pdf document 'Algorithm_Instructions.pdf' 32 | 33 | ### setup.py 34 | 35 | The first step is to run the setup.py script to compile the Cython (.pyx) libraries in lib/ (Cython must be installed). Run __python setup.py build\_ext --build-lib .__ from the OSSP directory. Be sure to include the period after --build-lib. 36 | 37 | ### ossp_process.py 38 | 39 | This combines all steps of the image classification scheme into one script and should be the only script to call directly. If given a folder of images, this script finds all appropriately formatted files in the directory (.tif(f) and .jpg) and queues them for processing. If given an image file, this script processes that single image alone. This script processes images as follows: Image preprocessing (histogram stretch or pansharpening if chosen) -> segmentation (segment.py) -> classification (classify.py) -> calculate statistics. Output results are saved as a geotiff with the same georeference as the input image. 40 | 41 | #### Required Arguments 42 | * __input directory__: directory containing all of the images you wish to process. Note that all .jpg and .tif images in the input directory, as well as in all of its sub-directories, will be processed. You can also provide the path and filename of a single image to process only that image. 43 | * __image type__: {'srgb', 'wv02_ms', 'pan'}: the type of imagery you are processing. 44 | 1.
'srgb': RGB imagery taken by a typical camera 45 | 2. 'wv02_ms': DigitalGlobe WorldView 2 multispectral imagery 46 | 3. 'pan': High resolution panchromatic imagery 47 | * __training dataset file__: filepath of the training dataset you wish to use to analyze the input imagery 48 | 49 | #### Optional Arguments 50 | 51 | * __-o | --output_dir__: Directory to write output files. 52 | * __-v | --verbose__: Display text output as algorithm progresses. 53 | * __-c | --stretch__: {'hist', 'pansh', 'none'}: Apply an image correction prior to classification. Pansharpening / orthorectification option requires PGC scripts. *Default = hist*. 54 | * __-t | --threads__: Number of subprocesses to spawn for classification. Using more than 2 threads only helps for images larger than ~10,000x10,000 pixels. 55 | * __--pgc_script__: Path for the PGC imagery_utils folder if 'pansh' was chosen for the image correction. 56 | * __--training\_label__: The label of a custom training dataset. See advanced section for details. *Default = image\_type*. 57 | 58 | #### Notes: 59 | 60 | Example: ossp\_process.py input\_dir im\_type training\_dataset\_file -v 61 | 62 | This example will process all .tif and .jpg files in the input\_dir. 63 | 64 | 65 | ### training_gui.py 66 | 67 | Graphical user interface for creating a custom training dataset. Provide a directory of images that you wish to use as the basis of your training set. The GUI will present a random segment each time a classification is assigned. The display images can also be clicked to classify a specific area. The segments themselves are automatically generated. The highlighted region corresponds to the segment that will be labeled. 68 | 69 | Output is a .h5 file that can be provided to ossp\_process.py.
70 | 71 | Note: Images are segmented prior to display on the GUI, and as such may take up to a minute to load (depending on image size and computer specs) 72 | 73 | #### Positional Arguments: 74 | * __input__: A directory containing the images you wish to use for training. 75 | * __image type__: {‘srgb’, ‘wv02_ms’, ‘pan'}: the type of imagery you are processing. 76 | 1. 'srgb': RGB imagery taken by a typical camera 77 | 2. 'wv02_ms': DigitalGlobe WorldView 2 multispectral imagery, 78 | 3. 'pan': High resolution panchromatic imagery 79 | 80 | #### Optional arguments: 81 | * __--tds_file__: Existing training dataset file. Will create a new one with this name if none exists. If a path is not provided, file is created in the image directory. *Default = \_training\_data.h5*. 82 | * __--username__: A specific label to attach to the training set. The --training\_label argument of ossp_\process references this value. *Default = * 83 | 84 | ### Contact 85 | Nicholas Wright 86 | 87 | -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | # title: Watershed Transform 2 | # author: Nick Wright 3 | # adapted from: Justin Chen, Arnold Song 4 | 5 | import numpy as np 6 | import gc 7 | import warnings 8 | from skimage import filters, morphology, feature, img_as_ubyte 9 | from scipy import ndimage 10 | from ctypes import * 11 | from lib import utils 12 | 13 | # For Testing: 14 | from skimage import segmentation 15 | import matplotlib.image as mimg 16 | 17 | 18 | def segment_image(input_data, image_type=False): 19 | ''' 20 | Wrapper function that handles all of the processing to create watersheds 21 | ''' 22 | 23 | #### Define segmentation parameters 24 | # High_threshold: 25 | # Low_threshold: Lower threshold for canny edge detection. Determines which "weak" edges to keep. 
26 | # Values above this amount that are connected to a strong edge will be marked as an edge. 27 | # Gauss_sigma: sigma value to use in the gaussian blur applied to the image prior to segmentation. 28 | # Value chosen here should be based on the quality and resolution of the image 29 | # Feature_separation: minimum distance, in pixels, between the center point of multiple features. Use a lower value 30 | # for lower resolution (.5m) images, and higher resolution for aerial images (~.1m). 31 | # These values are dependent on the type of imagery being processed, and are 32 | # mostly empirically derived. 33 | # band_list contains the three bands to be used for segmentation 34 | if image_type == 'pan': 35 | high_threshold = 0.15 * 255 ## Needs to be checked 36 | low_threshold = 0.05 * 255 ## Needs to be checked 37 | gauss_sigma = 1 38 | feature_separation = 1 39 | band_list = [0, 0, 0] 40 | elif image_type == 'wv02_ms': 41 | high_threshold = 0.20 * 255 ## Needs to be checked 42 | low_threshold = 0.05 * 255 ## Needs to be checked 43 | gauss_sigma = 1.5 44 | feature_separation = 3 45 | band_list = [4, 2, 1] 46 | else: #image_type == 'srgb' 47 | high_threshold = 0.15 * 255 48 | low_threshold = 0.05 * 255 49 | gauss_sigma = 2 50 | feature_separation = 5 51 | band_list = [0, 1, 2] 52 | 53 | segmented_data = watershed_transformation(input_data, band_list, low_threshold, high_threshold, 54 | gauss_sigma,feature_separation) 55 | 56 | # Method that provides the user an option to view the original image 57 | # side by side with the segmented image. 
58 | # print(np.amax(segmented_data)) 59 | # image_data = np.array([input_data[band_list[0]], 60 | # input_data[band_list[1]], 61 | # input_data[band_list[2]]], 62 | # dtype=np.uint8) 63 | # ws_bound = segmentation.find_boundaries(segmented_data) 64 | # ws_display = utils.create_composite(image_data) 65 | # 66 | # # save_name = '/Users/nicholas/Desktop/original_{}.png' 67 | # # mimg.imsave(save_name.format(np.random.randint(0,100)), ws_display, format='png') 68 | # 69 | # ws_display[:, :, 0][ws_bound] = 240 70 | # ws_display[:, :, 1][ws_bound] = 80 71 | # ws_display[:, :, 2][ws_bound] = 80 72 | # 73 | # save_name = '/Users/nicholas/Desktop/seg_{}.png' 74 | # mimg.imsave(save_name.format(np.random.randint(0, 100)), ws_display, format='png') 75 | 76 | return segmented_data 77 | 78 | 79 | def watershed_transformation(image_data, band_list, low_threshold, high_threshold, gauss_sigma, feature_separation): 80 | ''' 81 | Runs a watershed transform on the main dataset 82 | 1. Create a gradient image using the sobel algorithm 83 | 2. Adjust the gradient image based on given threshold and amplification. 84 | 3. Find the local minimum gradient values and place a marker 85 | 4. Construct watersheds on top of the gradient image starting at the 86 | markers. 87 | ''' 88 | # If this block has no data, return a placeholder watershed. 89 | if np.amax(image_data[0]) <= 1: 90 | # We just need the dimensions from one band 91 | return np.zeros(np.shape(image_data[0])) 92 | 93 | # Build a raster of detected edges to inform the creation of watershed seed points 94 | edge_image = edge_detect(image_data, band_list, gauss_sigma, low_threshold, high_threshold) 95 | # Build a raster of image gradient that will be the base for watershed expansion. 
96 | grad_image = build_gradient(image_data, band_list, gauss_sigma) 97 | image_data = None 98 | 99 | # Find local minimum values in the edge image by inverting 100 | # edge_image and finding the local maximum values 101 | inv_edge = np.empty_like(edge_image, dtype=np.uint8) 102 | np.subtract(255, edge_image, out=inv_edge) 103 | edge_image = None 104 | 105 | # Distance to the nearest detected edge 106 | distance_image = ndimage.distance_transform_edt(inv_edge) 107 | inv_edge = None 108 | 109 | # Local maximum distance 110 | local_min = feature.peak_local_max(distance_image, min_distance=feature_separation, 111 | exclude_border=False, indices=False, num_peaks_per_label=1) 112 | distance_image = None 113 | 114 | markers = ndimage.label(local_min)[0] 115 | local_min = None 116 | 117 | # Build a watershed from the markers on top of the edge image 118 | im_watersheds = morphology.watershed(grad_image, markers) 119 | grad_image = None 120 | 121 | # Set all values outside of the image area (empty pixels, usually caused by 122 | # orthorectification) to one value, at the end of the watershed list. 
123 | # im_watersheds[empty_pixels] = np.amax(im_watersheds)+1 124 | gc.collect() 125 | return im_watersheds 126 | 127 | 128 | def edge_detect(image_data, band_list, gauss_sigma, low_threshold, high_threshold): 129 | 130 | # Detect edges in the image with a canny edge detector 131 | with warnings.catch_warnings(): 132 | warnings.simplefilter("ignore") 133 | edge_image = img_as_ubyte(feature.canny(image_data[band_list[1]], sigma=gauss_sigma, 134 | low_threshold=low_threshold, high_threshold=high_threshold)) 135 | return edge_image 136 | 137 | 138 | def build_gradient(image_data, band_list, gauss_sigma): 139 | 140 | with warnings.catch_warnings(): 141 | warnings.simplefilter("ignore") 142 | smooth_im_blue = ndimage.filters.gaussian_filter(image_data[band_list[2]], sigma=gauss_sigma) 143 | grad_image = img_as_ubyte(filters.scharr(smooth_im_blue)) 144 | 145 | # Prevent the watersheds from 'leaking' along the sides of the image 146 | grad_image[:, 0] = grad_image[:, 1] 147 | grad_image[:, -1] = grad_image[:, -2] 148 | grad_image[0, :] = grad_image[1, :] 149 | grad_image[-1, :] = grad_image[-2, :] 150 | 151 | return grad_image -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | from setuptools import Extension 4 | except ImportError: 5 | from distutils.core import setup 6 | from distutils.extension import Extension 7 | 8 | USE_CYTHON = True 9 | 10 | ext = '.pyx' if USE_CYTHON else '.c' 11 | 12 | extensions = [Extension("lib.attribute_calculations", ['lib/attribute_calculations' + ext]), 13 | Extension("lib.create_clsf_raster", ["lib/create_clsf_raster" + ext]), 14 | Extension("lib.rescale_intensity", ['lib/rescale_intensity' + ext])] 15 | 16 | if USE_CYTHON: 17 | from Cython.Build import cythonize 18 | extensions = cythonize(extensions, annotate=True) 19 | 20 | setup( 21 | ext_modules = extensions, 22 | ) 
23 | -------------------------------------------------------------------------------- /training_datasets/icebridge_v5_training_data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/training_datasets/icebridge_v5_training_data.h5 -------------------------------------------------------------------------------- /training_datasets/icebridge_v7_training_data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/training_datasets/icebridge_v7_training_data.h5 -------------------------------------------------------------------------------- /training_datasets/pan_v2_training_data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/training_datasets/pan_v2_training_data.h5 -------------------------------------------------------------------------------- /training_datasets/wv02_ms_v2_training_data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/training_datasets/wv02_ms_v2_training_data.h5 -------------------------------------------------------------------------------- /training_datasets/wv02_ms_v3.1_training_data.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wrightni/OSSP/1d036edaedd15b66fc190490bff00b05431c0d98/training_datasets/wv02_ms_v3.1_training_data.h5 -------------------------------------------------------------------------------- /training_gui.py: -------------------------------------------------------------------------------- 1 | #title: Training Set Creation for Random Forest Classification 2 | #author: Nick 
Wright 3 | #Inspired by: Justin Chen 4 | 5 | #purpose: Creates a GUI for a user to identify watershed superpixels of an image as 6 | # melt ponds, sea ice, or open water to use as a training data set for a 7 | # Random Forest Classification method. 8 | 9 | # Python 3: 10 | import tkinter as tk 11 | # Python 2: 12 | # import Tkinter as tk 13 | import numpy as np 14 | import matplotlib 15 | matplotlib.use("TkAgg") 16 | from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg 17 | import matplotlib.pyplot as plt 18 | import h5py 19 | import os 20 | import argparse 21 | from ctypes import * 22 | import gdal 23 | from sklearn.ensemble import RandomForestClassifier 24 | from select import select 25 | import sys 26 | import preprocess as pp 27 | from segment import segment_image 28 | from lib import utils 29 | from lib import attribute_calculations as attr_calc 30 | 31 | 32 | class PrintColor: 33 | PURPLE = '\033[95m' 34 | CYAN = '\033[96m' 35 | DARKCYAN = '\033[36m' 36 | BLUE = '\033[94m' 37 | GREEN = '\033[92m' 38 | YELLOW = '\033[93m' 39 | RED = '\033[91m' 40 | BOLD = '\033[1m' 41 | UNDERLINE = '\033[4m' 42 | END = '\033[0m' 43 | 44 | 45 | class Buttons(tk.Frame): 46 | # Defines the properties of all the controller buttons to be used by the GUI. 
47 | def __init__(self, parent): 48 | tk.Frame.__init__(self, parent) 49 | 50 | prev_btn = tk.Button(self, text="Previous Segment", width=16, height=2, 51 | command=lambda: parent.event_manager.previous_segment()) 52 | prev_btn.grid(column=0, row=0, pady=(0,20)) 53 | 54 | water_btn = tk.Button(self, text="Open Water", width=16, height=2, highlightbackground='#000000', 55 | command=lambda: parent.event_manager.classify("water")) 56 | water_btn.grid(column=0, row=1, pady=1) 57 | 58 | melt_btn = tk.Button(self, text="Melt Pond", width=16, height=2, highlightbackground='#4C678C', 59 | command=lambda: parent.event_manager.classify("melt")) 60 | melt_btn.grid(column=0, row=2, pady=1) 61 | 62 | gray_btn = tk.Button(self, text="Dark and Thin Ice", width=16, height=2, highlightbackground='#D2D3D5', 63 | command=lambda: parent.event_manager.classify("gray")) 64 | gray_btn.grid(column=0, row=3, pady=1) 65 | 66 | snow_btn = tk.Button(self, text="Snow or Ice", width=16, height=2, 67 | command=lambda: parent.event_manager.classify("snow")) 68 | snow_btn.grid(column=0, row=4, pady=1) 69 | 70 | shadow_btn = tk.Button(self, text="Shadow", width=16, height=2, highlightbackground='#FF9200', 71 | command=lambda: parent.event_manager.classify("shadow")) 72 | shadow_btn.grid(column=0, row=5, pady=1) 73 | 74 | unknown_btn = tk.Button(self, text="Unknown / Mixed", width=16, height=2, 75 | command=lambda: parent.event_manager.classify("unknown")) 76 | unknown_btn.grid(column=0, row=6, pady=1) 77 | 78 | auto_btn = tk.Button(self, text="Autorun", width=16, height=2, 79 | command=lambda: parent.event_manager.autorun()) 80 | auto_btn.grid(column=0, row=7, pady=(20,0)) 81 | 82 | next_btn = tk.Button(self, text="Next Image", width=16, height=2, 83 | command=lambda: parent.event_manager.next_image()) 84 | next_btn.grid(column=0, row=8, pady=1) 85 | quit_btn = tk.Button(self, text="Save and Quit", width=16, height=2, 86 | command=lambda: parent.event_manager.quit_event()) 87 | 
quit_btn.grid(column=0, row=9, pady=1) 88 | 89 | load_first_btn = tk.Button(self, text="Initialize Image", width=16, height=2, 90 | command=lambda: parent.event_manager.initialize_image()) 91 | load_first_btn.grid(column=0, row=10, pady=(40,0)) 92 | 93 | 94 | class ProgressBar(tk.Frame): 95 | 96 | def __init__(self, parent): 97 | tk.Frame.__init__(self, parent) 98 | self.parent = parent 99 | 100 | self.total_counter = tk.StringVar() 101 | self.total_counter.set("Total Progress: {}".format(0)) 102 | self.image_tracker = tk.StringVar() 103 | self.image_tracker.set("") 104 | 105 | total_text = tk.Label(self, textvariable=self.total_counter) 106 | total_text.grid(column=0, row=0) 107 | 108 | image_text = tk.Label(self, textvariable=self.image_tracker) 109 | image_text.grid(column=0, row=1) 110 | 111 | def update_progress(self): 112 | self.total_counter.set("Total Progress: {}".format(self.parent.data.get_num_labels())) 113 | self.image_tracker.set("Image {} of {}".format(self.parent.data.im_index + 1, 114 | len(self.parent.data.available_images))) 115 | 116 | 117 | class ImageDisplay(tk.Frame): 118 | 119 | def __init__(self, parent): 120 | tk.Frame.__init__(self, parent) 121 | 122 | self.parent = parent 123 | # Initialize class variables 124 | # Populated in initialize_image method: 125 | self.display_image = None 126 | self.disp_xdim, self.disp_ydim, = 0, 0 127 | # Populated in update_images: 128 | self.zoom_win_x, self.zoom_win_y = 0, 0 129 | 130 | # Creating the canvas where the images will be 131 | self.fig = plt.figure(figsize=[10, 10]) 132 | self.fig.subplots_adjust(left=0.01, right=0.99, bottom=0.05, top=0.99, wspace=0.01, hspace=0.01) 133 | 134 | canvas = FigureCanvasTkAgg(self.fig, self) 135 | canvas.draw() 136 | # toolbar = NavigationToolbar2TkAgg(canvas, frame) 137 | canvas.get_tk_widget().grid(column=0, row=0) 138 | # toolbar.pack(in_=frame, side='top') 139 | self.cid = self.fig.canvas.mpl_connect('button_press_event', parent.event_manager.onclick) 140 | 
141 | # Create a placeholder while image data is loading 142 | self.initial_display() 143 | 144 | 145 | def initialize_image(self): 146 | # Creates a local composite of the original image data for display 147 | if self.parent.data.im_type == 'wv02_ms': 148 | self.display_image = utils.create_composite([self.parent.data.original_image[4, :, :], 149 | self.parent.data.original_image[2, :, :], 150 | self.parent.data.original_image[1, :, :]], 151 | dtype=np.uint8) 152 | 153 | elif self.parent.data.im_type == 'pan': 154 | self.display_image = utils.create_composite([self.parent.data.original_image, 155 | self.parent.data.original_image, 156 | self.parent.data.original_image], 157 | dtype=np.uint8) 158 | 159 | elif self.parent.data.im_type == 'srgb': 160 | self.display_image = utils.create_composite([self.parent.data.original_image[0, :, :], 161 | self.parent.data.original_image[1, :, :], 162 | self.parent.data.original_image[2, :, :]], 163 | dtype=np.uint8) 164 | self.disp_xdim, self.disp_ydim = np.shape(self.display_image)[0:2] 165 | 166 | 167 | def loading_display(self): 168 | 169 | plt.clf() 170 | 171 | loading_text = "Images are loading, please wait... " 172 | # Creates a image placeholder while the data is being loaded. 173 | ax = self.fig.add_subplot(1, 1, 1, adjustable='datalim', frame_on=False) 174 | ax.text(0.5, 0.5, loading_text, horizontalalignment='center', verticalalignment='center') 175 | ax.axis('off') 176 | 177 | # Updating the plots 178 | self.fig.canvas.draw() 179 | 180 | def initial_display(self): 181 | 182 | plt.clf() 183 | 184 | welcome_text = "No images have been loaded. Press to begin." 
185 | tds_text = "Training data file: \n {}".format(self.parent.data.tds_filename) 186 | image_text = "Images found: \n" 187 | if len(self.parent.data.available_images) == 0: 188 | image_text += 'None' 189 | else: 190 | for im in self.parent.data.available_images: 191 | image_text += im + '\n' 192 | 193 | # Creates a image placeholder while the data is being loaded. 194 | ax = self.fig.add_subplot(2, 1, 1, adjustable='datalim', frame_on=False) 195 | ax.text(0.5, 0.3, welcome_text, horizontalalignment='center', verticalalignment='bottom', weight='bold') 196 | ax.axis('off') 197 | 198 | ax2 = self.fig.add_subplot(2, 1, 2, adjustable='datalim', frame_on=False) 199 | ax2.text(0.5, 1, tds_text, horizontalalignment='center', verticalalignment='center') 200 | ax2.text(0.5, .9, image_text, horizontalalignment='center', verticalalignment='top') 201 | ax2.axis('off') 202 | 203 | # Updating the plots 204 | self.fig.canvas.draw() 205 | 206 | def update_images(self, segment_id): 207 | # Clear the existing display 208 | plt.clf() 209 | 210 | current_seg = self.parent.data.segmented_image == segment_id # array of 0 or 1 where 1 = current segment 211 | segment_pos = np.nonzero(current_seg) # returns the array position of the segment 212 | 213 | zoom_size = 100 214 | 215 | x_min = np.amin(segment_pos[0]) - zoom_size 216 | x_max = np.amax(segment_pos[0]) + zoom_size 217 | y_min = np.amin(segment_pos[1]) - zoom_size 218 | y_max = np.amax(segment_pos[1]) + zoom_size 219 | 220 | # Store the zoom window corner coordinates for reference in onclick() 221 | # xMin and yMin are defined backwards 222 | self.zoom_win_x = y_min 223 | self.zoom_win_y = x_min 224 | 225 | if x_min < 0: 226 | x_min = 0 227 | if x_max >= self.disp_xdim: 228 | x_max = self.disp_xdim - 1 229 | if y_min < 0: 230 | y_min = 0 231 | if y_max >= self.disp_ydim: 232 | y_max = self.disp_ydim - 1 233 | 234 | # Image 2 (Zoomed in image, no highlighted segment) 235 | cropped_image = self.display_image[x_min:x_max, y_min:y_max] 
236 | 237 | # Image 3 (Zoomed in image, with segment highlight) 238 | color_image = np.copy(self.display_image) 239 | color_image[:, :, 0][current_seg] = 255 240 | color_image[:, :, 2][current_seg] = 0 241 | color_image = color_image[x_min:x_max, y_min:y_max] 242 | 243 | # Text instructions 244 | instructions = ''' 245 | Open Water: Surface areas that had zero ice cover 246 | as well as those covered by an unconsolidated frazil 247 | or grease ice. \n 248 | Melt Pond: Surface areas with water covering ice. 249 | Areas where meltwater is trapped in isolated patches 250 | atop ice, and the optically similar submerged ice 251 | near the edge of a floe. \n 252 | Dark Ice: 253 | Freezing season: Surfaces of thin ice that are 254 | not snow covered, including nilas and young ice. 255 | Melt season: ice covered by saturated slush, 256 | but not completely submerged in water \n 257 | Snow/Ice: Optically thick ice, and ice with a snow cover. \n 258 | Shadow: Surfaces that are covered by a dark shadow. 
259 | \n 260 | ''' 261 | 262 | # Plotting onto the GUI 263 | ax = self.fig.add_subplot(2, 2, 1) 264 | ax.imshow(color_image, interpolation='None', vmin=0, vmax=255) 265 | ax.tick_params(axis='both', # changes apply to the x-axis 266 | which='both', # both major and minor ticks are affected 267 | bottom=False, # ticks along the bottom edge are off 268 | top=False, # ticks along the top edge are off 269 | left=False, 270 | right=False, 271 | labelleft=False, 272 | labelbottom=False) 273 | ax.set_label('ax1') 274 | 275 | ax = self.fig.add_subplot(2, 2, 2) 276 | ax.imshow(cropped_image, interpolation='None', vmin=0, vmax=255) 277 | ax.tick_params(axis='both', # changes apply to the x-axis 278 | which='both', # both major and minor ticks are affected 279 | bottom=False, # ticks along the bottom edge are off 280 | top=False, # ticks along the top edge are off 281 | left=False, 282 | right=False, 283 | labelleft=False, 284 | labelbottom=False) 285 | ax.set_label('ax2') 286 | 287 | ax = self.fig.add_subplot(2, 2, 3) 288 | ax.imshow(self.display_image, interpolation='None', vmin=0, vmax=255) 289 | ax.axvspan(y_min, 290 | y_max, 291 | 1. - float(x_max) / self.disp_xdim, 292 | 1. 
- float(x_min) / self.disp_xdim, 293 | color='red', 294 | alpha=0.3) 295 | ax.set_xlim([0, np.shape(self.display_image)[1]]) 296 | ax.tick_params(axis='both', # changes apply to the x-axis 297 | which='both', # both major and minor ticks are affected 298 | bottom=False, # ticks along the bottom edge are off 299 | top=False, # ticks along the top edge are off 300 | left=False, 301 | right=False, 302 | labelleft=False, 303 | labelbottom=False) 304 | ax.set_label('ax3') 305 | 306 | ax = self.fig.add_subplot(2, 2, 4, adjustable='datalim', frame_on=False) 307 | ax.text(0.5, 0.5, instructions, horizontalalignment='center', verticalalignment='center') 308 | ax.axis('off') 309 | 310 | # Updating the plots 311 | self.fig.canvas.draw() 312 | 313 | 314 | class DataManager: 315 | 316 | def __init__(self, available_images, tds_filename, username, im_type): 317 | 318 | # Image and segment data (populated in load_image()) 319 | self.original_image = None 320 | self.segmented_image = None 321 | 322 | # Variable Values (populated in load_training_data()) 323 | self.label_vector = [] 324 | self.segment_list = [] 325 | self.feature_matrix = [] 326 | self.tracker = 0 # Number of segment sets added from the current image 327 | self.im_index = 0 # Index for progressing through available images 328 | 329 | # Global Static Values 330 | self.tds_filename = tds_filename 331 | self.username = username 332 | self.im_type = im_type 333 | self.available_images = available_images 334 | 335 | # Image Static Value (populated in load_image()) 336 | self.wb_ref = None 337 | self.br_ref = None 338 | self.im_date = None 339 | self.im_name = None 340 | 341 | def load_next_image(self): 342 | # Increment the image index 343 | self.im_index += 1 344 | # Loop im_index based on the available number of images 345 | self.im_index = self.im_index % len(self.available_images) 346 | # Load the new data 347 | self._load_image() 348 | 349 | def load_previous_image(self): 350 | # If an image has already been 
loaded, and there is no previous data, 351 | # prevent the user from using this button. 352 | if self.get_num_labels() == 0 and self.im_name is not None: 353 | return 354 | 355 | # If labels exist find the correct image to load 356 | if self.get_num_labels() != 0: 357 | # If this does not find a match, im_index will default to its current value 358 | for i in range(len(self.available_images)): 359 | if self.get_current_segment()[0] in self.available_images[i]: 360 | self.im_index = i 361 | 362 | self._load_image() 363 | 364 | def _load_image(self): 365 | # Loads the optical and segmented image data from disk. Should only be called from 366 | # load_next_image method. 367 | full_image_name = self.available_images[self.im_index] 368 | 369 | self.im_name = os.path.splitext(os.path.split(full_image_name)[1])[0] 370 | 371 | src_ds = gdal.Open(full_image_name, gdal.GA_ReadOnly) 372 | 373 | # Read the image date from the metadata 374 | metadata = src_ds.GetMetadata() 375 | self.im_date = pp.parse_metadata(metadata, self.im_type) 376 | 377 | # Determine the datatype 378 | src_dtype = gdal.GetDataTypeSize(src_ds.GetRasterBand(1).DataType) 379 | 380 | # Calculate the reference points from the image histogram 381 | lower, upper, wb_ref, br_ref = pp.histogram_threshold(src_ds, src_dtype) 382 | self.wb_ref = np.array(wb_ref, dtype=c_uint8) 383 | self.br_ref = np.array(br_ref, dtype=c_uint8) 384 | 385 | # Load the image data 386 | image_data = src_ds.ReadAsArray() 387 | 388 | # Close the GDAL dataset 389 | src_ds = None 390 | 391 | # Rescale the input dataset using a histogram stretch 392 | image_data = pp.rescale_band(image_data, lower, upper) 393 | 394 | # Apply a white balance to the image 395 | image_data = pp.white_balance(image_data, self.wb_ref.astype(np.float), float(np.amax(self.wb_ref))) 396 | 397 | # Convert the input data to c_uint8 398 | self.original_image = np.ndarray.astype(image_data, c_uint8) 399 | 400 | print("Creating segments on provided image...") 401 | 
watershed_image = segment_image(image_data, image_type=self.im_type) 402 | # Convert the segmented image to c_int datatype. This is needed for the 403 | # Cython methods that calculate attribute of segments. 404 | self.segmented_image = np.ndarray.astype(watershed_image, c_uint32) 405 | # Clear these from memory explicitly 406 | image_data = None 407 | watershed_image = None 408 | 409 | def load_training_data(self): 410 | 411 | try: 412 | with h5py.File(self.tds_filename, 'r') as data_file: 413 | # Load the existing feature matrix and segment list if they exist, 414 | # otherwise initialize an empty array for these lists. 415 | if 'feature_matrix' in list(data_file.keys()): 416 | self.feature_matrix = data_file['feature_matrix'][:].tolist() 417 | else: 418 | self.feature_matrix = [] 419 | 420 | if 'segment_list' in list(data_file.keys()): 421 | # For loading files created in py2 422 | self.segment_list = [[name[0].decode(), name[1].decode()] for name in data_file['segment_list']] 423 | else: 424 | self.segment_list = [] 425 | 426 | # Determine if this user has data already stored in the training set. If so, 427 | # use the existing classifications. If not, start from the beginning. 428 | # must use .tolist() because datasets in h5py files are numpy arrays, and we want 429 | # these as python lists. 
430 | # [y1...yn] column vector where n : number of classified segments, y = classification 431 | if self.username in list(data_file.keys()): 432 | self.label_vector = data_file[self.username][:].tolist() 433 | else: 434 | self.label_vector = [] 435 | # If the file does not exist, create empty values 436 | except OSError: 437 | self.feature_matrix = [] 438 | self.segment_list = [] 439 | self.label_vector = [] 440 | 441 | def get_num_labels(self): 442 | return len(self.label_vector) 443 | 444 | def append_label(self, label): 445 | self.tracker += 1 446 | self.label_vector.append(label) 447 | 448 | # Removes the last entry from label_vector 449 | def remove_last_label(self): 450 | self.label_vector.pop() 451 | self.tracker -= 1 452 | 453 | def get_num_segments(self): 454 | return len(self.segment_list) 455 | 456 | # The current segment is the next one that doesn't have an associated label 457 | def get_current_segment(self): 458 | return self.segment_list[len(self.label_vector)] 459 | 460 | def add_single_segment(self, new_segment): 461 | self.segment_list.append(new_segment) 462 | 463 | # Trims all unclassified segments from segment_list by trimming it to 464 | # the length of label_vector 465 | def trim_segment_list(self): 466 | self.segment_list = self.segment_list[:len(self.label_vector)] 467 | 468 | # Add 10 randomly selected segments to the list of ones to classify 469 | def add_segments(self): 470 | segments_to_add = [] 471 | 472 | a = 0 473 | # Select random x,y coordinates from the input image, and pick the segment where the random 474 | # pixel lands. This makes the selected segments representative of the average surface 475 | # distribution within the image. This still wont work if the image has a minority of any 476 | # particular surface type. 
477 | while len(segments_to_add)<10: 478 | a += 1 479 | z, x, y = np.shape(self.original_image) 480 | i = np.random.randint(x) 481 | j = np.random.randint(y) 482 | # Find the segment label at the random pixel 483 | segment_id = self.segmented_image[i][j] 484 | sp_size = np.sum(self.segmented_image == segment_id) 485 | if sp_size >= 20: 486 | # Check for a duplicate segment already in the tds 487 | new_segment = [self.im_name, 488 | "{}".format(segment_id)] 489 | if new_segment not in self.segment_list and new_segment not in segments_to_add: 490 | segments_to_add.append(new_segment) 491 | 492 | print(("Attempts: {}".format(a))) 493 | self.segment_list += segments_to_add 494 | 495 | def compute_attributes(self, segment_id): 496 | # Create the a attribute list for the labeled segment 497 | feature_array = calc_attributes(self.original_image, self.segmented_image, 498 | self.wb_ref, self.br_ref, self.im_date, segment_id, self.im_type) 499 | 500 | # attribute_calculations returns a 2d array, but we only want the 1d list of features. 501 | feature_array = feature_array[0] 502 | 503 | return feature_array 504 | 505 | def append_features(self, feature_array): 506 | # If there are fewer features than labels, assume the new one should be appended 507 | # to the end 508 | if len(self.feature_matrix) == len(self.label_vector) - 1: 509 | #Adding all of the features found for this watershed to the main matrix 510 | self.feature_matrix.append(feature_array) 511 | # Otherwise replace the existing features with the newly calculated ones. 512 | # (Maybe just skip this in the future and assume they were calculated correctly before? 
513 | else: 514 | # old_feature_array = self.feature_matrix[len(self.label_vector) - 1] 515 | print("Recalculated Feature.") 516 | # print(("Old: {} {}".format(old_feature_array[0], old_feature_array[1]))) 517 | # print(("New: {} {}".format(feature_array[0], feature_array[1]))) 518 | self.feature_matrix[len(self.label_vector) - 1] = feature_array 519 | 520 | 521 | class EventManager: 522 | 523 | def __init__(self, parent): 524 | self.parent = parent 525 | self.is_active = False # Prevents events from happening while images are loading 526 | 527 | def activate(self): 528 | self.is_active = True 529 | 530 | def deactivate(self): 531 | self.is_active = False 532 | 533 | def next_segment(self): 534 | if not self.is_active: 535 | return 536 | 537 | # If all of the segments in the predefined list have been classified already, 538 | # present the user with a random new segment. 539 | if self.parent.data.get_num_labels() == self.parent.data.get_num_segments(): 540 | 541 | # I think if segment_list == [] is covered by the above..? 
542 | 543 | self.parent.data.add_segments() 544 | 545 | # retrain the random forest model if the live predictor is active 546 | if self.parent.live_predictor.is_active(): 547 | self.parent.live_predictor.retrain_model(self.parent.data.feature_matrix, 548 | self.parent.data.label_vector) 549 | 550 | # The current segment is the next one that doesn't have an associated label 551 | current_segment = self.parent.data.get_current_segment() 552 | segment_id = int(current_segment[1]) 553 | 554 | # Redraw the display with the new segment id 555 | self.parent.image_display.update_images(segment_id) 556 | 557 | def previous_segment(self): 558 | if not self.is_active: 559 | return 560 | # Make sure this function returns null if there is no previous sp to go back to 561 | if self.parent.data.get_num_labels() == 0: 562 | return 563 | else: 564 | # Delete the last label in the list, then get the 'new' current segment 565 | self.parent.data.remove_last_label() 566 | current_segment = self.parent.data.get_current_segment() 567 | self.parent.progress_bar.update_progress() 568 | 569 | if current_segment[0] != self.parent.data.im_name: 570 | self.previous_image() 571 | return 572 | 573 | segment_id = int(current_segment[1]) 574 | # Redraw the display with the new segment id 575 | self.parent.image_display.update_images(segment_id) 576 | 577 | def onclick(self, event): 578 | if not self.is_active: 579 | return 580 | 581 | if event.inaxes is not None: 582 | axes_properties = event.inaxes.properties() 583 | segment_id = -1 584 | x, y = 0, 0 585 | 586 | # If the mouse click was in the overview image 587 | if axes_properties['label'] == 'ax3': 588 | x = int(event.xdata) 589 | y = int(event.ydata) 590 | segment_id = self.parent.data.segmented_image[y, x] 591 | 592 | # Either of the top zoomed windows 593 | if axes_properties['label'] == 'ax1' or axes_properties['label'] == 'ax2': 594 | win_x = int(event.xdata) 595 | win_y = int(event.ydata) 596 | x = self.parent.image_display.zoom_win_x + 
win_x 597 | y = self.parent.image_display.zoom_win_y + win_y 598 | segment_id = self.parent.data.segmented_image[y, x] 599 | 600 | # If user clicked on a valid location, add the segment that was clicked on to segment_list, 601 | # then update the image render. 602 | if segment_id >= 0: 603 | print(("You clicked at ({}, {}) in {}".format(x, y, axes_properties['label']))) 604 | print(("Segment id: {}".format(segment_id))) 605 | new_segment = [self.parent.data.im_name, 606 | "{}".format(segment_id)] 607 | if new_segment not in self.parent.data.segment_list: 608 | # Trim all unclassified segments 609 | self.parent.data.trim_segment_list() 610 | # Add the selected one as the next segment 611 | self.parent.data.add_single_segment(new_segment) 612 | # Get the new current segment and redraw display 613 | segment_id = int(self.parent.data.get_current_segment()[1]) 614 | self.parent.image_display.update_images(segment_id) 615 | else: 616 | print("This segment has already been labeled") 617 | 618 | def classify(self, key_press): 619 | if not self.is_active: 620 | return 621 | 622 | # Assigning the highlighted segment a classification 623 | segment_id = int(self.parent.data.get_current_segment()[1]) 624 | print("Segment ID: {}".format(segment_id)) 625 | # Note that we classified one more image 626 | if key_press == "snow": 627 | self.parent.data.append_label(1) 628 | elif key_press == "gray": 629 | self.parent.data.append_label(2) 630 | elif key_press == "melt": 631 | self.parent.data.append_label(3) 632 | elif key_press == "water": 633 | self.parent.data.append_label(4) 634 | elif key_press == "shadow": 635 | self.parent.data.append_label(5) 636 | elif key_press == "unknown": 637 | self.parent.data.append_label(6) 638 | 639 | # Calculate the attributes for the current segment 640 | feature_array = self.parent.data.compute_attributes(segment_id) 641 | self.parent.data.append_features(feature_array) 642 | 643 | # Printing some useful statistics 644 | print("Assigned value: {} 
({})".format(str(self.parent.data.label_vector[-1]), key_press)) 645 | 646 | if self.parent.live_predictor.is_active(): 647 | self.parent.live_predictor.print_prediction(feature_array) 648 | 649 | print(("~"*80)) 650 | 651 | self.parent.progress_bar.update_progress() 652 | 653 | self.next_segment() 654 | 655 | # if len(self.feature_matrix) == len(self.label_vector)-1: 656 | # #Adding all of the features found for this watershed to the main matrix 657 | # self.feature_matrix.append(feature_array) 658 | # else: 659 | # old_feature_array = self.feature_matrix[len(self.label_vector)-1] 660 | # print("Recalculated Feature.") 661 | # print(("Old: {} {}".format(old_feature_array[0],old_feature_array[1]))) 662 | # print(("New: {} {}".format(feature_array[0], feature_array[1]))) 663 | # self.feature_matrix[len(self.label_vector)-1] = feature_array 664 | 665 | def autorun(self): 666 | if not self.is_active: 667 | return 668 | 669 | # In the future make this function a standalone window (instead of terminal output)?? 670 | # Prevent the user from accessing this if the predictor is inactive 671 | if not self.parent.live_predictor.is_active(): 672 | print("Autorun functionality disabled") 673 | return 674 | 675 | # segment_id = int(self.segment_list[len(self.label_vector):][0][1]) 676 | segment_id = int(self.parent.data.get_current_segment()[1]) 677 | 678 | # Create the a attribute list for the labeled segment 679 | feature_array = self.parent.data.compute_attributes(segment_id) 680 | 681 | # feature_array = calc_attributes(self.original_image, self.secondary_image, 682 | # self.wb_ref, self.br_ref, self.im_date, segment_id, self.im_type) 683 | 684 | print("~" * 80) 685 | # This both prints the results of the prediction for the user to check, and also returns the 686 | # predicted values for use here. 
687 | pred, proba = self.parent.live_predictor.print_prediction(feature_array) 688 | if 0.90 < proba < 0.96: 689 | timeout = 4 #6 690 | print((PrintColor.BOLD + "Label if incorrect:" + PrintColor.END)) 691 | elif proba < .9: 692 | timeout = 10 #12 693 | print((PrintColor.BOLD + PrintColor.RED + "Label if incorrect:" + PrintColor.END)) 694 | else: 695 | timeout = 0.5 696 | 697 | # Prompt the user to change the classification if they dont agree with the 698 | # predicted one. If no input is recieved, the predicted one is assumed to be correct. 699 | rlist, _, _ = select([sys.stdin], [], [], timeout) 700 | if rlist: 701 | s = sys.stdin.readline() 702 | try: 703 | s = int(s) 704 | except ValueError: 705 | print("Ending autorun.") 706 | return 707 | if 0 <= s < 6: 708 | label = s 709 | print(("Assigning label {} instead.".format(label))) 710 | else: 711 | print("Ending autorun.") 712 | return 713 | else: 714 | label = pred 715 | print(("No input. Assigning label: {}".format(label))) 716 | 717 | self.parent.data.append_label(label) 718 | self.parent.data.append_features(feature_array) 719 | 720 | self.parent.progress_bar.update_progress() 721 | 722 | self.next_segment() 723 | self.parent.after(100, self.autorun) 724 | 725 | def save(self): 726 | 727 | if self.parent.data.label_vector == []: 728 | return 729 | 730 | print("Saving...") 731 | 732 | username = self.parent.data.username 733 | 734 | prev_names = [] 735 | prev_data = [] 736 | try: 737 | with h5py.File(self.parent.data.tds_filename, 'r') as infile: 738 | # Compiles all of the user data that was in the previous training validation file so that 739 | # it can be added to the new file as well. 
(Because erasing and recreating a .h5 is easier 740 | # than altering an existing one) 741 | for prev_user in list(infile.keys()): 742 | if prev_user != 'feature_matrix' and prev_user != 'segment_list' and prev_user != username: 743 | prev_names.append(prev_user) 744 | prev_data.append(infile[prev_user][:]) 745 | infile.close() 746 | except OSError: 747 | pass 748 | 749 | # overwrite the h5 dataset with the updated information 750 | with h5py.File(self.parent.data.tds_filename, 'w') as outfile: 751 | outfile.create_dataset('feature_matrix', data=self.parent.data.feature_matrix) 752 | outfile.create_dataset(username, data=self.parent.data.label_vector) 753 | segment_list = np.array(self.parent.data.segment_list, dtype=np.string_) 754 | outfile.create_dataset('segment_list', data=segment_list) 755 | 756 | for i in range(len(prev_names)): 757 | outfile.create_dataset(prev_names[i], data=prev_data[i]) 758 | 759 | print("Done.") 760 | 761 | def next_image(self): 762 | if not self.is_active: 763 | return 764 | 765 | self.deactivate() 766 | # Trim the unlabeled segments from segment list 767 | self.parent.data.trim_segment_list() 768 | # Save the existing data 769 | self.save() 770 | # Set the display to the loading screen 771 | self.parent.after(10, self.parent.image_display.loading_display()) 772 | # Load the next image data 773 | self.parent.data.load_next_image() 774 | # Add the new data to the display class 775 | self.parent.image_display.initialize_image() 776 | # Update the display screen 777 | # Go to the next segment (which will add additional segments to the queue and update the display) 778 | self.parent.progress_bar.update_progress() 779 | self.activate() 780 | self.next_segment() 781 | 782 | def previous_image(self): 783 | 784 | self.deactivate() 785 | # Set the display to the loading screen 786 | self.parent.after(10, self.parent.image_display.loading_display()) 787 | # Load the previous image data 788 | self.parent.data.load_previous_image() 789 | # Add the 
new data to the display class 790 | self.parent.image_display.initialize_image() 791 | # Update the display screen 792 | # Go to the next segment (which will add additional segments to the queue and update the display) 793 | self.parent.progress_bar.update_progress() 794 | self.activate() 795 | self.next_segment() 796 | 797 | def initialize_image(self): 798 | if len(self.parent.data.available_images) == 0: 799 | print("No images to load!") 800 | return 801 | # Check to make sure no data has been loaded 802 | if self.parent.data.im_name is not None: 803 | return 804 | # Previous image does all the loading work we need for the first image 805 | self.previous_image() 806 | 807 | 808 | def quit_event(self): 809 | # Exits the GUI, automatically saves progress 810 | self.save() 811 | self.parent.exit_gui() 812 | 813 | 814 | class LivePredictor: 815 | 816 | def __init__(self, active_state): 817 | self.active_state = active_state 818 | self.is_trained = False 819 | self.rfc = RandomForestClassifier(n_estimators=100) 820 | 821 | # True if LivePredictor is running, false otherwise 822 | def is_active(self): 823 | return self.active_state 824 | 825 | def retrain_model(self, feature_matrix, label_vector): 826 | if len(label_vector) >= 10: 827 | self.rfc.fit(feature_matrix[:len(label_vector)], label_vector) 828 | self.is_trained = True 829 | 830 | def print_prediction(self, feature_array): 831 | if self.is_trained: 832 | pred = self.rfc.predict(feature_array.reshape(1, -1))[0] 833 | pred_prob = self.rfc.predict_proba(feature_array.reshape(1, -1))[0] 834 | pred_prob = np.amax(pred_prob) 835 | print(("Predicted value: {}{}{} ({})".format(PrintColor.PURPLE, pred, PrintColor.END, pred_prob))) 836 | return pred, pred_prob 837 | else: 838 | return 0, 0 839 | 840 | 841 | class TrainingWindow(tk.Frame): 842 | 843 | def __init__(self, parent, img_list, tds_filename, username, im_type, activate_autorun=False): 844 | 845 | tk.Frame.__init__(self, parent) 846 | self.parent = parent 847 | 
self.parent.title("Training GUI") 848 | 849 | # Create the controlling buttons and place them on the right side. 850 | self.buttons = Buttons(self) 851 | self.buttons.grid(column=1, row=1, sticky="N") 852 | 853 | # Manager for all the GUI events (e.g. button presses) 854 | self.event_manager = EventManager(self) 855 | 856 | # Data manager object 857 | self.data = DataManager(img_list, tds_filename, username, im_type) 858 | self.data.load_training_data() 859 | 860 | # Create the image display window 861 | self.image_display = ImageDisplay(self) 862 | self.image_display.grid(column=0, row=0, rowspan=2) 863 | 864 | self.progress_bar = ProgressBar(self) 865 | self.progress_bar.grid(column=1, row=0) 866 | 867 | self.progress_bar.update_progress() 868 | 869 | # Object for creating on the fly predictions and managing the auto_run method 870 | self.live_predictor = LivePredictor(activate_autorun) 871 | 872 | # Define keybindings 873 | self.parent.bind('1', lambda e: self.event_manager.classify("snow")) 874 | self.parent.bind('2', lambda e: self.event_manager.classify("gray")) 875 | self.parent.bind('3', lambda e: self.event_manager.classify("melt")) 876 | self.parent.bind('4', lambda e: self.event_manager.classify("water")) 877 | self.parent.bind('5', lambda e: self.event_manager.classify("shadow")) 878 | self.parent.bind('', lambda e: self.event_manager.classify("unknown")) 879 | self.parent.bind('', lambda e: self.event_manager.previous_segment()) 880 | 881 | 882 | def exit_gui(self): 883 | self.parent.quit() 884 | self.parent.destroy() 885 | 886 | 887 | def calc_attributes(original_image, secondary_image, 888 | wb_ref, br_ref, im_date, segment_id, im_type): 889 | feature_array = [] 890 | 891 | if im_type == 'pan': 892 | feature_array = attr_calc.analyze_pan_image(original_image, 893 | secondary_image, 894 | im_date, 895 | segment_id=segment_id) 896 | if im_type == 'srgb': 897 | feature_array = attr_calc.analyze_srgb_image(original_image, 898 | secondary_image, 899 | 
segment_id=segment_id) 900 | if im_type == 'wv02_ms': 901 | feature_array = attr_calc.analyze_ms_image(original_image, 902 | secondary_image, 903 | wb_ref, 904 | br_ref, 905 | segment_id=segment_id) 906 | return feature_array 907 | 908 | 909 | # Returns all of the unique images in segment_list 910 | def get_required_images(segment_list): 911 | image_list = [] 912 | for seg_id in segment_list: 913 | if not seg_id[0] in image_list: 914 | image_list.append(seg_id[0]) 915 | return image_list 916 | 917 | 918 | def validate_tds_file(tds_filename, input_dir, image_type): 919 | 920 | # Set the default tds filename if this was not entered 921 | if tds_filename is None: 922 | tds_filename = os.path.join(input_dir, image_type + "_training_data.h5") 923 | elif os.path.isfile(tds_filename): 924 | # If a real file was given, try opening it. 925 | try: 926 | data_file = h5py.File(tds_filename, 'r') 927 | data_file.close() 928 | except OSError: 929 | print("Invalid data file.") 930 | quit() 931 | 932 | return tds_filename 933 | 934 | 935 | # Finds all the unique images from the given directory 936 | def scrape_dir(src_dir): 937 | image_list = [] 938 | 939 | for ext in utils.valid_extensions: 940 | raw_list = utils.get_image_paths(src_dir, keyword=ext) 941 | for raw_im in raw_list: 942 | image_list.append(raw_im) 943 | 944 | # Save only the unique entries 945 | image_list = list(set(image_list)) 946 | utils.remove_hidden(image_list) 947 | 948 | return image_list 949 | 950 | 951 | if __name__ == "__main__": 952 | 953 | #### Set Up Arguments 954 | parser = argparse.ArgumentParser() 955 | parser.add_argument("input", 956 | help="folder containing training images") 957 | parser.add_argument("image_type", type=str, choices=['srgb','wv02_ms','pan'], 958 | help="image type: 'srgb', 'wv02_ms', 'pan'") 959 | parser.add_argument("--tds_file", type=str, default=None, 960 | help='''Existing training dataset file. Will create a new one with this name if none exists. 
961 | default: _training_data.h5''') 962 | parser.add_argument("--username", type=str, default=None, 963 | help='''username to associate with the training set. 964 | default: image_type''') 965 | parser.add_argument("-a", "--enable_autorun", action="store_true", 966 | help='''Enables the use of the autorun function.''') 967 | 968 | # Parse Arguments 969 | args = parser.parse_args() 970 | input_dir = os.path.abspath(args.input) 971 | image_type = args.image_type 972 | autorun_flag = args.enable_autorun 973 | 974 | # Add the images in the provided folder to the image list 975 | img_list = scrape_dir(input_dir) 976 | 977 | tds_file = validate_tds_file(args.tds_file, input_dir, image_type) 978 | 979 | if args.username is None: 980 | user_name = image_type 981 | else: 982 | user_name = args.username 983 | 984 | root = tk.Tk() 985 | TrainingWindow(root, img_list, tds_file, user_name, image_type, 986 | activate_autorun=autorun_flag).pack(side='top', fill='both', expand=True) 987 | root.mainloop() 988 | --------------------------------------------------------------------------------