├── src ├── deeplab │ ├── __init__.py │ └── deeplab_xception.py ├── LICENSE_deeplab ├── LICENSE_adamw ├── utils.py ├── visualizer.py ├── model.py ├── evaluator.py ├── adamw.py ├── cosine_scheduler.py ├── test.py ├── test_merge.py ├── train.py ├── finetune_real.py ├── finetune.py └── dataset.py ├── imgs └── teaser.png ├── bandsel ├── readme.txt ├── bands │ ├── rs.txt │ ├── nncv.txt │ └── mvpca.txt └── bandsel_nncv.py ├── prepare ├── create_real_hdf5.sh ├── create_hdf5_bg.py ├── create_hdf5_testext.py ├── create_hdf5_val.py ├── create_hdf5_test.py ├── create_hdf5_trainext.py ├── create_hdf5_bgext.py └── calibrate_kappa.py ├── LICENSE ├── README.md └── recog ├── recognition.py ├── recognition_testext.py └── recognition_trainext.py /src/deeplab/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplab_xception import DeepLabv3_plus 2 | -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tiancheng-zhi/ms-powder/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /bandsel/readme.txt: -------------------------------------------------------------------------------- 1 | Band Selection 2 | run: 3 | python bandsel_nncv.py 4 | 5 | Selected SWIR bands are in folder "bands". Band index ranges from 0 to 960. 
6 | -------------------------------------------------------------------------------- /bandsel/bands/rs.txt: -------------------------------------------------------------------------------- 1 | 901,418,0,584,91,660,534,148,438,472,168,298,3,255,461,73,646,515,26,751,117,163,320,82,596,70,656,143,57,356,113,367,557,542,110,52,61,346,471,464,216,930,303,65,413,36,430,342,229 2 | -------------------------------------------------------------------------------- /bandsel/bands/nncv.txt: -------------------------------------------------------------------------------- 1 | 746,397,875,73,679,430,395,672,47,562,365,676,45,709,394,506,235,363,620,406,819,41,72,103,275,75,219,344,71,787,699,341,329,705,372,559,442,431,656,536,163,204,18,744,48,428,238,105,498 2 | -------------------------------------------------------------------------------- /bandsel/bands/mvpca.txt: -------------------------------------------------------------------------------- 1 | 833,682,148,123,483,633,338,551,1,391,166,226,504,684,143,342,583,556,907,594,182,154,876,599,85,45,490,285,535,531,249,615,174,812,560,115,2,40,444,487,254,539,626,111,133,262,510,899,253,463,332,546,586 2 | -------------------------------------------------------------------------------- /prepare/create_real_hdf5.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ../real/bg 2 | mkdir -p ../real/bgext 3 | python create_hdf5_bg.py 4 | python create_hdf5_bgext.py 5 | python create_hdf5_trainext.py 6 | python create_hdf5_test.py 7 | python create_hdf5_testext.py 8 | python create_hdf5_val.py 9 | -------------------------------------------------------------------------------- /prepare/create_hdf5_bg.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path = Path('../data/train/') 9 | real_path = Path('../real/bg/') 10 | 
n_scenes = 16 11 | height = 160 12 | width = 280 13 | n_channels = 965 14 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W'] 15 | 16 | for light in lights: 17 | h5f = h5py.File(str(Path(real_path / (light + '.hdf5'))), 'w') 18 | dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32') 19 | for i in range(n_scenes): 20 | idx = str(i).zfill(2) 21 | im_npz = np.load(data_path / light / 'bgscene' / (idx + '_bgscene.npz')) 22 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 23 | dset_im[i, :, :, :] = im 24 | h5f.close() 25 | -------------------------------------------------------------------------------- /prepare/create_hdf5_testext.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path = Path('../data/testext/') 9 | real_path = Path('../real/') 10 | n_scenes = 64 11 | height = 160 12 | width = 280 13 | n_channels = 38 14 | h5f = h5py.File(str(Path(real_path / ('testext.hdf5'))), 'w') 15 | dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32') 16 | dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8') 17 | for i in range(n_scenes): 18 | idx = str(i).zfill(2) 19 | im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz')) 20 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 21 | label = cv2.imread(str(data_path / 'label' / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE) 22 | dset_im[i, :, :, :] = im 23 | dset_label[i, :, :] = label 24 | h5f.close() 25 | -------------------------------------------------------------------------------- /prepare/create_hdf5_val.py: -------------------------------------------------------------------------------- 1 | 
import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path = Path('../data/val/') 9 | real_path = Path('../real/') 10 | n_scenes = 32 11 | height = 160 12 | width = 280 13 | n_channels = 965 14 | h5f = h5py.File(str(Path(real_path / ('val.hdf5'))), 'w') 15 | dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32') 16 | dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8') 17 | for i in range(n_scenes): 18 | idx = str(i).zfill(2) 19 | im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz')) 20 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 21 | label = cv2.imread(str(data_path / 'label' / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE) 22 | dset_im[i, :, :, :] = im 23 | dset_label[i, :, :] = label 24 | h5f.close() 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /prepare/create_hdf5_test.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path = Path('../data/test/') 9 | real_path = Path('../real/') 10 | n_scenes = 32 11 | height = 160 12 | width = 280 13 | n_channels = 965 14 | h5f = h5py.File(str(Path(real_path / ('test.hdf5'))), 'w') 15 | dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32') 16 | dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8') 17 | for i in range(n_scenes): 18 | idx = str(i).zfill(2) 19 | im_npz = np.load(data_path / 'scene' / (idx + '_scene.npz')) 20 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 21 | label = cv2.imread(str(data_path / 'label' / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE) 22 | dset_im[i, :, :, :] = 
im 23 | dset_label[i, :, :] = label 24 | h5f.close() 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tiancheng Zhi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/LICENSE_deeplab: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Pyjcsx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /src/LICENSE_adamw: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Maksym Pyrozhok 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /prepare/create_hdf5_trainext.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path = Path('../data/trainext/') 9 | real_path = Path('../real/') 10 | n_scenes = 64 11 | height = 160 12 | width = 280 13 | n_channels = 38 14 | h5f = h5py.File(str(Path(real_path / ('trainext.hdf5'))), 'w') 15 | dset_im = h5f.create_dataset('im', (n_scenes, height, width, n_channels), dtype='float32') 16 | dset_label = h5f.create_dataset('label', (n_scenes, height, width), dtype='uint8') 17 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W'] 18 | 19 | for i in range(n_scenes): 20 | idx = str(i % (n_scenes // len(lights))).zfill(2) 21 | light = lights[i // (n_scenes // len(lights))] 22 | im_npz = np.load(data_path / light / 'scene' / (idx + '_scene.npz')) 23 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 24 | label = cv2.imread(str(data_path / light / 'label' / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE) 25 | dset_im[i, :, :, :] = im 26 | dset_label[i, :, :] = label 27 | h5f.close() 28 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def cpu_np(tensor): 6 | return tensor.cpu().numpy() 7 | 8 | 9 | def to_image(matrix): 10 | image = cpu_np(torch.clamp(matrix, 0, 1) * 255).astype(np.uint8) 11 | if matrix.size()[0] == 1: 12 | image = np.concatenate((image, image, image), 0) 13 | return image 14 | return image 15 | 16 | 17 | def errormap(green, yellow, red, blue): 18 | err = np.zeros((3, green.shape[0], green.shape[1]), dtype=np.uint8) 19 | err[2, 
:, :][blue] = 255 20 | err[0, :, :][red] = 255 21 | err[0, :, :][yellow] = 255 22 | err[1, :, :][yellow] = 255 23 | err[2, :, :][yellow] = 0 24 | err[0, :, :][green] = 0 25 | err[1, :, :][green] = 255 26 | return err 27 | 28 | 29 | def colormap(label): 30 | cm = [] 31 | for r in [35, 90, 145, 200, 255]: 32 | for g in [35, 90, 145, 200, 255]: 33 | for b in [60, 125, 190, 255]: 34 | cm.append((r, g, b)) 35 | cm.append((0, 0, 0)) 36 | label_cm = np.stack((label, label, label), 0).astype(np.uint8) 37 | for c, color in enumerate(cm): 38 | mask = (label == c) 39 | label_cm[0, :, :][mask] = color[0] 40 | label_cm[1, :, :][mask] = color[1] 41 | label_cm[2, :, :][mask] = color[2] 42 | return label_cm 43 | -------------------------------------------------------------------------------- /src/visualizer.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | import numpy as np 3 | 4 | class Visualizer(): 5 | 6 | def __init__(self, server='http://localhost', port=8097, env='main'): 7 | self.vis = visdom.Visdom(server=server, port=port, env=env, use_incoming_socket=False) 8 | self.iteration = [] 9 | self.nlogloss = [] 10 | self.epoch = [] 11 | self.acc = [] 12 | 13 | def state_dict(self): 14 | return {'iteration': self.iteration, 'nlogloss': self.nlogloss, 'epoch': self.epoch, 'acc': self.acc} 15 | 16 | 17 | def load_state_dict(self, state_dict): 18 | self.iteration = state_dict['iteration'] 19 | self.nlogloss = state_dict['nlogloss'] 20 | self.epoch = state_dict['epoch'] 21 | self.acc = state_dict['acc'] 22 | 23 | def plot_loss(self): 24 | self.vis.line( 25 | X=np.array(self.iteration), 26 | Y=np.array(self.nlogloss), 27 | opts={ 28 | 'title': '-LogLoss', 29 | 'legend': ['-LogLoss'], 30 | 'xlabel': 'epoch', 31 | 'ylabel': '-logloss'}, 32 | win=0) 33 | 34 | def plot_acc(self): 35 | self.vis.line( 36 | X=np.array(self.epoch), 37 | Y=np.array(self.acc), 38 | opts={ 39 | 'title': 'Performance', 40 | 'legend': ['mIoUval', 
'mIoUtest'], 41 | 'xlabel': 'epoch', 42 | 'ylabel': 'performance'}, 43 | win=1) 44 | 45 | def plot_image(self, im, idx): 46 | self.vis.image(im, win=idx + 2) 47 | -------------------------------------------------------------------------------- /prepare/create_hdf5_bgext.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import cv2 4 | 5 | from pathlib import Path 6 | 7 | if __name__ == '__main__': 8 | data_path_bg = Path('../data/train/') 9 | data_path_ext = Path('../data/trainext/') 10 | real_path = Path('../real/bgext/') 11 | real_path.mkdir(exist_ok=True, parents=True) 12 | n_scenes = 16 13 | height = 160 14 | width = 280 15 | n_channels = 38 16 | sel = [0, 1, 2, 3, 4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964] 17 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W'] 18 | for light in lights: 19 | h5f = h5py.File(str(Path(real_path / (light + '.hdf5'))), 'w') 20 | dset_im = h5f.create_dataset('im', (n_scenes * 2, height, width, n_channels), dtype='float32') 21 | for i in range(n_scenes): 22 | idx = str(i).zfill(2) 23 | im_npz = np.load(data_path_bg / light / 'bgscene' / (idx + '_bgscene.npz')) 24 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 25 | dset_im[i, :, :, :] = im[:, :, sel] 26 | for i in range(n_scenes): 27 | idx = str(i).zfill(2) 28 | im_npz = np.load(data_path_ext / light / 'bgscene' / (idx + '_bgscene.npz')) 29 | im = np.concatenate((im_npz['rgbn'].astype(np.float32), im_npz['swir'].astype(np.float32)), axis=2) 30 | dset_im[n_scenes + i, :, :, :] = im 31 | h5f.close() 32 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 
| import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | from torch.nn import Parameter 7 | from deeplab import DeepLabv3_plus 8 | 9 | 10 | class PowderNet(nn.Module): 11 | 12 | def __init__(self, arch, n_channels, n_classes): 13 | super(PowderNet, self).__init__() 14 | if arch == 'deeplab': 15 | self.body = DeepLabv3_plus(nInputChannels=n_channels, n_classes=n_classes, pretrained=False, _print=False) 16 | else: 17 | assert(False) 18 | 19 | def forward(self, x): 20 | out = self.body(x) 21 | return out 22 | 23 | 24 | def get_1x_lr_params(model): 25 | """ 26 | This generator returns all the parameters of the net except for 27 | the last classification layer. Note that for each batchnorm layer, 28 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 29 | any batchnorm parameter 30 | """ 31 | b = [model.body.xception_features] 32 | for i in range(len(b)): 33 | for k in b[i].parameters(): 34 | if k.requires_grad: 35 | yield k 36 | 37 | 38 | def get_10x_lr_params(model): 39 | """ 40 | This generator returns all the parameters for the last layer of the net, 41 | which does the classification of pixel into classes 42 | """ 43 | b = [model.body.aspp1, model.body.aspp2, model.body.aspp3, model.body.aspp4, model.body.conv1, model.body.conv2, model.body.last_conv] 44 | for j in range(len(b)): 45 | for k in b[j].parameters(): 46 | if k.requires_grad: 47 | yield k 48 | -------------------------------------------------------------------------------- /prepare/calibrate_kappa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser(description='Calibration') 7 | parser.add_argument('--data-path', type=str, default='../data/train') 8 | parser.add_argument('--out-path', type=str, default='../params') 9 | opt = 
parser.parse_args() 10 | return opt 11 | 12 | 13 | if __name__ == '__main__': 14 | opt = parse_args() 15 | print(opt) 16 | 17 | Path(opt.out_path).mkdir(parents=True, exist_ok=True) 18 | 19 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W'] 20 | n = 100 21 | h = 14 22 | w = 14 23 | c = 965 24 | valid_threshold = 20 25 | 26 | kappa_params = [] 27 | for i in range(n): 28 | print(i) 29 | key = str(i).zfill(2) 30 | kappa_lights = [] 31 | for lid, light in enumerate(lights): 32 | thick_path = Path(opt.data_path) / light / 'thick' 33 | thin_path = Path(opt.data_path) / light / 'thin' 34 | bg_path = Path(opt.data_path) / light / 'bg' 35 | thick = np.load(thick_path / (key + '_thick.npz')) 36 | thin = np.load(thin_path / (key + '_thin.npz')) 37 | bg = np.load(bg_path / (key + '_bg.npz')) 38 | thick = np.concatenate((thick['rgbn'], thick['swir']), 2) 39 | thin = np.concatenate((thin['rgbn'], thin['swir']), 2) 40 | bg = np.concatenate((bg['rgbn'], bg['swir']), 2) 41 | 42 | 43 | thick = np.mean(thick, (0, 1), keepdims=True) 44 | bg = np.mean(bg, (0, 1), keepdims=True) 45 | alpha = (thin - thick) / (bg - thick) 46 | 47 | # valid alpha selection 48 | alpha = alpha.reshape([h * w, c]) 49 | alpha = np.clip(alpha, 0.01, 0.99) 50 | kt = -np.log(alpha) 51 | kappa = np.median(kt, axis=0) 52 | ratio = (kappa[:4].mean() + kappa[4:].mean()) / 2 53 | kappa = kappa / ratio 54 | print(kappa[:4], kappa[[5, -1]], kappa.max(), kappa.min()) 55 | kappa_lights.append(kappa) 56 | kappa_params.append(kappa_lights) 57 | np.savez(Path(opt.out_path) / 'kappa_params.npz', params=kappa_params) 58 | -------------------------------------------------------------------------------- /src/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import cv2 5 | 6 | 7 | class Evaluator: 8 | 9 | def __init__(self, n_classes, bg_err): 
10 | self.n_classes = n_classes 11 | self.bg_err = bg_err 12 | 13 | self.bg_confs = [] 14 | self.pd_preds = [] 15 | self.gt_nums = [[] for c in range(n_classes)] 16 | self.cls_bg_confs = [[] for c in range(n_classes)] 17 | 18 | self.tp = np.zeros(n_classes) 19 | self.fp = np.zeros(n_classes) 20 | self.num = np.zeros(n_classes) 21 | self.itp = np.zeros(n_classes) 22 | self.inum = np.zeros(n_classes) 23 | 24 | self.preds = [] 25 | 26 | def register(self, label, prob): 27 | """ 28 | label: H x W, n_classes-1 is bg 29 | prob: C x H x W 30 | """ 31 | 32 | bg_conf = prob[-1,:,:] 33 | pd_pred = np.argmax(prob[:-1,:,:], axis=0) 34 | self.bg_confs.append(bg_conf) 35 | self.pd_preds.append(pd_pred) 36 | pred = np.argmax(prob, axis=0) 37 | self.preds.append(pred) 38 | for c in range(self.n_classes): 39 | gt_mask = (label == c) 40 | # for Powder Accuracy 41 | if gt_mask.any(): 42 | if c < self.n_classes - 1: 43 | pred_mask = (pd_pred == c) 44 | self.gt_nums[c].append(gt_mask.sum()) 45 | self.cls_bg_confs[c].append(bg_conf[gt_mask * pred_mask]) 46 | else: 47 | self.cls_bg_confs[c] += list(bg_conf[gt_mask]) 48 | # for IoU 49 | self.num[c] += gt_mask.sum() 50 | self.tp[c] += ((pred == c) * gt_mask).sum() 51 | self.inum[c] += 1 52 | self.itp[c] += ((pred == c) * gt_mask).sum() / gt_mask.sum() 53 | self.fp[c] += ((pred == c) * (1 - gt_mask)).sum() 54 | 55 | def evaluate(self): 56 | self.bg_conf_threshold = np.percentile(self.cls_bg_confs[-1], self.bg_err) 57 | accs = [] 58 | for c in range(self.n_classes - 1): 59 | for i, cls_bg_conf in enumerate(self.cls_bg_confs[c]): 60 | acc = (cls_bg_conf < self.bg_conf_threshold).sum() / self.gt_nums[c][i] 61 | accs.append(acc) 62 | msa = np.mean(np.array(accs)) 63 | 64 | self.bg_confs = np.array(self.bg_confs) 65 | self.pd_preds = np.array(self.pd_preds) 66 | predictions = self.pd_preds.copy() 67 | predictions[self.bg_confs >= self.bg_conf_threshold] = self.n_classes - 1 68 | 69 | iou = self.tp / (self.num + self.fp) 70 | miou = iou.mean() 71 
| iiou = (self.itp * self.num / self.inum) / (self.num + self.fp) 72 | miiou = iiou.mean() 73 | return msa, predictions, miou, miiou, self.preds 74 | 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multispectral Imaging for Fine-Grained Recognition of Powders on Complex Backgrounds 2 | 3 | 4 | 5 | [Tiancheng Zhi](http://cs.cmu.edu/~tzhi), [Bernardo R. Pires](http://www.andrew.cmu.edu/user/bpires/), [Martial Hebert](http://www.cs.cmu.edu/~hebert/), [Srinivasa G. Narasimhan](http://www.cs.cmu.edu/~srinivas/) 6 | 7 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019. 8 | 9 | [[Project](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/)] [[Paper](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/files/ZPHN-CVPR19.pdf)] [[Supp](http://www.cs.cmu.edu/~ILIM/projects/IM/MSPowder/files/ZPHN-CVPR19-supp.pdf)] 10 | 11 |

12 | 13 |

14 | 15 | ## Requirements 16 | - NVIDIA TITAN Xp 17 | - Ubuntu 16.04 18 | - Python 3.6 19 | - OpenCV 4.0 20 | - PyTorch 1.0 21 | - Visdom 22 | 23 | ## Download "SWIRPowder" Dataset 24 | Download the ["data" folder](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/data/), and put it in the repo root directory. 25 | See "data/readme.txt" for description. 26 | 27 | ## Calibrate Attenuation Parameter 28 | In "prepare" directory, run: 29 | ``` 30 | python calibrate_kappa.py 31 | ``` 32 | 33 | ## Band Selection 34 | See "readme.txt" in "bandsel" directory 35 | 36 | ## Recognition with Known Powder Location/Mask 37 | In "recog" directory, run: 38 | ``` 39 | python recognition.py 40 | ``` 41 | 42 | ## Recognition without Known Powder Location/Mask 43 | ### Prepare real data 44 | In "prepare" directory, run: 45 | ``` 46 | sh create_real_hdf5.sh 47 | ``` 48 | 49 | ### Prepare synthetic data 50 | Download the ["synthetic" folder](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/synthetic/) and put it in the repo root directory. 51 | 52 | ### Train on synthetic powder on synthetic background 53 | In "src" directory, run: 54 | ``` 55 | python train.py --out-path ckpts/ckpt_default --bands 0,1,2,3,77,401,750,879 56 | ``` 57 | 58 | Note that the hdf5 file merges RGBN and SWIR channels, so channel IDs 0\~3 are RGBN channels, channel IDs 4\~964 are SWIR channels. 59 | 60 | To use NNCV selection, use `--bands 0,1,2,3,77,401,750,879`. 61 | 62 | To use Grid selection, use `--bands 0,1,2,3,4,34,934,964`. 63 | 64 | To use MVPCA selection, use `--bands 0,1,2,3,127,152,686,837`. 65 | 66 | To use RS selection, use `--bands 0,1,2,3,4,422,588,905`. 67 | 68 | See "bandsel/bands/" for more selected bands. Remember to "add 4" to convert 0\~960 range to 4\~964 range. 
69 | 70 | 71 | ### Train on synthetic powder on real background 72 | In "src" directory, run: 73 | ``` 74 | python finetune.py --out-path ckpts/ckpt_default_extft --bands 0,1,2,3,77,401,750,879 --pretrain ckpts/ckpt_default/247.pth --split bgext 75 | ``` 76 | 77 | Note: use `--split bg` for experiments on unextended dataset. 78 | 79 | ### Train on real powder on real background 80 | In "src" directory, run: 81 | ``` 82 | python finetune_real.py --out-path ckpts/ckpt_default_extft_real --bands 0,1,2,3,77,401,750,879 --pretrain ckpts/ckpt_default_extft/55.pth 83 | ``` 84 | 85 | ### Test with CRF post-processing 86 | In "src" directory, run: 87 | ``` 88 | python test.py --ckpt model.pth # Test on Scene-test 89 | python test_merge.py --ckpt model.pth # Test on dataset merging Scene-test and Scene-sl-test 90 | ``` 91 | 92 | ### Pretrained model 93 | Download [pretrained.pth](http://platformpgh.cs.cmu.edu/tzhi/SWIRPowderRelease/pretrained.pth), put it in "src" directory, and test it with: 94 | ``` 95 | python test_merge.py --ckpt pretrained.pth 96 | ``` 97 | -------------------------------------------------------------------------------- /src/adamw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class AdamW(Optimizer): 7 | """Implements Adam algorithm. 
8 | 9 | Arguments: 10 | params (iterable): iterable of parameters to optimize or dicts defining 11 | parameter groups 12 | lr (float, optional): learning rate (default: 1e-3) 13 | betas (Tuple[float, float], optional): coefficients used for computing 14 | running averages of gradient and its square (default: (0.9, 0.999)) 15 | eps (float, optional): term added to the denominator to improve 16 | numerical stability (default: 1e-8) 17 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 18 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 19 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 20 | 21 | """ 22 | 23 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 24 | weight_decay=0, amsgrad=False): 25 | if not 0.0 <= betas[0] < 1.0: 26 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 27 | if not 0.0 <= betas[1] < 1.0: 28 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 29 | defaults = dict(lr=lr, betas=betas, eps=eps, 30 | weight_decay=weight_decay, amsgrad=amsgrad) 31 | #super(AdamW, self).__init__(params, defaults) 32 | super().__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | """Performs a single optimization step. 36 | 37 | Arguments: 38 | closure (callable, optional): A closure that reevaluates the model 39 | and returns the loss. 
40 | """ 41 | loss = None 42 | if closure is not None: 43 | loss = closure() 44 | 45 | for group in self.param_groups: 46 | for p in group['params']: 47 | if p.grad is None: 48 | continue 49 | grad = p.grad.data 50 | if grad.is_sparse: 51 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 52 | amsgrad = group['amsgrad'] 53 | 54 | state = self.state[p] 55 | 56 | # State initialization 57 | if len(state) == 0: 58 | state['step'] = 0 59 | # Exponential moving average of gradient values 60 | state['exp_avg'] = torch.zeros_like(p.data) 61 | # Exponential moving average of squared gradient values 62 | state['exp_avg_sq'] = torch.zeros_like(p.data) 63 | if amsgrad: 64 | # Maintains max of all exp. moving avg. of sq. grad. values 65 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 66 | 67 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 68 | if amsgrad: 69 | max_exp_avg_sq = state['max_exp_avg_sq'] 70 | beta1, beta2 = group['betas'] 71 | 72 | state['step'] += 1 73 | 74 | # Decay the first and second moment running average coefficient 75 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 76 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 77 | if amsgrad: 78 | # Maintains the maximum of all 2nd moment running avg. till now 79 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 80 | # Use the max. for normalizing running avg. 
def cosine(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    The rows of ``a`` are processed in chunks of 1024 so the broadcast
    intermediate built by ``F.cosine_similarity`` stays bounded.

    Args:
        a: 2-D tensor with one feature vector per row (the queries).
        b: 2-D tensor with one feature vector per row (the database).

    Returns:
        A ``torch.double`` tensor of shape ``(a.size(0), b.size(0))`` that
        lives on the same device as ``a``.
    """
    y = b.unsqueeze(0)
    n_pixels = a.size(0)
    batch_size = 1024
    # Allocate next to the input instead of unconditionally calling .cuda():
    # CUDA inputs behave as before, and CPU inputs now work as well.
    sim = torch.zeros((n_pixels, b.size(0)), dtype=torch.double, device=a.device)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim


def sim_func(dist, query, database):
    """Compute the similarity matrix between ``query`` and ``database`` rows.

    Args:
        dist: ``'full'`` for one cosine similarity over all columns, or
            ``'split'`` for the sum of the cosine similarity over the first
            four columns and the one over the remaining columns.
        query: 2-D feature tensor.
        database: 2-D feature tensor with the same column layout as ``query``.

    Returns:
        The similarity matrix produced by :func:`cosine`.

    Raises:
        ValueError: if ``dist`` is not a supported mode.
    """
    # With five columns or fewer the plain cosine is used whatever the mode.
    if dist == 'full' or query.size(1) <= 5:
        return cosine(query, database)
    if dist == 'split':
        return cosine(query[:, :4], database[:, :4]) + cosine(query[:, 4:], database[:, 4:])
    # An explicit raise survives ``python -O``; the old assert(False) did not.
    raise ValueError('unsupported dist: {!r}'.format(dist))


def feat_eng(dist, raw):
    """Build the feature matrix that is handed to :func:`sim_func`.

    Args:
        dist: ``'full'`` / ``'split'`` return ``raw`` untouched; ``'decouple'``
            prepends the per-row mean of columns ``4:`` as an extra column.
        raw: 2-D tensor of raw features.

    Returns:
        ``raw`` itself, or a new tensor with one additional leading column.

    Raises:
        ValueError: if ``dist`` is not a supported mode.
    """
    if dist in ('full', 'split'):
        return raw
    if dist == 'decouple':
        mean_swir = raw[:, 4:].mean(dim=1, keepdim=True)
        return torch.cat((mean_swir, raw), dim=1)
    raise ValueError('unsupported dist: {!r}'.format(dist))


def parse_args():
    """Parse the command-line options of the band-selection script."""
    parser = argparse.ArgumentParser(description='NNCV Band Selection')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./bands')
    parser.add_argument('--n-sels', type=int, default=49)
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    opt = parser.parse_args()
    return opt
class CosineLRWithRestarts():
    """Cosine-annealed learning rate with warm restarts and normalized weight decay.

    Both ``lr`` and ``weight_decay`` of every parameter group of the wrapped
    optimizer are rewritten on each :meth:`batch_step`.
    https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        batch_size: minibatch size
        epoch_size: training samples per epoch
        restart_period: epoch count in the first restart period
        t_mult: multiplication factor by which the next restart period will extend/shrink
        last_epoch: index of the last finished epoch; ``-1`` starts fresh
        eta_threshold: number of restarts after which the eta range shrinks
        verbose: print a message on every restart


    Example:
        >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
        >>> for epoch in range(100):
        >>>     scheduler.step()
        >>>     train(...)
        >>>         ...
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step()
        >>>     validate(...)
    """

    def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
                 t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        param_groups = optimizer.param_groups
        if last_epoch == -1:
            # Fresh run: remember the starting lr of every group.
            for group in param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            # Resumed run: the optimizer must already carry 'initial_lr'.
            for idx, group in enumerate(param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an"
                                   " optimizer".format(idx))
        self.base_lrs = [group['initial_lr'] for group in param_groups]

        self.last_epoch = last_epoch
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.eta_threshold = eta_threshold
        self.t_mult = t_mult
        self.verbose = verbose
        self.base_weight_decays = [group['weight_decay']
                                   for group in param_groups]
        self.restart_period = restart_period
        self.restarts = 0
        self.t_epoch = -1

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every instance attribute is included except the wrapped optimizer
        and the per-epoch ``batch_increment`` iterator.
        """
        excluded = ('optimizer', 'batch_increment')
        return {name: value for name, value in self.__dict__.items()
                if name not in excluded}

    def load_state_dict(self, state_dict):
        """Restore the scheduler state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def _schedule_eta(self):
        """Return ``(eta_min, eta_max)`` for the current restart count.

        The range is ``(0, 1)`` until more than ``eta_threshold`` restarts
        have happened; after that each extra restart narrows it by 0.09 on
        both ends.
        """
        eta_min = 0
        eta_max = 1
        extra_restarts = self.restarts - self.eta_threshold
        if extra_restarts <= 0:
            return eta_min, eta_max
        shrink = extra_restarts * 0.09
        return (eta_min + shrink, eta_max - shrink)

    def get_lr(self, t_cur):
        """Compute ``(lr, weight_decay)`` pairs for position ``t_cur``.

        Also performs the restart bookkeeping: once ``t_epoch`` has reached
        ``restart_period`` the period is multiplied by ``t_mult`` and
        ``t_epoch`` is reset to 0.

        Returns:
            An iterator of ``(lr, weight_decay)`` tuples, one per base lr.
        """
        eta_min, eta_max = self._schedule_eta()

        phase = math.pi * (t_cur / self.restart_period)
        eta_t = eta_min + 0.5 * (eta_max - eta_min) * (1. + math.cos(phase))

        samples_per_period = self.epoch_size * self.restart_period
        weight_decay_norm_multi = math.sqrt(self.batch_size / samples_per_period)

        lrs = []
        for base_lr in self.base_lrs:
            lrs.append(base_lr * eta_t)
        weight_decays = []
        for base_weight_decay in self.base_weight_decays:
            weight_decays.append(
                base_weight_decay * eta_t * weight_decay_norm_multi)

        if self.t_epoch % self.restart_period < self.t_epoch:
            if self.verbose:
                print("Restart at epoch {}".format(self.last_epoch))
            self.restart_period *= self.t_mult
            self.restarts += 1
            self.t_epoch = 0

        return zip(lrs, weight_decays)

    def _set_batch_size(self):
        """Create the iterator of fractional epoch offsets for one epoch."""
        full_batches, remainder = divmod(self.epoch_size, self.batch_size)
        if remainder > 0:
            batches_in_epoch = full_batches + 2
        else:
            batches_in_epoch = full_batches + 1
        self.batch_increment = iter(torch.linspace(0, 1, batches_in_epoch))

    def step(self):
        """Advance to the next epoch and apply its first schedule value."""
        self.last_epoch += 1
        self.t_epoch += 1
        self._set_batch_size()
        self.batch_step()

    def batch_step(self):
        """Write the next ``lr`` / ``weight_decay`` into every param group."""
        t_cur = self.t_epoch + next(self.batch_increment)
        schedule = self.get_lr(t_cur)
        for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,
                                                   schedule):
            param_group['lr'] = lr
            param_group['weight_decay'] = weight_decay
def parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--out-path', type=str, default='./result')
    parser.add_argument('--bg-err', type=float, default=1.0)
    parser.add_argument('--sdims', type=int, default=3)
    parser.add_argument('--schan', type=int, default=3)
    parser.add_argument('--compat', type=int, default=3)
    parser.add_argument('--iters', type=int, default=10)
    parser.add_argument('--threads', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=1)
    opt = parser.parse_args()
    return opt


def crf(prob, im, sdims, schan, compat, iters):
    """Refine class probabilities with a dense CRF and a bilateral pairwise term.

    NOTE: the channel subset used for the bilateral term is chosen from the
    module-level ``opt`` (``opt.channels`` / ``opt.n_channels``) that the
    ``__main__`` section creates, not from an argument.

    Args:
        prob: array of shape ``(C, H, W)`` with per-class probabilities.
        im: image array indexed as ``(channel, H, W)``.
        sdims: spatial scale of the bilateral kernel (used for both axes).
        schan: channel scale of the bilateral kernel.
        compat: compatibility value passed to ``addPairwiseEnergy``.
        iters: number of CRF inference iterations.

    Returns:
        Array of shape ``(C, H, W)`` holding the CRF inference result.
    """
    if opt.channels == 965:
        bilateral_ch = [0, 1, 2, 3, 4, 159, 324, 469, 644, 779, 964]
    elif opt.channels == 4:
        bilateral_ch = [0, 1, 2, 3]
    elif opt.channels == -961:
        bilateral_ch = [0, 155, 320, 465, 640, 775, 960]
    else:
        # Any other setting (e.g. an explicit band list): use every channel.
        bilateral_ch = range(opt.n_channels)
    C, H, W = prob.shape
    U = unary_from_softmax(prob)
    d = dcrf.DenseCRF2D(H, W, C)
    d.setUnaryEnergy(U)
    pairwise_energy = create_pairwise_bilateral(sdims=(sdims, sdims), schan=(schan,), img=im[bilateral_ch, :, :], chdim=0)
    d.addPairwiseEnergy(pairwise_energy, compat=compat)
    Q_unary = d.inference(iters)
    Q_unary = np.array(Q_unary).reshape(-1, H, W)
    return Q_unary


def test(opt, test_loader, net, split):
    """Evaluate ``net`` on ``test_loader`` before and after CRF refinement.

    Every sample is registered with two evaluators (raw softmax and CRF
    output); one visualization per sample is written to
    ``<opt.out_path>/<split>/NN.png``.

    Returns:
        ``(msa, miou, miiou, msa_crf, miou_crf, miiou_crf)``.
    """
    start_time = time.time()
    eva = Evaluator(opt.n_classes, opt.bg_err)
    eva_crf = Evaluator(opt.n_classes, opt.bg_err)
    ims = []
    labels = []

    net = net.eval()

    for iteration, batch in enumerate(test_loader):
        im, label = batch
        im = im.cuda()
        label = label.cuda()
        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            out = net(im)
            prob = F.softmax(out, dim=1)
        # Use the real size of this batch; the final batch may hold fewer
        # than opt.batch_size samples.
        for i in range(im.size(0)):
            prob_np = prob[i].cpu().numpy()
            label_np = label[i].cpu().numpy()
            im_np = im[i].cpu().numpy()
            ims.append(to_image(im[i, :3, :, :]))
            labels.append(label_np)
            eva.register(label_np, prob_np)
            prob_crf = crf(prob_np, im_np, opt.sdims, opt.schan, opt.compat, opt.iters)
            eva_crf.register(label_np, prob_crf)
            print(str(iteration * opt.batch_size + i).zfill(2), time.time() - start_time, 'seconds')

    msa, preds_msa, miou, miiou, preds_miou = eva.evaluate()
    msa_crf, preds_msa_crf, miou_crf, miiou_crf, preds_miou_crf = eva_crf.evaluate()
    print('Pre-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)))
    print('Post-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)))
    for i, label in enumerate(labels):
        vis_im = ims[i]
        vis_label = colormap(label)
        vis_pred_miou = colormap(preds_miou[i])
        vis_pred_miou_crf = colormap(preds_miou_crf[i])
        # 2x2 grid: input | label on top, prediction | CRF prediction below.
        vis_all = np.concatenate((
            np.concatenate((vis_im, vis_label), axis=2),
            np.concatenate((vis_pred_miou, vis_pred_miou_crf), axis=2)), axis=1)
        vis_all = vis_all.transpose((1, 2, 0))
        io.imsave(Path(opt.out_path) / split / (str(i).zfill(2) + '.png'), vis_all)
    return msa, miou, miiou, msa_crf, miou_crf, miiou_crf


if __name__ == '__main__':
    cv2.setNumThreads(0)

    opt = parse_args()
    print(opt)

    (Path(opt.out_path) / 'test').mkdir(parents=True, exist_ok=True)
    (Path(opt.out_path) / 'val').mkdir(parents=True, exist_ok=True)

    checkpoint = torch.load(opt.ckpt)

    # Model configuration comes from the checkpoint, with fallbacks for
    # checkpoints that lack the channel options.
    opt.channels = checkpoint['opt'].channels if 'channels' in checkpoint['opt'].__dict__ else 965
    opt.n_channels = checkpoint['opt'].n_channels if 'n_channels' in checkpoint['opt'].__dict__ else abs(opt.channels)
    opt.n_classes = checkpoint['opt'].n_classes
    opt.arch = checkpoint['opt'].arch

    test_set = RealDataset(opt.real_path, opt.channels, split='test')
    test_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    val_set = RealDataset(opt.real_path, opt.channels, split='val')
    val_loader = DataLoader(dataset=val_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=False)

    net = PowderNet(opt.arch, opt.n_channels, opt.n_classes)
    net = net.cuda()
    net.load_state_dict(checkpoint['state_dict'])

    # The context manager closes the log even if evaluation raises.
    with open(Path(opt.out_path) / 'performance.txt', 'w') as log_file:
        print(opt, file=log_file)
        msa, miou, miiou, msa_crf, miou_crf, miiou_crf = test(opt, test_loader, net, 'test')
        print('Test Pre-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)), file=log_file)
        print('Test Post-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)), file=log_file)
        msa, miou, miiou, msa_crf, miou_crf, miiou_crf = test(opt, val_loader, net, 'val')
        print('Val Pre-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)), file=log_file)
        print('Val Post-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)), file=log_file)
        print('Complete', file=log_file)
def parse_args():
    """Parse the command-line options of the merged evaluation script."""
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--out-path', type=str, default='./result')
    parser.add_argument('--bg-err', type=float, default=1.0)
    parser.add_argument('--sdims', type=int, default=3)
    parser.add_argument('--schan', type=int, default=3)
    parser.add_argument('--compat', type=int, default=3)
    parser.add_argument('--iters', type=int, default=10)
    parser.add_argument('--threads', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=1)
    opt = parser.parse_args()
    return opt


def crf(prob, im, sdims, schan, compat, iters):
    """Refine class probabilities with a dense CRF and a bilateral pairwise term.

    NOTE: the channel subset used for the bilateral term is chosen from the
    module-level ``opt`` (``opt.channels`` / ``opt.n_channels``) that the
    ``__main__`` section creates, not from an argument.

    Args:
        prob: array of shape ``(C, H, W)`` with per-class probabilities.
        im: image array indexed as ``(channel, H, W)``.
        sdims: spatial scale of the bilateral kernel (used for both axes).
        schan: channel scale of the bilateral kernel.
        compat: compatibility value passed to ``addPairwiseEnergy``.
        iters: number of CRF inference iterations.

    Returns:
        Array of shape ``(C, H, W)`` holding the CRF inference result.
    """
    if opt.channels == 965:
        bilateral_ch = [0, 1, 2, 3, 4, 159, 324, 469, 644, 779, 964]
    elif opt.channels == 4:
        bilateral_ch = [0, 1, 2, 3]
    elif opt.channels == -961:
        bilateral_ch = [0, 155, 320, 465, 640, 775, 960]
    else:
        # Any other setting (e.g. an explicit band list): use every channel.
        bilateral_ch = range(opt.n_channels)
    C, H, W = prob.shape
    U = unary_from_softmax(prob)
    d = dcrf.DenseCRF2D(H, W, C)
    d.setUnaryEnergy(U)
    pairwise_energy = create_pairwise_bilateral(sdims=(sdims, sdims), schan=(schan,), img=im[bilateral_ch, :, :], chdim=0)
    d.addPairwiseEnergy(pairwise_energy, compat=compat)
    Q_unary = d.inference(iters)
    Q_unary = np.array(Q_unary).reshape(-1, H, W)
    return Q_unary


def test(opt, test_loader, testext_loader, net, split):
    """Evaluate ``net`` on two loaders pooled into one set of metrics.

    Samples from ``test_loader`` are processed first, then those from
    ``testext_loader``; all of them are registered with the same pair of
    evaluators (raw softmax and CRF output). One visualization per sample
    is written to ``<opt.out_path>/<split>/NN.png``.

    Returns:
        ``(msa, miou, miiou, msa_crf, miou_crf, miiou_crf)``.
    """
    start_time = time.time()
    eva = Evaluator(opt.n_classes, opt.bg_err)
    eva_crf = Evaluator(opt.n_classes, opt.bg_err)
    ims = []
    labels = []

    net = net.eval()

    # Both loaders get identical treatment; only the progress tag differs.
    for tag, loader in (('test', test_loader), ('testext', testext_loader)):
        for iteration, batch in enumerate(loader):
            im, label = batch
            im = im.cuda()
            label = label.cuda()
            # Pure inference: do not record an autograd graph.
            with torch.no_grad():
                out = net(im)
                prob = F.softmax(out, dim=1)
            # Use the real size of this batch; the final batch may hold
            # fewer than opt.batch_size samples.
            for i in range(im.size(0)):
                prob_np = prob[i].cpu().numpy()
                label_np = label[i].cpu().numpy()
                im_np = im[i].cpu().numpy()
                ims.append(to_image(im[i, :3, :, :]))
                labels.append(label_np)
                eva.register(label_np, prob_np)
                prob_crf = crf(prob_np, im_np, opt.sdims, opt.schan, opt.compat, opt.iters)
                eva_crf.register(label_np, prob_crf)
                print(tag, str(iteration * opt.batch_size + i).zfill(2), time.time() - start_time, 'seconds')

    msa, preds_msa, miou, miiou, preds_miou = eva.evaluate()
    msa_crf, preds_msa_crf, miou_crf, miiou_crf, preds_miou_crf = eva_crf.evaluate()
    print('Pre-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa * 100, 1), round(miou * 100, 1), round(miiou * 100, 1)))
    print('Post-CRF: MSA: {} mIoU: {} miIoU: {}'.format(round(msa_crf * 100, 1), round(miou_crf * 100, 1), round(miiou_crf * 100, 1)))
    for i, label in enumerate(labels):
        vis_im = ims[i]
        vis_label = colormap(label)
        vis_pred_miou = colormap(preds_miou[i])
        vis_pred_miou_crf = colormap(preds_miou_crf[i])
        # 2x2 grid: input | label on top, prediction | CRF prediction below.
        vis_all = np.concatenate((
            np.concatenate((vis_im, vis_label), axis=2),
            np.concatenate((vis_pred_miou, vis_pred_miou_crf), axis=2)), axis=1)
        vis_all = vis_all.transpose((1, 2, 0))
        io.imsave(Path(opt.out_path) / split / (str(i).zfill(2) + '.png'), vis_all)
    return msa, miou, miiou, msa_crf, miou_crf, miiou_crf
def parse_args():
    """Parse the training options and derive ``opt.n_channels``.

    With ``--bands`` given, ``opt.channels`` becomes the list of band
    indices and ``opt.n_channels`` its length; otherwise ``opt.n_channels``
    is ``abs(opt.channels)``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--syn-path', type=str, default='../synthetic')
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--params-path', type=str, default='../params')
    parser.add_argument('--out-path', type=str, default='./checkpoint')
    parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands')
    parser.add_argument('--bands', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab')
    parser.add_argument('--threads', type=int, default=6)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=248)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--period', type=int, default=8)
    parser.add_argument('--t-mult', type=float, default=2)
    parser.add_argument('--vis-iter', type=int, default=0)
    parser.add_argument('--server', type=str, default='http://localhost')
    parser.add_argument('--env', type=str, default='main')
    opt = parser.parse_args()
    if opt.bands is not None:
        # Reported through argparse so the check also runs under ``python -O``
        # (the former assert would have been stripped).
        if opt.channels != 0:
            parser.error('--bands can only be combined with --channels 0')
        opt.channels = [int(i) for i in opt.bands.split(',')]
        opt.n_channels = len(opt.channels)
    else:
        opt.n_channels = abs(opt.channels)
    return opt


def train(opt, vis, epoch, train_loader, net, optimizer, scheduler):
    """Run one training epoch with per-batch logging and loss plotting.

    ``scheduler.step()`` is called once before the first batch and
    ``scheduler.batch_step()`` after every optimizer step.
    """
    net = net.train()
    train_len = len(train_loader)
    start_time = time.time()
    scheduler.step()
    for iteration, batch in enumerate(train_loader):
        # Load Data
        im, label = batch
        im = im.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward Pass
        out = net(im)
        loss = F.cross_entropy(out, label)

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()

        # Logging. The first opt.threads iterations are averaged on their
        # own; the clock is then reset so later averages exclude them.
        cur_time = time.time()
        loss_scalar = loss.item()
        if iteration < opt.threads:
            timed_iters = iteration + 1
        else:
            timed_iters = iteration + 1 - opt.threads
        print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(opt.env, epoch, iteration, train_len,
                                                            round((cur_time - start_time) / timed_iters, 2),
                                                            round(loss_scalar, 4)))
        if iteration == opt.threads - 1:
            start_time = cur_time

        # Visualization
        vis.iteration.append(epoch + iteration / train_len)
        vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_scalar)))
        vis.plot_loss()
        # Image panels index channel 964, so they are only drawn for the
        # full 965-channel input, every opt.vis_iter iterations.
        if opt.channels != 965 or opt.vis_iter <= 0 or iteration % opt.vis_iter > 0:
            continue
        _, pred = torch.max(out, dim=1)
        vis_rgb = to_image(im[0, 0:3, :, :] * 0.5)
        vis_nir = to_image(im[0, 3:4, :, :] * 0.5)
        vis_swir1 = to_image(im[0, 4:5, :, :] * 0.5)
        vis_swir2 = to_image(im[0, 964:965, :, :] * 0.5)
        vis_label = colormap(label[0].cpu().numpy())
        vis_pred = colormap(pred[0].cpu().numpy())
        vis_im = np.concatenate((np.concatenate((vis_label, vis_pred), axis=1),
                                 np.concatenate((vis_rgb, vis_nir), axis=1),
                                 np.concatenate((vis_swir1, vis_swir2), axis=1)), axis=2)
        vis.plot_image(vis_im, 0)


def test(opt, epoch, test_loader, net):
    """Return the mean per-class IoU of ``net`` over ``test_loader``.

    Per class, IoU is ``tp / (num + fp)`` where ``num`` counts labelled
    pixels of the class, so ``num + fp`` is the union of label and
    prediction. A class with an empty union yields NaN, which propagates
    into the mean.
    """
    net = net.eval()
    tp = np.zeros(opt.n_classes)
    fp = np.zeros(opt.n_classes)
    num = np.zeros(opt.n_classes)
    for batch in test_loader:
        # Load Data
        im, label = batch
        im = im.cuda()
        label = label.cuda()

        # Forward Pass (inference only: no autograd graph)
        with torch.no_grad():
            out = net(im)
            prob = F.softmax(out, dim=1)
            _, pred = torch.max(prob, dim=1)

        # Evaluation
        for i in range(pred.size(0)):
            label_np = label[i].cpu().numpy()
            pred_np = pred[i].cpu().numpy()
            for c in range(opt.n_classes):
                mask = (label_np == c)
                hit = (pred_np == c)
                tp[c] += (hit & mask).sum()
                fp[c] += (hit & ~mask).sum()
                num[c] += mask.sum()

    iou = tp / (num + fp)
    miou = iou.mean()

    return miou
def parse_args():
    """Parse the fine-tuning options and derive ``opt.n_channels``.

    With ``--bands`` given, ``opt.channels`` becomes the list of band
    indices and ``opt.n_channels`` its length; otherwise ``opt.n_channels``
    is ``abs(opt.channels)``.
    """
    parser = argparse.ArgumentParser(description='PowderDetector')
    parser.add_argument('--real-path', type=str, default='../real')
    parser.add_argument('--params-path', type=str, default='../params')
    parser.add_argument('--out-path', type=str, default='./checkpoint')
    parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands')
    parser.add_argument('--bands', type=str, default=None)
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab')
    parser.add_argument('--threads', type=int, default=6)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n-epochs', type=int, default=24)
    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--decay', type=float, default=1e-4)
    parser.add_argument('--period', type=int, default=8)
    parser.add_argument('--t-mult', type=float, default=2)
    parser.add_argument('--vis-iter', type=int, default=0)
    parser.add_argument('--server', type=str, default='http://localhost')
    parser.add_argument('--env', type=str, default='main')
    opt = parser.parse_args()
    if opt.bands is not None:
        # Reported through argparse so the check also runs under ``python -O``
        # (the former assert would have been stripped).
        if opt.channels != 0:
            parser.error('--bands can only be combined with --channels 0')
        opt.channels = [int(i) for i in opt.bands.split(',')]
        opt.n_channels = len(opt.channels)
    else:
        opt.n_channels = abs(opt.channels)
    return opt


def train(opt, vis, epoch, train_loader, net, optimizer, scheduler):
    """Run one fine-tuning epoch with per-batch logging and loss plotting.

    ``scheduler.step()`` is called once before the first batch and
    ``scheduler.batch_step()`` after every optimizer step.
    """
    net = net.train()
    train_len = len(train_loader)
    start_time = time.time()
    scheduler.step()
    for iteration, batch in enumerate(train_loader):
        # Load Data
        im, label = batch
        im = im.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward Pass
        out = net(im)
        loss = F.cross_entropy(out, label)

        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.batch_step()

        # Logging. The first opt.threads iterations are averaged on their
        # own; the clock is then reset so later averages exclude them.
        cur_time = time.time()
        loss_scalar = loss.item()
        if iteration < opt.threads:
            timed_iters = iteration + 1
        else:
            timed_iters = iteration + 1 - opt.threads
        print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(opt.env, epoch, iteration, train_len,
                                                            round((cur_time - start_time) / timed_iters, 2),
                                                            round(loss_scalar, 4)))
        if iteration == opt.threads - 1:
            start_time = cur_time

        # Visualization
        vis.iteration.append(epoch + iteration / train_len)
        vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_scalar)))
        vis.plot_loss()
        if opt.vis_iter <= 0 or iteration % opt.vis_iter > 0:
            continue
        _, pred = torch.max(out, dim=1)
        vis_rgb = to_image(im[0, 0:3, :, :] * 0.5)
        vis_nir = to_image(im[0, 3:4, :, :] * 0.5)
        vis_swir1 = to_image(im[0, 4:5, :, :] * 0.5)
        # NOTE(review): -2:-1 selects the second-to-last channel; train.py
        # shows the last one (964:965). Confirm which is intended.
        vis_swir2 = to_image(im[0, -2:-1, :, :] * 0.5)
        vis_label = colormap(label[0].cpu().numpy())
        vis_pred = colormap(pred[0].cpu().numpy())
        vis_im = np.concatenate((np.concatenate((vis_label, vis_pred), axis=1),
                                 np.concatenate((vis_rgb, vis_nir), axis=1),
                                 np.concatenate((vis_swir1, vis_swir2), axis=1)), axis=2)
        vis.plot_image(vis_im, 0)


def test(opt, epoch, test_loader, net):
    """Return the mean per-class IoU of ``net`` over ``test_loader``.

    Per class, IoU is ``tp / (num + fp)`` where ``num`` counts labelled
    pixels of the class, so ``num + fp`` is the union of label and
    prediction. A class with an empty union yields NaN, which propagates
    into the mean.
    """
    net = net.eval()
    tp = np.zeros(opt.n_classes)
    fp = np.zeros(opt.n_classes)
    num = np.zeros(opt.n_classes)
    for batch in test_loader:
        # Load Data
        im, label = batch
        im = im.cuda()
        label = label.cuda()

        # Forward Pass (inference only: no autograd graph)
        with torch.no_grad():
            out = net(im)
            prob = F.softmax(out, dim=1)
            _, pred = torch.max(prob, dim=1)

        # Evaluation
        for i in range(pred.size(0)):
            label_np = label[i].cpu().numpy()
            pred_np = pred[i].cpu().numpy()
            for c in range(opt.n_classes):
                mask = (label_np == c)
                hit = (pred_np == c)
                tp[c] += (hit & mask).sum()
                fp[c] += (hit & ~mask).sum()
                num[c] += mask.sum()

    iou = tp / (num + fp)
    miou = iou.mean()
    return miou
opt.blend) 183 | assert(old_opt.lr == opt.lr) 184 | assert(old_opt.decay == opt.decay) 185 | assert(old_opt.period == opt.period) 186 | assert(old_opt.t_mult == opt.t_mult) 187 | net.load_state_dict(checkpoint['state_dict']) 188 | optimizer.load_state_dict(checkpoint['optimizer']) 189 | scheduler.load_state_dict(checkpoint['scheduler']) 190 | vis.load_state_dict(checkpoint['vis']) 191 | start_epoch = checkpoint['epoch'] + 1 192 | elif opt.pretrain is not None: 193 | checkpoint = torch.load(opt.pretrain) 194 | old_opt = checkpoint['opt'] 195 | assert(old_opt.channels == opt.channels) 196 | assert(old_opt.bands == opt.bands) 197 | assert(old_opt.arch == opt.arch) 198 | assert(old_opt.blend == opt.blend) 199 | net.load_state_dict(checkpoint['state_dict']) 200 | else: 201 | assert(False) 202 | 203 | for epoch in range(start_epoch, opt.n_epochs): 204 | train(opt, vis, epoch, train_loader, net, optimizer, scheduler) 205 | miou_val = test(opt, epoch, val_loader, net) 206 | miou_test = test(opt, epoch, test_loader, net) 207 | vis.epoch.append(epoch) 208 | vis.acc.append([miou_val, miou_test]) 209 | vis.plot_acc() 210 | if (epoch + 1) % opt.period == 0: 211 | torch.save({'epoch': epoch, 'opt': opt, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'vis': vis.state_dict()}, Path(opt.out_path) / (str(epoch) + '.pth')) 212 | print('Val mIoU:', miou_val, ' Test mIoU:', miou_test) 213 | -------------------------------------------------------------------------------- /src/finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | from cosine_scheduler import CosineLRWithRestarts 10 | from pathlib import Path 11 | from torch.utils.data import DataLoader 12 | from 
dataset import RealDataset, HalfHalfDataset 13 | from visualizer import Visualizer 14 | from model import PowderNet 15 | from utils import to_image, colormap, errormap 16 | from adamw import AdamW 17 | from model import get_1x_lr_params, get_10x_lr_params 18 | import cv2 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='PowderDetector') 23 | parser.add_argument('--syn-path', type=str, default='../synthetic') 24 | parser.add_argument('--real-path', type=str, default='../real') 25 | parser.add_argument('--params-path', type=str, default='../params') 26 | parser.add_argument('--out-path', type=str, default='./checkpoint') 27 | parser.add_argument('--split', type=str, default='bgext', choices=['bg', 'bgext']) 28 | parser.add_argument('--channels', type=int, choices=[965, 4, -961, 0], default=0, help='x>0 select [:x]; x<0 select [x:]; x=0 see --bands') 29 | parser.add_argument('--bands', type=str, default=None) 30 | parser.add_argument('--pretrain', type=str, default=None) 31 | parser.add_argument('--resume', type=str, default=None) 32 | parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa') 33 | parser.add_argument('--arch', type=str, choices=['deeplab'], default='deeplab') 34 | parser.add_argument('--threads', type=int, default=6) 35 | parser.add_argument('--batch-size', type=int, default=8) 36 | parser.add_argument('--n-epochs', type=int, default=56) 37 | parser.add_argument('--lr', type=float, default=1e-4) 38 | parser.add_argument('--decay', type=float, default=1e-4) 39 | parser.add_argument('--period', type=int, default=8) 40 | parser.add_argument('--t-mult', type=float, default=2) 41 | parser.add_argument('--vis-iter', type=int, default=0) 42 | parser.add_argument('--server', type=str, default='http://localhost') 43 | parser.add_argument('--env', type=str, default='main') 44 | opt = parser.parse_args() 45 | if opt.bands is not None: 46 | assert(opt.channels == 0) 47 | opt.channels = [int(i) 
for i in opt.bands.split(',')] 48 | opt.n_channels = len(opt.channels) 49 | else: 50 | opt.n_channels = abs(opt.channels) 51 | return opt 52 | 53 | 54 | def train(opt, vis, epoch, train_loader, net, optimizer, scheduler): 55 | net = net.train() 56 | train_len = len(train_loader) 57 | start_time = time.time() 58 | scheduler.step() 59 | for iteration, batch in enumerate(train_loader): 60 | # Load Data 61 | im, label = batch 62 | im = im.cuda(non_blocking=True) 63 | label = label.cuda(non_blocking=True) 64 | 65 | # Forward Pass 66 | out = net(im) 67 | loss = F.cross_entropy(out, label) 68 | 69 | # Backward Pass 70 | optimizer.zero_grad() 71 | loss.backward() 72 | optimizer.step() 73 | scheduler.batch_step() 74 | 75 | # Logging 76 | cur_time = time.time() 77 | loss_scalar = float(loss.cpu().detach().numpy()) 78 | if iteration < opt.threads: 79 | print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(opt.env, epoch, iteration, train_len, \ 80 | round((cur_time - start_time) / (iteration + 1), 2), \ 81 | round(loss_scalar, 4))) 82 | if iteration == opt.threads - 1: 83 | start_time = cur_time 84 | else: 85 | print('{} [{}]({}/{}) AvgTime:{:>4} Loss:{:>4}'.format(opt.env, epoch, iteration, train_len, \ 86 | round((cur_time - start_time) / (iteration + 1 - opt.threads), 2), \ 87 | round(loss_scalar, 4))) 88 | 89 | # Visualization 90 | vis.iteration.append(epoch + iteration / train_len) 91 | vis.nlogloss.append(-np.log(np.maximum(1e-6, loss_scalar))) 92 | vis.plot_loss() 93 | if opt.vis_iter <= 0 or iteration % opt.vis_iter > 0: 94 | continue 95 | prob, pred = torch.max(out, dim=1) 96 | vis_rgb = to_image(im[0, 0:3, :, :] * 0.5) 97 | vis_nir = to_image(im[0, 3:4, :, :] * 0.5) 98 | vis_swir1 = to_image(im[0, 4:5, :, :] * 0.5) 99 | vis_swir2 = to_image(im[0, -2:-1, :, :] * 0.5) 100 | vis_label = colormap(label[0].cpu().numpy()) 101 | vis_pred = colormap(pred[0].cpu().numpy()) 102 | vis_im = np.concatenate((np.concatenate((vis_label, vis_pred), axis=1), \ 103 | 
np.concatenate((vis_rgb, vis_nir), axis=1), \ 104 | np.concatenate((vis_swir1, vis_swir2), axis=1)), axis=2) 105 | vis.plot_image(vis_im, 0) 106 | 107 | 108 | def test(opt, epoch, test_loader, net): 109 | if opt.channels == 965: 110 | bilateral_ch = [0,1,2,3,4,159,324,469,644,779,964] 111 | elif opt.channels == 4: 112 | bilateral_ch = [0,1,2,3] 113 | elif opt.channels == -961: 114 | bilateral_ch = [0,155,320,465,640,775,960] 115 | else: 116 | bilateral_ch = range(opt.n_channels) 117 | net = net.eval() 118 | test_len = len(test_loader) 119 | tp = np.zeros(opt.n_classes) 120 | fp = np.zeros(opt.n_classes) 121 | tp_crf = np.zeros(opt.n_classes) 122 | fp_crf = np.zeros(opt.n_classes) 123 | num = np.zeros(opt.n_classes) 124 | start_time = time.time() 125 | for iteration, batch in enumerate(test_loader): 126 | # Load Data 127 | im, label = batch 128 | im = im.cuda() 129 | label = label.cuda() 130 | 131 | # Forward Pass 132 | out = net(im) 133 | 134 | # Visualization 135 | prob = F.softmax(out, dim=1) 136 | _, pred = torch.max(prob, dim=1) 137 | 138 | bsize = pred.size()[0] 139 | 140 | for i in range(bsize): 141 | label_np = label[i].cpu().numpy() 142 | pred_np = pred[i].cpu().numpy() 143 | for c in range(opt.n_classes): 144 | mask = (label_np == c) 145 | tp[c] += ((pred_np == c) * mask).sum() 146 | fp[c] += ((pred_np == c) * (1 - mask)).sum() 147 | num[c] += mask.sum() 148 | 149 | iou = tp / (num + fp) 150 | miou = iou.mean() 151 | return miou 152 | 153 | 154 | if __name__ == '__main__': 155 | cv2.setNumThreads(0) 156 | 157 | opt = parse_args() 158 | print(opt) 159 | 160 | Path(opt.out_path).mkdir(parents=True, exist_ok=True) 161 | 162 | train_set = HalfHalfDataset(opt.real_path, opt.syn_path, opt.params_path, opt.blend, opt.channels, opt.split) 163 | train_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True, pin_memory=True) 164 | 165 | val_set = RealDataset(opt.real_path, opt.channels, split='val') 166 | val_loader = 
DataLoader(dataset=val_set, num_workers=0, batch_size=1, shuffle=False) 167 | 168 | test_set = RealDataset(opt.real_path, opt.channels, split='test') 169 | test_loader = DataLoader(dataset=test_set, num_workers=0, batch_size=1, shuffle=False) 170 | 171 | opt.n_classes = train_set.n_classes 172 | net = PowderNet(opt.arch, opt.n_channels, train_set.n_classes) 173 | net = net.cuda() 174 | optimizer = AdamW([{'params': get_1x_lr_params(net)}, {'params': get_10x_lr_params(net), 'lr': opt.lr * 10}], lr=opt.lr, weight_decay=opt.decay) 175 | scheduler = CosineLRWithRestarts(optimizer, opt.batch_size, len(train_set), opt.period, opt.t_mult) 176 | vis = Visualizer(server=opt.server, env=opt.env) 177 | start_epoch = 0 178 | if opt.resume is not None: 179 | checkpoint = torch.load(opt.resume) 180 | old_opt = checkpoint['opt'] 181 | assert(old_opt.channels == opt.channels) 182 | assert(old_opt.bands == opt.bands) 183 | assert(old_opt.arch == opt.arch) 184 | assert(old_opt.blend == opt.blend) 185 | assert(old_opt.lr == opt.lr) 186 | assert(old_opt.decay == opt.decay) 187 | assert(old_opt.period == opt.period) 188 | assert(old_opt.t_mult == opt.t_mult) 189 | net.load_state_dict(checkpoint['state_dict']) 190 | optimizer.load_state_dict(checkpoint['optimizer']) 191 | scheduler.load_state_dict(checkpoint['scheduler']) 192 | vis.load_state_dict(checkpoint['vis']) 193 | start_epoch = checkpoint['epoch'] + 1 194 | elif opt.pretrain is not None: 195 | checkpoint = torch.load(opt.pretrain) 196 | old_opt = checkpoint['opt'] 197 | #assert(old_opt.channels == opt.channels) 198 | #assert(old_opt.bands == opt.bands) 199 | assert(old_opt.arch == opt.arch) 200 | assert(old_opt.blend == opt.blend) 201 | net.load_state_dict(checkpoint['state_dict']) 202 | else: 203 | assert(False) 204 | 205 | for epoch in range(start_epoch, opt.n_epochs): 206 | train(opt, vis, epoch, train_loader, net, optimizer, scheduler) 207 | miou_val = test(opt, epoch, val_loader, net) 208 | miou_test = test(opt, epoch, 
test_loader, net) 209 | vis.epoch.append(epoch) 210 | vis.acc.append([miou_val, miou_test]) 211 | vis.plot_acc() 212 | if (epoch + 1) % opt.period == 0: 213 | torch.save({'epoch': epoch, 'opt': opt, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), 'vis': vis.state_dict()}, Path(opt.out_path) / (str(epoch) + '.pth')) 214 | print('Val mIoU:', miou_val, ' Test mIoU:', miou_test) 215 | -------------------------------------------------------------------------------- /recog/recognition.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import cv2 7 | import time 8 | 9 | from pathlib import Path 10 | 11 | 12 | def cosine(a, b): 13 | y = b.unsqueeze(0) 14 | n_pixels = a.size()[0] 15 | batch_size = 1000 16 | if n_pixels % batch_size == 0: 17 | n_batches = n_pixels // batch_size 18 | else: 19 | n_batches = n_pixels // batch_size + 1 20 | sim = torch.zeros((n_pixels, b.size()[0]), dtype=torch.double).cuda() 21 | for batch_idx in range(n_batches): 22 | bs = batch_idx * batch_size 23 | be = min(n_pixels, (batch_idx + 1) * batch_size) 24 | x = a[bs:be, :].unsqueeze(1) 25 | sim[bs:be, :] = F.cosine_similarity(x, y, dim=2) 26 | return sim 27 | 28 | 29 | def sim_func(opt, query, database): 30 | if opt.dist == 'full': 31 | return cosine(query, database) 32 | elif opt.dist == 'split': 33 | return cosine(query[:, :opt.n_rgbns], database[:, :opt.n_rgbns]) + cosine(query[:, opt.n_rgbns:], database[:, opt.n_rgbns:]) 34 | else: 35 | assert(False) 36 | 37 | 38 | def sims2pred(n_lights, sims): 39 | votes = sims.argmax(dim=1) // n_lights 40 | votes = votes.cpu().numpy() 41 | counts = collections.Counter(votes) 42 | pred = [i[0] for i in counts.most_common()] 43 | return pred 44 | 45 | 46 | def match_powder_none(opt, database, scene): 47 | n_pixels = scene.size()[0] 48 | query = 
scene.view((n_pixels, -1)) 49 | sims = sim_func(opt, query, database) 50 | return sims 51 | 52 | 53 | def match_powder_kappa(opt, database, scene, bg, kappa): 54 | n_database = database.size()[0] 55 | n_pixels = scene.size()[0] 56 | n_channels = scene.size()[1] 57 | 58 | eta = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double).cuda() 59 | alpha = eta.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1) 60 | 61 | # n_pixels * n_database * n_etas * n_channels 62 | bg = bg.unsqueeze(1).unsqueeze(2) 63 | 64 | # n_database * n_etas * n_channels 65 | database = database.unsqueeze(1) * (1 - alpha) 66 | 67 | batch_size = 64000 // n_channels 68 | sims = torch.zeros((n_pixels, n_database), dtype=torch.double).cuda() 69 | if n_database % batch_size == 0: 70 | n_batches = n_database // batch_size 71 | else: 72 | n_batches = n_database // batch_size + 1 73 | 74 | for p in range(n_pixels): 75 | query = scene[p:p+1, :] 76 | db = database + bg[p, :, :, :] * alpha 77 | for batch_idx in range(n_batches): 78 | bs = batch_idx * batch_size 79 | be = min(n_database, (batch_idx + 1) * batch_size) 80 | cur_database = db[bs:be, :, :].reshape(((be - bs) * opt.n_etas, -1)) 81 | cur_sims = sim_func(opt, query, cur_database) 82 | cur_sims, _ = cur_sims.view((be - bs, opt.n_etas)).max(1) 83 | sims[p, bs:be] = cur_sims 84 | 85 | return sims 86 | 87 | 88 | def parse_args(): 89 | parser = argparse.ArgumentParser(description='Recognition with known mask') 90 | parser.add_argument('--data-path', type=str, default='../data') 91 | parser.add_argument('--log-path', type=str, default='./log') 92 | parser.add_argument('--sel-path', type=str, default='../bandsel/bands') 93 | parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint') 94 | parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa') 95 | parser.add_argument('--eta-max', type=float, default=0.9) 96 | parser.add_argument('--n-etas', type=int, default=10) 97 | 
parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz') 98 | parser.add_argument('--n-swirs', type=int, default=4) 99 | parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4]) 100 | parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'mvpca', 'rs']) 101 | parser.add_argument('--dist', type=str, default='split', choices=['full', 'split']) 102 | parser.add_argument('--set', type=str, default='test', choices=['test', 'val']) 103 | opt = parser.parse_args() 104 | 105 | assert(opt.n_rgbns + opt.n_swirs > 1) 106 | if opt.n_rgbns <= 1 or opt.n_swirs <= 1: 107 | assert(opt.dist == 'full') 108 | 109 | return opt 110 | 111 | 112 | if __name__ == '__main__': 113 | opt = parse_args() 114 | 115 | Path(opt.log_path).mkdir(parents=True, exist_ok=True) 116 | log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns)) 117 | assert(not log_fname.is_file()) 118 | 119 | lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W'] 120 | n_lights = len(lights) 121 | n_powders = 100 122 | n_scenes = 32 123 | 124 | if opt.n_rgbns == 4: 125 | rgbn_channels = [0, 1, 2, 3] 126 | elif opt.n_rgbns == 3: 127 | rgbn_channels = [0, 1, 2] 128 | elif opt.n_rgbns == 1: 129 | rgbn_channels = [3] 130 | else: 131 | rgbn_channels = [] 132 | 133 | all_channels = rgbn_channels.copy() 134 | swir_channels = [] 135 | if opt.n_swirs > 0: 136 | if opt.sel == 'grid': 137 | assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs) 138 | if opt.n_swirs == 1: 139 | swir_channels.append(480) 140 | else: 141 | decimation = int(30 // (np.sqrt(opt.n_swirs) - 1)) 142 | for i in range(0, 31, decimation): 143 | for j in range(0, 31, decimation): 144 | swir_channels.append(i * 31 + j) 145 | else: 146 | sel_file = open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r') 147 | splited = sel_file.readlines()[-1].strip().split(',') 148 | sel_file.close() 
149 | for i in splited[:opt.n_swirs]: 150 | swir_channels.append(int(i)) 151 | assert(len(swir_channels) == opt.n_swirs) 152 | for i in swir_channels: 153 | all_channels.append(i + 4) 154 | 155 | n_channels = opt.n_rgbns + opt.n_swirs 156 | 157 | log_file = open(log_fname, 'w') 158 | print(opt) 159 | print(opt, file=log_file) 160 | print(swir_channels) 161 | print(swir_channels, file=log_file) 162 | 163 | train_path = Path(opt.data_path) / 'train' 164 | test_path = Path(opt.data_path) / opt.set 165 | 166 | scene_path = test_path / 'scene' 167 | bgscene_path = test_path / 'bgscene' 168 | label_path = test_path / 'label' 169 | 170 | thick_list = np.zeros((n_powders, n_lights, n_channels)) 171 | for i in range(n_powders): 172 | idx = str(i).zfill(2) 173 | for lid, light in enumerate(lights): 174 | thick = np.load(train_path / light / 'thick' / (idx + '_thick.npz')) 175 | thick = np.concatenate((thick['rgbn'][:, :, rgbn_channels], thick['swir'][:, :, swir_channels]), axis=2) 176 | thick = thick.mean((0, 1)) 177 | thick_list[i, lid] = thick 178 | thick_list = thick_list.reshape((n_powders * n_lights, n_channels)) 179 | thick_list = torch.from_numpy(thick_list).cuda() 180 | 181 | if opt.blend == 'alpha': 182 | kappa = torch.ones((n_powders * n_lights, n_channels)).double().cuda() 183 | elif opt.blend == 'kappa': 184 | kappa_params = np.load(opt.kappa_params) 185 | kappa = kappa_params['params'][:, :, all_channels].reshape((n_powders * n_lights, n_channels)) 186 | kappa = torch.from_numpy(kappa).cuda() 187 | 188 | acc_top1 = [] 189 | acc_top3 = [] 190 | start_time = time.time() 191 | for i in range(n_scenes): 192 | idx = str(i).zfill(2) 193 | print() 194 | print('scene', idx) 195 | 196 | print(file=log_file) 197 | print('scene', idx, file=log_file) 198 | 199 | scene = np.load(scene_path / (idx + '_scene.npz')) 200 | scene = np.concatenate((scene['rgbn'][:, :, rgbn_channels], scene['swir'][:, :, swir_channels]), axis=2) 201 | label = cv2.imread(str(label_path / (idx + 
'_label.png')), cv2.IMREAD_GRAYSCALE) 202 | if opt.bg == 'inpaint': 203 | mask = (label < 255).astype(np.uint8) * 255 204 | bgscene = scene.copy() 205 | for c in range(n_channels): 206 | scene_max = scene[mask == 255, c].max() 207 | bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535 208 | elif opt.bg == 'gt': 209 | bgscene = np.load(bgscene_path / (idx + '_bgscene.npz')) 210 | bgscene = np.concatenate((bgscene['rgbn'][:, :, rgbn_channels], bgscene['swir'][:, :, swir_channels]), axis=2) 211 | else: 212 | assert(False) 213 | for powder in range(n_powders): 214 | mask = (label == powder) 215 | if mask.any(): 216 | print('powder', powder) 217 | print('powder', powder, file=log_file) 218 | scene_list = scene[mask, :] 219 | bgscene_list = bgscene[mask, :] 220 | scene_list = torch.from_numpy(scene_list).cuda() 221 | bgscene_list = torch.from_numpy(bgscene_list).cuda() 222 | if opt.blend == 'none': 223 | sims = match_powder_none(opt, thick_list, scene_list) 224 | else: 225 | sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa) 226 | pred = sims2pred(n_lights, sims) 227 | top1 = (powder in pred[:1]) 228 | top3 = (powder in pred[:3]) 229 | acc_top1.append(top1) 230 | acc_top3.append(top3) 231 | print(pred) 232 | print(top1, top3) 233 | print('Acc:', np.mean(acc_top1), np.mean(acc_top3)) 234 | print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's') 235 | print(pred, file=log_file) 236 | print(top1, top3, file=log_file) 237 | print('Acc:', np.mean(acc_top1), np.mean(acc_top3), file=log_file) 238 | print('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's', file=log_file) 239 | print(np.mean(acc_top1), np.mean(acc_top3)) 240 | print(np.mean(acc_top1), np.mean(acc_top3), file=log_file) 241 | print('Complete', file=log_file) 242 | log_file.close() 243 | 
# --------------------------------------------------------------------------------
# /recog/recognition_testext.py:
# --------------------------------------------------------------------------------
"""Powder recognition with known masks on the extended test set ('testext').

Every masked scene pixel is compared, by cosine similarity, against a
database of mean "thick powder" spectra (one per powder and light source),
optionally blended with the pixel's background; the powder of a region is
then decided by majority vote over its pixels.
"""
import argparse
import collections
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F


LIGHTS = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
N_POWDERS = 100
N_SCENES = 64

# Channel ids (in the combined numbering where the 4 RGBN channels come first,
# so SWIR band b is channel b + 4) whose list position is used to index the
# 'swir' array of the ext scene archives.
EXT_CHANNELS = [4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]


def cosine(a, b, batch_size=1000):
    """Return the pairwise cosine similarity between rows of `a` and rows of `b`.

    Args:
        a: 2-D tensor of queries, one row per pixel.
        b: 2-D tensor of references with the same number of columns as `a`.
        batch_size: number of query rows processed per step, which bounds the
            size of the broadcast intermediate.

    Returns:
        A double tensor of shape (rows of a, rows of b), allocated on the
        device of `a`.
    """
    y = b.unsqueeze(0)
    n_pixels = a.size(0)
    # Allocate next to the inputs instead of forcing CUDA, so the helper also
    # runs on CPU tensors.
    sim = torch.zeros((n_pixels, b.size(0)), dtype=torch.double, device=a.device)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim


def sim_func(opt, query, database):
    """Score `query` rows against `database` rows according to `opt.dist`.

    'full' takes one cosine similarity over all channels; 'split' adds the
    similarity of the first `opt.n_rgbns` channels to that of the rest.

    Raises:
        ValueError: if `opt.dist` is neither 'full' nor 'split'.
    """
    if opt.dist == 'full':
        return cosine(query, database)
    if opt.dist == 'split':
        k = opt.n_rgbns
        return cosine(query[:, :k], database[:, :k]) + cosine(query[:, k:], database[:, k:])
    # An explicit raise survives `python -O`, unlike an assert.
    raise ValueError('unknown distance: {!r}'.format(opt.dist))


def sims2pred(n_lights, sims):
    """Turn per-pixel similarities into powder ids ranked by vote count.

    Database rows are grouped per powder in runs of `n_lights`, so the best
    row index of a pixel divided by `n_lights` is that pixel's powder vote.
    """
    votes = (sims.argmax(dim=1) // n_lights).cpu().numpy()
    counts = collections.Counter(votes)
    return [powder for powder, _ in counts.most_common()]


def match_powder_none(opt, database, scene):
    """Score scene pixels against the database without background blending."""
    n_pixels = scene.size(0)
    query = scene.view((n_pixels, -1))
    return sim_func(opt, query, database)


def match_powder_kappa(opt, database, scene, bg, kappa):
    """Score scene pixels against database spectra blended with the background.

    For each of `opt.n_etas` values eta in [0, opt.eta_max], a per-channel
    weight alpha = eta ** kappa is formed and the candidate spectrum is
    database * (1 - alpha) + background * alpha. The score of a database
    entry is its best similarity over all eta values.

    Args:
        opt: options providing `eta_max`, `n_etas` and the `sim_func` settings.
        database: (n_database, n_channels) reference spectra.
        scene: (n_pixels, n_channels) observed pixels.
        bg: (n_pixels, n_channels) background estimate for the same pixels.
        kappa: (n_database, n_channels) blending exponents.

    Returns:
        A double tensor of shape (n_pixels, n_database) on the device of `scene`.
    """
    n_database = database.size(0)
    n_pixels = scene.size(0)
    n_channels = scene.size(1)

    eta = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double, device=kappa.device)
    # n_database * n_etas * n_channels
    alpha = eta.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1)

    # n_pixels * 1 * 1 * n_channels, broadcast against alpha per pixel
    bg = bg.unsqueeze(1).unsqueeze(2)

    # n_database * n_etas * n_channels
    database = database.unsqueeze(1) * (1 - alpha)

    # Keep each similarity call to roughly 64000 values per eta step.
    batch_size = 64000 // n_channels
    sims = torch.zeros((n_pixels, n_database), dtype=torch.double, device=scene.device)

    for p in range(n_pixels):
        query = scene[p:p + 1, :]
        db = database + bg[p, :, :, :] * alpha
        for bs in range(0, n_database, batch_size):
            be = min(n_database, bs + batch_size)
            cur_database = db[bs:be, :, :].reshape(((be - bs) * opt.n_etas, -1))
            cur_sims = sim_func(opt, query, cur_database)
            # Best blend weight for every database entry of this batch.
            cur_sims, _ = cur_sims.view((be - bs, opt.n_etas)).max(1)
            sims[p, bs:be] = cur_sims

    return sims


def chmap(channels):
    """Map SWIR band indices to their positions in the reduced ext cube.

    Args:
        channels: iterable of SWIR band indices (0-based, without the RGBN
            offset of 4).

    Returns:
        The list of positions in `EXT_CHANNELS`, in the order of `channels`.

    Raises:
        ValueError: if a band is not part of `EXT_CHANNELS`. Dropping it
            silently would leave the scene with fewer channels than the
            database and misalign every spectrum.
    """
    position = {ch - 4: i for i, ch in enumerate(EXT_CHANNELS)}
    missing = [c for c in channels if c not in position]
    if missing:
        raise ValueError('SWIR bands not available in the ext data: {}'.format(missing))
    return [position[c] for c in channels]


def parse_args(argv=None):
    """Parse and validate the command line.

    Args:
        argv: argument list to parse; `None` means `sys.argv[1:]`.

    Returns:
        The parsed `argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(description='Recognition with known mask')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./log')
    parser.add_argument('--sel-path', type=str, default='../bandsel/bands')
    parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint')
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--eta-max', type=float, default=0.9)
    parser.add_argument('--n-etas', type=int, default=10)
    parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz')
    parser.add_argument('--n-swirs', type=int, default=4)
    parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4])
    parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'mvpca', 'rs'])
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    parser.add_argument('--set', type=str, default='testext', choices=['testext'])
    opt = parser.parse_args(argv)

    # parser.error() keeps these checks active under `python -O`.
    if opt.n_rgbns + opt.n_swirs <= 1:
        parser.error('at least two channels are required (--n-rgbns + --n-swirs > 1)')
    if (opt.n_rgbns <= 1 or opt.n_swirs <= 1) and opt.dist != 'full':
        parser.error("--dist split needs more than one RGBN and more than one SWIR channel; use --dist full")

    return opt


def _select_channels(opt):
    """Return (rgbn_channels, swir_channels, all_channels) selected by `opt`."""
    rgbn_channels = {4: [0, 1, 2, 3], 3: [0, 1, 2], 1: [3]}.get(opt.n_rgbns, [])

    swir_channels = []
    if opt.n_swirs > 0:
        if opt.sel == 'grid':
            if int(np.sqrt(opt.n_swirs)) ** 2 != opt.n_swirs:
                raise ValueError('--sel grid requires --n-swirs to be a perfect square')
            if opt.n_swirs == 1:
                swir_channels.append(480)
            else:
                # Regular sub-grid of the 31 x 31 band layout.
                decimation = int(30 // (np.sqrt(opt.n_swirs) - 1))
                for i in range(0, 31, decimation):
                    for j in range(0, 31, decimation):
                        swir_channels.append(i * 31 + j)
        else:
            # The last line of the selection file lists the ranked bands.
            with open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r') as sel_file:
                ranked = sel_file.readlines()[-1].strip().split(',')
            swir_channels = [int(i) for i in ranked[:opt.n_swirs]]
        if len(swir_channels) != opt.n_swirs:
            raise ValueError('expected {} SWIR bands, selected {}'.format(opt.n_swirs, len(swir_channels)))

    # SWIR bands follow the 4 RGBN channels in the combined numbering.
    all_channels = rgbn_channels + [i + 4 for i in swir_channels]
    return rgbn_channels, swir_channels, all_channels


def _load_cube(path, rgbn_channels, swir_positions):
    """Load an .npz archive and stack its selected 'rgbn' and 'swir' channels."""
    with np.load(path) as cube:
        return np.concatenate((cube['rgbn'][:, :, rgbn_channels], cube['swir'][:, :, swir_positions]), axis=2)


def _load_database(train_path, rgbn_channels, swir_channels, n_channels):
    """Build the (N_POWDERS * len(LIGHTS), n_channels) mean thick-powder spectra."""
    thick_list = np.zeros((N_POWDERS, len(LIGHTS), n_channels))
    for i in range(N_POWDERS):
        idx = str(i).zfill(2)
        for lid, light in enumerate(LIGHTS):
            # Training cubes are indexed with the raw SWIR band indices.
            thick = _load_cube(train_path / light / 'thick' / (idx + '_thick.npz'), rgbn_channels, swir_channels)
            thick_list[i, lid] = thick.mean((0, 1))
    thick_list = thick_list.reshape((N_POWDERS * len(LIGHTS), n_channels))
    return torch.from_numpy(thick_list).cuda()


def main():
    """Run recognition over all scenes, printing and logging the accuracy."""
    # Imported here because only the driver needs OpenCV; the matching helpers
    # above stay importable without it.
    import cv2

    opt = parse_args()

    Path(opt.log_path).mkdir(parents=True, exist_ok=True)
    log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns))
    if log_fname.is_file():
        raise FileExistsError('refusing to overwrite existing log: {}'.format(log_fname))

    n_lights = len(LIGHTS)
    rgbn_channels, swir_channels, all_channels = _select_channels(opt)
    n_channels = opt.n_rgbns + opt.n_swirs
    # Resolved once, before the log is created, so an unsupported band fails
    # early instead of inside the scene loop.
    swir_positions = chmap(swir_channels)

    train_path = Path(opt.data_path) / 'train'
    test_path = Path(opt.data_path) / opt.set
    scene_path = test_path / 'scene'
    bgscene_path = test_path / 'bgscene'
    label_path = test_path / 'label'

    with open(log_fname, 'w') as log_file:

        def log(*args):
            """Print to the console and mirror the same line into the log."""
            print(*args)
            print(*args, file=log_file)

        log(opt)
        log(swir_channels)

        thick_list = _load_database(train_path, rgbn_channels, swir_channels, n_channels)

        kappa = None  # unused when opt.blend == 'none'
        if opt.blend == 'alpha':
            kappa = torch.ones((N_POWDERS * n_lights, n_channels)).double().cuda()
        elif opt.blend == 'kappa':
            with np.load(opt.kappa_params) as kappa_params:
                kappa = kappa_params['params'][:, :, all_channels].reshape((N_POWDERS * n_lights, n_channels))
            kappa = torch.from_numpy(kappa).cuda()

        acc_top1 = []
        acc_top3 = []
        start_time = time.time()
        for i in range(N_SCENES):
            idx = str(i).zfill(2)
            log()
            log('scene', idx)

            scene = _load_cube(scene_path / (idx + '_scene.npz'), rgbn_channels, swir_positions)
            label = cv2.imread(str(label_path / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE)
            if opt.bg == 'inpaint':
                # Labels below 255 mark the pixels to be filled in.
                mask = (label < 255).astype(np.uint8) * 255
                bgscene = scene.copy()
                for c in range(n_channels):
                    scene_max = scene[mask == 255, c].max()
                    # cv2.inpaint is fed a 16-bit image here: scale, inpaint, scale back.
                    bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535
            elif opt.bg == 'gt':
                bgscene = _load_cube(bgscene_path / (idx + '_bgscene.npz'), rgbn_channels, swir_positions)
            else:
                raise ValueError('unknown background mode: {!r}'.format(opt.bg))

            for powder in range(N_POWDERS):
                mask = (label == powder)
                if not mask.any():
                    continue
                log('powder', powder)
                scene_list = torch.from_numpy(scene[mask, :]).cuda()
                bgscene_list = torch.from_numpy(bgscene[mask, :]).cuda()
                if opt.blend == 'none':
                    sims = match_powder_none(opt, thick_list, scene_list)
                else:
                    sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa)
                pred = sims2pred(n_lights, sims)
                top1 = (powder in pred[:1])
                top3 = (powder in pred[:3])
                acc_top1.append(top1)
                acc_top3.append(top3)
                log(pred)
                log(top1, top3)
                log('Acc:', np.mean(acc_top1), np.mean(acc_top3))
                log('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's')

        log(np.mean(acc_top1), np.mean(acc_top3))
        print('Complete', file=log_file)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
import argparse
import collections
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import time

from pathlib import Path


def cosine(a, b):
    """Return the cosine similarity between every row of `a` and every row of `b`.

    `a` is indexed as (n_pixels, n_channels) and `b` as (n_database, n_channels);
    the result is a double tensor of shape (n_pixels, n_database).  Rows of `a`
    are processed in chunks of 1000 to bound the size of the broadcast
    intermediate.  The result is allocated on the device of `a` (the original
    script always passes CUDA tensors).
    """
    y = b.unsqueeze(0)
    n_pixels = a.size()[0]
    batch_size = 1000
    sim = torch.zeros((n_pixels, b.size()[0]), dtype=torch.double, device=a.device)
    for bs in range(0, n_pixels, batch_size):
        be = min(n_pixels, bs + batch_size)
        x = a[bs:be, :].unsqueeze(1)
        sim[bs:be, :] = F.cosine_similarity(x, y, dim=2)
    return sim


def sim_func(opt, query, database):
    """Compute query/database similarities according to `opt.dist`.

    'full' uses one cosine similarity over all channels; 'split' adds the
    cosine similarity of the first `opt.n_rgbns` channels to that of the
    remaining channels.

    Raises:
        ValueError: if `opt.dist` is neither 'full' nor 'split'.
    """
    if opt.dist == 'full':
        return cosine(query, database)
    if opt.dist == 'split':
        n = opt.n_rgbns
        return cosine(query[:, :n], database[:, :n]) + cosine(query[:, n:], database[:, n:])
    # An assert would vanish under `python -O` and silently return None.
    raise ValueError('unknown distance mode: {!r}'.format(opt.dist))


def sims2pred(n_lights, sims):
    """Turn per-pixel similarities into a ranking of powder ids by vote count.

    Each pixel votes for the database row with the highest similarity; rows
    are laid out as powder * n_lights + light, so integer division by
    `n_lights` recovers the powder id.  Returns powder ids ordered from most
    to least votes.
    """
    votes = (sims.argmax(dim=1) // n_lights).cpu().numpy()
    counts = collections.Counter(votes)
    return [powder for powder, _ in counts.most_common()]


def match_powder_none(opt, database, scene):
    """Match scene pixels against the database without any background blending."""
    n_pixels = scene.size()[0]
    query = scene.view((n_pixels, -1))
    return sim_func(opt, query, database)


def match_powder_kappa(opt, database, scene, bg, kappa):
    """Match scene pixels against powder/background blends.

    For every pixel, each database entry is blended with that pixel's
    background using alpha = eta ** kappa for `opt.n_etas` values of eta in
    [0, opt.eta_max]; the similarity kept for an entry is the maximum over
    eta.  Returns a double tensor of shape (n_pixels, n_database) on the
    device of `scene`.
    """
    n_database = database.size()[0]
    n_pixels = scene.size()[0]
    n_channels = scene.size()[1]
    device = scene.device

    eta = torch.linspace(0, opt.eta_max, opt.n_etas, dtype=torch.double, device=device)
    # n_database * n_etas * n_channels
    alpha = eta.unsqueeze(0).unsqueeze(2) ** kappa.unsqueeze(1)

    # n_pixels * 1 * 1 * n_channels (broadcast against database and etas)
    bg = bg.unsqueeze(1).unsqueeze(2)

    # n_database * n_etas * n_channels
    database = database.unsqueeze(1) * (1 - alpha)

    batch_size = 64000 // n_channels
    sims = torch.zeros((n_pixels, n_database), dtype=torch.double, device=device)

    for p in range(n_pixels):
        query = scene[p:p+1, :]
        db = database + bg[p, :, :, :] * alpha
        for bs in range(0, n_database, batch_size):
            be = min(n_database, bs + batch_size)
            cur_database = db[bs:be, :, :].reshape(((be - bs) * opt.n_etas, -1))
            cur_sims = sim_func(opt, query, cur_database)
            cur_sims, _ = cur_sims.view((be - bs, opt.n_etas)).max(1)
            sims[p, bs:be] = cur_sims

    return sims


def chmap(channels):
    """Map SWIR band indices to their positions in the 'ext' capture.

    `ch_list` holds the bands stored in the ext data, numbered with an offset
    of 4 relative to the SWIR indices used elsewhere in this script (the code
    compares `ch - 4` with the requested band).  The returned list has one
    position per requested band, in the order given.

    Raises:
        ValueError: if a requested band is not part of the ext capture
            (previously such bands were dropped silently, leaving the caller
            with fewer channels than expected).
    """
    ch_list = [4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]
    position = {ch - 4: i for i, ch in enumerate(ch_list)}
    missing = [c for c in channels if c not in position]
    if missing:
        raise ValueError('SWIR channels not available in the ext set: {}'.format(missing))
    return [position[c] for c in channels]


def parse_args():
    """Parse and validate the command-line options of the recognition script."""
    parser = argparse.ArgumentParser(description='Recognition with known mask')
    parser.add_argument('--data-path', type=str, default='../data')
    parser.add_argument('--log-path', type=str, default='./log')
    parser.add_argument('--sel-path', type=str, default='../bandsel/bands')
    parser.add_argument('--bg', type=str, choices=['gt', 'inpaint'], default='inpaint')
    parser.add_argument('--blend', type=str, choices=['none', 'alpha', 'kappa'], default='kappa')
    parser.add_argument('--eta-max', type=float, default=0.9)
    parser.add_argument('--n-etas', type=int, default=10)
    parser.add_argument('--kappa-params', type=str, default='../params/kappa_params.npz')
    parser.add_argument('--n-swirs', type=int, default=4)
    parser.add_argument('--n-rgbns', type=int, default=4, choices=[0, 1, 3, 4])
    # 'uniform' is handled by the band-selection code below, so it must be selectable.
    parser.add_argument('--sel', type=str, default='nncv', choices=['nncv', 'grid', 'uniform', 'mvpca', 'rs'])
    parser.add_argument('--dist', type=str, default='split', choices=['full', 'split'])
    parser.add_argument('--set', type=str, default='trainext', choices=['trainext'])
    opt = parser.parse_args()

    if opt.n_rgbns + opt.n_swirs <= 1:
        parser.error('--n-rgbns + --n-swirs must be greater than 1')
    if (opt.n_rgbns <= 1 or opt.n_swirs <= 1) and opt.dist != 'full':
        parser.error("--dist split needs more than one RGBN and more than one SWIR channel; use --dist full")

    return opt


def _select_swir_channels(opt):
    """Return the list of `opt.n_swirs` SWIR band indices chosen by `opt.sel`."""
    swir_channels = []
    if opt.n_swirs <= 0:
        return swir_channels

    if opt.sel == 'grid':
        assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
        if opt.n_swirs == 1:
            swir_channels.append(480)
        else:
            decimation = int(30 // (np.sqrt(opt.n_swirs) - 1))
            for i in range(0, 31, decimation):
                for j in range(0, 31, decimation):
                    swir_channels.append(i * 31 + j)
    elif opt.sel == 'uniform':
        assert(int(np.sqrt(opt.n_swirs))**2 == opt.n_swirs)
        if opt.n_swirs == 1:
            swir_channels.append(480)
        else:
            a = int(np.sqrt(opt.n_swirs))
            for i in range(0, a):
                for j in range(0, a):
                    x = int(np.floor(31*(2*i+1)/2/a))
                    y = int(np.floor(31*(2*j+1)/2/a))
                    swir_channels.append(x * 31 + y)
    else:
        # Band files hold a comma-separated ranking on their last line.
        with open(Path(opt.sel_path) / (opt.sel + '.txt'), 'r') as sel_file:
            splited = sel_file.readlines()[-1].strip().split(',')
        swir_channels.extend(int(i) for i in splited[:opt.n_swirs])

    assert(len(swir_channels) == opt.n_swirs)
    return swir_channels


if __name__ == '__main__':
    opt = parse_args()

    Path(opt.log_path).mkdir(parents=True, exist_ok=True)
    log_fname = Path(opt.log_path) / ('{}_{}_{}_{}_{}.txt'.format(opt.set, opt.bg, opt.blend, opt.n_swirs, opt.n_rgbns))
    if log_fname.is_file():
        # Never overwrite the log of an earlier run.
        raise FileExistsError('log file already exists: {}'.format(log_fname))

    lights = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']
    n_lights = len(lights)
    n_powders = 100
    n_scenes_per_light = 16

    rgbn_channels = {4: [0, 1, 2, 3], 3: [0, 1, 2], 1: [3]}.get(opt.n_rgbns, [])
    swir_channels = _select_swir_channels(opt)
    # Index into the concatenated RGBN+SWIR axis of the kappa table: SWIR bands follow the 4 RGBN ones.
    all_channels = rgbn_channels + [c + 4 for c in swir_channels]
    # Positions of the selected SWIR bands inside the ext captures (same for every scene).
    ext_swir_channels = chmap(swir_channels)

    n_channels = opt.n_rgbns + opt.n_swirs

    train_path = Path(opt.data_path) / 'train'
    test_path = Path(opt.data_path) / opt.set

    with open(log_fname, 'w') as log_file:

        def log(*args):
            """Print to stdout and mirror the same line into the log file."""
            print(*args)
            print(*args, file=log_file)

        log(opt)
        log(swir_channels)

        # Mean spectrum of the thick sample of every powder under every light.
        thick_list = np.zeros((n_powders, n_lights, n_channels))
        for i in range(n_powders):
            idx = str(i).zfill(2)
            for lid, light in enumerate(lights):
                with np.load(train_path / light / 'thick' / (idx + '_thick.npz')) as thick:
                    thick = np.concatenate((thick['rgbn'][:, :, rgbn_channels], thick['swir'][:, :, swir_channels]), axis=2)
                thick_list[i, lid] = thick.mean((0, 1))
        thick_list = thick_list.reshape((n_powders * n_lights, n_channels))
        thick_list = torch.from_numpy(thick_list).cuda()

        kappa = None  # unused when opt.blend == 'none'
        if opt.blend == 'alpha':
            kappa = torch.ones((n_powders * n_lights, n_channels)).double().cuda()
        elif opt.blend == 'kappa':
            with np.load(opt.kappa_params) as kappa_params:
                kappa = kappa_params['params'][:, :, all_channels].reshape((n_powders * n_lights, n_channels))
            kappa = torch.from_numpy(kappa).cuda()

        acc_top1 = []
        acc_top3 = []
        start_time = time.time()
        for lid, light in enumerate(lights):
            for i in range(n_scenes_per_light):
                idx = str(i).zfill(2)
                log()
                log('scene', light, idx)

                scene_path = test_path / light / 'scene'
                bgscene_path = test_path / light / 'bgscene'
                label_path = test_path / light / 'label'

                with np.load(scene_path / (idx + '_scene.npz')) as scene:
                    scene = np.concatenate((scene['rgbn'][:, :, rgbn_channels], scene['swir'][:, :, ext_swir_channels]), axis=2)
                label = cv2.imread(str(label_path / (idx + '_label.png')), cv2.IMREAD_GRAYSCALE)
                if opt.bg == 'inpaint':
                    # Pixels with a label below 255 are filled in from their surroundings.
                    mask = (label < 255).astype(np.uint8) * 255
                    bgscene = scene.copy()
                    for c in range(n_channels):
                        # NOTE(review): the scale is the maximum over the *masked* pixels only;
                        # brighter unmasked pixels would wrap in the uint16 cast - confirm intent.
                        scene_max = scene[mask == 255, c].max()
                        bgscene[:, :, c] = (cv2.inpaint((scene[:, :, c] / scene_max * 65535).astype(np.uint16), mask, 3, cv2.INPAINT_TELEA)).astype(scene.dtype) * scene_max / 65535
                elif opt.bg == 'gt':
                    with np.load(bgscene_path / (idx + '_bgscene.npz')) as bgscene:
                        bgscene = np.concatenate((bgscene['rgbn'][:, :, rgbn_channels], bgscene['swir'][:, :, ext_swir_channels]), axis=2)
                else:
                    raise ValueError('unknown background mode: {!r}'.format(opt.bg))

                for powder in range(n_powders):
                    mask = (label == powder)
                    if not mask.any():
                        continue
                    log('powder', powder)
                    scene_list = torch.from_numpy(scene[mask, :]).cuda()
                    bgscene_list = torch.from_numpy(bgscene[mask, :]).cuda()
                    if opt.blend == 'none':
                        sims = match_powder_none(opt, thick_list, scene_list)
                    else:
                        sims = match_powder_kappa(opt, thick_list, scene_list, bgscene_list, kappa)
                    pred = sims2pred(n_lights, sims)
                    top1 = (powder in pred[:1])
                    top3 = (powder in pred[:3])
                    acc_top1.append(top1)
                    acc_top3.append(top3)
                    log(pred)
                    log(top1, top3)
                    log('Acc:', np.mean(acc_top1), np.mean(acc_top3))
                    log('No.', len(acc_top1), ' ', (time.time() - start_time) / len(acc_top1), 's')

        log(np.mean(acc_top1), np.mean(acc_top3))
        print('Complete', file=log_file)

# --------------------------------------------------------------------------------
# /src/dataset.py:
# --------------------------------------------------------------------------------
import numpy as np
import h5py
import torch.utils.data as data
import torch
import cv2
import random
import scipy.special
from pathlib import Path


# Light sources; one HDF5 file per light is read by the synthetic datasets.
_LIGHTS = ['EiKOIncandescent250W', 'IIIWoodsHalogen500W', 'LowelProHalogen250W', 'WestinghouseIncandescent150W']

# Channels stored in the 'ext' HDF5 files, as indices into the full 965-channel axis.
_EXT_CHANNELS = [0, 1, 2, 3, 4, 5, 19, 34, 51, 77, 95, 127, 152, 342, 399, 401, 422, 434, 442, 469, 484, 487, 499, 538, 555, 588, 637, 664, 676, 683, 686, 750, 837, 879, 905, 934, 949, 964]

_BLEND_MODES = ('none', 'alpha', 'kappa')


def _channel_window(channels):
    """Return (n_channels, ch_begin, ch_end) for a channel specification.

    An int selects a contiguous slice: the first `channels` channels when
    positive, otherwise the last `-channels` of the 965 channels.  Any other
    value is treated as an explicit index list, for which begin/end are None.
    """
    if isinstance(channels, int):
        if channels > 0:
            return abs(channels), 0, channels
        return abs(channels), 965 + channels, 965
    return len(channels), None, None


def _map_ext_channels(channels):
    """Translate full-axis channel indices into positions inside an ext file.

    The result follows the order of the ext file, not the order of `channels`.

    Raises:
        ValueError: if some requested channel is not stored in the ext files
            (or `channels` contains duplicates).
    """
    wanted = set(channels)
    mapped = [i for i, ch in enumerate(_EXT_CHANNELS) if ch in wanted]
    print(channels, mapped)
    if len(channels) != len(mapped):
        raise ValueError('channels {} cannot all be mapped to the ext set'.format(list(channels)))
    return mapped


def _load_kappa(params_path, blend, channels, ch_begin, ch_end):
    """Load the kappa table restricted to the selected channels, or None unless blend is 'kappa'."""
    if blend != 'kappa':
        return None
    with np.load(Path(params_path) / 'kappa_params.npz') as params_file:
        params = params_file['params']
    if ch_begin is None:
        return params[:, :, channels]
    return params[:, :, ch_begin:ch_end]


def _composite(ds, bg, powder, label, thickness, lid):
    """Jitter the powder layer and blend it over the background.

    `ds` supplies the dataset settings (n_powders, sigmas, threshold, blend,
    kappa, height, width, channel) and the `exp_gauss` sampler.  `label`,
    `thickness` and `powder` are modified in place: every powder id gets one
    random thickness and one random brightness factor, and pixels thinner than
    the threshold are relabelled 255.  Returns the blended image
    alpha * bg + (1 - alpha) * powder.
    """
    for i in range(ds.n_powders):
        mask = (label == i)
        # Both factors are drawn for every powder id, present or not, so the
        # sequence of random draws does not depend on the image content.
        thickness[mask] = thickness[mask] * ds.exp_gauss(ds.thick_sigma)
        powder[mask] = powder[mask] * ds.exp_gauss(ds.brdf_sigma)
    label[thickness < ds.thickness_threshold] = 255

    if ds.blend == 'none':
        thickness = (thickness >= ds.thickness_threshold).astype(np.float32)
        thickness = thickness[:, :, np.newaxis]
        alpha = 1 - thickness
    elif ds.blend == 'alpha':
        thickness[thickness > 1] = 1
        thickness = thickness[:, :, np.newaxis]
        alpha = 1 - thickness
    elif ds.blend == 'kappa':
        thickness[thickness > 1] = 1
        thickness = thickness[:, :, np.newaxis]
        alpha = np.ones((ds.height, ds.width, ds.channel), dtype=np.float32)
        for i in range(ds.n_powders):
            mask = (label == i)
            alpha[mask, :] = (1 - thickness[mask, :]) ** ds.kappa[i, lid, :][np.newaxis, :]
    return alpha * bg + (1 - alpha) * powder


class SyntheticDataset(data.Dataset):
    """Synthetic training images: a random background blended with a powder layer.

    Reads the 'bg', 'powder', 'shading', 'label' and 'thickness' datasets of
    `<data_path>/<light>.hdf5`.  `channels` is either an int (contiguous
    slice, see `_channel_window`) or a list of channel indices.
    """

    def __init__(self, data_path, params_path, blend, channels):
        super(SyntheticDataset, self).__init__()
        if blend not in _BLEND_MODES:
            raise ValueError('unknown blend mode: {!r}'.format(blend))
        self.data_path = Path(data_path)
        self.blend = blend
        self.lights = list(_LIGHTS)
        self.n_lights = len(self.lights)
        self.n_powders = 100
        self.height = 160
        self.width = 280
        self.channels = channels
        self.channel, self.ch_begin, self.ch_end = _channel_window(channels)

        self.thickness_threshold = 0.1
        self.n_classes = 100 + 1
        self.n_per_light = 1000
        self.thick_sigma = 0.1
        self.shad_sigma = 0.1
        self.brdf_sigma = 0.1
        self.kappa = _load_kappa(params_path, blend, self.channels, self.ch_begin, self.ch_end)

    def __getitem__(self, index):
        """Return (image, label): image is channel-first, label is int64 with background = n_classes - 1."""
        lid = index // self.n_per_light
        light = self.lights[lid]
        powder_idx = index % self.n_per_light
        bg_idx = random.randint(0, self.n_per_light - 1)
        with h5py.File(self.data_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                bg = h5file['bg'][bg_idx, :, :, self.channels].astype(np.float32)
                powder = h5file['powder'][powder_idx, :, :, self.channels].astype(np.float32)
            else:
                bg = h5file['bg'][bg_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
                powder = h5file['powder'][powder_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            shading = h5file['shading'][bg_idx].astype(np.float32)
            label = h5file['label'][powder_idx]
            thickness = h5file['thickness'][powder_idx].astype(np.float32)

        # Background (with its shading) and powder layer are flipped independently.
        if random.randint(0, 1) == 1:
            bg = np.fliplr(bg)
            shading = np.fliplr(shading)
        if random.randint(0, 1) == 1:
            bg = np.flipud(bg)
            shading = np.flipud(shading)
        if random.randint(0, 1) == 1:
            powder = np.fliplr(powder)
            label = np.fliplr(label)
            thickness = np.fliplr(thickness)
        if random.randint(0, 1) == 1:
            powder = np.flipud(powder)
            label = np.flipud(label)
            thickness = np.flipud(thickness)

        im = _composite(self, bg, powder, label, thickness, lid)
        # Re-apply the background shading, normalised by its median, plus a global gain jitter.
        med_shad = np.median(shading)
        im = im * shading[:, :, np.newaxis] / med_shad
        im = im * self.exp_gauss(self.shad_sigma)

        im = im.transpose([2, 0, 1])
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64)
        return im, label

    def exp_gauss(self, sigma):
        """Sample a log-normal factor exp(N(0, sigma))."""
        return np.exp(random.gauss(0, sigma))

    def __len__(self):
        return self.n_lights * self.n_per_light


class RealDataset(data.Dataset):
    """Real captured images read from `<data_path>/<split>.hdf5` ('im' and 'label').

    For the 'trainext' and 'testext' splits `channels` must be a list of
    full-axis channel indices, which is mapped to positions inside the ext
    file; other splits accept an int or a list as in `_channel_window`.
    """

    def __init__(self, data_path, channels, split, flip=False):
        super(RealDataset, self).__init__()
        self.data_path = Path(data_path)
        self.n_classes = 100 + 1
        self.split = split
        self.flip = flip

        if split == 'trainext' or split == 'testext':
            if isinstance(channels, int):
                raise TypeError('ext splits need an explicit channel list, not an int')
            self.n_images = 64
            self.channels = self.chmap(channels)
            self.channel = len(self.channels)
            self.ch_begin = None
            self.ch_end = None
        else:
            self.n_images = 32
            self.channels = channels
            self.channel, self.ch_begin, self.ch_end = _channel_window(channels)

    def __getitem__(self, index):
        """Return (image, label): image is channel-first float32, label is int64 with background = n_classes - 1."""
        with h5py.File(self.data_path / (self.split + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                im = h5file['im'][index, :, :, self.channels].astype(np.float32)
            else:
                im = h5file['im'][index, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            label = h5file['label'][index]

        if self.flip:
            if random.randint(0, 1) == 1:
                im = np.fliplr(im)
                label = np.fliplr(label)
            if random.randint(0, 1) == 1:
                im = np.flipud(im)
                label = np.flipud(label)

        # Copies turn the flipped views into ordinary arrays.
        im = im.transpose([2, 0, 1]).copy()
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64).copy()
        return im, label

    def __len__(self):
        return self.n_images

    def chmap(self, channels):
        """Map full-axis channel indices to positions in the ext file (see `_map_ext_channels`)."""
        return _map_ext_channels(channels)


class HalfHalfDataset(data.Dataset):
    """Real backgrounds blended with synthetic powder layers.

    Backgrounds come from the 'im' dataset of `<real_path>/<split>/<light>.hdf5`;
    'powder', 'label' and 'thickness' come from `<syn_path>/<light>.hdf5`.
    For split 'bgext' `channels` must be a list; background channels are then
    mapped to positions inside the ext file while the powder data keeps the
    full-axis indices.
    """

    def __init__(self, real_path, syn_path, params_path, blend, channels, split):
        super(HalfHalfDataset, self).__init__()
        if blend not in _BLEND_MODES:
            raise ValueError('unknown blend mode: {!r}'.format(blend))
        self.real_path = Path(real_path) / split
        self.syn_path = Path(syn_path)
        self.blend = blend
        self.lights = list(_LIGHTS)
        self.n_lights = len(self.lights)
        self.n_powders = 100
        self.height = 160
        self.width = 280
        self.channels = channels
        if split == 'bgext':
            if isinstance(channels, int):
                raise TypeError('split bgext needs an explicit channel list, not an int')
            self.n_bg_per_light = 32
            self.bg_channels = self.chmap(channels)
            self.channel = len(self.channels)
            self.ch_begin = None
            self.ch_end = None
        else:
            self.bg_channels = self.channels
            self.n_bg_per_light = 16
            self.channel, self.ch_begin, self.ch_end = _channel_window(channels)
        self.thickness_threshold = 0.1
        self.n_classes = 100 + 1
        self.n_powder_per_light = 1000
        self.thick_sigma = 0.1
        self.brdf_sigma = 0.15
        self.kappa = _load_kappa(params_path, blend, self.channels, self.ch_begin, self.ch_end)

    def __getitem__(self, index):
        """Return (image, label): image is channel-first, label is int64 with background = n_classes - 1."""
        lid = index // self.n_powder_per_light
        light = self.lights[lid]
        powder_idx = index % self.n_powder_per_light
        bg_idx = random.randint(0, self.n_bg_per_light - 1)
        with h5py.File(self.real_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                bg = h5file['im'][bg_idx, :, :, self.bg_channels].astype(np.float32)
            else:
                bg = h5file['im'][bg_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
        with h5py.File(self.syn_path / (light + '.hdf5'), 'r') as h5file:
            if self.ch_begin is None:
                powder = h5file['powder'][powder_idx, :, :, self.channels].astype(np.float32)
            else:
                powder = h5file['powder'][powder_idx, :, :, self.ch_begin:self.ch_end].astype(np.float32)
            label = h5file['label'][powder_idx]
            thickness = h5file['thickness'][powder_idx].astype(np.float32)

        # Background and powder layer are flipped independently.
        if random.randint(0, 1) == 1:
            bg = np.fliplr(bg)
        if random.randint(0, 1) == 1:
            bg = np.flipud(bg)
        if random.randint(0, 1) == 1:
            powder = np.fliplr(powder)
            label = np.fliplr(label)
            thickness = np.fliplr(thickness)
        if random.randint(0, 1) == 1:
            powder = np.flipud(powder)
            label = np.flipud(label)
            thickness = np.flipud(thickness)

        im = _composite(self, bg, powder, label, thickness, lid)

        im = im.transpose([2, 0, 1])
        label[label == 255] = self.n_classes - 1
        label = label.astype(np.int64)
        return im, label

    def exp_gauss(self, sigma):
        """Sample a log-normal factor exp(N(0, sigma))."""
        return np.exp(random.gauss(0, sigma))

    def __len__(self):
        return self.n_lights * self.n_powder_per_light

    def chmap(self, channels):
        """Map full-axis channel indices to positions in the ext file (see `_map_ext_channels`)."""
        return _map_ext_channels(channels)

# --------------------------------------------------------------------------------
# /src/deeplab/deeplab_xception.py:
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution."""

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation,
                               groups=inplanes, bias=bias)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


def fixed_padding(inputs, kernel_size, rate):
    """Zero-pad the last two dimensions for a dilated convolution.

    The total padding per dimension is the effective kernel size minus one,
    split between the two sides with the extra pixel (if any) at the end.
    """
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
    return padded_inputs


def _init_weights(module):
    """Kaiming-normal init for every Conv2d under `module`; unit/zero affine for BatchNorm2d.

    GroupNorm layers are left at their PyTorch defaults.
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            torch.nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class SeparableConv2d_same(nn.Module):
    """Separable convolution that pads its input explicitly via `fixed_padding`."""

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False):
        super(SeparableConv2d_same, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
                               groups=inplanes, bias=bias)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = fixed_padding(x, self.conv1.kernel_size[0], rate=self.conv1.dilation[0])
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    """Residual block of `reps` (ReLU, separable conv, GroupNorm) stages.

    A 1x1 convolution + GroupNorm is used on the skip path when the channel
    count or the stride changes.  With `stride != 1` an extra strided
    separable convolution is appended to the main path.
    """

    def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        if planes != inplanes or stride != 1:
            self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
            self.skipbn = nn.GroupNorm(8, planes)
        else:
            self.skip = None

        # NOTE(review): this ReLU is in-place and shared by all stages; when the
        # block starts with it, the block input is modified before the skip
        # path reads it.  Kept as is to preserve the trained behaviour.
        self.relu = nn.ReLU(inplace=True)
        rep = []

        filters = inplanes
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
            rep.append(nn.GroupNorm(8, planes))
            filters = planes

        for i in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation))
            rep.append(nn.GroupNorm(8, filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
            rep.append(nn.GroupNorm(8, planes))

        if not start_with_relu:
            rep = rep[1:]

        if stride != 1:
            rep.append(SeparableConv2d_same(planes, planes, 3, stride=stride))

        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x += skip

        return x


class Xception(nn.Module):
    """
    Modified Aligned Xception (GroupNorm in place of BatchNorm).

    `forward` returns the final feature map together with the output of
    `block1`, which the decoder uses as low-level features.
    """

    # Indices of the identical middle-flow blocks (block4 ... block19).
    _MIDDLE_BLOCKS = range(4, 20)

    def __init__(self, inplanes=3, pretrained=False):
        super(Xception, self).__init__()

        # Entry flow
        self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False)
        self.bn1 = nn.GroupNorm(16, 32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.GroupNorm(32, 64)

        self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=2, start_with_relu=True, grow_first=True)

        # Middle flow: registered under the original attribute names, in the
        # original order, so state_dict keys and initialisation order are unchanged.
        for idx in self._MIDDLE_BLOCKS:
            setattr(self, 'block{}'.format(idx),
                    Block(728, 728, reps=3, stride=1, start_with_relu=True, grow_first=True))

        # Exit flow
        self.block20 = Block(728, 1024, reps=2, stride=1, dilation=2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=2)
        self.bn3 = nn.GroupNorm(32, 1536)

        self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=2)
        self.bn4 = nn.GroupNorm(32, 1536)

        self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=2)
        self.bn5 = nn.GroupNorm(32, 2048)

        # Init weights
        self.__init_weight()

        # Load pretrained model
        if pretrained:
            self.__load_xception_pretrained()

    def forward(self, x):
        # Entry flow
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        low_level_feat = x
        x = self.block2(x)
        x = self.block3(x)

        # Middle flow
        for idx in self._MIDDLE_BLOCKS:
            x = getattr(self, 'block{}'.format(idx))(x)

        # Exit flow
        x = self.block20(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        return x, low_level_feat

    def __init_weight(self):
        _init_weights(self)

    def __load_xception_pretrained(self):
        """Copy ImageNet Xception weights into the matching layers of this network.

        Only keys that also exist in this model's state dict are considered.
        The pretrained block12 feeds block20, block11 is replicated into
        block12..block19, and the exit-flow conv/bn weights are shifted as
        coded below.
        """
        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
        model_dict = {}
        state_dict = self.state_dict()

        for k, v in pretrain_dict.items():
            if k in state_dict:
                if 'pointwise' in k:
                    # Pointwise weights are expanded to 4-D conv kernels.
                    v = v.unsqueeze(-1).unsqueeze(-1)
                if k.startswith('block12'):
                    model_dict[k.replace('block12', 'block20')] = v
                elif k.startswith('block11'):
                    # NOTE(review): block11 itself is not assigned here and keeps
                    # its random initialisation - confirm this is intended.
                    for target in range(12, 20):
                        model_dict[k.replace('block11', 'block{}'.format(target))] = v
                elif k.startswith('conv3'):
                    model_dict[k] = v
                elif k.startswith('bn3'):
                    model_dict[k] = v
                    model_dict[k.replace('bn3', 'bn4')] = v
                elif k.startswith('conv4'):
                    model_dict[k.replace('conv4', 'conv5')] = v
                elif k.startswith('bn4'):
                    model_dict[k.replace('bn4', 'bn5')] = v
                else:
                    model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


class ASPP_module(nn.Module):
    """One atrous branch: (dilated) convolution, BatchNorm and ReLU.

    `rate == 1` uses a 1x1 convolution; any other rate uses a 3x3 convolution
    with dilation and padding equal to `rate`, which keeps the spatial size.
    """

    def __init__(self, inplanes, planes, rate):
        super(ASPP_module, self).__init__()
        if rate == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = rate
        self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=rate, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self.__init_weight()

    def forward(self, x):
        x = self.atrous_convolution(x)
        x = self.bn(x)

        return self.relu(x)

    def __init_weight(self):
        _init_weights(self)


class DeepLabv3_plus(nn.Module):
    """DeepLabv3+ segmentation network on the Xception backbone above.

    `forward` returns per-pixel class scores with the spatial size of the input.
    """

    def __init__(self, nInputChannels=3, n_classes=21, pretrained=False, _print=True):
        if _print:
            print("Constructing DeepLabv3+ model...")
            print("Number of classes: {}".format(n_classes))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(DeepLabv3_plus, self).__init__()

        # Atrous Conv
        self.xception_features = Xception(nInputChannels, pretrained=pretrained)

        # ASPP
        rates = [1, 6, 12, 18]
        self.aspp1 = ASPP_module(2048, 256, rate=rates[0])
        self.aspp2 = ASPP_module(2048, 256, rate=rates[1])
        self.aspp3 = ASPP_module(2048, 256, rate=rates[2])
        self.aspp4 = ASPP_module(2048, 256, rate=rates[3])

        self.relu = nn.ReLU()

        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(2048, 256, 1, stride=1, bias=False),
                                             nn.GroupNorm(32, 256),
                                             nn.ReLU())

        # 5 branches of 256 channels each are concatenated before this 1x1 conv.
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = nn.GroupNorm(32, 256)

        # adopt [1x1, 48] for channel reduction.
        self.conv2 = nn.Conv2d(128, 48, 1, bias=False)
        self.bn2 = nn.GroupNorm(16, 48)

        # 256 decoder channels + 48 low-level channels = 304 inputs.
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.GroupNorm(32, 256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.GroupNorm(32, 256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))

    def forward(self, input):
        x, low_level_features = self.xception_features(input)

        low_level_features = self.conv2(low_level_features)
        low_level_features = self.bn2(low_level_features)
        low_level_features = self.relu(low_level_features)

        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)

        y = torch.cat((x1, x2, x3, x4, x5), dim=1)

        y = self.conv1(y)
        y = self.bn1(y)
        y = self.relu(y)
        y = F.interpolate(y, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)

        z = torch.cat((y, low_level_features), dim=1)
        z = self.last_conv(z)
        z = F.interpolate(z, size=input.size()[2:], mode='bilinear', align_corners=True)

        return z

    def freeze_bn(self):
        """Put every BatchNorm2d layer into eval mode (GroupNorm layers are unaffected)."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def __init_weight(self):
        # NOTE(review): never called from __init__, so the decoder convolutions
        # keep PyTorch's default initialisation.
        _init_weights(self)


def get_1x_lr_params(model):
    """
    This generator returns the trainable parameters of the backbone
    (model.xception_features), i.e. those with requires_grad set.
    """
    for module in [model.xception_features]:
        for k in module.parameters():
            if k.requires_grad:
                yield k


def get_10x_lr_params(model):
    """
    This generator returns the trainable parameters of the ASPP branches,
    the two 1x1 reduction convolutions and the final classifier.
    """
    # NOTE(review): model.global_avg_pool, model.bn1 and model.bn2 are not
    # listed, so they belong to neither parameter group - confirm intent.
    heads = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
    for module in heads:
        for k in module.parameters():
            if k.requires_grad:
                yield k


if __name__ == "__main__":
    model = DeepLabv3_plus(nInputChannels=3, n_classes=21, pretrained=True, _print=True)
    image = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        # Call the module itself rather than .forward() so hooks are honoured.
        output = model(image)
    print(output.size())

# --------------------------------------------------------------------------------