├── LSNet.py ├── README.md ├── config.py ├── mobilenetv2.py ├── requirements.txt ├── rgbd_dataset.py ├── rgbt_dataset.py ├── test.py ├── train.py └── utils.py /LSNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.functional as F 4 | 5 | class AFD_semantic(nn.Module): 6 | ''' 7 | Pay Attention to Features, Transfer Learn Faster CNNs 8 | https://openreview.net/pdf?id=ryxyCeHtPB 9 | ''' 10 | 11 | def __init__(self, in_channels, att_f): 12 | super(AFD_semantic, self).__init__() 13 | mid_channels = int(in_channels * att_f) 14 | 15 | self.attention = nn.Sequential(*[ 16 | nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=True), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(mid_channels, in_channels, 3, 1, 1, bias=True) 19 | ]) 20 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 21 | 22 | for m in self.modules(): 23 | if isinstance(m, nn.Conv2d): 24 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0) 27 | 28 | def forward(self, fm_s, fm_t, eps=1e-6): 29 | 30 | fm_t_pooled = self.avg_pool(fm_t) 31 | rho = self.attention(fm_t_pooled) 32 | rho = torch.sigmoid(rho.squeeze()) 33 | rho = rho / torch.sum(rho, dim=1, keepdim=True) 34 | 35 | fm_s_norm = torch.norm(fm_s, dim=(2, 3), keepdim=True) 36 | fm_s = torch.div(fm_s, fm_s_norm + eps) 37 | fm_t_norm = torch.norm(fm_t, dim=(2, 3), keepdim=True) 38 | fm_t = torch.div(fm_t, fm_t_norm + eps) 39 | 40 | loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=(2, 3)) 41 | loss = loss.sum(1).mean(0) 42 | 43 | return loss 44 | 45 | 46 | class AFD_spatial(nn.Module): 47 | ''' 48 | Pay Attention to Features, Transfer Learn Faster CNNs 49 | https://openreview.net/pdf?id=ryxyCeHtPB 50 | ''' 51 | 52 | def __init__(self, in_channels): 53 | super(AFD_spatial, self).__init__() 54 | 55 | self.attention = nn.Sequential(*[ 56 | nn.Conv2d(in_channels, 1, 3, 1, 1) 57 | ]) 58 | 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | 65 | def forward(self, fm_s, fm_t, eps=1e-6): 66 | 67 | rho = self.attention(fm_t) 68 | rho = torch.sigmoid(rho) 69 | rho = rho / torch.sum(rho, dim=(2,3), keepdim=True) 70 | 71 | fm_s_norm = torch.norm(fm_s, dim=1, keepdim=True) 72 | fm_s = torch.div(fm_s, fm_s_norm + eps) 73 | fm_t_norm = torch.norm(fm_t, dim=1, keepdim=True) 74 | fm_t = torch.div(fm_t, fm_t_norm + eps) 75 | loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=1, keepdim=True) 76 | loss =torch.sum(loss,dim=(2,3)).mean(0) 77 | return loss 78 | 79 | from mobilenetv2 import mobilenet_v2 80 | class LSNet(nn.Module): 81 | def __init__(self): 82 | super(LSNet, self).__init__() 83 | # rgb,depth encode 84 | self.rgb_pretrained = mobilenet_v2() 85 | self.depth_pretrained = mobilenet_v2() 86 | 87 | # Upsample_model 88 | self.upsample1_g = nn.Sequential(nn.Conv2d(68, 34, 3, 1, 1, ), nn.BatchNorm2d(34), nn.GELU(), 89 | nn.UpsamplingBilinear2d(scale_factor=2, )) 90 | 91 | self.upsample2_g = nn.Sequential(nn.Conv2d(104, 52, 3, 1, 1, ), nn.BatchNorm2d(52), nn.GELU(), 92 | nn.UpsamplingBilinear2d(scale_factor=2, )) 93 | 94 | self.upsample3_g = nn.Sequential(nn.Conv2d(160, 80, 3, 1, 1, ), nn.BatchNorm2d(80), nn.GELU(), 95 | nn.UpsamplingBilinear2d(scale_factor=2, )) 96 | 97 | self.upsample4_g = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1, ), nn.BatchNorm2d(128), nn.GELU(), 98 | 
nn.UpsamplingBilinear2d(scale_factor=2, )) 99 | 100 | self.upsample5_g = nn.Sequential(nn.Conv2d(320, 160, 3, 1, 1, ), nn.BatchNorm2d(160), nn.GELU(), 101 | nn.UpsamplingBilinear2d(scale_factor=2, )) 102 | 103 | 104 | self.conv_g = nn.Conv2d(34, 1, 1) 105 | self.conv2_g = nn.Conv2d(52, 1, 1) 106 | self.conv3_g = nn.Conv2d(80, 1, 1) 107 | 108 | 109 | # Tips: speed test and params and more this part is not included. 110 | # please comment this part when involved. 111 | if self.training: 112 | self.AFD_semantic_5_R_T = AFD_semantic(320,0.0625) 113 | self.AFD_semantic_4_R_T = AFD_semantic(96,0.0625) 114 | self.AFD_semantic_3_R_T = AFD_semantic(32,0.0625) 115 | self.AFD_spatial_3_R_T = AFD_spatial(32) 116 | self.AFD_spatial_2_R_T = AFD_spatial(24) 117 | self.AFD_spatial_1_R_T = AFD_spatial(16) 118 | 119 | 120 | def forward(self, rgb, ti): 121 | # rgb 122 | A1, A2, A3, A4, A5 = self.rgb_pretrained(rgb) 123 | # ti 124 | A1_t, A2_t, A3_t, A4_t, A5_t = self.depth_pretrained(ti) 125 | 126 | F5 = A5_t + A5 127 | F4 = A4_t + A4 128 | F3 = A3_t + A3 129 | F2 = A2_t + A2 130 | F1 = A1_t + A1 131 | 132 | 133 | F5 = self.upsample5_g(F5) 134 | F4 = torch.cat((F4, F5), dim=1) 135 | F4 = self.upsample4_g(F4) 136 | F3 = torch.cat((F3, F4), dim=1) 137 | F3 = self.upsample3_g(F3) 138 | F2 = torch.cat((F2, F3), dim=1) 139 | F2 = self.upsample2_g(F2) 140 | F1 = torch.cat((F1, F2), dim=1) 141 | F1 = self.upsample1_g(F1) 142 | 143 | out = self.conv_g(F1) 144 | 145 | 146 | if self.training: 147 | out3 = self.conv3_g(F3) 148 | out2 = self.conv2_g(F2) 149 | loss_semantic_5_R_T = self.AFD_semantic_5_R_T(A5, A5_t.detach()) 150 | loss_semantic_5_T_R = self.AFD_semantic_5_R_T(A5_t, A5.detach()) 151 | loss_semantic_4_R_T = self.AFD_semantic_4_R_T(A4, A4_t.detach()) 152 | loss_semantic_4_T_R = self.AFD_semantic_4_R_T(A4_t, A4.detach()) 153 | loss_semantic_3_R_T = self.AFD_semantic_3_R_T(A3, A3_t.detach()) 154 | loss_semantic_3_T_R = self.AFD_semantic_3_R_T(A3_t, A3.detach()) 155 | loss_spatial_3_R_T = self.AFD_spatial_3_R_T(A3, A3_t.detach()) 156 | loss_spatial_3_T_R = self.AFD_spatial_3_R_T(A3_t, A3.detach()) 157 | loss_spatial_2_R_T = self.AFD_spatial_2_R_T(A2, A2_t.detach()) 158 | loss_spatial_2_T_R = self.AFD_spatial_2_R_T(A2_t, A2.detach()) 159 | loss_spatial_1_R_T = self.AFD_spatial_1_R_T(A1, A1_t.detach()) 160 | loss_spatial_1_T_R = self.AFD_spatial_1_R_T(A1_t, A1.detach()) 161 | loss_KD = loss_semantic_5_R_T + loss_semantic_5_T_R + \ 162 | loss_semantic_4_R_T + loss_semantic_4_T_R + \ 163 | loss_semantic_3_R_T + loss_semantic_3_T_R + \ 164 | loss_spatial_3_R_T + loss_spatial_3_T_R + \ 165 | loss_spatial_2_R_T + loss_spatial_2_T_R + \ 166 | loss_spatial_1_R_T + loss_spatial_1_T_R 167 | return out, out2, out3, loss_KD 168 | return out 169 | 170 | 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSNet 2 | This project provides the code and results for 'LSNet: Lightweight Spatial Boosting Network for Detecting Salient Objects in RGB-Thermal Images', IEEE TIP, 2023. [IEEE link](https://ieeexplore.ieee.org/document/10042233)
3 | 4 | # Requirements 5 | Python 3.7+, PyTorch 1.5.0+, CUDA 10.2+, TensorboardX 2.1, opencv-python<br>
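For reference, the versions this repo pins in `requirements.txt` are `torch==1.9.0`, `torchvision==0.10.0`, `tensorboardX==2.4` and `opencv-python==4.5.2.54`; installing roughly these packages with pip should be sufficient, since most of the remaining entries in `requirements.txt` are a full conda environment export.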
6 | If anything goes wrong with the environment, please check requirements.txt for details. 7 | 8 | # Architecture and Details 9 | ![image](https://user-images.githubusercontent.com/38373305/218299592-13bb523b-8f1d-485f-9c65-137dca4e1544.png) 10 | drawing drawing 11 | 12 | # Results 13 | drawing 14 | drawing 15 | drawing 16 | drawing 17 | 18 | # Data Preparation 19 | - Download the RGB-T raw data from [baidu](https://pan.baidu.com/s/1fDht3BmqIYPks_iquST5hQ), pin: sf9y / [Google drive](https://drive.google.com/file/d/1vjdD13DTh9mM69mRRRdFBbpWbmj6MSKj/view?usp=share_link)
20 | - Download the RGB-D raw data from [baidu](https://pan.baidu.com/s/1A-fwxAtnwMPuznn1PCATWg), pin: 7pi5 / [Google drive](https://drive.google.com/file/d/1WzTuHQJCKPE5OreanoU0N2e82Y1_VZyA/view?usp=share_link)
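Based on the path construction in `train.py` and `test.py`, `train_root`/`val_root` should each point to a folder that directly contains `RGB/`, `GT/` and `T/` (RGB-T) or `depth/` (RGB-D) sub-folders, while `test_path` should point to the parent folder of the individual test sets (e.g. `VT800`, `VT1000`, `VT5000` for RGB-T, or `NJU2K`, `DES`, `LFSD`, `NLPR`, `SIP` for RGB-D), each of which again contains `RGB/`, `GT/` and `T/` or `depth/`.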
21 | 22 | Note that in the depth maps of the raw data above, the foreground is white. 23 | # Training & Testing 24 | Modify the `train_root`, `val_root` and `save_path` paths in `config.py` according to your own data paths. 25 | - Train the LSNet: 26 | 27 | `python train.py` 28 | 29 | Modify the `test_path` in `config.py` according to your own data path. 30 | 31 | - Test the LSNet: 32 | 33 | `python test.py` 34 | 35 | Note that `task` in `config.py` determines which task and dataset to use. 36 | 37 | # Evaluate tools 38 | - You can select one of the toolboxes below to get the metrics: 39 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics) 40 | 41 | # Saliency Maps 42 | - RGB-T [baidu](https://pan.baidu.com/s/1i5GwM0C0OfE5D5VLXlBkVA) pin: fxsk / [Google drive](https://drive.google.com/file/d/1ATEw8cNLHYfuCAK40VUBzcqBnMOKw-OV/view?usp=sharing)<br>
43 | - RGB-D [baidu](https://pan.baidu.com/s/1bAlk753MeeRG0BLMJXAzxQ) pin: 6352 / [Google drive](https://drive.google.com/file/d/1WgQlcVWg_YC4_64TaIn8JSWuzZC_FfhW/view?usp=sharing)
44 | 45 | Note that we resize the testing data to 224 * 224 for quick evaluation.<br>
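If you would rather evaluate at the original ground-truth resolution, the saved 224 * 224 maps can be resized back before computing metrics. A minimal MAE-only sketch (the folder paths and matching file names below are assumptions; use the toolboxes above for the full metric suite):

```python
# Hypothetical example: resize saved saliency maps back to the original
# GT resolution and compute MAE. Paths and file naming are assumptions.
import os
import cv2
import numpy as np

pred_dir = './saliency_maps/VT1000/'   # 224 x 224 .png maps saved by test.py
gt_dir = './VT1000/GT/'                # original ground-truth masks

maes = []
for name in sorted(os.listdir(pred_dir)):
    pred = cv2.imread(os.path.join(pred_dir, name), cv2.IMREAD_GRAYSCALE)
    gt = cv2.imread(os.path.join(gt_dir, name), cv2.IMREAD_GRAYSCALE)
    pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]))  # back to GT size
    maes.append(np.abs(pred / 255.0 - gt / 255.0).mean())

print('MAE:', np.mean(maes))
```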
46 | Please also check our previous works [APNet](https://github.com/zyrant/APNet) and [CCAFNet](https://github.com/zyrant/CCAFNet). 47 | 48 | # Pretrained Models 49 | - RGB-T [baidu](https://pan.baidu.com/s/1aGP283gNpb3oosvbq4OSDg) pin: wnoa / [Google drive](https://drive.google.com/drive/folders/17xmRA5zhLeIIS_-1EXbhxhPoW-Xn40xl?usp=sharing)<br>
50 | - RGB-D [baidu](https://pan.baidu.com/s/1aGP283gNpb3oosvbq4OSDg) pin: wnoa / [Google drive](https://drive.google.com/drive/folders/17xmRA5zhLeIIS_-1EXbhxhPoW-Xn40xl?usp=sharing)
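To run inference with one of the checkpoints above, a minimal sketch following `test.py` (the checkpoint and image paths are placeholders; for RGB-D, load the depth map and repeat it to three channels as `test.py` does):

```python
# Minimal single-image inference sketch based on test.py.
# './LSNet_RGBT.pth', 'demo_rgb.jpg' and 'demo_t.jpg' are placeholder paths.
import torch
from PIL import Image
from torchvision import transforms
from LSNet import LSNet

model = LSNet()
model.load_state_dict(torch.load('./LSNet_RGBT.pth'))
model.cuda().eval()

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

rgb = tf(Image.open('demo_rgb.jpg').convert('RGB')).unsqueeze(0).cuda()
ti = tf(Image.open('demo_t.jpg').convert('RGB')).unsqueeze(0).cuda()

with torch.no_grad():
    pred = torch.sigmoid(model(rgb, ti))  # 1 x 1 x 224 x 224 saliency map
pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
```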
51 | 52 | # Citation 53 | @ARTICLE{Zhou_2023_LSNet, 54 | author={Zhou, Wujie and Zhu, Yun and Lei, Jingsheng and Yang, Rongwang and Yu, Lu}, 55 | journal={IEEE Transactions on Image Processing}, 56 | title={LSNet: Lightweight Spatial Boosting Network for Detecting Salient Objects in RGB-Thermal Images}, 57 | year={2023}, 58 | volume={32}, 59 | number={}, 60 | pages={1329-1340}, 61 | doi={10.1109/TIP.2023.3242775}} 62 | 63 | # Acknowledgement 64 | The implementation of this project is based on the codebases below.<br>
65 | - [BBS-Net](https://github.com/zyjwuyan/BBS-Net)
66 | - [Knowledge-Distillation-Zoo](https://github.com/AberHu/Knowledge-Distillation-Zoo)
67 | - FPS/speed test [MobileSal](https://github.com/yuhuan-wu/MobileSal/blob/master/speed_test.py) 68 | - Evaluate tools [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)<br>
69 | 70 | If you find this project helpful, please also cite the codebases above. 71 | 72 | # Contact 73 | Please drop me an email for any problems or discussion: https://wujiezhou.github.io/ (wujiezhou@163.com) or zzzyylink@gmail.com. 74 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | # train/val 4 | parser.add_argument('--task', type=str, default='RGBT', help='task type: RGBT or RGBD') 5 | parser.add_argument('--epoch', type=int, default=20, help='epoch number') 6 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 7 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size') 8 | parser.add_argument('--trainsize', type=int, default=224, help='training dataset size') 9 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 10 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 11 | parser.add_argument('--decay_epoch', type=int, default=40, help='decay the learning rate every n epochs') 12 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints') 13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id') 14 | parser.add_argument('--train_root', type=str, default='', help='the train images root') 15 | parser.add_argument('--val_root', type=str, default='', help='the val images root') 16 | parser.add_argument('--save_path', type=str, default='', help='the path to save models and logs') 17 | # test(predict) 18 | parser.add_argument('--testsize', type=int, default=224, help='testing size') 19 | parser.add_argument('--test_path',type=str,default='',help='test dataset path') 20 | opt = parser.parse_args() 21 | -------------------------------------------------------------------------------- /mobilenetv2.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | model_urls = { 9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 10 | } 11 | 12 | 13 | def _make_divisible(v, divisor, min_value=None): 14 | """ 15 | This function is taken from the original tf repo. 16 | It ensures that all layers have a channel number that is divisible by 8 17 | It can be seen here: 18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 19 | :param v: 20 | :param divisor: 21 | :param min_value: 22 | :return: 23 | """ 24 | if min_value is None: 25 | min_value = divisor 26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 27 | # Make sure that round down does not go down by more than 10%. 
28 | if new_v < 0.9 * v: 29 | new_v += divisor 30 | return new_v 31 | 32 | 33 | class ConvBNReLU(nn.Sequential): 34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 35 | padding = (kernel_size - 1) // 2 36 | if norm_layer is None: 37 | norm_layer = nn.BatchNorm2d 38 | super(ConvBNReLU, self).__init__( 39 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 40 | norm_layer(out_planes), 41 | nn.ReLU6(inplace=True) 42 | ) 43 | 44 | 45 | class InvertedResidual(nn.Module): 46 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 47 | super(InvertedResidual, self).__init__() 48 | self.stride = stride 49 | assert stride in [1, 2] 50 | 51 | if norm_layer is None: 52 | norm_layer = nn.BatchNorm2d 53 | 54 | hidden_dim = int(round(inp * expand_ratio)) 55 | self.use_res_connect = self.stride == 1 and inp == oup 56 | 57 | layers = [] 58 | if expand_ratio != 1: 59 | # pw 60 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 61 | layers.extend([ 62 | # dw 63 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 64 | # pw-linear 65 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 66 | norm_layer(oup), 67 | ]) 68 | self.conv = nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | if self.use_res_connect: 72 | return x + self.conv(x) 73 | else: 74 | return self.conv(x) 75 | 76 | 77 | class MobileNetV2(nn.Module): 78 | def __init__(self, 79 | num_classes=1000, 80 | width_mult=1.0, 81 | inverted_residual_setting=None, 82 | round_nearest=8, 83 | block=None, 84 | norm_layer=None): 85 | """ 86 | MobileNet V2 main class 87 | 88 | Args: 89 | num_classes (int): Number of classes 90 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 91 | inverted_residual_setting: Network structure 92 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 93 | Set to 1 to turn off rounding 94 | block: Module specifying inverted residual building block for mobilenet 95 | norm_layer: Module specifying the normalization layer to use 96 | 97 | """ 98 | super(MobileNetV2, self).__init__() 99 | 100 | if block is None: 101 | block = InvertedResidual 102 | 103 | if norm_layer is None: 104 | norm_layer = nn.BatchNorm2d 105 | 106 | input_channel = 32 107 | last_channel = 1280 108 | 109 | if inverted_residual_setting is None: 110 | inverted_residual_setting = [ 111 | # t, c, n, s 112 | [1, 16, 1, 1], 113 | [6, 24, 2, 2], 114 | [6, 32, 3, 2], 115 | [6, 64, 4, 2], 116 | [6, 96, 3, 1], 117 | [6, 160, 3, 2], 118 | [6, 320, 1, 1], 119 | ] 120 | 121 | # only check the first element, assuming user knows t,c,n,s are required 122 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 123 | raise ValueError("inverted_residual_setting should be non-empty " 124 | "or a 4-element list, got {}".format(inverted_residual_setting)) 125 | 126 | # building first layer 127 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 128 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 129 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 130 | # building inverted residual blocks 131 | for t, c, n, s in inverted_residual_setting: 132 | output_channel = _make_divisible(c * width_mult, round_nearest) 133 | for i in range(n): 134 | stride = s if i == 0 else 1 135 | 
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 136 | input_channel = output_channel 137 | # building last several layers 138 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 139 | # make it nn.Sequential 140 | self.features = nn.Sequential(*features) 141 | 142 | # building classifier 143 | # self.classifier = nn.Sequential( 144 | # nn.Dropout(0.2), 145 | # nn.Linear(self.last_channel, num_classes), 146 | # ) 147 | 148 | # weight initialization 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 152 | if m.bias is not None: 153 | nn.init.zeros_(m.bias) 154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 155 | nn.init.ones_(m.weight) 156 | nn.init.zeros_(m.bias) 157 | elif isinstance(m, nn.Linear): 158 | nn.init.normal_(m.weight, 0, 0.01) 159 | nn.init.zeros_(m.bias) 160 | 161 | def _forward_impl(self, x): 162 | # print(x.shape) 163 | # This exists since TorchScript doesn't support inheritance, so the superclass method 164 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 165 | x = self.features[:2](x) 166 | out1 = x 167 | x = self.features[2:3](x) 168 | out2 = x 169 | x = self.features[3:7](x) 170 | out3 = x 171 | x = self.features[7:14](x) 172 | out4 = x 173 | x = self.features[14:18](x) 174 | out5 = x 175 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 176 | # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 177 | # x = self.classifier(x) 178 | return out1, out2, out3, out4, out5 179 | 180 | def forward(self, x): 181 | return self._forward_impl(x) 182 | 183 | 184 | def mobilenet_v2(pretrained=True, **kwargs): 185 | """ 186 | Constructs a MobileNetV2 architecture from 187 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | progress (bool): If True, displays a progress bar of the download to stderr 192 | """ 193 | model = MobileNetV2(**kwargs) 194 | # if pretrained: 195 | # print('v2 pretrained loading....') 196 | # state_dict = model_zoo.load_url(model_urls['mobilenet_v2']) 197 | # model.load_state_dict(state_dict) 198 | if pretrained: 199 | pretrained_vgg = model_zoo.load_url(model_urls['mobilenet_v2']) 200 | model_dict = {} 201 | state_dict = model.state_dict() 202 | for k, v in pretrained_vgg.items(): 203 | if k in state_dict: 204 | model_dict[k] = v 205 | # print(k, v) 206 | 207 | state_dict.update(model_dict) 208 | model.load_state_dict(state_dict) 209 | return model 210 | 211 | 212 | if __name__=='__main__': 213 | 214 | # model = ghost_net() 215 | # model.eval() 216 | import torch 217 | model = mobilenet_v2() 218 | rgb = torch.randn(1, 3, 224, 224) 219 | out = model(rgb) 220 | for i in out: 221 | print(i.shape) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.14.0 2 | addict==2.4.0 3 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work 4 | anaconda-client @ file:///tmp/build/80754af9/anaconda-client_1624473988214/work 5 | anaconda-navigator==2.0.3 6 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1621348054992/work 7 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist 8 | appdirs==1.4.4 9 | argh==0.26.2 10 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work 11 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work 12 | astroid @ file:///tmp/build/80754af9/astroid_1625075819965/work 13 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work 14 | asttokens==2.0.5 15 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work 16 | atomicwrites==1.4.0 17 | attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work 18 | autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work 19 | Babel @ file:///tmp/build/80754af9/babel_1620871417480/work 20 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 21 | backports.functools-lru-cache @ file:///tmp/build/80754af9/backports.functools_lru_cache_1618170165463/work 22 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work 23 | backports.tempfile @ file:///home/linux1/recipes/ci/backports.tempfile_1610991236607/work 24 | backports.weakref==1.0.post1 25 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work 26 | bitarray @ file:///tmp/build/80754af9/bitarray_1620827551536/work 27 | bkcharts==0.2 28 | black==19.10b0 29 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work 30 | bokeh @ file:///tmp/build/80754af9/bokeh_1620779595936/work 31 | boto==2.49.0 32 | Bottleneck==1.3.2 33 | brotlipy==0.7.0 34 | cachetools==4.2.2 35 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi 36 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work 37 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work 38 | click @ file:///tmp/build/80754af9/click_1621604852318/work 39 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work 40 | clyent==1.2.2 41 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work 42 | conda==4.11.0 43 | 
conda-build==3.21.4 44 | conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work 45 | conda-pack @ file:///tmp/build/80754af9/conda-pack_1611163042455/work 46 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1618262148928/work 47 | conda-repo-cli @ file:///tmp/build/80754af9/conda-repo-cli_1620168426516/work 48 | conda-token @ file:///tmp/build/80754af9/conda-token_1620076980546/work 49 | conda-verify==3.4.2 50 | contextlib2==0.6.0.post1 51 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work 52 | cycler==0.10.0 53 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work 54 | cytoolz==0.11.0 55 | dask @ file:///tmp/build/80754af9/dask-core_1624381970968/work 56 | dataclasses==0.6 57 | decorator @ file:///home/ktietz/src/ci/decorator_1611930055503/work 58 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work 59 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work 60 | distributed @ file:///tmp/build/80754af9/distributed_1624589265858/work 61 | docutils @ file:///tmp/build/80754af9/docutils_1620827984873/work 62 | dtaidistance==2.3.2 63 | easydict==1.9 64 | einops==0.3.2 65 | entrypoints==0.3 66 | et-xmlfile==1.1.0 67 | fastcache==1.1.0 68 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work 69 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work 70 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work 71 | fsspec @ file:///tmp/build/80754af9/fsspec_1623705546600/work 72 | future==0.18.2 73 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work 74 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work 75 | gmpy2==2.0.8 76 | google-auth==1.35.0 77 | google-auth-oauthlib==0.4.6 78 | greenlet @ file:///tmp/build/80754af9/greenlet_1620913319000/work 79 | grpcio==1.40.0 80 | h5py==2.10.0 81 | HeapDict==1.0.1 82 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work 83 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work 84 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work 85 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work 86 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work 87 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work 88 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work 89 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl 90 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work 91 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work 92 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work 93 | isort @ file:///tmp/build/80754af9/isort_1624300337312/work 94 | itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work 95 | jdcal==1.4.1 96 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work 97 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work 98 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work 99 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work 100 | json5 @ file:///tmp/build/80754af9/json5_1624432770122/work 101 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work 102 | jupyter==1.0.0 103 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work 104 | jupyter-console @ 
file:///tmp/build/80754af9/jupyter_console_1616615302928/work 105 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work 106 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work 107 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work 108 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work 109 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work 110 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work 111 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work 112 | keyring @ file:///tmp/build/80754af9/keyring_1621524402652/work 113 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work 114 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work 115 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work 116 | llvmlite==0.36.0 117 | locket==0.2.1 118 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work 119 | Markdown==3.3.4 120 | MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work 121 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work 122 | mccabe==0.6.1 123 | mindspore==1.5.1 124 | mistune==0.8.4 125 | mkl-fft==1.3.0 126 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work 127 | mkl-service==2.3.0 128 | mmcv==1.4.0 129 | mmcv-full==1.4.0 130 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work 131 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1622818384463/work 132 | mpmath==1.2.1 133 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work 134 | multipledispatch==0.6.0 135 | mypy-extensions==0.4.3 136 | navigator-updater==0.2.1 137 | nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work 138 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work 139 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1624479060632/work 140 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work 141 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work 142 | networkx @ file:///tmp/build/80754af9/networkx_1617653298338/work 143 | nltk @ file:///tmp/build/80754af9/nltk_1621347441292/work 144 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work 145 | notebook @ file:///tmp/build/80754af9/notebook_1621528346532/work 146 | numba @ file:///tmp/build/80754af9/numba_1616774046117/work 147 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work 148 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620830962040/work 149 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work 150 | nvidia-cublas-cu11==11.10.3.66 151 | nvidia-cuda-runtime-cu11==11.8.89 152 | nvidia-cudnn-cu11==8.6.0.163 153 | nvidia-pyindex==1.0.9 154 | nvidia-tensorrt==8.4.1.5 155 | oauthlib==3.1.1 156 | olefile==0.46 157 | opencv-python==4.5.2.54 158 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work 159 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work 160 | pandas==1.2.5 161 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work 162 | parso==0.7.0 163 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work 164 | path @ file:///tmp/build/80754af9/path_1623603875173/work 165 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1625585678054/work 166 | pathspec==0.7.0 167 | pathtools==0.1.2 168 | patsy==0.5.1 169 | 
pep8==1.7.1 170 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 171 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 172 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work 173 | pkginfo==1.7.0 174 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work 175 | ply==3.11 176 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work 177 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work 178 | protobuf==3.17.3 179 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work 180 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 181 | py @ file:///tmp/build/80754af9/py_1607971587848/work 182 | pyasn1==0.4.8 183 | pyasn1-modules==0.2.8 184 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work 185 | pycosat==0.6.3 186 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 187 | pycurl==7.43.0.6 188 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1621600989141/work 189 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1621560806183/work 190 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work 191 | Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work 192 | pylint @ file:///tmp/build/80754af9/pylint_1625158820537/work 193 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work 194 | pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work 195 | pyodbc===4.0.0-unsupported 196 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 197 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work 198 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work 199 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work 200 | pytest==6.2.4 201 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work 202 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work 203 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work 204 | pytorch-wavelets @ file:///home/sunfan/Downloads/qq-files/3101347528/file_recv/fft_conv/pytorch_wavelets 205 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 206 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work 207 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work 208 | PyYAML==5.4.1 209 | pyzmq==20.0.0 210 | QDarkStyle==2.8.1 211 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work 212 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work 213 | QtPy==1.9.0 214 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work 215 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work 216 | requests-oauthlib==1.3.0 217 | rope @ file:///tmp/build/80754af9/rope_1623703006312/work 218 | rsa==4.7.2 219 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work 220 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work 221 | scikit-image @ file:///tmp/build/80754af9/scikit-image_1648196304918/work 222 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1621370412049/work 223 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work 224 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work 225 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work 226 | Send2Trash @ 
file:///tmp/build/80754af9/send2trash_1607525499227/work 227 | simplegeneric==0.8.1 228 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1623948242478/work 229 | sip==4.19.13 230 | six @ file:///tmp/build/80754af9/six_1623709665295/work 231 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work 232 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work 233 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work 234 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1623949099177/work 235 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work 236 | Sphinx @ file:///tmp/build/80754af9/sphinx_1623884544367/work 237 | sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work 238 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work 239 | sphinxcontrib-htmlhelp @ file:///tmp/build/80754af9/sphinxcontrib-htmlhelp_1623945626792/work 240 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work 241 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work 242 | sphinxcontrib-serializinghtml @ file:///tmp/build/80754af9/sphinxcontrib-serializinghtml_1624451540180/work 243 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work 244 | spyder @ file:///tmp/build/80754af9/spyder_1616775618138/work 245 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1614030590686/work 246 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1624584182860/work 247 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work 248 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work 249 | tables==3.6.1 250 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work 251 | tensorboard==2.6.0 252 | tensorboard-data-server==0.6.1 253 | tensorboard-plugin-wit==1.8.0 254 | tensorboardX==2.4 255 | tensorrt==0.0.1 256 | terminado==0.9.4 257 | testpath @ file:///tmp/build/80754af9/testpath_1624638946665/work 258 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work 259 | thop==0.0.31.post2005241907 260 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl 261 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work 262 | tifffile==2020.10.1 263 | timm==0.4.12 264 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work 265 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work 266 | torch==1.9.0 267 | torch-scatter==2.0.9 268 | torch-sparse==0.6.13 269 | torch2trt==0.4.0 270 | torchaudio==0.9.0a0+33b2469 271 | torchvision==0.10.0 272 | tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work 273 | tqdm @ file:///tmp/build/80754af9/tqdm_1625563689033/work 274 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work 275 | ttach==0.0.3 276 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1624953673417/work 277 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1624965014186/work 278 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work 279 | unicodecsv==0.14.1 280 | urllib3 @ file:///tmp/build/80754af9/urllib3_1625084269274/work 281 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work 282 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work 283 | webencodings==0.5.1 284 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work 285 | 
widgetsnbextension==3.5.1 286 | wrapt==1.12.1 287 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work 288 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work 289 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1625006966557/work 290 | xlwt==1.3.0 291 | xmltodict==0.12.0 292 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work 293 | zict==2.0.0 294 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work 295 | zope.event==4.5.0 296 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work 297 | -------------------------------------------------------------------------------- /rgbd_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | import torch 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label, depth): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 19 | # top bottom flip 20 | # if flip_flag2==1: 21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 24 | return img, label, depth 25 | 26 | 27 | def randomCrop(image, label, depth): 28 | border = 30 29 | image_width = image.size[0] 30 | image_height = image.size[1] 31 | crop_win_width = np.random.randint(image_width - border, image_width) 32 | crop_win_height = np.random.randint(image_height - border, image_height) 33 | random_region = ( 34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 35 | (image_height + crop_win_height) >> 1) 36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region) 37 | 38 | 39 | def randomRotation(image, label, depth): 40 | mode = Image.BICUBIC 41 | if random.random() > 0.8: 42 | random_angle = np.random.randint(-15, 15) 43 | image = image.rotate(random_angle, mode) 44 | label = label.rotate(random_angle, mode) 45 | depth = depth.rotate(random_angle, mode) 46 | return image, label, depth 47 | 48 | 49 | def colorEnhance(image): 50 | bright_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 52 | contrast_intensity = random.randint(5, 15) / 10.0 53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 54 | color_intensity = random.randint(0, 20) / 10.0 55 | image = ImageEnhance.Color(image).enhance(color_intensity) 56 | sharp_intensity = random.randint(0, 30) / 10.0 57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 58 | return image 59 | 60 | 61 | def randomGaussian(image, mean=0.1, sigma=0.35): 62 | def gaussianNoisy(im, mean=mean, sigma=sigma): 63 | for _i in range(len(im)): 64 | im[_i] += random.gauss(mean, sigma) 65 | return im 66 | 67 | img = np.asarray(image) 68 | width, height = img.shape 69 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 70 | img = img.reshape([width, height]) 71 | return Image.fromarray(np.uint8(img)) 72 | 73 | 74 | def randomPeper(img): 75 | img = np.array(img) 76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 77 | for i in 
range(noiseNum): 78 | 79 | randX = random.randint(0, img.shape[0] - 1) 80 | 81 | randY = random.randint(0, img.shape[1] - 1) 82 | 83 | if random.randint(0, 1) == 0: 84 | 85 | img[randX, randY] = 0 86 | 87 | else: 88 | 89 | img[randX, randY] = 255 90 | return Image.fromarray(img) 91 | 92 | 93 | # dataset for training 94 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 96 | class SalObjDataset(data.Dataset): 97 | def __init__(self, image_root, gt_root, depth_root, trainsize): 98 | self.trainsize = trainsize 99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') 100 | or f.endswith('.png')] 101 | # print(self.images) 102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 103 | or f.endswith('.png')] 104 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 105 | or f.endswith('.png')] 106 | self.images = sorted(self.images) 107 | self.gts = sorted(self.gts) 108 | self.depths = sorted(self.depths) 109 | self.filter_files() 110 | self.size = len(self.images) 111 | self.img_transform = transforms.Compose([ 112 | transforms.Resize((self.trainsize, self.trainsize)), 113 | transforms.ToTensor(), 114 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 115 | self.gt_transform = transforms.Compose([ 116 | transforms.Resize((self.trainsize, self.trainsize)), 117 | transforms.ToTensor()]) 118 | self.depths_transform = transforms.Compose([ 119 | transforms.Resize((self.trainsize, self.trainsize)), 120 | transforms.ToTensor(), 121 | transforms.Normalize([0.485], [0.229]) 122 | ]) 123 | 124 | def __getitem__(self, index): 125 | image = self.rgb_loader(self.images[index]) 126 | gt = self.binary_loader(self.gts[index]) 127 | depth = self.binary_loader(self.depths[index]) 128 | image, gt, depth = cv_random_flip(image, gt, depth) 129 | image, gt, depth = randomCrop(image, gt, depth) 130 | image, gt, depth = randomRotation(image, gt, depth) 131 | image = colorEnhance(image) 132 | # gt=randomGaussian(gt) 133 | gt = randomPeper(gt) 134 | # image, gt, depth = self.resize(image,gt, depth) 135 | image = self.img_transform(image) 136 | gt = self.gt_transform(gt) 137 | depth = self.depths_transform(depth) 138 | # depth = torch.div(depth.float(),255.0) # DUT 139 | 140 | return image, gt, depth 141 | 142 | def filter_files(self): 143 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 144 | # print(len(self.images),len(self.gts),len(self.depths)) 145 | images = [] 146 | gts = [] 147 | depths = [] 148 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths): 149 | img = Image.open(img_path) 150 | gt = Image.open(gt_path) 151 | depth = Image.open(depth_path) 152 | if img.size == gt.size and gt.size == depth.size: 153 | # if img.size == gt.size: 154 | images.append(img_path) 155 | gts.append(gt_path) 156 | depths.append(depth_path) 157 | self.images = images 158 | self.gts = gts 159 | self.depths = depths 160 | 161 | def rgb_loader(self, path): 162 | # print(path) 163 | with open(path, 'rb') as f: 164 | img = Image.open(f) 165 | # print(img) 166 | return img.convert('RGB') 167 | 168 | def binary_loader(self, path): 169 | with open(path, 'rb') as f: 170 | img = Image.open(f) 171 | return img.convert('L') 172 | 173 | def resize(self, img, gt, depth): 174 | assert img.size == gt.size 
and gt.size == depth.size 175 | h = self.trainsize 176 | w = self.trainsize 177 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 178 | Image.NEAREST) 179 | 180 | 181 | def __len__(self): 182 | return self.size 183 | 184 | 185 | # dataloader for training 186 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True): 187 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize) 188 | data_loader = data.DataLoader(dataset=dataset, 189 | batch_size=batchsize, 190 | shuffle=shuffle, 191 | num_workers=num_workers, 192 | pin_memory=pin_memory) 193 | return data_loader 194 | 195 | 196 | # test dataset and loader 197 | class test_dataset: 198 | def __init__(self, image_root, gt_root, depth_root, testsize): 199 | self.testsize = testsize 200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 201 | 202 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 203 | or f.endswith('.png')] 204 | 205 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 206 | or f.endswith('.png')] 207 | 208 | self.images = sorted(self.images) 209 | # print(self.images) 210 | self.gts = sorted(self.gts) 211 | # print(self.gts) 212 | self.depths = sorted(self.depths) 213 | # print(self.depths) 214 | self.transform = transforms.Compose([ 215 | transforms.Resize((self.testsize, self.testsize)), 216 | transforms.ToTensor(), 217 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 218 | # self.gt_transform = transforms.ToTensor() 219 | self.gt_transform = transforms.Compose([ 220 | transforms.Resize((self.testsize, self.testsize)), 221 | transforms.ToTensor()]) 222 | self.depths_transform = transforms.Compose([ 223 | transforms.Resize((self.testsize, self.testsize)), 224 | transforms.ToTensor(), 225 | transforms.Normalize([0.485], [0.229]) 226 | ]) 227 | self.size = len(self.images) 228 | self.index = 0 229 | 230 | def load_data(self): 231 | image = self.rgb_loader(self.images[self.index]) 232 | gt = self.binary_loader(self.gts[self.index]) 233 | depth = self.binary_loader(self.depths[self.index]) 234 | # image, gt, depth = self.resize(image, gt, depth) 235 | image = self.transform(image).unsqueeze(0) 236 | gt = self.gt_transform(gt).unsqueeze(0) 237 | depth = self.depths_transform(depth) 238 | # depth = torch.div(depth.float(), 255.0) # DUT 239 | depth = depth.unsqueeze(0) 240 | name = self.images[self.index].split('/')[-1] 241 | if name.endswith('.jpg'): 242 | name = name.split('.jpg')[0] + '.png' 243 | self.index += 1 244 | self.index = self.index % self.size 245 | return image, gt, depth, name 246 | 247 | def rgb_loader(self, path): 248 | with open(path, 'rb') as f: 249 | img = Image.open(f) 250 | return img.convert('RGB') 251 | 252 | def binary_loader(self, path): 253 | with open(path, 'rb') as f: 254 | img = Image.open(f) 255 | return img.convert('L') 256 | 257 | def resize(self, img, gt, depth): 258 | # assert img.size == gt.size and gt.size == depth.size 259 | h = self.testsize 260 | w = self.testsize 261 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 262 | Image.NEAREST) 263 | 264 | def __len__(self): 265 | return self.size 266 | -------------------------------------------------------------------------------- /rgbt_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import 
Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label, ti): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | ti = ti.transpose(Image.FLIP_LEFT_RIGHT) 19 | # top bottom flip 20 | # if flip_flag2==1: 21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 23 | # ti = ti.transpose(Image.FLIP_TOP_BOTTOM) 24 | return img, label, ti 25 | 26 | 27 | def randomCrop(image, label, ti): 28 | border = 30 29 | image_width = image.size[0] 30 | image_height = image.size[1] 31 | crop_win_width = np.random.randint(image_width - border, image_width) 32 | crop_win_height = np.random.randint(image_height - border, image_height) 33 | random_region = ( 34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 35 | (image_height + crop_win_height) >> 1) 36 | return image.crop(random_region), label.crop(random_region), ti.crop(random_region) 37 | 38 | 39 | def randomRotation(image, label, ti): 40 | mode = Image.BICUBIC 41 | if random.random() > 0.8: 42 | random_angle = np.random.randint(-15, 15) 43 | image = image.rotate(random_angle, mode) 44 | label = label.rotate(random_angle, mode) 45 | ti = ti.rotate(random_angle, mode) 46 | return image, label, ti 47 | 48 | 49 | def colorEnhance(image): 50 | bright_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 52 | contrast_intensity = random.randint(5, 15) / 10.0 53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 54 | color_intensity = random.randint(0, 20) / 10.0 55 | image = ImageEnhance.Color(image).enhance(color_intensity) 56 | sharp_intensity = random.randint(0, 30) / 10.0 57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 58 | return image 59 | 60 | 61 | def randomGaussian(image, mean=0.1, sigma=0.35): 62 | def gaussianNoisy(im, mean=mean, sigma=sigma): 63 | for _i in range(len(im)): 64 | im[_i] += random.gauss(mean, sigma) 65 | return im 66 | 67 | img = np.asarray(image) 68 | width, height = img.shape 69 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 70 | img = img.reshape([width, height]) 71 | return Image.fromarray(np.uint8(img)) 72 | 73 | 74 | def randomPeper(img): 75 | img = np.array(img) 76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 77 | for i in range(noiseNum): 78 | 79 | randX = random.randint(0, img.shape[0] - 1) 80 | 81 | randY = random.randint(0, img.shape[1] - 1) 82 | 83 | if random.randint(0, 1) == 0: 84 | 85 | img[randX, randY] = 0 86 | 87 | else: 88 | 89 | img[randX, randY] = 255 90 | return Image.fromarray(img) 91 | 92 | 93 | # dataset for training 94 | # The current loader is not using the normalized ti maps for training and test. If you use the normalized ti maps 95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 
96 | class SalObjDataset(data.Dataset): 97 | def __init__(self, image_root, gt_root, ti_root, trainsize): 98 | self.trainsize = trainsize 99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 100 | # print(self.images) 101 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 102 | or f.endswith('.png')] 103 | 104 | self.tis = [ti_root + f for f in os.listdir(ti_root) if f.endswith('.jpg') 105 | or f.endswith('.png')] 106 | 107 | self.images = sorted(self.images) 108 | self.gts = sorted(self.gts) 109 | self.tis = sorted(self.tis) 110 | 111 | self.filter_files() 112 | self.size = len(self.images) 113 | self.img_transform = transforms.Compose([ 114 | transforms.Resize((self.trainsize, self.trainsize)), 115 | transforms.ToTensor(), 116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 117 | self.gt_transform = transforms.Compose([ 118 | transforms.Resize((self.trainsize, self.trainsize)), 119 | transforms.ToTensor()]) 120 | self.tis_transform = transforms.Compose([ 121 | transforms.Resize((self.trainsize, self.trainsize)), 122 | transforms.ToTensor(), 123 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 124 | 125 | 126 | def __getitem__(self, index): 127 | image = self.rgb_loader(self.images[index]) 128 | gt = self.binary_loader(self.gts[index]) 129 | ti = self.rgb_loader(self.tis[index]) 130 | image, gt, ti = cv_random_flip(image, gt, ti) 131 | image, gt, ti = randomCrop(image, gt, ti) 132 | image, gt, ti = randomRotation(image, gt, ti) 133 | image = colorEnhance(image) 134 | # gt=randomGaussian(gt) 135 | gt = randomPeper(gt) 136 | image = self.img_transform(image) 137 | gt = self.gt_transform(gt) 138 | ti = self.tis_transform(ti) 139 | return image, gt, ti 140 | 141 | def filter_files(self): 142 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.tis) 143 | images = [] 144 | gts = [] 145 | tis = [] 146 | for img_path, gt_path, ti_path in zip(self.images, self.gts, self.tis): 147 | img = Image.open(img_path) 148 | gt = Image.open(gt_path) 149 | ti = Image.open(ti_path) 150 | if img.size == gt.size and gt.size == ti.size: 151 | images.append(img_path) 152 | gts.append(gt_path) 153 | tis.append(ti_path) 154 | self.images = images 155 | self.gts = gts 156 | self.tis = tis 157 | 158 | def rgb_loader(self, path): 159 | with open(path, 'rb') as f: 160 | img = Image.open(f) 161 | return img.convert('RGB') 162 | 163 | def binary_loader(self, path): 164 | with open(path, 'rb') as f: 165 | img = Image.open(f) 166 | return img.convert('L') 167 | 168 | def resize(self, img, gt, ti): 169 | assert img.size == gt.size and gt.size == ti.size 170 | w, h = img.size 171 | if h < self.trainsize or w < self.trainsize: 172 | h = max(h, self.trainsize) 173 | w = max(w, self.trainsize) 174 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), ti.resize((w, h), 175 | Image.NEAREST) 176 | else: 177 | return img, gt, ti 178 | 179 | def __len__(self): 180 | return self.size 181 | 182 | 183 | # dataloader for training 184 | def get_loader(image_root, gt_root, ti_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=False): 185 | dataset = SalObjDataset(image_root, gt_root, ti_root, trainsize) 186 | 187 | data_loader = data.DataLoader(dataset=dataset, 188 | batch_size=batchsize, 189 | shuffle=shuffle, 190 | num_workers=num_workers, 191 | pin_memory=pin_memory) 192 | # print(len(data_loader)) 193 | return data_loader 194 | 195 | 196 | # test dataset and loader 
197 | class test_dataset: 198 | def __init__(self, image_root, gt_root, ti_root,testsize): 199 | self.testsize = testsize 200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 201 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 202 | or f.endswith('.png')] 203 | self.tis = [ti_root + f for f in os.listdir(ti_root) if f.endswith('.jpg') 204 | or f.endswith('.png')] 205 | 206 | self.images = sorted(self.images) 207 | self.gts = sorted(self.gts) 208 | self.tis = sorted(self.tis) 209 | self.transform = transforms.Compose([ 210 | transforms.Resize((self.testsize, self.testsize)), 211 | transforms.ToTensor(), 212 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 213 | # self.gt_transform = transforms.ToTensor() 214 | self.gt_transform = transforms.Compose([ 215 | transforms.Resize((self.testsize, self.testsize)), 216 | transforms.ToTensor()]) 217 | self.tis_transform = transforms.Compose([ 218 | transforms.Resize((self.testsize, self.testsize)), 219 | transforms.ToTensor(), 220 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 221 | self.size = len(self.images) 222 | self.index = 0 223 | 224 | def load_data(self): 225 | image = self.rgb_loader(self.images[self.index]) 226 | image = self.transform(image).unsqueeze(0) 227 | gt = self.binary_loader(self.gts[self.index]) 228 | gt = self.gt_transform(gt).unsqueeze(0) 229 | ti = self.rgb_loader(self.tis[self.index]) 230 | ti = self.tis_transform(ti).unsqueeze(0) 231 | 232 | name = self.images[self.index].split('/')[-1] 233 | if name.endswith('.jpg'): 234 | name = name.split('.jpg')[0] + '.png' 235 | self.index += 1 236 | self.index = self.index % self.size 237 | return image, gt, ti,name 238 | 239 | def rgb_loader(self, path): 240 | with open(path, 'rb') as f: 241 | img = Image.open(f) 242 | return img.convert('RGB') 243 | 244 | def binary_loader(self, path): 245 | with open(path, 'rb') as f: 246 | img = Image.open(f) 247 | return img.convert('L') 248 | 249 | def __len__(self): 250 | return self.size 251 | 252 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('./models') 5 | import numpy as np 6 | import os 7 | import cv2 8 | 9 | from LSNet import LSNet 10 | from config import opt 11 | 12 | 13 | 14 | dataset_path = opt.test_path 15 | 16 | #set device for test 17 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 18 | print('USE GPU:', opt.gpu_id) 19 | 20 | #load the model 21 | model = LSNet() 22 | 23 | #Large epoch size may not generalize well. You can choose a good model to load according to the log file and pth files saved in ('./BBSNet_cpts/') when training. 
24 | model.load_state_dict(torch.load('')) 25 | model.cuda() 26 | model.eval() 27 | 28 | 29 | #test 30 | test_mae = [] 31 | if opt.task =='RGBT': 32 | from rgbt_dataset import test_dataset 33 | test_datasets = ['VT800','VT1000','VT5000'] 34 | elif opt.task == 'RGBD': 35 | from rgbd_dataset import test_dataset 36 | test_datasets = ['NJU2K', 'DES', 'LFSD', 'NLPR', 'SIP'] 37 | else: 38 | raise ValueError(f"Unknown task type {opt.task}") 39 | 40 | for dataset in test_datasets: 41 | mae_sum = 0 42 | save_path = '/' + dataset + '/' 43 | if not os.path.exists(save_path): 44 | os.makedirs(save_path) 45 | if opt.task == 'RGBT': 46 | image_root = dataset_path + dataset + '/RGB/' 47 | gt_root = dataset_path + dataset + '/GT/' 48 | ti_root=dataset_path + dataset +'/T/' 49 | elif opt.task == 'RGBD': 50 | image_root = dataset_path + dataset + '/RGB/' 51 | gt_root = dataset_path + dataset + '/GT/' 52 | ti_root = dataset_path + dataset + '/depth/' 53 | else: 54 | raise ValueError(f"Unknown task type {opt.task}") 55 | test_loader = test_dataset(image_root, gt_root, ti_root, opt.testsize) 56 | for i in range(test_loader.size): 57 | image, gt, ti, name = test_loader.load_data() 58 | gt = gt.cuda() 59 | image = image.cuda() 60 | ti = ti.cuda() 61 | if opt.task == 'RGBD': 62 | ti = torch.cat((ti,ti,ti),dim=1) 63 | res = model(image,ti) 64 | predict = torch.sigmoid(res) 65 | predict = (predict - predict.min()) / (predict.max() - predict.min() + 1e-8) 66 | mae = torch.sum(torch.abs(predict - gt)) / torch.numel(gt) 67 | # mae = torch.abs(predict - gt).mean() 68 | mae_sum = mae.item() + mae_sum 69 | predict = predict.data.cpu().numpy().squeeze() 70 | # print(predict.shape) 71 | print('save img to: ',save_path+name) 72 | cv2.imwrite(save_path+name, predict*255) 73 | test_mae.append(mae_sum / test_loader.size) 74 | print('Test Done!', 'MAE', test_mae) 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | from datetime import datetime 7 | from torchvision.utils import make_grid 8 | from utils import clip_gradient, adjust_lr 9 | from tensorboardX import SummaryWriter 10 | import logging 11 | import torch.backends.cudnn as cudnn 12 | from config import opt 13 | from torch.cuda import amp 14 | # set the device for training 15 | cudnn.benchmark = True 16 | cudnn.enabled = True 17 | 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id 20 | print('USE GPU:', opt.gpu_id) 21 | 22 | # build the model 23 | from LSNet import LSNet 24 | model = LSNet() 25 | if (opt.load is not None): 26 | model.load_state_dict(torch.load(opt.load)) 27 | print('load model from ', opt.load) 28 | model.cuda() 29 | params = model.parameters() 30 | optimizer = torch.optim.Adam(params, opt.lr) 31 | 32 | # set the path 33 | train_dataset_path = opt.train_root 34 | 35 | val_dataset_path = opt.val_root 36 | 37 | save_path = opt.save_path 38 | 39 | if not os.path.exists(save_path): 40 | os.makedirs(save_path) 41 | 42 | # load data 43 | print('load data...') 44 | if opt.task =='RGBT': 45 | from rgbt_dataset import get_loader, test_dataset 46 | image_root = train_dataset_path + '/RGB/' 47 | ti_root = train_dataset_path + '/T/' 48 | gt_root = train_dataset_path + '/GT/' 49 | val_image_root = val_dataset_path + '/RGB/' 50 | val_ti_root = val_dataset_path + '/T/' 51 | val_gt_root = val_dataset_path + '/GT/' 52 | elif opt.task == 'RGBD': 
53 |     from rgbd_dataset import get_loader, test_dataset
54 |     image_root = train_dataset_path + '/RGB/'
55 |     ti_root = train_dataset_path + '/depth/'
56 |     gt_root = train_dataset_path + '/GT/'
57 |     val_image_root = val_dataset_path + '/RGB/'
58 |     val_ti_root = val_dataset_path + '/depth/'
59 |     val_gt_root = val_dataset_path + '/GT/'
60 | else:
61 |     raise ValueError(f"Unknown task type {opt.task}")
62 | 
63 | train_loader = get_loader(image_root, gt_root, ti_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
64 | test_loader = test_dataset(val_image_root, val_gt_root, val_ti_root, opt.trainsize)
65 | total_step = len(train_loader)
66 | # print(total_step)
67 | 
68 | logging.basicConfig(filename=save_path + 'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
69 |                     level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
70 | logging.info("Model:")
71 | logging.info(model)
72 | 
73 | logging.info(save_path + "Train")
74 | logging.info("Config")
75 | logging.info(
76 |     'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(
77 |         opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path,
78 |         opt.decay_epoch))
79 | 
80 | # set loss function
81 | import torch.nn as nn
82 | 
83 | class IOUBCE_loss(nn.Module):
84 |     def __init__(self):
85 |         super(IOUBCE_loss, self).__init__()
86 |         self.nll_loss = nn.BCEWithLogitsLoss()
87 | 
88 |     def forward(self, input_scale, target_scale):
89 |         b, _, _, _ = input_scale.size()
90 |         loss = []
91 |         for inputs, targets in zip(input_scale, target_scale):
92 |             bce = self.nll_loss(inputs, targets)
93 |             pred = torch.sigmoid(inputs)
94 |             inter = (pred * targets).sum(dim=(1, 2))
95 |             union = (pred + targets).sum(dim=(1, 2))
96 |             IOU = (inter + 1) / (union - inter + 1)
97 |             loss.append(1 - IOU + bce)
98 |         total_loss = sum(loss)
99 |         return total_loss / b
100 | 
101 | 
102 | CE = torch.nn.BCEWithLogitsLoss().cuda()
103 | IOUBCE = IOUBCE_loss().cuda()
104 | class IOUBCEWithoutLogits_loss(nn.Module):
105 |     def __init__(self):
106 |         super(IOUBCEWithoutLogits_loss, self).__init__()
107 |         self.nll_loss = nn.BCELoss()
108 | 
109 |     def forward(self, input_scale, target_scale):
110 |         b, c, h, w = input_scale.size()
111 |         loss = []
112 |         for inputs, targets in zip(input_scale, target_scale):
113 | 
114 |             bce = self.nll_loss(inputs, targets)
115 | 
116 |             inter = (inputs * targets).sum(dim=(1, 2))
117 |             union = (inputs + targets).sum(dim=(1, 2))
118 |             IOU = (inter + 1) / (union - inter + 1)
119 |             loss.append(1 - IOU + bce)
120 |         total_loss = sum(loss)
121 |         return total_loss / b
122 | IOUBCEWithoutLogits = IOUBCEWithoutLogits_loss().cuda()
123 | 
124 | 
125 | step = 0
126 | writer = SummaryWriter(save_path + 'summary', flush_secs=30)
127 | best_mae = 1
128 | best_epoch = 0
129 | scaler = amp.GradScaler()  # AMP gradient scaler (declared but not used in the training loop below)
130 | 
131 | # BBA: boundary extraction via per-window max (dilation) minus per-window min (erosion)
132 | def tensor_bound(img, ksize):
133 | 
134 |     '''
135 |     :param img: tensor, B * C * H * W
136 |     :param ksize: int, side length of the sliding window
137 |     patches: tensor, B * C * H * W * ksize * ksize (intermediate unfold result)
138 |     :return: tensor, (inflation - corrosion), B * C * H * W
139 |     '''
140 | 
141 |     B, C, H, W = img.shape
142 |     pad = int((ksize - 1) // 2)
143 |     img_pad = F.pad(img, pad=[pad, pad, pad, pad], mode='constant', value=0)
144 |     # unfold over the height and width dimensions (dims 2 and 3)
145 |     patches = img_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
146 |     corrosion, _ = torch.min(patches.contiguous().view(B, C, H, W, -1), dim=-1)
147 |     inflation, _ = torch.max(patches.contiguous().view(B, C, H, W, -1), dim=-1)
148 |     return inflation - corrosion
149 | 
150 | 
151 | 
152 | # train function
153 | def train(train_loader, model, optimizer, epoch, save_path):
154 |     global step
155 |     model.train()
156 |     loss_all = 0
157 |     epoch_step = 0
158 |     try:
159 |         for i, (images, gts, tis) in enumerate(train_loader, start=1):
160 |             optimizer.zero_grad()
161 |             images = images.cuda()
162 |             tis = tis.cuda()
163 |             gts = gts.cuda()
164 |             if opt.task == 'RGBD':
165 |                 tis = torch.cat((tis, tis, tis), dim=1)
166 | 
167 |             gts2 = F.interpolate(gts, (112, 112))
168 |             gts3 = F.interpolate(gts, (56, 56))
169 | 
170 | 
171 |             bound = tensor_bound(gts, 3).cuda()
172 |             bound2 = F.interpolate(bound, (112, 112))
173 |             bound3 = F.interpolate(bound, (56, 56))
174 | 
175 |             out = model(images, tis)
176 | 
177 | 
178 |             loss1 = IOUBCE(out[0], gts)
179 |             loss2 = IOUBCE(out[1], gts2)
180 |             loss3 = IOUBCE(out[2], gts3)
181 | 
182 |             predict_bound0 = out[0]
183 |             predict_bound1 = out[1]
184 |             predict_bound2 = out[2]
185 |             predict_bound0 = tensor_bound(torch.sigmoid(predict_bound0), 3)
186 |             predict_bound1 = tensor_bound(torch.sigmoid(predict_bound1), 3)
187 |             predict_bound2 = tensor_bound(torch.sigmoid(predict_bound2), 3)
188 |             loss6 = IOUBCEWithoutLogits(predict_bound0, bound)
189 |             loss7 = IOUBCEWithoutLogits(predict_bound1, bound2)
190 |             loss8 = IOUBCEWithoutLogits(predict_bound2, bound3)
191 | 
192 | 
193 |             loss_sod = loss1 + loss2 + loss3
194 |             loss_bound = loss6 + loss7 + loss8
195 |             loss_trans = out[3]
196 |             loss = loss_sod + loss_bound + loss_trans
197 |             loss.backward()
198 |             optimizer.step()
199 |             step = step + 1
200 |             epoch_step = epoch_step + 1
201 |             loss_all = loss.item() + loss_all
202 |             if i % 10 == 0 or i == total_step or i == 1:
203 |                 print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, loss_sod: {:.4f},'
204 |                       'loss_bound: {:.4f}, loss_trans: {:.4f}'.
205 |                       format(datetime.now(), epoch, opt.epoch, i, total_step, loss.item(),
206 |                              loss_sod.item(), loss_bound.item(), loss_trans.item()))
207 |                 logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, loss_sod: {:.4f},'
208 |                              'loss_bound: {:.4f}, loss_trans: {:.4f}'.
209 |                              format(epoch, opt.epoch, i, total_step, loss.item(),
210 |                                     loss_sod.item(), loss_bound.item(), loss_trans.item()))
211 |                 writer.add_scalar('Loss', loss, global_step=step)
212 |                 # grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True)
213 |                 # writer.add_image('train/RGB', grid_image, step)
214 |                 grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True)
215 |                 writer.add_image('train/Ground_truth', grid_image, step)
216 |                 grid_image = make_grid(bound[0].clone().cpu().data, 1, normalize=True)
217 |                 writer.add_image('train/bound', grid_image, step)
218 | 
219 |                 # grid_image = make_grid(body[0].clone().cpu().data, 1, normalize=True)
220 |                 # writer.add_image('train/body', grid_image, step)
221 |                 res = out[0][0].clone()
222 |                 res = res.sigmoid().data.cpu().numpy().squeeze()
223 |                 res = (res - res.min()) / (res.max() - res.min() + 1e-8)
224 |                 writer.add_image('OUT/out', torch.tensor(res), step, dataformats='HW')
225 |                 res = predict_bound0[0].clone()
226 |                 res = res.sigmoid().data.cpu().numpy().squeeze()
227 |                 res = (res - res.min()) / (res.max() - res.min() + 1e-8)
228 |                 writer.add_image('OUT/bound', torch.tensor(res), step, dataformats='HW')
229 | 
230 | 
231 |         loss_all /= epoch_step
232 |         # logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all))
233 |         writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
234 |         if (epoch) % 5 == 0:
235 |             torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch))
236 |     except KeyboardInterrupt:
237 |         print('Keyboard Interrupt: save model and exit.')
238 |         if not os.path.exists(save_path):
239 |             os.makedirs(save_path)
240 |         torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch + 1))
241 |         print('save checkpoints successfully!')
242 |         raise
243 | 
244 | 
245 | # test function
246 | def test(test_loader, model, epoch, save_path):
247 |     global best_mae, best_epoch
248 |     model.eval()
249 |     with torch.no_grad():
250 |         mae_sum = 0
251 |         for i in range(test_loader.size):
252 |             image, gt, ti, name = test_loader.load_data()
253 |             gt = gt.cuda()
254 |             image = image.cuda()
255 |             ti = ti.cuda()
256 |             if opt.task == 'RGBD':
257 |                 ti = torch.cat((ti, ti, ti), dim=1)
258 | 
259 |             res = model(image, ti)
260 |             res = torch.sigmoid(res)
261 |             res = (res - res.min()) / (res.max() - res.min() + 1e-8)
262 |             mae_train = torch.sum(torch.abs(res - gt)) * 1.0 / (torch.numel(gt))
263 |             # print(mae_train)
264 |             mae_sum = mae_train.item() + mae_sum
265 |             # print(test_loader.size)
266 |         mae = mae_sum / test_loader.size
267 |         # print(test_loader.size)
268 |         writer.add_scalar('MAE', torch.as_tensor(mae), global_step=epoch)
269 |         print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch))
270 |         if epoch == 1:
271 |             best_mae = mae
272 |         else:
273 |             if mae < best_mae:
274 |                 best_mae = mae
275 |                 best_epoch = epoch
276 |                 torch.save(model.state_dict(), save_path + 'Net_epoch_best.pth')
277 |                 print('best epoch:{}'.format(epoch))
278 |     logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae))
279 | 
280 | 
281 | if __name__ == '__main__':
282 |     print("Start train...")
283 |     for epoch in range(1, opt.epoch + 1):
284 |         cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
285 |         writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
286 |         train(train_loader, model, optimizer, epoch, save_path)
287 |         test(test_loader, model, epoch, save_path)
288 | 
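# A minimal usage sketch (not part of the repository) of the BBA boundary operator
# tensor_bound defined in train.py: over each ksize x ksize window the per-window max
# is a morphological dilation ("inflation") and the per-window min is an erosion
# ("corrosion"); their difference is a thin band along object contours, which is the
# target that the boundary losses (loss6-loss8) supervise. The sketch reproduces the
# same result with F.max_pool2d on explicitly zero-padded input; the helper name
# `boundary` and the 7x7 toy mask are illustrative assumptions, not code from train.py.
import torch
import torch.nn.functional as F

def boundary(mask, ksize=3):
    pad = (ksize - 1) // 2
    m = F.pad(mask, [pad, pad, pad, pad], mode='constant', value=0)
    dilation = F.max_pool2d(m, ksize, stride=1)    # local max = inflation
    erosion = -F.max_pool2d(-m, ksize, stride=1)   # local min = corrosion
    return dilation - erosion

mask = torch.zeros(1, 1, 7, 7)
mask[:, :, 2:5, 2:5] = 1.0           # a 3x3 square "object"
print(boundary(mask).squeeze())       # 1s form a band around the square; interior and far background stay 0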
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 |     for group in optimizer.param_groups:
3 |         for param in group['params']:
4 |             if param.grad is not None:
5 |                 param.grad.data.clamp_(-grad_clip, grad_clip)
6 | 
7 | 
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 |     decay = decay_rate ** (epoch // decay_epoch)
10 |     for param_group in optimizer.param_groups:
11 |         param_group['lr'] = decay * init_lr
12 |         lr = param_group['lr']
13 |     return lr
14 | 
--------------------------------------------------------------------------------
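# A minimal sketch (not part of the repository) of the step-decay schedule that
# adjust_lr in utils.py implements: every decay_epoch epochs the learning rate is
# multiplied by decay_rate, i.e. lr = init_lr * decay_rate ** (epoch // decay_epoch).
# The numbers below use the function's default arguments and an illustrative
# init_lr of 1e-4; the actual run takes opt.lr, opt.decay_rate and opt.decay_epoch from config.py.
init_lr, decay_rate, decay_epoch = 1e-4, 0.1, 30
for epoch in [1, 29, 30, 59, 60, 90]:
    print(epoch, init_lr * decay_rate ** (epoch // decay_epoch))
# epochs 1-29 keep 1e-4, epochs 30-59 drop to 1e-5, epochs 60-89 to 1e-6, and so on.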