├── png
│   ├── dss.png
│   ├── loss.png
│   ├── side.png
│   ├── example.png
│   ├── demo_anno.png
│   └── demo_img.jpg
├── tools
│   ├── extract_vgg.py
│   ├── crf_process.py
│   └── visual.py
├── loss.py
├── LICENSE
├── dataset.py
├── README.md
├── main.py
├── dssnet.py
└── solver.py

/png/dss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/dss.png
--------------------------------------------------------------------------------

/png/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/loss.png
--------------------------------------------------------------------------------

/png/side.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/side.png
--------------------------------------------------------------------------------

/png/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/example.png
--------------------------------------------------------------------------------

/png/demo_anno.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/demo_anno.png
--------------------------------------------------------------------------------

/png/demo_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/demo_img.jpg
--------------------------------------------------------------------------------

/tools/extract_vgg.py:
--------------------------------------------------------------------------------
import os
import torch
from torchvision import models

# extract the pretrained VGG-16 feature weights used as the DSS backbone
if __name__ == '__main__':
    save_fold = '../weights'
    if not os.path.exists(save_fold):
        os.mkdir(save_fold)
    vgg = models.vgg16(pretrained=True)
    torch.save(vgg.features.state_dict(), os.path.join(save_fold, 'vgg16_feat.pth'))
--------------------------------------------------------------------------------

/loss.py:
--------------------------------------------------------------------------------
from torch import nn
import torch.nn.functional as F


# loss function: seven probability maps --- 6 side outputs + 1 fused map
class Loss(nn.Module):
    def __init__(self, weight=[1.0] * 7):
        super(Loss, self).__init__()
        self.weight = weight

    def forward(self, x_list, label):
        # weighted sum of the binary cross-entropy of every map against the same label
        loss = self.weight[0] * F.binary_cross_entropy(x_list[0], label)
        for i, x in enumerate(x_list[1:]):
            loss += self.weight[i + 1] * F.binary_cross_entropy(x, label)
        return loss
--------------------------------------------------------------------------------
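A minimal usage sketch for `Loss`, mirroring how `solver.py` calls it; the random tensors are stand-ins for the seven sigmoid probability maps and the binary ground-truth mask:

```python
import torch
from loss import Loss

criterion = Loss()  # equal weight 1.0 for each of the 7 maps by default
preds = [torch.rand(1, 1, 64, 64) for _ in range(7)]  # stand-ins: 6 side outputs + 1 fused map
label = torch.rand(1, 1, 64, 64).round()              # stand-in binary ground-truth mask
print(criterion(preds, label))                        # summed weighted BCE, a scalar tensor
```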
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Ace

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/tools/crf_process.py:
--------------------------------------------------------------------------------
# reference: https://github.com/Andrew-Qibin/dss_crf/blob/master/examples/dense_hsal.py
import torch
import numpy as np
import pydensecrf.densecrf as dcrf


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


# parameters
EPSILON = 1e-8
tau = 1.05


# img: PIL image, anno: numpy saliency map with values in [0, 1]
def crf(img, anno, to_tensor=False):
    img = np.array(img)
    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 2)
    # unary energies for the background / foreground labels
    n_energy = -np.log((1.0 - anno + EPSILON)) / (tau * sigmoid(1 - anno))
    p_energy = -np.log(anno + EPSILON) / (tau * sigmoid(anno))
    U = np.zeros((2, img.shape[0] * img.shape[1]), dtype='float32')
    U[0, :] = n_energy.flatten()
    U[1, :] = p_energy.flatten()
    d.setUnaryEnergy(U)

    d.addPairwiseGaussian(sxy=3, compat=3)
    d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)

    # do the inference
    infer = np.array(d.inference(1)).astype('float32')
    res = np.expand_dims(infer[1, :].reshape(img.shape[:2]), 0)
    if to_tensor:
        res = torch.from_numpy(res).unsqueeze(0)
    return res
--------------------------------------------------------------------------------
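A minimal sketch of applying `crf` as a post-process, mirroring the call in `solver.py` when `--use_crf=True`; the random map is a stand-in for the network's saliency prediction resized to the image size:

```python
import numpy as np
from PIL import Image
from tools.crf_process import crf

img = Image.open('png/demo_img.jpg').convert('RGB')
h, w = img.size[1], img.size[0]
anno = np.random.rand(1, 1, h, w).astype('float32')  # stand-in saliency map in [0, 1]
refined = crf(img, anno, to_tensor=True)             # refined map, tensor of shape (1, 1, h, w)
print(refined.size())
```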
/dataset.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch
from torch.utils import data
from torchvision import transforms


class ImageData(data.Dataset):
    """ image dataset
    img_root: image root (directory which contains the images)
    label_root: label root (directory which contains the labels)
    transform: pre-process for the images
    t_transform: pre-process for the labels
    filename: split file --- MSRA-B uses xxx.txt files to define the train/val/test split (only for MSRA-B)
    """

    def __init__(self, img_root, label_root, transform, t_transform, filename=None):
        if filename is None:
            self.image_path = list(map(lambda x: os.path.join(img_root, x), os.listdir(img_root)))
            self.label_path = list(
                map(lambda x: os.path.join(label_root, x.split('/')[-1][:-3] + 'png'), self.image_path))
        else:
            lines = [line.rstrip('\n')[:-3] for line in open(filename)]
            self.image_path = list(map(lambda x: os.path.join(img_root, x + 'jpg'), lines))
            self.label_path = list(map(lambda x: os.path.join(label_root, x + 'png'), lines))

        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        image = Image.open(self.image_path[item])
        label = Image.open(self.label_path[item]).convert('L')
        if self.transform is not None:
            image = self.transform(image)
        if self.t_transform is not None:
            label = self.t_transform(label)
        return image, label

    def __len__(self):
        return len(self.image_path)


# get the dataloader (Note: without data augmentation)
# mode='train' returns a DataLoader; any other mode returns the raw Dataset (test keeps the original image sizes)
def get_loader(img_root, label_root, img_size, batch_size, filename=None, mode='train', num_thread=4, pin=True):
    if mode == 'train':
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        t_transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))  # TODO: it may be unnecessary
        ])
        dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
        data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread,
                                      pin_memory=pin)
        return data_loader
    else:
        t_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))  # TODO: it may be unnecessary
        ])
        dataset = ImageData(img_root, label_root, None, t_transform, filename=filename)
        return dataset


if __name__ == '__main__':
    import numpy as np

    img_root = '/home/ace/data/MSRA-B/image'
    label_root = '/home/ace/data/MSRA-B/annotation'
    filename = '/home/ace/data/MSRA-B/train_cvpr2013.txt'
    loader = get_loader(img_root, label_root, 224, 1, filename=filename, mode='test')
    for image, label in loader:
        print(np.array(image).shape)
        break
--------------------------------------------------------------------------------
/tools/visual.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt


class Viz_visdom(object):
    def __init__(self, name, display_id=0):
        self.name = name
        self.display_id = display_id
        self.idx = display_id
        self.plot_data = {}
        if display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port=8097)

    def plot_current_errors(self, epoch, counter_ratio, errors, idx=0):
        if idx not in self.plot_data:
            self.plot_data[idx] = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data[idx]['X'].append(epoch + counter_ratio)
        self.plot_data[idx]['Y'].append([errors[k] for k in self.plot_data[idx]['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data[idx]['X'])] * len(self.plot_data[idx]['legend']), 1)
            if len(errors) > 1 else np.array(self.plot_data[idx]['X']),
            Y=np.array(self.plot_data[idx]['Y']) if len(errors) > 1 else np.array(self.plot_data[idx]['Y'])[:, 0],
            opts={
                'title': self.name + ' loss over time %d' % idx,
                'legend': self.plot_data[idx]['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id + idx)
        if self.idx < self.display_id + idx:
            self.idx = self.display_id + idx

    def plot_current_img(self, visuals, c_prev=True):
        idx = self.idx + 1
        for label, image_numpy in visuals.items():
            if not c_prev:
                # move the channel axis first when the input is HWC
                image_numpy = image_numpy.swapaxes(0, 2).swapaxes(1, 2)
            self.vis.image(image_numpy, opts=dict(title=label),
                           win=self.display_id + idx)
            idx += 1


# reference: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
def plot_image(inp, fig_size, title=None, swap_channel=False, norm=False):
    """Imshow for Tensor."""
    if torch.is_tensor(inp):
        inp = inp.numpy().transpose((1, 2, 0)) if swap_channel else inp.numpy()
    else:
        inp = inp.transpose((1, 2, 0)) if swap_channel else inp
    if norm:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
    plt.figure(figsize=fig_size)
    if inp.shape[0] == 1:
        plt.imshow(inp[0], cmap='gray')
    else:
        plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.0001)  # pause a bit so that plots are updated


def make_simple_grid(inp, padding=2, padding_value=1):
    inp = torch.stack(inp, dim=0)
    nmaps = inp.size(0)
    height, width = inp.size(2), int(inp.size(3) + padding)
    grid = inp.new(1, height, width * nmaps + padding).fill_(padding_value)
    for i in range(nmaps):
        grid.narrow(2, i * width + padding, width - padding).copy_(inp[i])
    return grid


if __name__ == '__main__':
    inp = [torch.randn(1, 5, 5), torch.randn(1, 5, 5)]
    out = make_simple_grid(inp)
    print(out.size())
    plot_image(out, (4, 4))  # fig_size is a required positional argument
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DSS-PyTorch

A PyTorch implementation of [Deeply Supervised Salient Object Detection with Short Connections](https://arxiv.org/abs/1611.04849).

![](./png/dss.png)

The official Caffe version: [DSS](https://github.com/Andrew-Qibin/DSS)

## Prerequisites

- [Python 3](https://www.continuum.io/downloads)
- [PyTorch 0.4.1+](http://pytorch.org/)
- [torchvision](http://pytorch.org/)
- [visdom](https://github.com/facebookresearch/visdom) (optional, for visualization)
- [PyDenseCRF](https://github.com/lucasb-eyer/pydensecrf) (optional, for CRF post-processing)

## Results

Training loss:

![](./png/loss.png)

Example output:

![](png/example.png)

> Note: the blurred boundaries here are caused by the naive fusion method.

Outputs of the different side connections:

![](png/side.png)

#### Differences from the paper

1. At inference the original paper uses $Z = h(\sum_{m=2}^{4} f_m R^{(m)})$, while here we use $Z = h(\sum_{m=1}^{6} f_m R^{(m)})$.

#### Reproduced results

| Dataset (MSRA-B) | Paper | Here (v1) | Only Fusion (v1) | Here (v2) | Only Fusion (v2) | Here (v2, 700) |
| :------------------: | :---: | :-------: | :--------------: | :-------: | :--------------: | :----------: |
| MAE (without CRF) | 0.043 | 0.054 | 0.052 | 0.068 | 0.052 | 0.051 |
| F_beta (without CRF) | 0.920 | 0.910 | 0.914 | 0.912 | 0.910 | 0.918 |
| MAE (with CRF) | 0.028 | 0.047 | 0.048 | 0.047 | 0.049 | 0.047 |
| F_beta (with CRF) | 0.927 | 0.916 | 0.917 | 0.915 | 0.918 | 0.923 |

Note:

1. v1 uses average fusion; v2 uses learnable fusion.
2. You can try other inference strategies (other combinations may give better results). The default uses side output 2 + side output 3 + side output 4 + fusion; just change [`self.select`](https://github.com/AceCoooool/DSS-pytorch/blob/66419dee7045f4581e7e18f910ca98e1a596705a/solver.py#L20), as illustrated in the sketch below.
3. "v2 700" means training for 700 epochs. (I started from the model pre-trained for 500 epochs, so the optimizer state is slightly different from a direct 700-epoch run.)
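A minimal sketch of this inference combination, assuming a trained model; the indices mirror `self.select = [1, 2, 3, 6]` in `solver.py` (side outputs 2-4 plus the fused map), and the random tensor is a stand-in for a normalized input image:

```python
import torch
from dssnet import build_model

net = build_model().eval()
# net.load_state_dict(torch.load('./weights/final.pth'))  # load your trained weights here
x = torch.randn(1, 3, 256, 256)  # stand-in for a normalized input image
with torch.no_grad():
    prob = net(x)                # list of 7 probability maps: side outputs 1-6 + fusion
select = [1, 2, 3, 6]            # side outputs 2, 3, 4 + the fused map
pred = torch.mean(torch.cat([prob[i] for i in select], dim=1), dim=1, keepdim=True)
print(pred.shape)                # torch.Size([1, 1, 256, 256])
```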
## Usage

### 1. Clone the repository

```shell
git clone git@github.com:AceCoooool/DSS-pytorch.git
cd DSS-pytorch/
```

### 2. Download the dataset

Download the [MSRA-B](http://mmcheng.net/zh/msra10k/) dataset. (If you cannot find this dataset, email me --- I am not sure whether it is legal to host it on BaiduYun.)

```shell
# file structure
MSRA-B
--- annotation
    --- xxx.png
    --- xxx.png
--- image
    --- xxx.jpg
    --- xxx.jpg
--- test_cvpr2013.txt
--- train_cvpr2013.txt
--- valid_cvpr2013.txt
--- test_cvpr2013_debug.txt
--- train_cvpr2013_debug.txt
--- valid_cvpr2013_debug.txt
```

### 3. Get the pre-trained VGG

```bash
cd tools/
python extract_vgg.py
cd ..
```

### 4. Demo

Please see `demo.ipynb`.

Note:

1. By default: download the [pretrained model](https://pan.baidu.com/s/10XmHVMAOp1ewoJXhI0nRgA) and copy it to the `weights` directory.

### 5. Train

```shell
python main.py --mode='train' --train_path='you_data' --label_path='you_label' --batch_size=8 --visdom=True --train_file='you_file'
```

Note:

1. `--val=True` adds validation (you also need to set `--val_path`, `--val_label` and `--val_file`).
2. `you_data`, `you_label` are your training image and label roots (see step 2).
3. If you download the data to `yourhome/data/MSRA-B`, the path arguments can be left implicit (the defaults in `main.py` point there).

### 6. Test

```shell
python main.py --mode='test' --test_path='you_data' --test_label='your_label' --use_crf=False --model='your_trained_model' --test_file='you_file'
```

Note:

1. Only `batch_size=1` is supported.
2. `--use_crf=True` enables the CRF post-processing (see `tools/crf_process.py` and the usage sketch that follows it above).
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
from dataset import get_loader
from solver import Solver


def main(config):
    if config.mode == 'train':
        train_loader = get_loader(config.train_path, config.label_path, config.img_size, config.batch_size,
                                  filename=config.train_file, num_thread=config.num_thread)
        if config.val:
            val_loader = get_loader(config.val_path, config.val_label, config.img_size, config.batch_size,
                                    filename=config.val_file, num_thread=config.num_thread)
        run = 0
        while os.path.exists("%s/run-%d" % (config.save_fold, run)): run += 1
        os.mkdir("%s/run-%d" % (config.save_fold, run))
        os.mkdir("%s/run-%d/logs" % (config.save_fold, run))
        # os.mkdir("%s/run-%d/images" % (config.save_fold, run))
        os.mkdir("%s/run-%d/models" % (config.save_fold, run))
        config.save_fold = "%s/run-%d" % (config.save_fold, run)
        if config.val:
            train = Solver(train_loader, val_loader, None, config)
        else:
            train = Solver(train_loader, None, None, config)
        train.train()
    elif config.mode == 'test':
        test_loader = get_loader(config.test_path, config.test_label, config.img_size, config.batch_size, mode='test',
                                 filename=config.test_file, num_thread=config.num_thread)
        if not os.path.exists(config.test_fold): os.mkdir(config.test_fold)
        test = Solver(None, None, test_loader, config)
        test.test(100, use_crf=config.use_crf)
    else:
        raise IOError("illegal input!!!")


if __name__ == '__main__':
    data_root = os.path.join(os.path.expanduser('~'), 'data')
    vgg_path = './weights/vgg16_feat.pth'
    # # -----ECSSD dataset-----
    # train_path = os.path.join(data_root, 'ECSSD/images')
    # label_path = os.path.join(data_root, 'ECSSD/ground_truth_mask')
    #
    # val_path = os.path.join(data_root, 'ECSSD/val_images')
    # val_label = os.path.join(data_root, 'ECSSD/val_ground_truth_mask')
    # test_path = os.path.join(data_root, 'ECSSD/test_images')
    # test_label = os.path.join(data_root, 'ECSSD/test_ground_truth_mask')
    # # -----MSRA-B dataset-----
    image_path = os.path.join(data_root, 'MSRA-B/image')
    label_path = os.path.join(data_root, 'MSRA-B/annotation')
    train_file = os.path.join(data_root, 'MSRA-B/train_cvpr2013.txt')
    valid_file = os.path.join(data_root, 'MSRA-B/valid_cvpr2013.txt')
    test_file = os.path.join(data_root, 'MSRA-B/test_cvpr2013.txt')
    parser = argparse.ArgumentParser()

    # Hyper-parameters
    parser.add_argument('--n_color', type=int, default=3)
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=1e-6)
    parser.add_argument('--clip_gradient', type=float, default=1.0)
    # note: argparse with type=bool treats any non-empty string (even 'False') as True,
    # so rely on the defaults of these flags or pass an empty string to disable them
    parser.add_argument('--cuda', type=bool, default=True)

    # Training settings
    parser.add_argument('--vgg', type=str, default=vgg_path)
    parser.add_argument('--train_path', type=str, default=image_path)
    parser.add_argument('--label_path', type=str, default=label_path)
    parser.add_argument('--train_file', type=str, default=train_file)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=1)  # 8
    parser.add_argument('--val', type=bool, default=True)
    parser.add_argument('--val_path', type=str, default=image_path)
    parser.add_argument('--val_label', type=str, default=label_path)
    parser.add_argument('--val_file', type=str, default=valid_file)
    parser.add_argument('--num_thread', type=int, default=4)
    parser.add_argument('--load', type=str, default='')
    parser.add_argument('--save_fold', type=str, default='./results')
    parser.add_argument('--epoch_val', type=int, default=10)
    parser.add_argument('--epoch_save', type=int, default=20)
    parser.add_argument('--epoch_show', type=int, default=1)
    parser.add_argument('--pre_trained', type=str, default=None)

    # Testing settings
    parser.add_argument('--test_path', type=str, default=image_path)
    parser.add_argument('--test_label', type=str, default=label_path)
    parser.add_argument('--test_file', type=str, default=test_file)
    parser.add_argument('--model', type=str, default='./weights/final.pth')
    parser.add_argument('--test_fold', type=str, default='./results/test')
    parser.add_argument('--use_crf', type=bool, default=False)

    # Misc
    parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
    parser.add_argument('--visdom', type=bool, default=False)

    config = parser.parse_args()
    if not os.path.exists(config.save_fold): os.mkdir(config.save_fold)
    main(config)
--------------------------------------------------------------------------------
/dssnet.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import init

# vgg choice
base = {'dss': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
# extended vgg choice --- follows the paper; you can change it
extra = {'dss': [(64, 128, 3, [8, 16, 32, 64]), (128, 128, 3, [4, 8, 16, 32]), (256, 256, 5, [8, 16]),
                 (512, 256, 5, [4, 8]), (512, 512, 5, []), (512, 512, 7, [])]}
connect = {'dss': [[2, 3, 4, 5], [2, 3, 4, 5], [4, 5], [4, 5], [], []]}


# vgg16
def vgg(cfg, i=3, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers


# feature map before sigmoid: build the short connections and deconvolution
class ConcatLayer(nn.Module):
    def __init__(self, list_k, k, scale=True):
        super(ConcatLayer, self).__init__()
        l, up, self.scale = len(list_k), [], scale
        for i in range(l):
            up.append(nn.ConvTranspose2d(1, 1, list_k[i], list_k[i] // 2, list_k[i] // 4))
        self.upconv = nn.ModuleList(up)
        self.conv = nn.Conv2d(l + 1, 1, 1, 1)
        self.deconv = nn.ConvTranspose2d(1, 1, k * 2, k, k // 2) if scale else None

    def forward(self, x, list_x):
        elem_x = [x]
        for i, elem in enumerate(list_x):
            elem_x.append(self.upconv[i](elem))
        if self.scale:
            out = self.deconv(self.conv(torch.cat(elem_x, dim=1)))
        else:
            out = self.conv(torch.cat(elem_x, dim=1))
        return out


# extended vgg: side outputs
class FeatLayer(nn.Module):
    def __init__(self, in_channel, channel, k):
        super(FeatLayer, self).__init__()
        self.main = nn.Sequential(nn.Conv2d(in_channel, channel, k, 1, k // 2), nn.ReLU(inplace=True),
                                  nn.Conv2d(channel, channel, k, 1, k // 2), nn.ReLU(inplace=True),
                                  nn.Conv2d(channel, 1, 1, 1))

    def forward(self, x):
        return self.main(x)


# fuse the side-output features
class FusionLayer(nn.Module):
    def __init__(self, nums=6):
        super(FusionLayer, self).__init__()
        self.weights = nn.Parameter(torch.randn(nums))
        self.nums = nums
        self._reset_parameters()

    def _reset_parameters(self):
        init.constant_(self.weights, 1 / self.nums)

    def forward(self, x):
        for i in range(self.nums):
            out = self.weights[i] * x[i] if i == 0 else out + self.weights[i] * x[i]
        return out


# extra part
def extra_layer(vgg, cfg):
    feat_layers, concat_layers, scale = [], [], 1
    for k, v in enumerate(cfg):
        # side output (paper: figure 3)
        feat_layers += [FeatLayer(v[0], v[1], v[2])]
        # feature map before sigmoid
        concat_layers += [ConcatLayer(v[3], scale, k != 0)]
        scale *= 2
    return vgg, feat_layers, concat_layers


# DSS network
# Note: if you use another backbone network, please change extract
class DSS(nn.Module):
    def __init__(self, base, feat_layers, concat_layers, connect, extract=[3, 8, 15, 22, 29], v2=True):
        super(DSS, self).__init__()
        self.extract = extract
        self.connect = connect
        self.base = nn.ModuleList(base)
        self.feat = nn.ModuleList(feat_layers)
        self.comb = nn.ModuleList(concat_layers)
        self.pool = nn.AvgPool2d(3, 1, 1)
        self.v2 = v2
        if v2: self.fuse = FusionLayer()

    def forward(self, x, label=None):
        prob, back, y, num = list(), list(), list(), 0
        for k in range(len(self.base)):
            x = self.base[k](x)
            if k in self.extract:
                y.append(self.feat[num](x))
                num += 1
        # side output
        y.append(self.feat[num](self.pool(x)))
        for i in range(len(y)):
            back.append(self.comb[i](y[i], [y[j] for j in self.connect[i]]))
        # fusion map
        if self.v2:
            # version 2: learnable fusion
            back.append(self.fuse(back))
        else:
            # version 1: mean fusion
            back.append(torch.cat(back, dim=1).mean(dim=1, keepdim=True))
        # add sigmoid
        for i in back: prob.append(torch.sigmoid(i))
        return prob


# build the whole network
def build_model():
    return DSS(*extra_layer(vgg(base['dss'], 3), extra['dss']), connect['dss'])


# weight init
def xavier(param):
    init.xavier_uniform_(param)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)


if __name__ == '__main__':
    net = build_model()
    img = torch.randn(1, 3, 64, 64)
    net = net.to(torch.device('cuda:0'))
    img = img.to(torch.device('cuda:0'))
    out = net(img)
    k = [out[x] for x in [1, 2, 3, 6]]
    print(len(k))
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
import math
import torch
from collections import OrderedDict
from torch.nn import utils, functional as F
from torch.optim import Adam
from torch.backends import cudnn
from torchvision import transforms
from dssnet import build_model, weights_init
from loss import Loss
from tools.visual import Viz_visdom


class Solver(object):
    def __init__(self, train_loader, val_loader, test_dataset, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_dataset = test_dataset
        self.config = config
        self.beta = math.sqrt(0.3)  # for the max F_beta metric
        # inference: choose the side maps (see paper)
        self.select = [1, 2, 3, 6]
        self.device = torch.device('cpu')
        self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda:0')
        if config.visdom:
            self.visual = Viz_visdom("DSS", 1)
        self.build_model()
        if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')
            self.transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    # print the network information and parameter numbers
    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            if p.requires_grad: num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    # build the network
    def build_model(self):
        self.net = build_model().to(self.device)
        if self.config.mode == 'train': self.loss = Loss().to(self.device)
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'DSS')

    # update the learning rate
    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    # evaluate MAE (for the test or validation phase)
    def eval_mae(self, y_pred, y):
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version
    # get precisions and recalls: the threshold divides [0, 1] into num values
    def eval_pr(self, y_pred, y, num):
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall
    # validation: uses resized images and only evaluates the MAE metric
    def validation(self):
        avg_mae = 0.0
        self.net.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.val_loader):
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
                prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                avg_mae += self.eval_mae(prob_pred, labels).item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    # test phase: uses the original image size and evaluates the MAE and max F_beta metrics
    def test(self, num, use_crf=False):
        if use_crf: from tools.crf_process import crf
        avg_mae, img_num = 0.0, len(self.test_dataset)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        with torch.no_grad():
            for i, (img, labels) in enumerate(self.test_dataset):
                images = self.transform(img).unsqueeze(0)
                labels = labels.unsqueeze(0)
                shape = labels.size()[2:]
                images = images.to(self.device)
                prob_pred = self.net(images)
                prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
                if use_crf:
                    prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
                mae = self.eval_mae(prob_pred, labels)
                prec, recall = self.eval_pr(prob_pred, labels, num)
                print("[%d] mae: %.4f" % (i, mae))
                print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
                avg_mae += mae
                avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
        score[score != score] = 0  # set NaN entries (0 / 0) to 0
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)

    # training phase
    def train(self):
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (
                    epoch, self.config.epoch, i, iter_num, loss.item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)

            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    y_show = torch.mean(torch.cat([y_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
                    img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean), ('label', y.cpu()[0][0]),
                                       ('pred_label', y_show.cpu().data[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae))
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
--------------------------------------------------------------------------------