├── model ├── __init__.py ├── __pycache__ │ ├── BASNet.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── unet_model.cpython-36.pyc │ ├── unet_parts.cpython-36.pyc │ └── resnet_model.cpython-36.pyc ├── resnet_model.py └── BASNet.py ├── figures ├── qual.png ├── quan.png ├── architecture.png ├── cod_qual_comp.PNG ├── soc_qual_comp.PNG └── sod_qual_comp.PNG ├── test_data ├── test_images │ ├── 0003.jpg │ ├── 0005.jpg │ ├── 0010.jpg │ ├── 0012.jpg │ └── BKN06Z000006_W_big.jpg └── test_results │ ├── 0003.png │ ├── 0005.png │ ├── 0010.png │ └── 0012.png ├── __pycache__ └── data_loader.cpython-36.pyc ├── pytorch_iou ├── __pycache__ │ └── __init__.cpython-36.pyc └── __init__.py ├── pytorch_ssim ├── __pycache__ │ └── __init__.cpython-36.pyc └── __init__.py ├── LICENSE ├── README.md~ ├── basnet_test.py ├── README.md ├── basnet_train.py └── data_loader.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .BASNet import BASNet 2 | -------------------------------------------------------------------------------- /figures/qual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/qual.png -------------------------------------------------------------------------------- /figures/quan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/quan.png -------------------------------------------------------------------------------- /figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/architecture.png -------------------------------------------------------------------------------- /figures/cod_qual_comp.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/cod_qual_comp.PNG -------------------------------------------------------------------------------- /figures/soc_qual_comp.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/soc_qual_comp.PNG -------------------------------------------------------------------------------- /figures/sod_qual_comp.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/figures/sod_qual_comp.PNG -------------------------------------------------------------------------------- /test_data/test_images/0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_images/0003.jpg -------------------------------------------------------------------------------- /test_data/test_images/0005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_images/0005.jpg -------------------------------------------------------------------------------- /test_data/test_images/0010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_images/0010.jpg -------------------------------------------------------------------------------- /test_data/test_images/0012.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_images/0012.jpg -------------------------------------------------------------------------------- /test_data/test_results/0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_results/0003.png -------------------------------------------------------------------------------- /test_data/test_results/0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_results/0005.png -------------------------------------------------------------------------------- /test_data/test_results/0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_results/0010.png -------------------------------------------------------------------------------- /test_data/test_results/0012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_results/0012.png -------------------------------------------------------------------------------- /__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/BASNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/model/__pycache__/BASNet.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/unet_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/model/__pycache__/unet_model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/unet_parts.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/model/__pycache__/unet_parts.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/resnet_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/model/__pycache__/resnet_model.cpython-36.pyc -------------------------------------------------------------------------------- /test_data/test_images/BKN06Z000006_W_big.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/test_data/test_images/BKN06Z000006_W_big.jpg -------------------------------------------------------------------------------- 
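The `pytorch_iou` package whose source appears below implements a soft IoU loss: for each image it computes 1 - |P∩T| / (|P| + |T| - |P∩T|) on probability masks and averages over the batch. A tiny numeric sketch of what it returns (toy tensors for illustration only, assuming the repo root is the working directory):

```python
import torch
import pytorch_iou

iou_loss = pytorch_iou.IOU(size_average=True)

# batch of 1, one channel, 2x2 "masks"
pred   = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
target = torch.tensor([[[[1.0, 1.0], [0.0, 0.0]]]])

# intersection = 1, union = 1 + 2 - 1 = 2, so IoU = 0.5 and the loss is 1 - 0.5
print(iou_loss(pred, target))  # tensor(0.5000)
```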
/pytorch_iou/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/pytorch_iou/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_ssim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuebinqin/BASNet/HEAD/pytorch_ssim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | 8 | b = pred.shape[0] 9 | IoU = 0.0 10 | for i in range(0,b): 11 | #compute the IoU of the foreground 12 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 13 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 14 | IoU1 = Iand1/Ior1 15 | 16 | #IoU loss is (1-IoU1) 17 | IoU = IoU + (1-IoU1) 18 | 19 | return IoU/b 20 | 21 | class IOU(torch.nn.Module): 22 | def __init__(self, size_average = True): 23 | super(IOU, self).__init__() 24 | self.size_average = size_average 25 | 26 | def forward(self, pred, target): 27 | 28 | return _iou(pred, target, self.size_average) 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Xuebin Qin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md~: -------------------------------------------------------------------------------- 1 | # BASNet 2 | Code for CVPR 2019 paper '[*BASNet: Boundary-Aware Salient Object Detection*](https://webdocs.cs.ualberta.ca/~xuebin/BASNet.pdf)', [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), Zichen Zhang, Chenyang Huang, Chao Gao, Masood Dehghan and Martin Jagersand. 
[(supplementary)](https://webdocs.cs.ualberta.ca/~xuebin/BASNet-supp.pdf) 3 | 4 | __Contact__: xuebin[at]ualberta[dot]ca 5 | 6 | ## Required libraries 7 | 8 | Python 3.6 9 | numpy 1.15.2 10 | scikit-image 0.14.0 11 | PIL 5.2.0 12 | PyTorch 0.4.0 13 | torchvision 0.2.1 14 | glob 15 | 16 | The SSIM loss is adapted from [pytorch-ssim](https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py). 17 | 18 | ## Usage 19 | 1. Clone this repo 20 | ``` 21 | git clone https://github.com/NathanUA/BASNet.git 22 | ``` 23 | 2. Download the pre-trained model [basnet.pth](https://drive.google.com/file/d/1qeKYOTLIOeSJGqIhFJOEch48tPyzrsZx/view?usp=sharing) and put it into the directory 'saved_models/basnet_bsi/' 24 | 25 | 3. Cd to the directory 'BASNet' and run training or inference with ```python basnet_train.py``` 26 | or ```python basnet_test.py``` respectively. 27 | 28 | We also provide the predicted [saliency maps](https://drive.google.com/file/d/1K9y9HpupXT0RJ4U4OizJ_Uk5byUyCupK/view?usp=sharing) for the datasets SOD, ECSSD, DUT-OMRON, PASCAL-S, HKU-IS and DUTS-TE. 29 | 30 | ## Architecture 31 | 32 | ![BASNet architecture](figures/architecture.png) 33 | 34 | 35 | ## Quantitative Comparison 36 | 37 | ![Quantitative Comparison](figures/quan.png) 38 | 39 | ## Qualitative Comparison 40 | 41 | ![Qualitative Comparison](figures/qual.png) 42 | 43 | ## Citation 44 | ``` 45 | @InProceedings{Qin2019BASNet, 46 | author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin}, 47 | title = {BASNet: Boundary Aware Salient Object Detection}, 48 | booktitle = {IEEE CVPR}, 49 | year = {2019} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /basnet_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from skimage import io, transform 3 | import torch 4 | import torchvision 5 | from torch.autograd import Variable 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from torchvision import transforms#, utils 10 | # import torch.optim as optim 11 | 12 | import numpy as np 13 | from PIL import Image 14 | import glob 15 | 16 | from data_loader import RescaleT 17 | from data_loader import CenterCrop 18 | from data_loader import ToTensor 19 | from data_loader import ToTensorLab 20 | from data_loader import SalObjDataset 21 | 22 | from model import BASNet 23 | 24 | def normPRED(d): 25 | ma = torch.max(d) 26 | mi = torch.min(d) 27 | 28 | dn = (d-mi)/(ma-mi) 29 | 30 | return dn 31 | 32 | def save_output(image_name,pred,d_dir): 33 | 34 | predict = pred 35 | predict = predict.squeeze() 36 | predict_np = predict.cpu().data.numpy() 37 | 38 | im = Image.fromarray(predict_np*255).convert('RGB') 39 | img_name = image_name.split("/")[-1] 40 | image = io.imread(image_name) 41 | imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR) 42 | 43 | pb_np = np.array(imo) 44 | 45 | aaa = img_name.split(".") 46 | bbb = aaa[0:-1] 47 | imidx = bbb[0] 48 | for i in range(1,len(bbb)): 49 | imidx = imidx + "." + bbb[i] 50 | 51 | imo.save(d_dir+imidx+'.png') 52 | 53 | 54 | if __name__ == '__main__': 55 | # --------- 1.
get image path and name --------- 56 | 57 | image_dir = './test_data/test_images/' 58 | prediction_dir = './test_data/test_results/' 59 | model_dir = './saved_models/basnet_bsi/basnet.pth' 60 | 61 | img_name_list = glob.glob(image_dir + '*.jpg') 62 | 63 | # --------- 2. dataloader --------- 64 | #1. dataload 65 | test_salobj_dataset = SalObjDataset(img_name_list = img_name_list, lbl_name_list = [],transform=transforms.Compose([RescaleT(256),ToTensorLab(flag=0)])) 66 | test_salobj_dataloader = DataLoader(test_salobj_dataset, batch_size=1,shuffle=False,num_workers=1) 67 | 68 | # --------- 3. model define --------- 69 | print("...load BASNet...") 70 | net = BASNet(3,1) 71 | net.load_state_dict(torch.load(model_dir)) 72 | if torch.cuda.is_available(): 73 | net.cuda() 74 | net.eval() 75 | 76 | # --------- 4. inference for each image --------- 77 | for i_test, data_test in enumerate(test_salobj_dataloader): 78 | 79 | print("inferencing:",img_name_list[i_test].split("/")[-1]) 80 | 81 | inputs_test = data_test['image'] 82 | inputs_test = inputs_test.type(torch.FloatTensor) 83 | 84 | if torch.cuda.is_available(): 85 | inputs_test = Variable(inputs_test.cuda()) 86 | else: 87 | inputs_test = Variable(inputs_test) 88 | 89 | d1,d2,d3,d4,d5,d6,d7,d8 = net(inputs_test) 90 | 91 | # normalization 92 | pred = d1[:,0,:,:] 93 | pred = normPRED(pred) 94 | 95 | # save results to test_results folder 96 | save_output(img_name_list[i_test],pred,prediction_dir) 97 | 98 | del d1,d2,d3,d4,d5,d6,d7,d8 99 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BASNet (New Version May 2nd, 2021) 2 | 3 | '[Boundary-Aware Segmentation Network for 4 | Mobile and Web Applications](https://arxiv.org/pdf/2101.04704.pdf)', Xuebin Qin, Deng-Ping Fan, Chenyang Huang, Cyril Diagne, Zichen Zhang, 5 | Adria Cabeza Sant’Anna, Albert Suarez, Martin Jagersand, and Ling Shao. 6 | 7 | ## Salient Object Detection (SOD) Qualitative Comparison 8 | ![SOD Qualitative Comparison](figures/sod_qual_comp.PNG) 9 | 10 | ## Salient Objects in Clutter (SOC) Qualitative Comparison 11 | ![SOC Qualitative Comparison](figures/soc_qual_comp.PNG) 12 | 13 | ## Camouflaged Object Detection (COD) Qualitative Comparison 14 | ![COD Qualitative Comparison](figures/cod_qual_comp.PNG) 15 | 16 | ## Predicted maps of the SOD, SOC and COD datasets 17 | 18 | [SOD Results coming soon!]() \ 19 | [SOC Results coming soon!]() \ 20 | [COD Results](https://drive.google.com/file/d/12jijUPpdOe7k2O1YcLbkJHyXCJb3MRMN/view?usp=sharing) 21 | 22 | 23 | 24 | # BASNet (CVPR 2019) 25 | Code for the CVPR 2019 paper '[*BASNet: Boundary-Aware Salient Object Detection*](http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html)' ([code](https://github.com/NathanUA/BASNet)), [Xuebin Qin](https://webdocs.cs.ualberta.ca/~xuebin/), [Zichen Zhang](https://webdocs.cs.ualberta.ca/~zichen2/), [Chenyang Huang](https://chenyangh.com/), [Chao Gao](https://cgao3.github.io/), [Masood Dehghan](https://sites.google.com/view/masooddehghan) and [Martin Jagersand](https://webdocs.cs.ualberta.ca/~jag/). 26 | 27 | __Contact__: xuebin[at]ualberta[dot]ca 28 | 29 | ## (2020-May-09) NEWS! Our new Salient Object Detection model (U^2-Net), which was just accepted by Pattern Recognition, is available now!
30 | [U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection](https://github.com/NathanUA/U-2-Net) 31 | 32 | ## Evaluation 33 | [Evaluation Code](https://github.com/NathanUA/Binary-Segmentation-Evaluation-Tool) 34 | 35 | ## Required libraries 36 | 37 | Python 3.6 38 | numpy 1.15.2 39 | scikit-image 0.14.0 40 | PIL 5.2.0 41 | PyTorch 0.4.0 42 | torchvision 0.2.1 43 | glob 44 | 45 | The SSIM loss is adapted from [pytorch-ssim](https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py). 46 | 47 | ## Usage 48 | 1. Clone this repo 49 | ``` 50 | git clone https://github.com/NathanUA/BASNet.git 51 | ``` 52 | 2. Download the pre-trained model basnet.pth from [GoogleDrive](https://drive.google.com/open?id=1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu) or [Baidu](https://pan.baidu.com/s/1PrsBdepwrkMWPLSW22FhAg) (extraction code: 6phq), and put it into the directory 'saved_models/basnet_bsi/' 53 | 54 | 3. Cd to the directory 'BASNet' and run training or inference with ```python basnet_train.py``` 55 | or ```python basnet_test.py``` respectively. 56 | 57 | We also provide the predicted saliency maps ([GoogleDrive](https://drive.google.com/file/d/1K9y9HpupXT0RJ4U4OizJ_Uk5byUyCupK/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1FJKVO_9YrP7Iaz7WT6Xdhg)) for the datasets SOD, ECSSD, DUT-OMRON, PASCAL-S, HKU-IS and DUTS-TE. 58 | 59 | ## Architecture 60 | 61 | ![BASNet architecture](figures/architecture.png) 62 | 63 | 64 | ## Quantitative Comparison 65 | 66 | ![Quantitative Comparison](figures/quan.png) 67 | 68 | ## Qualitative Comparison 69 | 70 | ![Qualitative Comparison](figures/qual.png) 71 | 72 | 73 | ## Citation 74 | ``` 75 | @article{DBLP:journals/corr/abs-2101-04704, 76 | author = {Xuebin Qin and 77 | Deng{-}Ping Fan and 78 | Chenyang Huang and 79 | Cyril Diagne and 80 | Zichen Zhang and 81 | Adri{\`{a}} Cabeza Sant'Anna and 82 | Albert Su{\`{a}}rez and 83 | Martin J{\"{a}}gersand and 84 | Ling Shao}, 85 | title = {Boundary-Aware Segmentation Network for Mobile and Web Applications}, 86 | journal = {CoRR}, 87 | volume = {abs/2101.04704}, 88 | year = {2021}, 89 | url = {https://arxiv.org/abs/2101.04704}, 90 | archivePrefix = {arXiv}, 91 | } 92 | ``` 93 | 94 | ## Citation 95 | ``` 96 | @InProceedings{Qin_2019_CVPR, 97 | author = {Qin, Xuebin and Zhang, Zichen and Huang, Chenyang and Gao, Chao and Dehghan, Masood and Jagersand, Martin}, 98 | title = {BASNet: Boundary-Aware Salient Object Detection}, 99 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 100 | month = {June}, 101 | year = {2019} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /model/resnet_model.py: -------------------------------------------------------------------------------- 1 | ## code from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import torchvision 7 | 8 | # __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | # 'resnet152', 'ResNet34P','ResNet50S','ResNet50P','ResNet101P'] 10 | # 11 | # resnet18_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet18-5c106cde.pth' 12 | # resnet34_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet34-333f7ec4.pth' 13 | # resnet50_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet50-19c8e357.pth' 14 | # resnet101_dir = '/local/sda4/yqian3/RoadNets/resnet_model/resnet101-5d3b4d8f.pth' 15 | # 16 | #
model_urls = { 17 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 20 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 21 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | # } 23 | 24 | def conv3x3(in_planes, out_planes, stride=1): 25 | "3x3 convolution with padding" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=1, bias=False) 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | class BasicBlockDe(nn.Module): 61 | expansion = 1 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(BasicBlockDe, self).__init__() 65 | 66 | self.convRes = conv3x3(inplanes,planes,stride) 67 | self.bnRes = nn.BatchNorm2d(planes) 68 | self.reluRes = nn.ReLU(inplace=True) 69 | 70 | self.conv1 = conv3x3(inplanes, planes, stride) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.conv2 = conv3x3(planes, planes) 74 | self.bn2 = nn.BatchNorm2d(planes) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = self.convRes(x) 80 | residual = self.bnRes(residual) 81 | residual = self.reluRes(residual) 82 | 83 | out = self.conv1(x) 84 | out = self.bn1(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv2(out) 88 | out = self.bn2(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class Bottleneck(nn.Module): 100 | expansion = 4 101 | 102 | def __init__(self, inplanes, planes, stride=1, downsample=None): 103 | super(Bottleneck, self).__init__() 104 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 105 | self.bn1 = nn.BatchNorm2d(planes) 106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 107 | padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 110 | self.bn3 = nn.BatchNorm2d(planes * 4) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.downsample = downsample 113 | self.stride = stride 114 | 115 | def forward(self, x): 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | residual = self.downsample(x) 131 | 132 | out += residual 133 | out = self.relu(out) 134 | 135 | return out 136 | 
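# A quick shape-check sketch for the three blocks above (illustrative usage only,
# not used by BASNet itself; channel counts chosen so the identity shortcut lines up):
if __name__ == '__main__':
    x = torch.randn(1, 64, 56, 56)
    print(BasicBlock(64, 64)(x).shape)    # two 3x3 convs + identity shortcut -> (1, 64, 56, 56)
    print(BasicBlockDe(64, 64)(x).shape)  # decoder variant: shortcut is conv+BN+ReLU -> (1, 64, 56, 56)
    print(Bottleneck(64, 16)(x).shape)    # 1x1 -> 3x3 -> 1x1, expansion 4 restores 64 channels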
-------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1*mu2 73 | 74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 77 | 78 | C1 = 0.01**2 79 | C2 = 0.03**2 80 | 81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 83 | ssim_map = -torch.log(ssim_map + 1e-8) 84 | 85 | if size_average: 86 | return ssim_map.mean() 87 | 
else: 88 | return ssim_map.mean(1).mean(1).mean(1) 89 | 90 | class LOGSSIM(torch.nn.Module): 91 | def __init__(self, window_size = 11, size_average = True): 92 | super(LOGSSIM, self).__init__() 93 | self.window_size = window_size 94 | self.size_average = size_average 95 | self.channel = 1 96 | self.window = create_window(window_size, self.channel) 97 | 98 | def forward(self, img1, img2): 99 | (_, channel, _, _) = img1.size() 100 | 101 | if channel == self.channel and self.window.data.type() == img1.data.type(): 102 | window = self.window 103 | else: 104 | window = create_window(self.window_size, channel) 105 | 106 | if img1.is_cuda: 107 | window = window.cuda(img1.get_device()) 108 | window = window.type_as(img1) 109 | 110 | self.window = window 111 | self.channel = channel 112 | 113 | 114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 115 | 116 | 117 | def ssim(img1, img2, window_size = 11, size_average = True): 118 | (_, channel, _, _) = img1.size() 119 | window = create_window(window_size, channel) 120 | 121 | if img1.is_cuda: 122 | window = window.cuda(img1.get_device()) 123 | window = window.type_as(img1) 124 | 125 | return _ssim(img1, img2, window, window_size, channel, size_average) 126 | -------------------------------------------------------------------------------- /basnet_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms, utils 9 | import torch.optim as optim 10 | import torchvision.transforms as standard_transforms 11 | 12 | import numpy as np 13 | import glob 14 | 15 | from data_loader import Rescale 16 | from data_loader import RescaleT 17 | from data_loader import RandomCrop 18 | from data_loader import CenterCrop 19 | from data_loader import ToTensor 20 | from data_loader import ToTensorLab 21 | from data_loader import SalObjDataset 22 | 23 | from model import BASNet 24 | 25 | import pytorch_ssim 26 | import pytorch_iou 27 | 28 | # ------- 1. 
define loss function -------- 29 | 30 | bce_loss = nn.BCELoss(size_average=True) 31 | ssim_loss = pytorch_ssim.SSIM(window_size=11,size_average=True) 32 | iou_loss = pytorch_iou.IOU(size_average=True) 33 | 34 | def bce_ssim_loss(pred,target): 35 | 36 | bce_out = bce_loss(pred,target) 37 | ssim_out = 1 - ssim_loss(pred,target) 38 | iou_out = iou_loss(pred,target) 39 | 40 | loss = bce_out + ssim_out + iou_out 41 | 42 | return loss 43 | 44 | def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v): 45 | 46 | loss0 = bce_ssim_loss(d0,labels_v) 47 | loss1 = bce_ssim_loss(d1,labels_v) 48 | loss2 = bce_ssim_loss(d2,labels_v) 49 | loss3 = bce_ssim_loss(d3,labels_v) 50 | loss4 = bce_ssim_loss(d4,labels_v) 51 | loss5 = bce_ssim_loss(d5,labels_v) 52 | loss6 = bce_ssim_loss(d6,labels_v) 53 | loss7 = bce_ssim_loss(d7,labels_v) 54 | #ssim0 = 1 - ssim_loss(d0,labels_v) 55 | 56 | # iou0 = iou_loss(d0,labels_v) 57 | #loss = torch.pow(torch.mean(torch.abs(labels_v-d0)),2)*(5.0*loss0 + loss1 + loss2 + loss3 + loss4 + loss5) #+ 5.0*lossa 58 | loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7#+ 5.0*lossa 59 | print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data[0],loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],loss6.data[0])) 60 | # print("BCE: l1:%3f, l2:%3f, l3:%3f, l4:%3f, l5:%3f, la:%3f, all:%3f\n"%(loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],lossa.data[0],loss.data[0])) 61 | 62 | return loss0, loss 63 | 64 | 65 | # ------- 2. set the directory of training dataset -------- 66 | 67 | data_dir = './train_data/' 68 | tra_image_dir = 'DUTS/DUTS-TR/DUTS-TR/im_aug/' 69 | tra_label_dir = 'DUTS/DUTS-TR/DUTS-TR/gt_aug/' 70 | 71 | image_ext = '.jpg' 72 | label_ext = '.png' 73 | 74 | model_dir = "./saved_models/basnet_bsi/" 75 | 76 | 77 | epoch_num = 100000 78 | batch_size_train = 8 79 | batch_size_val = 1 80 | train_num = 0 81 | val_num = 0 82 | 83 | tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext) 84 | 85 | tra_lbl_name_list = [] 86 | for img_path in tra_img_name_list: 87 | img_name = img_path.split("/")[-1] 88 | 89 | aaa = img_name.split(".") 90 | bbb = aaa[0:-1] 91 | imidx = bbb[0] 92 | for i in range(1,len(bbb)): 93 | imidx = imidx + "." + bbb[i] 94 | 95 | tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext) 96 | 97 | print("---") 98 | print("train images: ", len(tra_img_name_list)) 99 | print("train labels: ", len(tra_lbl_name_list)) 100 | print("---") 101 | 102 | train_num = len(tra_img_name_list) 103 | 104 | salobj_dataset = SalObjDataset( 105 | img_name_list=tra_img_name_list, 106 | lbl_name_list=tra_lbl_name_list, 107 | transform=transforms.Compose([ 108 | RescaleT(256), 109 | RandomCrop(224), 110 | ToTensorLab(flag=0)])) 111 | salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=1) 112 | 113 | # ------- 3. define model -------- 114 | # define the net 115 | net = BASNet(3, 1) 116 | if torch.cuda.is_available(): 117 | net.cuda() 118 | 119 | # ------- 4. define optimizer -------- 120 | print("---define optimizer...") 121 | optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) 122 | 123 | # ------- 5. 
training process -------- 124 | print("---start training...") 125 | ite_num = 0 126 | running_loss = 0.0 127 | running_tar_loss = 0.0 128 | ite_num4val = 0 129 | 130 | for epoch in range(0, epoch_num): 131 | net.train() 132 | 133 | for i, data in enumerate(salobj_dataloader): 134 | ite_num = ite_num + 1 135 | ite_num4val = ite_num4val + 1 136 | 137 | inputs, labels = data['image'], data['label'] 138 | 139 | inputs = inputs.type(torch.FloatTensor) 140 | labels = labels.type(torch.FloatTensor) 141 | 142 | # wrap them in Variable 143 | if torch.cuda.is_available(): 144 | inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(), 145 | requires_grad=False) 146 | else: 147 | inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False) 148 | 149 | # y zero the parameter gradients 150 | optimizer.zero_grad() 151 | 152 | # forward + backward + optimize 153 | d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs_v) 154 | loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v) 155 | 156 | loss.backward() 157 | optimizer.step() 158 | 159 | # # print statistics 160 | running_loss += loss.data[0] 161 | running_tar_loss += loss2.data[0] 162 | 163 | # del temporary outputs and loss 164 | del d0, d1, d2, d3, d4, d5, d6, d7, loss2, loss 165 | 166 | print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % ( 167 | epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) 168 | 169 | if ite_num % 2000 == 0: # save model every 2000 iterations 170 | 171 | torch.save(net.state_dict(), model_dir + "basnet_bsi_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val)) 172 | running_loss = 0.0 173 | running_tar_loss = 0.0 174 | net.train() # resume train 175 | ite_num4val = 0 176 | 177 | print('-------------Congratulations! 
Training Done!!!-------------') 178 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # data loader 2 | from __future__ import print_function, division 3 | import glob 4 | import torch 5 | from skimage import io, transform, color 6 | import numpy as np 7 | import math 8 | import matplotlib.pyplot as plt 9 | from torch.utils.data import Dataset, DataLoader 10 | from torchvision import transforms, utils 11 | from PIL import Image 12 | #==========================dataset load========================== 13 | 14 | class RescaleT(object): 15 | 16 | def __init__(self,output_size): 17 | assert isinstance(output_size,(int,tuple)) 18 | self.output_size = output_size 19 | 20 | def __call__(self,sample): 21 | image, label = sample['image'],sample['label'] 22 | 23 | h, w = image.shape[:2] 24 | 25 | if isinstance(self.output_size,int): 26 | if h > w: 27 | new_h, new_w = self.output_size*h/w,self.output_size 28 | else: 29 | new_h, new_w = self.output_size,self.output_size*w/h 30 | else: 31 | new_h, new_w = self.output_size 32 | 33 | new_h, new_w = int(new_h), int(new_w) 34 | 35 | # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] 36 | # img = transform.resize(image,(new_h,new_w),mode='constant') 37 | # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) 38 | 39 | img = transform.resize(image,(self.output_size,self.output_size),mode='constant') 40 | lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True) 41 | 42 | return {'image':img,'label':lbl} 43 | 44 | class Rescale(object): 45 | 46 | def __init__(self,output_size): 47 | assert isinstance(output_size,(int,tuple)) 48 | self.output_size = output_size 49 | 50 | def __call__(self,sample): 51 | image, label = sample['image'],sample['label'] 52 | 53 | h, w = image.shape[:2] 54 | 55 | if isinstance(self.output_size,int): 56 | if h > w: 57 | new_h, new_w = self.output_size*h/w,self.output_size 58 | else: 59 | new_h, new_w = self.output_size,self.output_size*w/h 60 | else: 61 | new_h, new_w = self.output_size 62 | 63 | new_h, new_w = int(new_h), int(new_w) 64 | 65 | # #resize the image to new_h x new_w and convert image from range [0,255] to [0,1] 66 | img = transform.resize(image,(new_h,new_w),mode='constant') 67 | lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True) 68 | 69 | return {'image':img,'label':lbl} 70 | 71 | class CenterCrop(object): 72 | 73 | def __init__(self,output_size): 74 | assert isinstance(output_size, (int, tuple)) 75 | if isinstance(output_size, int): 76 | self.output_size = (output_size, output_size) 77 | else: 78 | assert len(output_size) == 2 79 | self.output_size = output_size 80 | def __call__(self,sample): 81 | image, label = sample['image'], sample['label'] 82 | 83 | h, w = image.shape[:2] 84 | new_h, new_w = self.output_size 85 | 86 | # print("h: %d, w: %d, new_h: %d, new_w: %d"%(h, w, new_h, new_w)) 87 | assert((h >= new_h) and (w >= new_w)) 88 | 89 | h_offset = int(math.floor((h - new_h)/2)) 90 | w_offset = int(math.floor((w - new_w)/2)) 91 | 92 | image = image[h_offset: h_offset + new_h, w_offset: w_offset + new_w] 93 | label = label[h_offset: h_offset + new_h, w_offset: w_offset + new_w] 94 | 95 | return {'image': image, 'label': label} 96 | 97 | class RandomCrop(object): 98 | 99 | def __init__(self,output_size): 100 | assert 
isinstance(output_size, (int, tuple)) 101 | if isinstance(output_size, int): 102 | self.output_size = (output_size, output_size) 103 | else: 104 | assert len(output_size) == 2 105 | self.output_size = output_size 106 | def __call__(self,sample): 107 | image, label = sample['image'], sample['label'] 108 | 109 | h, w = image.shape[:2] 110 | new_h, new_w = self.output_size 111 | 112 | top = np.random.randint(0, h - new_h) 113 | left = np.random.randint(0, w - new_w) 114 | 115 | image = image[top: top + new_h, left: left + new_w] 116 | label = label[top: top + new_h, left: left + new_w] 117 | 118 | return {'image': image, 'label': label} 119 | 120 | class ToTensor(object): 121 | """Convert ndarrays in sample to Tensors.""" 122 | 123 | def __call__(self, sample): 124 | 125 | image, label = sample['image'], sample['label'] 126 | 127 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 128 | tmpLbl = np.zeros(label.shape) 129 | 130 | image = image/np.max(image) 131 | if(np.max(label)<1e-6): 132 | label = label 133 | else: 134 | label = label/np.max(label) 135 | 136 | if image.shape[2]==1: 137 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 138 | tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229 139 | tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229 140 | else: 141 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 142 | tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224 143 | tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225 144 | 145 | tmpLbl[:,:,0] = label[:,:,0] 146 | 147 | # change the r,g,b to b,r,g from [0,255] to [0,1] 148 | #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) 149 | tmpImg = tmpImg.transpose((2, 0, 1)) 150 | tmpLbl = label.transpose((2, 0, 1)) 151 | 152 | return {'image': torch.from_numpy(tmpImg), 153 | 'label': torch.from_numpy(tmpLbl)} 154 | 155 | class ToTensorLab(object): 156 | """Convert ndarrays in sample to Tensors.""" 157 | def __init__(self,flag=0): 158 | self.flag = flag 159 | 160 | def __call__(self, sample): 161 | 162 | image, label = sample['image'], sample['label'] 163 | 164 | tmpLbl = np.zeros(label.shape) 165 | 166 | if(np.max(label)<1e-6): 167 | label = label 168 | else: 169 | label = label/np.max(label) 170 | 171 | # change the color space 172 | if self.flag == 2: # with rgb and Lab colors 173 | tmpImg = np.zeros((image.shape[0],image.shape[1],6)) 174 | tmpImgt = np.zeros((image.shape[0],image.shape[1],3)) 175 | if image.shape[2]==1: 176 | tmpImgt[:,:,0] = image[:,:,0] 177 | tmpImgt[:,:,1] = image[:,:,0] 178 | tmpImgt[:,:,2] = image[:,:,0] 179 | else: 180 | tmpImgt = image 181 | tmpImgtl = color.rgb2lab(tmpImgt) 182 | 183 | # nomalize image to range [0,1] 184 | tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0])) 185 | tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1])) 186 | tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2])) 187 | tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0])) 188 | tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1])) 189 | tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2])) 190 | 191 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 192 | 193 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) 194 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) 195 | 
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) 196 | tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3]) 197 | tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4]) 198 | tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5]) 199 | 200 | elif self.flag == 1: #with Lab color 201 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 202 | 203 | if image.shape[2]==1: 204 | tmpImg[:,:,0] = image[:,:,0] 205 | tmpImg[:,:,1] = image[:,:,0] 206 | tmpImg[:,:,2] = image[:,:,0] 207 | else: 208 | tmpImg = image 209 | 210 | tmpImg = color.rgb2lab(tmpImg) 211 | 212 | # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg)) 213 | 214 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0])) 215 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1])) 216 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2])) 217 | 218 | tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0]) 219 | tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1]) 220 | tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2]) 221 | 222 | else: # with rgb color 223 | tmpImg = np.zeros((image.shape[0],image.shape[1],3)) 224 | image = image/np.max(image) 225 | if image.shape[2]==1: 226 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 227 | tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229 228 | tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229 229 | else: 230 | tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229 231 | tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224 232 | tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225 233 | 234 | 235 | 236 | tmpLbl[:,:,0] = label[:,:,0] 237 | 238 | # change the r,g,b to b,r,g from [0,255] to [0,1] 239 | #transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)) 240 | tmpImg = tmpImg.transpose((2, 0, 1)) 241 | tmpLbl = label.transpose((2, 0, 1)) 242 | 243 | return {'image': torch.from_numpy(tmpImg), 244 | 'label': torch.from_numpy(tmpLbl)} 245 | 246 | class SalObjDataset(Dataset): 247 | def __init__(self,img_name_list,lbl_name_list,transform=None): 248 | # self.root_dir = root_dir 249 | # self.image_name_list = glob.glob(image_dir+'*.png') 250 | # self.label_name_list = glob.glob(label_dir+'*.png') 251 | self.image_name_list = img_name_list 252 | self.label_name_list = lbl_name_list 253 | self.transform = transform 254 | 255 | def __len__(self): 256 | return len(self.image_name_list) 257 | 258 | def __getitem__(self,idx): 259 | 260 | # image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx]) 261 | # label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx]) 262 | 263 | image = io.imread(self.image_name_list[idx]) 264 | 265 | if(0==len(self.label_name_list)): 266 | label_3 = np.zeros(image.shape) 267 | else: 268 | label_3 = io.imread(self.label_name_list[idx]) 269 | 270 | #print("len of label3") 271 | #print(len(label_3.shape)) 272 | #print(label_3.shape) 273 | 274 | label = np.zeros(label_3.shape[0:2]) 275 | if(3==len(label_3.shape)): 276 | label = label_3[:,:,0] 277 | elif(2==len(label_3.shape)): 278 | label = label_3 279 | 280 | if(3==len(image.shape) and 2==len(label.shape)): 281 | label = label[:,:,np.newaxis] 282 | elif(2==len(image.shape) and 2==len(label.shape)): 283 | image = image[:,:,np.newaxis] 284 | label = label[:,:,np.newaxis] 285 | 286 | # #vertical flipping 287 | # # 
fliph = np.random.randn(1) 288 | # flipv = np.random.randn(1) 289 | # 290 | # if flipv>0: 291 | # image = image[::-1,:,:] 292 | # label = label[::-1,:,:] 293 | # #vertical flip 294 | 295 | sample = {'image':image, 'label':label} 296 | 297 | if self.transform: 298 | sample = self.transform(sample) 299 | 300 | return sample 301 | -------------------------------------------------------------------------------- /model/BASNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | import torch.nn.functional as F 5 | 6 | from .resnet_model import * 7 | 8 | 9 | class RefUnet(nn.Module): 10 | def __init__(self,in_ch,inc_ch): 11 | super(RefUnet, self).__init__() 12 | 13 | self.conv0 = nn.Conv2d(in_ch,inc_ch,3,padding=1) 14 | 15 | self.conv1 = nn.Conv2d(inc_ch,64,3,padding=1) 16 | self.bn1 = nn.BatchNorm2d(64) 17 | self.relu1 = nn.ReLU(inplace=True) 18 | 19 | self.pool1 = nn.MaxPool2d(2,2,ceil_mode=True) 20 | 21 | self.conv2 = nn.Conv2d(64,64,3,padding=1) 22 | self.bn2 = nn.BatchNorm2d(64) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.pool2 = nn.MaxPool2d(2,2,ceil_mode=True) 26 | 27 | self.conv3 = nn.Conv2d(64,64,3,padding=1) 28 | self.bn3 = nn.BatchNorm2d(64) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.pool3 = nn.MaxPool2d(2,2,ceil_mode=True) 32 | 33 | self.conv4 = nn.Conv2d(64,64,3,padding=1) 34 | self.bn4 = nn.BatchNorm2d(64) 35 | self.relu4 = nn.ReLU(inplace=True) 36 | 37 | self.pool4 = nn.MaxPool2d(2,2,ceil_mode=True) 38 | 39 | ##### 40 | 41 | self.conv5 = nn.Conv2d(64,64,3,padding=1) 42 | self.bn5 = nn.BatchNorm2d(64) 43 | self.relu5 = nn.ReLU(inplace=True) 44 | 45 | ##### 46 | 47 | self.conv_d4 = nn.Conv2d(128,64,3,padding=1) 48 | self.bn_d4 = nn.BatchNorm2d(64) 49 | self.relu_d4 = nn.ReLU(inplace=True) 50 | 51 | self.conv_d3 = nn.Conv2d(128,64,3,padding=1) 52 | self.bn_d3 = nn.BatchNorm2d(64) 53 | self.relu_d3 = nn.ReLU(inplace=True) 54 | 55 | self.conv_d2 = nn.Conv2d(128,64,3,padding=1) 56 | self.bn_d2 = nn.BatchNorm2d(64) 57 | self.relu_d2 = nn.ReLU(inplace=True) 58 | 59 | self.conv_d1 = nn.Conv2d(128,64,3,padding=1) 60 | self.bn_d1 = nn.BatchNorm2d(64) 61 | self.relu_d1 = nn.ReLU(inplace=True) 62 | 63 | self.conv_d0 = nn.Conv2d(64,1,3,padding=1) 64 | 65 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') 66 | 67 | 68 | def forward(self,x): 69 | 70 | hx = x 71 | hx = self.conv0(hx) 72 | 73 | hx1 = self.relu1(self.bn1(self.conv1(hx))) 74 | hx = self.pool1(hx1) 75 | 76 | hx2 = self.relu2(self.bn2(self.conv2(hx))) 77 | hx = self.pool2(hx2) 78 | 79 | hx3 = self.relu3(self.bn3(self.conv3(hx))) 80 | hx = self.pool3(hx3) 81 | 82 | hx4 = self.relu4(self.bn4(self.conv4(hx))) 83 | hx = self.pool4(hx4) 84 | 85 | hx5 = self.relu5(self.bn5(self.conv5(hx))) 86 | 87 | hx = self.upscore2(hx5) 88 | 89 | d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx,hx4),1)))) 90 | hx = self.upscore2(d4) 91 | 92 | d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx,hx3),1)))) 93 | hx = self.upscore2(d3) 94 | 95 | d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx,hx2),1)))) 96 | hx = self.upscore2(d2) 97 | 98 | d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx,hx1),1)))) 99 | 100 | residual = self.conv_d0(d1) 101 | 102 | return x + residual 103 | 104 | class BASNet(nn.Module): 105 | def __init__(self,n_channels,n_classes): 106 | super(BASNet,self).__init__() 107 | 108 | resnet = models.resnet34(pretrained=True) 109 | 110 | ## -------------Encoder-------------- 
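# The input conv below is 3x3 with stride 1, so unlike the standard ResNet stem the
# encoder keeps full input resolution at stage 1; stages 1-4 reuse the pretrained
# ResNet-34 layers, and two extra 512-channel stages (5 and 6) extend the encoder
# down to 1/32 of the input resolution before the bridge.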
111 | 112 | self.inconv = nn.Conv2d(n_channels,64,3,padding=1) 113 | self.inbn = nn.BatchNorm2d(64) 114 | self.inrelu = nn.ReLU(inplace=True) 115 | 116 | #stage 1 117 | self.encoder1 = resnet.layer1 #224 118 | #stage 2 119 | self.encoder2 = resnet.layer2 #112 120 | #stage 3 121 | self.encoder3 = resnet.layer3 #56 122 | #stage 4 123 | self.encoder4 = resnet.layer4 #28 124 | 125 | self.pool4 = nn.MaxPool2d(2,2,ceil_mode=True) 126 | 127 | #stage 5 128 | self.resb5_1 = BasicBlock(512,512) 129 | self.resb5_2 = BasicBlock(512,512) 130 | self.resb5_3 = BasicBlock(512,512) #14 131 | 132 | self.pool5 = nn.MaxPool2d(2,2,ceil_mode=True) 133 | 134 | #stage 6 135 | self.resb6_1 = BasicBlock(512,512) 136 | self.resb6_2 = BasicBlock(512,512) 137 | self.resb6_3 = BasicBlock(512,512) #7 138 | 139 | ## -------------Bridge-------------- 140 | 141 | #stage Bridge 142 | self.convbg_1 = nn.Conv2d(512,512,3,dilation=2, padding=2) # 7 143 | self.bnbg_1 = nn.BatchNorm2d(512) 144 | self.relubg_1 = nn.ReLU(inplace=True) 145 | self.convbg_m = nn.Conv2d(512,512,3,dilation=2, padding=2) 146 | self.bnbg_m = nn.BatchNorm2d(512) 147 | self.relubg_m = nn.ReLU(inplace=True) 148 | self.convbg_2 = nn.Conv2d(512,512,3,dilation=2, padding=2) 149 | self.bnbg_2 = nn.BatchNorm2d(512) 150 | self.relubg_2 = nn.ReLU(inplace=True) 151 | 152 | ## -------------Decoder-------------- 153 | 154 | #stage 6d 155 | self.conv6d_1 = nn.Conv2d(1024,512,3,padding=1) # 16 156 | self.bn6d_1 = nn.BatchNorm2d(512) 157 | self.relu6d_1 = nn.ReLU(inplace=True) 158 | 159 | self.conv6d_m = nn.Conv2d(512,512,3,dilation=2, padding=2)### 160 | self.bn6d_m = nn.BatchNorm2d(512) 161 | self.relu6d_m = nn.ReLU(inplace=True) 162 | 163 | self.conv6d_2 = nn.Conv2d(512,512,3,dilation=2, padding=2) 164 | self.bn6d_2 = nn.BatchNorm2d(512) 165 | self.relu6d_2 = nn.ReLU(inplace=True) 166 | 167 | #stage 5d 168 | self.conv5d_1 = nn.Conv2d(1024,512,3,padding=1) # 16 169 | self.bn5d_1 = nn.BatchNorm2d(512) 170 | self.relu5d_1 = nn.ReLU(inplace=True) 171 | 172 | self.conv5d_m = nn.Conv2d(512,512,3,padding=1)### 173 | self.bn5d_m = nn.BatchNorm2d(512) 174 | self.relu5d_m = nn.ReLU(inplace=True) 175 | 176 | self.conv5d_2 = nn.Conv2d(512,512,3,padding=1) 177 | self.bn5d_2 = nn.BatchNorm2d(512) 178 | self.relu5d_2 = nn.ReLU(inplace=True) 179 | 180 | #stage 4d 181 | self.conv4d_1 = nn.Conv2d(1024,512,3,padding=1) # 32 182 | self.bn4d_1 = nn.BatchNorm2d(512) 183 | self.relu4d_1 = nn.ReLU(inplace=True) 184 | 185 | self.conv4d_m = nn.Conv2d(512,512,3,padding=1)### 186 | self.bn4d_m = nn.BatchNorm2d(512) 187 | self.relu4d_m = nn.ReLU(inplace=True) 188 | 189 | self.conv4d_2 = nn.Conv2d(512,256,3,padding=1) 190 | self.bn4d_2 = nn.BatchNorm2d(256) 191 | self.relu4d_2 = nn.ReLU(inplace=True) 192 | 193 | #stage 3d 194 | self.conv3d_1 = nn.Conv2d(512,256,3,padding=1) # 64 195 | self.bn3d_1 = nn.BatchNorm2d(256) 196 | self.relu3d_1 = nn.ReLU(inplace=True) 197 | 198 | self.conv3d_m = nn.Conv2d(256,256,3,padding=1)### 199 | self.bn3d_m = nn.BatchNorm2d(256) 200 | self.relu3d_m = nn.ReLU(inplace=True) 201 | 202 | self.conv3d_2 = nn.Conv2d(256,128,3,padding=1) 203 | self.bn3d_2 = nn.BatchNorm2d(128) 204 | self.relu3d_2 = nn.ReLU(inplace=True) 205 | 206 | #stage 2d 207 | 208 | self.conv2d_1 = nn.Conv2d(256,128,3,padding=1) # 128 209 | self.bn2d_1 = nn.BatchNorm2d(128) 210 | self.relu2d_1 = nn.ReLU(inplace=True) 211 | 212 | self.conv2d_m = nn.Conv2d(128,128,3,padding=1)### 213 | self.bn2d_m = nn.BatchNorm2d(128) 214 | self.relu2d_m = nn.ReLU(inplace=True) 215 | 216 | self.conv2d_2 = 
nn.Conv2d(128,64,3,padding=1) 217 | self.bn2d_2 = nn.BatchNorm2d(64) 218 | self.relu2d_2 = nn.ReLU(inplace=True) 219 | 220 | #stage 1d 221 | self.conv1d_1 = nn.Conv2d(128,64,3,padding=1) # 256 222 | self.bn1d_1 = nn.BatchNorm2d(64) 223 | self.relu1d_1 = nn.ReLU(inplace=True) 224 | 225 | self.conv1d_m = nn.Conv2d(64,64,3,padding=1)### 226 | self.bn1d_m = nn.BatchNorm2d(64) 227 | self.relu1d_m = nn.ReLU(inplace=True) 228 | 229 | self.conv1d_2 = nn.Conv2d(64,64,3,padding=1) 230 | self.bn1d_2 = nn.BatchNorm2d(64) 231 | self.relu1d_2 = nn.ReLU(inplace=True) 232 | 233 | ## -------------Bilinear Upsampling-------------- 234 | self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')### 235 | self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear') 236 | self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear') 237 | self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear') 238 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear') 239 | 240 | ## -------------Side Output-------------- 241 | self.outconvb = nn.Conv2d(512,1,3,padding=1) 242 | self.outconv6 = nn.Conv2d(512,1,3,padding=1) 243 | self.outconv5 = nn.Conv2d(512,1,3,padding=1) 244 | self.outconv4 = nn.Conv2d(256,1,3,padding=1) 245 | self.outconv3 = nn.Conv2d(128,1,3,padding=1) 246 | self.outconv2 = nn.Conv2d(64,1,3,padding=1) 247 | self.outconv1 = nn.Conv2d(64,1,3,padding=1) 248 | 249 | ## -------------Refine Module------------- 250 | self.refunet = RefUnet(1,64) 251 | 252 | 253 | def forward(self,x): 254 | 255 | hx = x 256 | 257 | ## -------------Encoder------------- 258 | hx = self.inconv(hx) 259 | hx = self.inbn(hx) 260 | hx = self.inrelu(hx) 261 | 262 | h1 = self.encoder1(hx) # 256 263 | h2 = self.encoder2(h1) # 128 264 | h3 = self.encoder3(h2) # 64 265 | h4 = self.encoder4(h3) # 32 266 | 267 | hx = self.pool4(h4) # 16 268 | 269 | hx = self.resb5_1(hx) 270 | hx = self.resb5_2(hx) 271 | h5 = self.resb5_3(hx) 272 | 273 | hx = self.pool5(h5) # 8 274 | 275 | hx = self.resb6_1(hx) 276 | hx = self.resb6_2(hx) 277 | h6 = self.resb6_3(hx) 278 | 279 | ## -------------Bridge------------- 280 | hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6))) # 8 281 | hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx))) 282 | hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx))) 283 | 284 | ## -------------Decoder------------- 285 | 286 | hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg,h6),1)))) 287 | hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx))) 288 | hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx))) 289 | 290 | hx = self.upscore2(hd6) # 8 -> 16 291 | 292 | hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx,h5),1)))) 293 | hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx))) 294 | hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx))) 295 | 296 | hx = self.upscore2(hd5) # 16 -> 32 297 | 298 | hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx,h4),1)))) 299 | hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx))) 300 | hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx))) 301 | 302 | hx = self.upscore2(hd4) # 32 -> 64 303 | 304 | hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx,h3),1)))) 305 | hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx))) 306 | hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx))) 307 | 308 | hx = self.upscore2(hd3) # 64 -> 128 309 | 310 | hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx,h2),1)))) 311 | hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx))) 312 | hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx))) 313 | 314 | hx = self.upscore2(hd2) # 128 -> 256 315 | 316 | hx = 
self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx,h1),1)))) 317 | hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx))) 318 | hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx))) 319 | 320 | ## -------------Side Output------------- 321 | db = self.outconvb(hbg) 322 | db = self.upscore6(db) # 8->256 323 | 324 | d6 = self.outconv6(hd6) 325 | d6 = self.upscore6(d6) # 8->256 326 | 327 | d5 = self.outconv5(hd5) 328 | d5 = self.upscore5(d5) # 16->256 329 | 330 | d4 = self.outconv4(hd4) 331 | d4 = self.upscore4(d4) # 32->256 332 | 333 | d3 = self.outconv3(hd3) 334 | d3 = self.upscore3(d3) # 64->256 335 | 336 | d2 = self.outconv2(hd2) 337 | d2 = self.upscore2(d2) # 128->256 338 | 339 | d1 = self.outconv1(hd1) # 256 340 | 341 | ## -------------Refine Module------------- 342 | dout = self.refunet(d1) # 256 343 | 344 | return F.sigmoid(dout), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6), F.sigmoid(db) 345 | --------------------------------------------------------------------------------
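For orientation, here is a minimal smoke-test sketch of the full model and the deep-supervision pattern used in `basnet_train.py`. It is an illustrative example under stated assumptions, not repo code: it assumes the repo root is the working directory, that instantiating `BASNet` may download the ImageNet ResNet-34 weights through torchvision, and it feeds random tensors in place of real data.

```python
import torch
import torch.nn as nn
from model import BASNet

net = BASNet(3, 1)                     # 3 input channels (RGB), 1 saliency channel
x = torch.randn(2, 3, 256, 256)        # the repo rescales inputs to 256x256
y = torch.rand(2, 1, 256, 256)         # dummy ground-truth masks in [0, 1]

outs = net(x)                          # 8 sigmoid maps: refined output first, then the side outputs
print([tuple(o.shape) for o in outs])  # each is (2, 1, 256, 256)

# every map is supervised against the same mask (basnet_train.py additionally
# adds SSIM and IoU terms on top of BCE in its hybrid loss)
loss = sum(nn.BCELoss()(o, y) for o in outs)
loss.backward()
```

The refined map `outs[0]` is what `basnet_test.py` normalizes and saves; the remaining maps only provide training supervision.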