├── download_resnet18.sh ├── criterion.py ├── init.py ├── LICENSE ├── README.md ├── test.py ├── dataset_pcd.py ├── train.py └── cscdnet.py /download_resnet18.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | wget https://download.pytorch.org/models/resnet18-5c106cde.pth 3 | 4 | -------------------------------------------------------------------------------- /criterion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class CrossEntropyLoss2d(nn.Module): 5 | 6 | def __init__(self, weight=None): 7 | super().__init__() 8 | self.loss = nn.NLLLoss(weight) # nn.NLLLoss2d is deprecated; nn.NLLLoss handles 2D targets 9 | 10 | def forward(self, outputs, mask): 11 | return self.loss(F.log_softmax(outputs, dim=1), mask) # log-softmax + NLL = per-pixel cross-entropy 12 | 13 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | 4 | def xavier_uniform_relu(modules): 5 | for m in modules: 6 | if isinstance(m, nn.Conv2d): 7 | init.xavier_uniform_(m.weight, gain=init.calculate_gain('relu')) 8 | if m.bias is not None: 9 | m.bias.data.zero_() 10 | elif isinstance(m, nn.BatchNorm2d): 11 | m.weight.data.fill_(1) 12 | m.bias.data.zero_() 13 | 14 | def xavier_uniform_sigmoid(modules): 15 | for m in modules: 16 | if isinstance(m, nn.Conv2d): 17 | init.xavier_uniform_(m.weight, gain=init.calculate_gain('sigmoid')) 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | elif isinstance(m, nn.BatchNorm2d): 21 | m.weight.data.fill_(1) 22 | m.bias.data.zero_() 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ken Sakurada 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Scene Change Detection Network 2 | This is an official implementation of "Correlated Siamese Change Detection Network (CSCDNet)" and "Silhouette-based Semantic Change Detection Network (SSCDNet)" in "[Weakly Supervised Silhouette-based Semantic Scene Change Detection](https://arxiv.org/abs/1811.11985)" (ICRA2020). 
(SSCDNet and the PSCD dataset are in preparation...) 3 | 4 |
5 | 6 |
7 | 8 | ## Environments 9 | This code was developed and tested with Python 3.6.8, PyTorch 1.0, and CUDA 9.2. 10 | * GCC 11 | ``` 12 | # Build and install GCC (>= 7.4.0) if not installed 13 | # Set path variables 14 | export PATH=/home/$USER/local/gcc/bin:$PATH 15 | export LD_LIBRARY_PATH=/home/$USER/local/gcc/lib64:$LD_LIBRARY_PATH 16 | ``` 17 | 18 | * Virtualenv environment setup 19 | ``` 20 | # Set the CUDA path. 21 | # On a server, setting the CUDA path with the module load command might be necessary. 22 | module load cuda/9.2/9.2.88.1 23 | 24 | # Create a virtualenv environment 25 | virtualenv -p python /path/to/env/pytorch1.0cuda9.2 26 | 27 | # Activate the virtualenv environment 28 | source /path/to/env/pytorch1.0cuda9.2/bin/activate 29 | 30 | # Install dependencies 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | * Download the pretrained ResNet-18 model 35 | ``` 36 | sh download_resnet18.sh 37 | ``` 38 | 39 | * Build the correlation layer package from [flownet2](https://github.com/NVIDIA/flownet2-pytorch). 40 | ``` 41 | sh build_correlation_package.sh 42 | ``` 43 | 44 | ## Dataset 45 | Please prepare a dataset in the following format, using change detection datasets such as [TSUNAMI](https://kensakurada.github.io/pcd_dataset.html). 46 | If your dataset is large, it is not necessary to split it into cross-validation sets. 47 | 48 | Training 49 | ``` 50 | pcd_5cv 51 | ├── set0/ 52 | │ ├── train/ # *.jpg 53 | │ ├── test/ # *.jpg 54 | │ ├── mask/ # *.png 55 | │ ├── train.txt 56 | │ ├── test.txt 57 | ├── set1/ 58 | ... 59 | ├── set2/ 60 | ... 61 | ├── set3/ 62 | ... 63 | ├── set4/ 64 | ├── train/ # *.jpg 65 | ├── test/ # *.jpg 66 | ├── mask/ # *.png 67 | ├── train.txt 68 | ├── test.txt 69 | ``` 70 | 71 | Testing 72 | ``` 73 | pcd 74 | ├── TSUNAMI/ 75 | ├── t0/ # *.jpg 76 | ├── t1/ # *.jpg 77 | ├── mask/ # *.png 78 | ``` 79 | 80 | 81 | ## Training 82 | Train the change detection network with correlation layers (CSCDNet) 83 | ``` 84 | # i-th set of five-fold cross-validation (0 <= i < 5) 85 | python train.py --cvset i --use-corr --datadir /path/to/pcd_5cv --checkpointdir /path/to/log --max-iteration 50000 --num-workers 16 --batch-size 32 --icount-plot 50 --icount-save 10000 86 | ``` 87 | 88 | Train the change detection network without correlation layers (CDNet) 89 | ``` 90 | # i-th set of five-fold cross-validation (0 <= i < 5) 91 | python train.py --cvset i --datadir /path/to/pcd_5cv --checkpointdir /path/to/log --max-iteration 50000 --num-workers 16 --batch-size 32 --icount-plot 50 --icount-save 10000 92 | ``` 93 | 94 | You can start a TensorBoard session to monitor training 95 | ``` 96 | tensorboard --logdir=/path/to/log 97 | ``` 98 | 99 | 100 | ## Testing 101 | CSCDNet 102 | ``` 103 | python test.py --use-corr --dataset PCD --datadir /path/to/pcd --checkpointdir /path/to/log/cscdnet/checkpointdir 104 | ``` 105 | CDNet 106 | ``` 107 | python test.py --dataset PCD --datadir /path/to/pcd --checkpointdir /path/to/log/cdnet/checkpointdir 108 | ``` 109 | 110 | ## Citation 111 | If you find this implementation useful in your work, please cite the paper. Here is a BibTeX entry: 112 | ``` 113 | @inproceedings{sakurada2020weakly, 114 | title={Weakly Supervised Silhouette-based Semantic Scene Change Detection}, 115 | author={Sakurada, Ken and Shibuya, Mikiya and Wang, Weimin}, 116 | booktitle={Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)}, 117 | year={2020} 118 | } 119 | ``` 120 | The preprint can be found [here](https://arxiv.org/abs/1811.11985). 
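For reference, `test.py` implements the fold assignment for the five-fold cross-validation above by image index: a panorama belongs to the test split of fold i when its index modulo 10 falls in that fold's bucket. A minimal sketch of the rule (the helper name `in_test_fold` is illustrative, not part of this repository):
```python
def in_test_fold(index, i_set, num_cv=5):
    # test.py keeps image `index` in the test split of fold `i_set` when
    # index % 10 falls in a bucket of width 10 / num_cv (2 when num_cv=5).
    width = 10 / num_cv
    return i_set * width <= (index % 10) < (i_set + 1) * width

# Fold 0 tests indices 0, 1, 10, 11, ...; fold 4 tests indices 8, 9, 18, 19, ...
```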
121 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os.path 4 | from argparse import ArgumentParser 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import sys 9 | sys.path.append("./correlation_package/build/lib.linux-x86_64-3.6") 10 | import cscdnet 11 | 12 | 13 | class DataInfo: 14 | def __init__(self): 15 | self.width = 1024 16 | self.height = 224 17 | self.no_start = 0 18 | self.no_end = 100 19 | self.num_cv = 5 20 | 21 | class Test: 22 | def __init__(self, arguments): 23 | self.args = arguments 24 | self.di = DataInfo() 25 | 26 | def test(self): 27 | 28 | _inputs = torch.from_numpy(np.concatenate((self.t0, self.t1), axis=0)).contiguous() 29 | _inputs = Variable(_inputs).view(1, -1, self.h_resize, self.w_resize) 30 | _inputs = _inputs.cuda() 31 | _outputs = self.model(_inputs) 32 | 33 | inputs = _inputs[0].cpu().data 34 | image_t0 = inputs[0:3, :, :] 35 | image_t1 = inputs[3:6, :, :] 36 | image_t0 = (image_t0 + 1.0) * 128 37 | image_t1 = (image_t1 + 1.0) * 128 38 | mask_gt = np.where(self.mask.data.numpy().squeeze(axis=0) == True, 0, 255) 39 | 40 | outputs = _outputs[0].cpu().data 41 | mask_pred = F.softmax(outputs[0:2, :, :], dim=0)[1] * 255 42 | 43 | self.display_results(image_t0, image_t1, mask_pred, mask_gt) 44 | 45 | def display_results(self, t0, t1, mask_pred, mask_gt): 46 | 47 | w, h = self.w_orig, self.h_orig 48 | t0_disp = cv2.resize(np.transpose(t0.numpy(), (1, 2, 0)).astype(np.uint8), (w, h)) 49 | t1_disp = cv2.resize(np.transpose(t1.numpy(), (1, 2, 0)).astype(np.uint8), (w, h)) 50 | mask_pred_disp = cv2.resize(cv2.cvtColor(mask_pred.numpy().astype(np.uint8), cv2.COLOR_GRAY2RGB), (w, h)) 51 | mask_gt_disp = cv2.resize(cv2.cvtColor(mask_gt.astype(np.uint8), cv2.COLOR_GRAY2RGB), (w, h)) 52 | 53 | img_out = np.zeros((h * 2, w * 2, 3), dtype=np.uint8) 54 | img_out[0:h, 0:w, :] = t0_disp 55 | img_out[0:h, w:w * 2, :] = t1_disp 56 | img_out[h:h * 2, 0:w * 1, :] = mask_gt_disp 57 | img_out[h:h * 2, w * 1:w * 2, :] = mask_pred_disp 58 | for dn, img in zip(['mask', 'disp'], [mask_pred_disp, img_out]): 59 | dn_save = os.path.join(self.args.checkpointdir, 'result', dn) 60 | fn_save = os.path.join(dn_save, '{0:08d}.png'.format(self.index)) 61 | if not os.path.exists(dn_save): 62 | os.makedirs(dn_save) 63 | print('Writing ... ' + fn_save) 64 | cv2.imwrite(fn_save, img) 65 | 66 | def run(self): 67 | 68 | for i_set in range(self.di.num_cv): 69 | if self.args.use_corr: 70 | print('Correlated Siamese Change Detection Network (CSCDNet)') 71 | self.model = cscdnet.Model(inc=6, outc=2, corr=True, pretrained=True) 72 | fn_model = os.path.join(self.args.checkpointdir, 'set{}'.format(i_set), 'cscdnet-00030000.pth') 73 | else: 74 | print('Siamese Change Detection Network (Siamese CDResNet)') 75 | self.model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=True) 76 | fn_model = os.path.join(self.args.checkpointdir, 'set{}'.format(i_set), 'cdnet-00030000.pth') 77 | 78 | if not os.path.isfile(fn_model): 79 | print("Error: Cannot read file ... " + fn_model) 80 | exit(-1) 81 | else: 82 | print("Reading model ... " + fn_model) 83 | self.model.load_state_dict(torch.load(fn_model)) 84 | self.model = self.model.cuda() 85 | 86 | if self.args.dataset == 'PCD': 87 | from dataset_pcd import PCD_full 88 | for dataset in ['TSUNAMI']: 89 | loader_test = PCD_full(os.path.join(self.args.datadir, dataset), self.di.no_start, self.di.no_end, self.di.width, self.di.height) 90 | for index in range(len(loader_test)): 91 | if i_set * (10 / self.di.num_cv) <= (index % 10) < (i_set + 1) * (10 / self.di.num_cv): # fold i_set tests images with 2*i_set <= index % 10 < 2*(i_set+1) 92 | self.index = index 93 | self.t0, self.t1, self.mask, self.w_orig, self.h_orig, self.w_resize, self.h_resize = loader_test[index] 94 | self.test() 95 | else: 96 | continue 97 | else: 98 | print('Error: Unexpected dataset') 99 | exit(-1) 100 | 101 | 102 | if __name__ == '__main__': 103 | 104 | parser = ArgumentParser(description='Start testing ...') 105 | parser.add_argument('--datadir', required=True) 106 | parser.add_argument('--checkpointdir', required=True) 107 | parser.add_argument('--use-corr', action='store_true', help='using correlation layer') 108 | parser.add_argument('--dataset', required=True) 109 | 110 | test = Test(parser.parse_args()) 111 | test.run() 112 | 113 | 114 | -------------------------------------------------------------------------------- /dataset_pcd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | EXTENSIONS = ['.jpg', '.png'] 8 | 9 | def check_img(filename): 10 | return any(filename.endswith(ext) for ext in EXTENSIONS) 11 | 12 | def get_img_path(root, basename, extension): 13 | return os.path.join(root, basename+extension) 14 | 15 | def get_img_basename(filename): 16 | return os.path.basename(os.path.splitext(filename)[0]) 17 | 18 | class PCD(Dataset): 19 | 20 | def __init__(self, root): 21 | super(PCD, self).__init__() 22 | self.img_t0_root = os.path.join(root, 't0') 23 | self.img_t1_root = os.path.join(root, 't1') 24 | self.mask_root = os.path.join(root, 'mask') 25 | 26 | self.filenames = [get_img_basename(f) for f in os.listdir(self.mask_root) if check_img(f)] 27 | self.filenames.sort() 28 | 29 | print('{}:{}'.format(root, len(self.filenames))) 30 | 31 | def __getitem__(self, index): 32 | filename = self.filenames[index] 33 | 34 | fn_img_t0 = get_img_path(self.img_t0_root, filename, '.jpg') 35 | fn_img_t1 = get_img_path(self.img_t1_root, filename, '.jpg') 36 | fn_mask = get_img_path(self.mask_root, filename, '.png') 37 | 38 | if not os.path.isfile(fn_img_t0): 39 | print('Error: File Not Found: ' + fn_img_t0) 40 | exit(-1) 41 | if not os.path.isfile(fn_img_t1): 42 | print('Error: File Not Found: ' + fn_img_t1) 43 | exit(-1) 44 | if not os.path.isfile(fn_mask): 45 | print('Error: File Not Found: ' + fn_mask) 46 | exit(-1) 47 | 48 | img_t0 = cv2.imread(fn_img_t0, cv2.IMREAD_COLOR) 49 | img_t1 = cv2.imread(fn_img_t1, cv2.IMREAD_COLOR) 50 | mask = cv2.imread(fn_mask, cv2.IMREAD_GRAYSCALE) 51 | h,w,c = img_t0.shape 52 | r = 286./min(w,h) 53 | # resize images so that min(w, h) == 286; random 256x256 crops are taken below 54 | img_t0 = cv2.resize(img_t0, (int(r*w), int(r*h))) 55 | img_t1 = cv2.resize(img_t1, (int(r*w), int(r*h))) 56 | mask = cv2.resize(mask, (int(r*w), int(r*h)))[:,:,np.newaxis] 57 | 58 | img_t0_ = np.asarray(img_t0).astype("f").transpose(2, 0, 1) / 128.0 - 1.0 59 | img_t1_ = np.asarray(img_t1).astype("f").transpose(2, 0, 1) / 128.0 - 1.0 60 | # binarize the mask: pixels > 128 become label 1 61 | mask_ = 
np.asarray(mask>128).astype("int").transpose(2, 0, 1) 62 | 63 | crop_width = 256 64 | _,h,w = img_t0_.shape 65 | x_l = np.random.randint(0,w-crop_width) 66 | x_r = x_l+crop_width 67 | y_l = np.random.randint(0,h-crop_width) 68 | y_r = y_l+crop_width 69 | 70 | input_ = torch.from_numpy(np.concatenate((img_t0_[:,y_l:y_r,x_l:x_r], img_t1_[:,y_l:y_r,x_l:x_r]), axis=0)) 71 | mask_ = torch.from_numpy(mask_[:, y_l:y_r, x_l:x_r]).long() 72 | 73 | return input_, mask_ 74 | 75 | def __len__(self): 76 | return len(self.filenames) 77 | 78 | def get_random_index(self): 79 | index = np.random.randint(0, len(self.filenames)) 80 | return index 81 | 82 | 83 | 84 | class PCD_full(Dataset): 85 | 86 | def __init__(self, root, id_s, id_e, width, height): 87 | super(PCD_full, self).__init__() 88 | self.img_t0_root = os.path.join(root, 't0') 89 | self.img_t1_root = os.path.join(root, 't1') 90 | self.mask_root = os.path.join(root, 'mask') 91 | 92 | self.filenames = [get_img_basename(f) for f in os.listdir(self.mask_root) if check_img(f)] 93 | self.filenames.sort() 94 | self.filenames = self.filenames[id_s:id_e] 95 | 96 | self.width = width 97 | self.height = height 98 | 99 | def __getitem__(self, index): 100 | filename = self.filenames[index] 101 | 102 | fn_img_t0 = get_img_path(self.img_t0_root, filename, '.jpg') 103 | fn_img_t1 = get_img_path(self.img_t1_root, filename, '.jpg') 104 | fn_mask = get_img_path(self.mask_root, filename, '.png') 105 | 106 | if os.path.isfile(fn_img_t0) == False: 107 | print ('Error: File Not Found: ' + fn_img_t0) 108 | exit(-1) 109 | if os.path.isfile(fn_img_t1) == False: 110 | print ('Error: File Not Found: ' + fn_img_t1) 111 | exit(-1) 112 | 113 | if os.path.isfile(fn_mask) == False: 114 | print ('Error: File Not Found: ' + fn_mask) 115 | exit(-1) 116 | 117 | img_t0 = cv2.imread(fn_img_t0, cv2.IMREAD_COLOR) 118 | img_t1 = cv2.imread(fn_img_t1, cv2.IMREAD_COLOR) 119 | mask = cv2.imread(fn_mask, cv2.IMREAD_GRAYSCALE) 120 | h,w,c = img_t0.shape 121 | r = min(w,h)/256 122 | w_r = int(256*max(w/256,1)) 123 | h_r = int(256*max(h/256,1)) 124 | # resize images so that min(w, h) == 256 125 | img_t0 = cv2.resize(img_t0, (w_r, h_r)) 126 | img_t1 = cv2.resize(img_t1, (w_r, h_r)) 127 | mask = cv2.resize(mask, (w_r, h_r))[:,:,np.newaxis] 128 | 129 | img_t0_ = np.asarray(img_t0).astype("f").transpose(2, 0, 1) / 128.0 - 1.0 130 | img_t1_ = np.asarray(img_t1).astype("f").transpose(2, 0, 1) / 128.0 - 1.0 131 | mask_ = np.asarray(mask>128).astype("int").transpose(2, 0, 1) 132 | mask_ = torch.from_numpy(mask_).long() 133 | 134 | return img_t0_, img_t1_, mask_, w, h, w_r, h_r 135 | 136 | def __len__(self): 137 | return len(self.filenames) 138 | 139 | 140 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from argparse import ArgumentParser 4 | import cv2 5 | import csv 6 | import os.path 7 | import numpy as np 8 | import torch 9 | from torch.optim import Adam, lr_scheduler 10 | from torch.autograd import Variable 11 | from torch.utils.data import DataLoader 12 | from tensorboardX import SummaryWriter 13 | from criterion import CrossEntropyLoss2d 14 | from dataset_pcd import PCD 15 | import sys 16 | sys.path.append("./correlation_package/build/lib.linux-x86_64-3.6") 17 | import cscdnet 18 | 19 | 20 | def colormap(): 21 | cmap=np.zeros([2, 3]).astype(np.uint8) 22 | 23 | cmap[0,:] = np.array([0, 0, 0]) 24 | cmap[1,:] = np.array([255, 
255, 255]) 25 | 26 | return cmap 27 | 28 | 29 | class Colorization: 30 | 31 | def __init__(self, n=2): 32 | self.cmap = colormap() 33 | self.cmap = torch.from_numpy(np.array(self.cmap[:n])) 34 | 35 | def __call__(self, gray_image): 36 | size = gray_image.size() 37 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 38 | 39 | for label in range(0, len(self.cmap)): 40 | mask = gray_image[0] == label 41 | 42 | color_image[0][mask] = self.cmap[label][0] 43 | color_image[1][mask] = self.cmap[label][1] 44 | color_image[2][mask] = self.cmap[label][2] 45 | 46 | return color_image 47 | 48 | 49 | class Training: 50 | def __init__(self, arguments): 51 | self.args = arguments 52 | self.icount = 0 53 | if self.args.use_corr: 54 | self.dn_save = os.path.join(self.args.checkpointdir,'cscdnet','checkpointdir','set{}'.format(self.args.cvset)) 55 | else: 56 | self.dn_save = os.path.join(self.args.checkpointdir,'cdnet','checkpointdir','set{}'.format(self.args.cvset)) 57 | 58 | def train(self): 59 | 60 | self.color_transform = Colorization(2) 61 | 62 | # Dataset loader for train and test 63 | dataset_train = DataLoader( 64 | PCD(os.path.join(self.args.datadir, 'set{}'.format(self.args.cvset), 'train')), 65 | num_workers=self.args.num_workers, batch_size=self.args.batch_size, shuffle=True) 66 | self.dataset_test = PCD(os.path.join(self.args.datadir, 'set{}'.format(self.args.cvset), 'test')) 67 | 68 | self.test_path = os.path.join(self.dn_save, 'test') 69 | if not os.path.exists(self.test_path): 70 | os.makedirs(self.test_path) 71 | 72 | # Set loss function, optimizer and learning rate 73 | weight = torch.ones(2) 74 | criterion = CrossEntropyLoss2d(weight.cuda()) 75 | optimizer = Adam(self.model.parameters(), lr=0.0001, betas=(0.5, 0.999)) 76 | lambda1 = lambda icount: (float)(self.args.max_iteration - icount) / (float)(self.args.max_iteration) 77 | model_lr_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 78 | 79 | fn_loss = os.path.join(self.dn_save,'loss.csv') 80 | f_loss = open(fn_loss, 'w') 81 | writer = csv.writer(f_loss) 82 | 83 | self.writers= SummaryWriter(os.path.join(self.dn_save, 'log')) 84 | 85 | # Training loop 86 | icount_loss = [] 87 | while self.icount < self.args.max_iteration: 88 | model_lr_scheduler.step() 89 | for step, (inputs_train, mask_train) in enumerate(dataset_train): 90 | inputs_train = inputs_train.cuda() 91 | mask_train = mask_train.cuda() 92 | 93 | inputs_train = Variable(inputs_train) 94 | mask_train = Variable(mask_train) 95 | outputs_train = self.model(inputs_train) 96 | 97 | optimizer.zero_grad() 98 | self.loss = criterion(outputs_train, mask_train[:, 0]) 99 | 100 | self.loss.backward() 101 | optimizer.step() 102 | 103 | self.icount += 1 104 | icount_loss.append(self.loss.item()) 105 | writer.writerow([self.icount, self.loss.item()]) 106 | if self.args.icount_plot > 0 and self.icount % self.args.icount_plot == 0: 107 | self.test() 108 | average = sum(icount_loss) / len(icount_loss) 109 | print('loss: {0} (icount: {1})'.format(average, self.icount)) 110 | icount_loss.clear() 111 | 112 | if self.args.icount_save > 0 and self.icount % self.args.icount_save == 0: 113 | self.checkpoint() 114 | 115 | f_loss.close() 116 | 117 | def test(self): 118 | 119 | index_test = self.dataset_test.get_random_index() 120 | inputs_test, mask_gt_test = self.dataset_test[index_test] 121 | inputs_test = inputs_test[np.newaxis, :, :] 122 | inputs_test = inputs_test.cuda() 123 | inputs_test = Variable(inputs_test) 124 | outputs_test = self.model(inputs_test) 125 | 126 | 
inputs = inputs_test[0].cpu().data 127 | t0_test = inputs[0:3, :, :] 128 | t1_test = inputs[3:6, :, :] 129 | t0_test = (t0_test + 1.0) * 128 130 | t1_test = (t1_test + 1.0) * 128 131 | mask_gt = mask_gt_test.numpy().astype(np.uint8) * 255 132 | 133 | outputs = outputs_test[0][np.newaxis, :, :, :] 134 | outputs = outputs[:, 0:2, :, :] 135 | mask_pred = np.transpose(self.color_transform(outputs[0].cpu().max(0)[1][np.newaxis, :, :].data).numpy(), (1, 2, 0)).astype(np.uint8) 136 | 137 | img_out = self.display_results(t0_test, t1_test, mask_pred, mask_gt) 138 | self.log_tbx(torch.from_numpy(np.transpose(np.flip(img_out,axis=2).copy(), (2, 0, 1)))) 139 | 140 | def display_results(self, t0, t1, mask_pred, mask_gt): 141 | 142 | rows = cols = 256 143 | img_out = np.zeros((rows * 2, cols * 2, 3), dtype=np.uint8) 144 | img_out[0:rows, 0:cols, :] = np.transpose(t0.numpy(), (1, 2, 0)).astype(np.uint8) 145 | img_out[0:rows, cols:cols * 2, :] = np.transpose(t1.numpy(), (1, 2, 0)).astype(np.uint8) 146 | img_out[rows:rows * 2, 0:cols, :] = cv2.cvtColor(np.transpose(mask_gt, (1, 2, 0)), cv2.COLOR_GRAY2RGB) 147 | img_out[rows:rows * 2, cols:cols * 2, :] = mask_pred 148 | 149 | return img_out 150 | 151 | # Output results for tensorboard 152 | def log_tbx(self, image): 153 | 154 | writer = self.writers 155 | writer.add_scalar('data/loss', self.loss.item(), self.icount) 156 | writer.add_image('change detection', image, self.icount) 157 | 158 | def checkpoint(self): 159 | if self.args.use_corr: 160 | filename = 'cscdnet-{0:08d}.pth'.format(self.icount) 161 | else: 162 | filename = 'cdnet-{0:08d}.pth'.format(self.icount) 163 | torch.save(self.model.state_dict(), os.path.join(self.dn_save, filename)) 164 | print('save: {0} (iteration: {1})'.format(filename, self.icount)) 165 | 166 | def run(self): 167 | 168 | if self.args.use_corr: 169 | print('Correlated Siamese Change Detection Network (CSCDNet)') 170 | self.model = cscdnet.Model(inc=6, outc=2, corr=True, pretrained=True) 171 | else: 172 | print('Siamese Change Detection Network (Siamese CDResNet)') 173 | self.model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=True) 174 | 175 | self.model = self.model.cuda() 176 | self.train() 177 | 178 | 179 | if __name__ == '__main__': 180 | 181 | parser = ArgumentParser(description='Start training ...') 182 | parser.add_argument('--checkpointdir', required=True) 183 | parser.add_argument('--datadir', required=True) 184 | parser.add_argument('--use-corr', action='store_true', help='using correlation layer') 185 | parser.add_argument('--max-iteration', type=int, default=50000) 186 | parser.add_argument('--num-workers', type=int, default=4) 187 | parser.add_argument('--batch-size', type=int, default=32) 188 | parser.add_argument('--cvset', type=int, default=0) 189 | parser.add_argument('--icount-plot', type=int, default=0) 190 | parser.add_argument('--icount-save', type=int, default=10) 191 | 192 | training = Training(parser.parse_args()) 193 | training.run() 194 | -------------------------------------------------------------------------------- /cscdnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import init 4 | from collections import OrderedDict 5 | from correlation_package.correlation import Correlation 6 | 7 | class Model(nn.Module): 8 | def __init__(self, inc, outc, corr=True, pretrained=True): 9 | super(Model, self).__init__() 10 | 11 | self.corr = corr 12 | 13 | # encoder1 14 | self.enc1_conv1 = nn.Conv2d(int(inc/2), 64, 7, 
padding=3, stride=2, bias=False) 15 | self.enc1_bn1 = nn.BatchNorm2d(64) 16 | self.enc1_pool1 = nn.MaxPool2d(3, stride=2, padding=1) 17 | self.enc1_res1_1 = ResBL( 64, 64, 64, stride=1) 18 | self.enc1_res1_2 = ResBL( 64, 64, 64, stride=1) 19 | self.enc1_res2_1 = ResBL( 64, 128, 128, stride=2) 20 | self.enc1_res2_2 = ResBL(128, 128, 128, stride=1) 21 | self.enc1_res3_1 = ResBL(128, 256, 256, stride=2) 22 | self.enc1_res3_2 = ResBL(256, 256, 256, stride=1) 23 | self.enc1_res4_1 = ResBL(256, 512, 512, stride=2) 24 | self.enc1_res4_2 = ResBL(512, 512, 512, stride=1) 25 | self.enc1_conv5 = nn.Conv2d( 512, 1024, 3, padding=1, stride=2) 26 | self.enc1_bn5 = nn.BatchNorm2d(1024) 27 | self.enc1_conv6 = nn.Conv2d(1024, 1024, 3, padding=1, stride=1) 28 | self.enc1_bn6 = nn.BatchNorm2d(1024) 29 | 30 | # encoder2 31 | self.enc2_conv1 = nn.Conv2d(int(inc/2), 64, 7, padding=3, stride=2, bias=False) 32 | self.enc2_bn1 = nn.BatchNorm2d(64) 33 | self.enc2_pool1 = nn.MaxPool2d(3, stride=2, padding=1) 34 | self.enc2_res1_1 = ResBL( 64, 64, 64, stride=1) 35 | self.enc2_res1_2 = ResBL( 64, 64, 64, stride=1) 36 | self.enc2_res2_1 = ResBL( 64, 128, 128, stride=2) 37 | self.enc2_res2_2 = ResBL(128, 128, 128, stride=1) 38 | self.enc2_res3_1 = ResBL(128, 256, 256, stride=2) 39 | self.enc2_res3_2 = ResBL(256, 256, 256, stride=1) 40 | self.enc2_res4_1 = ResBL(256, 512, 512, stride=2) 41 | self.enc2_res4_2 = ResBL(512, 512, 512, stride=1) 42 | self.enc2_conv5 = nn.Conv2d( 512, 1024, 3, padding=1, stride=2) 43 | self.enc2_bn5 = nn.BatchNorm2d(1024) 44 | self.enc2_conv6 = nn.Conv2d(1024, 1024, 3, padding=1, stride=1) 45 | self.enc2_bn6 = nn.BatchNorm2d(1024) 46 | 47 | # decoder 48 | self.dec_conv6 = nn.Conv2d(2048, 1024, 3, padding=1, stride=1) 49 | self.dec_bn6 = nn.BatchNorm2d(1024) 50 | self.dec_conv5 = nn.Conv2d(1024, 512, 3, padding=1, stride=1) 51 | self.dec_bn5 = nn.BatchNorm2d(512) 52 | self.dec_res4_2 = ResBL( 512, 512, 512, upscale=1, skip2=1024) 53 | self.dec_res4_1 = ResBL( 512, 512, 256, upscale=2) 54 | self.dec_res3_2 = ResBL( 256, 256, 256, upscale=1, skip2=512) 55 | self.dec_res3_1 = ResBL( 256, 256, 128, upscale=2) 56 | if self.corr is True: 57 | self.dec_corr2 = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 58 | self.dec_res2_2 = ResBL( 128, 128, 128, upscale=1, skip1=256+21*21) 59 | else: 60 | self.dec_res2_2 = ResBL(128, 128, 128, upscale=1, skip1=256) 61 | self.dec_res2_1 = ResBL( 128, 128, 64, upscale=2) 62 | self.dec_res1_2 = ResBL( 64, 64, 64, upscale=1, skip2=128) 63 | self.dec_res1_1 = ResBL( 64, 64, 64, upscale=1) 64 | if self.corr is True: 65 | self.dec_corr1 = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 66 | self.dec_conv1 = nn.Conv2d(192+21*21, 64, 7, padding=3, stride=1, bias=False) 67 | else: 68 | self.dec_conv1 = nn.Conv2d(192, 64, 7, padding=3, stride=1, bias=False) 69 | self.dec_bn1 = nn.BatchNorm2d(64) 70 | 71 | # classifier 72 | self.classifier = nn.Conv2d(64, outc, 1, padding=0, stride=1) 73 | 74 | # util 75 | self.unpool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 76 | self.relu = nn.ReLU(inplace=True) 77 | if self.corr is True: 78 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 79 | 80 | # initialization 81 | self.init_weights() 82 | if pretrained is True: 83 | self.load_net_param() 84 | 85 | def forward(self, x): 86 | x1, x2 = torch.split(x,3,1) 87 | 88 | # encoder1 89 | enc1_f1 = self.enc1_conv1(x1) 90 | enc1_f1 = self.enc1_bn1(enc1_f1) 91 | 
enc1_f1 = self.relu(enc1_f1) 92 | enc1_f2 = self.enc1_pool1(enc1_f1) 93 | enc1_f2 = self.enc1_res1_1(enc1_f2) 94 | enc1_f2 = self.enc1_res1_2(enc1_f2) 95 | enc1_f3 = self.enc1_res2_1(enc1_f2) 96 | enc1_f3 = self.enc1_res2_2(enc1_f3) 97 | enc1_f4 = self.enc1_res3_1(enc1_f3) 98 | enc1_f4 = self.enc1_res3_2(enc1_f4) 99 | enc1_f5 = self.enc1_res4_1(enc1_f4) 100 | enc1_f5 = self.enc1_res4_2(enc1_f5) 101 | enc1_f6 = self.enc1_conv5(enc1_f5) 102 | enc1_f6 = self.enc1_bn5(enc1_f6) 103 | enc1_f6 = self.relu(enc1_f6) 104 | enc1_f6 = self.enc1_conv6(enc1_f6) 105 | enc1_f6 = self.enc1_bn6(enc1_f6) 106 | enc1_f6 = self.relu(enc1_f6) 107 | 108 | # encoder2 109 | enc2_f1 = self.enc2_conv1(x2) 110 | enc2_f1 = self.enc2_bn1(enc2_f1) 111 | enc2_f1 = self.relu(enc2_f1) 112 | enc2_f2 = self.enc2_pool1(enc2_f1) 113 | enc2_f2 = self.enc2_res1_1(enc2_f2) 114 | enc2_f2 = self.enc2_res1_2(enc2_f2) 115 | enc2_f3 = self.enc2_res2_1(enc2_f2) 116 | enc2_f3 = self.enc2_res2_2(enc2_f3) 117 | enc2_f4 = self.enc2_res3_1(enc2_f3) 118 | enc2_f4 = self.enc2_res3_2(enc2_f4) 119 | enc2_f5 = self.enc2_res4_1(enc2_f4) 120 | enc2_f5 = self.enc2_res4_2(enc2_f5) 121 | enc2_f6 = self.enc2_conv5(enc2_f5) 122 | enc2_f6 = self.enc2_bn5(enc2_f6) 123 | enc2_f6 = self.relu(enc2_f6) 124 | enc2_f6 = self.enc2_conv6(enc2_f6) 125 | enc2_f6 = self.enc2_bn6(enc2_f6) 126 | enc2_f6 = self.relu(enc2_f6) 127 | 128 | # decoder 129 | enc_f6 = torch.cat([enc1_f6, enc2_f6], 1) 130 | dec = self.dec_conv6(enc_f6) 131 | dec = self.dec_bn6(dec) 132 | dec = self.relu(dec) 133 | dec = self.dec_conv5(dec) 134 | dec = self.unpool(dec) 135 | dec = self.dec_bn5(dec) 136 | dec = self.relu(dec) 137 | skp = torch.cat([enc1_f5, enc2_f5], 1) 138 | dec = self.dec_res4_2(dec, skip2=skp) 139 | dec = self.dec_res4_1(dec) 140 | skp = torch.cat([enc1_f4, enc2_f4], 1) 141 | dec = self.dec_res3_2(dec, skip2=skp) 142 | dec = self.dec_res3_1(dec) 143 | if self.corr is True: 144 | cor = self.dec_corr2(enc1_f3, enc2_f3) 145 | cor = self.corr_activation(cor) 146 | skp = torch.cat([enc1_f3, enc2_f3, cor], 1) 147 | else: 148 | skp = torch.cat([enc1_f3, enc2_f3], 1) 149 | dec = self.dec_res2_2(dec, skip1=skp) 150 | dec = self.dec_res2_1(dec) 151 | skp = torch.cat([enc1_f2, enc2_f2], 1) 152 | dec = self.dec_res1_2(dec, skip2=skp) 153 | dec = self.dec_res1_1(dec) 154 | dec = self.unpool(dec) 155 | if self.corr is True: 156 | cor = self.dec_corr1(enc1_f1, enc2_f1) 157 | cor = self.corr_activation(cor) 158 | dec = torch.cat([dec, enc1_f1, enc2_f1, cor], 1) 159 | else: 160 | dec = torch.cat([dec, enc1_f1, enc2_f1], 1) 161 | dec = self.dec_conv1(dec) 162 | dec = self.unpool(dec) 163 | dec = self.dec_bn1(dec) 164 | dec = self.relu(dec) 165 | 166 | out = self.classifier(dec) 167 | return out 168 | 169 | def init_weights(self): 170 | init.xavier_uniform_relu(self.modules()) 171 | 172 | def load_net_param(self): 173 | from torchvision.models import resnet18 174 | resnet = resnet18(pretrained=True) 175 | 176 | self.enc1_conv1.load_state_dict(resnet.conv1.state_dict()) 177 | self.enc1_bn1.load_state_dict(resnet.bn1.state_dict()) 178 | self.enc1_res1_1.load_state_dict(list(resnet.layer1.children())[0].state_dict()) 179 | self.enc1_res1_2.load_state_dict(list(resnet.layer1.children())[1].state_dict()) 180 | self.enc1_res2_1.load_state_dict(list(resnet.layer2.children())[0].state_dict()) 181 | self.enc1_res2_2.load_state_dict(list(resnet.layer2.children())[1].state_dict()) 182 | self.enc1_res3_1.load_state_dict(list(resnet.layer3.children())[0].state_dict()) 183 | 
self.enc1_res3_2.load_state_dict(list(resnet.layer3.children())[1].state_dict()) 184 | self.enc1_res4_1.load_state_dict(list(resnet.layer4.children())[0].state_dict()) 185 | self.enc1_res4_2.load_state_dict(list(resnet.layer4.children())[1].state_dict()) 186 | 187 | self.enc2_conv1.load_state_dict(resnet.conv1.state_dict()) 188 | self.enc2_bn1.load_state_dict(resnet.bn1.state_dict()) 189 | self.enc2_res1_1.load_state_dict(list(resnet.layer1.children())[0].state_dict()) 190 | self.enc2_res1_2.load_state_dict(list(resnet.layer1.children())[1].state_dict()) 191 | self.enc2_res2_1.load_state_dict(list(resnet.layer2.children())[0].state_dict()) 192 | self.enc2_res2_2.load_state_dict(list(resnet.layer2.children())[1].state_dict()) 193 | self.enc2_res3_1.load_state_dict(list(resnet.layer3.children())[0].state_dict()) 194 | self.enc2_res3_2.load_state_dict(list(resnet.layer3.children())[1].state_dict()) 195 | self.enc2_res4_1.load_state_dict(list(resnet.layer4.children())[0].state_dict()) 196 | self.enc2_res4_2.load_state_dict(list(resnet.layer4.children())[1].state_dict()) 197 | 198 | 199 | class ResBL(nn.Module): 200 | def __init__(self, inc, midc, outc, stride=1, upscale=1, skip1=0, skip2=0): 201 | super(ResBL, self).__init__() 202 | 203 | self.conv1 = nn.Conv2d(inc+skip1, midc, 3, padding=1, stride=stride, bias=False) 204 | self.bn1 = nn.BatchNorm2d(midc) 205 | self.relu = nn.ReLU(inplace=True) 206 | self.conv2 = nn.Conv2d(midc+skip2, outc, 3, padding=1, bias=False) 207 | self.bn2 = nn.BatchNorm2d(outc) 208 | 209 | self.upscale = None 210 | if upscale > 1: 211 | self.upscale = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True) 212 | 213 | self.downsample = None 214 | 215 | if inc != outc or stride > 1 or upscale > 1: 216 | if upscale > 1: 217 | self.downsample = nn.Sequential( 218 | nn.Conv2d(inc, outc, 1, padding=0, stride=stride, bias=False), 219 | nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True), 220 | nn.BatchNorm2d(outc), 221 | ) 222 | else: 223 | self.downsample = nn.Sequential( 224 | nn.Conv2d(inc, outc, 1, padding=0, stride=stride, bias=False), 225 | nn.BatchNorm2d(outc), 226 | ) 227 | 228 | def forward(self, x, skip1=None, skip2=None): 229 | if skip1 is not None: 230 | res = torch.cat([x, skip1], 1) 231 | else: 232 | res = x 233 | 234 | res = self.conv1(res) 235 | res = self.bn1(res) 236 | res = self.relu(res) 237 | 238 | if skip2 is not None: 239 | res = torch.cat([res, skip2], 1) 240 | 241 | res = self.conv2(res) 242 | if self.upscale is not None: 243 | res = self.upscale(res) 244 | res = self.bn2(res) 245 | 246 | identity = x 247 | if self.downsample is not None: 248 | identity = self.downsample(x) 249 | 250 | res += identity 251 | out = self.relu(res) 252 | 253 | return out 254 | 255 | 256 | 257 | 258 | --------------------------------------------------------------------------------
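For reference, the `Correlation` layer that `cscdnet.py` imports from the flownet2 correlation package compares, at each spatial location, the feature vector from one encoder stream against feature vectors at displaced locations in the other stream. Below is a minimal pure-PyTorch sketch of that operation with the parameters used above (`kernel_size=1`, `max_displacement=20`, `stride2=2`), which yields the `21*21` correlation channels the decoder concatenates into its skip connections. It is an illustrative reimplementation for clarity, not the CUDA package itself, and the per-channel mean is an assumption about the package's normalization:
```python
import torch
import torch.nn.functional as F

def naive_correlation(f1, f2, max_disp=20, stride2=2):
    # f1, f2: (B, C, H, W) feature maps from the two encoder streams.
    # Returns (B, D*D, H, W) with D = 2 * max_disp // stride2 + 1 (21 here),
    # matching the 21*21 channels added to skip connections in cscdnet.py.
    B, C, H, W = f1.shape
    f2_pad = F.pad(f2, (max_disp, max_disp, max_disp, max_disp))
    planes = []
    for dy in range(-max_disp, max_disp + 1, stride2):
        for dx in range(-max_disp, max_disp + 1, stride2):
            # shift f2 by (dy, dx) and take the channel-wise inner product
            shifted = f2_pad[:, :, max_disp + dy:max_disp + dy + H,
                             max_disp + dx:max_disp + dx + W]
            planes.append((f1 * shifted).mean(dim=1, keepdim=True))
    return torch.cat(planes, dim=1)
```
A large response at displacement (dy, dx) means the two time steps match once shifted by that amount, which is what lets the decoder separate viewpoint misalignment between t0 and t1 from real scene change.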