├── png
│   ├── dss.png
│   ├── loss.png
│   ├── side.png
│   ├── example.png
│   ├── demo_anno.png
│   └── demo_img.jpg
├── tools
│   ├── extract_vgg.py
│   ├── crf_process.py
│   └── visual.py
├── loss.py
├── LICENSE
├── dataset.py
├── README.md
├── main.py
├── dssnet.py
└── solver.py
/png/dss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/dss.png
--------------------------------------------------------------------------------
/png/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/loss.png
--------------------------------------------------------------------------------
/png/side.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/side.png
--------------------------------------------------------------------------------
/png/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/example.png
--------------------------------------------------------------------------------
/png/demo_anno.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/demo_anno.png
--------------------------------------------------------------------------------
/png/demo_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/DSS-pytorch/HEAD/png/demo_img.jpg
--------------------------------------------------------------------------------
/tools/extract_vgg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision import models
4 |
5 | # extract vgg features
6 | if __name__ == '__main__':
7 | save_fold = '../weights'
8 | if not os.path.exists(save_fold):
9 | os.mkdir(save_fold)
10 | vgg = models.vgg16(pretrained=True)
11 | torch.save(vgg.features.state_dict(), os.path.join(save_fold, 'vgg16_feat.pth'))
12 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 |
4 |
5 | # loss function: seven probability maps --- 6 side outputs + 1 fused map
6 | class Loss(nn.Module):
7 | def __init__(self, weight=[1.0] * 7):
8 | super(Loss, self).__init__()
9 | self.weight = weight
10 |
11 | def forward(self, x_list, label):
12 | loss = self.weight[0] * F.binary_cross_entropy(x_list[0], label)
13 | for i, x in enumerate(x_list[1:]):
14 | loss += self.weight[i + 1] * F.binary_cross_entropy(x, label)
15 | return loss
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ace
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tools/crf_process.py:
--------------------------------------------------------------------------------
1 | # reference: https://github.com/Andrew-Qibin/dss_crf/blob/master/examples/dense_hsal.py
2 | import torch
3 | import numpy as np
4 | import pydensecrf.densecrf as dcrf
5 |
6 |
7 | def sigmoid(x):
8 | return 1 / (1 + np.exp(-x))
9 |
10 |
11 | # parameter
12 | EPSILON = 1e-8
13 | tau = 1.05
14 |
15 |
16 | # img: PIL image (RGB)
17 | # anno: numpy saliency map with values in [0, 1]
18 | def crf(img, anno, to_tensor=False):
19 | img = np.array(img)
20 | d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], 2)
21 | n_energy = -np.log((1.0 - anno + EPSILON)) / (tau * sigmoid(1 - anno))  # background unary energy
22 | p_energy = -np.log(anno + EPSILON) / (tau * sigmoid(anno))  # foreground unary energy
23 | U = np.zeros((2, img.shape[0] * img.shape[1]), dtype='float32')
24 | U[0, :] = n_energy.flatten()
25 | U[1, :] = p_energy.flatten()
26 | d.setUnaryEnergy(U)
27 |
28 | d.addPairwiseGaussian(sxy=3, compat=3)
29 | d.addPairwiseBilateral(sxy=60, srgb=5, rgbim=img, compat=5)
30 |
31 | # Do the inference
32 | infer = np.array(d.inference(1)).astype('float32')
33 | res = np.expand_dims(infer[1, :].reshape(img.shape[:2]), 0)
34 | if to_tensor:
35 | res = torch.from_numpy(res).unsqueeze(0)
36 | return res
37 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | from torch.utils import data
5 | from torchvision import transforms
6 |
7 |
8 | class ImageData(data.Dataset):
9 | """ image dataset
10 | img_root: image root (root which contain images)
11 | label_root: label root (root which contains labels)
12 | transform: pre-process for image
13 | t_transform: pre-process for label
14 | filename: a txt file listing the train/val/test split (only used for MSRA-B)
15 | """
16 |
17 | def __init__(self, img_root, label_root, transform, t_transform, filename=None):
18 | if filename is None:
19 | self.image_path = list(map(lambda x: os.path.join(img_root, x), os.listdir(img_root)))
20 | self.label_path = list(
21 | map(lambda x: os.path.join(label_root, x.split('/')[-1][:-3] + 'png'), self.image_path))
22 | else:
23 | lines = [line.rstrip('\n')[:-3] for line in open(filename)]
24 | self.image_path = list(map(lambda x: os.path.join(img_root, x + 'jpg'), lines))
25 | self.label_path = list(map(lambda x: os.path.join(label_root, x + 'png'), lines))
26 |
27 | self.transform = transform
28 | self.t_transform = t_transform
29 |
30 | def __getitem__(self, item):
31 | image = Image.open(self.image_path[item])
32 | label = Image.open(self.label_path[item]).convert('L')
33 | if self.transform is not None:
34 | image = self.transform(image)
35 | if self.t_transform is not None:
36 | label = self.t_transform(label)
37 | return image, label
38 |
39 | def __len__(self):
40 | return len(self.image_path)
41 |
42 |
43 | # get the dataloader (Note: without data augmentation)
44 | def get_loader(img_root, label_root, img_size, batch_size, filename=None, mode='train', num_thread=4, pin=True):
45 | if mode == 'train':
46 | transform = transforms.Compose([
47 | transforms.Resize((img_size, img_size)),
48 | transforms.ToTensor(),
49 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50 | ])
51 | t_transform = transforms.Compose([
52 | transforms.Resize((img_size, img_size)),
53 | transforms.ToTensor(),
54 | transforms.Lambda(lambda x: torch.round(x))  # TODO: this may be unnecessary
55 | ])
56 | dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
57 | data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread,
58 | pin_memory=pin)
59 | return data_loader
60 | else:
61 | t_transform = transforms.Compose([
62 | transforms.ToTensor(),
63 | transforms.Lambda(lambda x: torch.round(x))  # TODO: this may be unnecessary
64 | ])
65 | dataset = ImageData(img_root, label_root, None, t_transform, filename=filename)
66 | return dataset
67 |
68 |
69 | if __name__ == '__main__':
70 | import numpy as np
71 | img_root = '/home/ace/data/MSRA-B/image'
72 | label_root = '/home/ace/data/MSRA-B/annotation'
73 | filename = '/home/ace/data/MSRA-B/train_cvpr2013.txt'
74 | loader = get_loader(img_root, label_root, 224, 1, filename=filename, mode='test')
75 | for image, label in loader:
76 | print(np.array(image).shape)
77 | break
78 |
--------------------------------------------------------------------------------
/tools/visual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 |
6 | class Viz_visdom(object):
7 | def __init__(self, name, display_id=0):
8 | self.name = name
9 | self.display_id = display_id
10 | self.idx = display_id
11 | self.plot_data = {}
12 | if display_id > 0:
13 | import visdom
14 | self.vis = visdom.Visdom(port=8097)
15 |
16 | def plot_current_errors(self, epoch, counter_ratio, errors, idx=0):
17 |
18 | if idx not in self.plot_data:
19 | self.plot_data[idx] = {'X': [], 'Y': [], 'legend': list(errors.keys())}
20 | # self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
21 | self.plot_data[idx]['X'].append(epoch + counter_ratio)
22 | self.plot_data[idx]['Y'].append([errors[k] for k in self.plot_data[idx]['legend']])
23 | self.vis.line(
24 | X=np.stack([np.array(self.plot_data[idx]['X'])] * len(self.plot_data[idx]['legend']), 1)
25 | if len(errors) > 1 else np.array(self.plot_data[idx]['X']),
26 | Y=np.array(self.plot_data[idx]['Y']) if len(errors) > 1 else np.array(self.plot_data[idx]['Y'])[:, 0],
27 | opts={
28 | 'title': self.name + ' loss over time %d' % idx,
29 | 'legend': self.plot_data[idx]['legend'],
30 | 'xlabel': 'epoch',
31 | 'ylabel': 'loss'},
32 | win=self.display_id + idx)
33 | if self.idx < self.display_id + idx:
34 | self.idx = self.display_id + idx
35 |
36 | def plot_current_img(self, visuals, c_prev=True):
37 | idx = self.idx + 1
38 | for label, image_numpy in visuals.items():
39 | if c_prev:
40 | self.vis.image(image_numpy, opts=dict(title=label),
41 | win=self.display_id + idx)
42 | else:
43 | image_numpy = image_numpy.swapaxes(0, 2).swapaxes(1, 2)
44 | self.vis.image(image_numpy, opts=dict(title=label),
45 | win=self.display_id + idx)
46 | idx += 1
47 |
48 |
49 | # reference: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
50 | def plot_image(inp, fig_size, title=None, swap_channel=False, norm=False):
51 | """Imshow for Tensor."""
52 | if torch.is_tensor(inp):
53 | inp = inp.numpy().transpose((1, 2, 0)) if swap_channel else inp.numpy()
54 | else:
55 | inp = inp.transpose((1, 2, 0)) if swap_channel else inp
56 | if norm:
57 | mean = np.array([0.485, 0.456, 0.406])
58 | std = np.array([0.229, 0.224, 0.225])
59 | inp = std * inp + mean
60 | inp = np.clip(inp, 0, 1)
61 | plt.figure(figsize=fig_size)
62 | if inp.shape[0] == 1:
63 | plt.imshow(inp[0], cmap='gray')
64 | else:
65 | plt.imshow(inp)
66 | if title is not None:
67 | plt.title(title)
68 | plt.pause(0.0001) # pause a bit so that plots are updated
69 |
70 |
71 | def make_simple_grid(inp, padding=2, padding_value=1):
72 | inp = torch.stack(inp, dim=0)
73 | nmaps = inp.size(0)
74 | height, width = inp.size(2), int(inp.size(3) + padding)
75 | grid = inp.new(1, height, width * nmaps + padding).fill_(padding_value)
76 | for i in range(nmaps):
77 | grid.narrow(2, i * width + padding, width - padding).copy_(inp[i])
78 | return grid
79 |
80 |
81 | if __name__ == '__main__':
82 | inp = [torch.randn(1, 5, 5), torch.randn(1, 5, 5)]
83 | out = make_simple_grid(inp)
84 | print(out.size())
85 | plot_image(out, (6, 3))  # fig_size is a required argument
86 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DSS-PyTorch
2 | A PyTorch implementation of [Deeply Supervised Salient Object Detection with Short Connections](https://arxiv.org/abs/1611.04849)
3 |
4 | ![architecture](png/dss.png)
5 |
6 | The official Caffe version: [DSS](https://github.com/Andrew-Qibin/DSS)
7 |
8 | ## Prerequisites
9 |
10 | - [Python 3](https://www.continuum.io/downloads)
11 | - [PyTorch 0.4.1+](http://pytorch.org/)
12 | - [torchvision](http://pytorch.org/)
13 | - [visdom](https://github.com/facebookresearch/visdom) (optional, for visualization)
14 | - [PyDenseCRF](https://github.com/lucasb-eyer/pydensecrf) (optional, for CRF post-processing)
15 |
16 | ## Results
17 |
18 | The training loss curve:
19 |
20 | 
21 |
22 | Example output:
23 |
24 | 
25 |
26 | > Note: the blurred boundaries here are caused by the naive combination method
27 |
28 | Outputs of the different side connections:
29 |
30 | 
31 |
32 | #### Differences from the paper
33 |
34 | 1. The original paper uses $Z=h\left(\sum_{m=2}^{4} f_m R^{(m)}\right)$ at the inference stage; here we use $Z=h\left(\sum_{m=1}^{6} f_m R^{(m)}\right)$ (a code sketch follows)
35 |
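As a reference, here is a minimal sketch of that learnable fusion (mirroring `FusionLayer` in `dssnet.py`; `side_maps` is a hypothetical stand-in for the six pre-sigmoid side maps $R^{(m)}$):

```python
import torch

def fuse(side_maps, weights):
    # Z = h(sum_m f_m * R^(m)) with h = sigmoid; the weights f_m are learned (initialized to 1/6)
    z = sum(w * r for w, r in zip(weights, side_maps))
    return torch.sigmoid(z)

# usage with random stand-in maps
maps = [torch.randn(1, 1, 64, 64) for _ in range(6)]
print(fuse(maps, torch.full((6,), 1 / 6)).shape)  # torch.Size([1, 1, 64, 64])
```
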
36 | #### Reproduced Results
37 |
38 | | Metric (MSRA-B) | Paper | Here (v1) | Only Fusion (v1) | Here (v2) | Only Fusion (v2) | Here (v2, 700 epochs) |
39 | | :------------------: | :---: | :-------: | :--------------: | :-------: | :--------------: | :----------: |
40 | | MAE (without CRF) | 0.043 | 0.054 | 0.052 | 0.068 | 0.052 | 0.051 |
41 | | F_beta (without CRF) | 0.920 | 0.910 | 0.914 | 0.912 | 0.910 | 0.918 |
42 | | MAE (with CRF) | 0.028 | 0.047 | 0.048 | 0.047 | 0.049 | 0.047 |
43 | | F_beta (with CRF) | 0.927 | 0.916 | 0.917 | 0.915 | 0.918 | 0.923 |
44 |
45 | Note:
46 |
47 | 1. v1 uses average fusion; v2 uses learnable (weighted) fusion
48 | 2. You can try other inference strategies (other combinations may give better results --- here we use sout-2 + sout-3 + sout-4 + fusion); just change [self.select](https://github.com/AceCoooool/DSS-pytorch/blob/66419dee7045f4581e7e18f910ca98e1a596705a/solver.py#L20), as in the sketch below
49 | 3. v2 700 means training for 700 epochs (resumed from a 500-epoch checkpoint, so the optimizer state differs slightly from a straight 700-epoch run)
50 |
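A sketch of the inference-time combination used in `solver.py` (`prob_maps` is the list of seven sigmoid outputs returned by the network; index 6 is the fused map):

```python
import torch

def combine(prob_maps, select=(1, 2, 3, 6)):
    # average the selected probability maps (default: sout-2 + sout-3 + sout-4 + fusion)
    return torch.mean(torch.cat([prob_maps[i] for i in select], dim=1),
                      dim=1, keepdim=True)
```
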
51 | ## Usage
52 |
53 | ### 1. Clone the repository
54 |
55 | ```shell
56 | git clone git@github.com:AceCoooool/DSS-pytorch.git
57 | cd DSS-pytorch/
58 | ```
59 |
60 | ### 2. Download the dataset
61 |
62 | Download the [MSRA-B](http://mmcheng.net/zh/msra10k/) dataset. (If you cannot find it, email me --- I am not sure whether it is legal to host it on BaiduYun.)
63 |
64 | ```shell
65 | # file construction
66 | MSRA-B
67 | --- annotation
68 | --- xxx.png
69 | --- xxx.png
70 | --- image
71 | --- xxx.jpg
72 | --- xxx.jpg
73 | --- test_cvpr2013.txt
74 | --- train_cvpr2013.txt
75 | --- valid_cvpr2013.txt
76 | --- test_cvpr2013_debug.txt
77 | --- train_cvpr2013_debug.txt
78 | --- valid_cvpr2013_debug.txt
79 | ```
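
With this layout, the training loader from `dataset.py` can be built as below (a sketch assuming the dataset sits under `~/data/MSRA-B`, matching the defaults in `main.py`):

```python
import os
from dataset import get_loader

root = os.path.join(os.path.expanduser('~'), 'data', 'MSRA-B')
train_loader = get_loader(os.path.join(root, 'image'),
                          os.path.join(root, 'annotation'),
                          img_size=256, batch_size=8,
                          filename=os.path.join(root, 'train_cvpr2013.txt'))
```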
80 |
81 | ### 3. Extract the pre-trained VGG weights
82 |
83 | ```bash
84 | cd tools/
85 | python extract_vgg.py
86 | cd ..
87 | ```
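
This saves the VGG-16 feature weights to `weights/vgg16_feat.pth`, which `main.py` loads through the `--vgg` argument.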
88 |
89 | ### 4. Demo
90 |
91 | Please see `demo.ipynb`
92 |
93 | Note:
94 |
95 | 1. Default setup: download the [pretrained model](https://pan.baidu.com/s/10XmHVMAOp1ewoJXhI0nRgA) and copy it to the `weights` directory (a minimal inference sketch follows)
96 |
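If the notebook is not handy, here is a minimal inference sketch (assuming the pretrained weights have been copied to `./weights/final.pth`; `png/demo_img.jpg` ships with the repo):

```python
import torch
from PIL import Image
from torchvision import transforms
from dssnet import build_model

net = build_model()
net.load_state_dict(torch.load('./weights/final.pth', map_location='cpu'))
net.eval()

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

img = transform(Image.open('png/demo_img.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    prob = net(img)
# average the selected side outputs and the fused map (index 6), as in solver.py
pred = torch.mean(torch.cat([prob[i] for i in (1, 2, 3, 6)], dim=1), dim=1, keepdim=True)
```
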
97 | ### 5. Train
98 |
99 | ```shell
100 | python main.py --mode='train' --train_path='your_data' --label_path='your_label' --batch_size=8 --visdom=True --train_file='your_file'
101 | ```
102 |
103 | Note:
104 |
105 | 1. `--val=True` enables validation (you also need to provide `--val_path`, `--val_label`, and `--val_file`)
106 | 2. `your_data` and `your_label` are the roots of your training images and labels (see step 2)
107 | 3. If you downloaded the data to `~/data/MSRA-B`, you can omit these paths --- they are the defaults in `main.py`
108 |
109 | ### 6. Test
110 |
111 | ```shell
112 | python main.py --mode='test' --test_path='your_data' --test_label='your_label' --use_crf=False --model='your_trained_model' --test_file='your_file'
113 | ```
114 |
115 | Note:
116 |
117 | 1. Only `batch_size=1` is supported
118 | 2. `--use_crf=True` enables CRF post-processing (see the sketch below)
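
For reference, `--use_crf=True` routes the prediction through `tools/crf_process.crf`; a sketch of calling it directly (the random `anno` here is only a stand-in for a real predicted saliency map with values in [0, 1]):

```python
import numpy as np
from PIL import Image
from tools.crf_process import crf

img = Image.open('png/demo_img.jpg').convert('RGB')
w, h = img.size
anno = np.random.rand(h, w).astype('float32')  # stand-in for the predicted map
refined = crf(img, anno, to_tensor=True)       # -> (1, 1, H, W) torch tensor
```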
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from dataset import get_loader
4 | from solver import Solver
5 |
6 |
7 | def main(config):
8 | if config.mode == 'train':
9 | train_loader = get_loader(config.train_path, config.label_path, config.img_size, config.batch_size,
10 | filename=config.train_file, num_thread=config.num_thread)
11 | if config.val:
12 | val_loader = get_loader(config.val_path, config.val_label, config.img_size, config.batch_size,
13 | filename=config.val_file, num_thread=config.num_thread)
14 | run = 0
15 | while os.path.exists("%s/run-%d" % (config.save_fold, run)): run += 1
16 | os.mkdir("%s/run-%d" % (config.save_fold, run))
17 | os.mkdir("%s/run-%d/logs" % (config.save_fold, run))
18 | # os.mkdir("%s/run-%d/images" % (config.save_fold, run))
19 | os.mkdir("%s/run-%d/models" % (config.save_fold, run))
20 | config.save_fold = "%s/run-%d" % (config.save_fold, run)
21 | if config.val:
22 | train = Solver(train_loader, val_loader, None, config)
23 | else:
24 | train = Solver(train_loader, None, None, config)
25 | train.train()
26 | elif config.mode == 'test':
27 | test_loader = get_loader(config.test_path, config.test_label, config.img_size, config.batch_size, mode='test',
28 | filename=config.test_file, num_thread=config.num_thread)
29 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold)
30 | test = Solver(None, None, test_loader, config)
31 | test.test(100, use_crf=config.use_crf)
32 | else:
33 | raise ValueError("unknown mode: %s" % config.mode)
34 |
35 |
36 | if __name__ == '__main__':
37 | data_root = os.path.join(os.path.expanduser('~'), 'data')
38 | vgg_path = './weights/vgg16_feat.pth'
39 | # # -----ECSSD dataset-----
40 | # train_path = os.path.join(data_root, 'ECSSD/images')
41 | # label_path = os.path.join(data_root, 'ECSSD/ground_truth_mask')
42 | #
43 | # val_path = os.path.join(data_root, 'ECSSD/val_images')
44 | # val_label = os.path.join(data_root, 'ECSSD/val_ground_truth_mask')
45 | # test_path = os.path.join(data_root, 'ECSSD/test_images')
46 | # test_label = os.path.join(data_root, 'ECSSD/test_ground_truth_mask')
47 | # # -----MSRA-B dataset-----
48 | image_path = os.path.join(data_root, 'MSRA-B/image')
49 | label_path = os.path.join(data_root, 'MSRA-B/annotation')
50 | train_file = os.path.join(data_root, 'MSRA-B/train_cvpr2013.txt')
51 | valid_file = os.path.join(data_root, 'MSRA-B/valid_cvpr2013.txt')
52 | test_file = os.path.join(data_root, 'MSRA-B/test_cvpr2013.txt')
53 | parser = argparse.ArgumentParser()
54 |
55 | # Hyper-parameters
56 | parser.add_argument('--n_color', type=int, default=3)
57 | parser.add_argument('--img_size', type=int, default=256) # 256
58 | parser.add_argument('--lr', type=float, default=1e-6)
59 | parser.add_argument('--clip_gradient', type=float, default=1.0)
60 | parser.add_argument('--cuda', type=lambda x: str(x).lower() == 'true', default=True)  # note: plain type=bool treats any non-empty string as True
61 |
62 | # Training settings
63 | parser.add_argument('--vgg', type=str, default=vgg_path)
64 | parser.add_argument('--train_path', type=str, default=image_path)
65 | parser.add_argument('--label_path', type=str, default=label_path)
66 | parser.add_argument('--train_file', type=str, default=train_file)
67 | parser.add_argument('--epoch', type=int, default=500)
68 | parser.add_argument('--batch_size', type=int, default=1) # 8
69 | parser.add_argument('--val', type=lambda x: str(x).lower() == 'true', default=True)
70 | parser.add_argument('--val_path', type=str, default=image_path)
71 | parser.add_argument('--val_label', type=str, default=label_path)
72 | parser.add_argument('--val_file', type=str, default=valid_file)
73 | parser.add_argument('--num_thread', type=int, default=4)
74 | parser.add_argument('--load', type=str, default='')
75 | parser.add_argument('--save_fold', type=str, default='./results')
76 | parser.add_argument('--epoch_val', type=int, default=10)
77 | parser.add_argument('--epoch_save', type=int, default=20)
78 | parser.add_argument('--epoch_show', type=int, default=1)
79 | parser.add_argument('--pre_trained', type=str, default=None)
80 |
81 | # Testing settings
82 | parser.add_argument('--test_path', type=str, default=image_path)
83 | parser.add_argument('--test_label', type=str, default=label_path)
84 | parser.add_argument('--test_file', type=str, default=test_file)
85 | parser.add_argument('--model', type=str, default='./weights/final.pth')
86 | parser.add_argument('--test_fold', type=str, default='./results/test')
87 | parser.add_argument('--use_crf', type=lambda x: str(x).lower() == 'true', default=False)  # so --use_crf=False parses as False
88 |
89 | # Misc
90 | parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
91 | parser.add_argument('--visdom', type=lambda x: str(x).lower() == 'true', default=False)
92 |
93 | config = parser.parse_args()
94 | if not os.path.exists(config.save_fold): os.mkdir(config.save_fold)
95 | main(config)
96 |
--------------------------------------------------------------------------------
/dssnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 |
5 | # vgg choice
6 | base = {'dss': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
7 | # extended vgg config (side outputs) --- follows the paper; you can change it
8 | extra = {'dss': [(64, 128, 3, [8, 16, 32, 64]), (128, 128, 3, [4, 8, 16, 32]), (256, 256, 5, [8, 16]),
9 | (512, 256, 5, [4, 8]), (512, 512, 5, []), (512, 512, 7, [])]}
10 | connect = {'dss': [[2, 3, 4, 5], [2, 3, 4, 5], [4, 5], [4, 5], [], []]}
11 |
12 |
13 | # vgg16
14 | def vgg(cfg, i=3, batch_norm=False):
15 | layers = []
16 | in_channels = i
17 | for v in cfg:
18 | if v == 'M':
19 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
20 | else:
21 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
22 | if batch_norm:
23 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
24 | else:
25 | layers += [conv2d, nn.ReLU(inplace=True)]
26 | in_channels = v
27 | return layers
28 |
29 |
30 | # feature map before sigmoid: build the connection and deconvolution
31 | class ConcatLayer(nn.Module):
32 | def __init__(self, list_k, k, scale=True):
33 | super(ConcatLayer, self).__init__()
34 | l, up, self.scale = len(list_k), [], scale
35 | for i in range(l):
36 | up.append(nn.ConvTranspose2d(1, 1, list_k[i], list_k[i] // 2, list_k[i] // 4))
37 | self.upconv = nn.ModuleList(up)
38 | self.conv = nn.Conv2d(l + 1, 1, 1, 1)
39 | self.deconv = nn.ConvTranspose2d(1, 1, k * 2, k, k // 2) if scale else None
40 |
41 | def forward(self, x, list_x):
42 | elem_x = [x]
43 | for i, elem in enumerate(list_x):
44 | elem_x.append(self.upconv[i](elem))
45 | if self.scale:
46 | out = self.deconv(self.conv(torch.cat(elem_x, dim=1)))
47 | else:
48 | out = self.conv(torch.cat(elem_x, dim=1))
49 | return out
50 |
51 |
52 | # extend vgg: side outputs
53 | class FeatLayer(nn.Module):
54 | def __init__(self, in_channel, channel, k):
55 | super(FeatLayer, self).__init__()
56 | self.main = nn.Sequential(nn.Conv2d(in_channel, channel, k, 1, k // 2), nn.ReLU(inplace=True),
57 | nn.Conv2d(channel, channel, k, 1, k // 2), nn.ReLU(inplace=True),
58 | nn.Conv2d(channel, 1, 1, 1))
59 |
60 | def forward(self, x):
61 | return self.main(x)
62 |
63 |
64 | # fusion features
65 | class FusionLayer(nn.Module):
66 | def __init__(self, nums=6):
67 | super(FusionLayer, self).__init__()
68 | self.weights = nn.Parameter(torch.randn(nums))
69 | self.nums = nums
70 | self._reset_parameters()
71 |
72 | def _reset_parameters(self):
73 | init.constant_(self.weights, 1 / self.nums)
74 |
75 | def forward(self, x):
76 | for i in range(self.nums):
77 | out = self.weights[i] * x[i] if i == 0 else out + self.weights[i] * x[i]
78 | return out
79 |
80 |
81 | # extra part
82 | def extra_layer(vgg, cfg):
83 | feat_layers, concat_layers, scale = [], [], 1
84 | for k, v in enumerate(cfg):
85 | # side output (paper: figure 3)
86 | feat_layers += [FeatLayer(v[0], v[1], v[2])]
87 | # feature map before sigmoid
88 | concat_layers += [ConcatLayer(v[3], scale, k != 0)]
89 | scale *= 2
90 | return vgg, feat_layers, concat_layers
91 |
92 |
93 | # DSS network
94 | # Note: if you use other backbone network, please change extract
95 | class DSS(nn.Module):
96 | def __init__(self, base, feat_layers, concat_layers, connect, extract=[3, 8, 15, 22, 29], v2=True):
97 | super(DSS, self).__init__()
98 | self.extract = extract
99 | self.connect = connect
100 | self.base = nn.ModuleList(base)
101 | self.feat = nn.ModuleList(feat_layers)
102 | self.comb = nn.ModuleList(concat_layers)
103 | self.pool = nn.AvgPool2d(3, 1, 1)
104 | self.v2 = v2
105 | if v2: self.fuse = FusionLayer()
106 |
107 | def forward(self, x, label=None):
108 | prob, back, y, num = list(), list(), list(), 0
109 | for k in range(len(self.base)):
110 | x = self.base[k](x)
111 | if k in self.extract:
112 | y.append(self.feat[num](x))
113 | num += 1
114 | # side output
115 | y.append(self.feat[num](self.pool(x)))
116 | for i in range(len(y)):  # build each side map together with its short connections
117 | back.append(self.comb[i](y[i], [y[j] for j in self.connect[i]]))
118 | # fusion map
119 | if self.v2:
120 | # version2: learning fusion
121 | back.append(self.fuse(back))
122 | else:
123 | # version1: mean fusion
124 | back.append(torch.cat(back, dim=1).mean(dim=1, keepdim=True))
125 | # add sigmoid
126 | for i in back: prob.append(torch.sigmoid(i))
127 | return prob
128 |
129 |
130 | # build the whole network
131 | def build_model():
132 | return DSS(*extra_layer(vgg(base['dss'], 3), extra['dss']), connect['dss'])
133 |
134 |
135 | # weight init
136 | def xavier(param):
137 | init.xavier_uniform_(param)
138 |
139 |
140 | def weights_init(m):
141 | if isinstance(m, nn.Conv2d):
142 | xavier(m.weight.data)
143 | elif isinstance(m, nn.BatchNorm2d):
144 | init.constant_(m.weight, 1)
145 | init.constant_(m.bias, 0)
146 |
147 |
148 | if __name__ == '__main__':
149 | net = build_model()
150 | img = torch.randn(1, 3, 64, 64)
151 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
152 | net, img = net.to(device), img.to(device)
153 | out = net(img)
154 | k = [out[x] for x in [1, 2, 3, 6]]
155 | print(len(k))
156 | # for param in net.parameters():
157 | # print(param)
158 |
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from collections import OrderedDict
4 | from torch.nn import utils, functional as F
5 | from torch.optim import Adam
6 | from torch.backends import cudnn
7 | from torchvision import transforms
8 | from dssnet import build_model, weights_init
9 | from loss import Loss
10 | from tools.visual import Viz_visdom
11 |
12 |
13 | class Solver(object):
14 | def __init__(self, train_loader, val_loader, test_dataset, config):
15 | self.train_loader = train_loader
16 | self.val_loader = val_loader
17 | self.test_dataset = test_dataset
18 | self.config = config
19 | self.beta = math.sqrt(0.3) # for max F_beta metric
20 | # inference: choose the side map (see paper)
21 | self.select = [1, 2, 3, 6]
22 | self.device = torch.device('cpu')
23 | self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
24 | self.std = torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
25 | if self.config.cuda:
26 | cudnn.benchmark = True
27 | self.device = torch.device('cuda:0')
28 | if config.visdom:
29 | self.visual = Viz_visdom("DSS", 1)
30 | self.build_model()
31 | if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
32 | if config.mode == 'train':
33 | self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
34 | else:
35 | self.net.load_state_dict(torch.load(self.config.model))
36 | self.net.eval()
37 | self.test_output = open("%s/test.txt" % config.test_fold, 'w')
38 | self.transform = transforms.Compose([
39 | transforms.Resize((256, 256)),
40 | transforms.ToTensor(),
41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
42 | ])
43 |
44 | # print the network information and parameter numbers
45 | def print_network(self, model, name):
46 | num_params = 0
47 | for p in model.parameters():
48 | if p.requires_grad: num_params += p.numel()
49 | print(name)
50 | print(model)
51 | print("The number of parameters: {}".format(num_params))
52 |
53 | # build the network
54 | def build_model(self):
55 | self.net = build_model().to(self.device)
56 | if self.config.mode == 'train': self.loss = Loss().to(self.device)
57 | self.net.train()
58 | self.net.apply(weights_init)
59 | if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
60 | if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
61 | self.optimizer = Adam(self.net.parameters(), self.config.lr)
62 | self.print_network(self.net, 'DSS')
63 |
64 | # update the learning rate
65 | def update_lr(self, lr):
66 | for param_group in self.optimizer.param_groups:
67 | param_group['lr'] = lr
68 |
69 | # evaluate MAE (for test or validation phase)
70 | def eval_mae(self, y_pred, y):
71 | return torch.abs(y_pred - y).mean()
72 |
73 | # TODO: write a more efficient version
74 | # get precisions and recalls: threshold---divided [0, 1] to num values
75 | def eval_pr(self, y_pred, y, num):
76 | prec, recall = torch.zeros(num), torch.zeros(num)
77 | thlist = torch.linspace(0, 1 - 1e-10, num)
78 | for i in range(num):
79 | y_temp = (y_pred >= thlist[i]).float()
80 | tp = (y_temp * y).sum()
81 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
82 | return prec, recall
83 |
84 | # validation: uses resized images and only evaluates the MAE metric
85 | def validation(self):
86 | avg_mae = 0.0
87 | self.net.eval()
88 | with torch.no_grad():
89 | for i, data_batch in enumerate(self.val_loader):
90 | images, labels = data_batch
91 | images, labels = images.to(self.device), labels.to(self.device)
92 | prob_pred = self.net(images)
93 | prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
94 | avg_mae += self.eval_mae(prob_pred, labels).item()
95 | self.net.train()
96 | return avg_mae / len(self.val_loader)
97 |
98 | # test phase: uses the original image size; evaluates MAE and max F_beta metrics
99 | def test(self, num, use_crf=False):
100 | if use_crf: from tools.crf_process import crf
101 | avg_mae, img_num = 0.0, len(self.test_dataset)
102 | avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
103 | with torch.no_grad():
104 | for i, (img, labels) in enumerate(self.test_dataset):
105 | images = self.transform(img).unsqueeze(0)
106 | labels = labels.unsqueeze(0)
107 | shape = labels.size()[2:]
108 | images = images.to(self.device)
109 | prob_pred = self.net(images)
110 | prob_pred = torch.mean(torch.cat([prob_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
111 | prob_pred = F.interpolate(prob_pred, size=shape, mode='bilinear', align_corners=True).cpu().data
112 | if use_crf:
113 | prob_pred = crf(img, prob_pred.numpy(), to_tensor=True)
114 | mae = self.eval_mae(prob_pred, labels)
115 | prec, recall = self.eval_pr(prob_pred, labels, num)
116 | print("[%d] mae: %.4f" % (i, mae))
117 | print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
118 | avg_mae += mae
119 | avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
120 | avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
121 | score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
122 | score[score != score] = 0  # replace NaN (from 0/0) with 0
123 | print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
124 | print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)
125 |
126 | # training phase
127 | def train(self):
128 | iter_num = len(self.train_loader.dataset) // self.config.batch_size
129 | best_mae = 1.0 if self.config.val else None
130 | for epoch in range(self.config.epoch):
131 | loss_epoch = 0
132 | for i, data_batch in enumerate(self.train_loader):
133 | if (i + 1) > iter_num: break
134 | self.net.zero_grad()
135 | x, y = data_batch
136 | x, y = x.to(self.device), y.to(self.device)
137 | y_pred = self.net(x)
138 | loss = self.loss(y_pred, y)
139 | loss.backward()
140 | utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
141 | # utils.clip_grad_norm(self.loss.parameters(), self.config.clip_gradient)
142 | self.optimizer.step()
143 | loss_epoch += loss.item()
144 | print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (
145 | epoch, self.config.epoch, i, iter_num, loss.item()))
146 | if self.config.visdom:
147 | error = OrderedDict([('loss:', loss.item())])
148 | self.visual.plot_current_errors(epoch, i / iter_num, error)
149 |
150 | if (epoch + 1) % self.config.epoch_show == 0:
151 | print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
152 | file=self.log_output)
153 | if self.config.visdom:
154 | avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
155 | self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
156 | y_show = torch.mean(torch.cat([y_pred[i] for i in self.select], dim=1), dim=1, keepdim=True)
157 | img = OrderedDict([('origin', x.cpu()[0] * self.std + self.mean), ('label', y.cpu()[0][0]),
158 | ('pred_label', y_show.cpu().data[0][0])])
159 | self.visual.plot_current_img(img)
160 | if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
161 | mae = self.validation()
162 | print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae))
163 | print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae), file=self.log_output)
164 | if best_mae > mae:
165 | best_mae = mae
166 | torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
167 | if (epoch + 1) % self.config.epoch_save == 0:
168 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
169 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
170 |
--------------------------------------------------------------------------------