├── src
├── deeplab
│ ├── __init__.py
│ └── deeplab_xception.py
├── LICENSE_deeplab
├── LICENSE_adamw
├── utils.py
├── visualizer.py
├── model.py
├── evaluator.py
├── adamw.py
├── cosine_scheduler.py
├── test.py
├── test_merge.py
├── train.py
├── finetune_real.py
├── finetune.py
└── dataset.py
├── imgs
└── teaser.png
├── bandsel
├── readme.txt
├── bands
│ ├── rs.txt
│ ├── nncv.txt
│ └── mvpca.txt
└── bandsel_nncv.py
├── prepare
├── create_real_hdf5.sh
├── create_hdf5_bg.py
├── create_hdf5_testext.py
├── create_hdf5_val.py
├── create_hdf5_test.py
├── create_hdf5_trainext.py
├── create_hdf5_bgext.py
└── calibrate_kappa.py
├── LICENSE
├── README.md
└── recog
├── recognition.py
├── recognition_testext.py
└── recognition_trainext.py
/src/deeplab/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplab_xception import DeepLabv3_plus
2 |
--------------------------------------------------------------------------------
/imgs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tiancheng-zhi/ms-powder/HEAD/imgs/teaser.png
--------------------------------------------------------------------------------
/bandsel/readme.txt:
--------------------------------------------------------------------------------
1 | Band Selection
2 | run:
3 | python bandsel_nncv.py
4 |
5 | Selected SWIR bands are in folder "bands". Band index ranges from 0 to 960.
6 |
--------------------------------------------------------------------------------
/bandsel/bands/rs.txt:
--------------------------------------------------------------------------------
1 | 901,418,0,584,91,660,534,148,438,472,168,298,3,255,461,73,646,515,26,751,117,163,320,82,596,70,656,143,57,356,113,367,557,542,110,52,61,346,471,464,216,930,303,65,413,36,430,342,229
2 |
--------------------------------------------------------------------------------
/bandsel/bands/nncv.txt:
--------------------------------------------------------------------------------
1 | 746,397,875,73,679,430,395,672,47,562,365,676,45,709,394,506,235,363,620,406,819,41,72,103,275,75,219,344,71,787,699,341,329,705,372,559,442,431,656,536,163,204,18,744,48,428,238,105,498
2 |
--------------------------------------------------------------------------------
/bandsel/bands/mvpca.txt:
--------------------------------------------------------------------------------
1 | 833,682,148,123,483,633,338,551,1,391,166,226,504,684,143,342,583,556,907,594,182,154,876,599,85,45,490,285,535,531,249,615,174,812,560,115,2,40,444,487,254,539,626,111,133,262,510,899,253,463,332,546,586
2 |
--------------------------------------------------------------------------------
/prepare/create_real_hdf5.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ../real/bg
2 | mkdir -p ../real/bgext
3 | python create_hdf5_bg.py
4 | python create_hdf5_bgext.py
5 | python create_hdf5_trainext.py
6 | python create_hdf5_test.py
7 | python create_hdf5_testext.py
8 | python create_hdf5_val.py
9 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_bg.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # Pack the background scenes of every light source into one HDF5 file
    # per light, each holding a single float32 dataset named 'im'.
    data_path = Path('../data/train/')
    real_path = Path('../real/bg/')
    n_scenes = 16
    height = 160
    width = 280
    n_channels = 965
    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']

    # Create the output folder here as well so the script also works when it
    # is launched directly instead of through create_real_hdf5.sh.
    real_path.mkdir(parents=True, exist_ok=True)

    for light in lights:
        # The context manager closes the file even when a scene fails to
        # load, so a partial run never leaves a dangling HDF5 handle.
        with h5py.File(str(real_path / (light + '.hdf5')), 'w') as h5f:
            dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32')
            for i in range(n_scenes):
                idx = str(i).zfill(2)
                im_npz = np.load(data_path / light / 'bgscene' / (idx + '_bgscene.npz'))
                # Channel layout of the merged image: 'rgbn' first, then 'swir'.
                im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
                dset_im[i, :, :, :] = im
25 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_testext.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # Pack the 'testext' scenes and their label maps into ../real/testext.hdf5
    # (datasets: 'im' float32, 'label' uint8).
    data_path = Path('../data/testext/')
    real_path = Path('../real/')
    n_scenes = 64
    height = 160
    width = 280
    n_channels = 38

    # Make sure the output folder exists when the script is run on its own.
    real_path.mkdir(parents=True, exist_ok=True)

    # The context manager closes the file even if a scene fails to load.
    with h5py.File(str(real_path / 'testext.hdf5'), 'w') as h5f:
        dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32')
        dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8')
        for i in range(n_scenes):
            idx = str(i).zfill(2)
            im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz'))
            # Channel layout of the merged image: 'rgbn' first, then 'swir'.
            im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
            label_file = data_path / 'label' / (idx + '_label.png')
            label = cv2.imread(str(label_file), cv2.IMREAD_GRAYSCALE)
            if label is None:
                # cv2.imread reports an unreadable file by returning None
                # instead of raising; fail with a message naming the file.
                raise FileNotFoundError('cannot read label image: {}'.format(label_file))
            dset_im[i, :, :, :] = im
            dset_label[i, :, :] = label
25 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_val.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # Pack the validation scenes and their label maps into ../real/val.hdf5
    # (datasets: 'im' float32, 'label' uint8).
    data_path = Path('../data/val/')
    real_path = Path('../real/')
    n_scenes = 32
    height = 160
    width = 280
    n_channels = 965

    # Make sure the output folder exists when the script is run on its own.
    real_path.mkdir(parents=True, exist_ok=True)

    # The context manager closes the file even if a scene fails to load.
    with h5py.File(str(real_path / 'val.hdf5'), 'w') as h5f:
        dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32')
        dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8')
        for i in range(n_scenes):
            idx = str(i).zfill(2)
            im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz'))
            # Channel layout of the merged image: 'rgbn' first, then 'swir'.
            im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
            label_file = data_path / 'label' / (idx + '_label.png')
            label = cv2.imread(str(label_file), cv2.IMREAD_GRAYSCALE)
            if label is None:
                # cv2.imread reports an unreadable file by returning None
                # instead of raising; fail with a message naming the file.
                raise FileNotFoundError('cannot read label image: {}'.format(label_file))
            dset_im[i, :, :, :] = im
            dset_label[i, :, :] = label
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_test.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # Pack the test scenes and their label maps into ../real/test.hdf5
    # (datasets: 'im' float32, 'label' uint8).
    data_path = Path('../data/test/')
    real_path = Path('../real/')
    n_scenes = 32
    height = 160
    width = 280
    n_channels = 965

    # Make sure the output folder exists when the script is run on its own.
    real_path.mkdir(parents=True, exist_ok=True)

    # The context manager closes the file even if a scene fails to load.
    with h5py.File(str(real_path / 'test.hdf5'), 'w') as h5f:
        dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32')
        dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8')
        for i in range(n_scenes):
            idx = str(i).zfill(2)
            im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz'))
            # Channel layout of the merged image: 'rgbn' first, then 'swir'.
            im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
            label_file = data_path / 'label' / (idx + '_label.png')
            label = cv2.imread(str(label_file), cv2.IMREAD_GRAYSCALE)
            if label is None:
                # cv2.imread reports an unreadable file by returning None
                # instead of raising; fail with a message naming the file.
                raise FileNotFoundError('cannot read label image: {}'.format(label_file))
            dset_im[i, :, :, :] = im
            dset_label[i, :, :] = label
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Tiancheng Zhi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/LICENSE_deeplab:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Pyjcsx
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/LICENSE_adamw:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Maksym Pyrozhok
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_trainext.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # Pack the 'trainext' scenes of all light sources, with their label maps,
    # into ../real/trainext.hdf5 (datasets: 'im' float32, 'label' uint8).
    data_path = Path('../data/trainext/')
    real_path = Path('../real/')
    n_scenes = 64
    height = 160
    width = 280
    n_channels = 38
    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    # The n_scenes output slots are split evenly over the lights: slot i maps
    # to light i // scenes_per_light and to scene i % scenes_per_light.
    scenes_per_light = n_scenes // len(lights)

    # Make sure the output folder exists when the script is run on its own.
    real_path.mkdir(parents=True, exist_ok=True)

    # The context manager closes the file even if a scene fails to load.
    with h5py.File(str(real_path / 'trainext.hdf5'), 'w') as h5f:
        dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32')
        dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8')

        for i in range(n_scenes):
            idx = str(i % scenes_per_light).zfill(2)
            light = lights[i // scenes_per_light]
            im_npz = np.load(data_path / light / 'scene' / (idx + '_scene.npz'))
            # Channel layout of the merged image: 'rgbn' first, then 'swir'.
            im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
            label_file = data_path / light / 'label' / (idx + '_label.png')
            label = cv2.imread(str(label_file), cv2.IMREAD_GRAYSCALE)
            if label is None:
                # cv2.imread reports an unreadable file by returning None
                # instead of raising; fail with a message naming the file.
                raise FileNotFoundError('cannot read label image: {}'.format(label_file))
            dset_im[i, :, :, :] = im
            dset_label[i, :, :] = label
28 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
def cpu_np(tensor):
    """Return the content of a torch tensor as a NumPy array on the host."""
    host_tensor = tensor.cpu()
    return host_tensor.numpy()
7 |
8 |
def to_image(matrix):
    """Convert a tensor with values in [0, 1] into a uint8 image array.

    Values are clamped to [0, 1] and scaled by 255 before the cast. When the
    first dimension has size 1, the array is repeated three times along that
    dimension; otherwise it is returned as is.
    """
    scaled = torch.clamp(matrix, 0, 1) * 255
    image = cpu_np(scaled).astype(np.uint8)
    single_plane = matrix.size()[0] == 1
    if single_plane:
        image = np.concatenate([image] * 3, 0)
    return image
15 |
16 |
def errormap(green, yellow, red, blue):
    """Paint four masks into a 3 x H x W uint8 image (planes 0/1/2 = R/G/B).

    The masks are applied in the order blue, red, yellow, green, so where
    masks overlap the later one overrides the planes it writes.
    """
    rows, cols = green.shape[0], green.shape[1]
    err = np.zeros((3, rows, cols), dtype=np.uint8)
    # Basic slicing returns views, so writing to a plane writes into `err`.
    r_plane = err[0, :, :]
    g_plane = err[1, :, :]
    b_plane = err[2, :, :]

    b_plane[blue] = 255

    r_plane[red] = 255

    r_plane[yellow] = 255
    g_plane[yellow] = 255
    b_plane[yellow] = 0

    r_plane[green] = 0
    g_plane[green] = 255
    return err
27 |
28 |
def colormap(label):
    """Turn an integer label map into a 3 x H x W uint8 colour image.

    Labels 0..99 get colours from a fixed 5 x 5 x 4 grid of (r, g, b) values
    and label 100 becomes black. Any other label value keeps the uint8 cast
    of the label itself in all three planes.
    """
    rg_levels = [35, 90, 145, 200, 255]
    b_levels = [60, 125, 190, 255]
    palette = [(r, g, b) for r in rg_levels for g in rg_levels for b in b_levels]
    palette.append((0, 0, 0))

    label_cm = np.stack([label] * 3, 0).astype(np.uint8)
    for class_id, (r, g, b) in enumerate(palette):
        # The mask comes from the untouched input, so colours that were
        # already written can never be mistaken for label values.
        hit = (label == class_id)
        label_cm[0, :, :][hit] = r
        label_cm[1, :, :][hit] = g
        label_cm[2, :, :][hit] = b
    return label_cm
43 |
--------------------------------------------------------------------------------
/src/visualizer.py:
--------------------------------------------------------------------------------
1 | import visdom
2 | import numpy as np
3 |
class Visualizer():
    """Keeps training curves in plain lists and draws them through visdom."""

    def __init__(self, server='http://localhost', port=8097, env='main'):
        self.vis = visdom.Visdom(server=server, port=port, env=env, use_incoming_socket=False)
        # Curve data: (iteration, nlogloss) feeds plot_loss and
        # (epoch, acc) feeds plot_acc.
        self.iteration, self.nlogloss = [], []
        self.epoch, self.acc = [], []

    def state_dict(self):
        """Return the four curve lists in a dict (references, not copies)."""
        return dict(iteration=self.iteration,
                    nlogloss=self.nlogloss,
                    epoch=self.epoch,
                    acc=self.acc)

    def load_state_dict(self, state_dict):
        """Restore the curve lists from a dict produced by state_dict()."""
        for key in ('iteration', 'nlogloss', 'epoch', 'acc'):
            setattr(self, key, state_dict[key])

    def plot_loss(self):
        """Draw the loss curve in visdom window 0."""
        # NOTE(review): X holds self.iteration while the axis label says
        # 'epoch' -- confirm which one is intended.
        opts = {
            'title': '-LogLoss',
            'legend': ['-LogLoss'],
            'xlabel': 'epoch',
            'ylabel': '-logloss'}
        xs = np.array(self.iteration)
        ys = np.array(self.nlogloss)
        self.vis.line(X=xs, Y=ys, opts=opts, win=0)

    def plot_acc(self):
        """Draw the performance curves in visdom window 1."""
        opts = {
            'title': 'Performance',
            'legend': ['mIoUval', 'mIoUtest'],
            'xlabel': 'epoch',
            'ylabel': 'performance'}
        xs = np.array(self.epoch)
        ys = np.array(self.acc)
        self.vis.line(X=xs, Y=ys, opts=opts, win=1)

    def plot_image(self, im, idx):
        """Show an image in window idx + 2 (windows 0 and 1 hold the curves)."""
        window = idx + 2
        self.vis.image(im, win=window)
47 |
--------------------------------------------------------------------------------
/prepare/create_hdf5_bgext.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import cv2
4 |
5 | from pathlib import Path
6 |
if __name__ == '__main__':
    # For every light source, write one HDF5 file whose 'im' dataset holds
    # the n_scenes backgrounds from ../data/train followed by the n_scenes
    # backgrounds from ../data/trainext.
    data_path_bg = Path('../data/train/')
    data_path_ext = Path('../data/trainext/')
    real_path = Path('../real/bgext/')
    real_path.mkdir(exist_ok=True, parents=True)
    n_scenes = 16
    height = 160
    width = 280
    n_channels = 38
    # Channel indices (38 of them) picked out of the merged images loaded
    # from ../data/train.
    sel = [0, 1, 2, 3, 4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    for light in lights:
        # The context manager closes the file even if a scene fails to load.
        with h5py.File(str(real_path / (light + '.hdf5')), 'w') as h5f:
            dset_im = h5f.create_dataset('im', (n_scenes * 2, height, width, n_channels), dtype='float32')
            # First half: original backgrounds, reduced to the selected channels.
            for i in range(n_scenes):
                idx = str(i).zfill(2)
                im_npz = np.load(data_path_bg / light / 'bgscene' / (idx + '_bgscene.npz'))
                im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
                dset_im[i, :, :, :] = im[:, :, sel]
            # Second half: extended backgrounds, stored without channel
            # selection (presumably they already have n_channels channels --
            # TODO confirm against the dataset description).
            for i in range(n_scenes):
                idx = str(i).zfill(2)
                im_npz = np.load(data_path_ext / light / 'bgscene' / (idx + '_bgscene.npz'))
                im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2)
                dset_im[n_scenes + i, :, :, :] = im
32 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 | from torch.nn import Parameter
7 | from deeplab import DeepLabv3_plus
8 |
9 |
class PowderNet(nn.Module):
    """Segmentation network wrapper that selects its backbone by name.

    Arguments:
        arch (str): backbone identifier; only 'deeplab' is supported.
        n_channels (int): number of input channels passed to the backbone.
        n_classes (int): number of output classes passed to the backbone.

    Raises:
        ValueError: if ``arch`` is not a supported backbone name.
    """

    def __init__(self, arch, n_channels, n_classes):
        super(PowderNet, self).__init__()
        if arch == 'deeplab':
            self.body = DeepLabv3_plus(nInputChannels=n_channels, n_classes=n_classes, pretrained=False, _print=False)
        else:
            # An explicit exception instead of `assert`: asserts are removed
            # under `python -O`, which would leave `self.body` undefined and
            # postpone the failure to the first forward pass.
            raise ValueError("Unsupported arch: {!r} (expected 'deeplab')".format(arch))

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged."""
        return self.body(x)
22 |
23 |
def get_1x_lr_params(model):
    """
    Generator over the trainable parameters of the feature extractor,
    i.e. every parameter of ``model.body.xception_features`` that has
    ``requires_grad`` set. Parameters with ``requires_grad == False``
    are skipped.
    """
    backbone = model.body.xception_features
    yield from (param for param in backbone.parameters() if param.requires_grad)
36 |
37 |
def get_10x_lr_params(model):
    """
    Generator over the trainable parameters of the modules that follow the
    feature extractor: ``aspp1`` .. ``aspp4``, ``conv1``, ``conv2`` and
    ``last_conv`` of ``model.body``, visited in that order. Parameters with
    ``requires_grad == False`` are skipped.
    """
    body = model.body
    head_modules = (body.aspp1, body.aspp2, body.aspp3, body.aspp4,
                    body.conv1, body.conv2, body.last_conv)
    for module in head_modules:
        yield from filter(lambda param: param.requires_grad, module.parameters())
48 |
--------------------------------------------------------------------------------
/prepare/calibrate_kappa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | from pathlib import Path
4 |
def parse_args():
    """Parse the command line of the calibration script.

    Options: --data-path (default '../data/train') and
    --out-path (default '../params').
    """
    parser = argparse.ArgumentParser(description='Calibration')
    string_options = (
        ('--data-path', '../data/train'),
        ('--out-path', '../params'),
    )
    for flag, default in string_options:
        parser.add_argument(flag, type=str, default=default)
    return parser.parse_args()
11 |
12 |
if __name__ == '__main__':
    # Estimate one normalized kappa vector per (powder index, light) pair
    # from the thick / thin / bg patches and store all of them in
    # <out-path>/kappa_params.npz under the key 'params'.
    opt = parse_args()
    print(opt)

    data_dir = Path(opt.data_path)
    out_dir = Path(opt.out_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    n = 100
    h = 14
    w = 14
    c = 965
    valid_threshold = 20  # not referenced anywhere in this block

    kappa_params = []
    for i in range(n):
        print(i)
        key = str(i).zfill(2)
        kappa_lights = []
        for light in lights:
            # Load the three patches and merge 'rgbn' and 'swir' along the
            # channel axis.
            patches = {}
            for kind in ('thick', 'thin', 'bg'):
                npz = np.load(data_dir / light / kind / (key + '_' + kind + '.npz'))
                patches[kind] = np.concatenate((npz['rgbn'], npz['swir']), 2)

            # thick and bg are reduced to their per-channel spatial mean;
            # thin stays per pixel.
            thick_mean = np.mean(patches['thick'], (0, 1), keepdims=True)
            bg_mean = np.mean(patches['bg'], (0, 1), keepdims=True)
            alpha = (patches['thin'] - thick_mean) / (bg_mean - thick_mean)

            # valid alpha selection: clip into [0.01, 0.99] so the logarithm
            # below stays finite
            alpha = np.clip(alpha.reshape([h * w, c]), 0.01, 0.99)
            kappa = np.median(-np.log(alpha), axis=0)
            # Normalize by the average of the mean over the first 4 channels
            # and the mean over the remaining channels.
            ratio = (kappa[:4].mean() + kappa[4:].mean()) / 2
            kappa = kappa / ratio
            print(kappa[:4], kappa[[5, -1]], kappa.max(), kappa.min())
            kappa_lights.append(kappa)
        kappa_params.append(kappa_lights)
    np.savez(out_dir / 'kappa_params.npz', params=kappa_params)
58 |
--------------------------------------------------------------------------------
/src/evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import cv2
5 |
6 |
class Evaluator:
    """Accumulates per-image predictions and derives accuracy / IoU metrics.

    Call register() once per image, then evaluate() once at the end.
    The last class index (n_classes - 1) is treated as background.
    """

    def __init__(self, n_classes, bg_err):
        self.n_classes = n_classes
        # Percentile (passed to np.percentile) used to pick the background
        # confidence threshold in evaluate().
        self.bg_err = bg_err

        # Per-image background confidence maps and non-background argmax maps.
        self.bg_confs = []
        self.pd_preds = []
        # Per class: ground-truth pixel count of every image containing it.
        self.gt_nums = [[] for c in range(n_classes)]
        # Per class: background confidences collected in register().
        self.cls_bg_confs = [[] for c in range(n_classes)]

        # Per-class accumulators for the IoU metrics.
        self.tp = np.zeros(n_classes)
        self.fp = np.zeros(n_classes)
        self.num = np.zeros(n_classes)
        self.itp = np.zeros(n_classes)
        self.inum = np.zeros(n_classes)

        # Per-image argmax over all classes (background included).
        self.preds = []

    def register(self, label, prob):
        """
        Record one image.

        label: H x W, n_classes-1 is bg
        prob: C x H x W
        """

        # The last channel is the background score; pd_pred is the best
        # class among the remaining channels only.
        bg_conf = prob[-1,:,:]
        pd_pred = np.argmax(prob[:-1,:,:], axis=0)
        self.bg_confs.append(bg_conf)
        self.pd_preds.append(pd_pred)
        pred = np.argmax(prob, axis=0)
        self.preds.append(pred)
        for c in range(self.n_classes):
            gt_mask = (label == c)
            # for Powder Accuracy
            if gt_mask.any():
                if c < self.n_classes - 1:
                    # Non-background class: keep the background confidence of
                    # the gt pixels whose best non-background class is c.
                    pred_mask = (pd_pred == c)
                    self.gt_nums[c].append(gt_mask.sum())
                    self.cls_bg_confs[c].append(bg_conf[gt_mask * pred_mask])
                else:
                    # Background class: pool the confidences of all gt
                    # background pixels into one flat list.
                    self.cls_bg_confs[c] += list(bg_conf[gt_mask])
                # for IoU
                self.num[c] += gt_mask.sum()
                self.tp[c] += ((pred == c) * gt_mask).sum()
                self.inum[c] += 1
                # Per-image recall; only well defined when the class is present.
                self.itp[c] += ((pred == c) * gt_mask).sum() / gt_mask.sum()
            # NOTE(review): the nesting of the IoU statements was reconstructed
            # from a flattened listing; in particular confirm whether this
            # false-positive update belongs inside `if gt_mask.any()`.
            self.fp[c] += ((pred == c) * (1 - gt_mask)).sum()

    def evaluate(self):
        """
        Compute the metrics over everything registered so far.

        Returns (msa, predictions, miou, miiou, preds):
        msa is the mean of the per-(class, image) accuracies, predictions is
        the stacked pd_preds with thresholded background, miou and miiou are
        the class means of iou and iiou, and preds is the list of per-image
        argmax maps.

        Note: self.bg_confs and self.pd_preds are replaced by arrays here, so
        register() (which calls .append on them) cannot be used afterwards.
        """
        # Threshold = bg_err-th percentile of the background confidence over
        # ground-truth background pixels.
        self.bg_conf_threshold = np.percentile(self.cls_bg_confs[-1], self.bg_err)
        accs = []
        for c in range(self.n_classes - 1):
            for i, cls_bg_conf in enumerate(self.cls_bg_confs[c]):
                # Fraction of the class' gt pixels that were assigned to c
                # and fall below the background threshold.
                acc = (cls_bg_conf < self.bg_conf_threshold).sum() / self.gt_nums[c][i]
                accs.append(acc)
        msa = np.mean(np.array(accs))

        self.bg_confs = np.array(self.bg_confs)
        self.pd_preds = np.array(self.pd_preds)
        predictions = self.pd_preds.copy()
        # Pixels at or above the threshold are labelled background.
        predictions[self.bg_confs >= self.bg_conf_threshold] = self.n_classes - 1

        iou = self.tp / (self.num + self.fp)
        miou = iou.mean()
        # Same as iou, but with tp replaced by (mean per-image recall) * num.
        iiou = (self.itp * self.num / self.inum) / (self.num + self.fp)
        miiou = iiou.mean()
        return msa, predictions, miou, miiou, self.preds
74 |
75 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multispectral Imaging for Fine-Grained Recognition of Powders on Complex Backgrounds
2 |
3 |
4 |
5 | [Tiancheng Zhi](http://cs.cmu.edu/~tzhi), [Bernardo R. Pires](http://www.andrew.cmu.edu/user/bpires/), [Martial Hebert](http://www.cs.cmu.edu/~hebert/), [Srinivasa G. Narasimhan](http://www.cs.cmu.edu/~srinivas/)
6 |
7 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019.
8 |
9 | [[Project](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/)] [[Paper](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/files/ZPHN-CVPR19.pdf)] [[Supp](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/files/ZPHN-CVPR19-supp.pdf)]
10 |
11 |
12 |
13 |
14 |
15 | ## Requirements
16 | - NVIDIA TITAN Xp
17 | - Ubuntu 16.04
18 | - Python 3.6
19 | - OpenCV 4.0
20 | - PyTorch 1.0
21 | - Visdom
22 |
23 | ## Download "SWIRPowder" Dataset
24 | Download the ["data" folder](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/data/), and put it in the repo root directory.
25 | See "data/readme.txt" for description.
26 |
27 | ## Calibrate Attenuation Parameter
28 | In "prepare" directory, run:
29 | ```
30 | python calibrate_kappa.py
31 | ```
32 |
33 | ## Band Selection
34 | See "readme.txt" in "bandsel" directory
35 |
36 | ## Recognition with Known Powder Location/Mask
37 | In "recog" directory, run:
38 | ```
39 | python recognition.py
40 | ```
41 |
42 | ## Recognition without Known Powder Location/Mask
43 | ### Prepare real data
44 | In "prepare" directory, run:
45 | ```
46 | sh create_real_hdf5.sh
47 | ```
48 |
49 | ### Prepare synthetic data
50 | Download the ["synthetic" folder](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/synthetic/) and put it in the repo root directory.
51 |
52 | ### Train on synthetic powder on synthetic background
53 | In "src" directory, run:
54 | ```
55 | python train.py --out-path ckpts/ckpt_default --bands 0,1,2,3,77,401,750,879
56 | ```
57 |
58 | Note that the hdf5 file merges RGBN and SWIR channels, so channel ID 0\~3 are RGBN channels, channel ID 4\~964 are SWIR channels.
59 |
60 | To use NNCV selection, use `--bands 0,1,2,3,77,401,750,879`.
61 |
62 | To use Grid selection, use `--bands 0,1,2,3,4,34,934,964`.
63 |
64 | To use MVPCA selection, use `--bands 0,1,2,3,127,152,686,837`.
65 |
66 | To use RS selection, use `--bands 0,1,2,3,4,422,588,905`.
67 |
68 | See "bandsel/bands/" for more selected bands. Remember to "add 4" to convert 0\~960 range to 4\~964 range.
69 |
70 |
71 | ### Train on synthetic powder on real background
72 | In "src" directory, run:
73 | ```
74 | python finetune.py --out-path ckpts/ckpt_default_extft --bands 0,1,2,3,77,401,750,879 --pretrain ckpts/ckpt_default/247.pth --split bgext
75 | ```
76 |
77 | Note: use `--split bg` for experiments on unextended dataset.
78 |
79 | ### Train on real powder on real background
80 | In "src" directory, run:
81 | ```
82 | python finetune_real.py --out-path ckpts/ckpt_default_extft_real --bands 0,1,2,3,77,401,750,879 --pretrain ckpts/ckpt_default_extft/55.pth
83 | ```
84 |
85 | ### Test with CRF post-processing
86 | In "src" directory, run:
87 | ```
88 | python test.py --ckpt model.pth # Test on Scene-test
89 | python test_merge.py --ckpt model.pth # Test on dataset merging Scene-test and Scene-sl-test
90 | ```
91 |
92 | ### Pretrained model
93 | Download [pretrained.pth](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/pretrained.pth), put it in "src" directory, and test it with:
94 | ```
95 | python test_merge.py --ckpt pretrained.pth
96 | ```
97 |
--------------------------------------------------------------------------------
/src/adamw.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
class AdamW(Optimizer):
    """Implements Adam with decoupled weight decay (AdamW).

    The Adam update is applied first; the weight decay term
    ``weight_decay * p`` (computed from the weights before the update) is
    then subtracted directly from the parameters instead of being added to
    the gradient. The decay term is not scaled by the learning rate.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): decoupled weight decay factor
            (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    Raises:
        ValueError: if one of the betas lies outside [0, 1).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient.
                # The scalar is passed by keyword: the positional scalar-first
                # overloads of add_/addcmul_/addcdiv_ are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    # Take the decay term from the weights *before* the Adam
                    # update so that the two contributions stay decoupled.
                    decayed_weights = torch.mul(p.data, group['weight_decay'])
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.sub_(decayed_weights)
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
97 |
--------------------------------------------------------------------------------
/bandsel/bandsel_nncv.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import numba
6 | import cv2
7 | import time
8 |
9 | from pathlib import Path
10 |
11 |
def cosine(a, b):
    """Cosine similarity between every row of ``a`` and every row of ``b``.

    The rows of ``a`` are processed in chunks of 1024 to bound the size of
    the broadcast intermediate.

    Arguments:
        a (Tensor): 2-D tensor whose rows are the query vectors.
        b (Tensor): 2-D tensor whose rows are the reference vectors; it must
            have the same number of columns as ``a``.

    Returns:
        Tensor of dtype double with shape (a.size(0), b.size(0)). It lives
        on the CUDA device when CUDA is available, otherwise on the device
        of ``a``.
    """
    n_pixels = a.size()[0]
    batch_size = 1024
    # Keep the GPU placement whenever a GPU exists, but do not crash on
    # CPU-only machines: fall back to wherever the input lives.
    device = torch.device('cuda') if torch.cuda.is_available() else a.device
    sim = torch.zeros((n_pixels, b.size()[0]), dtype=torch.double, device=device)
    y = b.unsqueeze(0)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        # (chunk, 1, C) against (1, m, C) broadcasts to (chunk, m, C);
        # reducing over dim 2 leaves one similarity per (row of a, row of b).
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim
27 |
28 |
def sim_func(dist, query, database):
    """Compute the similarity matrix between ``query`` and ``database`` rows.

    Args:
        dist: ``'full'`` for a single cosine similarity over all columns, or
            ``'split'`` for the sum of the cosine similarity over the first
            four columns and the cosine similarity over the remaining ones.
        query: 2-D tensor, one feature vector per row.
        database: 2-D tensor with the same column layout as ``query``.

    Returns:
        The tensor returned by :func:`cosine` (or the sum of two of them).

    Raises:
        ValueError: if ``dist`` is not a supported mode.
    """
    # With at most five columns the "rest" part would hold one value or none,
    # so the plain cosine similarity is used regardless of ``dist``.
    if dist == 'full' or query.size(1) <= 5:
        return cosine(query, database)
    if dist == 'split':
        head = cosine(query[:, :4], database[:, :4])
        tail = cosine(query[:, 4:], database[:, 4:])
        return head + tail
    # An explicit exception survives ``python -O``; ``assert(False)`` did not.
    raise ValueError('unsupported dist: {!r}'.format(dist))
36 |
37 |
def feat_eng(dist, raw):
    """Derive the feature matrix used for similarity search from ``raw``.

    Args:
        dist: ``'full'`` or ``'split'`` return ``raw`` unchanged;
            ``'decouple'`` prepends one column holding the per-row mean of
            the columns from index 4 onwards.
        raw: 2-D tensor, one sample per row.

    Returns:
        ``raw`` itself, or a new tensor with one extra leading column.

    Raises:
        ValueError: if ``dist`` is not a supported mode.
    """
    if dist in ('full', 'split'):
        return raw
    if dist == 'decouple':
        mean_swir = raw[:, 4:].mean(dim=1, keepdim=True)
        return torch.cat((mean_swir, raw), dim=1)
    # An explicit exception survives ``python -O``; ``assert(False)`` did not.
    raise ValueError('unsupported dist: {!r}'.format(dist))
48 |
49 |
def parse_args():
    """Parse the band-selection command line and return the resulting namespace."""
    cli = argparse.ArgumentParser(description='NNCV Band Selection')
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ('--data-path', {'type': str, 'default': '../data'}),
        ('--log-path', {'type': str, 'default': './bands'}),
        ('--n-sels', {'type': int, 'default': 49}),
        ('--dist', {'type': str, 'default': 'split', 'choices': ['full', 'split']}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
58 |
59 |
if __name__ == '__main__':
    # Greedy forward band selection: starting from the first four channels,
    # repeatedly add the channel that maximises leave-one-out nearest
    # neighbour accuracy on the per-material mean spectra.
    opt = parse_args()
    Path(opt.log_path).mkdir(parents=True, exist_ok=True)
    # The default 'split' metric keeps the plain file name.
    if opt.dist == 'split':
        log_file_path = Path(opt.log_path) / ('nncv.txt')
    else:
        log_file_path = Path(opt.log_path) / ('nncv_{}.txt'.format(opt.dist))
    print(opt)
    opt.lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    opt.n_lights = len(opt.lights)
    opt.n_powders = 100
    opt.n_bgmats = 100
    opt.n_channels = 965
    opt.n_full_swir_channels = 961

    train_path = Path(opt.data_path) / 'train'

    # Class label per sample: the powder index for powders and the single
    # extra label ``n_powders`` for every background material.
    y = []

    thick_list = np.zeros((opt.n_powders, opt.n_lights, opt.n_channels))
    for i in range(opt.n_powders):
        idx = str(i).zfill(2)
        for lid, light in enumerate(opt.lights):
            # ``with`` closes the archive; the original leaked one handle per file.
            with np.load(train_path / light / 'thick' / (idx + '_thick.npz')) as thick:
                thick = np.concatenate((thick['rgbn'], thick['swir']), axis=2)
            # Average over the two leading axes to get one value per channel.
            thick_list[i, lid] = thick.mean((0, 1))
            y.append(i)
    thick_list = thick_list.reshape((opt.n_powders * opt.n_lights, opt.n_channels))

    bgmat_list = np.zeros((opt.n_bgmats, opt.n_lights, opt.n_channels))
    for i in range(opt.n_bgmats):
        idx = str(i).zfill(2)
        for lid, light in enumerate(opt.lights):
            with np.load(train_path / light / 'bgmat' / (idx + '_bgmat.npz')) as bgmat:
                bgmat = np.concatenate((bgmat['rgbn'], bgmat['swir']), axis=2)
            bgmat_list[i, lid] = bgmat.mean((0, 1))
            y.append(opt.n_powders)
    bgmat_list = bgmat_list.reshape((opt.n_bgmats * opt.n_lights, opt.n_channels))

    raw = np.concatenate((thick_list, bgmat_list), axis=0)
    raw = torch.from_numpy(raw).cuda()
    y = np.array(y)
    y = torch.from_numpy(y).cuda()

    # The first four channels (the 'rgbn' part of the concatenation) are
    # always part of the selection.
    selection = np.zeros(opt.n_channels, dtype=np.bool_)
    selection[:4] = True

    start_time = time.time()
    bands = []
    for i in range(opt.n_sels):
        # Start below any reachable accuracy so that a candidate is always
        # chosen; with ``best_acc = 0`` an all-zero round left ``best_sel``
        # unbound (first round) or stale (later rounds).
        best_acc = -1.0
        best_sel = None
        for j in range(opt.n_channels):
            if selection[j]:
                continue
            selection[j] = True
            # Column indexing yields the same matrix as the former uint8
            # ``masked_select`` (columns in ascending order) without relying
            # on the deprecated uint8-mask behaviour.
            columns = torch.from_numpy(np.flatnonzero(selection)).cuda()
            x = raw[:, columns]
            if i > 0:
                x = feat_eng(opt.dist, x)

            sims = sim_func(opt.dist, x, x)
            _, indices = torch.sort(sims, dim=1)
            # The second-to-last entry of the ascending sort is used as the
            # nearest neighbour (the last one is presumably the sample itself).
            acc = (y == y[indices[:, -2]]).cpu().numpy().astype(np.float32)

            # Per-material accuracy, averaged over the light sources.
            acc = acc.reshape((opt.n_powders + opt.n_bgmats, opt.n_lights)).mean(axis=1)

            # Every powder counts as one class; all backgrounds share one class.
            acc = (acc[:opt.n_powders].sum() + acc[opt.n_powders:].mean()) / (opt.n_powders + 1)
            if acc > best_acc:
                best_acc = acc
                best_sel = j
            selection[j] = False
        if best_sel is None:
            # Every channel is already selected; nothing left to add.
            break
        selection[best_sel] = True
        # Reported indices are relative to the channels after the first four.
        print(i, best_acc, best_sel - 4, round(time.time() - start_time))
        bands.append(best_sel - 4)

    print(bands)
    st = ','.join(str(band) for band in bands)
    # Write once at the end; ``with`` guarantees the handle is closed.
    with open(log_file_path, 'w') as log:
        print(st, file=log)
147 |
148 |
--------------------------------------------------------------------------------
/src/cosine_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Optimizer
2 | import math
3 | import torch
4 |
5 |
class CosineLRWithRestarts():
    """Cosine-annealed learning rate with warm restarts and normalized weight decay.

    On every batch the scheduler rewrites ``lr`` and ``weight_decay`` of each
    parameter group of the wrapped optimizer.
    https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        batch_size: minibatch size
        epoch_size: training samples per epoch
        restart_period: epoch count in the first restart period
        t_mult: multiplication factor by which the next restart period will extend/shrink


    Example:
        >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>         ...
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step()
        >>>     validate(...)
    """

    def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
                 t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        fresh_start = last_epoch == -1
        for i, group in enumerate(optimizer.param_groups):
            if fresh_start:
                # Remember the starting rate so later steps can rescale it.
                group.setdefault('initial_lr', group['lr'])
            elif 'initial_lr' not in group:
                raise KeyError("param 'initial_lr' is not specified "
                               "in param_groups[{}] when resuming an"
                               " optimizer".format(i))
        self.base_lrs = [group['initial_lr']
                         for group in optimizer.param_groups]

        self.last_epoch = last_epoch
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.eta_threshold = eta_threshold
        self.t_mult = t_mult
        self.verbose = verbose
        self.base_weight_decays = [group['weight_decay']
                                   for group in optimizer.param_groups]
        self.restart_period = restart_period
        self.restarts = 0
        self.t_epoch = -1

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every entry of ``self.__dict__`` is included except the wrapped
        optimizer and the per-epoch batch increment iterator.
        """
        excluded = ('optimizer', 'batch_increment')
        return {key: value
                for key, value in self.__dict__.items()
                if key not in excluded}

    def load_state_dict(self, state_dict):
        """Restore the scheduler state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def _schedule_eta(self):
        """Return the ``(eta_min, eta_max)`` pair for the current restart count.

        Once the number of restarts exceeds ``eta_threshold`` the interval
        shrinks from both ends by 0.09 per additional restart.
        """
        eta_min = 0
        eta_max = 1
        overshoot = self.restarts - self.eta_threshold
        if overshoot <= 0:
            return eta_min, eta_max
        k = overshoot * 0.09
        return (eta_min + k, eta_max - k)

    def get_lr(self, t_cur):
        """Return an iterator of ``(lr, weight_decay)`` pairs, one per param group.

        Also triggers a restart (growing ``restart_period`` by ``t_mult``)
        after the values for ``t_cur`` have been computed.
        """
        eta_min, eta_max = self._schedule_eta()

        cosine_term = 1. + math.cos(math.pi * (t_cur / self.restart_period))
        eta_t = eta_min + 0.5 * (eta_max - eta_min) * cosine_term

        weight_decay_norm_multi = math.sqrt(
            self.batch_size / (self.epoch_size * self.restart_period))
        lrs = [base_lr * eta_t for base_lr in self.base_lrs]
        weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi
                         for base_weight_decay in self.base_weight_decays]

        # The restart bookkeeping runs after the values above were derived
        # from the old period.
        if self.t_epoch % self.restart_period < self.t_epoch:
            if self.verbose:
                print("Restart at epoch {}".format(self.last_epoch))
            self.restart_period *= self.t_mult
            self.restarts += 1
            self.t_epoch = 0

        return zip(lrs, weight_decays)

    def _set_batch_size(self):
        """Create the iterator of within-epoch fractions consumed by ``batch_step``."""
        d, r = divmod(self.epoch_size, self.batch_size)
        # One point per batch (including a partial last batch) plus one for
        # the call made by ``step``.
        batches_in_epoch = d + (2 if r > 0 else 1)
        self.batch_increment = iter(torch.linspace(0, 1, batches_in_epoch))

    def step(self):
        """Advance to the next epoch and apply the schedule for its first batch."""
        self.last_epoch += 1
        self.t_epoch += 1
        self._set_batch_size()
        self.batch_step()

    def batch_step(self):
        """Write the scheduled ``lr`` and ``weight_decay`` into every param group."""
        t_cur = self.t_epoch + next(self.batch_increment)
        schedule = self.get_lr(t_cur)
        for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,
                                                   schedule):
            param_group['lr'] = lr
            param_group['weight_decay'] = weight_decay
134 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import pydensecrf.densecrf as dcrf
8 | from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
9 | from pathlib import Path
10 | from torch.utils.data import DataLoader
11 | from dataset import RealDataset
12 | from model import PowderNet
13 | from utils import to_image, colormap, errormap
14 | from evaluator import Evaluator
15 | import skimage.io as io
16 | import cv2
17 | import collections
18 |
19 |
def parse_args():
    """Parse the evaluation command line and return the resulting namespace."""
    cli = argparse.ArgumentParser(description='Testing')
    cli.add_argument('--ckpt', type=str)
    # (flag, type, default) triples, registered in declaration order.
    defaults_table = (
        ('--real-path', str, '../real'),
        ('--out-path', str, './result'),
        ('--bg-err', float, 1.0),
        ('--sdims', int, 3),
        ('--schan', int, 3),
        ('--compat', int, 3),
        ('--iters', int, 10),
        ('--threads', int, 1),
        ('--batch-size', int, 1),
    )
    for flag, kind, default in defaults_table:
        cli.add_argument(flag, type=kind, default=default)
    return cli.parse_args()
34 |
35 |
def crf(prob, im, sdims, schan, compat, iters):
    """Refine class probabilities with a dense CRF using a bilateral pairwise term.

    Reads the module-level ``opt`` (set in the ``__main__`` section) to decide
    which channels of ``im`` feed the bilateral features.

    Args:
        prob: array unpacked as ``(C, H, W)``, passed to ``unary_from_softmax``.
        im: array indexed by channel on its first axis.
        sdims: spatial scale, used for both spatial dimensions.
        schan: channel scale of the bilateral features.
        compat: compatibility value of the pairwise energy.
        iters: number of inference iterations.

    Returns:
        The inference result reshaped to ``(-1, H, W)``.
    """
    # Fixed channel subsets for the known input layouts; any other layout
    # (for example a list of selected bands) uses every channel.
    presets = (
        (965, [0, 1, 2, 3, 4, 159, 324, 469, 644, 779, 964]),
        (4, [0, 1, 2, 3]),
        (-961, [0, 155, 320, 465, 640, 775, 960]),
    )
    for layout, channels in presets:
        if opt.channels == layout:
            bilateral_ch = channels
            break
    else:
        bilateral_ch = range(opt.n_channels)

    n_labels, height, width = prob.shape
    # NOTE(review): pydensecrf documents this constructor as
    # (width, height, nlabels); the (H, W, C) order is kept as in the
    # original. Confirm it is harmless with addPairwiseEnergy.
    model = dcrf.DenseCRF2D(height, width, n_labels)
    model.setUnaryEnergy(unary_from_softmax(prob))
    features = create_pairwise_bilateral(sdims=(sdims, sdims), schan=(schan,), img=im[bilateral_ch, :, :], chdim=0)
    model.addPairwiseEnergy(features, compat=compat)
    refined = np.array(model.inference(iters))
    return refined.reshape(-1, height, width)
54 |
55 |
def test(opt, test_loader, net, split):
    """Evaluate ``net`` on ``test_loader`` before and after CRF refinement.

    Every sample is registered with two ``Evaluator`` instances (raw softmax
    output and CRF-refined output). After evaluation, one comparison image per
    sample is written to ``<opt.out_path>/<split>/NN.png``.

    Args:
        opt: options; reads ``n_classes``, ``bg_err``, ``sdims``, ``schan``,
            ``compat``, ``iters``, ``batch_size`` and ``out_path``.
        test_loader: iterable yielding ``(im, label)`` batches.
        net: the network to evaluate; switched to eval mode.
        split: sub-directory name for the saved images.

    Returns:
        ``(msa, miou, miiou, msa_crf, miou_crf, miiou_crf)``.
    """
    start_time = time.time()
    eva = Evaluator(opt.n_classes, opt.bg_err)
    eva_crf = Evaluator(opt.n_classes, opt.bg_err)
    ims = []
    labels = []

    net = net.eval()

    for iteration, batch in enumerate(test_loader):
        im, label = batch
        im = im.cuda()
        label = label.cuda()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            out = net(im)
            prob = F.softmax(out, dim=1)
        # Use the real batch size: the last batch can be shorter than
        # opt.batch_size, which used to raise an IndexError.
        for i in range(im.size(0)):
            prob_np = prob[i].detach().cpu().numpy()
            label_np = label[i].cpu().numpy()
            im_np = im[i].cpu().numpy()
            ims.append(to_image(im[i, :3, :, :]))
            labels.append(label_np)
            eva.register(label_np, prob_np)
            prob_crf = crf(prob_np, im_np, opt.sdims, opt.schan, opt.compat, opt.iters)
            eva_crf.register(label_np, prob_crf)
            print(str(iteration * opt.batch_size + i).zfill(2), time.time() - start_time, 'seconds')

    msa, preds_msa, miou, miiou, preds_miou = eva.evaluate()
    msa_crf, preds_msa_crf, miou_crf, miiou_crf, preds_miou_crf = eva_crf.evaluate()
    print('Pre-CRF:  MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)))
    print('Post-CRF: MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)))
    for i, label in enumerate(labels):
        # Only the mIoU predictions are drawn; the MSA colour maps were
        # computed but never used and have been dropped.
        vis_im = ims[i]
        vis_label = colormap(label)
        vis_pred_miou = colormap(preds_miou[i])
        vis_pred_miou_crf = colormap(preds_miou_crf[i])
        # 2x2 grid: input | label on top, prediction | CRF prediction below.
        vis_all = np.concatenate((
            np.concatenate((vis_im, vis_label), axis=2),
            np.concatenate((vis_pred_miou, vis_pred_miou_crf), axis=2)), axis=1)
        vis_all = vis_all.transpose((1, 2, 0))
        io.imsave(Path(opt.out_path) / split / (str(i).zfill(2) + '.png'), vis_all)
    return msa, miou, miiou, msa_crf, miou_crf, miiou_crf
103 |
if __name__ == '__main__':
    # Evaluate a checkpoint on the real 'test' and 'val' splits and write the
    # metrics to <out-path>/performance.txt.
    cv2.setNumThreads(0)

    # ``opt`` must stay a module-level name: crf() reads it as a global.
    opt = parse_args()
    print(opt)

    (Path(opt.out_path) / 'test').mkdir(parents=True, exist_ok=True)
    (Path(opt.out_path) / 'val').mkdir(parents=True, exist_ok=True)

    checkpoint = torch.load(opt.ckpt)

    # Model settings come from the checkpoint; checkpoints without these
    # attributes fall back to the 965-channel layout.
    opt.channels = checkpoint['opt'].channels if 'channels' in checkpoint['opt'].__dict__ else 965
    opt.n_channels = checkpoint['opt'].n_channels if 'n_channels' in checkpoint['opt'].__dict__ else abs(opt.channels)
    opt.n_classes = checkpoint['opt'].n_classes
    opt.arch = checkpoint['opt'].arch

    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    test_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    val_set = RealDataset(opt.real_path, opt.channels, split='val')
    val_loader = DataLoader(dataset=val_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    net = PowderNet(opt.arch, opt.n_channels, opt.n_classes)
    net = net.cuda()
    net.load_state_dict(checkpoint['state_dict'])

    # ``with`` closes the log even when evaluation raises part-way through.
    with open(Path(opt.out_path) / 'performance.txt', 'w') as log_file:
        print(opt, file=log_file)
        msa, miou, miiou, msa_crf, miou_crf, miiou_crf = test(opt, test_loader, net, 'test')
        print('Test Pre-CRF:  MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)), file=log_file)
        print('Test Post-CRF: MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)), file=log_file)
        msa, miou, miiou, msa_crf, miou_crf, miiou_crf = test(opt, val_loader, net, 'val')
        print('Val Pre-CRF:  MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)), file=log_file)
        print('Val Post-CRF: MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)), file=log_file)
        print('Complete', file=log_file)
140 |
--------------------------------------------------------------------------------
/src/test_merge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import pydensecrf.densecrf as dcrf
8 | from pydensecrf.utils import unary_from_softmax, create_pairwise_bilateral
9 | from pathlib import Path
10 | from torch.utils.data import DataLoader
11 | from dataset import RealDataset
12 | from model import PowderNet
13 | from utils import to_image, colormap, errormap
14 | from evaluator import Evaluator
15 | import skimage.io as io
16 | import cv2
17 | import collections
18 |
19 |
def parse_args():
    """Parse the merged-evaluation command line and return the resulting namespace."""
    cli = argparse.ArgumentParser(description='Testing')
    # Flag -> keyword arguments, registered in declaration order.
    option_table = [
        ('--ckpt', {'type': str}),
        ('--real-path', {'type': str, 'default': '../real'}),
        ('--out-path', {'type': str, 'default': './result'}),
        ('--bg-err', {'type': float, 'default': 1.0}),
        ('--sdims', {'type': int, 'default': 3}),
        ('--schan', {'type': int, 'default': 3}),
        ('--compat', {'type': int, 'default': 3}),
        ('--iters', {'type': int, 'default': 10}),
        ('--threads', {'type': int, 'default': 1}),
        ('--batch-size', {'type': int, 'default': 1}),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    parsed = cli.parse_args()
    return parsed
34 |
35 |
def crf(prob, im, sdims, schan, compat, iters):
    """Refine class probabilities with a dense CRF using a bilateral pairwise term.

    Reads the module-level ``opt`` (set in the ``__main__`` section) to decide
    which channels of ``im`` feed the bilateral features.

    Args:
        prob: array unpacked as ``(C, H, W)``, passed to ``unary_from_softmax``.
        im: array indexed by channel on its first axis.
        sdims: spatial scale, used for both spatial dimensions.
        schan: channel scale of the bilateral features.
        compat: compatibility value of the pairwise energy.
        iters: number of inference iterations.

    Returns:
        The inference result reshaped to ``(-1, H, W)``.
    """
    # Known input layouts use a fixed channel subset; anything else (for
    # example a list of selected bands) uses every channel.
    layout = opt.channels
    if layout == 965:
        bilateral_ch = [0, 1, 2, 3, 4, 159, 324, 469, 644, 779, 964]
    elif layout == 4:
        bilateral_ch = [0, 1, 2, 3]
    elif layout == -961:
        bilateral_ch = [0, 155, 320, 465, 640, 775, 960]
    else:
        bilateral_ch = range(opt.n_channels)

    n_labels, height, width = prob.shape
    unary = unary_from_softmax(prob)
    # NOTE(review): pydensecrf documents this constructor as
    # (width, height, nlabels); the (H, W, C) order is kept as in the
    # original. Confirm it is harmless with addPairwiseEnergy.
    dense = dcrf.DenseCRF2D(height, width, n_labels)
    dense.setUnaryEnergy(unary)
    guide = im[bilateral_ch, :, :]
    pairwise_energy = create_pairwise_bilateral(sdims=(sdims, sdims), schan=(schan,), img=guide, chdim=0)
    dense.addPairwiseEnergy(pairwise_energy, compat=compat)
    marginals = dense.inference(iters)
    return np.array(marginals).reshape(-1, height, width)
54 |
55 |
def _register_loader(opt, loader, net, tag, eva, eva_crf, ims, labels, start_time):
    """Run ``net`` over ``loader`` and register raw and CRF-refined outputs per sample."""
    for iteration, batch in enumerate(loader):
        im, label = batch
        im = im.cuda()
        label = label.cuda()
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            out = net(im)
            prob = F.softmax(out, dim=1)
        # Use the real batch size: the last batch can be shorter than
        # opt.batch_size, which used to raise an IndexError.
        for i in range(im.size(0)):
            prob_np = prob[i].detach().cpu().numpy()
            label_np = label[i].cpu().numpy()
            im_np = im[i].cpu().numpy()
            ims.append(to_image(im[i, :3, :, :]))
            labels.append(label_np)
            eva.register(label_np, prob_np)
            prob_crf = crf(prob_np, im_np, opt.sdims, opt.schan, opt.compat, opt.iters)
            eva_crf.register(label_np, prob_crf)
            print(tag, str(iteration * opt.batch_size + i).zfill(2), time.time() - start_time, 'seconds')


def test(opt, test_loader, testext_loader, net, split):
    """Evaluate ``net`` jointly on two loaders before and after CRF refinement.

    Samples from ``test_loader`` and then ``testext_loader`` are registered
    with the same pair of ``Evaluator`` instances, so the metrics cover both
    sets together. One comparison image per sample is written to
    ``<opt.out_path>/<split>/NN.png``.

    Args:
        opt: options; reads ``n_classes``, ``bg_err``, ``sdims``, ``schan``,
            ``compat``, ``iters``, ``batch_size`` and ``out_path``.
        test_loader: iterable yielding ``(im, label)`` batches.
        testext_loader: second iterable of ``(im, label)`` batches.
        net: the network to evaluate; switched to eval mode.
        split: sub-directory name for the saved images.

    Returns:
        ``(msa, miou, miiou, msa_crf, miou_crf, miiou_crf)``.
    """
    start_time = time.time()
    eva = Evaluator(opt.n_classes, opt.bg_err)
    eva_crf = Evaluator(opt.n_classes, opt.bg_err)
    ims = []
    labels = []

    net = net.eval()

    # The two loops were copies of each other apart from the log tag.
    _register_loader(opt, test_loader, net, 'test', eva, eva_crf, ims, labels, start_time)
    _register_loader(opt, testext_loader, net, 'testext', eva, eva_crf, ims, labels, start_time)

    msa, preds_msa, miou, miiou, preds_miou = eva.evaluate()
    msa_crf, preds_msa_crf, miou_crf, miiou_crf, preds_miou_crf = eva_crf.evaluate()
    print('Pre-CRF:  MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)))
    print('Post-CRF: MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)))
    for i, label in enumerate(labels):
        # Only the mIoU predictions are drawn; the MSA colour maps were
        # computed but never used and have been dropped.
        vis_im = ims[i]
        vis_label = colormap(label)
        vis_pred_miou = colormap(preds_miou[i])
        vis_pred_miou_crf = colormap(preds_miou_crf[i])
        # 2x2 grid: input | label on top, prediction | CRF prediction below.
        vis_all = np.concatenate((
            np.concatenate((vis_im, vis_label), axis=2),
            np.concatenate((vis_pred_miou, vis_pred_miou_crf), axis=2)), axis=1)
        vis_all = vis_all.transpose((1, 2, 0))
        io.imsave(Path(opt.out_path) / split / (str(i).zfill(2) + '.png'), vis_all)
    return msa, miou, miiou, msa_crf, miou_crf, miiou_crf
120 |
if __name__ == '__main__':
    # Evaluate a checkpoint on the union of the real 'test' and 'testext'
    # splits and write the metrics to <out-path>/performance_merge.txt.
    cv2.setNumThreads(0)

    # ``opt`` must stay a module-level name: crf() reads it as a global.
    opt = parse_args()
    print(opt)

    (Path(opt.out_path) / 'merge').mkdir(parents=True, exist_ok=True)

    checkpoint = torch.load(opt.ckpt)

    # Model settings come from the checkpoint; checkpoints without these
    # attributes fall back to the 965-channel layout.
    opt.channels = checkpoint['opt'].channels if 'channels' in checkpoint['opt'].__dict__ else 965
    opt.n_channels = checkpoint['opt'].n_channels if 'n_channels' in checkpoint['opt'].__dict__ else abs(opt.channels)
    opt.n_classes = checkpoint['opt'].n_classes
    opt.arch = checkpoint['opt'].arch

    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    test_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    testext_set = RealDataset(opt.real_path, opt.channels, split='testext')
    testext_loader = DataLoader(dataset=testext_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    net = PowderNet(opt.arch, opt.n_channels, opt.n_classes)
    net = net.cuda()
    net.load_state_dict(checkpoint['state_dict'])

    # ``with`` closes the log even when evaluation raises part-way through.
    with open(Path(opt.out_path) / 'performance_merge.txt', 'w') as log_file:
        print(opt, file=log_file)
        msa, miou, miiou, msa_crf, miou_crf, miiou_crf = test(opt, test_loader, testext_loader, net, 'merge')
        print('Merge Pre-CRF:  MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)), file=log_file)
        print('Merge Post-CRF: MSA: {}  mIoU: {}  miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)), file=log_file)
        print('Complete', file=log_file)
153 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | from cosine_scheduler import CosineLRWithRestarts
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | from dataset import SyntheticDataset, RealDataset
13 | from visualizer import Visualizer
14 | from model import PowderNet
15 | from utils import to_image, colormap, errormap
16 | from adamw import AdamW
17 | import cv2
18 |
19 |
def parse_args():
    """Parse the training command line.

    Besides the declared options, the returned namespace carries
    ``n_channels``. When ``--bands`` is given, ``channels`` is replaced by the
    list of parsed band indices and ``n_channels`` is its length; otherwise
    ``n_channels`` is ``abs(channels)``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits with a usage error when ``--bands`` is combined with a non-zero
    ``--channels``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--syn-path', type=str, default='../synthetic')
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--params-path', type=str, default='../params')
    parser.add_argument('--out-path', type=str, default='./checkpoint')
    parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands')
    parser.add_argument('--bands', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab')
    parser.add_argument('--threads', type=int, default=6)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=248)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--period', type=int, default=8)
    parser.add_argument('--t-mult', type=float, default=2)
    parser.add_argument('--vis-iter', type=int, default=0)
    parser.add_argument('--server', type=str, default='http://localhost')
    parser.add_argument('--env', type=str, default='main')
    opt = parser.parse_args()
    if opt.bands is not None:
        # Validate explicitly: an ``assert`` disappears under ``python -O``
        # and gives the user no usage message.
        if opt.channels != 0:
            parser.error('--bands can only be used together with --channels 0')
        opt.channels = [int(i) for i in opt.bands.split(',')]
        opt.n_channels = len(opt.channels)
    else:
        opt.n_channels = abs(opt.channels)
    return opt
49 |
50 |
def train(opt, vis, epoch, train_loader, net, optimizer, scheduler):
    """Train ``net`` for one epoch with cross-entropy loss.

    The scheduler is stepped once at the start of the epoch and once after
    every optimizer step. Each iteration prints a timing and loss line and
    appends to the visualizer's loss curve; with 965-channel input and a
    positive ``opt.vis_iter``, a preview image is plotted every
    ``opt.vis_iter`` iterations.
    """
    net = net.train()
    n_batches = len(train_loader)
    tick = time.time()
    scheduler.step()
    for step, (im, label) in enumerate(train_loader):
        # Move the batch to the GPU.
        im = im.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward pass.
        out = net(im)
        loss = F.cross_entropy(out, label)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()

        # Logging. The first ``opt.threads`` iterations are averaged on their
        # own; afterwards the clock restarts and those iterations are
        # excluded from the average.
        now = time.time()
        loss_scalar = loss.item()
        if step < opt.threads:
            timed_steps = step + 1
        else:
            timed_steps = step + 1 - opt.threads
        print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(
            opt.env, epoch, step, n_batches,
            round((now - tick) / timed_steps, 2),
            round(loss_scalar, 4)))
        if step == opt.threads - 1:
            tick = now

        # Loss curve.
        vis.iteration.append(epoch + step / n_batches)
        vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_scalar)))
        vis.plot_loss()

        # Image preview only for the full 965-channel layout, on every
        # ``vis_iter``-th iteration.
        wants_preview = (opt.channels == 965 and opt.vis_iter > 0
                         and step % opt.vis_iter == 0)
        if not wants_preview:
            continue
        _, pred = torch.max(out, dim=1)

        def half_intensity(channels):
            return to_image(im[0, channels, :, :] * 0.5)

        vis_rgb = half_intensity(slice(0, 3))
        vis_nir = half_intensity(slice(3, 4))
        vis_swir1 = half_intensity(slice(4, 5))
        vis_swir2 = half_intensity(slice(964, 965))
        vis_label = colormap(label[0].cpu().numpy())
        vis_pred = colormap(pred[0].cpu().numpy())
        columns = (
            np.concatenate((vis_label, vis_pred), axis=1),
            np.concatenate((vis_rgb, vis_nir), axis=1),
            np.concatenate((vis_swir1, vis_swir2), axis=1),
        )
        vis.plot_image(np.concatenate(columns, axis=2), 0)
103 |
104 |
105 | def test(opt, epoch, test_loader, net):
106 | if opt.channels == 965:
107 | bilateral_ch = [0,1,2,3,4,159,324,469,644,779,964]
108 | elif opt.channels == 4:
109 | bilateral_ch = [0,1,2,3]
110 | elif opt.channels == -961:
111 | bilateral_ch = [0,155,320,465,640,775,960]
112 | else:
113 | bilateral_ch = range(opt.n_channels)
114 | net = net.eval()
115 | test_len = len(test_loader)
116 | tp = np.zeros(opt.n_classes)
117 | fp = np.zeros(opt.n_classes)
118 | num = np.zeros(opt.n_classes)
119 | start_time = time.time()
120 | for iteration, batch in enumerate(test_loader):
121 | # Load Data
122 | im, label = batch
123 | im = im.cuda()
124 | label = label.cuda()
125 |
126 | # Forward Pass
127 | out = net(im)
128 |
129 | # Evaluation
130 | prob = F.softmax(out, dim=1)
131 | _, pred = torch.max(prob, dim=1)
132 |
133 | bsize = pred.size()[0]
134 |
135 | for i in range(bsize):
136 | label_np = label[i].cpu().numpy()
137 | pred_np = pred[i].cpu().numpy()
138 | for c in range(opt.n_classes):
139 | mask = (label_np == c)
140 | tp[c] += ((pred_np == c) * mask).sum()
141 | fp[c] += ((pred_np == c) * (1 - mask)).sum()
142 | num[c] += mask.sum()
143 |
144 | iou = tp / (num + fp)
145 | miou = iou.mean()
146 |
147 | return miou
148 |
149 |
if __name__ == '__main__':
    # Train on synthetic data, track mIoU on the real val/test splits after
    # every epoch and save a checkpoint every ``opt.period`` epochs.
    cv2.setNumThreads(0)

    opt = parse_args()
    print(opt)

    out_dir = Path(opt.out_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    train_set = SyntheticDataset(opt.syn_path, opt.params_path, opt.blend, opt.channels)
    train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True, pin_memory=True)

    # Real data is evaluated one image at a time in the main process.
    val_set = RealDataset(opt.real_path, opt.channels, split='val')
    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)
    test_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False)

    opt.n_classes = train_set.n_classes
    net = PowderNet(opt.arch, opt.n_channels, train_set.n_classes).cuda()
    optimizer = AdamW(net.parameters(), lr=opt.lr, weight_decay=opt.decay)
    scheduler = CosineLRWithRestarts(optimizer, opt.batch_size, len(train_set), opt.period, opt.t_mult)
    vis = Visualizer(server=opt.server, env=opt.env)
    start_epoch = 0
    if opt.resume is not None:
        checkpoint = torch.load(opt.resume)
        old_opt = checkpoint['opt']
        # A resumed run must use the settings the checkpoint was trained with.
        for setting in ('channels', 'bands', 'arch', 'blend', 'lr', 'decay', 'period', 't_mult'):
            assert(getattr(old_opt, setting) == getattr(opt, setting))

        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        vis.load_state_dict(checkpoint['vis'])
        start_epoch = checkpoint['epoch'] + 1

    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()
        if (epoch + 1) % opt.period == 0:
            snapshot = {
                'epoch': epoch,
                'opt': opt,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'vis': vis.state_dict(),
            }
            torch.save(snapshot, out_dir / (str(epoch) + '.pth'))
        print('Val mIoU:', miou_val, '  Test mIoU:', miou_test)
202 |
--------------------------------------------------------------------------------
/src/finetune_real.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | from cosine_scheduler import CosineLRWithRestarts
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | from dataset import RealDataset
13 | from visualizer import Visualizer
14 | from model import PowderNet
15 | from utils import to_image, colormap, errormap
16 | from adamw import AdamW
17 | from model import get_1x_lr_params, get_10x_lr_params
18 | import cv2
19 |
20 |
def parse_args():
    """Parse the fine-tuning command line.

    Besides the declared options, the returned namespace carries
    ``n_channels``. When ``--bands`` is given, ``channels`` is replaced by the
    list of parsed band indices and ``n_channels`` is its length; otherwise
    ``n_channels`` is ``abs(channels)``.

    Returns:
        The populated ``argparse.Namespace``.

    Exits with a usage error when ``--bands`` is combined with a non-zero
    ``--channels``.
    """
    parser = argparse.ArgumentParser(description='PowderDetector')
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--params-path', type=str, default='../params')
    parser.add_argument('--out-path', type=str, default='./checkpoint')
    parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands')
    parser.add_argument('--bands', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab')
    parser.add_argument('--threads', type=int, default=6)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=24)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--period', type=int, default=8)
    parser.add_argument('--t-mult', type=float, default=2)
    parser.add_argument('--vis-iter', type=int, default=0)
    parser.add_argument('--server', type=str, default='http://localhost')
    parser.add_argument('--env', type=str, default='main')
    opt = parser.parse_args()
    if opt.bands is not None:
        # Validate explicitly: an ``assert`` disappears under ``python -O``
        # and gives the user no usage message.
        if opt.channels != 0:
            parser.error('--bands can only be used together with --channels 0')
        opt.channels = [int(i) for i in opt.bands.split(',')]
        opt.n_channels = len(opt.channels)
    else:
        opt.n_channels = abs(opt.channels)
    return opt
50 |
51 |
def train(opt, vis, epoch, train_loader, net, optimizer, scheduler):
    """Run one training epoch with console logging and visualizer updates.

    ``scheduler.step()`` is called once before the first batch and
    ``scheduler.batch_step()`` after every optimizer step.  The loss is
    pixel-wise cross entropy between the network output and the label map.
    """
    net = net.train()
    n_batches = len(train_loader)
    tick = time.time()
    scheduler.step()
    for step, (im, label) in enumerate(train_loader):
        # Move the batch to the GPU.
        im = im.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward pass.
        out = net(im)
        loss = F.cross_entropy(out, label)

        # Backward pass and parameter / learning-rate update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()

        # Console log.  The first `opt.threads` iterations are averaged on
        # their own; afterwards the timer restarts so that the running
        # average excludes them.
        now = time.time()
        loss_value = loss.item()
        if step < opt.threads:
            timed_steps = step + 1
        else:
            timed_steps = step + 1 - opt.threads
        print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(
            opt.env, epoch, step, n_batches,
            round((now - tick) / timed_steps, 2),
            round(loss_value, 4)))
        if step == opt.threads - 1:
            tick = now

        # Loss curve (negative log of the loss, clamped at 1e-6).
        vis.iteration.append(epoch + step / n_batches)
        vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_value)))
        vis.plot_loss()

        # Image panel only every `opt.vis_iter` iterations (never if <= 0).
        show_images = opt.vis_iter > 0 and step % opt.vis_iter == 0
        if not show_images:
            continue
        _, pred = torch.max(out, dim=1)
        label_and_pred = np.concatenate(
            (colormap(label[0].cpu().numpy()), colormap(pred[0].cpu().numpy())), axis=1)
        rgb_and_nir = np.concatenate(
            (to_image(im[0, 0:3, :, :] * 0.5), to_image(im[0, 3:4, :, :] * 0.5)), axis=1)
        swir_pair = np.concatenate(
            (to_image(im[0, 4:5, :, :] * 0.5), to_image(im[0, -2:-1, :, :] * 0.5)), axis=1)
        panel = np.concatenate((label_and_pred, rgb_and_nir, swir_pair), axis=2)
        vis.plot_image(panel, 0)
104 |
105 |
def test(opt, epoch, test_loader, net):
    """Evaluate ``net`` on ``test_loader`` and return the mean IoU.

    Args:
        opt: options namespace; only ``opt.n_classes`` is read.
        epoch: unused, kept so existing call sites keep working.
        test_loader: iterable of ``(image, label)`` batches.
        net: the network; switched to eval mode.

    Returns:
        Mean over classes of ``tp / (num + fp)``, where ``tp``/``fp`` count
        true/false positive pixels and ``num`` the labelled pixels of each
        class.  A class with ``num + fp == 0`` yields NaN, as before.
    """
    net = net.eval()
    n_classes = opt.n_classes
    tp = np.zeros(n_classes)
    fp = np.zeros(n_classes)
    num = np.zeros(n_classes)
    # Pure inference: without no_grad() every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for im, label in test_loader:
            out = net(im.cuda())
            prob = F.softmax(out, dim=1)
            _, pred = torch.max(prob, dim=1)

            # Labels are only compared on the host, so they never need to
            # visit the GPU.
            label_np = label.cpu().numpy()
            pred_np = pred.cpu().numpy()
            for c in range(n_classes):
                is_c = (label_np == c)
                predicted_c = (pred_np == c)
                tp[c] += np.logical_and(predicted_c, is_c).sum()
                fp[c] += np.logical_and(predicted_c, ~is_c).sum()
                num[c] += is_c.sum()

    iou = tp / (num + fp)
    miou = iou.mean()
    return miou
150 |
151 |
if __name__ == '__main__':
    # Turn off OpenCV's internal threading (presumably so it does not clash
    # with the DataLoader workers -- TODO confirm).
    cv2.setNumThreads(0)

    opt = parse_args()
    print(opt)

    Path(opt.out_path).mkdir(parents=True, exist_ok=True)

    train_set = RealDataset(opt.real_path, opt.channels, split='trainext', flip=True)
    train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True, pin_memory=True)

    val_set = RealDataset(opt.real_path, opt.channels, split='val')
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)

    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    test_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False)

    opt.n_classes = train_set.n_classes
    net = PowderNet(opt.arch, opt.n_channels, train_set.n_classes)
    net = net.cuda()
    # Two parameter groups; the second one trains with a 10x learning rate.
    optimizer = AdamW([{'params': get_1x_lr_params(net)}, {'params': get_10x_lr_params(net), 'lr': opt.lr * 10}], lr=opt.lr, weight_decay=opt.decay)
    scheduler = CosineLRWithRestarts(optimizer, opt.batch_size, len(train_set), opt.period, opt.t_mult)
    vis = Visualizer(server=opt.server, env=opt.env)
    start_epoch = 0
    if opt.resume is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(opt.resume)
        old_opt = checkpoint['opt']
        # Resuming must continue the exact same run configuration.  Raised
        # explicitly (not via assert) so the check survives `python -O`.
        for key in ('channels', 'bands', 'arch', 'blend', 'lr', 'decay', 'period', 't_mult'):
            if getattr(old_opt, key) != getattr(opt, key):
                raise ValueError('--resume checkpoint has {}={!r} but the current run uses {!r}'.format(
                    key, getattr(old_opt, key), getattr(opt, key)))
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        vis.load_state_dict(checkpoint['vis'])
        start_epoch = checkpoint['epoch'] + 1
    elif opt.pretrain is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        # Only the weights are reused, so only the options that determine
        # the network input/architecture have to agree.
        for key in ('channels', 'bands', 'arch', 'blend'):
            if getattr(old_opt, key) != getattr(opt, key):
                raise ValueError('--pretrain checkpoint has {}={!r} but the current run uses {!r}'.format(
                    key, getattr(old_opt, key), getattr(opt, key)))
        net.load_state_dict(checkpoint['state_dict'])
    else:
        # Fine-tuning always starts from existing weights.
        raise ValueError('one of --resume or --pretrain must be given')

    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()
        # A checkpoint is written every `opt.period` epochs.
        if (epoch + 1) % opt.period == 0:
            torch.save({'epoch': epoch, 'opt': opt, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'vis': vis.state_dict()}, Path(opt.out_path) / (str(epoch) + '.pth'))
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
213 |
--------------------------------------------------------------------------------
/src/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torch.optim.lr_scheduler as lr_scheduler
9 | from cosine_scheduler import CosineLRWithRestarts
10 | from pathlib import Path
11 | from torch.utils.data import DataLoader
12 | from dataset import RealDataset, HalfHalfDataset
13 | from visualizer import Visualizer
14 | from model import PowderNet
15 | from utils import to_image, colormap, errormap
16 | from adamw import AdamW
17 | from model import get_1x_lr_params, get_10x_lr_params
18 | import cv2
19 |
20 |
def parse_args():
    """Parse the command-line options for fine-tuning on mixed data.

    Post-processing of the parsed namespace:

    * with ``--bands`` (comma-separated integers) ``channels`` becomes the
      list of band indices and ``n_channels`` its length;
    * otherwise ``n_channels`` is ``abs(channels)``.

    Returns:
        argparse.Namespace: the parsed and post-processed options.

    Exits through ``parser.error`` (``SystemExit``) when ``--bands`` is
    combined with a non-zero ``--channels``.
    """
    parser = argparse.ArgumentParser(description='PowderDetector')
    parser.add_argument('--syn-path', type=str, default='../synthetic')
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--params-path', type=str, default='../params')
    parser.add_argument('--out-path', type=str, default='./checkpoint')
    parser.add_argument('--split', type=str, default='bgext', choices=['bg', 'bgext'])
    parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands')
    parser.add_argument('--bands', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab')
    parser.add_argument('--threads', type=int, default=6)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=56)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--period', type=int, default=8)
    parser.add_argument('--t-mult', type=float, default=2)
    parser.add_argument('--vis-iter', type=int, default=0)
    parser.add_argument('--server', type=str, default='http://localhost')
    parser.add_argument('--env', type=str, default='main')
    opt = parser.parse_args()
    if opt.bands is not None:
        # A band list cannot be combined with a channel preset.  Use
        # parser.error instead of `assert`, which is removed under -O.
        if opt.channels != 0:
            parser.error('--bands can only be used together with --channels 0')
        opt.channels = [int(i) for i in opt.bands.split(',')]
        opt.n_channels = len(opt.channels)
    else:
        opt.n_channels = abs(opt.channels)
    return opt
52 |
53 |
def train(opt, vis, epoch, train_loader, net, optimizer, scheduler):
    """Train ``net`` for one epoch, logging to stdout and to ``vis``.

    The scheduler is stepped once per epoch (before the first batch) and
    once per batch (``batch_step``) after the optimizer update.  The loss is
    the cross entropy between network output and label map.
    """
    net = net.train()
    total_iters = len(train_loader)
    timer_start = time.time()
    scheduler.step()
    for it, (im, label) in enumerate(train_loader):
        # Batch to GPU.
        im = im.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward.
        out = net(im)
        loss = F.cross_entropy(out, label)

        # Backward and update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()

        # Logging: the first `opt.threads` iterations have their own
        # average, then the timer is reset so later averages exclude them.
        t_now = time.time()
        loss_val = loss.item()
        counted = (it + 1) if it < opt.threads else (it + 1 - opt.threads)
        avg_time = round((t_now - timer_start) / counted, 2)
        print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(
            opt.env, epoch, it, total_iters, avg_time, round(loss_val, 4)))
        if it == opt.threads - 1:
            timer_start = t_now

        # Loss curve: negative log of the loss, clamped at 1e-6.
        vis.iteration.append(epoch + it / total_iters)
        vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_val)))
        vis.plot_loss()

        # Image preview every `opt.vis_iter` iterations; disabled if <= 0.
        if not (opt.vis_iter > 0 and it % opt.vis_iter == 0):
            continue
        _, pred = torch.max(out, dim=1)
        sample = im[0]
        row_maps = np.concatenate(
            (colormap(label[0].cpu().numpy()), colormap(pred[0].cpu().numpy())), axis=1)
        row_visible = np.concatenate(
            (to_image(sample[0:3, :, :] * 0.5), to_image(sample[3:4, :, :] * 0.5)), axis=1)
        row_swir = np.concatenate(
            (to_image(sample[4:5, :, :] * 0.5), to_image(sample[-2:-1, :, :] * 0.5)), axis=1)
        vis.plot_image(np.concatenate((row_maps, row_visible, row_swir), axis=2), 0)
106 |
107 |
def test(opt, epoch, test_loader, net):
    """Compute the mean IoU of ``net`` over all batches of ``test_loader``.

    Args:
        opt: options namespace; only ``opt.n_classes`` is read.
        epoch: unused, retained for call-site compatibility.
        test_loader: iterable of ``(image, label)`` batches.
        net: the network; put into eval mode.

    Returns:
        ``mean_c(tp[c] / (num[c] + fp[c]))`` with per-class true positives,
        false positives and labelled-pixel counts accumulated over the whole
        loader.  Classes with a zero denominator produce NaN, as before.
    """
    net = net.eval()
    n_classes = opt.n_classes
    tp = np.zeros(n_classes)
    fp = np.zeros(n_classes)
    num = np.zeros(n_classes)
    # Evaluation needs no gradients; no_grad() avoids building an autograd
    # graph for every forward pass.
    with torch.no_grad():
        for im, label in test_loader:
            out = net(im.cuda())
            prob = F.softmax(out, dim=1)
            _, pred = torch.max(prob, dim=1)

            # The comparison happens on the host, so the labels stay there.
            label_np = label.cpu().numpy()
            pred_np = pred.cpu().numpy()
            for c in range(n_classes):
                gt_c = (label_np == c)
                pr_c = (pred_np == c)
                tp[c] += np.logical_and(pr_c, gt_c).sum()
                fp[c] += np.logical_and(pr_c, ~gt_c).sum()
                num[c] += gt_c.sum()

    iou = tp / (num + fp)
    miou = iou.mean()
    return miou
152 |
153 |
if __name__ == '__main__':
    # Turn off OpenCV's internal threading (presumably so it does not clash
    # with the DataLoader workers -- TODO confirm).
    cv2.setNumThreads(0)

    opt = parse_args()
    print(opt)

    Path(opt.out_path).mkdir(parents=True, exist_ok=True)

    train_set = HalfHalfDataset(opt.real_path, opt.syn_path, opt.params_path, opt.blend, opt.channels, opt.split)
    train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True, pin_memory=True)

    val_set = RealDataset(opt.real_path, opt.channels, split='val')
    val_loader = DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False)

    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    test_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False)

    opt.n_classes = train_set.n_classes
    net = PowderNet(opt.arch, opt.n_channels, train_set.n_classes)
    net = net.cuda()
    # Two parameter groups; the second one trains with a 10x learning rate.
    optimizer = AdamW([{'params': get_1x_lr_params(net)}, {'params': get_10x_lr_params(net), 'lr': opt.lr * 10}], lr=opt.lr, weight_decay=opt.decay)
    scheduler = CosineLRWithRestarts(optimizer, opt.batch_size, len(train_set), opt.period, opt.t_mult)
    vis = Visualizer(server=opt.server, env=opt.env)
    start_epoch = 0
    if opt.resume is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(opt.resume)
        old_opt = checkpoint['opt']
        # Resuming must continue the exact same run configuration.  Raised
        # explicitly (not via assert) so the check survives `python -O`.
        for key in ('channels', 'bands', 'arch', 'blend', 'lr', 'decay', 'period', 't_mult'):
            if getattr(old_opt, key) != getattr(opt, key):
                raise ValueError('--resume checkpoint has {}={!r} but the current run uses {!r}'.format(
                    key, getattr(old_opt, key), getattr(opt, key)))
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        vis.load_state_dict(checkpoint['vis'])
        start_epoch = checkpoint['epoch'] + 1
    elif opt.pretrain is not None:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(opt.pretrain)
        old_opt = checkpoint['opt']
        # `channels` and `bands` are intentionally not compared for
        # pre-training (those two checks were disabled in this script);
        # only architecture and blend mode have to agree.
        for key in ('arch', 'blend'):
            if getattr(old_opt, key) != getattr(opt, key):
                raise ValueError('--pretrain checkpoint has {}={!r} but the current run uses {!r}'.format(
                    key, getattr(old_opt, key), getattr(opt, key)))
        net.load_state_dict(checkpoint['state_dict'])
    else:
        # Fine-tuning always starts from existing weights.
        raise ValueError('one of --resume or --pretrain must be given')

    for epoch in range(start_epoch, opt.n_epochs):
        train(opt, vis, epoch, train_loader, net, optimizer, scheduler)
        miou_val = test(opt, epoch, val_loader, net)
        miou_test = test(opt, epoch, test_loader, net)
        vis.epoch.append(epoch)
        vis.acc.append([miou_val, miou_test])
        vis.plot_acc()
        # A checkpoint is written every `opt.period` epochs.
        if (epoch + 1) % opt.period == 0:
            torch.save({'epoch': epoch, 'opt': opt, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'vis': vis.state_dict()}, Path(opt.out_path) / (str(epoch) + '.pth'))
        print('Val mIoU:', miou_val, ' Test mIoU:', miou_test)
215 |
--------------------------------------------------------------------------------
/recog/recognition.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import cv2
7 | import time
8 |
9 | from pathlib import Path
10 |
11 |
def cosine(a, b):
    """Return the pairwise cosine similarity between rows of ``a`` and ``b``.

    Args:
        a: 2-D tensor with one query vector per row.
        b: 2-D tensor with one reference vector per row (same width as ``a``).

    Returns:
        Double-precision tensor of shape ``(a.size(0), b.size(0))`` located
        on ``a``'s device; entry ``[i, j]`` is the similarity of ``a[i]`` and
        ``b[j]``.
    """
    y = b.unsqueeze(0)
    n_pixels = a.size(0)
    # Queries are processed in chunks so the broadcast intermediate stays
    # bounded at (batch_size, b.size(0), width).
    batch_size = 1000
    # Allocate the result where the inputs live instead of building it on
    # the CPU and copying it to a hard-coded CUDA device.
    sim = torch.zeros((n_pixels, b.size(0)), dtype=torch.double, device=a.device)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim
27 |
28 |
def sim_func(opt, query, database):
    """Similarity between query rows and database rows, per ``opt.dist``.

    * ``'full'``  -- one cosine similarity over all channels.
    * ``'split'`` -- sum of two cosine similarities: one over the first
      ``opt.n_rgbns`` channels and one over the remaining channels.

    Raises:
        ValueError: if ``opt.dist`` is neither ``'full'`` nor ``'split'``.
    """
    if opt.dist == 'full':
        return cosine(query, database)
    if opt.dist == 'split':
        k = opt.n_rgbns
        return cosine(query[:, :k], database[:, :k]) + cosine(query[:, k:], database[:, k:])
    # An explicit error instead of `assert(False)`, which would be stripped
    # under -O and let the function silently return None.
    raise ValueError('unknown distance mode: {!r}'.format(opt.dist))
36 |
37 |
def sims2pred(n_lights, sims):
    """Rank candidates by how many rows of ``sims`` vote for them.

    Each row votes for the column with the highest similarity; the column
    index is integer-divided by ``n_lights`` to obtain the candidate id.
    Returns the ids ordered from most to least votes.
    """
    best_column = torch.argmax(sims, dim=1)
    ballots = (best_column // n_lights).cpu().numpy()
    tally = collections.Counter(ballots)
    return [candidate for candidate, _ in tally.most_common()]
44 |
45 |
def match_powder_none(opt, database, scene):
    """Compare scene pixels with the database directly, without blending.

    ``scene`` is viewed as one row per pixel and handed to ``sim_func``.
    """
    flat_query = scene.view((scene.size()[0], -1))
    return sim_func(opt, flat_query, database)
51 |
52 |
def match_powder_kappa(opt, database, scene, bg, kappa):
    """Match scene pixels against background-blended database entries.

    For ``opt.n_etas`` values of ``eta`` in ``[0, opt.eta_max]`` the blending
    weight is ``alpha = eta ** kappa`` and the candidate appearance is
    ``database * (1 - alpha) + bg * alpha``.  For every (pixel, entry) pair
    the best similarity over all ``eta`` values is kept.

    Returns:
        Double tensor of shape ``(n_pixels, n_database)`` on the GPU.
    """
    n_entries = database.size()[0]
    n_pixels = scene.size()[0]
    n_channels = scene.size()[1]

    etas = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double).cuda()
    # Broadcasts to n_database * n_etas * n_channels (assuming kappa holds
    # one row per database entry -- TODO confirm).
    alpha = etas.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1)

    # Singleton axes so that bg[p] broadcasts against alpha.
    bg = bg.unsqueeze(1).unsqueeze(2)

    # Database contribution, n_database * n_etas * n_channels.
    powder_part = database.unsqueeze(1) * (1 - alpha)

    # Database entries are scored in chunks of this many entries.
    chunk = 64000 // n_channels
    sims = torch.zeros((n_pixels, n_entries), dtype=torch.double).cuda()
    n_chunks = (n_entries + chunk - 1) // chunk

    for p in range(n_pixels):
        query = scene[p:p + 1, :]
        blended = powder_part + bg[p, :, :, :] * alpha
        for k in range(n_chunks):
            lo = k * chunk
            hi = min(n_entries, lo + chunk)
            # One row per (entry, eta) combination.
            candidates = blended[lo:hi, :, :].reshape(((hi - lo) * opt.n_etas, -1))
            scores = sim_func(opt, query, candidates)
            # Best eta for each entry.
            best, _ = scores.view((hi - lo, opt.n_etas)).max(1)
            sims[p, lo:hi] = best

    return sims
86 |
87 |
def parse_args():
    """Parse and validate the command-line options of the recognition script.

    Returns:
        argparse.Namespace: the parsed options.

    Exits through ``parser.error`` (``SystemExit``) when fewer than two
    channels are selected in total, or when ``--dist split`` is requested
    while either channel group has at most one channel.
    """
    parser = argparse.ArgumentParser(description='Recognition with known mask')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./log')
    parser.add_argument('--sel-path', type=str, default='../bandsel/bands')
    parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint')
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--eta-max', type=float, default=0.9)
    parser.add_argument('--n-etas', type=int, default=10)
    parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz')
    parser.add_argument('--n-swirs', type=int, default=4)
    parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4])
    parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'mvpca', 'rs'])
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    parser.add_argument('--set', type=str, default='test', choices=['test', 'val'])
    opt = parser.parse_args()

    # Validation goes through argparse instead of `assert`, which is removed
    # when Python runs with -O.
    if opt.n_rgbns + opt.n_swirs <= 1:
        parser.error('--n-rgbns plus --n-swirs must be greater than 1')
    if (opt.n_rgbns <= 1 or opt.n_swirs <= 1) and opt.dist != 'full':
        parser.error('--dist split needs more than one RGBN and more than one SWIR channel; use --dist full')

    return opt
110 |
111 |
if __name__ == '__main__':
    # Recognition driver: builds a reference table of per-powder, per-light
    # mean "thick" spectra, then classifies every labelled powder region of
    # every scene and logs top-1 / top-3 accuracy.
    opt = parse_args()

    # One log file per configuration; refuse to overwrite an existing one.
    Path(opt.log_path).mkdir(parents=True, exist_ok=True)
    log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns))
    assert(not log_fname.is_file())

    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    n_lights = len(lights)
    n_powders = 100
    n_scenes = 32

    # Indices into the 'rgbn' array for the requested number of channels.
    if opt.n_rgbns == 4:
        rgbn_channels = [0, 1, 2, 3]
    elif opt.n_rgbns == 3:
        rgbn_channels = [0, 1, 2]
    elif opt.n_rgbns == 1:
        rgbn_channels = [3]
    else:
        rgbn_channels = []

    # `swir_channels` indexes the 'swir' array; `all_channels` indexes the
    # kappa parameter table, where SWIR bands are offset by 4.
    all_channels = rgbn_channels.copy()
    swir_channels = []
    if opt.n_swirs > 0:
        if opt.sel == 'grid':
            # Regular sub-grid: band indices are treated as a 31 x 31 grid,
            # so the requested count has to be a perfect square.
            assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
            if opt.n_swirs == 1:
                swir_channels.append(480)
            else:
                decimation = int(30 // (np.sqrt(opt.n_swirs) - 1))
                for i in range(0, 31, decimation):
                    for j in range(0, 31, decimation):
                        swir_channels.append(i * 31 + j)
        else:
            # Band list from the selection file: the last line holds
            # comma-separated indices, of which the first n_swirs are used.
            sel_file = open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r')
            splited = sel_file.readlines()[-1].strip().split(',')
            sel_file.close()
            for i in splited[:opt.n_swirs]:
                swir_channels.append(int(i))
        assert(len(swir_channels) == opt.n_swirs)
        for i in swir_channels:
            all_channels.append(i + 4)

    n_channels = opt.n_rgbns + opt.n_swirs

    # NOTE(review): log_file is closed only on the last line of the script;
    # an exception in between leaves it open.
    log_file = open(log_fname, 'w')
    print(opt)
    print(opt, file=log_file)
    print(swir_channels)
    print(swir_channels, file=log_file)

    train_path = Path(opt.data_path) / 'train'
    test_path = Path(opt.data_path) / opt.set

    scene_path = test_path / 'scene'
    bgscene_path = test_path / 'bgscene'
    label_path = test_path / 'label'

    # Reference table: spatial mean of each thick-powder sample over the
    # selected channels, one row per (powder, light) pair.
    thick_list = np.zeros((n_powders, n_lights, n_channels))
    for i in range(n_powders):
        idx = str(i).zfill(2)
        for lid, light in enumerate(lights):
            thick = np.load(train_path / light / 'thick' / (idx + '_thick.npz'))
            thick = np.concatenate((thick['rgbn'][:, :, rgbn_channels], thick['swir'][:, :, swir_channels]), axis=2)
            thick = thick.mean((0, 1))
            thick_list[i, lid] = thick
    thick_list = thick_list.reshape((n_powders * n_lights, n_channels))
    thick_list = torch.from_numpy(thick_list).cuda()

    # Blending exponents: all ones for 'alpha', calibrated values for
    # 'kappa'.  For 'none' kappa stays undefined and is never used.
    if opt.blend == 'alpha':
        kappa = torch.ones((n_powders * n_lights, n_channels)).double().cuda()
    elif opt.blend == 'kappa':
        kappa_params = np.load(opt.kappa_params)
        kappa = kappa_params['params'][:, :, all_channels].reshape((n_powders * n_lights, n_channels))
        kappa = torch.from_numpy(kappa).cuda()

    # One boolean per evaluated (scene, powder) region.
    acc_top1 = []
    acc_top3 = []
    start_time = time.time()
    for i in range(n_scenes):
        idx = str(i).zfill(2)
        print()
        print('scene', idx)

        print(file=log_file)
        print('scene', idx, file=log_file)

        scene = np.load(scene_path / (idx + '_scene.npz'))
        scene = np.concatenate((scene['rgbn'][:, :, rgbn_channels], scene['swir'][:, :, swir_channels]), axis=2)
        label = cv2.imread(str(label_path / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE)
        if opt.bg == 'inpaint':
            # Estimate the background by inpainting every pixel whose label
            # is below 255.  Each channel is scaled to uint16 for
            # cv2.inpaint (using its maximum over the masked pixels) and
            # scaled back afterwards.
            mask = (label < 255).astype(np.uint8) * 255
            bgscene = scene.copy()
            for c in range(n_channels):
                scene_max = scene[mask == 255, c].max()
                bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535
        elif opt.bg == 'gt':
            # Use the stored background capture of the same scene.
            bgscene = np.load(bgscene_path / (idx + '_bgscene.npz'))
            bgscene = np.concatenate((bgscene['rgbn'][:, :, rgbn_channels], bgscene['swir'][:, :, swir_channels]), axis=2)
        else:
            assert(False)
        for powder in range(n_powders):
            # Evaluate only the powders that actually occur in this scene.
            mask = (label == powder)
            if mask.any():
                print('powder', powder)
                print('powder', powder, file=log_file)
                scene_list = scene[mask, :]
                bgscene_list = bgscene[mask, :]
                scene_list = torch.from_numpy(scene_list).cuda()
                bgscene_list = torch.from_numpy(bgscene_list).cuda()
                if opt.blend == 'none':
                    sims = match_powder_none(opt, thick_list, scene_list)
                else:
                    sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa)
                # Pixels vote; `pred` lists powder ids by decreasing votes.
                pred = sims2pred(n_lights, sims)
                top1 = (powder in pred[:1])
                top3 = (powder in pred[:3])
                acc_top1.append(top1)
                acc_top3.append(top3)
                print(pred)
                print(top1, top3)
                print('Acc:', np.mean(acc_top1), np.mean(acc_top3))
                print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's')
                print(pred, file=log_file)
                print(top1, top3, file=log_file)
                print('Acc:', np.mean(acc_top1), np.mean(acc_top3), file=log_file)
                print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's', file=log_file)
    # Final accuracies over all evaluated regions.
    print(np.mean(acc_top1), np.mean(acc_top3))
    print(np.mean(acc_top1), np.mean(acc_top3), file=log_file)
    print('Complete', file=log_file)
    log_file.close()
243 |
--------------------------------------------------------------------------------
/recog/recognition_testext.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import cv2
7 | import time
8 |
9 | from pathlib import Path
10 |
11 |
def cosine(a, b):
    """Compute the cosine similarity of every row of ``a`` with every row of ``b``.

    Args:
        a: 2-D tensor, one query vector per row.
        b: 2-D tensor, one reference vector per row (same width as ``a``).

    Returns:
        Double-precision tensor of shape ``(a.size(0), b.size(0))`` on the
        same device as ``a``.
    """
    y = b.unsqueeze(0)
    n_pixels = a.size(0)
    # Chunked over the queries to bound the size of the broadcast
    # intermediate inside cosine_similarity.
    batch_size = 1000
    # Create the output directly on the input device rather than on the CPU
    # followed by a copy to a hard-coded CUDA device.
    sim = torch.zeros((n_pixels, b.size(0)), dtype=torch.double, device=a.device)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim
27 |
28 |
def sim_func(opt, query, database):
    """Dispatch to the similarity measure selected by ``opt.dist``.

    * ``'full'``  -- cosine similarity over all channels at once.
    * ``'split'`` -- cosine similarity over the first ``opt.n_rgbns``
      channels plus cosine similarity over the remaining channels.

    Raises:
        ValueError: if ``opt.dist`` is neither ``'full'`` nor ``'split'``.
    """
    if opt.dist == 'full':
        return cosine(query, database)
    if opt.dist == 'split':
        cut = opt.n_rgbns
        head = cosine(query[:, :cut], database[:, :cut])
        tail = cosine(query[:, cut:], database[:, cut:])
        return head + tail
    # `assert(False)` would vanish under -O and make this return None.
    raise ValueError('unknown distance mode: {!r}'.format(opt.dist))
36 |
37 |
def sims2pred(n_lights, sims):
    """Turn a similarity matrix into a vote-ranked list of candidate ids.

    Every row contributes one vote: the index of its largest entry,
    integer-divided by ``n_lights``.  Ids are returned from most to least
    voted.
    """
    winners = sims.argmax(dim=1)
    votes = (winners // n_lights).cpu().numpy()
    ranking = collections.Counter(votes).most_common()
    return [entry[0] for entry in ranking]
44 |
45 |
def match_powder_none(opt, database, scene):
    """Score scene pixels against the raw database (no background blending).

    The scene is viewed as ``(n_pixels, -1)`` and passed to ``sim_func``.
    """
    pixel_count = scene.size()[0]
    return sim_func(opt, scene.view((pixel_count, -1)), database)
51 |
52 |
def match_powder_kappa(opt, database, scene, bg, kappa):
    """Score scene pixels against database entries blended with the background.

    ``opt.n_etas`` values of ``eta`` are sampled in ``[0, opt.eta_max]``; for
    each the weight ``alpha = eta ** kappa`` mixes a database entry with the
    pixel's background as ``database * (1 - alpha) + bg * alpha``.  The
    highest similarity over all ``eta`` values is stored per (pixel, entry).

    Returns:
        Double tensor of shape ``(n_pixels, n_database)`` on the GPU.
    """
    db_size = database.size()[0]
    n_pixels = scene.size()[0]
    n_channels = scene.size()[1]

    eta_grid = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double).cuda()
    # Broadcasts to n_database * n_etas * n_channels (assuming kappa has one
    # row per database entry -- TODO confirm).
    alpha = eta_grid.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1)

    # Extra singleton axes so that bg[p] broadcasts against alpha.
    bg = bg.unsqueeze(1).unsqueeze(2)

    # n_database * n_etas * n_channels
    weighted_db = database.unsqueeze(1) * (1 - alpha)

    # Number of database entries scored per chunk.
    per_chunk = 64000 // n_channels
    sims = torch.zeros((n_pixels, db_size), dtype=torch.double).cuda()
    n_chunks = (db_size + per_chunk - 1) // per_chunk

    for p in range(n_pixels):
        pixel = scene[p:p + 1, :]
        mixed = weighted_db + bg[p, :, :, :] * alpha
        for chunk_idx in range(n_chunks):
            first = chunk_idx * per_chunk
            last = min(db_size, first + per_chunk)
            span = last - first
            # Flatten (entry, eta) into rows for the similarity call.
            rows = mixed[first:last, :, :].reshape((span * opt.n_etas, -1))
            row_sims = sim_func(opt, pixel, rows)
            # Keep the best eta per entry.
            top, _ = row_sims.view((span, opt.n_etas)).max(1)
            sims[p, first:last] = top

    return sims
86 |
def chmap(channels):
    """Map band indices to their positions in the fixed channel table.

    The table lists channel numbers that are offset by 4 relative to the
    band indices passed in.  Indices without a table entry are skipped
    silently, so the result can be shorter than ``channels``.
    """
    ch_list = [4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
    # band index -> position in the table (table entries are unique).
    position_of = {ch - 4: pos for pos, ch in enumerate(ch_list)}
    return [position_of[c] for c in channels if c in position_of]
95 |
def parse_args():
    """Parse and validate the command-line options (extended test set).

    Returns:
        argparse.Namespace: the parsed options.

    Exits through ``parser.error`` (``SystemExit``) when fewer than two
    channels are selected in total, or when ``--dist split`` is requested
    while either channel group has at most one channel.
    """
    parser = argparse.ArgumentParser(description='Recognition with known mask')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./log')
    parser.add_argument('--sel-path', type=str, default='../bandsel/bands')
    parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint')
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--eta-max', type=float, default=0.9)
    parser.add_argument('--n-etas', type=int, default=10)
    parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz')
    parser.add_argument('--n-swirs', type=int, default=4)
    parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4])
    parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'mvpca', 'rs'])
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    parser.add_argument('--set', type=str, default='testext', choices=['testext'])
    opt = parser.parse_args()

    # Reported through argparse rather than `assert`, which is stripped when
    # Python runs with -O.
    if opt.n_rgbns + opt.n_swirs <= 1:
        parser.error('--n-rgbns plus --n-swirs must be greater than 1')
    if (opt.n_rgbns <= 1 or opt.n_swirs <= 1) and opt.dist != 'full':
        parser.error('--dist split needs more than one RGBN and more than one SWIR channel; use --dist full')

    return opt
118 |
119 |
120 | if __name__ == '__main__':
121 | opt = parse_args()
122 |
123 | Path(opt.log_path).mkdir(parents=True, exist_ok=True)
124 | log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns))
125 | assert(not log_fname.is_file())
126 |
127 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
128 | n_lights = len(lights)
129 | n_powders = 100
130 | n_scenes = 64
131 |
132 | if opt.n_rgbns == 4:
133 | rgbn_channels = [0, 1, 2, 3]
134 | elif opt.n_rgbns == 3:
135 | rgbn_channels = [0, 1, 2]
136 | elif opt.n_rgbns == 1:
137 | rgbn_channels = [3]
138 | else:
139 | rgbn_channels = []
140 |
141 | all_channels = rgbn_channels.copy()
142 | swir_channels = []
143 | if opt.n_swirs > 0:
144 | if opt.sel == 'grid':
145 | assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
146 | if opt.n_swirs == 1:
147 | swir_channels.append(480)
148 | else:
149 | decimation = int(30 // (np.sqrt(opt.n_swirs) - 1))
150 | for i in range(0, 31, decimation):
151 | for j in range(0, 31, decimation):
152 | swir_channels.append(i * 31 + j)
153 | else:
154 | sel_file = open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r')
155 | splited = sel_file.readlines()[-1].strip().split(',')
156 | sel_file.close()
157 | for i in splited[:opt.n_swirs]:
158 | swir_channels.append(int(i))
159 | assert(len(swir_channels) == opt.n_swirs)
160 | for i in swir_channels:
161 | all_channels.append(i + 4)
162 |
163 | n_channels = opt.n_rgbns + opt.n_swirs
164 |
165 | log_file = open(log_fname, 'w')
166 | print(opt)
167 | print(opt, file=log_file)
168 | print(swir_channels)
169 | print(swir_channels, file=log_file)
170 |
171 | train_path = Path(opt.data_path) / 'train'
172 | test_path = Path(opt.data_path) / opt.set
173 |
174 |
175 | thick_list = np.zeros((n_powders, n_lights, n_channels))
176 | for i in range(n_powders):
177 | idx = str(i).zfill(2)
178 | for lid, light in enumerate(lights):
179 | thick = np.load(train_path / light / 'thick' / (idx + '_thick.npz'))
180 | thick = np.concatenate((thick['rgbn'][:, :, rgbn_channels], thick['swir'][:, :, swir_channels]), axis=2)
181 | thick = thick.mean((0, 1))
182 | thick_list[i, lid] = thick
183 | thick_list = thick_list.reshape((n_powders * n_lights, n_channels))
184 | thick_list = torch.from_numpy(thick_list).cuda()
185 |
186 | if opt.blend == 'alpha':
187 | kappa = torch.ones((n_powders * n_lights, n_channels)).double().cuda()
188 | elif opt.blend == 'kappa':
189 | kappa_params = np.load(opt.kappa_params)
190 | kappa = kappa_params['params'][:, :, all_channels].reshape((n_powders * n_lights, n_channels))
191 | kappa = torch.from_numpy(kappa).cuda()
192 |
193 | acc_top1 = []
194 | acc_top3 = []
195 | start_time = time.time()
196 | for i in range(n_scenes):
197 | idx = str(i).zfill(2)
198 | print()
199 | print('scene', idx)
200 | print(file=log_file)
201 | print('scene', idx, file=log_file)
202 |
203 | scene_path = test_path / 'scene'
204 | bgscene_path = test_path / 'bgscene'
205 | label_path = test_path / 'label'
206 |
207 | scene = np.load(scene_path / (idx + '_scene.npz'))
208 | scene = np.concatenate((scene['rgbn'][:, :, rgbn_channels], scene['swir'][:, :, chmap(swir_channels)]), axis=2)
209 | label = cv2.imread(str(label_path / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE)
210 | if opt.bg == 'inpaint':
211 | mask = (label < 255).astype(np.uint8) * 255
212 | bgscene = scene.copy()
213 | for c in range(n_channels):
214 | scene_max = scene[mask == 255, c].max()
215 | bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535
216 | elif opt.bg == 'gt':
217 | bgscene = np.load(bgscene_path / (idx + '_bgscene.npz'))
218 | bgscene = np.concatenate((bgscene['rgbn'][:, :, rgbn_channels], bgscene['swir'][:, :, chmap(swir_channels)]), axis=2)
219 | else:
220 | assert(False)
221 | for powder in range(n_powders):
222 | mask = (label == powder)
223 | if mask.any():
224 | print('powder', powder)
225 | print('powder', powder, file=log_file)
226 | scene_list = scene[mask, :]
227 | bgscene_list = bgscene[mask, :]
228 | scene_list = torch.from_numpy(scene_list).cuda()
229 | bgscene_list = torch.from_numpy(bgscene_list).cuda()
230 | if opt.blend == 'none':
231 | sims = match_powder_none(opt, thick_list, scene_list)
232 | else:
233 | sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa)
234 | pred = sims2pred(n_lights, sims)
235 | top1 = (powder in pred[:1])
236 | top3 = (powder in pred[:3])
237 | acc_top1.append(top1)
238 | acc_top3.append(top3)
239 | print(pred)
240 | print(top1, top3)
241 | print('Acc:', np.mean(acc_top1), np.mean(acc_top3))
242 | print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's')
243 | print(pred, file=log_file)
244 | print(top1, top3, file=log_file)
245 | print('Acc:', np.mean(acc_top1), np.mean(acc_top3), file=log_file)
246 | print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's', file=log_file)
247 | print(np.mean(acc_top1), np.mean(acc_top3))
248 | print(np.mean(acc_top1), np.mean(acc_top3), file=log_file)
249 | print('Complete', file=log_file)
250 | log_file.close()
251 |
--------------------------------------------------------------------------------
/recog/recognition_trainext.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import cv2
7 | import time
8 |
9 | from pathlib import Path
10 |
11 |
def cosine(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    Args:
        a: 2-D tensor of query vectors, one per row.
        b: 2-D tensor of reference vectors, one per row, with the same number
            of columns as ``a``.

    Returns:
        A double-precision tensor of shape ``(a.size(0), b.size(0))`` located
        on the same device as ``a``; entry ``[i, j]`` is the cosine similarity
        between ``a[i]`` and ``b[j]``.
    """
    refs = b.unsqueeze(0)
    n_pixels = a.size(0)
    batch_size = 1000
    # Allocate on the input's device instead of unconditionally on CUDA, so the
    # function also runs on CPU-only hosts; for CUDA inputs nothing changes.
    sim = torch.zeros((n_pixels, b.size(0)), dtype=torch.double, device=a.device)
    # Compare the queries in chunks so that each broadcast comparison involves
    # at most ``batch_size`` query rows.
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        sim[bs:be, :] = F.cosine_similarity(a[bs:be, :].unsqueeze(1), refs, dim=2)
    return sim
27 |
28 |
def sim_func(opt, query, database):
    """Compute query-vs-database similarity according to ``opt.dist``.

    ``'full'`` takes one cosine similarity over all channels. ``'split'`` adds
    the cosine similarity of the first ``opt.n_rgbns`` channels to the cosine
    similarity of the remaining channels. Any other value trips an assertion.
    """
    mode = opt.dist
    if mode == 'split':
        n_rgbn = opt.n_rgbns
        rgbn_sim = cosine(query[:, :n_rgbn], database[:, :n_rgbn])
        swir_sim = cosine(query[:, n_rgbn:], database[:, n_rgbn:])
        return rgbn_sim + swir_sim
    if mode == 'full':
        return cosine(query, database)
    assert(False)
36 |
37 |
def sims2pred(n_lights, sims):
    """Rank powders by how many rows of ``sims`` vote for them.

    Each row votes for the column with the highest similarity; the column
    index is integer-divided by ``n_lights`` to obtain a powder id. Returns
    the powder ids ordered from most to least votes.
    """
    best_entry = sims.argmax(dim=1)
    powder_votes = (best_entry // n_lights).cpu().numpy()
    tally = collections.Counter(powder_votes)
    return [powder for powder, _ in tally.most_common()]
44 |
45 |
def match_powder_none(opt, database, scene):
    """Score scene pixels directly against the database, with no blending.

    ``scene`` is flattened to one row per pixel and handed to ``sim_func``
    together with ``database``; the resulting similarity matrix is returned.
    """
    flat_query = scene.view((scene.size(0), -1))
    return sim_func(opt, flat_query, database)
51 |
52 |
def match_powder_kappa(opt, database, scene, bg, kappa):
    """Score each pixel against every database entry blended over its background.

    For every pixel, each database spectrum is mixed with that pixel's
    background using ``alpha = eta ** kappa`` for ``opt.n_etas`` values of
    ``eta`` evenly spaced in ``[0, opt.eta_max]``; the similarity kept for a
    (pixel, entry) pair is the maximum over those ``eta`` values.

    Args:
        opt: Options; ``eta_max`` and ``n_etas`` are read here and ``opt`` is
            forwarded to ``sim_func``.
        database: Reference spectra; dimension 0 is the database size.
        scene: Observed pixels, indexed as ``(n_pixels, n_channels)``.
        bg: Background estimate for the same pixels (assumed to have the same
            shape as ``scene`` -- TODO confirm against caller).
        kappa: Per-entry exponents (assumed ``(n_database, n_channels)`` so
            that it broadcasts against ``database`` -- TODO confirm).

    Returns:
        A double tensor of shape ``(n_pixels, n_database)`` allocated on CUDA.
    """
    n_database = database.size()[0]
    n_pixels = scene.size()[0]
    n_channels = scene.size()[1]

    # eta -> (1, n_etas, 1); kappa -> (n_database, 1, n_channels); the power
    # broadcasts to one alpha per database entry, eta value and channel.
    eta = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double).cuda()
    alpha = eta.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1)

    # n_pixels * n_database * n_etas * n_channels
    bg = bg.unsqueeze(1).unsqueeze(2)

    # n_database * n_etas * n_channels
    # Pre-compute the powder contribution (1 - alpha) * database once; the
    # per-pixel background contribution is added inside the loop below.
    database = database.unsqueeze(1) * (1 - alpha)

    # Batch over database entries; the batch shrinks as the channel count grows.
    batch_size = 64000 // n_channels
    sims = torch.zeros((n_pixels, n_database), dtype=torch.double).cuda()
    # Ceil division: number of database batches.
    if n_database % batch_size == 0:
        n_batches = n_database // batch_size
    else:
        n_batches = n_database // batch_size + 1

    for p in range(n_pixels):
        # Keep the pixel as a single-row 2-D query.
        query = scene[p:p+1, :]
        # Blended candidates for this pixel: powder part + alpha * background.
        db = database + bg[p, :, :, :] * alpha
        for batch_idx in range(n_batches):
            bs = batch_idx * batch_size
            be = min(n_database, (batch_idx + 1) * batch_size)
            # Flatten (entries, etas) into rows so sim_func sees a 2-D database.
            cur_database = db[bs:be, :, :].reshape(((be - bs) * opt.n_etas, -1))
            cur_sims = sim_func(opt, query, cur_database)
            # Regroup by entry and keep the best-matching eta for each one.
            cur_sims, _ = cur_sims.view((be - bs, opt.n_etas)).max(1)
            sims[p, bs:be] = cur_sims

    return sims
86 |
def chmap(channels):
    """Translate SWIR band indices into positions within the reduced band list.

    ``ch_list`` holds band numbers offset by 4 relative to the values in
    ``channels``. Each requested channel found in the list is replaced by its
    position there; channels that are not in the list are dropped. The output
    follows the order of ``channels``.
    """
    ch_list = [4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
    # Entries of ch_list are unique, so each channel has at most one position.
    position_of = {band - 4: pos for pos, band in enumerate(ch_list)}
    return [position_of[c] for c in channels if c in position_of]
95 |
def parse_args():
    """Parse and validate the command-line options of the recognition script.

    Returns:
        The ``argparse.Namespace`` with the parsed options.

    Raises:
        AssertionError: if fewer than two channels are selected in total, or
            if ``--dist split`` is requested while either the RGBN or the SWIR
            group has at most one channel.
    """
    parser = argparse.ArgumentParser(description='Recognition with known mask')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./log')
    parser.add_argument('--sel-path', type=str, default='../bandsel/bands')
    parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint')
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--eta-max', type=float, default=0.9)
    parser.add_argument('--n-etas', type=int, default=10)
    parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz')
    parser.add_argument('--n-swirs', type=int, default=4)
    parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4])
    # 'uniform' is handled by the band-selection code of this script, so it
    # has to be an accepted choice; otherwise that branch is unreachable.
    parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'uniform', 'mvpca', 'rs'])
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    parser.add_argument('--set', type=str, default='trainext', choices=['trainext'])
    opt = parser.parse_args()

    assert opt.n_rgbns + opt.n_swirs > 1, 'at least two channels are required'
    # The split distance takes a cosine similarity per group, which needs
    # more than one channel in each group.
    if opt.n_rgbns <= 1 or opt.n_swirs <= 1:
        assert opt.dist == 'full', "--dist must be 'full' when a channel group has at most one channel"

    return opt
118 |
119 |
if __name__ == '__main__':
    # Script entry point: recognise the powder under each labelled mask of the
    # selected scene set by matching its pixels against per-powder, per-light
    # reference spectra, and log running top-1 / top-3 accuracy.
    opt = parse_args()

    Path(opt.log_path).mkdir(parents=True, exist_ok=True)
    log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns))
    # Refuse to overwrite the log of an earlier run with the same settings.
    assert(not log_fname.is_file())

    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    n_lights = len(lights)
    n_powders = 100
    n_scenes_per_light = 16

    # Which of the four 'rgbn' channels to use (a single channel means index 3).
    if opt.n_rgbns == 4:
        rgbn_channels = [0, 1, 2, 3]
    elif opt.n_rgbns == 3:
        rgbn_channels = [0, 1, 2]
    elif opt.n_rgbns == 1:
        rgbn_channels = [3]
    else:
        rgbn_channels = []

    # all_channels indexes the kappa parameters: the rgbn channels followed by
    # the SWIR bands shifted by 4.
    all_channels = rgbn_channels.copy()
    swir_channels = []
    if opt.n_swirs > 0:
        if opt.sel == 'grid':
            # Regular sub-grid of a 31 x 31 band layout that includes both
            # ends of each axis; n_swirs must be a perfect square.
            assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
            if opt.n_swirs == 1:
                swir_channels.append(480)
            else:
                decimation = int(30 // (np.sqrt(opt.n_swirs) - 1))
                for i in range(0, 31, decimation):
                    for j in range(0, 31, decimation):
                        swir_channels.append(i * 31 + j)
        elif opt.sel == 'uniform':
            # Cell-centred samples of an a x a partition of the 31 x 31 layout.
            # Only reachable when 'uniform' is an accepted --sel value in
            # parse_args.
            assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
            if opt.n_swirs == 1:
                swir_channels.append(480)
            else:
                a = int(np.sqrt(opt.n_swirs))
                for i in range(0, a):
                    for j in range(0, a):
                        x = int(np.floor(31*(2*i+1)/2/a))
                        y = int(np.floor(31*(2*j+1)/2/a))
                        swir_channels.append(x * 31 + y)
        else:
            # Band list produced by a selection method: take the first n_swirs
            # comma-separated indices from the last line of <sel>.txt.
            sel_file = open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r')
            splited = sel_file.readlines()[-1].strip().split(',')
            sel_file.close()
            for i in splited[:opt.n_swirs]:
                swir_channels.append(int(i))
        assert(len(swir_channels) == opt.n_swirs)
        for i in swir_channels:
            all_channels.append(i + 4)

    n_channels = opt.n_rgbns + opt.n_swirs

    # Everything printed to stdout is mirrored into the log file.
    log_file = open(log_fname, 'w')
    print(opt)
    print(opt, file=log_file)
    print(swir_channels)
    print(swir_channels, file=log_file)

    train_path = Path(opt.data_path) / 'train'
    test_path = Path(opt.data_path) / opt.set


    # Reference database: the spatial mean of each powder's 'thick' sample
    # under each light, flattened to (n_powders * n_lights, n_channels) with
    # the light index varying fastest.
    thick_list = np.zeros((n_powders, n_lights, n_channels))
    for i in range(n_powders):
        idx = str(i).zfill(2)
        for lid, light in enumerate(lights):
            thick = np.load(train_path / light / 'thick' / (idx + '_thick.npz'))
            thick = np.concatenate((thick['rgbn'][:, :, rgbn_channels], thick['swir'][:, :, swir_channels]), axis=2)
            thick = thick.mean((0, 1))
            thick_list[i, lid] = thick
    thick_list = thick_list.reshape((n_powders * n_lights, n_channels))
    thick_list = torch.from_numpy(thick_list).cuda()

    # 'alpha' blending uses an exponent of one everywhere; 'kappa' loads
    # calibrated exponents. With blend == 'none', kappa stays undefined and is
    # never read (see the blend check in the loop below).
    if opt.blend == 'alpha':
        kappa = torch.ones((n_powders * n_lights, n_channels)).double().cuda()
    elif opt.blend == 'kappa':
        kappa_params = np.load(opt.kappa_params)
        # presumably 'params' is (n_powders, n_lights, n_all_channels) -- TODO confirm
        kappa = kappa_params['params'][:, :, all_channels].reshape((n_powders * n_lights, n_channels))
        kappa = torch.from_numpy(kappa).cuda()

    acc_top1 = []
    acc_top3 = []
    start_time = time.time()
    for lid, light in enumerate(lights):
        for i in range(n_scenes_per_light):
            idx = str(i).zfill(2)
            print()
            print('scene', light, idx)

            print(file=log_file)
            print('scene', light, idx, file=log_file)

            scene_path = test_path / light / 'scene'
            bgscene_path = test_path / light / 'bgscene'
            label_path = test_path / light / 'label'

            # The scene files of this set store a reduced SWIR band list, so
            # band indices are translated through chmap before indexing.
            scene = np.load(scene_path / (idx + '_scene.npz'))
            scene = np.concatenate((scene['rgbn'][:, :, rgbn_channels], scene['swir'][:, :, chmap(swir_channels)]), axis=2)
            label = cv2.imread(str(label_path / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE)
            if opt.bg == 'inpaint':
                # Estimate the background by inpainting every labelled pixel
                # (label < 255), one channel at a time, via a 16-bit image.
                mask = (label < 255).astype(np.uint8) * 255
                bgscene = scene.copy()
                for c in range(n_channels):
                    # NOTE(review): the scale is the maximum over masked pixels
                    # only; brighter unmasked pixels would exceed 65535 before
                    # the uint16 cast -- confirm this is intended.
                    scene_max = scene[mask == 255, c].max()
                    bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535
            elif opt.bg == 'gt':
                # Use the stored background scene with the same channel selection.
                bgscene = np.load(bgscene_path / (idx + '_bgscene.npz'))
                bgscene = np.concatenate((bgscene['rgbn'][:, :, rgbn_channels], bgscene['swir'][:, :, chmap(swir_channels)]), axis=2)
            else:
                assert(False)
            for powder in range(n_powders):
                # From here on, mask selects the pixels labelled with this powder.
                mask = (label == powder)
                if mask.any():
                    print('powder', powder)
                    print('powder', powder, file=log_file)
                    scene_list = scene[mask, :]
                    bgscene_list = bgscene[mask, :]
                    scene_list = torch.from_numpy(scene_list).cuda()
                    bgscene_list = torch.from_numpy(bgscene_list).cuda()
                    if opt.blend == 'none':
                        sims = match_powder_none(opt, thick_list, scene_list)
                    else:
                        sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa)
                    # Powders ranked by per-pixel votes; a hit means the true
                    # powder is among the first one / three.
                    pred = sims2pred(n_lights, sims)
                    top1 = (powder in pred[:1])
                    top3 = (powder in pred[:3])
                    acc_top1.append(top1)
                    acc_top3.append(top3)
                    print(pred)
                    print(top1, top3)
                    print('Acc:', np.mean(acc_top1), np.mean(acc_top3))
                    print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's')
                    print(pred, file=log_file)
                    print(top1, top3, file=log_file)
                    print('Acc:', np.mean(acc_top1), np.mean(acc_top3), file=log_file)
                    print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's', file=log_file)
    # Final top-1 / top-3 accuracy over all evaluated powder masks.
    print(np.mean(acc_top1), np.mean(acc_top3))
    print(np.mean(acc_top1), np.mean(acc_top3), file=log_file)
    print('Complete', file=log_file)
    log_file.close()
264 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import torch.utils.data as data
4 | import torch
5 | import cv2
6 | import random
7 | import scipy.special
8 | from pathlib import Path
9 |
10 |
class SyntheticDataset(data.Dataset):
    """Dataset that synthesizes labelled powder-on-background images on the fly.

    Every item blends one stored powder patch over a randomly chosen
    background, both read from ``<data_path>/<light>.hdf5``, after random
    flips and log-normal jitter of thickness, powder appearance and shading.

    Args:
        data_path: Directory holding one ``<light>.hdf5`` file per light source.
        params_path: Directory holding ``kappa_params.npz`` (read only when
            ``blend == 'kappa'``).
        blend: Blending model, one of ``'none'``, ``'alpha'`` or ``'kappa'``.
        channels: Either an int ``n`` (``n > 0`` selects the first ``n``
            channels, otherwise the last ``abs(n)`` of 965) or a sequence of
            explicit channel indices.
    """

    def __init__(self, data_path, params_path, blend, channels):
        super(SyntheticDataset, self).__init__()
        assert(blend in ['none', 'alpha', 'kappa'])
        self.data_path = Path(data_path)
        self.blend = blend
        self.lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
        self.n_lights = len(self.lights)
        self.n_powders = 100
        self.height = 160
        self.width = 280
        self.channels = channels

        # An int selects a contiguous channel range (ch_begin:ch_end); a
        # sequence selects explicit indices and leaves the range unset.
        if type(channels) is int:
            self.channel = abs(channels)
            if channels > 0:
                self.ch_begin = 0
                self.ch_end = channels
            else:
                self.ch_begin = 965 + channels
                self.ch_end = 965
        else:
            self.channel = len(channels)
            self.ch_begin = None
            self.ch_end = None

        self.thickness_threshold = 0.1
        self.n_classes = 100 + 1
        self.n_per_light = 1000
        self.thick_sigma = 0.1
        self.shad_sigma = 0.1
        self.brdf_sigma = 0.1
        if blend == 'kappa':
            if self.ch_begin is None:
                self.kappa = np.load(Path(params_path) / 'kappa_params.npz')['params'][:, :, self.channels]
            else:
                self.kappa = np.load(Path(params_path) / 'kappa_params.npz')['params'][:, :, self.ch_begin:self.ch_end]
        else:
            self.kappa = None

    def __getitem__(self, index):
        """Return ``(im, label)`` for synthetic sample ``index``.

        ``im`` is the blended image with the channel axis first; ``label`` is
        an int64 map in which the value 255 has been replaced by
        ``n_classes - 1``.
        """
        lid = index // self.n_per_light
        light = self.lights[lid]
        powder_idx = index % self.n_per_light
        bg_idx = random.randint(0, self.n_per_light - 1)
        # The context manager closes the HDF5 file even if a read raises.
        with h5py.File(self.data_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                bg = h5file['bg'][bg_idx, :, :, self.channels].astype(np.float32)
                powder = h5file['powder'][powder_idx, :, :, self.channels].astype(np.float32)
            else:
                bg = h5file['bg'][bg_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
                powder = h5file['powder'][powder_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            shading = h5file['shading'][bg_idx].astype(np.float32)
            label = h5file['label'][powder_idx]
            thickness = h5file['thickness'][powder_idx].astype(np.float32)

        # Background (with its shading) and powder (with its label and
        # thickness) are flipped independently of each other.
        if random.randint(0, 1) == 1:
            bg = np.fliplr(bg)
            shading = np.fliplr(shading)
        if random.randint(0, 1) == 1:
            bg = np.flipud(bg)
            shading = np.flipud(shading)
        if random.randint(0, 1) == 1:
            powder = np.fliplr(powder)
            label = np.fliplr(label)
            thickness = np.fliplr(thickness)
        if random.randint(0, 1) == 1:
            powder = np.flipud(powder)
            label = np.flipud(label)
            thickness = np.flipud(thickness)

        # One random scale per powder class for thickness and for appearance.
        for i in range(self.n_powders):
            mask = (label == i)
            thickness[mask] = thickness[mask] * self.exp_gauss(self.thick_sigma)
            powder[mask] = powder[mask] * self.exp_gauss(self.brdf_sigma)
        # Powder thinner than the threshold is labelled 255.
        label[thickness < self.thickness_threshold] = 255

        # alpha is the weight of the background in the blend below.
        if self.blend == 'none':
            thickness = (thickness >= self.thickness_threshold).astype(np.float32)
            thickness = thickness[:, :, np.newaxis]
            alpha = 1 - thickness
        elif self.blend == 'alpha':
            thickness[thickness > 1] = 1
            thickness = thickness[:, :, np.newaxis]
            alpha = 1 - thickness
        elif self.blend == 'kappa':
            thickness[thickness > 1] = 1
            thickness = thickness[:, :, np.newaxis]
            alpha = np.ones((self.height, self.width, self.channel), dtype=np.float32)
            for i in range(self.n_powders):
                mask = (label == i)
                alpha[mask, :] = (1 - thickness[mask, :]) ** self.kappa[i, lid, :][np.newaxis, :]
        im = alpha * bg + (1 - alpha) * powder
        # Apply the background's shading, normalised by its median, followed
        # by one global random scale.
        med_shad = np.median(shading)
        im = im * shading[:, :, np.newaxis] / med_shad
        im = im * self.exp_gauss(self.shad_sigma)

        im = im.transpose([2, 0, 1])
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64)
        return im, label

    def exp_gauss(self, sigma):
        """Return ``exp(g)`` for ``g ~ N(0, sigma)`` as a Python float.

        A plain float keeps float32 arrays in float32 when they are multiplied
        by the result; an ``np.float64`` scalar would promote them to float64
        under NumPy 2 promotion rules.
        """
        return float(np.exp(random.gauss(0, sigma)))

    def __len__(self):
        return self.n_lights * self.n_per_light
120 |
121 |
class RealDataset(data.Dataset):
    """Dataset of real captured images stored in ``<data_path>/<split>.hdf5``.

    Args:
        data_path: Directory containing the HDF5 file of the split.
        channels: Either an int ``n`` (``n > 0`` selects the first ``n``
            channels, otherwise the last ``abs(n)`` of 965) or a sequence of
            explicit channel indices. The ``'trainext'`` and ``'testext'``
            splits require a sequence, which is translated through ``chmap``.
        split: Name of the split; also the stem of the HDF5 file.
        flip: When true, apply random horizontal / vertical flips.
    """

    def __init__(self, data_path, channels, split, flip=False):
        super(RealDataset, self).__init__()
        self.data_path = Path(data_path)
        self.n_classes = 100 + 1
        self.split = split
        self.flip = flip

        if split == 'trainext' or split == 'testext':
            # These splits store a reduced band list, so the requested band
            # numbers are mapped to positions in that list.
            assert(type(channels) is not int)
            self.n_images = 64
            self.channels = self.chmap(channels)
            self.channel = len(self.channels)
            self.ch_begin = None
            self.ch_end = None
        else:
            self.n_images = 32
            self.channels = channels
            if type(channels) is int:
                self.channel = abs(channels)
                if channels > 0:
                    self.ch_begin = 0
                    self.ch_end = channels
                else:
                    self.ch_begin = 965 + channels
                    self.ch_end = 965
            else:
                self.channel = len(channels)
                self.ch_begin = None
                self.ch_end = None

    def __getitem__(self, index):
        """Return ``(im, label)`` for image ``index``.

        ``im`` is float32 with the channel axis first; ``label`` is an int64
        map in which the value 255 has been replaced by ``n_classes - 1``.
        """
        # The context manager closes the HDF5 file even if a read raises.
        with h5py.File(self.data_path / (self.split + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                im = h5file['im'][index, :, :, self.channels].astype(np.float32)
            else:
                im = h5file['im'][index, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            label = h5file['label'][index]

        if self.flip:
            if random.randint(0, 1) == 1:
                im = np.fliplr(im)
                label = np.fliplr(label)
            if random.randint(0, 1) == 1:
                im = np.flipud(im)
                label = np.flipud(label)

        # The flips return views; copy so the returned arrays own their data.
        im = im.transpose([2, 0, 1]).copy()
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64).copy()
        return im, label

    def __len__(self):
        return self.n_images

    def chmap(self, channels):
        """Map band numbers to their positions in the reduced band list.

        The positions are returned in the order of the reduced list. Every
        requested channel must be present in it, otherwise the assertion
        fails.
        """
        ch_list = [0, 1, 2, 3, 4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
        # Build the membership set once instead of once per list entry.
        wanted = set(channels)
        mapped = [i for i, ch in enumerate(ch_list) if ch in wanted]
        print(channels, mapped)
        assert(len(channels) == len(mapped))
        return mapped
188 |
189 |
class HalfHalfDataset(data.Dataset):
    """Dataset that blends stored powder patches over real background images.

    Backgrounds come from ``<real_path>/<split>/<light>.hdf5`` and the powder
    patches, labels and thickness maps from ``<syn_path>/<light>.hdf5``.

    Args:
        real_path: Root directory of the real background data.
        syn_path: Directory with the per-light powder HDF5 files.
        params_path: Directory holding ``kappa_params.npz`` (read only when
            ``blend == 'kappa'``).
        blend: Blending model, one of ``'none'``, ``'alpha'`` or ``'kappa'``.
        channels: Either an int ``n`` (``n > 0`` selects the first ``n``
            channels, otherwise the last ``abs(n)`` of 965) or a sequence of
            explicit channel indices. The ``'bgext'`` split requires a
            sequence.
        split: Sub-directory of ``real_path`` to read backgrounds from.
    """

    def __init__(self, real_path, syn_path, params_path, blend, channels, split):
        super(HalfHalfDataset, self).__init__()
        assert(blend in ['none', 'alpha', 'kappa'])
        self.real_path = Path(real_path) / split
        self.syn_path = Path(syn_path)
        self.blend = blend
        self.lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
        self.n_lights = len(self.lights)
        self.n_powders = 100
        self.height = 160
        self.width = 280
        self.channels = channels
        if split == 'bgext':
            # This split stores a reduced band list, so background channels
            # are addressed by position in that list.
            assert(type(channels) is not int)
            self.n_bg_per_light = 32
            self.bg_channels = self.chmap(channels)
            self.channel = len(self.channels)
            self.ch_begin = None
            self.ch_end = None
        else:
            self.bg_channels = self.channels
            self.n_bg_per_light = 16
            if type(channels) is int:
                self.channel = abs(channels)
                if channels > 0:
                    self.ch_begin = 0
                    self.ch_end = channels
                else:
                    self.ch_begin = 965 + channels
                    self.ch_end = 965
            else:
                self.channel = len(channels)
                self.ch_begin = None
                self.ch_end = None
        self.thickness_threshold = 0.1
        self.n_classes = 100 + 1
        self.n_powder_per_light = 1000
        self.thick_sigma = 0.1
        self.brdf_sigma = 0.15
        if blend == 'kappa':
            if self.ch_begin is None:
                self.kappa = np.load(Path(params_path) / 'kappa_params.npz')['params'][:, :, self.channels]
            else:
                self.kappa = np.load(Path(params_path) / 'kappa_params.npz')['params'][:, :, self.ch_begin:self.ch_end]
        else:
            self.kappa = None

    def __getitem__(self, index):
        """Return ``(im, label)`` for synthetic sample ``index``.

        ``im`` is the blended image with the channel axis first; ``label`` is
        an int64 map in which the value 255 has been replaced by
        ``n_classes - 1``.
        """
        lid = index // self.n_powder_per_light
        light = self.lights[lid]
        powder_idx = index % self.n_powder_per_light
        bg_idx = random.randint(0, self.n_bg_per_light - 1)
        # The context managers close each HDF5 file even if a read raises.
        with h5py.File(self.real_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                bg = h5file['im'][bg_idx, :, :, self.bg_channels].astype(np.float32)
            else:
                bg = h5file['im'][bg_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
        with h5py.File(self.syn_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                powder = h5file['powder'][powder_idx, :, :, self.channels].astype(np.float32)
            else:
                powder = h5file['powder'][powder_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            label = h5file['label'][powder_idx]
            thickness = h5file['thickness'][powder_idx].astype(np.float32)

        # Background and powder (with its label and thickness) are flipped
        # independently of each other.
        if random.randint(0, 1) == 1:
            bg = np.fliplr(bg)
        if random.randint(0, 1) == 1:
            bg = np.flipud(bg)
        if random.randint(0, 1) == 1:
            powder = np.fliplr(powder)
            label = np.fliplr(label)
            thickness = np.fliplr(thickness)
        if random.randint(0, 1) == 1:
            powder = np.flipud(powder)
            label = np.flipud(label)
            thickness = np.flipud(thickness)

        # One random scale per powder class for thickness and for appearance.
        for i in range(self.n_powders):
            mask = (label == i)
            thickness[mask] = thickness[mask] * self.exp_gauss(self.thick_sigma)
            powder[mask] = powder[mask] * self.exp_gauss(self.brdf_sigma)
        # Powder thinner than the threshold is labelled 255.
        label[thickness < self.thickness_threshold] = 255

        # alpha is the weight of the background in the blend below.
        if self.blend == 'none':
            thickness = (thickness >= self.thickness_threshold).astype(np.float32)
            thickness = thickness[:, :, np.newaxis]
            alpha = 1 - thickness
        elif self.blend == 'alpha':
            thickness[thickness > 1] = 1
            thickness = thickness[:, :, np.newaxis]
            alpha = 1 - thickness
        elif self.blend == 'kappa':
            thickness[thickness > 1] = 1
            thickness = thickness[:, :, np.newaxis]
            alpha = np.ones((self.height, self.width, self.channel), dtype=np.float32)
            for i in range(self.n_powders):
                mask = (label == i)
                alpha[mask, :] = (1 - thickness[mask, :]) ** self.kappa[i, lid, :][np.newaxis, :]
        im = alpha * bg + (1 - alpha) * powder

        im = im.transpose([2, 0, 1])
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64)
        return im, label

    def exp_gauss(self, sigma):
        """Return ``exp(g)`` for ``g`` drawn from ``N(0, sigma)``."""
        return np.exp(random.gauss(0, sigma))

    def __len__(self):
        return self.n_lights * self.n_powder_per_light

    def chmap(self, channels):
        """Map band numbers to their positions in the reduced band list.

        The positions are returned in the order of the reduced list. Every
        requested channel must be present in it, otherwise the assertion
        fails.
        """
        ch_list = [0, 1, 2, 3, 4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
        # Build the membership set once instead of once per list entry.
        wanted = set(channels)
        mapped = [i for i, ch in enumerate(ch_list) if ch in wanted]
        print(channels, mapped)
        assert(len(channels) == len(mapped))
        return mapped
315 |
--------------------------------------------------------------------------------
/src/deeplab/deeplab_xception.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 |
7 |
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution."""

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()

        # groups == in_channels gives one spatial filter per input channel.
        self.conv1 = nn.Conv2d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        return self.pointwise(self.conv1(x))
20 |
21 |
def fixed_padding(inputs, kernel_size, rate):
    """Zero-pad the last two dimensions for a dilated ``kernel_size`` kernel.

    The total padding per dimension equals the effective (dilated) kernel
    extent minus one; when that total is odd, the extra element goes at the
    end.
    """
    # Extent covered by a kernel of size k with dilation `rate`.
    effective = kernel_size + (kernel_size - 1) * (rate - 1)
    total = effective - 1
    before = total // 2
    after = total - before
    # The same (before, after) pair is applied to both of the last two dims.
    return F.pad(inputs, (before, after, before, after))
29 |
30 |
class SeparableConv2d_same(nn.Module):
    """Separable convolution that pads its input explicitly in ``forward``.

    The depthwise convolution itself uses zero padding; ``fixed_padding`` is
    applied to the input first, based on the kernel size and dilation.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=inplanes,
            out_channels=inplanes,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=inplanes,
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        depthwise = self.conv1
        padded = fixed_padding(x, depthwise.kernel_size[0], rate=depthwise.dilation[0])
        return self.pointwise(depthwise(padded))
44 |
45 |
class Block(nn.Module):
    """Residual block built from ReLU / separable-conv / GroupNorm units.

    The main branch stacks ``reps`` such units, changing the width from
    ``inplanes`` to ``planes`` in the first unit (``grow_first``) or in the
    last one, and ends with a strided separable convolution when
    ``stride != 1``. The shortcut is the identity unless the width or the
    stride changes, in which case it is a 1x1 convolution plus GroupNorm.
    """

    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True):
        super().__init__()

        needs_projection = planes != inplanes or stride != 1
        if needs_projection:
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            # GroupNorm is used for normalisation throughout this block.
            self.skipbn = nn.GroupNorm(8, planes)
        else:
            self.skip = None

        # A single in-place ReLU instance is shared by every unit.
        self.relu = nn.ReLU(inplace=True)

        def unit(n_in, n_out):
            # One ReLU -> separable conv -> GroupNorm unit of the main branch.
            return [
                self.relu,
                SeparableConv2d_same(n_in, n_out, 3, stride=1, dilation=dilation),
                nn.GroupNorm(8, n_out),
            ]

        layers = []
        width = inplanes
        if grow_first:
            layers += unit(inplanes, planes)
            width = planes

        for _ in range(reps - 1):
            layers += unit(width, width)

        if not grow_first:
            layers += unit(inplanes, planes)

        if not start_with_relu:
            # Drop the leading ReLU.
            layers = layers[1:]

        if stride != 1:
            layers.append(SeparableConv2d_same(planes, planes, 3, stride=stride))

        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        out = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        out += shortcut

        return out
100 |
101 |
102 | class Xception(nn.Module):
103 | """
104 | Modified Alighed Xception
105 | """
106 | def __init__(self, inplanes=3, pretrained=False):
107 | super(Xception, self).__init__()
108 |
109 | # Entry flow
110 | self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
111 | #self.bn1 = nn.BatchNorm2d(32)
112 | self.bn1 = nn.GroupNorm(16, 32)
113 | self.relu = nn.ReLU(inplace=True)
114 |
115 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
116 | #self.bn2 = nn.BatchNorm2d(64)
117 | self.bn2 = nn.GroupNorm(32, 64)
118 |
119 | self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
120 | self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
121 | self.block3 = Block(256, 728, reps=2, stride=2, start_with_relu=True, grow_first=True)
122 |
123 | # Middle flow
124 | self.block4 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
125 | self.block5 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
126 | self.block6 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
127 | self.block7 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
128 | self.block8 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
129 | self.block9 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
130 | self.block10 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
131 | self.block11 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
132 | self.block12 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
133 | self.block13 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
134 | self.block14 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
135 | self.block15 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
136 | self.block16 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
137 | self.block17 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
138 | self.block18 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
139 | self.block19 = Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True)
140 |
141 | # Exit flow
142 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=2, start_with_relu=True, grow_first=False)
143 |
144 | self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=2)
145 | #self.bn3 = nn.BatchNorm2d(1536)
146 | self.bn3 = nn.GroupNorm(32, 1536)
147 |
148 | self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=2)
149 | #self.bn4 = nn.BatchNorm2d(1536)
150 | self.bn4 = nn.GroupNorm(32, 1536)
151 |
152 | self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=2)
153 | #self.bn5 = nn.BatchNorm2d(2048)
154 | self.bn5 = nn.GroupNorm(32, 2048)
155 |
156 | # Init weights
157 | self.__init_weight()
158 |
159 | # Load pretrained model
160 | if pretrained:
161 | self.__load_xception_pretrained()
162 |
def forward(self, x):
    """Run the backbone and return the deep and the low-level feature maps.

    The input goes through the two stem convolutions, ``block1`` to
    ``block3`` (entry flow), ``block4`` to ``block19`` (middle flow) and
    finally ``block20`` followed by ``conv3``/``conv4``/``conv5`` (exit
    flow).  Each stem and exit convolution is followed by its own
    normalisation layer and the shared ``self.relu``.

    Returns:
        tuple: ``(x, low_level_feat)`` where ``x`` is the exit-flow output
        and ``low_level_feat`` is the output of ``block1``.
    """
    # Entry flow: conv -> norm -> relu, twice.
    for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
        x = self.relu(norm(conv(x)))

    # The output of block1 is handed back to the caller as the low-level feature.
    low_level_feat = self.block1(x)
    x = low_level_feat

    # block2-3 (entry flow), block4-19 (middle flow) and block20 (exit flow)
    # are applied strictly in numerical order.
    for index in range(2, 21):
        x = getattr(self, 'block{}'.format(index))(x)

    # Exit flow tail: three more conv -> norm -> relu stages.
    for conv, norm in ((self.conv3, self.bn3), (self.conv4, self.bn4), (self.conv5, self.bn5)):
        x = self.relu(norm(conv(x)))

    return x, low_level_feat
211 |
def __init_weight(self):
    """Reset convolution and batch-norm parameters of every submodule.

    ``nn.Conv2d`` weights are re-drawn with ``kaiming_normal_``;
    ``nn.BatchNorm2d`` layers get weight 1 and bias 0.  Modules of any other
    type (including ``nn.GroupNorm``) are left untouched.
    """
    for layer in self.modules():
        if isinstance(layer, nn.BatchNorm2d):
            layer.weight.data.fill_(1)
            layer.bias.data.zero_()
            continue
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.kaiming_normal_(layer.weight)
221 |
def __load_xception_pretrained(self):
    """Initialise this backbone from the pretrained Xception checkpoint.

    The checkpoint is fetched with ``model_zoo.load_url`` and its entries are
    remapped onto this network's layer names:

    * ``block11`` is kept for ``block11`` and copied to ``block12`` ...
      ``block19``;
    * ``block12`` is stored as ``block20``;
    * ``conv3`` keeps its name, ``bn3`` is used for both ``bn3`` and ``bn4``;
    * ``conv4`` is stored as ``conv5`` and ``bn4`` as ``bn5``;
    * every other entry keeps its name.

    Checkpoint entries whose name does not occur in ``self.state_dict()`` are
    ignored.  The merged dictionary is loaded with ``self.load_state_dict``.
    """
    pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
    state_dict = self.state_dict()
    model_dict = {}

    for k, v in pretrain_dict.items():
        if k not in state_dict:
            continue

        if 'pointwise' in k:
            # Append two singleton dimensions to pointwise weights
            # (presumably stored as 2-D matrices in the checkpoint; the
            # target is a 1x1 convolution kernel).
            v = v.unsqueeze(-1).unsqueeze(-1)

        if k.startswith('block12'):
            model_dict[k.replace('block12', 'block20')] = v
        elif k.startswith('block11'):
            # Keep the pretrained weights for block11 itself as well; without
            # this, block11 would stay randomly initialised while block12-19
            # all start from its pretrained values.
            model_dict[k] = v
            for index in range(12, 20):
                model_dict[k.replace('block11', 'block{}'.format(index))] = v
        elif k.startswith('conv3'):
            model_dict[k] = v
        elif k.startswith('bn3'):
            model_dict[k] = v
            model_dict[k.replace('bn3', 'bn4')] = v
        elif k.startswith('conv4'):
            model_dict[k.replace('conv4', 'conv5')] = v
        elif k.startswith('bn4'):
            model_dict[k.replace('bn4', 'bn5')] = v
        else:
            model_dict[k] = v

    state_dict.update(model_dict)
    self.load_state_dict(state_dict)
255 |
class ASPP_module(nn.Module):
    """Single atrous-convolution branch: convolution, batch norm, ReLU.

    With ``rate == 1`` the convolution is 1x1 without padding; for any other
    rate it is a 3x3 convolution whose dilation and padding both equal
    ``rate``.  The convolution has no bias term.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        rate: dilation rate of the convolution.
    """

    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        pointwise = rate == 1
        self.atrous_convolution = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=1 if pointwise else 3,
            stride=1,
            padding=0 if pointwise else rate,
            dilation=rate,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self.__init_weight()

    def forward(self, x):
        """Apply convolution, batch norm and ReLU to ``x``."""
        return self.relu(self.bn(self.atrous_convolution(x)))

    def __init_weight(self):
        """Kaiming-normal init for conv weights; batch norm gets weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
                continue
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
287 |
288 |
class DeepLabv3_plus(nn.Module):
    """DeepLabv3+ segmentation head on top of the ``Xception`` backbone.

    The backbone output feeds four ``ASPP_module`` branches (rates 1, 6, 12
    and 18) plus a global-average-pooling branch; the five 256-channel
    results are concatenated (1280 channels) and reduced to 256 channels.
    The backbone's low-level feature (128 channels) is reduced to 48 channels,
    concatenated with the upsampled result (304 channels) and turned into
    ``n_classes`` score maps at the input resolution.

    Args:
        nInputChannels: number of channels of the input image.
        n_classes: number of output score maps.
        pretrained: forwarded to ``Xception``.
        _print: when true, print a short construction summary.

    Note:
        ``__init_weight`` is defined on this class but is not called from
        ``__init__``.
    """

    def __init__(self, nInputChannels=3, n_classes=21, pretrained=False, _print=True):
        if _print:
            summary = (
                "Constructing DeepLabv3+ model...",
                "Number of classes: {}".format(n_classes),
                "Number of Input Channels: {}".format(nInputChannels),
            )
            for line in summary:
                print(line)
        super(DeepLabv3_plus, self).__init__()

        # Backbone
        self.xception_features = Xception(nInputChannels, pretrained=pretrained)

        # ASPP branches with dilation rates 1, 6, 12 and 18
        self.aspp1 = ASPP_module(2048, 256, rate=1)
        self.aspp2 = ASPP_module(2048, 256, rate=6)
        self.aspp3 = ASPP_module(2048, 256, rate=12)
        self.aspp4 = ASPP_module(2048, 256, rate=18)

        self.relu = nn.ReLU()

        # Image-level branch: global pooling -> 1x1 conv -> group norm -> relu
        pooling_layers = [
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(2048, 256, 1, stride=1, bias=False),
            nn.GroupNorm(32, 256),
            nn.ReLU(),
        ]
        self.global_avg_pool = nn.Sequential(*pooling_layers)

        # Fuses the five concatenated 256-channel branches (5 * 256 = 1280).
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.GroupNorm(32, 256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(128, 48, 1, bias=False)
        self.bn2 = nn.GroupNorm(16, 48)

        # Decoder on the 256 + 48 = 304 concatenated channels.
        decoder_layers = [
            nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.GroupNorm(32, 256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.GroupNorm(32, 256),
            nn.ReLU(),
            nn.Conv2d(256, n_classes, kernel_size=1, stride=1),
        ]
        self.last_conv = nn.Sequential(*decoder_layers)

    def forward(self, input):
        """Return score maps with the same spatial size as ``input``."""
        features, low_level_features = self.xception_features(input)

        low_level_features = self.relu(self.bn2(self.conv2(low_level_features)))

        # Four atrous branches plus the image-level branch, the latter
        # upsampled to the spatial size of the last atrous branch.
        branches = [aspp(features) for aspp in (self.aspp1, self.aspp2, self.aspp3, self.aspp4)]
        pooled = self.global_avg_pool(features)
        branches.append(
            F.interpolate(pooled, size=branches[-1].size()[2:], mode='bilinear', align_corners=True)
        )

        y = self.relu(self.bn1(self.conv1(torch.cat(branches, dim=1))))
        y = F.interpolate(y, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        z = self.last_conv(torch.cat((y, low_level_features), dim=1))
        return F.interpolate(z, size=input.size()[2:], mode='bilinear', align_corners=True)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into eval mode."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()

    def __init_weight(self):
        """Kaiming-normal init for conv weights; batch norm gets weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
                continue
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.kaiming_normal_(layer.weight)
374 |
def get_1x_lr_params(model):
    """Yield the trainable parameters of the backbone.

    Iterates over ``model.xception_features.parameters()`` and produces only
    those parameters whose ``requires_grad`` flag is set.
    """
    for param in model.xception_features.parameters():
        if param.requires_grad:
            yield param
387 |
388 |
def get_10x_lr_params(model):
    """Yield the trainable parameters of everything except the backbone.

    Covers the four ASPP branches, the fusion and channel-reduction
    convolutions, the decoder (``last_conv``), the image-level pooling branch
    (``global_avg_pool``) and the two norm layers ``bn1`` and ``bn2``.
    Only parameters with ``requires_grad`` set are produced.

    ``global_avg_pool``, ``bn1`` and ``bn2`` are listed last so the previously
    yielded parameters keep their position; they used to be missing, which
    meant an optimizer built from ``get_1x_lr_params`` and this generator
    never received them.
    """
    head_modules = (
        model.aspp1,
        model.aspp2,
        model.aspp3,
        model.aspp4,
        model.conv1,
        model.conv2,
        model.last_conv,
        model.global_avg_pool,
        model.bn1,
        model.bn2,
    )
    for module in head_modules:
        for param in module.parameters():
            if param.requires_grad:
                yield param
399 |
400 |
if __name__ == "__main__":
    # Smoke test: build the network with a pretrained backbone and push one
    # random 1x3x512x512 batch through it with gradients disabled, then
    # print the size of the result.
    net = DeepLabv3_plus(nInputChannels=3, n_classes=21, pretrained=True, _print=True)
    sample = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        prediction = net.forward(sample)
    print(prediction.size())
407 |
408 |
409 |
410 |
411 |
412 |
413 |
--------------------------------------------------------------------------------