├── LSNet.py
├── README.md
├── config.py
├── mobilenetv2.py
├── requirements.txt
├── rgbd_dataset.py
├── rgbt_dataset.py
├── test.py
├── train.py
└── utils.py
/LSNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class AFD_semantic(nn.Module):
6 | '''
7 | Pay Attention to Features, Transfer Learn Faster CNNs
8 | https://openreview.net/pdf?id=ryxyCeHtPB
9 | '''
10 |
11 | def __init__(self, in_channels, att_f):
12 | super(AFD_semantic, self).__init__()
13 | mid_channels = int(in_channels * att_f)
14 |
15 | self.attention = nn.Sequential(*[
16 | nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=True),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(mid_channels, in_channels, 3, 1, 1, bias=True)
19 | ])
20 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
21 |
22 | for m in self.modules():
23 | if isinstance(m, nn.Conv2d):
24 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
25 | if m.bias is not None:
26 | nn.init.constant_(m.bias, 0)
27 |
28 | def forward(self, fm_s, fm_t, eps=1e-6):
29 |
30 | fm_t_pooled = self.avg_pool(fm_t)
31 | rho = self.attention(fm_t_pooled)
32 |         rho = torch.sigmoid(rho.squeeze(-1).squeeze(-1))  # squeeze only the spatial dims so a batch of size 1 keeps its batch dimension
33 | rho = rho / torch.sum(rho, dim=1, keepdim=True)
34 |
35 | fm_s_norm = torch.norm(fm_s, dim=(2, 3), keepdim=True)
36 | fm_s = torch.div(fm_s, fm_s_norm + eps)
37 | fm_t_norm = torch.norm(fm_t, dim=(2, 3), keepdim=True)
38 | fm_t = torch.div(fm_t, fm_t_norm + eps)
39 |
40 | loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=(2, 3))
41 | loss = loss.sum(1).mean(0)
42 |
43 | return loss
44 |
45 |
46 | class AFD_spatial(nn.Module):
47 | '''
48 | Pay Attention to Features, Transfer Learn Faster CNNs
49 | https://openreview.net/pdf?id=ryxyCeHtPB
50 | '''
51 |
52 | def __init__(self, in_channels):
53 | super(AFD_spatial, self).__init__()
54 |
55 | self.attention = nn.Sequential(*[
56 | nn.Conv2d(in_channels, 1, 3, 1, 1)
57 | ])
58 |
59 | for m in self.modules():
60 | if isinstance(m, nn.Conv2d):
61 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
62 | if m.bias is not None:
63 | nn.init.constant_(m.bias, 0)
64 |
65 | def forward(self, fm_s, fm_t, eps=1e-6):
66 |
67 | rho = self.attention(fm_t)
68 | rho = torch.sigmoid(rho)
69 | rho = rho / torch.sum(rho, dim=(2,3), keepdim=True)
70 |
71 | fm_s_norm = torch.norm(fm_s, dim=1, keepdim=True)
72 | fm_s = torch.div(fm_s, fm_s_norm + eps)
73 | fm_t_norm = torch.norm(fm_t, dim=1, keepdim=True)
74 | fm_t = torch.div(fm_t, fm_t_norm + eps)
75 | loss = rho * torch.pow(fm_s - fm_t, 2).mean(dim=1, keepdim=True)
76 |         loss = torch.sum(loss, dim=(2, 3)).mean(0)
77 | return loss
78 |
79 | from mobilenetv2 import mobilenet_v2
80 | class LSNet(nn.Module):
81 | def __init__(self):
82 | super(LSNet, self).__init__()
83 | # rgb,depth encode
84 | self.rgb_pretrained = mobilenet_v2()
85 | self.depth_pretrained = mobilenet_v2()
86 |
87 | # Upsample_model
88 | self.upsample1_g = nn.Sequential(nn.Conv2d(68, 34, 3, 1, 1, ), nn.BatchNorm2d(34), nn.GELU(),
89 | nn.UpsamplingBilinear2d(scale_factor=2, ))
90 |
91 | self.upsample2_g = nn.Sequential(nn.Conv2d(104, 52, 3, 1, 1, ), nn.BatchNorm2d(52), nn.GELU(),
92 | nn.UpsamplingBilinear2d(scale_factor=2, ))
93 |
94 | self.upsample3_g = nn.Sequential(nn.Conv2d(160, 80, 3, 1, 1, ), nn.BatchNorm2d(80), nn.GELU(),
95 | nn.UpsamplingBilinear2d(scale_factor=2, ))
96 |
97 | self.upsample4_g = nn.Sequential(nn.Conv2d(256, 128, 3, 1, 1, ), nn.BatchNorm2d(128), nn.GELU(),
98 | nn.UpsamplingBilinear2d(scale_factor=2, ))
99 |
100 | self.upsample5_g = nn.Sequential(nn.Conv2d(320, 160, 3, 1, 1, ), nn.BatchNorm2d(160), nn.GELU(),
101 | nn.UpsamplingBilinear2d(scale_factor=2, ))
102 |
103 |
104 | self.conv_g = nn.Conv2d(34, 1, 1)
105 | self.conv2_g = nn.Conv2d(52, 1, 1)
106 | self.conv3_g = nn.Conv2d(80, 1, 1)
107 |
108 |
109 |         # Note: the distillation modules below are used only during training and are not part of the speed/params test.
110 |         # Please comment out this block when measuring speed or parameters.
111 | if self.training:
112 | self.AFD_semantic_5_R_T = AFD_semantic(320,0.0625)
113 | self.AFD_semantic_4_R_T = AFD_semantic(96,0.0625)
114 | self.AFD_semantic_3_R_T = AFD_semantic(32,0.0625)
115 | self.AFD_spatial_3_R_T = AFD_spatial(32)
116 | self.AFD_spatial_2_R_T = AFD_spatial(24)
117 | self.AFD_spatial_1_R_T = AFD_spatial(16)
118 |
119 |
120 | def forward(self, rgb, ti):
121 | # rgb
122 | A1, A2, A3, A4, A5 = self.rgb_pretrained(rgb)
123 | # ti
124 | A1_t, A2_t, A3_t, A4_t, A5_t = self.depth_pretrained(ti)
125 |
126 | F5 = A5_t + A5
127 | F4 = A4_t + A4
128 | F3 = A3_t + A3
129 | F2 = A2_t + A2
130 | F1 = A1_t + A1
131 |
132 |
133 | F5 = self.upsample5_g(F5)
134 | F4 = torch.cat((F4, F5), dim=1)
135 | F4 = self.upsample4_g(F4)
136 | F3 = torch.cat((F3, F4), dim=1)
137 | F3 = self.upsample3_g(F3)
138 | F2 = torch.cat((F2, F3), dim=1)
139 | F2 = self.upsample2_g(F2)
140 | F1 = torch.cat((F1, F2), dim=1)
141 | F1 = self.upsample1_g(F1)
142 |
143 | out = self.conv_g(F1)
144 |
145 |
146 | if self.training:
147 | out3 = self.conv3_g(F3)
148 | out2 = self.conv2_g(F2)
149 | loss_semantic_5_R_T = self.AFD_semantic_5_R_T(A5, A5_t.detach())
150 | loss_semantic_5_T_R = self.AFD_semantic_5_R_T(A5_t, A5.detach())
151 | loss_semantic_4_R_T = self.AFD_semantic_4_R_T(A4, A4_t.detach())
152 | loss_semantic_4_T_R = self.AFD_semantic_4_R_T(A4_t, A4.detach())
153 | loss_semantic_3_R_T = self.AFD_semantic_3_R_T(A3, A3_t.detach())
154 | loss_semantic_3_T_R = self.AFD_semantic_3_R_T(A3_t, A3.detach())
155 | loss_spatial_3_R_T = self.AFD_spatial_3_R_T(A3, A3_t.detach())
156 | loss_spatial_3_T_R = self.AFD_spatial_3_R_T(A3_t, A3.detach())
157 | loss_spatial_2_R_T = self.AFD_spatial_2_R_T(A2, A2_t.detach())
158 | loss_spatial_2_T_R = self.AFD_spatial_2_R_T(A2_t, A2.detach())
159 | loss_spatial_1_R_T = self.AFD_spatial_1_R_T(A1, A1_t.detach())
160 | loss_spatial_1_T_R = self.AFD_spatial_1_R_T(A1_t, A1.detach())
161 | loss_KD = loss_semantic_5_R_T + loss_semantic_5_T_R + \
162 | loss_semantic_4_R_T + loss_semantic_4_T_R + \
163 | loss_semantic_3_R_T + loss_semantic_3_T_R + \
164 | loss_spatial_3_R_T + loss_spatial_3_T_R + \
165 | loss_spatial_2_R_T + loss_spatial_2_T_R + \
166 | loss_spatial_1_R_T + loss_spatial_1_T_R
167 | return out, out2, out3, loss_KD
168 | return out
169 |
170 |
171 |
172 |
173 |
174 |
175 |
--------------------------------------------------------------------------------
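
A minimal usage sketch of the two forward modes above (illustrative only, not part of the repository; the 224 x 224 input size follows `trainsize`/`testsize` in config.py, and all tensor names are hypothetical):

    import torch
    from LSNet import LSNet

    model = LSNet()                        # builds two MobileNetV2 encoders (ImageNet weights are downloaded)
    rgb = torch.randn(2, 3, 224, 224)      # RGB batch
    ti = torch.randn(2, 3, 224, 224)       # thermal batch (or depth replicated to 3 channels, as in test.py)

    # training mode: full-resolution logits, two side outputs, and the mutual distillation loss
    model.train()
    out, out2, out3, loss_KD = model(rgb, ti)

    # eval mode: only the full-resolution saliency logits
    model.eval()
    with torch.no_grad():
        pred = torch.sigmoid(model(rgb, ti))
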
/README.md:
--------------------------------------------------------------------------------
1 | # LSNet
2 | This project provides the code and results for 'LSNet: Lightweight Spatial Boosting Network for Detecting Salient Objects in RGB-Thermal Images', IEEE TIP, 2023. [IEEE link](https://ieeexplore.ieee.org/document/10042233)
3 |
4 | # Requirements
5 | Python 3.7+, PyTorch 1.5.0+, CUDA 10.2+, TensorboardX 2.1, opencv-python
6 | If anything goes wrong with the environment, please check requirements.txt for details.
7 |
8 | # Architecture and Details
9 | 
10 |
11 |
12 | # Results
13 |
14 |
15 |
16 |
17 |
18 | # Data Preparation
19 | - Download the RGB-T raw data from [baidu](https://pan.baidu.com/s/1fDht3BmqIYPks_iquST5hQ), pin: sf9y / [Google drive](https://drive.google.com/file/d/1vjdD13DTh9mM69mRRRdFBbpWbmj6MSKj/view?usp=share_link)
20 | - Download the RGB-D raw data from [baidu](https://pan.baidu.com/s/1A-fwxAtnwMPuznn1PCATWg), pin: 7pi5 / [Google drive](https://drive.google.com/file/d/1WzTuHQJCKPE5OreanoU0N2e82Y1_VZyA/view?usp=share_link)
21 |
22 | Note that in the depth maps of the raw data above, the foreground is white.
23 | # Training & Testing
24 | Modify the `train_root`, `val_root`, and `save_path` paths in `config.py` according to your own data paths.
25 | - Train the LSNet:
26 |
27 | `python train.py`
28 |
29 | Modify the `test_path` in `config.py` according to your own data path.
30 |
31 | - Test the LSNet:
32 |
33 | `python test.py`
34 |
35 | Note that `task` in `config.py` determines which task and dataset to use.
36 |
37 | # Evaluate tools
38 | - You can use either of the toolboxes below to compute the metrics:
39 | [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)
40 |
41 | # Saliency Maps
42 | - RGB-T [baidu](https://pan.baidu.com/s/1i5GwM0C0OfE5D5VLXlBkVA) pin: fxsk / [Google drive](https://drive.google.com/file/d/1ATEw8cNLHYfuCAK40VUBzcqBnMOKw-OV/view?usp=sharing)
43 | - RGB-D [baidu](https://pan.baidu.com/s/1bAlk753MeeRG0BLMJXAzxQ) pin: 6352 / [Google drive](https://drive.google.com/file/d/1WgQlcVWg_YC4_64TaIn8JSWuzZC_FfhW/view?usp=sharing)
44 |
45 | Note that we resize the testing data to 224 × 224 for quick evaluation.
46 | Please also check our previous works [APNet](https://github.com/zyrant/APNet) and [CCAFNet](https://github.com/zyrant/CCAFNet).
47 |
48 | # Pretrained Models
49 | - RGB-T [baidu](https://pan.baidu.com/s/1aGP283gNpb3oosvbq4OSDg) pin: wnoa / [Google drive](https://drive.google.com/drive/folders/17xmRA5zhLeIIS_-1EXbhxhPoW-Xn40xl?usp=sharing)
50 | - RGB-D [baidu](https://pan.baidu.com/s/1aGP283gNpb3oosvbq4OSDg) pin: wnoa / [Google drive](https://drive.google.com/drive/folders/17xmRA5zhLeIIS_-1EXbhxhPoW-Xn40xl?usp=sharing)
51 |
52 | # Citation
53 | @ARTICLE{Zhou_2023_LSNet,
54 | author={Zhou, Wujie and Zhu, Yun and Lei, Jingsheng and Yang, Rongwang and Yu, Lu},
55 | journal={IEEE Transactions on Image Processing},
56 | title={LSNet: Lightweight Spatial Boosting Network for Detecting Salient Objects in RGB-Thermal Images},
57 | year={2023},
58 | volume={32},
59 | number={},
60 | pages={1329-1340},
61 | doi={10.1109/TIP.2023.3242775}}
62 |
63 | # Acknowledgement
64 | The implementation of this project is based on the codebases below.
65 | - [BBS-Net](https://github.com/zyjwuyan/BBS-Net)
66 | - [Knowledge-Distillation-Zoo](https://github.com/AberHu/Knowledge-Distillation-Zoo)
67 | - Fps/speed test [MobileSal](https://github.com/yuhuan-wu/MobileSal/blob/master/speed_test.py)
68 | - Evaluate tools [CODToolbox](https://github.com/DengPingFan/CODToolbox) / [PySODMetrics](https://github.com/lartpang/PySODMetrics)
69 |
70 | If you find this project helpful, please also cite the codebases above.
71 |
72 | # Contact
73 | Please drop me an email for any problems or discussion: https://wujiezhou.github.io/ (wujiezhou@163.com) or zzzyylink@gmail.com.
74 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | # train/val
4 | parser.add_argument('--task', type=str, default='RGBT', help='task type: RGBT or RGBD')
5 | parser.add_argument('--epoch', type=int, default=20, help='epoch number')
6 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
7 | parser.add_argument('--batchsize', type=int, default=10, help='training batch size')
8 | parser.add_argument('--trainsize', type=int, default=224, help='training dataset size')
9 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
10 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
11 | parser.add_argument('--decay_epoch', type=int, default=40, help='every n epochs decay learning rate')
12 | parser.add_argument('--load', type=str, default=None, help='train from checkpoints')
13 | parser.add_argument('--gpu_id', type=str, default='0', help='select gpu id')
14 | parser.add_argument('--train_root', type=str, default='', help='the train images root')
15 | parser.add_argument('--val_root', type=str, default='', help='the val images root')
16 | parser.add_argument('--save_path', type=str, default='', help='the path to save models and logs')
17 | # test(predict)
18 | parser.add_argument('--testsize', type=int, default=224, help='testing size')
19 | parser.add_argument('--test_path',type=str,default='',help='test dataset path')
20 | opt = parser.parse_args()
21 |
--------------------------------------------------------------------------------
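
Because the options above are ordinary argparse flags, the README's training and testing commands can also override them on the command line instead of editing config.py (the paths below are placeholders; test.py additionally expects the checkpoint path to be filled in at `torch.load('')`):

    python train.py --task RGBT --train_root ./VT5000/Train/ --val_root ./VT5000/Test/ --save_path ./checkpoints/ --gpu_id 0
    python test.py --task RGBT --test_path ./RGBT_test/ --gpu_id 0
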
/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 | __all__ = ['MobileNetV2', 'mobilenet_v2']
6 |
7 |
8 | model_urls = {
9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
10 | }
11 |
12 |
13 | def _make_divisible(v, divisor, min_value=None):
14 | """
15 | This function is taken from the original tf repo.
16 | It ensures that all layers have a channel number that is divisible by 8
17 | It can be seen here:
18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
19 | :param v:
20 | :param divisor:
21 | :param min_value:
22 | :return:
23 | """
24 | if min_value is None:
25 | min_value = divisor
26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
27 | # Make sure that round down does not go down by more than 10%.
28 | if new_v < 0.9 * v:
29 | new_v += divisor
30 | return new_v
31 |
32 |
33 | class ConvBNReLU(nn.Sequential):
34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
35 | padding = (kernel_size - 1) // 2
36 | if norm_layer is None:
37 | norm_layer = nn.BatchNorm2d
38 | super(ConvBNReLU, self).__init__(
39 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
40 | norm_layer(out_planes),
41 | nn.ReLU6(inplace=True)
42 | )
43 |
44 |
45 | class InvertedResidual(nn.Module):
46 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
47 | super(InvertedResidual, self).__init__()
48 | self.stride = stride
49 | assert stride in [1, 2]
50 |
51 | if norm_layer is None:
52 | norm_layer = nn.BatchNorm2d
53 |
54 | hidden_dim = int(round(inp * expand_ratio))
55 | self.use_res_connect = self.stride == 1 and inp == oup
56 |
57 | layers = []
58 | if expand_ratio != 1:
59 | # pw
60 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
61 | layers.extend([
62 | # dw
63 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
64 | # pw-linear
65 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
66 | norm_layer(oup),
67 | ])
68 | self.conv = nn.Sequential(*layers)
69 |
70 | def forward(self, x):
71 | if self.use_res_connect:
72 | return x + self.conv(x)
73 | else:
74 | return self.conv(x)
75 |
76 |
77 | class MobileNetV2(nn.Module):
78 | def __init__(self,
79 | num_classes=1000,
80 | width_mult=1.0,
81 | inverted_residual_setting=None,
82 | round_nearest=8,
83 | block=None,
84 | norm_layer=None):
85 | """
86 | MobileNet V2 main class
87 |
88 | Args:
89 | num_classes (int): Number of classes
90 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
91 | inverted_residual_setting: Network structure
92 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
93 | Set to 1 to turn off rounding
94 | block: Module specifying inverted residual building block for mobilenet
95 | norm_layer: Module specifying the normalization layer to use
96 |
97 | """
98 | super(MobileNetV2, self).__init__()
99 |
100 | if block is None:
101 | block = InvertedResidual
102 |
103 | if norm_layer is None:
104 | norm_layer = nn.BatchNorm2d
105 |
106 | input_channel = 32
107 | last_channel = 1280
108 |
109 | if inverted_residual_setting is None:
110 | inverted_residual_setting = [
111 | # t, c, n, s
112 | [1, 16, 1, 1],
113 | [6, 24, 2, 2],
114 | [6, 32, 3, 2],
115 | [6, 64, 4, 2],
116 | [6, 96, 3, 1],
117 | [6, 160, 3, 2],
118 | [6, 320, 1, 1],
119 | ]
120 |
121 | # only check the first element, assuming user knows t,c,n,s are required
122 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
123 | raise ValueError("inverted_residual_setting should be non-empty "
124 | "or a 4-element list, got {}".format(inverted_residual_setting))
125 |
126 | # building first layer
127 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
128 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
129 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
130 | # building inverted residual blocks
131 | for t, c, n, s in inverted_residual_setting:
132 | output_channel = _make_divisible(c * width_mult, round_nearest)
133 | for i in range(n):
134 | stride = s if i == 0 else 1
135 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
136 | input_channel = output_channel
137 | # building last several layers
138 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
139 | # make it nn.Sequential
140 | self.features = nn.Sequential(*features)
141 |
142 | # building classifier
143 | # self.classifier = nn.Sequential(
144 | # nn.Dropout(0.2),
145 | # nn.Linear(self.last_channel, num_classes),
146 | # )
147 |
148 | # weight initialization
149 | for m in self.modules():
150 | if isinstance(m, nn.Conv2d):
151 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
152 | if m.bias is not None:
153 | nn.init.zeros_(m.bias)
154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
155 | nn.init.ones_(m.weight)
156 | nn.init.zeros_(m.bias)
157 | elif isinstance(m, nn.Linear):
158 | nn.init.normal_(m.weight, 0, 0.01)
159 | nn.init.zeros_(m.bias)
160 |
161 | def _forward_impl(self, x):
162 | # print(x.shape)
163 | # This exists since TorchScript doesn't support inheritance, so the superclass method
164 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
165 | x = self.features[:2](x)
166 | out1 = x
167 | x = self.features[2:3](x)
168 | out2 = x
169 | x = self.features[3:7](x)
170 | out3 = x
171 | x = self.features[7:14](x)
172 | out4 = x
173 | x = self.features[14:18](x)
174 | out5 = x
175 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
176 | # x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
177 | # x = self.classifier(x)
178 | return out1, out2, out3, out4, out5
179 |
180 | def forward(self, x):
181 | return self._forward_impl(x)
182 |
183 |
184 | def mobilenet_v2(pretrained=True, **kwargs):
185 | """
186 | Constructs a MobileNetV2 architecture from
187 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | progress (bool): If True, displays a progress bar of the download to stderr
192 | """
193 | model = MobileNetV2(**kwargs)
194 | # if pretrained:
195 | # print('v2 pretrained loading....')
196 | # state_dict = model_zoo.load_url(model_urls['mobilenet_v2'])
197 | # model.load_state_dict(state_dict)
198 | if pretrained:
199 | pretrained_vgg = model_zoo.load_url(model_urls['mobilenet_v2'])
200 | model_dict = {}
201 | state_dict = model.state_dict()
202 | for k, v in pretrained_vgg.items():
203 | if k in state_dict:
204 | model_dict[k] = v
205 | # print(k, v)
206 |
207 | state_dict.update(model_dict)
208 | model.load_state_dict(state_dict)
209 | return model
210 |
211 |
212 | if __name__=='__main__':
213 |
214 | # model = ghost_net()
215 | # model.eval()
216 | import torch
217 | model = mobilenet_v2()
218 | rgb = torch.randn(1, 3, 224, 224)
219 | out = model(rgb)
220 | for i in out:
221 | print(i.shape)
--------------------------------------------------------------------------------
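
Running the `__main__` check above with a 1 x 3 x 224 x 224 input should print the following five shapes, matching the channel widths (16/24/32/96/320) that the decoder and AFD modules in LSNet.py consume:

    torch.Size([1, 16, 112, 112])
    torch.Size([1, 24, 56, 56])
    torch.Size([1, 32, 28, 28])
    torch.Size([1, 96, 14, 14])
    torch.Size([1, 320, 7, 7])
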
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.14.0
2 | addict==2.4.0
3 | alabaster @ file:///home/ktietz/src/ci/alabaster_1611921544520/work
4 | anaconda-client @ file:///tmp/build/80754af9/anaconda-client_1624473988214/work
5 | anaconda-navigator==2.0.3
6 | anaconda-project @ file:///tmp/build/80754af9/anaconda-project_1621348054992/work
7 | anyio @ file:///tmp/build/80754af9/anyio_1617783275907/work/dist
8 | appdirs==1.4.4
9 | argh==0.26.2
10 | argon2-cffi @ file:///tmp/build/80754af9/argon2-cffi_1613037097816/work
11 | asn1crypto @ file:///tmp/build/80754af9/asn1crypto_1596577642040/work
12 | astroid @ file:///tmp/build/80754af9/astroid_1625075819965/work
13 | astropy @ file:///tmp/build/80754af9/astropy_1617745353437/work
14 | asttokens==2.0.5
15 | async-generator @ file:///home/ktietz/src/ci/async_generator_1611927993394/work
16 | atomicwrites==1.4.0
17 | attrs @ file:///tmp/build/80754af9/attrs_1620827162558/work
18 | autopep8 @ file:///tmp/build/80754af9/autopep8_1615918855173/work
19 | Babel @ file:///tmp/build/80754af9/babel_1620871417480/work
20 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
21 | backports.functools-lru-cache @ file:///tmp/build/80754af9/backports.functools_lru_cache_1618170165463/work
22 | backports.shutil-get-terminal-size @ file:///tmp/build/80754af9/backports.shutil_get_terminal_size_1608222128777/work
23 | backports.tempfile @ file:///home/linux1/recipes/ci/backports.tempfile_1610991236607/work
24 | backports.weakref==1.0.post1
25 | beautifulsoup4 @ file:///home/linux1/recipes/ci/beautifulsoup4_1610988766420/work
26 | bitarray @ file:///tmp/build/80754af9/bitarray_1620827551536/work
27 | bkcharts==0.2
28 | black==19.10b0
29 | bleach @ file:///tmp/build/80754af9/bleach_1612211392645/work
30 | bokeh @ file:///tmp/build/80754af9/bokeh_1620779595936/work
31 | boto==2.49.0
32 | Bottleneck==1.3.2
33 | brotlipy==0.7.0
34 | cachetools==4.2.2
35 | certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi
36 | cffi @ file:///tmp/build/80754af9/cffi_1613246945912/work
37 | chardet @ file:///tmp/build/80754af9/chardet_1607706746162/work
38 | click @ file:///tmp/build/80754af9/click_1621604852318/work
39 | cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work
40 | clyent==1.2.2
41 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
42 | conda==4.11.0
43 | conda-build==3.21.4
44 | conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work
45 | conda-pack @ file:///tmp/build/80754af9/conda-pack_1611163042455/work
46 | conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1618262148928/work
47 | conda-repo-cli @ file:///tmp/build/80754af9/conda-repo-cli_1620168426516/work
48 | conda-token @ file:///tmp/build/80754af9/conda-token_1620076980546/work
49 | conda-verify==3.4.2
50 | contextlib2==0.6.0.post1
51 | cryptography @ file:///tmp/build/80754af9/cryptography_1616769286105/work
52 | cycler==0.10.0
53 | Cython @ file:///tmp/build/80754af9/cython_1618435160151/work
54 | cytoolz==0.11.0
55 | dask @ file:///tmp/build/80754af9/dask-core_1624381970968/work
56 | dataclasses==0.6
57 | decorator @ file:///home/ktietz/src/ci/decorator_1611930055503/work
58 | defusedxml @ file:///tmp/build/80754af9/defusedxml_1615228127516/work
59 | diff-match-patch @ file:///tmp/build/80754af9/diff-match-patch_1594828741838/work
60 | distributed @ file:///tmp/build/80754af9/distributed_1624589265858/work
61 | docutils @ file:///tmp/build/80754af9/docutils_1620827984873/work
62 | dtaidistance==2.3.2
63 | easydict==1.9
64 | einops==0.3.2
65 | entrypoints==0.3
66 | et-xmlfile==1.1.0
67 | fastcache==1.1.0
68 | filelock @ file:///home/linux1/recipes/ci/filelock_1610993975404/work
69 | flake8 @ file:///tmp/build/80754af9/flake8_1615834841867/work
70 | Flask @ file:///home/ktietz/src/ci/flask_1611932660458/work
71 | fsspec @ file:///tmp/build/80754af9/fsspec_1623705546600/work
72 | future==0.18.2
73 | gevent @ file:///tmp/build/80754af9/gevent_1616770671827/work
74 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
75 | gmpy2==2.0.8
76 | google-auth==1.35.0
77 | google-auth-oauthlib==0.4.6
78 | greenlet @ file:///tmp/build/80754af9/greenlet_1620913319000/work
79 | grpcio==1.40.0
80 | h5py==2.10.0
81 | HeapDict==1.0.1
82 | html5lib @ file:///tmp/build/80754af9/html5lib_1593446221756/work
83 | idna @ file:///home/linux1/recipes/ci/idna_1610986105248/work
84 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
85 | imagesize @ file:///home/ktietz/src/ci/imagesize_1611921604382/work
86 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1617874469820/work
87 | iniconfig @ file:///home/linux1/recipes/ci/iniconfig_1610983019677/work
88 | intervaltree @ file:///tmp/build/80754af9/intervaltree_1598376443606/work
89 | ipykernel @ file:///tmp/build/80754af9/ipykernel_1596207638929/work/dist/ipykernel-5.3.4-py3-none-any.whl
90 | ipython @ file:///tmp/build/80754af9/ipython_1617120885885/work
91 | ipython-genutils @ file:///tmp/build/80754af9/ipython_genutils_1606773439826/work
92 | ipywidgets @ file:///tmp/build/80754af9/ipywidgets_1610481889018/work
93 | isort @ file:///tmp/build/80754af9/isort_1624300337312/work
94 | itsdangerous @ file:///tmp/build/80754af9/itsdangerous_1621432558163/work
95 | jdcal==1.4.1
96 | jedi @ file:///tmp/build/80754af9/jedi_1606932564285/work
97 | jeepney @ file:///tmp/build/80754af9/jeepney_1606148855031/work
98 | Jinja2 @ file:///tmp/build/80754af9/jinja2_1612213139570/work
99 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work
100 | json5 @ file:///tmp/build/80754af9/json5_1624432770122/work
101 | jsonschema @ file:///tmp/build/80754af9/jsonschema_1602607155483/work
102 | jupyter==1.0.0
103 | jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1616770841739/work
104 | jupyter-console @ file:///tmp/build/80754af9/jupyter_console_1616615302928/work
105 | jupyter-core @ file:///tmp/build/80754af9/jupyter_core_1612213311222/work
106 | jupyter-packaging @ file:///tmp/build/80754af9/jupyter-packaging_1613502826984/work
107 | jupyter-server @ file:///tmp/build/80754af9/jupyter_server_1616083640759/work
108 | jupyterlab @ file:///tmp/build/80754af9/jupyterlab_1619133235951/work
109 | jupyterlab-pygments @ file:///tmp/build/80754af9/jupyterlab_pygments_1601490720602/work
110 | jupyterlab-server @ file:///tmp/build/80754af9/jupyterlab_server_1617134334258/work
111 | jupyterlab-widgets @ file:///tmp/build/80754af9/jupyterlab_widgets_1609884341231/work
112 | keyring @ file:///tmp/build/80754af9/keyring_1621524402652/work
113 | kiwisolver @ file:///tmp/build/80754af9/kiwisolver_1612282420641/work
114 | lazy-object-proxy @ file:///tmp/build/80754af9/lazy-object-proxy_1616526917483/work
115 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
116 | llvmlite==0.36.0
117 | locket==0.2.1
118 | lxml @ file:///tmp/build/80754af9/lxml_1616443220220/work
119 | Markdown==3.3.4
120 | MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528148836/work
121 | matplotlib @ file:///tmp/build/80754af9/matplotlib-suite_1613407855456/work
122 | mccabe==0.6.1
123 | mindspore==1.5.1
124 | mistune==0.8.4
125 | mkl-fft==1.3.0
126 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853849286/work
127 | mkl-service==2.3.0
128 | mmcv==1.4.0
129 | mmcv-full==1.4.0
130 | mock @ file:///tmp/build/80754af9/mock_1607622725907/work
131 | more-itertools @ file:///tmp/build/80754af9/more-itertools_1622818384463/work
132 | mpmath==1.2.1
133 | msgpack @ file:///tmp/build/80754af9/msgpack-python_1612287151062/work
134 | multipledispatch==0.6.0
135 | mypy-extensions==0.4.3
136 | navigator-updater==0.2.1
137 | nbclassic @ file:///tmp/build/80754af9/nbclassic_1616085367084/work
138 | nbclient @ file:///tmp/build/80754af9/nbclient_1614364831625/work
139 | nbconvert @ file:///tmp/build/80754af9/nbconvert_1624479060632/work
140 | nbformat @ file:///tmp/build/80754af9/nbformat_1617383369282/work
141 | nest-asyncio @ file:///tmp/build/80754af9/nest-asyncio_1613680548246/work
142 | networkx @ file:///tmp/build/80754af9/networkx_1617653298338/work
143 | nltk @ file:///tmp/build/80754af9/nltk_1621347441292/work
144 | nose @ file:///tmp/build/80754af9/nose_1606773131901/work
145 | notebook @ file:///tmp/build/80754af9/notebook_1621528346532/work
146 | numba @ file:///tmp/build/80754af9/numba_1616774046117/work
147 | numexpr @ file:///tmp/build/80754af9/numexpr_1618856167419/work
148 | numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620830962040/work
149 | numpydoc @ file:///tmp/build/80754af9/numpydoc_1605117425582/work
150 | nvidia-cublas-cu11==11.10.3.66
151 | nvidia-cuda-runtime-cu11==11.8.89
152 | nvidia-cudnn-cu11==8.6.0.163
153 | nvidia-pyindex==1.0.9
154 | nvidia-tensorrt==8.4.1.5
155 | oauthlib==3.1.1
156 | olefile==0.46
157 | opencv-python==4.5.2.54
158 | openpyxl @ file:///tmp/build/80754af9/openpyxl_1615411699337/work
159 | packaging @ file:///tmp/build/80754af9/packaging_1611952188834/work
160 | pandas==1.2.5
161 | pandocfilters @ file:///tmp/build/80754af9/pandocfilters_1605120460739/work
162 | parso==0.7.0
163 | partd @ file:///tmp/build/80754af9/partd_1618000087440/work
164 | path @ file:///tmp/build/80754af9/path_1623603875173/work
165 | pathlib2 @ file:///tmp/build/80754af9/pathlib2_1625585678054/work
166 | pathspec==0.7.0
167 | pathtools==0.1.2
168 | patsy==0.5.1
169 | pep8==1.7.1
170 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
171 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
172 | Pillow @ file:///tmp/build/80754af9/pillow_1617383569452/work
173 | pkginfo==1.7.0
174 | pluggy @ file:///tmp/build/80754af9/pluggy_1615976321666/work
175 | ply==3.11
176 | prometheus-client @ file:///tmp/build/80754af9/prometheus_client_1623189609245/work
177 | prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1616415428029/work
178 | protobuf==3.17.3
179 | psutil @ file:///tmp/build/80754af9/psutil_1612298023621/work
180 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
181 | py @ file:///tmp/build/80754af9/py_1607971587848/work
182 | pyasn1==0.4.8
183 | pyasn1-modules==0.2.8
184 | pycodestyle @ file:///home/ktietz/src/ci_mi/pycodestyle_1612807597675/work
185 | pycosat==0.6.3
186 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work
187 | pycurl==7.43.0.6
188 | pydocstyle @ file:///tmp/build/80754af9/pydocstyle_1621600989141/work
189 | pyerfa @ file:///tmp/build/80754af9/pyerfa_1621560806183/work
190 | pyflakes @ file:///home/ktietz/src/ci_ipy2/pyflakes_1612551159640/work
191 | Pygments @ file:///tmp/build/80754af9/pygments_1621606182707/work
192 | pylint @ file:///tmp/build/80754af9/pylint_1625158820537/work
193 | pyls-black @ file:///tmp/build/80754af9/pyls-black_1607553132291/work
194 | pyls-spyder @ file:///tmp/build/80754af9/pyls-spyder_1613849700860/work
195 | pyodbc===4.0.0-unsupported
196 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work
197 | pyparsing @ file:///home/linux1/recipes/ci/pyparsing_1610983426697/work
198 | pyrsistent @ file:///tmp/build/80754af9/pyrsistent_1600141720057/work
199 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
200 | pytest==6.2.4
201 | python-dateutil @ file:///home/ktietz/src/ci/python-dateutil_1611928101742/work
202 | python-jsonrpc-server @ file:///tmp/build/80754af9/python-jsonrpc-server_1600278539111/work
203 | python-language-server @ file:///tmp/build/80754af9/python-language-server_1607972495879/work
204 | pytorch-wavelets @ file:///home/sunfan/Downloads/qq-files/3101347528/file_recv/fft_conv/pytorch_wavelets
205 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
206 | PyWavelets @ file:///tmp/build/80754af9/pywavelets_1601658317819/work
207 | pyxdg @ file:///tmp/build/80754af9/pyxdg_1603822279816/work
208 | PyYAML==5.4.1
209 | pyzmq==20.0.0
210 | QDarkStyle==2.8.1
211 | QtAwesome @ file:///tmp/build/80754af9/qtawesome_1615991616277/work
212 | qtconsole @ file:///tmp/build/80754af9/qtconsole_1623278325812/work
213 | QtPy==1.9.0
214 | regex @ file:///tmp/build/80754af9/regex_1617569202463/work
215 | requests @ file:///tmp/build/80754af9/requests_1608241421344/work
216 | requests-oauthlib==1.3.0
217 | rope @ file:///tmp/build/80754af9/rope_1623703006312/work
218 | rsa==4.7.2
219 | Rtree @ file:///tmp/build/80754af9/rtree_1618420845272/work
220 | ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016699510/work
221 | scikit-image @ file:///tmp/build/80754af9/scikit-image_1648196304918/work
222 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1621370412049/work
223 | scipy @ file:///tmp/build/80754af9/scipy_1618855647378/work
224 | seaborn @ file:///tmp/build/80754af9/seaborn_1608578541026/work
225 | SecretStorage @ file:///tmp/build/80754af9/secretstorage_1614022784285/work
226 | Send2Trash @ file:///tmp/build/80754af9/send2trash_1607525499227/work
227 | simplegeneric==0.8.1
228 | singledispatch @ file:///tmp/build/80754af9/singledispatch_1623948242478/work
229 | sip==4.19.13
230 | six @ file:///tmp/build/80754af9/six_1623709665295/work
231 | sniffio @ file:///tmp/build/80754af9/sniffio_1614030475067/work
232 | snowballstemmer @ file:///tmp/build/80754af9/snowballstemmer_1611258885636/work
233 | sortedcollections @ file:///tmp/build/80754af9/sortedcollections_1611172717284/work
234 | sortedcontainers @ file:///tmp/build/80754af9/sortedcontainers_1623949099177/work
235 | soupsieve @ file:///tmp/build/80754af9/soupsieve_1616183228191/work
236 | Sphinx @ file:///tmp/build/80754af9/sphinx_1623884544367/work
237 | sphinxcontrib-applehelp @ file:///home/ktietz/src/ci/sphinxcontrib-applehelp_1611920841464/work
238 | sphinxcontrib-devhelp @ file:///home/ktietz/src/ci/sphinxcontrib-devhelp_1611920923094/work
239 | sphinxcontrib-htmlhelp @ file:///tmp/build/80754af9/sphinxcontrib-htmlhelp_1623945626792/work
240 | sphinxcontrib-jsmath @ file:///home/ktietz/src/ci/sphinxcontrib-jsmath_1611920942228/work
241 | sphinxcontrib-qthelp @ file:///home/ktietz/src/ci/sphinxcontrib-qthelp_1611921055322/work
242 | sphinxcontrib-serializinghtml @ file:///tmp/build/80754af9/sphinxcontrib-serializinghtml_1624451540180/work
243 | sphinxcontrib-websupport @ file:///tmp/build/80754af9/sphinxcontrib-websupport_1597081412696/work
244 | spyder @ file:///tmp/build/80754af9/spyder_1616775618138/work
245 | spyder-kernels @ file:///tmp/build/80754af9/spyder-kernels_1614030590686/work
246 | SQLAlchemy @ file:///tmp/build/80754af9/sqlalchemy_1624584182860/work
247 | statsmodels @ file:///tmp/build/80754af9/statsmodels_1614023746358/work
248 | sympy @ file:///tmp/build/80754af9/sympy_1618252284338/work
249 | tables==3.6.1
250 | tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work
251 | tensorboard==2.6.0
252 | tensorboard-data-server==0.6.1
253 | tensorboard-plugin-wit==1.8.0
254 | tensorboardX==2.4
255 | tensorrt==0.0.1
256 | terminado==0.9.4
257 | testpath @ file:///tmp/build/80754af9/testpath_1624638946665/work
258 | textdistance @ file:///tmp/build/80754af9/textdistance_1612461398012/work
259 | thop==0.0.31.post2005241907
260 | threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl
261 | three-merge @ file:///tmp/build/80754af9/three-merge_1607553261110/work
262 | tifffile==2020.10.1
263 | timm==0.4.12
264 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work
265 | toolz @ file:///home/linux1/recipes/ci/toolz_1610987900194/work
266 | torch==1.9.0
267 | torch-scatter==2.0.9
268 | torch-sparse==0.6.13
269 | torch2trt==0.4.0
270 | torchaudio==0.9.0a0+33b2469
271 | torchvision==0.10.0
272 | tornado @ file:///tmp/build/80754af9/tornado_1606942300299/work
273 | tqdm @ file:///tmp/build/80754af9/tqdm_1625563689033/work
274 | traitlets @ file:///home/ktietz/src/ci/traitlets_1611929699868/work
275 | ttach==0.0.3
276 | typed-ast @ file:///tmp/build/80754af9/typed-ast_1624953673417/work
277 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1624965014186/work
278 | ujson @ file:///tmp/build/80754af9/ujson_1611259522456/work
279 | unicodecsv==0.14.1
280 | urllib3 @ file:///tmp/build/80754af9/urllib3_1625084269274/work
281 | watchdog @ file:///tmp/build/80754af9/watchdog_1612471027849/work
282 | wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
283 | webencodings==0.5.1
284 | Werkzeug @ file:///home/ktietz/src/ci/werkzeug_1611932622770/work
285 | widgetsnbextension==3.5.1
286 | wrapt==1.12.1
287 | wurlitzer @ file:///tmp/build/80754af9/wurlitzer_1617224664226/work
288 | xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work
289 | XlsxWriter @ file:///tmp/build/80754af9/xlsxwriter_1625006966557/work
290 | xlwt==1.3.0
291 | xmltodict==0.12.0
292 | yapf @ file:///tmp/build/80754af9/yapf_1615749224965/work
293 | zict==2.0.0
294 | zipp @ file:///tmp/build/80754af9/zipp_1615904174917/work
295 | zope.event==4.5.0
296 | zope.interface @ file:///tmp/build/80754af9/zope.interface_1625035545636/work
297 |
--------------------------------------------------------------------------------
/rgbd_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 | import torch
9 |
10 | # several data augumentation strategies
11 | def cv_random_flip(img, label, depth):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
19 | # top bottom flip
20 | # if flip_flag2==1:
21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
24 | return img, label, depth
25 |
26 |
27 | def randomCrop(image, label, depth):
28 | border = 30
29 | image_width = image.size[0]
30 | image_height = image.size[1]
31 | crop_win_width = np.random.randint(image_width - border, image_width)
32 | crop_win_height = np.random.randint(image_height - border, image_height)
33 | random_region = (
34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
35 | (image_height + crop_win_height) >> 1)
36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region)
37 |
38 |
39 | def randomRotation(image, label, depth):
40 | mode = Image.BICUBIC
41 | if random.random() > 0.8:
42 | random_angle = np.random.randint(-15, 15)
43 | image = image.rotate(random_angle, mode)
44 | label = label.rotate(random_angle, mode)
45 | depth = depth.rotate(random_angle, mode)
46 | return image, label, depth
47 |
48 |
49 | def colorEnhance(image):
50 | bright_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
52 | contrast_intensity = random.randint(5, 15) / 10.0
53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
54 | color_intensity = random.randint(0, 20) / 10.0
55 | image = ImageEnhance.Color(image).enhance(color_intensity)
56 | sharp_intensity = random.randint(0, 30) / 10.0
57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
58 | return image
59 |
60 |
61 | def randomGaussian(image, mean=0.1, sigma=0.35):
62 | def gaussianNoisy(im, mean=mean, sigma=sigma):
63 | for _i in range(len(im)):
64 | im[_i] += random.gauss(mean, sigma)
65 | return im
66 |
67 | img = np.asarray(image)
68 | width, height = img.shape
69 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
70 | img = img.reshape([width, height])
71 | return Image.fromarray(np.uint8(img))
72 |
73 |
74 | def randomPeper(img):
75 | img = np.array(img)
76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
77 | for i in range(noiseNum):
78 |
79 | randX = random.randint(0, img.shape[0] - 1)
80 |
81 | randY = random.randint(0, img.shape[1] - 1)
82 |
83 | if random.randint(0, 1) == 0:
84 |
85 | img[randX, randY] = 0
86 |
87 | else:
88 |
89 | img[randX, randY] = 255
90 | return Image.fromarray(img)
91 |
92 |
93 | # dataset for training
94 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
96 | class SalObjDataset(data.Dataset):
97 | def __init__(self, image_root, gt_root, depth_root, trainsize):
98 | self.trainsize = trainsize
99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')
100 | or f.endswith('.png')]
101 | # print(self.images)
102 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
103 | or f.endswith('.png')]
104 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
105 | or f.endswith('.png')]
106 | self.images = sorted(self.images)
107 | self.gts = sorted(self.gts)
108 | self.depths = sorted(self.depths)
109 | self.filter_files()
110 | self.size = len(self.images)
111 | self.img_transform = transforms.Compose([
112 | transforms.Resize((self.trainsize, self.trainsize)),
113 | transforms.ToTensor(),
114 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
115 | self.gt_transform = transforms.Compose([
116 | transforms.Resize((self.trainsize, self.trainsize)),
117 | transforms.ToTensor()])
118 | self.depths_transform = transforms.Compose([
119 | transforms.Resize((self.trainsize, self.trainsize)),
120 | transforms.ToTensor(),
121 | transforms.Normalize([0.485], [0.229])
122 | ])
123 |
124 | def __getitem__(self, index):
125 | image = self.rgb_loader(self.images[index])
126 | gt = self.binary_loader(self.gts[index])
127 | depth = self.binary_loader(self.depths[index])
128 | image, gt, depth = cv_random_flip(image, gt, depth)
129 | image, gt, depth = randomCrop(image, gt, depth)
130 | image, gt, depth = randomRotation(image, gt, depth)
131 | image = colorEnhance(image)
132 | # gt=randomGaussian(gt)
133 | gt = randomPeper(gt)
134 | # image, gt, depth = self.resize(image,gt, depth)
135 | image = self.img_transform(image)
136 | gt = self.gt_transform(gt)
137 | depth = self.depths_transform(depth)
138 | # depth = torch.div(depth.float(),255.0) # DUT
139 |
140 | return image, gt, depth
141 |
142 | def filter_files(self):
143 |         assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
144 | # print(len(self.images),len(self.gts),len(self.depths))
145 | images = []
146 | gts = []
147 | depths = []
148 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
149 | img = Image.open(img_path)
150 | gt = Image.open(gt_path)
151 | depth = Image.open(depth_path)
152 | if img.size == gt.size and gt.size == depth.size:
153 | # if img.size == gt.size:
154 | images.append(img_path)
155 | gts.append(gt_path)
156 | depths.append(depth_path)
157 | self.images = images
158 | self.gts = gts
159 | self.depths = depths
160 |
161 | def rgb_loader(self, path):
162 | # print(path)
163 | with open(path, 'rb') as f:
164 | img = Image.open(f)
165 | # print(img)
166 | return img.convert('RGB')
167 |
168 | def binary_loader(self, path):
169 | with open(path, 'rb') as f:
170 | img = Image.open(f)
171 | return img.convert('L')
172 |
173 | def resize(self, img, gt, depth):
174 | assert img.size == gt.size and gt.size == depth.size
175 | h = self.trainsize
176 | w = self.trainsize
177 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
178 | Image.NEAREST)
179 |
180 |
181 | def __len__(self):
182 | return self.size
183 |
184 |
185 | # dataloader for training
186 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True):
187 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize)
188 | data_loader = data.DataLoader(dataset=dataset,
189 | batch_size=batchsize,
190 | shuffle=shuffle,
191 | num_workers=num_workers,
192 | pin_memory=pin_memory)
193 | return data_loader
194 |
195 |
196 | # test dataset and loader
197 | class test_dataset:
198 | def __init__(self, image_root, gt_root, depth_root, testsize):
199 | self.testsize = testsize
200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
201 |
202 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
203 | or f.endswith('.png')]
204 |
205 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
206 | or f.endswith('.png')]
207 |
208 | self.images = sorted(self.images)
209 | # print(self.images)
210 | self.gts = sorted(self.gts)
211 | # print(self.gts)
212 | self.depths = sorted(self.depths)
213 | # print(self.depths)
214 | self.transform = transforms.Compose([
215 | transforms.Resize((self.testsize, self.testsize)),
216 | transforms.ToTensor(),
217 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
218 | # self.gt_transform = transforms.ToTensor()
219 | self.gt_transform = transforms.Compose([
220 | transforms.Resize((self.testsize, self.testsize)),
221 | transforms.ToTensor()])
222 | self.depths_transform = transforms.Compose([
223 | transforms.Resize((self.testsize, self.testsize)),
224 | transforms.ToTensor(),
225 | transforms.Normalize([0.485], [0.229])
226 | ])
227 | self.size = len(self.images)
228 | self.index = 0
229 |
230 | def load_data(self):
231 | image = self.rgb_loader(self.images[self.index])
232 | gt = self.binary_loader(self.gts[self.index])
233 | depth = self.binary_loader(self.depths[self.index])
234 | # image, gt, depth = self.resize(image, gt, depth)
235 | image = self.transform(image).unsqueeze(0)
236 | gt = self.gt_transform(gt).unsqueeze(0)
237 | depth = self.depths_transform(depth)
238 | # depth = torch.div(depth.float(), 255.0) # DUT
239 | depth = depth.unsqueeze(0)
240 | name = self.images[self.index].split('/')[-1]
241 | if name.endswith('.jpg'):
242 | name = name.split('.jpg')[0] + '.png'
243 | self.index += 1
244 | self.index = self.index % self.size
245 | return image, gt, depth, name
246 |
247 | def rgb_loader(self, path):
248 | with open(path, 'rb') as f:
249 | img = Image.open(f)
250 | return img.convert('RGB')
251 |
252 | def binary_loader(self, path):
253 | with open(path, 'rb') as f:
254 | img = Image.open(f)
255 | return img.convert('L')
256 |
257 | def resize(self, img, gt, depth):
258 | # assert img.size == gt.size and gt.size == depth.size
259 | h = self.testsize
260 | w = self.testsize
261 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
262 | Image.NEAREST)
263 |
264 | def __len__(self):
265 | return self.size
266 |
--------------------------------------------------------------------------------
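
A minimal sketch of how these RGB-D loaders are consumed (directory names follow the RGBD layout used in train.py and test.py; the paths are placeholders):

    from rgbd_dataset import get_loader, test_dataset

    # training loader: yields (image, gt, depth) batches of B x 3/1/1 x 224 x 224
    train_loader = get_loader('./Train/RGB/', './Train/GT/', './Train/depth/',
                              batchsize=10, trainsize=224)
    for image, gt, depth in train_loader:
        pass

    # test loader: load_data() returns one sample at a time with a leading batch dim of 1
    test_loader = test_dataset('./Test/RGB/', './Test/GT/', './Test/depth/', testsize=224)
    for _ in range(test_loader.size):
        image, gt, depth, name = test_loader.load_data()
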
/rgbt_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augumentation strategies
11 | def cv_random_flip(img, label, ti):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | ti = ti.transpose(Image.FLIP_LEFT_RIGHT)
19 | # top bottom flip
20 | # if flip_flag2==1:
21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
23 | # ti = ti.transpose(Image.FLIP_TOP_BOTTOM)
24 | return img, label, ti
25 |
26 |
27 | def randomCrop(image, label, ti):
28 | border = 30
29 | image_width = image.size[0]
30 | image_height = image.size[1]
31 | crop_win_width = np.random.randint(image_width - border, image_width)
32 | crop_win_height = np.random.randint(image_height - border, image_height)
33 | random_region = (
34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
35 | (image_height + crop_win_height) >> 1)
36 | return image.crop(random_region), label.crop(random_region), ti.crop(random_region)
37 |
38 |
39 | def randomRotation(image, label, ti):
40 | mode = Image.BICUBIC
41 | if random.random() > 0.8:
42 | random_angle = np.random.randint(-15, 15)
43 | image = image.rotate(random_angle, mode)
44 | label = label.rotate(random_angle, mode)
45 | ti = ti.rotate(random_angle, mode)
46 | return image, label, ti
47 |
48 |
49 | def colorEnhance(image):
50 | bright_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
52 | contrast_intensity = random.randint(5, 15) / 10.0
53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
54 | color_intensity = random.randint(0, 20) / 10.0
55 | image = ImageEnhance.Color(image).enhance(color_intensity)
56 | sharp_intensity = random.randint(0, 30) / 10.0
57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
58 | return image
59 |
60 |
61 | def randomGaussian(image, mean=0.1, sigma=0.35):
62 | def gaussianNoisy(im, mean=mean, sigma=sigma):
63 | for _i in range(len(im)):
64 | im[_i] += random.gauss(mean, sigma)
65 | return im
66 |
67 | img = np.asarray(image)
68 | width, height = img.shape
69 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
70 | img = img.reshape([width, height])
71 | return Image.fromarray(np.uint8(img))
72 |
73 |
74 | def randomPeper(img):
75 | img = np.array(img)
76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
77 | for i in range(noiseNum):
78 |
79 | randX = random.randint(0, img.shape[0] - 1)
80 |
81 | randY = random.randint(0, img.shape[1] - 1)
82 |
83 | if random.randint(0, 1) == 0:
84 |
85 | img[randX, randY] = 0
86 |
87 | else:
88 |
89 | img[randX, randY] = 255
90 | return Image.fromarray(img)
91 |
92 |
93 | # dataset for training
94 | # The current loader is not using the normalized ti maps for training and test. If you use the normalized ti maps
95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
96 | class SalObjDataset(data.Dataset):
97 | def __init__(self, image_root, gt_root, ti_root, trainsize):
98 | self.trainsize = trainsize
99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
100 | # print(self.images)
101 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
102 | or f.endswith('.png')]
103 |
104 | self.tis = [ti_root + f for f in os.listdir(ti_root) if f.endswith('.jpg')
105 | or f.endswith('.png')]
106 |
107 | self.images = sorted(self.images)
108 | self.gts = sorted(self.gts)
109 | self.tis = sorted(self.tis)
110 |
111 | self.filter_files()
112 | self.size = len(self.images)
113 | self.img_transform = transforms.Compose([
114 | transforms.Resize((self.trainsize, self.trainsize)),
115 | transforms.ToTensor(),
116 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
117 | self.gt_transform = transforms.Compose([
118 | transforms.Resize((self.trainsize, self.trainsize)),
119 | transforms.ToTensor()])
120 | self.tis_transform = transforms.Compose([
121 | transforms.Resize((self.trainsize, self.trainsize)),
122 | transforms.ToTensor(),
123 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
124 |
125 |
126 | def __getitem__(self, index):
127 | image = self.rgb_loader(self.images[index])
128 | gt = self.binary_loader(self.gts[index])
129 | ti = self.rgb_loader(self.tis[index])
130 | image, gt, ti = cv_random_flip(image, gt, ti)
131 | image, gt, ti = randomCrop(image, gt, ti)
132 | image, gt, ti = randomRotation(image, gt, ti)
133 | image = colorEnhance(image)
134 | # gt=randomGaussian(gt)
135 | gt = randomPeper(gt)
136 | image = self.img_transform(image)
137 | gt = self.gt_transform(gt)
138 | ti = self.tis_transform(ti)
139 | return image, gt, ti
140 |
141 | def filter_files(self):
142 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.tis)
143 | images = []
144 | gts = []
145 | tis = []
146 | for img_path, gt_path, ti_path in zip(self.images, self.gts, self.tis):
147 | img = Image.open(img_path)
148 | gt = Image.open(gt_path)
149 | ti = Image.open(ti_path)
150 | if img.size == gt.size and gt.size == ti.size:
151 | images.append(img_path)
152 | gts.append(gt_path)
153 | tis.append(ti_path)
154 | self.images = images
155 | self.gts = gts
156 | self.tis = tis
157 |
158 | def rgb_loader(self, path):
159 | with open(path, 'rb') as f:
160 | img = Image.open(f)
161 | return img.convert('RGB')
162 |
163 | def binary_loader(self, path):
164 | with open(path, 'rb') as f:
165 | img = Image.open(f)
166 | return img.convert('L')
167 |
168 | def resize(self, img, gt, ti):
169 | assert img.size == gt.size and gt.size == ti.size
170 | w, h = img.size
171 | if h < self.trainsize or w < self.trainsize:
172 | h = max(h, self.trainsize)
173 | w = max(w, self.trainsize)
174 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), ti.resize((w, h),
175 | Image.NEAREST)
176 | else:
177 | return img, gt, ti
178 |
179 | def __len__(self):
180 | return self.size
181 |
182 |
183 | # dataloader for training
184 | def get_loader(image_root, gt_root, ti_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=False):
185 | dataset = SalObjDataset(image_root, gt_root, ti_root, trainsize)
186 |
187 | data_loader = data.DataLoader(dataset=dataset,
188 | batch_size=batchsize,
189 | shuffle=shuffle,
190 | num_workers=num_workers,
191 | pin_memory=pin_memory)
192 | # print(len(data_loader))
193 | return data_loader
194 |
195 |
196 | # test dataset and loader
197 | class test_dataset:
198 | def __init__(self, image_root, gt_root, ti_root,testsize):
199 | self.testsize = testsize
200 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
201 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
202 | or f.endswith('.png')]
203 | self.tis = [ti_root + f for f in os.listdir(ti_root) if f.endswith('.jpg')
204 | or f.endswith('.png')]
205 |
206 | self.images = sorted(self.images)
207 | self.gts = sorted(self.gts)
208 | self.tis = sorted(self.tis)
209 | self.transform = transforms.Compose([
210 | transforms.Resize((self.testsize, self.testsize)),
211 | transforms.ToTensor(),
212 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
213 | # self.gt_transform = transforms.ToTensor()
214 | self.gt_transform = transforms.Compose([
215 | transforms.Resize((self.testsize, self.testsize)),
216 | transforms.ToTensor()])
217 | self.tis_transform = transforms.Compose([
218 | transforms.Resize((self.testsize, self.testsize)),
219 | transforms.ToTensor(),
220 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
221 | self.size = len(self.images)
222 | self.index = 0
223 |
224 | def load_data(self):
225 | image = self.rgb_loader(self.images[self.index])
226 | image = self.transform(image).unsqueeze(0)
227 | gt = self.binary_loader(self.gts[self.index])
228 | gt = self.gt_transform(gt).unsqueeze(0)
229 | ti = self.rgb_loader(self.tis[self.index])
230 | ti = self.tis_transform(ti).unsqueeze(0)
231 |
232 | name = self.images[self.index].split('/')[-1]
233 | if name.endswith('.jpg'):
234 | name = name.split('.jpg')[0] + '.png'
235 | self.index += 1
236 | self.index = self.index % self.size
237 | return image, gt, ti,name
238 |
239 | def rgb_loader(self, path):
240 | with open(path, 'rb') as f:
241 | img = Image.open(f)
242 | return img.convert('RGB')
243 |
244 | def binary_loader(self, path):
245 | with open(path, 'rb') as f:
246 | img = Image.open(f)
247 | return img.convert('L')
248 |
249 | def __len__(self):
250 | return self.size
251 |
252 |
--------------------------------------------------------------------------------
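
The RGB-T variant differs from rgbd_dataset.py mainly in that thermal images are read with `rgb_loader` and normalized as 3-channel ImageNet inputs, so test.py only needs its channel-replication step (`torch.cat((ti, ti, ti), dim=1)`) for the RGBD task. A short sketch with placeholder paths:

    from rgbt_dataset import test_dataset

    loader = test_dataset('./VT1000/RGB/', './VT1000/GT/', './VT1000/T/', testsize=224)
    image, gt, ti, name = loader.load_data()
    # image and ti: 1 x 3 x 224 x 224; gt: 1 x 1 x 224 x 224; name: output file name (.png)
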
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import sys
4 | sys.path.append('./models')
5 | import numpy as np
6 | import os
7 | import cv2
8 |
9 | from LSNet import LSNet
10 | from config import opt
11 |
12 |
13 |
14 | dataset_path = opt.test_path
15 |
16 | #set device for test
17 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
18 | print('USE GPU:', opt.gpu_id)
19 |
20 | #load the model
21 | model = LSNet()
22 |
23 | # A large epoch count may not generalize well. Choose a good checkpoint to load according to the log file and the .pth files saved under `save_path` during training.
24 | model.load_state_dict(torch.load(''))
25 | model.cuda()
26 | model.eval()
27 |
28 |
29 | #test
30 | test_mae = []
31 | if opt.task =='RGBT':
32 | from rgbt_dataset import test_dataset
33 | test_datasets = ['VT800','VT1000','VT5000']
34 | elif opt.task == 'RGBD':
35 | from rgbd_dataset import test_dataset
36 | test_datasets = ['NJU2K', 'DES', 'LFSD', 'NLPR', 'SIP']
37 | else:
38 | raise ValueError(f"Unknown task type {opt.task}")
39 |
40 | for dataset in test_datasets:
41 | mae_sum = 0
42 | save_path = '/' + dataset + '/'
43 | if not os.path.exists(save_path):
44 | os.makedirs(save_path)
45 | if opt.task == 'RGBT':
46 | image_root = dataset_path + dataset + '/RGB/'
47 | gt_root = dataset_path + dataset + '/GT/'
48 |         ti_root = dataset_path + dataset + '/T/'
49 | elif opt.task == 'RGBD':
50 | image_root = dataset_path + dataset + '/RGB/'
51 | gt_root = dataset_path + dataset + '/GT/'
52 | ti_root = dataset_path + dataset + '/depth/'
53 | else:
54 | raise ValueError(f"Unknown task type {opt.task}")
55 | test_loader = test_dataset(image_root, gt_root, ti_root, opt.testsize)
56 | for i in range(test_loader.size):
57 | image, gt, ti, name = test_loader.load_data()
58 | gt = gt.cuda()
59 | image = image.cuda()
60 | ti = ti.cuda()
61 | if opt.task == 'RGBD':
62 |             ti = torch.cat((ti, ti, ti), dim=1)
63 |         res = model(image, ti)
64 | predict = torch.sigmoid(res)
65 | predict = (predict - predict.min()) / (predict.max() - predict.min() + 1e-8)
66 | mae = torch.sum(torch.abs(predict - gt)) / torch.numel(gt)
67 | # mae = torch.abs(predict - gt).mean()
68 | mae_sum = mae.item() + mae_sum
69 | predict = predict.data.cpu().numpy().squeeze()
70 | # print(predict.shape)
71 |         print('save img to: ', save_path + name)
72 |         cv2.imwrite(save_path + name, predict * 255)
73 | test_mae.append(mae_sum / test_loader.size)
74 | print('Test Done!', 'MAE', test_mae)
75 |
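# Minimal sketch (added for illustration, not invoked by the script above) of the
# per-image metric the loop accumulates: predictions are passed through a sigmoid,
# min-max normalised to [0, 1], and compared to the ground-truth mask with a mean
# absolute error. The helper name is hypothetical.
def mae_metric(pred_logits, gt):
    pred = torch.sigmoid(pred_logits)
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    return torch.mean(torch.abs(pred - gt)).item()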
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 | from datetime import datetime
7 | from torchvision.utils import make_grid
8 | from utils import clip_gradient, adjust_lr
9 | from tensorboardX import SummaryWriter
10 | import logging
11 | import torch.backends.cudnn as cudnn
12 | from config import opt
13 | from torch.cuda import amp
14 | # set the device for training
15 | cudnn.benchmark = True
16 | cudnn.enabled = True
17 |
18 |
19 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
20 | print('USE GPU:', opt.gpu_id)
21 |
22 | # build the model
23 | from LSNet import LSNet
24 | model = LSNet()
25 | if opt.load is not None:
26 |     model.load_state_dict(torch.load(opt.load))
27 |     print('load model from', opt.load)
28 | model.cuda()
29 | params = model.parameters()
30 | optimizer = torch.optim.Adam(params, opt.lr)
31 |
32 | # set the path
33 | train_dataset_path = opt.train_root
34 |
35 | val_dataset_path = opt.val_root
36 |
37 | save_path = opt.save_path
38 |
39 | if not os.path.exists(save_path):
40 | os.makedirs(save_path)
41 |
42 | # load data
43 | print('load data...')
44 | if opt.task == 'RGBT':
45 | from rgbt_dataset import get_loader, test_dataset
46 | image_root = train_dataset_path + '/RGB/'
47 | ti_root = train_dataset_path + '/T/'
48 | gt_root = train_dataset_path + '/GT/'
49 | val_image_root = val_dataset_path + '/RGB/'
50 | val_ti_root = val_dataset_path + '/T/'
51 | val_gt_root = val_dataset_path + '/GT/'
52 | elif opt.task == 'RGBD':
53 | from rgbd_dataset import get_loader, test_dataset
54 | image_root = train_dataset_path + '/RGB/'
55 | ti_root = train_dataset_path + '/depth/'
56 | gt_root = train_dataset_path + '/GT/'
57 | val_image_root = val_dataset_path + '/RGB/'
58 | val_ti_root = val_dataset_path + '/depth/'
59 | val_gt_root = val_dataset_path + '/GT/'
60 | else:
61 | raise ValueError(f"Unknown task type {opt.task}")
62 |
63 | train_loader = get_loader(image_root, gt_root, ti_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
64 | test_loader = test_dataset(val_image_root, val_gt_root, val_ti_root, opt.trainsize)
65 | total_step = len(train_loader)
66 | # print(total_step)
67 |
68 | logging.basicConfig(filename=save_path + 'log.log', format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]',
69 | level=logging.INFO, filemode='a', datefmt='%Y-%m-%d %I:%M:%S %p')
70 | logging.info("Model:")
71 | logging.info(model)
72 |
73 | logging.info(save_path + "Train")
74 | logging.info("Config")
75 | logging.info(
76 | 'epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{};decay_epoch:{}'.format(
77 | opt.epoch, opt.lr, opt.batchsize, opt.trainsize, opt.clip, opt.decay_rate, opt.load, save_path,
78 | opt.decay_epoch))
79 |
80 | # set loss function
81 | import torch.nn as nn
82 |
83 | class IOUBCE_loss(nn.Module):
84 |     def __init__(self):
85 |         super(IOUBCE_loss, self).__init__()
86 |         self.bce_loss = nn.BCEWithLogitsLoss()
87 |
88 |     def forward(self, input_scale, target_scale):
89 |         b, _, _, _ = input_scale.size()
90 |         loss = []
91 |         for inputs, targets in zip(input_scale, target_scale):
92 |             bce = self.bce_loss(inputs, targets)
93 |             pred = torch.sigmoid(inputs)
94 |             inter = (pred * targets).sum(dim=(1, 2))
95 |             union = (pred + targets).sum(dim=(1, 2))
96 |             IOU = (inter + 1) / (union - inter + 1)
97 |             loss.append(1 - IOU + bce)
98 |         total_loss = sum(loss)
99 |         return total_loss / b
100 |
101 |
102 | CE = torch.nn.BCEWithLogitsLoss().cuda()
103 | IOUBCE = IOUBCE_loss().cuda()
104 | class IOUBCEWithoutLogits_loss(nn.Module):
105 |     def __init__(self):
106 |         super(IOUBCEWithoutLogits_loss, self).__init__()
107 |         self.bce_loss = nn.BCELoss()
108 |
109 |     def forward(self, input_scale, target_scale):
110 |         b, _, _, _ = input_scale.size()
111 |         loss = []
112 |         for inputs, targets in zip(input_scale, target_scale):
113 |
114 |             bce = self.bce_loss(inputs, targets)
115 |
116 |             inter = (inputs * targets).sum(dim=(1, 2))
117 |             union = (inputs + targets).sum(dim=(1, 2))
118 |             IOU = (inter + 1) / (union - inter + 1)
119 |             loss.append(1 - IOU + bce)
120 |         total_loss = sum(loss)
121 |         return total_loss / b
122 | IOUBCEWithoutLogits = IOUBCEWithoutLogits_loss().cuda()
123 |
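# Note (added for clarity): both loss classes above add a soft IoU term to binary
# cross-entropy for each sample,
#     loss_i = BCE_i + 1 - (|P_i * G_i| + 1) / (|P_i| + |G_i| - |P_i * G_i| + 1),
# and average over the batch. IOUBCE_loss takes raw logits (BCEWithLogitsLoss),
# while IOUBCEWithoutLogits_loss takes probabilities already in [0, 1]; in train()
# below it is applied to boundary maps computed from the sigmoid-ed predictions.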
124 |
125 | step = 0
126 | writer = SummaryWriter(save_path + 'summary', flush_secs = 30)
127 | best_mae = 1
128 | best_epoch = 0
129 | scaler = amp.GradScaler()
130 |
131 | # BBA
132 | def tensor_bound(img, ksize):
133 |
134 |     '''
135 |     :param img: tensor, B * C * H * W
136 |     :param ksize: int, side length of the square sliding window
137 |     intermediate patches: tensor, B * C * H * W * ksize * ksize
138 |     :return: tensor, (inflation - corrosion), i.e. dilation minus erosion, B * C * H * W
139 |     '''
140 |
141 |     B, C, H, W = img.shape
142 |     pad = int((ksize - 1) // 2)
143 |     img_pad = F.pad(img, pad=[pad, pad, pad, pad], mode='constant', value=0)
144 | # unfold in the second and third dimensions
145 | patches = img_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
146 | corrosion, _ = torch.min(patches.contiguous().view(B, C, H, W, -1), dim=-1)
147 | inflation, _ = torch.max(patches.contiguous().view(B, C, H, W, -1), dim=-1)
148 | return inflation - corrosion
149 |
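# A functionally equivalent sketch (added for illustration, not used anywhere in
# this script): the unfold-based tensor_bound above computes a morphological
# gradient on the zero-padded map, i.e. dilation (window max) minus erosion
# (window min). The same result can be obtained with max pooling; the helper
# name below is hypothetical.
def tensor_bound_pool(img, ksize):
    pad = int((ksize - 1) // 2)
    img_pad = F.pad(img, pad=[pad, pad, pad, pad], mode='constant', value=0)
    inflation = F.max_pool2d(img_pad, ksize, stride=1)    # dilation: per-window max
    corrosion = -F.max_pool2d(-img_pad, ksize, stride=1)  # erosion: per-window min
    return inflation - corrosion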
150 |
151 |
152 | # train function
153 | def train(train_loader, model, optimizer, epoch, save_path):
154 | global step
155 | model.train()
156 | loss_all = 0
157 | epoch_step = 0
158 | try:
159 | for i, (images, gts, tis) in enumerate(train_loader, start=1):
160 | optimizer.zero_grad()
161 | images = images.cuda()
162 | tis = tis.cuda()
163 | gts = gts.cuda()
164 | if opt.task == 'RGBD':
165 | tis = torch.cat((tis, tis, tis), dim=1)
166 |
167 | gts2 = F.interpolate(gts, (112, 112))
168 | gts3 = F.interpolate(gts, (56, 56))
169 |
170 |
171 |             bound = tensor_bound(gts, 3).cuda()
172 | bound2 = F.interpolate(bound, (112, 112))
173 | bound3 = F.interpolate(bound, (56, 56))
174 |
175 | out = model(images, tis)
176 |
177 |
178 | loss1 = IOUBCE(out[0], gts)
179 | loss2 = IOUBCE(out[1], gts2)
180 | loss3 = IOUBCE(out[2], gts3)
181 |
182 | predict_bound0 = out[0]
183 | predict_bound1 = out[1]
184 | predict_bound2 = out[2]
185 |             predict_bound0 = tensor_bound(torch.sigmoid(predict_bound0), 3)
186 |             predict_bound1 = tensor_bound(torch.sigmoid(predict_bound1), 3)
187 |             predict_bound2 = tensor_bound(torch.sigmoid(predict_bound2), 3)
188 | loss6 = IOUBCEWithoutLogits(predict_bound0, bound)
189 | loss7 = IOUBCEWithoutLogits(predict_bound1, bound2)
190 | loss8 = IOUBCEWithoutLogits(predict_bound2, bound3)
191 |
192 |
193 | loss_sod = loss1 + loss2 + loss3
194 | loss_bound = loss6 + loss7 + loss8
195 | loss_trans = out[3]
196 | loss = loss_sod + loss_bound + loss_trans
197 | loss.backward()
198 | optimizer.step()
199 | step = step + 1
200 | epoch_step = epoch_step + 1
201 | loss_all = loss.item() + loss_all
202 | if i % 10 == 0 or i == total_step or i == 1:
203 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, loss_sod: {:.4f},'
204 | 'loss_bound: {:.4f},loss_trans: {:.4f}'.
205 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss.item(),
206 | loss_sod.item(),loss_bound.item(), loss_trans.item()))
207 | logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Loss: {:.4f}, loss_sod: {:.4f},'
208 | 'loss_bound: {:.4f},loss_trans: {:.4f} '.
209 | format(epoch, opt.epoch, i, total_step, loss.item(),
210 | loss_sod.item(),loss_bound.item(), loss_trans.item()))
211 | writer.add_scalar('Loss', loss, global_step=step)
212 | # grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True)
213 | # writer.add_image('train/RGB', grid_image, step)
214 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True)
215 | writer.add_image('train/Ground_truth', grid_image, step)
216 | grid_image = make_grid(bound[0].clone().cpu().data, 1, normalize=True)
217 | writer.add_image('train/bound', grid_image, step)
218 |
219 | # grid_image = make_grid(body[0].clone().cpu().data, 1, normalize=True)
220 | # writer.add_image('train/body', grid_image, step)
221 | res = out[0][0].clone()
222 | res = res.sigmoid().data.cpu().numpy().squeeze()
223 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
224 | writer.add_image('OUT/out', torch.tensor(res), step, dataformats='HW')
225 | res = predict_bound0[0].clone()
226 | res = res.sigmoid().data.cpu().numpy().squeeze()
227 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
228 | writer.add_image('OUT/bound', torch.tensor(res), step, dataformats='HW')
229 |
230 |
231 | loss_all /= epoch_step
232 | # logging.info('#TRAIN#:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format(epoch, opt.epoch, loss_all))
233 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch)
234 |         if epoch % 5 == 0:
235 | torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch))
236 | except KeyboardInterrupt:
237 | print('Keyboard Interrupt: save model and exit.')
238 | if not os.path.exists(save_path):
239 | os.makedirs(save_path)
240 | torch.save(model.state_dict(), save_path + 'Net_epoch_{}.pth'.format(epoch + 1))
241 | print('save checkpoints successfully!')
242 | raise
243 |
244 |
245 | # test function
246 | def test(test_loader, model, epoch, save_path):
247 | global best_mae, best_epoch
248 | model.eval()
249 | with torch.no_grad():
250 | mae_sum = 0
251 | for i in range(test_loader.size):
252 | image, gt, ti, name = test_loader.load_data()
253 | gt = gt.cuda()
254 | image = image.cuda()
255 | ti = ti.cuda()
256 | if opt.task == 'RGBD':
257 |                 ti = torch.cat((ti, ti, ti), dim=1)
258 |
259 | res = model(image, ti)
260 | res = torch.sigmoid(res)
261 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
262 | mae_train = torch.sum(torch.abs(res - gt)) * 1.0 / (torch.numel(gt))
263 | # print(mae_train)
264 | mae_sum = mae_train.item() + mae_sum
265 | # print(test_loader.size)
266 | mae = mae_sum / test_loader.size
267 | # print(test_loader.size)
268 | writer.add_scalar('MAE', torch.as_tensor(mae), global_step=epoch)
269 | print('Epoch: {} MAE: {} #### bestMAE: {} bestEpoch: {}'.format(epoch, mae, best_mae, best_epoch))
270 | if epoch == 1:
271 | best_mae = mae
272 | else:
273 | if mae < best_mae:
274 | best_mae = mae
275 | best_epoch = epoch
276 | torch.save(model.state_dict(), save_path + 'Net_epoch_best.pth')
277 | print('best epoch:{}'.format(epoch))
278 | logging.info('#TEST#:Epoch:{} MAE:{} bestEpoch:{} bestMAE:{}'.format(epoch, mae, best_epoch, best_mae))
279 |
280 |
281 | if __name__ == '__main__':
282 | print("Start train...")
283 | for epoch in range(1, opt.epoch+1):
284 | cur_lr = adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
285 | writer.add_scalar('learning_rate', cur_lr, global_step=epoch)
286 | train(train_loader, model, optimizer, epoch, save_path)
287 | test(test_loader, model, epoch, save_path)
288 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 | decay = decay_rate ** (epoch // decay_epoch)
10 | for param_group in optimizer.param_groups:
11 |         param_group['lr'] = decay * init_lr
12 |         lr = param_group['lr']
13 | return lr
14 |
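# Minimal self-contained sketch (added for illustration, not part of the original
# file): adjust_lr implements a step decay, lr = init_lr * decay_rate ** (epoch // decay_epoch).
# The optimizer and hyper-parameter values below are hypothetical.
if __name__ == '__main__':
    import torch
    optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1e-3)
    for epoch in (1, 30, 60, 90):
        lr = adjust_lr(optimizer, init_lr=1e-3, epoch=epoch, decay_rate=0.1, decay_epoch=30)
        print(epoch, lr)  # 1e-3, 1e-4, 1e-5, 1e-6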
--------------------------------------------------------------------------------