├── Fig
│   ├── HY_1.png
│   ├── HY_2.png
│   ├── HY_GT.png
│   ├── WH_1.png
│   ├── WH_2.png
│   ├── WH_GT.png
│   ├── SiamCRNN.jpg
│   ├── Buffalo_GT.bmp
│   ├── Buffalo_T1.png
│   └── Buffalo_T2.png
├── FCN_version
│   ├── dataset
│   │   ├── OSCD
│   │   │   └── original_data
│   │   │       ├── test.txt
│   │   │       └── train.txt
│   │   ├── make_data_loader.py
│   │   └── imutils.py
│   ├── deep_networks
│   │   ├── sync_batchnorm
│   │   │   ├── __pycache__
│   │   │   │   ├── comm.cpython-36.pyc
│   │   │   │   ├── __init__.cpython-36.pyc
│   │   │   │   ├── batchnorm.cpython-36.pyc
│   │   │   │   └── replicate.cpython-36.pyc
│   │   │   ├── __init__.py
│   │   │   ├── unittest.py
│   │   │   ├── replicate.py
│   │   │   ├── comm.py
│   │   │   └── batchnorm.py
│   │   ├── resnet_18_34.py
│   │   └── SiamCRNN.py
│   ├── README.md
│   ├── util_func
│   │   ├── metrics.py
│   │   └── lovasz_loss.py
│   └── script
│       └── train_siamcrnn.py
├── LICENSE
├── util
│   ├── data_prepro.py
│   └── net_util.py
├── infer.py
├── SiamCRNN.py
├── README.md
└── train.py
/Fig/HY_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_1.png
--------------------------------------------------------------------------------
/Fig/HY_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_2.png
--------------------------------------------------------------------------------
/Fig/HY_GT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/HY_GT.png
--------------------------------------------------------------------------------
/Fig/WH_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_1.png
--------------------------------------------------------------------------------
/Fig/WH_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_2.png
--------------------------------------------------------------------------------
/Fig/WH_GT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/WH_GT.png
--------------------------------------------------------------------------------
/Fig/SiamCRNN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/SiamCRNN.jpg
--------------------------------------------------------------------------------
/Fig/Buffalo_GT.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_GT.bmp
--------------------------------------------------------------------------------
/Fig/Buffalo_T1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_T1.png
--------------------------------------------------------------------------------
/Fig/Buffalo_T2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/Fig/Buffalo_T2.png
--------------------------------------------------------------------------------
/FCN_version/dataset/OSCD/original_data/test.txt:
--------------------------------------------------------------------------------
1 | brasilia
2 | montpellier
3 | norcia
4 | rio
5 | saclay_w
6 | valencia
7 | dubai
8 | lasvegas
9 | milano
10 | chongqing
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/comm.cpython-36.pyc
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenHongruixuan/SiamCRNN/HEAD/FCN_version/deep_networks/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
--------------------------------------------------------------------------------
/FCN_version/dataset/OSCD/original_data/train.txt:
--------------------------------------------------------------------------------
1 | aguasclaras
2 | bercy
3 | bordeaux
4 | nantes
5 | paris
6 | rennes
7 | saclay_e
8 | abudhabi
9 | cupertino
10 | pisa
11 | beihai
12 | hongkong
13 | beirut
14 | mumbai
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol, rtol=rtol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 I-Hope-Peace
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/util/data_prepro.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def norm_img(img, channel_first=True):
5 | if channel_first:
6 | channel, img_height, img_width = img.shape
7 | img = np.reshape(img, (channel, img_height * img_width)) # (channel, height * width)
8 | max_value = np.max(img, axis=1, keepdims=True) # (channel, 1)
9 | min_value = np.min(img, axis=1, keepdims=True) # (channel, 1)
10 | diff_value = max_value - min_value
11 | nm_img = (img - min_value) / diff_value
12 | nm_img = np.reshape(nm_img, (channel, img_height, img_width))
13 | else:
14 | img_height, img_width, channel = img.shape
15 | img = np.reshape(img, (img_height * img_width, channel))  # (height * width, channel)
16 | max_value = np.max(img, axis=0, keepdims=True)  # (1, channel)
17 | min_value = np.min(img, axis=0, keepdims=True)  # (1, channel)
18 | diff_value = max_value - min_value
19 | nm_img = (img - min_value) / diff_value
20 | nm_img = np.reshape(nm_img, (img_height, img_width, channel))
21 | return nm_img
22 |
23 |
24 | def stad_img(img, channel_first=True, get_para=False):
25 | """
26 | standardize an image per channel (zero mean, unit variance)
27 | :param channel_first: True for (C, H, W) input, False for (H, W, C)
28 | :param img: (C, H, W) or (H, W, C)
29 | :return:
30 | std_img: standardized image, same shape as img
31 | """
32 | if channel_first:
33 | channel, img_height, img_width = img.shape
34 | img = np.reshape(img, (channel, img_height * img_width)) # (channel, height * width)
35 | mean = np.mean(img, axis=1, keepdims=True) # (channel, 1)
36 | center = img - mean # (channel, height * width)
37 | var = np.sum(np.power(center, 2), axis=1, keepdims=True) / (img_height * img_width) # (channel, 1)
38 | std = np.sqrt(var) # (channel, 1)
39 | std_img = center / std # (channel, height * width)
40 | std_img = np.reshape(std_img, (channel, img_height, img_width))
41 | else:
42 | img_height, img_width, channel = img.shape
43 | img = np.reshape(img, (img_height * img_width, channel)) # (height * width, channel)
44 | mean = np.mean(img, axis=0, keepdims=True) # (1, channel)
45 | center = img - mean # (height * width, channel)
46 | var = np.sum(np.power(center, 2), axis=0, keepdims=True) / (img_height * img_width) # (1, channel)
47 | std = np.sqrt(var)  # (1, channel)
48 | std_img = center / std  # (height * width, channel)
49 | std_img = np.reshape(std_img, (img_height, img_width, channel))
50 | print('mean is ', mean)
51 | print('std is ', std)
52 | if get_para:
53 | return std_img, mean, std
54 | else:
55 | return std_img
56 |
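
# Usage sketch: standardize a hypothetical 4-band (C, H, W) raster so that
# every band has zero mean and unit variance, as train.py and infer.py do
# before sampling patches.
if __name__ == '__main__':
    demo = np.random.rand(4, 64, 64).astype(np.float32)
    std_demo, demo_mean, demo_std = stad_img(demo, channel_first=True, get_para=True)
    print(std_demo.shape)  # (4, 64, 64)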
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import cv2 as cv
2 |
3 | import gdal
4 | import numpy as np
5 | import tensorflow as tf
6 | from util.data_prepro import stad_img
7 |
8 | from SiamCRNN import SiamCRNN
9 | import time
10 |
11 |
12 |
13 | def load_data(path_X, path_Y):
14 | data_set_X = gdal.Open(path_X) # data set X
15 | data_set_Y = gdal.Open(path_Y) # data set Y
16 |
17 | img_width = data_set_X.RasterXSize # image width
18 | img_height = data_set_X.RasterYSize # image height
19 |
20 | img_X = data_set_X.ReadAsArray(0, 0, img_width, img_height)
21 | img_Y = data_set_Y.ReadAsArray(0, 0, img_width, img_height)
22 |
23 | img_X = stad_img(img_X) # (C, H, W)
24 | img_Y = stad_img(img_Y)
25 | img_X = np.transpose(img_X, [1, 2, 0]) # (H, W, C)
26 | img_Y = np.transpose(img_Y, [1, 2, 0]) # (H, W, C)
27 | return img_X, img_Y
28 |
29 |
30 | def infer_result(path_X, path_Y):
31 | patch_sz = 5
32 | batch_size = 1000
33 |
34 | img_X, img_Y = load_data(path_X, path_Y)
35 | img_X = np.pad(img_X, ((2, 2), (2, 2), (0, 0)), 'constant')
36 | img_Y = np.pad(img_Y, ((2, 2), (2, 2), (0, 0)), 'constant')
37 | img_height, img_width, channel = img_X.shape # padded image size
38 |
39 | edge = patch_sz // 2
40 | sample_X = []
41 | sample_Y = []
42 | for i in range(edge, img_height - edge):
43 | for j in range(edge, img_width - edge):
44 | sample_X.append(img_X[i - edge:i + edge + 1, j - edge:j + edge + 1, :])
45 | sample_Y.append(img_Y[i - edge:i + edge + 1, j - edge:j + edge + 1, :])
46 | sample_X = np.array(sample_X)
47 | sample_Y = np.array(sample_Y)
48 |
49 | num_batch = sample_X.shape[0] // batch_size # assumes batch_size divides the number of patches
50 |
51 | Input_X = tf.placeholder(dtype=tf.float32, shape=[None, patch_sz, patch_sz, channel], name='Input_X')
52 | Input_Y = tf.placeholder(dtype=tf.float32, shape=[None, patch_sz, patch_sz, channel], name='Input_Y')
53 | is_training = tf.placeholder(dtype=tf.bool, name='is_training')
54 |
55 | model_path = 'model_param'
56 | model = SiamCRNN()
57 | net, result = model.get_model(Input_X=Input_X, Input_Y=Input_Y, data_format='NHWC', is_training=is_training)
58 | config = tf.ConfigProto()
59 | config.gpu_options.allow_growth = True
60 | saver = tf.train.Saver()
61 | logits_result_list = []
62 | pred_results_list = []
63 | path = None
64 |
65 | with tf.Session(config=config) as sess:
66 | ckpt = tf.train.get_checkpoint_state(model_path)
67 | if ckpt and ckpt.model_checkpoint_path:
68 | path = ckpt.model_checkpoint_path
69 | saver.restore(sess, ckpt.model_checkpoint_path)
70 | tic = time.time()
71 | for _epoch in range(num_batch):
72 | pred_result = sess.run([result], feed_dict={
73 | Input_X: sample_X[batch_size * _epoch:batch_size * (_epoch + 1)],
74 | Input_Y: sample_Y[batch_size * _epoch:batch_size * (_epoch + 1)],
75 | is_training: False
76 | })
77 | pred_results_list.append(pred_result)
78 |
79 | pred = np.reshape(pred_results_list, (img_height - 2 * edge, img_width - 2 * edge)) # back to the unpadded extent
80 |
81 | idx_1 = (pred <= 0.5)
82 | idx_2 = (pred > 0.5)
83 | pred[idx_1] = 0
84 | pred[idx_2] = 255
85 | cv.imwrite('SiamCRNN.bmp', pred)
86 |
87 |
88 |
89 | if __name__ == '__main__':
90 | infer_result(path_X='path/to/T1_image', path_Y='path/to/T2_image') # placeholder paths; point these at the bi-temporal GDAL rasters
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 | Note that, as all modules are isomorphism, we assign each sub-module with a context
32 | (shared among multiple copies of this module on different devices).
33 | Through this context, different copies can share some information.
34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 | of any slave copies.
36 | """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51 | the original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/FCN_version/README.md:
--------------------------------------------------------------------------------
1 | # SiamCRNN in Fully Convolutional Version
2 |
3 |
5 |
6 | This is an implementation of the fully convolutional version of the **SiamCRNN** framework from our IEEE TGRS 2020 paper: [Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network](https://ieeexplore.ieee.org/document/8937755).
7 |
8 | ## Introduction
9 | We have improved the SiamCRNN from our [original paper](https://ieeexplore.ieee.org/document/8937755) so that SiamCRNN can be used for large-scale change detection tasks. The entire architecture is fully convolutional, where the encoder is an arbitrary fully convolutional deep network (we use ResNet here) and the decoder is a multilayer ConvLSTM+FPN.
10 |
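As a minimal usage sketch (the constructor arguments and forward signature follow `script/train_siamcrnn.py`; the batch size, spatial size, and the per-pixel two-class logits shape are illustrative assumptions):

```python
import torch

from deep_networks.SiamCRNN import SiamCRNN

# 13 input bands per branch, matching the OSCD 13-band setting in
# script/train_siamcrnn.py (SiamCRNN(in_dim_1=13, in_dim_2=13)).
model = SiamCRNN(in_dim_1=13, in_dim_2=13).cuda()

pre_img = torch.randn(4, 13, 256, 256).cuda()   # T1 batch, (B, C, H, W)
post_img = torch.randn(4, 13, 256, 256).cuda()  # T2 batch, (B, C, H, W)

logits = model(pre_img, post_img)   # per-pixel change logits, assumed (B, 2, H, W)
change_map = logits.argmax(dim=1)   # 0 = unchanged, 1 = changed
```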
11 | ## Get started
12 | ### Requirements
13 | ```
14 | python==3.8.18
15 | pytorch==1.12.1
16 | torchvision==0.13.1
17 | imageio==2.22.4
18 | numpy==1.14.0
19 | tqdm==4.64.1
20 | ```
21 |
22 | ### Dataset
23 | The fully convolutional version of SiamCRNN can be trained and tested on arbitrary large-scale change detection benchmark datasets, such as [SYSU](https://github.com/liumency/SYSU-CD), [LEVIR-CD](https://chenhao.in/LEVIR/), etc. We provide an example here on the [OSCD dataset](https://rcdaudt.github.io/oscd/).
24 |
25 | ### Training
26 | ```
27 | python train_siamcrnn.py
28 | ```
29 |
30 | ### Testing
31 | ```
32 | python test_siamcrnn.py
33 | ```
34 |
35 | ### Detection results
36 | | Method | Rec | Pre | OA | F1 | IoU | KC |
37 | |:-------------:|:---------------:|:--------------------:|:---------------:|:--------------------:|:--------------------:|:--------------------:|
39 | | FC-EF | 0.4612 | 0.4967 | 0.9480 | 0.4783 | 0.3143 | 0.4510 |
40 | | FC-Siam-Conc | 0.5362 | 0.4760 | 0.9455 | 0.5043 | 0.3372 | 0.4757 |
41 | | FC-Siam-Diff | 0.5385 | 0.4391 | 0.9406 | 0.4838 | 0.3191 | 0.4526 |
42 | | [DSIFN](https://github.com/GeoZcx/A-deeply-supervised-image-fusion-network-for-change-detection-in-remote-sensing-images) | 0.4538 | 0.6732 | 0.9604 | 0.5421 | 0.3718 | 0.5222 |
43 | | [ChangeFormer](https://github.com/wgcban/ChangeFormer) | 0.5351 | 0.5566 | 0.9540 | 0.5457 | 0.3752 | 0.5215 |
44 | | **SiamCRNN (ResNet18)** | **0.5160** | **0.5758** | **0.9553** | **0.5442** | **0.3739** | **0.5208** |
45 |
46 | ## Citation
47 | If this code or dataset contributes to your research, please consider citing our paper. We appreciate your support!🙂
48 | ```
49 | @article{Chen2020Change,
50 | author = {Chen, Hongruixuan and Wu, Chen and Du, Bo and Zhang, Liangpei and Wang, Le},
51 | issn = {0196-2892},
52 | journal = {IEEE Transactions on Geoscience and Remote Sensing},
53 | number = {4},
54 | pages = {2848--2864},
55 | title = {{Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network}},
56 | volume = {58},
57 | year = {2020}
58 | }
59 | ```
60 |
61 | ## Q & A
62 | **For any questions, please [contact us](mailto:Qschrx@gmail.com).**
63 |
--------------------------------------------------------------------------------
/FCN_version/util_func/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Evaluator(object):
5 | def __init__(self, num_class):
6 | self.num_class = num_class
7 | self.confusion_matrix = np.zeros((self.num_class,) * 2, dtype=np.float64)
8 |
9 | def Pixel_Accuracy(self):
10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
11 | return Acc
12 |
13 | def Pixel_Accuracy_Class(self):
14 | Acc = np.diag(self.confusion_matrix) / (self.confusion_matrix.sum(axis=1) + 1e-7)
15 | mAcc = np.nanmean(Acc)
16 | return mAcc, Acc
17 |
18 | def Pixel_Precision_Rate(self):
19 | assert self.confusion_matrix.shape[0] == 2
20 | Pre = self.confusion_matrix[1, 1] / (self.confusion_matrix[0, 1] + self.confusion_matrix[1, 1])
21 | return Pre
22 |
23 | def Pixel_Recall_Rate(self):
24 | assert self.confusion_matrix.shape[0] == 2
25 | Rec = self.confusion_matrix[1, 1] / (self.confusion_matrix[1, 0] + self.confusion_matrix[1, 1])
26 | return Rec
27 |
28 | def Pixel_F1_score(self):
29 | assert self.confusion_matrix.shape[0] == 2
30 | Rec = self.Pixel_Recall_Rate()
31 | Pre = self.Pixel_Precision_Rate()
32 | F1 = 2 * Rec * Pre / (Rec + Pre)
33 | return F1
34 |
35 | def Mean_Intersection_over_Union(self):
36 | MIoU = np.diag(self.confusion_matrix) / (
37 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
38 | np.diag(self.confusion_matrix) + 1e-7)
39 | MIoU = np.nanmean(MIoU)
40 | return MIoU
41 |
42 | def Intersection_over_Union(self):
43 | IoU = self.confusion_matrix[1, 1] / (
44 | self.confusion_matrix[0, 1] + self.confusion_matrix[1, 0] + self.confusion_matrix[1, 1])
45 | return IoU
46 |
47 | def Kappa_coefficient(self):
48 | # Number of observations (total number of classifications)
49 | # num_total = np.array(0, dtype=np.long)
50 | # row_sums = np.array([0, 0], dtype=np.long)
51 | # col_sums = np.array([0, 0], dtype=np.long)
52 | # total += np.sum(self.confusion_matrix)
53 | # # Observed agreement (i.e., sum of diagonal elements)
54 | # observed_agreement = np.sum(np.diag(self.confusion_matrix))
55 | # # Compute expected agreement
56 | # row_sums += np.sum(self.confusion_matrix, axis=0)
57 | # col_sums += np.sum(self.confusion_matrix, axis=1)
58 | # expected_agreement = np.sum((row_sums * col_sums) / total)
59 | num_total = np.sum(self.confusion_matrix, dtype=np.float64)
60 | observed_accuracy = np.trace(self.confusion_matrix, dtype=np.float64) / num_total
61 | expected_accuracy = np.sum(
62 | np.sum(self.confusion_matrix, axis=0, dtype=np.float64) / num_total * np.sum(self.confusion_matrix, axis=1,
63 | dtype=np.float64) / num_total)
64 |
65 | # Calculate Cohen's kappa
66 | kappa = (observed_accuracy - expected_accuracy) / (1 - expected_accuracy)
67 | return kappa
68 |
69 | def Frequency_Weighted_Intersection_over_Union(self):
70 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
71 | iu = np.diag(self.confusion_matrix) / (
72 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
73 | np.diag(self.confusion_matrix))
74 |
75 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
76 | return FWIoU
77 |
78 | def _generate_matrix(self, gt_image, pre_image):
79 | mask = (gt_image >= 0) & (gt_image < self.num_class)
80 | label = self.num_class * gt_image[mask].astype('int64') + pre_image[mask]
81 | count = np.bincount(label, minlength=self.num_class ** 2)
82 | confusion_matrix = count.reshape(self.num_class, self.num_class)
83 | return confusion_matrix
84 |
85 | def add_batch(self, gt_image, pre_image):
86 | assert gt_image.shape == pre_image.shape
87 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
88 |
89 | def reset(self):
90 | self.confusion_matrix = np.zeros((self.num_class,) * 2)
91 |
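# Usage sketch: accumulate a confusion matrix over a tiny hypothetical
# ground-truth/prediction pair, then read off the binary change-detection scores.
if __name__ == '__main__':
    evaluator = Evaluator(num_class=2)
    gt = np.array([[0, 1], [1, 1]])
    pred = np.array([[0, 1], [0, 1]])
    evaluator.add_batch(gt, pred)
    print(evaluator.Pixel_F1_score())     # 0.8 (Pre = 1.0, Rec = 2/3)
    print(evaluator.Kappa_coefficient())  # 0.5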
--------------------------------------------------------------------------------
/SiamCRNN.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from net_util import conv_2d, max_pool_2d, avg_pool_2d, fully_connected
3 |
4 |
5 | class SiamCRNN(object):
6 |
7 | def get_model(self, Input_X, Input_Y, data_format='NHWC', is_training=True):
8 | net_X = self._feature_extract_layer(inputs=Input_X, name='Fea_Ext_',
9 | data_format=data_format,
10 | is_training=is_training)
11 | net_Y = self._feature_extract_layer(inputs=Input_Y, name='Fea_Ext_',
12 | data_format=data_format,
13 | is_training=is_training, is_reuse=True)
14 |
15 | fea_1 = tf.squeeze(net_X, axis=1)
16 | fea_2 = tf.squeeze(net_Y, axis=1)
17 |
18 | logits, pred = self._change_judge_layer(feature_1=fea_1, feature_2=fea_2, name='Cha_Jud_',
19 | is_training=is_training)
20 | return logits, pred
21 |
22 | def _feature_extract_layer(self, inputs, name='Fea_Ext_', data_format='NHWC', is_training=True, is_reuse=False):
23 | with tf.variable_scope(name) as scope:
24 | if is_reuse:
25 | scope.reuse_variables()
26 | # (B, H, W, C) --> (B, H, W, 16)
27 | layer_1 = conv_2d(inputs=inputs, kernel_size=[3, 3], output_channel=16, stride=[1, 1], name='layer_1_conv',
28 | padding='SAME', data_format=data_format, is_training=is_training, is_bn=False,
29 | activation=tf.nn.relu)
30 | layer_2 = conv_2d(inputs=layer_1, kernel_size=[3, 3], output_channel=16, stride=[1, 1], name='layer_2_conv',
31 | padding='SAME', data_format=data_format, is_training=is_training, is_bn=False,
32 | activation=tf.nn.relu)
33 |
34 | layer_2 = tf.contrib.layers.dropout(inputs=layer_2, is_training=is_training, keep_prob=0.8)
35 |
36 | # (B, H, W, 16) --> (B, H, W, 32); stride-1 SAME convolutions keep the spatial size
37 | layer_3 = conv_2d(inputs=layer_2, kernel_size=[3, 3], output_channel=32, stride=[1, 1], padding='SAME',
38 | name='layer_3_conv', data_format=data_format, is_training=is_training, is_bn=False,
39 | activation=tf.nn.relu)
40 | layer_4 = conv_2d(inputs=layer_3, kernel_size=[3, 3], output_channel=32, stride=[1, 1], padding='SAME',
41 | name='layer_4_conv', data_format=data_format, is_training=is_training, is_bn=False,
42 | activation=tf.nn.relu)
43 | layer_4 = tf.contrib.layers.dropout(inputs=layer_4, is_training=is_training, keep_prob=0.7)
44 |
45 | # (B, H, W, 32) --> (B, H, W, 64); the 5x5 VALID conv below collapses each 5x5 patch to (B, 1, 1, 64)
46 | layer_5 = conv_2d(inputs=layer_4, kernel_size=[3, 3], output_channel=64, stride=[1, 1], padding='SAME',
47 | name='layer_5_conv', data_format=data_format, is_training=is_training, is_bn=False,
48 | activation=tf.nn.relu)
49 | net = conv_2d(inputs=layer_5, kernel_size=[5, 5], output_channel=64, stride=[1, 1], padding='VALID',
50 | name='layer_6_conv', data_format=data_format, is_training=is_training, is_bn=False,
51 | activation=tf.nn.relu)
52 | net = tf.contrib.layers.dropout(inputs=net, is_training=is_training, keep_prob=0.5)
53 | return net
54 |
55 | def _change_judge_layer(self, feature_1, feature_2, name='Cha_Jud_', is_training=True,
56 | activation=tf.nn.sigmoid):
57 | with tf.variable_scope(name) as scope:
58 | seq = tf.concat([feature_1, feature_2], axis=1) # (B, 2, 64): a two-step temporal sequence for the LSTM
59 | num_units = [128, 64]
60 | cells = [tf.nn.rnn_cell.LSTMCell(num_unit, activation=tf.nn.tanh) for num_unit in num_units]
61 | mul_cells = tf.nn.rnn_cell.MultiRNNCell(cells)
62 | output, cell_state = tf.nn.dynamic_rnn(mul_cells, seq, dtype=tf.float32, time_major=False)
63 | hidden_state = tf.contrib.layers.dropout(inputs=output[:, -1], is_training=is_training, keep_prob=0.5)
64 | logits_0 = fully_connected(hidden_state, num_outputs=32, is_training=is_training, is_bn=False,
65 | activation=tf.nn.tanh)
66 | logits = fully_connected(logits_0, num_outputs=1, is_training=is_training, is_bn=False)
67 | pred = activation(logits)
68 | return logits, pred
69 |
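# Usage sketch (mirrors the placeholders built in infer.py): the network is
# trained on 5x5 bi-temporal patches; the band count here is a hypothetical
# placeholder.
if __name__ == '__main__':
    channel = 4
    Input_X = tf.placeholder(tf.float32, [None, 5, 5, channel], name='Input_X')
    Input_Y = tf.placeholder(tf.float32, [None, 5, 5, channel], name='Input_Y')
    is_training = tf.placeholder(tf.bool, name='is_training')
    logits, pred = SiamCRNN().get_model(Input_X, Input_Y, data_format='NHWC',
                                        is_training=is_training)
    # logits feed the weighted cross-entropy in train.py; pred is the sigmoid
    # change probability that infer.py thresholds at 0.5.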
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network
2 |
3 |
5 |
6 | This is an official implementation of the **SiamCRNN** framework from our IEEE TGRS 2020 paper: [Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network](https://ieeexplore.ieee.org/document/8937755).
7 |
8 |
9 | ## Note
10 | **2024.01.18**
11 | - Wanna go beyond the patch-wise detection pipeline and apply SiamCRNN to large-scale change detection datasets? We have updated the [fully convolutional version of SiamCRNN](https://github.com/ChenHongruixuan/SiamCRNN/tree/master/FCN_version). Please feel free to test on the benchmark dataset.
12 |
13 | **2023.04.25**
14 | - The datasets Wuhan and Hanyang used in our paper have been open-sourced! You can download them [here](http://sigma.whu.edu.cn/resource.php).
15 |
16 |
17 |
18 | ## Abstract
19 | > With the rapid development of Earth observation technology, very-high-resolution (VHR) images from various satellite sensors are more available, which greatly enriches the data sources for change detection (CD). Multisource multitemporal images can provide abundant information on observed landscapes from various physical and material views, and it is exigent to develop efficient techniques to utilize these multisource data for CD. In this article, we propose a novel and general deep siamese convolutional multiple-layers recurrent neural network (RNN) (SiamCRNN) for CD in multitemporal VHR images. Superior to most VHR image CD methods, SiamCRNN can be used for both homogeneous and heterogeneous images. Integrating the merits of both convolutional neural network (CNN) and RNN, SiamCRNN consists of three subnetworks: a deep siamese convolutional neural network (DSCNN), a multiple-layers RNN (MRNN), and fully connected (FC) layers. The DSCNN has a flexible structure for multisource images and is able to extract spatial-spectral features from homogeneous or heterogeneous VHR image patches. The MRNN, stacked by long short-term memory (LSTM) units, is responsible for mapping the spatial-spectral features extracted by DSCNN into a new latent feature space and mining the change information between them. In addition, the FC layers, the last part of SiamCRNN, are adopted to predict the change probability. Experimental results on two homogeneous datasets and one challenging heterogeneous VHR image dataset demonstrate the promising performance of the proposed network, which outperforms several state-of-the-art approaches.
20 |
21 | ## Network architecture
22 |
23 | ![SiamCRNN architecture](Fig/SiamCRNN.jpg)
24 |
25 |
26 | ## Requirements
27 | ```
28 | tensorflow_gpu==1.9.0
29 | opencv==3.4.0
30 | numpy==1.14.0
31 | ```
32 |
33 | ## Dataset
34 | Two homogeneous datasets, Wuhan and Hanyang, and one heterogeneous dataset, Buffalo, are used in our work. The Wuhan and Hanyang datasets can be downloaded [here](http://sigma.whu.edu.cn/resource.php). For the Buffalo dataset, please request distribution from Prof. [Chen Wu](mailto:chen.wu@whu.edu.cn).
35 |
36 |
37 | | Dataset | Pre-event image | Post-event image | Reference Image |
38 | | :----: | :----: | :----: | :----: |
39 | | Wuhan | ![](Fig/WH_1.png) | ![](Fig/WH_2.png) | ![](Fig/WH_GT.png) |
40 | | Hanyang | ![](Fig/HY_1.png) | ![](Fig/HY_2.png) | ![](Fig/HY_GT.png) |
41 | | Buffalo | ![](Fig/Buffalo_T1.png) | ![](Fig/Buffalo_T2.png) | ![](Fig/Buffalo_GT.bmp) |
42 |
43 |
44 |
45 | ## Citation
46 | If this code or dataset contributes to your research, please consider citing our paper. We appreciate your support!🙂
47 | ```
48 | @article{Chen2020Change,
49 | author = {Chen, Hongruixuan and Wu, Chen and Du, Bo and Zhang, Liangpei and Wang, Le},
50 | issn = {0196-2892},
51 | journal = {IEEE Transactions on Geoscience and Remote Sensing},
52 | number = {4},
53 | pages = {2848--2864},
54 | title = {{Change Detection in Multisource VHR Images via Deep Siamese Convolutional Multiple-Layers Recurrent Neural Network}},
55 | volume = {58},
56 | year = {2020}
57 | }
58 | ```
59 |
60 | ## Q & A
61 | **For any questions, please [contact us](mailto:Qschrx@gmail.com).**
62 |
--------------------------------------------------------------------------------
/FCN_version/deep_networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master.
60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
61 | and passed to a registered callback.
62 | - After receiving the messages, the master device should gather the information and determine to message passed
63 | back to each slave devices.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 | Register a slave device.
85 | Args:
86 | identifier: an identifier, usually is the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 | The messages are first collected from each device (including the master device), and then
101 | a callback will be invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
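# Usage sketch: one hypothetical round trip through SyncMaster with a single
# slave thread; the master callback simply echoes each (id, msg) pair back.
if __name__ == '__main__':
    master = SyncMaster(master_callback=lambda intermediates: intermediates)
    pipe = master.register_slave(identifier=1)
    worker = threading.Thread(target=lambda: print('slave received:', pipe.run_slave('slave msg')))
    worker.start()
    print('master received:', master.run_master('master msg'))  # returns 'master msg'
    worker.join()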
--------------------------------------------------------------------------------
/FCN_version/script/train_siamcrnn.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/songjian/project/HSIFM')
3 | import argparse
4 | import os
5 | import time
6 |
7 | import imageio
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | from dataset.make_data_loader import OSCDDatset3Bands, make_data_loader, OSCDDatset13Bands
15 | from util_func.metrics import Evaluator
16 | from deep_networks.SiamCRNN import SiamCRNN
17 | import util_func.lovasz_loss as L
18 |
19 |
20 | class Trainer(object):
21 | def __init__(self, args):
22 | self.args = args
23 |
24 | self.train_data_loader = make_data_loader(args)
25 | print(args.model_type + ' is running')
26 | self.evaluator = Evaluator(num_class=2)
27 |
28 | self.deep_model = SiamCRNN(in_dim_1=13, in_dim_2=13)
29 | self.deep_model = self.deep_model.cuda()
30 |
31 | self.model_save_path = os.path.join(args.model_param_path, args.dataset,
32 | args.model_type + '_' + str(time.time()))
33 | self.lr = args.learning_rate
34 | self.epoch = args.max_iters // args.batch_size
35 |
36 | if not os.path.exists(self.model_save_path):
37 | os.makedirs(self.model_save_path)
38 |
39 | self.optim = optim.AdamW(self.deep_model.parameters(),
40 | lr=args.learning_rate,
41 | weight_decay=args.weight_decay)
42 |
43 | def training(self):
44 | best_kc = 0.0
45 | best_round = []
46 | torch.cuda.empty_cache()
47 | self.deep_model.train()
48 | class_weight = torch.FloatTensor([1, 10]).cuda()
49 | elem_num = len(self.train_data_loader)
50 | train_enumerator = enumerate(self.train_data_loader)
51 | for _ in tqdm(range(elem_num)):
52 | itera, data = train_enumerator.__next__()
53 | self.optim.zero_grad()
54 |
55 | pre_img, post_img, bcd_labels, _ = data
56 |
57 | pre_img = pre_img.cuda().float()
58 | post_img = post_img.cuda().float()
59 | bcd_labels = bcd_labels.cuda().long()
60 | # input_data = torch.cat([pre_img, post_img], dim=1)
61 |
62 | # bcd_output = self.deep_model(input_data)
63 | bcd_output = self.deep_model(pre_img, post_img)
64 |
65 | bcd_loss = F.cross_entropy(bcd_output, bcd_labels, weight=class_weight, ignore_index=255)
66 | lovasz_loss = L.lovasz_softmax(F.softmax(bcd_output, dim=1), bcd_labels, ignore=255)
67 |
68 | main_loss = bcd_loss + 0.75 * lovasz_loss
69 | main_loss.backward()
70 |
71 | self.optim.step()
72 |
73 | if (itera + 1) % 10 == 0:
74 | print(
75 | f'iter is {itera + 1}, change detection loss is {bcd_loss}')
76 | if (itera + 1) % 200 == 0:
77 | self.deep_model.eval()
78 | rec, pre, oa, f1_score, iou, kc = self.validation()
79 | if kc > best_kc:
80 | torch.save(self.deep_model.state_dict(),
81 | os.path.join(self.model_save_path, f'{itera + 1}_model.pth'))
82 |
83 | best_kc = kc
84 | best_round = [rec, pre, oa, f1_score, iou, kc]
85 | self.deep_model.train()
86 |
87 | print('The accuracy of the best round is ', best_round)
88 |
89 | def validation(self):
90 | print('---------starting evaluation-----------')
91 | self.evaluator.reset()
92 | dataset_path = '/home/songjian/project/HSIFM/dataset/OSCD/original_data'
93 | with open('/home/songjian/project/HSIFM/dataset/OSCD/original_data/test.txt', "r") as f:
94 | # data_name_list = f.read()
95 | data_name_list = [data_name.strip() for data_name in f]
96 | data_name_list = data_name_list
97 | dataset = OSCDDatset13Bands(dataset_path=dataset_path, data_list=data_name_list, crop_size=512,
98 | max_iters=None, type='test')
99 | val_data_loader = DataLoader(dataset, batch_size=1, num_workers=8, drop_last=False)
100 | torch.cuda.empty_cache()
101 |
102 | for itera, data in enumerate(val_data_loader):
103 | pre_img, post_img, bcd_labels, data_name = data
104 |
105 | pre_img = pre_img.cuda().float()
106 | post_img = post_img.cuda().float()
107 | bcd_labels = bcd_labels.cuda().long()
108 | # input_data = torch.cat([pre_img, post_img], dim=1)
109 |
110 | # bcd_output = self.deep_model(input_data)
111 | bcd_output = self.deep_model(pre_img, post_img)
112 | bcd_output = bcd_output.data.cpu().numpy()
113 | bcd_output = np.argmax(bcd_output, axis=1)
114 |
115 | bcd_img = bcd_output[0].copy()
116 | bcd_img[bcd_img == 1] = 255
117 |
118 | # imageio.imwrite('./' + data_name[0] + '.png', bcd_img)
119 |
120 | bcd_labels = bcd_labels.cpu().numpy()
121 | self.evaluator.add_batch(bcd_labels, bcd_output)
122 |
123 | f1_score = self.evaluator.Pixel_F1_score()
124 | oa = self.evaluator.Pixel_Accuracy()
125 | rec = self.evaluator.Pixel_Recall_Rate()
126 | pre = self.evaluator.Pixel_Precision_Rate()
127 | iou = self.evaluator.Intersection_over_Union()
128 | kc = self.evaluator.Kappa_coefficient()
129 | print(f'Recall rate is {rec}, Precision rate is {pre}, OA is {oa}, '
130 | f'F1 score is {f1_score}, IoU is {iou}, Kappa coefficient is {kc}')
131 | return rec, pre, oa, f1_score, iou, kc
132 |
133 |
134 | def main():
135 | parser = argparse.ArgumentParser(description="Training SiamCRNN on the OSCD dataset")
136 | parser.add_argument('--dataset', type=str, default='OSCD_13Bands')
137 | parser.add_argument('--dataset_path', type=str,
138 | default='/home/songjian/project/HSIFM/dataset/OSCD/original_data')
139 | parser.add_argument('--type', type=str, default='train')
140 | parser.add_argument('--train_data_list_path', type=str,
141 | default='/home/songjian/project/HSIFM/dataset/OSCD/original_data/train.txt')
142 | parser.add_argument('--shuffle', type=bool, default=True)
143 | parser.add_argument('--batch_size', type=int, default=16)
144 | parser.add_argument('--data_name_list', type=list)
145 | parser.add_argument('--start_iter', type=int, default=0)
146 | parser.add_argument('--cuda', type=bool, default=True)
147 | parser.add_argument('--crop_size', type=int, default=256)
148 | parser.add_argument('--max_iters', type=int, default=100000)
149 | parser.add_argument('--model_type', type=str, default='SiamCRNN')
150 | parser.add_argument('--model_param_path', type=str, default='../saved_models')
151 |
152 | parser.add_argument('--resume', type=str)
153 | parser.add_argument('--learning_rate', type=float, default=1e-4)
154 | parser.add_argument('--momentum', type=float, default=0.9)
155 | parser.add_argument('--weight_decay', type=float, default=5e-4)
156 |
157 | args = parser.parse_args()
158 | with open(args.train_data_list_path, "r") as f:
159 | # data_name_list = f.read()
160 | data_name_list = [data_name.strip() for data_name in f]
161 | args.data_name_list = data_name_list
162 |
163 | trainer = Trainer(args)
164 | trainer.training()
165 |
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
--------------------------------------------------------------------------------
/FCN_version/dataset/make_data_loader.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import imageio
5 | import numpy as np
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data import Dataset
9 |
10 | import dataset.imutils as imutils
11 |
12 | band_idx = ['B01.tif', 'B02.tif', 'B03.tif', 'B04.tif', 'B05.tif', 'B06.tif', 'B07.tif', 'B08.tif', 'B8A.tif',
13 | 'B09.tif', 'B10.tif', 'B11.tif', 'B12.tif']
14 |
15 |
16 | def img_loader(path):
17 | img = np.array(imageio.imread(path), np.float32)
18 | return img
19 |
20 |
21 | def sentinel_loader(path):
22 | band_0_img = np.array(imageio.imread(os.path.join(path, 'B01.tif')), np.float32)
23 | ms_data = np.zeros((band_0_img.shape[0], band_0_img.shape[1], 13))
24 | for i, band in enumerate(band_idx):
25 | ms_data[:, :, i] = np.array(imageio.imread(os.path.join(path, band)), np.float32)
26 |
27 | return ms_data
28 |
29 |
30 | def one_hot_encoding(image, num_classes=8):
31 | # Create a one hot encoded tensor
32 | one_hot = np.eye(num_classes)[image.astype(np.uint8)]
33 |
34 | # Move the channel axis to the front
35 | # one_hot = np.moveaxis(one_hot, -1, 0)
36 |
37 | return one_hot
38 |
39 |
40 | class OSCDDatset3Bands(Dataset):
41 | def __init__(self, dataset_path, data_list, crop_size, max_iters=None, type='train', data_loader=img_loader):
42 | self.dataset_path = dataset_path
43 | self.data_list = data_list
44 | self.loader = data_loader
45 | self.type = type
46 | self.data_pro_type = self.type
47 |
48 | if max_iters is not None:
49 | self.data_list = self.data_list * int(np.ceil(float(max_iters) / len(self.data_list)))
50 | self.data_list = self.data_list[0:max_iters]
51 | self.crop_size = crop_size
52 |
53 | def __transforms(self, aug, pre_img, post_img, label):
54 | if aug:
55 | pre_img, post_img, label = imutils.random_crop(pre_img, post_img, label, self.crop_size)
56 | pre_img, post_img, label = imutils.random_fliplr(pre_img, post_img, label)
57 | pre_img, post_img, label = imutils.random_flipud(pre_img, post_img, label)
58 | pre_img, post_img, label = imutils.random_rot(pre_img, post_img, label)
59 |
60 | pre_img = imutils.normalize_img(pre_img) # imagenet normalization
61 | pre_img = np.transpose(pre_img, (2, 0, 1))
62 |
63 | post_img = imutils.normalize_img(post_img) # imagenet normalization
64 | post_img = np.transpose(post_img, (2, 0, 1))
65 |
66 | return pre_img, post_img, label
67 |
68 | def __getitem__(self, index):
69 | pre_path = os.path.join(self.dataset_path, self.data_list[index], 'pair', 'img1.png')
70 | post_path = os.path.join(self.dataset_path, self.data_list[index], 'pair', 'img2.png')
71 | label_path = os.path.join(self.dataset_path, self.data_list[index], 'cm', 'cm.png')
72 | pre_img = self.loader(pre_path)
73 | post_img = self.loader(post_path)
74 | label = self.loader(label_path)
75 |
76 | if len(label.shape) > 2:
77 | label = label[:, :, 0]
78 | label = label / 255
79 |
80 | if 'train' in self.data_pro_type:
81 | pre_img, post_img, label = self.__transforms(True, pre_img, post_img, label)
82 | else:
83 | pre_img, post_img, label = self.__transforms(False, pre_img, post_img, label)
84 | label = np.asarray(label)
85 |
86 | data_idx = self.data_list[index]
87 | return pre_img, post_img, label, data_idx
88 |
89 | def __len__(self):
90 | return len(self.data_list)
91 |
92 |
93 | class OSCDDatset13Bands(Dataset):
94 | def __init__(self, dataset_path, data_list, crop_size, max_iters=None, type='train', data_loader=img_loader,
95 | sentinel_loader=sentinel_loader):
96 | self.dataset_path = dataset_path
97 | self.data_list = data_list
98 | self.data_loader = data_loader
99 | self.sentinel_loader = sentinel_loader
100 |
101 | self.type = type
102 | self.data_pro_type = self.type
103 |
104 | if max_iters is not None:
105 | self.data_list = self.data_list * int(np.ceil(float(max_iters) / len(self.data_list)))
106 | self.data_list = self.data_list[0:max_iters]
107 | self.crop_size = crop_size
108 |
109 | def __transforms(self, aug, pre_img, post_img, label):
110 | if aug:
111 | pre_img, post_img, label = imutils.random_crop(pre_img, post_img, label, self.crop_size)
112 | pre_img, post_img, label = imutils.random_fliplr(pre_img, post_img, label)
113 | pre_img, post_img, label = imutils.random_flipud(pre_img, post_img, label)
114 | pre_img, post_img, label = imutils.random_rot(pre_img, post_img, label)
115 |
116 | pre_img = imutils.normalize_img(pre_img) # imagenet normalization
117 | pre_img = np.transpose(pre_img, (2, 0, 1))
118 |
119 | post_img = imutils.normalize_img(post_img) # imagenet normalization
120 | post_img = np.transpose(post_img, (2, 0, 1))
121 |
122 | return pre_img, post_img, label
123 |
124 | def __getitem__(self, index):
125 | pre_path = os.path.join(self.dataset_path, self.data_list[index], 'imgs_1_rect', 'ms_data.npy')
126 | post_path = os.path.join(self.dataset_path, self.data_list[index], 'imgs_2_rect', 'ms_data.npy')
127 | label_path = os.path.join(self.dataset_path, self.data_list[index], 'cm', 'cm.png')
128 | pre_img = np.load(pre_path)
129 | post_img = np.load(post_path)
130 | label = self.data_loader(label_path)
131 |
132 | if len(label.shape) > 2:
133 | label = label[:, :, 0]
134 | label = label / 255
135 |
136 | if 'train' in self.data_pro_type:
137 | pre_img, post_img, label = self.__transforms(True, pre_img, post_img, label)
138 | else:
139 | pre_img, post_img, label = self.__transforms(False, pre_img, post_img, label)
140 | label = np.asarray(label)
141 |
142 | data_idx = self.data_list[index]
143 | return pre_img, post_img, label, data_idx
144 |
145 | def __len__(self):
146 | return len(self.data_list)
147 |
148 |
149 | def make_data_loader(args, **kwargs): # **kwargs could be omitted
150 | if 'OSCD_3Bands' in args.dataset:
151 | dataset = OSCDDatset3Bands(args.dataset_path, args.data_name_list, args.crop_size, args.max_iters, args.type)
152 | # train_sampler = DistributedSampler(dataset, shuffle=True)
153 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle, **kwargs, num_workers=16,
154 | drop_last=False)
155 | return data_loader
156 |
157 | elif 'OSCD_13Bands' in args.dataset:
158 | dataset = OSCDDatset13Bands(args.dataset_path, args.data_name_list, args.crop_size, args.max_iters, args.type)
159 | # train_sampler = DistributedSampler(dataset, shuffle=True)
160 | data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=args.shuffle, **kwargs, num_workers=16,
161 | drop_last=False)
162 | return data_loader
163 |
164 | else:
165 | raise NotImplementedError
166 |
167 |
168 | if __name__ == '__main__':
169 |
170 | parser = argparse.ArgumentParser(description="SECOND DataLoader Test")
171 | parser.add_argument('--dataset', type=str, default='WHUBCD')
172 | parser.add_argument('--max_iters', type=int, default=10000)
173 | parser.add_argument('--type', type=str, default='train')
174 | parser.add_argument('--dataset_path', type=str, default='D:/Workspace/Python/STCD/data/ST-WHU-BCD')
175 | parser.add_argument('--data_list_path', type=str, default='./ST-WHU-BCD/train_list.txt')
176 | parser.add_argument('--shuffle', type=bool, default=True)
177 | parser.add_argument('--batch_size', type=int, default=8)
178 | parser.add_argument('--data_name_list', type=list)
179 |
180 | args = parser.parse_args()
181 |
182 | with open(args.data_list_path, "r") as f:
183 | # data_name_list = f.read()
184 | data_name_list = [data_name.strip() for data_name in f]
185 | args.data_name_list = data_name_list
186 | train_data_loader = make_data_loader(args)
187 | for i, data in enumerate(train_data_loader):
188 | pre_img, post_img, labels, _ = data
189 | pre_data, post_data = Variable(pre_img), Variable(post_img)
190 | labels = Variable(labels)
191 | print(i, "inputs", pre_data.data.size(), "labels", labels.data.size())
192 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 |
6 | import gdal
7 | import numpy as np
8 | import tensorflow as tf
9 |
10 | from SiamCRNN import SiamCRNN
11 |
12 | from util.data_prepro import stad_img
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--max_epoch', type=int, default=300, help='epochs to run [default: 300]')
16 | parser.add_argument('--batch_size', type=int, default=1024, help='batch size during training [default: 1024]')
17 | parser.add_argument('--learning_rate', type=float, default=2e-4, help='initial learning rate [default: 2e-4]')
18 | parser.add_argument('--save_path', default='model_param', help='model param path')
19 | parser.add_argument('--data_path', default=None, help='dataset path')
20 | parser.add_argument('--gpu_num', type=int, default=1, help='number of GPU to train')
21 |
22 | # basic params
23 | FLAGS = parser.parse_args()
24 |
25 | BATCH_SZ = FLAGS.batch_size
26 | LEARNING_RATE = FLAGS.learning_rate
27 | MAX_EPOCH = FLAGS.max_epoch
28 | SAVE_PATH = FLAGS.save_path
29 | DATA_PATH = FLAGS.data_path
30 | GPU_NUM = FLAGS.gpu_num
31 | BATCH_PER_GPU = BATCH_SZ // GPU_NUM
32 |
33 |
34 | class ChangeTrainer(object):
35 |
36 | def __init__(self):
37 | self.Input_X = None
38 | self.Input_Y = None
39 | self.label = None
40 | self.is_training = None
41 |
42 | self.net = None
43 | self.pred = None
44 | self.loss = None
45 | self.opt = None
46 | self.train_op = None
47 | self.global_step = tf.Variable(0, trainable=False)
48 | self.siamcrnn_model = SiamCRNN()
49 |
50 |
51 | def load_data(self):
52 | data_set_X = gdal.Open('data/GF_2_2/0411') # data set X
53 | data_set_Y = gdal.Open('data/GF_2_2/0901') # data set Y
54 |
55 | img_width = data_set_X.RasterXSize # image width
56 | img_height = data_set_X.RasterYSize # image height
57 |
58 | img_X = data_set_X.ReadAsArray(0, 0, img_width, img_height)
59 | img_Y = data_set_Y.ReadAsArray(0, 0, img_width, img_height)
60 |
61 | img_X = stad_img(img_X) # (C, H, W)
62 | img_Y = stad_img(img_Y)
63 | img_X = np.transpose(img_X, [1, 2, 0]) # (H, W, C)
64 | img_Y = np.transpose(img_Y, [1, 2, 0]) # (H, W, C)
65 | return img_X, img_Y
66 |
67 | def training(self):
68 | train_X, train_Y, train_label = self._load_train_data(path=DATA_PATH)
69 | self.valid_X, self.valid_Y, self.valid_label = self._load_valid_data(path=DATA_PATH)
70 | train_label = np.reshape(train_label, (-1, 1))
71 | self.valid_label = np.reshape(self.valid_label, (-1, 1))
72 | self.valid_sz = self.valid_X.shape[0]
73 |
74 | shape_1 = train_X.shape
75 | shape_2 = train_Y.shape
76 | train_sz = train_X.shape[0]
77 | self.Input_X = tf.placeholder(dtype=tf.float32, shape=[None, shape_1[1], shape_1[2], shape_1[3]],
78 | name='Input_X')
79 | self.Input_Y = tf.placeholder(dtype=tf.float32, shape=[None, shape_2[1], shape_2[2], shape_2[3]],
80 | name='Input_Y')
81 | self.label = tf.placeholder(dtype=tf.float32, shape=[None, 1])
82 | self.is_training = tf.placeholder(dtype=tf.bool, name='is_training')
83 |
84 |
85 | self.net, self.pred = self.siamcrnn_model.get_model(Input_X=self.Input_X, Input_Y=self.Input_Y,
86 | data_format='NHWC',
87 | is_training=self.is_training) # logits and sigmoid prediction, each (B, 1)
88 | self.loss = self._get_loss(label=self.label, logits=self.net)
89 | self.opt = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
90 | self.train_op = self.opt.minimize(loss=self.loss)
91 | best_loss = 100000
92 | iter_in_epoch = train_sz // BATCH_SZ
93 | config = tf.ConfigProto()
94 | config.gpu_options.allow_growth = True
95 | saver = tf.train.Saver(max_to_keep=0, var_list=tf.global_variables())
96 | total_time = 0
97 | with tf.Session(config=config) as sess:
98 | sess.run(tf.global_variables_initializer())
99 | epoch_sz = MAX_EPOCH
100 |
101 | for epoch in range(epoch_sz):
102 | tic = time.time()
103 | ave_loss = 0
104 | train_idx = np.arange(0, train_sz)
105 | np.random.shuffle(train_idx)
106 | for _iter in range(iter_in_epoch):
107 | start_idx = _iter * BATCH_SZ
108 | end_idx = (_iter + 1) * BATCH_SZ
109 | batch_train_X = train_X[train_idx[start_idx:end_idx]]
110 | batch_train_Y = train_Y[train_idx[start_idx:end_idx]]
111 | batch_label = train_label[train_idx[start_idx:end_idx]]
112 | loss, _, logits = sess.run(
113 | [self.loss, self.train_op, self.net],
114 | feed_dict={
115 | self.Input_X: batch_train_X,
116 | self.Input_Y: batch_train_Y,
117 | self.label: batch_label,
118 | self.is_training: True
119 | })
120 | ave_loss += loss
121 | ave_loss /= iter_in_epoch
122 | toc = time.time()
123 | total_time += (toc - tic)
124 | # print("epoch %d , loss is %f take %.3f s , min logits is %.3f, min pred is %.3f" % (
125 | # epoch + 1, ave_loss, time.time() - tic, min_logits, min_pred))
126 | val_loss = self.evaluate(sess)
127 |
128 | if (epoch + 1) % 5 == 0:
129 | if val_loss < best_loss:
130 | best_loss = val_loss
131 | _path = saver.save(sess, os.path.join(SAVE_PATH, "best_model.ckpt"))
132 | print("best model is saved")
133 | _path = saver.save(sess, os.path.join(SAVE_PATH, "cha_model_%d.ckpt" % (epoch + 1)))
134 | print("epoch %d, model saved in file: " % (epoch + 1), _path)
135 | # self.evaluate(sess)
136 | _path = saver.save(sess, os.path.join(SAVE_PATH, 'final_model.ckpt'))
137 | print("Model saved in file: ", _path)
138 | print(total_time)
139 |
140 | def evaluate(self, sess):
141 | iter_in_epoch = self.valid_sz // BATCH_SZ
142 | valid_idx = np.arange(0, self.valid_sz)
143 | np.random.shuffle(valid_idx)
144 | ave_loss = 0
145 | for _iter in range(iter_in_epoch):
146 | start_idx = _iter * BATCH_SZ
147 | end_idx = (_iter + 1) * BATCH_SZ
148 | batch_valid_X = self.valid_X[valid_idx[start_idx:end_idx]]
149 | batch_valid_Y = self.valid_Y[valid_idx[start_idx:end_idx]]
150 | batch_label = self.valid_label[valid_idx[start_idx:end_idx]]
151 | loss = sess.run(self.loss, feed_dict={ # (B, 2)
152 | self.Input_X: batch_valid_X,
153 | self.Input_Y: batch_valid_Y,
154 | self.label: batch_label,
155 | self.is_training: False
156 | })
157 | ave_loss += loss
158 | ave_loss /= iter_in_epoch
159 | print("evaluate is done, validation loss is %.3f" % ave_loss)
160 | return ave_loss
161 |
162 | def _get_loss(self, label, logits):
163 | loss = tf.reduce_mean(
164 | tf.nn.weighted_cross_entropy_with_logits(targets=label, logits=logits, pos_weight=5, name='weight_loss'))
165 | return loss
166 |
167 | def _load_train_data(self, path):
168 | with open(os.path.join(path, 'train_sample_X.pickle'), 'rb') as file:
169 | train_X = pickle.load(file)
170 | with open(os.path.join(path, 'train_sample_Y.pickle'), 'rb') as file:
171 | train_Y = pickle.load(file)
172 | with open(os.path.join(path, 'train_label.pickle'), 'rb') as file:
173 | train_label = pickle.load(file)
174 |
175 | return train_X, train_Y, train_label
176 |
177 | def _load_valid_data(self, path):
178 | with open(os.path.join(path, 'valid_sample_X.pickle'), 'rb') as file:
179 | valid_X = pickle.load(file)
180 | with open(os.path.join(path, 'valid_sample_Y.pickle'), 'rb') as file:
181 | valid_Y = pickle.load(file)
182 | with open(os.path.join(path, 'valid_label.pickle'), 'rb') as file:
183 | valid_label = pickle.load(file)
184 |
185 | return valid_X, valid_Y, valid_label
186 |
187 |
188 | if __name__ == '__main__':
189 | trainer = ChangeTrainer()
190 | trainer.training()
191 |
--------------------------------------------------------------------------------
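
For reference, restoring the `best_model.ckpt` checkpoint saved by the trainer above and scoring new patch pairs might look like the sketch below. It assumes the placeholder names defined in `training()` (`Input_X`, `Input_Y`, `is_training`); the output tensor name `pred:0`, the `checkpoint/` directory, and the `new_X` / `new_Y` arrays are hypothetical stand-ins.

```python
import numpy as np
import tensorflow as tf

tf.reset_default_graph()
# Rebuild the graph from the checkpoint's meta file instead of re-running training()
saver = tf.train.import_meta_graph('checkpoint/best_model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'checkpoint/best_model.ckpt')
    graph = tf.get_default_graph()
    input_x = graph.get_tensor_by_name('Input_X:0')
    input_y = graph.get_tensor_by_name('Input_Y:0')
    is_training = graph.get_tensor_by_name('is_training:0')
    pred = graph.get_tensor_by_name('pred:0')  # hypothetical name for the sigmoid output

    # hypothetical test data: 8 pairs of 5x5 patches with 4 bands each
    new_X = np.random.rand(8, 5, 5, 4).astype(np.float32)
    new_Y = np.random.rand(8, 5, 5, 4).astype(np.float32)
    scores = sess.run(pred, feed_dict={input_x: new_X, input_y: new_Y,
                                       is_training: False})
    change_mask = (scores > 0.5).astype(np.uint8)  # threshold the change probability
```
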
/util/net_util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.layers import batch_normalization
3 |
4 | _EPSILON = 1e-7
5 |
6 |
7 | def conv_2d(inputs, kernel_size, output_channel, stride, name, is_reuse=False, padding='SAME', data_format='NHWC',
8 | is_bn=False, is_training=True, activation=None):
9 | """
10 | 2D Conv Layer
11 | :param name: scope name
12 | :param inputs: (B, H, W, C)
13 | :param kernel_size: [kernel_h, kernel_w]
14 | :param output_channel: feature num
15 | :param stride: a list of 2 ints
16 | :param padding: type of padding, str
17 | :param data_format: str, the format of input data
18 | :param is_bn: bool, is batch normalization
19 | :param is_training: bool, is training
20 | :param activation: activation function, such as tf.nn.relu
21 | :return: outputs
22 | """
23 | with tf.variable_scope(name) as scope:
24 | if is_reuse:
25 | scope.reuse_variables()
26 | kernel_h, kernel_w = kernel_size
27 | stride_h, stride_w = stride
28 | if data_format == 'NHWC':
29 | kernel_shape = [kernel_h, kernel_w, inputs.get_shape()[-1].value, output_channel]
30 | else:
31 | kernel_shape = [kernel_h, kernel_w, inputs.get_shape()[1].value, output_channel]
32 | init = tf.keras.initializers.he_normal()
33 |         kernel = tf.get_variable(name='conv_kernel', shape=kernel_shape, initializer=init, dtype=tf.float32)
34 | # kernel = tf.Variable(tf.truncated_normal(kernel_shape, dtype=tf.float32, stddev=0.1))
35 | outputs = tf.nn.conv2d(input=inputs,
36 | filter=kernel,
37 |                                strides=[1, stride_h, stride_w, 1] if data_format == 'NHWC' else [1, 1, stride_h, stride_w], padding=padding,
38 | data_format=data_format)
39 | biases = tf.Variable(tf.constant(0.1, shape=[output_channel], dtype=tf.float32))
40 |         outputs = tf.nn.bias_add(outputs, biases, data_format=data_format)  # broadcast correctly for NHWC and NCHW
41 | if is_bn:
42 | outputs = batch_normalization(outputs, training=is_training)
43 | if activation is not None:
44 | outputs = activation(outputs)
45 | return outputs
46 |
47 |
48 | def conv_2d_transpose(inputs, kernel_size, output_channel, output_shape, stride, name, padding='SAME',
49 | data_format='NHWC', is_bn=False, is_training=True, activation=None):
50 | """
51 | 2D Transpose Conv Layer
52 | :param output_shape:
53 | :param name: scope name
54 | :param inputs: (B, H, W, C)
55 | :param kernel_size: [kernel_h, kernel_w]
56 | :param output_channel: feature num
57 | :param stride: a list of 2 ints
58 | :param padding: type of padding, str
59 | :param data_format: str, the format of input data
60 | :param is_bn: bool, is batch normalization
61 | :param is_training: bool, is training
62 | :param activation: activation function, such as tf.nn.relu
63 | :return: outputs
64 | """
65 | with tf.variable_scope(name) as scope:
66 | kernel_h, kernel_w = kernel_size
67 | stride_h, stride_w = stride
68 | if data_format == 'NHWC':
69 | kernel_shape = [kernel_h, kernel_w, output_channel, inputs.get_shape()[-1].value]
70 | else:
71 | kernel_shape = [kernel_h, kernel_w, output_channel, inputs.get_shape()[1].value]
72 |
73 | init = tf.keras.initializers.he_normal()
74 |         kernel = tf.get_variable(name='conv_kernel', shape=kernel_shape, initializer=init, dtype=tf.float32)
75 |
76 | # calculate output shape
77 | # batch_size, height, width, _ = inputs.get_shape().as_list()
78 | # out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
79 | # out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
80 | # output_shape = [batch_size, out_height, out_width, output_channel]
81 |
82 | # kernel = tf.Variable(tf.truncated_normal(kernel_shape, dtype=tf.float32, stddev=0.1))
83 | outputs = tf.nn.conv2d_transpose(value=inputs,
84 | filter=kernel,
85 | output_shape=output_shape,
86 |                                          strides=[1, stride_h, stride_w, 1] if data_format == 'NHWC' else [1, 1, stride_h, stride_w], padding=padding,
87 | data_format=data_format)
88 | biases = tf.Variable(tf.constant(0.1, shape=[output_channel], dtype=tf.float32))
89 |         outputs = tf.nn.bias_add(outputs, biases, data_format=data_format)  # broadcast correctly for NHWC and NCHW
90 | if is_bn:
91 | outputs = batch_normalization(outputs, training=is_training)
92 | if activation is not None:
93 | outputs = activation(outputs)
94 | return outputs
95 |
96 |
97 | # from slim.convolution2d_transpose
98 | def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
99 | dim_size *= stride_size
100 |
101 | if padding == 'VALID' and dim_size is not None:
102 | dim_size += max(kernel_size - stride_size, 0)
103 | return dim_size
104 |
105 |
106 | def max_pool_2d(inputs, kernel_size, stride, padding='SAME'):
107 | """
108 | 2D Max Pool Layer
109 | :param inputs: (B, H, W, C)
110 | :param kernel_size: [kernel_h, kernel_w]
111 | :param padding: type of padding, str
112 | :return: outputs
113 | """
114 | kernel_h, kernel_w = kernel_size
115 | stride_h, stride_w = stride
116 | outputs = tf.nn.max_pool(inputs,
117 | ksize=[1, kernel_h, kernel_w, 1],
118 | strides=[1, stride_h, stride_w, 1],
119 | padding=padding)
120 | return outputs
121 |
122 |
123 | def avg_pool_2d(inputs, kernel_size, stride, padding='SAME'):
124 | """
125 | 2D Avg Pool Layer
126 | :param inputs: (B, H, W, C)
127 | :param kernel_size: [kernel_h, kernel_w]
128 | :param padding: type of padding, str
129 | :return: outputs
130 | """
131 | kernel_h, kernel_w = kernel_size
132 | stride_h, stride_w = stride
133 | outputs = tf.nn.avg_pool(inputs,
134 | ksize=[1, kernel_h, kernel_w, 1],
135 | strides=[1, stride_h, stride_w, 1],
136 | padding=padding)
137 | return outputs
138 |
139 |
140 | def fully_connected(inputs, num_outputs, is_training=True, is_bn=False, activation=None):
141 | """
142 | Fully connected layer with non-linear operation
143 | :param inputs: (B, N)
144 | :param num_outputs: int
145 | :param is_training: bool
146 | :param is_bn: bool
147 | :param activation: activation function, such as tf.nn.relu
148 | :return: outputs: (B, num_outputs)
149 | """
150 |
151 | num_input_units = inputs.get_shape()[-1].value
152 | weights = tf.Variable(tf.truncated_normal([num_input_units, num_outputs], dtype=tf.float32, stddev=0.1))
153 | outputs = tf.matmul(inputs, weights)
154 | biases = tf.Variable(tf.constant(0.1, shape=[num_outputs], dtype=tf.float32))
155 | outputs = tf.nn.bias_add(outputs, biases)
156 | if is_bn:
157 | outputs = batch_normalization(outputs, training=is_training)
158 | if activation is not None:
159 | outputs = activation(outputs)
160 | return outputs
161 |
162 |
163 | def weight_binary_cross_entropy(target, output, weight=1.0, from_logits=False):
164 | """weight binary crossentropy between an output tensor and a target tensor.
165 |
166 | # Arguments
167 | target: A tensor with the same shape as `output`.
168 |         output: A tensor.
169 |         weight: Coefficient applied to the positive targets (`pos_weight`).
170 |         from_logits: Whether `output` is expected to be a logits tensor.
171 |             By default, we consider that `output` encodes a probability distribution.
172 |
173 | # Returns
174 | A tensor.
175 | """
176 | # Note: tf.nn.sigmoid_cross_entropy_with_logits
177 | # expects logits, Keras expects probabilities.
178 | if not from_logits:
179 | # transform back to logits
180 | _epsilon = _to_tensor(epsilon(), output.dtype.base_dtype)
181 | output = tf.clip_by_value(output, _epsilon, 1 - _epsilon)
182 | output = tf.log(output / (1 - output))
183 |
184 | return tf.nn.weighted_cross_entropy_with_logits(targets=target,
185 | logits=output,
186 | pos_weight=weight)
187 |
188 |
189 | def _to_tensor(x, dtype):
190 | """Convert the input `x` to a tensor of type `dtype`.
191 |
192 | # Arguments
193 | x: An object to be converted (numpy array, list, tensors).
194 | dtype: The destination type.
195 |
196 | # Returns
197 | A tensor.
198 | """
199 | return tf.convert_to_tensor(x, dtype=dtype)
200 |
201 |
202 | def epsilon():
203 | """Returns the value of the fuzz factor used in numeric expressions.
204 |
205 | # Returns
206 | A float.
207 |
208 | # Example
209 | ```python
210 | >>> keras.backend.epsilon()
211 | 1e-07
212 | ```
213 | """
214 | return _EPSILON
215 |
--------------------------------------------------------------------------------
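
As a quick smoke test of these helpers, the sketch below stacks `conv_2d` and `max_pool_2d` on random NHWC data and evaluates `weight_binary_cross_entropy` on a toy per-sample score; the shapes and the pooled sigmoid head are arbitrary choices for illustration, not part of the SiamCRNN pipeline.

```python
import numpy as np
import tensorflow as tf
from util.net_util import conv_2d, max_pool_2d, weight_binary_cross_entropy

tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 16, 16, 4], name='x')
bn_flag = tf.placeholder(tf.bool, name='bn_flag')

h = conv_2d(x, kernel_size=[3, 3], output_channel=32, stride=[1, 1], name='conv1',
            is_bn=True, is_training=bn_flag, activation=tf.nn.relu)
h = max_pool_2d(h, kernel_size=[2, 2], stride=[2, 2])  # (B, 8, 8, 32)

score = tf.reduce_mean(tf.sigmoid(h), axis=[1, 2, 3])  # toy per-sample probability
labels = tf.placeholder(tf.float32, [None])
loss = tf.reduce_mean(weight_binary_cross_entropy(labels, score, weight=5.0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loss, feed_dict={x: np.random.rand(2, 16, 16, 4),
                                    labels: np.array([0., 1.], np.float32),
                                    bn_flag: True}))
```
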
/FCN_version/deep_networks/resnet_18_34.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | from deep_networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 |
8 |
9 | __model_file = {
10 | 18: '/home/songjian/project/HSIFM/pretrained_weight/resnet-18-pytorch.pth',
11 | 34: '/home/songjian/project/HSIFM/pretrained_weight/resnet-34-pytorch.pth',
12 | }
13 |
14 |
15 | class BasicBlock(nn.Module):
16 | expansion = 1
17 |
18 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
19 | super(BasicBlock, self).__init__()
20 | # if norm_layer is None:
21 | # norm_layer = nn.BatchNorm2d
22 | # if groups != 1 or base_width != 64:
23 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64')
24 | # if dilation > 1:
25 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
26 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
27 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
28 | dilation=dilation, padding=dilation, bias=False)
29 | self.bn1 = BatchNorm(planes)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
32 | self.bn2 = BatchNorm(planes)
33 | self.downsample = downsample
34 | self.stride = stride
35 |
36 | def forward(self, x):
37 | identity = x
38 |
39 | out = self.conv1(x)
40 | out = self.bn1(out)
41 | out = self.relu(out)
42 |
43 | out = self.conv2(out)
44 | out = self.bn2(out)
45 |
46 | if self.downsample is not None:
47 | identity = self.downsample(x)
48 |
49 | out += identity
50 | out = self.relu(out)
51 |
52 | return out
53 |
54 |
55 | class ResNet(nn.Module):
56 | def __init__(self, input_dim, block, layers, output_stride, BatchNorm):
57 | self.inplanes = 64
58 | super(ResNet, self).__init__()
59 | blocks = [1, 2, 4]
60 | if output_stride == 16:
61 | strides = [1, 2, 2, 1]
62 | dilations = [1, 1, 1, 2]
63 | elif output_stride == 8:
64 | strides = [1, 2, 1, 1]
65 | dilations = [1, 1, 2, 4]
66 | else:
67 | raise NotImplementedError
68 |
69 | # Modules
70 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3,
71 | bias=False)
72 | self.bn1 = BatchNorm(64)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
75 |
76 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0],
77 | BatchNorm=BatchNorm)
78 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1],
79 | BatchNorm=BatchNorm)
80 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2],
81 | BatchNorm=BatchNorm)
82 | # self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3],
83 | # BatchNorm=BatchNorm)
84 | self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3],
85 | BatchNorm=BatchNorm)
86 |
87 | self._init_weight()
88 |
89 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
90 | downsample = None
91 | if stride != 1 or self.inplanes != planes * block.expansion:
92 | downsample = nn.Sequential(
93 | nn.Conv2d(self.inplanes, planes * block.expansion,
94 | kernel_size=1, stride=stride, bias=False),
95 | BatchNorm(planes * block.expansion),
96 | )
97 |
98 | layers = []
99 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
100 | self.inplanes = planes * block.expansion
101 | for i in range(1, blocks):
102 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
103 |
104 | return nn.Sequential(*layers)
105 |
106 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
107 | downsample = None
108 | if stride != 1 or self.inplanes != planes * block.expansion:
109 | downsample = nn.Sequential(
110 | nn.Conv2d(self.inplanes, planes * block.expansion,
111 | kernel_size=1, stride=stride, bias=False),
112 | BatchNorm(planes * block.expansion),
113 | )
114 |
115 | layers = []
116 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation,
117 | downsample=downsample, BatchNorm=BatchNorm))
118 | self.inplanes = planes * block.expansion
119 | for i in range(1, len(blocks)):
120 | layers.append(block(self.inplanes, planes, stride=1,
121 | dilation=blocks[i] * dilation, BatchNorm=BatchNorm))
122 |
123 | return nn.Sequential(*layers)
124 |
125 | def forward(self, input):
126 | x = self.conv1(input)
127 | x = self.bn1(x)
128 | x = self.relu(x)
129 | # low_level_feat_0 = x
130 | x = self.maxpool(x)
131 |
132 | x = self.layer1(x)
133 | low_level_feat_1 = x
134 | x = self.layer2(x)
135 | low_level_feat_2 = x
136 | x = self.layer3(x)
137 | # x = self.dropout_3(x)
138 | low_level_feat_3 = x
139 | x = self.layer4(x)
140 | # x = self.dropout_4(x)
141 | return low_level_feat_1, low_level_feat_2, low_level_feat_3, x
142 |
143 | def _init_weight(self):
144 | for m in self.modules():
145 | if isinstance(m, nn.Conv2d):
146 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
147 | m.weight.data.normal_(0, math.sqrt(2. / n))
148 | elif isinstance(m, SynchronizedBatchNorm2d):
149 | m.weight.data.fill_(1)
150 | m.bias.data.zero_()
151 | elif isinstance(m, nn.BatchNorm2d):
152 | m.weight.data.fill_(1)
153 | m.bias.data.zero_()
154 |
155 | def _load_pretrained_model(self):
156 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
157 | model_dict = {}
158 | state_dict = self.state_dict()
159 | for k, v in pretrain_dict.items():
160 | if k in state_dict:
161 | model_dict[k] = v
162 | state_dict.update(model_dict)
163 | self.load_state_dict(state_dict)
164 |
165 |
166 | def ResNet18(input_dim, output_stride, BatchNorm, pretrained=False):
167 | """Constructs a ResNet-18 model.
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | """
171 | model = ResNet(input_dim, BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm)
172 | if pretrained:
173 | pretrain_dict = torch.load(__model_file[18])
174 | model_dict = {}
175 | state_dict = model.state_dict()
176 | for k, v in pretrain_dict.items():
177 | if k in state_dict:
178 | model_dict[k] = v
179 | state_dict.update(model_dict)
180 | model.load_state_dict(state_dict)
181 |         print('pretrained model has been loaded')
182 |
183 | return model
184 |
185 |
186 | def ResNet34(input_dim, output_stride, BatchNorm, pretrained=False):
187 | """Constructs a ResNet-18 model.
188 | Args:
189 | pretrained (bool): If True, returns a model pre-trained on ImageNet
190 | """
191 | model = ResNet(input_dim, BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm)
192 | if pretrained:
193 | pretrain_dict = torch.load(__model_file[34])
194 | model_dict = {}
195 | state_dict = model.state_dict()
196 | for k, v in pretrain_dict.items():
197 | if k in state_dict:
198 | model_dict[k] = v
199 | state_dict.update(model_dict)
200 | model.load_state_dict(state_dict)
201 |         print('pretrained model has been loaded')
202 |
203 | return model
204 |
205 |
206 | #
207 | # if __name__ == "__main__":
208 | # import torch
209 | #
210 | # model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
211 | # input = torch.rand(1, 3, 1024, 1024)
212 | # output, low_level_feat = model(input)
213 | # print(output.size())
214 | # print(low_level_feat.size())
215 |
--------------------------------------------------------------------------------
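
To see the multi-scale features this backbone returns (and that the FCN-version SiamCRNN consumes), a minimal sketch with random weights and an illustrative 256x256 input:

```python
import torch
import torch.nn as nn
from deep_networks.resnet_18_34 import ResNet18

model = ResNet18(input_dim=3, output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
model.eval()
with torch.no_grad():
    f1, f2, f3, f4 = model(torch.rand(1, 3, 256, 256))
for name, f in zip(['layer1', 'layer2', 'layer3', 'layer4'], [f1, f2, f3, f4]):
    print(name, tuple(f.shape))
# layer1 (1, 64, 64, 64)   stride 4
# layer2 (1, 128, 32, 32)  stride 8
# layer3 (1, 256, 16, 16)  stride 16
# layer4 (1, 512, 16, 16)  stride 16 (with output_stride=16, layer4 is dilated, not downsampled)
```
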
/FCN_version/util_func/lovasz_loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | """
5 |
6 | from __future__ import print_function, division
7 |
8 | import torch
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import numpy as np
12 |
13 | try:
14 | from itertools import ifilterfalse
15 | except ImportError: # py3k
16 | from itertools import filterfalse as ifilterfalse
17 |
18 |
19 | def lovasz_grad(gt_sorted):
20 | """
21 | Computes gradient of the Lovasz extension w.r.t sorted errors
22 | See Alg. 1 in paper
23 | """
24 | p = len(gt_sorted)
25 | gts = gt_sorted.sum()
26 | intersection = gts - gt_sorted.float().cumsum(0)
27 | union = gts + (1 - gt_sorted).float().cumsum(0)
28 | jaccard = 1. - intersection / union
29 | if p > 1: # cover 1-pixel case
30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
31 | return jaccard
32 |
33 |
34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
35 | """
36 | IoU for foreground class
37 | binary: 1 foreground, 0 background
38 | """
39 | if not per_image:
40 | preds, labels = (preds,), (labels,)
41 | ious = []
42 | for pred, label in zip(preds, labels):
43 | intersection = ((label == 1) & (pred == 1)).sum()
44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
45 | if not union:
46 | iou = EMPTY
47 | else:
48 | iou = float(intersection) / float(union)
49 | ious.append(iou)
50 |     iou = mean(ious)  # mean across images if per_image
51 | return 100 * iou
52 |
53 |
54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
55 | """
56 | Array of IoU for each (non ignored) class
57 | """
58 | if not per_image:
59 | preds, labels = (preds,), (labels,)
60 | ious = []
61 | for pred, label in zip(preds, labels):
62 | iou = []
63 | for i in range(C):
64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
65 | intersection = ((label == i) & (pred == i)).sum()
66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
67 | if not union:
68 | iou.append(EMPTY)
69 | else:
70 | iou.append(float(intersection) / float(union))
71 | ious.append(iou)
72 |     ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
73 | return 100 * np.array(ious)
74 |
75 |
76 | # --------------------------- BINARY LOSSES ---------------------------
77 |
78 |
79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
80 | """
81 | Binary Lovasz hinge loss
82 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
83 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
84 | per_image: compute the loss per image instead of per batch
85 | ignore: void class id
86 | """
87 | if per_image:
88 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
89 | for log, lab in zip(logits, labels))
90 | else:
91 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
92 | return loss
93 |
94 |
95 | def lovasz_hinge_flat(logits, labels):
96 | """
97 | Binary Lovasz hinge loss
98 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
99 | labels: [P] Tensor, binary ground truth labels (0 or 1)
100 | ignore: label to ignore
101 | """
102 | if len(labels) == 0:
103 | # only void pixels, the gradients should be 0
104 | return logits.sum() * 0.
105 | signs = 2. * labels.float() - 1.
106 | errors = (1. - logits * Variable(signs))
107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
108 | perm = perm.data
109 | gt_sorted = labels[perm]
110 | grad = lovasz_grad(gt_sorted)
111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
112 | return loss
113 |
114 |
115 | def flatten_binary_scores(scores, labels, ignore=None):
116 | """
117 | Flattens predictions in the batch (binary case)
118 | Remove labels equal to 'ignore'
119 | """
120 | scores = scores.view(-1)
121 | labels = labels.view(-1)
122 | if ignore is None:
123 | return scores, labels
124 | valid = (labels != ignore)
125 | vscores = scores[valid]
126 | vlabels = labels[valid]
127 | return vscores, vlabels
128 |
129 |
130 | class StableBCELoss(torch.nn.modules.Module):
131 | def __init__(self):
132 | super(StableBCELoss, self).__init__()
133 |
134 | def forward(self, input, target):
135 | neg_abs = - input.abs()
136 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
137 | return loss.mean()
138 |
139 |
140 | def binary_xloss(logits, labels, ignore=None):
141 | """
142 | Binary Cross entropy loss
143 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
144 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
145 | ignore: void class id
146 | """
147 | logits, labels = flatten_binary_scores(logits, labels, ignore)
148 | loss = StableBCELoss()(logits, Variable(labels.float()))
149 | return loss
150 |
151 |
152 | # --------------------------- MULTICLASS LOSSES ---------------------------
153 |
154 |
155 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=255):
156 | """
157 | Multi-class Lovasz-Softmax loss
158 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
159 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
160 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
161 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
162 | per_image: compute the loss per image instead of per batch
163 | ignore: void class labels
164 | """
165 | if per_image:
166 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
167 | for prob, lab in zip(probas, labels))
168 | else:
169 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
170 | return loss
171 |
172 |
173 | def lovasz_softmax_flat(probas, labels, classes='present'):
174 | """
175 | Multi-class Lovasz-Softmax loss
176 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
177 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
178 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
179 | """
180 | if probas.numel() == 0:
181 | # only void pixels, the gradients should be 0
182 |         return 0.  # probas has shape torch.Size([0, C]); an empty tensor loss cannot be backpropagated, so return a plain float
183 |     if len(probas.shape) == 1:
184 |         probas = probas.unsqueeze(0)  # a single pixel matched the classes of interest: reshape torch.Size([C]) to torch.Size([1, C])
185 | C = probas.size(1)
186 | losses = []
187 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
188 | for c in class_to_sum:
189 | fg = (labels == c).float() # foreground for class c
190 |         if classes == 'present' and fg.sum() == 0:
191 | continue
192 | if C == 1:
193 | if len(classes) > 1:
194 | raise ValueError('Sigmoid output possible only with 1 class')
195 | class_pred = probas[:, 0]
196 | else:
197 | class_pred = probas[:, c]
198 | errors = (Variable(fg) - class_pred).abs()
199 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
200 | perm = perm.data
201 | fg_sorted = fg[perm]
202 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
203 | return mean(losses)
204 |
205 |
206 | def flatten_probas(probas, labels, ignore=None):
207 | """
208 | Flattens predictions in the batch
209 | """
210 | if probas.dim() == 3:
211 | # assumes output of a sigmoid layer
212 | B, H, W = probas.size()
213 | probas = probas.view(B, 1, H, W)
214 | B, C, H, W = probas.size()
215 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
216 | labels = labels.view(-1)
217 | if ignore is None:
218 | return probas, labels
219 | valid = (labels != ignore)
220 | vprobas = probas[valid.nonzero().squeeze()]
221 | vlabels = labels[valid]
222 | return vprobas, vlabels
223 |
224 |
225 | def xloss(logits, labels, ignore=None):
226 | """
227 | Cross entropy loss
228 | """
229 | return F.cross_entropy(logits, Variable(labels), ignore_index=255)
230 |
231 |
232 | # --------------------------- HELPER FUNCTIONS ---------------------------
233 | def isnan(x):
234 | return x != x
235 |
236 |
237 | def mean(l, ignore_nan=False, empty=0):
238 | """
239 | nanmean compatible with generators.
240 | """
241 | l = iter(l)
242 | if ignore_nan:
243 | l = ifilterfalse(isnan, l)
244 | try:
245 | n = 1
246 | acc = next(l)
247 | except StopIteration:
248 | if empty == 'raise':
249 | raise ValueError('Empty mean')
250 | return empty
251 | for n, v in enumerate(l, 2):
252 | acc += v
253 | if n == 1:
254 | return acc
255 | return acc / n
--------------------------------------------------------------------------------
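
A toy invocation of the multi-class loss for binary change detection (C = 2). `lovasz_softmax` expects probabilities, so the network logits are passed through a softmax first; the random tensors are placeholders for real predictions and labels.

```python
import torch
import torch.nn.functional as F
from util_func.lovasz_loss import lovasz_softmax

logits = torch.randn(2, 2, 8, 8, requires_grad=True)  # (B, C, H, W) network output
labels = torch.randint(0, 2, (2, 8, 8))               # (B, H, W), ground truth in {0, 1}

probas = F.softmax(logits, dim=1)                     # convert logits to probabilities
loss = lovasz_softmax(probas, labels, classes='present', ignore=255)
loss.backward()                                       # gradients flow back through the softmax
print(loss.item())
```
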
/FCN_version/deep_networks/SiamCRNN.py:
--------------------------------------------------------------------------------
1 | from deep_networks.resnet_18_34 import ResNet34, ResNet18
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 |
8 |
9 | class SiamCRNN(nn.Module):
10 | def __init__(self, in_dim_1, in_dim_2, pretrained=True, output_stride=16, BatchNorm=nn.BatchNorm2d, bias=True):
11 | super(SiamCRNN, self).__init__()
12 | self.encoder_1 = ResNet34(input_dim=in_dim_1, BatchNorm=BatchNorm, pretrained=False, output_stride=output_stride)
13 | # self.encoder_1.conv1 = nn.Conv2d(in_dim_1, 64, kernel_size=7, stride=2, padding=3, bias=False)
14 |         # If your dataset is heterogeneous, please use a pseudo-siamese architecture with a separate second encoder:
15 | # self.encoder_2 = ResNet18(input_dim=in_dim_2, BatchNorm=BatchNorm, pretrained=True, output_stride=output_stride)
16 | self.convlstm_4 = ConvLSTM(input_dim=512, hidden_dim=128, kernel_size=(3, 3), num_layers=1,
17 | batch_first=True)
18 | self.convlstm_3 = ConvLSTM(input_dim=256, hidden_dim=128, kernel_size=(3, 3), num_layers=1,
19 | batch_first=True)
20 | self.convlstm_2 = ConvLSTM(input_dim=128, hidden_dim=128, kernel_size=(3, 3), num_layers=1,
21 | batch_first=True)
22 | self.convlstm_1 = ConvLSTM(input_dim=64, hidden_dim=128, kernel_size=(3, 3), num_layers=1,
23 | batch_first=True)
24 | self.smooth_layer_3 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
25 | nn.BatchNorm2d(128), nn.ReLU())
26 | self.smooth_layer_2 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
27 | nn.BatchNorm2d(128), nn.ReLU())
28 | self.smooth_layer_1 = nn.Sequential(nn.Conv2d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
29 | nn.BatchNorm2d(128), nn.ReLU())
30 |
31 | self.main_clf_1 = nn.Conv2d(in_channels=128, out_channels=2, kernel_size=1)
32 |
33 | def _upsample_add(self, x, y):
34 | _, _, H, W = y.size()
35 | return F.interpolate(x, size=(H, W), mode='bilinear') + y
36 |
37 | def forward(self, pre_data, post_data):
38 | pre_low_level_feat_1, pre_low_level_feat_2, pre_low_level_feat_3, pre_output = \
39 | self.encoder_1(pre_data)
40 | post_low_level_feat_1, post_low_level_feat_2, post_low_level_feat_3, post_output = \
41 | self.encoder_1(post_data)
42 |
43 | # Concatenate along the time dimension
44 | combined_4 = torch.stack([pre_output, post_output], dim=1)
45 | # Apply ConvLSTM
46 | _, last_state_list_4 = self.convlstm_4(combined_4)
47 | p4 = last_state_list_4[0][0]
48 |
49 | combined_3 = torch.stack([pre_low_level_feat_3, post_low_level_feat_3], dim=1)
50 | # Apply ConvLSTM
51 | _, last_state_list_3 = self.convlstm_3(combined_3)
52 | p3 = last_state_list_3[0][0]
53 | p3 = self._upsample_add(p4, p3)
54 | p3 = self.smooth_layer_3(p3)
55 |
56 | combined_2 = torch.stack([pre_low_level_feat_2, post_low_level_feat_2], dim=1)
57 | # Apply ConvLSTM
58 | _, last_state_list_2 = self.convlstm_2(combined_2)
59 | p2 = last_state_list_2[0][0]
60 | p2 = self._upsample_add(p3, p2)
61 | p2 = self.smooth_layer_2(p2)
62 |
63 | combined_1 = torch.stack([pre_low_level_feat_1, post_low_level_feat_1], dim=1)
64 | # Apply ConvLSTM
65 | _, last_state_list_1 = self.convlstm_1(combined_1)
66 | p1 = last_state_list_1[0][0]
67 | p1 = self._upsample_add(p2, p1)
68 | p1 = self.smooth_layer_1(p1)
69 |
70 | output_1 = self.main_clf_1(p1)
71 | output_1 = F.interpolate(output_1, size=pre_data.size()[-2:], mode='bilinear')
72 | return output_1
73 |
74 |
75 | class ConvLSTMCell(nn.Module):
76 |
77 | def __init__(self, input_dim, hidden_dim, kernel_size, bias):
78 | """
79 | Initialize ConvLSTM cell.
80 |
81 | Parameters
82 | ----------
83 | input_dim: int
84 | Number of channels of input tensor.
85 | hidden_dim: int
86 | Number of channels of hidden state.
87 | kernel_size: (int, int)
88 | Size of the convolutional kernel.
89 | bias: bool
90 | Whether or not to add the bias.
91 | """
92 |
93 | super(ConvLSTMCell, self).__init__()
94 |
95 | self.input_dim = input_dim
96 | self.hidden_dim = hidden_dim
97 |
98 | self.kernel_size = kernel_size
99 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
100 | self.bias = bias
101 |
102 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
103 | out_channels=4 * self.hidden_dim,
104 | kernel_size=self.kernel_size,
105 | padding=self.padding,
106 | bias=self.bias)
107 |
108 | def forward(self, input_tensor, cur_state):
109 | h_cur, c_cur = cur_state
110 |
111 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
112 |
113 | combined_conv = self.conv(combined)
114 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
115 | i = torch.sigmoid(cc_i)
116 | f = torch.sigmoid(cc_f)
117 | o = torch.sigmoid(cc_o)
118 | g = torch.tanh(cc_g)
119 |
120 | c_next = f * c_cur + i * g
121 | h_next = o * torch.tanh(c_next)
122 |
123 | return h_next, c_next
124 |
125 | def init_hidden(self, batch_size, image_size):
126 | height, width = image_size
127 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
128 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
129 |
130 |
131 | class ConvLSTM(nn.Module):
132 | """
133 |
134 | Parameters:
135 | input_dim: Number of channels in input
136 | hidden_dim: Number of hidden channels
137 | kernel_size: Size of kernel in convolutions
138 | num_layers: Number of LSTM layers stacked on each other
139 |         batch_first: Whether dimension 0 is the batch dimension
140 | bias: Bias or no bias in Convolution
141 | return_all_layers: Return the list of computations for all layers
142 | Note: Will do same padding.
143 |
144 | Input:
145 | A tensor of size B, T, C, H, W or T, B, C, H, W
146 | Output:
147 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
148 | 0 - layer_output_list is the list of lists of length T of each output
149 | 1 - last_state_list is the list of last states
150 | each element of the list is a tuple (h, c) for hidden state and memory
151 | Example:
152 | >> x = torch.rand((32, 10, 64, 128, 128))
153 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
154 | >> _, last_states = convlstm(x)
155 | >> h = last_states[0][0] # 0 for layer index, 0 for h index
156 | """
157 |
158 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
159 | batch_first=False, bias=True, return_all_layers=False):
160 | super(ConvLSTM, self).__init__()
161 |
162 | self._check_kernel_size_consistency(kernel_size)
163 |
164 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
165 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
166 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
167 | if not len(kernel_size) == len(hidden_dim) == num_layers:
168 | raise ValueError('Inconsistent list length.')
169 |
170 | self.input_dim = input_dim
171 | self.hidden_dim = hidden_dim
172 | self.kernel_size = kernel_size
173 | self.num_layers = num_layers
174 | self.batch_first = batch_first
175 | self.bias = bias
176 | self.return_all_layers = return_all_layers
177 |
178 | cell_list = []
179 | for i in range(0, self.num_layers):
180 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
181 |
182 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
183 | hidden_dim=self.hidden_dim[i],
184 | kernel_size=self.kernel_size[i],
185 | bias=self.bias))
186 |
187 | self.cell_list = nn.ModuleList(cell_list)
188 |
189 | def forward(self, input_tensor, hidden_state=None):
190 | """
191 |
192 | Parameters
193 | ----------
194 | input_tensor: todo
195 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
196 | hidden_state: todo
197 | None. todo implement stateful
198 |
199 | Returns
200 | -------
201 | last_state_list, layer_output
202 | """
203 | if not self.batch_first:
204 | # (t, b, c, h, w) -> (b, t, c, h, w)
205 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
206 |
207 | b, _, _, h, w = input_tensor.size()
208 |
209 | # Implement stateful ConvLSTM
210 | if hidden_state is not None:
211 | raise NotImplementedError()
212 | else:
213 |             # init is done in forward(), so the image size can be passed here
214 | hidden_state = self._init_hidden(batch_size=b,
215 | image_size=(h, w))
216 |
217 | layer_output_list = []
218 | last_state_list = []
219 |
220 | seq_len = input_tensor.size(1)
221 | cur_layer_input = input_tensor
222 |
223 | for layer_idx in range(self.num_layers):
224 |
225 | h, c = hidden_state[layer_idx]
226 | output_inner = []
227 | for t in range(seq_len):
228 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
229 | cur_state=[h, c])
230 | output_inner.append(h)
231 |
232 | layer_output = torch.stack(output_inner, dim=1)
233 | cur_layer_input = layer_output
234 |
235 | layer_output_list.append(layer_output)
236 | last_state_list.append([h, c])
237 |
238 | if not self.return_all_layers:
239 | layer_output_list = layer_output_list[-1:]
240 | last_state_list = last_state_list[-1:]
241 |
242 | return layer_output_list, last_state_list
243 |
244 | def _init_hidden(self, batch_size, image_size):
245 | init_states = []
246 | for i in range(self.num_layers):
247 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
248 | return init_states
249 |
250 | @staticmethod
251 | def _check_kernel_size_consistency(kernel_size):
252 | if not (isinstance(kernel_size, tuple) or
253 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
254 | raise ValueError('`kernel_size` must be tuple or list of tuples')
255 |
256 | @staticmethod
257 | def _extend_for_multilayer(param, num_layers):
258 | if not isinstance(param, list):
259 | param = [param] * num_layers
260 | return param
261 |
--------------------------------------------------------------------------------
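
A forward-pass smoke test for this FCN variant on bi-temporal RGB patches (batch and patch sizes are illustrative; note that `__init__` builds the encoder with `pretrained=False`, so no weight file is needed):

```python
import torch
from deep_networks.SiamCRNN import SiamCRNN

model = SiamCRNN(in_dim_1=3, in_dim_2=3)
model.eval()

pre = torch.rand(2, 3, 256, 256)    # t1 image batch
post = torch.rand(2, 3, 256, 256)   # t2 image batch
with torch.no_grad():
    logits = model(pre, post)       # (2, 2, 256, 256): per-pixel change / no-change logits
change_map = logits.argmax(dim=1)   # (2, 256, 256) binary change map
print(logits.shape, change_map.shape)
```
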
/FCN_version/deep_networks/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 | """sum over the first and last dimention"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 | """add new dementions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 | .. math::
132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 | standard-deviation are reduced across all devices during training.
135 | For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 | is also easy to implement, but the statistics might be inaccurate.
139 | Instead, in this synchronized version, the statistics will be computed
140 | over all training samples distributed on multiple devices.
141 |
142 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 | as the built-in PyTorch implementation.
144 | The mean and standard-deviation are calculated per-dimension over
145 | the mini-batches and gamma and beta are learnable parameter vectors
146 | of size C (where C is the input size).
147 | During training, this layer keeps a running estimate of its computed mean
148 | and variance. The running sum is kept with a default momentum of 0.1.
149 | During evaluation, this running mean/variance is used for normalization.
150 | Because the BatchNorm is done over the `C` dimension, computing statistics
151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152 | Args:
153 | num_features: num_features from an expected input of size
154 | `batch_size x num_features [x width]`
155 | eps: a value added to the denominator for numerical stability.
156 | Default: 1e-5
157 | momentum: the value used for the running_mean and running_var
158 | computation. Default: 0.1
159 | affine: a boolean value that when set to ``True``, gives the layer learnable
160 | affine parameters. Default: ``True``
161 | Shape:
162 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 | Examples:
165 | >>> # With Learnable Parameters
166 | >>> m = SynchronizedBatchNorm1d(100)
167 | >>> # Without Learnable Parameters
168 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 | >>> output = m(input)
171 | """
172 |
173 | def _check_input_dim(self, input):
174 | if input.dim() != 2 and input.dim() != 3:
175 | raise ValueError('expected 2D or 3D input (got {}D input)'
176 | .format(input.dim()))
177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 |
179 |
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 | of 3d inputs
183 | .. math::
184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 | standard-deviation are reduced across all devices during training.
187 | For example, when one uses `nn.DataParallel` to wrap the network during
188 |     training, PyTorch's implementation normalizes the tensor on each device using
189 |     the statistics only on that device, which accelerates the computation and
190 | is also easy to implement, but the statistics might be inaccurate.
191 | Instead, in this synchronized version, the statistics will be computed
192 | over all training samples distributed on multiple devices.
193 |
194 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 | as the built-in PyTorch implementation.
196 | The mean and standard-deviation are calculated per-dimension over
197 | the mini-batches and gamma and beta are learnable parameter vectors
198 | of size C (where C is the input size).
199 | During training, this layer keeps a running estimate of its computed mean
200 | and variance. The running sum is kept with a default momentum of 0.1.
201 | During evaluation, this running mean/variance is used for normalization.
202 | Because the BatchNorm is done over the `C` dimension, computing statistics
203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
204 | Args:
205 | num_features: num_features from an expected input of
206 | size batch_size x num_features x height x width
207 | eps: a value added to the denominator for numerical stability.
208 | Default: 1e-5
209 | momentum: the value used for the running_mean and running_var
210 | computation. Default: 0.1
211 | affine: a boolean value that when set to ``True``, gives the layer learnable
212 | affine parameters. Default: ``True``
213 | Shape:
214 | - Input: :math:`(N, C, H, W)`
215 | - Output: :math:`(N, C, H, W)` (same shape as input)
216 | Examples:
217 | >>> # With Learnable Parameters
218 | >>> m = SynchronizedBatchNorm2d(100)
219 | >>> # Without Learnable Parameters
220 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 | >>> output = m(input)
223 | """
224 |
225 | def _check_input_dim(self, input):
226 | if input.dim() != 4:
227 | raise ValueError('expected 4D input (got {}D input)'
228 | .format(input.dim()))
229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 |
231 |
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 | of 4d inputs
235 | .. math::
236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 | standard-deviation are reduced across all devices during training.
239 | For example, when one uses `nn.DataParallel` to wrap the network during
240 |     training, PyTorch's implementation normalizes the tensor on each device using
241 |     the statistics only on that device, which accelerates the computation and
242 | is also easy to implement, but the statistics might be inaccurate.
243 | Instead, in this synchronized version, the statistics will be computed
244 | over all training samples distributed on multiple devices.
245 |
246 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 | as the built-in PyTorch implementation.
248 | The mean and standard-deviation are calculated per-dimension over
249 | the mini-batches and gamma and beta are learnable parameter vectors
250 | of size C (where C is the input size).
251 | During training, this layer keeps a running estimate of its computed mean
252 | and variance. The running sum is kept with a default momentum of 0.1.
253 | During evaluation, this running mean/variance is used for normalization.
254 | Because the BatchNorm is done over the `C` dimension, computing statistics
255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 | or Spatio-temporal BatchNorm
257 | Args:
258 | num_features: num_features from an expected input of
259 | size batch_size x num_features x depth x height x width
260 | eps: a value added to the denominator for numerical stability.
261 | Default: 1e-5
262 | momentum: the value used for the running_mean and running_var
263 | computation. Default: 0.1
264 | affine: a boolean value that when set to ``True``, gives the layer learnable
265 | affine parameters. Default: ``True``
266 | Shape:
267 | - Input: :math:`(N, C, D, H, W)`
268 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 | Examples:
270 | >>> # With Learnable Parameters
271 | >>> m = SynchronizedBatchNorm3d(100)
272 | >>> # Without Learnable Parameters
273 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 | >>> output = m(input)
276 | """
277 |
278 | def _check_input_dim(self, input):
279 | if input.dim() != 5:
280 | raise ValueError('expected 5D input (got {}D input)'
281 | .format(input.dim()))
282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
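
When the FCN version is trained on several GPUs with `nn.DataParallel`, `SynchronizedBatchNorm2d` can be passed as the encoder's `BatchNorm` argument (the decoder's smoothing layers above hard-code `nn.BatchNorm2d`). A sketch, assuming the bundled `replicate.py` matches the upstream Synchronized-BatchNorm-PyTorch API:

```python
import torch
import torch.nn as nn
from deep_networks.SiamCRNN import SiamCRNN
from deep_networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from deep_networks.sync_batchnorm.replicate import patch_replication_callback

model = SiamCRNN(in_dim_1=3, in_dim_2=3, BatchNorm=SynchronizedBatchNorm2d)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model.cuda())
    # monkey-patch the replicate step so the master/slave message pipes are
    # wired on every replica and batch statistics are reduced across devices
    patch_replication_callback(model)
elif torch.cuda.is_available():
    model = model.cuda()
```
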
/FCN_version/dataset/imutils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from PIL import Image
4 | # from scipy import misc
5 | import torch
6 | import torchvision
7 | from PIL import ImageEnhance
8 |
9 |
10 | def normalize_img(img, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]):
11 | imgarr = np.asarray(img)
12 |     proc_img = (imgarr - mean[0]) / std[0]  # note: only the first channel's mean/std are applied to every band
13 | return proc_img
14 |
15 |
16 | def random_noise(pre_img, post_im):
17 | np.random.seed()
18 | noise_map = np.random.random(pre_img.size)
19 |     return pre_img  # placeholder: noise_map is computed but not yet applied
20 |
21 |
22 | def random_scaling(pre_img, post_img, loc_label, dam_label, size_range, scale_range):
23 | h, w, = dam_label.shape
24 |
25 | min_ratio, max_ratio = scale_range
26 | assert min_ratio <= max_ratio
27 |
28 | ratio = random.uniform(min_ratio, max_ratio)
29 |
30 | new_scale = int(size_range[0] * ratio), int(size_range[1] * ratio)
31 |
32 | max_long_edge = max(new_scale)
33 | max_short_edge = min(new_scale)
34 | scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
35 |
36 | return _img_rescaling(pre_img, post_img, loc_label, dam_label, scale=ratio)
37 |
38 |
39 | def _img_rescaling(pre_img, post_img, loc_label, dam_label, scale=None):
40 | # scale = random.uniform(scales)
41 | h, w, = dam_label.shape
42 |
43 | new_scale = [int(scale * w), int(scale * h)]
44 |
45 | new_pre_img = Image.fromarray(pre_img.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR)
46 | new_pre_img = np.asarray(new_pre_img).astype(np.float32)
47 |
48 | new_post_img = Image.fromarray(post_img.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR)
49 | new_post_img = np.asarray(new_post_img).astype(np.float32)
50 |
51 | if dam_label is None:
52 | return new_pre_img, new_post_img
53 |
54 | new_dam_label = Image.fromarray(dam_label).resize(new_scale, resample=Image.NEAREST)
55 | new_dam_label = np.asarray(new_dam_label)
56 | new_loc_label = Image.fromarray(loc_label).resize(new_scale, resample=Image.NEAREST)
57 | new_loc_label = np.asarray(new_loc_label)
58 |
59 | return new_pre_img, new_post_img, new_loc_label, new_dam_label
60 |
61 |
62 | def img_resize_short(image, min_size=512):
63 | h, w, _ = image.shape
64 | if min(h, w) >= min_size:
65 | return image
66 |
67 | scale = float(min_size) / min(h, w)
68 | new_scale = [int(scale * w), int(scale * h)]
69 |
70 | new_image = Image.fromarray(image.astype(np.uint8)).resize(new_scale, resample=Image.BILINEAR)
71 | new_image = np.asarray(new_image).astype(np.float32)
72 |
73 | return new_image
74 |
75 |
76 | def random_resize(image, label, size_range=None):
77 | _new_size = random.randint(size_range[0], size_range[1])
78 |
79 | h, w, = label.shape
80 | scale = _new_size / float(max(h, w))
81 | new_scale = [int(scale * w), int(scale * h)]
82 |
83 | new_image, new_label = _img_rescaling(image, label, scale=new_scale)
84 |
85 | return new_image, new_label
86 |
87 |
88 | def random_fliplr(pre_img, post_img, label):
89 | if random.random() > 0.5:
90 | label = np.fliplr(label)
91 | pre_img = np.fliplr(pre_img)
92 | post_img = np.fliplr(post_img)
93 |
94 | return pre_img, post_img, label
95 |
96 |
97 | def random_fliplr_multicd(pre_img, post_img, pre_lc_label, label):
98 | if random.random() > 0.5:
99 | label = np.fliplr(label)
100 | pre_img = np.fliplr(pre_img)
101 | post_img = np.fliplr(post_img)
102 | pre_lc_label = np.fliplr(pre_lc_label)
103 |
104 | return pre_img, post_img, pre_lc_label, label
105 |
106 |
107 | def random_fliplr_with_object(pre_img, post_img, object_map, label):
108 | if random.random() > 0.5:
109 | label = np.fliplr(label)
110 | pre_img = np.fliplr(pre_img)
111 | post_img = np.fliplr(post_img)
112 | object_map = np.fliplr(object_map)
113 |
114 | return pre_img, post_img, object_map, label
115 |
116 |
117 | def random_flipud(pre_img, post_img, label):
118 | if random.random() > 0.5:
119 | label = np.flipud(label)
120 | pre_img = np.flipud(pre_img)
121 | post_img = np.flipud(post_img)
122 |
123 | return pre_img, post_img, label
124 |
125 |
126 | def random_flipud_multicd(pre_img, post_img, pre_lc_label, label):
127 | if random.random() > 0.5:
128 | label = np.flipud(label)
129 | pre_img = np.flipud(pre_img)
130 | post_img = np.flipud(post_img)
131 | pre_lc_label = np.flipud(pre_lc_label)
132 |
133 | return pre_img, post_img, pre_lc_label, label
134 |
135 |
136 | def random_flipud_with_object(pre_img, post_img, object_map, label):
137 | if random.random() > 0.5:
138 | object_map = np.flipud(object_map)
139 | label = np.flipud(label)
140 | pre_img = np.flipud(pre_img)
141 | post_img = np.flipud(post_img)
142 |
143 | return pre_img, post_img, object_map, label
144 |
145 |
146 | def random_rot(pre_img, post_img, label):
147 | k = random.randrange(3) + 1
148 |
149 | pre_img = np.rot90(pre_img, k).copy()
150 | post_img = np.rot90(post_img, k).copy()
151 | label = np.rot90(label, k).copy()
152 |
153 | return pre_img, post_img, label
154 |
155 |
156 | def random_rot_multicd(pre_img, post_img, pre_lc_label, label):
157 | k = random.randrange(3) + 1
158 |
159 | pre_img = np.rot90(pre_img, k).copy()
160 | post_img = np.rot90(post_img, k).copy()
161 | label = np.rot90(label, k).copy()
162 | pre_lc_label = np.rot90(pre_lc_label, k).copy()
163 |
164 | return pre_img, post_img, pre_lc_label, label
165 |
166 |
167 | def random_rot_with_object(pre_img, post_img, object_map, label):
168 | k = random.randrange(3) + 1
169 |
170 | pre_img = np.rot90(pre_img, k).copy()
171 | post_img = np.rot90(post_img, k).copy()
172 | object_map = np.rot90(object_map, k).copy()
173 | label = np.rot90(label, k).copy()
174 |
175 | return pre_img, post_img, object_map, label
176 |
177 |
178 | def random_crop(pre_img, post_img, label, crop_size, mean_rgb=[0, 0, 0], ignore_index=255):
179 | h, w = label.shape
180 |
181 | H = max(crop_size, h)
182 | W = max(crop_size, w)
183 |
184 | # pad_pre_image = np.zeros((H, W), dtype=np.float32)
185 | pad_pre_image = np.zeros((H, W, pre_img.shape[-1]), dtype=np.float32)
186 |
187 | pad_post_image = np.zeros((H, W, pre_img.shape[-1]), dtype=np.float32)
188 | pad_label = np.ones((H, W), dtype=np.float32) * ignore_index
189 |
190 | # note: mean_rgb is currently unused in this function; the pad region is zero-filled
198 |
199 | H_pad = int(np.random.randint(H - h + 1))
200 | W_pad = int(np.random.randint(W - w + 1))
201 |
202 | pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img
204 | pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img
205 | pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label
206 |
207 | def get_random_cropbox(cat_max_ratio=0.75):
208 |
209 | for i in range(10):
210 |
211 | H_start = random.randrange(0, H - crop_size + 1, 1)
212 | H_end = H_start + crop_size
213 | W_start = random.randrange(0, W - crop_size + 1, 1)
214 | W_end = W_start + crop_size
215 |
216 | temp_label = pad_label[H_start:H_end, W_start:W_end]
217 | index, cnt = np.unique(temp_label, return_counts=True)
218 | cnt = cnt[index != ignore_index]
219 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio:  # more than one class present, none dominant
220 | break
221 |
222 | return H_start, H_end, W_start, W_end
223 |
224 | H_start, H_end, W_start, W_end = get_random_cropbox()
226 | pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :]
228 | post_img = pad_post_image[H_start:H_end, W_start:W_end, :]
229 | label = pad_label[H_start:H_end, W_start:W_end]
233 | return pre_img, post_img, label
234 |
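# random_crop, in outline: (1) zero-pad both images and ignore_index-pad the
# label up to crop_size; (2) paste the originals at a random offset inside the
# padded canvas; (3) try up to 10 random crop_size windows, accepting one where
# no single class covers cat_max_ratio (75%) or more of the valid pixels.
# Hypothetical usage with a 256-pixel crop:
#
#   pre, post, lbl = random_crop(pre, post, lbl, crop_size=256)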
235 |
236 | def random_crop_multicd(pre_img, post_img, pre_lc_label, label, crop_size, mean_rgb=[0, 0, 0], ignore_index=255):
237 | h, w = label.shape
238 |
239 | H = max(crop_size, h)
240 | W = max(crop_size, w)
241 |
242 | pad_pre_image = np.zeros((H, W, 3), dtype=np.float32)
243 |
244 | pad_post_image = np.zeros((H, W, 3), dtype=np.float32)
245 |
246 | pad_pre_lc_label = np.zeros((H, W), dtype=np.float32)
247 | pad_label = np.ones((H, W), dtype=np.float32) * ignore_index
248 |
250 | pad_pre_image[:, :, 0] = mean_rgb[0]
251 | pad_pre_image[:, :, 1] = mean_rgb[1]
252 | pad_pre_image[:, :, 2] = mean_rgb[2]
253 |
254 | pad_post_image[:, :, 0] = mean_rgb[0]
255 | pad_post_image[:, :, 1] = mean_rgb[1]
256 | pad_post_image[:, :, 2] = mean_rgb[2]
257 |
258 | H_pad = int(np.random.randint(H - h + 1))
259 | W_pad = int(np.random.randint(W - w + 1))
260 |
261 | pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img
262 | pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img
263 |
264 | pad_pre_lc_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = pre_lc_label
265 | pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label
266 |
267 | def get_random_cropbox(cat_max_ratio=0.75):
268 |
269 | for i in range(10):
270 |
271 | H_start = random.randrange(0, H - crop_size + 1, 1)
272 | H_end = H_start + crop_size
273 | W_start = random.randrange(0, W - crop_size + 1, 1)
274 | W_end = W_start + crop_size
275 |
276 | temp_label = pad_label[H_start:H_end, W_start:W_end]
277 | index, cnt = np.unique(temp_label, return_counts=True)
278 | cnt = cnt[index != ignore_index]
279 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio:
280 | break
281 |
282 | return H_start, H_end, W_start, W_end
283 |
284 | H_start, H_end, W_start, W_end = get_random_cropbox()
286 | pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :]
287 | post_img = pad_post_image[H_start:H_end, W_start:W_end, :]
288 | pre_lc_label = pad_pre_lc_label[H_start:H_end, W_start:W_end]
289 | label = pad_label[H_start:H_end, W_start:W_end]
293 | return pre_img, post_img, pre_lc_label, label
294 |
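# The multicd variant applies the identical pad-and-crop window to the
# pre-date land-cover map (pre_lc_label), so the semantic labels stay
# pixel-aligned with both images and the change label.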
295 |
296 | def random_crop_with_object(pre_img, post_img, object_map, label, crop_size, mean_rgb=[0, 0, 0], ignore_index=255):
297 | h, w = label.shape
298 |
299 | H = max(crop_size, h)
300 | W = max(crop_size, w)
301 |
303 | pad_pre_image = np.zeros((H, W, 3), dtype=np.float32)
304 | pad_post_image = np.zeros((H, W, 3), dtype=np.float32)
305 | pad_object_map = np.zeros((H, W), dtype=np.int64)  # np.long is deprecated/removed in recent NumPy; int64 is the intended type
306 | pad_label = np.ones((H, W), dtype=np.float32) * ignore_index
307 |
308 | pad_pre_image[:, :, 0] = mean_rgb[0]
309 | pad_pre_image[:, :, 1] = mean_rgb[1]
310 | pad_pre_image[:, :, 2] = mean_rgb[2]
311 |
312 | pad_post_image[:, :, 0] = mean_rgb[0]
313 | pad_post_image[:, :, 1] = mean_rgb[1]
314 | pad_post_image[:, :, 2] = mean_rgb[2]
315 |
316 | H_pad = int(np.random.randint(H - h + 1))
317 | W_pad = int(np.random.randint(W - w + 1))
318 |
320 | pad_pre_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = pre_img
321 | pad_post_image[H_pad:(H_pad + h), W_pad:(W_pad + w), :] = post_img
322 | pad_object_map[H_pad:(H_pad + h), W_pad:(W_pad + w)] = object_map
323 | pad_label[H_pad:(H_pad + h), W_pad:(W_pad + w)] = label
324 |
325 | def get_random_cropbox(cat_max_ratio=0.75):
326 |
327 | for i in range(10):
328 |
329 | H_start = random.randrange(0, H - crop_size + 1, 1)
330 | H_end = H_start + crop_size
331 | W_start = random.randrange(0, W - crop_size + 1, 1)
332 | W_end = W_start + crop_size
333 |
334 | temp_label = pad_label[H_start:H_end, W_start:W_end]
335 | index, cnt = np.unique(temp_label, return_counts=True)
336 | cnt = cnt[index != ignore_index]
337 | if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < cat_max_ratio:
338 | break
339 |
340 | return H_start, H_end, W_start, W_end
341 |
342 | H_start, H_end, W_start, W_end = get_random_cropbox()
345 | pre_img = pad_pre_image[H_start:H_end, W_start:W_end, :]
346 | post_img = pad_post_image[H_start:H_end, W_start:W_end, :]
347 | object_map = pad_object_map[H_start:H_end, W_start:W_end]
348 | label = pad_label[H_start:H_end, W_start:W_end]
352 | return pre_img, post_img, object_map, label
353 |
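# The with_object variant additionally crops an integer object/segment map with
# the same window; its pad region is filled with 0, which this sketch assumes
# the caller treats as a background/no-object id.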
354 |
355 | def encode_cmap(label):
356 | cmap = colormap()
357 | return cmap[label.astype(np.int16), :]
358 |
359 |
360 | def tensorboard_image(inputs=None, outputs=None, labels=None, bgr=None):
361 | ## images
362 | inputs[:, 0, :, :] = inputs[:, 0, :, :] + bgr[0]
363 | inputs[:, 1, :, :] = inputs[:, 1, :, :] + bgr[1]
364 | inputs[:, 2, :, :] = inputs[:, 2, :, :] + bgr[2]
365 | inputs = inputs[:, [2, 1, 0], :, :].type(torch.uint8)
366 | grid_inputs = torchvision.utils.make_grid(tensor=inputs, nrow=2)
367 |
368 | ## preds
369 | preds = torch.argmax(outputs, dim=1).cpu().numpy()
370 | preds_cmap = encode_cmap(preds)
371 | preds_cmap = torch.from_numpy(preds_cmap).permute([0, 3, 1, 2])
372 | grid_outputs = torchvision.utils.make_grid(tensor=preds_cmap, nrow=2)
373 |
374 | ## labels
375 | labels_cmap = encode_cmap(labels.cpu().numpy())
376 | labels_cmap = torch.from_numpy(labels_cmap).permute([0, 3, 1, 2])
377 | grid_labels = torchvision.utils.make_grid(tensor=labels_cmap, nrow=2)
378 |
379 | return grid_inputs, grid_outputs, grid_labels
380 |
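# Hypothetical TensorBoard usage (writer, tag names, and the BGR means are
# illustrative, not taken from the original code):
#
#   grid_in, grid_out, grid_lbl = tensorboard_image(inputs, outputs, labels,
#                                                   bgr=[104.0, 117.0, 123.0])
#   writer.add_image('train/inputs', grid_in, global_step)
#   writer.add_image('train/preds', grid_out, global_step)
#   writer.add_image('train/labels', grid_lbl, global_step)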
381 |
382 | def colormap(N=256, normalized=False):
383 | def bitget(byteval, idx):
384 | return ((byteval & (1 << idx)) != 0)
385 |
386 | dtype = 'float32' if normalized else 'uint8'
387 | cmap = np.zeros((N, 3), dtype=dtype)
388 | for i in range(N):
389 | r = g = b = 0
390 | c = i
391 | for j in range(8):
392 | r = r | (bitget(c, 0) << 7 - j)
393 | g = g | (bitget(c, 1) << 7 - j)
394 | b = b | (bitget(c, 2) << 7 - j)
395 | c = c >> 3
396 |
397 | cmap[i] = np.array([r, g, b])
398 |
399 | cmap = cmap / 255 if normalized else cmap
400 | return cmap
401 |
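# colormap() builds the standard PASCAL VOC palette: the bits of each class
# index are spread across the R/G/B channels from the MSB down, giving visually
# distinct colors (cmap[0] = black, cmap[1] = (128, 0, 0), cmap[2] = (0, 128, 0)).
# A minimal self-check for the helpers above (illustrative shapes and values
# only; not part of the original training pipeline):
if __name__ == '__main__':
    cmap = colormap()
    assert tuple(cmap[1]) == (128, 0, 0) and tuple(cmap[2]) == (0, 128, 0)

    # random_crop pads up to crop_size and returns fixed-size, aligned crops
    pre = np.random.rand(200, 200, 3).astype(np.float32)
    post = np.random.rand(200, 200, 3).astype(np.float32)
    lbl = np.random.randint(0, 2, (200, 200)).astype(np.float32)
    c_pre, c_post, c_lbl = random_crop(pre, post, lbl, crop_size=256)
    assert c_pre.shape == (256, 256, 3) and c_lbl.shape == (256, 256)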
--------------------------------------------------------------------------------