├── download_resnet18.sh
├── criterion.py
├── init.py
├── LICENSE
├── README.md
├── test.py
├── dataset_pcd.py
├── train.py
└── cscdnet.py
/download_resnet18.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | wget https://download.pytorch.org/models/resnet18-5c106cde.pth
3 |
4 |
--------------------------------------------------------------------------------
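The script fetches the standard torchvision ResNet-18 ImageNet checkpoint. Note that `cscdnet.py` loads pretrained weights through torchvision's `resnet18(pretrained=True)`, which resolves them from the torch model cache, so this local copy is mainly a convenience. A minimal sketch (an illustration, not part of the repo) to inspect the downloaded file:

```
# Inspect the downloaded ResNet-18 state dict (assumes it is in the current directory).
import torch

state = torch.load('resnet18-5c106cde.pth', map_location='cpu')
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape))  # e.g. conv1.weight (64, 3, 7, 7)
```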
/criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class CrossEntropyLoss2d(nn.Module):
5 |
6 | def __init__(self, weight=None):
7 | super().__init__()
8 |         self.loss = nn.NLLLoss(weight)  # nn.NLLLoss2d is deprecated; NLLLoss handles (N, C, H, W) inputs
9 |
10 |     def forward(self, outputs, mask):
11 |         return self.loss(F.log_softmax(outputs, dim=1), mask)
12 |
13 |
--------------------------------------------------------------------------------
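For reference, a minimal usage sketch of this wrapper (the dummy shapes are illustrative; `train.py` passes `mask_train[:, 0]` as the target):

```
# Hypothetical example: per-pixel 2-class loss on random tensors.
import torch
from criterion import CrossEntropyLoss2d

criterion = CrossEntropyLoss2d(weight=torch.ones(2))
logits = torch.randn(4, 2, 256, 256)         # (batch, classes, H, W) network output
target = torch.randint(0, 2, (4, 256, 256))  # per-pixel class indices (long)
print(criterion(logits, target).item())
```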
/init.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.init as init
3 |
4 | def xavier_uniform_relu(modules):
5 | for m in modules:
6 | if isinstance(m, nn.Conv2d):
7 |             init.xavier_uniform_(m.weight, gain=init.calculate_gain('relu'))
8 | if m.bias is not None:
9 | m.bias.data.zero_()
10 | elif isinstance(m, nn.BatchNorm2d):
11 | m.weight.data.fill_(1)
12 | m.bias.data.zero_()
13 |
14 | def xavier_uniform_sigmoid(modules):
15 | for m in modules:
16 | if isinstance(m, nn.Conv2d):
17 |             init.xavier_uniform_(m.weight, gain=init.calculate_gain('sigmoid'))
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 | elif isinstance(m, nn.BatchNorm2d):
21 | m.weight.data.fill_(1)
22 | m.bias.data.zero_()
23 |
--------------------------------------------------------------------------------
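`cscdnet.py` calls `xavier_uniform_relu(self.modules())` from `init_weights()`; a standalone sketch with a toy model (the model itself is illustrative):

```
# Hypothetical example: initialize the conv/BN layers of a small network.
import torch.nn as nn
from init import xavier_uniform_relu

model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
xavier_uniform_relu(model.modules())
```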
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Ken Sakurada
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Scene Change Detection Network
2 | This is an official implementation of the "Correlated Siamese Change Detection Network (CSCDNet)" and the "Silhouette-based Semantic Change Detection Network (SSCDNet)" from "[Weakly Supervised Silhouette-based Semantic Scene Change Detection](https://arxiv.org/abs/1811.11985)" (ICRA 2020). (SSCDNet and the PSCD dataset are in preparation.)
3 |
4 |
5 |
6 |
7 |
8 | ## Environments
9 | This code was developed and tested with Python 3.6.8, PyTorch 1.0, and CUDA 9.2.
10 | * GCC
11 | ```
12 | # Build and install GCC (>= 7.4.0) if not installed
13 | # Set path variables
14 | export PATH=/home/$USER/local/gcc/bin:$PATH
15 | export LD_LIBRARY_PATH=/home/$USER/local/gcc/lib64:$LD_LIBRARY_PATH
16 | ```
17 |
18 | * Virtualenv for the Python environment (a requirements.txt sketch follows this list)
19 | ```
20 | # Set the CUDA path.
21 | # On a shared server, loading CUDA with the module load command might be necessary.
22 | module load cuda/9.2/9.2.88.1
23 |
24 | # Create a virtualenv environment
25 | virtualenv -p python /path/to/env/pytorch1.0cuda9.2
26 |
27 | # Activate the virtualenv environment
28 | source /path/to/env/pytorch1.0cuda9.2/bin/activate
29 |
30 | # Install dependencies
31 | pip install -r requirements.txt
32 | ```
33 |
34 | * Download the pretrained ResNet-18 model
35 | ```
36 | sh download_resnet18.sh
37 | ```
38 |
39 | * Build the correlation layer package from [flownet2](https://github.com/NVIDIA/flownet2-pytorch).
40 | ```
41 | sh build_correlation_package.sh
42 | ```
43 |
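`requirements.txt` is not included in this listing; a hypothetical minimal set, consistent with the stated environment and the imports used by the code (the exact pins are assumptions), would be:

```
torch==1.0.0
torchvision==0.2.1
numpy
opencv-python
tensorboardX
```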
44 | ## Dataset
45 | Please prepare a dataset in the following format from change detection datasets such as [TSUNAMI](https://kensakurada.github.io/pcd_dataset.html).
46 | For a large dataset, splitting it into cross-validation sets is not necessary.
47 |
48 | Training
49 | ```
50 | pcd_5cv
51 | ├── set0/
52 | │ ├── train/ # *.jpg
53 | │ ├── test/ # *.jpg
54 | │ ├── mask/ # *.png
55 | │   ├── train.txt
56 | │   ├── test.txt
57 | ├── set1/
58 | ...
59 | ├── set2/
60 | ...
61 | ├── set3/
62 | ...
63 | ├── set4/
64 | ├── train/ # *.jpg
65 | ├── test/ # *.jpg
66 | ├── mask/ # *.png
67 | ├── train.txt
68 | ├── test.txt
69 | ```
70 |
71 | Testing
72 | ```
73 | pcd
74 | ├── TSUNAMI/
75 | ├── t0/ # *.jpg
76 | ├── t1/ # *.jpg
77 | ├── mask/ # *.png
78 | ```
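Before testing, you can sanity-check the layout with a small script (a hypothetical helper, not part of this repository; `dataset_pcd.PCD_full` expects matching basenames across `t0/`, `t1/`, and `mask/`):

```
# Hypothetical check: every mask has a matching t0/t1 image pair.
import os

root = '/path/to/pcd/TSUNAMI'
for f in sorted(os.listdir(os.path.join(root, 'mask'))):
    base = os.path.splitext(f)[0]
    for sub, ext in [('t0', '.jpg'), ('t1', '.jpg')]:
        path = os.path.join(root, sub, base + ext)
        assert os.path.isfile(path), 'missing ' + path
```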
79 |
80 |
81 | ## Training
82 | Train the change detection network with correlation layers (CSCDNet):
83 | ```
84 | # i-th set of five-fold cross-validation (0 <= i < 5)
85 | python train.py --cvset i --use-corr --datadir /path/to/pcd_5cv --checkpointdir /path/to/log --max-iteration 50000 --num-workers 16 --batch-size 32 --icount-plot 50 --icount-save 10000
86 | ```
87 |
88 | Train the change detection network without correlation layers (CDNet):
89 | ```
90 | # i-th set of five-fold cross-validation (0 <= i < 5)
91 | python train.py --cvset i --datadir /path/to/pcd_5cv --checkpointdir /path/to/log --max-iteration 50000 --num-workers 16 --batch-size 32 --icount-plot 50 --icount-save 10000
92 | ```
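To run all five folds back to back, a hypothetical wrapper around the commands above (the argument values mirror the examples):

```
# Hypothetical sketch: launch training for every cross-validation fold.
import subprocess

for i in range(5):
    subprocess.run(['python', 'train.py', '--cvset', str(i), '--use-corr',
                    '--datadir', '/path/to/pcd_5cv', '--checkpointdir', '/path/to/log',
                    '--max-iteration', '50000', '--num-workers', '16',
                    '--batch-size', '32', '--icount-plot', '50',
                    '--icount-save', '10000'], check=True)  # drop '--use-corr' for CDNet
```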
93 |
94 | You can start a TensorBoard session:
95 | ```
96 | tensorboard --logdir=/path/to/log
97 | ```
98 |
99 |
100 | ## Testing
101 | CSCDNet
102 | ```
103 | python test.py --use-corr --dataset PCD --datadir /path/to/pcd --checkpointdir /path/to/log/cscdnet/checkpointdir
104 | ```
105 | CDNet
106 | ```
107 | python test.py --dataset PCD --datadir /path/to/pcd --checkpointdir /path/to/log/cdnet/checkpointdir
108 | ```
109 |
110 | ## Citation
111 | If you find this implementation useful in your work, please cite the paper. Here is a BibTeX entry:
112 | ```
113 | @inproceedings{sakurada2020weakly,
114 |   title={Weakly Supervised Silhouette-based Semantic Scene Change Detection},
115 |   author={Sakurada, Ken and Shibuya, Mikiya and Wang, Weimin},
116 |   booktitle={Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
117 | year={2020}
118 | }
119 | ```
120 | The preprint can be found [here](https://arxiv.org/abs/1811.11985).
121 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os.path
4 | from argparse import ArgumentParser
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import sys
9 | sys.path.append("./correlation_package/build/lib.linux-x86_64-3.6")
10 | import cscdnet
11 |
12 |
13 | class DataInfo:
14 | def __init__(self):
15 | self.width = 1024
16 | self.height = 224
17 | self.no_start = 0
18 | self.no_end = 100
19 | self.num_cv = 5
20 |
21 | class Test:
22 | def __init__(self, arguments):
23 | self.args = arguments
24 | self.di = DataInfo()
25 |
26 | def test(self):
27 |
28 | _inputs = torch.from_numpy(np.concatenate((self.t0, self.t1), axis=0)).contiguous()
29 | _inputs = Variable(_inputs).view(1, -1, self.h_resize, self.w_resize)
30 | _inputs = _inputs.cuda()
31 | _outputs = self.model(_inputs)
32 |
33 | inputs = _inputs[0].cpu().data
34 | image_t0 = inputs[0:3, :, :]
35 | image_t1 = inputs[3:6, :, :]
36 | image_t0 = (image_t0 + 1.0) * 128
37 | image_t1 = (image_t1 + 1.0) * 128
38 |         mask_gt = np.where(self.mask.data.numpy().squeeze(axis=0) == 1, 0, 255)
39 |
40 | outputs = _outputs[0].cpu().data
41 | mask_pred = F.softmax(outputs[0:2, :, :], dim=0)[1] * 255
42 |
43 | self.display_results(image_t0, image_t1, mask_pred, mask_gt)
44 |
45 | def display_results(self, t0, t1, mask_pred, mask_gt):
46 |
47 | w, h = self.w_orig, self.h_orig
48 | t0_disp = cv2.resize(np.transpose(t0.numpy(), (1, 2, 0)).astype(np.uint8), (w, h))
49 | t1_disp = cv2.resize(np.transpose(t1.numpy(), (1, 2, 0)).astype(np.uint8), (w, h))
50 | mask_pred_disp = cv2.resize(cv2.cvtColor(mask_pred.numpy().astype(np.uint8), cv2.COLOR_GRAY2RGB), (w, h))
51 | mask_gt_disp = cv2.resize(cv2.cvtColor(mask_gt.astype(np.uint8), cv2.COLOR_GRAY2RGB), (w, h))
52 |
53 |         img_out = np.zeros((h * 2, w * 2, 3), dtype=np.uint8)
54 | img_out[0:h, 0:w, :] = t0_disp
55 | img_out[0:h, w:w * 2, :] = t1_disp
56 | img_out[h:h * 2, 0:w * 1, :] = mask_gt_disp
57 | img_out[h:h * 2, w * 1:w * 2, :] = mask_pred_disp
58 | for dn, img in zip(['mask', 'disp'], [mask_pred_disp, img_out]):
59 | dn_save = os.path.join(self.args.checkpointdir, 'result', dn)
60 | fn_save = os.path.join(dn_save, '{0:08d}.png'.format(self.index))
61 | if not os.path.exists(dn_save):
62 | os.makedirs(dn_save)
63 | print('Writing ... ' + fn_save)
64 | cv2.imwrite(fn_save, img)
65 |
66 | def run(self):
67 |
68 | for i_set in range(0,self.di.num_cv):
69 | if self.args.use_corr:
70 | print('Correlated Siamese Change Detection Network (CSCDNet)')
71 | self.model = cscdnet.Model(inc=6, outc=2, corr=True, pretrained=True)
72 |                 fn_model = os.path.join(self.args.checkpointdir, 'set{}'.format(i_set), 'cscdnet-00030000.pth')
73 | else:
74 | print('Siamese Change Detection Network (Siamese CDResNet)')
75 | self.model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=True)
76 |                 fn_model = os.path.join(self.args.checkpointdir, 'set{}'.format(i_set), 'cdnet-00030000.pth')
77 |
78 |             if not os.path.isfile(fn_model):
79 | print("Error: Cannot read file ... " + fn_model)
80 | exit(-1)
81 | else:
82 | print("Reading model ... " + fn_model)
83 | self.model.load_state_dict(torch.load(fn_model))
84 |                 self.model = self.model.cuda().eval()  # eval mode so BatchNorm uses running stats at test time
85 |
86 | if self.args.dataset == 'PCD':
87 | from dataset_pcd import PCD_full
88 | for dataset in ['TSUNAMI']:
89 | loader_test = PCD_full(os.path.join(self.args.datadir,dataset), self.di.no_start, self.di.no_end, self.di.width, self.di.height)
90 |                 for index in range(len(loader_test)):
91 | if i_set * (10 / self.di.num_cv) <= (index % 10) < (i_set + 1) * (10 / self.di.num_cv):
92 | self.index = index
93 |                         self.t0, self.t1, self.mask, self.w_orig, self.h_orig, self.w_resize, self.h_resize = loader_test[index]
94 | self.test()
95 | else:
96 | continue
97 | else:
98 | print('Error: Unexpected dataset')
99 | exit(-1)
100 |
101 |
102 | if __name__ == '__main__':
103 |
104 | parser = ArgumentParser(description='Start testing ...')
105 | parser.add_argument('--datadir', required=True)
106 | parser.add_argument('--checkpointdir', required=True)
107 | parser.add_argument('--use-corr', action='store_true', help='using correlation layer')
108 | parser.add_argument('--dataset', required=True)
109 |
110 | test = Test(parser.parse_args())
111 | test.run()
112 |
113 |
114 |
--------------------------------------------------------------------------------
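The fold selection in `run()` keeps image index `i` when `i % 10` lies in `[i_set * 2, (i_set + 1) * 2)` for `num_cv = 5`. A quick sketch of the resulting split (illustrative only):

```
# Which dataset indices (mod 10) each cross-validation set evaluates.
num_cv = 5
for i_set in range(num_cv):
    kept = [i for i in range(10)
            if i_set * (10 / num_cv) <= i < (i_set + 1) * (10 / num_cv)]
    print(i_set, kept)  # 0 [0, 1], 1 [2, 3], 2 [4, 5], 3 [6, 7], 4 [8, 9]
```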
/dataset_pcd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import torch
5 | from torch.utils.data import Dataset
6 |
7 | EXTENSIONS = ['.jpg', '.png']
8 |
9 | def check_img(filename):
10 | return any(filename.endswith(ext) for ext in EXTENSIONS)
11 |
12 | def get_img_path(root, basename, extension):
13 | return os.path.join(root, basename+extension)
14 |
15 | def get_img_basename(filename):
16 | return os.path.basename(os.path.splitext(filename)[0])
17 |
18 | class PCD(Dataset):
19 |
20 | def __init__(self, root):
21 | super(PCD, self).__init__()
22 | self.img_t0_root = os.path.join(root, 't0')
23 | self.img_t1_root = os.path.join(root, 't1')
24 | self.mask_root = os.path.join(root, 'mask')
25 |
26 | self.filenames = [get_img_basename(f) for f in os.listdir(self.mask_root) if check_img(f)]
27 | self.filenames.sort()
28 |
29 | print('{}:{}'.format(root,len(self.filenames)))
30 |
31 | def __getitem__(self, index):
32 | filename = self.filenames[index]
33 |
34 | fn_img_t0 = get_img_path(self.img_t0_root, filename, '.jpg')
35 | fn_img_t1 = get_img_path(self.img_t1_root, filename, '.jpg')
36 | fn_mask = get_img_path(self.mask_root, filename, '.png')
37 |
38 |         if not os.path.isfile(fn_img_t0):
39 |             print('Error: File Not Found: ' + fn_img_t0)
40 |             exit(-1)
41 |         if not os.path.isfile(fn_img_t1):
42 |             print('Error: File Not Found: ' + fn_img_t1)
43 |             exit(-1)
44 |         if not os.path.isfile(fn_mask):
45 |             print('Error: File Not Found: ' + fn_mask)
46 |             exit(-1)
47 |
48 | img_t0 = cv2.imread(fn_img_t0, cv2.IMREAD_COLOR)
49 | img_t1 = cv2.imread(fn_img_t1, cv2.IMREAD_COLOR)
50 | mask = cv2.imread(fn_mask, cv2.IMREAD_GRAYSCALE)
51 |         h,w,c = img_t0.shape  # cv2 images are (height, width, channels)
52 |         r = 286./min(w,h)
53 |         # resize so that min(w, h) == 286, then randomly crop a 256 x 256 patch below
54 |         img_t0 = cv2.resize(img_t0, (int(r*w), int(r*h)))
55 |         img_t1 = cv2.resize(img_t1, (int(r*w), int(r*h)))
56 |         mask = cv2.resize(mask, (int(r*w), int(r*h)))[:,:,np.newaxis]
57 |
58 | img_t0_ = np.asarray(img_t0).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
59 | img_t1_ = np.asarray(img_t1).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
60 |         # binarize the mask: white (> 128) -> 1, black -> 0
61 |         mask_ = np.asarray(mask>128).astype("int").transpose(2, 0, 1)
62 |
63 | crop_width = 256
64 | _,h,w = img_t0_.shape
65 | x_l = np.random.randint(0,w-crop_width)
66 | x_r = x_l+crop_width
67 | y_l = np.random.randint(0,h-crop_width)
68 | y_r = y_l+crop_width
69 |
70 | input_ = torch.from_numpy(np.concatenate((img_t0_[:,y_l:y_r,x_l:x_r], img_t1_[:,y_l:y_r,x_l:x_r]), axis=0))
71 | mask_ = torch.from_numpy(mask_[:, y_l:y_r, x_l:x_r]).long()
72 |
73 | return input_, mask_
74 |
75 | def __len__(self):
76 | return len(self.filenames)
77 |
78 | def get_random_index(self):
79 | index = np.random.randint(0, len(self.filenames))
80 | return index
81 |
82 |
83 |
84 | class PCD_full(Dataset):
85 |
86 | def __init__(self, root, id_s, id_e, width, height):
87 | super(PCD_full, self).__init__()
88 | self.img_t0_root = os.path.join(root, 't0')
89 | self.img_t1_root = os.path.join(root, 't1')
90 | self.mask_root = os.path.join(root, 'mask')
91 |
92 | self.filenames = [get_img_basename(f) for f in os.listdir(self.mask_root) if check_img(f)]
93 | self.filenames.sort()
94 | self.filenames = self.filenames[id_s:id_e]
95 |
96 | self.width = width
97 | self.height = height
98 |
99 | def __getitem__(self, index):
100 | filename = self.filenames[index]
101 |
102 | fn_img_t0 = get_img_path(self.img_t0_root, filename, '.jpg')
103 | fn_img_t1 = get_img_path(self.img_t1_root, filename, '.jpg')
104 | fn_mask = get_img_path(self.mask_root, filename, '.png')
105 |
106 |         if not os.path.isfile(fn_img_t0):
107 |             print('Error: File Not Found: ' + fn_img_t0)
108 |             exit(-1)
109 |         if not os.path.isfile(fn_img_t1):
110 |             print('Error: File Not Found: ' + fn_img_t1)
111 |             exit(-1)
112 |
113 |         if not os.path.isfile(fn_mask):
114 |             print('Error: File Not Found: ' + fn_mask)
115 |             exit(-1)
116 |
117 | img_t0 = cv2.imread(fn_img_t0, cv2.IMREAD_COLOR)
118 | img_t1 = cv2.imread(fn_img_t1, cv2.IMREAD_COLOR)
119 | mask = cv2.imread(fn_mask, cv2.IMREAD_GRAYSCALE)
120 | h,w,c = img_t0.shape
121 |         # enlarge any side smaller than 256 up to 256
122 |         # (equivalent to the original 256*max(side/256, 1); larger sides are kept)
123 |         w_r = max(w, 256)
124 |         h_r = max(h, 256)
125 | img_t0 = cv2.resize(img_t0, (w_r, h_r))
126 | img_t1 = cv2.resize(img_t1, (w_r, h_r))
127 | mask = cv2.resize(mask, (w_r, h_r))[:,:,np.newaxis]
128 |
129 | img_t0_ = np.asarray(img_t0).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
130 | img_t1_ = np.asarray(img_t1).astype("f").transpose(2, 0, 1) / 128.0 - 1.0
131 | mask_ = np.asarray(mask>128).astype("int").transpose(2, 0, 1)
132 | mask_ = torch.from_numpy(mask_).long()
133 |
134 | return img_t0_, img_t1_, mask_, w, h, w_r, h_r
135 |
136 | def __len__(self):
137 | return len(self.filenames)
138 |
139 |
140 |
--------------------------------------------------------------------------------
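A minimal sketch of how `train.py` consumes `PCD` (the path is a placeholder; the dataset must exist on disk):

```
# Hypothetical usage: PCD yields a 6-channel crop (t0 and t1 stacked) plus a 1xHxW mask.
from torch.utils.data import DataLoader
from dataset_pcd import PCD

loader = DataLoader(PCD('/path/to/pcd_5cv/set0/train'),
                    batch_size=4, shuffle=True, num_workers=2)
inputs, mask = next(iter(loader))
print(inputs.shape, mask.shape)  # torch.Size([4, 6, 256, 256]) torch.Size([4, 1, 256, 256])
```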
/train.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import print_function
3 | from argparse import ArgumentParser
4 | import cv2
5 | import csv
6 | import os.path
7 | import numpy as np
8 | import torch
9 | from torch.optim import Adam, lr_scheduler
10 | from torch.autograd import Variable
11 | from torch.utils.data import DataLoader
12 | from tensorboardX import SummaryWriter
13 | from criterion import CrossEntropyLoss2d
14 | from dataset_pcd import PCD
15 | import sys
16 | sys.path.append("./correlation_package/build/lib.linux-x86_64-3.6")
17 | import cscdnet
18 |
19 |
20 | def colormap():
21 |     cmap = np.zeros([2, 3]).astype(np.uint8)
22 |
23 | cmap[0,:] = np.array([0, 0, 0])
24 | cmap[1,:] = np.array([255, 255, 255])
25 |
26 | return cmap
27 |
28 |
29 | class Colorization:
30 |
31 | def __init__(self, n=2):
32 | self.cmap = colormap()
33 | self.cmap = torch.from_numpy(np.array(self.cmap[:n]))
34 |
35 | def __call__(self, gray_image):
36 | size = gray_image.size()
37 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
38 |
39 | for label in range(0, len(self.cmap)):
40 | mask = gray_image[0] == label
41 |
42 | color_image[0][mask] = self.cmap[label][0]
43 | color_image[1][mask] = self.cmap[label][1]
44 | color_image[2][mask] = self.cmap[label][2]
45 |
46 | return color_image
47 |
48 |
49 | class Training:
50 | def __init__(self, arguments):
51 | self.args = arguments
52 | self.icount = 0
53 | if self.args.use_corr:
54 | self.dn_save = os.path.join(self.args.checkpointdir,'cscdnet','checkpointdir','set{}'.format(self.args.cvset))
55 | else:
56 | self.dn_save = os.path.join(self.args.checkpointdir,'cdnet','checkpointdir','set{}'.format(self.args.cvset))
57 |
58 | def train(self):
59 |
60 | self.color_transform = Colorization(2)
61 |
62 | # Dataset loader for train and test
63 | dataset_train = DataLoader(
64 | PCD(os.path.join(self.args.datadir, 'set{}'.format(self.args.cvset), 'train')),
65 | num_workers=self.args.num_workers, batch_size=self.args.batch_size, shuffle=True)
66 | self.dataset_test = PCD(os.path.join(self.args.datadir, 'set{}'.format(self.args.cvset), 'test'))
67 |
68 | self.test_path = os.path.join(self.dn_save, 'test')
69 | if not os.path.exists(self.test_path):
70 | os.makedirs(self.test_path)
71 |
72 | # Set loss function, optimizer and learning rate
73 | weight = torch.ones(2)
74 | criterion = CrossEntropyLoss2d(weight.cuda())
75 | optimizer = Adam(self.model.parameters(), lr=0.0001, betas=(0.5, 0.999))
76 |         lambda1 = lambda icount: float(self.args.max_iteration - icount) / float(self.args.max_iteration)
77 | model_lr_scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
78 |
79 | fn_loss = os.path.join(self.dn_save,'loss.csv')
80 | f_loss = open(fn_loss, 'w')
81 | writer = csv.writer(f_loss)
82 |
83 |         self.writers = SummaryWriter(os.path.join(self.dn_save, 'log'))
84 |
85 | # Training loop
86 | icount_loss = []
87 | while self.icount < self.args.max_iteration:
88 | model_lr_scheduler.step()
89 | for step, (inputs_train, mask_train) in enumerate(dataset_train):
90 | inputs_train = inputs_train.cuda()
91 | mask_train = mask_train.cuda()
92 |
93 | inputs_train = Variable(inputs_train)
94 | mask_train = Variable(mask_train)
95 | outputs_train = self.model(inputs_train)
96 |
97 | optimizer.zero_grad()
98 | self.loss = criterion(outputs_train, mask_train[:, 0])
99 |
100 | self.loss.backward()
101 | optimizer.step()
102 |
103 | self.icount += 1
104 | icount_loss.append(self.loss.item())
105 | writer.writerow([self.icount, self.loss.item()])
106 | if self.args.icount_plot > 0 and self.icount % self.args.icount_plot == 0:
107 | self.test()
108 | average = sum(icount_loss) / len(icount_loss)
109 | print('loss: {0} (icount: {1})'.format(average, self.icount))
110 | icount_loss.clear()
111 |
112 | if self.args.icount_save > 0 and self.icount % self.args.icount_save == 0:
113 | self.checkpoint()
114 |
115 | f_loss.close()
116 |
117 | def test(self):
118 |
119 | index_test = self.dataset_test.get_random_index()
120 | inputs_test, mask_gt_test = self.dataset_test[index_test]
121 | inputs_test = inputs_test[np.newaxis, :, :]
122 | inputs_test = inputs_test.cuda()
123 | inputs_test = Variable(inputs_test)
124 | outputs_test = self.model(inputs_test)
125 |
126 | inputs = inputs_test[0].cpu().data
127 | t0_test = inputs[0:3, :, :]
128 | t1_test = inputs[3:6, :, :]
129 | t0_test = (t0_test + 1.0) * 128
130 | t1_test = (t1_test + 1.0) * 128
131 | mask_gt = mask_gt_test.numpy().astype(np.uint8) * 255
132 |
133 | outputs = outputs_test[0][np.newaxis, :, :, :]
134 | outputs = outputs[:, 0:2, :, :]
135 | mask_pred = np.transpose(self.color_transform(outputs[0].cpu().max(0)[1][np.newaxis, :, :].data).numpy(), (1, 2, 0)).astype(np.uint8)
136 |
137 | img_out = self.display_results(t0_test, t1_test, mask_pred, mask_gt)
138 | self.log_tbx(torch.from_numpy(np.transpose(np.flip(img_out,axis=2).copy(), (2, 0, 1))))
139 |
140 | def display_results(self, t0, t1, mask_pred, mask_gt):
141 |
142 | rows = cols = 256
143 | img_out = np.zeros((rows * 2, cols * 2, 3), dtype=np.uint8)
144 | img_out[0:rows, 0:cols, :] = np.transpose(t0.numpy(), (1, 2, 0)).astype(np.uint8)
145 | img_out[0:rows, cols:cols * 2, :] = np.transpose(t1.numpy(), (1, 2, 0)).astype(np.uint8)
146 | img_out[rows:rows * 2, 0:cols, :] = cv2.cvtColor(np.transpose(mask_gt, (1, 2, 0)), cv2.COLOR_GRAY2RGB)
147 | img_out[rows:rows * 2, cols:cols * 2, :] = mask_pred
148 |
149 | return img_out
150 |
151 | # Output results for tensorboard
152 | def log_tbx(self, image):
153 |
154 | writer = self.writers
155 | writer.add_scalar('data/loss', self.loss.item(), self.icount)
156 | writer.add_image('change detection', image, self.icount)
157 |
158 | def checkpoint(self):
159 | if self.args.use_corr:
160 | filename = 'cscdnet-{0:08d}.pth'.format(self.icount)
161 | else:
162 | filename = 'cdnet-{0:08d}.pth'.format(self.icount)
163 | torch.save(self.model.state_dict(), os.path.join(self.dn_save, filename))
164 | print('save: {0} (iteration: {1})'.format(filename, self.icount))
165 |
166 | def run(self):
167 |
168 | if self.args.use_corr:
169 | print('Correlated Siamese Change Detection Network (CSCDNet)')
170 | self.model = cscdnet.Model(inc=6, outc=2, corr=True, pretrained=True)
171 | else:
172 | print('Siamese Change Detection Network (Siamese CDResNet)')
173 | self.model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=True)
174 |
175 | self.model = self.model.cuda()
176 | self.train()
177 |
178 |
179 | if __name__ == '__main__':
180 |
181 | parser = ArgumentParser(description='Start training ...')
182 | parser.add_argument('--checkpointdir', required=True)
183 | parser.add_argument('--datadir', required=True)
184 | parser.add_argument('--use-corr', action='store_true', help='using correlation layer')
185 | parser.add_argument('--max-iteration', type=int, default=50000)
186 | parser.add_argument('--num-workers', type=int, default=4)
187 | parser.add_argument('--batch-size', type=int, default=32)
188 | parser.add_argument('--cvset', type=int, default=0)
189 | parser.add_argument('--icount-plot', type=int, default=0)
190 | parser.add_argument('--icount-save', type=int, default=10)
191 |
192 | training = Training(parser.parse_args())
193 | training.run()
194 |
--------------------------------------------------------------------------------
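The `LambdaLR` multiplier decays linearly from 1 to 0 over `max_iteration` scheduler steps, i.e. `lr(k) = 0.0001 * (max_iteration - k) / max_iteration`; note that `train.py` steps the scheduler once per pass over the loader, so `k` counts epochs there, not iterations. A quick check of the formula:

```
# Linear decay of the learning rate schedule used in train.py (base lr 0.0001).
max_iteration = 50000
for k in [0, 12500, 25000, 50000]:
    print(k, 0.0001 * (max_iteration - k) / max_iteration)  # 0.0001, 7.5e-05, 5e-05, 0.0
```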
/cscdnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import init
4 | from collections import OrderedDict
5 | from correlation_package.correlation import Correlation
6 |
7 | class Model(nn.Module):
8 | def __init__(self, inc, outc, corr=True, pretrained=True):
9 | super(Model, self).__init__()
10 |
11 | self.corr = corr
12 |
13 | # encoder1
14 | self.enc1_conv1 = nn.Conv2d(int(inc/2), 64, 7, padding=3, stride=2, bias=False)
15 | self.enc1_bn1 = nn.BatchNorm2d(64)
16 | self.enc1_pool1 = nn.MaxPool2d(3, stride=2, padding=1)
17 | self.enc1_res1_1 = ResBL( 64, 64, 64, stride=1)
18 | self.enc1_res1_2 = ResBL( 64, 64, 64, stride=1)
19 | self.enc1_res2_1 = ResBL( 64, 128, 128, stride=2)
20 | self.enc1_res2_2 = ResBL(128, 128, 128, stride=1)
21 | self.enc1_res3_1 = ResBL(128, 256, 256, stride=2)
22 | self.enc1_res3_2 = ResBL(256, 256, 256, stride=1)
23 | self.enc1_res4_1 = ResBL(256, 512, 512, stride=2)
24 | self.enc1_res4_2 = ResBL(512, 512, 512, stride=1)
25 | self.enc1_conv5 = nn.Conv2d( 512, 1024, 3, padding=1, stride=2)
26 | self.enc1_bn5 = nn.BatchNorm2d(1024)
27 | self.enc1_conv6 = nn.Conv2d(1024, 1024, 3, padding=1, stride=1)
28 | self.enc1_bn6 = nn.BatchNorm2d(1024)
29 |
30 | # encoder2
31 | self.enc2_conv1 = nn.Conv2d(int(inc/2), 64, 7, padding=3, stride=2, bias=False)
32 | self.enc2_bn1 = nn.BatchNorm2d(64)
33 | self.enc2_pool1 = nn.MaxPool2d(3, stride=2, padding=1)
34 | self.enc2_res1_1 = ResBL( 64, 64, 64, stride=1)
35 | self.enc2_res1_2 = ResBL( 64, 64, 64, stride=1)
36 | self.enc2_res2_1 = ResBL( 64, 128, 128, stride=2)
37 | self.enc2_res2_2 = ResBL(128, 128, 128, stride=1)
38 | self.enc2_res3_1 = ResBL(128, 256, 256, stride=2)
39 | self.enc2_res3_2 = ResBL(256, 256, 256, stride=1)
40 | self.enc2_res4_1 = ResBL(256, 512, 512, stride=2)
41 | self.enc2_res4_2 = ResBL(512, 512, 512, stride=1)
42 | self.enc2_conv5 = nn.Conv2d( 512, 1024, 3, padding=1, stride=2)
43 | self.enc2_bn5 = nn.BatchNorm2d(1024)
44 | self.enc2_conv6 = nn.Conv2d(1024, 1024, 3, padding=1, stride=1)
45 | self.enc2_bn6 = nn.BatchNorm2d(1024)
46 |
47 | # decoder
48 | self.dec_conv6 = nn.Conv2d(2048, 1024, 3, padding=1, stride=1)
49 | self.dec_bn6 = nn.BatchNorm2d(1024)
50 | self.dec_conv5 = nn.Conv2d(1024, 512, 3, padding=1, stride=1)
51 | self.dec_bn5 = nn.BatchNorm2d(512)
52 | self.dec_res4_2 = ResBL( 512, 512, 512, upscale=1, skip2=1024)
53 | self.dec_res4_1 = ResBL( 512, 512, 256, upscale=2)
54 | self.dec_res3_2 = ResBL( 256, 256, 256, upscale=1, skip2=512)
55 | self.dec_res3_1 = ResBL( 256, 256, 128, upscale=2)
56 | if self.corr is True:
57 | self.dec_corr2 = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
58 | self.dec_res2_2 = ResBL( 128, 128, 128, upscale=1, skip1=256+21*21)
59 | else:
60 | self.dec_res2_2 = ResBL(128, 128, 128, upscale=1, skip1=256)
61 | self.dec_res2_1 = ResBL( 128, 128, 64, upscale=2)
62 | self.dec_res1_2 = ResBL( 64, 64, 64, upscale=1, skip2=128)
63 | self.dec_res1_1 = ResBL( 64, 64, 64, upscale=1)
64 | if self.corr is True:
65 | self.dec_corr1 = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
66 | self.dec_conv1 = nn.Conv2d(192+21*21, 64, 7, padding=3, stride=1, bias=False)
67 | else:
68 | self.dec_conv1 = nn.Conv2d(192, 64, 7, padding=3, stride=1, bias=False)
69 | self.dec_bn1 = nn.BatchNorm2d(64)
70 |
71 | # classifier
72 | self.classifier = nn.Conv2d(64, outc, 1, padding=0, stride=1)
73 |
74 | # util
75 | self.unpool = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
76 | self.relu = nn.ReLU(inplace=True)
77 | if self.corr is True:
78 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True)
79 |
80 | # initialization
81 | self.init_weights()
82 | if pretrained is True:
83 | self.load_net_param()
84 |
85 | def forward(self, x):
86 | x1, x2 = torch.split(x,3,1)
87 |
88 | # encoder1
89 | enc1_f1 = self.enc1_conv1(x1)
90 | enc1_f1 = self.enc1_bn1(enc1_f1)
91 | enc1_f1 = self.relu(enc1_f1)
92 | enc1_f2 = self.enc1_pool1(enc1_f1)
93 | enc1_f2 = self.enc1_res1_1(enc1_f2)
94 | enc1_f2 = self.enc1_res1_2(enc1_f2)
95 | enc1_f3 = self.enc1_res2_1(enc1_f2)
96 | enc1_f3 = self.enc1_res2_2(enc1_f3)
97 | enc1_f4 = self.enc1_res3_1(enc1_f3)
98 | enc1_f4 = self.enc1_res3_2(enc1_f4)
99 | enc1_f5 = self.enc1_res4_1(enc1_f4)
100 | enc1_f5 = self.enc1_res4_2(enc1_f5)
101 | enc1_f6 = self.enc1_conv5(enc1_f5)
102 | enc1_f6 = self.enc1_bn5(enc1_f6)
103 | enc1_f6 = self.relu(enc1_f6)
104 | enc1_f6 = self.enc1_conv6(enc1_f6)
105 | enc1_f6 = self.enc1_bn6(enc1_f6)
106 | enc1_f6 = self.relu(enc1_f6)
107 |
108 | # encoder2
109 | enc2_f1 = self.enc2_conv1(x2)
110 | enc2_f1 = self.enc2_bn1(enc2_f1)
111 | enc2_f1 = self.relu(enc2_f1)
112 | enc2_f2 = self.enc2_pool1(enc2_f1)
113 | enc2_f2 = self.enc2_res1_1(enc2_f2)
114 | enc2_f2 = self.enc2_res1_2(enc2_f2)
115 | enc2_f3 = self.enc2_res2_1(enc2_f2)
116 | enc2_f3 = self.enc2_res2_2(enc2_f3)
117 | enc2_f4 = self.enc2_res3_1(enc2_f3)
118 | enc2_f4 = self.enc2_res3_2(enc2_f4)
119 | enc2_f5 = self.enc2_res4_1(enc2_f4)
120 | enc2_f5 = self.enc2_res4_2(enc2_f5)
121 | enc2_f6 = self.enc2_conv5(enc2_f5)
122 | enc2_f6 = self.enc2_bn5(enc2_f6)
123 | enc2_f6 = self.relu(enc2_f6)
124 | enc2_f6 = self.enc2_conv6(enc2_f6)
125 | enc2_f6 = self.enc2_bn6(enc2_f6)
126 | enc2_f6 = self.relu(enc2_f6)
127 |
128 | # decoder
129 | enc_f6 = torch.cat([enc1_f6, enc2_f6], 1)
130 | dec = self.dec_conv6(enc_f6)
131 | dec = self.dec_bn6(dec)
132 | dec = self.relu(dec)
133 | dec = self.dec_conv5(dec)
134 | dec = self.unpool(dec)
135 | dec = self.dec_bn5(dec)
136 | dec = self.relu(dec)
137 | skp = torch.cat([enc1_f5, enc2_f5], 1)
138 | dec = self.dec_res4_2(dec, skip2=skp)
139 | dec = self.dec_res4_1(dec)
140 | skp = torch.cat([enc1_f4, enc2_f4], 1)
141 | dec = self.dec_res3_2(dec, skip2=skp)
142 | dec = self.dec_res3_1(dec)
143 | if self.corr is True:
144 | cor = self.dec_corr2(enc1_f3, enc2_f3)
145 | cor = self.corr_activation(cor)
146 | skp = torch.cat([enc1_f3, enc2_f3, cor], 1)
147 | else:
148 | skp = torch.cat([enc1_f3, enc2_f3], 1)
149 | dec = self.dec_res2_2(dec, skip1=skp)
150 | dec = self.dec_res2_1(dec)
151 | skp = torch.cat([enc1_f2, enc2_f2], 1)
152 | dec = self.dec_res1_2(dec, skip2=skp)
153 | dec = self.dec_res1_1(dec)
154 | dec = self.unpool(dec)
155 | if self.corr is True:
156 | cor = self.dec_corr1(enc1_f1, enc2_f1)
157 | cor = self.corr_activation(cor)
158 | dec = torch.cat([dec, enc1_f1, enc2_f1, cor], 1)
159 | else:
160 | dec = torch.cat([dec, enc1_f1, enc2_f1], 1)
161 | dec = self.dec_conv1(dec)
162 | dec = self.unpool(dec)
163 | dec = self.dec_bn1(dec)
164 | dec = self.relu(dec)
165 |
166 | out = self.classifier(dec)
167 | return out
168 |
169 | def init_weights(self):
170 | init.xavier_uniform_relu(self.modules())
171 |
172 | def load_net_param(self):
173 | from torchvision.models import resnet18
174 | resnet = resnet18(pretrained=True)
175 |
176 | self.enc1_conv1.load_state_dict(resnet.conv1.state_dict())
177 | self.enc1_bn1.load_state_dict(resnet.bn1.state_dict())
178 | self.enc1_res1_1.load_state_dict(list(resnet.layer1.children())[0].state_dict())
179 | self.enc1_res1_2.load_state_dict(list(resnet.layer1.children())[1].state_dict())
180 | self.enc1_res2_1.load_state_dict(list(resnet.layer2.children())[0].state_dict())
181 | self.enc1_res2_2.load_state_dict(list(resnet.layer2.children())[1].state_dict())
182 | self.enc1_res3_1.load_state_dict(list(resnet.layer3.children())[0].state_dict())
183 | self.enc1_res3_2.load_state_dict(list(resnet.layer3.children())[1].state_dict())
184 | self.enc1_res4_1.load_state_dict(list(resnet.layer4.children())[0].state_dict())
185 | self.enc1_res4_2.load_state_dict(list(resnet.layer4.children())[1].state_dict())
186 |
187 | self.enc2_conv1.load_state_dict(resnet.conv1.state_dict())
188 | self.enc2_bn1.load_state_dict(resnet.bn1.state_dict())
189 | self.enc2_res1_1.load_state_dict(list(resnet.layer1.children())[0].state_dict())
190 | self.enc2_res1_2.load_state_dict(list(resnet.layer1.children())[1].state_dict())
191 | self.enc2_res2_1.load_state_dict(list(resnet.layer2.children())[0].state_dict())
192 | self.enc2_res2_2.load_state_dict(list(resnet.layer2.children())[1].state_dict())
193 | self.enc2_res3_1.load_state_dict(list(resnet.layer3.children())[0].state_dict())
194 | self.enc2_res3_2.load_state_dict(list(resnet.layer3.children())[1].state_dict())
195 | self.enc2_res4_1.load_state_dict(list(resnet.layer4.children())[0].state_dict())
196 | self.enc2_res4_2.load_state_dict(list(resnet.layer4.children())[1].state_dict())
197 |
198 |
199 | class ResBL(nn.Module):
200 | def __init__(self, inc, midc, outc, stride=1, upscale=1, skip1=0, skip2=0):
201 | super(ResBL, self).__init__()
202 |
203 | self.conv1 = nn.Conv2d(inc+skip1, midc, 3, padding=1, stride=stride, bias=False)
204 | self.bn1 = nn.BatchNorm2d(midc)
205 | self.relu = nn.ReLU(inplace=True)
206 | self.conv2 = nn.Conv2d(midc+skip2, outc, 3, padding=1, bias=False)
207 | self.bn2 = nn.BatchNorm2d(outc)
208 |
209 | self.upscale = None
210 | if upscale > 1:
211 | self.upscale = nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True)
212 |
213 | self.downsample = None
214 |
215 | if inc != outc or stride > 1 or upscale > 1:
216 | if upscale > 1:
217 | self.downsample = nn.Sequential(
218 | nn.Conv2d(inc, outc, 1, padding=0, stride=stride, bias=False),
219 | nn.Upsample(scale_factor=upscale, mode='bilinear', align_corners=True),
220 | nn.BatchNorm2d(outc),
221 | )
222 | else:
223 | self.downsample = nn.Sequential(
224 | nn.Conv2d(inc, outc, 1, padding=0, stride=stride, bias=False),
225 | nn.BatchNorm2d(outc),
226 | )
227 |
228 | def forward(self, x, skip1=None, skip2=None):
229 | if skip1 is not None:
230 | res = torch.cat([x, skip1], 1)
231 | else:
232 | res = x
233 |
234 | res = self.conv1(res)
235 | res = self.bn1(res)
236 | res = self.relu(res)
237 |
238 | if skip2 is not None:
239 | res = torch.cat([res, skip2], 1)
240 |
241 | res = self.conv2(res)
242 | if self.upscale is not None:
243 | res = self.upscale(res)
244 | res = self.bn2(res)
245 |
246 | identity = x
247 | if self.downsample is not None:
248 | identity = self.downsample(x)
249 |
250 | res += identity
251 | out = self.relu(res)
252 |
253 | return out
254 |
255 |
256 |
257 |
258 |
--------------------------------------------------------------------------------
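A smoke-test sketch for the model (assumes the correlation package is built, since cscdnet.py imports it at module level; `pretrained=False` skips the torchvision download):

```
# Dummy forward pass: two RGB frames stacked channel-wise -> 2-class change map.
import torch
import cscdnet

model = cscdnet.Model(inc=6, outc=2, corr=False, pretrained=False).eval()
x = torch.randn(1, 6, 256, 256)  # t0 and t1 concatenated along channels
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 2, 256, 256])
```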