├── dataset ├── __init__.py ├── dataset.py └── joint_dataset.py ├── networks ├── __init__.py ├── vgg.py ├── poolnet.py ├── deeplab_resnet.py └── joint_poolnet.py ├── train.sh ├── joint_train.sh ├── speed_test.sh ├── forward_edge.sh ├── forward.sh ├── LICENSE ├── main.py ├── README.md ├── joint_main.py ├── solver.py └── joint_solver.py /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | python main.py --arch resnet --train_root ./data/DUTS/DUTS-TR --train_list ./data/DUTS/DUTS-TR/train_pair.lst 4 | # you can optionly change the -lr and -wd params 5 | -------------------------------------------------------------------------------- /joint_train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | python joint_main.py --arch resnet --train_root ./data/DUTS/DUTS-TR --train_list ./data/DUTS/DUTS-TR/train_pair.lst --train_edge_root ./data/HED-BSDS_PASCAL --train_edge_list ./data/HED-BSDS_PASCAL/bsds_pascal_train_pair_r_val_r_small.lst 4 | # you can optionly change the -lr and -wd params 5 | -------------------------------------------------------------------------------- /speed_test.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # 1 gpu id ,2 python file name, # 3 results folder name 4 | 5 | ARRAY=(m_r) 6 | ELEMENTS=${#ARRAY[@]} 7 | 8 | echo "Testing on GPU " $1 " with file " $2 " to " $3 9 | 10 | for (( i=0;i<$ELEMENTS;i++)); do 11 | CUDA_VISIBLE_DEVICES=$1 python $2 --mode='test' --model=$3'/models/final.pth' --test_fold=$3'-sal-'${ARRAY[${i}]} --sal_mode=${ARRAY[${i}]} 12 | done 13 | 14 | echo "Speed test done." 15 | -------------------------------------------------------------------------------- /forward_edge.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # 1 gpu id ,2 python file name, # 3 results folder name 4 | 5 | ARRAY=(b) 6 | ELEMENTS=${#ARRAY[@]} 7 | 8 | echo "Testing on GPU " $1 " with file " $2 " to " $3 9 | 10 | for (( i=0;i<$ELEMENTS;i++)); do 11 | CUDA_VISIBLE_DEVICES=$1 python $2 --mode='test' --model=$3'/models/final.pth' --test_fold=$3'-edge' --sal_mode=${ARRAY[${i}]} --test_mode=0 12 | done 13 | 14 | echo "Testing on bsds dataset done." 15 | -------------------------------------------------------------------------------- /forward.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # 1 gpu id ,2 python file name, # 3 results folder name 4 | 5 | ARRAY=(e p d h s t) 6 | # ARRAY=(m) 7 | ELEMENTS=${#ARRAY[@]} 8 | 9 | echo "Testing on GPU " $1 " with file " $2 " to " $3 10 | 11 | for (( i=0;i<$ELEMENTS;i++)); do 12 | CUDA_VISIBLE_DEVICES=$1 python $2 --mode='test' --model=$3'/models/final.pth' --test_fold=$3'-sal-'${ARRAY[${i}]} --sal_mode=${ARRAY[${i}]} 13 | done 14 | 15 | echo "Testing on e,p,d,h,s,t datasets done." 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jiang-Jiang Liu and Qibin Hou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /networks/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | # vgg16 8 | def vgg(cfg, i, batch_norm=False): 9 | layers = [] 10 | in_channels = i 11 | stage = 1 12 | for v in cfg: 13 | if v == 'M': 14 | stage += 1 15 | if stage == 6: 16 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)] 17 | else: 18 | layers += [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] 19 | else: 20 | if stage == 6: 21 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 22 | else: 23 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 24 | if batch_norm: 25 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 26 | else: 27 | layers += [conv2d, nn.ReLU(inplace=True)] 28 | in_channels = v 29 | return layers 30 | 31 | class vgg16(nn.Module): 32 | def __init__(self): 33 | super(vgg16, self).__init__() 34 | self.cfg = {'tun': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'tun_ex': [512, 512, 512]} 35 | self.extract = [8, 15, 22, 29] # [3, 8, 15, 22, 29] 36 | self.base = nn.ModuleList(vgg(self.cfg['tun'], 3)) 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 40 | m.weight.data.normal_(0, 0.01) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.weight.data.fill_(1) 43 | m.bias.data.zero_() 44 | 45 | def load_pretrained_model(self, model): 46 | self.base.load_state_dict(model, strict=False) 47 | 48 | def forward(self, x): 49 | tmp_x = [] 50 | for k in range(len(self.base)): 51 | x = self.base[k](x) 52 | if k in self.extract: 53 | tmp_x.append(x) 54 | return tmp_x 55 | 56 | class vgg16_locate(nn.Module): 57 | def __init__(self): 58 | super(vgg16_locate,self).__init__() 59 | self.vgg16 = vgg16() 60 | self.in_planes = 512 61 | self.out_planes = [512, 256, 128] 62 | 63 | ppms, infos = [], [] 64 | for ii in [1, 3, 5]: 65 | ppms.append(nn.Sequential(nn.AdaptiveAvgPool2d(ii), nn.Conv2d(self.in_planes, self.in_planes, 1, 1, bias=False), nn.ReLU(inplace=True))) 66 | self.ppms = nn.ModuleList(ppms) 67 | 68 | self.ppm_cat = nn.Sequential(nn.Conv2d(self.in_planes * 4, self.in_planes, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 69 | for ii in self.out_planes: 70 | infos.append(nn.Sequential(nn.Conv2d(self.in_planes, ii, 3, 1, 1, bias=False), nn.ReLU(inplace=True))) 71 | self.infos = nn.ModuleList(infos) 72 | 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv2d): 75 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 76 | m.weight.data.normal_(0, 0.01) 77 | elif isinstance(m, nn.BatchNorm2d): 78 | m.weight.data.fill_(1) 79 | m.bias.data.zero_() 80 | 81 | def load_pretrained_model(self, model): 82 | self.vgg16.load_pretrained_model(model) 83 | 84 | def forward(self, x): 85 | x_size = x.size()[2:] 86 | xs = self.vgg16(x) 87 | 88 | xls = [xs[-1]] 89 | for k in range(len(self.ppms)): 90 | xls.append(F.interpolate(self.ppms[k](xs[-1]), xs[-1].size()[2:], mode='bilinear', align_corners=True)) 91 | xls = self.ppm_cat(torch.cat(xls, dim=1)) 92 | infos = [] 93 | for k in range(len(self.infos)): 94 | infos.append(self.infos[k](F.interpolate(xls, xs[len(self.infos) - 1 - k].size()[2:], mode='bilinear', align_corners=True))) 95 | 96 | return xs, infos 97 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from torchvision.transforms import functional as F 8 | import numbers 9 | import numpy as np 10 | import random 11 | 12 | class ImageDataTrain(data.Dataset): 13 | def __init__(self, data_root, data_list): 14 | self.sal_root = data_root 15 | self.sal_source = data_list 16 | 17 | with open(self.sal_source, 'r') as f: 18 | self.sal_list = [x.strip() for x in f.readlines()] 19 | 20 | self.sal_num = len(self.sal_list) 21 | 22 | 23 | def __getitem__(self, item): 24 | # sal data loading 25 | im_name = self.sal_list[item % self.sal_num].split()[0] 26 | gt_name = self.sal_list[item % self.sal_num].split()[1] 27 | sal_image = load_image(os.path.join(self.sal_root, im_name)) 28 | sal_label = load_sal_label(os.path.join(self.sal_root, gt_name)) 29 | sal_image, sal_label = cv_random_flip(sal_image, sal_label) 30 | sal_image = torch.Tensor(sal_image) 31 | sal_label = torch.Tensor(sal_label) 32 | 33 | sample = {'sal_image': sal_image, 'sal_label': sal_label} 34 | return sample 35 | 36 | def __len__(self): 37 | return self.sal_num 38 | 39 | class ImageDataTest(data.Dataset): 40 | def __init__(self, data_root, data_list): 41 | self.data_root = data_root 42 | self.data_list = data_list 43 | with open(self.data_list, 'r') as f: 44 | self.image_list = [x.strip() for x in f.readlines()] 45 | 46 | self.image_num = len(self.image_list) 47 | 48 | def __getitem__(self, item): 49 | image, im_size = load_image_test(os.path.join(self.data_root, self.image_list[item])) 50 | image = torch.Tensor(image) 51 | 52 | return {'image': image, 'name': self.image_list[item % self.image_num], 'size': im_size} 53 | 54 | def __len__(self): 55 | return self.image_num 56 | 57 | 58 | def get_loader(config, mode='train', pin=False): 59 | shuffle = False 60 | if mode == 'train': 61 | shuffle = True 62 | dataset = ImageDataTrain(config.train_root, config.train_list) 63 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 64 | else: 65 | dataset = ImageDataTest(config.test_root, config.test_list) 66 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 67 | return data_loader 68 | 69 | def load_image(path): 70 | if not os.path.exists(path): 71 | print('File {} not exists'.format(path)) 72 | im = cv2.imread(path) 73 | in_ = np.array(im, dtype=np.float32) 74 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 75 | in_ = in_.transpose((2,0,1)) 76 | return in_ 77 | 78 | def load_image_test(path): 79 | if not os.path.exists(path): 80 | print('File {} not exists'.format(path)) 81 | im = cv2.imread(path) 82 | in_ = np.array(im, dtype=np.float32) 83 | im_size = tuple(in_.shape[:2]) 84 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 85 | in_ = in_.transpose((2,0,1)) 86 | return in_, im_size 87 | 88 | def load_sal_label(path): 89 | if not os.path.exists(path): 90 | print('File {} not exists'.format(path)) 91 | im = Image.open(path) 92 | label = np.array(im, dtype=np.float32) 93 | if len(label.shape) == 3: 94 | label = label[:,:,0] 95 | label = label / 255. 96 | label = label[np.newaxis, ...] 97 | return label 98 | 99 | def cv_random_flip(img, label): 100 | flip_flag = random.randint(0, 1) 101 | if flip_flag == 1: 102 | img = img[:,:,::-1].copy() 103 | label = label[:,:,::-1].copy() 104 | return img, label 105 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from dataset.dataset import get_loader 4 | from solver import Solver 5 | 6 | def get_test_info(sal_mode='e'): 7 | if sal_mode == 'e': 8 | image_root = './data/ECSSD/Imgs/' 9 | image_source = './data/ECSSD/test.lst' 10 | elif sal_mode == 'p': 11 | image_root = './data/PASCALS/Imgs/' 12 | image_source = './data/PASCALS/test.lst' 13 | elif sal_mode == 'd': 14 | image_root = './data/DUTOMRON/Imgs/' 15 | image_source = './data/DUTOMRON/test.lst' 16 | elif sal_mode == 'h': 17 | image_root = './data/HKU-IS/Imgs/' 18 | image_source = './data/HKU-IS/test.lst' 19 | elif sal_mode == 's': 20 | image_root = './data/SOD/Imgs/' 21 | image_source = './data/SOD/test.lst' 22 | elif sal_mode == 't': 23 | image_root = './data/DUTS-TE/Imgs/' 24 | image_source = './data/DUTS-TE/test.lst' 25 | elif sal_mode == 'm_r': # for speed test 26 | image_root = './data/MSRA/Imgs_resized/' 27 | image_source = './data/MSRA/test_resized.lst' 28 | 29 | return image_root, image_source 30 | 31 | def main(config): 32 | if config.mode == 'train': 33 | train_loader = get_loader(config) 34 | run = 0 35 | while os.path.exists("%s/run-%d" % (config.save_folder, run)): 36 | run += 1 37 | os.mkdir("%s/run-%d" % (config.save_folder, run)) 38 | os.mkdir("%s/run-%d/models" % (config.save_folder, run)) 39 | config.save_folder = "%s/run-%d" % (config.save_folder, run) 40 | train = Solver(train_loader, None, config) 41 | train.train() 42 | elif config.mode == 'test': 43 | config.test_root, config.test_list = get_test_info(config.sal_mode) 44 | test_loader = get_loader(config, mode='test') 45 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold) 46 | test = Solver(None, test_loader, config) 47 | test.test() 48 | else: 49 | raise IOError("illegal input!!!") 50 | 51 | if __name__ == '__main__': 52 | 53 | vgg_path = './dataset/pretrained/vgg16_20M.pth' 54 | resnet_path = './dataset/pretrained/resnet50_caffe.pth' 55 | 56 | parser = argparse.ArgumentParser() 57 | 58 | # Hyper-parameters 59 | parser.add_argument('--n_color', type=int, default=3) 60 | parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4 61 | parser.add_argument('--wd', type=float, default=0.0005) # Weight decay 62 | parser.add_argument('--cuda', type=bool, default=True) 63 | 64 | # Training settings 65 | parser.add_argument('--arch', type=str, default='resnet') # resnet or vgg 66 | parser.add_argument('--pretrained_model', type=str, default=resnet_path) 67 | parser.add_argument('--epoch', type=int, default=24) 68 | parser.add_argument('--batch_size', type=int, default=1) # only support 1 now 69 | parser.add_argument('--num_thread', type=int, default=1) 70 | parser.add_argument('--load', type=str, default='') 71 | parser.add_argument('--save_folder', type=str, default='./results') 72 | parser.add_argument('--epoch_save', type=int, default=3) 73 | parser.add_argument('--iter_size', type=int, default=10) 74 | parser.add_argument('--show_every', type=int, default=50) 75 | 76 | # Train data 77 | parser.add_argument('--train_root', type=str, default='') 78 | parser.add_argument('--train_list', type=str, default='') 79 | 80 | # Testing settings 81 | parser.add_argument('--model', type=str, default=None) # Snapshot 82 | parser.add_argument('--test_fold', type=str, default=None) # Test results saving folder 83 | parser.add_argument('--sal_mode', type=str, default='e') # Test image dataset 84 | 85 | # Misc 86 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 87 | config = parser.parse_args() 88 | 89 | if not os.path.exists(config.save_folder): 90 | os.mkdir(config.save_folder) 91 | 92 | # Get test set info 93 | test_root, test_list = get_test_info(config.sal_mode) 94 | config.test_root = test_root 95 | config.test_list = test_list 96 | 97 | main(config) 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Simple Pooling-Based Design for Real-Time Salient Object Detection 2 | 3 | ### This is a PyTorch implementation of our CVPR 2019 [paper](https://arxiv.org/abs/1904.09569). 4 | 5 | ## Prerequisites 6 | 7 | - [Pytorch 0.4.1+](http://pytorch.org/) 8 | - [torchvision](http://pytorch.org/) 9 | 10 | ## Update 11 | 12 | We released our code for joint training with edge, which is also our best performance model. 13 | 14 | ## Todo 15 | 16 | Merge DSS into this repo. 17 | 18 | ## Usage 19 | 20 | ### 1. Clone the repository 21 | 22 | ```shell 23 | git clone https://github.com/backseason/PoolNet.git 24 | cd PoolNet/ 25 | ``` 26 | 27 | ### 2. Download the datasets 28 | 29 | Download the following datasets and unzip them into `data` folder. 30 | 31 | * [MSRA-B and HKU-IS](https://drive.google.com/open?id=1immMDAPC9Eb2KCtGi6AdfvXvQJnSkHHo) dataset. The .lst file for training is `data/msrab_hkuis/msrab_hkuis_train_no_small.lst`. 32 | * [DUTS](https://drive.google.com/open?id=14RA-qr7JxU6iljLv6PbWUCQG0AJsEgmd) dataset. The .lst file for training is `data/DUTS/DUTS-TR/train_pair.lst`. 33 | * [BSDS-PASCAL](https://drive.google.com/open?id=1qx8eyDNAewAAc6hlYHx3B9LXvEGSIqQp) dataset. The .lst file for training is `./data/HED-BSDS_PASCAL/bsds_pascal_train_pair_r_val_r_small.lst`. 34 | * [Datasets for testing](https://drive.google.com/open?id=1eB-59cMrYnhmMrz7hLWQ7mIssRaD-f4o). 35 | 36 | ### 3. Download the pre-trained models for backbone 37 | 38 | Download the following [pre-trained models](https://drive.google.com/open?id=1Q2Fg2KZV8AzNdWNjNgcavffKJBChdBgy) into `data/pretrained` folder. (Now we only provide models trained w/o edge) 39 | 40 | ### 4. Train 41 | 42 | 1. Set the `--train_root` and `--train_list` path in `train.sh` correctly. 43 | 44 | 2. We demo using ResNet-50 as network backbone and train with a initial lr of 5e-5 for 24 epoches, which is divided by 10 after 15 epochs. 45 | ```shell 46 | ./train.sh 47 | ``` 48 | 3. We demo joint training with edge using ResNet-50 as network backbone and train with a initial lr of 5e-5 for 11 epoches, which is divided by 10 after 8 epochs. Each epoch runs for 30000 iters. 49 | ```shell 50 | ./joint_train.sh 51 | ``` 52 | 4. After training the result model will be stored under `results/run-*` folder. 53 | 54 | ### 5. Test 55 | 56 | For single dataset testing: `*` changes accordingly and `--sal_mode` indicates different datasets (details can be found in `main.py`) 57 | ```shell 58 | python main.py --mode='test' --model='results/run-*/models/final.pth' --test_fold='results/run-*-sal-e' --sal_mode='e' 59 | ``` 60 | For all datasets testing used in our paper: `2` indicates the gpu to use 61 | ```shell 62 | ./forward.sh 2 main.py results/run-* 63 | ``` 64 | For joint training, to get salient object detection results use 65 | ```shell 66 | ./forward.sh 2 joint_main.py results/run-* 67 | ``` 68 | to get edge detection results use 69 | ```shell 70 | ./forward_edge.sh 2 joint_main.py results/run-* 71 | ``` 72 | 73 | All results saliency maps will be stored under `results/run-*-sal-*` folders in .png formats. 74 | 75 | 76 | ### 6. Pre-trained models, pre-computed results and evaluation results 77 | 78 | We provide the pre-trained model, pre-computed saliency maps and evaluation results for: 79 | 1. PoolNet-ResNet50 w/o edge model [run-0](https://drive.google.com/open?id=12Zgth_CP_kZPdXwnBJOu4gcTyVgV2Nof). 80 | 2. PoolNet-ResNet50 w/ edge model (best performance) [run-1](https://drive.google.com/open?id=1sH5RKEt6SnG33Z4sI-hfLs2d21GmegwR). 81 | 82 | Note: 83 | 84 | 1. only support `bath_size=1` 85 | 2. Except for the backbone we do not use BN layer. 86 | 87 | ### 7. Wants to participate in the project? 88 | 89 | You are welcome to send us your network to make this project bigger. 90 | 91 | Please email {j04.liu, andrewhoux}@gmail.com. 92 | 93 | 94 | ### If you think this work is helpful, please cite 95 | ```latex 96 | @inproceedings{Liu2019PoolSal, 97 | title={A Simple Pooling-Based Design for Real-Time Salient Object Detection}, 98 | author={Jiang-Jiang Liu and Qibin Hou and Ming-Ming Cheng and Jiashi Feng and Jianmin Jiang}, 99 | booktitle={IEEE CVPR}, 100 | year={2019}, 101 | } 102 | ``` 103 | ```latex 104 | @article{HouPami19Dss, 105 | title={Deeply Supervised Salient Object Detection with Short Connections}, 106 | author={Hou, Qibin and Cheng, Ming-Ming and Hu, Xiaowei and Borji, Ali and Tu, Zhuowen and Torr, Philip}, 107 | year = {2019}, 108 | volume={41}, 109 | number={4}, 110 | pages={815-828}, 111 | journal={IEEE TPAMI} 112 | } 113 | ``` 114 | 115 | Thanks to [DSS](https://github.com/Andrew-Qibin/DSS) and [DSS-pytorch](https://github.com/AceCoooool/DSS-pytorch). 116 | -------------------------------------------------------------------------------- /joint_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from dataset.joint_dataset import get_loader 4 | from joint_solver import Solver 5 | 6 | def get_test_info(sal_mode='e'): 7 | if sal_mode == 'e': 8 | image_root = './data/ECSSD/Imgs/' 9 | image_source = './data/ECSSD/test.lst' 10 | elif sal_mode == 'p': 11 | image_root = './data/PASCALS/Imgs/' 12 | image_source = './data/PASCALS/test.lst' 13 | elif sal_mode == 'd': 14 | image_root = './data/DUTOMRON/Imgs/' 15 | image_source = './data/DUTOMRON/test.lst' 16 | elif sal_mode == 'h': 17 | image_root = './data/HKU-IS/Imgs/' 18 | image_source = './data/HKU-IS/test.lst' 19 | elif sal_mode == 's': 20 | image_root = './data/SOD/Imgs/' 21 | image_source = './data/SOD/test.lst' 22 | elif sal_mode == 't': 23 | image_root = './data/DUTS-TE/Imgs/' 24 | image_source = './data/DUTS-TE/test.lst' 25 | elif sal_mode == 'm_r': # for speed test 26 | image_root = './data/MSRA/Imgs_resized/' 27 | image_source = './data/MSRA/test_resized.lst' 28 | elif sal_mode == 'b': # BSDS dataset for edge evaluation 29 | image_root = './data/HED-BSDS_PASCAL/HED-BSDS/test/' 30 | image_source = './data/HED-BSDS_PASCAL/HED-BSDS/test.lst' 31 | return image_root, image_source 32 | 33 | def main(config): 34 | if config.mode == 'train': 35 | train_loader = get_loader(config) 36 | run = 0 37 | while os.path.exists("%s/run-%d" % (config.save_folder, run)): 38 | run += 1 39 | os.mkdir("%s/run-%d" % (config.save_folder, run)) 40 | os.mkdir("%s/run-%d/models" % (config.save_folder, run)) 41 | config.save_folder = "%s/run-%d" % (config.save_folder, run) 42 | train = Solver(train_loader, None, config) 43 | train.train() 44 | elif config.mode == 'test': 45 | config.test_root, config.test_list = get_test_info(config.sal_mode) 46 | test_loader = get_loader(config, mode='test') 47 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold) 48 | test = Solver(None, test_loader, config) 49 | test.test(test_mode=config.test_mode) 50 | else: 51 | raise IOError("illegal input!!!") 52 | 53 | if __name__ == '__main__': 54 | 55 | vgg_path = './dataset/pretrained/vgg16_20M.pth' 56 | resnet_path = './dataset/pretrained/resnet50_caffe.pth' 57 | 58 | parser = argparse.ArgumentParser() 59 | 60 | # Hyper-parameters 61 | parser.add_argument('--n_color', type=int, default=3) 62 | parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4 63 | parser.add_argument('--wd', type=float, default=0.0005) # Weight decay 64 | parser.add_argument('--cuda', type=bool, default=True) 65 | 66 | # Training settings 67 | parser.add_argument('--arch', type=str, default='resnet') # resnet or vgg 68 | parser.add_argument('--pretrained_model', type=str, default=resnet_path) 69 | parser.add_argument('--epoch', type=int, default=11) 70 | parser.add_argument('--batch_size', type=int, default=1) # only support 1 now 71 | parser.add_argument('--num_thread', type=int, default=1) 72 | parser.add_argument('--load', type=str, default='') 73 | parser.add_argument('--save_folder', type=str, default='./results') 74 | parser.add_argument('--epoch_save', type=int, default=3) 75 | parser.add_argument('--iter_size', type=int, default=10) 76 | parser.add_argument('--show_every', type=int, default=50) 77 | 78 | # Train data 79 | parser.add_argument('--train_root', type=str, default='') 80 | parser.add_argument('--train_list', type=str, default='') 81 | parser.add_argument('--train_edge_root', type=str, default='') # path for edge data 82 | parser.add_argument('--train_edge_list', type=str, default='') # list file for edge data 83 | 84 | # Testing settings 85 | parser.add_argument('--model', type=str, default=None) # Snapshot 86 | parser.add_argument('--test_fold', type=str, default=None) # Test results saving folder 87 | parser.add_argument('--test_mode', type=int, default=1) # 0->edge, 1->saliency 88 | parser.add_argument('--sal_mode', type=str, default='e') # Test image dataset 89 | 90 | # Misc 91 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 92 | config = parser.parse_args() 93 | 94 | if not os.path.exists(config.save_folder): 95 | os.mkdir(config.save_folder) 96 | 97 | # Get test set info 98 | test_root, test_list = get_test_info(config.sal_mode) 99 | config.test_root = test_root 100 | config.test_list = test_list 101 | 102 | main(config) 103 | -------------------------------------------------------------------------------- /dataset/joint_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from torchvision.transforms import functional as F 8 | import numbers 9 | import numpy as np 10 | import random 11 | 12 | class ImageDataTrain(data.Dataset): 13 | def __init__(self, sal_data_root, sal_data_list, edge_data_root, edge_data_list): 14 | self.sal_root = sal_data_root 15 | self.sal_source = sal_data_list 16 | self.edge_root = edge_data_root 17 | self.edge_source = edge_data_list 18 | 19 | with open(self.sal_source, 'r') as f: 20 | self.sal_list = [x.strip() for x in f.readlines()] 21 | with open(self.edge_source, 'r') as f: 22 | self.edge_list = [x.strip() for x in f.readlines()] 23 | 24 | self.sal_num = len(self.sal_list) 25 | self.edge_num = len(self.edge_list) 26 | 27 | 28 | def __getitem__(self, item): 29 | # edge data loading 30 | edge_im_name = self.edge_list[item % self.edge_num].split()[0] 31 | edge_gt_name = self.edge_list[item % self.edge_num].split()[1] 32 | edge_image = load_image(os.path.join(self.edge_root, edge_im_name)) 33 | edge_label = load_edge_label(os.path.join(self.edge_root, edge_gt_name)) 34 | edge_image = torch.Tensor(edge_image) 35 | edge_label = torch.Tensor(edge_label) 36 | 37 | # sal data loading 38 | sal_im_name = self.sal_list[item % self.sal_num].split()[0] 39 | sal_gt_name = self.sal_list[item % self.sal_num].split()[1] 40 | sal_image = load_image(os.path.join(self.sal_root, sal_im_name)) 41 | sal_label = load_sal_label(os.path.join(self.sal_root, sal_gt_name)) 42 | sal_image, sal_label = cv_random_flip(sal_image, sal_label) 43 | sal_image = torch.Tensor(sal_image) 44 | sal_label = torch.Tensor(sal_label) 45 | 46 | sample = {'edge_image': edge_image, 'edge_label': edge_label, 'sal_image': sal_image, 'sal_label': sal_label} 47 | return sample 48 | 49 | def __len__(self): 50 | return max(self.sal_num, self.edge_num) 51 | 52 | class ImageDataTest(data.Dataset): 53 | def __init__(self, data_root, data_list): 54 | self.data_root = data_root 55 | self.data_list = data_list 56 | with open(self.data_list, 'r') as f: 57 | self.image_list = [x.strip() for x in f.readlines()] 58 | 59 | self.image_num = len(self.image_list) 60 | 61 | def __getitem__(self, item): 62 | image, im_size = load_image_test(os.path.join(self.data_root, self.image_list[item])) 63 | image = torch.Tensor(image) 64 | 65 | return {'image': image, 'name': self.image_list[item % self.image_num], 'size': im_size} 66 | 67 | def __len__(self): 68 | return self.image_num 69 | 70 | 71 | def get_loader(config, mode='train', pin=False): 72 | shuffle = False 73 | if mode == 'train': 74 | shuffle = True 75 | dataset = ImageDataTrain(config.train_root, config.train_list, config.train_edge_root, config.train_edge_list) 76 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 77 | else: 78 | dataset = ImageDataTest(config.test_root, config.test_list) 79 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 80 | return data_loader 81 | 82 | def load_image(path): 83 | if not os.path.exists(path): 84 | print('File {} not exists'.format(path)) 85 | im = cv2.imread(path) 86 | in_ = np.array(im, dtype=np.float32) 87 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 88 | in_ = in_.transpose((2,0,1)) 89 | return in_ 90 | 91 | def load_image_test(path): 92 | if not os.path.exists(path): 93 | print('File {} not exists'.format(path)) 94 | im = cv2.imread(path) 95 | in_ = np.array(im, dtype=np.float32) 96 | im_size = tuple(in_.shape[:2]) 97 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 98 | in_ = in_.transpose((2,0,1)) 99 | return in_, im_size 100 | 101 | def load_sal_label(path): 102 | if not os.path.exists(path): 103 | print('File {} not exists'.format(path)) 104 | im = Image.open(path) 105 | label = np.array(im, dtype=np.float32) 106 | if len(label.shape) == 3: 107 | label = label[:,:,0] 108 | label = label / 255. 109 | label = label[np.newaxis, ...] 110 | return label 111 | 112 | def load_edge_label(path): 113 | """ 114 | pixels > 0.5 -> 1. 115 | """ 116 | if not os.path.exists(path): 117 | print('File {} not exists'.format(path)) 118 | im = Image.open(path) 119 | label = np.array(im, dtype=np.float32) 120 | if len(label.shape) == 3: 121 | label = label[:,:,0] 122 | label = label / 255. 123 | label[np.where(label > 0.5)] = 1. 124 | label = label[np.newaxis, ...] 125 | return label 126 | 127 | def cv_random_flip(img, label): 128 | flip_flag = random.randint(0, 1) 129 | if flip_flag == 1: 130 | img = img[:,:,::-1].copy() 131 | label = label[:,:,::-1].copy() 132 | return img, label 133 | -------------------------------------------------------------------------------- /networks/poolnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_resnet import resnet50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 128} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 128} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class ScoreLayer(nn.Module): 63 | def __init__(self, k): 64 | super(ScoreLayer, self).__init__() 65 | self.score = nn.Conv2d(k ,1, 1, 1) 66 | 67 | def forward(self, x, x_size=None): 68 | x = self.score(x) 69 | if x_size is not None: 70 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 71 | return x 72 | 73 | def extra_layer(base_model_cfg, vgg): 74 | if base_model_cfg == 'vgg': 75 | config = config_vgg 76 | elif base_model_cfg == 'resnet': 77 | config = config_resnet 78 | convert_layers, deep_pool_layers, score_layers = [], [], [] 79 | convert_layers = ConvertLayer(config['convert']) 80 | 81 | for i in range(len(config['deep_pool'][0])): 82 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 83 | 84 | score_layers = ScoreLayer(config['score']) 85 | 86 | return vgg, convert_layers, deep_pool_layers, score_layers 87 | 88 | 89 | class PoolNet(nn.Module): 90 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, score_layers): 91 | super(PoolNet, self).__init__() 92 | self.base_model_cfg = base_model_cfg 93 | self.base = base 94 | self.deep_pool = nn.ModuleList(deep_pool_layers) 95 | self.score = score_layers 96 | if self.base_model_cfg == 'resnet': 97 | self.convert = convert_layers 98 | 99 | def forward(self, x): 100 | x_size = x.size() 101 | conv2merge, infos = self.base(x) 102 | if self.base_model_cfg == 'resnet': 103 | conv2merge = self.convert(conv2merge) 104 | conv2merge = conv2merge[::-1] 105 | 106 | edge_merge = [] 107 | merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 108 | for k in range(1, len(conv2merge)-1): 109 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 110 | 111 | merge = self.deep_pool[-1](merge) 112 | merge = self.score(merge, x_size) 113 | return merge 114 | 115 | def build_model(base_model_cfg='vgg'): 116 | if base_model_cfg == 'vgg': 117 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 118 | elif base_model_cfg == 'resnet': 119 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, resnet50_locate())) 120 | 121 | def weights_init(m): 122 | if isinstance(m, nn.Conv2d): 123 | m.weight.data.normal_(0, 0.01) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from torch.nn import utils, functional as F 4 | from torch.optim import Adam 5 | from torch.autograd import Variable 6 | from torch.backends import cudnn 7 | from networks.poolnet_1task import build_model, weights_init 8 | import scipy.misc as sm 9 | import numpy as np 10 | import os 11 | import torchvision.utils as vutils 12 | import cv2 13 | import math 14 | import time 15 | 16 | 17 | class Solver(object): 18 | def __init__(self, train_loader, test_loader, config): 19 | self.train_loader = train_loader 20 | self.test_loader = test_loader 21 | self.config = config 22 | self.iter_size = config.iter_size 23 | self.show_every = config.show_every 24 | self.lr_decay_epoch = [15,] 25 | self.build_model() 26 | if config.mode == 'test': 27 | print('Loading pre-trained model from %s...' % self.config.model) 28 | self.net.load_state_dict(torch.load(self.config.model)) 29 | self.net.eval() 30 | 31 | # print the network information and parameter numbers 32 | def print_network(self, model, name): 33 | num_params = 0 34 | for p in model.parameters(): 35 | num_params += p.numel() 36 | print(name) 37 | print(model) 38 | print("The number of parameters: {}".format(num_params)) 39 | 40 | # build the network 41 | def build_model(self): 42 | self.net = build_model(self.config.arch) 43 | if self.config.cuda: 44 | self.net = self.net.cuda() 45 | # self.net.train() 46 | self.net.eval() # use_global_stats = True 47 | self.net.apply(weights_init) 48 | if self.config.load == '': 49 | self.net.base.load_pretrained_model(torch.load(self.config.pretrained_model)) 50 | else: 51 | self.net.load_state_dict(torch.load(self.config.load)) 52 | 53 | self.lr = self.config.lr 54 | self.wd = self.config.wd 55 | 56 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 57 | self.print_network(self.net, 'PoolNet Structure') 58 | 59 | def test(self): 60 | mode_name = 'sal_fuse' 61 | time_s = time.time() 62 | img_num = len(self.test_loader) 63 | for i, data_batch in enumerate(self.test_loader): 64 | images, name, im_size = data_batch['image'], data_batch['name'][0], np.asarray(data_batch['size']) 65 | with torch.no_grad(): 66 | images = Variable(images) 67 | if self.config.cuda: 68 | images = images.cuda() 69 | preds = self.net(images) 70 | pred = np.squeeze(torch.sigmoid(preds).cpu().data.numpy()) 71 | multi_fuse = 255 * pred 72 | cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name + '.png'), multi_fuse) 73 | time_e = time.time() 74 | print('Speed: %f FPS' % (img_num/(time_e-time_s))) 75 | print('Test Done!') 76 | 77 | # training phase 78 | def train(self): 79 | iter_num = len(self.train_loader.dataset) // self.config.batch_size 80 | aveGrad = 0 81 | for epoch in range(self.config.epoch): 82 | r_sal_loss= 0 83 | self.net.zero_grad() 84 | for i, data_batch in enumerate(self.train_loader): 85 | sal_image, sal_label = data_batch['sal_image'], data_batch['sal_label'] 86 | if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)): 87 | print('IMAGE ERROR, PASSING```') 88 | continue 89 | sal_image, sal_label= Variable(sal_image), Variable(sal_label) 90 | if self.config.cuda: 91 | # cudnn.benchmark = True 92 | sal_image, sal_label = sal_image.cuda(), sal_label.cuda() 93 | 94 | sal_pred = self.net(sal_image) 95 | sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum') 96 | sal_loss = sal_loss_fuse / (self.iter_size * self.config.batch_size) 97 | r_sal_loss += sal_loss.data 98 | 99 | sal_loss.backward() 100 | 101 | aveGrad += 1 102 | 103 | # accumulate gradients as done in DSS 104 | if aveGrad % self.iter_size == 0: 105 | self.optimizer.step() 106 | self.optimizer.zero_grad() 107 | aveGrad = 0 108 | 109 | if i % (self.show_every // self.config.batch_size) == 0: 110 | if i == 0: 111 | x_showEvery = 1 112 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || Sal : %10.4f' % ( 113 | epoch, self.config.epoch, i, iter_num, r_sal_loss/x_showEvery)) 114 | print('Learning rate: ' + str(self.lr)) 115 | r_sal_loss= 0 116 | 117 | if (epoch + 1) % self.config.epoch_save == 0: 118 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_folder, epoch + 1)) 119 | 120 | if epoch in self.lr_decay_epoch: 121 | self.lr = self.lr * 0.1 122 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 123 | 124 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_folder) 125 | 126 | def bce2d(input, target, reduction=None): 127 | assert(input.size() == target.size()) 128 | pos = torch.eq(target, 1).float() 129 | neg = torch.eq(target, 0).float() 130 | 131 | num_pos = torch.sum(pos) 132 | num_neg = torch.sum(neg) 133 | num_total = num_pos + num_neg 134 | 135 | alpha = num_neg / num_total 136 | beta = 1.1 * num_pos / num_total 137 | # target pixel = 1 -> weight beta 138 | # target pixel = 0 -> weight 1-beta 139 | weights = alpha * pos + beta * neg 140 | 141 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 142 | 143 | -------------------------------------------------------------------------------- /networks/deeplab_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | affine_par = True 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, dilation_ = 1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 51 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par) 52 | for i in self.bn1.parameters(): 53 | i.requires_grad = False 54 | padding = 1 55 | if dilation_ == 2: 56 | padding = 2 57 | elif dilation_ == 4: 58 | padding = 4 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 60 | padding=padding, bias=False, dilation = dilation_) 61 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par) 62 | for i in self.bn2.parameters(): 63 | i.requires_grad = False 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par) 66 | for i in self.bn3.parameters(): 67 | i.requires_grad = False 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class ResNet(nn.Module): 95 | def __init__(self, block, layers): 96 | self.inplanes = 64 97 | super(ResNet, self).__init__() 98 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 99 | bias=False) 100 | self.bn1 = nn.BatchNorm2d(64,affine = affine_par) 101 | for i in self.bn1.parameters(): 102 | i.requires_grad = False 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation__ = 2) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, 0.01) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1,dilation__ = 1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion or dilation__ == 2 or dilation__ == 4: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par), 125 | ) 126 | for i in downsample._modules['1'].parameters(): 127 | i.requires_grad = False 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride,dilation_=dilation__, downsample = downsample )) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes,dilation_=dilation__)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | tmp_x = [] 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | tmp_x.append(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | tmp_x.append(x) 146 | x = self.layer2(x) 147 | tmp_x.append(x) 148 | x = self.layer3(x) 149 | tmp_x.append(x) 150 | x = self.layer4(x) 151 | tmp_x.append(x) 152 | 153 | return tmp_x 154 | 155 | 156 | class ResNet_locate(nn.Module): 157 | def __init__(self, block, layers): 158 | super(ResNet_locate,self).__init__() 159 | self.resnet = ResNet(block, layers) 160 | self.in_planes = 512 161 | self.out_planes = [512, 256, 256, 128] 162 | 163 | self.ppms_pre = nn.Conv2d(2048, self.in_planes, 1, 1, bias=False) 164 | ppms, infos = [], [] 165 | for ii in [1, 3, 5]: 166 | ppms.append(nn.Sequential(nn.AdaptiveAvgPool2d(ii), nn.Conv2d(self.in_planes, self.in_planes, 1, 1, bias=False), nn.ReLU(inplace=True))) 167 | self.ppms = nn.ModuleList(ppms) 168 | 169 | self.ppm_cat = nn.Sequential(nn.Conv2d(self.in_planes * 4, self.in_planes, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 170 | for ii in self.out_planes: 171 | infos.append(nn.Sequential(nn.Conv2d(self.in_planes, ii, 3, 1, 1, bias=False), nn.ReLU(inplace=True))) 172 | self.infos = nn.ModuleList(infos) 173 | 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 177 | m.weight.data.normal_(0, 0.01) 178 | elif isinstance(m, nn.BatchNorm2d): 179 | m.weight.data.fill_(1) 180 | m.bias.data.zero_() 181 | 182 | def load_pretrained_model(self, model): 183 | self.resnet.load_state_dict(model, strict=False) 184 | 185 | def forward(self, x): 186 | x_size = x.size()[2:] 187 | xs = self.resnet(x) 188 | 189 | xs_1 = self.ppms_pre(xs[-1]) 190 | xls = [xs_1] 191 | for k in range(len(self.ppms)): 192 | xls.append(F.interpolate(self.ppms[k](xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True)) 193 | xls = self.ppm_cat(torch.cat(xls, dim=1)) 194 | 195 | infos = [] 196 | for k in range(len(self.infos)): 197 | infos.append(self.infos[k](F.interpolate(xls, xs[len(self.infos) - 1 - k].size()[2:], mode='bilinear', align_corners=True))) 198 | 199 | return xs, infos 200 | 201 | def resnet50_locate(): 202 | model = ResNet_locate(Bottleneck, [3, 4, 6, 3]) 203 | return model 204 | -------------------------------------------------------------------------------- /joint_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from torch.nn import utils, functional as F 4 | from torch.optim import Adam 5 | from torch.autograd import Variable 6 | from torch.backends import cudnn 7 | from networks.joint_poolnet import build_model, weights_init 8 | import scipy.misc as sm 9 | import numpy as np 10 | import os 11 | import torchvision.utils as vutils 12 | import cv2 13 | import math 14 | import time 15 | 16 | 17 | class Solver(object): 18 | def __init__(self, train_loader, test_loader, config): 19 | self.train_loader = train_loader 20 | self.test_loader = test_loader 21 | self.config = config 22 | self.iter_size = config.iter_size 23 | self.show_every = config.show_every 24 | self.lr_decay_epoch = [8,] 25 | self.build_model() 26 | if config.mode == 'test': 27 | print('Loading pre-trained model from %s...' % self.config.model) 28 | self.net.load_state_dict(torch.load(self.config.model)) 29 | self.net.eval() 30 | 31 | # print the network information and parameter numbers 32 | def print_network(self, model, name): 33 | num_params = 0 34 | for p in model.parameters(): 35 | num_params += p.numel() 36 | print(name) 37 | print(model) 38 | print("The number of parameters: {}".format(num_params)) 39 | 40 | # build the network 41 | def build_model(self): 42 | self.net = build_model(self.config.arch) 43 | if self.config.cuda: 44 | self.net = self.net.cuda() 45 | # self.net.train() 46 | self.net.eval() # use_global_stats = True 47 | self.net.apply(weights_init) 48 | if self.config.load == '': 49 | self.net.base.load_pretrained_model(torch.load(self.config.pretrained_model)) 50 | else: 51 | self.net.load_state_dict(torch.load(self.config.load)) 52 | 53 | self.lr = self.config.lr 54 | self.wd = self.config.wd 55 | 56 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 57 | self.print_network(self.net, 'PoolNet Structure') 58 | 59 | def test(self, test_mode=1): 60 | mode_name = ['edge_fuse', 'sal_fuse'] 61 | EPSILON = 1e-8 62 | time_s = time.time() 63 | img_num = len(self.test_loader) 64 | for i, data_batch in enumerate(self.test_loader): 65 | images, name, im_size = data_batch['image'], data_batch['name'][0], np.asarray(data_batch['size']) 66 | if test_mode == 0: 67 | images = images.numpy()[0].transpose((1,2,0)) 68 | scale = [0.5, 1, 1.5, 2] # uncomment for multi-scale testing 69 | # scale = [1] 70 | multi_fuse = np.zeros(im_size, np.float32) 71 | for k in range(0, len(scale)): 72 | im_ = cv2.resize(images, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR) 73 | im_ = im_.transpose((2, 0, 1)) 74 | im_ = torch.Tensor(im_[np.newaxis, ...]) 75 | 76 | with torch.no_grad(): 77 | im_ = Variable(im_) 78 | if self.config.cuda: 79 | im_ = im_.cuda() 80 | preds = self.net(im_, mode=test_mode) 81 | pred_0 = np.squeeze(torch.sigmoid(preds[1][0]).cpu().data.numpy()) 82 | pred_1 = np.squeeze(torch.sigmoid(preds[1][1]).cpu().data.numpy()) 83 | pred_2 = np.squeeze(torch.sigmoid(preds[1][2]).cpu().data.numpy()) 84 | pred_fuse = np.squeeze(torch.sigmoid(preds[0]).cpu().data.numpy()) 85 | 86 | pred = (pred_0 + pred_1 + pred_2 + pred_fuse) / 4 87 | pred = (pred - np.min(pred) + EPSILON) / (np.max(pred) - np.min(pred) + EPSILON) 88 | 89 | pred = cv2.resize(pred, (im_size[1], im_size[0]), interpolation=cv2.INTER_LINEAR) 90 | multi_fuse += pred 91 | 92 | multi_fuse /= len(scale) 93 | multi_fuse = 255 * (1 - multi_fuse) 94 | cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name[test_mode] + '.png'), multi_fuse) 95 | elif test_mode == 1: 96 | with torch.no_grad(): 97 | images = Variable(images) 98 | if self.config.cuda: 99 | images = images.cuda() 100 | preds = self.net(images, mode=test_mode) 101 | pred = np.squeeze(torch.sigmoid(preds).cpu().data.numpy()) 102 | multi_fuse = 255 * pred 103 | cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name[test_mode] + '.png'), multi_fuse) 104 | time_e = time.time() 105 | print('Speed: %f FPS' % (img_num/(time_e-time_s))) 106 | print('Test Done!') 107 | 108 | # training phase 109 | def train(self): 110 | iter_num = 30000 # each batch only train 30000 iters.(This number is just a random choice...) 111 | aveGrad = 0 112 | for epoch in range(self.config.epoch): 113 | r_edge_loss, r_sal_loss, r_sum_loss= 0,0,0 114 | self.net.zero_grad() 115 | for i, data_batch in enumerate(self.train_loader): 116 | if (i + 1) == iter_num: break 117 | edge_image, edge_label, sal_image, sal_label = data_batch['edge_image'], data_batch['edge_label'], data_batch['sal_image'], data_batch['sal_label'] 118 | if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)): 119 | print('IMAGE ERROR, PASSING```') 120 | continue 121 | edge_image, edge_label, sal_image, sal_label= Variable(edge_image), Variable(edge_label), Variable(sal_image), Variable(sal_label) 122 | if self.config.cuda: 123 | edge_image, edge_label, sal_image, sal_label = edge_image.cuda(), edge_label.cuda(), sal_image.cuda(), sal_label.cuda() 124 | 125 | # edge part 126 | edge_pred = self.net(edge_image, mode=0) 127 | edge_loss_fuse = bce2d(edge_pred[0], edge_label, reduction='sum') 128 | edge_loss_part = [] 129 | for ix in edge_pred[1]: 130 | edge_loss_part.append(bce2d(ix, edge_label, reduction='sum')) 131 | edge_loss = (edge_loss_fuse + sum(edge_loss_part)) / (self.iter_size * self.config.batch_size) 132 | r_edge_loss += edge_loss.data 133 | 134 | # sal part 135 | sal_pred = self.net(sal_image, mode=1) 136 | sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum') 137 | sal_loss = sal_loss_fuse / (self.iter_size * self.config.batch_size) 138 | r_sal_loss += sal_loss.data 139 | 140 | loss = sal_loss + edge_loss 141 | r_sum_loss += loss.data 142 | 143 | loss.backward() 144 | 145 | aveGrad += 1 146 | 147 | # accumulate gradients as done in DSS 148 | if aveGrad % self.iter_size == 0: 149 | self.optimizer.step() 150 | self.optimizer.zero_grad() 151 | aveGrad = 0 152 | 153 | if i % (self.show_every // self.config.batch_size) == 0: 154 | if i == 0: 155 | x_showEvery = 1 156 | print('epoch: [%2d/%2d], iter: [%5d/%5d] || Edge : %10.4f || Sal : %10.4f || Sum : %10.4f' % ( 157 | epoch, self.config.epoch, i, iter_num, r_edge_loss/x_showEvery, r_sal_loss/x_showEvery, r_sum_loss/x_showEvery)) 158 | print('Learning rate: ' + str(self.lr)) 159 | r_edge_loss, r_sal_loss, r_sum_loss= 0,0,0 160 | 161 | if (epoch + 1) % self.config.epoch_save == 0: 162 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_folder, epoch + 1)) 163 | 164 | if epoch in self.lr_decay_epoch: 165 | self.lr = self.lr * 0.1 166 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 167 | 168 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_folder) 169 | 170 | def bce2d(input, target, reduction=None): 171 | assert(input.size() == target.size()) 172 | pos = torch.eq(target, 1).float() 173 | neg = torch.eq(target, 0).float() 174 | 175 | num_pos = torch.sum(pos) 176 | num_neg = torch.sum(neg) 177 | num_total = num_pos + num_neg 178 | 179 | alpha = num_neg / num_total 180 | beta = 1.1 * num_pos / num_total 181 | # target pixel = 1 -> weight beta 182 | # target pixel = 0 -> weight 1-beta 183 | weights = alpha * pos + beta * neg 184 | 185 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 186 | 187 | -------------------------------------------------------------------------------- /networks/joint_poolnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_resnet import resnet50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 256, 'edgeinfoc':[48,128], 'block': [[512, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16], True]} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 256, 'edgeinfoc':[64,128], 'block': [[512, [16]], [256, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16, 16], True]} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class BlockLayer(nn.Module): 63 | def __init__(self, k_in, k_out_list): 64 | super(BlockLayer, self).__init__() 65 | up_in1, up_mid1, up_in2, up_mid2, up_out = [], [], [], [], [] 66 | 67 | for k in k_out_list: 68 | up_in1.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 69 | up_mid1.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 70 | up_in2.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 71 | up_mid2.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 72 | up_out.append(nn.Conv2d(k_in, k, 1, 1, bias=False)) 73 | 74 | self.block_in1 = nn.ModuleList(up_in1) 75 | self.block_in2 = nn.ModuleList(up_in2) 76 | self.block_mid1 = nn.ModuleList(up_mid1) 77 | self.block_mid2 = nn.ModuleList(up_mid2) 78 | self.block_out = nn.ModuleList(up_out) 79 | self.relu = nn.ReLU() 80 | 81 | def forward(self, x, mode=0): 82 | x_tmp = self.relu(x + self.block_mid1[mode](self.block_in1[mode](x))) 83 | # x_tmp = self.block_mid2[mode](self.block_in2[mode](self.relu(x + x_tmp))) 84 | x_tmp = self.relu(x_tmp + self.block_mid2[mode](self.block_in2[mode](x_tmp))) 85 | x_tmp = self.block_out[mode](x_tmp) 86 | 87 | return x_tmp 88 | 89 | class EdgeInfoLayerC(nn.Module): 90 | def __init__(self, k_in, k_out): 91 | super(EdgeInfoLayerC, self).__init__() 92 | self.trans = nn.Sequential(nn.Conv2d(k_in, k_in, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 93 | nn.Conv2d(k_in, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 94 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 95 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 96 | 97 | def forward(self, x, x_size): 98 | tmp_x = [] 99 | for i_x in x: 100 | tmp_x.append(F.interpolate(i_x, x_size[2:], mode='bilinear', align_corners=True)) 101 | x = self.trans(torch.cat(tmp_x, dim=1)) 102 | return x 103 | 104 | class FuseLayer1(nn.Module): 105 | def __init__(self, list_k, deep_sup): 106 | super(FuseLayer1, self).__init__() 107 | up = [] 108 | for i in range(len(list_k)): 109 | up.append(nn.Conv2d(list_k[i], 1, 1, 1)) 110 | self.trans = nn.ModuleList(up) 111 | self.fuse = nn.Conv2d(len(list_k), 1, 1, 1) 112 | self.deep_sup = deep_sup 113 | 114 | def forward(self, list_x, x_size): 115 | up_x = [] 116 | for i, i_x in enumerate(list_x): 117 | up_x.append(F.interpolate(self.trans[i](i_x), x_size[2:], mode='bilinear', align_corners=True)) 118 | out_fuse = self.fuse(torch.cat(up_x, dim = 1)) 119 | if self.deep_sup: 120 | out_all = [] 121 | for up_i in up_x: 122 | out_all.append(up_i) 123 | return [out_fuse, out_all] 124 | else: 125 | return [out_fuse] 126 | 127 | class ScoreLayer(nn.Module): 128 | def __init__(self, k): 129 | super(ScoreLayer, self).__init__() 130 | self.score = nn.Conv2d(k ,1, 3, 1, 1) 131 | 132 | def forward(self, x, x_size=None): 133 | x = self.score(x) 134 | if x_size is not None: 135 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 136 | return x 137 | 138 | def extra_layer(base_model_cfg, base): 139 | if base_model_cfg == 'vgg': 140 | config = config_vgg 141 | elif base_model_cfg == 'resnet': 142 | config = config_resnet 143 | convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers = [], [], [], [], [], [] 144 | convert_layers = ConvertLayer(config['convert']) 145 | 146 | for k in config['block']: 147 | block_layers += [BlockLayer(k[0], k[1])] 148 | 149 | for i in range(len(config['deep_pool'][0])): 150 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 151 | 152 | fuse_layers = FuseLayer1(config['fuse'][0], config['fuse'][1]) 153 | edgeinfo_layers = EdgeInfoLayerC(config['edgeinfoc'][0], config['edgeinfoc'][1]) 154 | score_layers = ScoreLayer(config['score']) 155 | 156 | return base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers 157 | 158 | 159 | class PoolNet(nn.Module): 160 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers): 161 | super(PoolNet, self).__init__() 162 | self.base_model_cfg = base_model_cfg 163 | self.base = base 164 | self.block = nn.ModuleList(block_layers) 165 | self.deep_pool = nn.ModuleList(deep_pool_layers) 166 | self.fuse = fuse_layers 167 | self.edgeinfo = edgeinfo_layers 168 | self.score = score_layers 169 | if self.base_model_cfg == 'resnet': 170 | self.convert = convert_layers 171 | 172 | def forward(self, x, mode): 173 | x_size = x.size() 174 | conv2merge, infos = self.base(x) 175 | if self.base_model_cfg == 'resnet': 176 | conv2merge = self.convert(conv2merge) 177 | conv2merge = conv2merge[::-1] 178 | 179 | edge_merge = [] 180 | merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 181 | edge_merge.append(merge) 182 | for k in range(1, len(conv2merge)-1): 183 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 184 | edge_merge.append(merge) 185 | 186 | if mode == 0: 187 | edge_merge = [self.block[i](kk) for i, kk in enumerate(edge_merge)] 188 | merge = self.fuse(edge_merge, x_size) 189 | elif mode == 1: 190 | merge = self.deep_pool[-1](merge) 191 | edge_merge = [self.block[i](kk).detach() for i, kk in enumerate(edge_merge)] 192 | edge_merge = self.edgeinfo(edge_merge, merge.size()) 193 | merge = self.score(torch.cat([merge, edge_merge], dim=1), x_size) 194 | return merge 195 | 196 | def build_model(base_model_cfg='vgg'): 197 | if base_model_cfg == 'vgg': 198 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 199 | elif base_model_cfg == 'resnet': 200 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, resnet50_locate())) 201 | 202 | def weights_init(m): 203 | if isinstance(m, nn.Conv2d): 204 | m.weight.data.normal_(0, 0.01) 205 | if m.bias is not None: 206 | m.bias.data.zero_() 207 | --------------------------------------------------------------------------------