├── .gitignore
├── README.md
├── checkpoints
│   └── readme.txt
├── data_tools
│   ├── create_tfrecords.py
│   └── download_dataset.py
├── deepwatermap.py
├── inference.py
├── metrics.py
├── sample_data
│   ├── readme.txt
│   └── sample_input_output.png
└── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DeepWaterMap
DeepWaterMap is a deep convolutional neural network trained to segment surface water in multispectral imagery. This repository contains the resources for the most recent version of DeepWaterMap (v2.0).

![Sample input and output](sample_data/sample_input_output.png)

## Papers
If you find our code, models, or dataset useful in your research, please cite our papers:

* L. F. Isikdogan, A. C. Bovik, and P. Passalacqua, "Seeing Through the Clouds with DeepWaterMap," *IEEE Geoscience and Remote Sensing Letters*, 2019. [[**Read at IEEE Xplore**]](https://ieeexplore.ieee.org/document/8913594), [[**PDF**]](http://www.isikdogan.com/files/isikdogan2019_deepwatermap_v2.pdf)
* F. Isikdogan, A. C. Bovik, and P. Passalacqua, "Surface Water Mapping by Deep Learning," *IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing*, 2017. [[**Read at IEEE Xplore**]](https://ieeexplore.ieee.org/document/8013683/), [[**PDF**]](http://www.isikdogan.com/files/isikdogan2017_deepwatermap.pdf)

## Dependencies
* TensorFlow (tested on TensorFlow 1.12)
* NumPy
* Tifffile (for reading GeoTIFF files)
* OpenCV (for reading and writing images)

The dependencies can be installed using the Python package installer (pip):

```
pip install tensorflow==1.12.0 tifffile opencv-python
```

## Running inference

You can use our inference script to generate a surface water map given a multispectral image as follows:
```
$ python inference.py --checkpoint_path checkpoints/cp.135.ckpt \
    --image_path sample_data/sentinel2_example.tif --save_path water_map.png
```

You can download the checkpoint that contains our trained model parameters at:
[checkpoints.zip](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/565662752887).

You can find sample input images at: [sample_data.zip](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/565677626152).

If you receive an out-of-memory error, try running inference on the CPU (for example, by setting the environment variable `CUDA_VISIBLE_DEVICES=-1`, as noted at the top of `inference.py`) or on a smaller input image. Running inference on a full-sized Landsat image on the CPU may take a few minutes.

DeepWaterMap does not require the input images to be Landsat-7, Landsat-8, or Sentinel-2 images. However, it expects the input bands to approximately match the Landsat bands listed below.

```
B2: Blue
B3: Green
B4: Red
B5: Near Infrared (NIR)
B6: Shortwave Infrared 1 (SWIR1)
B7: Shortwave Infrared 2 (SWIR2)
```

See [https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites](https://www.usgs.gov/faqs/what-are-band-designations-landsat-satellites) for more information about those bands.

If you are using images acquired by a sensor other than Landsat, try to match the bands above as closely as possible and make sure the input bands are in the correct order. The model is robust against shifts in the spectral responses of different sensors, so the bands do not need to match perfectly.

The inference script we provide reads its input from a 6-band TIFF file. You can modify the script to feed the model an M×N×6 array assembled in any way; for example, you can read the input bands from separate files and concatenate them along the channel axis, as in the sketch below.
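A minimal sketch of that approach, assuming the bands are stored as single-band files (the filenames here are hypothetical placeholders):

```
import numpy as np
import tifffile as tiff

# Hypothetical per-band files; substitute the paths to your own data.
band_files = ['B2.tif', 'B3.tif', 'B4.tif', 'B5.tif', 'B6.tif', 'B7.tif']
bands = [tiff.imread(f) for f in band_files]  # each band is an M x N array
image = np.stack(bands, axis=-1)              # M x N x 6, in B2..B7 band order
```

The stacked array can then go through the same padding and normalization steps used in `inference.py` before calling `model.predict`.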
## Dataset

> You do not need to download the dataset to use DeepWaterMap to run inference on images. You can use the [checkpoint](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/565662752887) we provide instead. We made this dataset available for researchers who wish to experiment with it.

We provide both the original GeoTIFF files and the compressed TFRecords dataset that we used to train DeepWaterMap.

If you wish to re-train our model, or train another model on our dataset, download our TFRecords dataset and run ```trainer.py```.
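A typical invocation looks like this (the data path below is a placeholder; `trainer.py` expects the `train_*.tfrecord` and `test_*.tfrecord` shards to be in the directory passed as `--data_path`):

```
$ python trainer.py --checkpoint_path checkpoints/ --data_path /path/to/tfrecords
```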
[Download TFRecords (~205GB)](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/folder/94459511962)

You can find the original GeoTIFF images in the following Box folder.

[Original GeoTIFF files (~1TB)](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/folder/94459536870)

The GeoTIFF dataset is very large. If you need to find and download a particular tile, you can find an index of the tiles in shapefile format in [metadata.zip](https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/564393935179).

## Differences from the paper

The dataset we provide here has more samples than mentioned in the paper. The additional samples in the dataset have heavier cloud coverage. The checkpoint we provide is the result of training DeepWaterMap on the entire dataset for 135 epochs, including the additional, more challenging samples.

--------------------------------------------------------------------------------
/checkpoints/readme.txt:
--------------------------------------------------------------------------------
You can download our most recent checkpoint from:
https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/565662752887
--------------------------------------------------------------------------------
/data_tools/create_tfrecords.py:
--------------------------------------------------------------------------------
''' Creates TFRecords given GeoTIFF files.
We provide a copy of the dataset in TFRecords format.
You should not need this script unless you modify the dataset.
'''

import os, glob
import argparse
import random
import math
import tifffile as tiff
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _create_tfexample(B2, B3, B4, B5, B6, B7, label):
    example = tf.train.Example(features=tf.train.Features(feature={
        'B2': _bytes_feature(B2),
        'B3': _bytes_feature(B3),
        'B4': _bytes_feature(B4),
        'B5': _bytes_feature(B5),
        'B6': _bytes_feature(B6),
        'B7': _bytes_feature(B7),
        'L': _bytes_feature(label)
    }))
    return example

def preprocess_and_encode_sample(data_tensor):
    # the last channel is the label; the rest are the input bands
    image = data_tensor[..., :-1]
    label = data_tensor[..., -1]

    # min-max normalize the bands and quantize them to 8 bits
    image = tf.cast(image, tf.float32)
    image = image - tf.reduce_min(image)
    image = image / tf.maximum(tf.reduce_max(image), 1)
    image = image * 255

    image = tf.cast(image, tf.uint8)
    label = tf.cast(label, tf.uint8)

    # PNG-encode each band and the label separately
    B2 = tf.image.encode_png(image[..., 0, None])
    B3 = tf.image.encode_png(image[..., 1, None])
    B4 = tf.image.encode_png(image[..., 2, None])
    B5 = tf.image.encode_png(image[..., 3, None])
    B6 = tf.image.encode_png(image[..., 4, None])
    B7 = tf.image.encode_png(image[..., 5, None])
    L = tf.image.encode_png(label[..., None])
    return [B2, B3, B4, B5, B6, B7, L]

def create_tfrecords(save_dir, dataset_name, filenames, images_per_shard):
    data_placeholder = tf.placeholder(tf.uint16)
    processed_bands = preprocess_and_encode_sample(data_placeholder)

    with tf.Session() as sess:
        num_shards = math.ceil(len(filenames) / images_per_shard)
        for shard in range(num_shards):
            output_filename = os.path.join(save_dir, '{}_{:03d}-of-{:03d}.tfrecord'
                                           .format(dataset_name, shard, num_shards))
            print('Writing into {}'.format(output_filename))
            filenames_shard = filenames[shard*images_per_shard:(shard+1)*images_per_shard]

            with tf.io.TFRecordWriter(output_filename) as tfrecord_writer:
                for filename in filenames_shard:
                    data = tiff.imread(filename)
                    B2, B3, B4, B5, B6, B7, L = sess.run(processed_bands, feed_dict={data_placeholder: data})
                    example = _create_tfexample(B2, B3, B4, B5, B6, B7, L)
                    tfrecord_writer.write(example.SerializeToString())

    print('Finished writing {} images into TFRecords'.format(len(filenames)))

def main(args):
    path = os.path.join(args.input_dir, '**/*.tif')
    # recursive=True is required for the '**' pattern to match subdirectories
    filenames = glob.glob(path, recursive=True)

    random.seed(args.seed)
    random.shuffle(filenames)

    num_test = args.num_test_images

    # create TFRecords for the training and test sets
    create_tfrecords(save_dir=args.output_dir,
                     dataset_name='train',
                     filenames=filenames[num_test:],
                     images_per_shard=args.images_per_shard)
    create_tfrecords(save_dir=args.output_dir,
                     dataset_name='test',
                     filenames=filenames[:num_test],
                     images_per_shard=args.images_per_shard)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='E:/global_water_dataset',
                        help='path to the directory where the images will be read from')
    parser.add_argument('--output_dir', type=str, default='E:/tfrecords',
                        help='path to the directory where the TFRecords will be saved to')
    parser.add_argument('--images_per_shard', type=int, default=5000,
                        help='number of images per shard')
    parser.add_argument('--num_test_images', type=int, default=5000,
                        help='number of images in the test set')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for repeatable train/test splits')
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
/data_tools/download_dataset.py:
--------------------------------------------------------------------------------
''' Script to generate and download the dataset using the Google Earth Engine.
You should never need to use this script since we provide a copy of the dataset.
It takes over a month to finish processing the entire dataset using this script.
The script is included in the repository for archival purposes.
'''

import ee
import time

ee.Initialize()

# Select tiles
valid_tiles = ee.FeatureCollection("users/isikdogan/valid_tiles_filtered")
valid_tiles = valid_tiles.filter(ee.Filter.gt('occurrence', 1.0))
tile_list = valid_tiles.toList(valid_tiles.size())

# Create the dataset by matching inputs and outputs
date_start = '2015-01-01'
date_end = '2015-12-31'
input_bands = ee.ImageCollection('LANDSAT/LC08/C01/T1') \
                .filterDate(date_start, date_end).median() \
                .select(['B2', 'B3', 'B4', 'B5', 'B6', 'B7']) \
                .uint16()
labels = ee.ImageCollection('JRC/GSW1_0/YearlyHistory') \
           .filter(ee.Filter.date(date_start, date_end)) \
           .select('waterClass').first().uint16()
dataset = input_bands.addBands(labels)

def download_tile(i, tile_list, save_folder, savepath):
    current_tile = tile_list.get(i)
    tile_geometry = ee.Feature(current_tile).geometry().getInfo()["coordinates"]
    task = ee.batch.Export.image.toDrive(
        image=dataset,
        description=savepath,
        folder=save_folder,
        fileNamePrefix=savepath,
        region=tile_geometry,
        scale=30)
    task.start()

# Iterate and download
num_tiles = valid_tiles.size().getInfo()
subsample_ratio = 1
for i in range(0, num_tiles, subsample_ratio):
    savepath = "tile_{}".format(i)
    save_folder = 'tiles_data_{}'.format((i//10000) * 10000)
    try:
        download_tile(i, tile_list, save_folder, savepath)
    except Exception as e:
        print(e)
        print("Capacity reached, waiting...")
        time.sleep(1200)
        download_tile(i, tile_list, save_folder, savepath)
    print("Exporting {} ({} / {})".format(savepath, i, num_tiles))
    time.sleep(10)
--------------------------------------------------------------------------------
/deepwatermap.py:
--------------------------------------------------------------------------------
''' Implementation of DeepWaterMapV2.

The model architecture is explained in:
L.F. Isikdogan, A.C. Bovik, and P. Passalacqua,
"Seeing Through the Clouds with DeepWaterMap," IEEE GRSL, 2019.
'''

import tensorflow as tf

def model(min_width=4):
    inputs = tf.keras.layers.Input(shape=[None, None, 6])

    def conv_block(x, num_filters, kernel_size, stride=1, use_relu=True):
        x = tf.keras.layers.Conv2D(
            filters=num_filters,
            kernel_size=kernel_size,
            kernel_initializer='he_uniform',
            strides=stride,
            padding='same',
            use_bias=False)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        if use_relu:
            x = tf.keras.layers.Activation('relu')(x)
        return x

    def downscaling_unit(x):
        num_filters = int(x.get_shape()[-1]) * 4
        x_1 = conv_block(x, num_filters, kernel_size=5, stride=2)
        x_2 = conv_block(x_1, num_filters, kernel_size=3, stride=1)
        x = tf.keras.layers.Add()([x_1, x_2])
        return x

    def upscaling_unit(x):
        num_filters = int(x.get_shape()[-1]) // 4
        x = tf.keras.layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(x)
        x_1 = conv_block(x, num_filters, kernel_size=3)
        x_2 = conv_block(x_1, num_filters, kernel_size=3)
        x = tf.keras.layers.Add()([x_1, x_2])
        return x

    def bottleneck_unit(x):
        num_filters = int(x.get_shape()[-1])
        x_1 = conv_block(x, num_filters, kernel_size=3)
        x_2 = conv_block(x_1, num_filters, kernel_size=3)
        x = tf.keras.layers.Add()([x_1, x_2])
        return x

    # model flow
    skip_connections = []
    num_filters = min_width

    # first layer
    x = conv_block(inputs, num_filters, kernel_size=1, use_relu=False)
    skip_connections.append(x)

    # encoder
    for i in range(4):
        x = downscaling_unit(x)
        skip_connections.append(x)

    # bottleneck
    x = bottleneck_unit(x)

    # decoder
    for i in range(4):
        x = tf.keras.layers.Add()([x, skip_connections.pop()])
        x = upscaling_unit(x)

    # last layer
    x = tf.keras.layers.Add()([x, skip_connections.pop()])
    x = conv_block(x, 1, kernel_size=1, use_relu=False)
    x = tf.keras.layers.Activation('sigmoid')(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    return model
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
''' Runs inference on a given GeoTIFF image.

example:
$ python inference.py --checkpoint_path checkpoints/cp.135.ckpt \
    --image_path sample_data/sentinel2_example.tif --save_path water_map.png
'''

# Uncomment this to run inference on CPU if your GPU runs out of memory
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import argparse
import deepwatermap
import tifffile as tiff
import numpy as np
import cv2

def find_padding(v, divisor=32):
    # compute symmetric padding that rounds a dimension up to a multiple of the divisor
    v_divisible = max(divisor, int(divisor * np.ceil(v / divisor)))
    total_pad = v_divisible - v
    pad_1 = total_pad // 2
    pad_2 = total_pad - pad_1
    return pad_1, pad_2

def main(checkpoint_path, image_path, save_path):
    # load the model
    model = deepwatermap.model()
    model.load_weights(checkpoint_path)

    # load and preprocess the input image
    image = tiff.imread(image_path)
    pad_r = find_padding(image.shape[0])
    pad_c = find_padding(image.shape[1])
    image = np.pad(image, ((pad_r[0], pad_r[1]), (pad_c[0], pad_c[1]), (0, 0)), 'reflect')

    # solve no-pad index issue after inference
    if pad_r[1] == 0:
        pad_r = (pad_r[0], 1)
    if pad_c[1] == 0:
        pad_c = (pad_c[0], 1)

    image = image.astype(np.float32)

    # remove nans (and infinity) - replace with 0s
    image = np.nan_to_num(image, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

    image = image - np.min(image)
    image = image / np.maximum(np.max(image), 1)

    # run inference
    image = np.expand_dims(image, axis=0)
    dwm = model.predict(image)
    dwm = np.squeeze(dwm)
    dwm = dwm[pad_r[0]:-pad_r[1], pad_c[0]:-pad_c[1]]

    # soft threshold
    dwm = 1./(1+np.exp(-(16*(dwm-0.5))))
    dwm = np.clip(dwm, 0, 1)

    # save the output water map
    cv2.imwrite(save_path, (dwm * 255).astype(np.uint8))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str,
                        help="Path to the checkpoint prefix (e.g., checkpoints/cp.135.ckpt)")
    parser.add_argument('--image_path', type=str, help="Path to the input GeoTIFF image")
    parser.add_argument('--save_path', type=str, help="Path where the output map will be saved")
    args = parser.parse_args()
    main(args.checkpoint_path, args.image_path, args.save_path)
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
''' This script defines custom metrics and loss functions.

The Adaptive Max-Pool Loss acts as a weighting function that multiplies each
pixel's loss by the maximum loss value within an NxN neighborhood.
An earlier version of this loss function is described in:

F. Isikdogan, A.C. Bovik, and P. Passalacqua,
"Learning a River Network Extractor using an Adaptive Loss Function,"
IEEE Geoscience and Remote Sensing Letters, 2018.
'''

import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np

def running_recall(y_true, y_pred):
    # recall = TP / (TP + FN), accumulated over the batch
    TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    TP_FN = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = TP / (TP_FN + K.epsilon())
    return recall

def running_precision(y_true, y_pred):
    # precision = TP / (TP + FP), accumulated over the batch
    TP = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    TP_FP = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = TP / (TP_FP + K.epsilon())
    return precision

def running_f1(y_true, y_pred):
    precision = running_precision(y_true, y_pred)
    recall = running_recall(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))

def adaptive_maxpool_loss(y_true, y_pred, alpha=0.25):
    y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
    # alpha-weighted binary cross entropy, computed pointwise
    positive = -y_true * K.log(y_pred) * alpha
    negative = -(1. - y_true) * K.log(1. - y_pred) * (1-alpha)
    pointwise_loss = positive + negative
    # weight each pixel's loss by the maximum loss within its 8x8 neighborhood
    max_loss = tf.keras.layers.MaxPool2D(pool_size=8, strides=1, padding='same')(pointwise_loss)
    x = pointwise_loss * max_loss
    x = K.mean(x, axis=-1)
    return x
--------------------------------------------------------------------------------
/sample_data/readme.txt:
--------------------------------------------------------------------------------
You can find some sample input images at:
https://utexas.app.box.com/s/j9ymvdkaq36tk04be680mbmlaju08zkq/file/565677626152
--------------------------------------------------------------------------------
/sample_data/sample_input_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isikdogan/deepwatermap/f02bd76e1bf3f502746a13add185a66d097730b2/sample_data/sample_input_output.png
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
''' Trains a DeepWaterMap model. We provide a copy of the trained checkpoints.
You should not need this script unless you want to re-train the model.
'''

import os, glob
import argparse
import tensorflow as tf
import deepwatermap
from metrics import running_precision, running_recall, running_f1
from metrics import adaptive_maxpool_loss

class TFModelTrainer:
    def __init__(self, checkpoint_dir, data_path):
        self.checkpoint_dir = checkpoint_dir

        # set training parameters
        self.image_size = (512, 512)
        self.learning_rate = 0.1
        self.num_epoch = 150
        self.batch_size = 24

        # create the data generators
        train_filenames = glob.glob(os.path.join(data_path, 'train_*.tfrecord'))
        val_filenames = glob.glob(os.path.join(data_path, 'test_*.tfrecord'))

        self.dataset_train = self._data_layer(train_filenames)
        self.dataset_val = self._data_layer(val_filenames)

        self.dataset_train_size = 137682
        self.dataset_val_size = 5000
        self.steps_per_epoch = self.dataset_train_size // self.batch_size
        self.validation_steps = self.dataset_val_size // self.batch_size

    def _data_layer(self, filenames, num_threads=24):
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(self._parse_tfrecord, num_parallel_calls=num_threads)
        dataset = dataset.repeat()
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        dataset = dataset.prefetch(buffer_size=4)
        return dataset

    def _parse_tfrecord(self, example_proto):
        keys_to_features = {'B2': tf.io.FixedLenFeature([], tf.string),
                            'B3': tf.io.FixedLenFeature([], tf.string),
                            'B4': tf.io.FixedLenFeature([], tf.string),
                            'B5': tf.io.FixedLenFeature([], tf.string),
                            'B6': tf.io.FixedLenFeature([], tf.string),
                            'B7': tf.io.FixedLenFeature([], tf.string),
                            'L': tf.io.FixedLenFeature([], tf.string)}
        F = tf.io.parse_single_example(example_proto, keys_to_features)
        data = F['B2'], F['B3'], F['B4'], F['B5'], F['B6'], F['B7'], F['L']
        image, label = self._decode_images(data)
        return image, label

    def _decode_images(self, data_strings):
        bands = [[]] * len(data_strings)
        for i in range(len(data_strings)):
            bands[i] = tf.image.decode_png(data_strings[i])
        data = tf.concat(bands, -1)
        data = tf.image.random_crop(data, size=[self.image_size[0], self.image_size[1], len(data_strings)])
        data = tf.cast(data, tf.float32)
        image = data[..., :-1] / 255
        label = data[..., -1, None] / 3
        # apply augmentations (channel mixing, noise) and normalization
        image = self._preprocess_images(image)
        return image, label

    def _preprocess_images(self, image):
        image = self._random_channel_mixing(image)
        image = self._gaussian_noise(image)
        image = self._normalize_image(image)
        return image

    def _random_channel_mixing(self, image):
        # build a random 6x6 color-correction matrix that mixes each band with its neighbors
        ccm = tf.eye(6)[None, :, :, None]
        r = tf.random.uniform([3], maxval=0.25) + [0, 1, 0]
        filter = r[None, :, None, None]
        ccm = tf.nn.depthwise_conv2d(ccm, filter, strides=[1,1,1,1], padding='SAME', data_format='NHWC')
        ccm = tf.squeeze(ccm)
        image = tf.tensordot(image, ccm, (-1, 0))
        return image

    def _gaussian_noise(self, image):
        # add zero-mean Gaussian noise with a randomly drawn standard deviation
        r = tf.random.uniform((), maxval=0.04)
        image = image + tf.random.normal([self.image_size[0], self.image_size[1], 6], stddev=r)
        return image

    def _normalize_image(self, image):
        image = tf.cast(image, tf.float32)
        image = image - tf.reduce_min(image)
        image = image / tf.maximum(tf.reduce_max(image), 1)
        return image
    def _optimizer(self):
        optimizer = tf.keras.optimizers.SGD(lr=self.learning_rate, momentum=0.9)
        return optimizer

    def train(self):
        # Callbacks
        cp_callback = tf.keras.callbacks.ModelCheckpoint(os.path.join(self.checkpoint_dir, 'cp.{epoch:03d}.ckpt'),
                                                         save_weights_only=True)
        tb_callback = tf.keras.callbacks.TensorBoard(log_dir=self.checkpoint_dir)
        lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=1)

        # Model
        model = deepwatermap.model()

        initial_epoch = 0
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            model.load_weights(ckpt.model_checkpoint_path)
            print("Loaded weights from", ckpt.model_checkpoint_path)
            initial_epoch = int(ckpt.model_checkpoint_path.split('.')[-2])

        model.compile(optimizer=self._optimizer(),
                      loss=adaptive_maxpool_loss,
                      metrics=[tf.keras.metrics.binary_accuracy,
                               running_precision, running_recall, running_f1])
        model.fit(self.dataset_train,
                  validation_data=self.dataset_val,
                  epochs=self.num_epoch,
                  initial_epoch=initial_epoch,
                  steps_per_epoch=self.steps_per_epoch,
                  validation_steps=self.validation_steps,
                  callbacks=[cp_callback, tb_callback, lr_callback])

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, default='./checkpoints/',
                        help="Path to the dir where the checkpoints are saved")
    parser.add_argument('--data_path', type=str,
                        help="Path to the tfrecord files")
    args = parser.parse_args()
    trainer = TFModelTrainer(args.checkpoint_path, args.data_path)
    trainer.train()

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------