├── .gitignore
├── README.md
├── __pycache__
│   ├── dataset_loader.cpython-37.pyc
│   ├── loss.cpython-37.pyc
│   └── training.cpython-37.pyc
├── dataset_loader.py
├── imagenet_pretrain.py
├── launch_pretrain.sh
├── launch_test.sh
├── launch_train.sh
├── loss.py
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── dsam.cpython-36.pyc
│   │   ├── dsam.cpython-37.pyc
│   │   ├── genotypes.cpython-36.pyc
│   │   ├── genotypes.cpython-37.pyc
│   │   ├── gsmodule.cpython-36.pyc
│   │   ├── gsmodule.cpython-37.pyc
│   │   ├── model_depth.cpython-36.pyc
│   │   ├── model_depth.cpython-37.pyc
│   │   ├── model_fusion.cpython-36.pyc
│   │   ├── model_fusion.cpython-37.pyc
│   │   ├── model_fusion_raw.cpython-36.pyc
│   │   ├── model_fusion_raw.cpython-37.pyc
│   │   ├── model_rgb.cpython-36.pyc
│   │   ├── model_rgb.cpython-37.pyc
│   │   ├── operations.cpython-36.pyc
│   │   └── operations.cpython-37.pyc
│   ├── dsam.py
│   ├── genotypes.py
│   ├── gsmodule.py
│   ├── model_depth.py
│   ├── model_fusion.py
│   ├── model_rgb.py
│   └── operations.py
├── training.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── evaluateFM.cpython-36.pyc
    │   ├── evaluateFM.cpython-37.pyc
    │   ├── functions.cpython-36.pyc
    │   └── functions.cpython-37.pyc
    ├── evaluateFM.py
    ├── functions.py
    └── pretreat_SIP.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /runs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
9 |
10 |
11 | # DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion (CVPR'2021, Oral)
12 |
13 | This repo is the official implementation of
14 | ["DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion"](https://arxiv.org/pdf/2103.11832.pdf)
15 |
16 | by Peng Sun, Wenhu Zhang, Huanyu Wang, Songyuan Li, and Xi Li.
17 |
18 | # Prerequisites
19 | + Ubuntu 18
20 | + PyTorch 1.7.0
21 | + CUDA 10.1
22 | + cuDNN 7.5.1
23 | + Python 3.7
24 | + NumPy 1.17.3
25 |
26 |
27 | # Training
28 | Please see `launch_pretrain.sh` and `launch_train.sh` for ImageNet pretraining and SOD training, respectively.
29 |
30 | # Testing
31 | Please see `launch_test.sh` for testing on the SOD benchmarks.
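
The launch scripts assume the dataset layout below (inferred from `MyData`, `MyTestData`, and `get_FM`; adjust the paths in the scripts if your copy differs):

```
<train_dataroot>/
├── train_images/   # *.jpg RGB images
├── train_depth/    # *.png depth maps
└── train_masks/    # *.png binary saliency masks

<test_dataroot>/<DATASET>/
├── test_images/    # *.jpg RGB images
├── test_depth/     # *.png depth maps
└── test_masks/     # *.png ground-truth masks (used for evaluation)
```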
32 |
33 | ## Main Results
34 |
35 | |Dataset | Er| Sλmean|Fβmean| M |
36 | |:---:|:---:|:---:|:---:|:---:|
37 | |DUT-RGBD|0.950|0.921|0.926|0.030|
38 | |NJUD|0.923|0.903|0.901|0.039|
39 | |NLPR|0.950|0.918|0.897|0.024|
40 | |SSD|0.904|0.876|0.852|0.045|
41 | |STEREO|0.933|0.904|0.898|0.036|
42 | |LFSD|0.923|0.882|0.882|0.054|
43 | |RGBD135|0.962|0.920|0.896|0.021|
44 |
45 | ## Saliency maps and Evaluation
46 |
47 | All of the saliency maps mentioned in the paper are available on [GoogleDrive](https://drive.google.com/file/d/1pqRpWgyDry3o6iKNNDx_eM2_kEOftYY3/view?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/1Fr5PuABceE7ordJvE84PKA)(code:juc2).
48 |
49 | You can use the toolbox provided by [jiwei0921](https://github.com/jiwei0921/Saliency-Evaluation-Toolbox) for evaluation.
50 |
51 | We also provide the saliency maps of the STERE-1000 and SIP datasets on [BaiduYun](https://pan.baidu.com/s/1Pp1Hvckfsvr7mWq9qcY9pw)(code:qxfw) for easy comparison.
52 |
53 |
54 | |Dataset | Er| Sλmean|Fβmean| M |
55 | |:---:|:---:|:---:|:---:|:---:|
56 | |STERE-1000|0.928|0.897|0.895|0.038|
57 | |SIP|0.908|0.861|0.868|0.057|
58 |
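
For a quick sanity check without the full toolbox, a minimal script along the following lines can be used (this is not the official evaluation code; the folder layout, the 256-threshold max F-measure protocol, and β² = 0.3 are assumptions):

```
import os
import numpy as np
from PIL import Image

def quick_eval(sal_dir, gt_dir, beta2=0.3):
    """Rough MAE / max F-measure over one dataset (illustrative only)."""
    thresholds = np.linspace(0, 1, 256)
    maes, precs, recs = [], [], []
    for name in sorted(os.listdir(gt_dir)):
        if not name.endswith('.png'):
            continue
        gt_img = Image.open(os.path.join(gt_dir, name)).convert('L')
        sal_img = Image.open(os.path.join(sal_dir, name)).convert('L').resize(gt_img.size)
        gt = (np.asarray(gt_img, dtype=np.float64) / 255.0 > 0.5).astype(np.float64)
        sal = np.asarray(sal_img, dtype=np.float64) / 255.0
        maes.append(np.abs(sal - gt).mean())
        precs.append([((sal >= t) * gt).sum() / max((sal >= t).sum(), 1) for t in thresholds])
        recs.append([((sal >= t) * gt).sum() / max(gt.sum(), 1) for t in thresholds])
    prec, rec = np.mean(precs, axis=0), np.mean(recs, axis=0)
    f_curve = (1 + beta2) * prec * rec / np.maximum(beta2 * prec + rec, 1e-8)
    return float(np.mean(maes)), float(f_curve.max())

# Example with hypothetical paths:
# mae, max_f = quick_eval('sal_map/NJUD', '/path/to/SOD-RGBD/val/NJUD/test_masks')
```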
59 | ## Citation
60 | ```
61 | @inproceedings{Sun2021DeepRS,
62 | title={Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion},
63 | author={P. Sun and Wenhu Zhang and Huanyu Wang and Songyuan Li and Xi Li},
64 | booktitle={IEEE Conf. Comput. Vis. Pattern Recog.},
65 | year={2021}
66 | }
67 | ```
68 |
69 |
70 | ## License
71 |
72 | The code is released under MIT License (see LICENSE file for details).
73 |
--------------------------------------------------------------------------------
/__pycache__/dataset_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/dataset_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/training.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/__pycache__/training.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image
4 | import scipy.io as sio
5 | import torch
6 | from torch.utils import data
7 | import cv2
8 | from utils.functions import adaptive_bins, get_bins_masks
9 |
10 | class MyData(data.Dataset):
11 | """
12 | load data in a folder
13 | """
14 | mean_rgb = np.array([0.447, 0.407, 0.386])
15 | std_rgb = np.array([0.244, 0.250, 0.253])
16 |
17 |
18 | def __init__(self, root, transform=False):
19 | super(MyData, self).__init__()
20 | self.root = root
21 |
22 | self._transform = transform
23 | img_root = os.path.join(self.root, 'train_images')
24 | mask_root = os.path.join(self.root, 'train_masks')
25 | depth_root = os.path.join(self.root, 'train_depth')
26 | file_names = os.listdir(img_root)
27 | self.img_names = []
28 | self.mask_names = []
29 | self.depth_names = []
30 | for i, name in enumerate(file_names):
31 | if not name.endswith('.jpg'):
32 | continue
33 | ## training with 2 dataset
34 | # if len(name.split('_')[0]) ==4 :
35 | # continue
36 | # print(name)
37 | self.mask_names.append(
38 | os.path.join(mask_root, name[:-4] + '.png')
39 | )
40 |
41 | self.img_names.append(
42 | os.path.join(img_root, name)
43 | )
44 | self.depth_names.append(
45 | os.path.join(depth_root, name[:-4] + '.png')
46 | )
47 |
48 | def __len__(self):
49 | return len(self.img_names)
50 |
51 | def __getitem__(self, index):
52 | # load image
53 | img_file = self.img_names[index]
54 | img = PIL.Image.open(img_file)
55 | img = np.array(img, dtype=np.uint8)
56 | # load label
57 | mask_file = self.mask_names[index]
58 | mask = PIL.Image.open(mask_file)
59 | mask = np.array(mask, dtype=np.int32)
60 | mask[mask != 0] = 1
61 | # load depth
62 | depth_file = self.depth_names[index]
63 | depth = PIL.Image.open(depth_file)
64 | depth = np.array(depth, dtype=np.uint8)
65 | # bins
66 | bins_mask = get_bins_masks(depth)
67 |
68 | if self._transform:
69 | return self.transform(img, mask, depth, bins_mask)
70 | else:
71 | return img, mask, depth, bins_mask
72 |
73 | def transform(self, img, mask, depth, bins_mask):
74 | img = img.astype(np.float64)/255.0
75 | img -= self.mean_rgb
76 | img /= self.std_rgb
77 | img = img.transpose(2, 0, 1) # to verify
78 | img = torch.from_numpy(img).float()
79 | mask = torch.from_numpy(mask).long()
80 | depth = depth.astype(np.float64) / 255.0
81 | depth = torch.from_numpy(depth).float()
82 |
83 | bins_mask=torch.from_numpy(bins_mask).float()
84 | h,w=depth.size()
85 | bins_depth = depth.view(1, h, w).repeat(3, 1, 1)
86 | bins_depth=bins_depth * bins_mask
87 | for i in range(3):
88 | bins_depth[i]=bins_depth[i]/bins_depth[i].max()
89 | c, h, w = img.size()
90 | return img, mask, depth, bins_depth#
91 |
92 |
93 |
94 |
95 |
96 | class MyTestData(data.Dataset):
97 | """
98 | load data in a folder
99 | """
100 | mean_rgb = np.array([0.447, 0.407, 0.386])
101 | std_rgb = np.array([0.244, 0.250, 0.253])
102 |
103 | def __init__(self, root, transform=False, use_bins=True):
104 | super(MyTestData, self).__init__()
105 | self.root = root
106 | self._transform = transform
107 | self._bins = use_bins
108 |
109 | img_root = os.path.join(self.root, 'test_images')
110 | depth_root = os.path.join(self.root, 'test_depth')
111 | file_names = os.listdir(img_root)
112 | self.img_names = []
113 | self.names = []
114 | self.depth_names = []
115 |
116 | for i, name in enumerate(file_names):
117 | if not name.endswith('.jpg'):
118 | continue
119 | self.img_names.append(
120 | os.path.join(img_root, name)
121 | )
122 | self.names.append(name[:-4])
123 | self.depth_names.append(
124 | os.path.join(depth_root, name[:-4] + '.png')
125 | )
126 |
127 | def __len__(self):
128 | return len(self.img_names)
129 |
130 | def __getitem__(self, index):
131 | # load image
132 | img_file = self.img_names[index]
133 | img = PIL.Image.open(img_file)
134 | img_size = img.size
135 | img = np.array(img, dtype=np.uint8)
136 |
137 | # load depth
138 | depth_file = self.depth_names[index]
139 | depth = PIL.Image.open(depth_file)
140 | depth = np.array(depth, dtype=np.uint8)
141 |
142 | bins_mask = get_bins_masks(depth)
143 |
144 |
145 | if self._transform:
146 | img, depth, bins_depth = self.transform(img, depth, bins_mask)
147 | return img, depth, bins_depth, self.names[index], img_size
148 | else:
149 | return img, depth, bins_mask, self.names[index], img_size
150 |
151 |
152 |
153 | def transform(self, img, depth, bins_mask):
154 | img = img.astype(np.float64)/255.0
155 | img -= self.mean_rgb
156 | img /= self.std_rgb
157 | img = img.transpose(2, 0, 1) # to verify
158 | img = torch.from_numpy(img).float()
159 |
160 | depth = depth.astype(np.float64) / 255.0
161 | depth = torch.from_numpy(depth).float()
162 |
163 | bins_mask=torch.from_numpy(bins_mask).float()
164 | h,w=depth.size()
165 | bins_depth = depth.view(1, h, w).repeat(3, 1, 1)
166 | bins_depth=bins_depth * bins_mask
167 | for i in range(3):
168 | bins_depth[i]=bins_depth[i]/bins_depth[i].max()
169 | c, h, w = img.size()
170 | return img, depth,bins_depth#
171 |
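# ----------------------------------------------------------------------------
# Note (not part of the original file): `get_bins_masks`, imported from
# utils/functions.py (not included in this section), returns a (3, H, W) binary
# mask that assigns every depth pixel to one of three depth bins; `transform`
# above multiplies it with the tiled depth map to build `bins_depth`. The helper
# below is only a rough, standalone sketch of that idea; the equally spaced
# thresholds are an assumption, not the repository's actual binning logic.
def _illustrative_bins_masks(depth, n_bins=3):
    d = depth.astype(np.float64) / 255.0
    edges = np.linspace(d.min(), d.max() + 1e-6, n_bins + 1)
    return np.stack([(d >= edges[i]) & (d < edges[i + 1])
                     for i in range(n_bins)]).astype(np.float64)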
172 | if __name__ == '__main__':
173 | root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/SIP"
174 | test_loader = torch.utils.data.DataLoader(MyTestData(root, transform=True),
175 | batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
176 | for id, (data, depth, bins, img_name, img_size) in enumerate(test_loader):
177 | print(img_size)
--------------------------------------------------------------------------------
/imagenet_pretrain.py:
--------------------------------------------------------------------------------
1 | """
2 | DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion
3 | """
4 | import os
5 | import torch
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torchvision import transforms
13 | import torch.backends.cudnn as cudnn
14 | import torchvision.datasets as datasets
15 | import time
16 | import torchvision
17 | import logging
18 | import sys
19 | import argparse
20 | import numpy as np
22 | import torch.distributed as dist
23 | from torch.nn.parallel import DistributedDataParallel as DDP
24 | from tensorboardX import SummaryWriter
25 | from utils.functions import *
26 | from models.model_depth import DepthNet
27 | from models.model_rgb import RgbNet
28 | from models.model_fusion import NasFusionNet_pre
29 | import torch.multiprocessing as mp
30 | import warnings
31 | warnings.filterwarnings("ignore")
32 |
33 | def find_free_port():
34 | import socket
35 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
36 | # Binding to port 0 will cause the OS to find an available port for us
37 | sock.bind(("", 0))
38 | port = sock.getsockname()[1]
39 | sock.close()
40 | # NOTE: there is still a chance the port could be taken by other processes.
41 | return port
42 |
43 | def reduce_tensor(tensor):
44 | rt = tensor.clone()
45 | dist.all_reduce(rt, op=dist.reduce_op.SUM)
46 | rt /= args.world_size
47 | return rt
48 |
49 | def accuracy(output, target, topk=(1,)):
50 | """Computes the precision@k for the specified values of k"""
51 | maxk = max(topk)
52 | batch_size = target.size(0)
53 |
54 | _, pred = output.topk(maxk, 1, True, True)
55 | pred = pred.t()
56 | correct = pred.eq(target.view(1, -1).expand_as(pred))
57 |
58 | res = []
59 | for k in topk:
60 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
61 | res.append(correct_k.mul_(100.0 / batch_size))
62 | return res
63 |
64 | def to_python_float(t):
65 | if hasattr(t, 'item'):
66 | return t.item()
67 | else:
68 | return t[0]
69 |
70 |
71 | class AverageMeter(object):
72 | """Computes and stores the average and current value"""
73 | def __init__(self):
74 | self.reset()
75 |
76 | def reset(self):
77 | self.val = 0
78 | self.avg = 0
79 | self.sum = 0
80 | self.count = 0
81 |
82 | def update(self, val, n=1):
83 | self.val = val
84 | self.sum += val * n
85 | self.count += n
86 | self.avg = self.sum / self.count
87 |
88 |
89 | def adjust_learning_rate(optimizer, epoch, args):
90 | """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs"""
91 | lr = args.lr
92 | if epoch >= 30:
93 | lr = 0.1 * lr
94 | if epoch >= 60:
95 | lr = 0.1 * lr
96 | if epoch >= 80:
97 | lr = 0.1 * lr
98 | optimizer.param_groups[0]['lr'] = lr
99 |
100 |
101 | def train(train_loader, models, CE, optimizers, epoch, logger, logging):
102 | """Train for one epoch on the training set"""
103 | batch_time = AverageMeter()
104 | losses = AverageMeter()
105 |
106 | top1 = AverageMeter()
107 | top5 = AverageMeter()
108 |
109 | # switch to train mode
110 | for m in models:
111 | m.train()
112 | end = time.time()
113 |
114 | for i, (inputs, target) in enumerate(train_loader):
115 | global_step = epoch * len(train_loader) + i
116 | target = target.cuda()
117 | inputs = inputs.cuda()
118 |
119 | # print(gpu,models[0].device_ids, inputs.device)
120 | b,c,h,w = inputs.size()
121 | depth = torch.mean(inputs,dim = 1).view(b,1,h,w).repeat(1, c, 1, 1)  # ImageNet has no depth maps: synthesize a pseudo-depth from the per-pixel RGB mean and tile it to 3 channels
122 | # print("inpus:",inputs.shape)
123 | h1, h2, h3, h4, h5 = models[0](inputs, depth, gumbel=True)
124 | d0, d1, d2, d3, d4 = models[1](depth)
125 | output = models[2](h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
126 |
127 | # A loss
128 | loss = CE( output, target) * 1.0
129 |
130 | # measure accuracy and record loss
131 | prec1, prec5 = accuracy(output.data, target, topk=(1,5))
132 |
133 | reduced_loss = reduce_tensor(loss.data)
134 | prec1 = reduce_tensor(prec1)
135 | prec5 = reduce_tensor(prec5)
136 |
137 | losses.update(to_python_float(reduced_loss), inputs.size(0))
138 | top1.update(to_python_float(prec1), inputs.size(0))
139 | top5.update(to_python_float(prec5), inputs.size(0))
140 |
141 |
142 | # compute gradient and do SGD step
143 | for op in optimizers:
144 | op.zero_grad()
145 |
146 | loss.backward()
147 |
148 | for op in optimizers:
149 | op.step()
150 |
151 | # measure elapsed time
152 | batch_time.update(time.time() - end)
153 | end = time.time()
154 |
155 | if i % 50 == 0 and args.rank ==0:
156 | logging.info('Epoch: [{0}][{1}/{2}]\t'
157 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
158 | 'Loss {loss.val:.4f} \t'
159 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
160 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
161 | epoch, i, len(train_loader), batch_time=batch_time,
162 | loss=losses, top1=top1, top5 = top5))
163 |
164 | logger.add_scalar('train/losses', losses.avg, global_step=global_step)
165 | logger.add_scalar('train/top1', top1.avg, global_step=global_step)
166 | logger.add_scalar('train/top5', top5.avg, global_step=global_step)
167 | logger.add_scalar('train/lr', optimizers[0].param_groups[0]['lr'], global_step=global_step)
168 |
169 |
170 | def validate(valid_loader, models, CE, epoch, logger, logging):
171 | """Perform validation on the validation set"""
172 | batch_time = AverageMeter()
173 | losses = AverageMeter()
174 | top1 = AverageMeter()
175 | top5 = AverageMeter()
176 |
177 | for m in models:
178 | m.eval()
179 |
180 | end = time.time()
181 | for i, (inputs, target) in enumerate(valid_loader):
182 | target = target.cuda()
183 | inputs = inputs.cuda()
184 | with torch.no_grad():
185 | b,c,h,w = inputs.size()
186 | depth = torch.mean(inputs,dim = 1).view(b,1,h,w).repeat(1, c, 1, 1)
187 |
188 | h1, h2, h3, h4, h5 = models[0](inputs, depth, gumbel=False)
189 | d0, d1, d2, d3, d4 = models[1](depth)
190 | output = models[2](h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
191 |
192 | loss = CE(output, target)
193 |
194 | # measure accuracy and record loss
195 | prec1 , prec5 = accuracy(output.data, target, topk=(1,5))
196 |
197 | reduced_loss = reduce_tensor(loss.data)
198 | prec1 = reduce_tensor(prec1)
199 | prec5 = reduce_tensor(prec5)
200 |
201 | losses.update(to_python_float(reduced_loss), inputs.size(0))
202 | top1.update(to_python_float(prec1), inputs.size(0))
203 | top5.update(to_python_float(prec5), inputs.size(0))
204 |
205 | # measure elapsed time
206 | batch_time.update(time.time() - end)
207 | end = time.time()
208 |
209 | if i % 50 == 0 and args.rank == 0:
210 | logging.info('Test: [{0}/{1}]\t'
211 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
212 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
213 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
214 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
215 | i, len(valid_loader), batch_time=batch_time, loss=losses,
216 | top1=top1, top5 = top5))
217 |
218 |
219 | logger.add_scalar('valid/top1', top1.avg, global_step=epoch)
220 | logger.add_scalar('valid/top5', top5.avg, global_step=epoch)
221 |
222 | logging.info(' * Prec@1 {top1.avg:.3f} * Prec@5 {top5.avg:.3f} '.format(top1=top1, top5=top5))
223 |
224 | return top1.avg
225 |
226 |
227 |
228 | def main_worker(gpu, argss):
229 | global args
230 | args = argss
231 |
232 | torch.cuda.set_device(gpu)
233 | rank = args.nr * args.gpus + gpu
234 | args.rank = rank
235 | exp_name = '/imagenet_pretrain'
236 | args.save_path = args.save_path + exp_name
237 | args.snapshot_root = args.save_path +'/snapshot/'
238 | args.log_root = args.save_path + '/logs/train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
239 |
240 | if args.phase =='train' and args.rank ==0 :
241 | create_exp_dir(args.log_root, scripts_to_save=None)
242 | log_format = '%(asctime)s %(message)s'
243 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
244 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
245 | fh = logging.FileHandler(os.path.join(args.log_root, 'log.txt'))
246 | fh.setFormatter(logging.Formatter(log_format))
247 | logging.getLogger().addHandler(fh)
248 |
249 | if not os.path.exists(args.snapshot_root) and args.rank ==0 :
250 | os.mkdir(args.snapshot_root)
251 |
252 | dist.init_process_group(
253 | backend='nccl',
254 | init_method=args.dist_url,
255 | world_size=args.world_size,
256 | rank=args.rank)
257 |
258 |
259 | """""""""""dataset loader"""""""""
260 | # ImageNet Data loading code
261 | train_dataset = datasets.ImageFolder(
262 | os.path.join(args.data_root, 'train'),
263 | transforms.Compose([
264 | transforms.RandomResizedCrop(224),
265 | transforms.RandomHorizontalFlip(),
266 | transforms.ToTensor(),
267 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
268 | std=[0.229, 0.224, 0.225]),
269 | ]))
270 | train_sampler = torch.utils.data.distributed.DistributedSampler(
271 | train_dataset,
272 | num_replicas=args.world_size,
273 | rank=rank,
274 | )
275 | train_loader = torch.utils.data.DataLoader(
276 | dataset = train_dataset,
277 | batch_size = args.batchsize,
278 | num_workers=0, pin_memory=True, sampler = train_sampler)
279 |
280 |
281 | valid_dataset = datasets.ImageFolder(
282 | os.path.join(args.data_root, 'val'),
283 | transforms.Compose([
284 | transforms.Resize(256),
285 | transforms.CenterCrop(224),
286 | transforms.ToTensor(),
287 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
288 | std=[0.229, 0.224, 0.225]),
289 | ]))
290 | valid_sampler = torch.utils.data.distributed.DistributedSampler(
291 | valid_dataset,
292 | num_replicas=args.world_size,
293 | rank=rank,
294 | shuffle=False
295 | )
296 | valid_loader = torch.utils.data.DataLoader(
297 | dataset = valid_dataset,
298 | batch_size = args.batchsize,
299 | num_workers=0, pin_memory=True, sampler = valid_sampler)
300 |
301 | kwargs = {'num_workers': 2, 'pin_memory': True}
302 | logging.info('data already')
303 |
304 | """""""""""train_data/test_data through nets"""""""""
305 |
306 | model_depth = torch.nn.SyncBatchNorm.convert_sync_batchnorm(DepthNet())
307 | model_rgb = torch.nn.SyncBatchNorm.convert_sync_batchnorm(RgbNet())
308 | model_fusion = torch.nn.SyncBatchNorm.convert_sync_batchnorm(NasFusionNet_pre())
309 |
310 | model_depth.init_weights()
311 | vgg19_bn = torchvision.models.vgg19_bn(pretrained=True)
312 | model_rgb.copy_params_from_vgg19_bn(vgg19_bn)
313 | model_fusion.init_weights()
314 |
315 | if args.rank==0:
316 | print("model_rgb param size = %fMB", count_parameters_in_MB(model_rgb))
317 | print("model_depth param size = %fMB", count_parameters_in_MB(model_depth))
318 | print("nas-model param size = %fMB", count_parameters_in_MB(model_fusion))
319 |
320 | model_depth = model_depth.cuda()
321 | model_rgb = model_rgb.cuda()
322 | model_fusion = model_fusion.cuda()
323 |
324 | if args.distributed:
325 | model_depth = DDP(model_depth, device_ids=[gpu])
326 | model_rgb = DDP(model_rgb, device_ids=[gpu])
327 | model_fusion = DDP(model_fusion, device_ids=[gpu])
328 |
329 | optimizer_depth = optim.SGD(model_depth.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
330 | optimizer_rgb = optim.SGD(model_rgb.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
331 | optimizer_fusion = optim.SGD(model_fusion.parameters(), lr= args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
332 |
333 | # print(optimizer_depth.param_groups[0]['lr'])
334 | CE = nn.CrossEntropyLoss().cuda()
335 |
336 | logger = SummaryWriter(args.log_root)
337 |
338 | best_prec1 = -1
339 | for epoch in range(0, args.epoch):
340 |
341 | adjust_learning_rate(optimizer_depth, epoch, args)
342 | adjust_learning_rate(optimizer_rgb, epoch, args)
343 | adjust_learning_rate(optimizer_fusion, epoch, args)
344 | if args.rank==0:
345 | print("lr:",optimizer_rgb.param_groups[0]['lr'])
346 | # train for one epoch
347 | train_sampler.set_epoch(epoch)
348 | train(train_loader, [model_rgb, model_depth, model_fusion], CE, [optimizer_rgb, optimizer_depth, optimizer_fusion], epoch, logger, logging)
349 |
350 | # evaluate on validation set
351 | prec1 = validate(valid_loader, [model_rgb, model_depth, model_fusion], CE, epoch, logger, logging)
352 |
353 | # remember best prec@1 and save checkpoint
354 | is_best = prec1 > best_prec1
355 | best_prec1 = max(prec1, best_prec1)
356 |
357 | if args.rank ==0:
358 | logging.info('Best accuracy: %f' % best_prec1)
359 | logger.add_scalar('best/accuracy', best_prec1, global_step=epoch)
360 |
361 | savename_depth = ('%s/depth_pre_epoch%d.pth' % (args.snapshot_root, epoch))
362 | torch.save(model_depth.state_dict(), savename_depth)
363 | print('save: (snapshot: %d)' % (epoch))
364 |
365 | savename_rgb = ('%s/rgb_pre_epoch%d.pth' % (args.snapshot_root, epoch))
366 | torch.save(model_rgb.state_dict(), savename_rgb)
367 | print('save: (snapshot: %d)' % (epoch))
368 |
369 | savename_fusion = ('%s/fusion_pre_epoch%d.pth' % (args.snapshot_root, epoch))
370 | torch.save(model_fusion.state_dict(), savename_fusion)
371 | print('save: (snapshot: %d)' % (epoch))
372 |
373 | if is_best:
374 | savename_depth = ('%s/depth_pre.pth' % (args.snapshot_root))
375 | torch.save(model_depth.state_dict(), savename_depth)
376 | print('save: (snapshot: %d)' % (epoch))
377 |
378 | savename_rgb = ('%s/rgb_pre.pth' % (args.snapshot_root))
379 | torch.save(model_rgb.state_dict(), savename_rgb)
380 | print('save: (snapshot: %d)' % (epoch))
381 |
382 | savename_fusion = ('%s/fusion_pre.pth' % (args.snapshot_root))
383 | torch.save(model_fusion.state_dict(), savename_fusion)
384 | print('save: (snapshot: %d)' % (epoch))
385 |
386 | def main():
387 | parser=argparse.ArgumentParser()
388 | parser.add_argument('--phase', type=str, default='train', help='train or test')
389 | parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters')
390 | parser.add_argument('--data_root', type=str, default='/4T/sunpeng/ImageNet')
391 |
392 | parser.add_argument('--save_path', type=str, default='/home/wenhu/pami21/runs/', help='save & log path')
393 | parser.add_argument('--snapshot_root', type=str, default='None', help='path to snapshot')
394 | parser.add_argument('--log_root', type=str, default='path to logs')
395 |
396 | parser.add_argument('--test_dataset', type=str, default='')
397 | parser.add_argument('--parse_method', type=str, default='darts', help='parse the code method')
398 |
399 | parser.add_argument('--batchsize', type=int, default=2, help='batchsize')
400 | parser.add_argument('--epoch', type=int, default=100, help='epoch')
401 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
402 | help='initial learning rate')
403 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
404 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float,
405 | help='weight decay (default: 1e-4)')
406 |
407 | parser.add_argument('-n', '--nodes', default=1,
408 | type=int, metavar='N')
409 | parser.add_argument('-g', '--gpus', default=2, type=int,
410 | help='number of gpus per node')
411 | parser.add_argument('-nr', '--nr', default=0, type=int,
412 | help='ranking within the nodes')
413 | args = parser.parse_args()
414 |
415 | args.distributed = True
416 |
417 |
418 | args.world_size = args.gpus * args.nodes
419 | port = find_free_port()
420 | args.dist_url = f"tcp://127.0.0.1:{port}"
421 | mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
422 |
423 | if __name__ == '__main__':
424 | main()
--------------------------------------------------------------------------------
/launch_pretrain.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | ############### imagenet pretraining
4 | python imagenet_pretrain.py -n 1 -g 4 -nr 0 \
5 | --phase train --epoch 90 --batchsize 4 --lr 0.00625 --momentum 0.9 --weight_decay 5e-4 \
6 | --data_root /4T/sunpeng/ImageNet \
7 | --save_path /4T/wenhu/pami21/github_test
--------------------------------------------------------------------------------
/launch_test.sh:
--------------------------------------------------------------------------------
1 |
2 | ###
3 | # @Author: Wenhu Zhang
4 | # @Date: 2021-06-07 17:39:22
5 | # @LastEditTime: 2021-06-07 18:13:00
6 | # @LastEditors: Wenhu Zhang
7 | # @Description:
8 | # @FilePath: /github/wh/DSA2F/launch_test.sh
9 | ###
10 |
11 | CUDA_VISIBLE_DEVICES="0" python -u main.py --phase test --test_dataset NJUD --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_NJUD.txt &
12 | CUDA_VISIBLE_DEVICES="0" python -u main.py --phase test --test_dataset DUT-RGBD --begin_epoch 20 --end_epoch 97 --exp_name 0428debug > results/0428debug_DUT-RGBD.txt &
13 | CUDA_VISIBLE_DEVICES="1" python -u main.py --phase test --test_dataset NLPR --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_NLPR.txt &
14 | CUDA_VISIBLE_DEVICES="1" python -u main.py --phase test --test_dataset SSD --begin_epoch 20000 --end_epoch 20000 --exp_name 0428debug > results/0428debug_SSD.txt
15 |
16 | CUDA_VISIBLE_DEVICES="2" python -u main.py --phase test --test_dataset STEREO --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_STEREO.txt &
17 | CUDA_VISIBLE_DEVICES="2" python -u main.py --phase test --test_dataset LFSD --begin_epoch 1 --end_epoch 97 --exp_name 0428debug > results/0428debug_LFSD.txt &
18 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset RGBD135 --begin_epoch 1 --end_epoch 8 --exp_name 0428debug > results/0428debug_RGBD135.txt &
19 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset SIP --begin_epoch 1 --end_epoch 8 --exp_name 0428debug > results/0428debug_SIP.txt &
20 | CUDA_VISIBLE_DEVICES="3" python -u main.py --phase test --test_dataset ReDWeb --begin_epoch 1 --end_epoch 8 --exp_name 0428debug > results/0428debug_ReDWeb.txt &
--------------------------------------------------------------------------------
/launch_train.sh:
--------------------------------------------------------------------------------
1 |
2 | ###
3 | # @Author: Wenhu Zhang
4 | # @Date: 2021-06-07 17:39:22
5 | # @LastEditTime: 2021-06-07 18:09:35
6 | # @LastEditors: Wenhu Zhang
7 | # @Description:
8 | # @FilePath: /github/wh/DSA2F/launch_train.sh
9 | ###
10 | CUDA_VISIBLE_DEVICES="0" python main.py --phase train --epoch 60 \
11 | --save_path /4T/wenhu/pami21/ \
12 | --pretrain_path /home/wenhu/pami21/ckpt_best.pth.tar \
13 | --train_dataroot /4T/wenhu/dataset/SOD-RGBD/train_data-augment/ \
14 | --test_dataroot /4T/wenhu/dataset/SOD-RGBD/val/ \
15 | --exp_name 0607debug
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.autograd import Variable
3 | import torch.nn.functional as F
4 | import torch
5 | import numpy as np
6 | import torch.nn as nn
7 | class BinaryDiceLoss(nn.Module):
8 | """Dice loss of binary class
9 | Args:
10 | smooth: A float number to smooth loss, and avoid NaN error, default: 1
11 | p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
12 | predict: A tensor of shape [N, *]
13 | target: A tensor of shape same with predict
14 | reduction: Reduction method to apply, return mean over batch if 'mean',
15 | return sum if 'sum', return a tensor of shape [N,] if 'none'
16 | Returns:
17 | Loss tensor according to arg reduction
18 | Raise:
19 | Exception if unexpected reduction
20 | """
21 | def __init__(self, smooth=1, p=2, reduction='mean'):
22 | super(BinaryDiceLoss, self).__init__()
23 | self.smooth = smooth
24 | self.p = p
25 | self.reduction = reduction
26 |
27 | def forward(self, predict, target):
28 | predict = F.softmax(predict, dim=1)[:,1,:,:].unsqueeze(1)
29 | target = target.unsqueeze(1).float()
30 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
31 | predict = predict.contiguous().view(predict.shape[0], -1)
32 | target = target.contiguous().view(target.shape[0], -1)
33 |
34 | num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
35 | den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
36 |
37 | loss = 1 - num / den
38 |
39 | if self.reduction == 'mean':
40 | return loss.mean() *256*256
41 | elif self.reduction == 'sum':
42 | return loss.sum()*256*256
43 | elif self.reduction == 'none':
44 | return loss*256*256
45 | else:
46 | raise Exception('Unexpected reduction {}'.format(self.reduction))
47 |
48 |
49 |
50 | def cross_entropy2d(input, target, weight=None, size_average=True):
51 | n, c, h, w = input.size()
52 |
53 | input = input.transpose(1, 2).transpose(2, 3).contiguous()
54 | input = input[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0] # 262144 #input = 2*256*256*2
55 | input = input.view(-1, c)
56 | mask = target >= 0
57 | target = target[mask]
58 | loss = F.cross_entropy(input, target, weight=weight, size_average=False)
59 | if size_average:
60 | loss /= mask.data.sum()
61 | return loss
62 |
63 |
64 |
65 | def iou(pred, target, size_average = False):
66 |
67 | pred = F.softmax(pred, dim=1)
68 | IoU = 0.0
69 | Iand1 = torch.sum(target.float() * pred[:,1,:,:])
70 | Ior1 = torch.sum(target) + torch.sum(pred[:,1,:,:]) - Iand1
71 | IoU1 = (Iand1 + 1) / (Ior1 + 1)
72 |
73 | IoU = (1-IoU1)
74 |
75 | if size_average:
76 | IoU /= target.data.sum()
77 | return IoU * 256 * 256
78 |
79 |
80 | # class CrossEntropyLabelSmooth(nn.Module):
81 | # """Cross entropy loss with label smoothing regularizer.
82 | # Reference:
83 | # Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
84 | # Equation: y = (1 - epsilon) * y + epsilon / K.
85 | # Args:
86 | # num_classes (int): number of classes.
87 | # epsilon (float): weight.
88 | # """
89 |
90 | # def __init__(self, num_classes=1000, epsilon=0.1):
91 | # super(CrossEntropyLabelSmooth, self).__init__()
92 | # self.num_classes = num_classes
93 | # self.epsilon = epsilon
94 | # self.logsoftmax = nn.LogSoftmax(dim=1)
95 |
96 | # def forward_v1(self, inputs, targets):
97 | # log_probs = self.logsoftmax(inputs)
98 | # targets = torch.zeros(log_probs.size(), device=targets.device).scatter_(1, targets.unsqueeze(1), 1)
99 | # targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
100 | # loss = (- targets * log_probs).mean(0).sum()
101 | # return loss
102 |
103 | # def forward_v2(self, inputs, targets):
104 | # probs = self.logsoftmax(inputs)
105 | # targets = torch.zeros(probs.size(), device=targets.device).scatter_(1, targets.unsqueeze(1), 1)
106 | # targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
107 | # loss = nn.KLDivLoss()(probs, targets)
108 | # return loss
109 |
110 | # def forward(self, inputs, targets):
111 | # """
112 | # Args:
113 | # inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
114 | # targets: ground truth labels with shape (num_classes)
115 | # """
116 | # return self.forward_v1(inputs, targets)
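# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The saliency losses above
# all expect 2-channel logits of shape [N, 2, H, W] and an integer mask of shape
# [N, H, W]; the shapes and the equal weighting below are assumptions, since the
# actual loss combination used for training lives in training.py.
if __name__ == '__main__':
    logits = torch.randn(2, 2, 256, 256)             # raw 2-class predictions
    target = (torch.rand(2, 256, 256) > 0.5).long()  # binary saliency mask
    total = (cross_entropy2d(logits, target)
             + BinaryDiceLoss()(logits, target)
             + iou(logits, target))
    print('combined loss:', total.item())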
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | DSA^2 F: Deep RGB-D Saliency Detection with Depth-Sensitive Attention and Automatic Multi-Modal Fusion
3 | """
4 | import os
5 | import torch
6 | from torch.autograd import Variable
7 | from torch.utils.data import DataLoader
8 | import torch.nn.functional as F
9 | import time
10 | import torchvision
11 | import logging
12 | import sys
13 | import argparse
14 | import numpy as np
15 | import torch.backends.cudnn as cudnn
16 | import torch.distributed as dist
17 | from torch.nn.parallel import DistributedDataParallel as DDP
18 | from tensorboardX import SummaryWriter
19 | from dataset_loader import MyData, MyTestData
20 | from utils.functions import *
21 | from training import Trainer
22 | from utils.evaluateFM import get_FM
23 | from models.model_depth import DepthNet
24 | from models.model_rgb import RgbNet
25 | from models.model_fusion import NasFusionNet
26 | import warnings
27 | warnings.filterwarnings("ignore")
28 |
29 | configurations = {
30 | 1: dict(
31 | max_iteration=1000000,
32 | lr=5e-9,
33 | momentum=0.9,
34 | weight_decay=0.0005,
35 | spshot=20000,
36 | nclass=2,
37 | sshow=100,
38 | ),
39 | }
40 | parser=argparse.ArgumentParser()
41 | parser.add_argument('--phase', type=str, default='train', help='train or test')
42 | parser.add_argument('--param', type=str, default=True, help='path to pre-trained parameters')
43 | parser.add_argument('--train_dataroot', type=str, default='/4T/wenhu/dataset/SOD-RGBD/train_data-augment', help=
44 | 'path to train data')
45 | parser.add_argument('--test_dataroot', type=str, default='/4T/wenhu/dataset/SOD-RGBD/val/', help=
46 | 'path to test data')
47 | parser.add_argument('--pretrain_path', type=str, default='')
48 |
49 | parser.add_argument('--exp_name', type=str, default='debug', help='save & log path')
50 | parser.add_argument('--save_path', type=str, default='/4T/wenhu/pami21/', help='save & log path')
51 | parser.add_argument('--snapshot_root', type=str, default='None', help='path to snapshot')
52 | parser.add_argument('--salmap_root', type=str, default='None', help='path to saliency map')
53 | parser.add_argument('--log_root', type=str, default='path to logs')
54 |
55 | parser.add_argument('--test_dataset', type=str, default='LFSD')
56 | parser.add_argument('--begin_epoch', type=int, default=0)
57 | parser.add_argument('--end_epoch', type=int, default=0)
58 | parser.add_argument('--parse_method', type=str, default='darts', help='parse the code method')
59 |
60 | parser.add_argument('--batchsize', type=int, default=2, help='batchsize')
61 | parser.add_argument('--epoch', type=int, default=60, help='epoch')
62 | parser.add_argument("--local_rank", default=-1)
63 | parser.add_argument('-c', '--config', type=int, default=1, choices=configurations.keys())
64 | args = parser.parse_args()
65 | cfg = configurations
66 |
67 |
68 |
69 | args.save_path = args.save_path + args.exp_name
70 | if args.phase =='train':
71 | if os.path.exists(args.save_path):
72 | print(".... error!!!!!!!!!! save path already exist .....")
73 | logging.info(".... error!!!!!!!!!! save path already exist .....")
74 | sys.exit()
75 | else :
76 | os.mkdir(args.save_path)
77 |
78 | args.snapshot_root = args.save_path +'/snapshot/'
79 | args.salmap_root = args.save_path + '/sal_map/'
80 | args.log_root = args.save_path + '/logs/'
81 | if not os.path.exists(args.salmap_root):
82 | os.mkdir(args.salmap_root)
83 |
84 | cuda = torch.cuda.is_available()
85 |
86 | if args.phase =='train':
87 | create_exp_dir(args.log_root, scripts_to_save=None)
88 | log_format = '%(asctime)s %(message)s'
89 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
90 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
91 | fh = logging.FileHandler(os.path.join(args.log_root, 'log.txt'))
92 | fh.setFormatter(logging.Formatter(log_format))
93 | logging.getLogger().addHandler(fh)
94 |
95 |
96 | """""""""""dataset loader"""""""""
97 |
98 | train_dataRoot = args.train_dataroot
99 |
100 | if not os.path.exists(args.snapshot_root):
101 | os.mkdir(args.snapshot_root)
102 |
103 | if args.phase == 'train':
104 | SnapRoot = args.snapshot_root # checkpoint
105 | train_loader = torch.utils.data.DataLoader(MyData(train_dataRoot, transform=True),
106 | batch_size = args.batchsize, shuffle=True, num_workers=0, pin_memory=True)
107 | else:
108 | test_dataRoot = args.test_dataroot +args.test_dataset
109 | max_F_dict = {}
110 | min_mae_dict = {}
111 | MapRoot = args.salmap_root +args.test_dataset
112 | if not os.path.exists(MapRoot):
113 | os.mkdir(MapRoot)
114 | test_loader = torch.utils.data.DataLoader(MyTestData(test_dataRoot, transform=True),
115 | batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
116 | print ('data already')
117 |
118 | """""""""""train_data/test_data through nets"""""""""
119 | cuda = torch.cuda.is_available()
120 | start_epoch = 0
121 | start_iteration = 0
122 |
123 | model_depth = DepthNet()
124 | model_rgb = RgbNet()
125 | model_fusion = NasFusionNet()
126 |
127 | print("model_rgb param size = %fMB", count_parameters_in_MB(model_rgb))
128 | print("model_depth param size = %fMB", count_parameters_in_MB(model_depth))
129 | print("nas-model param size = %fMB", count_parameters_in_MB(model_fusion))
130 |
131 | if args.begin_epoch == args.end_epoch:
132 | test_check_list = [args.end_epoch]
133 | else:
134 | test_epoch_list = [i*16418 for i in range(1,61)]
135 | test_iter_list = [i*10000 for i in range(args.begin_epoch, args.end_epoch+1)]
136 | test_check_list = test_epoch_list + test_iter_list
137 | test_check_list.sort()
138 | for ckpt_i in test_check_list: # in the test phase this iterates over candidate checkpoints; with the default begin/end epochs in the train phase the list has a single entry
139 | best_F = -float('inf')
140 | best_mae = float('inf')
141 |
142 | if args.phase == 'test':
143 | ckpt = str(ckpt_i)
144 | print(".... load checkpoint "+ ckpt +" for test .....")
145 | model_depth.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'depth_snapshot_iter_' + ckpt + '.pth')))
146 | model_rgb.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'rgb_snapshot_iter_'+ckpt+'.pth')))
147 | model_fusion.load_state_dict(torch.load(os.path.join(args.snapshot_root, 'fusion_snapshot_iter_'+ckpt+'.pth')))
148 |
149 | elif (args.pretrain_path):
150 | pretrained_dict = load_pretrain(args.pretrain_path, model_depth.state_dict(), "model_depth.")
151 | model_depth.load_state_dict(pretrained_dict)
152 |
153 | pretrained_dict = load_pretrain(args.pretrain_path, model_rgb.state_dict(), "model_rgb.")
154 | model_rgb.load_state_dict(pretrained_dict)
155 |
156 | model_fusion.init_weights()
157 | pretrained_dict = load_pretrain(args.pretrain_path, model_fusion.state_dict(), "model_fusion.")
158 | model_fusion.load_state_dict(pretrained_dict)
159 | logging.info(".... load imagenet pretrain models .....")
160 |
161 | else:
162 | logging.info(".... norm init .....")
163 | model_depth.init_weights()
164 | vgg19_bn = torchvision.models.vgg19_bn(pretrained=True)
165 | model_rgb.copy_params_from_vgg19_bn(vgg19_bn)
166 | model_fusion.init_weights()
167 |
168 | if cuda:
169 | model_depth = model_depth.cuda()
170 | model_rgb = model_rgb.cuda()
171 | model_fusion = model_fusion.cuda()
172 |
173 | if args.phase == 'train':
174 | cudnn.benchmark = True
175 | # torch.manual_seed(444)
176 | cudnn.enabled=True
177 | # torch.cuda.manual_seed(444)
178 | writer = SummaryWriter(args.log_root)
179 | model_rgb.cuda()
180 | model_depth.cuda()
181 | model_fusion.cuda()
182 |
183 | training = Trainer(
184 | cuda=cuda,
185 | cfg=cfg,
186 | model_depth=model_depth,
187 | model_rgb=model_rgb,
188 | model_fusion=model_fusion,
189 | train_loader=train_loader,
190 | test_data_list = ["DUT-RGBD","NJUD","NLPR","SSD","STEREO","LFSD","RGBD135","SIP","ReDWeb"],
191 | test_data_root = args.test_dataroot,
192 | salmap_root = args.salmap_root,
193 | outpath=args.snapshot_root,
194 | logging=logging,
195 | writer=writer,
196 | max_epoch=args.epoch,
197 | )
198 | training.epoch = start_epoch
199 | training.iteration = start_iteration
200 | training.train()
201 | else:
202 | # -------------------------- inference --------------------------- #
203 | res = []
204 | for id, (data, depth, bins, img_name, img_size) in enumerate(test_loader):
205 | # print('testing bach %d' % id)
206 | inputs = Variable(data).cuda()
207 | depth = Variable(depth).cuda()
208 | bins = Variable(bins).cuda()
209 | n, c, h, w = inputs.size()
210 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1)
211 | torch.cuda.synchronize()
212 | start = time.time()
213 | with torch.no_grad():
214 | h1, h2, h3, h4, h5 = model_rgb(inputs, bins, gumbel=False)
215 | d0, d1, d2, d3, d4 = model_depth(depth)
216 | predict_mask = model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
217 | torch.cuda.synchronize()
218 | end = time.time()
219 |
220 | res.append(end - start)
221 | outputs_all = F.softmax(predict_mask, dim=1)
222 | outputs = outputs_all[0][1]
223 | outputs = outputs.cpu().data.resize_(h, w)
224 |
225 | imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size)
226 | time_sum = 0
227 | for i in res:
228 | time_sum += i
229 | print("FPS: %f" % (1.0 / (time_sum / len(res))))
230 | # -------------------------- validation --------------------------- #
231 | torch.cuda.empty_cache()
232 | print('the testing process has finished!')
233 | F_measure, mae = get_FM(salpath=MapRoot+'/', gtpath=test_dataRoot+'/test_masks/')
234 | print(args.test_dataset + ' F_measure:', F_measure)
235 | print(args.test_dataset + ' MAE:', mae)
236 |
237 | F_key = args.test_dataset +'_Fb'
238 | M_key = args.test_dataset +'_mae'
239 | ckpt_key = args.test_dataset +'_ckpt'
240 | if F_key in max_F_dict.keys():
241 | if F_measure > max_F_dict[F_key]:
242 | max_F_dict[F_key] = F_measure
243 | max_F_dict[M_key] = mae
244 | max_F_dict[ckpt_key] = ckpt
245 | else:
246 | max_F_dict[F_key] = F_measure
247 | max_F_dict[M_key] = mae
248 | max_F_dict[ckpt_key] = ckpt
249 |
250 | if M_key in min_mae_dict.keys():
251 | if mae < min_mae_dict[M_key]:
252 | min_mae_dict[F_key] = F_measure
253 | min_mae_dict[M_key] = mae
254 | min_mae_dict[ckpt_key] = ckpt
255 | else:
256 | min_mae_dict[F_key] = F_measure
257 | min_mae_dict[M_key] = mae
258 | min_mae_dict[ckpt_key] = ckpt
259 |
260 | if args.phase == 'test':
261 | print ("max_F_dict")
262 | print (max_F_dict)
263 | print ("min_mae_dict")
264 | print (min_mae_dict)
265 |
266 | print("finish!!!!!!!!")
267 |
268 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('models')
3 |
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dsam.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/dsam.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dsam.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/dsam.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/genotypes.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/genotypes.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/genotypes.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/genotypes.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/gsmodule.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/gsmodule.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/gsmodule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/gsmodule.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_depth.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_depth.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_depth.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_depth.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_fusion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_fusion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_fusion_raw.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion_raw.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_fusion_raw.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_fusion_raw.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_rgb.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_rgb.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model_rgb.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/model_rgb.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/operations.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/operations.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/operations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/models/__pycache__/operations.cpython-37.pyc
--------------------------------------------------------------------------------
/models/dsam.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: Wenhu Zhang
3 | Date: 2021-06-07 17:39:22
4 | LastEditTime: 2021-06-07 18:08:28
5 | LastEditors: Wenhu Zhang
6 | Description:
7 | FilePath: /github/wh/DSA2F/models/dsam.py
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import numpy as np
12 | import cv2
13 | from gsmodule import GumbelSoftmax2D
14 |
15 |
16 | class ChannelAttentionLayer(nn.Module):
17 | def __init__(self, C_in, C_out, reduction=16, affine=True, BN=nn.BatchNorm2d):
18 | super(ChannelAttentionLayer, self).__init__()
19 | # global average pooling: feature --> point
20 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
21 | # feature channel downscale and upscale --> channel weight
22 | self.conv_du = nn.Sequential(
23 | nn.Conv2d(C_in, max(1, C_in // reduction), 1, padding=0, bias=False),
24 | nn.ReLU(),
25 | nn.Conv2d(max(1, C_in // reduction) , C_out, 1, padding=0, bias=False),
26 | nn.Sigmoid())
27 | def forward(self, x):
28 | y = self.avg_pool(x)
29 | y = self.conv_du(y)
30 | return x * y
31 |
32 | # DSAM V2
33 | class Adaptive_DSAM(nn.Module):
34 | def __init__(self,channel):
35 | super(Adaptive_DSAM, self).__init__()
36 | self.depth_revise = nn.Sequential(
37 | nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True),nn.ReLU(inplace=True),
38 | )
39 | self.fc = nn.Conv2d(32, 3, 1)
40 | self.GS = GumbelSoftmax2D(hard = True)
41 |
42 | self.channel=channel
43 | self.conv0=nn.Conv2d(channel,channel,1,padding=0)
44 | self.conv1=nn.Conv2d(channel,channel,1,padding=0)
45 | self.conv2=nn.Conv2d(channel,channel,1,padding=0)
46 | self.channel_att = ChannelAttentionLayer(self.channel, self.channel)
47 | def forward(self,x, bins, gumbel=False):
48 | n,c,h,w=x.size()
49 |
50 | bins = self.depth_revise(bins)
51 | gate = self.fc(bins)
52 | bins = self.GS(gate, gumbel=gumbel) * torch.mean(bins, dim=1,keepdim=True)
53 |
54 | x0=self.conv0(bins[:,0,:,:].unsqueeze(1) * x)
55 | x1=self.conv1(bins[:,1,:,:].unsqueeze(1) * x)
56 | x2=self.conv2(bins[:,2,:,:].unsqueeze(1) * x)
57 | x = (x0+x1+x2)+ x
58 | x = self.channel_att(x)
59 | return x
60 |
61 |
62 | # DSAM
63 | class DSAM(nn.Module):
64 | def __init__(self,channel):
65 | super(DSAM, self).__init__()
66 | self.channel=channel
67 | self.conv0=nn.Conv2d(channel,channel,1,padding=0)
68 | self.conv1=nn.Conv2d(channel,channel,1,padding=0)
69 | self.conv2=nn.Conv2d(channel,channel,1,padding=0)
70 | def forward(self,x, bins, gumbel=False):
71 | n,c,h,w=x.size()
72 |
73 | x0=self.conv0(bins[:,0,:,:].unsqueeze(1) * x)
74 | x1=self.conv1(bins[:,1,:,:].unsqueeze(1) * x)
75 | x2=self.conv2(bins[:,2,:,:].unsqueeze(1) * x)
76 | x = (x0+x1+x2)+ x
77 | return x
78 |
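# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; shapes are assumptions).
# `x` is a backbone feature map and `bins` is the 3-channel depth-bin tensor
# produced by the data loader; the module re-weights the features in a
# depth-sensitive way and returns a tensor with the same shape as `x`.
# Run from inside the models/ directory so `from gsmodule import ...` resolves.
if __name__ == '__main__':
    dsam = Adaptive_DSAM(channel=64)
    x = torch.randn(2, 64, 32, 32)    # backbone features
    bins = torch.rand(2, 3, 32, 32)   # depth decomposed into three bins
    out = dsam(x, bins, gumbel=True)
    print(out.shape)                  # torch.Size([2, 64, 32, 32])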
--------------------------------------------------------------------------------
/models/genotypes.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 |
4 | Genotype_all = namedtuple('Genotype', 'fusion1 fusion1_concat fusion2 fusion2_concat fusion3 fusion3_concat aggregation aggregation_concat final_agg final_aggregation_concat low_high_agg low_high_agg_concat')
5 | """
6 | Operation sets
7 | """
8 | PRIMITIVES = [
9 | 'none',
10 | 'max_pool_3x3',
11 | 'skip_connect',
12 | 'sep_conv_3x3',
13 | 'dil_conv_3x3',
14 | 'conv_1x1',
15 | 'conv_3x3',
16 | 'spatial_attention',
17 | 'channel_attention'
18 | ]
19 |
20 |
21 | attention_snas_3_4_1 = Genotype_all(fusion1=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion1_concat=range(6, 12), fusion2=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion2_concat=range(6, 12), fusion3=[('sep_conv_3x3', 2), ('dil_conv_3x3', 0), ('spatial_attention', 3), ('sep_conv_3x3', 1), ('max_pool_3x3', 4), ('conv_1x1', 2), ('channel_attention', 3), ('conv_1x1', 0), ('max_pool_3x3', 4), ('max_pool_3x3', 5), ('conv_1x1', 0), ('conv_1x1', 2), ('sep_conv_3x3', 4), ('sep_conv_3x3', 2), ('sep_conv_3x3', 6), ('conv_1x1', 0), ('conv_1x1', 4), ('dil_conv_3x3', 7), ('dil_conv_3x3', 1), ('skip_connect', 5), ('dil_conv_3x3', 0), ('conv_1x1', 1), ('conv_3x3', 4), ('conv_3x3', 8), ('spatial_attention', 3), ('sep_conv_3x3', 0), ('channel_attention', 4), ('conv_1x1', 5), ('conv_1x1', 3), ('sep_conv_3x3', 0), ('dil_conv_3x3', 8), ('dil_conv_3x3', 9)], fusion3_concat=range(6, 12), aggregation=[('spatial_attention', 1), ('max_pool_3x3', 2), ('sep_conv_3x3', 0), ('spatial_attention', 1), ('dil_conv_3x3', 2), ('max_pool_3x3', 3), ('spatial_attention', 1), ('conv_3x3', 4), ('conv_1x1', 2), ('conv_3x3', 0), ('sep_conv_3x3', 5), ('dil_conv_3x3', 3), ('conv_3x3', 1), ('conv_1x1', 5), ('dil_conv_3x3', 0), ('channel_attention', 3), ('spatial_attention', 4), ('max_pool_3x3', 1), ('max_pool_3x3', 1), ('skip_connect', 3), ('conv_3x3', 4), ('channel_attention', 3), ('skip_connect', 1), ('sep_conv_3x3', 6)], aggregation_concat=range(5, 11), final_agg=[('conv_1x1', 1), ('conv_1x1', 0), ('max_pool_3x3', 2), ('conv_1x1', 3), ('channel_attention', 2), ('dil_conv_3x3', 1), ('dil_conv_3x3', 3), ('conv_1x1', 0), ('spatial_attention', 2), ('sep_conv_3x3', 4), ('conv_1x1', 5), ('conv_3x3', 0), ('dil_conv_3x3', 4), ('conv_1x1', 2), ('conv_1x1', 1), ('dil_conv_3x3', 6), ('conv_1x1', 4), ('skip_connect', 2), ('conv_1x1', 7), ('max_pool_3x3', 6), ('conv_3x3', 5), ('channel_attention', 7), ('max_pool_3x3', 2), ('conv_3x3', 4), ('spatial_attention', 7), ('max_pool_3x3', 3), ('sep_conv_3x3', 0), ('spatial_attention', 8), ('max_pool_3x3', 2), ('conv_1x1', 4), ('sep_conv_3x3', 5), ('conv_1x1', 3)], final_aggregation_concat=range(6, 12), low_high_agg=[('max_pool_3x3', 2), ('spatial_attention', 1), ('conv_3x3', 0), ('channel_attention', 1), ('max_pool_3x3', 2), ('conv_1x1', 3), ('max_pool_3x3', 3), ('channel_attention', 1), ('conv_3x3', 
4), ('max_pool_3x3', 3), ('skip_connect', 1), ('conv_1x1', 2)], low_high_agg_concat=range(3, 7))
22 |
--------------------------------------------------------------------------------
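Each field of the genotype above is a list of `(operation name, input-state index)` pairs: a fusion cell starts from four preprocessed inputs (states 0-3) and each of its eight steps applies four operations and appends one new state, so the `*_concat` ranges select which states are concatenated at the cell output. A minimal decoding sketch (illustrative only; it assumes `models/` is on the import path, as the repo's flat imports suggest):

```python
# Decode a few genotype entries (illustration, not part of the original repo).
from genotypes import attention_snas_3_4_1

ops, inputs = zip(*attention_snas_3_4_1.fusion1)
print(len(ops))                                   # 32 entries = 8 steps x 4 ops per step
print(ops[0], inputs[0])                          # ('sep_conv_3x3', 2): apply sep_conv_3x3 to state 2
print(list(attention_snas_3_4_1.fusion1_concat))  # [6, 7, 8, 9, 10, 11]: states concatenated at the output
```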
/models/gsmodule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 |
5 | """
6 | Gumbel Softmax Sampler
7 | Requires 2D input [batch size, number of categories].
8 |
9 | Does not support a single binary category; use two dimensions with softmax instead.
10 | """
11 |
12 | class GumbelSoftmax2D(torch.nn.Module):
13 | def __init__(self, hard=False):
14 | super(GumbelSoftmax2D, self).__init__()
15 | self.hard = hard
16 | self.gpu = False
17 |
18 | def cuda(self):
19 | self.gpu = True
20 |
21 | def cpu(self):
22 | self.gpu = False
23 |
24 | def sample_gumbel(self, shape, eps=1e-10):
25 | """Sample from Gumbel(0, 1)"""
26 | noise = torch.rand(shape)
27 |         noise.add_(eps).log_().neg_()   # -log(U + eps)
28 |         noise.add_(eps).log_().neg_()   # -log(-log(U + eps) + eps), i.e. a Gumbel(0, 1) sample
29 | if self.gpu:
30 | return Variable(noise).cuda()
31 | else:
32 | return Variable(noise)
33 |
34 | def sample_gumbel_like(self, template_tensor, eps=1e-10):
35 | uniform_samples_tensor = template_tensor.clone().uniform_()
36 | gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps))
37 | return gumble_samples_tensor
38 |
39 | def gumbel_softmax_sample(self, logits, temperature):
40 | """ Draw a sample from the Gumbel-Softmax distribution"""
41 | dim = logits.size(-1)
42 | gumble_samples_tensor = self.sample_gumbel_like(logits.data)
43 | gumble_trick_log_prob_samples = logits + Variable(gumble_samples_tensor)
44 | soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, 1)
45 | return soft_samples
46 |
47 | def gumbel_softmax(self, logits, temperature, hard=False, gumbel=False):
48 | """Sample from the Gumbel-Softmax distribution and optionally discretize.
49 | Args:
50 | logits: [batch_size, n_class] unnormalized log-probs
51 | temperature: non-negative scalar
52 | hard: if True, take argmax, but differentiate w.r.t. soft sample y
53 | Returns:
54 | [batch_size, n_class] sample from the Gumbel-Softmax distribution.
55 | If hard=True, then the returned sample will be one-hot, otherwise it will
56 |         be a probability distribution that sums to 1 across classes
57 | """
58 | if gumbel:
59 | y = self.gumbel_softmax_sample(logits, temperature)
60 | else:
61 | y = F.softmax(logits,1)
62 | if hard:
63 | _, max_value_indexes = y.data.max(1, keepdim=True)
64 | y_hard = logits.data.clone().zero_().scatter_(1, max_value_indexes, 1)
65 |             y = Variable(y_hard - y.data) + y  # straight-through: one-hot forward pass, gradients from the soft sample
66 | return y
67 |
68 | def forward(self, logits, gumbel=False, temp=1):
69 | b,c,h,w= logits.size()
70 | logits = logits.permute(0,2,3,1).contiguous().view(-1,c)
71 |         logits = self.gumbel_softmax(logits, temperature=temp, hard=self.hard, gumbel=gumbel)  # use the requested temperature instead of a hard-coded 1
72 |
73 | return logits.view(b,h,w,c).permute(0,3,1,2).contiguous()
74 |
75 |
--------------------------------------------------------------------------------
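A minimal usage sketch for `GumbelSoftmax2D` (illustrative, with random logits; it assumes `models/` is on the import path). With `hard=True` the forward pass is one-hot per pixel while gradients flow through the soft sample:

```python
# Per-pixel Gumbel-Softmax sampling over a 4-category logit map (illustration only).
import torch
from gsmodule import GumbelSoftmax2D

sampler = GumbelSoftmax2D(hard=True)
logits = torch.randn(2, 4, 16, 16)                 # N x C x H x W unnormalized log-probs
sample = sampler(logits, gumbel=True)              # one-hot over the channel dim at every pixel
print(sample.shape, sample.sum(dim=1).unique())    # torch.Size([2, 4, 16, 16]) tensor([1.])
```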
/models/model_depth.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import logging
6 | import torch.nn as nn
7 |
8 | BN_MOMENTUM = 0.1
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def conv3x3(in_planes, out_planes, stride=1):
13 | """3x3 convolution with padding"""
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
15 |
16 | class BasicBlock(nn.Module):
17 |
18 | def __init__(self, inplanes, planes, stride=1, downsample=None):
19 | super(BasicBlock, self).__init__()
20 | self.conv1 = conv3x3(inplanes, planes, stride)
21 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.conv2 = conv3x3(planes, planes)
24 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
25 | self.downsample = downsample
26 | self.stride = stride
27 |
28 | def forward(self, x):
29 | residual = x
30 |
31 | out = self.conv1(x)
32 | out = self.bn1(out)
33 | out = self.relu(out)
34 |
35 | out = self.conv2(out)
36 | out = self.bn2(out)
37 |
38 | out += residual
39 | out = self.relu(out)
40 |
41 | return out
42 |
43 | class DepthNet(nn.Module):
44 |
45 | def __init__(self):
46 | super(DepthNet, self).__init__()
47 | # conv1
48 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
49 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
50 | self.relu1_1 = nn.ReLU(inplace=True)
51 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
52 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
53 | self.relu1_2 = nn.ReLU(inplace=True)
54 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers
55 |
56 | # conv2
57 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
58 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
59 | self.relu2_1 = nn.ReLU(inplace=True)
60 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
61 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
62 | self.relu2_2 = nn.ReLU(inplace=True)
63 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers
64 | num_stages = 3
65 | blocks = BasicBlock
66 | num_blocks = [4, 4, 4]
67 | num_channels = [32, 32, 128]
68 | self.stage = self._make_stages(num_stages, blocks, num_blocks, num_channels)
69 | self.transition1 = nn.Sequential(
70 | nn.Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
71 | nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
72 | nn.ReLU(inplace=True)
73 | )
74 | self.transition2 = nn.Sequential(
75 | nn.Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
76 | nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
77 | nn.ReLU(inplace=True)
78 | )
79 |
80 | def _make_one_stage(self, stage_index, block, num_blocks, num_channels):
81 | layers = []
82 | for i in range(0, num_blocks[stage_index]):
83 | layers.append(
84 | block(
85 | num_channels[stage_index],
86 | num_channels[stage_index]
87 | )
88 | )
89 | return nn.Sequential(*layers)
90 |
91 | def _make_stages(self, num_stages, block, num_blocks, num_channels):
92 | branches = []
93 |
94 | for i in range(num_stages):
95 | branches.append(
96 | self._make_one_stage(i, block, num_blocks, num_channels)
97 | )
98 | return nn.ModuleList(branches)
99 |
100 | def forward(self, d):
101 | #depth branch
102 | d = self.relu1_1(self.bn1_1(self.conv1_1(d)))
103 | d = self.relu1_2(self.bn1_2(self.conv1_2(d)))
104 | d0 = self.pool1(d) # (128x128)*64
105 |
106 | d = self.relu2_1(self.bn2_1(self.conv2_1(d0)))
107 | d = self.relu2_2(self.bn2_2(self.conv2_2(d)))
108 | d1 = self.pool2(d) # (64x64)*128
109 | dt2 = self.transition1(d1)
110 | d2 = self.stage[0](dt2)
111 | d3 = self.stage[1](d2)
112 | dt4 = self.transition2(d3)
113 | d4 = self.stage[2](dt4)
114 | return d0, d1, d2, d3, d4
115 |
116 | def init_weights(self):
117 | logger.info('=> Depth model init weights from normal distribution')
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
121 | nn.init.normal_(m.weight, std=0.001)
122 | for name, _ in m.named_parameters():
123 | if name in ['bias']:
124 | nn.init.constant_(m.bias, 0)
125 | elif isinstance(m, nn.BatchNorm2d):
126 | nn.init.constant_(m.weight, 1)
127 | nn.init.constant_(m.bias, 0)
128 | elif isinstance(m, nn.ConvTranspose2d):
129 | nn.init.normal_(m.weight, std=0.001)
130 | for name, _ in m.named_parameters():
131 | if name in ['bias']:
132 | nn.init.constant_(m.bias, 0)
--------------------------------------------------------------------------------
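Output shapes of `DepthNet`, sketched for a 256x256 input (the sizes quoted in the forward comments; the trainer replicates the single-channel depth map to three channels before calling it):

```python
# Shape check for DepthNet (illustration; assumes models/ is on the import path).
import torch
from model_depth import DepthNet

net = DepthNet()
net.init_weights()
net.eval()
depth = torch.randn(1, 3, 256, 256)        # depth replicated to 3 channels, as in training.py
with torch.no_grad():
    d0, d1, d2, d3, d4 = net(depth)
# d0: 1x64x128x128, d1: 1x128x64x64, d2 and d3: 1x32x64x64, d4: 1x128x64x64
print([tuple(t.shape) for t in (d0, d1, d2, d3, d4)])
```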
/models/model_fusion.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import logging
7 | from operations import *
8 | import genotypes
9 | from genotypes import attention_snas_3_4_1
10 | from genotypes import PRIMITIVES
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class MixedOp(nn.Module):
18 | def __init__(self, C, stride):
19 | super(MixedOp, self).__init__()
20 | self._ops = nn.ModuleList()
21 | for primitive in PRIMITIVES:
22 | op = OPS[primitive](C, stride, False)
23 | if 'pool' in primitive:
24 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
25 | self._ops.append(op)
26 |
27 | def forward(self, x, weights):
28 | return sum(w * op(x) for w, op in zip(weights, self._ops))
29 |
30 | # Take four inputs
31 | class FusionCell(nn.Module):
32 | def __init__(self, genotype, index, steps, multiplier, parse_method):
33 | super(FusionCell, self).__init__()
34 |
35 | self.index = index
36 | if self.index == 0:
37 | op_names, indices = zip(*genotype.fusion1)
38 | concat = genotype.fusion1_concat
39 | C = 128 #128 // 2 # Fusion Scale 64x64
40 |             # two rgb feats (64x64 128c, 32x32 256c)
41 | # two depth feats (64x64 128c, 64x64 32c)
42 | self.preprocess0_rgb = nn.Sequential(
43 | nn.Conv2d(128, C, kernel_size=1, bias=False),
44 | nn.BatchNorm2d(C, affine=True))
45 | self.preprocess1_rgb = nn.Sequential(
46 | nn.Conv2d(256, C, kernel_size=1, bias=False),
47 | nn.BatchNorm2d(C, affine=True),
48 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
49 | self.preprocess0_depth = nn.Sequential(
50 | nn.Conv2d(128, C, kernel_size=1, bias=False),
51 | nn.BatchNorm2d(C, affine=True))
52 | self.preprocess1_depth = nn.Sequential(
53 | nn.Conv2d(32, C, kernel_size=1, bias=False),
54 | nn.BatchNorm2d(C, affine=True))
55 | elif self.index == 1:
56 | op_names, indices = zip(*genotype.fusion2)
57 | concat = genotype.fusion2_concat
58 | C = 128 #128 // 2 # Fusion Scale 64x64
59 | # two rgb feats (32x32 256c, 16x16 512c)
60 | # two depth feats (64x64 32c, 64x64 32c)
61 | self.preprocess0_rgb = nn.Sequential(
62 | nn.Conv2d(256, C, kernel_size=1, bias=False),
63 | nn.BatchNorm2d(C, affine=True),
64 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
65 | self.preprocess1_rgb = nn.Sequential(
66 | nn.Conv2d(512, C, kernel_size=1, bias=False),
67 | nn.BatchNorm2d(C, affine=True),
68 | nn.ReLU(inplace=True),
69 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
70 | nn.Conv2d(C, C, kernel_size=1, bias=False),
71 | nn.BatchNorm2d(C, affine=True),
72 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
73 | self.preprocess0_depth = nn.Sequential(
74 | nn.Conv2d(32, C, kernel_size=1, bias=False),
75 | nn.BatchNorm2d(C, affine=True))
76 | self.preprocess1_depth = nn.Sequential(
77 | nn.Conv2d(32, C, kernel_size=1, bias=False),
78 | nn.BatchNorm2d(C, affine=True))
79 | else:
80 | op_names, indices = zip(*genotype.fusion3)
81 | concat = genotype.fusion3_concat
82 | C = 128 #256 // 2 # Fusion Scale 32x32
83 | # two rgb feats (16x16 512c, 8x8 512c)
84 | # two depth feats (64x64 32c, 64x64 128c)
85 | self.preprocess0_rgb = nn.Sequential(
86 | nn.Conv2d(512, C, kernel_size=1, bias=False),
87 | nn.BatchNorm2d(C, affine=True),
88 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
89 | self.preprocess1_rgb = nn.Sequential(
90 | nn.Conv2d(512, C, kernel_size=1, bias=False),
91 | nn.BatchNorm2d(C, affine=True),
92 | nn.ReLU(inplace=True),
93 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
94 | nn.Conv2d(C, C, kernel_size=1, bias=False),
95 | nn.BatchNorm2d(C, affine=True),
96 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
97 | self.preprocess0_depth = nn.Sequential(
98 | nn.Conv2d(32, C, kernel_size=3, stride=2, padding=1, bias=False),
99 | nn.BatchNorm2d(C, affine=True))
100 | self.preprocess1_depth = nn.Sequential(
101 | nn.Conv2d(128, C, kernel_size=3, stride=2, padding=1, bias=False),
102 | nn.BatchNorm2d(C, affine=True))
103 |
104 | self._steps = steps
105 | self._multiplier = multiplier
106 | self._compile(C, op_names, indices, concat)
107 |
108 | def _compile(self, C, op_names, indices, concat):
109 | assert len(op_names) == len(indices)
110 | self._concat = concat
111 | self.multiplier = len(concat)
112 |
113 | self._ops = nn.ModuleList()
114 | for name, index in zip(op_names, indices):
115 | stride = 1
116 | op = OPS[name](C, stride, True)
117 | self._ops += [op]
118 | self._indices = indices
119 |
120 | def forward(self, s0, s1, s2, s3, drop_prob):
121 |
122 | # print("s_input:",s0.shape, s1.shape, s2.shape, s3.shape)
123 | s0 = self.preprocess0_rgb(s0)
124 | s1 = self.preprocess1_rgb(s1)
125 | s2 = self.preprocess0_depth(s2)
126 | s3 = self.preprocess1_depth(s3)
127 |
128 | # print("s_prepoce:",s0.shape, s1.shape, s2.shape, s3.shape)
129 | states = [s0, s1, s2, s3]
130 | for i in range(self._steps):
131 | h1 = states[self._indices[4*i]]
132 | h2 = states[self._indices[4*i+1]]
133 | h3 = states[self._indices[4*i+2]]
134 | h4 = states[self._indices[4*i+3]]
135 | op1 = self._ops[4*i]
136 | op2 = self._ops[4*i+1]
137 | op3 = self._ops[4*i+2]
138 | op4 = self._ops[4*i+3]
139 | h1 = op1(h1)
140 | h2 = op2(h2)
141 | h3 = op3(h3)
142 | h4 = op4(h4)
143 | if self.training and drop_prob > 0.:
144 | if not isinstance(op1, Identity):
145 | h1 = drop_path(h1, drop_prob)
146 | if not isinstance(op2, Identity):
147 | h2 = drop_path(h2, drop_prob)
148 | if not isinstance(op3, Identity):
149 | h3 = drop_path(h3, drop_prob)
150 | if not isinstance(op4, Identity):
151 | h4 = drop_path(h4, drop_prob)
152 | # print("h:",h1.shape, h2.shape, h3.shape, h4.shape)
153 | s = h1 + h2 + h3 + h4
154 | states += [s]
155 |
156 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W
157 |
158 | # Take three inputs
159 | class AggregationCell(nn.Module):
160 | def __init__(self, genotype, steps, multiplier, parse_method):
161 | super(AggregationCell, self).__init__()
162 | C = 128
163 | self.preprocess0 = None
164 | self.preprocess1 = None
165 | self.preprocess2 = None
166 |
167 | op_names, indices = zip(*genotype.aggregation)
168 | concat = genotype.aggregation_concat
169 | self._steps = steps
170 | self._multiplier = multiplier
171 | self._compile(C, op_names, indices, concat)
172 |
173 | def _compile(self, C, op_names, indices, concat):
174 | assert len(op_names) == len(indices)
175 | self._concat = concat
176 | self.multiplier = len(concat)
177 | self._ops = nn.ModuleList()
178 | for name, index in zip(op_names, indices):
179 | stride = 1
180 | op = OPS[name](C, stride, True)
181 | self._ops += [op]
182 | self._indices = indices
183 |
184 | def forward(self, s0, s1, s2, drop_prob):
185 | # print("000:",s0.shape, s1.shape, s2.shape)
186 | s0 = self.preprocess0(s0)
187 | s1 = self.preprocess1(s1)
188 | s2 = self.preprocess2(s2)
189 | # print("111:",s0.shape, s1.shape, s2.shape)
190 |
191 | states = [s0, s1, s2]
192 | for i in range(self._steps):
193 | h1 = states[self._indices[3*i]]
194 | h2 = states[self._indices[3*i+1]]
195 | h3 = states[self._indices[3*i+2]]
196 | op1 = self._ops[3*i]
197 | op2 = self._ops[3*i+1]
198 | op3 = self._ops[3*i+2]
199 | h1 = op1(h1)
200 | h2 = op2(h2)
201 | h3 = op3(h3)
202 | if self.training and drop_prob > 0.:
203 | if not isinstance(op1, Identity):
204 | h1 = drop_path(h1, drop_prob)
205 | if not isinstance(op2, Identity):
206 | h2 = drop_path(h2, drop_prob)
207 | if not isinstance(op3, Identity):
208 | h3 = drop_path(h3, drop_prob)
209 | s = h1 + h2 + h3
210 | states += [s]
211 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W
212 |
213 | class AggregationCell_1(AggregationCell):
214 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [768,768,768]):
215 | super().__init__(genotype, steps, multiplier, parse_method)
216 | C = 128
217 | self.preprocess0 = nn.Sequential(
218 | nn.ReLU(inplace=True),
219 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
220 | nn.BatchNorm2d(C, affine=True)
221 | )
222 | self.preprocess1 = nn.Sequential(
223 | nn.ReLU(inplace=True),
224 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
225 | nn.BatchNorm2d(C, affine=True),
226 | )
227 | self.preprocess2 = nn.Sequential(
228 | nn.ReLU(inplace=True),
229 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
230 | nn.BatchNorm2d(C, affine=True),
231 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
232 |
233 | class AggregationCell_2(AggregationCell):
234 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [512,128,768]):
235 | super().__init__(genotype, steps, multiplier, parse_method)
236 | C = 128
237 | self.preprocess0 = nn.Sequential(
238 | nn.ReLU(inplace=False),
239 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
240 | nn.BatchNorm2d(C, affine=True),
241 | nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
242 | )
243 | self.preprocess1 = nn.Sequential(
244 | nn.ReLU(inplace=False),
245 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False),
246 | nn.BatchNorm2d(C, affine=True),
247 | # nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
248 | )
249 | self.preprocess2 = nn.Sequential(
250 | nn.ReLU(inplace=False),
251 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False),
252 | nn.BatchNorm2d(C, affine=True))
253 |
254 | class AggregationCell_3(AggregationCell):
255 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [256,32,768]):
256 | super().__init__(genotype, steps, multiplier, parse_method)
257 | C = 128
258 | self.preprocess0 = nn.Sequential(
259 | nn.ReLU(inplace=False),
260 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
261 | nn.BatchNorm2d(C, affine=True),
262 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
263 | )
264 | self.preprocess1 = nn.Sequential(
265 | nn.ReLU(inplace=False),
266 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False),
267 | nn.BatchNorm2d(C, affine=True),
268 | )
269 | self.preprocess2 = nn.Sequential(
270 | nn.ReLU(inplace=False),
271 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False),
272 | nn.BatchNorm2d(C, affine=True),
273 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
274 |
275 | class AggregationCell_4(AggregationCell):
276 | def __init__(self, genotype, steps, multiplier, parse_method, C_in = [512,32,768]):
277 | super().__init__(genotype, steps, multiplier, parse_method)
278 | C = 128
279 | self.preprocess0 = nn.Sequential(
280 | nn.ReLU(inplace=False),
281 | nn.Conv2d(C_in[0], C*2, kernel_size=1, bias=False),
282 | nn.BatchNorm2d(C*2, affine=True),
283 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
284 | nn.ReLU(inplace=False),
285 | nn.Conv2d(C*2, C, kernel_size=1, bias=False),
286 | nn.BatchNorm2d(C, affine=True),
287 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
288 | )
289 | self.preprocess1 = nn.Sequential(
290 | nn.ReLU(inplace=False),
291 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False),
292 | nn.BatchNorm2d(C, affine=True),
293 | )
294 | self.preprocess2 = nn.Sequential(
295 | nn.ReLU(inplace=False),
296 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False),
297 | nn.BatchNorm2d(C, affine=True))
298 |
299 | # Take four inputs
300 | class GlobalAggregationCell(nn.Module):
301 | def __init__(self, genotype, steps, multiplier, parse_method):
302 | super(GlobalAggregationCell, self).__init__()
303 | C = 256
304 | self.preprocess0 = nn.Sequential(
305 | nn.ReLU(inplace=True),
306 | nn.Conv2d(768, C, kernel_size=1, bias=False),
307 | nn.BatchNorm2d(C, affine=True))
308 | self.preprocess1 = nn.Sequential(
309 | nn.ReLU(inplace=True),
310 | nn.Conv2d(768, C, kernel_size=1, bias=False),
311 | nn.BatchNorm2d(C, affine=True))
312 | self.preprocess2 = nn.Sequential(
313 | nn.ReLU(inplace=True),
314 | nn.Conv2d(768, C, kernel_size=1, bias=False),
315 | nn.BatchNorm2d(C, affine=True))
316 | self.preprocess3 = nn.Sequential(
317 | nn.ReLU(inplace=True),
318 | nn.Conv2d(768, C, kernel_size=1, bias=False),
319 | nn.BatchNorm2d(C, affine=True))
320 |
321 | op_names, indices = zip(*genotype.final_agg)
322 | concat = genotype.final_aggregation_concat
323 | self._steps = steps
324 | self._multiplier = multiplier
325 | self._compile(C, op_names, indices, concat)
326 |
327 | def _compile(self, C, op_names, indices, concat):
328 | assert len(op_names) == len(indices)
329 | self._concat = concat
330 | self.multiplier = len(concat)
331 | self._ops = nn.ModuleList()
332 | for name, index in zip(op_names, indices):
333 | stride = 1
334 | op = OPS[name](C, stride, True)
335 | self._ops += [op]
336 | self._indices = indices
337 |
338 | def forward(self, s0, s1, s2, s3, drop_prob):
339 | s0 = self.preprocess0(s0)
340 | s1 = self.preprocess1(s1)
341 | s2 = self.preprocess2(s2)
342 | s3 = self.preprocess3(s3)
343 |
344 | states = [s0, s1, s2, s3]
345 | for i in range(self._steps):
346 | h1 = states[self._indices[4*i]]
347 | h2 = states[self._indices[4*i+1]]
348 | h3 = states[self._indices[4*i+2]]
349 | h4 = states[self._indices[4*i+3]]
350 | op1 = self._ops[4*i]
351 | op2 = self._ops[4*i+1]
352 | op3 = self._ops[4*i+2]
353 | op4 = self._ops[4*i+3]
354 | h1 = op1(h1)
355 | h2 = op2(h2)
356 | h3 = op3(h3)
357 | h4 = op4(h4)
358 | if self.training and drop_prob > 0.:
359 | if not isinstance(op1, Identity):
360 | h1 = drop_path(h1, drop_prob)
361 | if not isinstance(op2, Identity):
362 | h2 = drop_path(h2, drop_prob)
363 | if not isinstance(op3, Identity):
364 | h3 = drop_path(h3, drop_prob)
365 | if not isinstance(op4, Identity):
366 | h4 = drop_path(h4, drop_prob)
367 | s = h1 + h2 + h3 + h4
368 | states += [s]
369 | return torch.cat([states[i] for i in self._concat], dim=1) # N,C,H, W
370 |
371 | # Take three inputs
372 | class Low_High_aggregation(AggregationCell):
373 |
374 | def __init__(self, genotype, steps, multiplier, parse_method, C_in=[64,64,128]):
375 | super().__init__(genotype, steps, multiplier, parse_method)
376 | C = 32
377 | self.preprocess0 = nn.Sequential(
378 | nn.ReLU(inplace=False),
379 | nn.Conv2d(C_in[0], C, kernel_size=1, bias=False),
380 | nn.BatchNorm2d(C, affine=True)
381 | )
382 | self.preprocess1 = nn.Sequential(
383 | nn.ReLU(inplace=False),
384 | nn.Conv2d(C_in[1], C, kernel_size=1, bias=False),
385 | nn.BatchNorm2d(C, affine=True),
386 | )
387 | self.preprocess2 = nn.Sequential(
388 | nn.ReLU(inplace=False),
389 | nn.Conv2d(C_in[2], C, kernel_size=1, bias=False),
390 | nn.BatchNorm2d(C, affine=True))
391 | op_names, indices = zip(*genotype.low_high_agg)
392 | concat = genotype.low_high_agg_concat
393 | self._compile(C, op_names, indices, concat)
394 |
395 |
396 |
397 | class NasFusionNet(nn.Module):
398 |     def __init__(self, fusion_cell_number=3, steps=8, multiplier=6, agg_steps=8, agg_multiplier=6,
399 |                  genotype=attention_snas_3_4_1, parse_method='darts', op_threshold=0.85, drop_path_prob=0):
400 | self.inplanes = 64
401 | super(NasFusionNet, self).__init__()
402 | self.drop_path_prob = 0
403 | self._multiplier = 6
404 | self.parse_method = parse_method
405 | self.op_threshold = op_threshold
406 | self._steps = steps
407 | # init the fusion cells
408 | self.MM_cells = nn.ModuleList()
409 |
410 | for i in range(fusion_cell_number):
411 | cell = FusionCell(genotype, i, steps, multiplier, parse_method)
412 | self.MM_cells += [cell]
413 |
414 | self.MS_cell_1 = AggregationCell_1(genotype, agg_steps, agg_multiplier, parse_method)
415 | self.MS_cell_2 = AggregationCell_2(genotype, agg_steps, agg_multiplier, parse_method)
416 | self.MS_cell_3 = AggregationCell_3(genotype, agg_steps, agg_multiplier, parse_method)
417 | self.MS_cell_4 = AggregationCell_4(genotype, agg_steps, agg_multiplier, parse_method)
418 |
419 | self.GA_cell = GlobalAggregationCell(genotype, agg_steps, agg_multiplier, parse_method)
420 | self.SR_cell_1 = Low_High_aggregation(genotype, 4, 4, parse_method, C_in = [128,128,256])
421 | self.SR_cell_2 = Low_High_aggregation(genotype, 4, 4, parse_method, C_in = [64, 64, 128])
422 |
423 | self.final_layer0 = nn.Sequential(
424 | nn.Conv2d(1536, 512, kernel_size=1), nn.BatchNorm2d(512, affine=True), nn.ReLU(inplace=True), # 256
425 | nn.Conv2d(512, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True),
426 | nn.Conv2d(256, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True),
427 | )
428 |
429 | self.final_layer1 = nn.Sequential(
430 | nn.ReLU(inplace=True),
431 | nn.Conv2d(256+128, 256, kernel_size=1), nn.BatchNorm2d(256, affine=True), nn.ReLU(inplace=True),
432 | nn.Conv2d(256, 128, kernel_size=1), nn.BatchNorm2d(128, affine=True), nn.ReLU(inplace=True),
433 | nn.Conv2d(128, 128, kernel_size=1), nn.BatchNorm2d(128, affine=True), nn.ReLU(inplace=True)
434 | )
435 |
436 | self.final_layer2 = nn.Sequential(
437 | nn.Conv2d(128+128, 64, kernel_size=1), nn.BatchNorm2d(64, affine=True), nn.ReLU(inplace=True),
438 | nn.Conv2d(64, 64, kernel_size=1), nn.BatchNorm2d(64, affine=True), nn.ReLU(inplace=True),
439 | # nn.Dropout2d(p=0.1),
440 | nn.Conv2d(64, 2, kernel_size=1)
441 | )
442 |
443 |
444 | def forward(self, h1, h2, h3, h4, h5, d0, d1, d2, d3, d4):
445 | # print(h2.shape,d1.shape, h5.shape,d4.shape)
446 |
447 | output1 = self.MM_cells[0](h2, h3, d1, d2, self.drop_path_prob)
448 | output2 = self.MM_cells[1](h3, h4, d2, d3, self.drop_path_prob)
449 | output3 = self.MM_cells[2](h4, h5, d3, d4, self.drop_path_prob)
450 |
451 | agg_features1 = self.MS_cell_1(output1, output2, output3,self.drop_path_prob)
452 | agg_features2 = self.MS_cell_2(h5, d4, output2, self.drop_path_prob)
453 | agg_features3 = self.MS_cell_3(h3, d2, output3, self.drop_path_prob)
454 | agg_features4 = self.MS_cell_4(h4, d3, output1, self.drop_path_prob)
455 |
456 | agg_features = self.GA_cell(agg_features1, agg_features2, agg_features3, agg_features4, self.drop_path_prob)
457 | predict_mask = self.final_layer0(agg_features) # c=256
458 |
459 | low_high_combined1 = self.SR_cell_1(h2, d1, predict_mask, self.drop_path_prob) # c==128
460 | predict_mask = torch.cat([predict_mask, low_high_combined1], dim=1) # 256 + 128
461 |
462 |         predict_mask = F.interpolate(predict_mask, scale_factor=2, mode='bilinear', align_corners=True)
463 | predict_mask = self.final_layer1(predict_mask) # 128
464 |
465 | low_high_combined2 = self.SR_cell_2(h1, d0, predict_mask, self.drop_path_prob) # 128
466 | predict_mask = torch.cat([predict_mask, low_high_combined2], dim=1)
467 |         predict_mask = F.interpolate(predict_mask, scale_factor=2, mode='bilinear', align_corners=True)
468 | predict_mask = self.final_layer2(predict_mask)
469 |
470 |         return torch.sigmoid(predict_mask)
471 |
472 | def init_weights(self):
473 | logger.info('=> NAS Fusion model init weights from normal distribution')
474 | for m in self.modules():
475 | if isinstance(m, nn.Conv2d):
476 | nn.init.normal_(m.weight, std=0.001)
477 | for name, _ in m.named_parameters():
478 | if name in ['bias']:
479 | nn.init.constant_(m.bias, 0)
480 | elif isinstance(m, nn.BatchNorm2d):
481 | nn.init.constant_(m.weight, 1)
482 | nn.init.constant_(m.bias, 0)
483 | elif isinstance(m, nn.ConvTranspose2d):
484 | nn.init.normal_(m.weight, std=0.001)
485 | for name, _ in m.named_parameters():
486 | if name in ['bias']:
487 | nn.init.constant_(m.bias, 0)
488 |
489 |
490 |
491 |
492 | class NasFusionNet_pre(nn.Module):
493 |     def __init__(self, fusion_cell_number=3, steps=8, multiplier=6, agg_steps=8, agg_multiplier=6,
494 |                  genotype=attention_snas_3_4_1, parse_method='darts', op_threshold=0.85, drop_path_prob=0):
495 | self.inplanes = 64
496 | super(NasFusionNet_pre, self).__init__()
497 | self.drop_path_prob = 0
498 | self._multiplier = 6
499 | self.parse_method = parse_method
500 | self.op_threshold = op_threshold
501 | self._steps = steps
502 | # init the fusion cells
503 | self.MM_cells = nn.ModuleList()
504 |
505 | for i in range(fusion_cell_number):
506 | cell = FusionCell(genotype, i, steps, multiplier, parse_method)
507 | self.MM_cells += [cell]
508 |
509 | self.MS_cell_1 = AggregationCell_1(genotype, agg_steps, agg_multiplier, parse_method)
510 | self.MS_cell_2 = AggregationCell_2(genotype, agg_steps, agg_multiplier, parse_method)
511 | self.MS_cell_3 = AggregationCell_3(genotype, agg_steps, agg_multiplier, parse_method)
512 | self.MS_cell_4 = AggregationCell_4(genotype, agg_steps, agg_multiplier, parse_method)
513 |
514 | self.GA_cell = GlobalAggregationCell(genotype, agg_steps, agg_multiplier, parse_method)
515 | ######## for pretrain
516 | self.class_head = nn.Sequential(
517 | nn.AvgPool2d((56, 56)))
518 | self.classifier = nn.Linear(1536, 1000)
519 |
520 |
521 | def forward(self, h1, h2, h3, h4, h5, d0, d1, d2, d3, d4):
522 |
523 | output1 = self.MM_cells[0](h2, h3, d1, d2, self.drop_path_prob)
524 | output2 = self.MM_cells[1](h3, h4, d2, d3, self.drop_path_prob)
525 | output3 = self.MM_cells[2](h4, h5, d3, d4, self.drop_path_prob)
526 |
527 | agg_features1 = self.MS_cell_1(output1, output2, output3,self.drop_path_prob)
528 | agg_features2 = self.MS_cell_2(h5, d4, output2, self.drop_path_prob)
529 | agg_features3 = self.MS_cell_3(h3, d2, output3, self.drop_path_prob)
530 | agg_features4 = self.MS_cell_4(h4, d3, output1, self.drop_path_prob)
531 |
532 | agg_features = self.GA_cell(agg_features1, agg_features2, agg_features3, agg_features4, self.drop_path_prob)
533 |
534 | ######## for pretrain
535 | # print(agg_features.shape)
536 | class_feature = self.class_head(agg_features).view(agg_features.size(0), -1)
537 | logits = self.classifier(class_feature)
538 | return logits
539 |
540 | def init_weights(self):
541 | logger.info('=> init weights from normal distribution')
542 | for m in self.modules():
543 | if isinstance(m, nn.Conv2d):
544 | nn.init.normal_(m.weight, std=0.001)
545 | for name, _ in m.named_parameters():
546 | if name in ['bias']:
547 | nn.init.constant_(m.bias, 0)
548 | elif isinstance(m, nn.BatchNorm2d):
549 | nn.init.constant_(m.weight, 1)
550 | nn.init.constant_(m.bias, 0)
551 | elif isinstance(m, nn.ConvTranspose2d):
552 | nn.init.normal_(m.weight, std=0.001)
553 | for name, _ in m.named_parameters():
554 | if name in ['bias']:
555 | nn.init.constant_(m.bias, 0)
556 |
557 |
--------------------------------------------------------------------------------
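A shape sketch for `NasFusionNet` alone, with random tensors standing in for the backbone features of a 256x256 input (channel/size layout taken from the comments in `FusionCell` and `model_rgb.py`; illustrative only):

```python
# Feed NasFusionNet with randomly generated RGB/depth features (illustration only).
import torch
from model_fusion import NasFusionNet

net = NasFusionNet()
net.eval()
# RGB features h1..h5 at strides 2, 4, 8, 16, 32
h1, h2, h3, h4, h5 = (torch.randn(1, 64, 128, 128), torch.randn(1, 128, 64, 64),
                      torch.randn(1, 256, 32, 32), torch.randn(1, 512, 16, 16),
                      torch.randn(1, 512, 8, 8))
# depth features d0..d4 as produced by DepthNet
d0, d1, d2, d3, d4 = (torch.randn(1, 64, 128, 128), torch.randn(1, 128, 64, 64),
                      torch.randn(1, 32, 64, 64), torch.randn(1, 32, 64, 64),
                      torch.randn(1, 128, 64, 64))
with torch.no_grad():
    mask = net(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
print(mask.shape)   # torch.Size([1, 2, 256, 256]): two-channel saliency prediction at full resolution
```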
/models/model_rgb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from dsam import Adaptive_DSAM as DepthAttention
5 |
6 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
7 | """Make a 2D bilinear kernel suitable for upsampling"""
8 | factor = (kernel_size + 1) // 2
9 | if kernel_size % 2 == 1:
10 | center = factor - 1
11 | else:
12 | center = factor - 0.5
13 | og = np.ogrid[:kernel_size, :kernel_size]
14 | filt = (1 - abs(og[0] - center) / factor) * \
15 | (1 - abs(og[1] - center) / factor)
16 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
17 | dtype=np.float64)
18 | weight[range(in_channels), range(out_channels), :, :] = filt
19 | return torch.from_numpy(weight).float()
20 |
21 | #################################### Rgb Network #####################################
22 |
23 | class RgbNet(nn.Module):
24 | def __init__(self):
25 | super(RgbNet, self).__init__()
26 |
27 | # original image's size = 256*256*3
28 | # conv1
29 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
30 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
31 | self.relu1_1 = nn.ReLU(inplace=True)
32 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
33 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
34 | self.relu1_2 = nn.ReLU(inplace=True)
35 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers
36 | self.depth_att1 = DepthAttention(64)
37 |
38 | # conv2
39 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
40 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
41 | self.relu2_1 = nn.ReLU(inplace=True)
42 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
43 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
44 | self.relu2_2 = nn.ReLU(inplace=True)
45 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers
46 | self.depth_att2 = DepthAttention(128)
47 | # conv3
48 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
49 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
50 | self.relu3_1 = nn.ReLU(inplace=True)
51 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
52 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
53 | self.relu3_2 = nn.ReLU(inplace=True)
54 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
55 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
56 | self.relu3_3 = nn.ReLU(inplace=True)
57 | self.conv3_4 = nn.Conv2d(256, 256, 3, padding=1)
58 | self.bn3_4 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
59 | self.relu3_4 = nn.ReLU(inplace=True)
60 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers
61 | self.depth_att3 = DepthAttention(256)
62 |
63 | # conv4
64 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
65 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
66 | self.relu4_1 = nn.ReLU(inplace=True)
67 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
68 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
69 | self.relu4_2 = nn.ReLU(inplace=True)
70 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
71 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
72 | self.relu4_3 = nn.ReLU(inplace=True)
73 | self.conv4_4 = nn.Conv2d(512, 512, 3, padding=1)
74 | self.bn4_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
75 | self.relu4_4 = nn.ReLU(inplace=True)
76 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers
77 | self.depth_att4 = DepthAttention(512)
78 |
79 | # conv5
80 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
81 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
82 | self.relu5_1 = nn.ReLU(inplace=True)
83 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
84 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
85 | self.relu5_2 = nn.ReLU(inplace=True)
86 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
87 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
88 | self.relu5_3 = nn.ReLU(inplace=True)
89 | self.conv5_4 = nn.Conv2d(512, 512, 3, padding=1)
90 | self.bn5_4 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
91 | self.relu5_4 = nn.ReLU(inplace=True)
92 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 4 layers
93 | self.depth_att5 = DepthAttention(512)
94 |
95 | self._initialize_weights()
96 |
97 | def _initialize_weights(self):
98 | for m in self.modules():
99 | if isinstance(m, nn.Conv2d):
100 | # m.weight.data.zero_()
101 |                 nn.init.normal_(m.weight, std=0.01)
102 | if m.bias is not None:
103 | m.bias.data.zero_()
104 | if isinstance(m, nn.ConvTranspose2d):
105 | assert m.kernel_size[0] == m.kernel_size[1]
106 | initial_weight = get_upsampling_weight(m.in_channels, m.out_channels, m.kernel_size[0])
107 | m.weight.data.copy_(initial_weight)
108 |
109 |
110 | def forward(self, x, bins, gumbel):
111 |
112 | b=bins
113 | h = x
114 | h = self.relu1_1(self.bn1_1(self.conv1_1(h)))
115 | h = self.conv1_2(h)
116 | h = self.depth_att1(h, b, gumbel=gumbel )
117 | h = self.relu1_2(self.bn1_2(h))
118 | h1 = self.pool1(h) # (128x128)*64
119 | b1 = self.pool1(b)
120 |
121 | h = self.relu2_1(self.bn2_1(self.conv2_1(h1)))
122 | h=self.conv2_2(h)
123 | h = self.depth_att2(h, b1, gumbel=gumbel)
124 | h = self.relu2_2(self.bn2_2(h))
125 | h2 = self.pool2(h) # (64x64)*128
126 | b2 = self.pool2(b1)
127 |
128 | h = self.relu3_1(self.bn3_1(self.conv3_1(h2)))
129 | h = self.relu3_2(self.bn3_2(self.conv3_2(h)))
130 | h = self.relu3_3(self.bn3_3(self.conv3_3(h)))
131 | h = self.conv3_4(h)
132 | h = self.depth_att3(h, b2, gumbel=gumbel)
133 | h = self.relu3_4(self.bn3_4(h))
134 | h3 = self.pool3(h)# (32x32)*256
135 | b3 = self.pool3(b2)
136 |
137 | h = self.relu4_1(self.bn4_1(self.conv4_1(h3)))
138 | h = self.relu4_2(self.bn4_2(self.conv4_2(h)))
139 | h = self.relu4_3(self.bn4_3(self.conv4_3(h)))
140 | h = self.conv4_4(h)
141 | h = self.depth_att4(h,b3, gumbel=gumbel)
142 | h = self.relu4_4(self.bn4_4(h))
143 | h4 = self.pool4(h)# (16x16)*512
144 | b4 = self.pool4(b3)
145 |
146 |
147 | h = self.relu5_1(self.bn5_1(self.conv5_1(h4)))
148 | h = self.relu5_2(self.bn5_2(self.conv5_2(h)))
149 | h = self.relu5_3(self.bn5_3(self.conv5_3(h)))
150 | h = self.conv5_4(h)
151 | h = self.depth_att5(h,b4, gumbel=gumbel)
152 | h = self.relu5_4(self.bn5_4(h))
153 | h5 = self.pool5(h)#(8x8)*512
154 |
155 | return h1,h2,h3,h4,h5
156 |
157 |
158 |
159 | def copy_params_from_vgg19_bn(self, vgg19_bn):
160 | features = [
161 | self.conv1_1, self.bn1_1, self.relu1_1,
162 | self.conv1_2, self.bn1_2, self.relu1_2,
163 | self.pool1,
164 | self.conv2_1, self.bn2_1, self.relu2_1,
165 | self.conv2_2, self.bn2_2, self.relu2_2,
166 | self.pool2,
167 | self.conv3_1, self.bn3_1, self.relu3_1,
168 | self.conv3_2, self.bn3_2, self.relu3_2,
169 | self.conv3_3, self.bn3_3, self.relu3_3,
170 | self.conv3_4, self.bn3_4, self.relu3_4,
171 | self.pool3,
172 | self.conv4_1, self.bn4_1, self.relu4_1,
173 | self.conv4_2, self.bn4_2, self.relu4_2,
174 | self.conv4_3, self.bn4_3, self.relu4_3,
175 | self.conv4_4, self.bn4_4, self.relu4_4,
176 | self.pool4,
177 | self.conv5_1, self.bn5_1, self.relu5_1,
178 | self.conv5_2, self.bn5_2, self.relu5_2,
179 | self.conv5_3, self.bn5_3, self.relu5_3,
180 | self.conv5_4, self.bn5_4, self.relu5_4,
181 | ]
182 | for l1, l2 in zip(vgg19_bn.features, features):
183 | if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
184 | assert l1.weight.size() == l2.weight.size()
185 | assert l1.bias.size() == l2.bias.size()
186 | l2.weight.data = l1.weight.data
187 | l2.bias.data = l1.bias.data
188 | if isinstance(l1, nn.BatchNorm2d) and isinstance(l2, nn.BatchNorm2d):
189 | assert l1.weight.size() == l2.weight.size()
190 | assert l1.bias.size() == l2.bias.size()
191 | l2.weight.data = l1.weight.data
192 | l2.bias.data = l1.bias.data
--------------------------------------------------------------------------------
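`copy_params_from_vgg19_bn` expects a torchvision-style VGG19-BN; a minimal initialization sketch (illustrative, not necessarily how `main.py` wires it up):

```python
# Initialize the RGB backbone from torchvision's VGG19-BN (illustration only).
import torchvision
from model_rgb import RgbNet

rgb_net = RgbNet()
vgg = torchvision.models.vgg19_bn(pretrained=True)   # downloads ImageNet weights
rgb_net.copy_params_from_vgg19_bn(vgg)                # copies conv/bn weights layer by layer; pools and DSAM modules are untouched
```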
/models/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | OPS = {
5 | 'none' : lambda C, stride, affine: Zero(stride),
6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
9 | 'conv_3x3' : lambda C, stride, affine : ReLUConvBN(C, C, 3, stride, 1, affine=affine),
10 | 'conv_1x1' : lambda C, stride, affine : ReLUConvBN(C, C, 1, stride, 0, affine=affine),
11 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
12 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
13 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
14 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
15 | 'dil_conv_3x3_2dil' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
16 | 'dil_conv_3x3_4dil' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 4, affine=affine),
17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
19 | nn.ReLU(inplace=False),
20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
22 | nn.BatchNorm2d(C, affine=affine)
23 | ),
24 | 'spatial_attention': lambda C, stride, affine : SpatialAttentionLayer(C, C, 8, stride, affine),
25 | 'channel_attention': lambda C, stride, affine : ChannelAttentionLayer(C, C, 8, stride, affine)
26 | }
27 |
28 | class Zero(nn.Module):
29 |
30 | def __init__(self, stride):
31 | super(Zero, self).__init__()
32 | self.stride = stride
33 |
34 | def forward(self, x):
35 | if self.stride == 1:
36 | return x.mul(0.)
37 | return x[:,:,::self.stride,::self.stride].mul(0.)
38 |
39 |
40 |
41 | class SpatialAttentionLayer(nn.Module):
42 | def __init__(self, C_in, C_out, reduction=16, stride=1, affine=True, BN=nn.BatchNorm2d):
43 | super(SpatialAttentionLayer, self).__init__()
44 | self.stride = stride
45 | if stride == 1:
46 | self.fc = nn.Sequential(
47 | nn.Conv2d(C_in, C_in // reduction, kernel_size=3, stride=1, padding=1, bias=False),
48 | BN(C_in // reduction, affine=affine),
49 | nn.ReLU(inplace=False),
50 | nn.Conv2d(C_in // reduction, 1,kernel_size=3, stride=1, padding=1, bias=False),
51 | nn.Sigmoid()
52 | )
53 | else:
54 | self.fc = nn.Sequential(
55 | nn.Conv2d(C_in, C_in // reduction, kernel_size=3, stride=2, padding=1, bias=False),
56 | BN(C_in // reduction, affine=affine),
57 | nn.ReLU(inplace=False),
58 | nn.Conv2d(C_in // reduction, 1, kernel_size=3, stride=1, padding=1, bias=False),
59 | nn.Sigmoid()
60 | )
61 | self.reduce_map = nn.Sequential(
62 | nn.ReLU(inplace=False),
63 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=2, padding=0, bias=False),
64 | BN(C_out, affine=affine)
65 | )
66 |
67 | def forward(self, x):
68 | y = self.fc(x)
69 | if self.stride == 2:
70 | x = self.reduce_map(x)
71 | return x * y
72 |
73 |
74 | ## Channel Attention (CA) Layer
75 | class ChannelAttentionLayer(nn.Module):
76 | def __init__(self, C_in, C_out, reduction=16, stride=1, affine=True, BN=nn.BatchNorm2d):
77 | super(ChannelAttentionLayer, self).__init__()
78 | # global average pooling: feature --> point
79 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
80 | self.stride = stride
81 | # feature channel downscale and upscale --> channel weight
82 | if stride == 1:
83 | self.conv_du = nn.Sequential(
84 | nn.Conv2d(C_in, C_in // reduction, 1, padding=0, bias=False),
85 | nn.ReLU(inplace=False),
86 | nn.Conv2d(C_in // reduction, C_out, 1, padding=0, bias=False),
87 | nn.Sigmoid()
88 | )
89 | else:
90 | self.conv_du = nn.Sequential(
91 | nn.Conv2d(C_in, C_in // reduction, kernel_size=1, stride=2, padding=0, bias=False),
92 | nn.ReLU(inplace=False),
93 | nn.Conv2d(C_in // reduction, C_out, 1, padding=0, bias=False),
94 | nn.Sigmoid()
95 | )
96 | self.reduce_map = nn.Sequential(
97 | nn.ReLU(inplace=False),
98 | nn.Conv2d(C_in, C_out, kernel_size=1, stride=2, padding=0, bias=False),
99 | BN(C_out, affine=affine)
100 | )
101 |
102 | def forward(self, x):
103 | if self.stride == 2:
104 | x = self.reduce_map(x)
105 | y = self.avg_pool(x)
106 | y = self.conv_du(y)
107 | return x * y
108 |
109 | class ReLUConvBN(nn.Module):
110 | """
111 |     ReLU -> Conv2d -> BatchNorm2d
112 | """
113 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
114 | super(ReLUConvBN, self).__init__()
115 | self.op = nn.Sequential(
116 | nn.ReLU(inplace=False),
117 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
118 | nn.BatchNorm2d(C_out, affine=affine)
119 | )
120 |
121 | def forward(self, x):
122 | return self.op(x)
123 |
124 | class DilConv(nn.Module):
125 | """
126 |     Dilated separable convolution: ReLU -> depthwise dilated Conv2d -> 1x1 Conv2d -> BatchNorm2d
127 | """
128 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
129 | super(DilConv, self).__init__()
130 | self.op = nn.Sequential(
131 | nn.ReLU(inplace=False),
132 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
133 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
134 | nn.BatchNorm2d(C_out, affine=affine),
135 | )
136 |
137 | def forward(self, x):
138 | return self.op(x)
139 |
140 |
141 | class SepConv(nn.Module):
142 |
143 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
144 | super(SepConv, self).__init__()
145 | self.op = nn.Sequential(
146 | nn.ReLU(inplace=False),
147 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
148 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
149 | nn.BatchNorm2d(C_in, affine=affine),
150 | nn.ReLU(inplace=False),
151 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
152 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
153 | nn.BatchNorm2d(C_out, affine=affine),
154 | )
155 |
156 | def forward(self, x):
157 | return self.op(x)
158 |
159 |
160 | class Identity(nn.Module):
161 |
162 | def __init__(self):
163 | super(Identity, self).__init__()
164 |
165 | def forward(self, x):
166 | return x
167 |
168 |
169 | class Zero(nn.Module):
170 |
171 | def __init__(self, stride):
172 | super(Zero, self).__init__()
173 | self.stride = stride
174 |
175 | def forward(self, x):
176 | if self.stride == 1:
177 | return x.mul(0.)
178 | return x[:,:,::self.stride,::self.stride].mul(0.) # N * C * W * H
179 |
180 |
181 | class FactorizedReduce(nn.Module):
182 |
183 | def __init__(self, C_in, C_out, affine=True):
184 | super(FactorizedReduce, self).__init__()
185 | assert C_out % 2 == 0
186 | self.relu = nn.ReLU(inplace=False)
187 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
188 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
189 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
190 |
191 | def forward(self, x):
192 | x = self.relu(x)
193 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
194 | out = self.bn(out)
195 | return out
196 |
--------------------------------------------------------------------------------
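Every entry in `OPS` is a factory `(C, stride, affine) -> nn.Module` that preserves the channel count at `stride=1`, which is what lets the searched cells mix operations freely. A quick sanity sketch (illustrative only):

```python
# All candidate operations keep the feature map shape at stride 1 (illustration only).
import torch
from operations import OPS

x = torch.randn(1, 32, 16, 16)
for name in ('conv_3x3', 'sep_conv_3x3', 'dil_conv_3x3', 'max_pool_3x3',
             'spatial_attention', 'channel_attention', 'skip_connect'):
    op = OPS[name](32, 1, True)        # C=32, stride=1, affine=True
    print(name, tuple(op(x).shape))    # (1, 32, 16, 16) for every op
```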
/training.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.autograd import Variable
3 | import torch.nn.functional as F
4 | import torch
5 | import torch.optim as optim
6 | from dataset_loader import MyTestData
7 | import logging
8 | from tqdm import tqdm
9 | import time
10 | from utils.functions import *
11 | from utils.evaluateFM import get_FM
12 | from loss import cross_entropy2d, iou, BinaryDiceLoss
13 | running_loss_final = 0
14 | iou_final = 0
15 | aux_final = 0
16 |
17 | class Trainer(object):
18 |
19 | def __init__(self, cuda, cfg, model_depth, model_rgb, model_fusion, train_loader, test_data_list, test_data_root, salmap_root, outpath, logging, writer, max_epoch):
20 | self.cuda = cuda
21 | self.model_depth = model_depth
22 | self.model_rgb = model_rgb
23 | self.model_fusion = model_fusion
24 |
25 | self.optim_depth = optim.SGD(self.model_depth.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay'])
26 | self.optim_rgb = optim.SGD(self.model_rgb.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay'])
27 | self.optim_fusion = optim.SGD(self.model_fusion.parameters(), lr=cfg[1]['lr'], momentum=cfg[1]['momentum'], weight_decay=cfg[1]['weight_decay'])
28 |
29 | self.train_loader = train_loader
30 | self.test_data_list = test_data_list
31 | self.test_data_root = test_data_root
32 | self.salmap_root = salmap_root
33 | self.test_loaders={}
34 | self.best_f = {}
35 | self.best_m = {}
36 | for data_i in self.test_data_list:
37 | MapRoot = self.salmap_root + data_i
38 | TestRoot = self.test_data_root + data_i
39 | if not os.path.exists(MapRoot):
40 | os.mkdir(MapRoot)
41 | loader_i = torch.utils.data.DataLoader(MyTestData(TestRoot, transform=True),
42 | batch_size = 1, shuffle=True, num_workers=0, pin_memory=True)
43 | self.test_loaders[data_i] = loader_i
44 | self.best_f[data_i] = -1
45 | self.best_m[data_i] = 10000
46 |
47 | self.epoch = 0
48 | self.iteration = 0
49 | self.max_iter = 0
50 | self.snapshot = cfg[1]['spshot']
51 | self.outpath = outpath
52 | self.sshow = cfg[1]['sshow']
53 | self.logging = logging
54 | self.writer = writer
55 | self.max_epoch = max_epoch
56 | self.base_lr = cfg[1]['lr']
57 |
58 | self.dice = BinaryDiceLoss()
59 |
60 |
61 | def train_epoch(self):
62 | self.logging.info("length trainloader: %s", len(self.train_loader))
63 | self.logging.info("current_lr is : %s", self.optim_fusion.param_groups[0]['lr'])
64 | for batch_idx, (img, mask, depth, bins) in enumerate(tqdm(self.train_loader)):
65 | ########## for debug
66 | # if batch_idx % 10==0 and batch_idx>10:
67 | # self.save_test(iteration)
68 | iteration = batch_idx + self.epoch * len(self.train_loader)
69 |
70 | if self.iteration != 0 and (iteration - 1) != self.iteration:
71 | continue # for resuming
72 | self.iteration = iteration
73 |
74 | if self.cuda:
75 | img, mask, depth, bins = img.cuda(), mask.cuda(), depth.cuda(), bins.cuda()
76 |             img, mask, depth, bins = Variable(img), Variable(mask), Variable(depth), Variable(bins)
77 | # print(img.size())
78 |             n, c, h, w = img.size() # batch_size, channels, height, width
79 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1)
80 |
81 | self.optim_depth.zero_grad()
82 | self.optim_rgb.zero_grad()
83 | self.optim_fusion.zero_grad()
84 |
85 | global running_loss_final ,iou_final, aux_final
86 |
87 | d0, d1, d2, d3, d4 = self.model_depth(depth)
88 | h1, h2, h3, h4, h5 = self.model_rgb(img, bins, gumbel=True)
89 | predict_mask = self.model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
90 |
91 | ce_loss = cross_entropy2d(predict_mask, mask, size_average=False)
92 | iou_loss = torch.zeros(1)
93 | aux_ce_loss = torch.zeros(1)
94 | # iou_loss = iou(predict_mask, mask,size_average=False ) * 0.2
95 | # iou_loss = self.dice(predict_mask, mask)
96 | loss = ce_loss #+ iou_loss + aux_ce_loss
97 |
98 | running_loss_final += ce_loss.item()
99 | iou_final += iou_loss.item()
100 | aux_final += aux_ce_loss.item()
101 |
102 | if iteration % self.sshow == (self.sshow - 1):
103 | self.logging.info('\n [%3d, %6d, RGB-D Net ce_loss: %.3f aux_loss: %.3f iou_loss: %.3f]' % (
104 | self.epoch + 1, iteration + 1, running_loss_final / (n * self.sshow), aux_final / (n * self.sshow), iou_final / (n * self.sshow)))
105 |
106 | self.writer.add_scalar('train/iou_loss', iou_final / (n * self.sshow), iteration + 1)
107 | self.writer.add_scalar('train/aux_loss', aux_final / (n * self.sshow), iteration + 1)
108 |
109 | self.writer.add_scalar('train/lr', self.optim_fusion.param_groups[0]['lr'] , iteration + 1)
110 | self.writer.add_scalar('train/iter_ce_loss', running_loss_final / (n * self.sshow), iteration + 1)
111 |
112 | self.writer.add_scalar('train/epoch_ce_loss', running_loss_final / (n * self.sshow), self.epoch + 1)
113 | running_loss_final = 0.0
114 | iou_final= 0.0
115 | aux_final=0.0
116 |
117 | loss.backward()
118 | self.optim_depth.step()
119 | self.optim_rgb.step()
120 | self.optim_fusion.step()
121 |
122 | if iteration <= 200000:
123 | if iteration % self.snapshot == (self.snapshot - 1):
124 | self.save_test(iteration)
125 | else:
126 | if iteration % 10000 == (10000 - 1):
127 | self.save_test(iteration)
128 |
129 | def test(self,iteration, test_data):
130 | res = []
131 | MapRoot = self.salmap_root + test_data
132 | for id, (data, depth, bins, img_name, img_size) in enumerate(self.test_loaders[test_data]):
133 |             # print('testing batch %d' % id)
134 | inputs = Variable(data).cuda()
135 | depth = Variable(depth).cuda()
136 | bins = Variable(bins).cuda()
137 | n, c, h, w = inputs.size()
138 | depth = depth.view(n, 1, h, w).repeat(1, c, 1, 1)
139 | torch.cuda.synchronize()
140 | start = time.time()
141 | with torch.no_grad():
142 | h1, h2, h3, h4, h5 = self.model_rgb(inputs, bins, gumbel=False)
143 | d0, d1, d2, d3, d4 = self.model_depth(depth)
144 | predict_mask = self.model_fusion(h1, h2, h3, h4, h5, d0, d1, d2, d3, d4)
145 | torch.cuda.synchronize()
146 | end = time.time()
147 |
148 | res.append(end - start)
149 | outputs_all = F.softmax(predict_mask, dim=1)
150 | outputs = outputs_all[0][1]
151 | # import pdb; pdb.set_trace()
152 | outputs = outputs.cpu().data.resize_(h, w)
153 |
154 | imsave(os.path.join(MapRoot,img_name[0] + '.png'), outputs, img_size)
155 | time_sum = 0
156 | for i in res:
157 | time_sum += i
158 | self.logging.info("FPS: %f" % (1.0 / (time_sum / len(res))))
159 | # -------------------------- validation --------------------------- #
160 | torch.cuda.empty_cache()
161 | F_measure, mae = get_FM(salpath=MapRoot+'/', gtpath=self.test_data_root + test_data+'/test_masks/')
162 |
163 | self.writer.add_scalar('test/'+ test_data +'_F_measure', F_measure, iteration +1)
164 | self.writer.add_scalar('test/'+ test_data +'_MAE', mae, iteration+1)
165 |
166 | self.logging.info(MapRoot.split('/')[-1] + ' F_measure: %f' , F_measure)
167 | self.logging.info(MapRoot.split('/')[-1] + ' MAE: %f', mae)
168 | print('the testing process has finished!')
169 |
170 | return F_measure, mae
171 |
172 |
173 | def save_test(self, iteration, epoch = -1):
174 | self.save(iteration, epoch)
175 | for data_i in self.test_data_list:
176 | f, m = self.test(iteration, data_i)
177 |
178 | self.best_f[data_i] = max(f, self.best_f[data_i])
179 | self.best_m[data_i] = min(m, self.best_m[data_i])
180 | self.writer.add_scalar('best/'+ data_i +'_MAE', self.best_m[data_i], iteration)
181 | self.writer.add_scalar('best/'+ data_i +'_Fmeasure', self.best_f[data_i], iteration)
182 |
183 | def save(self, iteration=-1, epoch=-1):
184 | savename_depth = ('%s/depth_snapshot_iter_%d.pth' % (self.outpath, iteration + 1))
185 | torch.save(self.model_depth.state_dict(), savename_depth)
186 | self.logging.info('save: (snapshot: %d)' % (iteration + 1))
187 |
188 | savename_rgb = ('%s/rgb_snapshot_iter_%d.pth' % (self.outpath, iteration + 1))
189 | torch.save(self.model_rgb.state_dict(), savename_rgb)
190 | self.logging.info('save: (snapshot: %d)' % (iteration + 1))
191 |
192 | savename_fusion = ('%s/fusion_snapshot_iter_%d.pth' % (self.outpath, iteration + 1))
193 | torch.save(self.model_fusion.state_dict(), savename_fusion)
194 | self.logging.info('save: (snapshot: %d)' % (iteration + 1))
195 |
196 |
197 | if epoch > 0 :
198 | savename_depth = ('%s/depth_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1))
199 | torch.save(self.model_depth.state_dict(), savename_depth)
200 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1))
201 |
202 | savename_rgb = ('%s/rgb_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1))
203 | torch.save(self.model_rgb.state_dict(), savename_rgb)
204 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1))
205 |
206 | savename_fusion = ('%s/fusion_snapshot_epoch_%d.pth' % (self.outpath, epoch + 1))
207 | torch.save(self.model_fusion.state_dict(), savename_fusion)
208 | self.logging.info('save: (snapshot: %d)' % (self.epoch + 1))
209 |
210 |
211 |
212 | def adjust_learning_rate(self, epoch):
213 |         """Sets the learning rate to the initial LR decayed by 10 after 20 and 40 epochs"""
214 | lr = self.base_lr
215 | if epoch >= 20:
216 | lr = 0.1 * lr
217 | if epoch >= 40:
218 | lr = 0.1 * lr
219 |
220 | self.optim_depth.param_groups[0]['lr']= lr
221 | self.optim_rgb.param_groups[0]['lr']= lr
222 | self.optim_fusion.param_groups[0]['lr']= lr
223 |
224 | def train(self):
225 | max_epoch = self.max_epoch
226 | print ("max_epoch", max_epoch)
227 | self.max_iter = int(math.ceil(len(self.train_loader) * self.max_epoch))
228 | print ("max_iter", self.max_iter)
229 |
230 | for epoch in range(max_epoch):
231 | # self.adjust_learning_rate(epoch)
232 | self.epoch = epoch
233 | self.train_epoch()
234 | # save each epoch.
235 | self.save_test(self.iteration, epoch = self.epoch )
236 |
237 | self.logging.info('all training process finished')
238 | print(self.best_f)
239 | print(self.best_m)
240 |
241 |
--------------------------------------------------------------------------------
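`Trainer` only reads a handful of hyper-parameters from `cfg[1]`; a hypothetical config sketch showing the expected keys (the values below are placeholders, not the ones used in `main.py`):

```python
# Hypothetical cfg layout consumed by Trainer (placeholder values for illustration).
cfg = {1: {'lr': 1e-3,             # base SGD learning rate shared by the three optimizers
           'momentum': 0.9,
           'weight_decay': 5e-4,
           'spshot': 10000,        # snapshot + run the test sets every `spshot` iterations
           'sshow': 20}}           # log/average running losses every `sshow` iterations
```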
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('utils')
3 |
4 |
5 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/evaluateFM.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/evaluateFM.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/evaluateFM.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/evaluateFM.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/functions.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/functions.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/functions.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunpeng1996/DSA2F/22ba0e20ef1c5ace50b748dcfe1b94c9c4d11a87/utils/__pycache__/functions.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/evaluateFM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | # import cv2
4 | import matplotlib.pyplot as plt
5 | import PIL.Image as Image
6 | def get_FM(salpath, gtpath):
7 |
8 | gtdir = gtpath
9 | saldir = salpath
10 |
11 | files = os.listdir(gtdir)
12 | eps = np.finfo(float).eps
13 |
14 | m_pres = np.zeros(21)
15 | m_recs = np.zeros(21)
16 | m_fms = np.zeros(21)
17 | m_thfm = 0
18 | m_mea = 0
19 | it = 1
20 | for i, name in enumerate(files):
21 | if not os.path.exists(gtdir + name):
22 | print(gtdir + name, 'does not exist'); continue  # skip samples with no ground-truth file
23 | gt = Image.open(gtdir + name)
24 | gt = np.array(gt, dtype=np.uint8)
25 |
26 |
27 | mask=Image.open(saldir+name).convert('L')
28 | mask=mask.resize((np.shape(gt)[1],np.shape(gt)[0]))
29 | mask = np.array(mask, dtype=float)  # np.float is removed in newer numpy releases
30 | # salmap = cv2.resize(salmap,(W,H))
31 |
32 | if len(mask.shape) != 2:
33 | mask = mask[:, :, 0]
34 | mask = (mask - mask.min()) / (mask.max() - mask.min() + eps)
35 | gt[gt != 0] = 1
36 | pres = []
37 | recs = []
38 | fms = []
39 | mea = np.abs(gt-mask).mean()
40 | # threshold fm
41 | binary = np.zeros(mask.shape)
42 | th = 2*mask.mean()
43 | if th > 1:
44 | th = 1
45 | binary[mask >= th] = 1
46 | sb = (binary * gt).sum()
47 | pre = sb / (binary.sum()+eps)
48 | rec = sb / (gt.sum()+eps)
49 | thfm = 1.3 * pre * rec / (0.3 * pre + rec + eps)
50 | for th in np.linspace(0, 1, 21):
51 | binary = np.zeros(mask.shape)
52 | binary[ mask >= th] = 1
53 | pre = (binary * gt).sum() / (binary.sum()+eps)
54 | rec = (binary * gt).sum() / (gt.sum()+ eps)
55 | fm = 1.3 * pre * rec / (0.3*pre + rec + eps)
56 | pres.append(pre)
57 | recs.append(rec)
58 | fms.append(fm)
59 | fms = np.array(fms)
60 | pres = np.array(pres)
61 | recs = np.array(recs)
62 | m_mea = m_mea * (it-1) / it + mea / it
63 | m_fms = m_fms * (it - 1) / it + fms / it
64 | m_recs = m_recs * (it - 1) / it + recs / it
65 | m_pres = m_pres * (it - 1) / it + pres / it
66 | m_thfm = m_thfm * (it - 1) / it + thfm / it
67 | it += 1
68 | return m_thfm, m_mea
69 |
70 | if __name__ == '__main__':
71 | m_thfm, m_mea = get_FM('path/to/saliency_maps/', 'path/to/ground_truth/')  # placeholder paths: point these at the prediction and GT directories
72 | print(m_thfm)
73 | print(m_mea)
--------------------------------------------------------------------------------
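For reference, `get_FM()` above walks a ground-truth folder and averages two numbers per image: the mean absolute error (MAE) and the F-measure at an adaptive threshold (twice the mean saliency value, capped at 1), with β² = 0.3 — hence the `1.3` and `0.3` constants. A self-contained numpy sketch of both metrics on a toy prediction (the values are made up purely for illustration):

```python
import numpy as np

# Toy saliency prediction in [0, 1] and binary ground truth.
pred = np.array([[0.8, 0.6], [0.1, 0.1]])
gt = np.array([[1, 1], [0, 0]], dtype=float)

eps = np.finfo(float).eps
mae = np.abs(gt - pred).mean()  # 0.2

# Adaptive threshold: twice the mean saliency, capped at 1 (same rule as get_FM).
th = min(2 * pred.mean(), 1.0)
binary = (pred >= th).astype(float)
tp = (binary * gt).sum()
precision = tp / (binary.sum() + eps)
recall = tp / (gt.sum() + eps)

# F-beta with beta^2 = 0.3: (1 + 0.3) * P * R / (0.3 * P + R).
f_beta = 1.3 * precision * recall / (0.3 * precision + recall + eps)
print(f_beta, mae)
```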
/utils/functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import torch
4 | # from scipy.misc import imresize
5 | from PIL import Image
6 | import os
7 | import cv2
8 | import shutil  # used by create_exp_dir below
9 | def adaptive_bins(hist, threshold):
10 | new = hist.copy()
11 | peak=hist.max()
12 | peak_depth=np.where(hist==peak)[0]
13 | delta_hist=np.diff(hist,n=1,axis=0)
14 | #print(peak,peak_depth,peak_depth.shape)
15 | left=peak_depth
16 | right=peak_depth
17 | i = np.array([peak_depth[0]])
18 | while(1):
19 | new[[i]]=0
20 | if (i>=254):
21 | right=np.array([254])
22 | break
23 | if (delta_hist[i]<0):
24 | i=i+1
25 | elif (hist[i]<=threshold*peak):
26 | right = i
27 | break
28 | else:
29 | i=i+1
30 | i = np.array([peak_depth[0]-1])
31 | while(1):
32 | new[[i+1]]=0
33 | if (i<=0):
34 | left=np.array([0])
35 | break
36 | if (delta_hist[i]>0):
37 | i=i-1
38 | elif (hist[i]<=threshold*peak):
39 | left = i+1
40 | break
41 | else:
42 | i=i-1
43 | #print(peak,peak_depth,left[0],right[0])
44 | return [new,left[0],right[0]]
45 |
46 | def get_bins_masks(depth):
47 | mask_list=[]
48 | hist = cv2.calcHist([depth],[0],None,[256],[0,255])
49 |
50 | hist1,left1,right1=adaptive_bins(hist,0.7)
51 | mask1 = (depth>left1-0.2*(right1-left1)) * (depth<=right1+0.2*(right1-left1))
52 | #mask1 = (depth>left1) * (depth<=right1)
53 |
54 | mask_list.append(mask1)
55 |
56 | hist2,left2,right2=adaptive_bins(hist1,0.2)
57 | mask2 = (depth>left2-0.2*(right2-left2)) * (depth<=right2+0.2*(right2-left2))
58 | #mask2 = (depth>left2) * (depth<=right2)
59 |
60 | mask_list.append(mask2)
61 |
62 | mask3_1 =(depth>left1) * (depth<=right1)
63 | mask3_2 =(depth>left2) * (depth<=right2)
64 | mask3=(~mask3_2)*(~mask3_1)
65 |
66 | mask_list.append(mask3)
67 | mask_bins = np.stack(mask_list,axis=0)
68 |
69 | return mask_bins
70 |
71 |
72 | def create_exp_dir(path, scripts_to_save=None):
73 | import time
74 | time.sleep(2)
75 | if not os.path.exists(path):
76 | os.makedirs(path)
77 | print('Experiment dir : {}'.format(path))
78 |
79 | if scripts_to_save is not None:
80 | os.makedirs(os.path.join(path, 'scripts'))
81 | for script in scripts_to_save:
82 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
83 | shutil.copyfile(script, dst_file)
84 |
85 |
86 |
87 | def count_parameters_in_MB(model):
88 | return np.sum(np.prod(v.size()) for v in model.parameters())/1e6
89 |
90 |
91 |
92 | def imsave(file_name, img, img_size):
93 | """
94 | save a torch tensor as an image
95 | :param file_name: 'image/folder/image_name'
96 | :param img: 3*h*w torch tensor
97 | :return: nothing
98 | """
99 | assert isinstance(img, torch.FloatTensor), \
100 | 'img must be a torch.FloatTensor'
101 | ndim = len(img.size())
102 | assert ndim in (2, 3), \
103 | 'img must be a 2 or 3 dimensional tensor'
104 |
105 | img = img.numpy()
106 |
107 | img = np.array(Image.fromarray(img).resize((img_size[1][0], img_size[0][0]), Image.NEAREST))
108 | # img = imresize(img, [img_size[1][0], img_size[0][0]], interp='nearest')
109 | if ndim == 3:
110 | plt.imsave(file_name, np.transpose(img, (1, 2, 0)))
111 | else:
112 | plt.imsave(file_name, img, cmap='gray')
113 |
114 | def load_pretrain(path, state_dict, name):
115 | state = torch.load(path)
116 | if 'state_dict' in state:
117 | state = state['state_dict']
118 | name = "module."+name
119 | length = len(name)
120 | for k, v in state.items():
121 | if k[:length] == name:
122 | if k[length:] in state_dict.keys():
123 | state_dict[k[length:]] = v
124 | # print(k[length:])
125 | else:
126 | print("pass keys: ",k[7:])
127 |
128 | return state_dict
129 |
--------------------------------------------------------------------------------
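The pair `adaptive_bins()` / `get_bins_masks()` above splits a depth map into three regions: it takes the 256-bin depth histogram, expands outward from the strongest peak until the counts drop below a fraction of that peak (0.7 for the first range, 0.2 for the second, computed on the remaining bins), widens the first two ranges by 20% on each side, and marks everything outside both unwidened ranges as the third region. A minimal usage sketch, assuming a single-channel 8-bit depth map at a placeholder path and that the interpreter is started from the repository root:

```python
import cv2
from utils.functions import get_bins_masks

# 'some_depth.png' is a placeholder; any 8-bit single-channel depth map will do.
depth = cv2.imread('some_depth.png', cv2.IMREAD_GRAYSCALE)  # uint8, shape (H, W)

# Three boolean masks stacked along axis 0: the dominant depth range,
# the second most prominent range, and everything outside both.
mask_bins = get_bins_masks(depth)
print(mask_bins.shape, mask_bins.dtype)  # (3, H, W) bool
```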
/utils/pretreat_SIP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image
4 | import scipy.io as sio
5 | import torch
6 | from torch.utils import data
7 | import cv2
8 | from utils.functions import adaptive_bins, get_bins_masks
9 |
10 | root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/raw_SIP"
11 |
12 | img_root = os.path.join(root, 'test_images')
13 | depth_root = os.path.join(root, 'test_depth')
14 | gt_root = os.path.join(root,'test_masks')
15 | names=[]
16 | img_names=[]
17 | depth_names=[]
18 | gt_names=[]
19 |
20 | file_names = os.listdir(img_root)
21 |
22 | for i, name in enumerate(file_names):
23 | if not name.endswith('.jpg'):
24 | continue
25 | names.append(name[:-4])
26 | img_names.append(
27 | os.path.join(img_root, name)
28 | )
29 |
30 | depth_names.append(
31 | os.path.join(depth_root, name[:-4] + '.png')
32 | )
33 |
34 | gt_names.append(
35 | os.path.join(gt_root, name[:-4] + '.png')
36 | )
37 |
38 | new_root = "/data/wenhu/RGBD-SOD/SOD-RGBD/val/SIP"
39 | new_img_root = os.path.join(new_root, 'test_images')
40 | new_depth_root = os.path.join(new_root, 'test_depth')
41 | new_gt_root = os.path.join(new_root,'test_masks')
42 |
43 | if not os.path.exists(new_depth_root):
44 | os.makedirs(new_depth_root)
45 | if not os.path.exists(new_img_root):
46 | os.makedirs(new_img_root)
47 | if not os.path.exists(new_gt_root):
48 | os.makedirs(new_gt_root)
49 |
50 | # i=0
51 | # print(gt_names[0])
52 | # img = np.array(PIL.Image.open(img_names[i]))
53 | # depth = np.array(PIL.Image.open(depth_names[i]))
54 | # gt = np.array(PIL.Image.open(gt_names[i]))
55 | # print(img.shape, depth.shape, gt.shape)
56 |
57 |
58 |
59 |
60 | for i in range(len(img_names)):
61 | img = cv2.imread(img_names[i])
62 | depth = cv2.imread(depth_names[i])
63 | gt = cv2.imread(gt_names[i])
64 |
65 | img = cv2.resize(img, (512,512), interpolation = cv2.INTER_LINEAR)
66 | depth = cv2.resize(depth, (512,512), interpolation = cv2.INTER_LINEAR)[:,:,0]
67 | gt = cv2.resize(gt, (512,512), interpolation = cv2.INTER_LINEAR)[:,:,0]
68 |
69 | cv2.imwrite( os.path.join(new_img_root, names[i]+ '.jpg'), img )
70 | cv2.imwrite( os.path.join(new_depth_root, names[i]+ '.png'), depth )
71 | cv2.imwrite( os.path.join(new_gt_root, names[i]+ '.png'), gt )
72 |
73 | # if i>10:
74 | # break
75 |
--------------------------------------------------------------------------------
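Usage note for the standalone script `pretreat_SIP.py` above: it resizes the raw SIP images, depth maps and ground-truth masks to 512×512 and rewrites them under `new_root`. The `root` and `new_root` constants are hard-coded absolute paths and need to be edited to your local copies; because the script imports from `utils.functions`, it is most easily run from the repository root, e.g. `python -m utils.pretreat_SIP`.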