├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── dataset_old.py
│   └── joint_dataset.py
├── forward.sh
├── forward_edge.sh
├── joint_main.py
├── joint_solver.py
├── joint_train_res2net.sh
├── main.py
├── networks
│   ├── __init__.py
│   ├── deeplab_res2net.py
│   ├── deeplab_resnet.py
│   ├── joint_poolnet.py
│   ├── joint_poolnet_res2net.py
│   ├── poolnet.py
│   ├── poolnet_res2net.py
│   └── vgg.py
├── solver.py
└── train_res2net.sh
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Jiang-Jiang Liu and Qibin Hou
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Res2Net for Salient Object Detection using PoolNet
2 | 
3 | ## Introduction
4 | This repo uses [*PoolNet* (CVPR 2019)](https://arxiv.org/abs/1904.09569) as the baseline method for Salient Object Detection.
5 | 
6 | [Res2Net](https://github.com/gasvn/Res2Net) is a powerful backbone architecture that can be easily plugged into state-of-the-art models by replacing the bottleneck block with the Res2Net module.
7 | More details can be found in ["Res2Net: A New Multi-scale Backbone Architecture"](https://arxiv.org/pdf/1904.01169.pdf).
8 | 
9 | ## Performance
10 | 
11 | ### Results on salient object detection datasets **without** joint training with edge. Models are trained on DUTS-TR.
12 | 
13 | | Backbone     | ECSSD        | PASCAL-S      | DUT-O         | HKU-IS         | SOD             | DUTS-TE        |
14 | |--------------|--------------|---------------|---------------|----------------|-----------------|----------------|
15 | | -            | MaxF & MAE   | MaxF & MAE    | MaxF & MAE    | MaxF & MAE     | MaxF & MAE      | MaxF & MAE     |
16 | | vgg          |0.936 & 0.047 | 0.857 & 0.078 | 0.817 & 0.058 | 0.928 & 0.035  | 0.859 & 0.115   | 0.876 & 0.043  |
17 | | resnet50     |0.940 & 0.042 | 0.863 & 0.075 | 0.830 & 0.055 | 0.934 & 0.032  | 0.867 & 0.100   | 0.886 & 0.040  |
18 | | **res2net50**|0.947 & 0.036 | 0.871 & 0.070 | 0.837 & 0.052 | 0.936 & 0.031  | 0.885 & 0.096   | 0.892 & 0.037  |
19 | 
20 | 
21 | 
22 | ## Evaluation
23 | 
24 | You may refer to this repo for results evaluation: [SalMetric](https://github.com/Andrew-Qibin/SalMetric).
25 | 
26 | ## Todo
27 | We will merge this repo into the official repo of PoolNet soon.
28 | Compared to the official PoolNet, we only modify the input normalization as follows:
29 | ```
30 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
31 |                                  std=[0.229, 0.224, 0.225])
32 | ```
33 | ## Usage
34 | 
35 | ### Prerequisites
36 | 
37 | - [PyTorch 0.4.1+](http://pytorch.org/)
38 | 
39 | ### 1. Clone the repository
40 | 
41 | ```shell
42 | git clone https://github.com/gasvn/Res2Net-PoolNet.git
43 | cd Res2Net-PoolNet/
44 | ```
45 | 
46 | ### 2. Download the datasets
47 | 
48 | Download the following datasets and unzip them into the `data` folder.
49 | 
50 | * [MSRA-B and HKU-IS](https://drive.google.com/open?id=14RA-qr7JxU6iljLv6PbWUCQG0AJsEgmd) dataset. The .lst file for training is `data/msrab_hkuis/msrab_hkuis_train_no_small.lst`.
51 | * [DUTS](https://drive.google.com/open?id=1immMDAPC9Eb2KCtGi6AdfvXvQJnSkHHo) dataset. The .lst file for training is `data/DUTS/DUTS-TR/train_pair.lst`.
52 | * [BSDS-PASCAL](https://drive.google.com/open?id=1qx8eyDNAewAAc6hlYHx3B9LXvEGSIqQp) dataset. The .lst file for training is `./data/HED-BSDS_PASCAL/bsds_pascal_train_pair_r_val_r_small.lst`.
53 | * [Datasets for testing](https://drive.google.com/open?id=1eB-59cMrYnhmMrz7hLWQ7mIssRaD-f4o).
54 | 
55 | ### 3. Download the pre-trained models for the backbone
56 | 
57 | Download the pretrained Res2Net50 model from [Res2Net](https://github.com/gasvn/Res2Net).
58 | Set the path to the pretrained Res2Net model in `main.py` (line 55):
59 | ```
60 | res2net_path = '/home/shgao/.torch/models/res2net50_26w_4s-06e79181.pth'
61 | ```
62 | ### 4. Train
63 | 
64 | 1. Set the `--train_root` and `--train_list` paths in `train_res2net.sh` correctly.
65 | 
66 | 2. We use Res2Net-50 as the network backbone and train with an initial lr of 5e-5 for 24 epochs; the lr is divided by 10 after 15 epochs.
```shell
67 | ```shell
68 | ./train_res2net.sh
69 | ```
70 | 3. We also demo joint training with edge detection, using Res2Net-50 as the network backbone and an initial lr of 5e-5 for 11 epochs; the lr is divided by 10 after 8 epochs.
Each epoch runs for 30000 iterations.
71 | ```shell
72 | ./joint_train_res2net.sh
73 | ```
74 | 4. After training, the resulting model will be stored under the `results/run-*` folder.
75 | 
76 | ### 5. Test
77 | 
78 | For testing on a single dataset (`*` changes accordingly; `--sal_mode` selects the dataset, see `main.py` for details):
79 | ```shell
80 | python main.py --mode='test' --model='results/run-*/models/final.pth' --test_fold='results/run-*-sal-e' --sal_mode='e' --arch res2net
81 | ```
82 | For testing on all datasets used in our paper (`0` is the GPU ID to use):
83 | ```shell
84 | ./forward.sh 0 main.py results/run-*
85 | ```
86 | For joint training, to get salient object detection results use
87 | ```shell
88 | ./forward.sh 0 joint_main.py results/run-*
89 | ```
90 | and to get edge detection results use
91 | ```shell
92 | ./forward_edge.sh 0 joint_main.py results/run-*
93 | ```
94 | 
95 | All resulting saliency maps will be stored under the `results/run-*-sal-*` folders in .png format.
96 | 
97 | 
98 | ### 6. Pre-trained models
99 | 
100 | The pretrained models for SOD using Res2Net are now available on [ONEDRIVE](https://1drv.ms/u/s!AkxDDnOtroRPe43-1JjD304ecvU?e=Y7qCHN).
101 | 
102 | Note:
103 | 
104 | 1. Only `batch_size=1` is supported.
105 | 2. Except for the backbone, we do not use BN layers.
106 | 
107 | 
108 | 
109 | 
110 | ## Applications
111 | Other applications such as classification, instance segmentation, object detection, semantic segmentation, pose estimation, and class activation maps can be found at https://mmcheng.net/res2net/ and https://github.com/gasvn/Res2Net .
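For reference, the test-time input preprocessing implemented in `dataset/dataset.py` (BGR-to-RGB conversion followed by the ImageNet normalization shown earlier) can be sketched in plain NumPy. This is a minimal sketch: the function name `preprocess` is ours, and the model forward pass is omitted.

```python
import numpy as np

def preprocess(im_bgr):
    # Mirror load_image_test in dataset/dataset.py: BGR -> RGB,
    # scale to [0, 1], subtract the ImageNet mean, divide by the
    # ImageNet std, then HWC -> CHW for PyTorch.
    x = im_bgr[:, :, ::-1].astype(np.float32) / 255.0
    x -= np.array((0.485, 0.456, 0.406), dtype=np.float32)
    x /= np.array((0.229, 0.224, 0.225), dtype=np.float32)
    return x.transpose((2, 0, 1))

dummy = np.zeros((4, 4, 3), dtype=np.uint8)  # stand-in for a cv2.imread result
chw = preprocess(dummy)
print(chw.shape)  # (3, 4, 4)
```

Wrap the result in `torch.Tensor` and add a batch dimension before feeding it to the network, as the dataset classes do.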
112 | 
113 | ## Citation
114 | If you find this work or code helpful in your research, please cite:
115 | ```
116 | @article{gao2019res2net,
117 |   title={Res2Net: A New Multi-scale Backbone Architecture},
118 |   author={Gao, Shang-Hua and Cheng, Ming-Ming and Zhao, Kai and Zhang, Xin-Yu and Yang, Ming-Hsuan and Torr, Philip},
119 |   journal={IEEE TPAMI},
120 |   year={2020},
121 |   doi={10.1109/TPAMI.2019.2938758},
122 | }
123 | @inproceedings{Liu2019PoolSal,
124 |   title={A Simple Pooling-Based Design for Real-Time Salient Object Detection},
125 |   author={Jiang-Jiang Liu and Qibin Hou and Ming-Ming Cheng and Jiashi Feng and Jianmin Jiang},
126 |   booktitle={IEEE CVPR},
127 |   year={2019},
128 | }
129 | ```
130 | ## Acknowledgement
131 | The code for salient object detection is partly borrowed from [A Simple Pooling-Based Design for Real-Time Salient Object Detection](https://github.com/backseason/PoolNet).
132 | 
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Res2Net/Res2Net-PoolNet/7bef0652e83a6c4ebe4ed47f1b03ab5b7b16074a/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import cv2
4 | import torch
5 | from torch.utils import data
6 | from torchvision import transforms
7 | from torchvision.transforms import functional as F
8 | import numbers
9 | import numpy as np
10 | import random
11 | 
12 | class ImageDataTrain(data.Dataset):
13 |     def __init__(self, data_root, data_list):
14 |         self.sal_root = data_root
15 |         self.sal_source = data_list
16 | 
17 |         with open(self.sal_source, 'r') as f:
18 |             self.sal_list = [x.strip() for x in f.readlines()]
19 | 
20 |         self.sal_num = len(self.sal_list)
21 | 
22 | 
23 |     def 
__getitem__(self, item):
24 |         # sal data loading
25 |         im_name = self.sal_list[item % self.sal_num].split()[0]
26 |         gt_name = self.sal_list[item % self.sal_num].split()[1]
27 |         sal_image = load_image(os.path.join(self.sal_root, im_name))
28 |         sal_label = load_sal_label(os.path.join(self.sal_root, gt_name))
29 |         sal_image, sal_label = cv_random_flip(sal_image, sal_label)
30 |         sal_image = torch.Tensor(sal_image)
31 |         sal_label = torch.Tensor(sal_label)
32 | 
33 |         sample = {'sal_image': sal_image, 'sal_label': sal_label}
34 |         return sample
35 | 
36 |     def __len__(self):
37 |         return self.sal_num
38 | 
39 | class ImageDataTest(data.Dataset):
40 |     def __init__(self, data_root, data_list):
41 |         self.data_root = data_root
42 |         self.data_list = data_list
43 |         with open(self.data_list, 'r') as f:
44 |             self.image_list = [x.strip() for x in f.readlines()]
45 | 
46 |         self.image_num = len(self.image_list)
47 | 
48 |     def __getitem__(self, item):
49 |         image, im_size = load_image_test(os.path.join(self.data_root, self.image_list[item]))
50 |         image = torch.Tensor(image)
51 | 
52 |         return {'image': image, 'name': self.image_list[item % self.image_num], 'size': im_size}
53 | 
54 |     def __len__(self):
55 |         return self.image_num
56 | 
57 | 
58 | def get_loader(config, mode='train', pin=False):
59 |     shuffle = False
60 |     if mode == 'train':
61 |         shuffle = True
62 |         dataset = ImageDataTrain(config.train_root, config.train_list)
63 |         data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin)
64 |     else:
65 |         dataset = ImageDataTest(config.test_root, config.test_list)
66 |         data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin)
67 |     return data_loader
68 | 
69 | def load_image(path):
70 |     if not os.path.exists(path):
71 |         print('File {} does not exist'.format(path))
72 |     im = cv2.imread(path)
73 |     im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # convert from BGR to RGB to suit the pytorch pretrained model.
74 |     in_ = np.array(im, dtype=np.float32)
75 |     in_ = in_/255
76 |     in_ -= np.array((0.485, 0.456, 0.406))
77 |     in_ /= np.array((0.229, 0.224, 0.225))
78 |     # in_ = np.array(im, dtype=np.float32)
79 |     # in_ -= np.array((104.00699, 116.66877, 122.67892))
80 |     in_ = in_.transpose((2,0,1))
81 |     return in_
82 | 
83 | def load_image_test(path):
84 |     if not os.path.exists(path):
85 |         print('File {} does not exist'.format(path))
86 |     im = cv2.imread(path)
87 |     im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # convert from BGR to RGB to suit the pytorch pretrained model.
88 |     in_ = np.array(im, dtype=np.float32)
89 |     in_ = in_/255
90 |     in_ -= np.array((0.485, 0.456, 0.406))
91 |     in_ /= np.array((0.229, 0.224, 0.225))
92 |     # in_ = np.array(im, dtype=np.float32)
93 |     im_size = tuple(in_.shape[:2])
94 |     # in_ -= np.array((104.00699, 116.66877, 122.67892))
95 |     in_ = in_.transpose((2,0,1))
96 |     return in_, im_size
97 | 
98 | def load_sal_label(path):
99 |     if not os.path.exists(path):
100 |         print('File {} does not exist'.format(path))
101 |     im = Image.open(path)
102 |     label = np.array(im, dtype=np.float32)
103 |     if len(label.shape) == 3:
104 |         label = label[:,:,0]
105 |     label = label / 255.
106 |     label = label[np.newaxis, ...]
107 | return label 108 | 109 | def cv_random_flip(img, label): 110 | flip_flag = random.randint(0, 1) 111 | if flip_flag == 1: 112 | img = img[:,:,::-1].copy() 113 | label = label[:,:,::-1].copy() 114 | return img, label 115 | -------------------------------------------------------------------------------- /dataset/dataset_old.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from torchvision.transforms import functional as F 8 | import numbers 9 | import numpy as np 10 | import random 11 | 12 | class ImageDataTrain(data.Dataset): 13 | def __init__(self, data_root, data_list): 14 | self.sal_root = data_root 15 | self.sal_source = data_list 16 | 17 | with open(self.sal_source, 'r') as f: 18 | self.sal_list = [x.strip() for x in f.readlines()] 19 | 20 | self.sal_num = len(self.sal_list) 21 | 22 | 23 | def __getitem__(self, item): 24 | # sal data loading 25 | im_name = self.sal_list[item % self.sal_num].split()[0] 26 | gt_name = self.sal_list[item % self.sal_num].split()[1] 27 | sal_image = load_image(os.path.join(self.sal_root, im_name)) 28 | sal_label = load_sal_label(os.path.join(self.sal_root, gt_name)) 29 | sal_image, sal_label = cv_random_flip(sal_image, sal_label) 30 | sal_image = torch.Tensor(sal_image) 31 | sal_label = torch.Tensor(sal_label) 32 | 33 | sample = {'sal_image': sal_image, 'sal_label': sal_label} 34 | return sample 35 | 36 | def __len__(self): 37 | return self.sal_num 38 | 39 | class ImageDataTest(data.Dataset): 40 | def __init__(self, data_root, data_list): 41 | self.data_root = data_root 42 | self.data_list = data_list 43 | with open(self.data_list, 'r') as f: 44 | self.image_list = [x.strip() for x in f.readlines()] 45 | 46 | self.image_num = len(self.image_list) 47 | 48 | def __getitem__(self, item): 49 | image, im_size = 
load_image_test(os.path.join(self.data_root, self.image_list[item])) 50 | image = torch.Tensor(image) 51 | 52 | return {'image': image, 'name': self.image_list[item % self.image_num], 'size': im_size} 53 | 54 | def __len__(self): 55 | return self.image_num 56 | 57 | 58 | def get_loader(config, mode='train', pin=False): 59 | shuffle = False 60 | if mode == 'train': 61 | shuffle = True 62 | dataset = ImageDataTrain(config.train_root, config.train_list) 63 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 64 | else: 65 | dataset = ImageDataTest(config.test_root, config.test_list) 66 | data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin) 67 | return data_loader 68 | 69 | def load_image(path): 70 | if not os.path.exists(path): 71 | print('File {} not exists'.format(path)) 72 | im = cv2.imread(path) 73 | in_ = np.array(im, dtype=np.float32) 74 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 75 | in_ = in_.transpose((2,0,1)) 76 | return in_ 77 | 78 | def load_image_test(path): 79 | if not os.path.exists(path): 80 | print('File {} not exists'.format(path)) 81 | im = cv2.imread(path) 82 | in_ = np.array(im, dtype=np.float32) 83 | im_size = tuple(in_.shape[:2]) 84 | in_ -= np.array((104.00699, 116.66877, 122.67892)) 85 | in_ = in_.transpose((2,0,1)) 86 | return in_, im_size 87 | 88 | def load_sal_label(path): 89 | if not os.path.exists(path): 90 | print('File {} not exists'.format(path)) 91 | im = Image.open(path) 92 | label = np.array(im, dtype=np.float32) 93 | if len(label.shape) == 3: 94 | label = label[:,:,0] 95 | label = label / 255. 96 | label = label[np.newaxis, ...] 
97 | return label 98 | 99 | def cv_random_flip(img, label): 100 | flip_flag = random.randint(0, 1) 101 | if flip_flag == 1: 102 | img = img[:,:,::-1].copy() 103 | label = label[:,:,::-1].copy() 104 | return img, label 105 | -------------------------------------------------------------------------------- /dataset/joint_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from torchvision.transforms import functional as F 8 | import numbers 9 | import numpy as np 10 | import random 11 | 12 | class ImageDataTrain(data.Dataset): 13 | def __init__(self, sal_data_root, sal_data_list, edge_data_root, edge_data_list): 14 | self.sal_root = sal_data_root 15 | self.sal_source = sal_data_list 16 | self.edge_root = edge_data_root 17 | self.edge_source = edge_data_list 18 | 19 | with open(self.sal_source, 'r') as f: 20 | self.sal_list = [x.strip() for x in f.readlines()] 21 | with open(self.edge_source, 'r') as f: 22 | self.edge_list = [x.strip() for x in f.readlines()] 23 | 24 | self.sal_num = len(self.sal_list) 25 | self.edge_num = len(self.edge_list) 26 | 27 | 28 | def __getitem__(self, item): 29 | # edge data loading 30 | edge_im_name = self.edge_list[item % self.edge_num].split()[0] 31 | edge_gt_name = self.edge_list[item % self.edge_num].split()[1] 32 | edge_image = load_image(os.path.join(self.edge_root, edge_im_name)) 33 | edge_label = load_edge_label(os.path.join(self.edge_root, edge_gt_name)) 34 | edge_image = torch.Tensor(edge_image) 35 | edge_label = torch.Tensor(edge_label) 36 | 37 | # sal data loading 38 | sal_im_name = self.sal_list[item % self.sal_num].split()[0] 39 | sal_gt_name = self.sal_list[item % self.sal_num].split()[1] 40 | sal_image = load_image(os.path.join(self.sal_root, sal_im_name)) 41 | sal_label = load_sal_label(os.path.join(self.sal_root, sal_gt_name)) 42 | 
sal_image, sal_label = cv_random_flip(sal_image, sal_label)
43 |         sal_image = torch.Tensor(sal_image)
44 |         sal_label = torch.Tensor(sal_label)
45 | 
46 |         sample = {'edge_image': edge_image, 'edge_label': edge_label, 'sal_image': sal_image, 'sal_label': sal_label}
47 |         return sample
48 | 
49 |     def __len__(self):
50 |         return max(self.sal_num, self.edge_num)
51 | 
52 | class ImageDataTest(data.Dataset):
53 |     def __init__(self, data_root, data_list):
54 |         self.data_root = data_root
55 |         self.data_list = data_list
56 |         with open(self.data_list, 'r') as f:
57 |             self.image_list = [x.strip() for x in f.readlines()]
58 | 
59 |         self.image_num = len(self.image_list)
60 | 
61 |     def __getitem__(self, item):
62 |         image, im_size = load_image_test(os.path.join(self.data_root, self.image_list[item]))
63 |         image = torch.Tensor(image)
64 | 
65 |         return {'image': image, 'name': self.image_list[item % self.image_num], 'size': im_size}
66 | 
67 |     def __len__(self):
68 |         return self.image_num
69 | 
70 | 
71 | def get_loader(config, mode='train', pin=False):
72 |     shuffle = False
73 |     if mode == 'train':
74 |         shuffle = True
75 |         dataset = ImageDataTrain(config.train_root, config.train_list, config.train_edge_root, config.train_edge_list)
76 |         data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin)
77 |     else:
78 |         dataset = ImageDataTest(config.test_root, config.test_list)
79 |         data_loader = data.DataLoader(dataset=dataset, batch_size=config.batch_size, shuffle=shuffle, num_workers=config.num_thread, pin_memory=pin)
80 |     return data_loader
81 | 
82 | def load_image(path):
83 |     if not os.path.exists(path):
84 |         print('File {} does not exist'.format(path))
85 |     im = cv2.imread(path)
86 |     in_ = np.array(im, dtype=np.float32)
87 |     in_ -= np.array((104.00699, 116.66877, 122.67892))
88 |     in_ = in_.transpose((2,0,1))
89 |     return in_
90 | 
91 | def load_image_test(path):
92 |     if not os.path.exists(path):
93 |         print('File {} does not exist'.format(path))
94 |     im = cv2.imread(path)
95 |     in_ = np.array(im, dtype=np.float32)
96 |     im_size = tuple(in_.shape[:2])
97 |     in_ -= np.array((104.00699, 116.66877, 122.67892))
98 |     in_ = in_.transpose((2,0,1))
99 |     return in_, im_size
100 | 
101 | def load_sal_label(path):
102 |     if not os.path.exists(path):
103 |         print('File {} does not exist'.format(path))
104 |     im = Image.open(path)
105 |     label = np.array(im, dtype=np.float32)
106 |     if len(label.shape) == 3:
107 |         label = label[:,:,0]
108 |     label = label / 255.
109 |     label = label[np.newaxis, ...]
110 |     return label
111 | 
112 | def load_edge_label(path):
113 |     """
114 |     pixels > 0.5 -> 1.
115 |     """
116 |     if not os.path.exists(path):
117 |         print('File {} does not exist'.format(path))
118 |     im = Image.open(path)
119 |     label = np.array(im, dtype=np.float32)
120 |     if len(label.shape) == 3:
121 |         label = label[:,:,0]
122 |     label = label / 255.
123 |     label[np.where(label > 0.5)] = 1.
124 |     label = label[np.newaxis, ...]
125 |     return label
126 | 
127 | def cv_random_flip(img, label):
128 |     flip_flag = random.randint(0, 1)
129 |     if flip_flag == 1:
130 |         img = img[:,:,::-1].copy()
131 |         label = label[:,:,::-1].copy()
132 |     return img, label
--------------------------------------------------------------------------------
/forward.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | # $1: gpu id, $2: python file name, $3: results folder name
4 | 
5 | ARRAY=(e p d h s t)
6 | # ARRAY=(m)
7 | ELEMENTS=${#ARRAY[@]}
8 | 
9 | echo "Testing on GPU " $1 " with file " $2 " to " $3
10 | 
11 | for (( i=0;i<$ELEMENTS;i++)); do
12 |     CUDA_VISIBLE_DEVICES=$1 python $2 --mode='test' --arch res2net --model=$3'/models/final.pth' --test_fold=$3'-sal-'${ARRAY[${i}]} --sal_mode=${ARRAY[${i}]}
13 | done
14 | 
15 | echo "Testing on e,p,d,h,s,t datasets done."
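The edge labels loaded by `load_edge_label` in `dataset/joint_dataset.py` are binarized with a 0.5 threshold, as its docstring notes. A standalone NumPy sketch of that step (the function name `binarize_edge` is ours):

```python
import numpy as np

def binarize_edge(label_uint8):
    # Mirror load_edge_label: scale to [0, 1], then force every
    # pixel above 0.5 to exactly 1; values <= 0.5 are left as-is.
    label = label_uint8.astype(np.float32) / 255.0
    label[label > 0.5] = 1.0
    return label[np.newaxis, ...]  # add a leading channel dimension

edges = np.array([[0, 100, 200],
                  [255, 30, 180]], dtype=np.uint8)
out = binarize_edge(edges)
print(out.shape)  # (1, 2, 3)
```

Note that pixels at or below the threshold keep their scaled gray value; only the strong edge pixels are saturated to 1 for the weighted BCE loss.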
16 | 
--------------------------------------------------------------------------------
/forward_edge.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | 
3 | # $1: gpu id, $2: python file name, $3: results folder name
4 | 
5 | ARRAY=(b)
6 | ELEMENTS=${#ARRAY[@]}
7 | 
8 | echo "Testing on GPU " $1 " with file " $2 " to " $3
9 | 
10 | for (( i=0;i<$ELEMENTS;i++)); do
11 |     CUDA_VISIBLE_DEVICES=$1 python $2 --mode='test' --arch res2net --model=$3'/models/final.pth' --test_fold=$3'-edge' --sal_mode=${ARRAY[${i}]} --test_mode=0
12 | done
13 | 
14 | echo "Testing on BSDS dataset done."
--------------------------------------------------------------------------------
/joint_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from dataset.joint_dataset import get_loader
4 | from joint_solver import Solver
5 | 
6 | def get_test_info(sal_mode='e'):
7 |     if sal_mode == 'e':
8 |         image_root = './data/ECSSD/Imgs/'
9 |         image_source = './data/ECSSD/test.lst'
10 |     elif sal_mode == 'p':
11 |         image_root = './data/PASCALS/Imgs/'
12 |         image_source = './data/PASCALS/test.lst'
13 |     elif sal_mode == 'd':
14 |         image_root = './data/DUTOMRON/Imgs/'
15 |         image_source = './data/DUTOMRON/test.lst'
16 |     elif sal_mode == 'h':
17 |         image_root = './data/HKU-IS/Imgs/'
18 |         image_source = './data/HKU-IS/test.lst'
19 |     elif sal_mode == 's':
20 |         image_root = './data/SOD/Imgs/'
21 |         image_source = './data/SOD/test.lst'
22 |     elif sal_mode == 't':
23 |         image_root = './data/DUTS-TE/Imgs/'
24 |         image_source = './data/DUTS-TE/test.lst'
25 |     elif sal_mode == 'm_r':  # for speed test
26 |         image_root = './data/MSRA/Imgs_resized/'
27 |         image_source = './data/MSRA/test_resized.lst'
28 |     elif sal_mode == 'b':  # BSDS dataset for edge evaluation
29 |         image_root = './data/HED-BSDS_PASCAL/HED-BSDS/test/'
30 |         image_source = './data/HED-BSDS_PASCAL/HED-BSDS/test.lst'
31 |     return image_root, 
image_source 32 | 33 | def main(config): 34 | if config.mode == 'train': 35 | train_loader = get_loader(config) 36 | run = 0 37 | while os.path.exists("%s/run-%d" % (config.save_folder, run)): 38 | run += 1 39 | os.mkdir("%s/run-%d" % (config.save_folder, run)) 40 | os.mkdir("%s/run-%d/models" % (config.save_folder, run)) 41 | config.save_folder = "%s/run-%d" % (config.save_folder, run) 42 | train = Solver(train_loader, None, config) 43 | train.train() 44 | elif config.mode == 'test': 45 | config.test_root, config.test_list = get_test_info(config.sal_mode) 46 | test_loader = get_loader(config, mode='test') 47 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold) 48 | test = Solver(None, test_loader, config) 49 | test.test(test_mode=config.test_mode) 50 | else: 51 | raise IOError("illegal input!!!") 52 | 53 | if __name__ == '__main__': 54 | 55 | vgg_path = './dataset/pretrained/vgg16_20M.pth' 56 | resnet_path = './dataset/pretrained/resnet50_caffe.pth' 57 | resnet_path = './dataset/pretrained/resnet50_caffe.pth' 58 | res2net_path = '/home/shgao/.torch/models/res2net50_26w_4s-06e79181.pth' 59 | parser = argparse.ArgumentParser() 60 | 61 | # Hyper-parameters 62 | parser.add_argument('--n_color', type=int, default=3) 63 | parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4 64 | parser.add_argument('--wd', type=float, default=0.0005) # Weight decay 65 | parser.add_argument('--no-cuda', dest='cuda', action='store_false') 66 | 67 | # Training settings 68 | parser.add_argument('--arch', type=str, default='resnet') # resnet or vgg 69 | parser.add_argument('--pretrained_model', type=str, default=res2net_path) 70 | parser.add_argument('--epoch', type=int, default=11) 71 | parser.add_argument('--batch_size', type=int, default=1) # only support 1 now 72 | parser.add_argument('--num_thread', type=int, default=1) 73 | parser.add_argument('--load', type=str, default='') 74 | parser.add_argument('--save_folder', type=str, 
default='./results_edge') 75 | parser.add_argument('--epoch_save', type=int, default=3) 76 | parser.add_argument('--iter_size', type=int, default=10) 77 | parser.add_argument('--show_every', type=int, default=50) 78 | 79 | # Train data 80 | parser.add_argument('--train_root', type=str, default='') 81 | parser.add_argument('--train_list', type=str, default='') 82 | parser.add_argument('--train_edge_root', type=str, default='') # path for edge data 83 | parser.add_argument('--train_edge_list', type=str, default='') # list file for edge data 84 | 85 | # Testing settings 86 | parser.add_argument('--model', type=str, default=None) # Snapshot 87 | parser.add_argument('--test_fold', type=str, default=None) # Test results saving folder 88 | parser.add_argument('--test_mode', type=int, default=1) # 0->edge, 1->saliency 89 | parser.add_argument('--sal_mode', type=str, default='e') # Test image dataset 90 | 91 | # Misc 92 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 93 | config = parser.parse_args() 94 | 95 | if not os.path.exists(config.save_folder): 96 | os.mkdir(config.save_folder) 97 | 98 | # Get test set info 99 | test_root, test_list = get_test_info(config.sal_mode) 100 | config.test_root = test_root 101 | config.test_list = test_list 102 | 103 | main(config) 104 | -------------------------------------------------------------------------------- /joint_solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from torch.nn import utils, functional as F 4 | from torch.optim import Adam 5 | from torch.autograd import Variable 6 | from torch.backends import cudnn 7 | from networks.joint_poolnet_res2net import build_model, weights_init 8 | import scipy.misc as sm 9 | import numpy as np 10 | import os 11 | import torchvision.utils as vutils 12 | import cv2 13 | import math 14 | import time 15 | 16 | 17 | class Solver(object): 18 | def 
__init__(self, train_loader, test_loader, config): 19 | self.train_loader = train_loader 20 | self.test_loader = test_loader 21 | self.config = config 22 | self.iter_size = config.iter_size 23 | self.show_every = config.show_every 24 | self.lr_decay_epoch = [8,] 25 | self.build_model() 26 | if config.mode == 'test': 27 | print('Loading pre-trained model from %s...' % self.config.model) 28 | if self.config.cuda: 29 | self.net.load_state_dict(torch.load(self.config.model)) 30 | else: 31 | self.net.load_state_dict(torch.load(self.config.model, map_location='cpu')) 32 | self.net.eval() 33 | 34 | # print the network information and parameter numbers 35 | def print_network(self, model, name): 36 | num_params = 0 37 | for p in model.parameters(): 38 | num_params += p.numel() 39 | print(name) 40 | print(model) 41 | print("The number of parameters: {}".format(num_params)) 42 | 43 | # build the network 44 | def build_model(self): 45 | self.net = build_model(self.config.arch) 46 | if self.config.cuda: 47 | self.net = self.net.cuda() 48 | # self.net.train() 49 | self.net.eval() # use_global_stats = True 50 | self.net.apply(weights_init) 51 | if self.config.load == '': 52 | self.net.base.load_pretrained_model(torch.load(self.config.pretrained_model)) 53 | else: 54 | self.net.load_state_dict(torch.load(self.config.load)) 55 | 56 | self.lr = self.config.lr 57 | self.wd = self.config.wd 58 | 59 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 60 | self.print_network(self.net, 'PoolNet Structure') 61 | 62 | def test(self, test_mode=1): 63 | mode_name = ['edge_fuse', 'sal_fuse'] 64 | EPSILON = 1e-8 65 | time_s = time.time() 66 | img_num = len(self.test_loader) 67 | for i, data_batch in enumerate(self.test_loader): 68 | images, name, im_size = data_batch['image'], data_batch['name'][0], np.asarray(data_batch['size']) 69 | if test_mode == 0: 70 | images = images.numpy()[0].transpose((1,2,0)) 71 | scale = [0.5, 1, 1.5, 
2]  # multi-scale testing
72 |                 # scale = [1]  # single-scale testing
73 |                 multi_fuse = np.zeros(im_size, np.float32)
74 |                 for k in range(0, len(scale)):
75 |                     im_ = cv2.resize(images, None, fx=scale[k], fy=scale[k], interpolation=cv2.INTER_LINEAR)
76 |                     im_ = im_.transpose((2, 0, 1))
77 |                     im_ = torch.Tensor(im_[np.newaxis, ...])
78 | 
79 |                     with torch.no_grad():
80 |                         im_ = Variable(im_)
81 |                         if self.config.cuda:
82 |                             im_ = im_.cuda()
83 |                         preds = self.net(im_, mode=test_mode)
84 |                         pred_0 = np.squeeze(torch.sigmoid(preds[1][0]).cpu().data.numpy())
85 |                         pred_1 = np.squeeze(torch.sigmoid(preds[1][1]).cpu().data.numpy())
86 |                         pred_2 = np.squeeze(torch.sigmoid(preds[1][2]).cpu().data.numpy())
87 |                         pred_fuse = np.squeeze(torch.sigmoid(preds[0]).cpu().data.numpy())
88 | 
89 |                         pred = (pred_0 + pred_1 + pred_2 + pred_fuse) / 4
90 |                         pred = (pred - np.min(pred) + EPSILON) / (np.max(pred) - np.min(pred) + EPSILON)
91 | 
92 |                         pred = cv2.resize(pred, (im_size[1], im_size[0]), interpolation=cv2.INTER_LINEAR)
93 |                         multi_fuse += pred
94 | 
95 |                 multi_fuse /= len(scale)
96 |                 multi_fuse = 255 * (1 - multi_fuse)
97 |                 cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name[test_mode] + '.png'), multi_fuse)
98 |             elif test_mode == 1:
99 |                 with torch.no_grad():
100 |                     images = Variable(images)
101 |                     if self.config.cuda:
102 |                         images = images.cuda()
103 |                     preds = self.net(images, mode=test_mode)
104 |                     pred = np.squeeze(torch.sigmoid(preds).cpu().data.numpy())
105 |                     multi_fuse = 255 * pred
106 |                     cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name[test_mode] + '.png'), multi_fuse)
107 |         time_e = time.time()
108 |         print('Speed: %f FPS' % (img_num/(time_e-time_s)))
109 |         print('Test Done!')
110 | 
111 |     # training phase
112 |     def train(self):
113 |         iter_num = 30000  # cap each epoch at 30000 iterations (this number is an arbitrary choice)
114 |         aveGrad = 0
115 |         for epoch in range(self.config.epoch):
116 |             r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
117 |             self.net.zero_grad()
118 |             for i, data_batch in enumerate(self.train_loader):
119 |                 if (i + 1) == iter_num: break
120 |                 edge_image, edge_label, sal_image, sal_label = data_batch['edge_image'], data_batch['edge_label'], data_batch['sal_image'], data_batch['sal_label']
121 |                 if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)):
122 |                     print('IMAGE ERROR, PASSING')
123 |                     continue
124 |                 edge_image, edge_label, sal_image, sal_label = Variable(edge_image), Variable(edge_label), Variable(sal_image), Variable(sal_label)
125 |                 if self.config.cuda:
126 |                     edge_image, edge_label, sal_image, sal_label = edge_image.cuda(), edge_label.cuda(), sal_image.cuda(), sal_label.cuda()
127 | 
128 |                 # edge part
129 |                 edge_pred = self.net(edge_image, mode=0)
130 |                 edge_loss_fuse = bce2d(edge_pred[0], edge_label, reduction='sum')
131 |                 edge_loss_part = []
132 |                 for ix in edge_pred[1]:
133 |                     edge_loss_part.append(bce2d(ix, edge_label, reduction='sum'))
134 |                 edge_loss = (edge_loss_fuse + sum(edge_loss_part)) / (self.iter_size * self.config.batch_size)
135 |                 r_edge_loss += edge_loss.data
136 | 
137 |                 # sal part
138 |                 sal_pred = self.net(sal_image, mode=1)
139 |                 sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum')
140 |                 sal_loss = sal_loss_fuse / (self.iter_size * self.config.batch_size)
141 |                 r_sal_loss += sal_loss.data
142 | 
143 |                 loss = sal_loss + edge_loss
144 |                 r_sum_loss += loss.data
145 | 
146 |                 loss.backward()
147 | 
148 |                 aveGrad += 1
149 | 
150 |                 # accumulate gradients as done in DSS
151 |                 if aveGrad % self.iter_size == 0:
152 |                     self.optimizer.step()
153 |                     self.optimizer.zero_grad()
154 |                     aveGrad = 0
155 | 
156 |                 if i % (self.show_every // self.config.batch_size) == 0:
157 |                     if i == 0:
158 |                         x_showEvery = 1
159 |                     print('epoch: [%2d/%2d], iter: [%5d/%5d] || Edge : %10.4f || Sal : %10.4f || Sum : %10.4f' %
( 160 | epoch, self.config.epoch, i, iter_num, r_edge_loss/x_showEvery, r_sal_loss/x_showEvery, r_sum_loss/x_showEvery)) 161 | print('Learning rate: ' + str(self.lr)) 162 | r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0 163 | 164 | if (epoch + 1) % self.config.epoch_save == 0: 165 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_folder, epoch + 1)) 166 | 167 | if epoch in self.lr_decay_epoch: 168 | self.lr = self.lr * 0.1 169 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 170 | 171 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_folder) 172 | 173 | def bce2d(input, target, reduction='none'): 174 | assert(input.size() == target.size()) 175 | pos = torch.eq(target, 1).float() 176 | neg = torch.eq(target, 0).float() 177 | 178 | num_pos = torch.sum(pos) 179 | num_neg = torch.sum(neg) 180 | num_total = num_pos + num_neg 181 | 182 | alpha = num_neg / num_total 183 | beta = 1.1 * num_pos / num_total 184 | # target pixel = 1 -> weight alpha (fraction of negative pixels) 185 | # target pixel = 0 -> weight beta (1.1 * fraction of positive pixels) 186 | weights = alpha * pos + beta * neg 187 | 188 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 189 | 190 | -------------------------------------------------------------------------------- /joint_train_res2net.sh: -------------------------------------------------------------------------------- 1 | #!
/bin/bash 2 | 3 | python joint_main.py --arch res2net --train_root ./data/DUTS/DUTS-TR --train_list ./data/DUTS/DUTS-TR/train_pair.lst --train_edge_root ./data/HED-BSDS_PASCAL --train_edge_list ./data/HED-BSDS_PASCAL/bsds_pascal_train_pair_r_val_r_small.lst 4 | # you can optionally change the --lr and --wd params 5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from dataset.dataset import get_loader 4 | from solver import Solver 5 | 6 | def get_test_info(sal_mode='e'): 7 | if sal_mode == 'e': 8 | image_root = './data/ECSSD/Imgs/' 9 | image_source = './data/ECSSD/test.lst' 10 | elif sal_mode == 'p': 11 | image_root = './data/PASCALS/Imgs/' 12 | image_source = './data/PASCALS/test.lst' 13 | elif sal_mode == 'd': 14 | image_root = './data/DUTOMRON/Imgs/' 15 | image_source = './data/DUTOMRON/test.lst' 16 | elif sal_mode == 'h': 17 | image_root = './data/HKU-IS/Imgs/' 18 | image_source = './data/HKU-IS/test.lst' 19 | elif sal_mode == 's': 20 | image_root = './data/SOD/Imgs/' 21 | image_source = './data/SOD/test.lst' 22 | elif sal_mode == 't': 23 | image_root = './data/DUTS-TE/Imgs/' 24 | image_source = './data/DUTS-TE/test.lst' 25 | elif sal_mode == 'm_r': # for speed test 26 | image_root = './data/MSRA/Imgs_resized/' 27 | image_source = './data/MSRA/test_resized.lst' 28 | 29 | return image_root, image_source 30 | 31 | def main(config): 32 | if config.mode == 'train': 33 | train_loader = get_loader(config) 34 | run = 0 35 | while os.path.exists("%s/run-%d" % (config.save_folder, run)): 36 | run += 1 37 | os.mkdir("%s/run-%d" % (config.save_folder, run)) 38 | os.mkdir("%s/run-%d/models" % (config.save_folder, run)) 39 | config.save_folder = "%s/run-%d" % (config.save_folder, run) 40 | train = Solver(train_loader, None, config) 41 | train.train() 42 | elif config.mode == 'test': 43 | config.test_root,
config.test_list = get_test_info(config.sal_mode) 44 | test_loader = get_loader(config, mode='test') 45 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold) 46 | test = Solver(None, test_loader, config) 47 | test.test() 48 | else: 49 | raise ValueError("mode must be 'train' or 'test'") 50 | 51 | if __name__ == '__main__': 52 | 53 | vgg_path = './dataset/pretrained/vgg16_20M.pth' 54 | resnet_path = './dataset/pretrained/resnet50_caffe.pth' 55 | res2net_path = '/home/shgao/.torch/models/res2net50_26w_4s-06e79181.pth' 56 | parser = argparse.ArgumentParser() 57 | 58 | # Hyper-parameters 59 | parser.add_argument('--n_color', type=int, default=3) 60 | parser.add_argument('--lr', type=float, default=5e-5) # Learning rate resnet:5e-5, vgg:1e-4 61 | parser.add_argument('--wd', type=float, default=0.0005) # Weight decay 62 | parser.add_argument('--no-cuda', dest='cuda', action='store_false') 63 | 64 | # Training settings 65 | parser.add_argument('--arch', type=str, default='res2net') # vgg, resnet or res2net 66 | parser.add_argument('--pretrained_model', type=str, default=res2net_path) 67 | parser.add_argument('--epoch', type=int, default=24) 68 | parser.add_argument('--batch_size', type=int, default=1) # only batch_size 1 is supported now 69 | parser.add_argument('--num_thread', type=int, default=1) 70 | parser.add_argument('--load', type=str, default='') 71 | parser.add_argument('--save_folder', type=str, default='./results') 72 | parser.add_argument('--epoch_save', type=int, default=3) 73 | parser.add_argument('--iter_size', type=int, default=10) 74 | parser.add_argument('--show_every', type=int, default=50) 75 | 76 | # Train data 77 | parser.add_argument('--train_root', type=str, default='') 78 | parser.add_argument('--train_list', type=str, default='') 79 | 80 | # Testing settings 81 | parser.add_argument('--model', type=str, default=None) # Snapshot 82 | parser.add_argument('--test_fold', type=str, default=None) # Test results saving folder 83 | parser.add_argument('--sal_mode', type=str,
default='e') # Test image dataset 84 | 85 | # Misc 86 | parser.add_argument('--mode', type=str, default='train', choices=['train', 'test']) 87 | config = parser.parse_args() 88 | 89 | if not os.path.exists(config.save_folder): 90 | os.mkdir(config.save_folder) 91 | 92 | # Get test set info 93 | test_root, test_list = get_test_info(config.sal_mode) 94 | config.test_root = test_root 95 | config.test_list = test_list 96 | 97 | main(config) 98 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Res2Net/Res2Net-PoolNet/7bef0652e83a6c4ebe4ed47f1b03ab5b7b16074a/networks/__init__.py -------------------------------------------------------------------------------- /networks/deeplab_res2net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | affine_par = True 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | 
residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottle2neck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, dilation_ = 1, downsample=None, baseWidth=26, scale = 4, stype='normal'): 49 | super(Bottle2neck, self).__init__() 50 | width = int(math.floor(planes * (baseWidth/64.0))) 51 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, stride=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(width*scale,affine = affine_par) 53 | for i in self.bn1.parameters(): 54 | i.requires_grad = False 55 | if scale == 1: 56 | self.nums = 1 57 | else: 58 | self.nums = scale -1 59 | if stype == 'stage': 60 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 61 | convs = [] 62 | bns = [] 63 | for i in range(self.nums): 64 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, dilation = dilation_, padding=dilation_, bias=False)) 65 | bns.append(nn.BatchNorm2d(width, affine = affine_par)) 66 | self.convs = nn.ModuleList(convs) 67 | self.bns = nn.ModuleList(bns) 68 | for j in range(self.nums): 69 | for i in self.bns[j].parameters(): 70 | i.requires_grad = False 71 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, affine = affine_par) 73 | for i in self.bn3.parameters(): 74 | i.requires_grad = False 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stype = stype 78 | self.scale = scale 79 | self.width = width 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | spx = torch.split(out, self.width, 1) 90 | for i in range(self.nums): 91 | if i==0 or self.stype=='stage': 92 | sp = spx[i] 93 | else: 94 | sp = sp + spx[i] 95 | sp = self.convs[i](sp) 96 | sp = self.relu(self.bns[i](sp)) 97 | if i==0: 98 
| out = sp 99 | else: 100 | out = torch.cat((out, sp), 1) 101 | if self.scale != 1 and self.stype=='normal': 102 | out = torch.cat((out, spx[self.nums]),1) 103 | elif self.scale != 1 and self.stype=='stage': 104 | out = torch.cat((out, self.pool(spx[self.nums])),1) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | residual = self.downsample(x) 111 | 112 | out += residual 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | class Res2Net(nn.Module): 118 | def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): 119 | self.inplanes = 64 120 | super(Res2Net, self).__init__() 121 | self.baseWidth = baseWidth 122 | self.scale = scale 123 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 124 | bias=False) 125 | self.bn1 = nn.BatchNorm2d(64,affine = affine_par) 126 | for i in self.bn1.parameters(): 127 | i.requires_grad = False 128 | self.relu = nn.ReLU(inplace=True) 129 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 130 | self.layer1 = self._make_layer(block, 64, layers[0]) 131 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 132 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 133 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation__ = 2) 134 | self.avgpool = nn.AvgPool2d(7, stride=1) 135 | self.fc = nn.Linear(512 * block.expansion, num_classes) 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 139 | m.weight.data.normal_(0, 0.01) 140 | elif isinstance(m, nn.BatchNorm2d): 141 | m.weight.data.fill_(1) 142 | m.bias.data.zero_() 143 | 144 | def _make_layer(self, block, planes, blocks, stride=1,dilation__ = 1): 145 | downsample = None 146 | if stride != 1 or self.inplanes != planes * block.expansion or dilation__ == 2 or dilation__ == 4: 147 | downsample = nn.Sequential( 148 | 
nn.Conv2d(self.inplanes, planes * block.expansion, 149 | kernel_size=1, stride=stride, bias=False), 150 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par), 151 | ) 152 | for i in downsample._modules['1'].parameters(): 153 | i.requires_grad = False 154 | layers = [] 155 | layers.append(block(self.inplanes, planes, stride,dilation_=dilation__, downsample = downsample, 156 | stype='stage', baseWidth = self.baseWidth, scale=self.scale)) 157 | self.inplanes = planes * block.expansion 158 | for i in range(1, blocks): 159 | layers.append(block(self.inplanes, planes, dilation_=dilation__, baseWidth = self.baseWidth, scale=self.scale)) 160 | 161 | return nn.Sequential(*layers) 162 | 163 | def forward(self, x): 164 | tmp_x = [] 165 | x = self.conv1(x) 166 | x = self.bn1(x) 167 | x = self.relu(x) 168 | tmp_x.append(x) 169 | x = self.maxpool(x) 170 | 171 | x = self.layer1(x) 172 | tmp_x.append(x) 173 | x = self.layer2(x) 174 | tmp_x.append(x) 175 | x = self.layer3(x) 176 | tmp_x.append(x) 177 | x = self.layer4(x) 178 | tmp_x.append(x) 179 | 180 | return tmp_x 181 | 182 | 183 | class Res2Net_locate(nn.Module): 184 | def __init__(self, block, layers): 185 | super(Res2Net_locate,self).__init__() 186 | self.resnet = Res2Net(block, layers) 187 | self.in_planes = 512 188 | self.out_planes = [512, 256, 256, 128] 189 | 190 | self.ppms_pre = nn.Conv2d(2048, self.in_planes, 1, 1, bias=False) 191 | ppms, infos = [], [] 192 | for ii in [1, 3, 5]: 193 | ppms.append(nn.Sequential(nn.AdaptiveAvgPool2d(ii), nn.Conv2d(self.in_planes, self.in_planes, 1, 1, bias=False), nn.ReLU(inplace=True))) 194 | self.ppms = nn.ModuleList(ppms) 195 | 196 | self.ppm_cat = nn.Sequential(nn.Conv2d(self.in_planes * 4, self.in_planes, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 197 | for ii in self.out_planes: 198 | infos.append(nn.Sequential(nn.Conv2d(self.in_planes, ii, 3, 1, 1, bias=False), nn.ReLU(inplace=True))) 199 | self.infos = nn.ModuleList(infos) 200 | 201 | for m in self.modules(): 202 | if 
isinstance(m, nn.Conv2d): 203 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 204 | m.weight.data.normal_(0, 0.01) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | m.weight.data.fill_(1) 207 | m.bias.data.zero_() 208 | 209 | def load_pretrained_model(self, model): 210 | self.resnet.load_state_dict(model, strict=False) 211 | 212 | def forward(self, x): 213 | x_size = x.size()[2:] 214 | xs = self.resnet(x) 215 | 216 | xs_1 = self.ppms_pre(xs[-1]) 217 | xls = [xs_1] 218 | for k in range(len(self.ppms)): 219 | xls.append(F.interpolate(self.ppms[k](xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True)) 220 | xls = self.ppm_cat(torch.cat(xls, dim=1)) 221 | 222 | infos = [] 223 | for k in range(len(self.infos)): 224 | infos.append(self.infos[k](F.interpolate(xls, xs[len(self.infos) - 1 - k].size()[2:], mode='bilinear', align_corners=True))) 225 | 226 | return xs, infos 227 | 228 | def res2net50_locate(): 229 | model = Res2Net_locate(Bottle2neck, [3, 4, 6, 3]) 230 | return model 231 | -------------------------------------------------------------------------------- /networks/deeplab_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | affine_par = True 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 23 | self.downsample = downsample 24 | 
self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, dilation_ = 1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 51 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par) 52 | for i in self.bn1.parameters(): 53 | i.requires_grad = False 54 | padding = 1 55 | if dilation_ == 2: 56 | padding = 2 57 | elif dilation_ == 4: 58 | padding = 4 59 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 60 | padding=padding, bias=False, dilation = dilation_) 61 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par) 62 | for i in self.bn2.parameters(): 63 | i.requires_grad = False 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par) 66 | for i in self.bn3.parameters(): 67 | i.requires_grad = False 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | class ResNet(nn.Module): 95 | def __init__(self, block, layers): 96 | self.inplanes = 64 97 | super(ResNet, self).__init__() 98 | 
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 99 | bias=False) 100 | self.bn1 = nn.BatchNorm2d(64,affine = affine_par) 101 | for i in self.bn1.parameters(): 102 | i.requires_grad = False 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation__ = 2) 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 113 | m.weight.data.normal_(0, 0.01) 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1,dilation__ = 1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion or dilation__ == 2 or dilation__ == 4: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par), 125 | ) 126 | for i in downsample._modules['1'].parameters(): 127 | i.requires_grad = False 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride,dilation_=dilation__, downsample = downsample )) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes,dilation_=dilation__)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | tmp_x = [] 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | tmp_x.append(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | tmp_x.append(x) 146 | x = self.layer2(x) 147 | tmp_x.append(x) 148 | x 
= self.layer3(x) 149 | tmp_x.append(x) 150 | x = self.layer4(x) 151 | tmp_x.append(x) 152 | 153 | return tmp_x 154 | 155 | 156 | class ResNet_locate(nn.Module): 157 | def __init__(self, block, layers): 158 | super(ResNet_locate,self).__init__() 159 | self.resnet = ResNet(block, layers) 160 | self.in_planes = 512 161 | self.out_planes = [512, 256, 256, 128] 162 | 163 | self.ppms_pre = nn.Conv2d(2048, self.in_planes, 1, 1, bias=False) 164 | ppms, infos = [], [] 165 | for ii in [1, 3, 5]: 166 | ppms.append(nn.Sequential(nn.AdaptiveAvgPool2d(ii), nn.Conv2d(self.in_planes, self.in_planes, 1, 1, bias=False), nn.ReLU(inplace=True))) 167 | self.ppms = nn.ModuleList(ppms) 168 | 169 | self.ppm_cat = nn.Sequential(nn.Conv2d(self.in_planes * 4, self.in_planes, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 170 | for ii in self.out_planes: 171 | infos.append(nn.Sequential(nn.Conv2d(self.in_planes, ii, 3, 1, 1, bias=False), nn.ReLU(inplace=True))) 172 | self.infos = nn.ModuleList(infos) 173 | 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 177 | m.weight.data.normal_(0, 0.01) 178 | elif isinstance(m, nn.BatchNorm2d): 179 | m.weight.data.fill_(1) 180 | m.bias.data.zero_() 181 | 182 | def load_pretrained_model(self, model): 183 | self.resnet.load_state_dict(model, strict=False) 184 | 185 | def forward(self, x): 186 | x_size = x.size()[2:] 187 | xs = self.resnet(x) 188 | 189 | xs_1 = self.ppms_pre(xs[-1]) 190 | xls = [xs_1] 191 | for k in range(len(self.ppms)): 192 | xls.append(F.interpolate(self.ppms[k](xs_1), xs_1.size()[2:], mode='bilinear', align_corners=True)) 193 | xls = self.ppm_cat(torch.cat(xls, dim=1)) 194 | 195 | infos = [] 196 | for k in range(len(self.infos)): 197 | infos.append(self.infos[k](F.interpolate(xls, xs[len(self.infos) - 1 - k].size()[2:], mode='bilinear', align_corners=True))) 198 | 199 | return xs, infos 200 | 201 | def resnet50_locate(): 202 | model = ResNet_locate(Bottleneck, 
[3, 4, 6, 3]) 203 | return model 204 | -------------------------------------------------------------------------------- /networks/joint_poolnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_resnet import resnet50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 256, 'edgeinfoc':[48,128], 'block': [[512, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16], True]} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 256, 'edgeinfoc':[64,128], 'block': [[512, [16]], [256, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16, 16], True]} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | 
pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class BlockLayer(nn.Module): 63 | def __init__(self, k_in, k_out_list): 64 | super(BlockLayer, self).__init__() 65 | up_in1, up_mid1, up_in2, up_mid2, up_out = [], [], [], [], [] 66 | 67 | for k in k_out_list: 68 | up_in1.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 69 | up_mid1.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 70 | up_in2.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 71 | up_mid2.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 72 | up_out.append(nn.Conv2d(k_in, k, 1, 1, bias=False)) 73 | 74 | self.block_in1 = nn.ModuleList(up_in1) 75 | self.block_in2 = nn.ModuleList(up_in2) 76 | self.block_mid1 = nn.ModuleList(up_mid1) 77 | self.block_mid2 = nn.ModuleList(up_mid2) 78 | self.block_out = nn.ModuleList(up_out) 79 | self.relu = nn.ReLU() 80 | 81 | def forward(self, x, mode=0): 82 | x_tmp = self.relu(x + self.block_mid1[mode](self.block_in1[mode](x))) 83 | # x_tmp = 
self.block_mid2[mode](self.block_in2[mode](self.relu(x + x_tmp))) 84 | x_tmp = self.relu(x_tmp + self.block_mid2[mode](self.block_in2[mode](x_tmp))) 85 | x_tmp = self.block_out[mode](x_tmp) 86 | 87 | return x_tmp 88 | 89 | class EdgeInfoLayerC(nn.Module): 90 | def __init__(self, k_in, k_out): 91 | super(EdgeInfoLayerC, self).__init__() 92 | self.trans = nn.Sequential(nn.Conv2d(k_in, k_in, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 93 | nn.Conv2d(k_in, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 94 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 95 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 96 | 97 | def forward(self, x, x_size): 98 | tmp_x = [] 99 | for i_x in x: 100 | tmp_x.append(F.interpolate(i_x, x_size[2:], mode='bilinear', align_corners=True)) 101 | x = self.trans(torch.cat(tmp_x, dim=1)) 102 | return x 103 | 104 | class FuseLayer1(nn.Module): 105 | def __init__(self, list_k, deep_sup): 106 | super(FuseLayer1, self).__init__() 107 | up = [] 108 | for i in range(len(list_k)): 109 | up.append(nn.Conv2d(list_k[i], 1, 1, 1)) 110 | self.trans = nn.ModuleList(up) 111 | self.fuse = nn.Conv2d(len(list_k), 1, 1, 1) 112 | self.deep_sup = deep_sup 113 | 114 | def forward(self, list_x, x_size): 115 | up_x = [] 116 | for i, i_x in enumerate(list_x): 117 | up_x.append(F.interpolate(self.trans[i](i_x), x_size[2:], mode='bilinear', align_corners=True)) 118 | out_fuse = self.fuse(torch.cat(up_x, dim = 1)) 119 | if self.deep_sup: 120 | out_all = [] 121 | for up_i in up_x: 122 | out_all.append(up_i) 123 | return [out_fuse, out_all] 124 | else: 125 | return [out_fuse] 126 | 127 | class ScoreLayer(nn.Module): 128 | def __init__(self, k): 129 | super(ScoreLayer, self).__init__() 130 | self.score = nn.Conv2d(k ,1, 3, 1, 1) 131 | 132 | def forward(self, x, x_size=None): 133 | x = self.score(x) 134 | if x_size is not None: 135 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 136 | return x 137 | 138 | 
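The per-stage arguments that `extra_layer` hands to `DeepPoolLayer` are stored column-wise across the four parallel lists in `config['deep_pool']`. A minimal standalone sketch (values copied from `config_resnet` in this file) of how they regroup into `(k, k_out, need_x2, need_fuse)` tuples, one per `DeepPoolLayer`:

```python
# Sanity-check sketch: extra_layer reads config['deep_pool'] column-wise.
# zip(*...) transposes the four parallel lists into one argument tuple
# per DeepPoolLayer: (k, k_out, need_x2, need_fuse).
deep_pool = [[512, 512, 256, 256, 128],   # k (input channels)
             [512, 256, 256, 128, 128],   # k_out (output channels)
             [False, True, True, True, False],  # need_x2 (upsample to x2's size)
             [True, True, True, True, False]]   # need_fuse (fuse with x2, x3)

stage_args = list(zip(*deep_pool))
for k, k_out, need_x2, need_fuse in stage_args:
    print(k, k_out, need_x2, need_fuse)
```

Note that the first stage has `need_x2=False` but `need_fuse=True`, while the last stage (called as `self.deep_pool[-1](merge)` in `PoolNet.forward`) fuses nothing, matching the `x2=None, x3=None` defaults in `DeepPoolLayer.forward`.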
def extra_layer(base_model_cfg, base): 139 | if base_model_cfg == 'vgg': 140 | config = config_vgg 141 | elif base_model_cfg == 'resnet': 142 | config = config_resnet 143 | convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers = [], [], [], [], [], [] 144 | convert_layers = ConvertLayer(config['convert']) 145 | 146 | for k in config['block']: 147 | block_layers += [BlockLayer(k[0], k[1])] 148 | 149 | for i in range(len(config['deep_pool'][0])): 150 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 151 | 152 | fuse_layers = FuseLayer1(config['fuse'][0], config['fuse'][1]) 153 | edgeinfo_layers = EdgeInfoLayerC(config['edgeinfoc'][0], config['edgeinfoc'][1]) 154 | score_layers = ScoreLayer(config['score']) 155 | 156 | return base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers 157 | 158 | 159 | class PoolNet(nn.Module): 160 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers): 161 | super(PoolNet, self).__init__() 162 | self.base_model_cfg = base_model_cfg 163 | self.base = base 164 | self.block = nn.ModuleList(block_layers) 165 | self.deep_pool = nn.ModuleList(deep_pool_layers) 166 | self.fuse = fuse_layers 167 | self.edgeinfo = edgeinfo_layers 168 | self.score = score_layers 169 | if self.base_model_cfg == 'resnet': 170 | self.convert = convert_layers 171 | 172 | def forward(self, x, mode): 173 | x_size = x.size() 174 | conv2merge, infos = self.base(x) 175 | if self.base_model_cfg == 'resnet': 176 | conv2merge = self.convert(conv2merge) 177 | conv2merge = conv2merge[::-1] 178 | 179 | edge_merge = [] 180 | merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 181 | edge_merge.append(merge) 182 | for k in range(1, len(conv2merge)-1): 183 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 
184 | edge_merge.append(merge) 185 | 186 | if mode == 0: 187 | edge_merge = [self.block[i](kk) for i, kk in enumerate(edge_merge)] 188 | merge = self.fuse(edge_merge, x_size) 189 | elif mode == 1: 190 | merge = self.deep_pool[-1](merge) 191 | edge_merge = [self.block[i](kk).detach() for i, kk in enumerate(edge_merge)] 192 | edge_merge = self.edgeinfo(edge_merge, merge.size()) 193 | merge = self.score(torch.cat([merge, edge_merge], dim=1), x_size) 194 | return merge 195 | 196 | def build_model(base_model_cfg='vgg'): 197 | if base_model_cfg == 'vgg': 198 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 199 | elif base_model_cfg == 'resnet': 200 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, resnet50_locate())) 201 | 202 | def weights_init(m): 203 | if isinstance(m, nn.Conv2d): 204 | m.weight.data.normal_(0, 0.01) 205 | if m.bias is not None: 206 | m.bias.data.zero_() 207 | -------------------------------------------------------------------------------- /networks/joint_poolnet_res2net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_res2net import res2net50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 256, 'edgeinfoc':[48,128], 'block': [[512, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16], True]} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 256, 
'edgeinfoc':[64,128], 'block': [[512, [16]], [256, [16]], [256, [16]], [128, [16]]], 'fuse': [[16, 16, 16, 16], True]} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class BlockLayer(nn.Module): 63 | def __init__(self, k_in, k_out_list): 64 | super(BlockLayer, self).__init__() 65 | up_in1, up_mid1, up_in2, up_mid2, up_out = [], [], [], [], [] 66 | 67 | for k in 
k_out_list: 68 | up_in1.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 69 | up_mid1.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 70 | up_in2.append(nn.Conv2d(k_in, k_in//4, 1, 1, bias=False)) 71 | up_mid2.append(nn.Sequential(nn.Conv2d(k_in//4, k_in//4, 3, 1, 1, bias=False), nn.Conv2d(k_in//4, k_in, 1, 1, bias=False))) 72 | up_out.append(nn.Conv2d(k_in, k, 1, 1, bias=False)) 73 | 74 | self.block_in1 = nn.ModuleList(up_in1) 75 | self.block_in2 = nn.ModuleList(up_in2) 76 | self.block_mid1 = nn.ModuleList(up_mid1) 77 | self.block_mid2 = nn.ModuleList(up_mid2) 78 | self.block_out = nn.ModuleList(up_out) 79 | self.relu = nn.ReLU() 80 | 81 | def forward(self, x, mode=0): 82 | x_tmp = self.relu(x + self.block_mid1[mode](self.block_in1[mode](x))) 83 | # x_tmp = self.block_mid2[mode](self.block_in2[mode](self.relu(x + x_tmp))) 84 | x_tmp = self.relu(x_tmp + self.block_mid2[mode](self.block_in2[mode](x_tmp))) 85 | x_tmp = self.block_out[mode](x_tmp) 86 | 87 | return x_tmp 88 | 89 | class EdgeInfoLayerC(nn.Module): 90 | def __init__(self, k_in, k_out): 91 | super(EdgeInfoLayerC, self).__init__() 92 | self.trans = nn.Sequential(nn.Conv2d(k_in, k_in, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 93 | nn.Conv2d(k_in, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 94 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True), 95 | nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 96 | 97 | def forward(self, x, x_size): 98 | tmp_x = [] 99 | for i_x in x: 100 | tmp_x.append(F.interpolate(i_x, x_size[2:], mode='bilinear', align_corners=True)) 101 | x = self.trans(torch.cat(tmp_x, dim=1)) 102 | return x 103 | 104 | class FuseLayer1(nn.Module): 105 | def __init__(self, list_k, deep_sup): 106 | super(FuseLayer1, self).__init__() 107 | up = [] 108 | for i in range(len(list_k)): 109 | up.append(nn.Conv2d(list_k[i], 1, 1, 1)) 110 | self.trans = nn.ModuleList(up) 111 | self.fuse = 
nn.Conv2d(len(list_k), 1, 1, 1) 112 | self.deep_sup = deep_sup 113 | 114 | def forward(self, list_x, x_size): 115 | up_x = [] 116 | for i, i_x in enumerate(list_x): 117 | up_x.append(F.interpolate(self.trans[i](i_x), x_size[2:], mode='bilinear', align_corners=True)) 118 | out_fuse = self.fuse(torch.cat(up_x, dim = 1)) 119 | if self.deep_sup: 120 | out_all = [] 121 | for up_i in up_x: 122 | out_all.append(up_i) 123 | return [out_fuse, out_all] 124 | else: 125 | return [out_fuse] 126 | 127 | class ScoreLayer(nn.Module): 128 | def __init__(self, k): 129 | super(ScoreLayer, self).__init__() 130 | self.score = nn.Conv2d(k ,1, 3, 1, 1) 131 | 132 | def forward(self, x, x_size=None): 133 | x = self.score(x) 134 | if x_size is not None: 135 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 136 | return x 137 | 138 | def extra_layer(base_model_cfg, base): 139 | if base_model_cfg == 'vgg': 140 | config = config_vgg 141 | elif base_model_cfg == 'res2net': 142 | config = config_resnet 143 | convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers = [], [], [], [], [], [] 144 | convert_layers = ConvertLayer(config['convert']) 145 | 146 | for k in config['block']: 147 | block_layers += [BlockLayer(k[0], k[1])] 148 | 149 | for i in range(len(config['deep_pool'][0])): 150 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 151 | 152 | fuse_layers = FuseLayer1(config['fuse'][0], config['fuse'][1]) 153 | edgeinfo_layers = EdgeInfoLayerC(config['edgeinfoc'][0], config['edgeinfoc'][1]) 154 | score_layers = ScoreLayer(config['score']) 155 | 156 | return base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers 157 | 158 | 159 | class PoolNet(nn.Module): 160 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, block_layers, fuse_layers, edgeinfo_layers, score_layers): 161 | 
super(PoolNet, self).__init__() 162 | self.base_model_cfg = base_model_cfg 163 | self.base = base 164 | self.block = nn.ModuleList(block_layers) 165 | self.deep_pool = nn.ModuleList(deep_pool_layers) 166 | self.fuse = fuse_layers 167 | self.edgeinfo = edgeinfo_layers 168 | self.score = score_layers 169 | if self.base_model_cfg == 'res2net': 170 | self.convert = convert_layers 171 | 172 | def forward(self, x, mode): 173 | x_size = x.size() 174 | conv2merge, infos = self.base(x) 175 | if self.base_model_cfg == 'res2net': 176 | conv2merge = self.convert(conv2merge) 177 | conv2merge = conv2merge[::-1] 178 | 179 | edge_merge = [] 180 | merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 181 | edge_merge.append(merge) 182 | for k in range(1, len(conv2merge)-1): 183 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 184 | edge_merge.append(merge) 185 | 186 | if mode == 0: 187 | edge_merge = [self.block[i](kk) for i, kk in enumerate(edge_merge)] 188 | merge = self.fuse(edge_merge, x_size) 189 | elif mode == 1: 190 | merge = self.deep_pool[-1](merge) 191 | edge_merge = [self.block[i](kk).detach() for i, kk in enumerate(edge_merge)] 192 | edge_merge = self.edgeinfo(edge_merge, merge.size()) 193 | merge = self.score(torch.cat([merge, edge_merge], dim=1), x_size) 194 | return merge 195 | 196 | def build_model(base_model_cfg='vgg'): 197 | if base_model_cfg == 'vgg': 198 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 199 | elif base_model_cfg == 'res2net': 200 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, res2net50_locate())) 201 | 202 | def weights_init(m): 203 | if isinstance(m, nn.Conv2d): 204 | m.weight.data.normal_(0, 0.01) 205 | if m.bias is not None: 206 | m.bias.data.zero_() 207 | -------------------------------------------------------------------------------- /networks/poolnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import 
nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_resnet import resnet50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 128} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 128} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = 
self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class ScoreLayer(nn.Module): 63 | def __init__(self, k): 64 | super(ScoreLayer, self).__init__() 65 | self.score = nn.Conv2d(k, 1, 1, 1) 66 | 67 | def forward(self, x, x_size=None): 68 | x = self.score(x) 69 | if x_size is not None: 70 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 71 | return x 72 | 73 | def extra_layer(base_model_cfg, base): 74 | if base_model_cfg == 'vgg': 75 | config = config_vgg 76 | elif base_model_cfg == 'resnet': 77 | config = config_resnet 78 | convert_layers, deep_pool_layers, score_layers = [], [], [] 79 | convert_layers = ConvertLayer(config['convert']) 80 | 81 | for i in range(len(config['deep_pool'][0])): 82 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 83 | 84 | score_layers = ScoreLayer(config['score']) 85 | 86 | return base, convert_layers, deep_pool_layers, score_layers 87 | 88 | 89 | class PoolNet(nn.Module): 90 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, score_layers): 91 | super(PoolNet, self).__init__() 92 | self.base_model_cfg = base_model_cfg 93 | self.base = base 94 | self.deep_pool = nn.ModuleList(deep_pool_layers) 95 | self.score = score_layers 96 | if self.base_model_cfg == 'resnet': 97 | self.convert = convert_layers 98 | 99 | def forward(self, x): 100 | x_size = x.size() 101 | conv2merge, infos = self.base(x) 102 | if self.base_model_cfg == 'resnet': 103 | conv2merge = self.convert(conv2merge) 104 | conv2merge = conv2merge[::-1] 105 | 106 | 107 |
merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 108 | for k in range(1, len(conv2merge)-1): 109 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 110 | 111 | merge = self.deep_pool[-1](merge) 112 | merge = self.score(merge, x_size) 113 | return merge 114 | 115 | def build_model(base_model_cfg='vgg'): 116 | if base_model_cfg == 'vgg': 117 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 118 | elif base_model_cfg == 'resnet': 119 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, resnet50_locate())) 120 | 121 | def weights_init(m): 122 | if isinstance(m, nn.Conv2d): 123 | m.weight.data.normal_(0, 0.01) 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | -------------------------------------------------------------------------------- /networks/poolnet_res2net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | import math 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from .deeplab_res2net import res2net50_locate 10 | from .vgg import vgg16_locate 11 | 12 | 13 | config_vgg = {'convert': [[128,256,512,512,512],[64,128,256,512,512]], 'deep_pool': [[512, 512, 256, 128], [512, 256, 128, 128], [True, True, True, False], [True, True, True, False]], 'score': 128} # no convert layer, no conv6 14 | 15 | config_resnet = {'convert': [[64,256,512,1024,2048],[128,256,256,512,512]], 'deep_pool': [[512, 512, 256, 256, 128], [512, 256, 256, 128, 128], [False, True, True, True, False], [True, True, True, True, False]], 'score': 128} 16 | 17 | class ConvertLayer(nn.Module): 18 | def __init__(self, list_k): 19 | super(ConvertLayer, self).__init__() 20 | up = [] 21 | for i in range(len(list_k[0])): 22 | up.append(nn.Sequential(nn.Conv2d(list_k[0][i], list_k[1][i], 1, 1, bias=False), nn.ReLU(inplace=True))) 23 | self.convert0 = 
nn.ModuleList(up) 24 | 25 | def forward(self, list_x): 26 | resl = [] 27 | for i in range(len(list_x)): 28 | resl.append(self.convert0[i](list_x[i])) 29 | return resl 30 | 31 | class DeepPoolLayer(nn.Module): 32 | def __init__(self, k, k_out, need_x2, need_fuse): 33 | super(DeepPoolLayer, self).__init__() 34 | self.pools_sizes = [2,4,8] 35 | self.need_x2 = need_x2 36 | self.need_fuse = need_fuse 37 | pools, convs = [],[] 38 | for i in self.pools_sizes: 39 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 40 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 41 | self.pools = nn.ModuleList(pools) 42 | self.convs = nn.ModuleList(convs) 43 | self.relu = nn.ReLU() 44 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 45 | if self.need_fuse: 46 | self.conv_sum_c = nn.Conv2d(k_out, k_out, 3, 1, 1, bias=False) 47 | 48 | def forward(self, x, x2=None, x3=None): 49 | x_size = x.size() 50 | resl = x 51 | for i in range(len(self.pools_sizes)): 52 | y = self.convs[i](self.pools[i](x)) 53 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 54 | resl = self.relu(resl) 55 | if self.need_x2: 56 | resl = F.interpolate(resl, x2.size()[2:], mode='bilinear', align_corners=True) 57 | resl = self.conv_sum(resl) 58 | if self.need_fuse: 59 | resl = self.conv_sum_c(torch.add(torch.add(resl, x2), x3)) 60 | return resl 61 | 62 | class ScoreLayer(nn.Module): 63 | def __init__(self, k): 64 | super(ScoreLayer, self).__init__() 65 | self.score = nn.Conv2d(k ,1, 1, 1) 66 | 67 | def forward(self, x, x_size=None): 68 | x = self.score(x) 69 | if x_size is not None: 70 | x = F.interpolate(x, x_size[2:], mode='bilinear', align_corners=True) 71 | return x 72 | 73 | def extra_layer(base_model_cfg, vgg): 74 | if base_model_cfg == 'vgg': 75 | config = config_vgg 76 | elif base_model_cfg == 'res2net': 77 | config = config_resnet 78 | convert_layers, deep_pool_layers, score_layers = [], [], [] 79 | convert_layers = ConvertLayer(config['convert']) 80 | 81 
| for i in range(len(config['deep_pool'][0])): 82 | deep_pool_layers += [DeepPoolLayer(config['deep_pool'][0][i], config['deep_pool'][1][i], config['deep_pool'][2][i], config['deep_pool'][3][i])] 83 | 84 | score_layers = ScoreLayer(config['score']) 85 | 86 | return vgg, convert_layers, deep_pool_layers, score_layers 87 | 88 | 89 | class PoolNet(nn.Module): 90 | def __init__(self, base_model_cfg, base, convert_layers, deep_pool_layers, score_layers): 91 | super(PoolNet, self).__init__() 92 | self.base_model_cfg = base_model_cfg 93 | self.base = base 94 | self.deep_pool = nn.ModuleList(deep_pool_layers) 95 | self.score = score_layers 96 | if self.base_model_cfg == 'res2net': 97 | self.convert = convert_layers 98 | 99 | def forward(self, x): 100 | x_size = x.size() 101 | conv2merge, infos = self.base(x) 102 | if self.base_model_cfg == 'res2net': 103 | conv2merge = self.convert(conv2merge) 104 | conv2merge = conv2merge[::-1] 105 | 106 | 107 | merge = self.deep_pool[0](conv2merge[0], conv2merge[1], infos[0]) 108 | for k in range(1, len(conv2merge)-1): 109 | merge = self.deep_pool[k](merge, conv2merge[k+1], infos[k]) 110 | 111 | merge = self.deep_pool[-1](merge) 112 | merge = self.score(merge, x_size) 113 | return merge 114 | 115 | def build_model(base_model_cfg='res2net'): 116 | print('base_model_cfg:', base_model_cfg) 117 | if base_model_cfg == 'vgg': 118 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, vgg16_locate())) 119 | elif base_model_cfg == 'res2net': 120 | return PoolNet(base_model_cfg, *extra_layer(base_model_cfg, res2net50_locate())) 121 | 122 | def weights_init(m): 123 | if isinstance(m, nn.Conv2d): 124 | m.weight.data.normal_(0, 0.01) 125 | if m.bias is not None: 126 | m.bias.data.zero_() 127 | 128 | if __name__ == '__main__': 129 | #images = torch.rand(2, 3, 224, 224) 130 | images = torch.rand(1, 3, 224, 224).cuda(0) 131 | model = build_model(base_model_cfg='res2net') 132 | model = model.cuda(0) 133 | total =
sum([param.nelement() for param in model.parameters()]) 134 | print(' + Number of params: %.4fM' % (total / 1e6)) 135 | print(model(images).size()) 136 | print('Memory usage: %.4fM' % (torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)) -------------------------------------------------------------------------------- /networks/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | # vgg16 8 | def vgg(cfg, i, batch_norm=False): 9 | layers = [] 10 | in_channels = i 11 | stage = 1 12 | for v in cfg: 13 | if v == 'M': 14 | stage += 1 15 | if stage == 6: 16 | layers += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1)] 17 | else: 18 | layers += [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)] 19 | else: 20 | if stage == 6: 21 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 22 | else: 23 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 24 | if batch_norm: 25 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 26 | else: 27 | layers += [conv2d, nn.ReLU(inplace=True)] 28 | in_channels = v 29 | return layers 30 | 31 | class vgg16(nn.Module): 32 | def __init__(self): 33 | super(vgg16, self).__init__() 34 | self.cfg = {'tun': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 'tun_ex': [512, 512, 512]} 35 | self.extract = [8, 15, 22, 29] # [3, 8, 15, 22, 29] 36 | self.base = nn.ModuleList(vgg(self.cfg['tun'], 3)) 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 40 | m.weight.data.normal_(0, 0.01) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.weight.data.fill_(1) 43 | m.bias.data.zero_() 44 | 45 | def load_pretrained_model(self, model): 46 | self.base.load_state_dict(model, strict=False) 47 | 48 | def forward(self, x): 49 | tmp_x = [] 50 | for k in
range(len(self.base)): 51 | x = self.base[k](x) 52 | if k in self.extract: 53 | tmp_x.append(x) 54 | return tmp_x 55 | 56 | class vgg16_locate(nn.Module): 57 | def __init__(self): 58 | super(vgg16_locate,self).__init__() 59 | self.vgg16 = vgg16() 60 | self.in_planes = 512 61 | self.out_planes = [512, 256, 128] 62 | 63 | ppms, infos = [], [] 64 | for ii in [1, 3, 5]: 65 | ppms.append(nn.Sequential(nn.AdaptiveAvgPool2d(ii), nn.Conv2d(self.in_planes, self.in_planes, 1, 1, bias=False), nn.ReLU(inplace=True))) 66 | self.ppms = nn.ModuleList(ppms) 67 | 68 | self.ppm_cat = nn.Sequential(nn.Conv2d(self.in_planes * 4, self.in_planes, 3, 1, 1, bias=False), nn.ReLU(inplace=True)) 69 | for ii in self.out_planes: 70 | infos.append(nn.Sequential(nn.Conv2d(self.in_planes, ii, 3, 1, 1, bias=False), nn.ReLU(inplace=True))) 71 | self.infos = nn.ModuleList(infos) 72 | 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv2d): 75 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 76 | m.weight.data.normal_(0, 0.01) 77 | elif isinstance(m, nn.BatchNorm2d): 78 | m.weight.data.fill_(1) 79 | m.bias.data.zero_() 80 | 81 | def load_pretrained_model(self, model): 82 | self.vgg16.load_pretrained_model(model) 83 | 84 | def forward(self, x): 85 | x_size = x.size()[2:] 86 | xs = self.vgg16(x) 87 | 88 | xls = [xs[-1]] 89 | for k in range(len(self.ppms)): 90 | xls.append(F.interpolate(self.ppms[k](xs[-1]), xs[-1].size()[2:], mode='bilinear', align_corners=True)) 91 | xls = self.ppm_cat(torch.cat(xls, dim=1)) 92 | infos = [] 93 | for k in range(len(self.infos)): 94 | infos.append(self.infos[k](F.interpolate(xls, xs[len(self.infos) - 1 - k].size()[2:], mode='bilinear', align_corners=True))) 95 | 96 | return xs, infos 97 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from torch.nn import utils, 
functional as F 4 | from torch.optim import Adam 5 | from torch.autograd import Variable 6 | from torch.backends import cudnn 7 | from networks.poolnet_res2net import build_model, weights_init 8 | import scipy.misc as sm 9 | import numpy as np 10 | import os 11 | import torchvision.utils as vutils 12 | import cv2 13 | import math 14 | import time 15 | 16 | 17 | class Solver(object): 18 | def __init__(self, train_loader, test_loader, config): 19 | self.train_loader = train_loader 20 | self.test_loader = test_loader 21 | self.config = config 22 | self.iter_size = config.iter_size 23 | self.show_every = config.show_every 24 | self.lr_decay_epoch = [15,] 25 | self.build_model() 26 | if config.mode == 'test': 27 | print('Loading pre-trained model from %s...' % self.config.model) 28 | if self.config.cuda: 29 | self.net.load_state_dict(torch.load(self.config.model)) 30 | else: 31 | self.net.load_state_dict(torch.load(self.config.model, map_location='cpu')) 32 | self.net.eval() 33 | 34 | # print the network information and parameter numbers 35 | def print_network(self, model, name): 36 | num_params = 0 37 | for p in model.parameters(): 38 | num_params += p.numel() 39 | print(name) 40 | print(model) 41 | print("The number of parameters: {}".format(num_params)) 42 | 43 | # build the network 44 | def build_model(self): 45 | self.net = build_model(self.config.arch) 46 | if self.config.cuda: 47 | self.net = self.net.cuda() 48 | # self.net.train() 49 | self.net.eval() # use_global_stats = True 50 | self.net.apply(weights_init) 51 | if self.config.load == '': 52 | self.net.base.load_pretrained_model(torch.load(self.config.pretrained_model)) 53 | else: 54 | self.net.load_state_dict(torch.load(self.config.load)) 55 | 56 | self.lr = self.config.lr 57 | self.wd = self.config.wd 58 | 59 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 60 | self.print_network(self.net, 'PoolNet Structure') 61 | 62 | def test(self): 63 | 
mode_name = 'sal_fuse' 64 | time_s = time.time() 65 | img_num = len(self.test_loader) 66 | for i, data_batch in enumerate(self.test_loader): 67 | images, name, im_size = data_batch['image'], data_batch['name'][0], np.asarray(data_batch['size']) 68 | with torch.no_grad(): 69 | images = Variable(images) 70 | if self.config.cuda: 71 | images = images.cuda() 72 | preds = self.net(images) 73 | pred = np.squeeze(torch.sigmoid(preds).cpu().data.numpy()) 74 | multi_fuse = 255 * pred 75 | cv2.imwrite(os.path.join(self.config.test_fold, name[:-4] + '_' + mode_name + '.png'), multi_fuse) 76 | time_e = time.time() 77 | print('Speed: %f FPS' % (img_num/(time_e-time_s))) 78 | print('Test Done!') 79 | 80 | # training phase 81 | def train(self): 82 | iter_num = len(self.train_loader.dataset) // self.config.batch_size 83 | aveGrad = 0 84 | for epoch in range(self.config.epoch): 85 | r_sal_loss = 0 86 | self.net.zero_grad() 87 | for i, data_batch in enumerate(self.train_loader): 88 | sal_image, sal_label = data_batch['sal_image'], data_batch['sal_label'] 89 | if (sal_image.size(2) != sal_label.size(2)) or (sal_image.size(3) != sal_label.size(3)): 90 | print('IMAGE ERROR, PASSING') 91 | continue 92 | sal_image, sal_label = Variable(sal_image), Variable(sal_label) 93 | if self.config.cuda: 94 | # cudnn.benchmark = True 95 | sal_image, sal_label = sal_image.cuda(), sal_label.cuda() 96 | 97 | sal_pred = self.net(sal_image) 98 | sal_loss_fuse = F.binary_cross_entropy_with_logits(sal_pred, sal_label, reduction='sum') 99 | sal_loss = sal_loss_fuse / (self.iter_size * self.config.batch_size) 100 | r_sal_loss += sal_loss.data 101 | 102 | sal_loss.backward() 103 | 104 | aveGrad += 1 105 | 106 | # accumulate gradients as done in DSS 107 | if aveGrad % self.iter_size == 0: 108 | self.optimizer.step() 109 | self.optimizer.zero_grad() 110 | aveGrad = 0 111 | 112 | if i % (self.show_every // self.config.batch_size) == 0: 113 | if i == 0: 114 | x_showEvery = 1 115 | print('epoch: [%2d/%2d], iter:
[%5d/%5d] || Sal : %10.4f' % ( 116 | epoch, self.config.epoch, i, iter_num, r_sal_loss/x_showEvery)) 117 | print('Learning rate: ' + str(self.lr)) 118 | r_sal_loss = 0 119 | 120 | if (epoch + 1) % self.config.epoch_save == 0: 121 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_folder, epoch + 1)) 122 | 123 | if epoch in self.lr_decay_epoch: 124 | self.lr = self.lr * 0.1 125 | self.optimizer = Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=self.lr, weight_decay=self.wd) 126 | 127 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_folder) 128 | 129 | def bce2d(input, target, reduction='none'): 130 | assert(input.size() == target.size()) 131 | pos = torch.eq(target, 1).float() 132 | neg = torch.eq(target, 0).float() 133 | 134 | num_pos = torch.sum(pos) 135 | num_neg = torch.sum(neg) 136 | num_total = num_pos + num_neg 137 | 138 | alpha = num_neg / num_total 139 | beta = 1.1 * num_pos / num_total 140 | # target pixel = 1 -> weight alpha (num_neg / num_total) 141 | # target pixel = 0 -> weight beta (1.1 * num_pos / num_total) 142 | weights = alpha * pos + beta * neg 143 | 144 | return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction) 145 | 146 | -------------------------------------------------------------------------------- /train_res2net.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | python main.py --arch res2net --train_root ./data/DUTS/DUTS-TR --train_list ./data/DUTS/DUTS-TR/train_pair.lst 4 | # you can optionally change the --lr and --wd params --------------------------------------------------------------------------------
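The `bce2d` loss in `solver.py` balances the sparse edge class against the background by using swapped class frequencies as per-pixel weights: positive (edge) pixels are weighted by `alpha = num_neg / num_total` and negative pixels by `beta = 1.1 * num_pos / num_total`, so the rare class is up-weighted. A minimal pure-Python sketch of that weighting rule on a toy 1-D mask (no PyTorch; `edge_weights` and the example mask are illustrative names, not part of the repo):

```python
def edge_weights(mask):
    """Per-pixel weights for class-balanced BCE, mirroring bce2d in solver.py.

    Edge pixels (value 1) get alpha = num_neg / num_total; background pixels
    (value 0) get beta = 1.1 * num_pos / num_total, so the rarer edge class
    contributes comparably to the total loss.
    """
    num_pos = sum(1 for v in mask if v == 1)
    num_neg = sum(1 for v in mask if v == 0)
    num_total = num_pos + num_neg
    alpha = num_neg / num_total           # weight for positive (edge) pixels
    beta = 1.1 * num_pos / num_total      # weight for negative pixels
    return [alpha if v == 1 else beta for v in mask]

# Toy mask with 2 edge pixels out of 10: edges weighted 0.8, background 0.22.
weights = edge_weights([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
```

In `bce2d` these weights are passed to `binary_cross_entropy_with_logits`, which scales each pixel's loss term before the reduction is applied.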
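`Solver.train()` accumulates gradients over `iter_size` mini-batches before each optimizer step (the DSS-style trick noted in the comment), which is why the loss is divided by `iter_size * batch_size`: the resulting update approximates one step on a batch `iter_size` times larger. A framework-free sketch of the `aveGrad` counter logic (`accumulation_steps` is a hypothetical helper and `iter_size=10` just an example value):

```python
def accumulation_steps(num_batches, iter_size):
    """Return the 1-based batch indices at which an optimizer step fires
    when gradients are accumulated over iter_size mini-batches, mirroring
    the aveGrad counter in Solver.train()."""
    steps = []
    ave_grad = 0
    for i in range(1, num_batches + 1):
        ave_grad += 1                   # one backward() has run
        if ave_grad % iter_size == 0:   # optimizer.step() + zero_grad()
            steps.append(i)
            ave_grad = 0
    return steps

# With iter_size=10 an update fires every 10 mini-batches; the last 5
# backward passes here never reach an optimizer step.
print(accumulation_steps(25, 10))  # -> [10, 20]
```

This also exposes a subtlety of the training loop: partially accumulated gradients left over at the end of an epoch are discarded by `self.net.zero_grad()` at the start of the next epoch, while the `aveGrad` counter itself carries over.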