├── Fig
│   ├── HY_1.png
│   ├── HY_2.png
│   ├── HY_GT.png
│   ├── WH_1.png
│   ├── WH_2.png
│   ├── WH_GT.png
│   ├── SiamCRNN.jpg
│   ├── Buffalo_GT.bmp
│   ├── Buffalo_T1.png
│   └── Buffalo_T2.png
├── FCN_version
│   ├── dataset
│   │   ├── OSCD
│   │   │   └── original_data
│   │   │       ├── test.txt
│   │   │       └── train.txt
│   │   ├── make_data_loader.py
│   │   └── imutils.py
│   ├── deep_networks
│   │   ├── sync_batchnorm
│   │   │   ├── __pycache__
│   │   │   │   ├── comm.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── batchnorm.cpython-36.pyc
│   │   │   │   └── replicate.cpython-36.pyc
│   │   │   ├── __init__.py
│   │   │   ├── unittest.py
│   │   │   ├── replicate.py
│   │   │   ├── comm.py
│   │   │   └── batchnorm.py
│   │   ├── resnet_18_34.py
│   │   └── SiamCRNN.py
│   ├── README.md
│   ├── util_func
│   │   ├── metrics.py
│   │   └── lovasz_loss.py
│   └── script
│       └── train_siamcrnn.py
├── LICENSE
├── util
│   ├── data_prepro.py
│   └── net_util.py
├── infer.py
├── SiamCRNN.py
├── README.md
└── train.py
--------------------------------------------------------------------------------
/Fig/HY_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_1.png
--------------------------------------------------------------------------------
/Fig/HY_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_2.png
--------------------------------------------------------------------------------
/Fig/HY_GT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_GT.png
--------------------------------------------------------------------------------
/Fig/WH_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_1.png
--------------------------------------------------------------------------------
/Fig/WH_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_2.png
--------------------------------------------------------------------------------
/Fig/WH_GT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_GT.png
--------------------------------------------------------------------------------
/Fig/SiamCRNN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/SiamCRNN.jpg
--------------------------------------------------------------------------------
/Fig/Buffalo_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_GT.bmp
--------------------------------------------------------------------------------
/Fig/Buffalo_T1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_T1.png
--------------------------------------------------------------------------------
/Fig/Buffalo_T2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_T2.png
--------------------------------------------------------------------------------
/FCN_version/dataset/OSCD/original_data/test.txt:
-------------------------------------------------------------------------------- 1 | brasilia 2 | montpellier 3 | norcia 4 | rio 5 | saclay_w 6 | valencia 7 | dubai 8 | lasvegas 9 | milano 10 | chongqing -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc -------------------------------------------------------------------------------- /FCN_version/dataset/OSCD/original_data/train.txt: -------------------------------------------------------------------------------- 1 | aguasclaras 2 | bercy 3 | bordeaux 4 | nantes 5 | paris 6 | rennes 7 | saclay_e 8 | abudhabi 9 | cupertino 10 | pisa 11 | beihai 12 | hongkong 13 | beirut 14 | mumbai -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 I-Hope-Peace 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 

--------------------------------------------------------------------------------
/util/data_prepro.py:
--------------------------------------------------------------------------------
import numpy as np


def norm_img(img, channel_first=True):
    if channel_first:
        channel, img_height, img_width = img.shape
        img = np.reshape(img, (channel, img_height * img_width))  # (channel, height * width)
        max_value = np.max(img, axis=1, keepdims=True)  # (channel, 1)
        min_value = np.min(img, axis=1, keepdims=True)  # (channel, 1)
        diff_value = max_value - min_value
        nm_img = (img - min_value) / diff_value
        nm_img = np.reshape(nm_img, (channel, img_height, img_width))
    else:
        img_height, img_width, channel = img.shape
        img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
        max_value = np.max(img, axis=0, keepdims=True)  # (1, channel)
        min_value = np.min(img, axis=0, keepdims=True)  # (1, channel)
        diff_value = max_value - min_value
        nm_img = (img - min_value) / diff_value
        nm_img = np.reshape(nm_img, (img_height, img_width, channel))
    return nm_img


def stad_img(img, channel_first=True, get_para=False):
    """
    Standardize an image band-wise to zero mean and unit variance.
    :param channel_first: True if img is (C, H, W), False if (H, W, C)
    :param img: (C, H, W) or (H, W, C)
    :param get_para: if True, also return the per-band mean and std
    :return:
        std_img: same shape as img
    """
    if channel_first:
        channel, img_height, img_width = img.shape
        img = np.reshape(img, (channel, img_height * img_width))  # (channel, height * width)
        mean = np.mean(img, axis=1, keepdims=True)  # (channel, 1)
        center = img - mean  # (channel, height * width)
        var = np.sum(np.power(center, 2), axis=1, keepdims=True) / (img_height * img_width)  # (channel, 1)
        std = np.sqrt(var)  # (channel, 1)
        std_img = center / std  # (channel, height * width)
        std_img = np.reshape(std_img, (channel, img_height, img_width))
    else:
        img_height, img_width, channel = img.shape
        img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
        mean = np.mean(img, axis=0, keepdims=True)  # (1, channel)
        center = img - mean  # (height * width, channel)
        var = np.sum(np.power(center, 2), axis=0, keepdims=True) / (img_height * img_width)  # (1, channel)
        std = np.sqrt(var)  # (1, channel)
        std_img = center / std  # (height * width, channel)
        std_img = np.reshape(std_img, (img_height, img_width, channel))
    print('mean is ', mean)
    print('std is ', std)
    if get_para:
        return std_img, mean, std
    else:
        return std_img
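

if __name__ == '__main__':
    # Editor's sketch (not in the original file): a quick self-check of the two
    # normalizers above on a fake 4-band (C, H, W) image.
    demo = np.random.rand(4, 32, 32).astype(np.float32)
    nm = norm_img(demo)
    print(nm.min(), nm.max())  # close to 0.0 and 1.0 after per-band min-max scaling
    std_demo, mean, std = stad_img(demo, get_para=True)
    print(std_demo.shape, mean.shape, std.shape)  # (4, 32, 32) (4, 1) (4, 1)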
img_Y 28 | 29 | 30 | def infer_result(): 31 | patch_sz = 5 32 | batch_size = 1000 33 | 34 | img_X, img_Y = load_data() 35 | img_X = np.pad(img_X, ((2, 2), (2, 2), (0, 0)), 'constant') 36 | img_Y = np.pad(img_Y, ((2, 2), (2, 2), (0, 0)), 'constant') 37 | img_height, img_width, channel = img_X.shape # image width 38 | 39 | edge = patch_sz // 2 40 | sample_X = [] 41 | sample_Y = [] 42 | for i in range(edge, img_height - edge): 43 | for j in range(edge, img_width - edge): 44 | sample_X.append(img_X[i - edge:i + edge + 1, j - edge:j + edge + 1, :]) 45 | sample_Y.append(img_Y[i - edge:i + edge + 1, j - edge:j + edge + 1, :]) 46 | sample_X = np.array(sample_X) 47 | sample_Y = np.array(sample_Y) 48 | 49 | epoch = sample_X.shape[0] // batch_size 50 | 51 | Input_X = tf.placeholder(dtype=tf.float32, shape=[None, patch_sz, patch_sz, channel], name='Input_X') 52 | Input_Y = tf.placeholder(dtype=tf.float32, shape=[None, patch_sz, patch_sz, channel], name='Input_Y') 53 | is_training = tf.placeholder(dtype=tf.bool, name='is_training') 54 | 55 | model_path = 'model_param' 56 | model = SiamCRNN() 57 | net, result = model.get_model(Input_X=Input_X, Input_Y=Input_Y, data_format='NHWC', is_training=is_training) 58 | config = tf.ConfigProto() 59 | config.gpu_options.allow_growth = True 60 | saver = tf.train.Saver() 61 | logits_result_list = [] 62 | pred_results_list = [] 63 | path = None 64 | 65 | with tf.Session() as sess: 66 | ckpt = tf.train.get_checkpoint_state(model_path) 67 | if ckpt and ckpt.model_checkpoint_path: 68 | path = ckpt.model_checkpoint_path 69 | saver.restore(sess, ckpt.model_checkpoint_path) 70 | tic = time.time() 71 | for _epoch in range(100): 72 | pred_result = sess.run([result], feed_dict={ 73 | Input_X: sample_X[batch_size * _epoch:batch_size * (_epoch + 1)], 74 | Input_Y: sample_Y[batch_size * _epoch:batch_size * (_epoch + 1)], 75 | is_training: False 76 | }) 77 | pred_results_list.append(pred_result) 78 | 79 | pred = np.reshape(pred_results_list, (img_height, img_width)) 80 | 81 | idx_1 = (pred <= 0.5) 82 | idx_2 = (pred > 0.5) 83 | pred[idx_1] = 0 84 | pred[idx_2] = 255 85 | cv.imwrite('SiamCRNN.bmp', pred) 86 | 87 | 88 | 89 | if __name__ == '__main__': 90 | infer_result() -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
    Note that, as all copies are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.
    We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.
    A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.
    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/FCN_version/README.md:
--------------------------------------------------------------------------------

SiamCRNN in Fully Convolutional Version


Hongruixuan Chen, Chen Wu, Bo Du, Liangpei Zhang, and Le Wang

This is an implementation of the fully convolutional version of the **SiamCRNN** framework from our IEEE TGRS 2020 paper: [Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network](https://ieeexplore.ieee.org/document/8937755).

## Introduction
We have improved SiamCRNN from our [original paper](https://ieeexplore.ieee.org/document/8937755) so that it can be applied to large-scale change detection tasks. The entire architecture is fully convolutional: the encoder is an arbitrary fully convolutional deep network (we use ResNet here) and the decoder is a multilayer ConvLSTM + FPN.
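For readers new to ConvLSTM, the sketch below is an editor's illustration, not the repository code (the actual decoder lives in `FCN_version/deep_networks/SiamCRNN.py` and may differ): the gates are computed by a single convolution over the concatenated input and hidden state, and the bitemporal encoder features are consumed as a length-2 sequence.

```python
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    # One ConvLSTM step: all four gates come from a shared convolution
    # over the concatenated input and hidden state.
    def __init__(self, in_ch, hid_ch, kernel_size=3):
        super().__init__()
        self.gates = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, kernel_size,
                               padding=kernel_size // 2)

    def forward(self, x, state):
        h, c = state  # hidden / cell state, each (B, hid_ch, H, W)
        i, f, o, g = torch.chunk(self.gates(torch.cat([x, h], dim=1)), 4, dim=1)
        c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c

# Bitemporal use: feed the T1 and T2 encoder features as a length-2 sequence.
cell = ConvLSTMCell(in_ch=64, hid_ch=32)
f_t1, f_t2 = torch.randn(2, 4, 64, 32, 32)  # fake encoder features
h = c = torch.zeros(4, 32, 32, 32)          # zero initial state
for f in (f_t1, f_t2):
    h, c = cell(f, (h, c))                  # h now summarizes the temporal change
```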
## Get started
### Requirements
```
python==3.8.18
pytorch==1.12.1
torchvision==0.13.1
imageio==2.22.4
numpy==1.14.0
tqdm==4.64.1
```

### Dataset
The fully convolutional version of SiamCRNN can be trained and tested on arbitrary large-scale change detection benchmark datasets, such as [SYSU](https://github.com/liumency/SYSU-CD), [LEVIR-CD](https://chenhao.in/LEVIR/), etc. Here we provide an example on the [OSCD dataset](https://rcdaudt.github.io/oscd/).

### Training
```
python train_siamcrnn.py
```

### Testing
```
python test_siamcrnn.py
```

### Detection results
| Method | Rec | Pre | OA | F1 | IoU | KC |
|:-------------:|:---------------:|:--------------------:|:---------------:|:--------------------:|:--------------------:|:--------------------:|
| FC-EF | 0.4612 | 0.4967 | 0.9480 | 0.4783 | 0.3143 | 0.4510 |
| FC-Siam-Conc | 0.5362 | 0.4760 | 0.9455 | 0.5043 | 0.3372 | 0.4757 |
| FC-Siam-Diff | 0.5385 | 0.4391 | 0.9406 | 0.4838 | 0.3191 | 0.4526 |
| [DSIFN](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images) | 0.4538 | 0.6732 | 0.9604 | 0.5421 | 0.3718 | 0.5222 |
| [ChangeFormer](https://github.com/wgcban/ChangeFormer) | 0.5351 | 0.5566 | 0.9540 | 0.5457 | 0.3752 | 0.5215 |
| **SiamCRNN (ResNet18)** | **0.5160** | **0.5758** | **0.9553** | **0.5442** | **0.3739** | **0.5208** |

## Citation
If this code or dataset contributes to your research, please consider citing our paper. We appreciate your support!🙂
```
@article{Chen2020Change,
author = {Chen, Hongruixuan and Wu, Chen and Du, Bo and Zhang, Liangpei and Wang, Le},
issn = {0196-2892},
journal = {IEEE Transactions on Geoscience and Remote Sensing},
number = {4},
pages = {2848--2864},
title = {{Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network}},
volume = {58},
year = {2020}
}
```

## Q & A
**For any questions, please [contact us.](mailto:Qschrx@gmail.com)**
--------------------------------------------------------------------------------
/FCN_version/util_func/metrics.py:
--------------------------------------------------------------------------------
import numpy as np


class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,) * 2, dtype=np.float64)

    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = np.diag(self.confusion_matrix) / (self.confusion_matrix.sum(axis=1) + 1e-7)
        mAcc = np.nanmean(Acc)
        return mAcc, Acc

    def Pixel_Precision_Rate(self):
        assert self.confusion_matrix.shape[0] == 2
        Pre = self.confusion_matrix[1, 1] / (self.confusion_matrix[0, 1] + self.confusion_matrix[1, 1])
        return Pre

    def Pixel_Recall_Rate(self):
        assert self.confusion_matrix.shape[0] == 2
        Rec = self.confusion_matrix[1, 1] / (self.confusion_matrix[1, 0] + self.confusion_matrix[1, 1])
        return Rec

    def Pixel_F1_score(self):
        assert self.confusion_matrix.shape[0] == 2
        Rec = self.Pixel_Recall_Rate()
        Pre = self.Pixel_Precision_Rate()
        F1 = 2 * Rec * Pre / (Rec + Pre)
        return F1

    def Mean_Intersection_over_Union(self):
        MIoU = np.diag(self.confusion_matrix) / (
                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                np.diag(self.confusion_matrix) + 1e-7)
        MIoU = np.nanmean(MIoU)
        return MIoU

    def Intersection_over_Union(self):
        IoU = self.confusion_matrix[1, 1] / (
                self.confusion_matrix[0, 1] + self.confusion_matrix[1, 0] + self.confusion_matrix[1, 1])
        return IoU

    def Kappa_coefficient(self):
        num_total = np.sum(self.confusion_matrix, dtype=np.float64)
        # Observed agreement: proportion of correctly classified pixels
        observed_accuracy = np.trace(self.confusion_matrix, dtype=np.float64) / num_total
        # Expected agreement from the row / column marginals
        expected_accuracy = np.sum(
            np.sum(self.confusion_matrix, axis=0, dtype=np.float64) / num_total * np.sum(self.confusion_matrix, axis=1,
                                                                                         dtype=np.float64) / num_total)

        # Calculate Cohen's kappa
        kappa = (observed_accuracy - expected_accuracy) / (1 - expected_accuracy)
        return kappa
    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
                np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
                np.diag(self.confusion_matrix))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int64') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class ** 2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
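

if __name__ == '__main__':
    # Editor's sketch (not in the original file): binary-change evaluation on a
    # tiny dummy prediction; by hand, Rec = 2/3, Pre = 1.0, so F1 = 0.8.
    evaluator = Evaluator(num_class=2)
    gt = np.array([[0, 1], [1, 1]])
    pred = np.array([[0, 1], [0, 1]])
    evaluator.add_batch(gt, pred)
    print(evaluator.Pixel_F1_score(), evaluator.Kappa_coefficient())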
--------------------------------------------------------------------------------
/SiamCRNN.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from util.net_util import conv_2d, max_pool_2d, avg_pool_2d, fully_connected


class SiamCRNN(object):

    def get_model(self, Input_X, Input_Y, data_format='NHWC', is_training=True):
        net_X = self._feature_extract_layer(inputs=Input_X, name='Fea_Ext_',
                                            data_format=data_format,
                                            is_training=is_training)
        net_Y = self._feature_extract_layer(inputs=Input_Y, name='Fea_Ext_',
                                            data_format=data_format,
                                            is_training=is_training, is_reuse=True)

        fea_1 = tf.squeeze(net_X, axis=1)  # (B, 1, 64)
        fea_2 = tf.squeeze(net_Y, axis=1)  # (B, 1, 64)

        logits, pred = self._change_judge_layer(feature_1=fea_1, feature_2=fea_2, name='Cha_Jud_',
                                                is_training=is_training)
        return logits, pred

    def _feature_extract_layer(self, inputs, name='Fea_Ext_', data_format='NHWC', is_training=True, is_reuse=False):
        with tf.variable_scope(name) as scope:
            if is_reuse:
                scope.reuse_variables()
            # (B, H, W, C) --> (B, H, W, 16)
            layer_1 = conv_2d(inputs=inputs, kernel_size=[3, 3], output_channel=16, stride=[1, 1], name='layer_1_conv',
                              padding='SAME', data_format=data_format, is_training=is_training, is_bn=False,
                              activation=tf.nn.relu)
            layer_2 = conv_2d(inputs=layer_1, kernel_size=[3, 3], output_channel=16, stride=[1, 1], name='layer_2_conv',
                              padding='SAME', data_format=data_format, is_training=is_training, is_bn=False,
                              activation=tf.nn.relu)

            layer_2 = tf.contrib.layers.dropout(inputs=layer_2, is_training=is_training, keep_prob=0.8)

            # (B, H, W, 16) --> (B, H, W, 32)
            layer_3 = conv_2d(inputs=layer_2, kernel_size=[3, 3], output_channel=32, stride=[1, 1], padding='SAME',
                              name='layer_3_conv', data_format=data_format, is_training=is_training, is_bn=False,
                              activation=tf.nn.relu)
            layer_4 = conv_2d(inputs=layer_3, kernel_size=[3, 3], output_channel=32, stride=[1, 1], padding='SAME',
                              name='layer_4_conv', data_format=data_format, is_training=is_training, is_bn=False,
                              activation=tf.nn.relu)
            layer_4 = tf.contrib.layers.dropout(inputs=layer_4, is_training=is_training, keep_prob=0.7)

            # (B, H, W, 32) --> (B, H, W, 64)
            layer_5 = conv_2d(inputs=layer_4, kernel_size=[3, 3], output_channel=64, stride=[1, 1], padding='SAME',
                              name='layer_5_conv', data_format=data_format, is_training=is_training, is_bn=False,
                              activation=tf.nn.relu)
            # the 5x5 VALID conv collapses each 5x5 patch into one feature vector: (B, 1, 1, 64)
            net = conv_2d(inputs=layer_5, kernel_size=[5, 5], output_channel=64, stride=[1, 1], padding='VALID',
                          name='layer_6_conv', data_format=data_format, is_training=is_training, is_bn=False,
                          activation=tf.nn.relu)
            net = tf.contrib.layers.dropout(inputs=net, is_training=is_training, keep_prob=0.5)
            return net

    def _change_judge_layer(self, feature_1, feature_2, name='Cha_Jud_', is_training=True,
                            activation=tf.nn.sigmoid):
        with tf.variable_scope(name) as scope:
            seq = tf.concat([feature_1, feature_2], axis=1)  # (B, 2, 64): a length-2 temporal sequence
            num_units = [128, 64]
            cells = [tf.nn.rnn_cell.LSTMCell(num_unit, activation=tf.nn.tanh) for num_unit in num_units]
            mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
            output, cell_state = tf.nn.dynamic_rnn(mul_cells, seq, dtype=tf.float32, time_major=False)
            hidden_state = tf.contrib.layers.dropout(inputs=output[:, -1], is_training=is_training, keep_prob=0.5)
            logits_0 = fully_connected(hidden_state, num_outputs=32, is_training=is_training, is_bn=False,
                                       activation=tf.nn.tanh)
            logits = fully_connected(logits_0, num_outputs=1, is_training=is_training, is_bn=False)
            pred = activation(logits)
            return logits, pred
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network


Hongruixuan Chen, Chen Wu, Bo Du, Liangpei Zhang, and Le Wang

This is an official implementation of the **SiamCRNN** framework from our IEEE TGRS 2020 paper: [Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network](https://ieeexplore.ieee.org/document/8937755).


## Note
**2024.01.18**
- Wanna go beyond the patch-wise detection pipeline and apply SiamCRNN to large-scale change detection datasets? We have updated the [fully convolutional version of SiamCRNN](https://github.com/ChenHongruixuan/SiamCRNN/tree/master/FCN_version). Please feel free to test it on benchmark datasets.

**2023.04.25**
- The datasets Wuhan and Hanyang used in our paper have been open-sourced! You can download them [here](http://sigma.whu.edu.cn/resource.php).


## Abstract
> With the rapid development of Earth observation technology, very-high-resolution (VHR) images from various satellite sensors are more available, which greatly enrich the data source of change detection (CD). Multisource multitemporal images can provide abundant information on observed landscapes with various physical and material views, and it is exigent to develop efficient techniques to utilize these multisource data for CD. In this article, we propose a novel and general deep siamese convolutional multiple-layers recurrent neural network (RNN) (SiamCRNN) for CD in multitemporal VHR images. Superior to most VHR image CD methods, SiamCRNN can be used for both homogeneous and heterogeneous images. Integrating the merits of both convolutional neural network (CNN) and RNN, Siam-CRNN consists of three subnetworks: deep siamese convolutional neural network (DSCNN), multiple-layers RNN (MRNN), and fully connected (FC) layers. The DSCNN has a flexible structure for multisource image and is able to extract spatial–spectral features from homogeneous or heterogeneous VHR image patches. The MRNN stacked by long-short term memory (LSTM) units is responsible for mapping the spatial–spectral features extracted by DSCNN into a new latent feature space and mining the change information between them. In addition, FC, the last part of SiamCRNN, is adopted to predict change probability. The experimental results in two homogeneous data sets and one challenging heterogeneous VHR images data set demonstrate that the promising performances of the proposed network outperform several state-of-the-art approaches.
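As a quick orientation before the architecture figure, here is a minimal usage sketch (an editor's illustration; the 4-band sensor and the 5×5 patch size are assumptions taken from `infer.py`) of how the three subnetworks are wired together through `SiamCRNN.get_model`:

```python
import tensorflow as tf
from SiamCRNN import SiamCRNN

Input_X = tf.placeholder(tf.float32, [None, 5, 5, 4], name='Input_X')  # T1 patches
Input_Y = tf.placeholder(tf.float32, [None, 5, 5, 4], name='Input_Y')  # T2 patches
is_training = tf.placeholder(tf.bool, name='is_training')

# DSCNN extracts features from both patches, MRNN mines the change
# information, and the FC head outputs a change probability per patch.
logits, pred = SiamCRNN().get_model(Input_X=Input_X, Input_Y=Input_Y,
                                    data_format='NHWC', is_training=is_training)
```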
## Network architecture

![Overall architecture of SiamCRNN](Fig/SiamCRNN.jpg)


## Requirements
```
tensorflow_gpu==1.9.0
opencv==3.4.0
numpy==1.14.0
```

## Dataset
Two homogeneous datasets, Wuhan and Hanyang, and one heterogeneous dataset, Buffalo, are used in our work. The Wuhan and Hanyang datasets can be downloaded [here](http://sigma.whu.edu.cn/resource.php). For the Buffalo dataset, please request distribution from Prof. [Chen Wu](mailto:chen.wu@whu.edu.cn). A minimal loading sketch follows the table below.

| Dataset | Pre-event image | Post-event image | Reference Image |
| :----: | :----: | :----: | :----: |
| Wuhan | ![](Fig/WH_1.png) | ![](Fig/WH_2.png) | ![](Fig/WH_GT.png) |
| Hanyang | ![](Fig/HY_1.png) | ![](Fig/HY_2.png) | ![](Fig/HY_GT.png) |
| Buffalo | ![](Fig/Buffalo_T1.png) | ![](Fig/Buffalo_T2.png) | ![](Fig/Buffalo_GT.bmp) |
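Whichever dataset is used, both images of a pair are read with GDAL and standardized band-wise before patches are extracted. A minimal loading helper in the style of `train.py` / `infer.py` (the function name and paths are hypothetical):

```python
import gdal  # or `from osgeo import gdal` on newer GDAL builds
import numpy as np
from util.data_prepro import stad_img

def load_pair(path_t1, path_t2):
    ds_1, ds_2 = gdal.Open(path_t1), gdal.Open(path_t2)
    img_1 = np.transpose(stad_img(ds_1.ReadAsArray()), (1, 2, 0))  # (H, W, C)
    img_2 = np.transpose(stad_img(ds_2.ReadAsArray()), (1, 2, 0))  # (H, W, C)
    return img_1, img_2
```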
## Citation
If this code or dataset contributes to your research, please consider citing our paper. We appreciate your support!🙂
```
@article{Chen2020Change,
author = {Chen, Hongruixuan and Wu, Chen and Du, Bo and Zhang, Liangpei and Wang, Le},
issn = {0196-2892},
journal = {IEEE Transactions on Geoscience and Remote Sensing},
number = {4},
pages = {2848--2864},
title = {{Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network}},
volume = {58},
year = {2020}
}
```

## Q & A
**For any questions, please [contact us.](mailto:Qschrx@gmail.com)**
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result has not been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.
    - During replication, as data parallelism will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.
        Args:
            identifier: an identifier, usually the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).
        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
--------------------------------------------------------------------------------
/FCN_version/script/train_siamcrnn.py:
--------------------------------------------------------------------------------
import sys
sys.path.append('/home/songjian/project/HSIFM')
import argparse
import os
import time

import imageio
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset.make_data_loader import OSCDDatset3Bands, make_data_loader, OSCDDatset13Bands
from util_func.metrics import Evaluator
from deep_networks.SiamCRNN import SiamCRNN
import util_func.lovasz_loss as L


class Trainer(object):
    def __init__(self, args):
        self.args = args

        self.train_data_loader = make_data_loader(args)
        print(args.model_type + ' is running')
        self.evaluator = Evaluator(num_class=2)

        self.deep_model = SiamCRNN(in_dim_1=13, in_dim_2=13)
        self.deep_model = self.deep_model.cuda()

        self.model_save_path = os.path.join(args.model_param_path, args.dataset,
                                            args.model_type + '_' + str(time.time()))
        self.lr = args.learning_rate
        self.epoch = args.max_iters // args.batch_size

        if not os.path.exists(self.model_save_path):
            os.makedirs(self.model_save_path)

        self.optim = optim.AdamW(self.deep_model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    def training(self):
        best_kc = 0.0
        best_round = []
        torch.cuda.empty_cache()
        self.deep_model.train()
        class_weight = torch.FloatTensor([1, 10]).cuda()
        elem_num = len(self.train_data_loader)
        train_enumerator = enumerate(self.train_data_loader)
        for _ in tqdm(range(elem_num)):
            itera, data = train_enumerator.__next__()
            self.optim.zero_grad()

            pre_img, post_img, bcd_labels, _ = data

            pre_img = pre_img.cuda().float()
            post_img = post_img.cuda().float()
            bcd_labels = bcd_labels.cuda().long()

            bcd_output = self.deep_model(pre_img, post_img)

            bcd_loss = F.cross_entropy(bcd_output, bcd_labels, weight=class_weight, ignore_index=255)
            lovasz_loss = L.lovasz_softmax(F.softmax(bcd_output, dim=1), bcd_labels, ignore=255)

            main_loss = bcd_loss + 0.75 * lovasz_loss
            main_loss.backward()

            self.optim.step()

            if (itera + 1) % 10 == 0:
                print(f'iter is {itera + 1}, change detection loss is {bcd_loss}')
            if (itera + 1) % 200 == 0:
                self.deep_model.eval()
                rec, pre, oa, f1_score, iou, kc = self.validation()
                if kc > best_kc:
                    torch.save(self.deep_model.state_dict(),
                               os.path.join(self.model_save_path, f'{itera + 1}_model.pth'))
                    best_kc = kc
                    best_round = [rec, pre, oa, f1_score, iou, kc]
                self.deep_model.train()

        print('The accuracy of the best round is ', best_round)

    def validation(self):
        print('---------starting evaluation-----------')
        self.evaluator.reset()
        dataset_path = '/home/songjian/project/HSIFM/dataset/OSCD/original_data'
        with open('/home/songjian/project/HSIFM/dataset/OSCD/original_data/test.txt', "r") as f:
            data_name_list = [data_name.strip() for data_name in f]
        dataset = OSCDDatset13Bands(dataset_path=dataset_path, data_list=data_name_list, crop_size=512,
                                    max_iters=None, type='test')
        val_data_loader = DataLoader(dataset, batch_size=1, num_workers=8, drop_last=False)
        torch.cuda.empty_cache()

        for itera, data in enumerate(val_data_loader):
            pre_img, post_img, bcd_labels, data_name = data

            pre_img = pre_img.cuda().float()
            post_img = post_img.cuda().float()
            bcd_labels = bcd_labels.cuda().long()

            bcd_output = self.deep_model(pre_img, post_img)
            bcd_output = bcd_output.data.cpu().numpy()
            bcd_output = np.argmax(bcd_output, axis=1)

            bcd_img = bcd_output[0].copy()
            bcd_img[bcd_img == 1] = 255

            # imageio.imwrite('./' + data_name[0] + '.png', bcd_img)

            bcd_labels = bcd_labels.cpu().numpy()
            self.evaluator.add_batch(bcd_labels, bcd_output)

        f1_score = self.evaluator.Pixel_F1_score()
        oa = self.evaluator.Pixel_Accuracy()
        rec = self.evaluator.Pixel_Recall_Rate()
        pre = self.evaluator.Pixel_Precision_Rate()
        iou = self.evaluator.Intersection_over_Union()
        kc = self.evaluator.Kappa_coefficient()
        print(f'Recall rate is {rec}, Precision rate is {pre}, OA is {oa}, '
              f'F1 score is {f1_score}, IoU is {iou}, Kappa coefficient is {kc}')
        return rec, pre, oa, f1_score, iou, kc


def main():
    parser = argparse.ArgumentParser(description="Training SiamCRNN on the OSCD dataset")
    parser.add_argument('--dataset', type=str, default='OSCD_13Bands')
    parser.add_argument('--dataset_path', type=str,
                        default='/home/songjian/project/HSIFM/dataset/OSCD/original_data')
    parser.add_argument('--type', type=str, default='train')
    parser.add_argument('--train_data_list_path', type=str,
                        default='/home/songjian/project/HSIFM/dataset/OSCD/original_data/train.txt')
    parser.add_argument('--shuffle', type=bool, default=True)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--data_name_list', type=list)
    parser.add_argument('--start_iter', type=int, default=0)
    parser.add_argument('--cuda', type=bool, default=True)
    parser.add_argument('--crop_size', type=int, default=256)
    parser.add_argument('--max_iters', type=int, default=100000)
    parser.add_argument('--model_type', type=str, default='SiamCRNN')
    parser.add_argument('--model_param_path', type=str, default='../saved_models')

    parser.add_argument('--resume', type=str)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)

    args = parser.parse_args()
    with open(args.train_data_list_path, "r") as f:
        data_name_list = [data_name.strip() for data_name in f]
    args.data_name_list = data_name_list

    trainer = Trainer(args)
    trainer.training()


if __name__ == "__main__":
    main()
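

# Editor's sketch (not part of the original script): how the combined objective in
# Trainer.training() -- class-weighted cross-entropy plus 0.75x Lovasz-softmax --
# comes together on dummy tensors. Never called by the training pipeline.
def _combined_loss_sketch():
    logits = torch.randn(4, 2, 64, 64)         # (B, classes, H, W) network output
    labels = torch.randint(0, 2, (4, 64, 64))  # binary change labels
    class_weight = torch.FloatTensor([1, 10])  # up-weight the rare change class
    bcd_loss = F.cross_entropy(logits, labels, weight=class_weight, ignore_index=255)
    lovasz = L.lovasz_softmax(F.softmax(logits, dim=1), labels, ignore=255)
    return bcd_loss + 0.75 * lovasz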
-------------------------------------------------------------------------------- /FCN_version/dataset/make_data_loader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import imageio 5 | import numpy as np 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import Dataset 9 | 10 | import dataset.imutils as imutils 11 | 12 | band_idx = ['B01.tif', 'B02.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B8A.tif', 13 | 'B09.tif', 'B10.tif', 'B11.tif', 'B12.tif'] 14 | 15 | 16 | def img_loader(path): 17 | img = np.array(imageio.imread(path), np.float32) 18 | return img 19 | 20 | 21 | def sentinel_loader(path): 22 | band_0_img = np.array(imageio.imread(os.path.join(path, 'B01.tif')), np.float32) 23 | ms_data = np.zeros((band_0_img.shape[0], band_0_img.shape[1], 13)) 24 | for i, band in enumerate(band_idx): 25 | ms_data[:, :, i] = np.array(imageio.imread(os.path.join(path, band)), np.float32) 26 | 27 | return ms_data 28 | 29 | 30 | def one_hot_encoding(image, num_classes=8): 31 | # Create a one hot encoded tensor 32 | one_hot = np.eye(num_classes)[image.astype(np.uint8)] 33 | 34 | # Move the channel axis to the front 35 | # one_hot = np.moveaxis(one_hot, -1, 0) 36 | 37 | return one_hot 38 | 39 | 40 | class OSCDDatset3Bands(Dataset): 41 | def __init__(self, dataset_path, data_list, crop_size, max_iters=None, type='train', data_loader=img_loader): 42 | self.dataset_path = dataset_path 43 | self.data_list = data_list 44 | self.loader = data_loader 45 | self.type = type 46 | self.data_pro_type = self.type 47 | 48 | if max_iters is not None: 49 | self.data_list = self.data_list * int(np.ceil(float(max_iters) / len(self.data_list))) 50 | self.data_list = self.data_list[0:max_iters] 51 | self.crop_size = crop_size 52 | 53 | def __transforms(self, aug, pre_img, post_img, label): 54 | if aug: 55 | pre_img, post_img, label = imutils.random_crop(pre_img, post_img, label, self.crop_size) 56 | pre_img, post_img, label = imutils.random_fliplr(pre_img, post_img, label) 57 | pre_img, post_img, label = imutils.random_flipud(pre_img, post_img, label) 58 | pre_img, post_img, label = imutils.random_rot(pre_img, post_img, label) 59 | 60 | pre_img = imutils.normalize_img(pre_img) # imagenet normalization 61 | pre_img = np.transpose(pre_img, (2, 0, 1)) 62 | 63 | post_img = imutils.normalize_img(post_img) # imagenet normalization 64 | post_img = np.transpose(post_img, (2, 0, 1)) 65 | 66 | return pre_img, post_img, label 67 | 68 | def __getitem__(self, index): 69 | pre_path = os.path.join(self.dataset_path, self.data_list[index], 'pair', 'img1.png') 70 | post_path = os.path.join(self.dataset_path, self.data_list[index], 'pair', 'img2.png') 71 | label_path = os.path.join(self.dataset_path, self.data_list[index], 'cm', 'cm.png') 72 | pre_img = self.loader(pre_path) 73 | post_img = self.loader(post_path) 74 | label = self.loader(label_path) 75 | 76 | if len(label.shape) > 2: 77 | label = label[:, :, 0] 78 | label = label / 255 79 | 80 | if 'train' in self.data_pro_type: 81 | pre_img, post_img, label = self.__transforms(True, pre_img, post_img, label) 82 | else: 83 | pre_img, post_img, label = self.__transforms(False, pre_img, post_img, label) 84 | label = np.asarray(label) 85 | 86 | data_idx = self.data_list[index] 87 | return pre_img, post_img, label, data_idx 88 | 89 | def __len__(self): 90 | return len(self.data_list) 91 | 92 | 93 | class OSCDDatset13Bands(Dataset): 
94 | def __init__(self, dataset_path, data_list, crop_size, max_iters=None, type='train', data_loader=img_loader, 95 | sentinel_loader=sentinel_loader): 96 | self.dataset_path = dataset_path 97 | self.data_list = data_list 98 | self.data_loader = data_loader 99 | self.sentinel_loader = sentinel_loader 100 | 101 | self.type = type 102 | self.data_pro_type = self.type 103 | 104 | if max_iters is not None: 105 | self.data_list = self.data_list * int(np.ceil(float(max_iters) / len(self.data_list))) 106 | self.data_list = self.data_list[0:max_iters] 107 | self.crop_size = crop_size 108 | 109 | def __transforms(self, aug, pre_img, post_img, label): 110 | if aug: 111 | pre_img, post_img, label = imutils.random_crop(pre_img, post_img, label, self.crop_size) 112 | pre_img, post_img, label = imutils.random_fliplr(pre_img, post_img, label) 113 | pre_img, post_img, label = imutils.random_flipud(pre_img, post_img, label) 114 | pre_img, post_img, label = imutils.random_rot(pre_img, post_img, label) 115 | 116 | pre_img = imutils.normalize_img(pre_img) # imagenet normalization 117 | pre_img = np.transpose(pre_img, (2, 0, 1)) 118 | 119 | post_img = imutils.normalize_img(post_img) # imagenet normalization 120 | post_img = np.transpose(post_img, (2, 0, 1)) 121 | 122 | return pre_img, post_img, label 123 | 124 | def __getitem__(self, index): 125 | pre_path = os.path.join(self.dataset_path, self.data_list[index], 'imgs_1_rect', 'ms_data.npy') 126 | post_path = os.path.join(self.dataset_path, self.data_list[index], 'imgs_2_rect', 'ms_data.npy') 127 | label_path = os.path.join(self.dataset_path, self.data_list[index], 'cm', 'cm.png') 128 | pre_img = np.load(pre_path) 129 | post_img = np.load(post_path) 130 | label = self.data_loader(label_path) 131 | 132 | if len(label.shape) > 2: 133 | label = label[:, :, 0] 134 | label = label / 255 135 | 136 | if 'train' in self.data_pro_type: 137 | pre_img, post_img, label = self.__transforms(True, pre_img, post_img, label) 138 | else: 139 | pre_img, post_img, label = self.__transforms(False, pre_img, post_img, label) 140 | label = np.asarray(label) 141 | 142 | data_idx = self.data_list[index] 143 | return pre_img, post_img, label, data_idx 144 | 145 | def __len__(self): 146 | return len(self.data_list) 147 | 148 | 149 | def make_data_loader(args, **kwargs): # **kwargs could be omitted 150 | if 'OSCD_3Bands' in args.dataset: 151 | dataset = OSCDDatset3Bands(args.dataset_path, args.data_name_list, args.crop_size, args.max_iters, args.type) 152 | # train_sampler = DistributedSampler(dataset, shuffle=True) 153 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle, **kwargs, num_workers=16, 154 | drop_last=False) 155 | return data_loader 156 | 157 | if 'OSCD_13Bands' in args.dataset: 158 | dataset = OSCDDatset13Bands(args.dataset_path, args.data_name_list, args.crop_size, args.max_iters, args.type) 159 | # train_sampler = DistributedSampler(dataset, shuffle=True) 160 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle, **kwargs, num_workers=16, 161 | drop_last=False) 162 | return data_loader 163 | 164 | else: 165 | raise NotImplementedError 166 | 167 | 168 | if __name__ == '__main__': 169 | 170 | parser = argparse.ArgumentParser(description="SECOND DataLoader Test") 171 | parser.add_argument('--dataset', type=str, default='WHUBCD') 172 | parser.add_argument('--max_iters', type=int, default=10000) 173 | parser.add_argument('--type', type=str, default='train') 174 | parser.add_argument('--dataset_path', type=str, 
default='D:/Workspace/Python/STCD/data/ST-WHU-BCD')
    parser.add_argument('--data_list_path', type=str, default='./ST-WHU-BCD/train_list.txt')
    parser.add_argument('--shuffle', type=bool, default=True)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--data_name_list', type=list)

    args = parser.parse_args()

    with open(args.data_list_path, "r") as f:
        data_name_list = [data_name.strip() for data_name in f]
    args.data_name_list = data_name_list
    train_data_loader = make_data_loader(args)
    for i, data in enumerate(train_data_loader):
        pre_img, post_img, labels, _ = data
        pre_data, post_data = Variable(pre_img), Variable(post_img)
        labels = Variable(labels)
        print(i, "inputs", pre_data.data.size(), "labels", labels.data.size())
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle
import time

import gdal
import numpy as np
import tensorflow as tf

from SiamCRNN import SiamCRNN

from util.data_prepro import stad_img

parser = argparse.ArgumentParser()
parser.add_argument('--max_epoch', type=int, default=300, help='epochs to run [default: 300]')
parser.add_argument('--batch_size', type=int, default=1024, help='batch size during training [default: 1024]')
parser.add_argument('--learning_rate', type=float, default=2e-4, help='initial learning rate [default: 2e-4]')
parser.add_argument('--save_path', default='model_param', help='model param path')
parser.add_argument('--data_path', default=None, help='dataset path')
parser.add_argument('--gpu_num', type=int, default=1, help='number of GPUs to train with')

# basic params
FLAGS = parser.parse_args()

BATCH_SZ = FLAGS.batch_size
LEARNING_RATE = FLAGS.learning_rate
MAX_EPOCH = FLAGS.max_epoch
SAVE_PATH = FLAGS.save_path
DATA_PATH = FLAGS.data_path
GPU_NUM = FLAGS.gpu_num
BATCH_PER_GPU = BATCH_SZ // GPU_NUM


class ChangeTrainer(object):

    def __init__(self):
        self.Input_X = None
        self.Input_Y = None
        self.label = None
        self.is_training = None

        self.net = None
        self.pred = None
        self.loss = None
        self.opt = None
        self.train_op = None
        self.global_step = tf.Variable(0, trainable=False)
        self.siamcrnn_model = SiamCRNN()

    def load_data(self):
        data_set_X = gdal.Open('data/GF_2_2/0411')  # data set X
        data_set_Y = gdal.Open('data/GF_2_2/0901')  # data set Y

        img_width = data_set_X.RasterXSize  # image width
        img_height = data_set_X.RasterYSize  # image height

        img_X = data_set_X.ReadAsArray(0, 0, img_width, img_height)
        img_Y = data_set_Y.ReadAsArray(0, 0, img_width, img_height)

        img_X = stad_img(img_X)  # (C, H, W)
        img_Y = stad_img(img_Y)
        img_X = np.transpose(img_X, [1, 2, 0])  # (H, W, C)
        img_Y = np.transpose(img_Y, [1, 2, 0])  # (H, W, C)
        return img_X, img_Y

    def training(self):
        train_X, train_Y, train_label = self._load_train_data(path=DATA_PATH)
        self.valid_X, self.valid_Y, self.valid_label = self._load_valid_data(path=DATA_PATH)
        train_label = np.reshape(train_label, (-1, 1))
        self.valid_label = np.reshape(self.valid_label, (-1, 1))
        self.valid_sz = self.valid_X.shape[0]

        shape_1 = train_X.shape
        shape_2 = train_Y.shape
        train_sz = train_X.shape[0]
        self.Input_X = tf.placeholder(dtype=tf.float32, shape=[None, shape_1[1], shape_1[2], shape_1[3]],
                                      name='Input_X')
        self.Input_Y = tf.placeholder(dtype=tf.float32, shape=[None, shape_2[1], shape_2[2], shape_2[3]],
                                      name='Input_Y')
        self.label = tf.placeholder(dtype=tf.float32, shape=[None, 1])
        self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')

        # SiamCRNN.get_model returns (logits, pred)
        self.net, self.pred = self.siamcrnn_model.get_model(Input_X=self.Input_X, Input_Y=self.Input_Y,
                                                            data_format='NHWC',
                                                            is_training=self.is_training)
        self.loss = self._get_loss(label=self.label, logits=self.net)
        self.opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        self.train_op = self.opt.minimize(loss=self.loss)
        best_loss = 100000
        iter_in_epoch = train_sz // BATCH_SZ
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        saver = tf.train.Saver(max_to_keep=0, var_list=tf.global_variables())
        total_time = 0
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            epoch_sz = MAX_EPOCH

            for epoch in range(epoch_sz):
                tic = time.time()
                ave_loss = 0
                train_idx = np.arange(0, train_sz)
                np.random.shuffle(train_idx)
                for _iter in range(iter_in_epoch):
                    start_idx = _iter * BATCH_SZ
                    end_idx = (_iter + 1) * BATCH_SZ
                    batch_train_X = train_X[train_idx[start_idx:end_idx]]
                    batch_train_Y = train_Y[train_idx[start_idx:end_idx]]
                    batch_label = train_label[train_idx[start_idx:end_idx]]
                    loss, _, logits = sess.run(
                        [self.loss, self.train_op, self.net],
                        feed_dict={
                            self.Input_X: batch_train_X,
                            self.Input_Y: batch_train_Y,
                            self.label: batch_label,
                            self.is_training: True
                        })
                    ave_loss += loss
                ave_loss /= iter_in_epoch
                toc = time.time()
                total_time += (toc - tic)
                val_loss = self.evaluate(sess)

                if (epoch + 1) % 5 == 0:
                    if val_loss < best_loss:
                        best_loss = val_loss
                        _path = saver.save(sess, os.path.join(SAVE_PATH, "best_model.ckpt"))
                        print("best model is saved")
                    _path = saver.save(sess, os.path.join(SAVE_PATH, "cha_model_%d.ckpt" % (epoch + 1)))
                    print("epoch %d, model saved in file: " % (epoch + 1), _path)
            _path = saver.save(sess, os.path.join(SAVE_PATH, 'final_model.ckpt'))
            print("Model saved in file: ", _path)
            print(total_time)

    def evaluate(self, sess):
        iter_in_epoch = self.valid_sz // BATCH_SZ
        valid_idx = np.arange(0, self.valid_sz)
        np.random.shuffle(valid_idx)
        ave_loss = 0
        for _iter in range(iter_in_epoch):
            start_idx = _iter * BATCH_SZ
            end_idx = (_iter + 1) * BATCH_SZ
            batch_valid_X = self.valid_X[valid_idx[start_idx:end_idx]]
            batch_valid_Y = self.valid_Y[valid_idx[start_idx:end_idx]]
            batch_label = self.valid_label[valid_idx[start_idx:end_idx]]
            loss = sess.run(self.loss, feed_dict={
                self.Input_X: batch_valid_X,
                self.Input_Y: batch_valid_Y,
                self.label: batch_label,
                self.is_training: False
            })
            ave_loss += loss
        ave_loss /= iter_in_epoch
        print("evaluate is
done, validation loss is %.3f" % ave_loss) 160 | return ave_loss 161 | 162 | def _get_loss(self, label, logits): 163 | loss = tf.reduce_mean( 164 | tf.nn.weighted_cross_entropy_with_logits(targets=label, logits=logits, pos_weight=5, name='weight_loss')) 165 | return loss 166 | 167 | def _load_train_data(self, path): 168 | with open(os.path.join(path, 'train_sample_X.pickle'), 'rb') as file: 169 | train_X = pickle.load(file) 170 | with open(os.path.join(path, 'train_sample_Y.pickle'), 'rb') as file: 171 | train_Y = pickle.load(file) 172 | with open(os.path.join(path, 'train_label.pickle'), 'rb') as file: 173 | train_label = pickle.load(file) 174 | 175 | return train_X, train_Y, train_label 176 | 177 | def _load_valid_data(self, path): 178 | with open(os.path.join(path, 'valid_sample_X.pickle'), 'rb') as file: 179 | valid_X = pickle.load(file) 180 | with open(os.path.join(path, 'valid_sample_Y.pickle'), 'rb') as file: 181 | valid_Y = pickle.load(file) 182 | with open(os.path.join(path, 'valid_label.pickle'), 'rb') as file: 183 | valid_label = pickle.load(file) 184 | 185 | return valid_X, valid_Y, valid_label 186 | 187 | 188 | if __name__ == '__main__': 189 | trainer = ChangeTrainer() 190 | trainer.training() 191 | -------------------------------------------------------------------------------- /util/net_util.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.layers import batch_normalization 3 | 4 | _EPSILON = 1e-7 5 | 6 | 7 | def conv_2d(inputs, kernel_size, output_channel, stride, name, is_reuse=False, padding='SAME', data_format='NHWC', 8 | is_bn=False, is_training=True, activation=None): 9 | """ 10 | 2D Conv Layer 11 | :param name: scope name 12 | :param inputs: (B, H, W, C) 13 | :param kernel_size: [kernel_h, kernel_w] 14 | :param output_channel: feature num 15 | :param stride: a list of 2 ints 16 | :param padding: type of padding, str 17 | :param data_format: str, the format of input data 18 | :param is_bn: bool, is batch normalization 19 | :param is_training: bool, is training 20 | :param activation: activation function, such as tf.nn.relu 21 | :return: outputs 22 | """ 23 | with tf.variable_scope(name) as scope: 24 | if is_reuse: 25 | scope.reuse_variables() 26 | kernel_h, kernel_w = kernel_size 27 | stride_h, stride_w = stride 28 | if data_format == 'NHWC': 29 | kernel_shape = [kernel_h, kernel_w, inputs.get_shape()[-1].value, output_channel] 30 | else: 31 | kernel_shape = [kernel_h, kernel_w, inputs.get_shape()[1].value, output_channel] 32 | init = tf.keras.initializers.he_normal() 33 | kernel = tf.get_variable(name='conv_kenel', shape=kernel_shape, initializer=init, dtype=tf.float32) 34 | # kernel = tf.Variable(tf.truncated_normal(kernel_shape, dtype=tf.float32, stddev=0.1)) 35 | outputs = tf.nn.conv2d(input=inputs, 36 | filter=kernel, 37 | strides=[1, stride_h, stride_w, 1], padding=padding, 38 | data_format=data_format) 39 | biases = tf.Variable(tf.constant(0.1, shape=[output_channel], dtype=tf.float32)) 40 | outputs = outputs + biases 41 | if is_bn: 42 | outputs = batch_normalization(outputs, training=is_training) 43 | if activation is not None: 44 | outputs = activation(outputs) 45 | return outputs 46 | 47 | 48 | def conv_2d_transpose(inputs, kernel_size, output_channel, output_shape, stride, name, padding='SAME', 49 | data_format='NHWC', is_bn=False, is_training=True, activation=None): 50 | """ 51 | 2D Transpose Conv Layer 52 | :param output_shape: 53 | :param name: scope name 54 | :param 
inputs: (B, H, W, C) 55 | :param kernel_size: [kernel_h, kernel_w] 56 | :param output_channel: feature num 57 | :param stride: a list of 2 ints 58 | :param padding: type of padding, str 59 | :param data_format: str, the format of input data 60 | :param is_bn: bool, is batch normalization 61 | :param is_training: bool, is training 62 | :param activation: activation function, such as tf.nn.relu 63 | :return: outputs 64 | """ 65 | with tf.variable_scope(name) as scope: 66 | kernel_h, kernel_w = kernel_size 67 | stride_h, stride_w = stride 68 | if data_format == 'NHWC': 69 | kernel_shape = [kernel_h, kernel_w, output_channel, inputs.get_shape()[-1].value] 70 | else: 71 | kernel_shape = [kernel_h, kernel_w, output_channel, inputs.get_shape()[1].value] 72 | 73 | init = tf.keras.initializers.he_normal() 74 | kernel = tf.get_variable(name='conv_kenel', shape=kernel_shape, initializer=init, dtype=tf.float32) 75 | 76 | # calculate output shape 77 | # batch_size, height, width, _ = inputs.get_shape().as_list() 78 | # out_height = get_deconv_dim(height, stride_h, kernel_h, padding) 79 | # out_width = get_deconv_dim(width, stride_w, kernel_w, padding) 80 | # output_shape = [batch_size, out_height, out_width, output_channel] 81 | 82 | # kernel = tf.Variable(tf.truncated_normal(kernel_shape, dtype=tf.float32, stddev=0.1)) 83 | outputs = tf.nn.conv2d_transpose(value=inputs, 84 | filter=kernel, 85 | output_shape=output_shape, 86 | strides=[1, stride_h, stride_w, 1], padding=padding, 87 | data_format=data_format) 88 | biases = tf.Variable(tf.constant(0.1, shape=[output_channel], dtype=tf.float32)) 89 | outputs = outputs + biases 90 | if is_bn: 91 | outputs = batch_normalization(outputs, training=is_training) 92 | if activation is not None: 93 | outputs = activation(outputs) 94 | return outputs 95 | 96 | 97 | # from slim.convolution2d_transpose 98 | def get_deconv_dim(dim_size, stride_size, kernel_size, padding): 99 | dim_size *= stride_size 100 | 101 | if padding == 'VALID' and dim_size is not None: 102 | dim_size += max(kernel_size - stride_size, 0) 103 | return dim_size 104 | 105 | 106 | def max_pool_2d(inputs, kernel_size, stride, padding='SAME'): 107 | """ 108 | 2D Max Pool Layer 109 | :param inputs: (B, H, W, C) 110 | :param kernel_size: [kernel_h, kernel_w] 111 | :param padding: type of padding, str 112 | :return: outputs 113 | """ 114 | kernel_h, kernel_w = kernel_size 115 | stride_h, stride_w = stride 116 | outputs = tf.nn.max_pool(inputs, 117 | ksize=[1, kernel_h, kernel_w, 1], 118 | strides=[1, stride_h, stride_w, 1], 119 | padding=padding) 120 | return outputs 121 | 122 | 123 | def avg_pool_2d(inputs, kernel_size, stride, padding='SAME'): 124 | """ 125 | 2D Avg Pool Layer 126 | :param inputs: (B, H, W, C) 127 | :param kernel_size: [kernel_h, kernel_w] 128 | :param padding: type of padding, str 129 | :return: outputs 130 | """ 131 | kernel_h, kernel_w = kernel_size 132 | stride_h, stride_w = stride 133 | outputs = tf.nn.avg_pool(inputs, 134 | ksize=[1, kernel_h, kernel_w, 1], 135 | strides=[1, stride_h, stride_w, 1], 136 | padding=padding) 137 | return outputs 138 | 139 | 140 | def fully_connected(inputs, num_outputs, is_training=True, is_bn=False, activation=None): 141 | """ 142 | Fully connected layer with non-linear operation 143 | :param inputs: (B, N) 144 | :param num_outputs: int 145 | :param is_training: bool 146 | :param is_bn: bool 147 | :param activation: activation function, such as tf.nn.relu 148 | :return: outputs: (B, num_outputs) 149 | """ 150 | 151 | num_input_units = 
inputs.get_shape()[-1].value 152 | weights = tf.Variable(tf.truncated_normal([num_input_units, num_outputs], dtype=tf.float32, stddev=0.1)) 153 | outputs = tf.matmul(inputs, weights) 154 | biases = tf.Variable(tf.constant(0.1, shape=[num_outputs], dtype=tf.float32)) 155 | outputs = tf.nn.bias_add(outputs, biases) 156 | if is_bn: 157 | outputs = batch_normalization(outputs, training=is_training) 158 | if activation is not None: 159 | outputs = activation(outputs) 160 | return outputs 161 | 162 | 163 | def weight_binary_cross_entropy(target, output, weight=1.0, from_logits=False): 164 | """weight binary crossentropy between an output tensor and a target tensor. 165 | 166 | # Arguments 167 | target: A tensor with the same shape as `output`. 168 | output: A tensor. 169 | from_logits: Whether `output` is expected to be a logits tensor. 170 | By default, we consider that `output` 171 | encodes a probability distribution. 172 | 173 | # Returns 174 | A tensor. 175 | """ 176 | # Note: tf.nn.sigmoid_cross_entropy_with_logits 177 | # expects logits, Keras expects probabilities. 178 | if not from_logits: 179 | # transform back to logits 180 | _epsilon = _to_tensor(epsilon(), output.dtype.base_dtype) 181 | output = tf.clip_by_value(output, _epsilon, 1 - _epsilon) 182 | output = tf.log(output / (1 - output)) 183 | 184 | return tf.nn.weighted_cross_entropy_with_logits(targets=target, 185 | logits=output, 186 | pos_weight=weight) 187 | 188 | 189 | def _to_tensor(x, dtype): 190 | """Convert the input `x` to a tensor of type `dtype`. 191 | 192 | # Arguments 193 | x: An object to be converted (numpy array, list, tensors). 194 | dtype: The destination type. 195 | 196 | # Returns 197 | A tensor. 198 | """ 199 | return tf.convert_to_tensor(x, dtype=dtype) 200 | 201 | 202 | def epsilon(): 203 | """Returns the value of the fuzz factor used in numeric expressions. 204 | 205 | # Returns 206 | A float. 
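(Editor's note: this mirrors keras.backend.epsilon(); weight_binary_cross_entropy above uses it to clip probabilities away from 0 and 1 before converting them back to logits.)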
207 | 208 | # Example 209 | ```python 210 | >>> keras.backend.epsilon() 211 | 1e-07 212 | ``` 213 | """ 214 | return _EPSILON 215 | -------------------------------------------------------------------------------- /FCN_version/deep_networks/resnet_18_34.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | from deep_networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | from torchvision.models import resnet50 7 | from torchvision.models import resnet 8 | 9 | __model_file = { 10 | 18: '/home/songjian/project/HSIFM/pretrained_weight/resnet-18-pytorch.pth', 11 | 34: '/home/songjian/project/HSIFM/pretrained_weight/resnet-34-pytorch.pth', 12 | } 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 19 | super(BasicBlock, self).__init__() 20 | # if norm_layer is None: 21 | # norm_layer = nn.BatchNorm2d 22 | # if groups != 1 or base_width != 64: 23 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 24 | # if dilation > 1: 25 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 27 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 28 | dilation=dilation, padding=dilation, bias=False) 29 | self.bn1 = BatchNorm(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 32 | self.bn2 = BatchNorm(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | identity = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | identity = self.downsample(x) 48 | 49 | out += identity 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class ResNet(nn.Module): 56 | def __init__(self, input_dim, block, layers, output_stride, BatchNorm): 57 | self.inplanes = 64 58 | super(ResNet, self).__init__() 59 | blocks = [1, 2, 4] 60 | if output_stride == 16: 61 | strides = [1, 2, 2, 1] 62 | dilations = [1, 1, 1, 2] 63 | elif output_stride == 8: 64 | strides = [1, 2, 1, 1] 65 | dilations = [1, 1, 2, 4] 66 | else: 67 | raise NotImplementedError 68 | 69 | # Modules 70 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, 71 | bias=False) 72 | self.bn1 = BatchNorm(64) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 75 | 76 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], 77 | BatchNorm=BatchNorm) 78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], 79 | BatchNorm=BatchNorm) 80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], 81 | BatchNorm=BatchNorm) 82 | # self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], 83 | # BatchNorm=BatchNorm) 84 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], 85 | BatchNorm=BatchNorm) 86 | 87 | self._init_weight() 88 | 89 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 
BatchNorm=None): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | nn.Conv2d(self.inplanes, planes * block.expansion, 94 | kernel_size=1, stride=stride, bias=False), 95 | BatchNorm(planes * block.expansion), 96 | ) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 100 | self.inplanes = planes * block.expansion 101 | for i in range(1, blocks): 102 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 103 | 104 | return nn.Sequential(*layers) 105 | 106 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 107 | downsample = None 108 | if stride != 1 or self.inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | nn.Conv2d(self.inplanes, planes * block.expansion, 111 | kernel_size=1, stride=stride, bias=False), 112 | BatchNorm(planes * block.expansion), 113 | ) 114 | 115 | layers = [] 116 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation, 117 | downsample=downsample, BatchNorm=BatchNorm)) 118 | self.inplanes = planes * block.expansion 119 | for i in range(1, len(blocks)): 120 | layers.append(block(self.inplanes, planes, stride=1, 121 | dilation=blocks[i] * dilation, BatchNorm=BatchNorm)) 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, input): 126 | x = self.conv1(input) 127 | x = self.bn1(x) 128 | x = self.relu(x) 129 | # low_level_feat_0 = x 130 | x = self.maxpool(x) 131 | 132 | x = self.layer1(x) 133 | low_level_feat_1 = x 134 | x = self.layer2(x) 135 | low_level_feat_2 = x 136 | x = self.layer3(x) 137 | # x = self.dropout_3(x) 138 | low_level_feat_3 = x 139 | x = self.layer4(x) 140 | # x = self.dropout_4(x) 141 | return low_level_feat_1, low_level_feat_2, low_level_feat_3, x 142 | 143 | def _init_weight(self): 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 147 | m.weight.data.normal_(0, math.sqrt(2. / n)) 148 | elif isinstance(m, SynchronizedBatchNorm2d): 149 | m.weight.data.fill_(1) 150 | m.bias.data.zero_() 151 | elif isinstance(m, nn.BatchNorm2d): 152 | m.weight.data.fill_(1) 153 | m.bias.data.zero_() 154 | 155 | def _load_pretrained_model(self): 156 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 157 | model_dict = {} 158 | state_dict = self.state_dict() 159 | for k, v in pretrain_dict.items(): 160 | if k in state_dict: 161 | model_dict[k] = v 162 | state_dict.update(model_dict) 163 | self.load_state_dict(state_dict) 164 | 165 | 166 | def ResNet18(input_dim, output_stride, BatchNorm, pretrained=False): 167 | """Constructs a ResNet-18 model. 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | """ 171 | model = ResNet(input_dim, BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm) 172 | if pretrained: 173 | pretrain_dict = torch.load(__model_file[18]) 174 | model_dict = {} 175 | state_dict = model.state_dict() 176 | for k, v in pretrain_dict.items(): 177 | if k in state_dict: 178 | model_dict[k] = v 179 | state_dict.update(model_dict) 180 | model.load_state_dict(state_dict) 181 | print('pretrained model has been loaded') 182 | 183 | return model 184 | 185 | 186 | def ResNet34(input_dim, output_stride, BatchNorm, pretrained=False): 187 | """Constructs a ResNet-34 model. 
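(Editor's note) Like ResNet18 above, but with BasicBlock layer counts [3, 4, 6, 3]; the forward pass returns the three intermediate feature maps plus the final one, which the FPN-style decoder in SiamCRNN consumes.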
188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(input_dim, BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm) 192 | if pretrained: 193 | pretrain_dict = torch.load(__model_file[34]) 194 | model_dict = {} 195 | state_dict = model.state_dict() 196 | for k, v in pretrain_dict.items(): 197 | if k in state_dict: 198 | model_dict[k] = v 199 | state_dict.update(model_dict) 200 | model.load_state_dict(state_dict) 201 | print('pretrained model has been loaded') 202 | 203 | return model 204 | 205 | 206 | # 207 | # if __name__ == "__main__": 208 | # import torch 209 | # 210 | # model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 211 | # input = torch.rand(1, 3, 1024, 1024) 212 | # output, low_level_feat = model(input) 213 | # print(output.size()) 214 | # print(low_level_feat.size()) 215 | -------------------------------------------------------------------------------- /FCN_version/util_func/lovasz_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | try: 14 | from itertools import ifilterfalse 15 | except ImportError: # py3k 16 | from itertools import filterfalse as ifilterfalse 17 | 18 | 19 | def lovasz_grad(gt_sorted): 20 | """ 21 | Computes gradient of the Lovasz extension w.r.t sorted errors 22 | See Alg. 1 in paper 23 | """ 24 | p = len(gt_sorted) 25 | gts = gt_sorted.sum() 26 | intersection = gts - gt_sorted.float().cumsum(0) 27 | union = gts + (1 - gt_sorted).float().cumsum(0)
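# (Editor's note) The two cumulative sums above give, for every prefix of the error-sorted
# ground truth, its intersection and union with the full foreground set; the differencing
# below turns the prefix Jaccard losses into the per-position gradient of Alg. 1.
# For example, gt_sorted = [1, 0, 1] yields prefix losses 1 - I/U = [1/2, 2/3, 1]
# and hence the gradient [1/2, 1/6, 1/3] after differencing.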
28 | jaccard = 1. - intersection / union 29 | if p > 1: # cover 1-pixel case 30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 31 | return jaccard 32 | 33 | 34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 35 | """ 36 | IoU for foreground class 37 | binary: 1 foreground, 0 background 38 | """ 39 | if not per_image: 40 | preds, labels = (preds,), (labels,) 41 | ious = [] 42 | for pred, label in zip(preds, labels): 43 | intersection = ((label == 1) & (pred == 1)).sum() 44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 45 | if not union: 46 | iou = EMPTY 47 | else: 48 | iou = float(intersection) / float(union) 49 | ious.append(iou) 50 | iou = mean(ious) # mean across images if per_image 51 | return 100 * iou 52 | 53 | 54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 55 | """ 56 | Array of IoU for each (non ignored) class 57 | """ 58 | if not per_image: 59 | preds, labels = (preds,), (labels,) 60 | ious = [] 61 | for pred, label in zip(preds, labels): 62 | iou = [] 63 | for i in range(C): 64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 65 | intersection = ((label == i) & (pred == i)).sum() 66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 67 | if not union: 68 | iou.append(EMPTY) 69 | else: 70 | iou.append(float(intersection) / float(union)) 71 | ious.append(iou) 72 | ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image 73 | return 100 * np.array(ious) 74 | 75 | 76 | # --------------------------- BINARY LOSSES --------------------------- 77 | 78 | 79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 80 | """ 81 | Binary Lovasz hinge loss 82 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 83 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 84 | per_image: compute the loss per image instead of per batch 85 | ignore: void class id 86 | """ 87 | if per_image: 88 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 89 | for log, lab in zip(logits, labels)) 90 | else: 91 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 92 | return loss 93 | 94 | 95 | def lovasz_hinge_flat(logits, labels): 96 | """ 97 | Binary Lovasz hinge loss 98 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 99 | labels: [P] Tensor, binary ground truth labels (0 or 1) 100 | ignore: label to ignore 101 | """ 102 | if len(labels) == 0: 103 | # only void pixels, the gradients should be 0 104 | return logits.sum() * 0. 105 | signs = 2. * labels.float() - 1. 106 | errors = (1. 
- logits * Variable(signs)) 107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 108 | perm = perm.data 109 | gt_sorted = labels[perm] 110 | grad = lovasz_grad(gt_sorted) 111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 112 | return loss 113 | 114 | 115 | def flatten_binary_scores(scores, labels, ignore=None): 116 | """ 117 | Flattens predictions in the batch (binary case) 118 | Remove labels equal to 'ignore' 119 | """ 120 | scores = scores.view(-1) 121 | labels = labels.view(-1) 122 | if ignore is None: 123 | return scores, labels 124 | valid = (labels != ignore) 125 | vscores = scores[valid] 126 | vlabels = labels[valid] 127 | return vscores, vlabels 128 | 129 | 130 | class StableBCELoss(torch.nn.modules.Module): 131 | def __init__(self): 132 | super(StableBCELoss, self).__init__() 133 | 134 | def forward(self, input, target): 135 | neg_abs = - input.abs() 136 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 137 | return loss.mean() 138 | 139 | 140 | def binary_xloss(logits, labels, ignore=None): 141 | """ 142 | Binary Cross entropy loss 143 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 144 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 145 | ignore: void class id 146 | """ 147 | logits, labels = flatten_binary_scores(logits, labels, ignore) 148 | loss = StableBCELoss()(logits, Variable(labels.float())) 149 | return loss 150 | 151 | 152 | # --------------------------- MULTICLASS LOSSES --------------------------- 153 | 154 | 155 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255): 156 | """ 157 | Multi-class Lovasz-Softmax loss 158 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 159 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 160 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 161 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 162 | per_image: compute the loss per image instead of per batch 163 | ignore: void class labels 164 | """ 165 | if per_image: 166 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 167 | for prob, lab in zip(probas, labels)) 168 | else: 169 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 170 | return loss 171 | 172 | 173 | def lovasz_softmax_flat(probas, labels, classes='present'): 174 | """ 175 | Multi-class Lovasz-Softmax loss 176 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 177 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 178 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 179 | """ 180 | if probas.numel() == 0: 181 | # only void pixels, the gradients should be 0 182 | return 0. 
# when probas.shape == torch.Size([0, 4]), (probas * 0.).shape is also torch.Size([0, 4]); such a tensor loss cannot be backpropagated 183 | if len(probas.shape) == 1: 184 | probas = probas.unsqueeze(0) # only one pixel matches the classes of interest: probas.shape == torch.Size([4]) must become torch.Size([1, 4]) for the subsequent calculation 185 | C = probas.size(1) 186 | losses = [] 187 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 188 | for c in class_to_sum: 189 | fg = (labels == c).float() # foreground for class c 190 | if classes == 'present' and fg.sum() == 0: 191 | continue 192 | if C == 1: 193 | if len(classes) > 1: 194 | raise ValueError('Sigmoid output possible only with 1 class') 195 | class_pred = probas[:, 0] 196 | else: 197 | class_pred = probas[:, c] 198 | errors = (Variable(fg) - class_pred).abs() 199 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 200 | perm = perm.data 201 | fg_sorted = fg[perm] 202 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 203 | return mean(losses) 204 | 205 | 206 | def flatten_probas(probas, labels, ignore=None): 207 | """ 208 | Flattens predictions in the batch 209 | """ 210 | if probas.dim() == 3: 211 | # assumes output of a sigmoid layer 212 | B, H, W = probas.size() 213 | probas = probas.view(B, 1, H, W) 214 | B, C, H, W = probas.size() 215 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 216 | labels = labels.view(-1) 217 | if ignore is None: 218 | return probas, labels 219 | valid = (labels != ignore) 220 | vprobas = probas[valid.nonzero().squeeze()] 221 | vlabels = labels[valid] 222 | return vprobas, vlabels 223 | 224 | 225 | def xloss(logits, labels, ignore=None): 226 | """ 227 | Cross entropy loss 228 | """ 229 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 230 | 231 | 232 | # --------------------------- HELPER FUNCTIONS --------------------------- 233 | def isnan(x): 234 | return x != x 235 | 236 | 237 | def mean(l, ignore_nan=False, empty=0): 238 | """ 239 | nanmean compatible with generators.
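(Editor's note) Returns `empty` (0 by default) for an empty iterable, or raises ValueError when empty == 'raise'; with ignore_nan=True, NaN entries are filtered out before averaging.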
240 | """ 241 | l = iter(l) 242 | if ignore_nan: 243 | l = ifilterfalse(isnan, l) 244 | try: 245 | n = 1 246 | acc = next(l) 247 | except StopIteration: 248 | if empty == 'raise': 249 | raise ValueError('Empty mean') 250 | return empty 251 | for n, v in enumerate(l, 2): 252 | acc += v 253 | if n == 1: 254 | return acc 255 | return acc / n -------------------------------------------------------------------------------- /FCN_version/deep_networks/SiamCRNN.py: -------------------------------------------------------------------------------- 1 | from deep_networks.resnet_18_34 import ResNet34, ResNet18 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SiamCRNN(nn.Module): 10 | def __init__(self, in_dim_1, in_dim_2, pretrained=True, output_stride=16, BatchNorm=nn.BatchNorm2d, bias=True): 11 | super(SiamCRNN, self).__init__() 12 | self.encoder_1 = ResNet34(input_dim=in_dim_1, BatchNorm=BatchNorm, pretrained=False, output_stride=output_stride) 13 | # self.encoder_1.conv1 = nn.Conv2d(in_dim_1, 64, kernel_size=7, stride=2, padding=3, bias=False) 14 | # If your dataset is heterogeneous, please use a pseudo-siamese architecture 15 | # self.encoder_2 = ResNet18(input_dim=in_dim_2, BatchNorm=BatchNorm, pretrained=True, output_stride=output_stride) 16 | self.convlstm_4 = ConvLSTM(input_dim=512, hidden_dim=128, kernel_size=(3, 3), num_layers=1, 17 | batch_first=True) 18 | self.convlstm_3 = ConvLSTM(input_dim=256, hidden_dim=128, kernel_size=(3, 3), num_layers=1, 19 | batch_first=True) 20 | self.convlstm_2 = ConvLSTM(input_dim=128, hidden_dim=128, kernel_size=(3, 3), num_layers=1, 21 | batch_first=True) 22 | self.convlstm_1 = ConvLSTM(input_dim=64, hidden_dim=128, kernel_size=(3, 3), num_layers=1, 23 | batch_first=True) 24 | self.smooth_layer_3 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1), 25 | nn.BatchNorm2d(128), nn.ReLU()) 26 | self.smooth_layer_2 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1), 27 | nn.BatchNorm2d(128), nn.ReLU()) 28 | self.smooth_layer_1 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1), 29 | nn.BatchNorm2d(128), nn.ReLU()) 30 | 31 | self.main_clf_1 = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1) 32 | 33 | def _upsample_add(self, x, y): 34 | _, _, H, W = y.size() 35 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 36 | 37 | def forward(self, pre_data, post_data): 38 | pre_low_level_feat_1, pre_low_level_feat_2, pre_low_level_feat_3, pre_output = \ 39 | self.encoder_1(pre_data) 40 | post_low_level_feat_1, post_low_level_feat_2, post_low_level_feat_3, post_output = \ 41 | self.encoder_1(post_data) 42 | 43 | # Stack along a new time dimension 44 | combined_4 = torch.stack([pre_output, post_output], dim=1) 45 | # Apply ConvLSTM 46 | _, last_state_list_4 = self.convlstm_4(combined_4) 47 | p4 = last_state_list_4[0][0] 48 | 49 | combined_3 = torch.stack([pre_low_level_feat_3, post_low_level_feat_3], dim=1) 50 | # Apply ConvLSTM 51 | _, last_state_list_3 = self.convlstm_3(combined_3) 52 | p3 = last_state_list_3[0][0] 53 | p3 = self._upsample_add(p4, p3) 54 | p3 = self.smooth_layer_3(p3) 55 | 56 | combined_2 = torch.stack([pre_low_level_feat_2, post_low_level_feat_2], dim=1) 57 | # Apply ConvLSTM 58 | _, last_state_list_2 = self.convlstm_2(combined_2) 59 | p2 = last_state_list_2[0][0] 60 | p2 = self._upsample_add(p3, p2) 61 | p2 = self.smooth_layer_2(p2) 62 | 63 | combined_1 = 
torch.stack([pre_low_level_feat_1, post_low_level_feat_1], dim=1) 64 | # Apply ConvLSTM 65 | _, last_state_list_1 = self.convlstm_1(combined_1) 66 | p1 = last_state_list_1[0][0] 67 | p1 = self._upsample_add(p2, p1) 68 | p1 = self.smooth_layer_1(p1) 69 | 70 | output_1 = self.main_clf_1(p1) 71 | output_1 = F.interpolate(output_1, size=pre_data.size()[-2:], mode='bilinear') 72 | return output_1 73 | 74 | 75 | class ConvLSTMCell(nn.Module): 76 | 77 | def __init__(self, input_dim, hidden_dim, kernel_size, bias): 78 | """ 79 | Initialize ConvLSTM cell. 80 | 81 | Parameters 82 | ---------- 83 | input_dim: int 84 | Number of channels of input tensor. 85 | hidden_dim: int 86 | Number of channels of hidden state. 87 | kernel_size: (int, int) 88 | Size of the convolutional kernel. 89 | bias: bool 90 | Whether or not to add the bias. 91 | """ 92 | 93 | super(ConvLSTMCell, self).__init__() 94 | 95 | self.input_dim = input_dim 96 | self.hidden_dim = hidden_dim 97 | 98 | self.kernel_size = kernel_size 99 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 100 | self.bias = bias 101 | 102 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 103 | out_channels=4 * self.hidden_dim, 104 | kernel_size=self.kernel_size, 105 | padding=self.padding, 106 | bias=self.bias) 107 | 108 | def forward(self, input_tensor, cur_state): 109 | h_cur, c_cur = cur_state 110 | 111 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 112 | 113 | combined_conv = self.conv(combined) 114 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 115 | i = torch.sigmoid(cc_i) 116 | f = torch.sigmoid(cc_f) 117 | o = torch.sigmoid(cc_o) 118 | g = torch.tanh(cc_g) 119 | 120 | c_next = f * c_cur + i * g 121 | h_next = o * torch.tanh(c_next) 122 | 123 | return h_next, c_next 124 | 125 | def init_hidden(self, batch_size, image_size): 126 | height, width = image_size 127 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 128 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) 129 | 130 | 131 | class ConvLSTM(nn.Module): 132 | """ 133 | 134 | Parameters: 135 | input_dim: Number of channels in input 136 | hidden_dim: Number of hidden channels 137 | kernel_size: Size of kernel in convolutions 138 | num_layers: Number of LSTM layers stacked on each other 139 | batch_first: Whether or not dimension 0 is the batch or not 140 | bias: Bias or no bias in Convolution 141 | return_all_layers: Return the list of computations for all layers 142 | Note: Will do same padding. 143 | 144 | Input: 145 | A tensor of size B, T, C, H, W or T, B, C, H, W 146 | Output: 147 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False). 
148 | 0 - layer_output_list is the list of lists of length T of each output 149 | 1 - last_state_list is the list of last states 150 | each element of the list is a tuple (h, c) for hidden state and memory 151 | Example: 152 | >> x = torch.rand((32, 10, 64, 128, 128)) 153 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False) 154 | >> _, last_states = convlstm(x) 155 | >> h = last_states[0][0] # 0 for layer index, 0 for h index 156 | """ 157 | 158 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, 159 | batch_first=False, bias=True, return_all_layers=False): 160 | super(ConvLSTM, self).__init__() 161 | 162 | self._check_kernel_size_consistency(kernel_size) 163 | 164 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 165 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 166 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 167 | if not len(kernel_size) == len(hidden_dim) == num_layers: 168 | raise ValueError('Inconsistent list length.') 169 | 170 | self.input_dim = input_dim 171 | self.hidden_dim = hidden_dim 172 | self.kernel_size = kernel_size 173 | self.num_layers = num_layers 174 | self.batch_first = batch_first 175 | self.bias = bias 176 | self.return_all_layers = return_all_layers 177 | 178 | cell_list = [] 179 | for i in range(0, self.num_layers): 180 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] 181 | 182 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim, 183 | hidden_dim=self.hidden_dim[i], 184 | kernel_size=self.kernel_size[i], 185 | bias=self.bias)) 186 | 187 | self.cell_list = nn.ModuleList(cell_list) 188 | 189 | def forward(self, input_tensor, hidden_state=None): 190 | """ 191 | 192 | Parameters 193 | ---------- 194 | input_tensor: todo 195 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 196 | hidden_state: todo 197 | None. todo implement stateful 198 | 199 | Returns 200 | ------- 201 | last_state_list, layer_output 202 | """ 203 | if not self.batch_first: 204 | # (t, b, c, h, w) -> (b, t, c, h, w) 205 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 206 | 207 | b, _, _, h, w = input_tensor.size() 208 | 209 | # Implement stateful ConvLSTM 210 | if hidden_state is not None: 211 | raise NotImplementedError() 212 | else: 213 | # Since the init is done in forward. 
Can send image size here 214 | hidden_state = self._init_hidden(batch_size=b, 215 | image_size=(h, w)) 216 | 217 | layer_output_list = [] 218 | last_state_list = [] 219 | 220 | seq_len = input_tensor.size(1) 221 | cur_layer_input = input_tensor 222 | 223 | for layer_idx in range(self.num_layers): 224 | 225 | h, c = hidden_state[layer_idx] 226 | output_inner = [] 227 | for t in range(seq_len): 228 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 229 | cur_state=[h, c]) 230 | output_inner.append(h) 231 | 232 | layer_output = torch.stack(output_inner, dim=1) 233 | cur_layer_input = layer_output 234 | 235 | layer_output_list.append(layer_output) 236 | last_state_list.append([h, c]) 237 | 238 | if not self.return_all_layers: 239 | layer_output_list = layer_output_list[-1:] 240 | last_state_list = last_state_list[-1:] 241 | 242 | return layer_output_list, last_state_list 243 | 244 | def _init_hidden(self, batch_size, image_size): 245 | init_states = [] 246 | for i in range(self.num_layers): 247 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) 248 | return init_states 249 | 250 | @staticmethod 251 | def _check_kernel_size_consistency(kernel_size): 252 | if not (isinstance(kernel_size, tuple) or 253 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 254 | raise ValueError('`kernel_size` must be tuple or list of tuples') 255 | 256 | @staticmethod 257 | def _extend_for_multilayer(param, num_layers): 258 | if not isinstance(param, list): 259 | param = [param] * num_layers 260 | return param 261 | -------------------------------------------------------------------------------- /FCN_version/deep_networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
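# (Editor's note) The branch below implements the two-phase synchronization: the replica
# whose _parallel_id is 0 (assigned in __data_parallel_replicate__ further down) acts as
# master and reduces the per-device (sum, sum-of-squares, count) messages, while the
# remaining replicas submit their statistics through _slave_pipe and block until the
# global mean / inv_std are broadcast back.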
50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. 
math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 
196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 
261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /FCN_version/dataset/imutils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image 4 | # from scipy import misc 5 | import torch 6 | import torchvision 7 | from PIL import ImageEnhance 8 | 9 | 10 | def normalize_img(img, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]): 11 | imgarr = np.asarray(img) 12 | proc_img = (imgarr - mean[0]) / std[0] 13 | return proc_img 14 | 15 | 16 | def random_noise(pre_img, post_im): 17 | np.random.seed() 18 | noise_map = np.random.random(pre_img.size) 19 | return pre_img 20 | 21 | 22 | def random_scaling(pre_img, post_img, loc_label, dam_label, size_range, scale_range): 23 | h, w, = dam_label.shape 24 | 25 | min_ratio, max_ratio = scale_range 26 | assert min_ratio <= max_ratio 27 | 28 | ratio = random.uniform(min_ratio, max_ratio) 29 | 30 | new_scale = int(size_range[0] * ratio), int(size_range[1] * ratio) 31 | 32 | max_long_edge = max(new_scale) 33 | max_short_edge = min(new_scale) 34 | scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w)) 35 | 36 | return _img_rescaling(pre_img, post_img, loc_label, dam_label, scale=ratio) 37 | 38 | 39 | def _img_rescaling(pre_img, post_img, loc_label, dam_label, scale=None): 40 | # scale = random.uniform(scales) 41 | h, w, = dam_label.shape 42 | 43 | new_scale = [int(scale * w), int(scale * h)] 44 | 45 | new_pre_img = Image.fromarray(pre_img.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR) 46 | new_pre_img = np.asarray(new_pre_img).astype(np.float32) 47 | 48 | new_post_img = Image.fromarray(post_img.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR) 49 | new_post_img = np.asarray(new_post_img).astype(np.float32) 50 | 51 | if dam_label is None: 52 | return new_pre_img, new_post_img 53 | 54 | new_dam_label = Image.fromarray(dam_label).resize(new_scale, resample=Image.NEAREST) 55 | new_dam_label = np.asarray(new_dam_label) 56 | new_loc_label = Image.fromarray(loc_label).resize(new_scale, resample=Image.NEAREST) 57 | new_loc_label = np.asarray(new_loc_label) 58 | 59 | return new_pre_img, new_post_img, new_loc_label, new_dam_label 60 | 61 | 62 | def img_resize_short(image, min_size=512): 63 | h, w, _ = image.shape 64 | if min(h, w) >= min_size: 65 | return image 66 | 67 | scale = float(min_size) / min(h, w) 68 | new_scale = [int(scale * w), int(scale * h)] 69 | 70 | new_image = Image.fromarray(image.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR) 71 | new_image = np.asarray(new_image).astype(np.float32) 72 
| 73 | return new_image 74 | 75 | 76 | def random_resize(image, label, size_range=None): 77 | _new_size = random.randint(size_range[0], size_range[1]) 78 | 79 | h, w, = label.shape 80 | scale = _new_size / float(max(h, w)) 81 | new_scale = [int(scale * w), int(scale * h)] 82 | 83 | new_image, new_label = _img_rescaling(image, label, scale=new_scale) 84 | 85 | return new_image, new_label 86 | 87 | 88 | def random_fliplr(pre_img, post_img, label): 89 | if random.random() > 0.5: 90 | label = np.fliplr(label) 91 | pre_img = np.fliplr(pre_img) 92 | post_img = np.fliplr(post_img) 93 | 94 | return pre_img, post_img, label 95 | 96 | 97 | def random_fliplr_multicd(pre_img, post_img, pre_lc_label, label): 98 | if random.random() > 0.5: 99 | label = np.fliplr(label) 100 | pre_img = np.fliplr(pre_img) 101 | post_img = np.fliplr(post_img) 102 | pre_lc_label = np.fliplr(pre_lc_label) 103 | 104 | return pre_img, post_img, pre_lc_label, label 105 | 106 | 107 | def random_fliplr_with_object(pre_img, post_img, object_map, label): 108 | if random.random() > 0.5: 109 | label = np.fliplr(label) 110 | pre_img = np.fliplr(pre_img) 111 | post_img = np.fliplr(post_img) 112 | object_map = np.fliplr(object_map) 113 | 114 | return pre_img, post_img, object_map, label 115 | 116 | 117 | def random_flipud(pre_img, post_img, label): 118 | if random.random() > 0.5: 119 | label = np.flipud(label) 120 | pre_img = np.flipud(pre_img) 121 | post_img = np.flipud(post_img) 122 | 123 | return pre_img, post_img, label 124 | 125 | 126 | def random_flipud_multicd(pre_img, post_img, pre_lc_label, label): 127 | if random.random() > 0.5: 128 | label = np.flipud(label) 129 | pre_img = np.flipud(pre_img) 130 | post_img = np.flipud(post_img) 131 | pre_lc_label = np.flipud(pre_lc_label) 132 | 133 | return pre_img, post_img, pre_lc_label, label 134 | 135 | 136 | def random_flipud_with_object(pre_img, post_img, object_map, label): 137 | if random.random() > 0.5: 138 | object_map = np.flipud(object_map) 139 | label = np.flipud(label) 140 | pre_img = np.flipud(pre_img) 141 | post_img = np.flipud(post_img) 142 | 143 | return pre_img, post_img, object_map, label 144 | 145 | 146 | def random_rot(pre_img, post_img, label): 147 | k = random.randrange(3) + 1 148 | 149 | pre_img = np.rot90(pre_img, k).copy() 150 | post_img = np.rot90(post_img, k).copy() 151 | label = np.rot90(label, k).copy() 152 | 153 | return pre_img, post_img, label 154 | 155 | 156 | def random_rot_multicd(pre_img, post_img, pre_lc_label, label): 157 | k = random.randrange(3) + 1 158 | 159 | pre_img = np.rot90(pre_img, k).copy() 160 | post_img = np.rot90(post_img, k).copy() 161 | label = np.rot90(label, k).copy() 162 | pre_lc_label = np.rot90(pre_lc_label, k).copy() 163 | 164 | return pre_img, post_img, pre_lc_label, label 165 | 166 | 167 | def random_rot_with_object(pre_img, post_img, object_map, label): 168 | k = random.randrange(3) + 1 169 | 170 | pre_img = np.rot90(pre_img, k).copy() 171 | post_img = np.rot90(post_img, k).copy() 172 | object_map = np.rot90(object_map, k).copy() 173 | label = np.rot90(label, k).copy() 174 | 175 | return pre_img, post_img, object_map, label 176 | 177 | 178 | def random_crop(pre_img, post_img, label, crop_size, mean_rgb=[0, 0, 0], ignore_index=255): 179 | h, w = label.shape 180 | 181 | H = max(crop_size, h) 182 | W = max(crop_size, w) 183 | 184 | # pad_pre_image = np.zeros((H, W), dtype=np.float32) 185 | pad_pre_image = np.zeros((H, W, pre_img.shape[-1]), dtype=np.float32) 186 | 187 | pad_post_image = np.zeros((H, W, pre_img.shape[-1]), 
dtype=np.float32) 188 | pad_label = np.ones((H, W), dtype=np.float32) * ignore_index 189 | 190 | # pad_pre_image[:, :] = mean_rgb[0] 191 | # pad_pre_image[:, :, 0] = mean_rgb[0] 192 | # pad_pre_image[:, :, 1] = mean_rgb[1] 193 | # pad_pre_image[:, :, 2] = mean_rgb[2] 194 | # 195 | # pad_post_image[:, :, 0] = mean_rgb[0] 196 | # pad_post_image[:, :, 1] = mean_rgb[1] 197 | # pad_post_image[:, :, 2] = mean_rgb[2] 198 | 199 | H_pad = int(np.random.randint(H - h + 1)) 200 | W_pad = int(np.random.randint(W - w + 1)) 201 | 202 | pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w)] = pre_img 203 | # pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img 204 | pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img 205 | pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label 206 | 207 | def get_random_cropbox(cat_max_ratio=0.75): 208 | 209 | for i in range(10): 210 | 211 | H_start = random.randrange(0, H - crop_size + 1, 1) 212 | H_end = H_start + crop_size 213 | W_start = random.randrange(0, W - crop_size + 1, 1) 214 | W_end = W_start + crop_size 215 | 216 | temp_label = pad_label[H_start:H_end, W_start:W_end] 217 | index, cnt = np.unique(temp_label, return_counts=True) 218 | cnt = cnt[index != ignore_index] 219 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio: 220 | break 221 | 222 | return H_start, H_end, W_start, W_end, 223 | 224 | H_start, H_end, W_start, W_end = get_random_cropbox() 225 | # print(W_start) 226 | pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :] 227 | # pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :] 228 | post_img = pad_post_image[H_start:H_end, W_start:W_end, :] 229 | label = pad_label[H_start:H_end, W_start:W_end] 230 | # cmap = colormap() 231 | # misc.imsave('cropimg.png',image/255) 232 | # misc.imsave('croplabel.png',encode_cmap(label)) 233 | return pre_img, post_img, label 234 | 235 | 236 | def random_crop_multicd(pre_img, post_img, pre_lc_label, label, crop_size, mean_rgb=[0, 0, 0], ignore_index=255): 237 | h, w = label.shape 238 | 239 | H = max(crop_size, h) 240 | W = max(crop_size, w) 241 | 242 | pad_pre_image = np.zeros((H, W, 3), dtype=np.float32) 243 | 244 | pad_post_image = np.zeros((H, W, 3), dtype=np.float32) 245 | 246 | pad_pre_lc_label = np.zeros((H, W), dtype=np.float32) 247 | pad_label = np.ones((H, W), dtype=np.float32) * ignore_index 248 | 249 | # pad_pre_image[:, :] = mean_rgb[0] 250 | pad_pre_image[:, :, 0] = mean_rgb[0] 251 | pad_pre_image[:, :, 1] = mean_rgb[1] 252 | pad_pre_image[:, :, 2] = mean_rgb[2] 253 | 254 | pad_post_image[:, :, 0] = mean_rgb[0] 255 | pad_post_image[:, :, 1] = mean_rgb[1] 256 | pad_post_image[:, :, 2] = mean_rgb[2] 257 | 258 | H_pad = int(np.random.randint(H - h + 1)) 259 | W_pad = int(np.random.randint(W - w + 1)) 260 | 261 | pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img 262 | pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img 263 | 264 | pad_pre_lc_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = pre_lc_label 265 | pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label 266 | 267 | def get_random_cropbox(cat_max_ratio=0.75): 268 | 269 | for i in range(10): 270 | 271 | H_start = random.randrange(0, H - crop_size + 1, 1) 272 | H_end = H_start + crop_size 273 | W_start = random.randrange(0, W - crop_size + 1, 1) 274 | W_end = W_start + crop_size 275 | 276 | temp_label = pad_label[H_start:H_end, W_start:W_end] 277 | index, cnt = np.unique(temp_label, return_counts=True) 278 | cnt = cnt[index != ignore_index]
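# (Editor's note) As in random_crop above: retry up to 10 times until no single class
# occupies more than cat_max_ratio of the valid (non-ignore) pixels in the candidate
# crop, so crops are not dominated by one category; if no candidate qualifies, the
# last one drawn is used.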
def random_crop_with_object(pre_img, post_img, object_map, label, crop_size, mean_rgb=(0, 0, 0), ignore_index=255):
    h, w = label.shape

    H = max(crop_size, h)
    W = max(crop_size, w)

    pad_pre_image = np.zeros((H, W, 3), dtype=np.float32)
    pad_post_image = np.zeros((H, W, 3), dtype=np.float32)
    # np.long was removed in NumPy 1.24; int64 is the equivalent dtype.
    pad_object_map = np.zeros((H, W), dtype=np.int64)
    pad_label = np.ones((H, W), dtype=np.float32) * ignore_index

    # The original code filled every channel of pad_pre_image with
    # mean_rgb[0] before overwriting channels 1 and 2; assigning channel 0
    # explicitly makes the intent clear.
    pad_pre_image[:, :, 0] = mean_rgb[0]
    pad_pre_image[:, :, 1] = mean_rgb[1]
    pad_pre_image[:, :, 2] = mean_rgb[2]

    pad_post_image[:, :, 0] = mean_rgb[0]
    pad_post_image[:, :, 1] = mean_rgb[1]
    pad_post_image[:, :, 2] = mean_rgb[2]

    H_pad = int(np.random.randint(H - h + 1))
    W_pad = int(np.random.randint(W - w + 1))

    pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img
    pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img
    pad_object_map[H_pad:(H_pad + h), W_pad:(W_pad + w)] = object_map
    pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label

    def get_random_cropbox(cat_max_ratio=0.75):
        for _ in range(10):
            H_start = random.randrange(0, H - crop_size + 1, 1)
            H_end = H_start + crop_size
            W_start = random.randrange(0, W - crop_size + 1, 1)
            W_end = W_start + crop_size

            temp_label = pad_label[H_start:H_end, W_start:W_end]
            index, cnt = np.unique(temp_label, return_counts=True)
            cnt = cnt[index != ignore_index]
            if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio:
                break

        return H_start, H_end, W_start, W_end

    H_start, H_end, W_start, W_end = get_random_cropbox()
    pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :]
    post_img = pad_post_image[H_start:H_end, W_start:W_end, :]
    object_map = pad_object_map[H_start:H_end, W_start:W_end]
    label = pad_label[H_start:H_end, W_start:W_end]

    return pre_img, post_img, object_map, label


def encode_cmap(label):
    # Map integer class ids to RGB colours via the VOC-style palette below.
    cmap = colormap()
    return cmap[label.astype(np.int16), :]


def tensorboard_image(inputs=None, outputs=None, labels=None, bgr=None):
    ## images
    # Undo the per-channel mean subtraction (note: inputs is modified in
    # place) and reorder BGR -> RGB before building the grid.
    inputs[:, 0, :, :] = inputs[:, 0, :, :] + bgr[0]
    inputs[:, 1, :, :] = inputs[:, 1, :, :] + bgr[1]
    inputs[:, 2, :, :] = inputs[:, 2, :, :] + bgr[2]
    inputs = inputs[:, [2, 1, 0], :, :].type(torch.uint8)
    grid_inputs = torchvision.utils.make_grid(tensor=inputs, nrow=2)

    ## preds
    preds = torch.argmax(outputs, dim=1).cpu().numpy()
    preds_cmap = encode_cmap(preds)
    preds_cmap = torch.from_numpy(preds_cmap).permute([0, 3, 1, 2])
    grid_outputs = torchvision.utils.make_grid(tensor=preds_cmap, nrow=2)

    ## labels
    labels_cmap = encode_cmap(labels.cpu().numpy())
    labels_cmap = torch.from_numpy(labels_cmap).permute([0, 3, 1, 2])
    grid_labels = torchvision.utils.make_grid(tensor=labels_cmap, nrow=2)

    return grid_inputs, grid_outputs, grid_labels


def colormap(N=256, normalized=False):
    # Standard PASCAL VOC colour map: the bits of each class id are spread
    # across the R/G/B channels, three bits per round, from the MSB down.
    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r = r | (bitget(c, 0) << 7 - j)
            g = g | (bitget(c, 1) << 7 - j)
            b = b | (bitget(c, 2) << 7 - j)
            c = c >> 3

        cmap[i] = np.array([r, g, b])

    cmap = cmap / 255 if normalized else cmap
    return cmap
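# --- Editor's sketch (not part of the original file): palette sanity check ---
# colormap() reproduces the standard PASCAL VOC palette, so the first few
# entries are black, maroon, green and olive; encode_cmap() then turns a map
# of class ids into an RGB image via fancy indexing.
if __name__ == "__main__":
    _cmap = colormap()
    assert tuple(_cmap[0]) == (0, 0, 0)
    assert tuple(_cmap[1]) == (128, 0, 0)
    assert tuple(_cmap[2]) == (0, 128, 0)
    assert tuple(_cmap[3]) == (128, 128, 0)

    _ids = np.array([[0, 1], [2, 3]])
    _rgb = encode_cmap(_ids)  # shape (2, 2, 3), dtype uint8
    assert _rgb.shape == (2, 2, 3)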
--------------------------------------------------------------------------------