├── README.md
├── README.zh.md
├── dataset.py
├── demo.py
├── download.sh
├── loss.py
├── main.py
├── nldf.py
├── png
│   ├── demo.jpg
│   ├── example.png
│   └── loss.png
├── solver.py
└── tools
    ├── extract_vgg.py
    └── visual.py
/README.md:
--------------------------------------------------------------------------------
1 | # NLDF
2 | [中文说明](./README.zh.md)
3 |
4 | An unofficial implementation of [Non-Local Deep Features for Salient Object Detection](https://sites.google.com/view/zhimingluo/nldf).
5 |
6 | ![example](./png/example.png)
7 |
8 | The official Tensorflow version: [NLDF](https://github.com/zhimingluo/NLDF)
9 |
10 | Some differences from the official version:
11 |
12 | 1. ~~dataset~~
13 | 2. the score map has one channel, rather than two channels
14 | 3. Dice IoU loss: a boundary version and an area version (see the sketch below)
15 |
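Both IoU variants are Dice-style losses. A minimal sketch of the area version, mirroring `loss.py` (`pred` and `label` are probability maps in `[0, 1]`):

```python
import torch

def area_dice_iou_loss(pred, label):
    # Dice-style area IoU from loss.py; the +1 smooths against empty masks
    inter = (pred * label).sum()
    return 1 - 2 * (inter + 1) / (pred.sum() + label.sum() + 1)
```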
16 | ## Prerequisites
17 |
18 | - [Python 3](https://www.continuum.io/downloads)
19 | - [Pytorch 1.0](http://pytorch.org/)
20 | - [torchvision](http://pytorch.org/)
21 | - [visdom](https://github.com/facebookresearch/visdom) (optional for visualization)
22 |
23 | ## Results
24 |
25 | The training loss curve:
26 |
27 | 
28 |
29 | Performance (the max F-measure computation is sketched after the notes):
30 |
31 | | Dataset | max F(paper) | MAE(paper) | max F(here) | MAE(here) |
32 | | :-----: | :----------: | :--------: | :---------: | :-------: |
33 | | MSRA-B | 0.911 | 0.048 | 0.9006 | 0.0592 |
34 |
35 | Note:
36 |
37 | 1. Trained for only 200 epochs; training longer may come closer to the paper's numbers.
38 | 2. This reproduction uses the area IoU loss, while the original paper uses the boundary IoU loss.
39 | 3. ~~The comparison is not entirely fair: the training data differ (I could not find the dataset used in the original paper).~~
40 |
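The max F-measure sweeps a set of thresholds over the predicted saliency map and keeps the best F-score, with β = 0.3 as in `solver.py`. A minimal sketch, mirroring `Solver.eval_pr` and `Solver.test`:

```python
import torch

def max_f_measure(pred, label, num=100, beta=0.3):
    # precision/recall at `num` thresholds, as in Solver.eval_pr
    thresholds = torch.linspace(0, 1 - 1e-10, num)
    prec, recall = torch.zeros(num), torch.zeros(num)
    for i, th in enumerate(thresholds):
        binary = (pred >= th).float()
        tp = (binary * label).sum()
        prec[i] = tp / (binary.sum() + 1e-20)
        recall[i] = tp / label.sum()
    f = (1 + beta ** 2) * prec * recall / (beta ** 2 * prec + recall)
    f[f != f] = 0  # NaN -> 0 where precision = recall = 0
    return f.max()
```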
41 | ## Usage
42 |
43 | ### 1. Clone the repository
44 |
45 | ```shell
46 | git clone git@github.com:AceCoooool/NLDF-pytorch.git
47 | cd NLDF-pytorch/
48 | ```
49 |
50 | ### 2. Download the dataset
51 |
52 | Note: the original paper uses other datasets.
53 |
54 | Download the [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html) dataset.
55 |
56 | ```shell
57 | bash download.sh
58 | ```
59 |
60 | ### 3. Get pre-trained vgg
61 |
62 | ```bash
63 | cd tools/
64 | python extract_vgg.py
65 | cd ..
66 | ```
67 |
68 | ### 4. Demo
69 |
70 | ```shell
71 | python demo.py --demo_img='your_picture' --trained_model='pre_trained pth' --cuda=True
72 | ```
73 |
74 | Note:
75 |
76 | 1. By default, download the [pretrained model](https://drive.google.com/file/d/10cnWpqABT6MRdTO0p17hcHornMs6ggQL/view?usp=sharing) and copy it to the `weights` directory.
77 | 2. A demo picture is provided at `png/demo.jpg`.
78 |
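Under the hood, the demo resizes the input to 352×352 and subtracts the VGG channel mean before the forward pass. A minimal sketch of this preprocessing, mirroring `demo.py`:

```python
import torch
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((352, 352)), transforms.ToTensor()])
mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
img = transform(Image.open('png/demo.jpg').convert('RGB')) - mean
img = img.unsqueeze(0)  # (1, 3, 352, 352), ready for the network
```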
79 | ### 5. Train
80 |
81 | ```shell
82 | python main.py --mode='train' --train_path='you_data' --label_path='you_label' --batch_size=8 --visdom=True --area=True
83 | ```
84 |
85 | Note:
86 |
87 | 1. `--area=True --boundary=True` enables the area and/or boundary Dice IoU losses (default: `--area=True --boundary=False`)
88 | 2. `--val=True` enables validation during training (you also need to set `--val_path` and `--val_label`)
89 | 3. `you_data`, `you_label` are the roots of your training images and labels (the data from step 2)
90 |
91 | ### 6. Test
92 |
93 | ```shell
94 | python main.py --mode='test' --test_path='you_data' --test_label='your_label' --batch_size=1 --model='your_trained_model'
95 | ```
96 |
97 | Note:
98 |
99 | 1. The evaluation metrics match the original paper's (adapted from the official code).
100 |
101 | ## Bug
102 |
103 | 1. The boundary Dice IoU may produce `inf`; the area Dice IoU is more stable.
104 |
105 | Adding Batch Normalization may also help; see the sketch below.
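The `vgg()` helper in `nldf.py` already accepts a `batch_norm` flag, so a BN variant could start from the sketch below. This is an untested assumption: BN inserts an extra layer after every conv, so the feature-tap indices (`extra['352']` and `NLDF.pos`) would need to be remapped before wiring it into `build_model()`.

```python
from nldf import vgg, base

# Untested sketch: VGG-16 feature layers with BatchNorm after each conv.
# extra['352'] and NLDF.pos index the plain-VGG layer list and must be
# recomputed for the BN layout.
bn_layers = vgg(base['352'], 3, batch_norm=True)
```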
--------------------------------------------------------------------------------
/README.zh.md:
--------------------------------------------------------------------------------
1 | # NLDF
2 | [English](./README.md)
3 |
4 | An unofficial PyTorch implementation of [Non-Local Deep Features for Salient Object Detection](https://sites.google.com/view/zhimingluo/nldf).
5 |
6 | 
7 |
8 | The official Tensorflow version: [NLDF](https://github.com/zhimingluo/NLDF)
9 |
10 | Changes in this implementation:
11 |
12 | 1. ~~dataset (I could not find the MSRA-B images)~~
13 | 2. A structural difference: the network outputs a single probability map here, while the official version outputs two complementary probability maps
14 | 3. Added an "area IoU" loss; the original paper uses a "boundary IoU" loss (both can be enabled at once)
15 |
16 | ## Prerequisites
17 |
18 | - [Python 3](https://www.continuum.io/downloads)
19 | - [Pytorch 1.0](http://pytorch.org/)
20 | - [torchvision](http://pytorch.org/)
21 | - [visdom](https://github.com/facebookresearch/visdom) (optional for visualization)
22 |
23 | ## Results
24 |
25 | The training loss curve:
26 |
27 | 
28 |
29 | Performance:
30 |
31 | | Dataset | max F(paper) | MAE(paper) | max F(here) | MAE(here) |
32 | | :-----: | :----------: | :--------: | :---------: | :-------: |
33 | | ECSSD | 0.905 | 0.063 | 0.9830 | 0.0375 |
34 |
35 | Notes:
36 |
37 | 1. This reproduction uses the area IoU loss; the original paper uses the boundary IoU loss.
38 | 2. The comparison is "unfair": the datasets differ, and the numbers were measured directly on the training set (only to show that the performance can reach or even exceed the original paper) --- I could not find the dataset used in the original paper.
39 |
40 | ## Usage
41 |
42 | ### 1. Clone the repository
43 |
44 | ```shell
45 | git clone git@github.com:AceCoooool/NLDF-pytorch.git
46 | cd NLDF-pytorch/
47 | ```
48 |
49 | ### 2. Download the dataset
50 |
51 | Note: the original paper uses more datasets.
52 |
53 | Download the [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html) dataset:
54 |
55 | ```shell
56 | bash download.sh
57 | ```
58 |
59 | ### 3. Extract the pre-trained VGG
60 |
61 | ```bash
62 | cd tools/
63 | python extract_vgg.py
64 | cd ..
65 | ```
66 |
67 | Note: this directly uses the pre-trained VGG from torchvision.
68 |
69 | ### 4. Demo
70 |
71 | ```shell
72 | python demo.py --demo_img='your_picture' --trained_model='pre_trained pth' --cuda=True
73 | ```
74 |
75 | Notes:
76 |
77 | 1. Default setting: download the [pretrained model](https://drive.google.com/file/d/10cnWpqABT6MRdTO0p17hcHornMs6ggQL/view?usp=sharing) and copy it into the `weights` directory
78 | 2. The demo picture defaults to `png/demo.jpg`
79 |
80 | ### 5. Train
81 |
82 | ```shell
83 | python main.py --mode='train' --train_path='you_data' --label_path='you_label' --batch_size=8 --visdom=True --area=True --boundary=False
84 | ```
85 |
86 | Notes:
87 |
88 | 1. `--val=True`: enable validation during training. You can hold out part of the training set as a validation set; also provide the validation paths
89 | 2. `you_data, you_label`: the dataset paths from step 2
90 | 3. `--area --boundary`: choose the area IoU, the boundary IoU, or both (implemented in `loss.py`; the default, area only, is recommended)
91 |
92 | ### 6. Test
93 |
94 | ```shell
95 | python main.py --mode='test' --test_path='you_data' --test_label='your_label' --batch_size=1 --model='your_trained_model'
96 | ```
97 |
98 | Notes:
99 |
100 | 1. The evaluation metrics match the original paper (adapted from the original code)
101 |
102 | ## Bug
103 |
104 | 1. The boundary IoU loss tends to produce `inf`; the learning rate must be set very small, e.g. `1e-10`
105 | 2. There may be other numerical issues
106 |
107 |
108 |
109 | If you have any questions, feel free to ask in the issues~
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch
4 | from torch.utils import data
5 | from torchvision import transforms
6 |
7 |
8 | class ImageData(data.Dataset):
9 | """ image dataset
10 | img_root: image root (root which contain images)
11 | label_root: label root (root which contains labels)
12 | transform: pre-process for image
13 | t_transform: pre-process for label
14 | filename: MSRA-B use xxx.txt to recognize train-val-test data (only for MSRA-B)
15 | """
16 |
17 | def __init__(self, img_root, label_root, transform, t_transform, filename=None):
18 | if filename is None:
19 | self.image_path = list(map(lambda x: os.path.join(img_root, x), os.listdir(img_root)))
20 | self.label_path = list(
21 | map(lambda x: os.path.join(label_root, os.path.basename(x)[:-3] + 'png'), self.image_path))
22 | else:
23 | lines = [line.rstrip('\n')[:-3] for line in open(filename)]
24 | self.image_path = list(map(lambda x: os.path.join(img_root, x + 'jpg'), lines))
25 | self.label_path = list(map(lambda x: os.path.join(label_root, x + 'png'), lines))
26 |
27 | self.transform = transform
28 | self.t_transform = t_transform
29 |
30 | def __getitem__(self, item):
31 | image = Image.open(self.image_path[item])
32 | label = Image.open(self.label_path[item]).convert('L')
33 | if self.transform is not None:
34 | image = self.transform(image)
35 | if self.t_transform is not None:
36 | label = self.t_transform(label)
37 | return image, label
38 |
39 | def __len__(self):
40 | return len(self.image_path)
41 |
42 |
43 | # get the dataloader (Note: without data augmentation)
44 | def get_loader(img_root, label_root, img_size, batch_size, filename=None, mode='train', num_thread=4, pin=True):
45 | mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
46 | if mode == 'train':
47 | transform = transforms.Compose([
48 | transforms.Resize((img_size, img_size)),
49 | transforms.ToTensor(),
50 | transforms.Lambda(lambda x: x - mean)
51 | ])
52 | t_transform = transforms.Compose([
53 | transforms.Resize((img_size // 2, img_size // 2)),
54 | transforms.ToTensor(),
55 | transforms.Lambda(lambda x: torch.round(x))  # TODO: this may be unnecessary
56 | ])
57 | dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
58 | data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread, pin_memory=pin)
59 | else:
60 | transform = transforms.Compose([
61 | transforms.Resize((img_size, img_size)),
62 | transforms.ToTensor(),
63 | transforms.Lambda(lambda x: x - mean)
64 | ])
65 | t_transform = transforms.Compose([
66 | transforms.ToTensor(),
67 | transforms.Lambda(lambda x: torch.round(x))  # TODO: this may be unnecessary
68 | ])
69 | dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
70 | data_loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=num_thread, pin_memory=pin)
71 | return data_loader
72 |
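# Usage sketch (mirroring main.py; the paths below are assumptions):
#   loader = get_loader('data/images', 'data/ground_truth_mask', img_size=352,
#                       batch_size=8, mode='train')
#   for image, label in loader:
#       ...  # image: (B, 3, 352, 352), label: (B, 1, 176, 176) with img_size=352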
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from PIL import Image
6 | import numpy as np
7 | from torchvision import transforms
8 | from nldf import build_model
9 |
10 |
11 | def demo(model_path, img_path, cuda):
12 | transform = transforms.Compose([transforms.Resize((352, 352)), transforms.ToTensor()])
13 | img = Image.open(img_path).convert('RGB')  # guard against grayscale / RGBA inputs
14 | shape = img.size
15 | img = transform(img) - torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
16 | img = img.unsqueeze(0)
17 | net = build_model()
18 | net.load_state_dict(torch.load(model_path))
19 | net.eval()
20 | if cuda: img, net = img.cuda(), net.cuda()
21 | with torch.no_grad(): prob = net(img)  # Variable(volatile=True) is gone in PyTorch 1.0
22 | prob = (prob.cpu()[0, 0].numpy() * 255).astype(np.uint8)
23 | p_img = Image.fromarray(prob, mode='L').resize(shape)
24 | p_img.show()
25 |
26 |
27 | if __name__ == '__main__':
28 | model_path = './weights/best.pth'
29 | img_path = './png/demo.jpg'
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument('--demo_img', type=str, default=img_path)
33 | parser.add_argument('--trained_model', type=str, default=model_path)
34 | parser.add_argument('--cuda', type=bool, default=True)
35 | config = parser.parse_args()
36 | ext = ['.jpg', '.png']
37 | if not os.path.splitext(config.demo_img)[-1] in ext:
38 | raise IOError('illegal image path')
39 |
40 | demo(config.trained_model, config.demo_img, config.cuda)
41 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ECSSD images
3 | URL=http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/images.zip
4 | ZIP_FILE=./data/images.zip
5 | mkdir -p ./data/
6 | wget -N $URL -O $ZIP_FILE
7 | unzip $ZIP_FILE -d ./data/
8 | rm $ZIP_FILE
9 |
10 | # ECSSD labels
11 | URL=http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/ground_truth_mask.zip
12 | ZIP_FILE=./data/ground_truth_mask.zip
13 | mkdir -p ./data/
14 | wget -N $URL -O $ZIP_FILE
15 | unzip $ZIP_FILE -d ./data/
16 | rm $ZIP_FILE
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GradLayer(nn.Module):
7 | def __init__(self):
8 | super(GradLayer, self).__init__()
9 | self.grad_x = nn.Conv2d(1, 1, 3, padding=1, bias=False)  # Sobel filter along x
10 | self.grad_y = nn.Conv2d(1, 1, 3, padding=1, bias=False)  # Sobel filter along y
11 | self.set_weight()
12 |
13 | def set_weight(self):
14 | x = torch.Tensor([[-1., 0, 1], [-2., 0, 2.], [-1., 0, 1.]]).view(1, 1, 3, 3)
15 | y = torch.Tensor([[-1., -2., -1.], [0, 0, 0], [1., 2., 1.]]).view(1, 1, 3, 3)
16 | weight_x = nn.Parameter(x, requires_grad=False)
17 | weight_y = nn.Parameter(y, requires_grad=False)
18 | self.grad_x.weight, self.grad_y.weight = weight_x, weight_y
19 |
20 | def forward(self, x):
21 | x1, x2 = self.grad_x(x), self.grad_y(x)
22 | # return torch.sqrt(torch.pow(x1, 2) + torch.pow(x2, 2))
23 | return torch.pow(x1, 2) + torch.pow(x2, 2)
24 |
25 |
26 | class Loss(nn.Module):
27 | def __init__(self, area=True, boundary=False, contour_th=1.5, ratio=1):
28 | super(Loss, self).__init__()
29 | self.area, self.boundary, self.cth, self.ratio = area, boundary, contour_th, ratio
30 | if boundary:
31 | self.gradlayer = GradLayer()
32 |
33 | def forward(self, x, label):
34 | loss = F.binary_cross_entropy(x, label)
35 | if self.area:
36 | area_loss = 1 - 2 * ((x * label).sum() + 1) / (x.sum() + label.sum() + 1)
37 | loss += area_loss
38 | if self.boundary:
39 | prob_grad = torch.tanh(self.gradlayer(x))
40 | label_grad = torch.gt(self.gradlayer(label), self.cth).float()
41 | inter = torch.sum(prob_grad * label_grad)
42 | union = torch.pow(prob_grad, 2).sum() + torch.pow(label_grad, 2).sum()
43 | boundary_loss = (1 - 2 * (inter + 1) / (union + 1))
44 | loss = loss + self.ratio * boundary_loss
45 | return loss
46 |
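# Usage sketch (mirroring solver.py):
#   criterion = Loss(area=True, boundary=False)
#   loss = criterion(prob, label)  # prob, label: (N, 1, H, W), values in [0, 1]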
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from dataset import get_loader
4 | from solver import Solver
5 |
6 |
7 | def main(config):
8 | if config.mode == 'train':
9 | train_loader = get_loader(config.train_path, config.label_path, config.img_size, config.batch_size,
10 | filename=config.train_file, num_thread=config.num_thread)
11 | if config.val:
12 | val_loader = get_loader(config.val_path, config.val_label, config.img_size, config.batch_size,
13 | filename=config.val_file, num_thread=config.num_thread)
14 | run = 0
15 | while os.path.exists("%s/run-%d" % (config.save_fold, run)): run += 1
16 | os.mkdir("%s/run-%d" % (config.save_fold, run))
17 | os.mkdir("%s/run-%d/logs" % (config.save_fold, run))
18 | # os.mkdir("%s/run-%d/images" % (config.save_fold, run))
19 | os.mkdir("%s/run-%d/models" % (config.save_fold, run))
20 | config.save_fold = "%s/run-%d" % (config.save_fold, run)
21 | if config.val:
22 | train = Solver(train_loader, val_loader, None, config)
23 | else:
24 | train = Solver(train_loader, None, None, config)
25 | train.train()
26 | elif config.mode == 'test':
27 | test_loader = get_loader(config.test_path, config.test_label, config.img_size, config.batch_size, mode='test',
28 | filename=config.test_file, num_thread=config.num_thread)
29 | if not os.path.exists(config.test_fold): os.mkdir(config.test_fold)
30 | test = Solver(None, None, test_loader, config)
31 | test.test(100)
32 | else:
33 | raise ValueError("illegal mode: expected 'train' or 'test'")
34 |
35 |
36 | if __name__ == '__main__':
37 | data_root = os.path.join(os.path.expanduser('~'), 'data')
38 | vgg_path = './weights/vgg16_feat.pth'
39 |
40 | # # -----ECSSD dataset-----
41 | # train_path = os.path.join(data_root, 'ECSSD/images')
42 | # label_path = os.path.join(data_root, 'ECSSD/ground_truth_mask')
43 | # val_path = os.path.join(data_root, 'ECSSD/val_images')
44 | # val_label = os.path.join(data_root, 'ECSSD/val_ground_truth_mask')
45 | # test_path = os.path.join(data_root, 'ECSSD/test_images')
46 | # test_label = os.path.join(data_root, 'ECSSD/test_ground_truth_mask')
47 |
48 | # -----MSRA-B dataset-----
49 | image_path = os.path.join(data_root, 'MSRA-B/image')
50 | label_path = os.path.join(data_root, 'MSRA-B/annotation')
51 | train_file = os.path.join(data_root, 'MSRA-B/train_cvpr2013.txt')
52 | valid_file = os.path.join(data_root, 'MSRA-B/valid_cvpr2013.txt')
53 | test_file = os.path.join(data_root, 'MSRA-B/test_cvpr2013.txt')
54 |
55 | parser = argparse.ArgumentParser()
56 |
57 | # Hyper-parameters
58 | parser.add_argument('--n_color', type=int, default=3)
59 | parser.add_argument('--img_size', type=int, default=352)
60 | parser.add_argument('--lr', type=float, default=1e-6)
61 | parser.add_argument('--clip_gradient', type=float, default=1.0)
62 | parser.add_argument('--cuda', type=bool, default=True)
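# NOTE: argparse's type=bool treats any non-empty string as True, so
# --cuda=False is still truthy; pass an empty string (--cuda=) to disable CUDA.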
63 | parser.add_argument('--contour_th', type=float, default=1.5)
64 |
65 | # Training settings
66 | parser.add_argument('--vgg', type=str, default=vgg_path)
67 | parser.add_argument('--train_path', type=str, default=image_path)
68 | parser.add_argument('--label_path', type=str, default=label_path)
69 | parser.add_argument('--train_file', type=str, default=train_file)
70 | parser.add_argument('--epoch', type=int, default=500)
71 | parser.add_argument('--batch_size', type=int, default=1)
72 | parser.add_argument('--val', type=bool, default=True)
73 | parser.add_argument('--val_path', type=str, default=image_path)
74 | parser.add_argument('--val_label', type=str, default=label_path)
75 | parser.add_argument('--val_file', type=str, default=valid_file)
76 | parser.add_argument('--num_thread', type=int, default=4)
77 | parser.add_argument('--load', type=str, default='')
78 | parser.add_argument('--save_fold', type=str, default='./results')
79 | parser.add_argument('--epoch_val', type=int, default=5)
80 | parser.add_argument('--epoch_save', type=int, default=20)
81 | parser.add_argument('--epoch_show', type=int, default=1)
82 | parser.add_argument('--pre_trained', type=str, default=None)
83 | parser.add_argument('--area', type=bool, default=True)
84 | parser.add_argument('--boundary', type=bool, default=False)
85 |
86 | # Testing settings
87 | parser.add_argument('--test_path', type=str, default=image_path)
88 | parser.add_argument('--test_label', type=str, default=label_path)
89 | parser.add_argument('--model', type=str, default='./weights/best.pth')
90 | parser.add_argument('--test_fold', type=str, default='./results/test')
91 | parser.add_argument('--test_file', type=str, default=test_file)
92 |
93 | # Misc
94 | parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
95 | parser.add_argument('--visdom', type=bool, default=False)
96 |
97 | config = parser.parse_args()
98 | if not os.path.exists(config.save_fold): os.mkdir(config.save_fold)
99 | main(config)
100 |
--------------------------------------------------------------------------------
/nldf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 |
6 | base = {'352': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
7 | extra = {'352': [2, 7, 14, 21, 28]}
8 |
9 |
10 | # vgg16
11 | def vgg(cfg, i, batch_norm=False):
12 | layers = []
13 | in_channels = i
14 | for v in cfg:
15 | if v == 'M':
16 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
17 | else:
18 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
19 | if batch_norm:
20 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
21 | else:
22 | layers += [conv2d, nn.ReLU(inplace=True)]
23 | in_channels = v
24 | return layers
25 |
26 |
27 | class ConvConstract(nn.Module):  # local feature + local contrast feature (x - avgpool(x))
28 | def __init__(self, in_channel):
29 | super(ConvConstract, self).__init__()
30 | self.conv1 = nn.Conv2d(in_channel, 128, kernel_size=3, padding=1)
31 | self.cons1 = nn.AvgPool2d(3, stride=1, padding=1)
32 |
33 | def forward(self, x):
34 | x = F.relu(self.conv1(x), inplace=True)
35 | x2 = self.cons1(x)
36 | return x, x - x2
37 |
38 |
39 | # extra part
40 | def extra_layer(vgg, cfg):
41 | feat_layers, pool_layers = [], []
42 | for k, v in enumerate(cfg):
43 | feat_layers += [ConvConstract(vgg[v].out_channels)]
44 | if k == 0:
45 | pool_layers += [nn.Conv2d(128 * (6 - k), 128 * (5 - k), 1)]
46 | else:
47 | # TODO: change this to upsampling
48 | pool_layers += [nn.ConvTranspose2d(128 * (6 - k), 128 * (5 - k), 3, 2, 1, 1)]
49 | return vgg, feat_layers, pool_layers
50 |
51 |
52 | class NLDF(nn.Module):
53 | def __init__(self, base, feat_layers, pool_layers):
54 | super(NLDF, self).__init__()
55 | self.pos = [4, 9, 16, 23, 30]
56 | self.base = nn.ModuleList(base)
57 | self.feat = nn.ModuleList(feat_layers)
58 | self.pool = nn.ModuleList(pool_layers)
59 | self.glob = nn.Sequential(nn.Conv2d(512, 128, 5), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 5),
60 | nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3))
61 | self.conv_g = nn.Conv2d(128, 1, 1)
62 | self.conv_l = nn.Conv2d(640, 1, 1)
63 |
64 | def forward(self, x, label=None):
65 | sources, num = list(), 0
66 | for k in range(len(self.base)):
67 | x = self.base[k](x)
68 | if k in self.pos:
69 | sources.append(self.feat[num](x))
70 | num = num + 1
71 | for k in range(4, -1, -1):
72 | if k == 4:
73 | out = F.relu(self.pool[k](torch.cat([sources[k][0], sources[k][1]], dim=1)), inplace=True)
74 | else:
75 | out = self.pool[k](torch.cat([sources[k][0], sources[k][1], out], dim=1)) if k == 0 else F.relu(
76 | self.pool[k](torch.cat([sources[k][0], sources[k][1], out], dim=1)), inplace=True)
77 |
78 | score = self.conv_g(self.glob(x)) + self.conv_l(out)
79 | prob = torch.sigmoid(score)
80 | return prob
81 |
82 |
83 | def build_model():
84 | return NLDF(*extra_layer(vgg(base['352'], 3), extra['352']))
85 |
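# Minimal usage sketch (assumption: weights/vgg16_feat.pth was produced by
# tools/extract_vgg.py):
#   net = build_model()
#   net.base.load_state_dict(torch.load('./weights/vgg16_feat.pth'))
#   prob = net(torch.randn(1, 3, 352, 352))  # -> (1, 1, 176, 176) saliency map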
86 |
87 | def xavier(param):
88 | init.xavier_uniform_(param)
89 |
90 |
91 | def weights_init(m):
92 | if isinstance(m, nn.Conv2d):
93 | xavier(m.weight.data)
94 | m.bias.data.zero_()
95 |
96 |
--------------------------------------------------------------------------------
/png/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/demo.jpg
--------------------------------------------------------------------------------
/png/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/example.png
--------------------------------------------------------------------------------
/png/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/loss.png
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | from torch.nn import utils, functional as F
4 | from torch.optim import Adam
5 | from torch.backends import cudnn
6 | from nldf import build_model, weights_init
7 | from loss import Loss
8 | from tools.visual import Viz_visdom
9 |
10 |
11 | class Solver(object):
12 | def __init__(self, train_loader, val_loader, test_loader, config):
13 | self.train_loader = train_loader
14 | self.val_loader = val_loader
15 | self.test_loader = test_loader
16 | self.config = config
17 | self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
18 | self.beta = 0.3
19 | self.device = torch.device('cpu')
20 | if self.config.cuda:
21 | cudnn.benchmark = True
22 | self.device = torch.device('cuda')
23 | if config.visdom:
24 | self.visual = Viz_visdom("NLDF", 1)
25 | self.build_model()
26 | if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
27 | if config.mode == 'train':
28 | self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
29 | else:
30 | self.net.load_state_dict(torch.load(self.config.model))
31 | self.net.eval()
32 | self.test_output = open("%s/test.txt" % config.test_fold, 'w')
33 |
34 | def print_network(self, model, name):
35 | num_params = 0
36 | for p in model.parameters():
37 | num_params += p.numel()
38 | print(name)
39 | print(model)
40 | print("The number of parameters: {}".format(num_params))
41 |
42 | def build_model(self):
43 | self.net = build_model()
44 | if self.config.mode == 'train': self.loss = Loss(self.config.area, self.config.boundary)
45 | self.net = self.net.to(self.device)
46 | if self.config.cuda and self.config.mode == 'train': self.loss = self.loss.cuda()
47 | self.net.train()
48 | self.net.apply(weights_init)
49 | if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
50 | if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
51 | self.optimizer = Adam(self.net.parameters(), self.config.lr)
52 | self.print_network(self.net, 'NLDF')
53 |
54 | def update_lr(self, lr):
55 | for param_group in self.optimizer.param_groups:
56 | param_group['lr'] = lr
57 |
58 | def clip(self, y):
59 | return torch.clamp(y, 0.0, 1.0)
60 |
61 | def eval_mae(self, y_pred, y):
62 | return torch.abs(y_pred - y).mean()
63 |
64 | # TODO: write a more efficient version
65 | def eval_pr(self, y_pred, y, num):
66 | prec, recall = torch.zeros(num), torch.zeros(num)
67 | thlist = torch.linspace(0, 1 - 1e-10, num)
68 | for i in range(num):
69 | y_temp = (y_pred >= thlist[i]).float()
70 | tp = (y_temp * y).sum()
71 | prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
72 | return prec, recall
73 |
74 | def validation(self):
75 | avg_mae = 0.0
76 | self.net.eval()
77 | for i, data_batch in enumerate(self.val_loader):
78 | with torch.no_grad():
79 | images, labels = data_batch
80 | images, labels = images.to(self.device), labels.to(self.device)
81 | prob_pred = self.net(images)
82 | avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
83 | self.net.train()
84 | return avg_mae / len(self.val_loader)
85 |
86 | def test(self, num):
87 | avg_mae, img_num = 0.0, len(self.test_loader)
88 | avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
89 | for i, data_batch in enumerate(self.test_loader):
90 | with torch.no_grad():
91 | images, labels = data_batch
92 | shape = labels.size()[2:]
93 | images = images.to(self.device)
94 | prob_pred = F.interpolate(self.net(images), size=shape, mode='bilinear', align_corners=True).cpu()
95 | mae = self.eval_mae(prob_pred, labels)
96 | prec, recall = self.eval_pr(prob_pred, labels, num)
97 | print("[%d] mae: %.4f" % (i, mae))
98 | print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
99 | avg_mae += mae
100 | avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
101 | avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
102 | score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
103 | score[score != score] = 0  # replace NaN (where precision = recall = 0) with 0
104 | print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
105 | print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)
106 |
107 | def train(self):
108 | iter_num = len(self.train_loader.dataset) // self.config.batch_size
109 | best_mae = 1.0 if self.config.val else None
110 | for epoch in range(self.config.epoch):
111 | loss_epoch = 0
112 | for i, data_batch in enumerate(self.train_loader):
113 | if (i + 1) > iter_num: break
114 | self.net.zero_grad()
115 | x, y = data_batch
116 | x, y = x.to(self.device), y.to(self.device)
117 | y_pred = self.net(x)
118 | loss = self.loss(y_pred, y)
119 | loss.backward()
120 | utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
121 | self.optimizer.step()
122 | loss_epoch += loss.cpu().item()
123 | print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (
124 | epoch, self.config.epoch, i, iter_num, loss.cpu().item()))
125 | if self.config.visdom:
126 | error = OrderedDict([('loss:', loss.cpu().item())])
127 | self.visual.plot_current_errors(epoch, i / iter_num, error)
128 | if (epoch + 1) % self.config.epoch_show == 0:
129 | print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
130 | file=self.log_output)
131 | if self.config.visdom:
132 | avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
133 | self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
134 | img = OrderedDict([('origin', self.mean + x.cpu()[0]), ('label', y.cpu()[0][0]),
135 | ('pred_label', y_pred.cpu()[0][0])])
136 | self.visual.plot_current_img(img)
137 | if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
138 | mae = self.validation()
139 | print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae))
140 | print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae), file=self.log_output)
141 | if best_mae > mae:
142 | best_mae = mae
143 | torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
144 | if (epoch + 1) % self.config.epoch_save == 0:
145 | torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
146 | torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
147 |
--------------------------------------------------------------------------------
/tools/extract_vgg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision import models
4 |
5 |
6 | if __name__ == '__main__':
7 | save_fold = '../weights'
8 | if not os.path.exists(save_fold):
9 | os.mkdir(save_fold)
10 | vgg = models.vgg16(pretrained=True)
11 | torch.save(vgg.features.state_dict(), os.path.join(save_fold, 'vgg16_feat.pth'))
--------------------------------------------------------------------------------
/tools/visual.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Viz_visdom():
5 | def __init__(self, name, display_id=0):
6 | self.name = name
7 | self.display_id = display_id
8 | self.idx = display_id
9 | self.plot_data = {}
10 | if display_id > 0:
11 | import visdom
12 | self.vis = visdom.Visdom(port=8097)
13 |
14 | def plot_current_errors(self, epoch, counter_ratio, errors, idx=0):
15 |
16 | if idx not in self.plot_data:
17 | self.plot_data[idx] = {'X': [], 'Y': [], 'legend': list(errors.keys())}
18 | # self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
19 | self.plot_data[idx]['X'].append(epoch + counter_ratio)
20 | self.plot_data[idx]['Y'].append([errors[k] for k in self.plot_data[idx]['legend']])
21 | self.vis.line(
22 | X=np.stack([np.array(self.plot_data[idx]['X'])] * len(self.plot_data[idx]['legend']), 1)
23 | if len(errors) > 1 else np.array(self.plot_data[idx]['X']),
24 | Y=np.array(self.plot_data[idx]['Y']) if len(errors) > 1 else np.array(self.plot_data[idx]['Y'])[:, 0],
25 | opts={
26 | 'title': self.name + ' loss over time %d' % idx,
27 | 'legend': self.plot_data[idx]['legend'],
28 | 'xlabel': 'epoch',
29 | 'ylabel': 'loss'},
30 | win=self.display_id + idx)
31 | if self.idx < self.display_id + idx:
32 | self.idx = self.display_id + idx
33 |
34 | def plot_current_img(self, visuals, c_prev=True):
35 | idx = self.idx + 1
36 | for label, image_numpy in visuals.items():
37 | if c_prev:
38 | self.vis.image(image_numpy, opts=dict(title=label),
39 | win=self.display_id + idx)
40 | else:
41 | image_numpy = image_numpy.swapaxes(0, 2).swapaxes(1, 2)
42 | self.vis.image(image_numpy, opts=dict(title=label),
43 | win=self.display_id + idx)
44 | idx += 1
45 |
--------------------------------------------------------------------------------