├── models ├── __init__.py ├── model_store.py ├── resnet.py ├── resnet_dilation.py └── fpn_global_local_fmreg_ensemble.py ├── utils ├── __init__.py ├── metrics.py ├── lr_scheduler.py ├── loss.py └── lovasz_losses.py ├── dataset ├── __init__.py └── deep_globe.py ├── docs └── images │ ├── glnet.png │ ├── examples.jpg │ ├── gl_branch.png │ └── deep_globe_acc_mem_ext.jpg ├── requirements.txt ├── train_deep_globe_global.sh ├── train_deep_globe_global2local.sh ├── train_deep_globe_local2global.sh ├── eval_deep_globe.sh ├── LICENSE ├── .gitignore ├── test.txt ├── option.py ├── crossvali.txt ├── README.md ├── train.txt ├── train_deep_globe.py └── helper.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/glnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/GLNet/HEAD/docs/images/glnet.png -------------------------------------------------------------------------------- /docs/images/examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/GLNet/HEAD/docs/images/examples.jpg -------------------------------------------------------------------------------- /docs/images/gl_branch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/GLNet/HEAD/docs/images/gl_branch.png -------------------------------------------------------------------------------- /docs/images/deep_globe_acc_mem_ext.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/GLNet/HEAD/docs/images/deep_globe_acc_mem_ext.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch==0.4.1 3 | torchvision==0.3.0 4 | tqdm 5 | tensorboardX 6 | Pillow 7 | opencv-python==3.4.4 8 | -------------------------------------------------------------------------------- /train_deep_globe_global.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python train_deep_globe.py \ 3 | --n_class 7 \ 4 | --data_path "/ssd1/chenwy/deep_globe/data/" \ 5 | --model_path "/home/chenwy/deep_globe/saved_models/" \ 6 | --log_path "/home/chenwy/deep_globe/runs/" \ 7 | --task_name "fpn_deepglobe_global" \ 8 | --mode 1 \ 9 | --batch_size 6 \ 10 | --sub_batch_size 6 \ 11 | --size_g 508 \ 12 | --size_p 508 \ -------------------------------------------------------------------------------- /train_deep_globe_global2local.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python train_deep_globe.py \ 3 | --n_class 7 \ 4 | --data_path "/ssd1/chenwy/deep_globe/data/" \ 5 | --model_path "/home/chenwy/deep_globe/saved_models/" \ 6 | --log_path "/home/chenwy/deep_globe/runs/" \ 
7 | --task_name "fpn_deepglobe_global2local" \ 8 | --mode 2 \ 9 | --batch_size 6 \ 10 | --sub_batch_size 6 \ 11 | --size_g 508 \ 12 | --size_p 508 \ 13 | --path_g "fpn_deepglobe_global.pth" \ -------------------------------------------------------------------------------- /train_deep_globe_local2global.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python train_deep_globe.py \ 3 | --n_class 7 \ 4 | --data_path "/ssd1/chenwy/deep_globe/data/" \ 5 | --model_path "/home/chenwy/deep_globe/saved_models/" \ 6 | --log_path "/home/chenwy/deep_globe/runs/" \ 7 | --task_name "fpn_deepglobe_local2global" \ 8 | --mode 3 \ 9 | --batch_size 6 \ 10 | --sub_batch_size 6 \ 11 | --size_g 508 \ 12 | --size_p 508 \ 13 | --path_g "fpn_deepglobe_global.pth" \ 14 | --path_g2l "fpn_deepglobe_global2local.pth" \ -------------------------------------------------------------------------------- /eval_deep_globe.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0 2 | python train_deep_globe.py \ 3 | --n_class 7 \ 4 | --data_path "/ssd1/chenwy/deep_globe/data/" \ 5 | --model_path "/home/chenwy/deep_globe/saved_models/" \ 6 | --log_path "/home/chenwy/deep_globe/runs/" \ 7 | --task_name "eval" \ 8 | --mode 3 \ 9 | --batch_size 6 \ 10 | --sub_batch_size 6 \ 11 | --size_g 508 \ 12 | --size_p 508 \ 13 | --path_g "fpn_deepglobe_global.pth" \ 14 | --path_g2l "fpn_deepglobe_global2local.pth" \ 15 | --path_l2g "fpn_deepglobe_local2global.pth" \ 16 | --evaluation -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Wuyang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vim swp files 2 | *.swp 3 | # caffe/pytorch model files 4 | *.pth 5 | 6 | # Mkdocs 7 | # /docs/ 8 | /mkdocs/docs/temp 9 | 10 | .DS_Store 11 | .idea 12 | .vscode 13 | .pytest_cache 14 | /experiments 15 | 16 | # resource temp folder 17 | tests/resources/temp/* 18 | !tests/resources/temp/.gitkeep 19 | 20 | # Byte-compiled / optimized / DLL files 21 | __pycache__/ 22 | *.py[cod] 23 | *$py.class 24 | 25 | # C extensions 26 | *.so 27 | 28 | # Distribution / packaging 29 | .Python 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | .static_storage/ 76 | .media/ 77 | local_settings.py 78 | local_settings.py 79 | db.sqlite3 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | target/ 93 | 94 | # Jupyter Notebook 95 | .ipynb_checkpoints 96 | 97 | # pyenv 98 | .python-version 99 | 100 | # celery beat schedule file 101 | celerybeat-schedule 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | 6 | class ConfusionMatrix(object): 7 | 8 | def __init__(self, n_classes): 9 | self.n_classes = n_classes 10 | # axis = 0: target 11 | # axis = 1: prediction 12 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 13 | # self.iou = [] 14 | # self.iou_threshold = [] 15 | 16 | def _fast_hist(self, label_true, label_pred, n_class): 17 | mask = (label_true >= 0) & (label_true < n_class) 18 | hist = np.bincount(n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class**2).reshape(n_class, n_class) 19 | return hist 20 | 21 | def update(self, label_trues, label_preds): 22 | for lt, lp in zip(label_trues, label_preds): 23 | tmp = self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 24 | 25 | # iu = np.diag(tmp) / (tmp.sum(axis=1) + tmp.sum(axis=0) - np.diag(tmp)) 26 | # self.iou.append(iu[1]) 27 | # if iu[1] >= 0.65: self.iou_threshold.append(iu[1]) 28 | # else: self.iou_threshold.append(0) 29 | 30 | 
self.confusion_matrix += tmp 31 | 32 | def get_scores(self): 33 | """Returns accuracy score evaluation result. 34 | - overall accuracy 35 | - mean accuracy 36 | - mean IU 37 | - fwavacc 38 | """ 39 | hist = self.confusion_matrix 40 | # accuracy is recall/sensitivity for each class, predicted TP / all real positives 41 | # axis in sum: perform summation along 42 | acc = np.nan_to_num(np.diag(hist) / hist.sum(axis=1)) 43 | acc_mean = np.mean(np.nan_to_num(acc)) 44 | 45 | intersect = np.diag(hist) 46 | union = hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) 47 | iou = intersect / union 48 | mean_iou = np.mean(np.nan_to_num(iou)) 49 | 50 | freq = hist.sum(axis=1) / hist.sum() # freq of each target 51 | # fwavacc = (freq[freq > 0] * iou[freq > 0]).sum() 52 | freq_iou = (freq * iou).sum() 53 | 54 | return {'accuracy': acc, 55 | 'accuracy_mean': acc_mean, 56 | 'freqw_iou': freq_iou, 57 | 'iou': iou, 58 | 'iou_mean': mean_iou, 59 | # 'IoU_threshold': np.mean(np.nan_to_num(self.iou_threshold)), 60 | } 61 | 62 | def reset(self): 63 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 64 | # self.iou = [] 65 | # self.iou_threshold = [] -------------------------------------------------------------------------------- /test.txt: -------------------------------------------------------------------------------- 1 | 10452_sat.jpg 2 | 114473_sat.jpg 3 | 120245_sat.jpg 4 | 127660_sat.jpg 5 | 137499_sat.jpg 6 | 143364_sat.jpg 7 | 143794_sat.jpg 8 | 147545_sat.jpg 9 | 148260_sat.jpg 10 | 148381_sat.jpg 11 | 161109_sat.jpg 12 | 170535_sat.jpg 13 | 181447_sat.jpg 14 | 185522_sat.jpg 15 | 186739_sat.jpg 16 | 195769_sat.jpg 17 | 209787_sat.jpg 18 | 211316_sat.jpg 19 | 219555_sat.jpg 20 | 225393_sat.jpg 21 | 225945_sat.jpg 22 | 226788_sat.jpg 23 | 242583_sat.jpg 24 | 245846_sat.jpg 25 | 255876_sat.jpg 26 | 271245_sat.jpg 27 | 271941_sat.jpg 28 | 273002_sat.jpg 29 | 277049_sat.jpg 30 | 277900_sat.jpg 31 | 28689_sat.jpg 32 | 28935_sat.jpg 33 | 294978_sat.jpg 34 | 307626_sat.jpg 35 | 309818_sat.jpg 36 | 321711_sat.jpg 37 | 326173_sat.jpg 38 | 326238_sat.jpg 39 | 330838_sat.jpg 40 | 332354_sat.jpg 41 | 338111_sat.jpg 42 | 340798_sat.jpg 43 | 343215_sat.jpg 44 | 349442_sat.jpg 45 | 351228_sat.jpg 46 | 387018_sat.jpg 47 | 393043_sat.jpg 48 | 396979_sat.jpg 49 | 397137_sat.jpg 50 | 402209_sat.jpg 51 | 407467_sat.jpg 52 | 412210_sat.jpg 53 | 420078_sat.jpg 54 | 427037_sat.jpg 55 | 428841_sat.jpg 56 | 432089_sat.jpg 57 | 437963_sat.jpg 58 | 449319_sat.jpg 59 | 454655_sat.jpg 60 | 457070_sat.jpg 61 | 457265_sat.jpg 62 | 471187_sat.jpg 63 | 498049_sat.jpg 64 | 501284_sat.jpg 65 | 503968_sat.jpg 66 | 504704_sat.jpg 67 | 505217_sat.jpg 68 | 508676_sat.jpg 69 | 509290_sat.jpg 70 | 513585_sat.jpg 71 | 513968_sat.jpg 72 | 525105_sat.jpg 73 | 533948_sat.jpg 74 | 533952_sat.jpg 75 | 543806_sat.jpg 76 | 547080_sat.jpg 77 | 556452_sat.jpg 78 | 557439_sat.jpg 79 | 560353_sat.jpg 80 | 572237_sat.jpg 81 | 574789_sat.jpg 82 | 576417_sat.jpg 83 | 584663_sat.jpg 84 | 589940_sat.jpg 85 | 591815_sat.jpg 86 | 599743_sat.jpg 87 | 603617_sat.jpg 88 | 606_sat.jpg 89 | 615420_sat.jpg 90 | 620018_sat.jpg 91 | 624916_sat.jpg 92 | 627583_sat.jpg 93 | 635841_sat.jpg 94 | 639004_sat.jpg 95 | 649042_sat.jpg 96 | 652183_sat.jpg 97 | 659953_sat.jpg 98 | 660933_sat.jpg 99 | 661864_sat.jpg 100 | 671164_sat.jpg 101 | 68078_sat.jpg 102 | 684377_sat.jpg 103 | 691384_sat.jpg 104 | 708588_sat.jpg 105 | 71125_sat.jpg 106 | 713813_sat.jpg 107 | 732669_sat.jpg 108 | 751939_sat.jpg 109 | 755453_sat.jpg 110 | 757745_sat.jpg 111 | 771393_sat.jpg 112 
| 772452_sat.jpg 113 | 777185_sat.jpg 114 | 7791_sat.jpg 115 | 78298_sat.jpg 116 | 78430_sat.jpg 117 | 7892_sat.jpg 118 | 79049_sat.jpg 119 | 799523_sat.jpg 120 | 810749_sat.jpg 121 | 819442_sat.jpg 122 | 828684_sat.jpg 123 | 829962_sat.jpg 124 | 835147_sat.jpg 125 | 842556_sat.jpg 126 | 850510_sat.jpg 127 | 857201_sat.jpg 128 | 858771_sat.jpg 129 | 875327_sat.jpg 130 | 882451_sat.jpg 131 | 898741_sat.jpg 132 | 925382_sat.jpg 133 | 937922_sat.jpg 134 | 950926_sat.jpg 135 | 956410_sat.jpg 136 | 956928_sat.jpg 137 | 965276_sat.jpg 138 | 982744_sat.jpg 139 | 987381_sat.jpg 140 | 992507_sat.jpg 141 | 994520_sat.jpg 142 | 998002_sat.jpg 143 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | class LR_Scheduler(object): 14 | """Learning Rate Scheduler 15 | 16 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 17 | 18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 19 | 20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 21 | 22 | Args: 23 | args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 24 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 25 | :attr:`args.lr_step` 26 | 27 | iters_per_epoch: number of iterations per epoch 28 | """ 29 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 30 | lr_step=0, warmup_epochs=0): 31 | self.mode = mode 32 | print('Using {} LR Scheduler!'.format(self.mode)) 33 | self.lr = base_lr 34 | if mode == 'step': 35 | assert lr_step 36 | self.lr_step = lr_step 37 | self.iters_per_epoch = iters_per_epoch 38 | self.N = num_epochs * iters_per_epoch 39 | self.epoch = -1 40 | self.warmup_iters = warmup_epochs * iters_per_epoch 41 | 42 | def __call__(self, optimizer, i, epoch, best_pred): 43 | T = epoch * self.iters_per_epoch + i 44 | if self.mode == 'cos': 45 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 46 | elif self.mode == 'poly': 47 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 48 | elif self.mode == 'step': 49 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 50 | else: 51 | raise NotImplemented 52 | # warm up lr schedule 53 | if self.warmup_iters > 0 and T < self.warmup_iters: 54 | lr = lr * 1.0 * T / self.warmup_iters 55 | if epoch > self.epoch: 56 | print('\n=>Epoches %i, learning rate = %.7f, \ 57 | previous best = %.4f' % (epoch, lr, best_pred)) 58 | self.epoch = epoch 59 | assert lr >= 0 60 | self._adjust_learning_rate(optimizer, lr) 61 | 62 | def _adjust_learning_rate(self, optimizer, lr): 63 | if len(optimizer.param_groups) == 1: 64 | optimizer.param_groups[0]['lr'] = lr 65 | else: 66 | # enlarge the lr at the head 67 | for i in range(len(optimizer.param_groups)): 68 | if optimizer.param_groups[i]['lr'] > 0: optimizer.param_groups[i]['lr'] = lr 69 | # optimizer.param_groups[0]['lr'] = lr 70 | # for i in range(1, len(optimizer.param_groups)): 71 | # optimizer.param_groups[i]['lr'] = lr * 10 72 | 
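For reference, a minimal usage sketch of the `LR_Scheduler` above. The training loop, model, and iteration counts here are placeholders, not the actual loop in `train_deep_globe.py`:

```python
# Minimal usage sketch for LR_Scheduler (hypothetical training loop; the real
# one lives in train_deep_globe.py and is not shown here). The scheduler is
# called once per iteration and writes the decayed lr into the optimizer.
import torch
from utils.lr_scheduler import LR_Scheduler

model = torch.nn.Linear(8, 7)                             # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

num_epochs, iters_per_epoch = 3, 10                       # illustrative values (option.py defaults to 120 epochs)
scheduler = LR_Scheduler('poly', base_lr=5e-5,
                         num_epochs=num_epochs,
                         iters_per_epoch=iters_per_epoch)

best_pred = 0.0                                           # best validation score so far
for epoch in range(num_epochs):
    for i in range(iters_per_epoch):
        scheduler(optimizer, i, epoch, best_pred)         # sets param_groups[*]['lr'] in place
        optimizer.zero_grad()
        loss = model(torch.randn(2, 8)).sum()             # placeholder forward/loss
        loss.backward()
        optimizer.step()
```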
-------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: CASIA IVA 3 | # Email: jliu@nlpr.ia.ac.cn 4 | # Copyright (c) 2018 5 | ########################################################################### 6 | 7 | import os 8 | import argparse 9 | import torch 10 | 11 | # path_g = os.path.join(model_path, "cityscapes_global.800_4.5.2019.lr5e5.pth") 12 | # # path_g = os.path.join(model_path, "fpn_global.804_nonorm_3.17.2019.lr2e5" + ".pth") 13 | # path_g2l = os.path.join(model_path, "fpn_global2local.508_deep.cat.1x_fmreg_ensemble.p3.0.15l2_3.19.2019.lr2e5.pth") 14 | # path_l2g = os.path.join(model_path, "fpn_local2global.508_deep.cat.1x_fmreg_ensemble.p3_3.19.2019.lr2e5.pth") 15 | class Options(): 16 | def __init__(self): 17 | parser = argparse.ArgumentParser(description='PyTorch Segmentation') 18 | # model and dataset 19 | parser.add_argument('--n_class', type=int, default=7, help='segmentation classes') 20 | parser.add_argument('--data_path', type=str, help='path to dataset where images store') 21 | parser.add_argument('--model_path', type=str, help='path to store trained model files, no need to include task specific name') 22 | parser.add_argument('--log_path', type=str, help='path to store tensorboard log files, no need to include task specific name') 23 | parser.add_argument('--task_name', type=str, help='task name for naming saved model files and log files') 24 | parser.add_argument('--mode', type=int, default=1, choices=[1, 2, 3], help='mode for training procedure. 1: train global branch only. 2: train local branch with fixed global branch. 
3: train global branch with fixed local branch') 25 | parser.add_argument('--evaluation', action='store_true', default=False, help='evaluation only') 26 | parser.add_argument('--batch_size', type=int, default=6, help='batch size for origin global image (without downsampling)') 27 | parser.add_argument('--sub_batch_size', type=int, default=6, help='batch size for using local image patches') 28 | parser.add_argument('--size_g', type=int, default=508, help='size (in pixel) for downsampled global image') 29 | parser.add_argument('--size_p', type=int, default=508, help='size (in pixel) for cropped local image') 30 | parser.add_argument('--path_g', type=str, default="", help='name for global model path') 31 | parser.add_argument('--path_g2l', type=str, default="", help='name for local from global model path') 32 | parser.add_argument('--path_l2g', type=str, default="", help='name for global from local model path') 33 | parser.add_argument('--lamb_fmreg', type=float, default=0.15, help='loss weight feature map regularization') 34 | 35 | # the parser 36 | self.parser = parser 37 | 38 | def parse(self): 39 | args = self.parser.parse_args() 40 | # default settings for epochs and lr 41 | if args.mode == 1 or args.mode == 3: 42 | args.num_epochs = 120 43 | args.lr = 5e-5 44 | else: 45 | args.num_epochs = 50 46 | args.lr = 2e-5 47 | return args 48 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class CrossEntropyLoss2d(nn.Module): 7 | def __init__(self, weight=None, size_average=True, ignore_index=-100): 8 | super(CrossEntropyLoss2d, self).__init__() 9 | self.nll_loss = nn.NLLLoss(weight, size_average, ignore_index) 10 | 11 | def forward(self, inputs, targets): 12 | return self.nll_loss(F.log_softmax(inputs, dim=1), targets) 13 | 14 | 15 | def one_hot(index, classes): 16 | # index is not flattened (pypass ignore) ############ 17 | # size = index.size()[:1] + (classes,) + index.size()[1:] 18 | # view = index.size()[:1] + (1,) + index.size()[1:] 19 | ##################################################### 20 | # index is flatten (during ignore) ################## 21 | size = index.size()[:1] + (classes,) 22 | view = index.size()[:1] + (1,) 23 | ##################################################### 24 | 25 | # mask = torch.Tensor(size).fill_(0).to(device) 26 | mask = torch.Tensor(size).fill_(0).cuda() 27 | index = index.view(view) 28 | ones = 1. 29 | 30 | return mask.scatter_(1, index, ones) 31 | 32 | 33 | class FocalLoss(nn.Module): 34 | 35 | def __init__(self, gamma=0, eps=1e-7, size_average=True, one_hot=True, ignore=None): 36 | super(FocalLoss, self).__init__() 37 | self.gamma = gamma 38 | self.eps = eps 39 | self.size_average = size_average 40 | self.one_hot = one_hot 41 | self.ignore = ignore 42 | 43 | def forward(self, input, target): 44 | ''' 45 | only support ignore at 0 46 | ''' 47 | B, C, H, W = input.size() 48 | input = input.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 49 | target = target.view(-1) 50 | if self.ignore is not None: 51 | valid = (target != self.ignore) 52 | input = input[valid] 53 | target = target[valid] 54 | 55 | if self.one_hot: target = one_hot(target, input.size(1)) 56 | probs = F.softmax(input, dim=1) 57 | probs = (probs * target).sum(1) 58 | probs = probs.clamp(self.eps, 1. 
- self.eps) 59 | 60 | log_p = probs.log() 61 | # print('probs size= {}'.format(probs.size())) 62 | # print(probs) 63 | 64 | batch_loss = -(torch.pow((1 - probs), self.gamma)) * log_p 65 | # print('-----bacth_loss------') 66 | # print(batch_loss) 67 | 68 | if self.size_average: 69 | loss = batch_loss.mean() 70 | else: 71 | loss = batch_loss.sum() 72 | return loss 73 | 74 | 75 | class SoftCrossEntropyLoss2d(nn.Module): 76 | def __init__(self): 77 | super(SoftCrossEntropyLoss2d, self).__init__() 78 | 79 | def forward(self, inputs, targets): 80 | loss = 0 81 | inputs = -F.log_softmax(inputs, dim=1) 82 | for index in range(inputs.size()[0]): 83 | loss += F.conv2d(inputs[range(index, index+1)], targets[range(index, index+1)])/(targets.size()[2] * 84 | targets.size()[3]) 85 | return loss 86 | -------------------------------------------------------------------------------- /models/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | __all__ = ['get_model_file', 'purge'] 4 | import os 5 | import zipfile 6 | 7 | from .utils import download, check_sha1 8 | 9 | _model_sha1 = {name: checksum for checksum, name in [ 10 | ('ebb6acbbd1d1c90b7f446ae59d30bf70c74febc1', 'resnet50'), 11 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'), 12 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'), 13 | ('2e22611a7f3992ebdee6726af169991bc26d7363', 'deepten_minc'), 14 | ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'), 15 | ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'), 16 | ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'), 17 | ('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'), 18 | ('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'), 19 | ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'), 20 | ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'), 21 | ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'), 22 | ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'), 23 | ]} 24 | 25 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' 26 | _url_format = '{repo_url}encoding/models/{file_name}.zip' 27 | 28 | def short_hash(name): 29 | if name not in _model_sha1: 30 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 31 | return _model_sha1[name][:8] 32 | 33 | def get_model_file(name, root=os.path.join('~', '.encoding', 'models')): 34 | r"""Return location for the pretrained on local file system. 35 | 36 | This function will download from online model zoo when model cannot be found or has mismatch. 37 | The root directory will be created if it doesn't exist. 38 | 39 | Parameters 40 | ---------- 41 | name : str 42 | Name of the model. 43 | root : str, default '~/.encoding/models' 44 | Location for keeping the model parameters. 45 | 46 | Returns 47 | ------- 48 | file_path 49 | Path to the requested pretrained model file. 50 | """ 51 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) 52 | root = os.path.expanduser(root) 53 | file_path = os.path.join(root, file_name+'.pth') 54 | sha1_hash = _model_sha1[name] 55 | if os.path.exists(file_path): 56 | if check_sha1(file_path, sha1_hash): 57 | return file_path 58 | else: 59 | print('Mismatch in the content of model file {} detected.' 
+ 60 | ' Downloading again.'.format(file_path)) 61 | else: 62 | print('Model file {} is not found. Downloading.'.format(file_path)) 63 | 64 | if not os.path.exists(root): 65 | os.makedirs(root) 66 | 67 | zip_file_path = os.path.join(root, file_name+'.zip') 68 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 69 | if repo_url[-1] != '/': 70 | repo_url = repo_url + '/' 71 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 72 | path=zip_file_path, 73 | overwrite=True) 74 | with zipfile.ZipFile(zip_file_path) as zf: 75 | zf.extractall(root) 76 | os.remove(zip_file_path) 77 | 78 | if check_sha1(file_path, sha1_hash): 79 | return file_path 80 | else: 81 | raise ValueError('Downloaded file has different hash. Please try again.') 82 | 83 | def purge(root=os.path.join('~', '.encoding', 'models')): 84 | r"""Purge all pretrained model files in local file store. 85 | 86 | Parameters 87 | ---------- 88 | root : str, default '~/.encoding/models' 89 | Location for keeping the model parameters. 90 | """ 91 | root = os.path.expanduser(root) 92 | files = os.listdir(root) 93 | for f in files: 94 | if f.endswith(".pth"): 95 | os.remove(os.path.join(root, f)) 96 | 97 | def pretrained_model_list(): 98 | return list(_model_sha1.keys()) 99 | -------------------------------------------------------------------------------- /crossvali.txt: -------------------------------------------------------------------------------- 1 | 102122_sat.jpg 2 | 114577_sat.jpg 3 | 115444_sat.jpg 4 | 119012_sat.jpg 5 | 123172_sat.jpg 6 | 124529_sat.jpg 7 | 125510_sat.jpg 8 | 126796_sat.jpg 9 | 127976_sat.jpg 10 | 129297_sat.jpg 11 | 129298_sat.jpg 12 | 133209_sat.jpg 13 | 136252_sat.jpg 14 | 139581_sat.jpg 15 | 143353_sat.jpg 16 | 147716_sat.jpg 17 | 154626_sat.jpg 18 | 155165_sat.jpg 19 | 162310_sat.jpg 20 | 16453_sat.jpg 21 | 166293_sat.jpg 22 | 166805_sat.jpg 23 | 168514_sat.jpg 24 | 176225_sat.jpg 25 | 180902_sat.jpg 26 | 192918_sat.jpg 27 | 194156_sat.jpg 28 | 19627_sat.jpg 29 | 200561_sat.jpg 30 | 200589_sat.jpg 31 | 210436_sat.jpg 32 | 211739_sat.jpg 33 | 219670_sat.jpg 34 | 229383_sat.jpg 35 | 233615_sat.jpg 36 | 234269_sat.jpg 37 | 246378_sat.jpg 38 | 247179_sat.jpg 39 | 255889_sat.jpg 40 | 262885_sat.jpg 41 | 264436_sat.jpg 42 | 268881_sat.jpg 43 | 273274_sat.jpg 44 | 2774_sat.jpg 45 | 280861_sat.jpg 46 | 283326_sat.jpg 47 | 286339_sat.jpg 48 | 300745_sat.jpg 49 | 312676_sat.jpg 50 | 315848_sat.jpg 51 | 323581_sat.jpg 52 | 324170_sat.jpg 53 | 329017_sat.jpg 54 | 331421_sat.jpg 55 | 334677_sat.jpg 56 | 334811_sat.jpg 57 | 338661_sat.jpg 58 | 34567_sat.jpg 59 | 350033_sat.jpg 60 | 350328_sat.jpg 61 | 351271_sat.jpg 62 | 354033_sat.jpg 63 | 358314_sat.jpg 64 | 358464_sat.jpg 65 | 362191_sat.jpg 66 | 373103_sat.jpg 67 | 375563_sat.jpg 68 | 394500_sat.jpg 69 | 406425_sat.jpg 70 | 416794_sat.jpg 71 | 418261_sat.jpg 72 | 419820_sat.jpg 73 | 424590_sat.jpg 74 | 427774_sat.jpg 75 | 428597_sat.jpg 76 | 430587_sat.jpg 77 | 434210_sat.jpg 78 | 43814_sat.jpg 79 | 438721_sat.jpg 80 | 44070_sat.jpg 81 | 442338_sat.jpg 82 | 443271_sat.jpg 83 | 455374_sat.jpg 84 | 461001_sat.jpg 85 | 461755_sat.jpg 86 | 462612_sat.jpg 87 | 467855_sat.jpg 88 | 471930_sat.jpg 89 | 472774_sat.jpg 90 | 479682_sat.jpg 91 | 491491_sat.jpg 92 | 495406_sat.jpg 93 | 499325_sat.jpg 94 | 499600_sat.jpg 95 | 501804_sat.jpg 96 | 512669_sat.jpg 97 | 514385_sat.jpg 98 | 514414_sat.jpg 99 | 51911_sat.jpg 100 | 536496_sat.jpg 101 | 537221_sat.jpg 102 | 538243_sat.jpg 103 | 538922_sat.jpg 104 | 544078_sat.jpg 105 | 544537_sat.jpg 106 | 
550312_sat.jpg 107 | 552001_sat.jpg 108 | 557175_sat.jpg 109 | 559477_sat.jpg 110 | 563092_sat.jpg 111 | 565914_sat.jpg 112 | 570992_sat.jpg 113 | 571520_sat.jpg 114 | 577164_sat.jpg 115 | 584712_sat.jpg 116 | 584865_sat.jpg 117 | 586222_sat.jpg 118 | 586806_sat.jpg 119 | 600230_sat.jpg 120 | 605707_sat.jpg 121 | 614561_sat.jpg 122 | 619800_sat.jpg 123 | 62078_sat.jpg 124 | 621459_sat.jpg 125 | 626323_sat.jpg 126 | 628479_sat.jpg 127 | 638168_sat.jpg 128 | 638937_sat.jpg 129 | 641771_sat.jpg 130 | 646596_sat.jpg 131 | 650253_sat.jpg 132 | 651537_sat.jpg 133 | 652733_sat.jpg 134 | 654770_sat.jpg 135 | 660069_sat.jpg 136 | 669156_sat.jpg 137 | 673927_sat.jpg 138 | 679507_sat.jpg 139 | 686781_sat.jpg 140 | 688544_sat.jpg 141 | 692982_sat.jpg 142 | 702918_sat.jpg 143 | 703413_sat.jpg 144 | 705728_sat.jpg 145 | 706996_sat.jpg 146 | 707319_sat.jpg 147 | 708527_sat.jpg 148 | 725646_sat.jpg 149 | 726265_sat.jpg 150 | 728521_sat.jpg 151 | 730889_sat.jpg 152 | 733758_sat.jpg 153 | 741105_sat.jpg 154 | 748225_sat.jpg 155 | 749375_sat.jpg 156 | 762470_sat.jpg 157 | 762937_sat.jpg 158 | 767012_sat.jpg 159 | 772130_sat.jpg 160 | 775304_sat.jpg 161 | 77669_sat.jpg 162 | 784518_sat.jpg 163 | 794214_sat.jpg 164 | 81039_sat.jpg 165 | 818254_sat.jpg 166 | 820347_sat.jpg 167 | 831146_sat.jpg 168 | 834900_sat.jpg 169 | 838873_sat.jpg 170 | 839012_sat.jpg 171 | 839641_sat.jpg 172 | 841286_sat.jpg 173 | 841404_sat.jpg 174 | 861353_sat.jpg 175 | 864488_sat.jpg 176 | 867349_sat.jpg 177 | 867983_sat.jpg 178 | 875409_sat.jpg 179 | 876248_sat.jpg 180 | 891153_sat.jpg 181 | 893651_sat.jpg 182 | 897901_sat.jpg 183 | 900985_sat.jpg 184 | 904606_sat.jpg 185 | 908837_sat.jpg 186 | 912087_sat.jpg 187 | 912620_sat.jpg 188 | 918105_sat.jpg 189 | 919602_sat.jpg 190 | 925425_sat.jpg 191 | 930491_sat.jpg 192 | 934795_sat.jpg 193 | 935193_sat.jpg 194 | 935318_sat.jpg 195 | 941237_sat.jpg 196 | 942986_sat.jpg 197 | 949559_sat.jpg 198 | 958243_sat.jpg 199 | 958443_sat.jpg 200 | 961919_sat.jpg 201 | 96841_sat.jpg 202 | 970925_sat.jpg 203 | 97337_sat.jpg 204 | 978039_sat.jpg 205 | 981253_sat.jpg 206 | 986342_sat.jpg 207 | 997521_sat.jpg 208 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet50', 'resnet101'] 6 | 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | class Bottleneck(nn.Module): 18 | expansion = 4 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None): 21 | super(Bottleneck, self).__init__() 22 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 
33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv3(out) 45 | out = self.bn3(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class ResNet(nn.Module): 57 | 58 | def __init__(self, block, layers): 59 | self.inplanes = 64 60 | super(ResNet, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | self.layer1 = self._make_layer(block, 64, layers[0]) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 70 | 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 74 | elif isinstance(m, nn.BatchNorm2d): 75 | nn.init.constant_(m.weight, 1) 76 | nn.init.constant_(m.bias, 0) 77 | 78 | def _make_layer(self, block, planes, blocks, stride=1): 79 | downsample = None 80 | if stride != 1 or self.inplanes != planes * block.expansion: 81 | downsample = nn.Sequential( 82 | nn.Conv2d(self.inplanes, planes * block.expansion, 83 | kernel_size=1, stride=stride, bias=False), 84 | nn.BatchNorm2d(planes * block.expansion), 85 | ) 86 | 87 | layers = [] 88 | layers.append(block(self.inplanes, planes, stride, downsample)) 89 | self.inplanes = planes * block.expansion 90 | for i in range(1, blocks): 91 | layers.append(block(self.inplanes, planes)) 92 | 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | x = self.conv1(x) 97 | x = self.bn1(x) 98 | x = self.relu(x) 99 | x = self.maxpool(x) 100 | 101 | c2 = self.layer1(x) 102 | c3 = self.layer2(c2) 103 | c4 = self.layer3(c3) 104 | c5 = self.layer4(c4) 105 | 106 | return c2, c3, c4, c5 107 | 108 | 109 | def resnet50(pretrained=False, **kwargs): 110 | """Constructs a ResNet-50 model. 111 | Args: 112 | pretrained (bool): If True, returns a model pre-trained on ImageNet 113 | """ 114 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 115 | if pretrained: 116 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 117 | return model 118 | 119 | 120 | def resnet101(pretrained=False, **kwargs): 121 | """Constructs a ResNet-101 model. 
122 | Args: 123 | pretrained (bool): If True, returns a model pre-trained on ImageNet 124 | """ 125 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 126 | if pretrained: 127 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 128 | return model 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GLNet for Memory-Efficient Segmentation of Ultra-High Resolution Images 2 | 3 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/chenwydj/ultra_high_resolution_segmentation.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/chenwydj/ultra_high_resolution_segmentation/context:python) [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | 5 | Collaborative Global-Local Networks for Memory-Efficient Segmentation of Ultra-High Resolution Images 6 | 7 | Wuyang Chen*, Ziyu Jiang*, Zhangyang Wang, Kexin Cui, and Xiaoning Qian 8 | 9 | In CVPR 2019 (Oral). [[Youtube](https://www.youtube.com/watch?v=am1GiItQI88)] 10 | 11 | ## Overview 12 | 13 | Segmentation of ultra-high resolution images is increasingly demanded in a wide range of applications (e.g. urban planning), yet poses significant challenges for algorithm efficiency, in particular considering the (GPU) memory limits. 14 | 15 | We propose collaborative **Global-Local Networks (GLNet)** to effectively preserve both global and local information in a highly memory-efficient manner. 16 | 17 | * **Memory-efficient**: **training w. only one 1080Ti** and **inference w. less than 2GB GPU memory**, for ultra-high resolution images of up to 30M pixels. 18 | 19 | * **High-quality**: GLNet outperforms existing segmentation models on ultra-high resolution images. 20 | 21 |

22 | [Figure: Acc_vs_Mem]
23 | Inference memory vs. mIoU on the DeepGlobe dataset.
24 | GLNet (red dots) integrates both global and local information in a compact way, contributing to a well-balanced trade-off between accuracy and memory usage.
25 | 
26 | [Figure: Examples]
27 | Ultra-high resolution datasets: DeepGlobe, ISIC, Inria Aerial
28 | 

32 | 33 | ## Methods 34 | 35 |

36 | [Figure: GLNet architecture]
37 | GLNet: the global and local branches take downsampled and cropped images, respectively. Deep feature map sharing and feature map regularization enforce our global-local collaboration. The final segmentation is generated by aggregating high-level feature maps from the two branches.
38 | 
39 | [Figure: Deep feature map sharing]
40 | Deep feature map sharing: at each layer, feature maps with global context and feature maps with local fine structures are brought together bidirectionally, contributing to a complete patch-based deep global-local collaboration.
41 | 

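The full two-branch model is implemented in `models/fpn_global_local_fmreg_ensemble.py` (not reproduced in this listing). As a rough, illustrative sketch of the global-to-local direction of the feature map sharing described above — crop the patch's region out of the global feature map, upsample it to the local feature map's resolution, and concatenate along channels — one could write:

```python
# Illustrative only: a simplified global-to-local feature sharing step.
# Tensor names and the crop bookkeeping are assumptions, not the code from
# models/fpn_global_local_fmreg_ensemble.py.
import torch
import torch.nn.functional as F

def share_global_to_local(global_fm, local_fm, top, left, crop_h, crop_w):
    """global_fm: [B, C, Hg, Wg] feature map from the downsampled global image.
    local_fm:  [B, C, Hl, Wl] feature map from one cropped local patch.
    (top, left, crop_h, crop_w): the patch's location in global_fm coordinates.
    Returns the local feature map concatenated with its matching global context.
    """
    # 1. crop the region of the global feature map corresponding to the patch
    region = global_fm[:, :, top:top + crop_h, left:left + crop_w]
    # 2. upsample it to the local feature map's spatial size
    region = F.interpolate(region, size=local_fm.shape[2:],
                           mode='bilinear', align_corners=True)
    # 3. fuse by channel-wise concatenation (a 1x1 conv would typically follow)
    return torch.cat([local_fm, region], dim=1)


# toy shapes: a coarse global map and one local patch map
fused = share_global_to_local(torch.randn(1, 256, 64, 64),
                              torch.randn(1, 256, 32, 32),
                              top=16, left=16, crop_h=16, crop_w=16)
print(fused.shape)  # torch.Size([1, 512, 32, 32])
```

The local-to-global direction is the mirror image in spirit: patch-level feature maps are scaled back down and merged into their region of the global feature map.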
44 | 
45 | ## Training
46 | Currently, this code base works with Python >= 3.5.
47 | 
48 | Please install the dependencies: `pip install -r requirements.txt`
49 | 
50 | First, register and download the Deep Globe "Land Cover Classification" dataset here:
51 | https://competitions.codalab.org/competitions/18468
52 | 
53 | Then please complete the following steps in order:
54 | 1. `./train_deep_globe_global.sh`
55 | 2. `./train_deep_globe_global2local.sh`
56 | 3. `./train_deep_globe_local2global.sh`
57 | 
58 | The above jobs complete the following tasks:
59 | * create the folders "saved_models" and "runs" to store the model checkpoints and logging files (you can configure the bash scripts to use your own paths).
60 | * steps 1 and 2 prepare the trained models needed by steps 2 and 3, respectively. You can use your own names for the saved model checkpoints, but this requires updating the values of the flags `path_g` and `path_g2l`.
61 | 
62 | ## Evaluation
63 | 1. Please download the pre-trained models for the Deep Globe dataset and put them into the folder "saved_models":
64 | * [fpn_deepglobe_global.pth](https://drive.google.com/file/d/1xUJoNEzj5LeclH9tHXZ2VsEI9LpC77kQ/view?usp=sharing)
65 | * [fpn_deepglobe_global2local.pth](https://drive.google.com/file/d/1_lCzi2KIygcrRcvBJ31G3cBwAMibn_AS/view?usp=sharing)
66 | * [fpn_deepglobe_local2global.pth](https://drive.google.com/file/d/198EcAO7VN8Ujn4N4FBg3sRgb8R_UKhYv/view?usp=sharing)
67 | 2. Download (see the "Training" section above) and prepare the Deep Globe dataset according to train.txt and crossvali.txt: put the image and label files into the folders "train" and "crossvali".
68 | 3. Run the script `./eval_deep_globe.sh` (a short sketch of how the reported scores are computed with `utils/metrics.py` is included at the end of this README).
69 | 
70 | ## Citation
71 | If you use this code for your research, please cite our paper.
72 | ```
73 | @inproceedings{chen2019GLNET,
74 | title={Collaborative Global-Local Networks for Memory-Efficient Segmentation of Ultra-High Resolution Images},
75 | author={Chen, Wuyang and Jiang, Ziyu and Wang, Zhangyang and Cui, Kexin and Qian, Xiaoning},
76 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
77 | year={2019}
78 | }
79 | ```
80 | 
81 | ## Acknowledgement
82 | We thank Prof. Andrew Jiang and Junru Wu for helping with the experiments.
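As noted in the Evaluation section, the reported scores come from `utils/metrics.py`. A minimal sketch of computing them with `ConfusionMatrix` — the prediction and label arrays below are small random placeholders, not real DeepGlobe outputs (real tiles are 2448x2448):

```python
# Minimal sketch: accumulate a confusion matrix over predicted / ground-truth
# class maps and read out the aggregate scores. Arrays are random placeholders.
import numpy as np
from utils.metrics import ConfusionMatrix

n_class = 7
metric = ConfusionMatrix(n_class)

for _ in range(3):  # pretend we evaluated three images
    label = np.random.randint(0, n_class, size=(256, 256))   # ground-truth class map
    pred = np.random.randint(0, n_class, size=(256, 256))    # predicted class map
    metric.update([label], [pred])    # update expects iterables of label/pred maps

scores = metric.get_scores()
print(scores['iou_mean'], scores['accuracy_mean'], scores['freqw_iou'])
metric.reset()
```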
83 | 84 | 86 | -------------------------------------------------------------------------------- /dataset/deep_globe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import numpy as np 4 | from PIL import Image, ImageFile 5 | import random 6 | from torchvision.transforms import ToTensor 7 | from torchvision import transforms 8 | import cv2 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 15 | 16 | 17 | def find_label_map_name(img_filenames, labelExtension=".png"): 18 | img_filenames = img_filenames.replace('_sat.jpg', '_mask') 19 | return img_filenames + labelExtension 20 | 21 | 22 | def RGB_mapping_to_class(label): 23 | l, w = label.shape[0], label.shape[1] 24 | classmap = np.zeros(shape=(l, w)) 25 | indices = np.where(np.all(label == (0, 255, 255), axis=-1)) 26 | classmap[indices[0].tolist(), indices[1].tolist()] = 1 27 | indices = np.where(np.all(label == (255, 255, 0), axis=-1)) 28 | classmap[indices[0].tolist(), indices[1].tolist()] = 2 29 | indices = np.where(np.all(label == (255, 0, 255), axis=-1)) 30 | classmap[indices[0].tolist(), indices[1].tolist()] = 3 31 | indices = np.where(np.all(label == (0, 255, 0), axis=-1)) 32 | classmap[indices[0].tolist(), indices[1].tolist()] = 4 33 | indices = np.where(np.all(label == (0, 0, 255), axis=-1)) 34 | classmap[indices[0].tolist(), indices[1].tolist()] = 5 35 | indices = np.where(np.all(label == (255, 255, 255), axis=-1)) 36 | classmap[indices[0].tolist(), indices[1].tolist()] = 6 37 | indices = np.where(np.all(label == (0, 0, 0), axis=-1)) 38 | classmap[indices[0].tolist(), indices[1].tolist()] = 0 39 | # plt.imshow(colmap) 40 | # plt.show() 41 | return classmap 42 | 43 | 44 | def classToRGB(label): 45 | l, w = label.shape[0], label.shape[1] 46 | colmap = np.zeros(shape=(l, w, 3)).astype(np.float32) 47 | indices = np.where(label == 1) 48 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [0, 255, 255] 49 | indices = np.where(label == 2) 50 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [255, 255, 0] 51 | indices = np.where(label == 3) 52 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [255, 0, 255] 53 | indices = np.where(label == 4) 54 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [0, 255, 0] 55 | indices = np.where(label == 5) 56 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [0, 0, 255] 57 | indices = np.where(label == 6) 58 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [255, 255, 255] 59 | indices = np.where(label == 0) 60 | colmap[indices[0].tolist(), indices[1].tolist(), :] = [0, 0, 0] 61 | transform = ToTensor(); 62 | # plt.imshow(colmap) 63 | # plt.show() 64 | return transform(colmap) 65 | 66 | 67 | def class_to_target(inputs, numClass): 68 | batchSize, l, w = inputs.shape[0], inputs.shape[1], inputs.shape[2] 69 | target = np.zeros(shape=(batchSize, l, w, numClass), dtype=np.float32) 70 | for index in range(7): 71 | indices = np.where(inputs == index) 72 | temp = np.zeros(shape=7, dtype=np.float32) 73 | temp[index] = 1 74 | target[indices[0].tolist(), indices[1].tolist(), indices[2].tolist(), :] = temp 75 | return target.transpose(0, 3, 1, 2) 76 | 77 | 78 | def label_bluring(inputs): 79 | batchSize, numClass, height, width = inputs.shape 80 | outputs = np.ones((batchSize, numClass, height, width), dtype=np.float) 81 | for batchCnt in range(batchSize): 82 | for 
index in range(numClass): 83 | outputs[batchCnt, index, ...] = cv2.GaussianBlur(inputs[batchCnt, index, ...].astype(np.float), (7, 7), 0) 84 | return outputs 85 | 86 | 87 | class DeepGlobe(data.Dataset): 88 | """input and label image dataset""" 89 | 90 | def __init__(self, root, ids, label=False, transform=False): 91 | super(DeepGlobe, self).__init__() 92 | """ 93 | Args: 94 | 95 | fileDir(string): directory with all the input images. 96 | transform(callable, optional): Optional transform to be applied on a sample 97 | """ 98 | self.root = root 99 | self.label = label 100 | self.transform = transform 101 | self.ids = ids 102 | self.classdict = {1: "urban", 2: "agriculture", 3: "rangeland", 4: "forest", 5: "water", 6: "barren", 0: "unknown"} 103 | 104 | self.color_jitter = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.04) 105 | self.resizer = transforms.Resize((2448, 2448)) 106 | 107 | def __getitem__(self, index): 108 | sample = {} 109 | sample['id'] = self.ids[index][:-8] 110 | image = Image.open(os.path.join(self.root, "Sat/" + self.ids[index])) # w, h 111 | sample['image'] = image 112 | # sample['image'] = transforms.functional.adjust_contrast(image, 1.4) 113 | if self.label: 114 | # label = scipy.io.loadmat(join(self.root, 'Notification/' + self.ids[index].replace('_sat.jpg', '_mask.mat')))["label"] 115 | # label = Image.fromarray(label) 116 | label = Image.open(os.path.join(self.root, 'Label/' + self.ids[index].replace('_sat.jpg', '_mask.png'))) 117 | sample['label'] = label 118 | if self.transform and self.label: 119 | image, label = self._transform(image, label) 120 | sample['image'] = image 121 | sample['label'] = label 122 | # return {'image': image.astype(np.float32), 'label': label.astype(np.int64)} 123 | return sample 124 | 125 | def _transform(self, image, label): 126 | # if np.random.random() > 0.5: 127 | # image = self.color_jitter(image) 128 | 129 | # if np.random.random() > 0.5: 130 | # image = transforms.functional.vflip(image) 131 | # label = transforms.functional.vflip(label) 132 | 133 | if np.random.random() > 0.5: 134 | image = transforms.functional.hflip(image) 135 | label = transforms.functional.hflip(label) 136 | 137 | if np.random.random() > 0.5: 138 | degree = random.choice([90, 180, 270]) 139 | image = transforms.functional.rotate(image, degree) 140 | label = transforms.functional.rotate(label, degree) 141 | 142 | # if np.random.random() > 0.5: 143 | # degree = 60 * np.random.random() - 30 144 | # image = transforms.functional.rotate(image, degree) 145 | # label = transforms.functional.rotate(label, degree) 146 | 147 | # if np.random.random() > 0.5: 148 | # ratio = np.random.random() 149 | # h = int(2448 * (ratio + 2) / 3.) 150 | # w = int(2448 * (ratio + 2) / 3.) 
151 | # i = int(np.floor(np.random.random() * (2448 - h))) 152 | # j = int(np.floor(np.random.random() * (2448 - w))) 153 | # image = self.resizer(transforms.functional.crop(image, i, j, h, w)) 154 | # label = self.resizer(transforms.functional.crop(label, i, j, h, w)) 155 | 156 | return image, label 157 | 158 | 159 | def __len__(self): 160 | return len(self.ids) -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | # https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py 2 | """ 3 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 4 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 5 | """ 6 | 7 | from __future__ import print_function, division 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | import numpy as np 13 | try: 14 | from itertools import ifilterfalse 15 | except ImportError: # py3k 16 | from itertools import filterfalse 17 | 18 | 19 | def lovasz_grad(gt_sorted): 20 | """ 21 | Computes gradient of the Lovasz extension w.r.t sorted errors 22 | See Alg. 1 in paper 23 | """ 24 | p = len(gt_sorted) 25 | gts = gt_sorted.sum() 26 | intersection = gts - gt_sorted.float().cumsum(0) 27 | union = gts + (1 - gt_sorted).float().cumsum(0) 28 | jaccard = 1. - intersection / union 29 | if p > 1: # cover 1-pixel case 30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 31 | return jaccard 32 | 33 | 34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 35 | """ 36 | IoU for foreground class 37 | binary: 1 foreground, 0 background 38 | """ 39 | if not per_image: 40 | preds, labels = (preds,), (labels,) 41 | ious = [] 42 | for pred, label in zip(preds, labels): 43 | intersection = ((label == 1) & (pred == 1)).sum() 44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 45 | if not union: 46 | iou = EMPTY 47 | else: 48 | iou = float(intersection) / union 49 | ious.append(iou) 50 | iou = mean(ious) # mean accross images if per_image 51 | return 100 * iou 52 | 53 | 54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 55 | """ 56 | Array of IoU for each (non ignored) class 57 | """ 58 | if not per_image: 59 | preds, labels = (preds,), (labels,) 60 | ious = [] 61 | for pred, label in zip(preds, labels): 62 | iou = [] 63 | for i in range(C): 64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 65 | intersection = ((label == i) & (pred == i)).sum() 66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 67 | if not union: 68 | iou.append(EMPTY) 69 | else: 70 | iou.append(float(intersection) / union) 71 | ious.append(iou) 72 | ious = map(mean, zip(*ious)) # mean accross images if per_image 73 | return 100 * np.array(ious) 74 | 75 | 76 | # --------------------------- BINARY LOSSES --------------------------- 77 | 78 | 79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 80 | """ 81 | Binary Lovasz hinge loss 82 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 83 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 84 | per_image: compute the loss per image instead of per batch 85 | ignore: void class id 86 | """ 87 | if per_image: 88 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 89 | for log, lab in zip(logits, labels)) 90 | else: 91 | loss = 
lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 92 | return loss 93 | 94 | 95 | def lovasz_hinge_flat(logits, labels): 96 | """ 97 | Binary Lovasz hinge loss 98 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 99 | labels: [P] Tensor, binary ground truth labels (0 or 1) 100 | ignore: label to ignore 101 | """ 102 | if len(labels) == 0: 103 | # only void pixels, the gradients should be 0 104 | return logits.sum() * 0. 105 | signs = 2. * labels.float() - 1. 106 | errors = (1. - logits * Variable(signs)) 107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 108 | perm = perm.data 109 | gt_sorted = labels[perm] 110 | grad = lovasz_grad(gt_sorted) 111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 112 | return loss 113 | 114 | 115 | def flatten_binary_scores(scores, labels, ignore=None): 116 | """ 117 | Flattens predictions in the batch (binary case) 118 | Remove labels equal to 'ignore' 119 | """ 120 | scores = scores.view(-1) 121 | labels = labels.view(-1) 122 | if ignore is None: 123 | return scores, labels 124 | valid = (labels != ignore) 125 | vscores = scores[valid] 126 | vlabels = labels[valid] 127 | return vscores, vlabels 128 | 129 | 130 | class StableBCELoss(torch.nn.modules.Module): 131 | def __init__(self): 132 | super(StableBCELoss, self).__init__() 133 | def forward(self, input, target): 134 | neg_abs = - input.abs() 135 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 136 | return loss.mean() 137 | 138 | 139 | def binary_xloss(logits, labels, ignore=None): 140 | """ 141 | Binary Cross entropy loss 142 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 143 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 144 | ignore: void class id 145 | """ 146 | logits, labels = flatten_binary_scores(logits, labels, ignore) 147 | loss = StableBCELoss()(logits, Variable(labels.float())) 148 | return loss 149 | 150 | 151 | # --------------------------- MULTICLASS LOSSES --------------------------- 152 | 153 | 154 | def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=None): 155 | """ 156 | Multi-class Lovasz-Softmax loss 157 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | only_present: average only on classes present in ground truth 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, only_present=False): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | only_present: average only on classes present in ground truth 177 | """ 178 | C = probas.size(1) 179 | losses = [] 180 | for c in range(C): 181 | fg = (labels == c).float() # foreground for class c 182 | if only_present and fg.sum() == 0: 183 | continue 184 | errors = (Variable(fg) - probas[:, c]).abs() 185 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 186 | 
perm = perm.data 187 | fg_sorted = fg[perm] 188 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 189 | return mean(losses) 190 | 191 | 192 | def flatten_probas(probas, labels, ignore=None): 193 | """ 194 | Flattens predictions in the batch 195 | """ 196 | B, C, H, W = probas.size() 197 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 198 | labels = labels.view(-1) 199 | if ignore is None: 200 | return probas, labels 201 | valid = (labels != ignore) 202 | vprobas = probas[valid.nonzero().squeeze()] 203 | vlabels = labels[valid] 204 | # vlabels = labels[valid] - 1 205 | return vprobas, vlabels 206 | 207 | def xloss(logits, labels, ignore=None): 208 | """ 209 | Cross entropy loss 210 | """ 211 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 212 | 213 | 214 | # --------------------------- HELPER FUNCTIONS --------------------------- 215 | 216 | def mean(l, ignore_nan=False, empty=0): 217 | """ 218 | nanmean compatible with generators. 219 | """ 220 | l = iter(l) 221 | if ignore_nan: 222 | l = ifilterfalse(np.isnan, l) 223 | try: 224 | n = 1 225 | acc = next(l) 226 | except StopIteration: 227 | if empty == 'raise': 228 | raise ValueError('Empty mean') 229 | return empty 230 | for n, v in enumerate(l, 2): 231 | acc += v 232 | if n == 1: 233 | return acc 234 | return acc / n 235 | -------------------------------------------------------------------------------- /train.txt: -------------------------------------------------------------------------------- 1 | 100694_sat.jpg 2 | 10233_sat.jpg 3 | 103665_sat.jpg 4 | 103730_sat.jpg 5 | 104113_sat.jpg 6 | 10901_sat.jpg 7 | 111335_sat.jpg 8 | 114433_sat.jpg 9 | 119079_sat.jpg 10 | 119_sat.jpg 11 | 120625_sat.jpg 12 | 122104_sat.jpg 13 | 122178_sat.jpg 14 | 125795_sat.jpg 15 | 131720_sat.jpg 16 | 133254_sat.jpg 17 | 13415_sat.jpg 18 | 134465_sat.jpg 19 | 137806_sat.jpg 20 | 139482_sat.jpg 21 | 140299_sat.jpg 22 | 141685_sat.jpg 23 | 142766_sat.jpg 24 | 149624_sat.jpg 25 | 152569_sat.jpg 26 | 154124_sat.jpg 27 | 15573_sat.jpg 28 | 156574_sat.jpg 29 | 156951_sat.jpg 30 | 157839_sat.jpg 31 | 158163_sat.jpg 32 | 159177_sat.jpg 33 | 159280_sat.jpg 34 | 159322_sat.jpg 35 | 160037_sat.jpg 36 | 161838_sat.jpg 37 | 164029_sat.jpg 38 | 172307_sat.jpg 39 | 172854_sat.jpg 40 | 174980_sat.jpg 41 | 176112_sat.jpg 42 | 176506_sat.jpg 43 | 182027_sat.jpg 44 | 182422_sat.jpg 45 | 185562_sat.jpg 46 | 192576_sat.jpg 47 | 192602_sat.jpg 48 | 20187_sat.jpg 49 | 202277_sat.jpg 50 | 204494_sat.jpg 51 | 204562_sat.jpg 52 | 207663_sat.jpg 53 | 207743_sat.jpg 54 | 208495_sat.jpg 55 | 208695_sat.jpg 56 | 21023_sat.jpg 57 | 210473_sat.jpg 58 | 210669_sat.jpg 59 | 215525_sat.jpg 60 | 217085_sat.jpg 61 | 21717_sat.jpg 62 | 218329_sat.jpg 63 | 221278_sat.jpg 64 | 232373_sat.jpg 65 | 2334_sat.jpg 66 | 235869_sat.jpg 67 | 238322_sat.jpg 68 | 239955_sat.jpg 69 | 244423_sat.jpg 70 | 24813_sat.jpg 71 | 252743_sat.jpg 72 | 253691_sat.jpg 73 | 254565_sat.jpg 74 | 255711_sat.jpg 75 | 256189_sat.jpg 76 | 257695_sat.jpg 77 | 26261_sat.jpg 78 | 263576_sat.jpg 79 | 266_sat.jpg 80 | 267065_sat.jpg 81 | 267163_sat.jpg 82 | 269601_sat.jpg 83 | 271609_sat.jpg 84 | 27460_sat.jpg 85 | 276761_sat.jpg 86 | 276912_sat.jpg 87 | 277644_sat.jpg 88 | 277994_sat.jpg 89 | 280703_sat.jpg 90 | 282120_sat.jpg 91 | 28559_sat.jpg 92 | 291214_sat.jpg 93 | 291781_sat.jpg 94 | 293776_sat.jpg 95 | 29419_sat.jpg 96 | 294697_sat.jpg 97 | 296279_sat.jpg 98 | 296368_sat.jpg 99 | 298396_sat.jpg 100 | 298817_sat.jpg 101 | 299287_sat.jpg 102 | 
300626_sat.jpg 103 | 300967_sat.jpg 104 | 303327_sat.jpg 105 | 306486_sat.jpg 106 | 308959_sat.jpg 107 | 310419_sat.jpg 108 | 311386_sat.jpg 109 | 315352_sat.jpg 110 | 316446_sat.jpg 111 | 318338_sat.jpg 112 | 321724_sat.jpg 113 | 322400_sat.jpg 114 | 325354_sat.jpg 115 | 331533_sat.jpg 116 | 331994_sat.jpg 117 | 33262_sat.jpg 118 | 333661_sat.jpg 119 | 335737_sat.jpg 120 | 33573_sat.jpg 121 | 337272_sat.jpg 122 | 338798_sat.jpg 123 | 340898_sat.jpg 124 | 343016_sat.jpg 125 | 34330_sat.jpg 126 | 343425_sat.jpg 127 | 34359_sat.jpg 128 | 345134_sat.jpg 129 | 345494_sat.jpg 130 | 347676_sat.jpg 131 | 347725_sat.jpg 132 | 3484_sat.jpg 133 | 351727_sat.jpg 134 | 352808_sat.jpg 135 | 358591_sat.jpg 136 | 361129_sat.jpg 137 | 36183_sat.jpg 138 | 362274_sat.jpg 139 | 365555_sat.jpg 140 | 373186_sat.jpg 141 | 37586_sat.jpg 142 | 376441_sat.jpg 143 | 37755_sat.jpg 144 | 382428_sat.jpg 145 | 383392_sat.jpg 146 | 383637_sat.jpg 147 | 384477_sat.jpg 148 | 387554_sat.jpg 149 | 388811_sat.jpg 150 | 392711_sat.jpg 151 | 397351_sat.jpg 152 | 397864_sat.jpg 153 | 400179_sat.jpg 154 | 40168_sat.jpg 155 | 402002_sat.jpg 156 | 40350_sat.jpg 157 | 403978_sat.jpg 158 | 405378_sat.jpg 159 | 405744_sat.jpg 160 | 411741_sat.jpg 161 | 413779_sat.jpg 162 | 416381_sat.jpg 163 | 416463_sat.jpg 164 | 417313_sat.jpg 165 | 41944_sat.jpg 166 | 420066_sat.jpg 167 | 423117_sat.jpg 168 | 428327_sat.jpg 169 | 434243_sat.jpg 170 | 435277_sat.jpg 171 | 439854_sat.jpg 172 | 442329_sat.jpg 173 | 444902_sat.jpg 174 | 45357_sat.jpg 175 | 45676_sat.jpg 176 | 457982_sat.jpg 177 | 458687_sat.jpg 178 | 458776_sat.jpg 179 | 463855_sat.jpg 180 | 467076_sat.jpg 181 | 468103_sat.jpg 182 | 470446_sat.jpg 183 | 470798_sat.jpg 184 | 476582_sat.jpg 185 | 476991_sat.jpg 186 | 482365_sat.jpg 187 | 483506_sat.jpg 188 | 485061_sat.jpg 189 | 491356_sat.jpg 190 | 491696_sat.jpg 191 | 492365_sat.jpg 192 | 495876_sat.jpg 193 | 496948_sat.jpg 194 | 499161_sat.jpg 195 | 499266_sat.jpg 196 | 499418_sat.jpg 197 | 499511_sat.jpg 198 | 501053_sat.jpg 199 | 507241_sat.jpg 200 | 508571_sat.jpg 201 | 511850_sat.jpg 202 | 515521_sat.jpg 203 | 516056_sat.jpg 204 | 516317_sat.jpg 205 | 518833_sat.jpg 206 | 520614_sat.jpg 207 | 524056_sat.jpg 208 | 524518_sat.jpg 209 | 528163_sat.jpg 210 | 530040_sat.jpg 211 | 534154_sat.jpg 212 | 53987_sat.jpg 213 | 541060_sat.jpg 214 | 541353_sat.jpg 215 | 544464_sat.jpg 216 | 547201_sat.jpg 217 | 547785_sat.jpg 218 | 548423_sat.jpg 219 | 548686_sat.jpg 220 | 549870_sat.jpg 221 | 549959_sat.jpg 222 | 552206_sat.jpg 223 | 552396_sat.jpg 224 | 55374_sat.jpg 225 | 556572_sat.jpg 226 | 557309_sat.jpg 227 | 561117_sat.jpg 228 | 568270_sat.jpg 229 | 56924_sat.jpg 230 | 570332_sat.jpg 231 | 575902_sat.jpg 232 | 584941_sat.jpg 233 | 585043_sat.jpg 234 | 586670_sat.jpg 235 | 587968_sat.jpg 236 | 588542_sat.jpg 237 | 58864_sat.jpg 238 | 58910_sat.jpg 239 | 596837_sat.jpg 240 | 599842_sat.jpg 241 | 599975_sat.jpg 242 | 601966_sat.jpg 243 | 602453_sat.jpg 244 | 604647_sat.jpg 245 | 604833_sat.jpg 246 | 605037_sat.jpg 247 | 605764_sat.jpg 248 | 606014_sat.jpg 249 | 606370_sat.jpg 250 | 607622_sat.jpg 251 | 608673_sat.jpg 252 | 609234_sat.jpg 253 | 611015_sat.jpg 254 | 612214_sat.jpg 255 | 61245_sat.jpg 256 | 613687_sat.jpg 257 | 616234_sat.jpg 258 | 616860_sat.jpg 259 | 617844_sat.jpg 260 | 618372_sat.jpg 261 | 621206_sat.jpg 262 | 621633_sat.jpg 263 | 622733_sat.jpg 264 | 623857_sat.jpg 265 | 625296_sat.jpg 266 | 626208_sat.jpg 267 | 627806_sat.jpg 268 | 629198_sat.jpg 269 | 632489_sat.jpg 270 | 634421_sat.jpg 271 | 634717_sat.jpg 272 | 
635157_sat.jpg 273 | 636849_sat.jpg 274 | 638158_sat.jpg 275 | 639149_sat.jpg 276 | 639314_sat.jpg 277 | 6399_sat.jpg 278 | 642909_sat.jpg 279 | 644103_sat.jpg 280 | 644150_sat.jpg 281 | 645001_sat.jpg 282 | 649260_sat.jpg 283 | 650751_sat.jpg 284 | 651312_sat.jpg 285 | 65170_sat.jpg 286 | 651774_sat.jpg 287 | 652883_sat.jpg 288 | 655313_sat.jpg 289 | 66344_sat.jpg 290 | 664140_sat.jpg 291 | 664396_sat.jpg 292 | 665914_sat.jpg 293 | 668465_sat.jpg 294 | 669010_sat.jpg 295 | 669779_sat.jpg 296 | 672041_sat.jpg 297 | 672823_sat.jpg 298 | 675424_sat.jpg 299 | 675849_sat.jpg 300 | 676758_sat.jpg 301 | 678520_sat.jpg 302 | 679036_sat.jpg 303 | 682046_sat.jpg 304 | 682688_sat.jpg 305 | 682949_sat.jpg 306 | 692004_sat.jpg 307 | 695475_sat.jpg 308 | 696257_sat.jpg 309 | 69628_sat.jpg 310 | 698065_sat.jpg 311 | 698628_sat.jpg 312 | 699650_sat.jpg 313 | 711893_sat.jpg 314 | 714414_sat.jpg 315 | 715633_sat.jpg 316 | 715846_sat.jpg 317 | 71619_sat.jpg 318 | 717225_sat.jpg 319 | 723067_sat.jpg 320 | 723719_sat.jpg 321 | 727832_sat.jpg 322 | 72807_sat.jpg 323 | 730821_sat.jpg 324 | 736869_sat.jpg 325 | 736933_sat.jpg 326 | 739122_sat.jpg 327 | 739760_sat.jpg 328 | 740937_sat.jpg 329 | 747824_sat.jpg 330 | 749523_sat.jpg 331 | 753408_sat.jpg 332 | 759668_sat.jpg 333 | 759855_sat.jpg 334 | 761189_sat.jpg 335 | 762359_sat.jpg 336 | 763075_sat.jpg 337 | 763892_sat.jpg 338 | 765792_sat.jpg 339 | 76759_sat.jpg 340 | 768475_sat.jpg 341 | 772144_sat.jpg 342 | 772567_sat.jpg 343 | 77388_sat.jpg 344 | 774779_sat.jpg 345 | 778804_sat.jpg 346 | 782103_sat.jpg 347 | 784140_sat.jpg 348 | 786226_sat.jpg 349 | 7906_sat.jpg 350 | 798411_sat.jpg 351 | 801361_sat.jpg 352 | 802645_sat.jpg 353 | 80318_sat.jpg 354 | 803958_sat.jpg 355 | 805150_sat.jpg 356 | 806805_sat.jpg 357 | 807146_sat.jpg 358 | 80808_sat.jpg 359 | 808980_sat.jpg 360 | 81011_sat.jpg 361 | 810368_sat.jpg 362 | 811075_sat.jpg 363 | 820543_sat.jpg 364 | 825592_sat.jpg 365 | 825816_sat.jpg 366 | 827126_sat.jpg 367 | 830444_sat.jpg 368 | 834433_sat.jpg 369 | 838669_sat.jpg 370 | 841621_sat.jpg 371 | 845069_sat.jpg 372 | 847604_sat.jpg 373 | 848649_sat.jpg 374 | 848728_sat.jpg 375 | 848780_sat.jpg 376 | 849797_sat.jpg 377 | 853702_sat.jpg 378 | 855_sat.jpg 379 | 860326_sat.jpg 380 | 866782_sat.jpg 381 | 867017_sat.jpg 382 | 868003_sat.jpg 383 | 86805_sat.jpg 384 | 870705_sat.jpg 385 | 873132_sat.jpg 386 | 875328_sat.jpg 387 | 877160_sat.jpg 388 | 878990_sat.jpg 389 | 880610_sat.jpg 390 | 88571_sat.jpg 391 | 888263_sat.jpg 392 | 888343_sat.jpg 393 | 889145_sat.jpg 394 | 889920_sat.jpg 395 | 890145_sat.jpg 396 | 893261_sat.jpg 397 | 893904_sat.jpg 398 | 895509_sat.jpg 399 | 899693_sat.jpg 400 | 901715_sat.jpg 401 | 902350_sat.jpg 402 | 903649_sat.jpg 403 | 906113_sat.jpg 404 | 910525_sat.jpg 405 | 911457_sat.jpg 406 | 914008_sat.jpg 407 | 916141_sat.jpg 408 | 916336_sat.jpg 409 | 916518_sat.jpg 410 | 917081_sat.jpg 411 | 918446_sat.jpg 412 | 919051_sat.jpg 413 | 923223_sat.jpg 414 | 923618_sat.jpg 415 | 924236_sat.jpg 416 | 926392_sat.jpg 417 | 927126_sat.jpg 418 | 927644_sat.jpg 419 | 930028_sat.jpg 420 | 939614_sat.jpg 421 | 940229_sat.jpg 422 | 942307_sat.jpg 423 | 942594_sat.jpg 424 | 943463_sat.jpg 425 | 943943_sat.jpg 426 | 946386_sat.jpg 427 | 946408_sat.jpg 428 | 946475_sat.jpg 429 | 947994_sat.jpg 430 | 949235_sat.jpg 431 | 951120_sat.jpg 432 | 952430_sat.jpg 433 | 954552_sat.jpg 434 | 95613_sat.jpg 435 | 95683_sat.jpg 436 | 95863_sat.jpg 437 | 961407_sat.jpg 438 | 965977_sat.jpg 439 | 967818_sat.jpg 440 | 96870_sat.jpg 441 | 969934_sat.jpg 442 | 
971880_sat.jpg 443 | 98150_sat.jpg 444 | 981852_sat.jpg 445 | 983603_sat.jpg 446 | 987079_sat.jpg 447 | 987427_sat.jpg 448 | 988517_sat.jpg 449 | 989499_sat.jpg 450 | 990573_sat.jpg 451 | 990617_sat.jpg 452 | 990619_sat.jpg 453 | 991758_sat.jpg 454 | 995492_sat.jpg 455 | -------------------------------------------------------------------------------- /train_deep_globe.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from torchvision import transforms 11 | from tqdm import tqdm 12 | from dataset.deep_globe import DeepGlobe, classToRGB, is_image_file 13 | from utils.loss import CrossEntropyLoss2d, SoftCrossEntropyLoss2d, FocalLoss 14 | from utils.lovasz_losses import lovasz_softmax 15 | from utils.lr_scheduler import LR_Scheduler 16 | from tensorboardX import SummaryWriter 17 | from helper import create_model_load_weights, get_optimizer, Trainer, Evaluator, collate, collate_test 18 | from option import Options 19 | 20 | args = Options().parse() 21 | n_class = args.n_class 22 | 23 | # torch.cuda.synchronize() 24 | # torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.deterministic = True 26 | 27 | data_path = args.data_path 28 | model_path = args.model_path 29 | if not os.path.isdir(model_path): os.mkdir(model_path) 30 | log_path = args.log_path 31 | if not os.path.isdir(log_path): os.mkdir(log_path) 32 | task_name = args.task_name 33 | 34 | print(task_name) 35 | ################################### 36 | 37 | mode = args.mode # 1: train global; 2: train local from global; 3: train global from local 38 | evaluation = args.evaluation 39 | test = evaluation and False 40 | print("mode:", mode, "evaluation:", evaluation, "test:", test) 41 | 42 | ################################### 43 | print("preparing datasets and dataloaders......") 44 | batch_size = args.batch_size 45 | ids_train = [image_name for image_name in os.listdir(os.path.join(data_path, "train", "Sat")) if is_image_file(image_name)] 46 | ids_val = [image_name for image_name in os.listdir(os.path.join(data_path, "crossvali", "Sat")) if is_image_file(image_name)] 47 | ids_test = [image_name for image_name in os.listdir(os.path.join(data_path, "offical_crossvali", "Sat")) if is_image_file(image_name)] 48 | 49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | dataset_train = DeepGlobe(os.path.join(data_path, "train"), ids_train, label=True, transform=True) 51 | dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=True, pin_memory=True) 52 | dataset_val = DeepGlobe(os.path.join(data_path, "crossvali"), ids_val, label=True) 53 | dataloader_val = torch.utils.data.DataLoader(dataset=dataset_val, batch_size=batch_size, num_workers=10, collate_fn=collate, shuffle=False, pin_memory=True) 54 | dataset_test = DeepGlobe(os.path.join(data_path, "offical_crossvali"), ids_test, label=False) 55 | dataloader_test = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=batch_size, num_workers=10, collate_fn=collate_test, shuffle=False, pin_memory=True) 56 | 57 | ##### sizes are (w, h) ############################## 58 | # make sure margin / 32 is over 1.5 AND size_g is divisible by 4 59 | size_g = (args.size_g, args.size_g) # resized global image 60 | size_p = (args.size_p, args.size_p) # 
cropped local patch size 61 | sub_batch_size = args.sub_batch_size # batch size for train local patches 62 | ################################### 63 | print("creating models......") 64 | 65 | path_g = os.path.join(model_path, args.path_g) 66 | path_g2l = os.path.join(model_path, args.path_g2l) 67 | path_l2g = os.path.join(model_path, args.path_l2g) 68 | model, global_fixed = create_model_load_weights(n_class, mode, evaluation, path_g=path_g, path_g2l=path_g2l, path_l2g=path_l2g) 69 | 70 | ################################### 71 | num_epochs = args.num_epochs 72 | learning_rate = args.lr 73 | lamb_fmreg = args.lamb_fmreg 74 | 75 | optimizer = get_optimizer(model, mode, learning_rate=learning_rate) 76 | 77 | scheduler = LR_Scheduler('poly', learning_rate, num_epochs, len(dataloader_train)) 78 | ################################## 79 | 80 | criterion1 = FocalLoss(gamma=3) 81 | criterion2 = nn.CrossEntropyLoss() 82 | criterion3 = lovasz_softmax 83 | criterion = lambda x,y: criterion1(x, y) 84 | # criterion = lambda x,y: 0.5*criterion1(x, y) + 0.5*criterion3(x, y) 85 | mse = nn.MSELoss() 86 | 87 | if not evaluation: 88 | writer = SummaryWriter(log_dir=log_path + task_name) 89 | f_log = open(log_path + task_name + ".log", 'w') 90 | 91 | trainer = Trainer(criterion, optimizer, n_class, size_g, size_p, sub_batch_size, mode, lamb_fmreg) 92 | evaluator = Evaluator(n_class, size_g, size_p, sub_batch_size, mode, test) 93 | 94 | best_pred = 0.0 95 | print("start training......") 96 | for epoch in range(num_epochs): 97 | trainer.set_train(model) 98 | optimizer.zero_grad() 99 | tbar = tqdm(dataloader_train); train_loss = 0 100 | for i_batch, sample_batched in enumerate(tbar): 101 | if evaluation: break 102 | scheduler(optimizer, i_batch, epoch, best_pred) 103 | loss = trainer.train(sample_batched, model, global_fixed) 104 | train_loss += loss.item() 105 | score_train, score_train_global, score_train_local = trainer.get_scores() 106 | if mode == 1: tbar.set_description('Train loss: %.3f; global mIoU: %.3f' % (train_loss / (i_batch + 1), np.mean(np.nan_to_num(score_train_global["iou"])))) 107 | else: tbar.set_description('Train loss: %.3f; agg mIoU: %.3f' % (train_loss / (i_batch + 1), np.mean(np.nan_to_num(score_train["iou"])))) 108 | 109 | score_train, score_train_global, score_train_local = trainer.get_scores() 110 | trainer.reset_metrics() 111 | # torch.cuda.empty_cache() 112 | 113 | if epoch % 1 == 0: 114 | with torch.no_grad(): 115 | model.eval() 116 | print("evaluating...") 117 | 118 | if test: tbar = tqdm(dataloader_test) 119 | else: tbar = tqdm(dataloader_val) 120 | 121 | for i_batch, sample_batched in enumerate(tbar): 122 | predictions, predictions_global, predictions_local = evaluator.eval_test(sample_batched, model, global_fixed) 123 | score_val, score_val_global, score_val_local = evaluator.get_scores() 124 | # use [1:] since class0 is not considered in deep_globe metric 125 | if mode == 1: tbar.set_description('global mIoU: %.3f' % (np.mean(np.nan_to_num(score_val_global["iou"])[1:]))) 126 | else: tbar.set_description('agg mIoU: %.3f' % (np.mean(np.nan_to_num(score_val["iou"])[1:]))) 127 | images = sample_batched['image'] 128 | if not test: 129 | labels = sample_batched['label'] # PIL images 130 | 131 | if test: 132 | if not os.path.isdir("./prediction/"): os.mkdir("./prediction/") 133 | for i in range(len(images)): 134 | if mode == 1: 135 | transforms.functional.to_pil_image(classToRGB(predictions_global[i]) * 255.).save("./prediction/" + sample_batched['id'][i] + "_mask.png") 136 | else: 137 | 
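# --- Editor's note: illustrative sketch, not part of the original script. ----
# The commented-out criterion above mixes focal loss with Lovasz-Softmax. A
# hedged version of that combination, assuming `logits` are raw [B, C, H, W]
# scores and `target` is a [B, H, W] integer class map; lovasz_softmax's
# docstring expects class probabilities, so softmax is applied explicitly
# here. The 0.5/0.5 weights come from the commented-out line, not from tuning.
import torch.nn.functional as F
from utils.loss import FocalLoss
from utils.lovasz_losses import lovasz_softmax

focal = FocalLoss(gamma=3)

def mixed_criterion(logits, target, w_focal=0.5, w_lovasz=0.5):
    return (w_focal * focal(logits, target)
            + w_lovasz * lovasz_softmax(F.softmax(logits, dim=1), target))
# ------------------------------------------------------------------------------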
transforms.functional.to_pil_image(classToRGB(predictions[i]) * 255.).save("./prediction/" + sample_batched['id'][i] + "_mask.png") 138 | 139 | if not evaluation and not test: 140 | if i_batch * batch_size + len(images) > (epoch % len(dataloader_val)) and i_batch * batch_size <= (epoch % len(dataloader_val)): 141 | writer.add_image('image', transforms.ToTensor()(images[(epoch % len(dataloader_val)) - i_batch * batch_size]), epoch) 142 | if not test: 143 | writer.add_image('mask', classToRGB(np.array(labels[(epoch % len(dataloader_val)) - i_batch * batch_size])) * 255., epoch) 144 | if mode == 2 or mode == 3: 145 | writer.add_image('prediction', classToRGB(predictions[(epoch % len(dataloader_val)) - i_batch * batch_size]) * 255., epoch) 146 | writer.add_image('prediction_local', classToRGB(predictions_local[(epoch % len(dataloader_val)) - i_batch * batch_size]) * 255., epoch) 147 | writer.add_image('prediction_global', classToRGB(predictions_global[(epoch % len(dataloader_val)) - i_batch * batch_size]) * 255., epoch) 148 | 149 | # torch.cuda.empty_cache() 150 | 151 | # if not (test or evaluation): torch.save(model.state_dict(), "./saved_models/" + task_name + ".epoch" + str(epoch) + ".pth") 152 | if not (test or evaluation): torch.save(model.state_dict(), "./saved_models/" + task_name + ".pth") 153 | 154 | if test: break 155 | else: 156 | score_val, score_val_global, score_val_local = evaluator.get_scores() 157 | evaluator.reset_metrics() 158 | if mode == 1: 159 | if np.mean(np.nan_to_num(score_val_global["iou"][1:])) > best_pred: best_pred = np.mean(np.nan_to_num(score_val_global["iou"][1:])) 160 | else: 161 | if np.mean(np.nan_to_num(score_val["iou"][1:])) > best_pred: best_pred = np.mean(np.nan_to_num(score_val["iou"][1:])) 162 | log = "" 163 | log = log + 'epoch [{}/{}] IoU: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, np.mean(np.nan_to_num(score_train["iou"][1:])), np.mean(np.nan_to_num(score_val["iou"][1:]))) + "\n" 164 | log = log + 'epoch [{}/{}] Local -- IoU: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, np.mean(np.nan_to_num(score_train_local["iou"][1:])), np.mean(np.nan_to_num(score_val_local["iou"][1:]))) + "\n" 165 | log = log + 'epoch [{}/{}] Global -- IoU: train = {:.4f}, val = {:.4f}'.format(epoch+1, num_epochs, np.mean(np.nan_to_num(score_train_global["iou"][1:])), np.mean(np.nan_to_num(score_val_global["iou"][1:]))) + "\n" 166 | log = log + "train: " + str(score_train["iou"]) + "\n" 167 | log = log + "val:" + str(score_val["iou"]) + "\n" 168 | log = log + "Local train:" + str(score_train_local["iou"]) + "\n" 169 | log = log + "Local val:" + str(score_val_local["iou"]) + "\n" 170 | log = log + "Global train:" + str(score_train_global["iou"]) + "\n" 171 | log = log + "Global val:" + str(score_val_global["iou"]) + "\n" 172 | log += "================================\n" 173 | print(log) 174 | if evaluation: break 175 | 176 | f_log.write(log) 177 | f_log.flush() 178 | if mode == 1: 179 | writer.add_scalars('IoU', {'train iou': np.mean(np.nan_to_num(score_train_global["iou"][1:])), 'validation iou': np.mean(np.nan_to_num(score_val_global["iou"][1:]))}, epoch) 180 | else: 181 | writer.add_scalars('IoU', {'train iou': np.mean(np.nan_to_num(score_train["iou"][1:])), 'validation iou': np.mean(np.nan_to_num(score_val["iou"][1:]))}, epoch) 182 | 183 | if not evaluation: f_log.close() -------------------------------------------------------------------------------- /models/resnet_dilation.py: 
-------------------------------------------------------------------------------- 1 | """Dilated ResNet""" 2 | # https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/dilated/resnet.py 3 | # https://github.com/fyu/drn 4 | import math 5 | import torch 6 | import torch.utils.model_zoo as model_zoo 7 | import torch.nn as nn 8 | from .model_store import get_model_file 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152', 'BasicBlock', 'Bottleneck'] 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | """ResNet BasicBlock 30 | """ 31 | expansion = 1 32 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 33 | norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 36 | padding=dilation, dilation=dilation, bias=False) 37 | self.bn1 = norm_layer(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 40 | padding=previous_dilation, dilation=previous_dilation, bias=False) 41 | self.bn2 = norm_layer(planes) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | """ResNet Bottleneck 66 | """ 67 | # pylint: disable=unused-argument 68 | expansion = 4 69 | def __init__(self, inplanes, planes, stride=1, dilation=1, 70 | downsample=None, previous_dilation=1, norm_layer=None): 71 | super(Bottleneck, self).__init__() 72 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 73 | self.bn1 = norm_layer(planes) 74 | self.conv2 = nn.Conv2d( 75 | planes, planes, kernel_size=3, stride=stride, 76 | padding=dilation, dilation=dilation, bias=False) 77 | self.bn2 = norm_layer(planes) 78 | self.conv3 = nn.Conv2d( 79 | planes, planes * 4, kernel_size=1, bias=False) 80 | self.bn3 = norm_layer(planes * 4) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.dilation = dilation 84 | self.stride = stride 85 | 86 | def _sum_each(self, x, y): 87 | assert(len(x) == len(y)) 88 | z = [] 89 | for i in range(len(x)): 90 | z.append(x[i]+y[i]) 91 | return z 92 | 93 | def forward(self, x): 94 | residual = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | residual = self.downsample(x) 109 | 110 | out += residual 111 | out = 
self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5. 118 | 119 | Parameters 120 | ---------- 121 | block : Block 122 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 123 | layers : list of int 124 | Numbers of layers in each block 125 | classes : int, default 1000 126 | Number of classification classes. 127 | dilated : bool, default False 128 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 129 | typically used in Semantic Segmentation. 130 | norm_layer : object 131 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 132 | for Synchronized Cross-GPU BachNormalization). 133 | 134 | Reference: 135 | 136 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 137 | 138 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 139 | """ 140 | # pylint: disable=unused-variable 141 | def __init__(self, block, layers, num_classes=1000, dilated=True, norm_layer=nn.BatchNorm2d, multi_grid=True, multi_dilation=(1, 2, 3)): 142 | self.inplanes = 128 # 64 143 | super(ResNet, self).__init__() 144 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 145 | self.conv1 = nn.Sequential( 146 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 147 | norm_layer(64), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 150 | norm_layer(64), 151 | nn.ReLU(inplace=True), 152 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False), 153 | ) 154 | self.bn1 = norm_layer(self.inplanes) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 157 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 158 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 159 | if dilated: 160 | if multi_grid: 161 | self.layer3 = self._make_layer(block,256,layers[2],stride=1, 162 | dilation=2, norm_layer=norm_layer) 163 | self.layer4 = self._make_layer(block,512,layers[3],stride=1, 164 | dilation=4, norm_layer=norm_layer, 165 | multi_grid=multi_grid, multi_dilation=multi_dilation) 166 | else: 167 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 168 | dilation=2, norm_layer=norm_layer) 169 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 170 | dilation=4, norm_layer=norm_layer) 171 | else: 172 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 173 | norm_layer=norm_layer) 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 175 | norm_layer=norm_layer) 176 | # self.avgpool = nn.AvgPool2d(7) 177 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 178 | 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 183 | elif isinstance(m, norm_layer): 184 | m.weight.data.fill_(1) 185 | m.bias.data.zero_() 186 | 187 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False, multi_dilation=None): 188 | downsample = None 189 | if stride != 1 or self.inplanes != planes * block.expansion: 190 | downsample = nn.Sequential( 191 | nn.Conv2d(self.inplanes, planes * block.expansion, 192 | kernel_size=1, stride=stride, bias=False), 193 | norm_layer(planes * block.expansion), 194 | ) 195 | 196 | layers = [] 197 | if multi_grid == False: 198 | if dilation == 1 or dilation == 2: 199 | layers.append(block(self.inplanes, planes, stride, dilation=1, 200 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 201 | elif dilation == 4: 202 | layers.append(block(self.inplanes, planes, stride, dilation=2, 203 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 204 | else: 205 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 206 | else: 207 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilation[0], 208 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 209 | self.inplanes = planes * block.expansion 210 | if multi_grid: 211 | div = len(multi_dilation) 212 | for i in range(1,blocks): 213 | layers.append(block(self.inplanes, planes, dilation=multi_dilation[i%div], previous_dilation=dilation, 214 | norm_layer=norm_layer)) 215 | else: 216 | for i in range(1, blocks): 217 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, 218 | norm_layer=norm_layer)) 219 | 220 | return nn.Sequential(*layers) 221 | 222 | def forward(self, x): 223 | x = self.conv1(x) 224 | x = self.bn1(x) 225 | x = self.relu(x) 226 | x = self.maxpool(x) 227 | 228 | c2 = self.layer1(x) 229 | c3 = self.layer2(c2) 230 | c4 = self.layer3(c3) 231 | c5 = self.layer4(c4) 232 | # x = self.avgpool(x) 233 | # x = x.view(x.size(0), -1) 234 | # x = self.fc(x) 235 | # return x 236 | return c2, c3, c4, c5 237 | 238 | 239 | def resnet18(pretrained=False, **kwargs): 240 | """Constructs a ResNet-18 model. 241 | 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | """ 245 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 246 | if pretrained: 247 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 248 | return model 249 | 250 | 251 | def resnet34(pretrained=False, **kwargs): 252 | """Constructs a ResNet-34 model. 253 | 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | """ 257 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 258 | if pretrained: 259 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 260 | return model 261 | 262 | 263 | def resnet50(pretrained=False, root='./pretrain_models', **kwargs): 264 | """Constructs a ResNet-50 model. 265 | 266 | Args: 267 | pretrained (bool): If True, returns a model pre-trained on ImageNet 268 | """ 269 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 270 | if pretrained: 271 | # from ..models.model_store import get_model_file 272 | model.load_state_dict(torch.load( 273 | get_model_file('resnet50', root=root)), strict=False) 274 | return model 275 | 276 | 277 | def resnet101(pretrained=False, root='./pretrain_models', **kwargs): 278 | """Constructs a ResNet-101 model. 
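# --- Editor's note: quick shape check, not part of the original file. --------
# With the default dilated=True, layer3/layer4 keep stride 1 and use dilation
# instead, so c4 and c5 stay at 1/8 of the input resolution (c2 stays at 1/4).
# A minimal sketch, assuming the repo root is on PYTHONPATH; pretrained=False
# avoids needing any weight files.
import torch
from models.resnet_dilation import resnet50

net = resnet50(pretrained=False).eval()
with torch.no_grad():
    c2, c3, c4, c5 = net(torch.randn(1, 3, 224, 224))
print(c2.shape, c3.shape, c4.shape, c5.shape)
# expected: [1, 256, 56, 56], [1, 512, 28, 28], [1, 1024, 28, 28], [1, 2048, 28, 28]
# ------------------------------------------------------------------------------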
279 | 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | """ 283 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 284 | #Remove the following lines of comments 285 | #if u want to train from a pretrained model 286 | if pretrained: 287 | # from ..models.model_store import get_model_file 288 | model.load_state_dict(torch.load( 289 | get_model_file('resnet101', root=root)), strict=False) 290 | return model 291 | 292 | 293 | def resnet152(pretrained=False, root='~/.encoding/models', **kwargs): 294 | """Constructs a ResNet-152 model. 295 | 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | """ 299 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 300 | if pretrained: 301 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 302 | model.load_state_dict(torch.load( 303 | './pretrain_models/resnet152-b121ed2d.pth'), strict=False) 304 | return model 305 | -------------------------------------------------------------------------------- /models/fpn_global_local_fmreg_ensemble.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet50 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | import numpy as np 6 | 7 | 8 | class fpn_module_global(nn.Module): 9 | def __init__(self, numClass): 10 | super(fpn_module_global, self).__init__() 11 | self._up_kwargs = {'mode': 'bilinear'} 12 | # Top layer 13 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels 14 | # Lateral layers 15 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 16 | self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0) 17 | self.latlayer3 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 18 | # Smooth layers 19 | self.smooth1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 20 | self.smooth2_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 21 | self.smooth3_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 22 | self.smooth4_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 23 | self.smooth1_2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 24 | self.smooth2_2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 25 | self.smooth3_2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 26 | self.smooth4_2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 27 | # Classify layers 28 | self.classify = nn.Conv2d(128*4, numClass, kernel_size=3, stride=1, padding=1) 29 | 30 | # Local2Global: double #channels #################################### 31 | # Top layer 32 | self.toplayer_ext = nn.Conv2d(2048*2, 256, kernel_size=1, stride=1, padding=0) # Reduce channels 33 | # Lateral layers 34 | self.latlayer1_ext = nn.Conv2d(1024*2, 256, kernel_size=1, stride=1, padding=0) 35 | self.latlayer2_ext = nn.Conv2d(512*2, 256, kernel_size=1, stride=1, padding=0) 36 | self.latlayer3_ext = nn.Conv2d(256*2, 256, kernel_size=1, stride=1, padding=0) 37 | # Smooth layers 38 | self.smooth1_1_ext = nn.Conv2d(256*2, 256, kernel_size=3, stride=1, padding=1) 39 | self.smooth2_1_ext = nn.Conv2d(256*2, 256, kernel_size=3, stride=1, padding=1) 40 | self.smooth3_1_ext = nn.Conv2d(256*2, 256, kernel_size=3, stride=1, padding=1) 41 | self.smooth4_1_ext = nn.Conv2d(256*2, 256, kernel_size=3, stride=1, padding=1) 42 | self.smooth1_2_ext = nn.Conv2d(256*2, 128, kernel_size=3, stride=1, padding=1) 43 | self.smooth2_2_ext = 
nn.Conv2d(256*2, 128, kernel_size=3, stride=1, padding=1) 44 | self.smooth3_2_ext = nn.Conv2d(256*2, 128, kernel_size=3, stride=1, padding=1) 45 | self.smooth4_2_ext = nn.Conv2d(256*2, 128, kernel_size=3, stride=1, padding=1) 46 | self.smooth = nn.Conv2d(128*4*2, 128*4, kernel_size=3, stride=1, padding=1) 47 | 48 | def _concatenate(self, p5, p4, p3, p2): 49 | _, _, H, W = p2.size() 50 | p5 = F.interpolate(p5, size=(H, W), **self._up_kwargs) 51 | p4 = F.interpolate(p4, size=(H, W), **self._up_kwargs) 52 | p3 = F.interpolate(p3, size=(H, W), **self._up_kwargs) 53 | return torch.cat([p5, p4, p3, p2], dim=1) 54 | 55 | def _upsample_add(self, x, y): 56 | '''Upsample and add two feature maps. 57 | Args: 58 | x: (Variable) top feature map to be upsampled. 59 | y: (Variable) lateral feature map. 60 | Returns: 61 | (Variable) added feature map. 62 | Note in PyTorch, when input size is odd, the upsampled feature map 63 | with `F.interpolate(..., scale_factor=2, mode='nearest')` 64 | maybe not equal to the lateral feature map size. 65 | e.g. 66 | original input size: [N,_,15,15] -> 67 | conv2d feature map size: [N,_,8,8] -> 68 | upsampled feature map size: [N,_,16,16] 69 | So we choose bilinear upsample which supports arbitrary output sizes. 70 | ''' 71 | _, _, H, W = y.size() 72 | return F.interpolate(x, size=(H, W), **self._up_kwargs) + y 73 | 74 | def forward(self, c2, c3, c4, c5, c2_ext=None, c3_ext=None, c4_ext=None, c5_ext=None, ps0_ext=None, ps1_ext=None, ps2_ext=None): 75 | 76 | # Top-down 77 | if c5_ext is None: 78 | p5 = self.toplayer(c5) 79 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 80 | p3 = self._upsample_add(p4, self.latlayer2(c3)) 81 | p2 = self._upsample_add(p3, self.latlayer3(c2)) 82 | else: 83 | p5 = self.toplayer_ext(torch.cat((c5, c5_ext), dim=1)) 84 | p4 = self._upsample_add(p5, self.latlayer1_ext(torch.cat((c4, c4_ext), dim=1))) 85 | p3 = self._upsample_add(p4, self.latlayer2_ext(torch.cat((c3, c3_ext), dim=1))) 86 | p2 = self._upsample_add(p3, self.latlayer3_ext(torch.cat((c2, c2_ext), dim=1))) 87 | ps0 = [p5, p4, p3, p2] 88 | 89 | # Smooth 90 | if ps0_ext is None: 91 | p5 = self.smooth1_1(p5) 92 | p4 = self.smooth2_1(p4) 93 | p3 = self.smooth3_1(p3) 94 | p2 = self.smooth4_1(p2) 95 | else: 96 | p5 = self.smooth1_1_ext(torch.cat((p5, ps0_ext[0]), dim=1)) 97 | p4 = self.smooth2_1_ext(torch.cat((p4, ps0_ext[1]), dim=1)) 98 | p3 = self.smooth3_1_ext(torch.cat((p3, ps0_ext[2]), dim=1)) 99 | p2 = self.smooth4_1_ext(torch.cat((p2, ps0_ext[3]), dim=1)) 100 | ps1 = [p5, p4, p3, p2] 101 | 102 | if ps1_ext is None: 103 | p5 = self.smooth1_2(p5) 104 | p4 = self.smooth2_2(p4) 105 | p3 = self.smooth3_2(p3) 106 | p2 = self.smooth4_2(p2) 107 | else: 108 | p5 = self.smooth1_2_ext(torch.cat((p5, ps1_ext[0]), dim=1)) 109 | p4 = self.smooth2_2_ext(torch.cat((p4, ps1_ext[1]), dim=1)) 110 | p3 = self.smooth3_2_ext(torch.cat((p3, ps1_ext[2]), dim=1)) 111 | p2 = self.smooth4_2_ext(torch.cat((p2, ps1_ext[3]), dim=1)) 112 | ps2 = [p5, p4, p3, p2] 113 | 114 | # Classify 115 | if ps2_ext is None: 116 | ps3 = self._concatenate(p5, p4, p3, p2) 117 | output = self.classify(ps3) 118 | else: 119 | p = self._concatenate( 120 | torch.cat((p5, ps2_ext[0]), dim=1), 121 | torch.cat((p4, ps2_ext[1]), dim=1), 122 | torch.cat((p3, ps2_ext[2]), dim=1), 123 | torch.cat((p2, ps2_ext[3]), dim=1) 124 | ) 125 | ps3 = self.smooth(p) 126 | output = self.classify(ps3) 127 | 128 | return output, ps0, ps1, ps2, ps3 129 | 130 | 131 | class fpn_module_local(nn.Module): 132 | def __init__(self, numClass): 133 | 
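# --- Editor's note: standalone illustration, not part of the original file. --
# _upsample_add above resizes the coarser map to the lateral map's exact
# (H, W) with bilinear interpolation before the element-wise add, which is why
# odd input sizes are safe. Shapes below are arbitrary illustration values.
import torch
import torch.nn.functional as F

top = torch.randn(1, 256, 8, 8)        # coarser pyramid level
lateral = torch.randn(1, 256, 15, 15)  # odd-sized lateral feature map
merged = F.interpolate(top, size=lateral.shape[2:], mode='bilinear') + lateral
print(merged.shape)  # torch.Size([1, 256, 15, 15])
# ------------------------------------------------------------------------------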
super(fpn_module_local, self).__init__() 134 | self._up_kwargs = {'mode': 'bilinear'} 135 | # Top layer 136 | fold = 2 137 | self.toplayer = nn.Conv2d(2048 * fold, 256, kernel_size=1, stride=1, padding=0) # Reduce channels 138 | # Lateral layers [C] 139 | self.latlayer1 = nn.Conv2d(1024 * fold, 256, kernel_size=1, stride=1, padding=0) 140 | self.latlayer2 = nn.Conv2d(512 * fold, 256, kernel_size=1, stride=1, padding=0) 141 | self.latlayer3 = nn.Conv2d(256 * fold, 256, kernel_size=1, stride=1, padding=0) 142 | # Smooth layers 143 | # ps0 144 | self.smooth1_1 = nn.Conv2d(256 * fold, 256, kernel_size=3, stride=1, padding=1) 145 | self.smooth2_1 = nn.Conv2d(256 * fold, 256, kernel_size=3, stride=1, padding=1) 146 | self.smooth3_1 = nn.Conv2d(256 * fold, 256, kernel_size=3, stride=1, padding=1) 147 | self.smooth4_1 = nn.Conv2d(256 * fold, 256, kernel_size=3, stride=1, padding=1) 148 | # ps1 149 | self.smooth1_2 = nn.Conv2d(256 * fold, 128, kernel_size=3, stride=1, padding=1) 150 | self.smooth2_2 = nn.Conv2d(256 * fold, 128, kernel_size=3, stride=1, padding=1) 151 | self.smooth3_2 = nn.Conv2d(256 * fold, 128, kernel_size=3, stride=1, padding=1) 152 | self.smooth4_2 = nn.Conv2d(256 * fold, 128, kernel_size=3, stride=1, padding=1) 153 | # ps2 is concatenation 154 | # Classify layers 155 | self.smooth = nn.Conv2d(128*4*fold, 128*4, kernel_size=3, stride=1, padding=1) 156 | self.classify = nn.Conv2d(128*4, numClass, kernel_size=3, stride=1, padding=1) 157 | 158 | def _concatenate(self, p5, p4, p3, p2): 159 | _, _, H, W = p2.size() 160 | p5 = F.interpolate(p5, size=(H, W), **self._up_kwargs) 161 | p4 = F.interpolate(p4, size=(H, W), **self._up_kwargs) 162 | p3 = F.interpolate(p3, size=(H, W), **self._up_kwargs) 163 | return torch.cat([p5, p4, p3, p2], dim=1) 164 | 165 | def _upsample_add(self, x, y): 166 | '''Upsample and add two feature maps. 167 | Args: 168 | x: (Variable) top feature map to be upsampled. 169 | y: (Variable) lateral feature map. 170 | Returns: 171 | (Variable) added feature map. 172 | Note in PyTorch, when input size is odd, the upsampled feature map 173 | with `F.interpolate(..., scale_factor=2, mode='nearest')` 174 | maybe not equal to the lateral feature map size. 175 | e.g. 176 | original input size: [N,_,15,15] -> 177 | conv2d feature map size: [N,_,8,8] -> 178 | upsampled feature map size: [N,_,16,16] 179 | So we choose bilinear upsample which supports arbitrary output sizes. 
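# --- Editor's note: shape illustration, not part of the original file. -------
# Why every conv in this local module takes `fold = 2` times the usual
# channels: in its forward pass each local feature map is concatenated with
# the matching crop of the global branch's feature map (resized to the patch's
# spatial size) before the conv. Sizes below are made up for the illustration.
import torch
import torch.nn.functional as F

c5_local = torch.randn(1, 2048, 16, 16)      # resnet_local output for one patch
c5_global_crop = torch.randn(1, 2048, 5, 5)  # cropped region of the global c5
c5_global_crop = F.interpolate(c5_global_crop, size=c5_local.shape[2:], mode='bilinear')
fused = torch.cat([c5_local, c5_global_crop], dim=1)
print(fused.shape)  # torch.Size([1, 4096, 16, 16]) -> matches 2048 * fold
# ------------------------------------------------------------------------------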
180 | ''' 181 | _, _, H, W = y.size() 182 | return F.interpolate(x, size=(H, W), **self._up_kwargs) + y 183 | 184 | def forward(self, c2, c3, c4, c5, c2_ext, c3_ext, c4_ext, c5_ext, ps0_ext, ps1_ext, ps2_ext): 185 | 186 | # Top-down 187 | p5 = self.toplayer(torch.cat([c5] + [F.interpolate(c5_ext[0], size=c5.size()[2:], **self._up_kwargs)], dim=1)) 188 | p4 = self._upsample_add(p5, self.latlayer1(torch.cat([c4] + [F.interpolate(c4_ext[0], size=c4.size()[2:], **self._up_kwargs)], dim=1))) 189 | p3 = self._upsample_add(p4, self.latlayer2(torch.cat([c3] + [F.interpolate(c3_ext[0], size=c3.size()[2:], **self._up_kwargs)], dim=1))) 190 | p2 = self._upsample_add(p3, self.latlayer3(torch.cat([c2] + [F.interpolate(c2_ext[0], size=c2.size()[2:], **self._up_kwargs)], dim=1))) 191 | ps0 = [p5, p4, p3, p2] 192 | 193 | # Smooth 194 | p5 = self.smooth1_1(torch.cat([p5] + [F.interpolate(ps0_ext[0][0], size=p5.size()[2:], **self._up_kwargs)], dim=1)) 195 | p4 = self.smooth2_1(torch.cat([p4] + [F.interpolate(ps0_ext[1][0], size=p4.size()[2:], **self._up_kwargs)], dim=1)) 196 | p3 = self.smooth3_1(torch.cat([p3] + [F.interpolate(ps0_ext[2][0], size=p3.size()[2:], **self._up_kwargs)], dim=1)) 197 | p2 = self.smooth4_1(torch.cat([p2] + [F.interpolate(ps0_ext[3][0], size=p2.size()[2:], **self._up_kwargs)], dim=1)) 198 | ps1 = [p5, p4, p3, p2] 199 | 200 | p5 = self.smooth1_2(torch.cat([p5] + [F.interpolate(ps1_ext[0][0], size=p5.size()[2:], **self._up_kwargs)], dim=1)) 201 | p4 = self.smooth2_2(torch.cat([p4] + [F.interpolate(ps1_ext[1][0], size=p4.size()[2:], **self._up_kwargs)], dim=1)) 202 | p3 = self.smooth3_2(torch.cat([p3] + [F.interpolate(ps1_ext[2][0], size=p3.size()[2:], **self._up_kwargs)], dim=1)) 203 | p2 = self.smooth4_2(torch.cat([p2] + [F.interpolate(ps1_ext[3][0], size=p2.size()[2:], **self._up_kwargs)], dim=1)) 204 | ps2 = [p5, p4, p3, p2] 205 | 206 | # Classify 207 | # use ps2_ext 208 | ps3 = self._concatenate( 209 | torch.cat([p5] + [F.interpolate(ps2_ext[0][0], size=p5.size()[2:], **self._up_kwargs)], dim=1), 210 | torch.cat([p4] + [F.interpolate(ps2_ext[1][0], size=p4.size()[2:], **self._up_kwargs)], dim=1), 211 | torch.cat([p3] + [F.interpolate(ps2_ext[2][0], size=p3.size()[2:], **self._up_kwargs)], dim=1), 212 | torch.cat([p2] + [F.interpolate(ps2_ext[3][0], size=p2.size()[2:], **self._up_kwargs)], dim=1) 213 | ) 214 | ps3 = self.smooth(ps3) 215 | output = self.classify(ps3) 216 | 217 | return output, ps0, ps1, ps2, ps3 218 | 219 | 220 | class fpn(nn.Module): 221 | def __init__(self, numClass): 222 | super(fpn, self).__init__() 223 | self._up_kwargs = {'mode': 'bilinear'} 224 | # Res net 225 | self.resnet_global = resnet50(True) 226 | self.resnet_local = resnet50(True) 227 | 228 | # fpn module 229 | self.fpn_global = fpn_module_global(numClass) 230 | self.fpn_local = fpn_module_local(numClass) 231 | 232 | self.c2_g = None; self.c3_g = None; self.c4_g = None; self.c5_g = None; self.output_g = None 233 | self.ps0_g = None; self.ps1_g = None; self.ps2_g = None; self.ps3_g = None 234 | 235 | self.c2_l = []; self.c3_l = []; self.c4_l = []; self.c5_l = []; 236 | self.ps00_l = []; self.ps01_l = []; self.ps02_l = []; self.ps03_l = []; 237 | self.ps10_l = []; self.ps11_l = []; self.ps12_l = []; self.ps13_l = []; 238 | self.ps20_l = []; self.ps21_l = []; self.ps22_l = []; self.ps23_l = []; 239 | self.ps0_l = None; self.ps1_l = None; self.ps2_l = None 240 | self.ps3_l = []#; self.output_l = [] 241 | 242 | self.c2_b = None; self.c3_b = None; self.c4_b = None; self.c5_b = None; 243 | self.ps00_b = 
None; self.ps01_b = None; self.ps02_b = None; self.ps03_b = None; 244 | self.ps10_b = None; self.ps11_b = None; self.ps12_b = None; self.ps13_b = None; 245 | self.ps20_b = None; self.ps21_b = None; self.ps22_b = None; self.ps23_b = None; 246 | self.ps3_b = []#; self.output_b = [] 247 | 248 | self.patch_n = 0 249 | 250 | self.mse = nn.MSELoss() 251 | 252 | self.ensemble_conv = nn.Conv2d(128*4 * 2, numClass, kernel_size=3, stride=1, padding=1) 253 | nn.init.normal_(self.ensemble_conv.weight, mean=0, std=0.01) 254 | 255 | # init fpn 256 | for m in self.fpn_global.children(): 257 | if hasattr(m, 'weight'): nn.init.normal_(m.weight, mean=0, std=0.01) 258 | if hasattr(m, 'bias'): nn.init.constant_(m.bias, 0) 259 | for m in self.fpn_local.children(): 260 | if hasattr(m, 'weight'): nn.init.normal_(m.weight, mean=0, std=0.01) 261 | if hasattr(m, 'bias'): nn.init.constant_(m.bias, 0) 262 | 263 | def clear_cache(self): 264 | self.c2_g = None; self.c3_g = None; self.c4_g = None; self.c5_g = None; self.output_g = None 265 | self.ps0_g = None; self.ps1_g = None; self.ps2_g = None; self.ps3_g = None 266 | 267 | self.c2_l = []; self.c3_l = []; self.c4_l = []; self.c5_l = []; 268 | self.ps00_l = []; self.ps01_l = []; self.ps02_l = []; self.ps03_l = []; 269 | self.ps10_l = []; self.ps11_l = []; self.ps12_l = []; self.ps13_l = []; 270 | self.ps20_l = []; self.ps21_l = []; self.ps22_l = []; self.ps23_l = []; 271 | self.ps0_l = None; self.ps1_l = None; self.ps2_l = None 272 | self.ps3_l = []; self.output_l = [] 273 | 274 | self.c2_b = None; self.c3_b = None; self.c4_b = None; self.c5_b = None; 275 | self.ps00_b = None; self.ps01_b = None; self.ps02_b = None; self.ps03_b = None; 276 | self.ps10_b = None; self.ps11_b = None; self.ps12_b = None; self.ps13_b = None; 277 | self.ps20_b = None; self.ps21_b = None; self.ps22_b = None; self.ps23_b = None; 278 | self.ps3_b = []; self.output_b = [] 279 | 280 | self.patch_n = 0 281 | 282 | 283 | def _sample_grid(self, fm, bbox, sampleSize): 284 | """ 285 | :param fm: tensor(b,c,h,w) the global feature map 286 | :param bbox: list [b* nparray(x1, y1, x2, y2)] the (x1,y1) is the left_top of bbox, (x2, y2) is the right_bottom of bbox 287 | there are in range [0, 1]. 
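# --- Editor's note: illustrative sketch, not part of the original file. ------
# The cropping helpers below work in normalized coordinates: `top_lefts`
# stores each patch's (top, left) as fractions of the full image, and `ratio`
# is patch_size / global_size per axis, so a crop of the global feature map is
# just a rounded slice. Values here are illustrative only.
import numpy as np
import torch

f_global = torch.randn(1, 256, 64, 64)  # one global feature map
top_left = (0.25, 0.5)                  # normalized (top, left) of a patch
ratio = (0.25, 0.25)                    # patch covers a quarter of each axis

_, _, H, W = f_global.shape
h, w = int(np.round(H * ratio[0])), int(np.round(W * ratio[1]))
top, left = int(np.round(top_left[0] * H)), int(np.round(top_left[1] * W))
crop = f_global[:, :, top:top + h, left:left + w]
print(crop.shape)  # torch.Size([1, 256, 16, 16])
# ------------------------------------------------------------------------------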
x is corresponding to width dimension and y is corresponding to height dimension 288 | :param sampleSize: (oH, oW) the point to sample in height dimension and width dimension 289 | :return: tensor(b, c, oH, oW) sampled tensor 290 | """ 291 | b, c, h, w = fm.shape 292 | b_bbox = len(bbox) 293 | bbox = [x*2 - 1 for x in bbox] # range transform 294 | if b != b_bbox and b == 1: 295 | fm = torch.cat([fm,]*b_bbox, dim=0) 296 | grid = np.zeros((b_bbox,) + sampleSize + (2,), dtype=np.float32) 297 | gridMap = np.array([[(cnt_w/(sampleSize[1]-1), cnt_h/(sampleSize[0]-1)) for cnt_w in range(sampleSize[1])] for cnt_h in range(sampleSize[0])]) 298 | for cnt_b in range(b_bbox): 299 | grid[cnt_b, :, :, 0] = bbox[cnt_b][0] + (bbox[cnt_b][2] - bbox[cnt_b][0])*gridMap[:, :, 0] 300 | grid[cnt_b, :, :, 1] = bbox[cnt_b][1] + (bbox[cnt_b][3] - bbox[cnt_b][1])*gridMap[:, :, 1] 301 | grid = torch.from_numpy(grid).cuda() 302 | return F.grid_sample(fm, grid) 303 | 304 | def _crop_global(self, f_global, top_lefts, ratio): 305 | ''' 306 | top_lefts: [(top, left)] * b 307 | ''' 308 | _, c, H, W = f_global.size() 309 | b = len(top_lefts) 310 | h, w = int(np.round(H * ratio[0])), int(np.round(W * ratio[1])) 311 | 312 | # bbox = [ np.array([left, top, left + ratio, top + ratio]) for (top, left) in top_lefts ] 313 | # crop = self._sample_grid(f_global, bbox, (H, W)) 314 | 315 | crop = [] 316 | for i in range(b): 317 | top, left = int(np.round(top_lefts[i][0] * H)), int(np.round(top_lefts[i][1] * W)) 318 | # # global's sub-region & upsample 319 | # f_global_patch = F.interpolate(f_global[0:1, :, top:top+h, left:left+w], size=(h, w), mode='bilinear') 320 | f_global_patch = f_global[0:1, :, top:top+h, left:left+w] 321 | crop.append(f_global_patch[0]) 322 | crop = torch.stack(crop, dim=0) # stack into mini-batch 323 | return [crop] # return as a list for easy to torch.cat 324 | 325 | def _merge_local(self, f_local, merge, f_global, top_lefts, oped, ratio, template): 326 | ''' 327 | merge feature maps from local patches, and finally to a whole image's feature map (on cuda) 328 | f_local: a sub_batch_size of patch's feature map 329 | oped: [start, end) 330 | ''' 331 | b, _, _, _ = f_local.size() 332 | _, c, H, W = f_global.size() # match global feature size 333 | if merge is None: 334 | merge = torch.zeros((1, c, H, W)).cuda() 335 | h, w = int(np.round(H * ratio[0])), int(np.round(W * ratio[1])) 336 | for i in range(b): 337 | index = oped[0] + i 338 | top, left = int(np.round(H * top_lefts[index][0])), int(np.round(W * top_lefts[index][1])) 339 | merge[:, :, top:top+h, left:left+w] += F.interpolate(f_local[i:i+1], size=(h, w), **self._up_kwargs) 340 | if oped[1] >= len(top_lefts): 341 | template = F.interpolate(template, size=(H, W), **self._up_kwargs) 342 | template = template.expand_as(merge) 343 | # template = Variable(template).cuda() 344 | merge /= template 345 | return merge 346 | 347 | def ensemble(self, f_local, f_global): 348 | return self.ensemble_conv(torch.cat((f_local, f_global), dim=1)) 349 | 350 | def collect_local_fm(self, image_global, patches, ratio, top_lefts, oped, batch_size, global_model=None, template=None, n_patch_all=None): 351 | ''' 352 | patches: 1 patch 353 | top_lefts: all top-left 354 | oped: [start, end) 355 | ''' 356 | with torch.no_grad(): 357 | if self.patch_n == 0: 358 | self.c2_g, self.c3_g, self.c4_g, self.c5_g = global_model.module.resnet_global.forward(image_global) 359 | self.output_g, self.ps0_g, self.ps1_g, self.ps2_g, self.ps3_g = global_model.module.fpn_global.forward(self.c2_g, 
self.c3_g, self.c4_g, self.c5_g) 360 | # self.output_g = F.interpolate(self.output_g, image_global.size()[2:], mode='nearest') 361 | self.patch_n += patches.size()[0] 362 | self.patch_n %= n_patch_all 363 | 364 | self.resnet_local.eval() 365 | self.fpn_local.eval() 366 | c2, c3, c4, c5 = self.resnet_local.forward(patches) 367 | # global's 1x patch cat 368 | output, ps0, ps1, ps2, ps3 = self.fpn_local.forward( 369 | c2, c3, c4, c5, 370 | self._crop_global(self.c2_g, top_lefts[oped[0]:oped[1]], ratio), 371 | c3_ext=self._crop_global(self.c3_g, top_lefts[oped[0]:oped[1]], ratio), 372 | c4_ext=self._crop_global(self.c4_g, top_lefts[oped[0]:oped[1]], ratio), 373 | c5_ext=self._crop_global(self.c5_g, top_lefts[oped[0]:oped[1]], ratio), 374 | ps0_ext=[ self._crop_global(f, top_lefts[oped[0]:oped[1]], ratio) for f in self.ps0_g ], 375 | ps1_ext=[ self._crop_global(f, top_lefts[oped[0]:oped[1]], ratio) for f in self.ps1_g ], 376 | ps2_ext=[ self._crop_global(f, top_lefts[oped[0]:oped[1]], ratio) for f in self.ps2_g ] 377 | ) 378 | # output = F.interpolate(output, patches.size()[2:], mode='nearest') 379 | 380 | self.c2_b = self._merge_local(c2, self.c2_b, self.c2_g, top_lefts, oped, ratio, template) 381 | self.c3_b = self._merge_local(c3, self.c3_b, self.c3_g, top_lefts, oped, ratio, template) 382 | self.c4_b = self._merge_local(c4, self.c4_b, self.c4_g, top_lefts, oped, ratio, template) 383 | self.c5_b = self._merge_local(c5, self.c5_b, self.c5_g, top_lefts, oped, ratio, template) 384 | 385 | self.ps00_b = self._merge_local(ps0[0], self.ps00_b, self.ps0_g[0], top_lefts, oped, ratio, template) 386 | self.ps01_b = self._merge_local(ps0[1], self.ps01_b, self.ps0_g[1], top_lefts, oped, ratio, template) 387 | self.ps02_b = self._merge_local(ps0[2], self.ps02_b, self.ps0_g[2], top_lefts, oped, ratio, template) 388 | self.ps03_b = self._merge_local(ps0[3], self.ps03_b, self.ps0_g[3], top_lefts, oped, ratio, template) 389 | self.ps10_b = self._merge_local(ps1[0], self.ps10_b, self.ps1_g[0], top_lefts, oped, ratio, template) 390 | self.ps11_b = self._merge_local(ps1[1], self.ps11_b, self.ps1_g[1], top_lefts, oped, ratio, template) 391 | self.ps12_b = self._merge_local(ps1[2], self.ps12_b, self.ps1_g[2], top_lefts, oped, ratio, template) 392 | self.ps13_b = self._merge_local(ps1[3], self.ps13_b, self.ps1_g[3], top_lefts, oped, ratio, template) 393 | self.ps20_b = self._merge_local(ps2[0], self.ps20_b, self.ps2_g[0], top_lefts, oped, ratio, template) 394 | self.ps21_b = self._merge_local(ps2[1], self.ps21_b, self.ps2_g[1], top_lefts, oped, ratio, template) 395 | self.ps22_b = self._merge_local(ps2[2], self.ps22_b, self.ps2_g[2], top_lefts, oped, ratio, template) 396 | self.ps23_b = self._merge_local(ps2[3], self.ps23_b, self.ps2_g[3], top_lefts, oped, ratio, template) 397 | 398 | self.ps3_b.append(ps3.cpu()) 399 | # self.output_b.append(output.cpu()) # each output is 1, 7, h, w 400 | 401 | if self.patch_n == 0: 402 | # merged all patches into an image 403 | self.c2_l.append(self.c2_b); self.c3_l.append(self.c3_b); self.c4_l.append(self.c4_b); self.c5_l.append(self.c5_b); 404 | self.ps00_l.append(self.ps00_b); self.ps01_l.append(self.ps01_b); self.ps02_l.append(self.ps02_b); self.ps03_l.append(self.ps03_b) 405 | self.ps10_l.append(self.ps10_b); self.ps11_l.append(self.ps11_b); self.ps12_l.append(self.ps12_b); self.ps13_l.append(self.ps13_b) 406 | self.ps20_l.append(self.ps20_b); self.ps21_l.append(self.ps21_b); self.ps22_l.append(self.ps22_b); self.ps23_l.append(self.ps23_b) 407 | 408 | # collected all ps3 and 
output of patches as a (b) tensor, append into list 409 | self.ps3_l.append(torch.cat(self.ps3_b, dim=0)); # a list of tensors 410 | # self.output_l.append(torch.cat(self.output_b, dim=0)) # a list of 36, 7, h, w tensors 411 | 412 | self.c2_b = None; self.c3_b = None; self.c4_b = None; self.c5_b = None; 413 | self.ps00_b = None; self.ps01_b = None; self.ps02_b = None; self.ps03_b = None; 414 | self.ps10_b = None; self.ps11_b = None; self.ps12_b = None; self.ps13_b = None; 415 | self.ps20_b = None; self.ps21_b = None; self.ps22_b = None; self.ps23_b = None; 416 | self.ps3_b = []# ; self.output_b = [] 417 | if len(self.c2_l) == batch_size: 418 | self.c2_l = torch.cat(self.c2_l, dim=0)# .cuda() 419 | self.c3_l = torch.cat(self.c3_l, dim=0)# .cuda() 420 | self.c4_l = torch.cat(self.c4_l, dim=0)# .cuda() 421 | self.c5_l = torch.cat(self.c5_l, dim=0)# .cuda() 422 | self.ps00_l = torch.cat(self.ps00_l, dim=0)# .cuda() 423 | self.ps01_l = torch.cat(self.ps01_l, dim=0)# .cuda() 424 | self.ps02_l = torch.cat(self.ps02_l, dim=0)# .cuda() 425 | self.ps03_l = torch.cat(self.ps03_l, dim=0)# .cuda() 426 | self.ps10_l = torch.cat(self.ps10_l, dim=0)# .cuda() 427 | self.ps11_l = torch.cat(self.ps11_l, dim=0)# .cuda() 428 | self.ps12_l = torch.cat(self.ps12_l, dim=0)# .cuda() 429 | self.ps13_l = torch.cat(self.ps13_l, dim=0)# .cuda() 430 | self.ps20_l = torch.cat(self.ps20_l, dim=0)# .cuda() 431 | self.ps21_l = torch.cat(self.ps21_l, dim=0)# .cuda() 432 | self.ps22_l = torch.cat(self.ps22_l, dim=0)# .cuda() 433 | self.ps23_l = torch.cat(self.ps23_l, dim=0)# .cuda() 434 | self.ps0_l = [self.ps00_l, self.ps01_l, self.ps02_l, self.ps03_l] 435 | self.ps1_l = [self.ps10_l, self.ps11_l, self.ps12_l, self.ps13_l] 436 | self.ps2_l = [self.ps20_l, self.ps21_l, self.ps22_l, self.ps23_l] 437 | # self.ps3_l = torch.cat(self.ps3_l, dim=0)# .cuda() 438 | return self.ps3_l, output# self.output_l 439 | 440 | 441 | def forward(self, image_global, patches, top_lefts, ratio, mode=1, global_model=None, n_patch=None): 442 | if mode == 1: 443 | # train global model 444 | c2_g, c3_g, c4_g, c5_g = self.resnet_global.forward(image_global) 445 | output_g, ps0_g, ps1_g, ps2_g, ps3_g = self.fpn_global.forward(c2_g, c3_g, c4_g, c5_g) 446 | # imsize = image_global.size()[2:] 447 | # output_g = F.interpolate(output_g, imsize, mode='nearest') 448 | return output_g, None 449 | elif mode == 2: 450 | # train global2local model 451 | with torch.no_grad(): 452 | if self.patch_n == 0: 453 | # calculate global images only if patches belong to a new set of global images (when self.patch_n % n_patch == 0) 454 | self.c2_g, self.c3_g, self.c4_g, self.c5_g = self.resnet_global.forward(image_global) 455 | self.output_g, self.ps0_g, self.ps1_g, self.ps2_g, self.ps3_g = self.fpn_global.forward(self.c2_g, self.c3_g, self.c4_g, self.c5_g) 456 | # imsize_glb = image_global.size()[2:] 457 | # self.output_g = F.interpolate(self.output_g, imsize_glb, mode='nearest') 458 | self.patch_n += patches.size()[0] 459 | self.patch_n %= n_patch 460 | 461 | # train local model ####################################### 462 | c2_l, c3_l, c4_l, c5_l = self.resnet_local.forward(patches) 463 | # global's 1x patch cat 464 | output_l, ps0_l, ps1_l, ps2_l, ps3_l = self.fpn_local.forward(c2_l, c3_l, c4_l, c5_l, 465 | self._crop_global(self.c2_g, top_lefts, ratio), 466 | self._crop_global(self.c3_g, top_lefts, ratio), 467 | self._crop_global(self.c4_g, top_lefts, ratio), 468 | self._crop_global(self.c5_g, top_lefts, ratio), 469 | [ self._crop_global(f, top_lefts, ratio) for f in 
self.ps0_g ], 470 | [ self._crop_global(f, top_lefts, ratio) for f in self.ps1_g ], 471 | [ self._crop_global(f, top_lefts, ratio) for f in self.ps2_g ] 472 | ) 473 | # imsize = patches.size()[2:] 474 | # output_l = F.interpolate(output_l, imsize, mode='nearest') 475 | ps3_g2l = self._crop_global(self.ps3_g, top_lefts, ratio)[0] # only calculate loss on 1x 476 | ps3_g2l = F.interpolate(ps3_g2l, size=ps3_l.size()[2:], **self._up_kwargs) 477 | 478 | output = self.ensemble(ps3_l, ps3_g2l) 479 | # output = F.interpolate(output, imsize, mode='nearest') 480 | return output, self.output_g, output_l, self.mse(ps3_l, ps3_g2l) 481 | else: 482 | # train local2global model 483 | c2_g, c3_g, c4_g, c5_g = self.resnet_global.forward(image_global) 484 | # local patch cat into global 485 | output_g, ps0_g, ps1_g, ps2_g, ps3_g = self.fpn_global.forward(c2_g, c3_g, c4_g, c5_g, c2_ext=self.c2_l, c3_ext=self.c3_l, c4_ext=self.c4_l, c5_ext=self.c5_l, ps0_ext=self.ps0_l, ps1_ext=self.ps1_l, ps2_ext=self.ps2_l) 486 | # imsize = image_global.size()[2:] 487 | # output_g = F.interpolate(output_g, imsize, mode='nearest') 488 | self.clear_cache() 489 | return output_g, ps3_g -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | from torchvision import transforms 13 | from models.fpn_global_local_fmreg_ensemble import fpn 14 | from utils.metrics import ConfusionMatrix 15 | from PIL import Image 16 | 17 | # torch.cuda.synchronize() 18 | # torch.backends.cudnn.benchmark = True 19 | torch.backends.cudnn.deterministic = True 20 | 21 | transformer = transforms.Compose([ 22 | transforms.ToTensor(), 23 | ]) 24 | 25 | def resize(images, shape, label=False): 26 | ''' 27 | resize PIL images 28 | shape: (w, h) 29 | ''' 30 | resized = list(images) 31 | for i in range(len(images)): 32 | if label: 33 | resized[i] = images[i].resize(shape, Image.NEAREST) 34 | else: 35 | resized[i] = images[i].resize(shape, Image.BILINEAR) 36 | return resized 37 | 38 | def _mask_transform(mask): 39 | target = np.array(mask).astype('int32') 40 | target[target == 255] = -1 41 | # target -= 1 # in DeepGlobe: make class 0 (should be ignored) as -1 (to be ignored in cross_entropy) 42 | return target 43 | 44 | def masks_transform(masks, numpy=False): 45 | ''' 46 | masks: list of PIL images 47 | ''' 48 | targets = [] 49 | for m in masks: 50 | targets.append(_mask_transform(m)) 51 | targets = np.array(targets) 52 | if numpy: 53 | return targets 54 | else: 55 | return torch.from_numpy(targets).long().cuda() 56 | 57 | def images_transform(images): 58 | ''' 59 | images: list of PIL images 60 | ''' 61 | inputs = [] 62 | for img in images: 63 | inputs.append(transformer(img)) 64 | inputs = torch.stack(inputs, dim=0).cuda() 65 | return inputs 66 | 67 | def get_patch_info(shape, p_size): 68 | ''' 69 | shape: origin image size, (x, y) 70 | p_size: patch size (square) 71 | return: n_x, n_y, step_x, step_y 72 | ''' 73 | x = shape[0] 74 | y = shape[1] 75 | n = m = 1 76 | while x > n * p_size: 77 | n += 1 78 | while p_size - 1.0 * (x - p_size) / (n - 1) < 50: 79 | n += 1 80 | while y > m * p_size: 81 | m += 1 82 | while p_size - 1.0 * (y - p_size) / (m - 1) < 50: 83 | m 
+= 1 84 | return n, m, (x - p_size) * 1.0 / (n - 1), (y - p_size) * 1.0 / (m - 1) 85 | 86 | def global2patch(images, p_size): 87 | ''' 88 | image/label => patches 89 | p_size: patch size 90 | return: list of PIL patch images; coordinates: images->patches; ratios: (h, w) 91 | ''' 92 | patches = []; coordinates = []; templates = []; sizes = []; ratios = [(0, 0)] * len(images); patch_ones = np.ones(p_size) 93 | for i in range(len(images)): 94 | w, h = images[i].size 95 | size = (h, w) 96 | sizes.append(size) 97 | ratios[i] = (float(p_size[0]) / size[0], float(p_size[1]) / size[1]) 98 | template = np.zeros(size) 99 | n_x, n_y, step_x, step_y = get_patch_info(size, p_size[0]) 100 | patches.append([images[i]] * (n_x * n_y)) 101 | coordinates.append([(0, 0)] * (n_x * n_y)) 102 | for x in range(n_x): 103 | if x < n_x - 1: top = int(np.round(x * step_x)) 104 | else: top = size[0] - p_size[0] 105 | for y in range(n_y): 106 | if y < n_y - 1: left = int(np.round(y * step_y)) 107 | else: left = size[1] - p_size[1] 108 | template[top:top+p_size[0], left:left+p_size[1]] += patch_ones 109 | coordinates[i][x * n_y + y] = (1.0 * top / size[0], 1.0 * left / size[1]) 110 | patches[i][x * n_y + y] = transforms.functional.crop(images[i], top, left, p_size[0], p_size[1]) 111 | templates.append(Variable(torch.Tensor(template).expand(1, 1, -1, -1)).cuda()) 112 | return patches, coordinates, templates, sizes, ratios 113 | 114 | def patch2global(patches, n_class, sizes, coordinates, p_size): 115 | ''' 116 | predicted patches (after classify layer) => predictions 117 | return: list of np.array 118 | ''' 119 | predictions = [ np.zeros((n_class, size[0], size[1])) for size in sizes ] 120 | for i in range(len(sizes)): 121 | for j in range(len(coordinates[i])): 122 | top, left = coordinates[i][j] 123 | top = int(np.round(top * sizes[i][0])); left = int(np.round(left * sizes[i][1])) 124 | predictions[i][:, top: top + p_size[0], left: left + p_size[1]] += patches[i][j] 125 | return predictions 126 | 127 | def template_patch2global(size_g, size_p, n, step): 128 | template = np.zeros(size_g) 129 | coordinates = [(0, 0)] * n ** 2 130 | patch = np.ones(size_p) 131 | step = (size_g[0] - size_p[0]) // (n - 1) 132 | x = y = 0 133 | i = 0 134 | while x + size_p[0] <= size_g[0]: 135 | while y + size_p[1] <= size_g[1]: 136 | template[x:x+size_p[0], y:y+size_p[1]] += patch 137 | coordinates[i] = (1.0 * x / size_g[0], 1.0 * y / size_g[1]) 138 | i += 1 139 | y += step 140 | x += step 141 | y = 0 142 | return Variable(torch.Tensor(template).expand(1, 1, -1, -1)).cuda(), coordinates 143 | 144 | def one_hot_gaussian_blur(index, classes): 145 | ''' 146 | index: numpy array b, h, w 147 | classes: int 148 | ''' 149 | mask = np.transpose((np.arange(classes) == index[..., None]).astype(float), (0, 3, 1, 2)) 150 | b, c, _, _ = mask.shape 151 | for i in range(b): 152 | for j in range(c): 153 | mask[i][j] = cv2.GaussianBlur(mask[i][j], (0, 0), 8) 154 | 155 | return mask 156 | 157 | def collate(batch): 158 | image = [ b['image'] for b in batch ] # w, h 159 | label = [ b['label'] for b in batch ] 160 | id = [ b['id'] for b in batch ] 161 | return {'image': image, 'label': label, 'id': id} 162 | 163 | def collate_test(batch): 164 | image = [ b['image'] for b in batch ] # w, h 165 | id = [ b['id'] for b in batch ] 166 | return {'image': image, 'id': id} 167 | 168 | 169 | def create_model_load_weights(n_class, mode=1, evaluation=False, path_g=None, path_g2l=None, path_l2g=None): 170 | model = fpn(n_class) 171 | model = nn.DataParallel(model) 172 | 
model = model.cuda() 173 | 174 | if (mode == 2 and not evaluation) or (mode == 1 and evaluation): 175 | # load fixed basic global branch 176 | partial = torch.load(path_g) 177 | state = model.state_dict() 178 | # 1. filter out unnecessary keys 179 | pretrained_dict = {k: v for k, v in partial.items() if k in state and "local" not in k} 180 | # 2. overwrite entries in the existing state dict 181 | state.update(pretrained_dict) 182 | # 3. load the new state dict 183 | model.load_state_dict(state) 184 | 185 | if (mode == 3 and not evaluation) or (mode == 2 and evaluation): 186 | partial = torch.load(path_g2l) 187 | state = model.state_dict() 188 | # 1. filter out unnecessary keys 189 | pretrained_dict = {k: v for k, v in partial.items() if k in state}# and "global" not in k} 190 | # 2. overwrite entries in the existing state dict 191 | state.update(pretrained_dict) 192 | # 3. load the new state dict 193 | model.load_state_dict(state) 194 | 195 | global_fixed = None 196 | if mode == 3: 197 | # load fixed basic global branch 198 | global_fixed = fpn(n_class) 199 | global_fixed = nn.DataParallel(global_fixed) 200 | global_fixed = global_fixed.cuda() 201 | partial = torch.load(path_g) 202 | state = global_fixed.state_dict() 203 | # 1. filter out unnecessary keys 204 | pretrained_dict = {k: v for k, v in partial.items() if k in state and "local" not in k} 205 | # 2. overwrite entries in the existing state dict 206 | state.update(pretrained_dict) 207 | # 3. load the new state dict 208 | global_fixed.load_state_dict(state) 209 | global_fixed.eval() 210 | 211 | if mode == 3 and evaluation: 212 | partial = torch.load(path_l2g) 213 | state = model.state_dict() 214 | # 1. filter out unnecessary keys 215 | pretrained_dict = {k: v for k, v in partial.items() if k in state}# and "global" not in k} 216 | # 2. overwrite entries in the existing state dict 217 | state.update(pretrained_dict) 218 | # 3. 
load the new state dict 219 | model.load_state_dict(state) 220 | 221 | if mode == 1 or mode == 3: 222 | model.module.resnet_local.eval() 223 | model.module.fpn_local.eval() 224 | else: 225 | model.module.resnet_global.eval() 226 | model.module.fpn_global.eval() 227 | 228 | return model, global_fixed 229 | 230 | 231 | def get_optimizer(model, mode=1, learning_rate=2e-5): 232 | if mode == 1 or mode == 3: 233 | # train global 234 | optimizer = torch.optim.Adam([ 235 | {'params': model.module.resnet_global.parameters(), 'lr': learning_rate}, 236 | {'params': model.module.resnet_local.parameters(), 'lr': 0}, 237 | {'params': model.module.fpn_global.parameters(), 'lr': learning_rate}, 238 | {'params': model.module.fpn_local.parameters(), 'lr': 0}, 239 | {'params': model.module.ensemble_conv.parameters(), 'lr': learning_rate}, 240 | ], weight_decay=5e-4) 241 | else: 242 | # train local 243 | optimizer = torch.optim.Adam([ 244 | {'params': model.module.resnet_global.parameters(), 'lr': 0}, 245 | {'params': model.module.resnet_local.parameters(), 'lr': learning_rate}, 246 | {'params': model.module.fpn_global.parameters(), 'lr': 0}, 247 | {'params': model.module.fpn_local.parameters(), 'lr': learning_rate}, 248 | {'params': model.module.ensemble_conv.parameters(), 'lr': learning_rate}, 249 | ], weight_decay=5e-4) 250 | return optimizer 251 | 252 | 253 | class Trainer(object): 254 | def __init__(self, criterion, optimizer, n_class, size_g, size_p, sub_batch_size=6, mode=1, lamb_fmreg=0.15): 255 | self.criterion = criterion 256 | self.optimizer = optimizer 257 | self.metrics_global = ConfusionMatrix(n_class) 258 | self.metrics_local = ConfusionMatrix(n_class) 259 | self.metrics = ConfusionMatrix(n_class) 260 | self.n_class = n_class 261 | self.size_g = size_g 262 | self.size_p = size_p 263 | self.sub_batch_size = sub_batch_size 264 | self.mode = mode 265 | self.lamb_fmreg = lamb_fmreg 266 | 267 | def set_train(self, model): 268 | model.module.ensemble_conv.train() 269 | if self.mode == 1 or self.mode == 3: 270 | model.module.resnet_global.train() 271 | model.module.fpn_global.train() 272 | else: 273 | model.module.resnet_local.train() 274 | model.module.fpn_local.train() 275 | 276 | def get_scores(self): 277 | score_train = self.metrics.get_scores() 278 | score_train_local = self.metrics_local.get_scores() 279 | score_train_global = self.metrics_global.get_scores() 280 | return score_train, score_train_global, score_train_local 281 | 282 | def reset_metrics(self): 283 | self.metrics.reset() 284 | self.metrics_local.reset() 285 | self.metrics_global.reset() 286 | 287 | def train(self, sample, model, global_fixed): 288 | images, labels = sample['image'], sample['label'] # PIL images 289 | labels_npy = masks_transform(labels, numpy=True) # label of origin size in numpy 290 | 291 | images_glb = resize(images, self.size_g) # list of resized PIL images 292 | images_glb = images_transform(images_glb) 293 | labels_glb = resize(labels, (self.size_g[0] // 4, self.size_g[1] // 4), label=True) # FPN down 1/4, for loss 294 | labels_glb = masks_transform(labels_glb) 295 | 296 | if self.mode == 2 or self.mode == 3: 297 | patches, coordinates, templates, sizes, ratios = global2patch(images, self.size_p) 298 | label_patches, _, _, _, _ = global2patch(labels, self.size_p) 299 | predicted_patches = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ] 300 | predicted_ensembles = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i 
in range(len(images)) ] 301 | outputs_global = [ None for i in range(len(images)) ] 302 | 303 | if self.mode == 1: 304 | # training with only (resized) global image ######################################### 305 | outputs_global, _ = model.forward(images_glb, None, None, None) 306 | loss = self.criterion(outputs_global, labels_glb) 307 | loss.backward() 308 | self.optimizer.step() 309 | self.optimizer.zero_grad() 310 | ############################################## 311 | 312 | if self.mode == 2: 313 | # training with patches ########################################### 314 | for i in range(len(images)): 315 | j = 0 316 | while j < len(coordinates[i]): 317 | patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w 318 | label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True)) # down 1/4 for loss 319 | 320 | output_ensembles, output_global, output_patches, fmreg_l2 = model.forward(images_glb[i:i+1], patches_var, coordinates[i][j : j+self.sub_batch_size], ratios[i], mode=self.mode, n_patch=len(coordinates[i])) 321 | loss = self.criterion(output_patches, label_patches_var) + self.criterion(output_ensembles, label_patches_var) + self.lamb_fmreg * fmreg_l2 322 | loss.backward() 323 | 324 | # patch predictions 325 | predicted_patches[i][j:j+output_patches.size()[0]] = F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy() 326 | predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy() 327 | j += self.sub_batch_size 328 | outputs_global[i] = output_global 329 | outputs_global = torch.cat(outputs_global, dim=0) 330 | 331 | self.optimizer.step() 332 | self.optimizer.zero_grad() 333 | ##################################################################################### 334 | 335 | if self.mode == 3: 336 | # train global with help from patches ################################################## 337 | # go through local patches to collect feature maps 338 | # collect predictions from patches 339 | for i in range(len(images)): 340 | j = 0 341 | while j < len(coordinates[i]): 342 | patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w 343 | fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, ratios[i], coordinates[i], [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=templates[i], n_patch_all=len(coordinates[i])) 344 | predicted_patches[i][j:j+output_patches.size()[0]] = F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy() 345 | j += self.sub_batch_size 346 | # train on global image 347 | outputs_global, fm_global = model.forward(images_glb, None, None, None, mode=self.mode) 348 | loss = self.criterion(outputs_global, labels_glb) 349 | loss.backward(retain_graph=True) 350 | # fmreg loss 351 | # generate ensembles & calc loss 352 | for i in range(len(images)): 353 | j = 0 354 | while j < len(coordinates[i]): 355 | label_patches_var = masks_transform(resize(label_patches[i][j : j+self.sub_batch_size], (self.size_p[0] // 4, self.size_p[1] // 4), label=True)) 356 | fl = fm_patches[i][j : j+self.sub_batch_size].cuda() 357 | fg = model.module._crop_global(fm_global[i:i+1], coordinates[i][j:j+self.sub_batch_size], ratios[i])[0] 358 | fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear') 359 | output_ensembles = model.module.ensemble(fl, fg) 360 | loss = 
self.criterion(output_ensembles, label_patches_var)# + 0.15 * mse(fl, fg) 361 | if i == len(images) - 1 and j + self.sub_batch_size >= len(coordinates[i]): 362 | loss.backward() 363 | else: 364 | loss.backward(retain_graph=True) 365 | 366 | # ensemble predictions 367 | predicted_ensembles[i][j:j+output_ensembles.size()[0]] = F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy() 368 | j += self.sub_batch_size 369 | self.optimizer.step() 370 | self.optimizer.zero_grad() 371 | 372 | # global predictions ########################### 373 | outputs_global = outputs_global.cpu() 374 | predictions_global = [F.interpolate(outputs_global[i:i+1], images[i].size[::-1], mode='nearest').argmax(1).detach().numpy() for i in range(len(images))] 375 | self.metrics_global.update(labels_npy, predictions_global) 376 | 377 | if self.mode == 2 or self.mode == 3: 378 | # patch predictions ########################### 379 | scores_local = np.array(patch2global(predicted_patches, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps) 380 | predictions_local = scores_local.argmax(1) # b, h, w 381 | self.metrics_local.update(labels_npy, predictions_local) 382 | ################################################### 383 | # combined/ensemble predictions ########################### 384 | scores = np.array(patch2global(predicted_ensembles, self.n_class, sizes, coordinates, self.size_p)) # merge softmax scores from patches (overlaps) 385 | predictions = scores.argmax(1) # b, h, w 386 | self.metrics.update(labels_npy, predictions) 387 | return loss 388 | 389 | 390 | class Evaluator(object): 391 | def __init__(self, n_class, size_g, size_p, sub_batch_size=6, mode=1, test=False): 392 | self.metrics_global = ConfusionMatrix(n_class) 393 | self.metrics_local = ConfusionMatrix(n_class) 394 | self.metrics = ConfusionMatrix(n_class) 395 | self.n_class = n_class 396 | self.size_g = size_g 397 | self.size_p = size_p 398 | self.sub_batch_size = sub_batch_size 399 | self.mode = mode 400 | self.test = test 401 | 402 | if test: 403 | self.flip_range = [False, True] 404 | self.rotate_range = [0, 1, 2, 3] 405 | else: 406 | self.flip_range = [False] 407 | self.rotate_range = [0] 408 | 409 | def get_scores(self): 410 | score_train = self.metrics.get_scores() 411 | score_train_local = self.metrics_local.get_scores() 412 | score_train_global = self.metrics_global.get_scores() 413 | return score_train, score_train_global, score_train_local 414 | 415 | def reset_metrics(self): 416 | self.metrics.reset() 417 | self.metrics_local.reset() 418 | self.metrics_global.reset() 419 | 420 | def eval_test(self, sample, model, global_fixed): 421 | with torch.no_grad(): 422 | images = sample['image'] 423 | if not self.test: 424 | labels = sample['label'] # PIL images 425 | labels_npy = masks_transform(labels, numpy=True) 426 | 427 | images_global = resize(images, self.size_g) 428 | outputs_global = np.zeros((len(images), self.n_class, self.size_g[0] // 4, self.size_g[1] // 4)) 429 | if self.mode == 2 or self.mode == 3: 430 | images_local = [ image.copy() for image in images ] 431 | scores_local = [ np.zeros((1, self.n_class, images[i].size[1], images[i].size[0])) for i in range(len(images)) ] 432 | scores = [ np.zeros((1, self.n_class, images[i].size[1], images[i].size[0])) for i in range(len(images)) ] 433 | 434 | for flip in self.flip_range: 435 | if flip: 436 | # we already rotated images for 270' 437 | for b in range(len(images)): 438 | images_global[b] = 
transforms.functional.rotate(images_global[b], 90) # rotate back! 439 | images_global[b] = transforms.functional.hflip(images_global[b]) 440 | if self.mode == 2 or self.mode == 3: 441 | images_local[b] = transforms.functional.rotate(images_local[b], 90) # rotate back! 442 | images_local[b] = transforms.functional.hflip(images_local[b]) 443 | for angle in self.rotate_range: 444 | if angle > 0: 445 | for b in range(len(images)): 446 | images_global[b] = transforms.functional.rotate(images_global[b], 90) 447 | if self.mode == 2 or self.mode == 3: 448 | images_local[b] = transforms.functional.rotate(images_local[b], 90) 449 | 450 | # prepare global images onto cuda 451 | images_glb = images_transform(images_global) # b, c, h, w 452 | 453 | if self.mode == 2 or self.mode == 3: 454 | patches, coordinates, templates, sizes, ratios = global2patch(images, self.size_p) 455 | predicted_patches = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ] 456 | predicted_ensembles = [ np.zeros((len(coordinates[i]), self.n_class, self.size_p[0], self.size_p[1])) for i in range(len(images)) ] 457 | 458 | if self.mode == 1: 459 | # eval with only resized global image ########################## 460 | if flip: 461 | outputs_global += np.flip(np.rot90(model.forward(images_glb, None, None, None)[0].data.cpu().numpy(), k=angle, axes=(3, 2)), axis=3) 462 | else: 463 | outputs_global += np.rot90(model.forward(images_glb, None, None, None)[0].data.cpu().numpy(), k=angle, axes=(3, 2)) 464 | ################################################################ 465 | 466 | if self.mode == 2: 467 | # eval with patches ########################################### 468 | for i in range(len(images)): 469 | j = 0 470 | while j < len(coordinates[i]): 471 | patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w 472 | output_ensembles, output_global, output_patches, _ = model.forward(images_glb[i:i+1], patches_var, coordinates[i][j : j+self.sub_batch_size], ratios[i], mode=self.mode, n_patch=len(coordinates[i])) 473 | 474 | # patch predictions 475 | predicted_patches[i][j:j+output_patches.size()[0]] += F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy() 476 | predicted_ensembles[i][j:j+output_ensembles.size()[0]] += F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy() 477 | j += patches_var.size()[0] 478 | if flip: 479 | outputs_global[i] += np.flip(np.rot90(output_global[0].data.cpu().numpy(), k=angle, axes=(2, 1)), axis=2) 480 | scores_local[i] += np.flip(np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps) 481 | scores[i] += np.flip(np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)), axis=3) # merge softmax scores from patches (overlaps) 482 | else: 483 | outputs_global[i] += np.rot90(output_global[0].data.cpu().numpy(), k=angle, axes=(2, 1)) 484 | scores_local[i] += np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps) 485 | scores[i] += np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)) # merge softmax scores from patches 
(overlaps) 486 | ############################################################### 487 | 488 | if self.mode == 3: 489 | # eval global with help from patches ################################################## 490 | # go through local patches to collect feature maps 491 | # collect predictions from patches 492 | for i in range(len(images)): 493 | j = 0 494 | while j < len(coordinates[i]): 495 | patches_var = images_transform(patches[i][j : j+self.sub_batch_size]) # b, c, h, w 496 | fm_patches, output_patches = model.module.collect_local_fm(images_glb[i:i+1], patches_var, ratios[i], coordinates[i], [j, j+self.sub_batch_size], len(images), global_model=global_fixed, template=templates[i], n_patch_all=len(coordinates[i])) 497 | predicted_patches[i][j:j+output_patches.size()[0]] += F.interpolate(output_patches, size=self.size_p, mode='nearest').data.cpu().numpy() 498 | j += self.sub_batch_size 499 | # go through global image 500 | tmp, fm_global = model.forward(images_glb, None, None, None, mode=self.mode) 501 | if flip: 502 | outputs_global += np.flip(np.rot90(tmp.data.cpu().numpy(), k=angle, axes=(3, 2)), axis=3) 503 | else: 504 | outputs_global += np.rot90(tmp.data.cpu().numpy(), k=angle, axes=(3, 2)) 505 | # generate ensembles 506 | for i in range(len(images)): 507 | j = 0 508 | while j < len(coordinates[i]): 509 | fl = fm_patches[i][j : j+self.sub_batch_size].cuda() 510 | fg = model.module._crop_global(fm_global[i:i+1], coordinates[i][j:j+self.sub_batch_size], ratios[i])[0] 511 | fg = F.interpolate(fg, size=fl.size()[2:], mode='bilinear') 512 | output_ensembles = model.module.ensemble(fl, fg) # include cordinates 513 | 514 | # ensemble predictions 515 | predicted_ensembles[i][j:j+output_ensembles.size()[0]] += F.interpolate(output_ensembles, size=self.size_p, mode='nearest').data.cpu().numpy() 516 | j += self.sub_batch_size 517 | if flip: 518 | scores_local[i] += np.flip(np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)), axis=3)[0] # merge softmax scores from patches (overlaps) 519 | scores[i] += np.flip(np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)), axis=3)[0] # merge softmax scores from patches (overlaps) 520 | else: 521 | scores_local[i] += np.rot90(np.array(patch2global(predicted_patches[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps) 522 | scores[i] += np.rot90(np.array(patch2global(predicted_ensembles[i:i+1], self.n_class, sizes[i:i+1], coordinates[i:i+1], self.size_p)), k=angle, axes=(3, 2)) # merge softmax scores from patches (overlaps) 523 | ################################################### 524 | 525 | # global predictions ########################### 526 | outputs_global = torch.Tensor(outputs_global) 527 | predictions_global = [F.interpolate(outputs_global[i:i+1], images[i].size[::-1], mode='nearest').argmax(1).detach().numpy()[0] for i in range(len(images))] 528 | if not self.test: 529 | self.metrics_global.update(labels_npy, predictions_global) 530 | 531 | if self.mode == 2 or self.mode == 3: 532 | # patch predictions ########################### 533 | predictions_local = [ score.argmax(1)[0] for score in scores_local ] 534 | if not self.test: 535 | self.metrics_local.update(labels_npy, predictions_local) 536 | ################################################### 537 | # combined/ensemble predictions 
########################### 538 | predictions = [ score.argmax(1)[0] for score in scores ] 539 | if not self.test: 540 | self.metrics.update(labels_npy, predictions) 541 | return predictions, predictions_global, predictions_local 542 | else: 543 | return None, predictions_global, None 544 | --------------------------------------------------------------------------------
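A minimal sketch of the patch round trip implemented by the helpers in helper.py above: one large image is cut into overlapping 508x508 patches with global2patch, each patch is scored (here by a dummy all-zeros scorer standing in for the trained GLNet local branch), and the per-patch class scores are merged back to full resolution with patch2global. The 2448x2448 image size, n_class = 7, and the dummy scorer are illustrative assumptions; running it requires a CUDA device (global2patch moves its coverage templates onto the GPU) and the repository root on the import path.

#!/usr/bin/env python
# coding: utf-8
# Sketch only: dummy scorer instead of the trained local FPN branch.
import numpy as np
from PIL import Image
from helper import global2patch, patch2global

n_class = 7
p_size = (508, 508)

# Hypothetical 2448x2448 RGB image (the DeepGlobe tile size).
image = Image.fromarray(np.zeros((2448, 2448, 3), dtype=np.uint8))

# Cut into overlapping patches; also returns per-pixel coverage templates,
# original sizes, and patch-to-image size ratios.
patches, coordinates, templates, sizes, ratios = global2patch([image], p_size)
print(len(patches[0]), "patches per image, ratio", ratios[0])  # 36 patches for a 2448px tile at 508px

# Dummy per-patch class scores in place of the local branch's softmax output.
scores_per_patch = np.zeros((len(patches[0]), n_class, p_size[0], p_size[1]))

# Sum overlapping patch scores back onto the full-resolution canvas and argmax.
merged = patch2global([scores_per_patch], n_class, sizes, coordinates, p_size)
prediction = merged[0].argmax(0)  # (2448, 2448) label map
print(prediction.shape)

Note that patch2global simply accumulates scores where patches overlap and the argmax is taken over the summed scores; this mirrors how Trainer.train and Evaluator.eval_test consume its output for the local and ensemble predictions.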