├── BaseTester.py ├── BaseTrainer.py ├── LICENSE ├── Mem_monitor.py ├── README.md ├── checkpoint-best.pth ├── configs ├── __init__.py ├── config.cfg └── config.py ├── data ├── __init__.py └── dataset.py ├── environment.yml ├── gpu_utilization_test.py ├── gradient_debug.py ├── inference_compare.py ├── launcher.py ├── metrics.py ├── models ├── BaseModel.py ├── EDANet.py ├── ENet.py ├── ERFNet.py ├── FCN.py ├── MyNetworks │ ├── ESFNet.py │ └── layers.py ├── SegNet.py ├── UNet.py └── __init__.py ├── predict.py ├── train.py └── utils ├── __init__.py ├── dataset.py ├── unpatchy.py └── util.py /BaseTester.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | import torch.optim as optim 7 | import torch.nn as nn 8 | import cv2 9 | import glob 10 | import torchvision.transforms.functional as TF 11 | import torchvision.transforms as transforms 12 | from tqdm import tqdm 13 | from metrics import Accuracy, MIoU 14 | from utils.util import AverageMeter, ensure_dir 15 | from PIL import Image 16 | 17 | 18 | 19 | class BaseTester(object): 20 | def __init__(self, 21 | model, 22 | config, 23 | args, 24 | test_data_loader, 25 | begin_time, 26 | resume_file, 27 | loss_weight, 28 | ): 29 | 30 | # for general 31 | self.config = config 32 | self.args = args 33 | self.device = torch.device('cpu') if self.args.gpu == -1 else torch.device('cuda:{}'.format(self.args.gpu)) 34 | #self.do_predict = do_predict 35 | 36 | # for train 37 | #self.visdom = visdom 38 | self.model = model.to(self.device) 39 | self.loss_weight = loss_weight.to(self.device) 40 | self.loss = self._loss(loss_function= self.config.loss).to(self.device) 41 | self.optimizer = self._optimizer(lr_algorithm=self.config.lr_algorithm) 42 | self.lr_scheduler = self._lr_scheduler() 43 | 44 | # for time 45 | self.begin_time = begin_time 46 | 47 | # for data 48 | self.test_data_loader = test_data_loader 49 | 50 | # for resume/save path 51 | self.history = { 52 | 'eval': { 53 | 'loss': [], 54 | 'acc': [], 55 | 'miou': [], 56 | 'time': [], 57 | }, 58 | } 59 | self.test_log_path = os.path.join(self.args.output, 'test', 'log', self.model.name, self.begin_time) 60 | self.predict_path = os.path.join(self.args.output, 'test', 'predict', self.model.name, self.begin_time) 61 | # here begin_time is the same with the time used in BaseTrainer.py 62 | # loading args.weight or the checkpoint-best.pth 63 | self.resume_ckpt_path = resume_file if resume_file is not None else \ 64 | os.path.join(self.config.save_dir, self.model.name, self.begin_time, 'checkpoint-best.pth') 65 | 66 | ensure_dir(self.test_log_path) 67 | ensure_dir(self.predict_path) 68 | 69 | def _optimizer(self, lr_algorithm): 70 | 71 | if lr_algorithm == 'adam': 72 | optimizer = optim.Adam(self.model.parameters(), 73 | lr=self.config.init_lr, 74 | betas=(0.9, 0.999), 75 | eps=1e-08, 76 | weight_decay=self.config.weight_decay, 77 | amsgrad=False) 78 | return optimizer 79 | if lr_algorithm == 'sgd': 80 | optimizer = optim.SGD(self.model.parameters(), 81 | lr=self.config.init_lr, 82 | momentum=self.config.momentum, 83 | dampening=0, 84 | weight_decay=self.config.weight_decay, 85 | nesterov=True) 86 | return optimizer 87 | 88 | def _loss(self, loss_function): 89 | """ 90 | loss weight, ignore_index 91 | :param loss_function: bce_loss / cross_entropy 92 | :return: 93 | """ 94 | if loss_function == 'bceloss': 95 | loss = nn.BCEWithLogitsLoss(weight=self.loss_weight) 96 | return loss 97 | 98 | if 
loss_function == 'crossentropy': 99 | loss = nn.CrossEntropyLoss(weight=self.loss_weight) 100 | return loss 101 | 102 | def _lr_scheduler(self): 103 | 104 | lambda1 = lambda epoch: pow((1-((epoch-1)/self.config.epochs)), 0.9) 105 | lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda1) 106 | return lr_scheduler 107 | 108 | def eval_and_predict(self): 109 | 110 | self._resume_ckpt() 111 | 112 | self.model.eval() 113 | 114 | #predictions = [] 115 | #filenames = [] 116 | predict_time = AverageMeter() 117 | batch_time = AverageMeter() 118 | data_time = AverageMeter() 119 | ave_total_loss = AverageMeter() 120 | ave_acc = AverageMeter() 121 | ave_iou = AverageMeter() 122 | 123 | with torch.no_grad(): 124 | tic = time.time() 125 | for steps, (data, target, filename) in enumerate(self.test_data_loader,start=1): 126 | 127 | # data 128 | data = data.to(self.device, non_blocking=True) 129 | target = target.to(self.device, non_blocking=True) 130 | data_time.update(time.time()-tic) 131 | 132 | # output, loss, and metrics 133 | pre_tic = time.time() 134 | logits = self.model(data) 135 | self._save_pred(logits, filename) 136 | predict_time.update(time.time()-pre_tic) 137 | 138 | loss = self.loss(logits, target) 139 | acc = Accuracy(logits, target) 140 | miou = MIoU(logits, target, self.config.nb_classes) 141 | 142 | # update ave loss and metrics 143 | batch_time.update(time.time()-tic) 144 | tic = time.time() 145 | 146 | ave_total_loss.update(loss.data.item()) 147 | ave_acc.update(acc) 148 | ave_iou.update(miou) 149 | 150 | # display evaluation result at the end 151 | print('Evaluation phase !\n' 152 | 'Time: {:.2f}, Data: {:.2f},\n' 153 | 'MIoU: {:6.4f}, Accuracy: {:6.4f}, Loss: {:.6f}' 154 | .format(batch_time.average(), data_time.average(), 155 | ave_iou.average(), ave_acc.average(), ave_total_loss.average())) 156 | #print('Saving Predict Map ... ...') 157 | #self._save_pred(predictions, filenames) 158 | print('Prediction Phase !\n' 159 | 'Total Time cost: {}s\n' 160 | 'Average Time cost per batch: {}s!' 
161 | .format(predict_time._get_sum(), predict_time.average())) 162 | 163 | 164 | self.history['eval']['loss'].append(ave_total_loss.average()) 165 | self.history['eval']['acc'].append(ave_acc.average()) 166 | self.history['eval']['miou'].append(ave_iou.average()) 167 | self.history['eval']['time'].append(predict_time.average()) 168 | 169 | #TODO 170 | print(" + Saved history of evaluation phase !") 171 | hist_path = os.path.join(self.test_log_path, "history1.txt") 172 | with open(hist_path, 'w') as f: 173 | f.write(str(self.history)) 174 | 175 | def _save_pred(self, predictions, filenames): 176 | """ 177 | save predictions after evaluation phase 178 | :param predictions: batch of model logits; the class map is taken via argmax over the class dimension 179 | :param filenames: list of filenames corresponding to the predictions 180 | :return: None 181 | """ 182 | 183 | for index, map in enumerate(predictions): 184 | 185 | map = torch.argmax(map, dim=0) 186 | 187 | map = map * 255 188 | map = np.asarray(map.cpu(), dtype=np.uint8) 189 | map = Image.fromarray(map) 190 | # e.g. filename '/0.1.png' splits into ['0', '1', 'png']; keep '0.1' as the stem 191 | filename = filenames[index].split('/')[-1].split('.') 192 | save_filename = filename[0]+'.'+filename[1] 193 | save_path = os.path.join(self.predict_path, save_filename+'.png') 194 | 195 | map.save(save_path) 196 | 197 | # pred is tensor --> numpy.ndarray saved as single-channel --> save 198 | # the result is a single-channel mask, so there is no channel dimension to handle 199 | 200 | 201 | def _resume_ckpt(self): 202 | 203 | print(" + Loading ckpt path : {} ...".format(self.resume_ckpt_path)) 204 | checkpoint = torch.load(self.resume_ckpt_path) 205 | 206 | self.model.load_state_dict(checkpoint['state_dict']) 207 | print(" + Model State Loaded ! :D ") 208 | self.optimizer.load_state_dict(checkpoint['optimizer']) 209 | print(" + Optimizer State Loaded ! :D ") 210 | print(" + Checkpoint file: '{}' , Loaded ! \n" 211 | " + Prepare to test ! ! !"
212 | .format(self.resume_ckpt_path)) 213 | 214 | 215 | def _untrain_data_transform(self, data): 216 | 217 | rgb_mean = (0.4353, 0.4452, 0.4131) 218 | rgb_std = (0.2044, 0.1924, 0.2013) 219 | 220 | data = TF.resize(data, size=self.config.input_size) 221 | data = TF.to_tensor(data) 222 | data = TF.normalize(data, mean=rgb_mean, std=rgb_std) 223 | 224 | return data 225 | 226 | # Using for predicting only 227 | def prediction(self, data_loader_for_predict): 228 | 229 | self._resume_ckpt() 230 | self.model.eval() 231 | 232 | predict_time = AverageMeter() 233 | batch_time = AverageMeter() 234 | data_time = AverageMeter() 235 | 236 | with torch.no_grad(): 237 | tic = time.time() 238 | for steps, (data, target, filenames) in enumerate(data_loader_for_predict, start=1): 239 | 240 | # data 241 | data = data.to(self.device, non_blocking=True) 242 | data_time.update(time.time() - tic) 243 | 244 | pre_tic = time.time() 245 | logits = self.model(data) 246 | predict_time.update(time.time() - pre_tic) 247 | self._save_pred(logits, filenames) 248 | 249 | batch_time.update(time.time() - tic) 250 | tic = time.time() 251 | 252 | print("Predicting and Saving Done!\n" 253 | "Total Time: {:.2f}\n" 254 | "Data Time: {:.2f}\n" 255 | "Pre Time: {:.2f}" 256 | .format(batch_time._get_sum(), data_time._get_sum(), predict_time._get_sum())) 257 | -------------------------------------------------------------------------------- /BaseTrainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import time 8 | from utils.util import AverageMeter, ensure_dir 9 | from metrics import Accuracy, MIoU 10 | #from visdom import Visdom 11 | 12 | 13 | class BaseTrainer(object): 14 | 15 | def __init__(self, 16 | model, 17 | config, 18 | args, 19 | train_data_loader, 20 | valid_data_loader, 21 | 22 | visdom, 23 | begin_time, 24 | # TODO resume_file 25 | resume_file=None, 26 | loss_weight=None): 27 | 28 | print(" + Training Start ... 
...") 29 | # for general 30 | self.config = config 31 | self.args = args 32 | self.device = (self._device(self.args.gpu)) 33 | self.model = model.to(self.device) 34 | self.train_data_loader = train_data_loader 35 | self.valid_data_loder = valid_data_loader 36 | 37 | # for time 38 | self.begin_time = begin_time # part of ckpt name 39 | self.save_period = self.config.save_period # for save ckpt 40 | self.dis_period = self.config.dis_period # for display 41 | 42 | # for save directory : model, best model, log{train_log, validation_log} per epoch 43 | ''' 44 | Directory: 45 | #root | -- train 46 | | -- valid 47 | | -- test 48 | | -- save | -- {model.name} | -- datetime | -- ckpt-epoch{}.pth.format(epoch) 49 | | | -- best_model.pth 50 | | 51 | | -- log | -- {model.name} | -- datetime | -- history1.txt 52 | | -- test| -- log 53 | | -- predict 54 | ''' 55 | # /home/UserGroup/UserName/Building_Detecion/cropped_aerial_torch/save/model.name/time 56 | # TODO model name setting 57 | self.checkpoint_dir = os.path.join(self.args.output, self.model.name, self.begin_time) 58 | # /home/UserGroup/UserName/Building_Detection/cropped_aerial_torch/save/log/model.name/time 59 | self.log_dir = os.path.join(self.args.output, 'log', self.model.name, self.begin_time) 60 | ensure_dir(self.checkpoint_dir) 61 | ensure_dir(self.log_dir) 62 | 63 | self.history = { 64 | 'train': { 65 | 'epoch': [], 66 | 'loss': [], 67 | 'acc': [], 68 | 'miou': [], 69 | }, 70 | 'valid': { 71 | 'epoch': [], 72 | 'loss': [], 73 | 'acc': [], 74 | 'miou': [], 75 | } 76 | } 77 | # for resume update curve 78 | self.viz_winname = { 79 | 'miou': [], 80 | 'loss': [], 81 | 'acc': [], 82 | 'learning_rate': [], 83 | } 84 | # for optimize 85 | self.loss_weight = loss_weight.to(self.device) 86 | self.loss = self._loss(loss_function=self.config.loss).to(self.device) 87 | self.optimizer = self._optimizer(lr_algorithm=self.config.lr_algorithm) 88 | self.lr_scheduler = self._lr_scheduler() 89 | self.weight_init_algorithm = self.config.init_algorithm 90 | self.current_lr = self.config.init_lr 91 | 92 | print(self.optimizer) 93 | print(self.loss) 94 | 95 | # for train 96 | self.start_epoch = 1 97 | self.early_stop = self.config.early_stop # early stop steps 98 | self.monitor_mode = self.config.monitor.split('/')[0] 99 | self.monitor_metric = self.config.monitor.split('/')[1] 100 | self.monitor_best = 0 101 | self.best_epoch = -1 102 | 103 | # resume file: the confirmed ckpt file. 
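# (the dict written by _save_ckpt stores 'epoch', 'arch', 'history', 'state_dict', 'optimizer', 'monitor_best' and 'windows_name'; _resume_ckpt below restores them)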
104 | 105 | self.resume_file = resume_file 106 | self.resume_ = True if resume_file else False 107 | 108 | # monitor init 109 | if self.monitor_mode != 'off': 110 | assert self.monitor_mode in ['min', 'max'] 111 | self.monitor_best = math.inf if self.monitor_mode == 'min' else -math.inf 112 | 113 | if resume_file is not None: 114 | self._resume_ckpt(resume_file=resume_file) 115 | 116 | # TODO visualization 117 | # value needed to visualize: loss, metrics[acc, miou], learning_rate 118 | self.visdom = visdom 119 | 120 | def _device(self, gpu): 121 | 122 | if gpu == -1: 123 | device = torch.device('cpu') 124 | return device 125 | else: 126 | device = torch.device('cuda:{}'.format(gpu)) 127 | return device 128 | 129 | def _optimizer(self, lr_algorithm): 130 | 131 | if lr_algorithm == 'adam': 132 | optimizer = optim.Adam(self.model.parameters(), 133 | lr=self.config.init_lr, 134 | betas=(0.9, 0.999), 135 | eps=1e-08, 136 | weight_decay=self.config.weight_decay, 137 | amsgrad=False) 138 | return optimizer 139 | if lr_algorithm == 'sgd': 140 | optimizer = optim.SGD(self.model.parameters(), 141 | lr=self.config.init_lr, 142 | momentum=self.config.momentum, 143 | dampening=0, 144 | weight_decay=self.config.weight_decay, 145 | nesterov=True) 146 | ''' 147 | optimizer = optim.RMSprop(self.model.parameters(), 148 | lr = self.config.init_lr, 149 | weight_decay=self.config.weight_decay, 150 | momentum=self.config.momentum 151 | ) 152 | ''' 153 | return optimizer 154 | 155 | def _loss(self, loss_function): 156 | """ 157 | loss weight, ignore_index 158 | :param loss_function: bce_loss / cross_entropy 159 | :return: 160 | """ 161 | if loss_function == 'bceloss': 162 | loss = nn.BCEWithLogitsLoss(weight=self.loss_weight) 163 | return loss 164 | 165 | if loss_function == 'crossentropy': 166 | loss = nn.CrossEntropyLoss(weight=self.loss_weight) 167 | return loss 168 | 169 | def _lr_scheduler(self): 170 | 171 | lambda1 = lambda epoch: pow((1-((epoch-1)/self.config.epochs)), 0.9) 172 | lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda1) 173 | return lr_scheduler 174 | 175 | def _weight_init(self, m): 176 | 177 | # no bias use 178 | classname = m.__class__.__name__ 179 | if classname.find('Conv') != -1: 180 | if self.weight_init_algorithm == 'kaiming': 181 | init.kaiming_normal_(m.weight.data) 182 | else: 183 | init.xavier_normal_(m.weight.data) 184 | elif classname.find('BatchNorm') != -1: 185 | m.weight.data.normal_(1.0, 0.02) 186 | m.bias.data.fill_(0) 187 | 188 | def train(self): 189 | 190 | if self.resume_ == False: 191 | # create panes for training phase for loss metrics learning_rate 192 | print(" + Visualization init ... ...") 193 | loss_window = self.visdom.line( 194 | X = torch.stack((torch.ones(1),torch.ones(1)),1), 195 | Y = torch.stack((torch.ones(1),torch.ones(1)),1), 196 | opts= dict(title='train_val_loss', 197 | # for different size panes, the result of download is the same! 
198 | showlegend=True, 199 | legend=['training_loss', 'valid_loss'], 200 | xtype='linear', 201 | label='epoch', 202 | xtickmin=0, 203 | xtick=True, 204 | xtickstep=10, 205 | ytype='linear', 206 | ylabel='loss', 207 | ytickmin=0, 208 | #ytickmax=1, 209 | #ytickstep=0.1, 210 | ytick=True, 211 | ) 212 | ) 213 | lr_window = self.visdom.line( 214 | X = torch.ones(1), 215 | Y = torch.tensor([self.current_lr]), 216 | opts = dict(title = 'learning_rate', 217 | showlegend=True, 218 | legend=['learning_rate'], 219 | xtype='linear', 220 | xlabel='epoch', 221 | xtickmin=0, 222 | xtick=True, 223 | xtickstep=10, 224 | ytype='linear', 225 | ytickmin=0, 226 | #ytickmax=1, 227 | #ytickstep=0.1, 228 | ylabel='lr', 229 | ytick=True) 230 | ) 231 | miou_window = self.visdom.line( 232 | X = torch.stack((torch.ones(1),torch.ones(1)),1), 233 | Y = torch.stack((torch.ones(1),torch.ones(1)),1), 234 | opts = dict(title='train_val_MIoU', 235 | showlegend=True, 236 | legend=['Train_MIoU', 'Val_MIoU'], 237 | xtype='linear', 238 | xlabel='epoch', 239 | xtickmin=0, 240 | xtick=True, 241 | xtickstep=10, 242 | ytype='linear', 243 | ylabel='MIoU', 244 | ytickmin=0, 245 | #ytickmax=1, 246 | #ytickstep=0.1, 247 | ytick=True 248 | ) 249 | ) 250 | acc_window = self.visdom.line( 251 | X = torch.stack((torch.ones(1), torch.ones(1)),1), 252 | Y = torch.stack((torch.ones(1), torch.ones(1)),1), 253 | opts = dict(title='train_val_Accuracy', 254 | showlegend=True, 255 | legend=['Train_Acc', 'Val_Acc'], 256 | xtype='linear', 257 | xlabel='epoch', 258 | xtickmin=0, 259 | xtick=True, 260 | xtickstep=10, 261 | ytype='linear', 262 | ylabel='Accuracy', 263 | ytickmin=0, 264 | #ytickmax=1, 265 | #ytickstep=0.1, 266 | ytick=True) 267 | ) 268 | self.viz_winname['miou'].append(str(miou_window)) 269 | self.viz_winname['loss'].append(str(loss_window)) 270 | self.viz_winname['learning_rate'].append(str(lr_window)) 271 | self.viz_winname['acc'].append(str(acc_window)) 272 | else: 273 | # here already loaded the checkpoint file: resume_file 274 | print(" + Loading visdom file ... ... 
Done!") 275 | 276 | print(" + Loaded, Training !") 277 | 278 | epochs = self.config.epochs 279 | # init weights from scratch; when resuming, re-applying init would overwrite the loaded weights 280 | if not self.resume_: 281 | self.model.apply(self._weight_init) 282 | not_improved_count = 0 283 | for epoch in range(self.start_epoch, epochs+1): 284 | # get log information of train and evaluation phase 285 | train_log = self._train_epoch(epoch) 286 | eval_log = self._eval_epoch(epoch) 287 | 288 | # TODO visualization -- for loss 289 | self.visdom.line( 290 | X = torch.stack((torch.ones(1)*epoch,torch.ones(1)*epoch),1), 291 | Y = torch.stack((torch.tensor([train_log['loss']]),torch.tensor([eval_log['val_Loss']])),1), 292 | win = self.viz_winname['loss'][0], 293 | update='append' if epoch!=1 else 'replace', 294 | ) 295 | # for metrics_miou 296 | self.visdom.line( 297 | X = torch.stack((torch.ones(1)*epoch, torch.ones(1)*epoch),1), 298 | Y = torch.stack((torch.tensor([train_log['miou']]), torch.tensor([eval_log['val_MIoU']])),1), 299 | win = self.viz_winname['miou'][0], 300 | update='append' if epoch!=1 else 'replace', 301 | ) 302 | # for metrics_accuracy 303 | self.visdom.line( 304 | X = torch.stack((torch.ones(1)*epoch, torch.ones(1)*epoch),1), 305 | Y = torch.stack((torch.tensor([train_log['acc']]), torch.tensor([eval_log['val_Accuracy']])),1), 306 | win = self.viz_winname['acc'][0], 307 | update='append' if epoch!=1 else 'replace', 308 | ) 309 | # for learning_rate 310 | self.visdom.line( 311 | X=torch.ones(1) * epoch, 312 | Y=torch.tensor([self.current_lr]), 313 | win=self.viz_winname['learning_rate'][0], 314 | update='append' if epoch != 1 else 'replace', 315 | ) 316 | 317 | # save best model and save ckpt 318 | best = False 319 | if self.monitor_mode != 'off': 320 | improved = (self.monitor_mode == 'min' and eval_log['val_'+self.monitor_metric] < self.monitor_best) or \ 321 | (self.monitor_mode == 'max' and eval_log['val_'+self.monitor_metric] > self.monitor_best) 322 | if improved: 323 | self.monitor_best = eval_log['val_'+self.monitor_metric] 324 | best = True 325 | self.best_epoch = eval_log['epoch'] 326 | not_improved_count = 0 327 | else: 328 | not_improved_count += 1 329 | 330 | if not_improved_count > self.early_stop: 331 | print(" + Validation Performance didn't improve for {} epochs.\n" 332 | " + Training stop :/" 333 | .format(not_improved_count)) 334 | break 335 | if epoch % self.save_period == 0 or best: 336 | self._save_ckpt(epoch, best=best) 337 | # save history file 338 | print(" + Saving History ... ... 
") 339 | hist_path = os.path.join(self.log_dir, 'history1.txt') 340 | with open(hist_path, 'w') as f: 341 | f.write(str(self.history)) 342 | 343 | def _train_epoch(self, epoch): 344 | 345 | # lr update 346 | if self.lr_scheduler is not None: 347 | self.lr_scheduler.step(epoch) 348 | for param_group in self.optimizer.param_groups: 349 | self.current_lr = param_group['lr'] 350 | 351 | batch_time = AverageMeter() 352 | data_time = AverageMeter() 353 | ave_total_loss = AverageMeter() 354 | ave_acc = AverageMeter() 355 | ave_iou = AverageMeter() 356 | 357 | # set model mode 358 | self.model.train() 359 | tic = time.time() 360 | 361 | for steps, (data, target) in enumerate(self.train_data_loader, start=1): 362 | 363 | data = data.to(self.device, non_blocking=True) 364 | target = target.to(self.device, non_blocking=True) 365 | # 加载数据所用的时间 366 | data_time.update(time.time() - tic) 367 | 368 | # forward calculate 369 | logits = self.model(data) 370 | loss = self.loss(logits, target) 371 | acc = Accuracy(logits, target) 372 | miou = MIoU(logits, target, self.config.nb_classes) 373 | 374 | # compute gradient and do SGD step 375 | self.optimizer.zero_grad() 376 | loss.backward() 377 | self.optimizer.step() 378 | 379 | # update average metrics 380 | batch_time.update(time.time() - tic) 381 | ave_total_loss.update(loss.data.item()) 382 | ave_acc.update(acc.item()) 383 | ave_iou.update(miou.item()) 384 | 385 | # display on the screen per display_steps 386 | if steps % self.dis_period == 0: 387 | print('Epoch: [{}][{}/{}],\n' 388 | 'Learning_Rate: {:.6f},\n' 389 | 'Time: {:.4f}, Data: {:.4f},\n' 390 | 'MIoU: {:6.4f}, Accuracy: {:6.4f}, Loss: {:.6f}' 391 | .format(epoch, steps, len(self.train_data_loader), 392 | self.current_lr, 393 | batch_time.average(), data_time.average(), 394 | ave_iou.average(), ave_acc.average(), ave_total_loss.average())) 395 | tic = time.time() 396 | # train log and return 397 | self.history['train']['epoch'].append(epoch) 398 | self.history['train']['loss'].append(ave_total_loss.average()) 399 | self.history['train']['acc'].append(ave_acc.average()) 400 | self.history['train']['miou'].append(ave_iou.average()) 401 | return { 402 | 'epoch': epoch, 403 | 'loss': ave_total_loss.average(), 404 | 'acc': ave_acc.average(), 405 | 'miou': ave_iou.average(), 406 | } 407 | 408 | def _eval_epoch(self, epoch): 409 | 410 | 411 | batch_time = AverageMeter() 412 | data_time = AverageMeter() 413 | ave_total_loss = AverageMeter() 414 | ave_acc = AverageMeter() 415 | ave_iou = AverageMeter() 416 | 417 | # set model mode 418 | self.model.eval() 419 | 420 | with torch.no_grad(): 421 | tic = time.time() 422 | for steps, (data, target) in enumerate(self.valid_data_loder, start=1): 423 | 424 | # processing no blocking 425 | # non_blocking tries to convert asynchronously with respect to the host if possible 426 | # converting CPU tensor with pinned memory to CUDA tensor 427 | # overlap transfer if pinned memory 428 | data = data.to(self.device, non_blocking=True) 429 | target = target.to(self.device, non_blocking=True) 430 | data_time.update(time.time() - tic) 431 | 432 | logits = self.model(data) 433 | loss = self.loss(logits, target) 434 | # calculate metrics 435 | acc = Accuracy(logits, target) 436 | miou = MIoU(logits, target, self.config.nb_classes) 437 | #print("===========acc, miou==========", acc, miou) 438 | 439 | # update ave metrics 440 | batch_time.update(time.time()-tic) 441 | 442 | ave_total_loss.update(loss.data.item()) 443 | ave_acc.update(acc.item()) 444 | ave_iou.update(miou.item()) 
445 | tic = time.time() 446 | # display validation at the end 447 | print('Epoch {} validation done !'.format(epoch)) 448 | print('Time: {:.4f}, Data: {:.4f},\n' 449 | 'MIoU: {:6.4f}, Accuracy: {:6.4f}, Loss: {:.6f}' 450 | .format(batch_time.average(), data_time.average(), 451 | ave_iou.average(), ave_acc.average(), ave_total_loss.average())) 452 | 453 | self.history['valid']['epoch'].append(epoch) 454 | self.history['valid']['loss'].append(ave_total_loss.average()) 455 | self.history['valid']['acc'].append(ave_acc.average()) 456 | self.history['valid']['miou'].append(ave_iou.average()) 457 | # validation log and return 458 | return { 459 | 'epoch': epoch, 460 | 'val_Loss': ave_total_loss.average(), 461 | 'val_Accuracy': ave_acc.average(), 462 | 'val_MIoU': ave_iou.average(), 463 | } 464 | 465 | def _save_ckpt(self, epoch, best): 466 | 467 | # save model ckpt 468 | state = { 469 | 'epoch': epoch, 470 | 'arch': str(self.model), 471 | 'history': self.history, 472 | 'state_dict': self.model.state_dict(), 473 | 'optimizer': self.optimizer.state_dict(), 474 | 'monitor_best': self.monitor_best, 475 | 'windows_name': self.viz_winname, 476 | } 477 | filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{}.pth'.format(epoch)) 478 | best_filename = os.path.join(self.checkpoint_dir, 'checkpoint-best.pth') 479 | if best: 480 | print(" + Saving Best Checkpoint : Epoch {} path: {} ... ".format(epoch, best_filename)) 481 | torch.save(state, best_filename) 482 | else: 483 | print(" + Saving Checkpoint per {} epochs, path: {} ... ".format(self.save_period, filename)) 484 | torch.save(state, filename) 485 | 486 | def _resume_ckpt(self, resume_file): 487 | """ 488 | :param resume_file: checkpoint file name 489 | :return: 490 | """ 491 | resume_path = os.path.join(resume_file) 492 | print(" + Loading Checkpoint: {} ... ".format(resume_path)) 493 | checkpoint = torch.load(resume_path) 494 | self.start_epoch = checkpoint['epoch'] + 1 # the next epoch, what saved in the ckeckpoint is the end-epoch 495 | self.monitor_best = checkpoint['monitor_best'] 496 | 497 | self.model.load_state_dict(checkpoint['state_dict']) 498 | print(" + Model State Loaded ! :D ") 499 | self.optimizer.load_state_dict(checkpoint['optimizer']) 500 | print(" + Optimizer State Loaded ! :D ") 501 | self.history = checkpoint['history'] 502 | self.viz_winname = checkpoint['windows_name'] 503 | print(" + Checkpoint file: '{}' , Start epoch {} Loaded !\n" 504 | " + Prepare to run ! ! !" 505 | .format(resume_path, self.start_epoch)) 506 | 507 | def state_cuda(self, msg): 508 | print("--", msg) 509 | print("allocated: %dM, max allocated: %dM, cached: %dM, max cached: %dM"%( 510 | torch.cuda.memory_allocated(self.device) / 1024 / 1024, 511 | torch.cuda.max_memory_allocated(self.device) / 1024/ 1024, 512 | torch.cuda.memory_cached(self.device) / 1024/ 1024, 513 | torch.cuda.max_memory_cached(self.device) / 1024/ 1024, 514 | )) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 
7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /Mem_monitor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # type_size defaults to 4 bytes, i.e. float32 7 | 8 | def modelsize(model, input, type_size=4): 9 | para = sum([np.prod(list(p.size())) for p in model.parameters()]) 10 | print('Model {} : params: {:4f}MB'.format(model._get_name(), para*type_size/1000/1000)) 11 | 12 | input_ = input.clone() 13 | input_.requires_grad_(requires_grad=False) 14 | 15 | mods = list(model.modules()) 16 | # list holding the sizes of the intermediate outputs 17 | out_sizes= [] 18 | 19 | for i in range(1, len(mods)): 20 | m = mods[i] 21 | if isinstance(m, nn.ReLU): 22 | if m.inplace: 23 | continue 24 | out = m(input_) 25 | out_sizes.append(np.array(out.size())) 26 | input_ = out 27 | 28 | total_nums = 0 29 | for i in range(len(out_sizes)): 30 | s = out_sizes[i] 31 | nums = np.prod(np.array(s)) 32 | total_nums += nums 33 | 34 | print('Model {} : intermediate variables: {:3f} MB (without backward)' 35 | .format(model._get_name(), total_nums * type_size/ 1000/ 1000)) 36 | # during backward all intermediate variables are kept for the gradient computation, hence the factor 2 37 | print('Model {} : intermediate variables: {:3f} MB (with backward)' 38 | .format(model._get_name(), total_nums * type_size *2/ 1000/ 1000)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ESFNet: Efficient Networks for Building Extraction from High-Resolution Images 2 | The implementation of the novel efficient neural network ESFNet 3 | 4 | ### Clone the Repository 5 | ``` 6 | git clone https://github.com/mrluin/ESFNet-Pytorch.git 7 | ``` 8 | ``` 9 | cd ./ESFNet-Pytorch 10 | ``` 11 | 12 | 13 | ### Installation using Conda 14 | ``` 15 | conda env create -f environment.yml 16 | ``` 17 | ``` 18 | conda activate esfnet 19 | ``` 20 | 21 | ### Sample Dataset 22 | For training, you can use as an example the [WHU Building Dataset](http://study.rsgis.whu.edu.cn/pages/download/). 23 | 24 | You would need to download the cropped aerial images, 
i.e. `The 3rd option` on the download page. 25 | 26 | ### Directory Structure 27 | ``` 28 | Directory: 29 | #root | -- train 30 | | -- valid 31 | | -- test 32 | | -- save | -- {model.name} | -- datetime | -- ckpt-epoch{}.pth.format(epoch) 33 | | | -- best_model.pth 34 | | 35 | | -- log | -- {model.name} | -- datetime | -- history.txt 36 | | -- test| -- log | -- {model.name} | --datetime | -- history.txt 37 | | -- predict | -- {model.name} | --datetime | -- *.png 38 | ``` 39 | ### Training 40 | 1. set `root_dir` in `./configs/config.cfg` to the dataset root, organized as shown above. 41 | 2. set `device_id` to choose which GPU will be used. 42 | 3. set `epochs` to control the length of the training phase. 43 | 4. start the visdom server, then run `train.py`: 44 | ``` 45 | python -m visdom.server -env_path='./visdom_log/' -port=8097 # start visdom server 46 | python train.py 47 | ``` 48 | `-env_path` is where the visdom log file is stored, and `-port` is the port `visdom` listens on. You can also change the `-port` in `train.py`. 49 | 50 | 51 | 52 | **If my work gives you some insights and hints, please star this repo! Thank you~** 53 | -------------------------------------------------------------------------------- /checkpoint-best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrluin/ESFNet-Pytorch/a2c166a91281e96f953398cf953f446ff6337a14/checkpoint-best.pth -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /configs/config.cfg: -------------------------------------------------------------------------------- 1 | [Directory] 2 | root_dir = /home/UserGroup/UserName/Building_Detection/cropped_aerial_torch_512 3 | save_dir = ${root_dir}/save 4 | log_dir = ${save_dir}/log 5 | test_dir = ${save_dir}/test 6 | test_log_dir = ${test_dir}/log 7 | pred_dir = ${test_dir}/predict 8 | trainset_dir = ${root_dir}/train 9 | validset_dir = ${root_dir}/val 10 | testset_dir = ${root_dir}/test 11 | data_folder_name = image 12 | target_folder_name = label 13 | 14 | [Data] 15 | batch_size = 16 16 | nb_classes = 2 17 | original_size = 512 18 | cropped_size = 256 19 | input_size = 512 20 | overlapped = 0 21 | 22 | [General] 23 | use_gpu = True 24 | use_multi_gpus = False 25 | device_id = 0 26 | random_seed = 1 27 | num_workers = 0 28 | 29 | [Optimizer] 30 | lr_algorithm = adam 31 | init_lr = 5e-4 32 | lr_decay = 1e-1 33 | momentum = 0.9 34 | weight_decay = 2e-4 35 | epsilon = 1e-8 36 | 37 | [Train] 38 | monitor = max/MIoU 39 | init_algorithm = kaiming 40 | loss = crossentropy 41 | pre_trained = False 42 | visualization = True 43 | verbosity = 2 44 | early_stop = 10 45 | save_period = 10 46 | dis_period = 20 47 | epochs = 300 48 | 49 | -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | from configparser import ConfigParser 2 | import configparser 3 | 4 | """ 5 | # section names are case-sensitive, but keys and values are not. 6 | The configuration always stores data as strings; values are converted to the right type when read. 7 | [DEFAULT]: values in [DEFAULT] provide default values for all sections, and they have the highest priority. 
8 | typed getters: getboolean() getint() getfloat() 9 | the get() method offers a richer interface and keeps backward compatibility; the fallback keyword supplies a fallback value 10 | 11 | 12 | to refer to values in other sections, set an 13 | interpolation method: configparser.BasicInterpolation() 14 | configparser.ExtendedInterpolation() ${section:key} 15 | config.set 16 | config.write 17 | """ 18 | class MyConfiguration(): 19 | def __init__(self, config_file=None): 20 | super(MyConfiguration, self).__init__() 21 | 22 | # ./ current directory 23 | if config_file is None: 24 | config_file = './configs/config.cfg' 25 | 26 | config = ConfigParser() 27 | # interpolation method 28 | config._interpolation = configparser.ExtendedInterpolation() 29 | config.read(filenames= config_file) 30 | 31 | self.config = config 32 | self.config_path = config_file 33 | self.add_section = 'Additional' 34 | print("Loaded config file successfully ...") 35 | #print(len(self.config.sections())) 36 | for section in self.config.sections(): 37 | for k, v in self.config.items(section): 38 | print(k, ":", v) 39 | 40 | # TODO make save dir 41 | 42 | config.write(open(config_file, 'w')) 43 | 44 | def add_args(self, key, value): 45 | self.config.set(self.add_section, key, value) 46 | self.config.write(open(self.config_path, 'w')) 47 | 48 | # string int float boolean 49 | @property 50 | def test_dir(self): 51 | return self.config.get("Directory", "test_dir") 52 | 53 | @property 54 | def test_log_dir(self): 55 | return self.config.get("Directory", "test_log_dir") 56 | 57 | @property 58 | def pred_dir(self): 59 | return self.config.get("Directory", "pred_dir") 60 | 61 | @property 62 | def root_dir(self): 63 | return self.config.get("Directory", "root_dir") 64 | 65 | @property 66 | def save_dir(self): 67 | return self.config.get("Directory", "save_dir") 68 | 69 | @property 70 | def log_dir(self): 71 | return self.config.get("Directory", "log_dir") 72 | 73 | 74 | 75 | 76 | 77 | @property 78 | def trainset_dir(self): 79 | return self.config.get("Directory", "trainset_dir") 80 | 81 | @property 82 | def validset_dir(self): 83 | return self.config.get("Directory", "validset_dir") 84 | 85 | @property 86 | def testset_dir(self): 87 | return self.config.get("Directory", "testset_dir") 88 | 89 | @property 90 | def data_folder_name(self): 91 | return self.config.get("Directory", "data_folder_name") 92 | 93 | @property 94 | def target_folder_name(self): 95 | return self.config.get("Directory", "target_folder_name") 96 | 97 | @property 98 | def batch_size(self): 99 | return self.config.getint("Data", "batch_size") 100 | 101 | @property 102 | def nb_classes(self): 103 | return self.config.getint("Data", "nb_classes") 104 | 105 | @property 106 | def original_size(self): 107 | return self.config.getint("Data", "original_size") 108 | 109 | @property 110 | def cropped_size(self): 111 | return self.config.getint("Data", "cropped_size") 112 | 113 | @property 114 | def input_size(self): 115 | return self.config.getint("Data", "input_size") 116 | 117 | @property 118 | def overlapped(self): 119 | return self.config.getint("Data", "overlapped") 120 | 121 | @property 122 | def use_gpu(self): 123 | return self.config.getboolean("General", "use_gpu") 124 | 125 | @property 126 | def use_multi_gpus(self): 127 | return self.config.getboolean("General", "use_multi_gpus") 128 | 129 | @property 130 | def device_id(self): 131 | return self.config.getint("General", "device_id") 132 | 133 | @property 134 | def random_seed(self): 135 | return self.config.getint("General", 
"random_seed") 136 | 137 | @property 138 | def num_workers(self): 139 | return self.config.getint("General", "num_workers") 140 | 141 | @property 142 | def lr_algorithm(self): 143 | return self.config.get("Optimizer", "lr_algorithm") 144 | 145 | @property 146 | def init_lr(self): 147 | return self.config.getfloat("Optimizer", "init_lr") 148 | 149 | @property 150 | def lr_decay(self): 151 | return self.config.getfloat("Optimizer", "lr_decay") 152 | 153 | @property 154 | def momentum(self): 155 | return self.config.getfloat("Optimizer", "momentum") 156 | 157 | @property 158 | def weight_decay(self): 159 | return self.config.getfloat("Optimizer", "weight_decay") 160 | 161 | @property 162 | def epsilon(self): 163 | return self.config.getfloat("Optimizer", "epsilon") 164 | 165 | @property 166 | def monitor(self): 167 | return self.config.get("Train", "monitor") 168 | 169 | @property 170 | def dis_period(self): 171 | return self.config.getint("Train", "dis_period") 172 | 173 | @property 174 | def init_algorithm(self): 175 | return self.config.get("Train", "init_algorithm") 176 | 177 | @property 178 | def loss(self): 179 | return self.config.get("Train", "loss") 180 | 181 | @property 182 | def pre_trained(self): 183 | return self.config.getboolean("Train", "pre_trained") 184 | 185 | @property 186 | def visualization(self): 187 | return self.config.getboolean("Train", "visualization") 188 | 189 | @property 190 | def verbosity(self): 191 | return self.config.getint("Train", "verbosity") 192 | 193 | @property 194 | def early_stop(self): 195 | return self.config.getint("Train", "early_stop") 196 | 197 | @property 198 | def save_period(self): 199 | return self.config.getint("Train", "save_period") 200 | 201 | @property 202 | def epochs(self): 203 | return self.config.getint("Train", "epochs") 204 | 205 | 206 | 207 | 208 | 209 | if __name__ == '__main__': 210 | config = MyConfiguration() 211 | print(config.root_dir) 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms.functional as TF 4 | import torchvision.transforms as transforms 5 | import glob 6 | import random 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | 11 | rgb_mean = (0.4353, 0.4452, 0.4131) 12 | rgb_std = (0.2044, 0.1924, 0.2013) 13 | 14 | class MyDataset(Dataset): 15 | def __init__(self, 16 | config, 17 | args, 18 | subset): 19 | super(MyDataset, self).__init__() 20 | assert subset == 'train' or subset == 'val' or subset == 'test' 21 | 22 | self.args = args 23 | self.config = config 24 | self.root = args.input 25 | self.subset = subset 26 | self.data = self.config.data_folder_name # image 27 | self.target = self.config.target_folder_name # label 28 | 29 | #self.data_transforms = data_transforms if data_transforms!=None else TF.to_tensor 30 | #self.target_transforms = target_transforms if target_transforms!= None else TF.to_tensor 31 | 32 | self.mapping = { 33 | 0: 0, 34 | 255: 1, 35 | } 36 | self.data_list = glob.glob(os.path.join( 37 | self.root, 38 | subset, 39 | self.data, 40 | '*' 41 | )) 42 | self.target_list = glob.glob(os.path.join( 43 | self.root, 44 | subset, 45 | self.target, 46 | '*' 
47 | )) 48 | 49 | def mask_to_class(self, mask): 50 | for k in self.mapping: 51 | mask[mask == k] = self.mapping[k] 52 | return mask 53 | 54 | def train_transforms(self, image, mask): 55 | 56 | resize = transforms.Resize(size=(self.config.input_size, self.config.input_size)) 57 | image = resize(image) 58 | mask = TF.resize(mask, size=(self.config.input_size, self.config.input_size), interpolation=Image.NEAREST) # NEAREST keeps the label values {0, 255} intact 59 | 60 | if random.random() > 0.5: 61 | image = TF.hflip(image) 62 | mask = TF.hflip(mask) 63 | 64 | if random.random() > 0.5: 65 | image = TF.vflip(image) 66 | mask = TF.vflip(mask) 67 | 68 | image = TF.to_tensor(image) # scale 0-1 69 | image = TF.normalize(image, mean=rgb_mean, std=rgb_std) # normalize 70 | mask = torch.from_numpy(np.array(mask, dtype=np.uint8)) 71 | mask = self.mask_to_class(mask) 72 | mask = mask.long() 73 | return image, mask 74 | 75 | def untrain_transforms(self, image, mask): 76 | 77 | resize = transforms.Resize(size=(self.config.input_size, self.config.input_size)) 78 | image = resize(image) 79 | mask = TF.resize(mask, size=(self.config.input_size, self.config.input_size), interpolation=Image.NEAREST) 80 | 81 | # no flip augmentation for validation / test 82 | 83 | image = TF.to_tensor(image) 84 | image = TF.normalize(image, mean=rgb_mean, std=rgb_std) 85 | mask = torch.from_numpy(np.array(mask, dtype=np.uint8)) 86 | mask = self.mask_to_class(mask) 87 | mask = mask.long() 88 | return image, mask 89 | 90 | def __getitem__(self, index): 91 | 92 | datas = Image.open(self.data_list[index]) 93 | targets = Image.open(self.target_list[index]) 94 | if self.subset == 'train': 95 | t_datas, t_targets = self.train_transforms(datas, targets) 96 | return t_datas, t_targets 97 | elif self.subset == 'val': 98 | t_datas, t_targets = self.untrain_transforms(datas, targets) 99 | return t_datas, t_targets 100 | elif self.subset == 'test': 101 | t_datas, t_targets = self.untrain_transforms(datas, targets) 102 | return t_datas, t_targets, self.data_list[index] 103 | 104 | def __len__(self): 105 | 106 | return len(self.data_list) 107 | 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: esfnet 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.6 7 | - pytorch 8 | - torchvision 9 | - cudatoolkit 10 | - cudnn 11 | - visdom 12 | - pillow 13 | - tqdm 14 | - opencv 15 | - pandas 16 | 17 | -------------------------------------------------------------------------------- /gpu_utilization_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | 5 | import torchvision.models as models 6 | import torchvision.datasets as datasets 7 | import torchvision.transforms as transforms 8 | 9 | model = models.resnet50() 10 | criterion = nn.CrossEntropyLoss() 11 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) 12 | 13 | dataset = datasets.FakeData( 14 | size=1000, 15 | transform=transforms.ToTensor() 16 | ) 17 | loader = DataLoader( 18 | dataset, 19 | num_workers=1, 20 | pin_memory=True, 21 | ) 22 | model.to('cuda') 23 | 24 | for data, target in loader: 25 | data = data.to('cuda', non_blocking=True) 26 | target = target.to('cuda', non_blocking=True).long() 27 | optimizer.zero_grad() 28 | output = model(data) 29 | loss = criterion(output, target) 30 | loss.backward() 31 | optimizer.step() 32 | 33 | print("Done") 34 | -------------------------------------------------------------------------------- /gradient_debug.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_printer(msg): 5 | """ 6 | returns a printer function that prints information about a tensor's gradient. 7 | Used by register_hook in the backward pass. 8 | :param msg: 9 | :return: printer function 10 | """ 11 | def printer(tensor): 12 | if tensor.nelement() == 1: 13 | print("{} {}".format(msg, tensor)) 14 | else: 15 | print("{} shape: {}\n" 16 | "max: {} min: {}\n" 17 | "mean: {}" 18 | .format(msg, tensor.shape, tensor.max(), tensor.min(), tensor.mean())) 19 | return printer 20 | 21 | def register_hook(tensor, msg): 22 | """ 23 | Utility function to call retain_grad and register_hook in a single line 24 | :param tensor: 25 | :param msg: 26 | :return: 27 | """ 28 | tensor.retain_grad() 29 | tensor.register_hook(get_printer(msg)) 30 | 31 | if __name__ == '__main__': 32 | 33 | x = torch.randn((1,1), requires_grad=True) 34 | y = 3*x 35 | z = y**2 36 | register_hook(y, 'y') 37 | z.backward() 38 | 39 | -------------------------------------------------------------------------------- /inference_compare.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import time 5 | from utils.util import AverageMeter 6 | # normal conv2d stacked 50 layers 7 | 8 | 9 | class separable_conv2d(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, dilation=1): 11 | """ 12 | when in_channels == out_channels == groups, PyTorch dispatches to its THCUNN depthwise backend instead of cuDNN 13 | """ 14 | super(separable_conv2d, self).__init__() 15 | self.depthwise_conv2d = nn.Conv2d(in_channels=in_channels, 16 | out_channels=out_channels, 17 | kernel_size=kernel_size, stride=1, 18 | padding=padding, dilation=dilation, groups=in_channels) 19 | self.pointwise_conv2d = nn.Conv2d(in_channels=in_channels, 20 | out_channels=out_channels, 21 | kernel_size=1, stride=1, 22 | padding=0, dilation=dilation, groups=1) 23 | def forward(self, input): 24 | output = self.depthwise_conv2d(input) 25 | output = self.pointwise_conv2d(output) 26 | return output 27 | 28 | 29 | class normal_convnet(nn.Module): 30 | def __init__(self): 31 | super(normal_convnet, self).__init__() 32 | self.layer_list = nn.ModuleList() 33 | self.layer_list.append(nn.Conv2d(3, 256, 3, padding=1, groups=1)) 34 | for index in range(48): 35 | self.layer_list.append(nn.Conv2d(256, 256, 3, padding=1, groups=1)) 36 | self.layer_list.append(nn.Conv2d(256, 10, 3, padding=1, groups=1)) 37 | 38 | def forward(self, x): 39 | for layer in self.layer_list: 40 | x = layer(x) 41 | return x 42 | 43 | class sep_convnet(nn.Module): 44 | def __init__(self): 45 | super(sep_convnet, self).__init__() 46 | self.layer_list = nn.ModuleList() 47 | self.layer_list.append(nn.Conv2d(3, 256, 3, padding=1, groups=1)) 48 | for index in range(48): 49 | self.layer_list.append(separable_conv2d(256,256,3,1)) 50 | self.layer_list.append(nn.Conv2d(256, 10, 3, padding=1, groups=1)) 51 | 52 | def forward(self, x): 53 | for layer in self.layer_list: 54 | x = layer(x) 55 | 56 | return x 57 | 58 | def inference_test_both(): 59 | 60 | #torch.manual_seed(1) 61 | #torch.backends.cudnn.enabled = False 62 | #torch.backends.cudnn.benchmark = True 63 | #torch.backends.cudnn.deterministic=True 64 | 65 | random_input = torch.randn((1, 3, 256, 256)) 66 | random_output = torch.randint(low=0, high=10, size=(1,256,256)) # for 0,1,2,3,4,5,6,7,8,9 67 | #random_output = torch.randn((1,3,256,256)) 
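# rough parameter arithmetic for one 256->256 layer (bias ignored): a normal 3x3 conv holds 256*256*9 = 589,824 weights, while depthwise (256*9) + pointwise (256*256) = 67,840, roughly 8.7x fewer -- consistent with the ~8.6x overall ratio reported in the docstring at the bottom of this file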
68 | 69 | net1 = normal_convnet().to('cuda:0') 70 | net2 = sep_convnet().to('cuda:0') 71 | 72 | # params = 0 73 | #params_normal = sum(p.numel() for p in normal_net.parameters() if p.requires_grad) 74 | #print("Trainable Parameters :", params_normal) 75 | 76 | #criterion = nn.MSELoss().to('cuda:0') 77 | criterion = nn.CrossEntropyLoss().to('cuda:0') 78 | 79 | #optimizer_1 = torch.optim.Adam(params=net1.parameters(), lr=0.1) 80 | #optimizer_2 = torch.optim.Adam(params=net2.parameters(), lr=0.1) 81 | optimizer_1 = torch.optim.SGD(params=net1.parameters(),lr=0.1) 82 | optimizer_2 = torch.optim.SGD(params=net2.parameters(),lr=0.1) 83 | # same learning rate in both optimizers for a fair comparison 84 | 85 | cost1 = AverageMeter() 86 | cost2 = AverageMeter() 87 | print("Simulate Training ... ...") 88 | 89 | 90 | input1 = random_input.to('cuda:0') 91 | target1 = random_output.to('cuda:0') 92 | torch.cuda.synchronize() 93 | tic = time.time() 94 | optimizer_1.zero_grad() 95 | output1 = net1(input1) 96 | loss = criterion(output1, target1) 97 | loss.backward() 98 | optimizer_1.step() 99 | torch.cuda.synchronize() 100 | cost1.update(time.time() - tic) 101 | 102 | #print(dw_net) 103 | #params_dw = sum(p.numel() for p in normal_net.parameters() if p.requires_grad) 104 | #print("Trainable Parameters :", params_dw) 105 | #optimizer_dw = torch.optim.Adam(params=dw_net.parameters()) 106 | 107 | input2 = random_input.to('cuda:0') 108 | target2 = random_output.to('cuda:0') 109 | torch.cuda.synchronize() 110 | tic = time.time() 111 | optimizer_2.zero_grad() 112 | output2 = net2(input2) 113 | loss = criterion(output2, target2) 114 | loss.backward() 115 | optimizer_2.step() 116 | torch.cuda.synchronize() 117 | cost2.update(time.time() - tic) 118 | 119 | print("Done for All !") 120 | 121 | print("Trainable Parameters:\n" 122 | "Normal_conv2d: {}\n" 123 | "Sep_conv2d : {}".format(parameters_sum(net1), parameters_sum(net2))) 124 | print("Inference Time cost:\n" 125 | "Normal_conv2d: {}s\n" 126 | "Sep_conv2d : {}s".format(cost1._get_sum(), cost2._get_sum())) 127 | 128 | 129 | def parameters_sum(model): 130 | 131 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 132 | 133 | 134 | 135 | if __name__ == '__main__': 136 | 137 | ''' 138 | #for separable conv2d 139 | #Duration: 0.07718038558959961 for normal 140 | #Duration: 0.041310787200927734 for separable conv2d including depthwise and pointwise 141 | input = torch.randn(size=(1,256,256,256)) 142 | target = torch.randn(size=(1,256,256,256)) 143 | 144 | m1 = nn.Conv2d(256,256,3,padding=1, groups=1,bias=False).to('cuda:0') 145 | m2 = nn.Conv2d(256,256,3,padding=1, groups=256,bias=False).to('cuda:0') 146 | m2_p = nn.Conv2d(256,256,1,padding=0, groups=1,bias=False).to('cuda:0') 147 | criterion = nn.MSELoss().to('cuda:0') 148 | optimizer_m1 = torch.optim.SGD(params=m1.parameters(),lr=0.1) 149 | optimizer_m2 = torch.optim.SGD(params=m2.parameters(),lr=0.1) 150 | 151 | tic = time.time() 152 | input1 = input.to('cuda:0') 153 | target1 = target.to('cuda:0') 154 | out1 = m1(input1) 155 | loss = criterion(out1, target1) 156 | optimizer_m1.zero_grad() 157 | loss.backward() 158 | optimizer_m1.step() 159 | print("Duration: ", time.time()-tic) 160 | 161 | tic = time.time() 162 | input2 = input.to('cuda:0') 163 | target2 = target.to('cuda:0') 164 | out2 = m2(input2) 165 | out2 = m2_p(out2) 166 | loss = criterion(out2, target2) 167 | optimizer_m2.zero_grad() 168 | loss.backward() 169 | optimizer_m2.step() 170 | print("Duration: ", time.time()-tic) 171 | ''' 172 | inference_test_both() 173 | """ 174 | output: for 50 layers 
stacked up and run 100 iterations 175 | Normal_conv2d: 17.160080671310425s 176 | sep_conv2d : 7.4953773021698s 177 | """ 178 | """ 179 | output: net1 trainable parameters: 29504000 180 | net2 trainable parameters: 3417600 181 | 182 | 29504000 / 3417600 = 8.6 183 | net1 = normal_convnet() 184 | net2 = sep_convnet() 185 | print("net1 trainable parameters: ", parameters_sum(net1)) 186 | print("net2 trainable parameters: ", parameters_sum(net2)) 187 | 188 | """ 189 | -------------------------------------------------------------------------------- /launcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import glob 5 | import random 6 | import cv2 7 | import torchvision.transforms.functional as TF 8 | import numpy as np 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import torch.utils.data as data 12 | from utils.dataset import Cropper 13 | from predict import Predictor, dataset_predict 14 | from configs.config import MyConfiguration 15 | from models.MyNetworks.ESFNet import ESFNet 16 | from utils.unpatchy import unpatchify 17 | ''' 18 | # instructions high-resolution images are saved in '--input' 19 | # then we use Cropper, get patches and saved in '--input/image_patches/' 20 | # use torch.utils.data.Dataset and torch.utils.data.DataLoader to load data 21 | # get predictions by the pre-trained model 22 | # save the output patches in '--output/patches' 23 | # re-merge the output patches into high-resolution images and save them in '--output/remerge' 24 | ''' 25 | 26 | def config_parser(): 27 | 28 | parser = argparse.ArgumentParser(description='configurations') 29 | parser.add_argument('--gpu', type=int, default=0, 30 | help='0 and 1 means gpu id, and -1 means cpu') 31 | parser.add_argument('-i', '--input', type=str, default=os.path.join('.', 'input'), 32 | help='directory of input images, including images used to train and predict') 33 | parser.add_argument('-o', '--output', type=str, default=os.path.join('.', 'output'), 34 | help='directory of output images, for predictions') 35 | parser.add_argument('--ckpt_path', type=str, default=os.path.join('.', 'checkpoint-best.pth'), 36 | help='path to the checkpoint file, default name checkpoint-best.pth') 37 | # dataloader settings 38 | parser.add_argument('--batch_size', type=int, default=1, 39 | help='batch_size') 40 | parser.add_argument('--pin_memory', type=bool, default=False, 41 | help='When True, it will accelerate the prediction phase but with high CPU-Utilization, and it ' 42 | 'will also allocate additional GPU-Memory') 43 | parser.add_argument('--nb_workers', type=int, default=1, 44 | help='workers for DataLoader') 45 | # patches settings, some configs have already included in config.cfg 46 | parser.add_argument('--image_margin_color', type=list, default=[255, 255, 255], 47 | help='the color of image margin color') 48 | parser.add_argument('--label_margin_color', type=list, default=[255, 255, 255], 49 | help='the color of label margin color') 50 | 51 | return parser.parse_args() 52 | 53 | def main(): 54 | 55 | args = config_parser() 56 | config = MyConfiguration() 57 | 58 | # for duplicating 59 | torch.backends.cudnn.benchmark = True 60 | torch.backends.cudnn.deterministic = True 61 | torch.manual_seed(config.random_seed) 62 | random.seed(config.random_seed) 63 | np.random.seed(config.random_seed) 64 | 65 | # model load the pre-trained weight, load ckpt once out of predictor 66 | model = ESFNet(config=config).to('cuda:{}'.format(args.gpu) 
if args.gpu >= 0 else 'cpu') 67 | ckpt = torch.load(args.ckpt_path, map_location='cuda:{}'.format(args.gpu) if args.gpu >=0 else 'cpu') 68 | model.load_state_dict(ckpt['state_dict']) 69 | 70 | # path for each high-resolution image -> crop -> predict -> merge 71 | source_image_paths = glob.glob(os.path.join(args.input, '*.png')) 72 | for source_image in tqdm(source_image_paths): 73 | # get high-resolution image name 74 | filename = source_image.split('/')[-1].split('.')[0] 75 | # the cropper produces patches and saves them to --input/patches 76 | c = Cropper(args=args, configs=config, predict=True) 77 | _, n_w, n_h, image_h, image_w = c.image_processor(image_path=source_image) 78 | my_dataset = dataset_predict(args=args) 79 | my_dataloader = data.DataLoader(my_dataset, batch_size=args.batch_size, shuffle=False, 80 | pin_memory=args.pin_memory, drop_last=False, num_workers=args.nb_workers) 81 | 82 | # predict using pre-trained network 83 | p = Predictor(args=args, model=model, dataloader_predict=my_dataloader) 84 | p.predict() 85 | # patches [total_size, C, H, W] p.patches tensor -> reshape -> [total_size, H, W, C] 86 | patches_tensor = torch.transpose(p.patches, 1, 3) 87 | patches_tensor = patches_tensor.view(n_h, n_w, config.cropped_size, config.cropped_size, 3) 88 | # merge and save the output image 89 | patches = patches_tensor.cpu().numpy() 90 | img = unpatchify(patches, image_h, image_w) 91 | #img = Image.fromarray(img) 92 | save_path = os.path.join(args.output, 'remerge', filename+'.png') 93 | cv2.imwrite(save_path, img) 94 | #img.save(save_path) 95 | 96 | def patchify_and_unpatchify_test(): 97 | 98 | args = config_parser() 99 | config = MyConfiguration() 100 | 101 | source_image_path = os.path.join(args.input, 'top_mosaic_09cm_area17.tif') 102 | filename = source_image_path.split('/')[-1].split('.')[0] 103 | c = Cropper(args=args, configs=config, predict=True) 104 | patches, n_w, n_h, image_h, image_w = c.image_processor(image_path=source_image_path) 105 | # patch list -> np.array [total_size, H, W, C] 106 | np_patches = np.asarray(patches) 107 | np_patches = np_patches.reshape(n_h, n_w, config.cropped_size, config.cropped_size, 3) 108 | img = unpatchify(np_patches, image_h, image_w) 109 | save_path = os.path.join(args.output, 'remerge', filename+'.png') 110 | cv2.imwrite(save_path, img) 111 | 112 | if __name__ == '__main__': 113 | 114 | #main() 115 | patchify_and_unpatchify_test() -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | ''' 5 | Tensor has no attr 'copy', use 'clone' instead; 6 | pred has requires_grad = True, so detach it via 7 | .clone().cpu().numpy() 8 | ''' 9 | def Accuracy(pred, label): 10 | 11 | with torch.no_grad(): 12 | pred = torch.argmax(pred, dim=1) 13 | pred = pred.view(-1) 14 | label = label.view(-1) 15 | # ignore 0 background 16 | valid = (label > 0).long() 17 | # convert to float(): numerator and denominator must both be float for the division; with long tensors the result would be truncated to zero 18 | # .long() converts boolean to long, then .float() converts to float 19 | # number of valid pixels where pred == label 20 | acc_sum = torch.sum(valid * (pred == label).long()).float() 21 | # total number of valid pixels 22 | pixel_sum = torch.sum(valid).float() 23 | # epsilon 24 | acc = acc_sum / (pixel_sum + 1e-10) 25 | return acc 26 | 27 | 28 | def MIoU(pred, label, nb_classes): 29 | 30 | with torch.no_grad(): 31 | pred = torch.argmax(pred, dim=1) 32 | pred = pred.view(-1) 33 | label = label.view(-1) 34 | iou = torch.zeros(nb_classes 
).to(pred.device)
35 |         for k in range(1, nb_classes):
36 |             # pred_inds and target_inds are boolean masks for class k
37 |             pred_inds = pred == k
38 |             target_inds = label == k
39 |             intersection = pred_inds[target_inds].long().sum().float()
40 |             union = (pred_inds.long().sum() + target_inds.long().sum() - intersection).float()
41 | 
42 |             iou[k] = (intersection / (union + 1e-10))
43 | 
44 |         return (iou.sum() / (nb_classes - 1))
45 | 
46 | 
47 | if __name__ == '__main__':
48 |     a = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
49 |                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
50 |                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
51 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
52 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
53 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
54 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
55 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
56 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
57 |                   [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],])
58 |     b = np.array([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
59 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
60 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
61 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
62 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
63 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
64 |                   [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
65 |                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
66 |                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
67 |                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],])
68 | 
69 |     a = torch.from_numpy(a).long()
70 |     b = torch.from_numpy(b).long()
71 | 
72 |     # Accuracy expects logits shaped [N, C, H, W]; turn the label map `a` into
73 |     # one-hot pseudo-logits so that the argmax inside Accuracy recovers it exactly
74 |     logits = torch.nn.functional.one_hot(a, num_classes=2).permute(2, 0, 1).unsqueeze(0).float()
75 |     # foreground overlap of a and b: intersection = 16, union = 49 + 49 - 16 = 82
76 |     acc = Accuracy(logits, b.unsqueeze(0))
77 |     print(acc)
--------------------------------------------------------------------------------
/models/BaseModel.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import numpy as np
 4 | 
 5 | class BaseModel(nn.Module):
 6 |     """
 7 |     Base class for all models
 8 |     """
 9 |     def __init__(self):
10 |         super(BaseModel, self).__init__()
11 | 
12 |     def forward(self, *input):
13 |         """
14 |         Forward pass logic
15 |         :param input:
16 |         :return: model output
17 |         """
18 |         raise NotImplementedError
19 | 
20 |     def summary(self):
21 |         """
22 |         Model Summary
23 |         :return: None
24 |         """
25 |         # pick out the trainable parameters
26 |         #model_parameters = filter(lambda p: p.requires_grad, self.parameters())
27 |         # np.prod computes the product of all elements along an axis
28 |         #params = sum([np.prod(p.size()) for p in model_parameters])
29 |         params = sum(p.numel() for p in self.parameters() if p.requires_grad)
30 |         print("Trainable parameters: {}".format(params))
31 | 
32 |     def __str__(self):
33 |         """
34 |         Model prints with the number of trainable parameters
35 |         :return:
36 |         """
37 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
38 |         params = sum([np.prod(p.size()) for p in model_parameters])
39 |         return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params)
--------------------------------------------------------------------------------
/models/EDANet.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | import torch.nn.init as init
 5 | from models.BaseModel import BaseModel
 6 | 
 7 | class DownsamplerBlock(nn.Module):
 8 |     def __init__(self, ninput, noutput):
 9 |         super(DownsamplerBlock, self).__init__()
10 | 
11 |         self.ninput = ninput
12 |         self.noutput = noutput
13 | 
14 |         if self.ninput < self.noutput:
15 |             self.conv = nn.Conv2d(ninput, noutput-ninput,
16 |                                   kernel_size=3, stride=2, padding=1)
17 |             self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
18 | 
19 |         else:
20 |             self.conv = nn.Conv2d(ninput, noutput,
21 |                                   kernel_size=3, stride=2, padding=1)
22 | 
23 |         self.bn = nn.BatchNorm2d(noutput)
24 | 
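    # When the block widens the feature map (ninput < noutput), the strided
    # 3x3 conv only produces noutput - ninput channels; the max-pooled input
    # supplies the remaining ninput channels, and forward() concatenates the
    # two before the shared BatchNorm.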
25 |     def forward(self, input):
26 |         if self.ninput < self.noutput:
27 |             output = torch.cat([self.conv(input), self.pool(input)], 1)
28 |         else:
29 |             output = self.conv(input)
30 | 
31 |         output = self.bn(output)
32 |         return F.relu(output)
33 | 
34 | class EDABlock(nn.Module):
35 |     def __init__(self, ninput, dilated, k=40, dropprob=0.02):
36 |         super(EDABlock, self).__init__()
37 | 
38 |         self.conv1x1 = nn.Conv2d(ninput, k, kernel_size=1)
39 |         self.bn0 = nn.BatchNorm2d(k)
40 | 
41 |         self.conv3x1_1 = nn.Conv2d(k, k, kernel_size=(3,1), padding=(1,0))
42 |         self.conv1x3_1 = nn.Conv2d(k, k, kernel_size=(1,3), padding=(0,1))
43 |         self.bn1 = nn.BatchNorm2d(k)
44 |         # for the dilated 3x1/1x3 pair below, padding = dilation * (kernel_size-1)/2
45 |         # keeps the spatial size unchanged (the plain pair above uses (kernel_size-1)/2)
46 |         self.conv3x1_2 = nn.Conv2d(k, k, kernel_size=(3,1), padding=(dilated,0), dilation=dilated)
47 |         self.conv1x3_2 = nn.Conv2d(k, k, kernel_size=(1,3), padding=(0,dilated), dilation=dilated)
48 |         self.bn2 = nn.BatchNorm2d(k)
49 | 
50 |         self.dropout = nn.Dropout2d(dropprob)
51 | 
52 |     def forward(self, input):
53 | 
54 |         x = input
55 |         output = self.conv1x1(input)
56 |         output = self.bn0(output)
57 |         output = F.relu(output)
58 | 
59 |         output = self.conv3x1_1(output)
60 |         output = self.conv1x3_1(output)
61 |         output = self.bn1(output)
62 |         output = F.relu(output)
63 | 
64 |         output = self.conv3x1_2(output)
65 |         output = self.conv1x3_2(output)
66 |         output = self.bn2(output)
67 |         output = F.relu(output)
68 | 
69 |         if self.dropout.p != 0:
70 |             output = self.dropout(output)
71 |         output = torch.cat([output, x], 1)
72 | 
73 |         return output
74 | 
75 | class EDANet(BaseModel):
76 |     def __init__(self, config):
77 |         super(EDANet, self).__init__()
78 | 
79 |         self.name='EDANet'
80 |         self.nb_classes = config.nb_classes
81 | 
82 |         self.layers = nn.ModuleList()
83 |         # for stage1
84 |         self.dilation1 = [1,1,1,2,2]
85 |         # for stage2
86 |         self.dilation2 = [2,2,4,4,8,8,16,16]
87 | 
88 |         self.layers.append(DownsamplerBlock(3,15))
89 |         self.layers.append(DownsamplerBlock(15,60))
90 | 
91 |         for i in range(5):
92 |             self.layers.append(EDABlock(60+40*i, self.dilation1[i]))
93 | 
94 |         self.layers.append(DownsamplerBlock(260,130))
95 | 
96 |         for j in range(8):
97 |             self.layers.append(EDABlock(130+40*j, self.dilation2[j]))
98 | 
99 |         # projection layer
100 |         self.project_layer = nn.Conv2d(450, self.nb_classes, kernel_size=1)
101 | 
102 |         self.weights_init()
103 | 
104 | 
105 |     def weights_init(self):
106 |         for idx, m in enumerate(self.modules()):
107 |             classname = m.__class__.__name__
108 |             if classname.find('Conv') != -1:
109 |                 init.kaiming_normal_(m.weight.data)
110 |             elif classname.find('BatchNorm') != -1:
111 |                 m.weight.data.normal_(1.0, 0.02)
112 |                 m.bias.data.fill_(0)
113 | 
114 |     def forward(self,x):
115 |         output = x
116 | 
117 |         for layer in self.layers:
118 |             output = layer(output)
119 | 
120 |         output = self.project_layer(output)
121 | 
122 |         # bilinear interpolation x8
123 |         output = F.interpolate(output, scale_factor=8, mode='bilinear', align_corners=True)
124 | 
125 |         # bilinear interpolation x2
126 |         #if not self.training:
127 |         #    output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=True)
128 | 
129 |         return output
130 | 
131 | if __name__ == '__main__':
132 | 
133 |     from types import SimpleNamespace
134 | 
135 |     input = torch.randn(1,3,512,512)
136 |     # for inference only; EDANet needs a config object exposing nb_classes,
137 |     # so a small stub is enough for this smoke test
138 |     model = EDANet(config=SimpleNamespace(nb_classes=2)).eval()
139 |     print(model)
140 |     output = model(input)
141 |     print(output.shape)
142 | 
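A quick sanity check of EDANet's channel bookkeeping, before moving on to ENet.py below: every EDABlock concatenates its k=40 output channels onto its input, so stage 1 grows 60 -> 260 over five blocks, the middle DownsamplerBlock compresses 260 -> 130, and stage 2 grows 130 -> 450 over eight blocks, which is exactly the in_channels of project_layer. A minimal sketch of that arithmetic (not part of the repo):

k = 40                    # growth rate of every EDABlock
stage1_out = 60 + 5 * k   # 260, the input of DownsamplerBlock(260, 130)
stage2_out = 130 + 8 * k  # 450, the input of the 1x1 projection layer
assert (stage1_out, stage2_out) == (260, 450)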
-------------------------------------------------------------------------------- /models/ENet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from models.BaseModel import BaseModel 4 | import torch.nn.functional as F 5 | 6 | 7 | class InitialBlock(nn.Module): 8 | def __init__(self, in_channels,out_channels, kernel_size, padding=0, bias=False,relu=True): 9 | super(InitialBlock, self).__init__() 10 | 11 | if relu: 12 | activation = nn.ReLU() 13 | else: 14 | activation = nn.PReLU() 15 | 16 | self.main_branch = nn.Conv2d( 17 | in_channels, 18 | out_channels-3, 19 | kernel_size=kernel_size, 20 | stride=2, 21 | padding=padding, 22 | bias=bias, 23 | ) 24 | # MP need padding too 25 | self.ext_branch = nn.MaxPool2d(kernel_size, stride=2, padding=padding) 26 | self.batch_norm = nn.BatchNorm2d(out_channels) 27 | self.out_prelu = activation 28 | 29 | 30 | def forward(self, input): 31 | main = self.main_branch(input) 32 | ext = self.ext_branch(input) 33 | 34 | out = torch.cat((main, ext), dim=1) 35 | 36 | out = self.batch_norm(out) 37 | return self.out_prelu(out) 38 | 39 | class RegularBottleneck(nn.Module): 40 | def __init__(self, channels, internal_ratio=4, kernel_size=3, padding=0, 41 | dilation=1, asymmetric=False, dropout_prob=0., bias=False, relu=True): 42 | super(RegularBottleneck, self).__init__() 43 | 44 | internal_channels = channels // internal_ratio 45 | 46 | if relu: 47 | activation = nn.ReLU() 48 | else: 49 | activation = nn.PReLU() 50 | 51 | # 1x1 projection conv 52 | self.ext_conv1 = nn.Sequential( 53 | nn.Conv2d(channels, internal_channels, kernel_size=1, stride=1, bias=bias), 54 | nn.BatchNorm2d(internal_channels), 55 | activation, 56 | ) 57 | if asymmetric: 58 | self.ext_conv2 = nn.Sequential( 59 | nn.Conv2d(internal_channels, internal_channels, kernel_size=(kernel_size,1), 60 | stride=1, padding=(padding,0), dilation=dilation, bias=bias), 61 | nn.BatchNorm2d(internal_channels), 62 | activation, 63 | nn.Conv2d(internal_channels, internal_channels, kernel_size=(1,kernel_size), 64 | stride=1, padding=(0, padding), dilation=dilation, bias=bias), 65 | nn.BatchNorm2d(internal_channels), 66 | activation, 67 | ) 68 | else: 69 | self.ext_conv2 = nn.Sequential( 70 | nn.Conv2d(internal_channels, internal_channels, kernel_size=kernel_size, 71 | stride=1, padding=padding, dilation=dilation, bias=bias), 72 | nn.BatchNorm2d(internal_channels), 73 | activation, 74 | ) 75 | 76 | self.ext_conv3 = nn.Sequential( 77 | nn.Conv2d(internal_channels, channels, kernel_size=1, stride=1, bias=bias), 78 | nn.BatchNorm2d(channels), 79 | activation, 80 | ) 81 | self.ext_regu1 = nn.Dropout2d(p=dropout_prob) 82 | self.out_prelu = activation 83 | 84 | def forward(self, input): 85 | main = input 86 | 87 | ext = self.ext_conv1(input) 88 | ext = self.ext_conv2(ext) 89 | ext = self.ext_conv3(ext) 90 | ext = self.ext_regu1(ext) 91 | 92 | out = main + ext 93 | return self.out_prelu(out) 94 | 95 | class DownsamplingBottleneck(nn.Module): 96 | def __init__(self, 97 | in_channels, 98 | out_channels, 99 | internal_ratio=4, 100 | kernel_size=3, 101 | padding=0, 102 | return_indices=False, 103 | dropout_prob=0., 104 | bias=False, 105 | relu=True): 106 | super().__init__() 107 | 108 | # Store parameters that are needed later 109 | self.return_indices = return_indices 110 | 111 | internal_channels = in_channels // internal_ratio 112 | 113 | if relu: 114 | activation = nn.ReLU() 115 | else: 116 | activation = nn.PReLU() 117 | 118 | # Main branch - 
max pooling followed by feature map (channels) padding
119 |         self.main_max1 = nn.MaxPool2d(
120 |             kernel_size,
121 |             stride=2,
122 |             padding=padding,
123 |             return_indices=return_indices)
124 | 
125 |         # Extension branch - 2x2 convolution, followed by a regular, dilated or
126 |         # asymmetric convolution, followed by a 1x1 convolution. Number
127 |         # of channels is doubled.
128 | 
129 |         # 2x2 projection convolution with stride 2, no padding
130 |         self.ext_conv1 = nn.Sequential(
131 |             nn.Conv2d(in_channels,internal_channels,kernel_size=2,stride=2,bias=bias),
132 |             nn.BatchNorm2d(internal_channels),
133 |             activation
134 |         )
135 | 
136 |         # Convolution
137 |         self.ext_conv2 = nn.Sequential(
138 |             nn.Conv2d(
139 |                 internal_channels,
140 |                 internal_channels,
141 |                 kernel_size=kernel_size,
142 |                 stride=1,
143 |                 padding=padding,
144 |                 bias=bias), nn.BatchNorm2d(internal_channels), activation)
145 | 
146 |         # 1x1 expansion convolution
147 |         self.ext_conv3 = nn.Sequential(
148 |             nn.Conv2d(
149 |                 internal_channels,
150 |                 out_channels,
151 |                 kernel_size=1,
152 |                 stride=1,
153 |                 bias=bias), nn.BatchNorm2d(out_channels), activation)
154 | 
155 |         self.ext_regul = nn.Dropout2d(p=dropout_prob)
156 | 
157 |         # PReLU layer to apply after concatenating the branches
158 |         self.out_prelu = activation
159 | 
160 |     def forward(self, x):
161 |         # Main branch shortcut
162 |         if self.return_indices:
163 |             main, max_indices = self.main_max1(x)
164 |         else:
165 |             main, max_indices = self.main_max1(x), None  # keep max_indices defined for the return below
166 | 
167 |         # Extension branch
168 |         ext = self.ext_conv1(x)
169 |         ext = self.ext_conv2(ext)
170 |         ext = self.ext_conv3(ext)
171 |         ext = self.ext_regul(ext)
172 | 
173 |         # Main branch channel padding
174 |         # the main branch needs ch_ext - ch_main extra channels
175 |         n, ch_ext, h, w = ext.size()
176 |         ch_main = main.size()[1]
177 |         padding = torch.zeros(n, ch_ext - ch_main, h, w)
178 | 
179 |         # Before concatenating, check if main is on the CPU or GPU and
180 |         # convert padding accordingly
181 |         if main.is_cuda:
182 |             padding = padding.cuda()
183 | 
184 |         # Concatenate: zero-pad the main branch up to the extension branch's channels
185 |         main = torch.cat((main, padding), 1)
186 | 
187 |         # Add main and extension branches
188 |         out = main + ext
189 | 
190 |         return self.out_prelu(out), max_indices
191 | 
192 | class UpsamplingBottleneck(nn.Module):
193 |     def __init__(self,
194 |                  in_channels,
195 |                  out_channels,
196 |                  internal_ratio=4,
197 |                  kernel_size=3,
198 |                  padding=0,
199 |                  dropout_prob=0.,
200 |                  bias=False,
201 |                  relu=True):
202 |         super().__init__()
203 | 
204 |         internal_channels = in_channels // internal_ratio
205 | 
206 |         if relu:
207 |             activation = nn.ReLU()
208 |         else:
209 |             activation = nn.PReLU()
210 | 
211 |         # Main branch - 1x1 convolution followed by max unpooling
212 |         self.main_conv1 = nn.Sequential(
213 |             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
214 |             nn.BatchNorm2d(out_channels))
215 | 
216 |         # Remember that the stride is the same as the kernel_size, just like
217 |         # the max pooling layers
218 |         self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
219 | 
220 |         # Extension branch - 1x1 convolution, followed by a regular, dilated or
221 |         # asymmetric convolution, followed by another 1x1 convolution. Number
222 |         # of channels is reduced.
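        # Note: main_unpool1 consumes the pooling indices saved by the matching
        # DownsamplingBottleneck (return_indices=True), so encoder and decoder
        # stages are paired one-to-one in ENet.forward below.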
223 | 224 | # 1x1 projection convolution with stride 1 225 | self.ext_conv1 = nn.Sequential( 226 | nn.Conv2d( 227 | in_channels, internal_channels, kernel_size=1, bias=bias), 228 | nn.BatchNorm2d(internal_channels), activation) 229 | 230 | # Transposed convolution 231 | self.ext_conv2 = nn.Sequential( 232 | nn.ConvTranspose2d( 233 | internal_channels, 234 | internal_channels, 235 | kernel_size=kernel_size, 236 | stride=2, 237 | padding=padding, 238 | output_padding=1, 239 | bias=bias), nn.BatchNorm2d(internal_channels), activation) 240 | 241 | # 1x1 expansion convolution 242 | self.ext_conv3 = nn.Sequential( 243 | nn.Conv2d( 244 | internal_channels, out_channels, kernel_size=1, bias=bias), 245 | nn.BatchNorm2d(out_channels), activation) 246 | 247 | self.ext_regul = nn.Dropout2d(p=dropout_prob) 248 | 249 | # PReLU layer to apply after concatenating the branches 250 | self.out_prelu = activation 251 | 252 | def forward(self, x, max_indices): 253 | # Main branch shortcut 254 | main = self.main_conv1(x) 255 | main = self.main_unpool1(main, max_indices) 256 | # Extension branch 257 | ext = self.ext_conv1(x) 258 | ext = self.ext_conv2(ext) 259 | ext = self.ext_conv3(ext) 260 | ext = self.ext_regul(ext) 261 | 262 | # Add main and extension branches 263 | out = main + ext 264 | 265 | return self.out_prelu(out) 266 | 267 | class ENet(nn.Module): 268 | def __init__(self, num_classes, encoder_relu=False, decoder_relu=True): 269 | super().__init__() 270 | # source code 271 | self.name='BaseLine_ENet_trans' 272 | 273 | self.initial_block = InitialBlock(3, 16, kernel_size=3 ,padding=1, relu=encoder_relu) 274 | 275 | # Stage 1 - Encoder 276 | self.downsample1_0 = DownsamplingBottleneck( 277 | 16, 278 | 64, 279 | padding=1, 280 | return_indices=True, 281 | dropout_prob=0.01, 282 | relu=encoder_relu) 283 | self.regular1_1 = RegularBottleneck( 284 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 285 | self.regular1_2 = RegularBottleneck( 286 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 287 | self.regular1_3 = RegularBottleneck( 288 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 289 | self.regular1_4 = RegularBottleneck( 290 | 64, padding=1, dropout_prob=0.01, relu=encoder_relu) 291 | 292 | # Stage 2 - Encoder 293 | self.downsample2_0 = DownsamplingBottleneck( 294 | 64, 295 | 128, 296 | padding=1, 297 | return_indices=True, 298 | dropout_prob=0.1, 299 | relu=encoder_relu) 300 | self.regular2_1 = RegularBottleneck( 301 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 302 | self.dilated2_2 = RegularBottleneck( 303 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 304 | self.asymmetric2_3 = RegularBottleneck( 305 | 128, 306 | kernel_size=5, 307 | padding=2, 308 | asymmetric=True, 309 | dropout_prob=0.1, 310 | relu=encoder_relu) 311 | self.dilated2_4 = RegularBottleneck( 312 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 313 | self.regular2_5 = RegularBottleneck( 314 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 315 | self.dilated2_6 = RegularBottleneck( 316 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 317 | self.asymmetric2_7 = RegularBottleneck( 318 | 128, 319 | kernel_size=5, 320 | asymmetric=True, 321 | padding=2, 322 | dropout_prob=0.1, 323 | relu=encoder_relu) 324 | self.dilated2_8 = RegularBottleneck( 325 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 326 | 327 | # Stage 3 - Encoder 328 | self.regular3_0 = RegularBottleneck( 329 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 330 
| self.dilated3_1 = RegularBottleneck( 331 | 128, dilation=2, padding=2, dropout_prob=0.1, relu=encoder_relu) 332 | self.asymmetric3_2 = RegularBottleneck( 333 | 128, 334 | kernel_size=5, 335 | padding=2, 336 | asymmetric=True, 337 | dropout_prob=0.1, 338 | relu=encoder_relu) 339 | self.dilated3_3 = RegularBottleneck( 340 | 128, dilation=4, padding=4, dropout_prob=0.1, relu=encoder_relu) 341 | self.regular3_4 = RegularBottleneck( 342 | 128, padding=1, dropout_prob=0.1, relu=encoder_relu) 343 | self.dilated3_5 = RegularBottleneck( 344 | 128, dilation=8, padding=8, dropout_prob=0.1, relu=encoder_relu) 345 | self.asymmetric3_6 = RegularBottleneck( 346 | 128, 347 | kernel_size=5, 348 | asymmetric=True, 349 | padding=2, 350 | dropout_prob=0.1, 351 | relu=encoder_relu) 352 | self.dilated3_7 = RegularBottleneck( 353 | 128, dilation=16, padding=16, dropout_prob=0.1, relu=encoder_relu) 354 | 355 | # Stage 4 - Decoder 356 | self.upsample4_0 = UpsamplingBottleneck( 357 | 128, 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 358 | self.regular4_1 = RegularBottleneck( 359 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 360 | self.regular4_2 = RegularBottleneck( 361 | 64, padding=1, dropout_prob=0.1, relu=decoder_relu) 362 | 363 | # Stage 5 - Decoder 364 | self.upsample5_0 = UpsamplingBottleneck( 365 | 64, 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 366 | self.regular5_1 = RegularBottleneck( 367 | 16, padding=1, dropout_prob=0.1, relu=decoder_relu) 368 | self.transposed_conv = nn.ConvTranspose2d( 369 | 16, 370 | num_classes, 371 | kernel_size=3, 372 | stride=2, 373 | padding=1, 374 | output_padding=1, 375 | bias=False) 376 | 377 | self.project_layer = nn.Conv2d(128, num_classes, 1, bias=False) 378 | 379 | def forward(self, x): 380 | # Initial block 381 | x = self.initial_block(x) 382 | 383 | # Stage 1 - Encoder 384 | x, max_indices1_0 = self.downsample1_0(x) 385 | x = self.regular1_1(x) 386 | x = self.regular1_2(x) 387 | x = self.regular1_3(x) 388 | x = self.regular1_4(x) 389 | 390 | # Stage 2 - Encoder 391 | x, max_indices2_0 = self.downsample2_0(x) 392 | x = self.regular2_1(x) 393 | x = self.dilated2_2(x) 394 | x = self.asymmetric2_3(x) 395 | x = self.dilated2_4(x) 396 | x = self.regular2_5(x) 397 | x = self.dilated2_6(x) 398 | x = self.asymmetric2_7(x) 399 | x = self.dilated2_8(x) 400 | 401 | # Stage 3 - Encoder 402 | x = self.regular3_0(x) 403 | x = self.dilated3_1(x) 404 | x = self.asymmetric3_2(x) 405 | x = self.dilated3_3(x) 406 | x = self.regular3_4(x) 407 | x = self.dilated3_5(x) 408 | x = self.asymmetric3_6(x) 409 | x = self.dilated3_7(x) 410 | 411 | #x = self.project_layer(x) 412 | #x = F.interpolate(x, scale_factor=8, mode='bilinear', align_corners=True) 413 | 414 | # Stage 4 - Decoder 415 | x = self.upsample4_0(x, max_indices2_0) 416 | x = self.regular4_1(x) 417 | x = self.regular4_2(x) 418 | 419 | # Stage 5 - Decoder 420 | x = self.upsample5_0(x, max_indices1_0) 421 | x = self.regular5_1(x) 422 | x = self.transposed_conv(x) 423 | 424 | 425 | return x 426 | 427 | if __name__ == '__main__': 428 | 429 | 430 | model = ENet(num_classes=2, encoder_relu=True, decoder_relu=True) 431 | -------------------------------------------------------------------------------- /models/ERFNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.BaseModel import BaseModel 5 | 6 | 7 | class down_sampling_block(nn.Module): 8 | 9 | def __init__(self, inpc, oupc): 10 | 
super(down_sampling_block, self).__init__() 11 | self.branch_conv = nn.Conv2d(inpc, oupc-inpc, 3, stride=2, padding= 1, bias=False) 12 | self.branch_mp = nn.MaxPool2d(kernel_size=2, stride=2) 13 | self.bn = nn.BatchNorm2d(oupc, eps=1e-03) 14 | 15 | def forward(self, x): 16 | output = torch.cat([self.branch_conv(x), self.branch_mp(x)], 1) 17 | output = self.bn(output) 18 | return F.relu(output) 19 | 20 | class up_sampling_block(nn.Module): 21 | 22 | def __init__(self, inpc, oupc): 23 | super(up_sampling_block, self).__init__() 24 | self.conv = nn.ConvTranspose2d(inpc, oupc, 3, 2, padding=1, output_padding=1, bias=False) 25 | self.bn = nn.BatchNorm2d(oupc) 26 | 27 | def forward(self, input): 28 | output = self.conv(input) 29 | output = self.bn(output) 30 | return F.relu(output) 31 | 32 | class non_bottleneck_1d(nn.Module): 33 | # stride = 1 kernel_size=3 34 | def __init__(self, inpc, oupc, dilated_rate, dropout_rate): 35 | super(non_bottleneck_1d, self).__init__() 36 | self.conv3x1_1 = nn.Conv2d(inpc, oupc, (3, 1), stride=1, padding=(1,0), bias=False) 37 | self.conv1x3_1 = nn.Conv2d(inpc, oupc, (1, 3), stride=1, padding=(0,1), bias=False) 38 | self.bn1 = nn.BatchNorm2d(oupc) 39 | self.conv3x1_2 = nn.Conv2d(inpc, oupc, (3, 1), stride=1, padding=(1*dilated_rate,0), dilation=(dilated_rate,1), bias=False) 40 | self.conv1x3_2 = nn.Conv2d(inpc, oupc, (1, 3), stride=1, padding=(0,1*dilated_rate), dilation=(1,dilated_rate), bias=False) 41 | self.bn2 = nn.BatchNorm2d(oupc) 42 | self.dropout = nn.Dropout2d(dropout_rate) 43 | 44 | def forward(self, input): 45 | output = self.conv3x1_1(input) 46 | output = F.relu(output) 47 | output = self.conv1x3_1(output) 48 | output = self.bn1(output) 49 | output = F.relu(output) 50 | 51 | output = self.conv3x1_2(output) 52 | output = F.relu(output) 53 | output = self.conv1x3_2(output) 54 | output = self.bn2(output) 55 | 56 | if self.dropout.p != 0: 57 | output = self.dropout(output) 58 | output += input 59 | return F.relu(output) 60 | 61 | class ERFNet(BaseModel): 62 | 63 | def __init__(self, 64 | config): 65 | super(ERFNet, self).__init__() 66 | 67 | self.config = config 68 | self.name="ERFNet" 69 | self.nb_classes = self.config.nb_classes 70 | #self.nb_classes = nb_classes 71 | 72 | # input 512 73 | self.initial_block = down_sampling_block(3, 16) 74 | # output 256 75 | self.layers = nn.ModuleList() 76 | 77 | self.layers.append(down_sampling_block(16, 64)) 78 | # output 128 79 | for index in range(0,5): 80 | self.layers.append(non_bottleneck_1d(64, 64, 1, 0.03)) 81 | 82 | self.layers.append(down_sampling_block(64, 128)) 83 | 84 | for index in range(0,2): 85 | self.layers.append(non_bottleneck_1d(128, 128, 2, 0.3)) 86 | self.layers.append(non_bottleneck_1d(128, 128, 4, 0.3)) 87 | self.layers.append(non_bottleneck_1d(128, 128, 8, 0.3)) 88 | self.layers.append(non_bottleneck_1d(128, 128, 16, 0.3)) 89 | 90 | self.layers.append(up_sampling_block(128, 64)) 91 | self.layers.append(non_bottleneck_1d(64, 64, 1, 0)) 92 | self.layers.append(non_bottleneck_1d(64, 64, 1, 0)) 93 | 94 | self.layers.append(up_sampling_block(64, 16)) 95 | self.layers.append(non_bottleneck_1d(16, 16, 1, 0)) 96 | self.layers.append(non_bottleneck_1d(16, 16, 1, 0)) 97 | 98 | self.output_conv = nn.ConvTranspose2d(16, self.nb_classes, 2, 2, padding=0, output_padding=0, bias=False) 99 | 100 | def forward(self, x): 101 | output = self.initial_block(x) 102 | for layer in self.layers: 103 | output = layer(output) 104 | 105 | output = self.output_conv(output) 106 | return output 107 | 108 | if __name__ == 
'__main__':
109 | 
110 |     from types import SimpleNamespace
111 | 
112 |     # ERFNet requires a config object exposing nb_classes; a stub is enough here
113 |     model = ERFNet(config=SimpleNamespace(nb_classes=2))
114 |     print(model)
--------------------------------------------------------------------------------
/models/FCN.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | #from __future__ import print_function
 3 | import torch
 4 | import torch.nn as nn
 5 | import torch.optim as optim
 6 | from torchvision import models
 7 | from torchvision.models.vgg import VGG
 8 | 
 9 | 
10 | class FCN32s(nn.Module):
11 | 
12 |     def __init__(self, pretrained_net, n_class):
13 |         super().__init__()
14 |         self.n_class = n_class
15 |         self.pretrained_net = pretrained_net
16 |         self.relu = nn.ReLU(inplace=True)
17 |         self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
18 |         self.bn1 = nn.BatchNorm2d(512)
19 |         self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
20 |         self.bn2 = nn.BatchNorm2d(256)
21 |         self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
22 |         self.bn3 = nn.BatchNorm2d(128)
23 |         self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
24 |         self.bn4 = nn.BatchNorm2d(64)
25 |         self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
26 |         self.bn5 = nn.BatchNorm2d(32)
27 |         self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
28 | 
29 |     def forward(self, x):
30 |         output = self.pretrained_net(x)
31 |         x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
32 | 
33 |         score = self.bn1(self.relu(self.deconv1(x5)))     # size=(N, 512, x.H/16, x.W/16)
34 |         score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)
35 |         score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
36 |         score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
37 |         score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
38 |         score = self.classifier(score)                    # size=(N, n_class, x.H/1, x.W/1)
39 | 
40 |         return score  # size=(N, n_class, x.H/1, x.W/1)
41 | 
42 | 
43 | class FCN16s(nn.Module):
44 | 
45 |     def __init__(self, pretrained_net, n_class):
46 |         super().__init__()
47 |         self.n_class = n_class
48 |         self.pretrained_net = pretrained_net
49 |         self.relu = nn.ReLU(inplace=True)
50 |         self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
51 |         self.bn1 = nn.BatchNorm2d(512)
52 |         self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
53 |         self.bn2 = nn.BatchNorm2d(256)
54 |         self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
55 |         self.bn3 = nn.BatchNorm2d(128)
56 |         self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
57 |         self.bn4 = nn.BatchNorm2d(64)
58 |         self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1)
59 |         self.bn5 = nn.BatchNorm2d(32)
60 |         self.classifier = nn.Conv2d(32, n_class, kernel_size=1)
61 | 
62 |     def forward(self, x):
63 |         output = self.pretrained_net(x)
64 |         x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
65 |         x4 = output['x4']  # size=(N, 512, x.H/16, x.W/16)
66 | 
67 |         score = self.relu(self.deconv1(x5))  # size=(N, 512, x.H/16, x.W/16)
68 |         score = self.bn1(score + x4)         # element-wise add, size=(N, 512, x.H/16, x.W/16)
69 | 
score = self.bn2(self.relu(self.deconv2(score))) # size=(N, 256, x.H/8, x.W/8) 70 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 71 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 72 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 73 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 74 | 75 | return score # size=(N, n_class, x.H/1, x.W/1) 76 | 77 | 78 | class FCN8s(nn.Module): 79 | 80 | def __init__(self, pretrained_net, n_class): 81 | super().__init__() 82 | self.n_class = n_class 83 | self.pretrained_net = pretrained_net 84 | self.relu = nn.ReLU(inplace=True) 85 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 86 | self.bn1 = nn.BatchNorm2d(512) 87 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 88 | self.bn2 = nn.BatchNorm2d(256) 89 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 90 | self.bn3 = nn.BatchNorm2d(128) 91 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 92 | self.bn4 = nn.BatchNorm2d(64) 93 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 94 | self.bn5 = nn.BatchNorm2d(32) 95 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 96 | 97 | def forward(self, x): 98 | output = self.pretrained_net(x) 99 | x5 = output['x5'] # size=(N, 512, x.H/32, x.W/32) 100 | x4 = output['x4'] # size=(N, 512, x.H/16, x.W/16) 101 | x3 = output['x3'] # size=(N, 256, x.H/8, x.W/8) 102 | 103 | score = self.relu(self.deconv1(x5)) # size=(N, 512, x.H/16, x.W/16) 104 | score = self.bn1(score + x4) # element-wise add, size=(N, 512, x.H/16, x.W/16) 105 | score = self.relu(self.deconv2(score)) # size=(N, 256, x.H/8, x.W/8) 106 | score = self.bn2(score + x3) # element-wise add, size=(N, 256, x.H/8, x.W/8) 107 | score = self.bn3(self.relu(self.deconv3(score))) # size=(N, 128, x.H/4, x.W/4) 108 | score = self.bn4(self.relu(self.deconv4(score))) # size=(N, 64, x.H/2, x.W/2) 109 | score = self.bn5(self.relu(self.deconv5(score))) # size=(N, 32, x.H, x.W) 110 | score = self.classifier(score) # size=(N, n_class, x.H/1, x.W/1) 111 | 112 | return score # size=(N, n_class, x.H/1, x.W/1) 113 | 114 | 115 | class FCNs(nn.Module): 116 | 117 | def __init__(self, pretrained_net, n_class): 118 | super().__init__() 119 | self.n_class = n_class 120 | self.pretrained_net = pretrained_net 121 | self.relu = nn.ReLU(inplace=True) 122 | self.deconv1 = nn.ConvTranspose2d(512, 512, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 123 | self.bn1 = nn.BatchNorm2d(512) 124 | self.deconv2 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 125 | self.bn2 = nn.BatchNorm2d(256) 126 | self.deconv3 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 127 | self.bn3 = nn.BatchNorm2d(128) 128 | self.deconv4 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 129 | self.bn4 = nn.BatchNorm2d(64) 130 | self.deconv5 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, dilation=1, output_padding=1) 131 | self.bn5 = nn.BatchNorm2d(32) 132 | self.classifier = nn.Conv2d(32, n_class, kernel_size=1) 133 | 134 | def forward(self, x): 135 | output = self.pretrained_net(x) 
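        # FCNs fuses all five pooling stages: each upsampled score map is summed
        # with the matching encoder feature map before the next deconvolution.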
136 |         x5 = output['x5']  # size=(N, 512, x.H/32, x.W/32)
137 |         x4 = output['x4']  # size=(N, 512, x.H/16, x.W/16)
138 |         x3 = output['x3']  # size=(N, 256, x.H/8, x.W/8)
139 |         x2 = output['x2']  # size=(N, 128, x.H/4, x.W/4)
140 |         x1 = output['x1']  # size=(N, 64, x.H/2, x.W/2)
141 | 
142 |         score = self.bn1(self.relu(self.deconv1(x5)))     # size=(N, 512, x.H/16, x.W/16)
143 |         score = score + x4                                # element-wise add, size=(N, 512, x.H/16, x.W/16)
144 |         score = self.bn2(self.relu(self.deconv2(score)))  # size=(N, 256, x.H/8, x.W/8)
145 |         score = score + x3                                # element-wise add, size=(N, 256, x.H/8, x.W/8)
146 |         score = self.bn3(self.relu(self.deconv3(score)))  # size=(N, 128, x.H/4, x.W/4)
147 |         score = score + x2                                # element-wise add, size=(N, 128, x.H/4, x.W/4)
148 |         score = self.bn4(self.relu(self.deconv4(score)))  # size=(N, 64, x.H/2, x.W/2)
149 |         score = score + x1                                # element-wise add, size=(N, 64, x.H/2, x.W/2)
150 |         score = self.bn5(self.relu(self.deconv5(score)))  # size=(N, 32, x.H, x.W)
151 |         score = self.classifier(score)                    # size=(N, n_class, x.H/1, x.W/1)
152 | 
153 |         return score  # size=(N, n_class, x.H/1, x.W/1)
154 | 
155 | 
156 | class VGGNet(VGG):
157 |     def __init__(self, pretrained=True, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
158 |         super().__init__(make_layers(cfg[model]))
159 |         self.ranges = ranges[model]
160 | 
161 |         if pretrained:
162 |             self.load_state_dict(getattr(models, model)(pretrained=True).state_dict())  # look the constructor up by name on torchvision.models
163 | 
164 |         if not requires_grad:
165 |             for param in super().parameters():
166 |                 param.requires_grad = False
167 | 
168 |         if remove_fc:  # delete redundant fully-connected layer params, can save memory
169 |             del self.classifier
170 | 
171 |         if show_params:
172 |             for name, param in self.named_parameters():
173 |                 print(name, param.size())
174 | 
175 |     def forward(self, x):
176 |         output = {}
177 | 
178 |         # get the output of each maxpooling layer (5 maxpool in VGG net)
179 |         for idx in range(len(self.ranges)):
180 |             for layer in range(self.ranges[idx][0], self.ranges[idx][1]):
181 |                 x = self.features[layer](x)
182 |             output["x%d"%(idx+1)] = x
183 | 
184 |         return output
185 | 
186 | 
187 | ranges = {
188 |     'vgg11': ((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
189 |     'vgg13': ((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
190 |     'vgg16': ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
191 |     'vgg19': ((0, 5), (5, 10), (10, 19), (19, 28), (28, 37))
192 | }
193 | 
194 | # cropped version from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
195 | cfg = {
196 |     'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
197 |     'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
198 |     'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
199 |     'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
200 | }
201 | 
202 | def make_layers(cfg, batch_norm=False):
203 |     layers = []
204 |     in_channels = 3
205 |     for v in cfg:
206 |         if v == 'M':
207 |             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
208 |         else:
209 |             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
210 |             if batch_norm:
211 |                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
212 |             else:
213 |                 layers += [conv2d, nn.ReLU(inplace=True)]
214 |             in_channels = v
215 |     return nn.Sequential(*layers)
216 | 
217 | 
218 | if __name__ == "__main__":
219 |     batch_size, n_class, h, w = 10, 20, 160, 160
220 | 
221 |     # test output size
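    # h and w must be multiples of 32 here: VGG applies five stride-2 max-pools
    # (2**5 = 32) and every FCN variant upsamples by the same overall factor, so
    # the asserts below can require output size == input size exactly.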
222 |     vgg_model = VGGNet(requires_grad=True)
223 |     input = torch.randn(batch_size, 3, 224, 224)
224 |     output = vgg_model(input)
225 |     assert output['x5'].size() == torch.Size([batch_size, 512, 7, 7])
226 | 
227 |     fcn_model = FCN32s(pretrained_net=vgg_model, n_class=n_class)
228 |     input = torch.randn(batch_size, 3, h, w)
229 |     output = fcn_model(input)
230 |     assert output.size() == torch.Size([batch_size, n_class, h, w])
231 | 
232 |     fcn_model = FCN16s(pretrained_net=vgg_model, n_class=n_class)
233 |     input = torch.randn(batch_size, 3, h, w)
234 |     output = fcn_model(input)
235 |     assert output.size() == torch.Size([batch_size, n_class, h, w])
236 | 
237 |     fcn_model = FCN8s(pretrained_net=vgg_model, n_class=n_class)
238 |     input = torch.randn(batch_size, 3, h, w)
239 |     output = fcn_model(input)
240 |     assert output.size() == torch.Size([batch_size, n_class, h, w])
241 | 
242 |     fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
243 |     input = torch.randn(batch_size, 3, h, w)
244 |     output = fcn_model(input)
245 |     assert output.size() == torch.Size([batch_size, n_class, h, w])
246 | 
247 |     print("Pass size check")
248 | 
249 |     # test a random batch, loss should decrease
250 |     fcn_model = FCNs(pretrained_net=vgg_model, n_class=n_class)
251 |     criterion = nn.BCELoss()
252 |     optimizer = optim.SGD(fcn_model.parameters(), lr=1e-3, momentum=0.9)
253 |     input = torch.randn(batch_size, 3, h, w)
254 |     y = torch.rand(batch_size, n_class, h, w)  # BCELoss expects targets in [0, 1]
255 |     for it in range(10):
256 |         optimizer.zero_grad()
257 |         output = fcn_model(input)
258 |         output = torch.sigmoid(output)
259 |         loss = criterion(output, y)
260 |         loss.backward()
261 |         print("iter{}, loss {}".format(it, loss.item()))
262 |         optimizer.step()
263 | 
--------------------------------------------------------------------------------
/models/MyNetworks/ESFNet.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | from models.MyNetworks.layers import down_sampling_block, DownSamplingBlock_v2
 5 | from models.MyNetworks.layers import SFRB, nonbt_dw_fac, bt_dw, Separabel_conv2d, nonbt_dw
 6 | from models.MyNetworks.layers import bt, non_bt, bt_fac, nonbt_fac
 7 | 
 8 | class ESFNet(nn.Module):
 9 |     def __init__(self,
10 |                  config,
11 |                  down_factor=8,
12 |                  interpolate=True,
13 |                  dilation=True,
14 |                  dropout=False,
15 |                  ):
16 |         super(ESFNet, self).__init__()
17 | 
18 |         self.name = 'ESFNet_base'
19 |         self.nb_classes = config.nb_classes
20 |         self.down_factor = down_factor
21 |         self.interpolate = interpolate
22 |         self.stage_channels = [-1, 16, 64, 128, 256, 512]
23 | 
24 |         if dilation == True:
25 |             self.dilation_list = [1, 2, 4, 8, 16]
26 |         else:
27 |             self.dilation_list = [1, 1, 1, 1, 1]
28 | 
29 |         if dropout == True:
30 |             self.dropout_list = [0.01, 0.001]
31 | 
32 |         if down_factor==8:
33 |             # 8x downsampling
34 |             self.encoder = nn.Sequential(
35 |                 down_sampling_block(3, 16),
36 | 
37 |                 DownSamplingBlock_v2(in_channels=self.stage_channels[1], out_channels=self.stage_channels[2]),
38 | 
39 |                 SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0),
40 |                 SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0),
41 | 
SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 42 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 43 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 44 | 45 | DownSamplingBlock_v2(in_channels=self.stage_channels[2], out_channels=self.stage_channels[3]), 46 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0), 47 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[1], dropout_rate=0.0), 48 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0), 49 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[2], dropout_rate=0.0), 50 | 51 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0), 52 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[3], dropout_rate=0.0), 53 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0), 54 | SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[4], dropout_rate=0.0), 55 | 56 | ) 57 | if interpolate == True: 58 | self.project_layer = nn.Conv2d(self.stage_channels[3], self.nb_classes, 1, bias=False) 59 | else: 60 | self.decoder = nn.Sequential( 61 | nn.ConvTranspose2d(self.stage_channels[3], self.stage_channels[2], kernel_size=3, stride=2, 62 | padding=1, 63 | output_padding=1, bias=False), 64 | nn.BatchNorm2d(self.stage_channels[2]), 65 | nn.ReLU(inplace=True), 66 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], ), 67 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], ), 68 | nn.ConvTranspose2d(self.stage_channels[2], self.stage_channels[1], kernel_size=3, stride=2, 69 | padding=1, 70 | output_padding=1, bias=False), 71 | nn.BatchNorm2d(self.stage_channels[1]), 72 | nn.ReLU(inplace=True), 73 | SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1], ), 74 | SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1], ), 75 | nn.ConvTranspose2d(self.stage_channels[1], self.nb_classes, kernel_size=3, stride=2, padding=1, 76 | output_padding=1, bias=False) 77 | 78 | ) 79 | 80 | elif down_factor==16: 81 | # 16x downsampling 82 | self.encoder = nn.Sequential( 83 | down_sampling_block(3, 16), 84 | DownSamplingBlock_v2(self.stage_channels[1], self.stage_channels[2]), 85 | 86 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 87 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 88 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 89 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= self.dilation_list[0], dropout_rate=0.0), 90 | SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], dilation= 
self.dilation_list[0], dropout_rate=0.0),
91 | 
92 |                 DownSamplingBlock_v2(self.stage_channels[2], self.stage_channels[3]),
93 |                 SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0),
94 |                 SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0),
95 |                 SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0),
96 |                 SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0),
97 |                 SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3], dilation=self.dilation_list[0], dropout_rate=0.0),
98 | 
99 |                 DownSamplingBlock_v2(self.stage_channels[3], self.stage_channels[4]),
100 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[0], dropout_rate=0.0),
101 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[1], dropout_rate=0.0),
102 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[0], dropout_rate=0.0),
103 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[2], dropout_rate=0.0),
104 | 
105 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[0], dropout_rate=0.0),
106 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[3], dropout_rate=0.0),
107 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[0], dropout_rate=0.0),
108 |                 SFRB(in_channels=self.stage_channels[4], out_channels=self.stage_channels[4], dilation=self.dilation_list[4], dropout_rate=0.0),
109 |             )
110 |             if interpolate == True:
111 |                 self.project_layer = nn.Conv2d(self.stage_channels[4], self.nb_classes, 1, bias=False)
112 |             else:
113 |                 self.decoder = nn.Sequential(
114 |                     nn.ConvTranspose2d(self.stage_channels[4], self.stage_channels[3], kernel_size=3, stride=2,
115 |                                        padding=1,
116 |                                        output_padding=1, bias=False),
117 |                     nn.BatchNorm2d(self.stage_channels[3]),
118 |                     nn.ReLU(inplace=True),
119 |                     SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3]),
120 |                     SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3]),
121 |                     nn.ConvTranspose2d(self.stage_channels[3], self.stage_channels[2], kernel_size=3, stride=2,
122 |                                        padding=1,
123 |                                        output_padding=1, bias=False),
124 |                     nn.BatchNorm2d(self.stage_channels[2]),
125 |                     nn.ReLU(inplace=True),
126 |                     SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2]),
127 |                     SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2]),
128 |                     nn.ConvTranspose2d(self.stage_channels[2], self.stage_channels[1], kernel_size=3, stride=2, padding=1,
129 |                                        output_padding=1, bias=False),
130 |                     nn.BatchNorm2d(self.stage_channels[1]),
131 |                     nn.ReLU(inplace=True),
132 |                     SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1]),
133 |                     SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1]),
134 |                     # in_channels must match the 16-channel (stage_channels[1]) maps produced by the SFRBs above
135 |                     nn.ConvTranspose2d(self.stage_channels[1], self.nb_classes, kernel_size=3, stride=2,
136 |                                        padding=1,
137 |                                        output_padding=1, bias=False),
138 |                 )
139 | 
140 | 
141 |     def forward(self, input):
142 | 
143 |         encoder_out =
self.encoder(input)
144 | 
145 |         if self.interpolate == True:
146 |             decoder_out = self.project_layer(encoder_out)
147 |             decoder_out = F.interpolate(decoder_out, scale_factor=self.down_factor, mode='bilinear', align_corners=True)
148 |         else:
149 |             decoder_out = self.decoder(encoder_out)
150 | 
151 |         return decoder_out
152 | 
153 | 
154 | class ESFNet_mini_ex(nn.Module):
155 |     def __init__(self,
156 |                  config,
157 |                  interpolate=True,
158 |                  dilation=True,
159 |                  dropout=False,):
160 |         super(ESFNet_mini_ex, self).__init__()
161 | 
162 |         self.name = 'ESFNet_mini_ex'
163 |         self.nb_classes = config.nb_classes
164 |         self.interpolate = interpolate
165 |         self.stage_channels = [-1, 16, 64, 128, 256, 512]
166 | 
167 |         if dilation == True:
168 |             self.dilation_list = [1, 2, 4, 8, 16]
169 |         else:
170 |             self.dilation_list = [1, 1, 1, 1, 1]
171 | 
172 | 
173 |         self.encoder = nn.Sequential(
174 |             down_sampling_block(3, 16),
175 |             DownSamplingBlock_v2(in_channels=self.stage_channels[1], out_channels=self.stage_channels[2]),
176 | 
177 |             SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2],
178 |                  dilation=self.dilation_list[0], dropout_rate=0.0),
179 |             SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2],
180 |                  dilation=self.dilation_list[0], dropout_rate=0.0),
181 |             SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2],
182 |                  dilation=self.dilation_list[0], dropout_rate=0.0),
183 |             DownSamplingBlock_v2(in_channels=self.stage_channels[2], out_channels=self.stage_channels[3]),
184 |             SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3],
185 |                  dilation=self.dilation_list[1], dropout_rate=0.0),
186 |             SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3],
187 |                  dilation=self.dilation_list[2], dropout_rate=0.0),
188 |             SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3],
189 |                  dilation=self.dilation_list[3], dropout_rate=0.0),
190 |             SFRB(in_channels=self.stage_channels[3], out_channels=self.stage_channels[3],
191 |                  dilation=self.dilation_list[4], dropout_rate=0.0),
192 |         )
193 |         if interpolate == True:
194 |             self.project_layer = nn.Conv2d(self.stage_channels[3], self.nb_classes, 1, bias=False)
195 |         else:
196 |             self.decoder = nn.Sequential(
197 |                 nn.ConvTranspose2d(self.stage_channels[3], self.stage_channels[2], kernel_size=3, stride=2,
198 |                                    padding=1,
199 |                                    output_padding=1, bias=False),
200 |                 nn.BatchNorm2d(self.stage_channels[2]),
201 |                 nn.ReLU(inplace=True),
202 |                 SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], ),
203 |                 SFRB(in_channels=self.stage_channels[2], out_channels=self.stage_channels[2], ),
204 |                 nn.ConvTranspose2d(self.stage_channels[2], self.stage_channels[1], kernel_size=3, stride=2,
205 |                                    padding=1,
206 |                                    output_padding=1, bias=False),
207 |                 nn.BatchNorm2d(self.stage_channels[1]),
208 |                 nn.ReLU(inplace=True),
209 |                 SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1], ),
210 |                 SFRB(in_channels=self.stage_channels[1], out_channels=self.stage_channels[1], ),
211 |                 nn.ConvTranspose2d(self.stage_channels[1], self.nb_classes, kernel_size=3, stride=2, padding=1,
212 |                                    output_padding=1, bias=False)
213 |             )
214 |     def forward(self, x):
215 | 
216 |         encoder_out = self.encoder(x)
217 | 
218 |         if self.interpolate == True:
219 |             decoder_out = self.project_layer(encoder_out)
220 |             decoder_out = F.interpolate(decoder_out, scale_factor=8, mode='bilinear', align_corners=True)
221 |         else:
222 |             decoder_out =
self.decoder(encoder_out)
223 | 
224 |         return decoder_out
--------------------------------------------------------------------------------
/models/MyNetworks/layers.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | class Separabel_conv2d(nn.Module):
 6 |     def __init__(self,
 7 |                  in_channels,
 8 |                  out_channels,
 9 |                  groups,
10 |                  kernel_size=(3,3),
11 |                  dilation=(1,1),
12 |                  #padding=(1,1),
13 |                  stride=(1,1),
14 |                  bias=False):
15 |         """
16 |         # Note: defaults assume kernel_size=3.
17 |         For a depthwise conv2d, groups should equal in_channels and out_channels == in_channels.
18 |         Only BN follows the depthwise conv2d; there is no non-linearity in between.
19 | 
20 |         padding = (kernel_size - 1) / 2
21 |         padding = padding * dilation
22 |         """
23 |         super(Separabel_conv2d, self).__init__()
24 |         self.depthwise_conv2d = nn.Conv2d(
25 |             in_channels=in_channels,
26 |             out_channels=out_channels,
27 |             kernel_size=kernel_size,
28 |             stride=stride,
29 |             padding= (int((kernel_size[0]-1)/2)*dilation[0],int((kernel_size[1]-1)/2)*dilation[1]),
30 |             dilation= dilation, groups=groups,bias=bias
31 |         )
32 |         self.dw_bn = nn.BatchNorm2d(out_channels)
33 |         self.pointwise_conv2d = nn.Conv2d(
34 |             in_channels=out_channels,
35 |             out_channels=out_channels,
36 |             kernel_size=1,
37 |             stride=1, padding=0, dilation=1, groups=1, bias=False
38 |         )
39 |         self.pw_bn = nn.BatchNorm2d(out_channels)
40 |     def forward(self, input):
41 | 
42 |         out = self.depthwise_conv2d(input)
43 |         out = self.dw_bn(out)
44 |         out = self.pointwise_conv2d(out)
45 |         out = self.pw_bn(out)
46 | 
47 |         return out
48 | 
49 | class down_sampling_block(nn.Module):
50 | 
51 |     def __init__(self, inpc, oupc):
52 |         super(down_sampling_block, self).__init__()
53 |         self.branch_conv = nn.Conv2d(inpc, oupc-inpc, 3, stride=2, padding= 1, bias=False)
54 |         self.branch_mp = nn.MaxPool2d(kernel_size=2, stride=2)
55 |         self.bn = nn.BatchNorm2d(oupc, eps=1e-03)
56 | 
57 |     def forward(self, x):
58 |         output = torch.cat([self.branch_conv(x), self.branch_mp(x)], 1)
59 |         output = self.bn(output)
60 | 
61 |         return F.relu(output)
62 | 
63 | def channel_shuffle(input, groups):
64 |     """
65 |     # Note: groups is typically set to the channel count.
66 |     For a depthwise conv2d, groups == in_channels, so channel_shuffle would have no effect there.
67 |     """
68 |     batch_size, channels, height, width = input.shape
69 |     #groups = channels
70 |     channels_per_group = channels // groups
71 | 
72 |     input = input.view(batch_size, groups, channels_per_group, height, width)
73 |     input = input.transpose(1,2).contiguous()
74 |     input = input.view(batch_size, -1, height, width)
75 | 
76 |     return input
77 | 
78 | 
79 | class DownSamplingBlock_v2(nn.Module):
80 |     def __init__(self,
81 |                  in_channels,
82 |                  out_channels,):
83 |         """
84 |         # Note: initial block borrowed from ENet.
85 |         By default out_channels = 2 * in_channels.
86 |         Compared to DownSamplingBlock v1, the 3x3 conv is split into a depthwise conv plus a 1x1 projection layer.
87 | 
88 |         Add: channel_shuffle after the concatenation
89 |         (still being tested)
90 | 
91 |         (the output activation is named ret_prelu but is an nn.ReLU here)
92 |         """
93 |         super(DownSamplingBlock_v2, self).__init__()
94 | 
95 |         # MaxPooling or AvgPooling
96 |         self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
97 |         '''
98 |         # FacDW
99 |         self.depthwise_conv2d_1 = nn.Conv2d(in_channels= in_channels, out_channels=in_channels, kernel_size=(3,1), stride=(2,1),
100 |                                             padding=(1,0), groups=in_channels, bias=False)
101 |         self.dwbn1 = nn.BatchNorm2d(in_channels)
102 |         self.depthwise_conv2d_2 = nn.Conv2d(in_channels= in_channels, out_channels=in_channels, kernel_size=(1,3), stride=(1,2),
103 |                                             padding=(0,1), groups=in_channels, bias=False)
104 |         self.dwbn2 = nn.BatchNorm2d(in_channels)
105 |         '''
106 |         self.depthwise_conv2d = nn.Conv2d(in_channels= in_channels, out_channels=in_channels,
107 |                                           kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
108 |         self.dw_bn = nn.BatchNorm2d(in_channels)
109 | 
110 |         self.project_layer = nn.Conv2d(in_channels= in_channels, out_channels=out_channels-in_channels,
111 |                                        kernel_size=1, bias=False)
112 |         # no separate project_bn needed here; ret_bn normalizes both branches together after the concat
113 |         #self.project_bn = nn.BatchNorm2d(out_channels-in_channels)
114 | 
115 |         self.ret_bn = nn.BatchNorm2d(out_channels)
116 |         self.ret_prelu = nn.ReLU(inplace= True)
117 | 
118 |     def forward(self, input):
119 | 
120 |         ext_branch = self.pooling(input)
121 |         '''
122 |         # facDW
123 |         main_branch = self.dwbn1(self.depthwise_conv2d_1(input))
124 |         main_branch = self.dwbn2(self.depthwise_conv2d_2(main_branch))
125 | 
126 |         '''
127 |         main_branch = self.depthwise_conv2d(input)
128 |         main_branch = self.dw_bn(main_branch)
129 | 
130 |         main_branch = self.project_layer(main_branch)
131 | 
132 |         ret = torch.cat([ext_branch, main_branch], dim=1)
133 |         ret = self.ret_bn(ret)
134 | 
135 |         #ret = channel_shuffle(ret, 2)
136 | 
137 |         return self.ret_prelu(ret)
138 | 
139 | class bt(nn.Module):
140 |     def __init__(self,
141 |                  in_channels,
142 |                  out_channels,
143 |                  kernel_size=3,
144 |                  stride=1,
145 |                  dilation=1,
146 |                  ):
147 |         super(bt, self).__init__()
148 |         self.internal_channels = in_channels // 4
149 |         # compress conv
150 |         self.conv1 = nn.Conv2d(in_channels, self.internal_channels, 1, bias=False)
151 |         self.conv1_bn = nn.BatchNorm2d(self.internal_channels)
152 |         # a relu
153 |         self.conv2 = nn.Conv2d(self.internal_channels, self.internal_channels, kernel_size,
154 |                                stride, padding=int((kernel_size-1)/2*dilation), dilation=dilation, groups=1, bias=False)
155 |         self.conv2_bn = nn.BatchNorm2d(self.internal_channels)
156 |         # a relu
157 |         self.conv4 = nn.Conv2d(self.internal_channels, out_channels, 1, bias=False)
158 |         self.conv4_bn = nn.BatchNorm2d(out_channels)
159 |     def forward(self, x):
160 | 
161 |         residual = x
162 |         main = F.relu(self.conv1_bn(self.conv1(x)),inplace=True)
163 |         main = F.relu(self.conv2_bn(self.conv2(main)), inplace=True)
164 | 
main = self.conv4_bn(self.conv4(main)) 165 | out = F.relu(torch.add(main, residual), inplace=True) 166 | 167 | return out 168 | 169 | class bt_fac(nn.Module): 170 | def __init__(self, 171 | in_channels, 172 | out_channels, 173 | kernel_size=3, 174 | stride=1, 175 | dilation=1,): 176 | super(bt_fac, self).__init__() 177 | 178 | self.internal_channels = in_channels // 4 179 | self.compress_conv1 = nn.Conv2d(in_channels, self.internal_channels, 1, padding=0, bias=False) 180 | self.conv1_bn = nn.BatchNorm2d(self.internal_channels) 181 | # here is relu 182 | self.conv2_1 = nn.Conv2d(self.internal_channels, self.internal_channels, (kernel_size, 1), stride=(stride, 1), 183 | padding=(int((kernel_size-1)/2*dilation), 0), dilation=(dilation, 1), bias=False) 184 | self.conv2_1_bn = nn.BatchNorm2d(self.internal_channels) 185 | self.conv2_2 = nn.Conv2d(self.internal_channels, self.internal_channels, (1, kernel_size), stride=(1, stride), 186 | padding=(0, int((kernel_size-1)/2*dilation)), dilation=(1, dilation), bias=False) 187 | self.conv2_2_bn = nn.BatchNorm2d(self.internal_channels) 188 | # here is relu 189 | self.extend_conv3 = nn.Conv2d(self.internal_channels, out_channels, 1, padding=0, bias=False) 190 | 191 | self.conv3_bn = nn.BatchNorm2d(out_channels) 192 | def forward(self, x): 193 | 194 | main = F.relu((self.conv1_bn(self.compress_conv1(x))),inplace=True) 195 | main = F.relu(self.conv2_1_bn(self.conv2_1(main)), inplace=True) 196 | main = F.relu(self.conv2_2_bn(self.conv2_2(main)), inplace=True) 197 | 198 | main = self.conv3_bn(self.extend_conv3(main)) 199 | return F.relu((torch.add(main, x)), inplace=True) 200 | 201 | 202 | class non_bt(nn.Module): 203 | def __init__(self, 204 | in_channels, 205 | out_channels, 206 | stride=1, 207 | kernel_size=3, 208 | dilation=1): 209 | super(non_bt, self).__init__() 210 | 211 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, 212 | stride, padding=int((kernel_size-1)/2*dilation), dilation=dilation, groups=1, bias=False) 213 | self.conv1_bn = nn.BatchNorm2d(out_channels) 214 | # here is relu 215 | self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size, 216 | stride, padding=int((kernel_size-1)/2*dilation), dilation=dilation, groups=1, bias=False) 217 | self.conv2_bn = nn.BatchNorm2d(out_channels) 218 | # here is relu 219 | def forward(self, x): 220 | 221 | x1 = x 222 | x = F.relu(self.conv1_bn(self.conv1(x)), inplace=True) 223 | x = self.conv2_bn(self.conv2(x)) 224 | return F.relu(torch.add(x, x1), inplace=True) 225 | 226 | class bt_dw(nn.Module): 227 | def __init__(self, 228 | in_channels, 229 | out_channels, 230 | kernel_size=3, 231 | stride=1, 232 | dilation=1, 233 | ): 234 | 235 | # default decoupled 236 | super(bt_dw, self).__init__() 237 | 238 | self.internal_channels = in_channels // 4 239 | # compress conv 240 | self.conv1 = nn.Conv2d(in_channels, self.internal_channels, 1, bias=False) 241 | self.conv1_bn = nn.BatchNorm2d(self.internal_channels) 242 | # a relu 243 | self.conv2_3 = nn.Conv2d(self.internal_channels, self.internal_channels, (kernel_size, kernel_size), stride=(stride, stride), 244 | padding=(int((kernel_size-1)/2*dilation),int((kernel_size-1)/2*dilation)), 245 | dilation =(dilation, dilation), 246 | groups=self.internal_channels, bias=False) 247 | self.bn_2_3 = nn.BatchNorm2d(self.internal_channels) 248 | # a relu 249 | self.conv4 = nn.Conv2d(self.internal_channels, out_channels, 1, bias=False) 250 | self.conv4_bn = nn.BatchNorm2d(out_channels) 251 | def forward(self, input): 252 | 253 | residual = input 254 | main 
= self.conv1(input) 255 | main = self.conv2_3(F.relu(self.conv1_bn(main), inplace=True)) 256 | main = self.conv4(F.relu(self.bn_2_3(main), inplace=True)) 257 | 258 | return F.relu(torch.add(self.conv4_bn(main),residual)) 259 | 260 | class nonbt_dw(nn.Module): 261 | def __init__(self, 262 | in_channels, 263 | out_channels, 264 | kernel_size=3, 265 | stride=1, 266 | dilation=1, 267 | ): 268 | super(nonbt_dw, self).__init__() 269 | 270 | self.conv1_dw = nn.Conv2d(in_channels, out_channels, (kernel_size, kernel_size), stride=(stride, stride), 271 | padding=(int((kernel_size-1)/2),int((kernel_size-1)/2)), 272 | dilation =(1,1), 273 | groups=in_channels, bias=False) 274 | self.conv1_dw_bn = nn.BatchNorm2d(out_channels) 275 | self.conv1_pw = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False) 276 | self.conv1_pw_bn = nn.BatchNorm2d(out_channels) 277 | # here is relu 278 | self.conv2_dw = nn.Conv2d(in_channels, out_channels, (kernel_size, kernel_size), stride=(stride, stride), 279 | padding=(int((kernel_size-1)/2*dilation),int((kernel_size-1)/2*dilation)), 280 | dilation =(dilation, dilation), 281 | groups=in_channels, bias=False) 282 | self.conv2_dw_bn = nn.BatchNorm2d(out_channels) 283 | self.conv2_pw = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False) 284 | self.conv2_pw_bn = nn.BatchNorm2d(out_channels) 285 | 286 | # here is relu 287 | def forward(self, x): 288 | residual = x 289 | m = self.conv1_dw(x) 290 | m = self.conv1_dw_bn(m) 291 | m = self.conv1_pw(m) 292 | m = self.conv1_pw_bn(m) 293 | 294 | m = self.conv2_dw(F.relu(m, inplace=True)) 295 | m = self.conv2_dw_bn(m) 296 | m = self.conv2_pw(m) 297 | m = self.conv2_pw_bn(m) 298 | 299 | return F.relu(torch.add(m, residual), inplace=True) 300 | 301 | 302 | class SFRB(nn.Module): 303 | def __init__(self, 304 | in_channels, 305 | out_channels, 306 | kernel_size=3, 307 | stride=1, 308 | dilation=1, 309 | dropout_rate =0.0, 310 | ): 311 | 312 | # default decoupled 313 | super(SFRB, self).__init__() 314 | 315 | self.internal_channels = in_channels // 4 316 | # compress conv 317 | self.conv1 = nn.Conv2d(in_channels, self.internal_channels, 1, bias=False) 318 | self.conv1_bn = nn.BatchNorm2d(self.internal_channels) 319 | # a relu 320 | # Depthwise_conv 3x1 and 1x3 321 | self.conv2 = nn.Conv2d(self.internal_channels, self.internal_channels, (kernel_size,1), stride=(stride,1), 322 | padding=(int((kernel_size-1)/2*dilation),0), dilation=(dilation,1), 323 | groups=self.internal_channels, bias=False) 324 | self.conv2_bn = nn.BatchNorm2d(self.internal_channels) 325 | self.conv3 = nn.Conv2d(self.internal_channels, self.internal_channels, (1,kernel_size), stride=(1,stride), 326 | padding=(0,int((kernel_size-1)/2*dilation)), dilation=(1, dilation), 327 | groups=self.internal_channels, bias=False) 328 | self.conv3_bn = nn.BatchNorm2d(self.internal_channels) 329 | self.conv4 = nn.Conv2d(self.internal_channels, out_channels, 1, bias=False) 330 | self.conv4_bn = nn.BatchNorm2d(out_channels) 331 | 332 | # regularization 333 | self.dropout = nn.Dropout2d(inplace=True, p=dropout_rate) 334 | def forward(self, input): 335 | 336 | residual = input 337 | main = self.conv1(input) 338 | main = self.conv1_bn(main) 339 | main = F.relu(main, inplace=True) 340 | 341 | main = self.conv2(main) 342 | main = self.conv2_bn(main) 343 | main = self.conv3(main) 344 | main = self.conv3_bn(main) 345 | main = self.conv4(main) 346 | main = self.conv4_bn(main) 347 | 348 | if self.dropout.p != 0: 349 | main = self.dropout(main) 350 | 351 | return 
F.relu(torch.add(main, residual), inplace=True) 352 | 353 | class nonbt_dw_fac(nn.Module): 354 | def __init__(self, 355 | in_channels, 356 | out_channels, 357 | kernel_size=3, 358 | stride=1, 359 | dilation=1, 360 | dropout_rate=0.0, 361 | ): 362 | super(nonbt_dw_fac, self).__init__() 363 | 364 | # defaultly inchannels = outchannels 365 | self.conv1_1 = nn.Conv2d(in_channels, out_channels, (kernel_size,1), stride=(stride,1), 366 | padding=(int((kernel_size-1)/2),0), dilation=(1,1), 367 | groups=in_channels, bias=False) 368 | self.conv1_1_bn = nn.BatchNorm2d(out_channels) 369 | self.conv1_2 = nn.Conv2d(in_channels, out_channels, (1,kernel_size), stride=(1,stride), 370 | padding=(0,int((kernel_size-1)/2)), dilation=(1, 1), 371 | groups=in_channels, bias=False) 372 | self.conv1_2_bn = nn.BatchNorm2d(out_channels) 373 | self.conv1 = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False) 374 | self.conv1_bn = nn.BatchNorm2d(out_channels) 375 | # here is relu 376 | 377 | self.conv2_1 = nn.Conv2d(in_channels, out_channels, (kernel_size,1), stride=(stride,1), 378 | padding=(int((kernel_size-1)/2*dilation),0), dilation=(dilation,1), 379 | groups=in_channels, bias=False) 380 | self.conv2_1_bn = nn.BatchNorm2d(out_channels) 381 | self.conv2_2 = nn.Conv2d(in_channels, out_channels, (1,kernel_size), stride=(1,stride), 382 | padding=(0,int((kernel_size-1)/2*dilation)), dilation=(1, dilation), 383 | groups=in_channels, bias=False) 384 | self.conv2_2_bn = nn.BatchNorm2d(out_channels) 385 | self.conv2 = nn.Conv2d(in_channels, out_channels, 1, padding=0, bias=False) 386 | self.conv2_bn = nn.BatchNorm2d(out_channels) 387 | #self.drop2d = nn.Dropout2d(p=dropout_rate) 388 | 389 | # here is relu 390 | def forward(self, x): 391 | 392 | residual = x 393 | main = self.conv1_1(x) 394 | main = self.conv1_1_bn(main) 395 | main = self.conv1_2(main) 396 | main = self.conv1_2_bn(main) 397 | main = self.conv1(main) 398 | main = self.conv1_bn(main) 399 | 400 | main = self.conv2_1(main) 401 | main = self.conv2_1_bn(main) 402 | main = self.conv2_2(main) 403 | main = self.conv2_2_bn(main) 404 | main = self.conv2(main) 405 | 406 | return F.relu(torch.add(self.conv2_bn(main), residual), inplace=True) 407 | 408 | class nonbt_fac(nn.Module): 409 | # stride = 1 kernel_size=3 410 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, dropout_rate=0.0): 411 | super(nonbt_fac, self).__init__() 412 | 413 | self.conv3x1_1 = nn.Conv2d(in_channels, out_channels, (kernel_size, 1), stride=1, padding=(1,0), bias=False) 414 | self.conv1x3_1 = nn.Conv2d(in_channels, out_channels, (1, kernel_size), stride=1, padding=(0,1), bias=False) 415 | self.bn1 = nn.BatchNorm2d(out_channels) 416 | self.conv3x1_2 = nn.Conv2d(in_channels, out_channels, (3, 1), stride=1, padding=(1*dilation,0), dilation=(dilation,1), bias=False) 417 | self.conv1x3_2 = nn.Conv2d(in_channels, out_channels, (1, 3), stride=1, padding=(0,1*dilation), dilation=(1,dilation), bias=False) 418 | self.bn2 = nn.BatchNorm2d(out_channels) 419 | self.dropout = nn.Dropout2d(dropout_rate) 420 | 421 | def forward(self, input): 422 | output = self.conv3x1_1(input) 423 | output = F.relu(output) 424 | output = self.conv1x3_1(output) 425 | output = self.bn1(output) 426 | output = F.relu(output) 427 | 428 | output = self.conv3x1_2(output) 429 | output = F.relu(output) 430 | output = self.conv1x3_2(output) 431 | output = self.bn2(output) 432 | 433 | if self.dropout.p != 0: 434 | output = self.dropout(output) 435 | output += input 436 | return F.relu(output) 
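# Note on the padding used throughout the residual blocks above: with stride=1,
# padding = int((kernel_size-1)/2*dilation) preserves H and W, which is what makes
# torch.add(main, residual) in each forward() shape-compatible. A minimal sanity
# check (illustrative only, not part of the module definitions):
#
#   block = SFRB(in_channels=64, out_channels=64, kernel_size=3, dilation=2)
#   x = torch.randn(1, 64, 32, 32)
#   assert block(x).shape == x.shape  # the residual add needs matching shapes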
437 | 438 | class conv2d_bn_relu(nn.Module): 439 | def __init__(self, 440 | in_channels, 441 | out_channels, 442 | kernel_size=3, 443 | stride=1, 444 | dilation=1, 445 | padding=1, 446 | bias=False, 447 | groups=1): 448 | super(conv2d_bn_relu, self).__init__() 449 | self.conv2d = nn.Conv2d( 450 | in_channels=in_channels, out_channels=out_channels, 451 | kernel_size= kernel_size, stride=stride, 452 | dilation=dilation, padding=padding, 453 | bias=bias, groups=groups, 454 | ) 455 | self.bn = nn.BatchNorm2d(out_channels) 456 | def forward(self, input): 457 | 458 | out = self.conv2d(input) 459 | 460 | return F.relu(self.bn(out)) 461 | -------------------------------------------------------------------------------- /models/SegNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | class SegNet(nn.Module): 7 | # modified 8 | def __init__(self,config): 9 | super(SegNet, self).__init__() 10 | 11 | self.config = config 12 | self.name = 'SegNet' 13 | batchNorm_momentum = 0.1 14 | 15 | self.conv11 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 16 | self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum) 17 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 18 | self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum) 19 | 20 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 21 | self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum) 22 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 23 | self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum) 24 | 25 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 26 | self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 27 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 28 | self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 29 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 30 | self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 31 | 32 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 33 | self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 34 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 35 | self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 36 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 37 | self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 38 | 39 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 41 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 43 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 44 | self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 45 | 46 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 47 | self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 48 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 49 | self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 50 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 51 | self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 52 | 53 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 54 | self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 55 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 56 | self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum) 
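# (decoder, continued) each decoder stage mirrors its encoder counterpart; the
# actual upsampling happens in forward() via F.max_unpool2d, which reuses the
# pooling indices returned by F.max_pool2d in the corresponding encoder stage.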
57 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1) 58 | self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 59 | 60 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 61 | self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 62 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 63 | self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum) 64 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1) 65 | self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum) 66 | 67 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1) 68 | self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum) 69 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1) 70 | self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum) 71 | 72 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1) 73 | self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum) 74 | self.conv11d = nn.Conv2d(64, self.config.nb_classes, kernel_size=3, padding=1) 75 | 76 | 77 | def forward(self, x): 78 | 79 | # Stage 1 80 | x11 = F.relu(self.bn11(self.conv11(x))) 81 | x12 = F.relu(self.bn12(self.conv12(x11))) 82 | x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True) 83 | 84 | # Stage 2 85 | x21 = F.relu(self.bn21(self.conv21(x1p))) 86 | x22 = F.relu(self.bn22(self.conv22(x21))) 87 | x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True) 88 | 89 | # Stage 3 90 | x31 = F.relu(self.bn31(self.conv31(x2p))) 91 | x32 = F.relu(self.bn32(self.conv32(x31))) 92 | x33 = F.relu(self.bn33(self.conv33(x32))) 93 | x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True) 94 | 95 | # Stage 4 96 | x41 = F.relu(self.bn41(self.conv41(x3p))) 97 | x42 = F.relu(self.bn42(self.conv42(x41))) 98 | x43 = F.relu(self.bn43(self.conv43(x42))) 99 | x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True) 100 | 101 | # Stage 5 102 | x51 = F.relu(self.bn51(self.conv51(x4p))) 103 | x52 = F.relu(self.bn52(self.conv52(x51))) 104 | x53 = F.relu(self.bn53(self.conv53(x52))) 105 | x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True) 106 | 107 | 108 | # Stage 5d 109 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2) 110 | x53d = F.relu(self.bn53d(self.conv53d(x5d))) 111 | x52d = F.relu(self.bn52d(self.conv52d(x53d))) 112 | x51d = F.relu(self.bn51d(self.conv51d(x52d))) 113 | 114 | # Stage 4d 115 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2) 116 | x43d = F.relu(self.bn43d(self.conv43d(x4d))) 117 | x42d = F.relu(self.bn42d(self.conv42d(x43d))) 118 | x41d = F.relu(self.bn41d(self.conv41d(x42d))) 119 | 120 | # Stage 3d 121 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2) 122 | x33d = F.relu(self.bn33d(self.conv33d(x3d))) 123 | x32d = F.relu(self.bn32d(self.conv32d(x33d))) 124 | x31d = F.relu(self.bn31d(self.conv31d(x32d))) 125 | 126 | # Stage 2d 127 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2) 128 | x22d = F.relu(self.bn22d(self.conv22d(x2d))) 129 | x21d = F.relu(self.bn21d(self.conv21d(x22d))) 130 | 131 | # Stage 1d 132 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2) 133 | x12d = F.relu(self.bn12d(self.conv12d(x1d))) 134 | x11d = self.conv11d(x12d) 135 | 136 | return x11d 137 | 138 | def load_from_segnet(self, model_path): 139 | s_dict = self.state_dict()# create a copy of the state dict 140 | th = torch.load(model_path).state_dict() # load the weigths 141 | # for name in th: 142 | # s_dict[corresp_name[name]] = th[name] 143 | 
self.load_state_dict(th) -------------------------------------------------------------------------------- /models/UNet.py: -------------------------------------------------------------------------------- 1 | # sub-parts of the U-Net model 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class double_conv(nn.Module): 9 | '''(conv => BN => ReLU) * 2''' 10 | 11 | def __init__(self, in_ch, out_ch): 12 | super(double_conv, self).__init__() 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 15 | nn.BatchNorm2d(out_ch), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 18 | nn.BatchNorm2d(out_ch), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | return x 25 | 26 | 27 | class inconv(nn.Module): 28 | def __init__(self, in_ch, out_ch): 29 | super(inconv, self).__init__() 30 | self.conv = double_conv(in_ch, out_ch) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | 37 | class down(nn.Module): 38 | def __init__(self, in_ch, out_ch): 39 | super(down, self).__init__() 40 | self.mpconv = nn.Sequential( 41 | nn.MaxPool2d(2), 42 | double_conv(in_ch, out_ch) 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.mpconv(x) 47 | return x 48 | 49 | 50 | class up(nn.Module): 51 | def __init__(self, in_ch, out_ch, bilinear=True): 52 | super(up, self).__init__() 53 | 54 | # it would be a nice idea if the upsampling could be learned too, 55 | # but my machine does not have enough memory to handle all those weights 56 | if bilinear: 57 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 58 | else: 59 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 60 | 61 | self.conv = double_conv(in_ch, out_ch) 62 | 63 | def forward(self, x1, x2): 64 | x1 = self.up(x1) 65 | 66 | # input is NCHW 67 | diffY = x2.size()[2] - x1.size()[2] 68 | diffX = x2.size()[3] - x1.size()[3] 69 | 70 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 71 | diffY // 2, diffY - diffY // 2)) 72 | 73 | # for padding issues, see 74 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 75 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 76 | 77 | x = torch.cat([x2, x1], dim=1) 78 | x = self.conv(x) 79 | return x 80 | 81 | 82 | class outconv(nn.Module): 83 | def __init__(self, in_ch, out_ch): 84 | super(outconv, self).__init__() 85 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | 93 | class UNet(nn.Module): 94 | def __init__(self, config): 95 | super(UNet, self).__init__() 96 | self.name = 'Unet' 97 | self.config=config 98 | self.inc = inconv(3, 64) 99 | self.down1 = down(64, 128) 100 | self.down2 = down(128, 256) 101 | self.down3 = down(256, 512) 102 | self.down4 = down(512, 512) 103 | self.up1 = up(1024, 256) 104 | self.up2 = up(512, 128) 105 | self.up3 = up(256, 64) 106 | self.up4 = up(128, 64) 107 | self.outc = outconv(64, self.config.nb_classes) 108 | 109 | def forward(self, x): 110 | x1 = self.inc(x) 111 | x2 = self.down1(x1) 112 | x3 = self.down2(x2) 113 | x4 = self.down3(x3) 114 | x5 = self.down4(x4) 115 | x = self.up1(x5, x4) 116 | x = self.up2(x, x3) 117 | x = self.up3(x, x2) 118 | x = self.up4(x, x1) 119 | x = self.outc(x) 120 | #return F.sigmoid(x) 121 | 122 | return x --------------------------------------------------------------------------------
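The `up` block pads the upsampled decoder feature before concatenating it with the encoder skip connection, so inputs whose spatial size becomes odd at some stage still line up. A minimal sketch of that alignment step (the shapes here are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

x1 = torch.randn(1, 256, 28, 28)  # upsampled decoder feature
x2 = torch.randn(1, 256, 29, 29)  # encoder skip feature with an odd size
diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
# pad (left, right, top, bottom) so x1 matches x2 exactly
x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
assert torch.cat([x2, x1], dim=1).shape[2:] == x2.shape[2:]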
/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import glob 4 | import os 5 | import argparse 6 | import time 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import torch.utils.data as data 11 | from models.MyNetworks import ESFNet 12 | import torchvision.transforms as transforms 13 | import torchvision.transforms.functional as TF 14 | from utils.util import AverageMeter 15 | from data.dataset import MyDataset 16 | 17 | # mean and std for the WHU Building dataset 18 | # whether to use them depends on your use case: if your dataset is larger than the WHU Building dataset, compute the mean and 19 | # std w.r.t. your own dataset; otherwise we recommend using these values. 20 | rgb_mean = (0.4353, 0.4452, 0.4131) 21 | rgb_std = (0.2044, 0.1924, 0.2013) 22 | 23 | class dataset_predict(data.Dataset): 24 | def __init__(self, 25 | args): 26 | super(dataset_predict, self).__init__() 27 | 28 | self.args = args 29 | self.input_path = os.path.join(self.args.input, 'image_patches') 30 | self.data_list = glob.glob(os.path.join(self.input_path, '*')) 31 | 32 | def transform(self, image): 33 | 34 | image = TF.to_tensor(image) 35 | image = TF.normalize(image, mean=rgb_mean, std=rgb_std) 36 | 37 | return image 38 | 39 | def __getitem__(self, index): 40 | 41 | datas = Image.open(self.data_list[index]) 42 | t_datas = self.transform(datas) 43 | # return filename for saving patch predictions. 44 | return t_datas, self.data_list[index] 45 | 46 | def __len__(self): 47 | 48 | return len(self.data_list) 49 | 50 | 51 | class Predictor(object): 52 | def __init__(self, 53 | args, model, dataloader_predict): 54 | super(Predictor, self).__init__() 55 | 56 | self.args = args 57 | self.model = model 58 | self.dataloader_predict = dataloader_predict 59 | self.patches = None 60 | 61 | def predict(self): 62 | 63 | self.model.eval() 64 | #predict_time = AverageMeter() 65 | #batch_time = AverageMeter() 66 | #data_time = AverageMeter() 67 | 68 | with torch.no_grad(): 69 | tic = time.time() 70 | for steps, (data, filenames) in enumerate(self.dataloader_predict, start=1): 71 | data = data.to(self.model.device, non_blocking = True) 72 | #data_time.update(time.time() - tic) 73 | pre_tic = time.time() 74 | logits = self.model(data) 75 | self._save_pred(logits, filenames) 76 | # how logits are turned into a mask here depends on the use case 77 | if self.patches is None: 78 | self.patches = torch.argmax(logits, dim=1) * 255.
79 | else: 80 | self.patches = torch.cat([self.patches, torch.argmax(logits, dim=1)*255.], 0) 81 | #predict_time.update(time.time() - pre_tic) 82 | #batch_time.update(time.time() - tic) 83 | tic = time.time() 84 | 85 | #print("Predicting and Saving Done!\n" 86 | # "Total Time: {:.2f}\n" 87 | # "Data Time: {:.2f}\n" 88 | # "Pre Time: {:.2f}" 89 | # .format(batch_time._get_sum(), data_time._get_sum(), predict_time._get_sum())) 90 | def _save_pred(self, predictions, filenames): 91 | 92 | for index, pred in enumerate(predictions): 93 | 94 | pred = torch.argmax(pred, dim=0) 95 | pred = pred * 255 96 | pred = np.asarray(pred.cpu(), dtype=np.uint8) 97 | pred = Image.fromarray(pred) 98 | # e.g. filename "0.1.png": parts[0] = "0", parts[1] = "1" 99 | filename = filenames[index].split('/')[-1].split('.') 100 | save_filename = filename[0]+'.'+filename[1] 101 | save_path = os.path.join(self.args.output, 'patches', save_filename+'.png') 102 | 103 | pred.save(save_path) 104 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import argparse 3 | import torch 4 | import random 5 | import numpy as np 6 | from configs.config import MyConfiguration 7 | from BaseTrainer import BaseTrainer 8 | from BaseTester import BaseTester 9 | from data.dataset import MyDataset 10 | from torch.utils.data import DataLoader 11 | from models.MyNetworks import ESFNet 12 | from visdom import Visdom 13 | 14 | def for_train(model, 15 | config, 16 | args, 17 | train_data_loader, 18 | valid_data_loader, 19 | begin_time, 20 | resume_file, 21 | loss_weight, 22 | visdom): 23 | """ 24 | :param model: 25 | :param config: 26 | :param train_data_loader: 27 | :param valid_data_loader: 28 | :param resume_file: 29 | :param loss_weight: 30 | :return: 31 | """ 32 | Trainer = BaseTrainer(model=model, config=config, args= args, 33 | train_data_loader= train_data_loader, 34 | valid_data_loader= valid_data_loader, 35 | begin_time=begin_time, 36 | resume_file = resume_file, 37 | loss_weight= loss_weight, 38 | visdom=visdom) 39 | Trainer.train() 40 | print(" Training Done ! ") 41 | 42 | def for_test(model, config, args, test_data_loader, begin_time, resume_file, loss_weight): 43 | """ 44 | :param model: 45 | :param config: 46 | :param test_data_loader: 47 | :param begin_time: 48 | :param resume_file: 49 | :param loss_weight: 50 | :param predict: 51 | :return: 52 | """ 53 | Tester = BaseTester(model= model, config= config, args = args, 54 | test_data_loader= test_data_loader, 55 | begin_time= begin_time, 56 | resume_file = resume_file, 57 | loss_weight= loss_weight) 58 | 59 | Tester.eval_and_predict() 60 | print(" Evaluation Done ! ") 61 | #if do_predict == True : 62 | # Tester.predict() 63 | # print(" Make Predict Image Done ! 
") 64 | 65 | def main(config, args): 66 | 67 | loss_weight = torch.ones(config.nb_classes) 68 | loss_weight[0] = 1.53297775619 69 | loss_weight[1] = 7.63194124408 70 | 71 | # Here config in model, only used for nb_classes, so we do not use args 72 | 73 | model = ESFNet.ESFNet(config= config) 74 | print(model) 75 | 76 | # create visdom 77 | viz = Visdom(server=args.server, port=args.port, env=model.name) 78 | assert viz.check_connection(timeout_seconds=3), \ 79 | 'No connection could be formed quickly' 80 | 81 | # TODO there are somewhat still need to change in ../configs/config.cfg 82 | train_dataset = MyDataset(config=config, args= args, subset='train') 83 | valid_dataset = MyDataset(config=config, args= args, subset='val') 84 | test_dataset = MyDataset(config=config, args= args, subset='test') 85 | 86 | train_data_loader = DataLoader(dataset=train_dataset, 87 | batch_size=config.batch_size, 88 | shuffle=True, 89 | pin_memory=True, 90 | num_workers=args.threads, 91 | drop_last=True) 92 | valid_data_loader = DataLoader(dataset=valid_dataset, 93 | batch_size=config.batch_size, 94 | shuffle=False, 95 | pin_memory=True, 96 | num_workers=args.threads, 97 | drop_last=True) 98 | test_data_loader = DataLoader(dataset=test_dataset, 99 | batch_size=config.batch_size, 100 | shuffle=False, 101 | pin_memory=True, 102 | num_workers=args.threads, 103 | drop_last=True) 104 | 105 | begin_time = datetime.datetime.now().strftime('%m%d_%H%M%S') 106 | 107 | 108 | for_train(model = model, config=config, args = args, 109 | train_data_loader = train_data_loader, 110 | valid_data_loader= valid_data_loader, 111 | begin_time= begin_time, 112 | resume_file = args.weight, 113 | loss_weight= loss_weight, 114 | visdom=viz) 115 | 116 | """ 117 | # testing phase does not need visdom, just one scalar for loss, miou and accuracy 118 | """ 119 | for_test(model = model, config=config, args= args, 120 | test_data_loader=test_data_loader, 121 | begin_time= begin_time, 122 | resume_file = args.weight, 123 | loss_weight= loss_weight, 124 | ) 125 | 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | config = MyConfiguration() 131 | 132 | # for visdom 133 | DEFAULT_PORT=8097 134 | DEFAULT_HOSTNAME="http://localhost" 135 | 136 | parser = argparse.ArgumentParser(description="Efficient Semantic Segmentation Network") 137 | parser.add_argument('-port', metavar='port', type=int, default=DEFAULT_PORT, 138 | help='port the visdom server is running on.') 139 | parser.add_argument('-server', metavar='server', type=str, default=DEFAULT_HOSTNAME, 140 | help='Server address of the target to run the demo on.') 141 | parser.add_argument('-input', metavar='input', type=str, default=config.root_dir, 142 | help='root path to directory containing input images, including train & valid & test') 143 | parser.add_argument('-output', metavar='output', type=str, default=config.save_dir, 144 | help='root path to directory containing all the output, including predictions, logs and ckpt') 145 | parser.add_argument('-weight', metavar='weight', type=str, default=None, 146 | help='path to ckpt which will be loaded') 147 | parser.add_argument('-threads', metavar='threads', type=int, default=8, 148 | help='number of thread used for DataLoader') 149 | parser.add_argument('-gpu', metavar='gpu', type=int, default=0, 150 | help='gpu id to be used for prediction') 151 | 152 | args = parser.parse_args() 153 | 154 | # GPU setting init 155 | # keep prediction results the same when model runs each time 156 | 157 | """ 158 | You have chosen to seed training. 
159 | This will turn on the CUDNN deterministic setting, 160 | which can slow down your training considerably! 161 | You may see unexpected behavior when restarting 162 | from checkpoints. 163 | """ 164 | ''' 165 | # cudnn.benchmark lets cuDNN auto-tune and cache the fastest algorithms for the input shapes it sees, 166 | which improves running efficiency. 167 | # This helps when input dims and types are stable; otherwise cuDNN re-tunes for every new input format, 168 | which hurts efficiency instead. 169 | ''' 170 | torch.backends.cudnn.benchmark = True 171 | ''' 172 | # using deterministic mode can have a performance (speed) impact, depending on your model. 173 | ''' 174 | #torch.backends.cudnn.deterministic = True 175 | #torch.cuda.manual_seed(config.random_seed) 176 | # for distribution 177 | #torch.cuda.manual_seed_all(config.random_seed) 178 | # seed the RNG for all devices (both CPU and GPUs) 179 | torch.manual_seed(config.random_seed) 180 | random.seed(config.random_seed) 181 | np.random.seed(config.random_seed) 182 | 183 | main(config= config, args = args) 184 | 185 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import cv2 5 | import numpy as np 6 | from functools import partial 7 | from tqdm import tqdm 8 | 9 | def parse_args(): 10 | 11 | parser = argparse.ArgumentParser(description='configurations of dataset') 12 | 13 | parser.add_argument('-i', '--input', type=str, default=os.path.join('.','input'), 14 | help='directory of input images, including images used to train and predict') 15 | parser.add_argument('-o', '--output', type=str, default=os.path.join('.','output'), 16 | help='directory of output images') 17 | # whether to shuffle or not is set in the DataLoader 18 | parser.add_argument('--size_x', type=int, default=224, 19 | help='the width of image patches') 20 | parser.add_argument('--size_y', type=int, default=224, 21 | help='the height of image patches') 22 | # step is set equal to patch size by default, i.e. the overlap is zero 23 | # overlap = size - step 24 | parser.add_argument('--step_x', type=int, default=224, 25 | help='the horizontal step of cropping images') 26 | parser.add_argument('--step_y', type=int, default=224, 27 | help='the vertical step of cropping images') 28 | parser.add_argument('--image_margin_color', type=int, nargs=3, default=[255,255,255], 29 | help='padding color used for the image margin') 30 | parser.add_argument('--label_margin_color', type=int, nargs=3, default=[255,255,255], 31 | help='padding color used for the label margin') 32 | 33 | return parser.parse_args() 34 | 35 | 36 | class Cropper(object): 37 | def __init__(self, args, configs, predict=True): 38 | super(Cropper, self).__init__() 39 | 40 | self.args = args 41 | self.configs = configs 42 | self.predict = predict 43 | 44 | self.input_path = self.args.input 45 | self.input_patches_path = os.path.join(self.input_path, 'image_patches') 46 | self.input_label_path = os.path.join(self.input_path, 'label_patches') 47 | 48 | self.output_path = self.args.output 49 | 50 | if predict: 51 | 
self.ensure_and_mkdir(self.input_path) 52 | self.ensure_and_mkdir(self.input_patches_path) 53 | else: 54 | self.ensure_and_mkdir(self.input_path) 55 | self.ensure_and_mkdir(self.input_patches_path) 56 | self.ensure_and_mkdir(self.input_label_path) 57 | # by default 58 | self.size_x = self.configs.cropped_size 59 | self.size_y = self.configs.cropped_size 60 | self.step_x = self.configs.step_x 61 | self.step_y = self.configs.step_y 62 | 63 | self.image_margin_color = args.image_margin_color 64 | self.label_margin_color = args.label_margin_color 65 | 66 | def get_filename(self, path): 67 | return path.split('/')[-1].split('.')[0] 68 | 69 | def ensure_and_mkdir(self, path): 70 | if not os.path.exists(path): 71 | os.makedirs(path) 72 | 73 | def pad_and_crop_images(self, image, margin_color): 74 | 75 | # TODO: handle odd and even padding 76 | # should the value of the border be equal to the background ? 77 | # padding 78 | image_y, image_x = image.shape[:2] 79 | border_y = 0 80 | if image_y % self.size_y != 0: 81 | border_y_double = (self.size_y - (image_y % self.size_y)) 82 | if border_y_double % 2 == 0: 83 | image = cv2.copyMakeBorder(image, border_y_double//2, border_y_double//2, 0, 0, cv2.BORDER_CONSTANT, value=margin_color) 84 | else: 85 | image = cv2.copyMakeBorder(image, border_y_double//2, border_y_double//2+1, 0, 0, cv2.BORDER_CONSTANT, value=margin_color) 86 | image_y = image.shape[0] 87 | border_x = 0 88 | if image_x % self.size_x != 0: 89 | border_x_double = (self.size_x - (image_x % self.size_x)) 90 | if border_x_double % 2 == 0: 91 | image = cv2.copyMakeBorder(image, 0, 0, border_x_double//2, border_x_double//2, cv2.BORDER_CONSTANT, value=margin_color) 92 | else: 93 | image = cv2.copyMakeBorder(image, 0, 0, border_x_double//2, border_x_double//2+1, cv2.BORDER_CONSTANT, value=margin_color) 94 | image_x = image.shape[1] 95 | 96 | # calculate n_w and n_h 97 | n_w = int(image_x / self.size_x) 98 | n_h = int(image_y / self.size_y) 99 | 100 | # cropping 101 | # use while loops rather than a fixed-range for loop, so steps smaller than the patch size (overlap) are handled and no thin margin is left over 102 | patches = [] 103 | start_y = 0 104 | while (start_y + self.size_y) <= image_y: 105 | start_x = 0 106 | while (start_x + self.size_x) <= image_x: 107 | patches.append(image[start_y:start_y + self.size_y, start_x:start_x + self.size_x]) 108 | start_x += self.step_x 109 | start_y += self.step_y 110 | 111 | return patches, n_w, n_h, image_y, image_x 112 | 113 | def save_images(self, patches, save_path, father_name): 114 | 115 | for i, patch in enumerate(patches): 116 | cv2.imwrite(os.path.join(save_path, father_name + str(i) + '.png'), patch) 117 | 118 | def image_processor(self, image_path, label_path=None): 119 | 120 | image = cv2.imread(image_path) 121 | filename = self.get_filename(image_path) 122 | patches, n_w, n_h, image_h, image_w = self.pad_and_crop_images(image=image, margin_color=self.image_margin_color) 123 | # patches are saved in input_path/image_patches 124 | input_path = os.path.join(self.input_path, 'image_patches') 125 | self.save_images(patches, input_path, filename) 126 | 127 | if self.predict is False: 128 | assert label_path is not None, \ 129 | 'label_path is None' 130 | label = cv2.imread(label_path) 131 | label_filename = self.get_filename(label_path) 132 | label_patches, _, _, _, _ = self.pad_and_crop_images(image=label, margin_color=self.label_margin_color) 133 | # label_patches are saved in input_path/label_patches (the directory created in __init__) 134 | input_path = os.path.join(self.input_path, 'label_patches') 135 | self.save_images(label_patches, input_path, 
label_filename) 136 | 137 | return patches, n_w, n_h, image_h, image_w 138 | 139 | -------------------------------------------------------------------------------- /utils/unpatchy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import product 3 | 4 | 5 | def unpatchify(patches, image_h, image_w): 6 | 7 | # shape as [n_h, n_w, p_h, p_w] 8 | assert len(patches.shape) == 5 9 | 10 | image = np.zeros((image_h, image_w, 3), dtype=patches.dtype) 11 | divisor = np.zeros((image_h, image_w, 3), dtype=patches.dtype) 12 | n_h, n_w, p_h, p_w, _ = patches.shape 13 | 14 | # overlap 15 | o_w = (n_w * p_w - image_w) / (n_w - 1) 16 | o_h = (n_h * p_h - image_h) / (n_h - 1) 17 | 18 | #assert int(o_w) == o_w 19 | #assert int(o_h) == o_h 20 | 21 | o_w = int(o_w) 22 | o_h = int(o_h) 23 | 24 | # start position 25 | s_w = p_w - o_w 26 | s_h = p_h - o_h 27 | 28 | for i, j in product(range(n_h), range(n_w)): 29 | patch = patches[i, j, :, :, :] 30 | image[(i * s_h):(i * s_h) + p_h, (j* s_w):(j * s_w) + p_w, :] += patch 31 | divisor[(i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w, :] += 1 32 | 33 | return image / divisor 34 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import sys 5 | from configs.config import MyConfiguration 6 | import torch.nn as nn 7 | import glob 8 | from tqdm import tqdm 9 | from scipy import misc 10 | import numpy as np 11 | import re 12 | import functools 13 | from collections import OrderedDict 14 | import pandas as pd 15 | # from torch.autograd import Variable 16 | import torch.nn.functional as F 17 | 18 | 19 | def ensure_dir(path): 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | class AverageMeter(object): 25 | """ 26 | # Computes and stores the average and current value 27 | """ 28 | 29 | def __init__(self): 30 | self.initialized = False 31 | self.val = None 32 | self.avg = None 33 | self.sum = None 34 | self.count = None 35 | 36 | def initialize(self, val, weight): 37 | self.val = val 38 | self.avg = val 39 | self.sum = val * weight 40 | self.count = weight 41 | self.initialized = True 42 | 43 | def update(self, val, weight=1): 44 | if not self.initialized: 45 | self.initialize(val, weight) 46 | else: 47 | self.add(val, weight) 48 | 49 | def add(self, val, weight): 50 | self.val = val 51 | self.sum += val * weight 52 | self.count += weight 53 | self.avg = self.sum / self.count 54 | 55 | def value(self): 56 | return self.val 57 | 58 | def average(self): 59 | return self.avg 60 | 61 | def _get_sum(self): 62 | return self.sum 63 | 64 | 65 | def cropped_dataset(config, subset): 66 | assert subset == 'train' or subset == 'val' or subset == 'test' 67 | dataset_path = os.path.join(config.root_dir, subset, config.data_folder_name) 68 | target_path = os.path.join(config.root_dir, subset, config.target_folder_name) 69 | 70 | new_dataset_path = dataset_path.replace('512', '256') 71 | new_target_path = target_path.replace('512', '256') 72 | # print(dataset_path) 73 | # print(target_path) 74 | 75 | data_paths = glob.glob(os.path.join(dataset_path, '*.tif')) 76 | for data_path in tqdm(data_paths): 77 | target_path = data_path.replace(config.data_folder_name, config.target_folder_name) 78 | 79 | filename = data_path.split('/')[-1].split('.')[0] 80 | 81 | data = misc.imread(data_path) 82 | target = 
misc.imread(target_path) 83 | 84 | h, w = config.original_size, config.original_size 85 | 86 | subimgcounter = 1 87 | stride = config.cropped_size - config.overlapped 88 | for h_pixel in range(0, h - config.cropped_size + 1, stride): 89 | for w_pixel in range(0, w - config.cropped_size + 1, stride): 90 | new_data = data[h_pixel:h_pixel + config.cropped_size, w_pixel:w_pixel + config.cropped_size, :] 91 | new_target = target[h_pixel:h_pixel + config.cropped_size, w_pixel:w_pixel + config.cropped_size] 92 | 93 | data_save_path = os.path.join(new_dataset_path, filename + '.{}.tif'.format(subimgcounter)) 94 | target_save_path = data_save_path.replace(config.data_folder_name, config.target_folder_name) 95 | 96 | misc.imsave(data_save_path, new_data) 97 | misc.imsave(target_save_path, new_target) 98 | subimgcounter += 1 99 | 100 | 101 | # select features from intermediate layers of a network (a module or Sequential) 102 | class SelectiveSequential(nn.Module): 103 | def __init__(self, to_select, modules_dict): 104 | super(SelectiveSequential, self).__init__() 105 | for key, module in modules_dict.items(): 106 | self.add_module(key, module) 107 | 108 | self._to_select = to_select 109 | 110 | def forward(self, x): 111 | selected = [] 112 | for name, module in self._modules.items(): 113 | x = module(x) 114 | if name in self._to_select: 115 | selected.append(x) 116 | 117 | return selected 118 | 119 | 120 | class FeatureExtractor(nn.Module): 121 | def __init__(self, submodule, extracted_layers): 122 | super(FeatureExtractor, self).__init__() 123 | self.submodule = submodule 124 | self.extracted_layers = extracted_layers 125 | 126 | def forward(self, x): 127 | outputs = [] 128 | for name, module in self.submodule._modules.items(): 129 | x = module(x) 130 | if name in self.extracted_layers: 131 | outputs += [x] 132 | 133 | return outputs + [x] 134 | 135 | 136 | class model_utils(): 137 | # class global variable 138 | flops_list_conv = [] 139 | flops_list_linear = [] 140 | flops_list_bn = [] 141 | flops_list_relu = [] 142 | flops_list_pooling = [] 143 | 144 | mac_list_conv = [] 145 | mac_list_linear = [] 146 | mac_list_bn = [] 147 | mac_list_relu = [] 148 | mac_list_pooling = [] 149 | 150 | MULTIPLY_ADDS = False 151 | summary = OrderedDict() 152 | names = {} 153 | display_input_shape = True 154 | display_weights = False 155 | display_nb_trainable = False 156 | 157 | def __init__(self, model, config, index=None): 158 | # TODO: decide whether to include batch_size 159 | """ 160 | registers hooks for profiling 161 | Generally, since the majority of FLOPs are in conv and linear layers, counting mainly those approximates the total well, and 162 | that is sufficient for almost all purposes. 
163 | 164 | :param model: 165 | :param config: for simulating input_size, batch_size 166 | """ 167 | self.model = model 168 | self.config = config 169 | self.hooks = [] 170 | # for testing downsampling flops and params 171 | # input_size=256 172 | self.channels_list = [3, 64, 128, 256, 512, 1024] 173 | self.size_list = [256, 128, 64 , 32, 16, 8] 174 | self.index = index 175 | def model_params(self): 176 | 177 | trainable_params = sum([p.nelement() for p in filter(lambda p: p.requires_grad, self.model.parameters())]) 178 | total_params = sum([p.nelement() for p in self.model.parameters()]) 179 | 180 | return trainable_params, total_params 181 | 182 | def speed_testing(self): 183 | 184 | # cuDNN configurations 185 | torch.backends.cudnn.benchmark = False 186 | torch.backends.cudnn.deterministic = False 187 | # get the average inference time over 10,000 iterations 188 | 189 | name = self.model.name 190 | print(" + {} Speed testing... ...".format(name)) 191 | model = self.model.to('cuda:{}'.format(self.config.device_id)) 192 | random_input = torch.randn(1,3,self.config.input_size, self.config.input_size).to('cuda:{}'.format(self.config.device_id)) 193 | 194 | model.eval() 195 | print(' +warm up ... ...') 196 | for i in range(500): 197 | model(random_input) 198 | 199 | time_list = [] 200 | for i in tqdm(range(10001)): 201 | torch.cuda.synchronize() 202 | tic = time.time() 203 | model(random_input) 204 | torch.cuda.synchronize() 205 | # the first iteration costs much more time, so it is excluded 206 | #print(time.time()-tic) 207 | time_list.append(time.time()-tic) 208 | time_list = time_list[1:] 209 | print(" + Done 10000 inference iterations !") 210 | print(" + Total time cost: {}s".format(sum(time_list))) 211 | print(" + Average time cost: {}s".format(sum(time_list)/10000)) 212 | print(" + Frame Per Second: {:.2f}".format(1/(sum(time_list)/10000))) 213 | 214 | 215 | def _register_hooks(self): 216 | 217 | # a hook registered via register_forward_hook is called every time after forward() computes an output 218 | # .children() returns only the modules defined in __init__, while .modules() recursively returns all submodules 219 | for module in self.model.modules(): 220 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d): 221 | self.hooks.append(module.register_forward_hook(self.conv_hook)) 222 | self.hooks.append(module.register_forward_hook(self.hook)) 223 | elif isinstance(module, nn.Linear): 224 | self.hooks.append(module.register_forward_hook(self.linear_hook)) 225 | self.hooks.append(module.register_forward_hook(self.hook)) 226 | #elif isinstance(module, nn.BatchNorm2d): 227 | # self.hooks.append(module.register_forward_hook(self.bn_hook)) 228 | # self.hooks.append(module.register_forward_hook(self.hook)) 229 | #elif isinstance(module, nn.ReLU): 230 | # self.hooks.append(module.register_forward_hook(self.relu_hook)) 231 | # self.hooks.append(module.register_forward_hook(self.hook)) 232 | #elif isinstance(module, nn.MaxPool2d) or isinstance(module, nn.AvgPool2d): 233 | # self.hooks.append(module.register_forward_hook(self.pooling_hook)) 234 | # self.hooks.append(module.register_forward_hook(self.hook)) 235 | 236 | def simulate_forward(self): 237 | 238 | trainable_params, total_params = self.model_params() 239 | model_utils.names = self.get_names_dict() 240 | self._register_hooks() 241 | 242 | input = torch.randn(size=(1, 3, self.config.input_size, self.config.input_size)) 243 | # for testing module performance 244 | #input = torch.randn(size=(1, self.channels_list[self.index], self.size_list[self.index], self.size_list[self.index])) 245 | if 
next(self.model.parameters()).is_cuda: 246 | input = input.cuda() 247 | self.model.eval() 248 | #tic = time.time() 249 | self.model(input) 250 | #duration = time.time()-tic 251 | 252 | total_mac = sum(self.mac_list_conv) + sum(self.mac_list_linear) + sum(self.mac_list_bn) + \ 253 | sum(self.mac_list_relu) + sum(self.mac_list_pooling) 254 | total_flops = sum(self.flops_list_conv) + sum(self.flops_list_linear) + sum(self.flops_list_bn) + \ 255 | sum(self.flops_list_relu) + sum(self.flops_list_pooling) 256 | df_summary = pd.DataFrame.from_dict(model_utils.summary, orient='index') 257 | 258 | print(df_summary) 259 | print(" + Number MAC of Model: {:.2f}M ".format(float(total_mac) / 1e6)) 260 | print(" + Number FLOPs of Model: {:.2f}M ".format(float(total_flops) / 1e6)) 261 | print(" + Number Total Params of Model: {:.2f}M ".format(float(total_params) / 1e6)) 262 | print(" + Number Trainable Params of Model: {:.3f}K ".format(float(trainable_params) / 1e3)) 263 | #print(" + Duration per sample: {}s".format(duration)) 264 | #print(" + Frame Per Second: {}".format(int(1/duration))) 265 | for h in self.hooks: 266 | h.remove() 267 | 268 | def get_names_dict(self): 269 | """ 270 | Recursive walk to get names including path 271 | """ 272 | names = {} 273 | 274 | def _get_names(module, parent_name=''): 275 | for key, module in module.named_children(): 276 | name = parent_name + '.' + key if parent_name else key 277 | names[name] = module 278 | if isinstance(module, torch.nn.Module): 279 | _get_names(module, parent_name=name) 280 | 281 | _get_names(self.model) 282 | # print(names.keys()) 283 | return names 284 | 285 | @staticmethod 286 | def conv_hook(self, input, output): 287 | 288 | batch_size, in_channels, height, width = input[0].shape 289 | out_channels, out_height, out_width = output[0].shape 290 | 291 | kernel_flops = self.kernel_size[0] * self.kernel_size[1] * (in_channels // self.groups) * ( 292 | 2 if model_utils.MULTIPLY_ADDS else 1) 293 | 294 | kernel_mac = self.kernel_size[0] * self.kernel_size[1] * in_channels * out_channels 295 | map_mac = height * width * in_channels * out_channels 296 | MAC = kernel_mac + map_mac 297 | 298 | bias_flops = 1 if self.bias is not None else 0 299 | # Dk*Dk*in*Df*Df*out 300 | flops = batch_size * (kernel_flops + bias_flops) * out_height * out_width * out_channels 301 | # access class variables via class_name.variable_name 302 | model_utils.flops_list_conv.append(flops) 303 | model_utils.mac_list_conv.append(MAC) 304 | 305 | 306 | @staticmethod 307 | def linear_hook(self, input, output): 308 | 309 | batch_size = input[0].shape[0] if input[0].dim() == 2 else 1 310 | # a linear layer is an (in_features x out_features) matrix multiply: 311 | # each of the out_features outputs is a dot product over in_features inputs 312 | weight_flops = self.weight.nelement() * (2 if model_utils.MULTIPLY_ADDS else 1) 313 | bias_flops = self.bias.nelement() 314 | 315 | weight_mac = self.weight.nelement() 316 | map_mac = self.in_features + self.out_features 317 | MAC = weight_mac + map_mac 318 | 319 | flops = batch_size * (weight_flops + bias_flops) 320 | model_utils.flops_list_linear.append(flops) 321 | model_utils.mac_list_linear.append(MAC) 322 | 323 | 324 | @staticmethod 325 | def bn_hook(self, input, output): 326 | model_utils.flops_list_bn.append(input[0].nelement()) 327 | map_mac = input[0].nelement() * 2 328 | model_utils.mac_list_bn.append(map_mac) 329 | 330 | @staticmethod 331 | def relu_hook(self, input, output): 332 | model_utils.flops_list_relu.append(input[0].nelement()) 333 | 
map_mac = input[0].nelement() * 2 334 | model_utils.mac_list_relu.append(map_mac) 335 | 336 | # TODO: does pooling have flops ? max pooling or average pooling? 337 | # average and max pooling omit the adds automatically 338 | @staticmethod 339 | def pooling_hook(self, input, output): 340 | batch_size, in_channels, height, width = input[0].shape 341 | #print(output) 342 | out_channels, out_height, out_width = output[0].shape 343 | 344 | # kernel_flops per channel; pooling does not share weights across channels 345 | kernel_flops = self.kernel_size * self.kernel_size 346 | # MAC counts only the feature maps; the pooling kernel has no parameters 347 | map_mac = height * width * in_channels + out_height * out_width * out_channels 348 | 349 | bias_flops = 0 # pooling has no bias 350 | flops = batch_size * (kernel_flops + bias_flops) * out_height * out_width * out_channels 351 | model_utils.flops_list_pooling.append(flops) 352 | model_utils.mac_list_pooling.append(map_mac) 353 | 354 | @staticmethod 355 | def hook(self, input, output): 356 | name = '' 357 | for key, item in model_utils.names.items(): 358 | if item == self: 359 | name = key 360 | # 361 | class_name = str(self.__class__).split('.')[-1].split("'")[0] 362 | module_idx = len(model_utils.summary) 363 | 364 | # key_id for new module 365 | m_key = module_idx + 1 366 | 367 | model_utils.summary[m_key] = OrderedDict() 368 | model_utils.summary[m_key]['name'] = name 369 | model_utils.summary[m_key]['class_name'] = class_name 370 | 371 | if model_utils.display_input_shape: 372 | model_utils.summary[m_key]['input_shape'] = (-1,) + tuple(input[0].size())[1:] 373 | model_utils.summary[m_key]['output_shape'] = (-1,) + tuple(output.size())[1:] 374 | 375 | if model_utils.display_weights: 376 | model_utils.summary[m_key]['weights'] = list([tuple(p.size()) for p in self.parameters()]) 377 | 378 | #summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()]) 379 | if model_utils.display_nb_trainable: 380 | params_trainable = sum( 381 | [torch.LongTensor(list(p.size())).prod() for p in self.parameters() if p.requires_grad] 382 | ) 383 | model_utils.summary[m_key]['nb_trainable'] = params_trainable 384 | 385 | params = sum([torch.LongTensor(list(p.size())).prod() for p in self.parameters()]) 386 | model_utils.summary[m_key]['nb_params'] = params 387 | # hook ends 388 | 389 | # Net for testing the util functions 390 | class Net(nn.Module): 391 | def __init__(self): 392 | super(Net, self).__init__() 393 | self.features = SelectiveSequential( 394 | ['conv1', 'conv3'], 395 | { 396 | 'conv1': nn.Conv2d(1, 1, 3), 397 | 'conv2': nn.Conv2d(1, 1, 3), 398 | 'conv3': nn.Conv2d(1, 1, 3) 399 | } 400 | 401 | ) 402 | 403 | def forward(self, input): 404 | return self.features(input) 405 | 406 | 407 | # only for testing 408 | class test_Net(nn.Module): 409 | def __init__(self): 410 | super(test_Net, self).__init__() 411 | self.conv1 = nn.Conv2d(3, 10, 3, padding=1, bias=False) 412 | self.conv2 = nn.Conv2d(10, 10, 3, padding=1, bias=False) 413 | self.conv3 = nn.Conv2d(10, 3, 3, padding=1, bias=False) 414 | 415 | def forward(self, input): 416 | out = self.conv1(input) 417 | out = self.conv2(out) 418 | out = self.conv3(out) 419 | 420 | return out 421 | 422 | 423 | class test_hook_register(): 424 | aaa = 1 425 | 426 | def __init__(self, model): 427 | self.value = 2 428 | self.model = model 429 | 430 | @staticmethod 431 | def hook(self, input, output): 432 | # print(self.kernel_size) 433 | print('1') 434 | print(test_hook_register.aaa) 435 | # print(value) 436 | 437 | def get_value(self): 
438 | pass 439 | 440 | def register(self): 441 | for child in self.model.children(): 442 | child.register_forward_hook(self.hook) 443 | 444 | def simulate_forward(self): 445 | self.register() 446 | input = torch.randn(size=(1, 3, 64, 64)) 447 | self.model(input) 448 | # print(" + Number flops of Module: {}".format()) 449 | 450 | 451 | if __name__ == '__main__': 452 | model = test_Net() 453 | # for name, module in model._modules.items(): 454 | # print(name, module) 455 | # print(module.in_channels, module.out_channels, module.kernel_size) 456 | # model.named_modules() walks the module tree recursively 457 | # model.children() yields only the modules defined in __init__, the same output as model._modules.items() 458 | # model_flops(model) 459 | 460 | # input = torch.randn(1,3,224,224) 461 | # output = model(input) 462 | 463 | utils = model_utils(model, MyConfiguration()) # model_utils requires a config; avoid shadowing the class name 464 | --------------------------------------------------------------------------------
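The FLOP/MAC counting in model_utils is built on register_forward_hook: each hook fires after a module's forward pass with (module, input, output), so the shapes needed for per-layer cost formulas are available without touching the model code. A stripped-down sketch of the same pattern (standalone, not using the classes above):

import torch
import torch.nn as nn

flops = []

def conv_hook(module, inputs, output):
    # per output element: kH * kW * (C_in / groups) multiply-accumulates
    n, c_out, h, w = output.shape
    k_h, k_w = module.kernel_size
    flops.append(n * c_out * h * w * k_h * k_w * module.in_channels // module.groups)

net = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
handle = net.register_forward_hook(conv_hook)
net(torch.randn(1, 3, 64, 64))
handle.remove()  # always remove hooks once profiling is done
print(sum(flops))  # 16*64*64 * 3*3*3 = 1,769,472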