├── README.md
├── README.zh.md
├── dataset.py
├── demo.py
├── download.sh
├── loss.py
├── main.py
├── nldf.py
├── png
│   ├── demo.jpg
│   ├── example.png
│   └── loss.png
├── solver.py
└── tools
    ├── extract_vgg.py
    └── visual.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# NLDF
[中文说明](./README.zh.md)

An unofficial PyTorch implementation of [Non-Local Deep Features for Salient Object Detection](https://sites.google.com/view/zhimingluo/nldf).

![](./png/example.png)

The official TensorFlow version: [NLDF](https://github.com/zhimingluo/NLDF)

Some differences from the official version:

1. ~~dataset~~
2. the score map has a single channel, rather than two channels
3. Dice IoU loss: a boundary version and an area version (written out below)
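
Written out from `loss.py`, the two Dice-style IoU losses are:

```latex
L_{area} = 1 - \frac{2\left(\sum_i P_i G_i + 1\right)}{\sum_i P_i + \sum_i G_i + 1}, \qquad
L_{boundary} = 1 - \frac{2\left(\sum_i \tilde{P}_i \tilde{G}_i + 1\right)}{\sum_i \tilde{P}_i^2 + \sum_i \tilde{G}_i^2 + 1}
```

where `P` is the predicted saliency map, `G` is the ground truth, the contour maps are `P̃ = tanh(sobel²(P))` and `G̃ = 1[sobel²(G) > 1.5]` (computed by the fixed Sobel layer `GradLayer`), and the `+1` terms smooth both ratios against empty masks.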
## Prerequisites

- [Python 3](https://www.continuum.io/downloads)
- [PyTorch 1.0](http://pytorch.org/)
- [torchvision](http://pytorch.org/)
- [visdom](https://github.com/facebookresearch/visdom) (optional, for visualization)

## Results

Training loss curve:

![](./png/loss.png)

Performance:

| Dataset | max F (paper) | MAE (paper) | max F (here) | MAE (here) |
| :-----: | :-----------: | :---------: | :----------: | :--------: |
| MSRA-B  | 0.911         | 0.048       | 0.9006       | 0.0592     |

Note:

1. trained for only 200 epochs; training longer should get closer to the paper's numbers
2. this reproduction uses the area IoU, while the original paper uses the boundary IoU
3. ~~the comparison is not entirely fair (different training data; I could not find the dataset used in the original paper)~~

## Usage

### 1. Clone the repository

```shell
git clone git@github.com:AceCoooool/NLDF-pytorch.git
cd NLDF-pytorch/
```

### 2. Download the dataset

Note: the original paper uses other datasets.

Download the [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html) dataset:

```shell
bash download.sh
```

### 3. Get the pre-trained VGG

```bash
cd tools/
python extract_vgg.py
cd ..
```

### 4. Demo

```shell
python demo.py --demo_img='your_picture' --trained_model='your_model.pth' --cuda=True
```

Note:

1. by default, download the [pretrained model](https://drive.google.com/file/d/10cnWpqABT6MRdTO0p17hcHornMs6ggQL/view?usp=sharing) and copy it to the `weights` directory
2. a demo picture is provided at `png/demo.jpg`

### 5. Train

```shell
python main.py --mode='train' --train_path='your_data' --label_path='your_label' --batch_size=8 --visdom=True --area=True
```

Note:

1. `--area=True --boundary=True` enables both the area and boundary Dice IoU losses (default: `--area=True --boundary=False`)
2. `--val=True` adds validation (you also need to set `--val_path` and `--val_label`)
3. `your_data, your_label` are the roots of your training images and labels (see step 2)

### 6. Test

```shell
python main.py --mode='test' --test_path='your_data' --test_label='your_label' --batch_size=1 --model='your_trained_model'
```

Note:

1. the evaluation is the same as in the original paper (re-implemented from the official code)

## Bug

1. The boundary Dice IoU may produce `inf`; it is safer to use the area Dice IoU.

Adding Batch Normalization may also help.
--------------------------------------------------------------------------------
/README.zh.md:
--------------------------------------------------------------------------------
# NLDF
[English](./README.md)

An unofficial PyTorch implementation of [Non-Local Deep Features for Salient Object Detection](https://sites.google.com/view/zhimingluo/nldf).

![](./png/example.png)

The official TensorFlow version: [NLDF](https://github.com/zhimingluo/NLDF)

Changes in this implementation:

1. ~~dataset (I could not find the MSRA-B images)~~
2. a small architectural difference: this version outputs a single probability map, while the official version outputs two complementary maps
3. an "area IoU" loss is added; the original paper uses a "boundary IoU" (both can be enabled at the same time)

## Prerequisites

- [Python 3](https://www.continuum.io/downloads)
- [PyTorch 1.0](http://pytorch.org/)
- [torchvision](http://pytorch.org/)
- [visdom](https://github.com/facebookresearch/visdom) (optional, for visualization)

## Results

Training loss curve:

![](./png/loss.png)

Performance:

| Dataset | max F (paper) | MAE (paper) | max F (here) | MAE (here) |
| :-----: | :-----------: | :---------: | :----------: | :--------: |
| ECSSD   | 0.905         | 0.063       | 0.9830       | 0.0375     |

Note:

1. this reproduction uses the area IoU, while the original paper uses the boundary IoU
2. the comparison is "unfair": the datasets differ, and the metrics here are measured directly on the training set (this only shows that the performance can reach or exceed the original paper); I could not find the dataset used in the original paper

## Usage

### 1. Clone the repository

```shell
git clone git@github.com:AceCoooool/NLDF-pytorch.git
cd NLDF-pytorch/
```

### 2. Download the dataset

Note: the original paper uses more datasets.

The dataset can be downloaded from: [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html)

```shell
bash download.sh
```

### 3. Extract the pre-trained VGG

```bash
cd tools/
python extract_vgg.py
cd ..
```

Note: this simply reuses the VGG weights pre-trained in torchvision.

### 4. Demo

```shell
python demo.py --demo_img='your_picture' --trained_model='your_model.pth' --cuda=True
```

Note:

1. default setup: download the [pretrained model](https://drive.google.com/file/d/10cnWpqABT6MRdTO0p17hcHornMs6ggQL/view?usp=sharing) and copy it to the `weights` directory
2. demo picture: `png/demo.jpg` is used by default

### 5. Train

```shell
python main.py --mode='train' --train_path='your_data' --label_path='your_label' --batch_size=8 --visdom=True --area=True --boundary=False
```

Note:

1. `--val=True`: enables validation during training; you can hold out part of the training set as a validation set and provide its paths
2. `your_data, your_label`: the dataset paths from step 2
3. `--area --boundary`: choose the area IoU, the boundary IoU, or both (implemented in the loss function, `loss.py`; the default of only `area=True` is recommended)

### 6. Test

```shell
python main.py --mode='test' --test_path='your_data' --test_label='your_label' --batch_size=1 --model='your_trained_model'
```

Note:

1. the evaluation metrics are the same as in the original paper (adapted from the original code)

## Bug

1. Using the boundary IoU easily produces `inf`; the learning rate has to be made very small, e.g. `1e-10`.
2. There may be other numerical issues.

If you have any questions, feel free to open an issue.
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
from PIL import Image
import torch
from torch.utils import data
from torchvision import transforms


class ImageData(data.Dataset):
    """Image dataset.

    img_root: image root (directory that contains the images)
    label_root: label root (directory that contains the labels)
    transform: pre-processing for the images
    t_transform: pre-processing for the labels
    filename: MSRA-B uses xxx.txt files to split train-val-test data (only for MSRA-B)
    """

    def __init__(self, img_root, label_root, transform, t_transform, filename=None):
        if filename is None:
            self.image_path = list(map(lambda x: os.path.join(img_root, x), os.listdir(img_root)))
            self.label_path = list(
                map(lambda x: os.path.join(label_root, os.path.splitext(os.path.basename(x))[0] + '.png'),
                    self.image_path))
        else:
            lines = [line.rstrip('\n')[:-3] for line in open(filename)]
            self.image_path = list(map(lambda x: os.path.join(img_root, x + 'jpg'), lines))
            self.label_path = list(map(lambda x: os.path.join(label_root, x + 'png'), lines))

        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        image = Image.open(self.image_path[item])
        label = Image.open(self.label_path[item]).convert('L')
        if self.transform is not None:
            image = self.transform(image)
        if self.t_transform is not None:
            label = self.t_transform(label)
        return image, label

    def __len__(self):
        return len(self.image_path)


# get the dataloader (Note: without data augmentation)
def get_loader(img_root, label_root, img_size, batch_size, filename=None, mode='train', num_thread=4, pin=True):
    mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
    if mode == 'train':
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x - mean)
        ])
        t_transform = transforms.Compose([
            transforms.Resize((img_size // 2, img_size // 2)),  # the network predicts at half resolution
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))  # keep the mask strictly binary after resizing
        ])
        dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
        data_loader = data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True,
                                      num_workers=num_thread, pin_memory=pin)
    else:
        transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x - mean)
        ])
        t_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: torch.round(x))  # keep the mask strictly binary
        ])
        dataset = ImageData(img_root, label_root, transform, t_transform, filename=filename)
        data_loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                                      num_workers=num_thread, pin_memory=pin)
    return data_loader
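
# Minimal usage sketch (not part of the original file). The paths below assume the
# ECSSD layout produced by download.sh (./data/images and ./data/ground_truth_mask).
if __name__ == '__main__':
    loader = get_loader('./data/images', './data/ground_truth_mask', img_size=352, batch_size=8)
    images, labels = next(iter(loader))
    # expected: images [8, 3, 352, 352], labels [8, 1, 176, 176]
    print(images.size(), labels.size())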
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
import os
import torch
import argparse
from PIL import Image
import numpy as np
from torchvision import transforms
from nldf import build_model


def demo(model_path, img_path, cuda):
    transform = transforms.Compose([transforms.Resize((352, 352)), transforms.ToTensor()])
    img = Image.open(img_path).convert('RGB')
    shape = img.size  # (width, height), matching PIL's resize convention
    img = transform(img) - torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
    img = img.unsqueeze(0)
    net = build_model()
    net.load_state_dict(torch.load(model_path))
    net.eval()
    if cuda: img, net = img.cuda(), net.cuda()
    with torch.no_grad():
        prob = net(img)
    prob = (prob.cpu()[0][0].numpy() * 255).astype(np.uint8)
    p_img = Image.fromarray(prob, mode='L').resize(shape)
    p_img.show()


if __name__ == '__main__':
    model_path = './weights/best.pth'
    img_path = './png/demo.jpg'
    parser = argparse.ArgumentParser()

    parser.add_argument('--demo_img', type=str, default=img_path)
    parser.add_argument('--trained_model', type=str, default=model_path)
    parser.add_argument('--cuda', type=bool, default=True)
    config = parser.parse_args()
    ext = ['.jpg', '.png']
    if os.path.splitext(config.demo_img)[-1] not in ext:
        raise IOError('illegal image path')

    demo(config.trained_model, config.demo_img, config.cuda)
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# ECSSD images
URL=http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/images.zip
ZIP_FILE=./data/images.zip
mkdir -p ./data/
wget -N $URL -O $ZIP_FILE
unzip $ZIP_FILE -d ./data/
rm $ZIP_FILE

# ECSSD labels
URL=http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/data/ECSSD/ground_truth_mask.zip
ZIP_FILE=./data/ground_truth_mask.zip
mkdir -p ./data/
wget -N $URL -O $ZIP_FILE
unzip $ZIP_FILE -d ./data/
rm $ZIP_FILE
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import torch.nn.functional as F


class GradLayer(nn.Module):
    """Fixed Sobel layer: returns the squared gradient magnitude of a 1-channel map."""

    def __init__(self):
        super(GradLayer, self).__init__()
        # declared with 1 in/out channel to match the 1x1x3x3 Sobel kernels set below
        self.grad_x = nn.Conv2d(1, 1, 3, padding=1, bias=False)
        self.grad_y = nn.Conv2d(1, 1, 3, padding=1, bias=False)
        self.set_weight()

    def set_weight(self):
        x = torch.Tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
        y = torch.Tensor([[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]]).view(1, 1, 3, 3)
        self.grad_x.weight = nn.Parameter(x, requires_grad=False)
        self.grad_y.weight = nn.Parameter(y, requires_grad=False)

    def forward(self, x):
        x1, x2 = self.grad_x(x), self.grad_y(x)
        # return torch.sqrt(torch.pow(x1, 2) + torch.pow(x2, 2))
        return torch.pow(x1, 2) + torch.pow(x2, 2)


class Loss(nn.Module):
    def __init__(self, area=True, boundary=False, contour_th=1.5, ratio=1):
        super(Loss, self).__init__()
        self.area, self.boundary, self.cth, self.ratio = area, boundary, contour_th, ratio
        if boundary:
            self.gradlayer = GradLayer()

    def forward(self, x, label):
        loss = F.binary_cross_entropy(x, label)
        if self.area:
            area_loss = 1 - 2 * ((x * label).sum() + 1) / (x.sum() + label.sum() + 1)
            loss += area_loss
        if self.boundary:
            prob_grad = torch.tanh(self.gradlayer(x))
            label_grad = torch.gt(self.gradlayer(label), self.cth).float()
            inter = torch.sum(prob_grad * label_grad)
            union = torch.pow(prob_grad, 2).sum() + torch.pow(label_grad, 2).sum()
            boundary_loss = 1 - 2 * (inter + 1) / (union + 1)
            loss = loss + self.ratio * boundary_loss
        return loss
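
# Minimal sanity check on random tensors (a sketch, not part of the original file):
if __name__ == '__main__':
    criterion = Loss(area=True, boundary=True)
    pred = torch.rand(2, 1, 176, 176).clamp(1e-4, 1 - 1e-4)  # keep probabilities strictly inside (0, 1) for BCE
    target = torch.round(torch.rand(2, 1, 176, 176))  # random binary mask
    print(criterion(pred, target).item())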
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
from dataset import get_loader
from solver import Solver


def main(config):
    if config.mode == 'train':
        train_loader = get_loader(config.train_path, config.label_path, config.img_size, config.batch_size,
                                  filename=config.train_file, num_thread=config.num_thread)
        if config.val:
            val_loader = get_loader(config.val_path, config.val_label, config.img_size, config.batch_size,
                                    filename=config.val_file, num_thread=config.num_thread)
        run = 0
        while os.path.exists("%s/run-%d" % (config.save_fold, run)): run += 1
        os.mkdir("%s/run-%d" % (config.save_fold, run))
        os.mkdir("%s/run-%d/logs" % (config.save_fold, run))
        # os.mkdir("%s/run-%d/images" % (config.save_fold, run))
        os.mkdir("%s/run-%d/models" % (config.save_fold, run))
        config.save_fold = "%s/run-%d" % (config.save_fold, run)
        if config.val:
            train = Solver(train_loader, val_loader, None, config)
        else:
            train = Solver(train_loader, None, None, config)
        train.train()
    elif config.mode == 'test':
        test_loader = get_loader(config.test_path, config.test_label, config.img_size, config.batch_size, mode='test',
                                 filename=config.test_file, num_thread=config.num_thread)
        if not os.path.exists(config.test_fold): os.mkdir(config.test_fold)
        test = Solver(None, None, test_loader, config)
        test.test(100)
    else:
        raise ValueError("unknown mode: %s" % config.mode)


if __name__ == '__main__':
    data_root = os.path.join(os.path.expanduser('~'), 'data')
    vgg_path = './weights/vgg16_feat.pth'

    # -----ECSSD dataset-----
    # train_path = os.path.join(data_root, 'ECSSD/images')
    # label_path = os.path.join(data_root, 'ECSSD/ground_truth_mask')
    # val_path = os.path.join(data_root, 'ECSSD/val_images')
    # val_label = os.path.join(data_root, 'ECSSD/val_ground_truth_mask')
    # test_path = os.path.join(data_root, 'ECSSD/test_images')
    # test_label = os.path.join(data_root, 'ECSSD/test_ground_truth_mask')

    # -----MSRA-B dataset-----
    image_path = os.path.join(data_root, 'MSRA-B/image')
    label_path = os.path.join(data_root, 'MSRA-B/annotation')
    train_file = os.path.join(data_root, 'MSRA-B/train_cvpr2013.txt')
    valid_file = os.path.join(data_root, 'MSRA-B/valid_cvpr2013.txt')
    test_file = os.path.join(data_root, 'MSRA-B/test_cvpr2013.txt')

    parser = argparse.ArgumentParser()

    # Hyper-parameters
    parser.add_argument('--n_color', type=int, default=3)
    parser.add_argument('--img_size', type=int, default=352)
    parser.add_argument('--lr', type=float, default=1e-6)
    parser.add_argument('--clip_gradient', type=float, default=1.0)
    parser.add_argument('--cuda', type=bool, default=True)
    parser.add_argument('--contour_th', type=float, default=1.5)

    # Training settings
    parser.add_argument('--vgg', type=str, default=vgg_path)
    parser.add_argument('--train_path', type=str, default=image_path)
    parser.add_argument('--label_path', type=str, default=label_path)
    parser.add_argument('--train_file', type=str, default=train_file)
    parser.add_argument('--epoch', type=int, default=500)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--val', type=bool, default=True)
    parser.add_argument('--val_path', type=str, default=image_path)
    parser.add_argument('--val_label', type=str, default=label_path)
    parser.add_argument('--val_file', type=str, default=valid_file)
    parser.add_argument('--num_thread', type=int, default=4)
    parser.add_argument('--load', type=str, default='')
    parser.add_argument('--save_fold', type=str, default='./results')
    parser.add_argument('--epoch_val', type=int, default=5)
    parser.add_argument('--epoch_save', type=int, default=20)
    parser.add_argument('--epoch_show', type=int, default=1)
    parser.add_argument('--pre_trained', type=str, default=None)
    parser.add_argument('--area', type=bool, default=True)
    parser.add_argument('--boundary', type=bool, default=False)

    # Testing settings
    parser.add_argument('--test_path', type=str, default=image_path)
    parser.add_argument('--test_label', type=str, default=label_path)
    parser.add_argument('--model', type=str, default='./weights/best.pth')
    parser.add_argument('--test_fold', type=str, default='./results/test')
    parser.add_argument('--test_file', type=str, default=test_file)

    # Misc
    parser.add_argument('--mode', type=str, default='test', choices=['train', 'test'])
    parser.add_argument('--visdom', type=bool, default=False)

    config = parser.parse_args()
    if not os.path.exists(config.save_fold): os.mkdir(config.save_fold)
    main(config)
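
# Note (a sketch, not part of the original file): argparse's type=bool converts any
# non-empty string to True, so e.g. --cuda=False or --visdom=False is still parsed
# as True. A common workaround is a converter like the one below, which could be
# passed as type=str2bool in the add_argument calls above.
def str2bool(v):
    return str(v).lower() in ('true', '1', 'yes')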
--------------------------------------------------------------------------------
/nldf.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F

base = {'352': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']}
extra = {'352': [2, 7, 14, 21, 28]}


# vgg16
def vgg(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return layers


class ConvContrast(nn.Module):
    """Local feature plus its contrast feature (the feature minus its local average)."""

    def __init__(self, in_channel):
        super(ConvContrast, self).__init__()
        self.conv1 = nn.Conv2d(in_channel, 128, kernel_size=3, padding=1)
        self.cons1 = nn.AvgPool2d(3, stride=1, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x), inplace=True)
        x2 = self.cons1(x)
        return x, x - x2


# extra part
def extra_layer(vgg, cfg):
    feat_layers, pool_layers = [], []
    for k, v in enumerate(cfg):
        feat_layers += [ConvContrast(vgg[v].out_channels)]
        if k == 0:
            pool_layers += [nn.Conv2d(128 * (6 - k), 128 * (5 - k), 1)]
        else:
            # TODO: change this to sampling (e.g. interpolation instead of a transposed conv)
            pool_layers += [nn.ConvTranspose2d(128 * (6 - k), 128 * (5 - k), 3, 2, 1, 1)]
    return vgg, feat_layers, pool_layers


class NLDF(nn.Module):
    def __init__(self, base, feat_layers, pool_layers):
        super(NLDF, self).__init__()
        self.pos = [4, 9, 16, 23, 30]  # indices of the max-pool layers in the VGG trunk
        self.base = nn.ModuleList(base)
        self.feat = nn.ModuleList(feat_layers)
        self.pool = nn.ModuleList(pool_layers)
        self.glob = nn.Sequential(nn.Conv2d(512, 128, 5), nn.ReLU(inplace=True), nn.Conv2d(128, 128, 5),
                                  nn.ReLU(inplace=True), nn.Conv2d(128, 128, 3))
        self.conv_g = nn.Conv2d(128, 1, 1)
        self.conv_l = nn.Conv2d(640, 1, 1)

    def forward(self, x, label=None):
        sources, num = list(), 0
        for k in range(len(self.base)):
            x = self.base[k](x)
            if k in self.pos:
                sources.append(self.feat[num](x))
                num = num + 1
        for k in range(4, -1, -1):
            if k == 4:
                out = F.relu(self.pool[k](torch.cat([sources[k][0], sources[k][1]], dim=1)), inplace=True)
            else:
                out = self.pool[k](torch.cat([sources[k][0], sources[k][1], out], dim=1)) if k == 0 else F.relu(
                    self.pool[k](torch.cat([sources[k][0], sources[k][1], out], dim=1)), inplace=True)

        score = self.conv_g(self.glob(x)) + self.conv_l(out)  # the 1x1 global score broadcasts over the local map
        prob = torch.sigmoid(score)
        return prob


def build_model():
    return NLDF(*extra_layer(vgg(base['352'], 3), extra['352']))


def xavier(param):
    init.xavier_uniform_(param)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        xavier(m.weight.data)
        m.bias.data.zero_()
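
# Minimal forward sketch on a random input (not part of the original file):
if __name__ == '__main__':
    net = build_model()
    net.apply(weights_init)
    x = torch.rand(1, 3, 352, 352)
    # the saliency map is predicted at half the input resolution: [1, 1, 176, 176]
    print(net(x).size())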
--------------------------------------------------------------------------------
/png/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/demo.jpg
--------------------------------------------------------------------------------
/png/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/example.png
--------------------------------------------------------------------------------
/png/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/NLDF-pytorch/7c4a5dddab277c6136378e1592ef7c287b60d314/png/loss.png
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
import torch
from collections import OrderedDict
from torch.nn import utils, functional as F
from torch.optim import Adam
from torch.backends import cudnn
from nldf import build_model, weights_init
from loss import Loss
from tools.visual import Viz_visdom


class Solver(object):
    def __init__(self, train_loader, val_loader, test_loader, config):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255
        self.beta = 0.3
        self.device = torch.device('cpu')
        if self.config.cuda:
            cudnn.benchmark = True
            self.device = torch.device('cuda')
        if config.visdom:
            self.visual = Viz_visdom("NLDF", 1)
        self.build_model()
        if self.config.pre_trained: self.net.load_state_dict(torch.load(self.config.pre_trained))
        if config.mode == 'train':
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            self.net.load_state_dict(torch.load(self.config.model))
            self.net.eval()
            self.test_output = open("%s/test.txt" % config.test_fold, 'w')

    def print_network(self, model, name):
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def build_model(self):
        self.net = build_model()
        if self.config.mode == 'train': self.loss = Loss(self.config.area, self.config.boundary)
        self.net = self.net.to(self.device)
        if self.config.cuda and self.config.mode == 'train': self.loss = self.loss.cuda()
        self.net.train()
        self.net.apply(weights_init)
        if self.config.load == '': self.net.base.load_state_dict(torch.load(self.config.vgg))
        if self.config.load != '': self.net.load_state_dict(torch.load(self.config.load))
        self.optimizer = Adam(self.net.parameters(), self.config.lr)
        self.print_network(self.net, 'NLDF')

    def update_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def clip(self, y):
        return torch.clamp(y, 0.0, 1.0)

    def eval_mae(self, y_pred, y):
        return torch.abs(y_pred - y).mean()

    # TODO: write a more efficient version (a vectorized sketch is appended at the end of this file)
    def eval_pr(self, y_pred, y, num):
        prec, recall = torch.zeros(num), torch.zeros(num)
        thlist = torch.linspace(0, 1 - 1e-10, num)
        for i in range(num):
            y_temp = (y_pred >= thlist[i]).float()
            tp = (y_temp * y).sum()
            prec[i], recall[i] = tp / (y_temp.sum() + 1e-20), tp / y.sum()
        return prec, recall

    def validation(self):
        avg_mae = 0.0
        self.net.eval()
        for i, data_batch in enumerate(self.val_loader):
            with torch.no_grad():
                images, labels = data_batch
                images, labels = images.to(self.device), labels.to(self.device)
                prob_pred = self.net(images)
                avg_mae += self.eval_mae(prob_pred, labels).cpu().item()
        self.net.train()
        return avg_mae / len(self.val_loader)

    def test(self, num):
        avg_mae, img_num = 0.0, len(self.test_loader)
        avg_prec, avg_recall = torch.zeros(num), torch.zeros(num)
        for i, data_batch in enumerate(self.test_loader):
            with torch.no_grad():
                images, labels = data_batch
                shape = labels.size()[2:]
                images = images.to(self.device)
                prob_pred = F.interpolate(self.net(images), size=shape, mode='bilinear', align_corners=True).cpu()
            mae = self.eval_mae(prob_pred, labels)
            prec, recall = self.eval_pr(prob_pred, labels, num)
            print("[%d] mae: %.4f" % (i, mae))
            print("[%d] mae: %.4f" % (i, mae), file=self.test_output)
            avg_mae += mae
            avg_prec, avg_recall = avg_prec + prec, avg_recall + recall
        avg_mae, avg_prec, avg_recall = avg_mae / img_num, avg_prec / img_num, avg_recall / img_num
        score = (1 + self.beta ** 2) * avg_prec * avg_recall / (self.beta ** 2 * avg_prec + avg_recall)
        score[score != score] = 0  # replace NaN (where precision + recall is zero) with 0
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()))
        print('average mae: %.4f, max fmeasure: %.4f' % (avg_mae, score.max()), file=self.test_output)

    def train(self):
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        best_mae = 1.0 if self.config.val else None
        for epoch in range(self.config.epoch):
            loss_epoch = 0
            for i, data_batch in enumerate(self.train_loader):
                if (i + 1) > iter_num: break
                self.net.zero_grad()
                x, y = data_batch
                x, y = x.to(self.device), y.to(self.device)
                y_pred = self.net(x)
                loss = self.loss(y_pred, y)
                loss.backward()
                utils.clip_grad_norm_(self.net.parameters(), self.config.clip_gradient)
                self.optimizer.step()
                loss_epoch += loss.cpu().item()
                print('epoch: [%d/%d], iter: [%d/%d], loss: [%.4f]' % (
                    epoch, self.config.epoch, i, iter_num, loss.cpu().item()))
                if self.config.visdom:
                    error = OrderedDict([('loss:', loss.cpu().item())])
                    self.visual.plot_current_errors(epoch, i / iter_num, error)
            if (epoch + 1) % self.config.epoch_show == 0:
                print('epoch: [%d/%d], epoch_loss: [%.4f]' % (epoch, self.config.epoch, loss_epoch / iter_num),
                      file=self.log_output)
                if self.config.visdom:
                    avg_err = OrderedDict([('avg_loss', loss_epoch / iter_num)])
                    self.visual.plot_current_errors(epoch, i / iter_num, avg_err, 1)
                    img = OrderedDict([('origin', self.mean + x.cpu()[0]), ('label', y.cpu()[0][0]),
                                       ('pred_label', y_pred.detach().cpu()[0][0])])
                    self.visual.plot_current_img(img)
            if self.config.val and (epoch + 1) % self.config.epoch_val == 0:
                mae = self.validation()
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae))
                print('--- Best MAE: %.4f, Curr MAE: %.4f ---' % (best_mae, mae), file=self.log_output)
                if best_mae > mae:
                    best_mae = mae
                    torch.save(self.net.state_dict(), '%s/models/best.pth' % self.config.save_fold)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net.state_dict(), '%s/models/epoch_%d.pth' % (self.config.save_fold, epoch + 1))
        torch.save(self.net.state_dict(), '%s/models/final.pth' % self.config.save_fold)
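
# A vectorized alternative to Solver.eval_pr (a sketch, not part of the original file,
# addressing the TODO above): all thresholds are evaluated with a single broadcast
# comparison instead of a Python loop, at the cost of a [num, N] temporary tensor.
def eval_pr_vectorized(y_pred, y, num):
    thlist = torch.linspace(0, 1 - 1e-10, num)
    y_temp = (y_pred.reshape(1, -1) >= thlist.reshape(num, 1)).float()  # [num, N]
    tp = (y_temp * y.reshape(1, -1)).sum(dim=1)
    prec = tp / (y_temp.sum(dim=1) + 1e-20)
    recall = tp / y.sum()
    return prec, recall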
--------------------------------------------------------------------------------
/tools/extract_vgg.py:
--------------------------------------------------------------------------------
import os
import torch
from torchvision import models


if __name__ == '__main__':
    save_fold = '../weights'
    if not os.path.exists(save_fold):
        os.mkdir(save_fold)
    vgg = models.vgg16(pretrained=True)
    torch.save(vgg.features.state_dict(), os.path.join(save_fold, 'vgg16_feat.pth'))
--------------------------------------------------------------------------------
/tools/visual.py:
--------------------------------------------------------------------------------
import numpy as np


class Viz_visdom():
    def __init__(self, name, display_id=0):
        self.name = name
        self.display_id = display_id
        self.idx = display_id
        self.plot_data = {}
        if display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port=8097)

    def plot_current_errors(self, epoch, counter_ratio, errors, idx=0):
        if idx not in self.plot_data:
            self.plot_data[idx] = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data[idx]['X'].append(epoch + counter_ratio)
        self.plot_data[idx]['Y'].append([errors[k] for k in self.plot_data[idx]['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data[idx]['X'])] * len(self.plot_data[idx]['legend']), 1)
            if len(errors) > 1 else np.array(self.plot_data[idx]['X']),
            Y=np.array(self.plot_data[idx]['Y']) if len(errors) > 1 else np.array(self.plot_data[idx]['Y'])[:, 0],
            opts={
                'title': self.name + ' loss over time %d' % idx,
                'legend': self.plot_data[idx]['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id + idx)
        if self.idx < self.display_id + idx:
            self.idx = self.display_id + idx

    def plot_current_img(self, visuals, c_prev=True):
        idx = self.idx + 1
        for label, image_numpy in visuals.items():
            if c_prev:
                self.vis.image(image_numpy, opts=dict(title=label), win=self.display_id + idx)
            else:
                image_numpy = image_numpy.swapaxes(0, 2).swapaxes(1, 2)
                self.vis.image(image_numpy, opts=dict(title=label), win=self.display_id + idx)
            idx += 1
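
# Minimal usage sketch (not part of the original file). It assumes a visdom server is
# already running on port 8097, e.g. started with `python -m visdom.server`:
#
#   from collections import OrderedDict
#   viz = Viz_visdom("NLDF", display_id=1)
#   viz.plot_current_errors(epoch=0, counter_ratio=0.5, errors=OrderedDict([('loss', 0.7)]))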