├── .gitignore ├── README.md ├── __pycache__ ├── dataset_loader.cpython-37.pyc ├── loss.cpython-37.pyc └── training.cpython-37.pyc ├── dataset_loader.py ├── imagenet_pretrain.py ├── launch_pretrain.sh ├── launch_test.sh ├── launch_train.sh ├── loss.py ├── main.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── dsam.cpython-36.pyc │ ├── dsam.cpython-37.pyc │ ├── genotypes.cpython-36.pyc │ ├── genotypes.cpython-37.pyc │ ├── gsmodule.cpython-36.pyc │ ├── gsmodule.cpython-37.pyc │ ├── model_depth.cpython-36.pyc │ ├── model_depth.cpython-37.pyc │ ├── model_fusion.cpython-36.pyc │ ├── model_fusion.cpython-37.pyc │ ├── model_fusion_raw.cpython-36.pyc │ ├── model_fusion_raw.cpython-37.pyc │ ├── model_rgb.cpython-36.pyc │ ├── model_rgb.cpython-37.pyc │ ├── operations.cpython-36.pyc │ └── operations.cpython-37.pyc ├── dsam.py ├── genotypes.py ├── gsmodule.py ├── model_depth.py ├── model_fusion.py ├── model_rgb.py └── operations.py ├── training.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── evaluateFM.cpython-36.pyc ├── evaluateFM.cpython-37.pyc ├── functions.cpython-36.pyc └── functions.cpython-37.pyc ├── evaluateFM.py ├── functions.py └── pretreat_SIP.py
/.gitignore: -------------------------------------------------------------------------------- 1 | /runs --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | 9 | 10 | 11 | # DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral) 12 | 13 | This repo is the official implementation of 14 | ["DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion"](https://arxiv.org/pdf/2103.11832.pdf) 15 | 16 | by Peng Sun, Wenhu Zhang, Huanyu Wang, Songyuan Li, and Xi Li. 17 | 18 | # Prerequisites 19 | + Ubuntu 18 20 | + PyTorch 1.7.0 21 | + CUDA 10.1 22 | + cuDNN 7.5.1 23 | + Python 3.7 24 | + NumPy 1.17.3 25 | 26 | 27 | # Training 28 | Please see `launch_pretrain.sh` and `launch_train.sh` for ImageNet pretraining and SOD training, respectively. 29 | 30 | # Testing 31 | Please see `launch_test.sh` for testing on the SOD benchmarks. 32 | 33 | ## Main Results (Er: E-measure, Sλ: S-measure, Fβ: F-measure, M: mean absolute error; higher is better except M) 34 | 35 | |Dataset | Er| Sλmean|Fβmean| M | 36 | |:---:|:---:|:---:|:---:|:---:| 37 | |DUT-RGBD|0.950|0.921|0.926|0.030| 38 | |NJUD|0.923|0.903|0.901|0.039| 39 | |NLPR|0.950|0.918|0.897|0.024| 40 | |SSD|0.904|0.876|0.852|0.045| 41 | |STEREO|0.933|0.904|0.898|0.036| 42 | |LFSD|0.923|0.882|0.882|0.054| 43 | |RGBD135|0.962|0.920|0.896|0.021| 44 | 45 | ## Saliency maps and Evaluation 46 | 47 | All of the saliency maps mentioned in the paper are available on [GoogleDrive](https://drive.google.com/file/d/1pqRpWgyDry3o6iKNNDx_eM2_kEOftYY3/view?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/1Fr5PuABceE7ordJvE84PKA)(code:juc2). 48 | 49 | You can use the toolbox provided by [jiwei0921](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox) for evaluation. 50 | 51 | We also provide the saliency maps of the STERE-1000 and SIP datasets on [BaiduYun](https://pan.baidu.com/s/1Pp1Hvckfsvr7mWq9qcY9pw)(code:qxfw) for easy comparison.
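For a quick numerical sanity check of downloaded or exported maps without the full toolbox, MAE can be computed directly from a folder of predictions and the corresponding `test_masks` folder. The snippet below is only a rough sketch (the directory layout it expects is an assumption); `utils/evaluateFM.get_FM`, which `main.py` calls during testing, remains the reference implementation.

```
# Rough MAE check for exported saliency maps (illustrative only; the
# folder layout passed in is an assumption, not part of this repo).
import os
import numpy as np
import PIL.Image

def rough_mae(sal_dir, gt_dir):
    """Mean absolute error between predicted maps and binary GT masks."""
    errors = []
    for name in sorted(os.listdir(gt_dir)):
        if not name.endswith('.png'):
            continue
        gt = np.array(PIL.Image.open(os.path.join(gt_dir, name)).convert('L'),
                      dtype=np.float64) / 255.0
        sal = PIL.Image.open(os.path.join(sal_dir, name)).convert('L')
        sal = np.array(sal.resize((gt.shape[1], gt.shape[0])),
                       dtype=np.float64) / 255.0
        errors.append(np.abs(sal - (gt > 0.5)).mean())
    return float(np.mean(errors))

# e.g. rough_mae('sal_map/NJUD', '/path/to/SOD-RGBD/val/NJUD/test_masks')
```

The STERE-1000 and SIP results for these maps are listed below.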
52 | 53 | 54 | |Dataset | Er| Sλmean|Fβmean| M | 55 | |:---:|:---:|:---:|:---:|:---:| 56 | |STERE-1000|0.928|0.897|0.895|0.038| 57 | |SIP|0.908|0.861|0.868|0.057| 58 | 59 | ## Citation 60 | ``` 61 | @inproceedings{Sun2021DeepRS, 62 | title={Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion}, 63 | author={P. Sun and Wenhu Zhang and Huanyu Wang and Songyuan Li and Xi Li}, 64 | journal={IEEE Conf. Comput. Vis. Pattern Recog.}, 65 | year={2021} 66 | } 67 | ``` 68 | 69 | 70 | ## License 71 | 72 | The code is released under MIT License (see LICENSE file for details). 73 | -------------------------------------------------------------------------------- /__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/training.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/training.cpython-37.pyc -------------------------------------------------------------------------------- /dataset_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL.Image 4 | import scipy.io as sio 5 | import torch 6 | from torch.utils import data 7 | import cv2 8 | from utils.functions import adaptive_bins, get_bins_masks 9 | 10 | class MyData(data.Dataset): 11 | """ 12 | load data in a folder 13 | """ 14 | mean_rgb = np.array([0.447, 0.407, 0.386]) 15 | std_rgb = np.array([0.244, 0.250, 0.253]) 16 | 17 | 18 | def __init__(self, root, transform=False): 19 | super(MyData, self).__init__() 20 | self.root = root 21 | 22 | self._transform = transform 23 | img_root = os.path.join(self.root, 'train_images') 24 | mask_root = os.path.join(self.root, 'train_masks') 25 | depth_root = os.path.join(self.root, 'train_depth') 26 | file_names = os.listdir(img_root) 27 | self.img_names = [] 28 | self.mask_names = [] 29 | self.depth_names = [] 30 | for i, name in enumerate(file_names): 31 | if not name.endswith('.jpg'): 32 | continue 33 | ## training with 2 dataset 34 | # if len(name.split('_')[0]) ==4 : 35 | # continue 36 | # print(name) 37 | self.mask_names.append( 38 | os.path.join(mask_root, name[:-4] + '.png') 39 | ) 40 | 41 | self.img_names.append( 42 | os.path.join(img_root, name) 43 | ) 44 | self.depth_names.append( 45 | os.path.join(depth_root, name[:-4] + '.png') 46 | ) 47 | 48 | def __len__(self): 49 | return len(self.img_names) 50 | 51 | def __getitem__(self, index): 52 | # load image 53 | img_file = self.img_names[index] 54 | img = PIL.Image.open(img_file) 55 | img = np.array(img, dtype=np.uint8) 56 | # load label 57 | mask_file = self.mask_names[index] 58 | mask = PIL.Image.open(mask_file) 59 | mask = np.array(mask, dtype=np.int32) 60 | mask[mask != 0] = 1 61 | # load depth 62 | depth_file = self.depth_names[index] 63 | depth = 
PIL.Image.open(depth_file) 64 | depth = np.array(depth, dtype=np.uint8) 65 | # bins 66 | bins_mask = get_bins_masks(depth) 67 | 68 | if self._transform: 69 | return self.transform(img, mask, depth, bins_mask) 70 | else: 71 | return img, mask, depth, bins_mask 72 | 73 | def transform(self, img, mask, depth, bins_mask): 74 | img = img.astype(np.float64)/255.0 75 | img -= self.mean_rgb 76 | img /= self.std_rgb 77 | img = img.transpose(2, 0, 1) # to verify 78 | img = torch.from_numpy(img).float() 79 | mask = torch.from_numpy(mask).long() 80 | depth = depth.astype(np.float64) / 255.0 81 | depth = torch.from_numpy(depth).float() 82 | 83 | bins_mask=torch.from_numpy(bins_mask).float() 84 | h,w=depth.size() 85 | bins_depth = depth.view(1, h, w).repeat(3, 1, 1) 86 | bins_depth=bins_depth * bins_mask 87 | for i in range(3): 88 | bins_depth[i]=bins_depth[i]/bins_depth[i].max() 89 | c, h, w = img.size() 90 | return img, mask, depth, bins_depth# 91 | 92 | 93 | 94 | 95 | 96 | class MyTestData(data.Dataset): 97 | """ 98 | load data in a folder 99 | """ 100 | mean_rgb = np.array([0.447, 0.407, 0.386]) 101 | std_rgb = np.array([0.244, 0.250, 0.253]) 102 | 103 | def __init__(self, root, transform=False, use_bins=True): 104 | super(MyTestData, self).__init__() 105 | self.root = root 106 | self._transform = transform 107 | self._bins = use_bins 108 | 109 | img_root = os.path.join(self.root, 'test_images') 110 | depth_root = os.path.join(self.root, 'test_depth') 111 | file_names = os.listdir(img_root) 112 | self.img_names = [] 113 | self.names = [] 114 | self.depth_names = [] 115 | 116 | for i, name in enumerate(file_names): 117 | if not name.endswith('.jpg'): 118 | continue 119 | self.img_names.append( 120 | os.path.join(img_root, name) 121 | ) 122 | self.names.append(name[:-4]) 123 | self.depth_names.append( 124 | os.path.join(depth_root, name[:-4] + '.png') 125 | ) 126 | 127 | def __len__(self): 128 | return len(self.img_names) 129 | 130 | def __getitem__(self, index): 131 | # load image 132 | img_file = self.img_names[index] 133 | img = PIL.Image.open(img_file) 134 | img_size = img.size 135 | img = np.array(img, dtype=np.uint8) 136 | 137 | # load depth 138 | depth_file = self.depth_names[index] 139 | depth = PIL.Image.open(depth_file) 140 | depth = np.array(depth, dtype=np.uint8) 141 | 142 | bins_mask = get_bins_masks(depth) 143 | 144 | 145 | if self._transform: 146 | img, depth, bins_depth = self.transform(img, depth, bins_mask) 147 | return img, depth, bins_depth, self.names[index], img_size 148 | else: 149 | return img, depth, bins_mask, self.names[index], img_size 150 | 151 | 152 | 153 | def transform(self, img, depth, bins_mask): 154 | img = img.astype(np.float64)/255.0 155 | img -= self.mean_rgb 156 | img /= self.std_rgb 157 | img = img.transpose(2, 0, 1) # to verify 158 | img = torch.from_numpy(img).float() 159 | 160 | depth = depth.astype(np.float64) / 255.0 161 | depth = torch.from_numpy(depth).float() 162 | 163 | bins_mask=torch.from_numpy(bins_mask).float() 164 | h,w=depth.size() 165 | bins_depth = depth.view(1, h, w).repeat(3, 1, 1) 166 | bins_depth=bins_depth * bins_mask 167 | for i in range(3): 168 | bins_depth[i]=bins_depth[i]/bins_depth[i].max() 169 | c, h, w = img.size() 170 | return img, depth,bins_depth# 171 | 172 | if __name__ == '__main__': 173 | root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/SIP" 174 | test_loader = torch.utils.data.DataLoader(MyTestData(root, transform=True), 175 | batch_size=1, shuffle=True, num_workers=4, pin_memory=True) 176 | for id, (data, depth, bins, img_name, 
img_size) in enumerate(test_loader): 177 | print(img_size) -------------------------------------------------------------------------------- /imagenet_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion 3 | """ 4 | import os 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torchvision import transforms 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import torchvision.datasets as datasets 15 | import time 16 | import torchvision 17 | import logging 18 | import sys 19 | import argparse 20 | import numpy as np 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | from torch.nn.parallel import DistributedDataParallel as DDP 24 | from tensorboardX import SummaryWriter 25 | from utils.functions import * 26 | from models.model_depth import DepthNet 27 | from models.model_rgb import RgbNet 28 | from models.model_fusion import NasFusionNet_pre 29 | import torch.multiprocessing as mp 30 | import warnings 31 | warnings.filterwarnings("ignore") 32 | 33 | def find_free_port(): 34 | import socket 35 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 36 | # Binding to port 0 will cause the OS to find an available port for us 37 | sock.bind(("", 0)) 38 | port = sock.getsockname()[1] 39 | sock.close() 40 | # NOTE: there is still a chance the port could be taken by other processes. 41 | return port 42 | 43 | def reduce_tensor(tensor): 44 | rt = tensor.clone() 45 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 46 | rt /= args.world_size 47 | return rt 48 | 49 | def accuracy(output, target, topk=(1,)): 50 | """Computes the precision@k for the specified values of k""" 51 | maxk = max(topk) 52 | batch_size = target.size(0) 53 | 54 | _, pred = output.topk(maxk, 1, True, True) 55 | pred = pred.t() 56 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 57 | 58 | res = [] 59 | for k in topk: 60 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 61 | res.append(correct_k.mul_(100.0 / batch_size)) 62 | return res 63 | 64 | def to_python_float(t): 65 | if hasattr(t, 'item'): 66 | return t.item() 67 | else: 68 | return t[0] 69 | 70 | 71 | class AverageMeter(object): 72 | """Computes and stores the average and current value""" 73 | def __init__(self): 74 | self.reset() 75 | 76 | def reset(self): 77 | self.val = 0 78 | self.avg = 0 79 | self.sum = 0 80 | self.count = 0 81 | 82 | def update(self, val, n=1): 83 | self.val = val 84 | self.sum += val * n 85 | self.count += n 86 | self.avg = self.sum / self.count 87 | 88 | 89 | def adjust_learning_rate(optimizer, epoch, args): 90 | """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs""" 91 | lr = args.lr 92 | if epoch >= 30: 93 | lr = 0.1 * lr 94 | if epoch >= 60: 95 | lr = 0.1 * lr 96 | if epoch >= 80: 97 | lr = 0.1 * lr 98 | optimizer.param_groups[0]['lr'] = lr 99 | 100 | 101 | def train(train_loader, models, CE, optimizers, epoch, logger, logging): 102 | """Train for one epoch on the training set""" 103 | batch_time = AverageMeter() 104 | losses = AverageMeter() 105 | 106 | top1 = AverageMeter() 107 | top5 = AverageMeter() 108 | 109 | # switch to train mode 110 | for m in models: 111 | m.train() 112 | end = time.time() 113 | 114 | for i, 
(inputs, target) in enumerate(train_loader): 115 | global_step = epoch * len(train_loader) + i 116 | target = target.cuda() 117 | inputs = inputs.cuda() 118 | 119 | # print(gpu,models[0].device_ids, inputs.device) 120 | b,c,h,w = inputs.size() 121 | depth = torch.mean(inputs,dim = 1).view(b,1,h,w).repeat(1, c, 1, 1) 122 | # print("inpus:",inputs.shape) 123 | h1, h2, h3, h4, h5 = models[0](inputs, depth, gumbel=True) 124 | d0, d1, d2, d3, d4 = models[1](depth) 125 | output = models[2](h1, h2, h3, h4, h5, d0, d1, d2, d3, d4) 126 | 127 | # A loss 128 | loss = CE( output, target) * 1.0 129 | 130 | # measure accuracy and record loss 131 | prec1, prec5 = accuracy(output.data, target, topk=(1,5)) 132 | 133 | reduced_loss = reduce_tensor(loss.data) 134 | prec1 = reduce_tensor(prec1) 135 | prec5 = reduce_tensor(prec5) 136 | 137 | losses.update(to_python_float(reduced_loss), inputs.size(0)) 138 | top1.update(to_python_float(prec1), inputs.size(0)) 139 | top5.update(to_python_float(prec5), inputs.size(0)) 140 | 141 | 142 | # compute gradient and do SGD step 143 | for op in optimizers: 144 | op.zero_grad() 145 | 146 | loss.backward() 147 | 148 | for op in optimizers: 149 | op.step() 150 | 151 | # measure elapsed time 152 | batch_time.update(time.time() - end) 153 | end = time.time() 154 | 155 | if i % 50 == 0 and args.rank ==0: 156 | logging.info('Epoch: [{0}][{1}/{2}]\t' 157 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 158 | 'Loss {loss.val:.4f} \t' 159 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 160 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 161 | epoch, i, len(train_loader), batch_time=batch_time, 162 | loss=losses, top1=top1, top5 = top5)) 163 | 164 | logger.add_scalar('train/losses', losses.avg, global_step=global_step) 165 | logger.add_scalar('train/top1', top1.avg, global_step=global_step) 166 | logger.add_scalar('train/top5', top5.avg, global_step=global_step) 167 | logger.add_scalar('train/lr', optimizers[0].param_groups[0]['lr'], global_step=global_step) 168 | 169 | 170 | def validate(valid_loader, models, CE, epoch, logger, logging): 171 | """Perform validation on the validation set""" 172 | batch_time = AverageMeter() 173 | losses = AverageMeter() 174 | top1 = AverageMeter() 175 | top5 = AverageMeter() 176 | 177 | for m in models: 178 | m.eval() 179 | 180 | end = time.time() 181 | for i, (inputs, target) in enumerate(valid_loader): 182 | target = target.cuda() 183 | inputs = inputs.cuda() 184 | with torch.no_grad(): 185 | b,c,h,w = inputs.size() 186 | depth = torch.mean(inputs,dim = 1).view(b,1,h,w).repeat(1, c, 1, 1) 187 | 188 | h1, h2, h3, h4, h5 = models[0](inputs, depth, gumbel=False) 189 | d0, d1, d2, d3, d4 = models[1](depth) 190 | output = models[2](h1, h2, h3, h4, h5, d0, d1, d2, d3, d4) 191 | 192 | loss = CE(output, target) 193 | 194 | # measure accuracy and record loss 195 | prec1 , prec5 = accuracy(output.data, target, topk=(1,5)) 196 | 197 | reduced_loss = reduce_tensor(loss.data) 198 | prec1 = reduce_tensor(prec1) 199 | prec5 = reduce_tensor(prec5) 200 | 201 | losses.update(to_python_float(reduced_loss), inputs.size(0)) 202 | top1.update(to_python_float(prec1), inputs.size(0)) 203 | top5.update(to_python_float(prec5), inputs.size(0)) 204 | 205 | # measure elapsed time 206 | batch_time.update(time.time() - end) 207 | end = time.time() 208 | 209 | if i % 50 == 0 and args.rank == 0: 210 | logging.info('Test: [{0}/{1}]\t' 211 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 212 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 213 | 'Prec@1 {top1.val:.3f} 
({top1.avg:.3f})\t' 214 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 215 | i, len(valid_loader), batch_time=batch_time, loss=losses, 216 | top1=top1, top5 = top5)) 217 | 218 | 219 | logger.add_scalar('valid/top1', top1.avg, global_step=epoch) 220 | logger.add_scalar('valid/top5', top5.avg, global_step=epoch) 221 | 222 | logging.info(' * Prec@1 {top1.avg:.3f} * Prec@5 {top5.avg:.3f} '.format(top1=top1, top5=top5)) 223 | 224 | return top1.avg 225 | 226 | 227 | 228 | def main_worker(gpu, argss): 229 | global args 230 | args = argss 231 | 232 | torch.cuda.set_device(gpu) 233 | rank = args.nr * args.gpus + gpu 234 | args.rank = rank 235 | exp_name = '/imagenet_pretrain' 236 | args.save_path = args.save_path + exp_name 237 | args.snapshot_root = args.save_path +'/snapshot/' 238 | args.log_root = args.save_path + '/logs/train-{}'.format(time.strftime("%Y%m%d-%H%M%S")) 239 | 240 | if args.phase =='train' and args.rank ==0 : 241 | create_exp_dir(args.log_root, scripts_to_save=None) 242 | log_format = '%(asctime)s %(message)s' 243 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 244 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 245 | fh = logging.FileHandler(os.path.join(args.log_root, 'log.txt')) 246 | fh.setFormatter(logging.Formatter(log_format)) 247 | logging.getLogger().addHandler(fh) 248 | 249 | if not os.path.exists(args.snapshot_root) and args.rank ==0 : 250 | os.mkdir(args.snapshot_root) 251 | 252 | dist.init_process_group( 253 | backend='nccl', 254 | init_method=args.dist_url, 255 | world_size=args.world_size, 256 | rank=args.rank) 257 | 258 | 259 | """""""""""dataset loader""""""""" 260 | # ImageNet Data loading code 261 | train_dataset = datasets.ImageFolder( 262 | os.path.join(args.data_root, 'train'), 263 | transforms.Compose([ 264 | transforms.RandomSizedCrop(224), 265 | transforms.RandomHorizontalFlip(), 266 | transforms.ToTensor(), 267 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 268 | std=[0.229, 0.224, 0.225]), 269 | ])) 270 | train_sampler = torch.utils.data.distributed.DistributedSampler( 271 | train_dataset, 272 | num_replicas=args.world_size, 273 | rank=rank, 274 | ) 275 | train_loader = torch.utils.data.DataLoader( 276 | dataset = train_dataset, 277 | batch_size = args.batchsize, 278 | num_workers=0, pin_memory=True, sampler = train_sampler) 279 | 280 | 281 | valid_dataset = datasets.ImageFolder( 282 | os.path.join(args.data_root, 'val'), 283 | transforms.Compose([ 284 | transforms.Scale(256), 285 | transforms.CenterCrop(224), 286 | transforms.ToTensor(), 287 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 288 | std=[0.229, 0.224, 0.225]), 289 | ])) 290 | valid_sampler = torch.utils.data.distributed.DistributedSampler( 291 | valid_dataset, 292 | num_replicas=args.world_size, 293 | rank=rank, 294 | shuffle=False 295 | ) 296 | valid_loader = torch.utils.data.DataLoader( 297 | dataset = valid_dataset, 298 | batch_size = args.batchsize, 299 | num_workers=0, pin_memory=True, sampler = valid_sampler) 300 | 301 | kwargs = {'num_workers': 2, 'pin_memory': True} 302 | logging.info('data already') 303 | 304 | """""""""""train_data/test_data through nets""""""""" 305 | 306 | model_depth = torch.nn.SyncBatchNorm.convert_sync_batchnorm(DepthNet()) 307 | model_rgb = torch.nn.SyncBatchNorm.convert_sync_batchnorm(RgbNet()) 308 | model_fusion = torch.nn.SyncBatchNorm.convert_sync_batchnorm(NasFusionNet_pre()) 309 | 310 | model_depth.init_weights() 311 | vgg19_bn = torchvision.models.vgg19_bn(pretrained=True) 312 | model_rgb.copy_params_from_vgg19_bn(vgg19_bn) 
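    # Initialization note: DepthNet starts from a normal-distribution init via
    # init_weights(), RgbNet copies its convolutional weights from torchvision's
    # ImageNet-pretrained VGG19-BN just above, and the NAS fusion head below is
    # initialized from scratch; the three models are then moved to the GPU and
    # wrapped in DistributedDataParallel. During ImageNet pretraining there is
    # no real depth map, so train()/validate() feed the channel-wise mean of the
    # RGB image, repeated to three channels, as a depth proxy.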
313 | model_fusion.init_weights() 314 | 315 | if args.rank==0: 316 | print("model_rgb param size = %fMB", count_parameters_in_MB(model_rgb)) 317 | print("model_depth param size = %fMB", count_parameters_in_MB(model_depth)) 318 | print("nas-model param size = %fMB", count_parameters_in_MB(model_fusion)) 319 | 320 | model_depth = model_depth.cuda() 321 | model_rgb = model_rgb.cuda() 322 | model_fusion = model_fusion.cuda() 323 | 324 | if args.distributed: 325 | model_depth = DDP(model_depth, device_ids=[gpu]) 326 | model_rgb = DDP(model_rgb, device_ids=[gpu]) 327 | model_fusion = DDP(model_fusion, device_ids=[gpu]) 328 | 329 | optimizer_depth = optim.SGD(model_depth.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 330 | optimizer_rgb = optim.SGD(model_rgb.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 331 | optimizer_fusion = optim.SGD(model_fusion.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 332 | 333 | # print(optimizer_depth.param_groups[0]['lr']) 334 | CE = nn.CrossEntropyLoss().cuda() 335 | 336 | logger = SummaryWriter(args.log_root) 337 | 338 | best_prec1 = -1 339 | for epoch in range(0, args.epoch): 340 | 341 | adjust_learning_rate(optimizer_depth, epoch, args) 342 | adjust_learning_rate(optimizer_rgb, epoch, args) 343 | adjust_learning_rate(optimizer_fusion, epoch, args) 344 | if args.rank==0: 345 | print("lr:",optimizer_rgb.param_groups[0]['lr']) 346 | # train for one epoch 347 | train_sampler.set_epoch(epoch) 348 | train(train_loader, [model_rgb, model_depth, model_fusion], CE, [optimizer_rgb, optimizer_depth, optimizer_fusion], epoch, logger, logging) 349 | 350 | # evaluate on validation set 351 | prec1 = validate(valid_loader, [model_rgb, model_depth, model_fusion], CE, epoch, logger, logging) 352 | 353 | # remember best prec@1 and save checkpoint 354 | is_best = prec1 > best_prec1 355 | best_prec1 = max(prec1, best_prec1) 356 | 357 | if args.rank ==0: 358 | logging.info('Best accuracy: %f' % best_prec1) 359 | logger.add_scalar('best/accuracy', best_prec1, global_step=epoch) 360 | 361 | savename_depth = ('%s/depth_pre_epoch%d.pth' % (args.snapshot_root, epoch)) 362 | torch.save(model_depth.state_dict(), savename_depth) 363 | print('save: (snapshot: %d)' % (epoch)) 364 | 365 | savename_rgb = ('%s/rgb_pre_epoch%d.pth' % (args.snapshot_root, epoch)) 366 | torch.save(model_rgb.state_dict(), savename_rgb) 367 | print('save: (snapshot: %d)' % (epoch)) 368 | 369 | savename_fusion = ('%s/fusion_pre_epoch%d.pth' % (args.snapshot_root, epoch)) 370 | torch.save(model_fusion.state_dict(), savename_fusion) 371 | print('save: (snapshot: %d)' % (epoch)) 372 | 373 | if is_best: 374 | savename_depth = ('%s/depth_pre.pth' % (args.snapshot_root)) 375 | torch.save(model_depth.state_dict(), savename_depth) 376 | print('save: (snapshot: %d)' % (epoch)) 377 | 378 | savename_rgb = ('%s/rgb_pre.pth' % (args.snapshot_root)) 379 | torch.save(model_rgb.state_dict(), savename_rgb) 380 | print('save: (snapshot: %d)' % (epoch)) 381 | 382 | savename_fusion = ('%s/fusion_pre.pth' % (args.snapshot_root)) 383 | torch.save(model_fusion.state_dict(), savename_fusion) 384 | print('save: (snapshot: %d)' % (epoch)) 385 | 386 | def main(): 387 | parser=argparse.ArgumentParser() 388 | parser.add_argument('--phase', type=str, default='train', help='train or test') 389 | parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters') 390 | parser.add_argument('--data_root', type=str, 
default='/4T/sunpeng/ImageNet') 391 | 392 | parser.add_argument('--save_path', type=str, default='/home/wenhu/pami21/runs/', help='save & log path') 393 | parser.add_argument('--snapshot_root', type=str, default='None', help='path to snapshot') 394 | parser.add_argument('--log_root', type=str, default='path to logs') 395 | 396 | parser.add_argument('--test_dataset', type=str, default='') 397 | parser.add_argument('--parse_method', type=str, default='darts', help='parse the code method') 398 | 399 | parser.add_argument('--batchsize', type=int, default=2, help='batchsize') 400 | parser.add_argument('--epoch', type=int, default=100, help='epoch') 401 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 402 | help='initial learning rate') 403 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 404 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 405 | help='weight decay (default: 1e-4)') 406 | 407 | parser.add_argument('-n', '--nodes', default=1, 408 | type=int, metavar='N') 409 | parser.add_argument('-g', '--gpus', default=2, type=int, 410 | help='number of gpus per node') 411 | parser.add_argument('-nr', '--nr', default=0, type=int, 412 | help='ranking within the nodes') 413 | args = parser.parse_args() 414 | 415 | args.distributed = True 416 | 417 | 418 | args.world_size = args.gpus * args.nodes 419 | port = find_free_port() 420 | args.dist_url = f"tcp://127.0.0.1:{port}" 421 | mp.spawn(main_worker, nprocs=args.gpus, args=(args,)) 422 | 423 | if __name__ == '__main__': 424 | main() -------------------------------------------------------------------------------- /launch_pretrain.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | ############### imagenet pretraining 4 | python imagenet_pretrain.py -n 1 -g 4 -nr 0 \ 5 | --phase train --epoch 90 --batchsize 4 --lr 0.00625 --momentum 0.9 --weight_decay 5e-4 \ 6 | --data_root /4T/sunpeng/ImageNet \ 7 | --save_path /4T/wenhu/pami21/github_test -------------------------------------------------------------------------------- /launch_test.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: Wenhu Zhang 4 | # @Date: 2021-06-07 17:39:22 5 | # @LastEditTime: 2021-06-07 18:13:00 6 | # @LastEditors: Wenhu Zhang 7 | # @Description: 8 | # @FilePath: /github/wh/DSA2F/launch_test.sh 9 | ### 10 | 11 | CUDA_VISIBLE_DEVICES="0" python -u main.py --phase test --test_dataset NJUD --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_NJUD.txt & 12 | CUDA_VISIBLE_DEVICES="0" python -u main.py --phase test --test_dataset DUT-RGBD --begin_epoch 20 --end_epoch 97 --exp_name 0428debug > results/0428debug_DUT-RGBD.txt & 13 | CUDA_VISIBLE_DEVICES="1" python -u main.py --phase test --test_dataset NLPR --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_NLPR.txt & 14 | CUDA_VISIBLE_DEVICES="1" python -u main.py --phase test --test_dataset SSD --begin_epoch 20000 --end_epoch 20000 --exp_name 0428debug > results/0428debug_SSD.txt 15 | 16 | CUDA_VISIBLE_DEVICES="2" python -u main.py --phase test --test_dataset STEREO --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_STEREO.txt & 17 | CUDA_VISIBLE_DEVICES="2" python -u main.py --phase test --test_dataset LFSD --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_LFSD.txt & 18 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset RGBD135 --begin_epoch 1 
--end_epoch 8 --exp_name 0428debug > results/0428debug_RGBD135.txt & 19 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset SIP --begin_epoch 1 --end_epoch 8 --exp_name 0428debug > results/0428debug_SIP.txt & 20 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset ReDWeb --begin_epoch 1 --end_epoch 8 --exp_name 0428debug > results/0428debug_ReDWeb.txt & -------------------------------------------------------------------------------- /launch_train.sh: -------------------------------------------------------------------------------- 1 | 2 | ### 3 | # @Author: Wenhu Zhang 4 | # @Date: 2021-06-07 17:39:22 5 | # @LastEditTime: 2021-06-07 18:09:35 6 | # @LastEditors: Wenhu Zhang 7 | # @Description: 8 | # @FilePath: /github/wh/DSA2F/launch_train.sh 9 | ### 10 | CUDA_VISIBLE_DEVICES="0" python main.py --phase train --epoch 60 \ 11 | --save_path /4T/wenhu/pami21/ \ 12 | --pretrain_path /home/wenhu/pami21/ckpt_best.pth.tar \ 13 | --train_dataroot /4T/wenhu/dataset/SOD-RGBD/train_data-augment/ \ 14 | --test_dataroot /4T/wenhu/dataset/SOD-RGBD/val/ \ 15 | --exp_name 0607debug -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | class BinaryDiceLoss(nn.Module): 8 | """Dice loss of binary class 9 | Args: 10 | smooth: A float number to smooth loss, and avoid NaN error, default: 1 11 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 12 | predict: A tensor of shape [N, *] 13 | target: A tensor of shape same with predict 14 | reduction: Reduction method to apply, return mean over batch if 'mean', 15 | return sum if 'sum', return a tensor of shape [N,] if 'none' 16 | Returns: 17 | Loss tensor according to arg reduction 18 | Raise: 19 | Exception if unexpected reduction 20 | """ 21 | def __init__(self, smooth=1, p=2, reduction='mean'): 22 | super(BinaryDiceLoss, self).__init__() 23 | self.smooth = smooth 24 | self.p = p 25 | self.reduction = reduction 26 | 27 | def forward(self, predict, target): 28 | predict = F.softmax(predict, dim=1)[:,1,:,:].unsqueeze(1) 29 | target = target.unsqueeze(1).float() 30 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" 31 | predict = predict.contiguous().view(predict.shape[0], -1) 32 | target = target.contiguous().view(target.shape[0], -1) 33 | 34 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth 35 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth 36 | 37 | loss = 1 - num / den 38 | 39 | if self.reduction == 'mean': 40 | return loss.mean() *256*256 41 | elif self.reduction == 'sum': 42 | return loss.sum()*256*256 43 | elif self.reduction == 'none': 44 | return loss*256*256 45 | else: 46 | raise Exception('Unexpected reduction {}'.format(self.reduction)) 47 | 48 | 49 | 50 | def cross_entropy2d(input, target, weight=None, size_average=True): 51 | n, c, h, w = input.size() 52 | 53 | input = input.transpose(1, 2).transpose(2, 3).contiguous() 54 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] # 262144 #input = 2*256*256*2 55 | input = input.view(-1, c) 56 | mask = target >= 0 57 | target = target[mask] 58 | loss = F.cross_entropy(input, target, weight=weight, size_average=False) 59 | if size_average: 60 | loss /= mask.data.sum() 61 | return loss 62 | 63 | 
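# The iou() loss below parallels BinaryDiceLoss above: predictions are
# softmaxed over their two channels, the foreground channel is compared to
# the binary target with a soft intersection-over-union, and the resulting
# (1 - IoU) is scaled by 256*256 (the same scaling used in BinaryDiceLoss),
# presumably to keep its magnitude comparable to the pixel-wise loss terms.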
64 | 65 | def iou(pred, target, size_average = False): 66 | 67 | pred = F.softmax(pred, dim=1) 68 | IoU = 0.0 69 | Iand1 = torch.sum(target.float() * pred[:,1,:,:]) 70 | Ior1 = torch.sum(target) + torch.sum(pred[:,1,:,:]) - Iand1 71 | IoU1 = (Iand1 + 1) / (Ior1 + 1) 72 | 73 | IoU = (1-IoU1) 74 | 75 | if size_average: 76 | IoU /= target.data.sum() 77 | return IoU * 256 * 256 78 | 79 | 80 | # class CrossEntropyLabelSmooth(nn.Module): 81 | # """Cross entropy loss with label smoothing regularizer. 82 | # Reference: 83 | # Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 84 | # Equation: y = (1 - epsilon) * y + epsilon / K. 85 | # Args: 86 | # num_classes (int): number of classes. 87 | # epsilon (float): weight. 88 | # """ 89 | 90 | # def __init__(self, num_classes=1000, epsilon=0.1): 91 | # super(CrossEntropyLabelSmooth, self).__init__() 92 | # self.num_classes = num_classes 93 | # self.epsilon = epsilon 94 | # self.logsoftmax = nn.LogSoftmax(dim=1) 95 | 96 | # def forward_v1(self, inputs, targets): 97 | # log_probs = self.logsoftmax(inputs) 98 | # targets = torch.zeros(log_probs.size(), device=targets.device).scatter_(1, targets.unsqueeze(1), 1) 99 | # targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 100 | # loss = (- targets * log_probs).mean(0).sum() 101 | # return loss 102 | 103 | # def forward_v2(self, inputs, targets): 104 | # probs = self.logsoftmax(inputs) 105 | # targets = torch.zeros(probs.size(), device=targets.device).scatter_(1, targets.unsqueeze(1), 1) 106 | # targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 107 | # loss = nn.KLDivLoss()(probs, targets) 108 | # return loss 109 | 110 | # def forward(self, inputs, targets): 111 | # """ 112 | # Args: 113 | # inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 114 | # targets: ground truth labels with shape (num_classes) 115 | # """ 116 | # return self.forward_v1(inputs, targets) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion 3 | """ 4 | import os 5 | import torch 6 | from torch.autograd import Variable 7 | from torch.utils.data import DataLoader 8 | import torch.nn.functional as F 9 | import time 10 | import torchvision 11 | import logging 12 | import sys 13 | import argparse 14 | import numpy as np 15 | import torch.backends.cudnn as cudnn 16 | import torch.distributed as dist 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from tensorboardX import SummaryWriter 19 | from dataset_loader import MyData, MyTestData 20 | from utils.functions import * 21 | from training import Trainer 22 | from utils.evaluateFM import get_FM 23 | from models.model_depth import DepthNet 24 | from models.model_rgb import RgbNet 25 | from models.model_fusion import NasFusionNet 26 | import warnings 27 | warnings.filterwarnings("ignore") 28 | 29 | configurations = { 30 | 1: dict( 31 | max_iteration=1000000, 32 | lr=5e-9, 33 | momentum=0.9, 34 | weight_decay=0.0005, 35 | spshot=20000, 36 | nclass=2, 37 | sshow=100, 38 | ), 39 | } 40 | parser=argparse.ArgumentParser() 41 | parser.add_argument('--phase', type=str, default='train', help='train or test') 42 | parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters') 43 | 
parser.add_argument('--train_dataroot', type=str, default='/4T/wenhu/dataset/SOD-RGBD/train_data-augment', help= 44 | 'path to train data') 45 | parser.add_argument('--test_dataroot', type=str, default='/4T/wenhu/dataset/SOD-RGBD/val/', help= 46 | 'path to test data') 47 | parser.add_argument('--pretrain_path', type=str, default='') 48 | 49 | parser.add_argument('--exp_name', type=str, default='debug', help='save & log path') 50 | parser.add_argument('--save_path', type=str, default='/4T/wenhu/pami21/', help='save & log path') 51 | parser.add_argument('--snapshot_root', type=str, default='None', help='path to snapshot') 52 | parser.add_argument('--salmap_root', type=str, default='None', help='path to saliency map') 53 | parser.add_argument('--log_root', type=str, default='path to logs') 54 | 55 | parser.add_argument('--test_dataset', type=str, default='LFSD') 56 | parser.add_argument('--begin_epoch', type=int, default=0) 57 | parser.add_argument('--end_epoch', type=int, default=0) 58 | parser.add_argument('--parse_method', type=str, default='darts', help='parse the code method') 59 | 60 | parser.add_argument('--batchsize', type=int, default=2, help='batchsize') 61 | parser.add_argument('--epoch', type=int, default=60, help='epoch') 62 | parser.add_argument("--local_rank", default=-1) 63 | parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys()) 64 | args = parser.parse_args() 65 | cfg = configurations 66 | 67 | 68 | 69 | args.save_path = args.save_path + args.exp_name 70 | if args.phase =='train': 71 | if os.path.exists(args.save_path): 72 | print(".... error!!!!!!!!!! save path already exist .....") 73 | logging.info(".... error!!!!!!!!!! save path already exist .....") 74 | sys.exit() 75 | else : 76 | os.mkdir(args.save_path) 77 | 78 | args.snapshot_root = args.save_path +'/snapshot/' 79 | args.salmap_root = args.save_path + '/sal_map/' 80 | args.log_root = args.save_path + '/logs/' 81 | if not os.path.exists(args.salmap_root): 82 | os.mkdir(args.salmap_root) 83 | 84 | cuda = torch.cuda.is_available 85 | 86 | if args.phase =='train': 87 | create_exp_dir(args.log_root, scripts_to_save=None) 88 | log_format = '%(asctime)s %(message)s' 89 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 90 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 91 | fh = logging.FileHandler(os.path.join(args.log_root, 'log.txt')) 92 | fh.setFormatter(logging.Formatter(log_format)) 93 | logging.getLogger().addHandler(fh) 94 | 95 | 96 | """""""""""dataset loader""""""""" 97 | 98 | train_dataRoot = args.train_dataroot 99 | 100 | if not os.path.exists(args.snapshot_root): 101 | os.mkdir(args.snapshot_root) 102 | 103 | if args.phase == 'train': 104 | SnapRoot = args.snapshot_root # checkpoint 105 | train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True), 106 | batch_size = args.batchsize, shuffle=True, num_workers=0, pin_memory=True) 107 | else: 108 | test_dataRoot = args.test_dataroot +args.test_dataset 109 | max_F_dict = {} 110 | min_mae_dict = {} 111 | MapRoot = args.salmap_root +args.test_dataset 112 | if not os.path.exists(MapRoot): 113 | os.mkdir(MapRoot) 114 | test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True), 115 | batch_size=1, shuffle=True, num_workers=4, pin_memory=True) 116 | print ('data already') 117 | 118 | """""""""""train_data/test_data through nets""""""""" 119 | cuda = torch.cuda.is_available 120 | start_epoch = 0 121 | start_iteration = 0 122 | 123 | model_depth = DepthNet() 124 | model_rgb = 
RgbNet() 125 | model_fusion = NasFusionNet() 126 | 127 | print("model_rgb param size = %fMB", count_parameters_in_MB(model_rgb)) 128 | print("model_depth param size = %fMB", count_parameters_in_MB(model_depth)) 129 | print("nas-model param size = %fMB", count_parameters_in_MB(model_fusion)) 130 | 131 | if args.begin_epoch == args.end_epoch: 132 | test_check_list = [args.end_epoch] 133 | else: 134 | test_epoch_list = [i*16418 for i in range(1,61)] 135 | test_iter_list = [i*10000 for i in range(args.begin_epoch, args.end_epoch+1)] 136 | test_check_list = test_epoch_list + test_iter_list 137 | test_check_list.sort() 138 | for ckpt_i in test_check_list: # When training, remove this line.ssss 139 | best_F = -float('inf') 140 | best_mae = float('inf') 141 | 142 | if args.phase == 'test': 143 | ckpt = str(ckpt_i) 144 | print(".... load checkpoint "+ ckpt +" for test .....") 145 | model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_' + ckpt + '.pth'))) 146 | model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'rgb_snapshot_iter_'+ckpt+'.pth'))) 147 | model_fusion.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'fusion_snapshot_iter_'+ckpt+'.pth'))) 148 | 149 | elif (args.pretrain_path): 150 | pretrained_dict = load_pretrain(args.pretrain_path, model_depth.state_dict(), "model_depth.") 151 | model_depth.load_state_dict(pretrained_dict) 152 | 153 | pretrained_dict = load_pretrain(args.pretrain_path, model_rgb.state_dict(), "model_rgb.") 154 | model_rgb.load_state_dict(pretrained_dict) 155 | 156 | model_fusion.init_weights() 157 | pretrained_dict = load_pretrain(args.pretrain_path, model_fusion.state_dict(), "model_fusion.") 158 | model_fusion.load_state_dict(pretrained_dict) 159 | logging.info(".... load imagenet pretrain models .....") 160 | 161 | else: 162 | logging.info(".... 
norm init .....") 163 | model_depth.init_weights() 164 | vgg19_bn = torchvision.models.vgg19_bn(pretrained=True) 165 | model_rgb.copy_params_from_vgg19_bn(vgg19_bn) 166 | model_fusion.init_weights() 167 | 168 | if cuda: 169 | model_depth = model_depth.cuda() 170 | model_rgb = model_rgb.cuda() 171 | model_fusion = model_fusion.cuda() 172 | 173 | if args.phase == 'train': 174 | cudnn.benchmark = True 175 | # torch.manual_seed(444) 176 | cudnn.enabled=True 177 | # torch.cuda.manual_seed(444) 178 | writer = SummaryWriter(args.log_root) 179 | model_rgb.cuda() 180 | model_depth.cuda() 181 | model_fusion.cuda() 182 | 183 | training = Trainer( 184 | cuda=cuda, 185 | cfg=cfg, 186 | model_depth=model_depth, 187 | model_rgb=model_rgb, 188 | model_fusion=model_fusion, 189 | train_loader=train_loader, 190 | test_data_list = ["DUT-RGBD","NJUD","NLPR","SSD","STEREO","LFSD","RGBD135","SIP","ReDWeb"], 191 | test_data_root = args.test_dataroot, 192 | salmap_root = args.salmap_root, 193 | outpath=args.snapshot_root, 194 | logging=logging, 195 | writer=writer, 196 | max_epoch=args.epoch, 197 | ) 198 | training.epoch = start_epoch 199 | training.iteration = start_iteration 200 | training.train() 201 | else: 202 | # -------------------------- inference --------------------------- # 203 | res = [] 204 | for id, (data, depth, bins, img_name, img_size) in enumerate(test_loader): 205 | # print('testing bach %d' % id) 206 | inputs = Variable(data).cuda() 207 | depth = Variable(depth).cuda() 208 | bins = Variable(bins).cuda() 209 | n, c, h, w = inputs.size() 210 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1) 211 | torch.cuda.synchronize() 212 | start = time.time() 213 | with torch.no_grad(): 214 | h1, h2, h3, h4, h5 = model_rgb(inputs, bins, gumbel=False) 215 | d0, d1, d2, d3, d4 = model_depth(depth) 216 | predict_mask = model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4) 217 | torch.cuda.synchronize() 218 | end = time.time() 219 | 220 | res.append(end - start) 221 | outputs_all = F.softmax(predict_mask, dim=1) 222 | outputs = outputs_all[0][1] 223 | outputs = outputs.cpu().data.resize_(h, w) 224 | 225 | imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) 226 | time_sum = 0 227 | for i in res: 228 | time_sum += i 229 | print("FPS: %f" % (1.0 / (time_sum / len(res)))) 230 | # -------------------------- validation --------------------------- # 231 | torch.cuda.empty_cache() 232 | print('the testing process has finished!') 233 | F_measure, mae = get_FM(salpath=MapRoot+'/', gtpath=test_dataRoot+'/test_masks/') 234 | print(args.test_dataset + ' F_measure:', F_measure) 235 | print(args.test_dataset + ' MAE:', mae) 236 | 237 | F_key = args.test_dataset +'_Fb' 238 | M_key = args.test_dataset +'_mae' 239 | ckpt_key = args.test_dataset +'_ckpt' 240 | if F_key in max_F_dict.keys(): 241 | if F_measure > max_F_dict[F_key]: 242 | max_F_dict[F_key] = F_measure 243 | max_F_dict[M_key] = mae 244 | max_F_dict[ckpt_key] = ckpt 245 | else: 246 | max_F_dict[F_key] = F_measure 247 | max_F_dict[M_key] = mae 248 | max_F_dict[ckpt_key] = ckpt 249 | 250 | if M_key in min_mae_dict.keys(): 251 | if mae < min_mae_dict[M_key]: 252 | min_mae_dict[F_key] = F_measure 253 | min_mae_dict[M_key] = mae 254 | min_mae_dict[ckpt_key] = ckpt 255 | else: 256 | min_mae_dict[F_key] = F_measure 257 | min_mae_dict[M_key] = mae 258 | min_mae_dict[ckpt_key] = ckpt 259 | 260 | if args.phase == 'test': 261 | print ("max_F_dict") 262 | print (max_F_dict) 263 | print ("min_mae_dict") 264 | print (min_mae_dict) 265 | 266 | 
print("finish!!!!!!!!") 267 | 268 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('models') 3 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/dsam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/dsam.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/dsam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/dsam.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/genotypes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/genotypes.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/genotypes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/genotypes.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/gsmodule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/gsmodule.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/gsmodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/gsmodule.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_depth.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_depth.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_depth.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_depth.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_fusion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_fusion.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_fusion_raw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion_raw.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_fusion_raw.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion_raw.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_rgb.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_rgb.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_rgb.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_rgb.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/operations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/operations.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/operations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/operations.cpython-37.pyc -------------------------------------------------------------------------------- /models/dsam.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: Wenhu Zhang 3 | Date: 2021-06-07 17:39:22 4 | LastEditTime: 2021-06-07 18:08:28 5 | LastEditors: Wenhu Zhang 6 | Description: 7 | FilePath: /github/wh/DSA2F/models/dsam.py 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | import cv2 13 | from gsmodule import GumbelSoftmax2D 14 | 15 | 16 | class 
ChannelAttentionLayer(nn.Module): 17 | def __init__(self, C_in, C_out, reduction=16, affine=True, BN=nn.BatchNorm2d): 18 | super(ChannelAttentionLayer, self).__init__() 19 | # global average pooling: feature --> point 20 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 21 | # feature channel downscale and upscale --> channel weight 22 | self.conv_du = nn.Sequential( 23 | nn.Conv2d(C_in, max(1, C_in // reduction), 1, padding=0, bias=False), 24 | nn.ReLU(), 25 | nn.Conv2d(max(1, C_in // reduction) , C_out, 1, padding=0, bias=False), 26 | nn.Sigmoid()) 27 | def forward(self, x): 28 | y = self.avg_pool(x) 29 | y = self.conv_du(y) 30 | return x * y 31 | 32 | # DSAM V2 33 | class Adaptive_DSAM(nn.Module): 34 | def __init__(self,channel): 35 | super(Adaptive_DSAM, self).__init__() 36 | self.depth_revise = nn.Sequential( 37 | nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),nn.ReLU(inplace=True), 38 | ) 39 | self.fc = nn.Conv2d(32, 3, 1) 40 | self.GS = GumbelSoftmax2D(hard = True) 41 | 42 | self.channel=channel 43 | self.conv0=nn.Conv2d(channel,channel,1,padding=0) 44 | self.conv1=nn.Conv2d(channel,channel,1,padding=0) 45 | self.conv2=nn.Conv2d(channel,channel,1,padding=0) 46 | self.channel_att = ChannelAttentionLayer(self.channel, self.channel) 47 | def forward(self,x, bins, gumbel=False): 48 | n,c,h,w=x.size() 49 | 50 | bins = self.depth_revise(bins) 51 | gate = self.fc(bins) 52 | bins = self.GS(gate, gumbel=gumbel) * torch.mean(bins, dim=1,keepdim=True) 53 | 54 | x0=self.conv0(bins[:,0,:,:].unsqueeze(1) * x) 55 | x1=self.conv1(bins[:,1,:,:].unsqueeze(1) * x) 56 | x2=self.conv2(bins[:,2,:,:].unsqueeze(1) * x) 57 | x = (x0+x1+x2)+ x 58 | x = self.channel_att(x) 59 | return x 60 | 61 | 62 | # DSAM 63 | class DSAM(nn.Module): 64 | def __init__(self,channel): 65 | super(DSAM, self).__init__() 66 | self.channel=channel 67 | self.conv0=nn.Conv2d(channel,channel,1,padding=0) 68 | self.conv1=nn.Conv2d(channel,channel,1,padding=0) 69 | self.conv2=nn.Conv2d(channel,channel,1,padding=0) 70 | def forward(self,x, bins, gumbel=False): 71 | n,c,h,w=x.size() 72 | 73 | x0=self.conv0(bins[:,0,:,:].unsqueeze(1) * x) 74 | x1=self.conv1(bins[:,1,:,:].unsqueeze(1) * x) 75 | x2=self.conv2(bins[:,2,:,:].unsqueeze(1) * x) 76 | x = (x0+x1+x2)+ x 77 | return x 78 | -------------------------------------------------------------------------------- /models/genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | Genotype_all = namedtuple('Genotype', 'fusion1 fusion1_concat fusion2 fusion2_concat fusion3 fusion3_concat aggregation aggregation_concat final_agg final_aggregation_concat low_high_agg low_high_agg_concat') 5 | """ 6 | Operation sets 7 | """ 8 | PRIMITIVES = [ 9 | 'none', 10 | 'max_pool_3x3', 11 | 'skip_connect', 12 | 'sep_conv_3x3', 13 | 'dil_conv_3x3', 14 | 'conv_1x1', 15 | 'conv_3x3', 16 | 'spatial_attention', 17 | 'channel_attention' 18 | ] 19 | 20 | 21 | attention_snas_3_4_1 = Genotype_all(fusion1=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 
3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion1_concat=range(6, 12), fusion2=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion2_concat=range(6, 12), fusion3=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion3_concat=range(6, 12), aggregation=[('spatial_attention', 1), ('max_pool_3x3', 2), ('sep_conv_3x3', 0), ('spatial_attention', 1), ('dil_conv_3x3', 2), ('max_pool_3x3', 3), ('spatial_attention', 1), ('conv_3x3', 4), ('conv_1x1', 2), ('conv_3x3', 0), ('sep_conv_3x3', 5), ('dil_conv_3x3', 3), ('conv_3x3', 1), ('conv_1x1', 5), ('dil_conv_3x3', 0), ('channel_attention', 3), ('spatial_attention', 4), ('max_pool_3x3', 1), ('max_pool_3x3', 1), ('skip_connect', 3), ('conv_3x3', 4), ('channel_attention', 3), ('skip_connect', 1), ('sep_conv_3x3', 6)], aggregation_concat=range(5, 11), final_agg=[('conv_1x1', 1), ('conv_1x1', 0), ('max_pool_3x3', 2), ('conv_1x1', 3), ('channel_attention', 2), ('dil_conv_3x3', 1), ('dil_conv_3x3', 3), ('conv_1x1', 0), ('spatial_attention', 2), ('sep_conv_3x3', 4), ('conv_1x1', 5), ('conv_3x3', 0), ('dil_conv_3x3', 4), ('conv_1x1', 2), ('conv_1x1', 1), ('dil_conv_3x3', 6), ('conv_1x1', 4), ('skip_connect', 2), ('conv_1x1', 7), ('max_pool_3x3', 6), ('conv_3x3', 5), ('channel_attention', 7), ('max_pool_3x3', 2), ('conv_3x3', 4), ('spatial_attention', 7), ('max_pool_3x3', 3), ('sep_conv_3x3', 0), ('spatial_attention', 8), ('max_pool_3x3', 2), ('conv_1x1', 4), ('sep_conv_3x3', 5), ('conv_1x1', 3)], final_aggregation_concat=range(6, 12), low_high_agg=[('max_pool_3x3', 2), ('spatial_attention', 1), ('conv_3x3', 0), ('channel_attention', 1), ('max_pool_3x3', 2), ('conv_1x1', 3), ('max_pool_3x3', 3), ('channel_attention', 1), ('conv_3x3', 4), ('max_pool_3x3', 3), ('skip_connect', 1), ('conv_1x1', 2)], low_high_agg_concat=range(3, 7)) 22 | -------------------------------------------------------------------------------- /models/gsmodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | 5 | """ 6 | Gumbel Softmax Sampler 7 | Requires 2D input [batchsize, number of categories] 8 | 9 | Does not support sinlge binary category. 
Use two dimensions with softmax instead. 10 | """ 11 | 12 | class GumbelSoftmax2D(torch.nn.Module): 13 | def __init__(self, hard=False): 14 | super(GumbelSoftmax2D, self).__init__() 15 | self.hard = hard 16 | self.gpu = False 17 | 18 | def cuda(self): 19 | self.gpu = True 20 | 21 | def cpu(self): 22 | self.gpu = False 23 | 24 | def sample_gumbel(self, shape, eps=1e-10): 25 | """Sample from Gumbel(0, 1)""" 26 | noise = torch.rand(shape) 27 | noise.add_(eps).log_().neg_() 28 | noise.add_(eps).log_().neg_() 29 | if self.gpu: 30 | return Variable(noise).cuda() 31 | else: 32 | return Variable(noise) 33 | 34 | def sample_gumbel_like(self, template_tensor, eps=1e-10): 35 | uniform_samples_tensor = template_tensor.clone().uniform_() 36 | gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps)) 37 | return gumble_samples_tensor 38 | 39 | def gumbel_softmax_sample(self, logits, temperature): 40 | """ Draw a sample from the Gumbel-Softmax distribution""" 41 | dim = logits.size(-1) 42 | gumble_samples_tensor = self.sample_gumbel_like(logits.data) 43 | gumble_trick_log_prob_samples = logits + Variable(gumble_samples_tensor) 44 | soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, 1) 45 | return soft_samples 46 | 47 | def gumbel_softmax(self, logits, temperature, hard=False, gumbel=False): 48 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 49 | Args: 50 | logits: [batch_size, n_class] unnormalized log-probs 51 | temperature: non-negative scalar 52 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 53 | Returns: 54 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 55 | If hard=True, then the returned sample will be one-hot, otherwise it will 56 | be a probabilitiy distribution that sums to 1 across classes 57 | """ 58 | if gumbel: 59 | y = self.gumbel_softmax_sample(logits, temperature) 60 | else: 61 | y = F.softmax(logits,1) 62 | if hard: 63 | _, max_value_indexes = y.data.max(1, keepdim=True) 64 | y_hard = logits.data.clone().zero_().scatter_(1, max_value_indexes, 1) 65 | y = Variable(y_hard - y.data) + y 66 | return y 67 | 68 | def forward(self, logits, gumbel=False, temp=1): 69 | b,c,h,w= logits.size() 70 | logits = logits.permute(0,2,3,1).contiguous().view(-1,c) 71 | logits = self.gumbel_softmax(logits, temperature=1, hard=self.hard, gumbel=gumbel) 72 | 73 | return logits.view(b,h,w,c).permute(0,3,1,2).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /models/model_depth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | import torch.nn as nn 7 | 8 | BN_MOMENTUM = 0.1 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | class BasicBlock(nn.Module): 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 25 | self.downsample = downsample 26 | self.stride = 
stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | out += residual 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | class DepthNet(nn.Module): 44 | 45 | def __init__(self): 46 | super(DepthNet, self).__init__() 47 | # conv1 48 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 49 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 50 | self.relu1_1 = nn.ReLU(inplace=True) 51 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 52 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 53 | self.relu1_2 = nn.ReLU(inplace=True) 54 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 55 | 56 | # conv2 57 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 58 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 59 | self.relu2_1 = nn.ReLU(inplace=True) 60 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 61 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 62 | self.relu2_2 = nn.ReLU(inplace=True) 63 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 64 | num_stages = 3 65 | blocks = BasicBlock 66 | num_blocks = [4, 4, 4] 67 | num_channels = [32, 32, 128] 68 | self.stage = self._make_stages(num_stages, blocks, num_blocks, num_channels) 69 | self.transition1 = nn.Sequential( 70 | nn.Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), 71 | nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 72 | nn.ReLU(inplace=True) 73 | ) 74 | self.transition2 = nn.Sequential( 75 | nn.Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), 76 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), 77 | nn.ReLU(inplace=True) 78 | ) 79 | 80 | def _make_one_stage(self, stage_index, block, num_blocks, num_channels): 81 | layers = [] 82 | for i in range(0, num_blocks[stage_index]): 83 | layers.append( 84 | block( 85 | num_channels[stage_index], 86 | num_channels[stage_index] 87 | ) 88 | ) 89 | return nn.Sequential(*layers) 90 | 91 | def _make_stages(self, num_stages, block, num_blocks, num_channels): 92 | branches = [] 93 | 94 | for i in range(num_stages): 95 | branches.append( 96 | self._make_one_stage(i, block, num_blocks, num_channels) 97 | ) 98 | return nn.ModuleList(branches) 99 | 100 | def forward(self, d): 101 | #depth branch 102 | d = self.relu1_1(self.bn1_1(self.conv1_1(d))) 103 | d = self.relu1_2(self.bn1_2(self.conv1_2(d))) 104 | d0 = self.pool1(d) # (128x128)*64 105 | 106 | d = self.relu2_1(self.bn2_1(self.conv2_1(d0))) 107 | d = self.relu2_2(self.bn2_2(self.conv2_2(d))) 108 | d1 = self.pool2(d) # (64x64)*128 109 | dt2 = self.transition1(d1) 110 | d2 = self.stage[0](dt2) 111 | d3 = self.stage[1](d2) 112 | dt4 = self.transition2(d3) 113 | d4 = self.stage[2](dt4) 114 | return d0, d1, d2, d3, d4 115 | 116 | def init_weights(self): 117 | logger.info('=> Depth model init weights from normal distribution') 118 | for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 121 | nn.init.normal_(m.weight, std=0.001) 122 | for name, _ in m.named_parameters(): 123 | if name in ['bias']: 124 | nn.init.constant_(m.bias, 0) 125 | elif isinstance(m, nn.BatchNorm2d): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | elif 
isinstance(m, nn.ConvTranspose2d): 129 | nn.init.normal_(m.weight, std=0.001) 130 | for name, _ in m.named_parameters(): 131 | if name in ['bias']: 132 | nn.init.constant_(m.bias, 0) -------------------------------------------------------------------------------- /models/model_fusion.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | from operations import * 8 | import genotypes 9 | from genotypes import attention_snas_3_4_1 10 | from genotypes import PRIMITIVES 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class MixedOp(nn.Module): 18 | def __init__(self, C, stride): 19 | super(MixedOp, self).__init__() 20 | self._ops = nn.ModuleList() 21 | for primitive in PRIMITIVES: 22 | op = OPS[primitive](C, stride, False) 23 | if 'pool' in primitive: 24 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False)) 25 | self._ops.append(op) 26 | 27 | def forward(self, x, weights): 28 | return sum(w * op(x) for w, op in zip(weights, self._ops)) 29 | 30 | # Take four inputs 31 | class FusionCell(nn.Module): 32 | def __init__(self, genotype, index, steps, multiplier, parse_method): 33 | super(FusionCell, self).__init__() 34 | 35 | self.index = index 36 | if self.index == 0: 37 | op_names, indices = zip(*genotype.fusion1) 38 | concat = genotype.fusion1_concat 39 | C = 128 #128 // 2 # Fusion Scale 64x64 40 | # two rgb feats (64x64 128c, 32x32s 256c) 41 | # two depth feats (64x64 128c, 64x64 32c) 42 | self.preprocess0_rgb = nn.Sequential( 43 | nn.Conv2d(128, C, kernel_size=1, bias=False), 44 | nn.BatchNorm2d(C, affine=True)) 45 | self.preprocess1_rgb = nn.Sequential( 46 | nn.Conv2d(256, C, kernel_size=1, bias=False), 47 | nn.BatchNorm2d(C, affine=True), 48 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 49 | self.preprocess0_depth = nn.Sequential( 50 | nn.Conv2d(128, C, kernel_size=1, bias=False), 51 | nn.BatchNorm2d(C, affine=True)) 52 | self.preprocess1_depth = nn.Sequential( 53 | nn.Conv2d(32, C, kernel_size=1, bias=False), 54 | nn.BatchNorm2d(C, affine=True)) 55 | elif self.index == 1: 56 | op_names, indices = zip(*genotype.fusion2) 57 | concat = genotype.fusion2_concat 58 | C = 128 #128 // 2 # Fusion Scale 64x64 59 | # two rgb feats (32x32 256c, 16x16 512c) 60 | # two depth feats (64x64 32c, 64x64 32c) 61 | self.preprocess0_rgb = nn.Sequential( 62 | nn.Conv2d(256, C, kernel_size=1, bias=False), 63 | nn.BatchNorm2d(C, affine=True), 64 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 65 | self.preprocess1_rgb = nn.Sequential( 66 | nn.Conv2d(512, C, kernel_size=1, bias=False), 67 | nn.BatchNorm2d(C, affine=True), 68 | nn.ReLU(inplace=True), 69 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 70 | nn.Conv2d(C, C, kernel_size=1, bias=False), 71 | nn.BatchNorm2d(C, affine=True), 72 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 73 | self.preprocess0_depth = nn.Sequential( 74 | nn.Conv2d(32, C, kernel_size=1, bias=False), 75 | nn.BatchNorm2d(C, affine=True)) 76 | self.preprocess1_depth = nn.Sequential( 77 | nn.Conv2d(32, C, kernel_size=1, bias=False), 78 | nn.BatchNorm2d(C, affine=True)) 79 | else: 80 | op_names, indices = zip(*genotype.fusion3) 81 | concat = genotype.fusion3_concat 82 | C = 128 #256 // 2 # Fusion Scale 32x32 83 | # two rgb feats 
(16x16 512c, 8x8 512c) 84 | # two depth feats (64x64 32c, 64x64 128c) 85 | self.preprocess0_rgb = nn.Sequential( 86 | nn.Conv2d(512, C, kernel_size=1, bias=False), 87 | nn.BatchNorm2d(C, affine=True), 88 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 89 | self.preprocess1_rgb = nn.Sequential( 90 | nn.Conv2d(512, C, kernel_size=1, bias=False), 91 | nn.BatchNorm2d(C, affine=True), 92 | nn.ReLU(inplace=True), 93 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 94 | nn.Conv2d(C, C, kernel_size=1, bias=False), 95 | nn.BatchNorm2d(C, affine=True), 96 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 97 | self.preprocess0_depth = nn.Sequential( 98 | nn.Conv2d(32, C, kernel_size=3, stride=2, padding=1, bias=False), 99 | nn.BatchNorm2d(C, affine=True)) 100 | self.preprocess1_depth = nn.Sequential( 101 | nn.Conv2d(128, C, kernel_size=3, stride=2, padding=1, bias=False), 102 | nn.BatchNorm2d(C, affine=True)) 103 | 104 | self._steps = steps 105 | self._multiplier = multiplier 106 | self._compile(C, op_names, indices, concat) 107 | 108 | def _compile(self, C, op_names, indices, concat): 109 | assert len(op_names) == len(indices) 110 | self._concat = concat 111 | self.multiplier = len(concat) 112 | 113 | self._ops = nn.ModuleList() 114 | for name, index in zip(op_names, indices): 115 | stride = 1 116 | op = OPS[name](C, stride, True) 117 | self._ops += [op] 118 | self._indices = indices 119 | 120 | def forward(self, s0, s1, s2, s3, drop_prob): 121 | 122 | # print("s_input:",s0.shape, s1.shape, s2.shape, s3.shape) 123 | s0 = self.preprocess0_rgb(s0) 124 | s1 = self.preprocess1_rgb(s1) 125 | s2 = self.preprocess0_depth(s2) 126 | s3 = self.preprocess1_depth(s3) 127 | 128 | # print("s_prepoce:",s0.shape, s1.shape, s2.shape, s3.shape) 129 | states = [s0, s1, s2, s3] 130 | for i in range(self._steps): 131 | h1 = states[self._indices[4*i]] 132 | h2 = states[self._indices[4*i+1]] 133 | h3 = states[self._indices[4*i+2]] 134 | h4 = states[self._indices[4*i+3]] 135 | op1 = self._ops[4*i] 136 | op2 = self._ops[4*i+1] 137 | op3 = self._ops[4*i+2] 138 | op4 = self._ops[4*i+3] 139 | h1 = op1(h1) 140 | h2 = op2(h2) 141 | h3 = op3(h3) 142 | h4 = op4(h4) 143 | if self.training and drop_prob > 0.: 144 | if not isinstance(op1, Identity): 145 | h1 = drop_path(h1, drop_prob) 146 | if not isinstance(op2, Identity): 147 | h2 = drop_path(h2, drop_prob) 148 | if not isinstance(op3, Identity): 149 | h3 = drop_path(h3, drop_prob) 150 | if not isinstance(op4, Identity): 151 | h4 = drop_path(h4, drop_prob) 152 | # print("h:",h1.shape, h2.shape, h3.shape, h4.shape) 153 | s = h1 + h2 + h3 + h4 154 | states += [s] 155 | 156 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W 157 | 158 | # Take three inputs 159 | class AggregationCell(nn.Module): 160 | def __init__(self, genotype, steps, multiplier, parse_method): 161 | super(AggregationCell, self).__init__() 162 | C = 128 163 | self.preprocess0 = None 164 | self.preprocess1 = None 165 | self.preprocess2 = None 166 | 167 | op_names, indices = zip(*genotype.aggregation) 168 | concat = genotype.aggregation_concat 169 | self._steps = steps 170 | self._multiplier = multiplier 171 | self._compile(C, op_names, indices, concat) 172 | 173 | def _compile(self, C, op_names, indices, concat): 174 | assert len(op_names) == len(indices) 175 | self._concat = concat 176 | self.multiplier = len(concat) 177 | self._ops = nn.ModuleList() 178 | for name, index in zip(op_names, indices): 179 | stride = 1 180 | op = OPS[name](C, 
stride, True) 181 | self._ops += [op] 182 | self._indices = indices 183 | 184 | def forward(self, s0, s1, s2, drop_prob): 185 | # print("000:",s0.shape, s1.shape, s2.shape) 186 | s0 = self.preprocess0(s0) 187 | s1 = self.preprocess1(s1) 188 | s2 = self.preprocess2(s2) 189 | # print("111:",s0.shape, s1.shape, s2.shape) 190 | 191 | states = [s0, s1, s2] 192 | for i in range(self._steps): 193 | h1 = states[self._indices[3*i]] 194 | h2 = states[self._indices[3*i+1]] 195 | h3 = states[self._indices[3*i+2]] 196 | op1 = self._ops[3*i] 197 | op2 = self._ops[3*i+1] 198 | op3 = self._ops[3*i+2] 199 | h1 = op1(h1) 200 | h2 = op2(h2) 201 | h3 = op3(h3) 202 | if self.training and drop_prob > 0.: 203 | if not isinstance(op1, Identity): 204 | h1 = drop_path(h1, drop_prob) 205 | if not isinstance(op2, Identity): 206 | h2 = drop_path(h2, drop_prob) 207 | if not isinstance(op3, Identity): 208 | h3 = drop_path(h3, drop_prob) 209 | s = h1 + h2 + h3 210 | states += [s] 211 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W 212 | 213 | class AggregationCell_1(AggregationCell): 214 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [768,768,768]): 215 | super().__init__(genotype, steps, multiplier, parse_method) 216 | C = 128 217 | self.preprocess0 = nn.Sequential( 218 | nn.ReLU(inplace=True), 219 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 220 | nn.BatchNorm2d(C, affine=True) 221 | ) 222 | self.preprocess1 = nn.Sequential( 223 | nn.ReLU(inplace=True), 224 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 225 | nn.BatchNorm2d(C, affine=True), 226 | ) 227 | self.preprocess2 = nn.Sequential( 228 | nn.ReLU(inplace=True), 229 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 230 | nn.BatchNorm2d(C, affine=True), 231 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)) 232 | 233 | class AggregationCell_2(AggregationCell): 234 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [512,128,768]): 235 | super().__init__(genotype, steps, multiplier, parse_method) 236 | C = 128 237 | self.preprocess0 = nn.Sequential( 238 | nn.ReLU(inplace=False), 239 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 240 | nn.BatchNorm2d(C, affine=True), 241 | nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 242 | ) 243 | self.preprocess1 = nn.Sequential( 244 | nn.ReLU(inplace=False), 245 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False), 246 | nn.BatchNorm2d(C, affine=True), 247 | # nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 248 | ) 249 | self.preprocess2 = nn.Sequential( 250 | nn.ReLU(inplace=False), 251 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False), 252 | nn.BatchNorm2d(C, affine=True)) 253 | 254 | class AggregationCell_3(AggregationCell): 255 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [256,32,768]): 256 | super().__init__(genotype, steps, multiplier, parse_method) 257 | C = 128 258 | self.preprocess0 = nn.Sequential( 259 | nn.ReLU(inplace=False), 260 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 261 | nn.BatchNorm2d(C, affine=True), 262 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 263 | ) 264 | self.preprocess1 = nn.Sequential( 265 | nn.ReLU(inplace=False), 266 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False), 267 | nn.BatchNorm2d(C, affine=True), 268 | ) 269 | self.preprocess2 = nn.Sequential( 270 | nn.ReLU(inplace=False), 271 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False), 272 | nn.BatchNorm2d(C, affine=True), 273 | nn.Upsample(scale_factor=2, 
mode='bilinear', align_corners=True)) 274 | 275 | class AggregationCell_4(AggregationCell): 276 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [512,32,768]): 277 | super().__init__(genotype, steps, multiplier, parse_method) 278 | C = 128 279 | self.preprocess0 = nn.Sequential( 280 | nn.ReLU(inplace=False), 281 | nn.Conv2d(C_in[0], C*2, kernel_size=1, bias=False), 282 | nn.BatchNorm2d(C*2, affine=True), 283 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 284 | nn.ReLU(inplace=False), 285 | nn.Conv2d(C*2, C, kernel_size=1, bias=False), 286 | nn.BatchNorm2d(C, affine=True), 287 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 288 | ) 289 | self.preprocess1 = nn.Sequential( 290 | nn.ReLU(inplace=False), 291 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False), 292 | nn.BatchNorm2d(C, affine=True), 293 | ) 294 | self.preprocess2 = nn.Sequential( 295 | nn.ReLU(inplace=False), 296 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False), 297 | nn.BatchNorm2d(C, affine=True)) 298 | 299 | # Take foul inputs 300 | class GlobalAggregationCell(nn.Module): 301 | def __init__(self, genotype, steps, multiplier, parse_method): 302 | super(GlobalAggregationCell, self).__init__() 303 | C = 256 304 | self.preprocess0 = nn.Sequential( 305 | nn.ReLU(inplace=True), 306 | nn.Conv2d(768, C, kernel_size=1, bias=False), 307 | nn.BatchNorm2d(C, affine=True)) 308 | self.preprocess1 = nn.Sequential( 309 | nn.ReLU(inplace=True), 310 | nn.Conv2d(768, C, kernel_size=1, bias=False), 311 | nn.BatchNorm2d(C, affine=True)) 312 | self.preprocess2 = nn.Sequential( 313 | nn.ReLU(inplace=True), 314 | nn.Conv2d(768, C, kernel_size=1, bias=False), 315 | nn.BatchNorm2d(C, affine=True)) 316 | self.preprocess3 = nn.Sequential( 317 | nn.ReLU(inplace=True), 318 | nn.Conv2d(768, C, kernel_size=1, bias=False), 319 | nn.BatchNorm2d(C, affine=True)) 320 | 321 | op_names, indices = zip(*genotype.final_agg) 322 | concat = genotype.final_aggregation_concat 323 | self._steps = steps 324 | self._multiplier = multiplier 325 | self._compile(C, op_names, indices, concat) 326 | 327 | def _compile(self, C, op_names, indices, concat): 328 | assert len(op_names) == len(indices) 329 | self._concat = concat 330 | self.multiplier = len(concat) 331 | self._ops = nn.ModuleList() 332 | for name, index in zip(op_names, indices): 333 | stride = 1 334 | op = OPS[name](C, stride, True) 335 | self._ops += [op] 336 | self._indices = indices 337 | 338 | def forward(self, s0, s1, s2, s3, drop_prob): 339 | s0 = self.preprocess0(s0) 340 | s1 = self.preprocess1(s1) 341 | s2 = self.preprocess2(s2) 342 | s3 = self.preprocess3(s3) 343 | 344 | states = [s0, s1, s2, s3] 345 | for i in range(self._steps): 346 | h1 = states[self._indices[4*i]] 347 | h2 = states[self._indices[4*i+1]] 348 | h3 = states[self._indices[4*i+2]] 349 | h4 = states[self._indices[4*i+3]] 350 | op1 = self._ops[4*i] 351 | op2 = self._ops[4*i+1] 352 | op3 = self._ops[4*i+2] 353 | op4 = self._ops[4*i+3] 354 | h1 = op1(h1) 355 | h2 = op2(h2) 356 | h3 = op3(h3) 357 | h4 = op4(h4) 358 | if self.training and drop_prob > 0.: 359 | if not isinstance(op1, Identity): 360 | h1 = drop_path(h1, drop_prob) 361 | if not isinstance(op2, Identity): 362 | h2 = drop_path(h2, drop_prob) 363 | if not isinstance(op3, Identity): 364 | h3 = drop_path(h3, drop_prob) 365 | if not isinstance(op4, Identity): 366 | h4 = drop_path(h4, drop_prob) 367 | s = h1 + h2 + h3 + h4 368 | states += [s] 369 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W 370 | 371 | 
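# Note on the cell dataflow defined in this file (the concrete channel counts
# below are read off the code above; treat them as an illustration of the
# mechanism rather than extra configuration): every genotype-driven cell seeds
# a `states` list with its preprocessed inputs; at step i it uses
# `self._indices` to select earlier states, applies the matching ops from
# `self._ops` (each built as OPS[name](C, stride, affine)), sums the results
# into one new state and appends it. The cell output is torch.cat over the
# states named in the genotype's `*_concat`, so it carries len(concat) * C
# channels: the fusion/aggregation cells use C = 128 with 6 concatenated
# states (768 channels), while GlobalAggregationCell uses C = 256 with 6
# states (1536 channels, matching final_layer0's 1536-channel input below).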
# Take three inputs 372 | class Low_High_aggregation(AggregationCell): 373 | 374 | def __init__(self, genotype, steps, multiplier, parse_method, C_in=[64,64,128]): 375 | super().__init__(genotype, steps, multiplier, parse_method) 376 | C = 32 377 | self.preprocess0 = nn.Sequential( 378 | nn.ReLU(inplace=False), 379 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False), 380 | nn.BatchNorm2d(C, affine=True) 381 | ) 382 | self.preprocess1 = nn.Sequential( 383 | nn.ReLU(inplace=False), 384 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False), 385 | nn.BatchNorm2d(C, affine=True), 386 | ) 387 | self.preprocess2 = nn.Sequential( 388 | nn.ReLU(inplace=False), 389 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False), 390 | nn.BatchNorm2d(C, affine=True)) 391 | op_names, indices = zip(*genotype.low_high_agg) 392 | concat = genotype.low_high_agg_concat 393 | self._compile(C, op_names, indices, concat) 394 | 395 | 396 | 397 | class NasFusionNet(nn.Module): 398 | def __init__(self, fusion_cell_number=3, steps=8, multiplier=6, agg_steps=8, agg_multiplier=6, genotype 399 | =attention_snas_3_4_1, parse_method='darts', op_threshold=0.85, drop_path_prob=0): 400 | self.inplanes = 64 401 | super(NasFusionNet, self).__init__() 402 | self.drop_path_prob = 0 403 | self._multiplier = 6 404 | self.parse_method = parse_method 405 | self.op_threshold = op_threshold 406 | self._steps = steps 407 | # init the fusion cells 408 | self.MM_cells = nn.ModuleList() 409 | 410 | for i in range(fusion_cell_number): 411 | cell = FusionCell(genotype, i, steps, multiplier, parse_method) 412 | self.MM_cells += [cell] 413 | 414 | self.MS_cell_1 = AggregationCell_1(genotype, agg_steps, agg_multiplier, parse_method) 415 | self.MS_cell_2 = AggregationCell_2(genotype, agg_steps, agg_multiplier, parse_method) 416 | self.MS_cell_3 = AggregationCell_3(genotype, agg_steps, agg_multiplier, parse_method) 417 | self.MS_cell_4 = AggregationCell_4(genotype, agg_steps, agg_multiplier, parse_method) 418 | 419 | self.GA_cell = GlobalAggregationCell(genotype, agg_steps, agg_multiplier, parse_method) 420 | self.SR_cell_1 = Low_High_aggregation(genotype, 4, 4, parse_method, C_in = [128,128,256]) 421 | self.SR_cell_2 = Low_High_aggregation(genotype, 4, 4, parse_method, C_in = [64, 64, 128]) 422 | 423 | self.final_layer0 = nn.Sequential( 424 | nn.Conv2d(1536, 512, kernel_size=1), nn.BatchNorm2d(512, affine=True), nn.ReLU(inplace=True), # 256 425 | nn.Conv2d(512, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True), 426 | nn.Conv2d(256, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True), 427 | ) 428 | 429 | self.final_layer1 = nn.Sequential( 430 | nn.ReLU(inplace=True), 431 | nn.Conv2d(256+128, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True), 432 | nn.Conv2d(256, 128, kernel_size=1), nn.BatchNorm2d(128, affine=True), nn.ReLU(inplace=True), 433 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128, affine=True), nn.ReLU(inplace=True) 434 | ) 435 | 436 | self.final_layer2 = nn.Sequential( 437 | nn.Conv2d(128+128, 64, kernel_size=1), nn.BatchNorm2d(64, affine=True), nn.ReLU(inplace=True), 438 | nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64, affine=True), nn.ReLU(inplace=True), 439 | # nn.Dropout2d(p=0.1), 440 | nn.Conv2d(64, 2, kernel_size=1) 441 | ) 442 | 443 | 444 | def forward(self, h1, h2, h3, h4, h5, d0, d1, d2, d3, d4): 445 | # print(h2.shape,d1.shape, h5.shape,d4.shape) 446 | 447 | output1 = self.MM_cells[0](h2, h3, d1, d2, self.drop_path_prob) 448 | output2 = 
self.MM_cells[1](h3, h4, d2, d3, self.drop_path_prob) 449 | output3 = self.MM_cells[2](h4, h5, d3, d4, self.drop_path_prob) 450 | 451 | agg_features1 = self.MS_cell_1(output1, output2, output3,self.drop_path_prob) 452 | agg_features2 = self.MS_cell_2(h5, d4, output2, self.drop_path_prob) 453 | agg_features3 = self.MS_cell_3(h3, d2, output3, self.drop_path_prob) 454 | agg_features4 = self.MS_cell_4(h4, d3, output1, self.drop_path_prob) 455 | 456 | agg_features = self.GA_cell(agg_features1, agg_features2, agg_features3, agg_features4, self.drop_path_prob) 457 | predict_mask = self.final_layer0(agg_features) # c=256 458 | 459 | low_high_combined1 = self.SR_cell_1(h2, d1, predict_mask, self.drop_path_prob) # c==128 460 | predict_mask = torch.cat([predict_mask, low_high_combined1], dim=1) # 256 + 128 461 | 462 | predict_mask = F.upsample(predict_mask, scale_factor=2, mode='bilinear', align_corners=True) 463 | predict_mask = self.final_layer1(predict_mask) # 128 464 | 465 | low_high_combined2 = self.SR_cell_2(h1, d0, predict_mask, self.drop_path_prob) # 128 466 | predict_mask = torch.cat([predict_mask, low_high_combined2], dim=1) 467 | predict_mask = F.upsample(predict_mask, scale_factor=2, mode='bilinear', align_corners=True) 468 | predict_mask = self.final_layer2(predict_mask) 469 | 470 | return F.sigmoid(predict_mask) 471 | 472 | def init_weights(self): 473 | logger.info('=> NAS Fusion model init weights from normal distribution') 474 | for m in self.modules(): 475 | if isinstance(m, nn.Conv2d): 476 | nn.init.normal_(m.weight, std=0.001) 477 | for name, _ in m.named_parameters(): 478 | if name in ['bias']: 479 | nn.init.constant_(m.bias, 0) 480 | elif isinstance(m, nn.BatchNorm2d): 481 | nn.init.constant_(m.weight, 1) 482 | nn.init.constant_(m.bias, 0) 483 | elif isinstance(m, nn.ConvTranspose2d): 484 | nn.init.normal_(m.weight, std=0.001) 485 | for name, _ in m.named_parameters(): 486 | if name in ['bias']: 487 | nn.init.constant_(m.bias, 0) 488 | 489 | 490 | 491 | 492 | class NasFusionNet_pre(nn.Module): 493 | def __init__(self, fusion_cell_number=3, steps=8, multiplier=6, agg_steps=8, agg_multiplier=6, genotype 494 | =attention_snas_3_4_1, parse_method='darts', op_threshold=0.85, drop_path_prob=0): 495 | self.inplanes = 64 496 | super(NasFusionNet_pre, self).__init__() 497 | self.drop_path_prob = 0 498 | self._multiplier = 6 499 | self.parse_method = parse_method 500 | self.op_threshold = op_threshold 501 | self._steps = steps 502 | # init the fusion cells 503 | self.MM_cells = nn.ModuleList() 504 | 505 | for i in range(fusion_cell_number): 506 | cell = FusionCell(genotype, i, steps, multiplier, parse_method) 507 | self.MM_cells += [cell] 508 | 509 | self.MS_cell_1 = AggregationCell_1(genotype, agg_steps, agg_multiplier, parse_method) 510 | self.MS_cell_2 = AggregationCell_2(genotype, agg_steps, agg_multiplier, parse_method) 511 | self.MS_cell_3 = AggregationCell_3(genotype, agg_steps, agg_multiplier, parse_method) 512 | self.MS_cell_4 = AggregationCell_4(genotype, agg_steps, agg_multiplier, parse_method) 513 | 514 | self.GA_cell = GlobalAggregationCell(genotype, agg_steps, agg_multiplier, parse_method) 515 | ######## for pretrain 516 | self.class_head = nn.Sequential( 517 | nn.AvgPool2d((56, 56))) 518 | self.classifier = nn.Linear(1536, 1000) 519 | 520 | 521 | def forward(self, h1, h2, h3, h4, h5, d0, d1, d2, d3, d4): 522 | 523 | output1 = self.MM_cells[0](h2, h3, d1, d2, self.drop_path_prob) 524 | output2 = self.MM_cells[1](h3, h4, d2, d3, self.drop_path_prob) 525 | output3 = 
self.MM_cells[2](h4, h5, d3, d4, self.drop_path_prob) 526 | 527 | agg_features1 = self.MS_cell_1(output1, output2, output3,self.drop_path_prob) 528 | agg_features2 = self.MS_cell_2(h5, d4, output2, self.drop_path_prob) 529 | agg_features3 = self.MS_cell_3(h3, d2, output3, self.drop_path_prob) 530 | agg_features4 = self.MS_cell_4(h4, d3, output1, self.drop_path_prob) 531 | 532 | agg_features = self.GA_cell(agg_features1, agg_features2, agg_features3, agg_features4, self.drop_path_prob) 533 | 534 | ######## for pretrain 535 | # print(agg_features.shape) 536 | class_feature = self.class_head(agg_features).view(agg_features.size(0), -1) 537 | logits = self.classifier(class_feature) 538 | return logits 539 | 540 | def init_weights(self): 541 | logger.info('=> init weights from normal distribution') 542 | for m in self.modules(): 543 | if isinstance(m, nn.Conv2d): 544 | nn.init.normal_(m.weight, std=0.001) 545 | for name, _ in m.named_parameters(): 546 | if name in ['bias']: 547 | nn.init.constant_(m.bias, 0) 548 | elif isinstance(m, nn.BatchNorm2d): 549 | nn.init.constant_(m.weight, 1) 550 | nn.init.constant_(m.bias, 0) 551 | elif isinstance(m, nn.ConvTranspose2d): 552 | nn.init.normal_(m.weight, std=0.001) 553 | for name, _ in m.named_parameters(): 554 | if name in ['bias']: 555 | nn.init.constant_(m.bias, 0) 556 | 557 | -------------------------------------------------------------------------------- /models/model_rgb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from dsam import Adaptive_DSAM as DepthAttention 5 | 6 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 7 | """Make a 2D bilinear kernel suitable for upsampling""" 8 | factor = (kernel_size + 1) // 2 9 | if kernel_size % 2 == 1: 10 | center = factor - 1 11 | else: 12 | center = factor - 0.5 13 | og = np.ogrid[:kernel_size, :kernel_size] 14 | filt = (1 - abs(og[0] - center) / factor) * \ 15 | (1 - abs(og[1] - center) / factor) 16 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 17 | dtype=np.float64) 18 | weight[range(in_channels), range(out_channels), :, :] = filt 19 | return torch.from_numpy(weight).float() 20 | 21 | #################################### Rgb Network ##################################### 22 | 23 | class RgbNet(nn.Module): 24 | def __init__(self): 25 | super(RgbNet, self).__init__() 26 | 27 | # original image's size = 256*256*3 28 | # conv1 29 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 30 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 31 | self.relu1_1 = nn.ReLU(inplace=True) 32 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 33 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 34 | self.relu1_2 = nn.ReLU(inplace=True) 35 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 36 | self.depth_att1 = DepthAttention(64) 37 | 38 | # conv2 39 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 40 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 41 | self.relu2_1 = nn.ReLU(inplace=True) 42 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 43 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 44 | self.relu2_2 = nn.ReLU(inplace=True) 45 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 46 | self.depth_att2 = DepthAttention(128) 47 | # conv3 48 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 49 | self.bn3_1 = nn.BatchNorm2d(256, 
eps=1e-05, momentum=0.1, affine=True) 50 | self.relu3_1 = nn.ReLU(inplace=True) 51 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 52 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 53 | self.relu3_2 = nn.ReLU(inplace=True) 54 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 55 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 56 | self.relu3_3 = nn.ReLU(inplace=True) 57 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1) 58 | self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 59 | self.relu3_4 = nn.ReLU(inplace=True) 60 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 61 | self.depth_att3 = DepthAttention(256) 62 | 63 | # conv4 64 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 65 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 66 | self.relu4_1 = nn.ReLU(inplace=True) 67 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 68 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 69 | self.relu4_2 = nn.ReLU(inplace=True) 70 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 71 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 72 | self.relu4_3 = nn.ReLU(inplace=True) 73 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1) 74 | self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 75 | self.relu4_4 = nn.ReLU(inplace=True) 76 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 77 | self.depth_att4 = DepthAttention(512) 78 | 79 | # conv5 80 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 81 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 82 | self.relu5_1 = nn.ReLU(inplace=True) 83 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 84 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 85 | self.relu5_2 = nn.ReLU(inplace=True) 86 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 87 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 88 | self.relu5_3 = nn.ReLU(inplace=True) 89 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1) 90 | self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 91 | self.relu5_4 = nn.ReLU(inplace=True) 92 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 4 layers 93 | self.depth_att5 = DepthAttention(512) 94 | 95 | self._initialize_weights() 96 | 97 | def _initialize_weights(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | # m.weight.data.zero_() 101 | nn.init.normal(m.weight.data, std=0.01) 102 | if m.bias is not None: 103 | m.bias.data.zero_() 104 | if isinstance(m, nn.ConvTranspose2d): 105 | assert m.kernel_size[0] == m.kernel_size[1] 106 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0]) 107 | m.weight.data.copy_(initial_weight) 108 | 109 | 110 | def forward(self, x, bins, gumbel): 111 | 112 | b=bins 113 | h = x 114 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 115 | h = self.conv1_2(h) 116 | h = self.depth_att1(h, b, gumbel=gumbel ) 117 | h = self.relu1_2(self.bn1_2(h)) 118 | h1 = self.pool1(h) # (128x128)*64 119 | b1 = self.pool1(b) 120 | 121 | h = self.relu2_1(self.bn2_1(self.conv2_1(h1))) 122 | h=self.conv2_2(h) 123 | h = self.depth_att2(h, b1, gumbel=gumbel) 124 | h = self.relu2_2(self.bn2_2(h)) 125 | h2 = self.pool2(h) # (64x64)*128 126 | b2 = self.pool2(b1) 127 | 128 | h = self.relu3_1(self.bn3_1(self.conv3_1(h2))) 129 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 130 | h = 
self.relu3_3(self.bn3_3(self.conv3_3(h))) 131 | h = self.conv3_4(h) 132 | h = self.depth_att3(h, b2, gumbel=gumbel) 133 | h = self.relu3_4(self.bn3_4(h)) 134 | h3 = self.pool3(h)# (32x32)*256 135 | b3 = self.pool3(b2) 136 | 137 | h = self.relu4_1(self.bn4_1(self.conv4_1(h3))) 138 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 139 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 140 | h = self.conv4_4(h) 141 | h = self.depth_att4(h,b3, gumbel=gumbel) 142 | h = self.relu4_4(self.bn4_4(h)) 143 | h4 = self.pool4(h)# (16x16)*512 144 | b4 = self.pool4(b3) 145 | 146 | 147 | h = self.relu5_1(self.bn5_1(self.conv5_1(h4))) 148 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 149 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 150 | h = self.conv5_4(h) 151 | h = self.depth_att5(h,b4, gumbel=gumbel) 152 | h = self.relu5_4(self.bn5_4(h)) 153 | h5 = self.pool5(h)#(8x8)*512 154 | 155 | return h1,h2,h3,h4,h5 156 | 157 | 158 | 159 | def copy_params_from_vgg19_bn(self, vgg19_bn): 160 | features = [ 161 | self.conv1_1, self.bn1_1, self.relu1_1, 162 | self.conv1_2, self.bn1_2, self.relu1_2, 163 | self.pool1, 164 | self.conv2_1, self.bn2_1, self.relu2_1, 165 | self.conv2_2, self.bn2_2, self.relu2_2, 166 | self.pool2, 167 | self.conv3_1, self.bn3_1, self.relu3_1, 168 | self.conv3_2, self.bn3_2, self.relu3_2, 169 | self.conv3_3, self.bn3_3, self.relu3_3, 170 | self.conv3_4, self.bn3_4, self.relu3_4, 171 | self.pool3, 172 | self.conv4_1, self.bn4_1, self.relu4_1, 173 | self.conv4_2, self.bn4_2, self.relu4_2, 174 | self.conv4_3, self.bn4_3, self.relu4_3, 175 | self.conv4_4, self.bn4_4, self.relu4_4, 176 | self.pool4, 177 | self.conv5_1, self.bn5_1, self.relu5_1, 178 | self.conv5_2, self.bn5_2, self.relu5_2, 179 | self.conv5_3, self.bn5_3, self.relu5_3, 180 | self.conv5_4, self.bn5_4, self.relu5_4, 181 | ] 182 | for l1, l2 in zip(vgg19_bn.features, features): 183 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): 184 | assert l1.weight.size() == l2.weight.size() 185 | assert l1.bias.size() == l2.bias.size() 186 | l2.weight.data = l1.weight.data 187 | l2.bias.data = l1.bias.data 188 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d): 189 | assert l1.weight.size() == l2.weight.size() 190 | assert l1.bias.size() == l2.bias.size() 191 | l2.weight.data = l1.weight.data 192 | l2.bias.data = l1.bias.data -------------------------------------------------------------------------------- /models/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'none' : lambda C, stride, affine: Zero(stride), 6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 9 | 'conv_3x3' : lambda C, stride, affine : ReLUConvBN(C, C, 3, stride, 1, affine=affine), 10 | 'conv_1x1' : lambda C, stride, affine : ReLUConvBN(C, C, 1, stride, 0, affine=affine), 11 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 12 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 13 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 14 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 15 | 'dil_conv_3x3_2dil' : lambda C, 
stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 16 | 'dil_conv_3x3_4dil' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 4, affine=affine), 17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 19 | nn.ReLU(inplace=False), 20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 22 | nn.BatchNorm2d(C, affine=affine) 23 | ), 24 | 'spatial_attention': lambda C, stride, affine : SpatialAttentionLayer(C, C, 8, stride, affine), 25 | 'channel_attention': lambda C, stride, affine : ChannelAttentionLayer(C, C, 8, stride, affine) 26 | } 27 | 28 | class Zero(nn.Module): 29 | 30 | def __init__(self, stride): 31 | super(Zero, self).__init__() 32 | self.stride = stride 33 | 34 | def forward(self, x): 35 | if self.stride == 1: 36 | return x.mul(0.) 37 | return x[:,:,::self.stride,::self.stride].mul(0.) 38 | 39 | 40 | 41 | class SpatialAttentionLayer(nn.Module): 42 | def __init__(self, C_in, C_out, reduction=16, stride=1, affine=True, BN=nn.BatchNorm2d): 43 | super(SpatialAttentionLayer, self).__init__() 44 | self.stride = stride 45 | if stride == 1: 46 | self.fc = nn.Sequential( 47 | nn.Conv2d(C_in, C_in // reduction, kernel_size=3, stride=1, padding=1, bias=False), 48 | BN(C_in // reduction, affine=affine), 49 | nn.ReLU(inplace=False), 50 | nn.Conv2d(C_in // reduction, 1,kernel_size=3, stride=1, padding=1, bias=False), 51 | nn.Sigmoid() 52 | ) 53 | else: 54 | self.fc = nn.Sequential( 55 | nn.Conv2d(C_in, C_in // reduction, kernel_size=3, stride=2, padding=1, bias=False), 56 | BN(C_in // reduction, affine=affine), 57 | nn.ReLU(inplace=False), 58 | nn.Conv2d(C_in // reduction, 1, kernel_size=3, stride=1, padding=1, bias=False), 59 | nn.Sigmoid() 60 | ) 61 | self.reduce_map = nn.Sequential( 62 | nn.ReLU(inplace=False), 63 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=2, padding=0, bias=False), 64 | BN(C_out, affine=affine) 65 | ) 66 | 67 | def forward(self, x): 68 | y = self.fc(x) 69 | if self.stride == 2: 70 | x = self.reduce_map(x) 71 | return x * y 72 | 73 | 74 | ## Channel Attention (CA) Layer 75 | class ChannelAttentionLayer(nn.Module): 76 | def __init__(self, C_in, C_out, reduction=16, stride=1, affine=True, BN=nn.BatchNorm2d): 77 | super(ChannelAttentionLayer, self).__init__() 78 | # global average pooling: feature --> point 79 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 80 | self.stride = stride 81 | # feature channel downscale and upscale --> channel weight 82 | if stride == 1: 83 | self.conv_du = nn.Sequential( 84 | nn.Conv2d(C_in, C_in // reduction, 1, padding=0, bias=False), 85 | nn.ReLU(inplace=False), 86 | nn.Conv2d(C_in // reduction, C_out, 1, padding=0, bias=False), 87 | nn.Sigmoid() 88 | ) 89 | else: 90 | self.conv_du = nn.Sequential( 91 | nn.Conv2d(C_in, C_in // reduction, kernel_size=1, stride=2, padding=0, bias=False), 92 | nn.ReLU(inplace=False), 93 | nn.Conv2d(C_in // reduction, C_out, 1, padding=0, bias=False), 94 | nn.Sigmoid() 95 | ) 96 | self.reduce_map = nn.Sequential( 97 | nn.ReLU(inplace=False), 98 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=2, padding=0, bias=False), 99 | BN(C_out, affine=affine) 100 | ) 101 | 102 | def forward(self, x): 103 | if self.stride == 2: 104 | x = self.reduce_map(x) 105 | y = self.avg_pool(x) 106 | y = self.conv_du(y) 107 | return x * y 108 | 109 | class ReLUConvBN(nn.Module): 110 | """ 111 | ReLu -> Conv2d -> BatchNorm2d 112 | 
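    Ops in this file are created through the OPS table above, e.g.
    OPS['conv_1x1'](C, stride, affine) -> ReLUConvBN(C, C, 1, stride, 0) and
    OPS['conv_3x3'](C, stride, affine) -> ReLUConvBN(C, C, 3, stride, 1),
    so the channel count is unchanged and the spatial size is preserved when
    stride == 1 (the fusion/aggregation cells always pass stride = 1).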
""" 113 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 114 | super(ReLUConvBN, self).__init__() 115 | self.op = nn.Sequential( 116 | nn.ReLU(inplace=False), 117 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 118 | nn.BatchNorm2d(C_out, affine=affine) 119 | ) 120 | 121 | def forward(self, x): 122 | return self.op(x) 123 | 124 | class DilConv(nn.Module): 125 | """ 126 | Dilation Convolution : ReLU -> DilConv -> Conv2d -> BatchNorm2d 127 | """ 128 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 129 | super(DilConv, self).__init__() 130 | self.op = nn.Sequential( 131 | nn.ReLU(inplace=False), 132 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 133 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 134 | nn.BatchNorm2d(C_out, affine=affine), 135 | ) 136 | 137 | def forward(self, x): 138 | return self.op(x) 139 | 140 | 141 | class SepConv(nn.Module): 142 | 143 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 144 | super(SepConv, self).__init__() 145 | self.op = nn.Sequential( 146 | nn.ReLU(inplace=False), 147 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 148 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 149 | nn.BatchNorm2d(C_in, affine=affine), 150 | nn.ReLU(inplace=False), 151 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 152 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 153 | nn.BatchNorm2d(C_out, affine=affine), 154 | ) 155 | 156 | def forward(self, x): 157 | return self.op(x) 158 | 159 | 160 | class Identity(nn.Module): 161 | 162 | def __init__(self): 163 | super(Identity, self).__init__() 164 | 165 | def forward(self, x): 166 | return x 167 | 168 | 169 | class Zero(nn.Module): 170 | 171 | def __init__(self, stride): 172 | super(Zero, self).__init__() 173 | self.stride = stride 174 | 175 | def forward(self, x): 176 | if self.stride == 1: 177 | return x.mul(0.) 178 | return x[:,:,::self.stride,::self.stride].mul(0.) 
# N * C * W * H 179 | 180 | 181 | class FactorizedReduce(nn.Module): 182 | 183 | def __init__(self, C_in, C_out, affine=True): 184 | super(FactorizedReduce, self).__init__() 185 | assert C_out % 2 == 0 186 | self.relu = nn.ReLU(inplace=False) 187 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 188 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 189 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 190 | 191 | def forward(self, x): 192 | x = self.relu(x) 193 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1) 194 | out = self.bn(out) 195 | return out 196 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.autograd import Variable 3 | import torch.nn.functional as F 4 | import torch 5 | import torch.optim as optim 6 | from dataset_loader import MyTestData 7 | import logging 8 | from tqdm import tqdm 9 | import time 10 | from utils.functions import * 11 | from utils.evaluateFM import get_FM 12 | from loss import cross_entropy2d, iou, BinaryDiceLoss 13 | running_loss_final = 0 14 | iou_final = 0 15 | aux_final = 0 16 | 17 | class Trainer(object): 18 | 19 | def __init__(self, cuda, cfg, model_depth, model_rgb, model_fusion, train_loader, test_data_list, test_data_root, salmap_root, outpath, logging, writer, max_epoch): 20 | self.cuda = cuda 21 | self.model_depth = model_depth 22 | self.model_rgb = model_rgb 23 | self.model_fusion = model_fusion 24 | 25 | self.optim_depth = optim.SGD(self.model_depth.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 26 | self.optim_rgb = optim.SGD(self.model_rgb.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 27 | self.optim_fusion = optim.SGD(self.model_fusion.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay']) 28 | 29 | self.train_loader = train_loader 30 | self.test_data_list = test_data_list 31 | self.test_data_root = test_data_root 32 | self.salmap_root = salmap_root 33 | self.test_loaders={} 34 | self.best_f = {} 35 | self.best_m = {} 36 | for data_i in self.test_data_list: 37 | MapRoot = self.salmap_root + data_i 38 | TestRoot = self.test_data_root + data_i 39 | if not os.path.exists(MapRoot): 40 | os.mkdir(MapRoot) 41 | loader_i = torch.utils.data.DataLoader(MyTestData(TestRoot, transform=True), 42 | batch_size = 1, shuffle=True, num_workers=0, pin_memory=True) 43 | self.test_loaders[data_i] = loader_i 44 | self.best_f[data_i] = -1 45 | self.best_m[data_i] = 10000 46 | 47 | self.epoch = 0 48 | self.iteration = 0 49 | self.max_iter = 0 50 | self.snapshot = cfg[1]['spshot'] 51 | self.outpath = outpath 52 | self.sshow = cfg[1]['sshow'] 53 | self.logging = logging 54 | self.writer = writer 55 | self.max_epoch = max_epoch 56 | self.base_lr = cfg[1]['lr'] 57 | 58 | self.dice = BinaryDiceLoss() 59 | 60 | 61 | def train_epoch(self): 62 | self.logging.info("length trainloader: %s", len(self.train_loader)) 63 | self.logging.info("current_lr is : %s", self.optim_fusion.param_groups[0]['lr']) 64 | for batch_idx, (img, mask, depth, bins) in enumerate(tqdm(self.train_loader)): 65 | ########## for debug 66 | # if batch_idx % 10==0 and batch_idx>10: 67 | # self.save_test(iteration) 68 | iteration = batch_idx + self.epoch * len(self.train_loader) 69 | 70 | if self.iteration != 0 and (iteration - 
1) != self.iteration: 71 | continue # for resuming 72 | self.iteration = iteration 73 | 74 | if self.cuda: 75 | img, mask, depth, bins = img.cuda(), mask.cuda(), depth.cuda(), bins.cuda() 76 | img, mask, depth, bins = Variable(img), Variable(mask), Variable(depth), bins.cuda() 77 | # print(img.size()) 78 | n, c, h, w = img.size() # batch_size, channels, height, weight 79 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1) 80 | 81 | self.optim_depth.zero_grad() 82 | self.optim_rgb.zero_grad() 83 | self.optim_fusion.zero_grad() 84 | 85 | global running_loss_final ,iou_final, aux_final 86 | 87 | d0, d1, d2, d3, d4 = self.model_depth(depth) 88 | h1, h2, h3, h4, h5 = self.model_rgb(img, bins, gumbel=True) 89 | predict_mask = self.model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4) 90 | 91 | ce_loss = cross_entropy2d(predict_mask, mask, size_average=False) 92 | iou_loss = torch.zeros(1) 93 | aux_ce_loss = torch.zeros(1) 94 | # iou_loss = iou(predict_mask, mask,size_average=False ) * 0.2 95 | # iou_loss = self.dice(predict_mask, mask) 96 | loss = ce_loss #+ iou_loss + aux_ce_loss 97 | 98 | running_loss_final += ce_loss.item() 99 | iou_final += iou_loss.item() 100 | aux_final += aux_ce_loss.item() 101 | 102 | if iteration % self.sshow == (self.sshow - 1): 103 | self.logging.info('\n [%3d, %6d, RGB-D Net ce_loss: %.3f aux_loss: %.3f iou_loss: %.3f]' % ( 104 | self.epoch + 1, iteration + 1, running_loss_final / (n * self.sshow), aux_final / (n * self.sshow), iou_final / (n * self.sshow))) 105 | 106 | self.writer.add_scalar('train/iou_loss', iou_final / (n * self.sshow), iteration + 1) 107 | self.writer.add_scalar('train/aux_loss', aux_final / (n * self.sshow), iteration + 1) 108 | 109 | self.writer.add_scalar('train/lr', self.optim_fusion.param_groups[0]['lr'] , iteration + 1) 110 | self.writer.add_scalar('train/iter_ce_loss', running_loss_final / (n * self.sshow), iteration + 1) 111 | 112 | self.writer.add_scalar('train/epoch_ce_loss', running_loss_final / (n * self.sshow), self.epoch + 1) 113 | running_loss_final = 0.0 114 | iou_final= 0.0 115 | aux_final=0.0 116 | 117 | loss.backward() 118 | self.optim_depth.step() 119 | self.optim_rgb.step() 120 | self.optim_fusion.step() 121 | 122 | if iteration <= 200000: 123 | if iteration % self.snapshot == (self.snapshot - 1): 124 | self.save_test(iteration) 125 | else: 126 | if iteration % 10000 == (10000 - 1): 127 | self.save_test(iteration) 128 | 129 | def test(self,iteration, test_data): 130 | res = [] 131 | MapRoot = self.salmap_root + test_data 132 | for id, (data, depth, bins, img_name, img_size) in enumerate(self.test_loaders[test_data]): 133 | # print('testing bach %d' % id) 134 | inputs = Variable(data).cuda() 135 | depth = Variable(depth).cuda() 136 | bins = Variable(bins).cuda() 137 | n, c, h, w = inputs.size() 138 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1) 139 | torch.cuda.synchronize() 140 | start = time.time() 141 | with torch.no_grad(): 142 | h1, h2, h3, h4, h5 = self.model_rgb(inputs, bins, gumbel=False) 143 | d0, d1, d2, d3, d4 = self.model_depth(depth) 144 | predict_mask = self.model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4) 145 | torch.cuda.synchronize() 146 | end = time.time() 147 | 148 | res.append(end - start) 149 | outputs_all = F.softmax(predict_mask, dim=1) 150 | outputs = outputs_all[0][1] 151 | # import pdb; pdb.set_trace() 152 | outputs = outputs.cpu().data.resize_(h, w) 153 | 154 | imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size) 155 | time_sum = 0 156 | for i in res: 157 | time_sum += i 158 
| self.logging.info("FPS: %f" % (1.0 / (time_sum / len(res)))) 159 | # -------------------------- validation --------------------------- # 160 | torch.cuda.empty_cache() 161 | F_measure, mae = get_FM(salpath=MapRoot+'/', gtpath=self.test_data_root + test_data+'/test_masks/') 162 | 163 | self.writer.add_scalar('test/'+ test_data +'_F_measure', F_measure, iteration +1) 164 | self.writer.add_scalar('test/'+ test_data +'_MAE', mae, iteration+1) 165 | 166 | self.logging.info(MapRoot.split('/')[-1] + ' F_measure: %f' , F_measure) 167 | self.logging.info(MapRoot.split('/')[-1] + ' MAE: %f', mae) 168 | print('the testing process has finished!') 169 | 170 | return F_measure, mae 171 | 172 | 173 | def save_test(self, iteration, epoch = -1): 174 | self.save(iteration, epoch) 175 | for data_i in self.test_data_list: 176 | f, m = self.test(iteration, data_i) 177 | 178 | self.best_f[data_i] = max(f, self.best_f[data_i]) 179 | self.best_m[data_i] = min(m, self.best_m[data_i]) 180 | self.writer.add_scalar('best/'+ data_i +'_MAE', self.best_m[data_i], iteration) 181 | self.writer.add_scalar('best/'+ data_i +'_Fmeasure', self.best_f[data_i], iteration) 182 | 183 | def save(self, iteration=-1, epoch=-1): 184 | savename_depth = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 185 | torch.save(self.model_depth.state_dict(), savename_depth) 186 | self.logging.info('save: (snapshot: %d)' % (iteration + 1)) 187 | 188 | savename_rgb = ('%s/rgb_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 189 | torch.save(self.model_rgb.state_dict(), savename_rgb) 190 | self.logging.info('save: (snapshot: %d)' % (iteration + 1)) 191 | 192 | savename_fusion = ('%s/fusion_snapshot_iter_%d.pth' % (self.outpath, iteration + 1)) 193 | torch.save(self.model_fusion.state_dict(), savename_fusion) 194 | self.logging.info('save: (snapshot: %d)' % (iteration + 1)) 195 | 196 | 197 | if epoch > 0 : 198 | savename_depth = ('%s/depth_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1)) 199 | torch.save(self.model_depth.state_dict(), savename_depth) 200 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1)) 201 | 202 | savename_rgb = ('%s/rgb_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1)) 203 | torch.save(self.model_rgb.state_dict(), savename_rgb) 204 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1)) 205 | 206 | savename_fusion = ('%s/fusion_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1)) 207 | torch.save(self.model_fusion.state_dict(), savename_fusion) 208 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1)) 209 | 210 | 211 | 212 | def adjust_learning_rate(self, epoch): 213 | """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs""" 214 | lr = self.base_lr 215 | if epoch >= 20: 216 | lr = 0.1 * lr 217 | if epoch >= 40: 218 | lr = 0.1 * lr 219 | 220 | self.optim_depth.param_groups[0]['lr']= lr 221 | self.optim_rgb.param_groups[0]['lr']= lr 222 | self.optim_fusion.param_groups[0]['lr']= lr 223 | 224 | def train(self): 225 | max_epoch = self.max_epoch 226 | print ("max_epoch", max_epoch) 227 | self.max_iter = int(math.ceil(len(self.train_loader) * self.max_epoch)) 228 | print ("max_iter", self.max_iter) 229 | 230 | for epoch in range(max_epoch): 231 | # self.adjust_learning_rate(epoch) 232 | self.epoch = epoch 233 | self.train_epoch() 234 | # save each epoch. 
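# (save_test below first writes depth/rgb/fusion snapshots via save(), then
# runs test() on every dataset in self.test_data_list and updates the tracked
# best F-measure / MAE scalars on the summary writer.)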
235 | self.save_test(self.iteration, epoch = self.epoch ) 236 | 237 | self.logging.info('all training process finished') 238 | print(self.best_f) 239 | print(self.best_m) 240 | 241 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('utils') 3 | 4 | 5 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/evaluateFM.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/evaluateFM.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/evaluateFM.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/evaluateFM.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/functions.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/functions.cpython-37.pyc -------------------------------------------------------------------------------- /utils/evaluateFM.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | # import cv2 4 | import matplotlib.pyplot as plt 5 | import PIL.Image as Image 6 | def get_FM(salpath,gtpath): 7 | 8 | gtdir = gtpath 9 | saldir = salpath 10 | 11 | files = os.listdir(gtdir) 12 | eps = np.finfo(float).eps 13 | 14 | m_pres = np.zeros(21) 15 | m_recs = np.zeros(21) 16 | m_fms = np.zeros(21) 17 | m_thfm = 0 18 | m_mea = 0 19 | it = 1 20 | for i, name in enumerate(files): 21 | if not os.path.exists(gtdir + name): 22 | print(gtdir + name, 'does not exist') 23 | gt = Image.open(gtdir + name) 24 | gt = np.array(gt, dtype=np.uint8) 25 | 26 | 27 | mask=Image.open(saldir+name).convert('L') 28 | mask=mask.resize((np.shape(gt)[1],np.shape(gt)[0])) 29 | mask = np.array(mask, dtype=np.float) 30 | # salmap = cv2.resize(salmap,(W,H)) 31 | 32 | if len(mask.shape) != 2: 33 | mask = mask[:, :, 0] 34 | mask = (mask - mask.min()) / 
35 |         gt[gt != 0] = 1
36 |         pres = []
37 |         recs = []
38 |         fms = []
39 |         mea = np.abs(gt - mask).mean()
40 |         # threshold fm
41 |         binary = np.zeros(mask.shape)
42 |         th = 2 * mask.mean()
43 |         if th > 1:
44 |             th = 1
45 |         binary[mask >= th] = 1
46 |         sb = (binary * gt).sum()
47 |         pre = sb / (binary.sum() + eps)
48 |         rec = sb / (gt.sum() + eps)
49 |         thfm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
50 |         for th in np.linspace(0, 1, 21):
51 |             binary = np.zeros(mask.shape)
52 |             binary[mask >= th] = 1
53 |             pre = (binary * gt).sum() / (binary.sum() + eps)
54 |             rec = (binary * gt).sum() / (gt.sum() + eps)
55 |             fm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
56 |             pres.append(pre)
57 |             recs.append(rec)
58 |             fms.append(fm)
59 |         fms = np.array(fms)
60 |         pres = np.array(pres)
61 |         recs = np.array(recs)
62 |         m_mea = m_mea * (it - 1) / it + mea / it
63 |         m_fms = m_fms * (it - 1) / it + fms / it
64 |         m_recs = m_recs * (it - 1) / it + recs / it
65 |         m_pres = m_pres * (it - 1) / it + pres / it
66 |         m_thfm = m_thfm * (it - 1) / it + thfm / it
67 |         it += 1
68 |     return m_thfm, m_mea
69 | 
70 | if __name__ == '__main__':
71 |     # usage: python evaluateFM.py <saliency_map_dir>/ <ground_truth_dir>/
72 |     import sys
73 |     m_thfm, m_mea = get_FM(salpath=sys.argv[1], gtpath=sys.argv[2])
74 |     print(m_thfm)
75 |     print(m_mea)
--------------------------------------------------------------------------------
/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | # from scipy.misc import imresize
5 | from PIL import Image
6 | import os
7 | import cv2
8 | import shutil
9 | def adaptive_bins(hist, threshold):  # locate the bin range around the dominant depth-histogram peak (also returns the histogram with that range zeroed)
10 |     new = hist.copy()
11 |     peak = hist.max()
12 |     peak_depth = np.where(hist == peak)[0]
13 |     delta_hist = np.diff(hist, n=1, axis=0)
14 |     #print(peak,peak_depth,peak_depth.shape)
15 |     left = peak_depth
16 |     right = peak_depth
17 |     i = np.array([peak_depth[0]])
18 |     while(1):
19 |         new[[i]] = 0
20 |         if (i >= 254):
21 |             right = np.array([254])
22 |             break
23 |         if (delta_hist[i] < 0):
24 |             i = i + 1
25 |         elif (hist[i] <= threshold*peak):
26 |             right = i
27 |             break
28 |         else:
29 |             i = i + 1
30 |     i = np.array([peak_depth[0]-1])
31 |     while(1):
32 |         new[[i+1]] = 0
33 |         if (i <= 0):
34 |             left = np.array([0])
35 |             break
36 |         if (delta_hist[i] > 0):
37 |             i = i - 1
38 |         elif (hist[i] <= threshold*peak):
39 |             left = i + 1
40 |             break
41 |         else:
42 |             i = i - 1
43 |     #print(peak,peak_depth,left[0],right[0])
44 |     return [new, left[0], right[0]]
45 | 
46 | def get_bins_masks(depth):  # split a depth map into three depth-sensitive binary masks (two peak regions + the remainder)
47 |     mask_list = []
48 |     hist = cv2.calcHist([depth], [0], None, [256], [0,255])
49 | 
50 |     hist1, left1, right1 = adaptive_bins(hist, 0.7)
51 |     mask1 = (depth > left1-0.2*(right1-left1)) * (depth <= right1+0.2*(right1-left1))
52 |     #mask1 = (depth>left1) * (depth<=right1)
53 | 
54 |     mask_list.append(mask1)
55 | 
56 |     hist2, left2, right2 = adaptive_bins(hist1, 0.2)
57 |     mask2 = (depth > left2-0.2*(right2-left2)) * (depth <= right2+0.2*(right2-left2))
58 |     #mask2 = (depth>left2) * (depth<=right2)
59 | 
60 |     mask_list.append(mask2)
61 | 
62 |     mask3_1 = (depth > left1) * (depth <= right1)
63 |     mask3_2 = (depth > left2) * (depth <= right2)
64 |     mask3 = (~mask3_2) * (~mask3_1)
65 | 
66 |     mask_list.append(mask3)
67 |     mask_bins = np.stack(mask_list, axis=0)
68 | 
69 |     return mask_bins
70 | 
71 | 
72 | def create_exp_dir(path, scripts_to_save=None):
73 |     import time
74 |     time.sleep(2)
75 |     if not os.path.exists(path):
76 |         os.makedirs(path)
77 |     print('Experiment dir : {}'.format(path))
78 | 
79 |     if scripts_to_save is not None:
80 |         os.makedirs(os.path.join(path, 'scripts'))
81 |         for script in scripts_to_save:
82 |             dst_file = os.path.join(path, 'scripts', os.path.basename(script))
83 |             shutil.copyfile(script, dst_file)
84 | 
85 | 
86 | 
87 | def count_parameters_in_MB(model):
88 |     return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
89 | 
90 | 
91 | 
92 | def imsave(file_name, img, img_size):
93 |     """
94 |     save a torch tensor as an image
95 |     :param file_name: 'image/folder/image_name'
96 |     :param img: 3*h*w torch tensor
97 |     :return: nothing
98 |     """
99 |     assert torch.is_tensor(img), \
100 |         'img must be a torch tensor'
101 |     ndim = len(img.size())
102 |     assert ndim == 2 or ndim == 3, \
103 |         'img must be a 2 or 3 dimensional tensor'
104 | 
105 |     img = img.numpy()
106 | 
107 |     img = np.array(Image.fromarray(img).resize((img_size[1][0], img_size[0][0]), Image.NEAREST))
108 |     # img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest')
109 |     if ndim == 3:
110 |         plt.imsave(file_name, np.transpose(img, (1, 2, 0)))
111 |     else:
112 |         plt.imsave(file_name, img, cmap='gray')
113 | 
114 | def load_pretrain(path, state_dict, name):  # copy weights whose keys start with "module.<name>" into state_dict
115 |     state = torch.load(path)
116 |     if 'state_dict' in state:
117 |         state = state['state_dict']
118 |     name = "module." + name
119 |     length = len(name)
120 |     for k, v in state.items():
121 |         if k[:length] == name:
122 |             if k[length:] in state_dict.keys():
123 |                 state_dict[k[length:]] = v
124 |                 # print(k[length:])
125 |         else:
126 |             print("pass keys: ", k[7:])
127 | 
128 |     return state_dict
129 | 
--------------------------------------------------------------------------------
/utils/pretreat_SIP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image
4 | import scipy.io as sio
5 | import torch
6 | from torch.utils import data
7 | import cv2
8 | from utils.functions import adaptive_bins, get_bins_masks
9 | # resize the raw SIP images, depth maps and masks to 512x512 and save them under a new root
10 | root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/raw_SIP"
11 | 
12 | img_root = os.path.join(root, 'test_images')
13 | depth_root = os.path.join(root, 'test_depth')
14 | gt_root = os.path.join(root, 'test_masks')
15 | names = []
16 | img_names = []
17 | depth_names = []
18 | gt_names = []
19 | 
20 | file_names = os.listdir(img_root)
21 | 
22 | for i, name in enumerate(file_names):
23 |     if not name.endswith('.jpg'):
24 |         continue
25 |     names.append(name[:-4])
26 |     img_names.append(
27 |         os.path.join(img_root, name)
28 |     )
29 | 
30 |     depth_names.append(
31 |         os.path.join(depth_root, name[:-4] + '.png')
32 |     )
33 | 
34 |     gt_names.append(
35 |         os.path.join(gt_root, name[:-4] + '.png')
36 |     )
37 | 
38 | new_root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/SIP"
39 | new_img_root = os.path.join(new_root, 'test_images')
40 | new_depth_root = os.path.join(new_root, 'test_depth')
41 | new_gt_root = os.path.join(new_root, 'test_masks')
42 | 
43 | if not os.path.exists(new_depth_root):
44 |     os.mkdir(new_depth_root)
45 | if not os.path.exists(new_img_root):
46 |     os.mkdir(new_img_root)
47 | if not os.path.exists(new_gt_root):
48 |     os.mkdir(new_gt_root)
49 | 
50 | # i=0
51 | # print(gt_names[0])
52 | # img = np.array(PIL.Image.open(img_names[i]))
53 | # depth = np.array(PIL.Image.open(depth_names[i]))
54 | # gt = np.array(PIL.Image.open(gt_names[i]))
55 | # print(img.shape, depth.shape, gt.shape)
56 | 
57 | 
58 | 
59 | 
60 | for i in range(len(img_names)):
61 |     img = cv2.imread(img_names[i])
62 |     depth = cv2.imread(depth_names[i])
63 |     gt = cv2.imread(gt_names[i])
64 | 
65 |     img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
66 |     depth = cv2.resize(depth, (512, 512), interpolation=cv2.INTER_LINEAR)[:, :, 0]
67 |     gt = cv2.resize(gt, (512, 512), interpolation=cv2.INTER_LINEAR)[:, :, 0]
68 | 
69 |     cv2.imwrite(os.path.join(new_img_root, names[i] + '.jpg'), img)
70 |     cv2.imwrite(os.path.join(new_depth_root, names[i] + '.png'), depth)
71 |     cv2.imwrite(os.path.join(new_gt_root, names[i] + '.png'), gt)
72 | 
73 |     # if i>10:
74 |     #     break
75 | 
--------------------------------------------------------------------------------