├── Images
│   ├── MFNet.png
│   └── PST900.png
├── README.md
├── __pycache__
│   ├── resnet.cpython-37.pyc
│   └── resnet.cpython-38.pyc
├── class_weights.py
├── configs
│   └── LASNet.json
├── generate_binary_labels.m
├── generate_bound_or_edge.m
├── model
│   ├── LASNet.json
│   ├── predicts_MFNet.zip
│   └── predicts_PST900.zip
├── resnet.py
├── sober.py
├── test_LASNet.py
├── toolbox
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dual_self_att.cpython-37.pyc
│   │   ├── dual_self_att.cpython-38.pyc
│   │   ├── log.cpython-36.pyc
│   │   ├── log.cpython-37.pyc
│   │   ├── log.cpython-38.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── metrics.cpython-36.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   ├── utils.cpython-36.pyc
│   │   ├── utils.cpython-37.pyc
│   │   └── utils.cpython-38.pyc
│   ├── datasets
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── augmentations.cpython-36.pyc
│   │   │   ├── augmentations.cpython-37.pyc
│   │   │   ├── augmentations.cpython-38.pyc
│   │   │   ├── camvid.cpython-37.pyc
│   │   │   ├── irseg.cpython-36.pyc
│   │   │   ├── irseg.cpython-37.pyc
│   │   │   ├── irseg.cpython-38.pyc
│   │   │   ├── nyuv2.cpython-37.pyc
│   │   │   └── pst900.cpython-38.pyc
│   │   ├── augmentations.py
│   │   ├── camvid.py
│   │   ├── irseg.py
│   │   └── pst900.py
│   ├── dual_self_att.py
│   ├── log.py
│   ├── losses.py
│   ├── metrics.py
│   ├── models
│   │   ├── LASNet.py
│   │   └── __pycache__
│   │       ├── EGFNet.cpython-37.pyc
│   │       ├── EGFNet.cpython-38.pyc
│   │       ├── LASNet.cpython-38.pyc
│   │       ├── LgyTestNet.cpython-37.pyc
│   │       └── LgyTestNet.cpython-38.pyc
│   ├── optim
│   │   ├── Ranger.py
│   │   └── __pycache__
│   │       ├── Ranger.cpython-37.pyc
│   │       └── Ranger.cpython-38.pyc
│   ├── scheduler
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   └── utils.py
└── train_LASNet.py
/Images/MFNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/Images/MFNet.png
--------------------------------------------------------------------------------
/Images/PST900.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/Images/PST900.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LASNet
2 | This project provides the code and results for 'RGB-T Semantic Segmentation with Location, Activation, and Sharpening', IEEE TCSVT, 2023. [IEEE link](https://ieeexplore.ieee.org/document/9900351) | [arXiv link](https://arxiv.org/abs/2210.14530) | [Homepage](https://mathlee.github.io/)
3 |
4 | # Requirements
5 | Python 3.7/3.8 + PyTorch 1.9.0 (built on [EGFNet](https://github.com/ShaohuaDong2021/EGFNet))
6 |
7 |
8 | # Segmentation maps and performance
9 | We provide segmentation maps for the MFNet and PST900 datasets under './model/' (predicts_MFNet.zip and predicts_PST900.zip).
10 |
11 | **Performance on MFNet dataset**
12 |
13 |
14 | <img src="./Images/MFNet.png">
15 |
16 |
17 | **Performance on PST900 dataset**
18 |
19 |
20 | <img src="./Images/PST900.png">
21 |
22 |
23 |
24 | # Training
25 | 1. Install '[apex](https://github.com/NVIDIA/apex)'.
26 | 2. Download the [MFNet dataset](https://pan.baidu.com/s/1NHGazP7pwgEM47SP_ljJPg) (code: 3b9o) or the [PST900 dataset](https://pan.baidu.com/s/13xgwFfUbu8zNvkwJq2Ggug) (code: mp2h).
27 | 3. Use 'generate_binary_labels.m' to generate binary labels, and 'generate_bound_or_edge.m' to generate edge labels.
28 | 4. Run 'train_LASNet.py' (defaults to the MFNet dataset).
29 |
30 | Note: our main model is defined in './toolbox/models/LASNet.py'.
31 |
32 |
33 | # Pre-trained model and testing
34 | 1. Download one of the following pre-trained models and put it under './model/': [model_MFNet.pth](https://pan.baidu.com/s/1dWCbTl274nzgdHGOsJkK_Q) (code: 5th1), [model_PST900.pth](https://pan.baidu.com/s/1zQif2_8LTG5R7aabQOXjrA) (code: okdq).
35 |
36 | 2. Rename the pre-trained model to 'model.pth', and then run 'test_LASNet.py' (defaults to the MFNet dataset).
37 |
38 |
39 | # Citation
40 | ```
41 | @ARTICLE{Li_2023_LASNet,
42 |   author = {Gongyang Li and Yike Wang and Zhi Liu and Xinpeng Zhang and Dan Zeng},
43 |   title = {RGB-T Semantic Segmentation with Location, Activation, and Sharpening},
44 |   journal = {IEEE Transactions on Circuits and Systems for Video Technology},
45 |   year = {2023},
46 |   volume = {33},
47 |   number = {3},
48 |   pages = {1223-1235},
49 |   month = {Mar.},
50 | }
51 | ```
52 | If you encounter any problems with the code or want to report bugs, please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn.
53 |
54 |
55 |
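56 | # Quick usage sketch
57 | For reference, a minimal sketch (assuming the default repo layout and the provided './configs/LASNet.json') of how the config file drives the dataset and model factories in 'toolbox':
58 | ```python
59 | import json
60 | from toolbox import get_dataset, get_model
61 |
62 | with open('./configs/LASNet.json', 'r') as fp:
63 |     cfg = json.load(fp)  # dataset='irseg' (MFNet), n_classes=9
64 |
65 | train_set, val_set, test_set = get_dataset(cfg)  # toolbox/datasets/irseg.py
66 | model = get_model(cfg)  # builds LASNet(n_classes=cfg['n_classes'])
67 | ```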
--------------------------------------------------------------------------------
/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/class_weights.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import configparser
4 | import cv2
5 | import numpy as np
6 |
7 |
8 |
9 |
10 |
11 |
12 | class ClassWeights:
13 | """
14 | Calculate class weights for PST900
15 | """
16 | def __init__(self, datapath=''):
17 | """
18 | Initialize class
19 | """
20 | self.data_path = datapath
21 | self.label_path_train = os.path.join(datapath, 'train', 'labels')
22 | self.label_path_test = os.path.join(datapath, 'test', 'labels')
23 | self.label_stack = []
24 | self.label_paths = []
25 | self.num_classes = 5
26 |
27 | def process_labels(self):
28 | """
29 | Wrapper for processing all labels
30 | """
31 | train_labels = glob.glob(self.label_path_train + '/*.png')
32 | test_labels = glob.glob(self.label_path_test + '/*.png')
33 | self.label_paths = train_labels + test_labels
34 | print("Accumulating labels...")
35 | print(len(self.label_paths))
36 | for label_img in self.label_paths:
37 | print(label_img)
38 | label = cv2.imread(label_img, -1)
39 | self.label_stack.append(label)
40 | print("Accumulating stack of labels done...")
41 | print(self.label_stack)
42 | stack_np = np.stack(self.label_stack, axis=0)
43 | self.weights = self.calculate_class_weights(stack_np, self.num_classes)
44 | print("Weights are: {}".format(self.weights))
45 |
46 | def load_class_weights(self, weight_file):
47 | """
48 | Load class weights from .ini file
49 | """
50 | config = configparser.ConfigParser()
51 | config.sections()
52 | config.read(weight_file)
53 | weights_mat = np.zeros([1, self.num_classes])
54 | weights_mat[0,0] = float(config['ClassWeights']['background'])
55 | weights_mat[0,1] = float(config['ClassWeights']['fire_extinguisher'])
56 | weights_mat[0,2] = float(config['ClassWeights']['backpack'])
57 | weights_mat[0,3] = float(config['ClassWeights']['drill'])
58 | weights_mat[0,4] = float(config['ClassWeights']['rescue_randy'])
59 | num_images = float(config['ClassWeights']['num_images'])
60 | print("Loaded class weights from .ini file...")
61 | return weights_mat.squeeze(), num_images
62 |
63 | def save_class_weights(self, weight_file):
64 | """
65 | Save class weights to .ini file
66 | """
67 | config = configparser.ConfigParser()
68 | config['ClassWeights'] = {}
69 | config['ClassWeights']['background'] = str(self.weights[0])
70 | config['ClassWeights']['fire_extinguisher'] = str(self.weights[1])
71 | config['ClassWeights']['backpack'] = str(self.weights[2])
72 | config['ClassWeights']['drill'] = str(self.weights[3])
73 | config['ClassWeights']['rescue_randy'] = str(self.weights[4])
74 | config['ClassWeights']['num_images'] = str(len(self.label_paths))
75 | with open(weight_file, 'w') as configfile:
76 | config.write(configfile)
77 | print("Saved class weights to .ini file...")
78 |
79 | def calculate_class_weights(self, Y, n_classes, method="paszke", c=1.02):
80 | """ Given the training data labels Calculates the class weights.
81 | Args:
82 | Y: (numpy array) The training labels as class id integers.
83 | The shape does not matter, as long as each element represents
84 | a class id (ie, NOT one-hot-vectors).
85 | n_classes: (int) Number of possible classes.
86 | method: (str) The type of class weighting to use.
87 | - "paszke" = use the method from from Paszke et al 2016
88 | `1/ln(c + class_probability)`
89 | c: (float) Coefficient to use, when using paszke method.
90 | Returns:
91 | weights: (numpy array) Array of shape [n_classes] assigning a
92 | weight value to each class.
93 | References:
94 | Paszke et al 2016: https://arxiv.org/abs/1606.02147
95 | """
96 | ids, counts = np.unique(Y, return_counts=True)
97 | n_pixels = Y.size
98 | p_class = np.zeros(n_classes)
99 | p_class[ids] = counts/n_pixels
100 | weights = 1/np.log(c+p_class)
101 | return weights
102 |
103 | def main():
104 |
105 | pst900_path = './PST900_RGBT_Dataset/'
106 |
107 | weight_path = pst900_path + 'weights.ini'
108 |
109 | # Instantiate ClassWeights
110 | calc_weights = ClassWeights(pst900_path)
111 |
112 | # Example: to calculate weights for the entire dataset
113 | calc_weights.process_labels()
114 |
115 | # Example: to save weights to config file
116 | calc_weights.save_class_weights(weight_path)
117 |
118 | # Example: to load weights from config file
119 | weights, img_count = calc_weights.load_class_weights(weight_path)
120 |
121 | if __name__ == '__main__':
122 | main()
123 |
124 |
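125 | # Worked example (illustrative): with Y = np.array([0, 0, 0, 1]) and n_classes=2,
126 | # p_class = [0.75, 0.25], so the Paszke weights 1/ln(1.02 + p_class) come out to
127 | # roughly [1.75, 4.18] -- the rarer class receives the larger weight.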
--------------------------------------------------------------------------------
/configs/LASNet.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "LASNet",
3 |
4 | "inputs": "rgbd",
5 |
6 | "dataset": "irseg",
7 | "root": "./dataset/",
8 | "n_classes": 9,
9 | "id_unlabel": -1,
10 | "brightness": 0.5,
11 | "contrast": 0.5,
12 | "saturation": 0.5,
13 | "p": 0.5,
14 | "scales_range": "0.5 2.0",
15 | "crop_size": "480 640",
16 | "eval_scales": "0.5 0.75 1.0 1.25 1.5 1.75",
17 | "eval_flip": "true",
18 |
19 |
20 | "ims_per_gpu": 4,
21 | "num_workers": 4,
22 |
23 | "lr_start": 5e-5,
24 | "momentum": 0.9,
25 | "weight_decay": 5e-4,
26 | "lr_power": 0.9,
27 | "epochs": 200,
28 |
29 | "loss": "crossentropy",
30 | "class_weight": "enet"
31 | }
32 |
33 |
34 |
--------------------------------------------------------------------------------
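Note: the space-separated string fields above ('scales_range', 'crop_size', 'eval_scales') are parsed into tuples by the dataset classes; a sketch of that parsing, mirroring toolbox/datasets/irseg.py (cfg is the dict loaded from this file):

```python
scale_range = tuple(float(i) for i in cfg['scales_range'].split(' '))  # (0.5, 2.0)
crop_size = tuple(int(i) for i in cfg['crop_size'].split(' '))         # (480, 640)
```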
/generate_binary_labels.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 | %Path of semantic gts
3 | gtPath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/labels/';
4 |
5 | savePath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/binary_labels/';
6 |
7 | gts = dir([gtPath '*.png']);
8 | gtsNum = length(gts);
9 |
10 |
11 | for i=1:gtsNum
12 | gt_name = gts(i).name();
13 |
14 | gt = imread(fullfile(gtPath, gt_name));
15 |
16 | gt(gt>1) = 255;
17 |
18 | imwrite(gt, [savePath gt_name] );
19 |
20 | end
21 |
22 |
--------------------------------------------------------------------------------
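For readers without MATLAB, a rough Python/NumPy equivalent of the script above (paths as in the original; it mirrors the script's 'gt > 1' threshold exactly):

```python
import glob
import os

import cv2

gt_path = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/labels/'
save_path = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/binary_labels/'

for name in glob.glob(os.path.join(gt_path, '*.png')):
    gt = cv2.imread(name, cv2.IMREAD_UNCHANGED)
    gt[gt > 1] = 255  # mirrors gt(gt>1) = 255
    cv2.imwrite(os.path.join(save_path, os.path.basename(name)), gt)
```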
/generate_bound_or_edge.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 | %Path of semantic gts
3 | ssgtPath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/labels/';
4 |
5 | savePath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/bound/';
6 |
7 | ssgts = dir([ssgtPath '*.png']);
8 | gtsNum = length(ssgts);
9 |
10 |
11 | for i=1:gtsNum
12 | ssgt_name = ssgts(i).name();
13 |
14 | ssgt = imread(fullfile(ssgtPath, ssgt_name));
15 |
16 | [h,w] = size(ssgt);
17 |
18 | bound = zeros(size(ssgt));
19 |
20 | padmap = zeros(h+4, w+4);
21 |
22 | padmap(3:h+2,3:w+2) = ssgt;
23 |
24 |
25 | for hh = 1:h
26 | for ww = 1:w
27 | slidewindow = padmap(hh:hh+4, ww:ww+4);
28 | class = unique(slidewindow);
29 | if length(class)>=2
30 | bound(hh,ww) = 255;
31 | end
32 | end
33 | end
34 |
35 |
36 | imwrite(uint8(bound), [savePath ssgt_name] );
37 |
38 | end
39 |
40 |
--------------------------------------------------------------------------------
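Equivalently in Python (a sketch, not part of the repo): a pixel is marked as a boundary pixel if its 5x5 neighborhood in the zero-padded label map contains two or more distinct classes:

```python
import numpy as np

def boundary_map(ssgt: np.ndarray) -> np.ndarray:
    """5x5 multi-class boundary test, mirroring generate_bound_or_edge.m."""
    h, w = ssgt.shape
    padmap = np.zeros((h + 4, w + 4), dtype=ssgt.dtype)
    padmap[2:h + 2, 2:w + 2] = ssgt
    bound = np.zeros((h, w), dtype=np.uint8)
    for i in range(h):
        for j in range(w):
            if np.unique(padmap[i:i + 5, j:j + 5]).size >= 2:
                bound[i, j] = 255
    return bound
```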
/model/LASNet.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "LASNet",
3 |
4 | "inputs": "rgbd",
5 |
6 | "dataset": "irseg",
7 | "root": "./dataset/",
8 | "n_classes": 9,
9 | "id_unlabel": -1,
10 | "brightness": 0.5,
11 | "contrast": 0.5,
12 | "saturation": 0.5,
13 | "p": 0.5,
14 | "scales_range": "0.5 2.0",
15 | "crop_size": "480 640",
16 | "eval_scales": "0.5 0.75 1.0 1.25 1.5 1.75",
17 | "eval_flip": "true",
18 |
19 |
20 | "ims_per_gpu": 4,
21 | "num_workers": 4,
22 |
23 | "lr_start": 5e-5,
24 | "momentum": 0.9,
25 | "weight_decay": 5e-4,
26 | "lr_power": 0.9,
27 | "epochs": 200,
28 |
29 | "loss": "crossentropy",
30 | "class_weight": "enet"
31 | }
32 |
33 |
34 |
--------------------------------------------------------------------------------
/model/predicts_MFNet.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/model/predicts_MFNet.zip
--------------------------------------------------------------------------------
/model/predicts_PST900.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/model/predicts_PST900.zip
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | # import torchvision.models as models
2 | # import torch.nn as nn
3 | # # https://pytorch.org/docs/stable/torchvision/models.html#id3
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.utils.model_zoo as model_zoo
8 |
9 | model_urls = {
10 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
11 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
12 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
13 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
14 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
21 |
22 |
23 | def conv1x1(in_planes, out_planes, stride=1):
24 | """1x1 convolution"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = nn.BatchNorm2d(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | identity = x
43 |
44 | out = self.conv1(x)
45 | out = self.bn1(out)
46 | out = self.relu(out)
47 |
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | identity = self.downsample(x)
53 |
54 | out += identity
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None):
64 | super(Bottleneck, self).__init__()
65 | self.conv1 = conv1x1(inplanes, planes)
66 | self.bn1 = nn.BatchNorm2d(planes)
67 | self.conv2 = conv3x3(planes, planes, stride)
68 | self.bn2 = nn.BatchNorm2d(planes)
69 | self.conv3 = conv1x1(planes, planes * self.expansion)
70 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | identity = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | identity = self.downsample(x)
91 |
92 | out += identity
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class ResNet(nn.Module):
99 | def __init__(self, block, layers, zero_init_residual=False):
100 | super(ResNet, self).__init__()
101 | self.inplanes = 64
102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.layer1 = self._make_layer(block, 64, layers[0])
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 6
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 3
110 |
111 | for m in self.modules():
112 | if isinstance(m, nn.Conv2d):
113 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
114 | elif isinstance(m, nn.BatchNorm2d):
115 | nn.init.constant_(m.weight, 1)
116 | nn.init.constant_(m.bias, 0)
117 |
118 | # Zero-initialize the last BN in each residual branch,
119 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
120 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
121 | if zero_init_residual:
122 | for m in self.modules():
123 | if isinstance(m, Bottleneck):
124 | nn.init.constant_(m.bn3.weight, 0)
125 | elif isinstance(m, BasicBlock):
126 | nn.init.constant_(m.bn2.weight, 0)
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x)
147 | x = self.maxpool(x)
148 |
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 |
154 | return x
155 |
156 |
157 | def resnet18(pretrained=False, **kwargs):
158 | """Constructs a ResNet-18 model.
159 |
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if pretrained:
165 | pretrained_dict = model_zoo.load_url(model_urls["resnet18"])
166 |
167 | model_dict = model.state_dict()
168 | # 1. filter out unnecessary keys
169 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
170 | # 2. overwrite entries in the existing state dict
171 | model_dict.update(pretrained_dict)
172 | # 3. load the new state dict
173 | model.load_state_dict(model_dict)
174 | return model
175 |
176 |
177 | def resnet34(pretrained=False, **kwargs):
178 | """Constructs a ResNet-34 model.
179 |
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | """
183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
184 | if pretrained:
185 | pretrained_dict = model_zoo.load_url(model_urls["resnet34"])
186 |
187 | model_dict = model.state_dict()
188 | # 1. filter out unnecessary keys
189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
190 | # 2. overwrite entries in the existing state dict
191 | model_dict.update(pretrained_dict)
192 | # 3. load the new state dict
193 | model.load_state_dict(model_dict)
194 | return model
195 |
196 |
197 | def resnet50(pretrained=False, **kwargs):
198 | """Constructs a ResNet-50 model.
199 |
200 | Args:
201 | pretrained (bool): If True, returns a model pre-trained on ImageNet
202 | """
203 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
204 |
205 | if pretrained:
206 | pretrained_dict = model_zoo.load_url(model_urls["resnet50"])
207 |
208 | model_dict = model.state_dict()
209 | # 1. filter out unnecessary keys
210 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
211 | # 2. overwrite entries in the existing state dict
212 | model_dict.update(pretrained_dict)
213 | # 3. load the new state dict
214 | model.load_state_dict(model_dict)
215 |
216 | return model
217 |
218 |
219 | def resnet101(pretrained=False, **kwargs):
220 | """Constructs a ResNet-101 model.
221 |
222 | Args:
223 | pretrained (bool): If True, returns a model pre-trained on ImageNet
224 | """
225 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
226 | if pretrained:
227 | pretrained_dict = model_zoo.load_url(model_urls["resnet101"])
228 |
229 | model_dict = model.state_dict()
230 | # 1. filter out unnecessary keys
231 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
232 | # 2. overwrite entries in the existing state dict
233 | model_dict.update(pretrained_dict)
234 | # 3. load the new state dict
235 | model.load_state_dict(model_dict)
236 | return model
237 |
238 |
239 | def resnet152(pretrained=False, **kwargs):
240 | """Constructs a ResNet-152 model.
241 |
242 | Args:
243 | pretrained (bool): If True, returns a model pre-trained on ImageNet
244 | """
245 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
246 |
247 | if pretrained:
248 | pretrained_dict = model_zoo.load_url(model_urls["resnet152"])
249 |
250 | model_dict = model.state_dict()
251 | # 1. filter out unnecessary keys
252 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
253 | # 2. overwrite entries in the existing state dict
254 | model_dict.update(pretrained_dict)
255 | # 3. load the new state dict
256 | model.load_state_dict(model_dict)
257 |
258 | return model
259 |
260 |
261 | def Backbone_ResNet34_in3(pretrained=True):
262 | if pretrained:
263 | print("The backbone model loads the pretrained parameters...")
264 | net = resnet34(pretrained=pretrained)
265 | div_2 = nn.Sequential(*list(net.children())[:3])
266 | div_4 = nn.Sequential(*list(net.children())[3:5])
267 | div_8 = net.layer2
268 | div_16 = net.layer3
269 | div_32 = net.layer4
270 |
271 | return div_2, div_4, div_8, div_16, div_32
272 |
273 |
274 | def Backbone_ResNet34_in1(pretrained=True):
275 | if pretrained:
276 | print("The backbone model loads the pretrained parameters...")
277 | net = resnet34(pretrained=pretrained)
278 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
279 | div_2 = nn.Sequential(*list(net.children())[:3])
280 | div_4 = nn.Sequential(*list(net.children())[3:5])
281 | div_8 = net.layer2
282 | div_16 = net.layer3
283 | div_32 = net.layer4
284 |
285 | return div_2, div_4, div_8, div_16, div_32
286 |
287 | def Backbone_ResNet50_in3(pretrained=True):
288 | if pretrained:
289 | print("The backbone model loads the pretrained parameters...")
290 | net = resnet50(pretrained=pretrained)
291 | div_2 = nn.Sequential(*list(net.children())[:3])
292 | div_4 = nn.Sequential(*list(net.children())[3:5])
293 | div_8 = net.layer2
294 | div_16 = net.layer3
295 | div_32 = net.layer4
296 |
297 | return div_2, div_4, div_8, div_16, div_32
298 |
299 |
300 | def Backbone_ResNet50_in1(pretrained=True):
301 | if pretrained:
302 | print("The backbone model loads the pretrained parameters...")
303 | net = resnet50(pretrained=pretrained)
304 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
305 | div_2 = nn.Sequential(*list(net.children())[:3])
306 | div_4 = nn.Sequential(*list(net.children())[3:5])
307 | div_8 = net.layer2
308 | div_16 = net.layer3
309 | div_32 = net.layer4
310 |
311 | return div_2, div_4, div_8, div_16, div_32
312 |
313 |
314 | def Backbone_ResNet152_in3(pretrained=True):
315 | if pretrained:
316 | print("The backbone model loads the pretrained parameters...")
317 | net = resnet152(pretrained=pretrained)
318 | div_2 = nn.Sequential(*list(net.children())[:3])
319 | div_4 = nn.Sequential(*list(net.children())[3:5])
320 | div_8 = net.layer2
321 | div_16 = net.layer3
322 | div_32 = net.layer4
323 |
324 | return div_2, div_4, div_8, div_16, div_32
325 |
326 |
327 | def Backbone_ResNet152_in1(pretrained=True):
328 | if pretrained:
329 | print("The backbone model loads the pretrained parameters...")
330 | net = resnet152(pretrained=pretrained)
331 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
332 | div_2 = nn.Sequential(*list(net.children())[:3])
333 | div_4 = nn.Sequential(*list(net.children())[3:5])
334 | div_8 = net.layer2
335 | div_16 = net.layer3
336 | div_32 = net.layer4
337 |
338 | return div_2, div_4, div_8, div_16, div_32
339 |
340 |
341 | if __name__ == "__main__":
342 | div_2, div_4, div_8, div_16, div_32 = Backbone_ResNet50_in1()
343 | indata = torch.rand(4, 1, 480, 640)
344 | x1 = div_2(indata)
345 | x2 = div_4(x1)
346 | x3 = div_8(x2)
347 | x4 = div_16(x3)
348 | x5 = div_32(x4)
349 | # print(div_8)
350 | print(x1.size())
351 | print(x2.size())
352 | print(x3.size())
353 | print(x4.size())
354 | print(x5.size())
355 |
356 |
--------------------------------------------------------------------------------
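The Backbone_* helpers above split a ResNet into its stride-2 to stride-32 stages. A sketch of the expected feature shapes for a 480x640 RGB input (pretrained=False here to skip the checkpoint download):

```python
import torch
from resnet import Backbone_ResNet50_in3

div_2, div_4, div_8, div_16, div_32 = Backbone_ResNet50_in3(pretrained=False)
x = torch.rand(1, 3, 480, 640)
f2 = div_2(x)      # (1, 64, 240, 320)    conv1 + bn1 + relu
f4 = div_4(f2)     # (1, 256, 120, 160)   maxpool + layer1
f8 = div_8(f4)     # (1, 512, 60, 80)     layer2
f16 = div_16(f8)   # (1, 1024, 30, 40)    layer3
f32 = div_32(f16)  # (1, 2048, 15, 20)    layer4
```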
/sober.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from torchvision import transforms
4 | import numpy as np
5 |
6 | loader = transforms.Compose([transforms.ToTensor()])
7 | unloader = transforms.ToPILImage()
8 |
9 |
10 | def tensor_to_PIL(tensor):
11 |     image = tensor.squeeze(0)
12 |     image = unloader(image)
13 |     return image
14 |
15 |
16 | with open(os.path.join('/home/user/EGFNet/dataset', 'all.txt'), 'r') as f:
17 |     image_labels = f.readlines()
18 |
19 | for i in range(len(image_labels)):
20 |     label_path1 = image_labels[i].strip()
21 |     imgrgb = cv2.imread('/home/user/EGFNet/dataset/seperated_images/' + label_path1 + '_rgb.png', 0)
22 |     imgdepth = cv2.imread('/home/user/EGFNet/dataset/seperated_images/' + label_path1 + '_th.png', 0)
23 |
24 |     # Sobel gradients along x and y for the grayscale RGB and thermal images
25 |     x1 = cv2.Sobel(imgrgb, cv2.CV_16S, 1, 0)
26 |     y1 = cv2.Sobel(imgrgb, cv2.CV_16S, 0, 1)
27 |     x2 = cv2.Sobel(imgdepth, cv2.CV_16S, 1, 0)
28 |     y2 = cv2.Sobel(imgdepth, cv2.CV_16S, 0, 1)
29 |
30 |     absX1 = cv2.convertScaleAbs(x1)
31 |     absY1 = cv2.convertScaleAbs(y1)
32 |     absX2 = cv2.convertScaleAbs(x2)
33 |     absY2 = cv2.convertScaleAbs(y2)
34 |
35 |     # combine the x/y gradients of each modality, then fuse the two edge maps
36 |     dst1 = cv2.addWeighted(absX1, 0.5, absY1, 0.5, 0)
37 |     dst2 = cv2.addWeighted(absX2, 0.5, absY2, 0.5, 0)
38 |
39 |     dst1 = loader(dst1)
40 |     dst2 = loader(dst2)
41 |     dst = (dst1 + dst2) / 255.
42 |
43 |     c = tensor_to_PIL(dst)
44 |     c = np.array(c)
45 |
46 |     cv2.imwrite('/home/user/EGFNet/dataset/edge/' + label_path1 + '.png', c)
47 |
--------------------------------------------------------------------------------
/test_LASNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from tqdm import tqdm
4 | from PIL import Image
5 | import json
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 |
11 | from toolbox import get_model
12 | from toolbox import averageMeter, runningScore
13 | from toolbox import class_to_RGB, load_ckpt, save_ckpt
14 |
15 | from toolbox.datasets.irseg import IRSeg
16 | from toolbox.datasets.pst900 import PSTSeg
17 |
18 |
19 | def evaluate(logdir, save_predict=False, options=['val', 'test', 'test_day', 'test_night'], prefix=''):
20 | # load the configuration (cfg) from the .json file in logdir
21 | cfg = None
22 | for file in os.listdir(logdir):
23 | if file.endswith('.json'):
24 | with open(os.path.join(logdir, file), 'r') as fp:
25 | cfg = json.load(fp)
26 | assert cfg is not None
27 |
28 | device = torch.device('cuda')
29 |
30 | loaders = []
31 | for opt in options:
32 | dataset = IRSeg(cfg, mode=opt)
33 | # dataset = PSTSeg(cfg, mode=opt)  # for the PST900 dataset
34 | loaders.append((opt, DataLoader(dataset, batch_size=1, shuffle=False, num_workers=cfg['num_workers'])))
35 | cmap = dataset.cmap
36 |
37 | model = get_model(cfg).to(device)
38 |
39 |
40 | model = load_ckpt(logdir, model, prefix=prefix)
41 |
42 | running_metrics_val = runningScore(cfg['n_classes'], ignore_index=cfg['id_unlabel'])
43 | time_meter = averageMeter()
44 |
45 | save_path = os.path.join(logdir, 'predicts')
46 | if not os.path.exists(save_path) and save_predict:
47 | os.mkdir(save_path)
48 |
49 | for name, test_loader in loaders:
50 | running_metrics_val.reset()
51 | print('#'*50 + ' ' + name+prefix + ' ' + '#'*50)
52 | with torch.no_grad():
53 | model.eval()
54 | for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
55 |
56 | time_start = time.time()
57 |
58 | if cfg['inputs'] == 'rgb':
59 | image = sample['image'].to(device)
60 | label = sample['label'].to(device)
61 | predict = model(image)
62 | else:
63 | image = sample['image'].to(device)
64 | depth = sample['depth'].to(device)
65 | label = sample['label'].to(device)
66 | edge = sample['edge'].to(device)
67 | predict = model(image, depth)[0]
68 |
69 | predict = predict.max(1)[1].cpu().numpy()
70 | label = label.cpu().numpy()
71 | running_metrics_val.update(label, predict)
72 |
73 | time_meter.update(time.time() - time_start, n=image.size(0))
74 |
75 | if save_predict:
76 | predict = predict.squeeze(0)
77 | predict = class_to_RGB(predict, N=len(cmap), cmap=cmap)
78 | predict = Image.fromarray(predict)
79 | predict.save(os.path.join(save_path, sample['label_path'][0]))
80 |
81 | metrics = running_metrics_val.get_scores()
82 | print('overall metrics .....')
83 | for k, v in metrics[0].items():
84 | print(k, f'{v:.4f}')
85 |
86 | print('iou for each class .....')
87 | for k, v in metrics[1].items():
88 | print(k, f'{v:.4f}')
89 | print('acc for each class .....')
90 | for k, v in metrics[2].items():
91 | print(k, f'{v:.4f}')
92 |
93 |
94 |
95 | if __name__ == '__main__':
96 | import argparse
97 |
98 | parser = argparse.ArgumentParser(description="evaluate")
99 | parser.add_argument("--logdir", default="./model/", type=str,
100 | help="run logdir")
101 | parser.add_argument("-s", type=bool, default="./model/",
102 | help="save predict or not")
103 | args = parser.parse_args()
104 |
105 | # prefix options: ['', 'best_val_', 'best_test_']
106 | # options=['test', 'test_day', 'test_night']
107 | evaluate(args.logdir, save_predict=args.s, options=['test'], prefix='')
108 | # evaluate(args.logdir, save_predict=args.s, options=['val'], prefix='')
109 | # evaluate(args.logdir, save_predict=args.s, options=['test_day'], prefix='')
110 | #evaluate(args.logdir, save_predict=args.s, options=['test_night'], prefix='')
111 | # msc_evaluate(args.logdir, save_predict=args.s)
112 |
--------------------------------------------------------------------------------
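The evaluate() function can also be called directly (a usage sketch, assuming a logdir that contains the .json config and the renamed 'model.pth' checkpoint):

```python
from test_LASNet import evaluate

# mirrors `python test_LASNet.py --logdir ./model/`
evaluate('./model/', save_predict=True, options=['test'], prefix='')
```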
/toolbox/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import averageMeter, runningScore
2 | from .log import get_logger
3 | from .optim import Ranger
4 |
5 | from .utils import ClassWeight, save_ckpt, load_ckpt, class_to_RGB, \
6 | compute_speed, setup_seed, group_weight_decay
7 |
8 |
9 | def get_dataset(cfg):
10 | assert cfg['dataset'] in ['nyuv2', 'nyuv2_new', 'sunrgbd', 'cityscapes', 'camvid', 'irseg', 'pst900', 'irseg_msv']
11 |
12 | if cfg['dataset'] == 'irseg':
13 | from .datasets.irseg import IRSeg
14 | # return IRSeg(cfg, mode='trainval'), IRSeg(cfg, mode='test')
15 | return IRSeg(cfg, mode='train'), IRSeg(cfg, mode='val'), IRSeg(cfg, mode='test')
16 | elif cfg['dataset'] == 'pst900':
17 | from .datasets.pst900 import PSTSeg
18 | # return IRSeg(cfg, mode='trainval'), IRSeg(cfg, mode='test')
19 | return PSTSeg(cfg, mode='train'), PSTSeg(cfg, mode='val'), PSTSeg(cfg, mode='test')
20 |
21 |
22 | def get_model(cfg):
23 |
24 | if cfg['model_name'] == 'EGFNet':
25 | from .models.EGFNet import EGFNet
26 | return EGFNet(n_classes=cfg['n_classes'])
27 | else:
28 | from .models.LASNet import LASNet
29 | return LASNet(n_classes=cfg['n_classes'])
30 |
--------------------------------------------------------------------------------
/toolbox/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/dual_self_att.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/dual_self_att.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/dual_self_att.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/dual_self_att.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/log.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/log.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/log.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/losses.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/losses.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/losses.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/losses.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/augmentations.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/augmentations.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/augmentations.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/camvid.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/camvid.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/nyuv2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/nyuv2.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/pst900.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/pst900.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import sys
3 | import random
4 | from PIL import Image
5 |
6 | try:
7 | import accimage
8 | except ImportError:
9 | accimage = None
10 | import numbers
11 | import collections
12 |
13 | import torchvision.transforms.functional as F
14 |
15 | __all__ = ["Compose",
16 | "Resize", # 尺寸缩减到对应size, 如果给定size为int,尺寸缩减到(size * height / width, size)
17 | "RandomScale", # 尺寸随机缩放
18 | "RandomCrop", # 随机裁剪,必要时可以进行padding
19 | "RandomHorizontalFlip", # 随机水平翻转
20 | "ColorJitter", # 亮度,对比度,饱和度,色调
21 | "RandomRotation", # 随机旋转
22 | ]
23 |
24 | _pil_interpolation_to_str = {
25 | Image.NEAREST: 'PIL.Image.NEAREST',
26 | Image.BILINEAR: 'PIL.Image.BILINEAR',
27 | Image.BICUBIC: 'PIL.Image.BICUBIC',
28 | Image.LANCZOS: 'PIL.Image.LANCZOS',
29 | Image.HAMMING: 'PIL.Image.HAMMING',
30 | Image.BOX: 'PIL.Image.BOX',
31 | }
32 |
33 | if sys.version_info < (3, 3):
34 | Sequence = collections.Sequence
35 | Iterable = collections.Iterable
36 | else:
37 | Sequence = collections.abc.Sequence
38 | Iterable = collections.abc.Iterable
39 |
40 |
41 | class Lambda(object):
42 | """Apply a user-defined lambda as a transform.
43 |
44 | Args:
45 | lambd (function): Lambda/function to be used for transform.
46 | """
47 |
48 | def __init__(self, lambd):
49 | assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
50 | self.lambd = lambd
51 |
52 | def __call__(self, img):
53 | return self.lambd(img)
54 |
55 |
56 | class Compose(object):
57 | def __init__(self, transforms):
58 | self.transforms = transforms
59 |
60 | def __call__(self, sample):
61 | for t in self.transforms:
62 | sample = t(sample)
63 | return sample
64 |
65 |
66 | class Resize(object):
67 | def __init__(self, size):
68 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
69 | self.size = size
70 |
71 | def __call__(self, sample):
72 | assert 'image' in sample.keys()
73 | assert 'label' in sample.keys()
74 |
75 | for key in sample.keys():
76 | # BILINEAR for image
77 | if key in ['image']:
78 | sample[key] = F.resize(sample[key], self.size, Image.BILINEAR)
79 | # NEAREST for depth, label, bound
80 | else:
81 | sample[key] = F.resize(sample[key], self.size, Image.NEAREST)
82 |
83 | return sample
84 |
85 |
86 | class RandomCrop(object):
87 |
88 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
89 | if isinstance(size, numbers.Number):
90 | self.size = (int(size), int(size))
91 | else:
92 | self.size = size
93 | self.padding = padding
94 | self.pad_if_needed = pad_if_needed
95 | self.fill = fill
96 | self.padding_mode = padding_mode
97 |
98 | @staticmethod
99 | def get_params(img, output_size):
100 | w, h = img.size
101 | th, tw = output_size
102 | if w == tw and h == th:
103 | return 0, 0, h, w
104 |
105 | i = random.randint(0, h - th)
106 | j = random.randint(0, w - tw)
107 | return i, j, th, tw
108 |
109 | def __call__(self, sample):
110 | img = sample['image']
111 | if self.padding is not None:
112 | for key in sample.keys():
113 | sample[key] = F.pad(sample[key], self.padding, self.fill, self.padding_mode)
114 |
115 | # pad the width if needed
116 | if self.pad_if_needed and img.size[0] < self.size[1]:
117 | for key in sample.keys():
118 | sample[key] = F.pad(sample[key], (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
119 | # pad the height if needed
120 | if self.pad_if_needed and img.size[1] < self.size[0]:
121 | for key in sample.keys():
122 | sample[key] = F.pad(sample[key], (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
123 |
124 | i, j, h, w = self.get_params(sample['image'], self.size)
125 | for key in sample.keys():
126 | sample[key] = F.crop(sample[key], i, j, h, w)
127 |
128 | return sample
129 |
130 |
131 | class RandomHorizontalFlip(object):
132 |
133 | def __init__(self, p=0.5):
134 | self.p = p
135 |
136 | def __call__(self, sample):
137 | if random.random() < self.p:
138 | for key in sample.keys():
139 | sample[key] = F.hflip(sample[key])
140 |
141 | return sample
142 |
143 |
144 | class RandomScale(object):
145 | def __init__(self, scale):
146 | assert isinstance(scale, Iterable) and len(scale) == 2
147 | assert 0 < scale[0] <= scale[1]
148 | self.scale = scale
149 |
150 | def __call__(self, sample):
151 | assert 'image' in sample.keys()
152 | assert 'label' in sample.keys()
153 |
154 | w, h = sample['image'].size
155 |
156 | scale = random.uniform(self.scale[0], self.scale[1])
157 | size = (int(round(h * scale)), int(round(w * scale)))
158 |
159 | for key in sample.keys():
160 | # BILINEAR for image
161 | if key in ['image']:
162 | sample[key] = F.resize(sample[key], size, Image.BILINEAR)
163 | # NEAREST for depth, label, bound
164 | else:
165 | sample[key] = F.resize(sample[key], size, Image.NEAREST)
166 |
167 | return sample
168 |
169 |
170 | class ColorJitter(object):
171 | """Randomly change the brightness, contrast and saturation of an image.
172 |
173 | Args:
174 | brightness (float or tuple of float (min, max)): How much to jitter brightness.
175 | brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
176 | or the given [min, max]. Should be non negative numbers.
177 | contrast (float or tuple of float (min, max)): How much to jitter contrast.
178 | contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
179 | or the given [min, max]. Should be non negative numbers.
180 | saturation (float or tuple of float (min, max)): How much to jitter saturation.
181 | saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
182 | or the given [min, max]. Should be non negative numbers.
183 | hue (float or tuple of float (min, max)): How much to jitter hue.
184 | hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
185 | Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
186 | """
187 |
188 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
189 | self.brightness = self._check_input(brightness, 'brightness')
190 | self.contrast = self._check_input(contrast, 'contrast')
191 | self.saturation = self._check_input(saturation, 'saturation')
192 | self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
193 | clip_first_on_zero=False)
194 |
195 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
196 | if isinstance(value, numbers.Number):
197 | if value < 0:
198 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
199 | value = [center - value, center + value]
200 | if clip_first_on_zero:
201 | value[0] = max(value[0], 0)
202 | elif isinstance(value, (tuple, list)) and len(value) == 2:
203 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
204 | raise ValueError("{} values should be between {}".format(name, bound))
205 | else:
206 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
207 |
208 | # if value is 0 or (1., 1.) for brightness/contrast/saturation
209 | # or (0., 0.) for hue, do nothing
210 | if value[0] == value[1] == center:
211 | value = None
212 | return value
213 |
214 | @staticmethod
215 | def get_params(brightness, contrast, saturation, hue):
216 |
217 | transforms = []
218 |
219 | if brightness is not None:
220 | brightness_factor = random.uniform(brightness[0], brightness[1])
221 | transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
222 |
223 | if contrast is not None:
224 | contrast_factor = random.uniform(contrast[0], contrast[1])
225 | transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
226 |
227 | if saturation is not None:
228 | saturation_factor = random.uniform(saturation[0], saturation[1])
229 | transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
230 |
231 | if hue is not None:
232 | hue_factor = random.uniform(hue[0], hue[1])
233 | transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
234 |
235 | random.shuffle(transforms)
236 | transform = Compose(transforms)
237 |
238 | return transform
239 |
240 | def __call__(self, sample):
241 | assert 'image' in sample.keys()
242 | transform = self.get_params(self.brightness, self.contrast,
243 | self.saturation, self.hue)
244 | sample['image'] = transform(sample['image'])
245 | return sample
246 |
247 |
248 | class RandomRotation(object):
249 |
250 | def __init__(self, degrees, resample=False, expand=False, center=None):
251 | if isinstance(degrees, numbers.Number):
252 | if degrees < 0:
253 | raise ValueError("If degrees is a single number, it must be positive.")
254 | self.degrees = (-degrees, degrees)
255 | else:
256 | if len(degrees) != 2:
257 | raise ValueError("If degrees is a sequence, it must be of len 2.")
258 | self.degrees = degrees
259 |
260 | self.resample = resample
261 | self.expand = expand
262 | self.center = center
263 |
264 | @staticmethod
265 | def get_params(degrees):
266 |
267 | return random.uniform(degrees[0], degrees[1])
268 |
269 | def __call__(self, sample):
270 |
271 | angle = self.get_params(self.degrees)
272 | for key in sample.keys():
273 | sample[key] = F.rotate(sample[key], angle, self.resample, self.expand, self.center)
274 |
275 | return sample
276 |
--------------------------------------------------------------------------------
/toolbox/datasets/camvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data as data
7 | from torchvision import transforms
8 |
9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale
10 |
11 |
12 | class Camvid(data.Dataset):
13 |
14 | def __init__(self, cfg, mode='trainval', do_aug=True):
15 |
16 | assert mode in ['trainval', 'test'], f'{mode} not supported.'
17 | self.mode = mode
18 |
19 | ## pre-processing
20 | self.im_to_tensor = transforms.Compose([
21 | transforms.ToTensor(),
22 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23 | ])
24 |
25 | self.root = os.path.join(cfg['root'], 'all_data')
26 | self.n_classes = cfg['n_classes']
27 |
28 | scale_range = tuple(float(i) for i in cfg['scales_range'].split(' '))
29 | crop_size = tuple(int(i) for i in cfg['crop_size'].split(' '))
30 |
31 | self.aug = Compose([
32 | ColorJitter(
33 | brightness=cfg['brightness'],
34 | contrast=cfg['contrast'],
35 | saturation=cfg['saturation']),
36 | RandomHorizontalFlip(cfg['p']),
37 | RandomScale(scale_range),
38 | RandomCrop(crop_size, pad_if_needed=True)
39 | ])
40 |
41 | self.val_resize = Resize(crop_size)
42 |
43 | self.mode = mode
44 | self.do_aug = do_aug
45 |
46 | if cfg['class_weight'] == 'enet':
47 | self.class_weight = np.array(
48 | [6.3040, 4.3505, 35.0686, 3.4997, 14.0079, 8.0937, 32.6272, 28.6828, 14.8280, 38.3528, 37.4353,
49 | 18.7975])
50 | elif cfg['class_weight'] == 'median_freq_balancing':
51 | self.class_weight = np.array(
52 | [0.2778, 0.1770, 4.7280, 0.1358, 0.7816, 0.3785, 3.7939, 2.5866, 0.8480, 6.5770, 5.8139, 1.2184])
53 | else:
54 | raise (f"{cfg['class_weight']} not support.")
55 |
56 | with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f:
57 | self.infos = f.readlines()
58 |
59 | def __len__(self):
60 | return len(self.infos)
61 |
62 | def __getitem__(self, index):
63 | image_path = self.infos[index].strip()
64 |
65 | image = Image.open(os.path.join(self.root, 'image', self.mode, image_path)) # RGB 0~255
66 | label = Image.open(os.path.join(self.root, 'label', self.mode, image_path)) # 1 channel 0~11
67 | # bound = Image.open(os.path.join(self.root, 'bound', self.mode, image_path))
68 |
69 | # move unlabel_id from 11 to 0
70 | label = np.asarray(label)
71 | label = label + 1
72 | label[label == 12] = 0
73 | label = Image.fromarray(label)
74 |
75 | sample = {
76 | 'image': image,
77 | # 'bound': bound,
78 | 'label': label,
79 | }
80 |
81 | if self.mode in ['train', 'trainval'] and self.do_aug:  # augment the training set only
82 | sample = self.aug(sample)
83 | else:
84 | sample = self.val_resize(sample)
85 |
86 | sample['image'] = self.im_to_tensor(sample['image'])
87 | sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long()
88 | # sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64)).long()
89 |
90 | sample['label_path'] = image_path.strip().split('/')[-1]  # keep the saved prediction's filename consistent with the label's
91 | return sample
92 |
93 | @property
94 | def cmap(self):
95 | return [
96 | (0, 0, 0), # unlabeled
97 |
98 | (128, 128, 128), # sky
99 | (128, 0, 0), # building
100 | (192, 192, 128), # pole
101 | (128, 64, 128), # road
102 | (0, 0, 192), # pavement sidewalk
103 | (128, 128, 0), # tree
104 | (192, 128, 128), # sign_symbol
105 | (64, 64, 128), # fence
106 | (64, 0, 128), # car
107 | (64, 64, 0), # pedestrian
108 | (0, 128, 192), # bicyclist
109 |
110 | ]
111 |
112 |
113 | if __name__ == '__main__':
114 | import json
115 |
116 | path = '/home/dtrimina/Desktop/lxy/Segmentation_final/configs/bbbmodel/camvid_bbbmodel.json'
117 | with open(path, 'r') as fp:
118 | cfg = json.load(fp)
119 | cfg['root'] = '/home/dtrimina/Desktop/lxy/database/camvid'
120 |
121 |
122 | dataset = Camvid(cfg, mode='trainval', do_aug=True)
123 | from toolbox.utils import class_to_RGB
124 | import matplotlib.pyplot as plt
125 |
126 | for i in range(len(dataset)):
127 | sample = dataset[i]
128 |
129 | image = sample['image']
130 | label = sample['label']
131 |
132 | image = image.numpy()
133 | image = image.transpose((1, 2, 0))
134 | image *= np.asarray([0.229, 0.224, 0.225])
135 | image += np.asarray([0.485, 0.456, 0.406])
136 |
137 | label = label.numpy()
138 | label = class_to_RGB(label, N=len(dataset.cmap), cmap=dataset.cmap)
139 |
140 | plt.subplot(121)
141 | plt.imshow(image)
142 | plt.subplot(122)
143 | plt.imshow(label)
144 |
145 | plt.show()
146 |
147 | if i == 10:
148 | break
149 |
150 |
151 | # dataset = Camvid(cfg, mode='trainval', do_aug=False)
152 | # from toolbox.utils import ClassWeight
153 | #
154 | # train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg['ims_per_gpu'], shuffle=True,
155 | # num_workers=cfg['num_workers'], pin_memory=True)
156 | # classweight = ClassWeight('median_freq_balancing') # enet, median_freq_balancing
157 | # class_weight = classweight.get_weight(train_loader, cfg['n_classes'])
158 | # class_weight = torch.from_numpy(class_weight).float()
159 | # # class_weight[cfg['id_unlabel']] = 0
160 | #
161 | # print(class_weight)
162 | #
163 | # # # median_freq_balancing
164 | # # tensor([0.2778, 0.1770, 4.7280, 0.1358, 0.7816, 0.3785, 3.7939, 2.5866, 0.8480,
165 | # # 6.5770, 5.8139, 1.2184])
166 | #
167 | # # # enet
168 | # # tensor([6.3040, 4.3505, 35.0686, 3.4997, 14.0079, 8.0937, 32.6272, 28.6828,
169 | # # 14.8280, 38.3528, 37.4353, 18.7975])
170 |
--------------------------------------------------------------------------------
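The hard-coded `class_weight` arrays above are meant to be handed to the loss function. A minimal sketch of that wiring (not part of the repo; `ignore_index=0` assumes the unlabeled id, which `__getitem__` moves to 0, should not contribute to the loss):

```python
import torch
import torch.nn as nn

dataset = Camvid(cfg, mode='train', do_aug=True)               # cfg loaded as in the __main__ block
weight = torch.from_numpy(dataset.class_weight).float()        # 12 entries, unlabeled first
criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=0) # skip the unlabeled class
```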
/toolbox/datasets/irseg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from sklearn.model_selection import train_test_split
5 |
6 | import torch
7 | import torch.utils.data as data
8 | from torchvision import transforms
9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale, \
10 | RandomRotation
11 |
12 |
13 | class IRSeg(data.Dataset):
14 |
15 | def __init__(self, cfg, mode='trainval', do_aug=True):
16 |
17 |         assert mode in ['train', 'val', 'trainval', 'test', 'test_day', 'test_night'], f'{mode} is not supported.'
18 | self.mode = mode
19 |
20 | ## pre-processing
21 | self.im_to_tensor = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24 | ])
25 |
26 | self.dp_to_tensor = transforms.Compose([
27 | transforms.ToTensor(),
28 | transforms.Normalize([0.449, 0.449, 0.449], [0.226, 0.226, 0.226]),
29 | ])
30 |
31 | self.root = cfg['root']
32 | self.n_classes = cfg['n_classes']
33 |
34 | scale_range = tuple(float(i) for i in cfg['scales_range'].split(' '))
35 | crop_size = tuple(int(i) for i in cfg['crop_size'].split(' '))
36 |
37 | self.aug = Compose([
38 | ColorJitter(
39 | brightness=cfg['brightness'],
40 | contrast=cfg['contrast'],
41 | saturation=cfg['saturation']),
42 | RandomHorizontalFlip(cfg['p']),
43 | RandomScale(scale_range),
44 | RandomCrop(crop_size, pad_if_needed=True)
45 | ])
46 |
47 |
48 | self.mode = mode
49 | self.do_aug = do_aug
50 |
51 | if cfg['class_weight'] == 'enet':
52 | self.class_weight = np.array(
53 | [1.5105, 16.6591, 29.4238, 34.6315, 40.0845, 41.4357, 47.9794, 45.3725, 44.9000])
54 | self.binary_class_weight = np.array([1.5121, 10.2388])
55 | elif cfg['class_weight'] == 'median_freq_balancing':
56 | self.class_weight = np.array(
57 | [0.0118, 0.2378, 0.7091, 1.0000, 1.9267, 1.5433, 0.9057, 3.2556, 1.0686])
58 | self.binary_class_weight = np.array([0.5454, 6.0061])
59 | else:
60 |             raise ValueError(f"{cfg['class_weight']} is not supported.")
61 |
62 | with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f:
63 | self.infos = f.readlines()
64 |
65 | def __len__(self):
66 | return len(self.infos)
67 |
68 | def __getitem__(self, index):
69 | image_path = self.infos[index].strip()
70 |
71 |
72 | image = Image.open(os.path.join(self.root, 'seperated_images', image_path + '_rgb.png'))
73 | depth = Image.open(os.path.join(self.root, 'seperated_images', image_path + '_th.png')).convert('RGB')
74 | label = Image.open(os.path.join(self.root, 'labels', image_path + '.png'))
75 | bound = Image.open(os.path.join(self.root, 'bound', image_path+'.png'))
76 | edge = Image.open(os.path.join(self.root, 'edge', image_path+'.png'))
77 | binary_label = Image.open(os.path.join(self.root, 'binary_labels', image_path + '.png'))
78 |
79 |
80 | sample = {
81 | 'image': image,
82 | 'depth': depth, # depth is TIR image.
83 | 'label': label,
84 | 'bound': bound,
85 | 'edge': edge,
86 | 'binary_label': binary_label,
87 | }
88 |
89 |         if self.mode in ['train', 'trainval'] and self.do_aug:  # augment the training set only
90 | sample = self.aug(sample)
91 |
92 | sample['image'] = self.im_to_tensor(sample['image'])
93 | sample['depth'] = self.dp_to_tensor(sample['depth'])
94 | sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long()
95 | sample['edge'] = torch.from_numpy(np.asarray(sample['edge'], dtype=np.int64)).long()
96 | sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64) / 255.).long()
97 | sample['binary_label'] = torch.from_numpy(np.asarray(sample['binary_label'], dtype=np.int64) / 255.).long()
98 |         sample['label_path'] = image_path.strip().split('/')[-1] + '.png'  # keep the saved prediction's filename consistent with the label's
99 | return sample
100 |
101 | @property
102 | def cmap(self):
103 | return [
104 | (0, 0, 0), # unlabelled
105 | (64, 0, 128), # car
106 | (64, 64, 0), # person
107 | (0, 128, 192), # bike
108 | (0, 0, 192), # curve
109 | (128, 128, 0), # car_stop
110 | (64, 64, 128), # guardrail
111 | (192, 128, 128), # color_cone
112 | (192, 64, 0), # bump
113 | ]
114 |
115 |
116 |
--------------------------------------------------------------------------------
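For reference, a sketch of how `IRSeg` is driven. The config keys are inferred from `__init__` above; the values here are placeholders for what `configs/LASNet.json` actually provides:

```python
import torch.utils.data

cfg = {
    'root': '/path/to/MFNet',                        # placeholder dataset root
    'n_classes': 9,
    'scales_range': '0.5 2.0', 'crop_size': '480 640',
    'brightness': 0.5, 'contrast': 0.5, 'saturation': 0.5, 'p': 0.5,
    'class_weight': 'enet',
}
train_set = IRSeg(cfg, mode='train', do_aug=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4,
                                           shuffle=True, num_workers=4, pin_memory=True)
# each sample carries: image, depth (the TIR image), label, bound, edge,
# binary_label, and label_path
```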
/toolbox/datasets/pst900.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from sklearn.model_selection import train_test_split
5 |
6 | import torch
7 | import torch.utils.data as data
8 | from torchvision import transforms
9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale, \
10 | RandomRotation
11 |
12 |
13 | class PSTSeg(data.Dataset):
14 |
15 | def __init__(self, cfg, mode='trainval', do_aug=True):
16 |
17 |         assert mode in ['train', 'val', 'trainval', 'test'], f'{mode} is not supported.'
18 | self.mode = mode
19 |
20 | ## pre-processing
21 | self.im_to_tensor = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
24 | ])
25 |
26 | self.dp_to_tensor = transforms.Compose([
27 | transforms.ToTensor(),
28 | transforms.Normalize([0.449, 0.449, 0.449], [0.226, 0.226, 0.226]),
29 | ])
30 |
31 | self.root = cfg['root']
32 | self.n_classes = cfg['n_classes']
33 |
34 | scale_range = tuple(float(i) for i in cfg['scales_range'].split(' '))
35 | crop_size = tuple(int(i) for i in cfg['crop_size'].split(' '))
36 |
37 | self.aug = Compose([
38 | ColorJitter(
39 | brightness=cfg['brightness'],
40 | contrast=cfg['contrast'],
41 | saturation=cfg['saturation']),
42 | RandomHorizontalFlip(cfg['p']),
43 | RandomScale(scale_range),
44 | RandomCrop(crop_size, pad_if_needed=True)
45 | ])
46 |
47 |
48 | self.mode = mode
49 | self.do_aug = do_aug
50 |
51 | if cfg['class_weight'] == 'enet':
52 | self.class_weight = np.array(
53 | [1.4537, 44.2457, 31.6650, 46.4071, 30.1391])
54 | self.binary_class_weight = np.array([1.4507, 21.5033])
55 | else:
56 |             raise ValueError(f"{cfg['class_weight']} is not supported.")
57 |
58 | with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f:
59 | self.infos = f.readlines()
60 |
61 | def __len__(self):
62 | return len(self.infos)
63 |
64 | def __getitem__(self, index):
65 | image_path = self.infos[index].strip()
66 |
67 |
68 | image = Image.open(os.path.join(self.root, 'rgb', image_path + '.png'))
69 | depth = Image.open(os.path.join(self.root, 'thermal', image_path + '.png')).convert('RGB')
70 | label = Image.open(os.path.join(self.root, 'labels', image_path + '.png'))
71 | bound = Image.open(os.path.join(self.root, 'bound', image_path+'.png'))
72 |         edge = Image.open(os.path.join(self.root, 'bound', image_path+'.png'))  # PST900 has no edge labels, so 'bound' doubles as 'edge'
73 | binary_label = Image.open(os.path.join(self.root, 'binary_labels', image_path + '.png'))
74 |
75 |
76 | sample = {
77 | 'image': image,
78 | 'depth': depth,
79 | 'label': label,
80 | 'bound': bound,
81 | 'edge': edge,
82 | 'binary_label': binary_label,
83 | }
84 |
85 |         if self.mode in ['train', 'trainval'] and self.do_aug:  # augment the training set only
86 | sample = self.aug(sample)
87 |
88 | sample['image'] = self.im_to_tensor(sample['image'])
89 | sample['depth'] = self.dp_to_tensor(sample['depth'])
90 | sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long()
91 |         sample['edge'] = torch.from_numpy(np.asarray(sample['edge'], dtype=np.int64)).long()  # no dedicated edge labels; 'bound' is reused (see above)
92 | sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64) / 255.).long()
93 | sample['binary_label'] = torch.from_numpy(np.asarray(sample['binary_label'], dtype=np.int64) / 255.).long()
94 |         sample['label_path'] = image_path.strip().split('/')[-1] + '.png'  # keep the saved prediction's filename consistent with the label's
95 | return sample
96 |
97 | @property
98 | def cmap(self):
99 | return [
100 | [0, 0, 0], # background
101 | [0, 0, 255], # fire_extinguisher
102 | [0, 255, 0], # backpack
103 | [255, 0, 0], # drill
104 | [255, 255, 255], # survivor/rescue_randy
105 | ]
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
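A small sketch showing how a PST900 prediction can be colorized with the `cmap` above and `class_to_RGB` from `toolbox/utils.py` (the prediction here is random, for illustration only):

```python
import numpy as np
from toolbox.utils import class_to_RGB

pst900_cmap = [[0, 0, 0], [0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 255, 255]]
pred = np.random.randint(0, 5, (720, 1280))          # stand-in for an argmaxed prediction
colored = class_to_RGB(pred, N=5, cmap=pst900_cmap)  # uint8 array of shape (720, 1280, 3)
```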
/toolbox/dual_self_att.py:
--------------------------------------------------------------------------------
1 | ###########################################################################
2 | # Created by: CASIA IVA
3 | # Email: jliu@nlpr.ia.ac.cn
4 | # Copyright (c) 2018
5 | ###########################################################################
6 |
7 | import numpy as np
8 | import torch
9 | import math
10 | from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveMaxPool2d, AdaptiveAvgPool2d, \
11 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding
12 | from torch.nn import functional as F
13 | from torch.autograd import Variable
14 | torch_ver = torch.__version__[:3]
15 |
16 | __all__ = ['PAM_Module', 'CAM_Module']
17 |
18 |
19 | class PAM_Module(Module):
20 | """ Position attention module"""
21 | #Ref from SAGAN
22 | def __init__(self, in_dim):
23 | super(PAM_Module, self).__init__()
24 |         self.channel_in = in_dim
25 |
26 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
27 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
28 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
29 | self.gamma = Parameter(torch.zeros(1))
30 |
31 | self.softmax = Softmax(dim=-1)
32 | def forward(self, x):
33 | """
34 | inputs :
35 | x : input feature maps( B X C X H X W)
36 | returns :
37 | out : attention value + input feature
38 | attention: B X (HxW) X (HxW)
39 | """
40 | m_batchsize, C, height, width = x.size()
41 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)
42 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)
43 | energy = torch.bmm(proj_query, proj_key)
44 | attention = self.softmax(energy)
45 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)
46 |
47 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
48 | out = out.view(m_batchsize, C, height, width)
49 |
50 | out = self.gamma*out + x
51 | return out
52 |
53 |
54 | class CAM_Module(Module):
55 | """ Channel attention module"""
56 | def __init__(self, in_dim):
57 | super(CAM_Module, self).__init__()
58 |         self.channel_in = in_dim
59 |
60 |
61 | self.gamma = Parameter(torch.zeros(1))
62 | self.softmax = Softmax(dim=-1)
63 | def forward(self,x):
64 | """
65 | inputs :
66 | x : input feature maps( B X C X H X W)
67 | returns :
68 | out : attention value + input feature
69 | attention: B X C X C
70 | """
71 | m_batchsize, C, height, width = x.size()
72 | proj_query = x.view(m_batchsize, C, -1)
73 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
74 | energy = torch.bmm(proj_query, proj_key)
75 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
76 | attention = self.softmax(energy_new)
77 | proj_value = x.view(m_batchsize, C, -1)
78 |
79 | out = torch.bmm(attention, proj_value)
80 | out = out.view(m_batchsize, C, height, width)
81 |
82 | out = self.gamma*out + x
83 | return out
84 |
85 |
--------------------------------------------------------------------------------
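Both attention modules are residual (`out = gamma * attention_out + x`, with `gamma` initialized to zero), so they preserve the input shape. A quick sanity check (a sketch, not in the repo):

```python
import torch
from toolbox.dual_self_att import PAM_Module, CAM_Module

x = torch.randn(2, 64, 15, 20)              # B x C x H x W
assert PAM_Module(64)(x).shape == x.shape   # attention over the HxW positions
assert CAM_Module(64)(x).shape == x.shape   # attention over the C channels
```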
/toolbox/log.py:
--------------------------------------------------------------------------------
1 | """
2 | 日志记录
3 | 同时输出到屏幕和文件
4 | 可以通过日志等级,将训练最后得到的结果发送到邮箱,参考下面example
5 |
6 | """
7 |
8 | import logging
9 | import os
10 | import sys
11 | import time
12 |
13 |
14 | def get_logger(logdir):
15 |
16 | if not os.path.exists(logdir):
17 | os.makedirs(logdir)
18 | logname = f'run-{time.strftime("%Y-%m-%d-%H-%M")}.log'
19 | log_file = os.path.join(logdir, logname)
20 |
21 | # create log
22 | logger = logging.getLogger('train')
23 | logger.setLevel(logging.INFO)
24 |
25 | # Formatter 设置日志输出格式
26 | formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
27 |
28 | # StreamHandler 日志输出1 -> 到控制台
29 | stream_handler = logging.StreamHandler(sys.stdout)
30 | stream_handler.setFormatter(formatter)
31 | logger.addHandler(stream_handler)
32 |
33 | # FileHandler 日志输出2 -> 保存到文件log_file
34 | file_handler = logging.FileHandler(log_file)
35 | file_handler.setFormatter(formatter)
36 | logger.addHandler(file_handler)
37 |
38 | return logger
39 |
40 |
41 | # # example: send results to an email address
42 | # from logging.handlers import SMTPHandler
43 | #
44 | # logger = logging.getLogger('train')
45 | # logger.setLevel(logging.INFO)
46 | #
47 | # SMTP_handler = SMTPHandler(
48 | # mailhost=('smtp.163.com', 25),
49 | # fromaddr='xxx163emailxxx@163.com',
50 | # toaddrs=['xxxqqemailxxx@qq.com', 'or other emails you want to send'],
51 | # subject='send title',
52 | # credentials=('fromaddr email', 'fromaddr passwd')
53 | # )
54 | #
55 | # formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
56 | # SMTP_handler.setFormatter(formatter)
57 | # SMTP_handler.setLevel(logging.WARNING)  # set the level to WARNING; logger.warning('infos') will then email important results
58 | # logger.addHandler(SMTP_handler)
59 | #
60 | # logging.warning('information need to be send to email. the final results_old or errors')
61 |
62 |
--------------------------------------------------------------------------------
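Typical usage is one logger per run (note that `get_logger` adds fresh handlers on every call, so calling it repeatedly duplicates output lines). A short sketch with a hypothetical log directory:

```python
from toolbox.log import get_logger

logger = get_logger('./run/example_logdir')          # hypothetical directory
logger.info('epoch 1/100 | loss 0.1234 | mIoU 0.56')
# writes to stdout and to ./run/example_logdir/run-YYYY-mm-dd-HH-MM.log
```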
/toolbox/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | """
5 | #https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
6 | from __future__ import print_function, division
7 |
8 | import torch
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import numpy as np
12 | try:
13 | from itertools import ifilterfalse
14 | except ImportError: # py3k
15 | from itertools import filterfalse as ifilterfalse
16 |
17 |
18 | def lovasz_grad(gt_sorted):
19 | """
20 | Computes gradient of the Lovasz extension w.r.t sorted errors
21 | See Alg. 1 in paper
22 | """
23 | p = len(gt_sorted)
24 | gts = gt_sorted.sum()
25 | intersection = gts - gt_sorted.float().cumsum(0)
26 | union = gts + (1 - gt_sorted).float().cumsum(0)
27 | jaccard = 1. - intersection / union
28 | if p > 1: # cover 1-pixel case
29 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
30 | return jaccard
31 |
32 |
33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
34 | """
35 | IoU for foreground class
36 | binary: 1 foreground, 0 background
37 | """
38 | if not per_image:
39 | preds, labels = (preds,), (labels,)
40 | ious = []
41 | for pred, label in zip(preds, labels):
42 | intersection = ((label == 1) & (pred == 1)).sum()
43 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
44 | if not union:
45 | iou = EMPTY
46 | else:
47 | iou = float(intersection) / float(union)
48 | ious.append(iou)
49 |     iou = mean(ious)    # mean across images if per_image
50 | return 100 * iou
51 |
52 |
53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
54 | """
55 | Array of IoU for each (non ignored) class
56 | """
57 | if not per_image:
58 | preds, labels = (preds,), (labels,)
59 | ious = []
60 | for pred, label in zip(preds, labels):
61 | iou = []
62 | for i in range(C):
63 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
64 | intersection = ((label == i) & (pred == i)).sum()
65 | union = ((label == i) | ((pred == i) & (label != ignore))).sum()
66 | if not union:
67 | iou.append(EMPTY)
68 | else:
69 | iou.append(float(intersection) / float(union))
70 | ious.append(iou)
71 |     ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
72 | return 100 * np.array(ious)
73 |
74 |
75 | # --------------------------- BINARY LOSSES ---------------------------
76 |
77 |
78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
79 | """
80 | Binary Lovasz hinge loss
81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
83 | per_image: compute the loss per image instead of per batch
84 | ignore: void class id
85 | """
86 | if per_image:
87 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
88 | for log, lab in zip(logits, labels))
89 | else:
90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
91 | return loss
92 |
93 |
94 | def lovasz_hinge_flat(logits, labels):
95 | """
96 | Binary Lovasz hinge loss
97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
98 | labels: [P] Tensor, binary ground truth labels (0 or 1)
99 | ignore: label to ignore
100 | """
101 | if len(labels) == 0:
102 | # only void pixels, the gradients should be 0
103 | return logits.sum() * 0.
104 | signs = 2. * labels.float() - 1.
105 | errors = (1. - logits * Variable(signs))
106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
107 | perm = perm.data
108 | gt_sorted = labels[perm]
109 | grad = lovasz_grad(gt_sorted)
110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad))
111 | return loss
112 |
113 |
114 | def flatten_binary_scores(scores, labels, ignore=None):
115 | """
116 | Flattens predictions in the batch (binary case)
117 | Remove labels equal to 'ignore'
118 | """
119 | scores = scores.view(-1)
120 | labels = labels.view(-1)
121 | if ignore is None:
122 | return scores, labels
123 | valid = (labels != ignore)
124 | vscores = scores[valid]
125 | vlabels = labels[valid]
126 | return vscores, vlabels
127 |
128 |
129 | class StableBCELoss(torch.nn.modules.Module):
130 | def __init__(self):
131 | super(StableBCELoss, self).__init__()
132 | def forward(self, input, target):
133 | neg_abs = - input.abs()
134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
135 | return loss.mean()
136 |
137 |
138 | def binary_xloss(logits, labels, ignore=None):
139 | """
140 | Binary Cross entropy loss
141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
143 | ignore: void class id
144 | """
145 | logits, labels = flatten_binary_scores(logits, labels, ignore)
146 | loss = StableBCELoss()(logits, Variable(labels.float()))
147 | return loss
148 |
149 |
150 | # --------------------------- MULTICLASS LOSSES ---------------------------
151 |
152 |
153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
154 | """
155 | Multi-class Lovasz-Softmax loss
156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
160 | per_image: compute the loss per image instead of per batch
161 | ignore: void class labels
162 | """
163 | if per_image:
164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
165 | for prob, lab in zip(probas, labels))
166 | else:
167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
168 | return loss
169 |
170 |
171 | def lovasz_softmax_flat(probas, labels, classes='present'):
172 | """
173 | Multi-class Lovasz-Softmax loss
174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
177 | """
178 | if probas.numel() == 0:
179 | # only void pixels, the gradients should be 0
180 | return probas * 0.
181 | C = probas.size(1)
182 | losses = []
183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
184 | for c in class_to_sum:
185 | fg = (labels == c).float() # foreground for class c
186 |         if classes == 'present' and fg.sum() == 0:
187 | continue
188 | if C == 1:
189 | if len(classes) > 1:
190 | raise ValueError('Sigmoid output possible only with 1 class')
191 | class_pred = probas[:, 0]
192 | else:
193 | class_pred = probas[:, c]
194 | errors = (Variable(fg) - class_pred).abs()
195 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
196 | perm = perm.data
197 | fg_sorted = fg[perm]
198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
199 | return mean(losses)
200 |
201 |
202 | def flatten_probas(probas, labels, ignore=None):
203 | """
204 | Flattens predictions in the batch
205 | """
206 | if probas.dim() == 3:
207 | # assumes output of a sigmoid layer
208 | B, H, W = probas.size()
209 | probas = probas.view(B, 1, H, W)
210 | B, C, H, W = probas.size()
211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
212 | labels = labels.view(-1)
213 | if ignore is None:
214 | return probas, labels
215 | valid = (labels != ignore)
216 | vprobas = probas[valid.nonzero().squeeze()]
217 | vlabels = labels[valid]
218 | return vprobas, vlabels
219 |
220 | def xloss(logits, labels, ignore=None):
221 | """
222 | Cross entropy loss
223 | """
224 |     return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
225 |
226 |
227 | # --------------------------- HELPER FUNCTIONS ---------------------------
228 | def isnan(x):
229 | return x != x
230 |
231 |
232 | def mean(l, ignore_nan=False, empty=0):
233 | """
234 | nanmean compatible with generators.
235 | """
236 | l = iter(l)
237 | if ignore_nan:
238 | l = ifilterfalse(isnan, l)
239 | try:
240 | n = 1
241 | acc = next(l)
242 | except StopIteration:
243 | if empty == 'raise':
244 | raise ValueError('Empty mean')
245 | return empty
246 | for n, v in enumerate(l, 2):
247 | acc += v
248 | if n == 1:
249 | return acc
250 | return acc / n
--------------------------------------------------------------------------------
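Note that `lovasz_softmax` expects probabilities, not logits, so a softmax goes in front. A minimal sketch:

```python
import torch
import torch.nn.functional as F
from toolbox.losses import lovasz_softmax

logits = torch.randn(2, 9, 480, 640, requires_grad=True)   # B x C x H x W
labels = torch.randint(0, 9, (2, 480, 640))
loss = lovasz_softmax(F.softmax(logits, dim=1), labels, classes='present')
loss.backward()
```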
/toolbox/metrics.py:
--------------------------------------------------------------------------------
1 | # https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/metrics.py
2 |
3 | import numpy as np
4 |
5 |
6 | class runningScore(object):
7 |     '''
8 |     n_classes: number of classes in the dataset, including background
9 |     ignore_index: class id(s) to ignore, usually the unlabeled id, e.g. CamVid.id_unlabel
10 |     '''
11 |
12 | def __init__(self, n_classes, ignore_index=None):
13 | self.n_classes = n_classes
14 | self.confusion_matrix = np.zeros((n_classes, n_classes))
15 |
16 |         if ignore_index is None or (isinstance(ignore_index, int) and not 0 <= ignore_index < n_classes):
17 | self.ignore_index = None
18 | elif isinstance(ignore_index, int):
19 | self.ignore_index = (ignore_index,)
20 | else:
21 | try:
22 | self.ignore_index = tuple(ignore_index)
23 | except TypeError:
24 | raise ValueError("'ignore_index' must be an int or iterable")
25 |
26 | def _fast_hist(self, label_true, label_pred, n_class):
27 | mask = (label_true >= 0) & (label_true < n_class)
28 | hist = np.bincount(
29 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
30 | ).reshape(n_class, n_class)
31 | return hist
32 |
33 | def update(self, label_trues, label_preds):
34 | for lt, lp in zip(label_trues, label_preds):
35 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
36 |
37 | def get_scores(self):
38 | """Returns accuracy score evaluation result.
39 | - pixel_acc:
40 | - class_acc: class mean acc
41 | - mIou : mean intersection over union
42 | - fwIou: frequency weighted intersection union
43 | """
44 |
45 | hist = self.confusion_matrix
46 |
47 | # ignore unlabel
48 | if self.ignore_index is not None:
49 | for index in self.ignore_index:
50 | hist = np.delete(hist, index, axis=0)
51 | hist = np.delete(hist, index, axis=1)
52 |
53 | acc = np.diag(hist).sum() / hist.sum()
54 | cls_acc = np.diag(hist) / hist.sum(axis=1)
55 | acc_cls = np.nanmean(cls_acc)
56 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
57 | mean_iou = np.nanmean(iu)
58 | freq = hist.sum(axis=1) / hist.sum()
59 | fw_iou = (freq[freq > 0] * iu[freq > 0]).sum()
60 |
61 | # set unlabel as nan
62 | if self.ignore_index is not None:
63 | for index in self.ignore_index:
64 | iu = np.insert(iu, index, np.nan)
65 |
66 | cls_iu = dict(zip(range(self.n_classes), iu))
67 | cls_acc = dict(zip(range(self.n_classes), cls_acc))
68 |
69 | return (
70 | {
71 | "pixel_acc: ": acc,
72 | "class_acc: ": acc_cls,
73 | "mIou: ": mean_iou,
74 | "fwIou: ": fw_iou,
75 | },
76 | cls_iu,
77 | cls_acc,
78 | )
79 |
80 | def reset(self):
81 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
82 |
83 |
84 | class averageMeter(object):
85 | """Computes and stores the average and current value"""
86 |
87 | def __init__(self):
88 | self.reset()
89 |
90 | def reset(self):
91 | self.val = 0
92 | self.avg = 0
93 | self.sum = 0
94 | self.count = 0
95 |
96 | def update(self, val, n=1):
97 | self.val = val
98 | self.sum += val * n
99 | self.count += n
100 | self.avg = self.sum / self.count
101 |
--------------------------------------------------------------------------------
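A sketch of the intended evaluation loop: accumulate the confusion matrix batch by batch, then read the summary (random arrays stand in for real labels and predictions):

```python
import numpy as np
from toolbox.metrics import runningScore

running_metrics = runningScore(n_classes=9, ignore_index=0)  # ignore the unlabeled class
gt = np.random.randint(0, 9, (4, 480, 640))
pred = np.random.randint(0, 9, (4, 480, 640))
running_metrics.update(gt, pred)
scores, cls_iu, cls_acc = running_metrics.get_scores()
print(scores)                                                # pixel_acc, class_acc, mIou, fwIou
running_metrics.reset()
```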
/toolbox/models/LASNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.nn as nn
3 | import torch
4 | from resnet import Backbone_ResNet152_in3
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from toolbox.dual_self_att import CAM_Module
8 |
9 |
10 | class BasicConv2d(nn.Module):
11 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
12 | super(BasicConv2d, self).__init__()
13 | self.conv = nn.Conv2d(in_planes, out_planes,
14 | kernel_size=kernel_size, stride=stride,
15 | padding=padding, dilation=dilation, bias=False)
16 | self.bn = nn.BatchNorm2d(out_planes)
17 | #self.relu = nn.ReLU(inplace=True)
18 | self.relu = nn.LeakyReLU(0.1)
19 |
20 | def forward(self, x):
21 | x = self.conv(x)
22 | x = self.bn(x)
23 | x = self.relu(x)
24 | return x
25 |
26 | class ChannelAttention(nn.Module):
27 | def __init__(self, in_planes, ratio=4):
28 | super(ChannelAttention, self).__init__()
29 | self.max_pool = nn.AdaptiveMaxPool2d(1)
30 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
31 | self.relu1 = nn.ReLU()
32 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
33 |
34 | self.sigmoid = nn.Sigmoid()
35 |
36 | def forward(self, x):
37 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
38 | out = max_out
39 | return self.sigmoid(out)
40 |
41 | class SpatialAttention(nn.Module):
42 | def __init__(self, kernel_size=3):
43 | super(SpatialAttention, self).__init__()
44 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
45 | padding = 3 if kernel_size == 7 else 1
46 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
47 | self.sigmoid = nn.Sigmoid()
48 |
49 | def forward(self, x):
50 | max_out, _ = torch.max(x, dim=1, keepdim=True)
51 | x = max_out
52 | x = self.conv1(x)
53 | return self.sigmoid(x)
54 |
55 |
56 | class CorrelationModule(nn.Module):
57 | def __init__(self, all_channel=64):
58 | super(CorrelationModule, self).__init__()
59 | self.linear_e = nn.Linear(all_channel, all_channel,bias = False)
60 | self.channel = all_channel
61 | self.fusion = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
62 |
63 | def forward(self, exemplar, query): # exemplar: middle, query: rgb or T
64 | fea_size = exemplar.size()[2:]
65 | all_dim = fea_size[0]*fea_size[1]
66 | exemplar_flat = exemplar.view(-1, self.channel, all_dim) #N,C,H*W
67 | query_flat = query.view(-1, self.channel, all_dim)
68 | exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous() #batchsize x dim x num, N,H*W,C
69 | exemplar_corr = self.linear_e(exemplar_t) #
70 | A = torch.bmm(exemplar_corr, query_flat)
71 | B = F.softmax(torch.transpose(A,1,2),dim=1)
72 | exemplar_att = torch.bmm(query_flat, B).contiguous()
73 |
74 | exemplar_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1])
75 | exemplar_out = self.fusion(exemplar_att)
76 |
77 | return exemplar_out
78 |
79 | class CLM(nn.Module):
80 | def __init__(self, all_channel=64):
81 | super(CLM, self).__init__()
82 | self.corr_x_2_x_ir = CorrelationModule(all_channel)
83 | self.corr_ir_2_x_ir = CorrelationModule(all_channel)
84 | self.smooth1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
85 | self.smooth2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
86 | self.fusion = BasicConv2d(2*all_channel, all_channel, kernel_size=3, padding=1)
87 |         self.pred = nn.Conv2d(all_channel, 2, kernel_size=3, padding=1, bias=True)
88 |
89 | def forward(self, x, x_ir, ir): # exemplar: middle, query: rgb or T
90 | corr_x_2_x_ir = self.corr_x_2_x_ir(x_ir,x)
91 | corr_ir_2_x_ir = self.corr_ir_2_x_ir(x_ir,ir)
92 |
93 | summation = self.smooth1(corr_x_2_x_ir + corr_ir_2_x_ir)
94 | multiplication = self.smooth2(corr_x_2_x_ir * corr_ir_2_x_ir)
95 |
96 | fusion = self.fusion(torch.cat([summation,multiplication],1))
97 | sal_pred = self.pred(fusion)
98 |
99 | return fusion, sal_pred
100 |
101 |
102 | class CAM(nn.Module):
103 | def __init__(self, all_channel=64):
104 | super(CAM, self).__init__()
105 | #self.conv1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
106 | self.conv2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
107 | self.sa = SpatialAttention()
108 | # self-channel attention
109 | self.cam = CAM_Module(all_channel)
110 |
111 | def forward(self, x, ir):
112 | multiplication = x * ir
113 | summation = self.conv2(x + ir)
114 |
115 | sa = self.sa(multiplication)
116 | summation_sa = summation.mul(sa)
117 |
118 | sc_feat = self.cam(summation_sa)
119 |
120 | return sc_feat
121 |
122 |
123 | class ESM(nn.Module):
124 | def __init__(self, all_channel=64):
125 | super(ESM, self).__init__()
126 | self.conv1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
127 | self.conv2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1)
128 |         self.dconv1 = BasicConv2d(all_channel, all_channel // 4, kernel_size=3, padding=1)
129 |         self.dconv2 = BasicConv2d(all_channel, all_channel // 4, kernel_size=3, dilation=3, padding=3)
130 |         self.dconv3 = BasicConv2d(all_channel, all_channel // 4, kernel_size=3, dilation=5, padding=5)
131 |         self.dconv4 = BasicConv2d(all_channel, all_channel // 4, kernel_size=3, dilation=7, padding=7)
132 |         self.fuse_dconv = nn.Conv2d(all_channel, all_channel, kernel_size=3, padding=1)
133 |         self.pred = nn.Conv2d(all_channel, 2, kernel_size=3, padding=1, bias=True)
134 |
135 | def forward(self, x, ir):
136 | multiplication = self.conv1(x * ir)
137 | summation = self.conv2(x + ir)
138 | fusion = (summation + multiplication)
139 | x1 = self.dconv1(fusion)
140 | x2 = self.dconv2(fusion)
141 | x3 = self.dconv3(fusion)
142 | x4 = self.dconv4(fusion)
143 | out = self.fuse_dconv(torch.cat((x1, x2, x3, x4), dim=1))
144 | edge_pred = self.pred(out)
145 |
146 | return out, edge_pred
147 |
148 |
149 | class prediction_decoder(nn.Module):
150 | def __init__(self, channel1=64, channel2=128, channel3=256, channel4=256, channel5=512, n_classes=9):
151 | super(prediction_decoder, self).__init__()
152 | # 15 20
153 | self.decoder5 = nn.Sequential(
154 | nn.Dropout2d(p=0.1),
155 | BasicConv2d(channel5, channel5, kernel_size=3, padding=3, dilation=3),
156 | BasicConv2d(channel5, channel4, kernel_size=3, padding=1),
157 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
158 | )
159 | # 30 40
160 | self.decoder4 = nn.Sequential(
161 | nn.Dropout2d(p=0.1),
162 | BasicConv2d(channel4, channel4, kernel_size=3, padding=3, dilation=3),
163 | BasicConv2d(channel4, channel3, kernel_size=3, padding=1),
164 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
165 | )
166 | # 60 80
167 | self.decoder3 = nn.Sequential(
168 | nn.Dropout2d(p=0.1),
169 | BasicConv2d(channel3, channel3, kernel_size=3, padding=3, dilation=3),
170 | BasicConv2d(channel3, channel2, kernel_size=3, padding=1),
171 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
172 | )
173 | # 120 160
174 | self.decoder2 = nn.Sequential(
175 | nn.Dropout2d(p=0.1),
176 | BasicConv2d(channel2, channel2, kernel_size=3, padding=3, dilation=3),
177 | BasicConv2d(channel2, channel1, kernel_size=3, padding=1),
178 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
179 | )
180 | self.semantic_pred2 = nn.Conv2d(channel1, n_classes, kernel_size=3, padding=1)
181 | # 240 320 -> 480 640
182 | self.decoder1 = nn.Sequential(
183 | nn.Dropout2d(p=0.1),
184 | BasicConv2d(channel1, channel1, kernel_size=3, padding=3, dilation=3),
185 | BasicConv2d(channel1, channel1, kernel_size=3, padding=1),
186 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), # 480 640
187 | BasicConv2d(channel1, channel1, kernel_size=3, padding=1),
188 | nn.Conv2d(channel1, n_classes, kernel_size=3, padding=1)
189 | )
190 |
191 | def forward(self, x5, x4, x3, x2, x1):
192 | x5_decoder = self.decoder5(x5)
193 | # for PST900 dataset
194 |         # since the input size is 720x1280, the sizes of x5_decoder and x4_decoder are 23 and 45, so we cannot use 2x upsampling directly.
195 | # x5_decoder = F.interpolate(x5_decoder, size=fea_size, mode="bilinear", align_corners=True)
196 | x4_decoder = self.decoder4(x5_decoder + x4)
197 | x3_decoder = self.decoder3(x4_decoder + x3)
198 | x2_decoder = self.decoder2(x3_decoder + x2)
199 | semantic_pred2 = self.semantic_pred2(x2_decoder)
200 | semantic_pred = self.decoder1(x2_decoder + x1)
201 |
202 |         return semantic_pred, semantic_pred2
203 |
204 |
205 | class LASNet(nn.Module):
206 | def __init__(self, n_classes):
207 | super(LASNet, self).__init__()
208 |
209 | (
210 | self.layer1_rgb,
211 | self.layer2_rgb,
212 | self.layer3_rgb,
213 | self.layer4_rgb,
214 | self.layer5_rgb,
215 | ) = Backbone_ResNet152_in3(pretrained=True)
216 |
217 | # reduce the channel number, input: 480 640
218 | self.rgbconv1 = BasicConv2d(64, 64, kernel_size=3, padding=1) # 240 320
219 | self.rgbconv2 = BasicConv2d(256, 128, kernel_size=3, padding=1) # 120 160
220 | self.rgbconv3 = BasicConv2d(512, 256, kernel_size=3, padding=1) # 60 80
221 | self.rgbconv4 = BasicConv2d(1024, 256, kernel_size=3, padding=1) # 30 40
222 | self.rgbconv5 = BasicConv2d(2048, 512, kernel_size=3, padding=1) # 15 20
223 |
224 | self.CLM5 = CLM(512)
225 | self.CAM4 = CAM(256)
226 | self.CAM3 = CAM(256)
227 | self.CAM2 = CAM(128)
228 | self.ESM1 = ESM(64)
229 |
230 | self.decoder = prediction_decoder(64,128,256,256,512, n_classes)
231 |
232 | def forward(self, rgb, depth):
233 | x = rgb
234 | ir = depth[:, :1, ...]
235 | ir = torch.cat((ir, ir, ir), dim=1)
236 |
237 | x1 = self.layer1_rgb(x)
238 | x2 = self.layer2_rgb(x1)
239 | x3 = self.layer3_rgb(x2)
240 | x4 = self.layer4_rgb(x3)
241 | x5 = self.layer5_rgb(x4)
242 |
243 | ir1 = self.layer1_rgb(ir)
244 | ir2 = self.layer2_rgb(ir1)
245 | ir3 = self.layer3_rgb(ir2)
246 | ir4 = self.layer4_rgb(ir3)
247 | ir5 = self.layer5_rgb(ir4)
248 |
249 | x1 = self.rgbconv1(x1)
250 | x2 = self.rgbconv2(x2)
251 | x3 = self.rgbconv3(x3)
252 | x4 = self.rgbconv4(x4)
253 | x5 = self.rgbconv5(x5)
254 |
255 | ir1 = self.rgbconv1(ir1)
256 | ir2 = self.rgbconv2(ir2)
257 | ir3 = self.rgbconv3(ir3)
258 | ir4 = self.rgbconv4(ir4)
259 | ir5 = self.rgbconv5(ir5)
260 |
261 | out5, sal = self.CLM5(x5, x5*ir5, ir5)
262 | out4 = self.CAM4(x4, ir4)
263 | out3 = self.CAM3(x3, ir3)
264 | out2 = self.CAM2(x2, ir2)
265 | out1, edge = self.ESM1(x1, ir1)
266 |
267 | semantic, semantic2 = self.decoder(out5, out4, out3, out2, out1)
268 | semantic2 = torch.nn.functional.interpolate(semantic2, scale_factor=2, mode='bilinear')
269 | sal = torch.nn.functional.interpolate(sal, scale_factor=32, mode='bilinear')
270 | edge = torch.nn.functional.interpolate(edge, scale_factor=2, mode='bilinear')
271 |
272 |
273 | return semantic, semantic2, sal, edge
274 |
275 | if __name__ == '__main__':
276 | LASNet(9)
277 | # for PST900 dataset
278 | # LASNet(5)
279 |
--------------------------------------------------------------------------------
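A sketch of a full forward pass on MFNet-sized inputs (instantiating the model downloads the ImageNet-pretrained ResNet-152 backbone the first time, and `resnet.py` must be importable):

```python
import torch
from toolbox.models.LASNet import LASNet

model = LASNet(n_classes=9).eval()
rgb = torch.randn(1, 3, 480, 640)
thermal = torch.randn(1, 3, 480, 640)   # the TIR image is passed via the 'depth' argument
with torch.no_grad():
    semantic, semantic2, sal, edge = model(rgb, thermal)
# semantic and semantic2: (1, 9, 480, 640) class logits;
# sal and edge: (1, 2, 480, 640) auxiliary predictions from CLM and ESM
```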
/toolbox/models/__pycache__/EGFNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/EGFNet.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/models/__pycache__/EGFNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/EGFNet.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/models/__pycache__/LASNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LASNet.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/models/__pycache__/LgyTestNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LgyTestNet.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/models/__pycache__/LgyTestNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LgyTestNet.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/optim/Ranger.py:
--------------------------------------------------------------------------------
1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2 |
3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4 | # and/or
5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6 |
7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard.
8 |
9 | # This version = 20.4.11
10 |
11 | # Credits:
12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
16 |
17 | # summary of changes:
18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold;
22 | # changed eps to 1e-5 as better default than 1e-8.
23 |
24 | import math
25 | import torch
26 | from torch.optim.optimizer import Optimizer, required
27 |
28 |
29 | class Ranger(Optimizer):
30 |
31 | def __init__(self, params, lr=1e-3, # lr
32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
34 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers
35 | use_gc=True, gc_conv_only=False
36 | ):
37 |
38 | # parameter checks
39 | if not 0.0 <= alpha <= 1.0:
40 | raise ValueError(f'Invalid slow update rate: {alpha}')
41 | if not 1 <= k:
42 | raise ValueError(f'Invalid lookahead steps: {k}')
43 | if not lr > 0:
44 | raise ValueError(f'Invalid Learning Rate: {lr}')
45 | if not eps > 0:
46 | raise ValueError(f'Invalid eps: {eps}')
47 |
48 | # parameter comments:
49 | # beta1 (momentum) of .95 seems to work better than .90...
50 | # N_sma_threshold of 5 seems better in testing than 4.
51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
52 |
53 | # prep defaults and init torch.optim base
54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
55 | N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
56 | super().__init__(params, defaults)
57 |
58 | # adjustable threshold
59 | self.N_sma_threshhold = N_sma_threshhold
60 |
61 | # look ahead params
62 |
63 | self.alpha = alpha
64 | self.k = k
65 |
66 | # radam buffer for state
67 | self.radam_buffer = [[None, None, None] for ind in range(10)]
68 |
69 | # gc on or off
70 | self.use_gc = use_gc
71 |
72 | # level of gradient centralization
73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1
74 |
75 | print(
76 | f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
77 | if (self.use_gc and self.gc_gradient_threshold == 1):
78 | print(f"GC applied to both conv and fc layers")
79 | elif (self.use_gc and self.gc_gradient_threshold == 3):
80 | print(f"GC applied to conv layers only")
81 |
82 | def __setstate__(self, state):
83 | print("set state called")
84 | super(Ranger, self).__setstate__(state)
85 |
86 | def step(self, closure=None):
87 | loss = None
88 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
89 | # Uncomment if you need to use the actual closure...
90 |
91 | # if closure is not None:
92 | #loss = closure()
93 |
94 | # Evaluate averages and grad, update param tensors
95 | for group in self.param_groups:
96 |
97 | for p in group['params']:
98 | if p.grad is None:
99 | continue
100 | grad = p.grad.data.float()
101 |
102 | if grad.is_sparse:
103 | raise RuntimeError(
104 | 'Ranger optimizer does not support sparse gradients')
105 |
106 | p_data_fp32 = p.data.float()
107 |
108 | state = self.state[p] # get state dict for this param
109 |
110 | if len(state) == 0: # if first time to run...init dictionary with our desired entries
111 | # if self.first_run_check==0:
112 | # self.first_run_check=1
113 | #print("Initializing slow buffer...should not see this at load from saved model!")
114 | state['step'] = 0
115 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
116 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
117 |
118 | # look ahead weight storage now in state dict
119 | state['slow_buffer'] = torch.empty_like(p.data)
120 | state['slow_buffer'].copy_(p.data)
121 |
122 | else:
123 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
124 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
125 | p_data_fp32)
126 |
127 | # begin computations
128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
129 | beta1, beta2 = group['betas']
130 |
131 | # GC operation for Conv layers and FC layers
132 | if grad.dim() > self.gc_gradient_threshold:
133 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
134 |
135 | state['step'] += 1
136 |
137 | # compute variance mov avg
138 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
139 |                 # compute mean moving avg
140 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
141 |
142 | buffered = self.radam_buffer[int(state['step'] % 10)]
143 |
144 | if state['step'] == buffered[0]:
145 | N_sma, step_size = buffered[1], buffered[2]
146 | else:
147 | buffered[0] = state['step']
148 | beta2_t = beta2 ** state['step']
149 | N_sma_max = 2 / (1 - beta2) - 1
150 | N_sma = N_sma_max - 2 * \
151 | state['step'] * beta2_t / (1 - beta2_t)
152 | buffered[1] = N_sma
153 | if N_sma > self.N_sma_threshhold:
154 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
155 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
156 | else:
157 | step_size = 1.0 / (1 - beta1 ** state['step'])
158 | buffered[2] = step_size
159 |
160 |                 if group['weight_decay'] != 0:
161 |                     p_data_fp32.add_(p_data_fp32,
162 |                                      alpha=-group['weight_decay'] * group['lr'])
163 |
164 | # apply lr
165 | if N_sma > self.N_sma_threshhold:
166 | denom = exp_avg_sq.sqrt().add_(group['eps'])
167 |                     p_data_fp32.addcdiv_(exp_avg, denom,
168 |                                          value=-step_size * group['lr'])
169 | else:
170 |                     p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
171 |
172 | p.data.copy_(p_data_fp32)
173 |
174 | # integrated look ahead...
175 | # we do it at the param level instead of group level
176 | if state['step'] % group['k'] == 0:
177 | # get access to slow param tensor
178 | slow_p = state['slow_buffer']
179 | # (fast weights - slow weights) * alpha
180 |                     slow_p.add_(p.data - slow_p, alpha=self.alpha)
181 | # copy interpolated weights to RAdam param tensor
182 | p.data.copy_(slow_p)
183 |
184 | return loss
--------------------------------------------------------------------------------
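Ranger follows the standard `torch.optim` interface, so it drops into the usual training step. A minimal sketch:

```python
import torch
from toolbox.optim.Ranger import Ranger

model = torch.nn.Linear(10, 2)
optimizer = Ranger(model.parameters(), lr=1e-3, k=6, use_gc=True)
loss = model(torch.randn(4, 10)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()   # the lookahead (slow-weight) sync runs every k steps
```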
/toolbox/optim/__pycache__/Ranger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/optim/__pycache__/Ranger.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/optim/__pycache__/Ranger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/optim/__pycache__/Ranger.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import *
2 |
--------------------------------------------------------------------------------
/toolbox/scheduler/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
3 |
4 |
5 | class WarmupMultiStepLR(MultiStepLR):
6 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
7 | warmup_iters=500, last_epoch=-1):
8 | self.warmup_factor = warmup_factor
9 | self.warmup_iters = warmup_iters
10 | super().__init__(optimizer, milestones, gamma, last_epoch)
11 |
12 | def get_lr(self):
13 | if self.last_epoch <= self.warmup_iters:
14 | alpha = self.last_epoch / self.warmup_iters
15 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
16 | # print(self.base_lrs[0]*warmup_factor)
17 | return [lr * warmup_factor for lr in self.base_lrs]
18 | else:
19 | lr = super().get_lr()
20 | return lr
21 |
22 |
23 | class WarmupCosineLR(_LRScheduler):
24 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500,
25 | eta_min=0, last_epoch=-1):
26 | self.warmup_factor = warmup_factor
27 | self.warmup_iters = warmup_iters
28 | self.T_max, self.eta_min = T_max, eta_min
29 | super().__init__(optimizer, last_epoch)
30 |
31 | def get_lr(self):
32 | if self.last_epoch <= self.warmup_iters:
33 | alpha = self.last_epoch / self.warmup_iters
34 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
35 | # print(self.base_lrs[0]*warmup_factor)
36 | return [lr * warmup_factor for lr in self.base_lrs]
37 | else:
38 | return [self.eta_min + (base_lr - self.eta_min) *
39 | (1 + math.cos(
40 | math.pi * (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters))) / 2
41 | for base_lr in self.base_lrs]
42 |
43 |
44 | class WarmupPolyLR(_LRScheduler):
45 | def __init__(self, optimizer, T_max, cur_iter, warmup_factor=1.0 / 3, warmup_iters=500,
46 | eta_min=0, power=0.9):
47 | self.warmup_factor = warmup_factor
48 | self.warmup_iters = warmup_iters
49 | self.power = power
50 | self.T_max, self.eta_min = T_max, eta_min
51 | self.cur_iter = cur_iter
52 | super().__init__(optimizer)
53 |
54 | def get_lr(self):
55 | if self.cur_iter <= self.warmup_iters:
56 | alpha = self.cur_iter / self.warmup_iters
57 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
58 | # print(self.base_lrs[0]*warmup_factor)
59 | return [lr * warmup_factor for lr in self.base_lrs]
60 | else:
61 | return [self.eta_min + (base_lr - self.eta_min) *
62 | math.pow(1 - (self.cur_iter - self.warmup_iters) / (self.T_max - self.warmup_iters),
63 | self.power) for base_lr in self.base_lrs]
64 |
65 |
66 | def poly_learning_rate(cur_epoch, max_epoch, curEpoch_iter, perEpoch_iter, baselr):
67 | cur_iter = cur_epoch * perEpoch_iter + curEpoch_iter
68 | max_iter = max_epoch * perEpoch_iter
69 | lr = baselr * pow((1 - 1.0 * cur_iter / max_iter), 0.9)
70 |
71 | return lr
72 |
73 |
74 | class GradualWarmupScheduler(_LRScheduler):
75 | """ Gradually warm-up(increasing) learning rate in optimizer.
76 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
77 | Args:
78 | optimizer (Optimizer): Wrapped optimizer.
79 | min_lr_mul: target learning rate = base lr * min_lr_mul
80 | total_epoch: target learning rate is reached at total_epoch, gradually
81 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
82 | """
83 |
84 | def __init__(self, optimizer, total_epoch, min_lr_mul=0.1, after_scheduler=None):
85 | self.min_lr_mul = min_lr_mul
86 | if self.min_lr_mul > 1. or self.min_lr_mul < 0.:
87 | raise ValueError('min_lr_mul should be [0., 1.]')
88 | self.total_epoch = total_epoch
89 | self.after_scheduler = after_scheduler
90 | self.finished = False
91 | super(GradualWarmupScheduler, self).__init__(optimizer)
92 |
93 | def get_lr(self):
94 | if self.last_epoch > self.total_epoch:
95 | if self.after_scheduler:
96 | if not self.finished:
97 | self.after_scheduler.base_lrs = self.base_lrs
98 | self.finished = True
99 | return self.after_scheduler.get_lr()
100 | else:
101 | return self.base_lrs
102 | else:
103 | return [base_lr * (self.min_lr_mul + (1. - self.min_lr_mul) * (self.last_epoch / float(self.total_epoch)))
104 | for base_lr in self.base_lrs]
105 |
106 | def step(self, epoch=None):
107 | if self.finished and self.after_scheduler:
108 | return self.after_scheduler.step(epoch - self.total_epoch)
109 | else:
110 | return super(GradualWarmupScheduler, self).step(epoch)
111 |
112 |
113 | if __name__ == '__main__':
114 |     # minimal smoke test: WarmupPolyLR requires an optimizer and iteration counters
115 |     import torch
116 |     net = torch.nn.Linear(2, 2)
117 |     sgd = torch.optim.SGD(net.parameters(), lr=0.01)
118 |     scheduler = WarmupPolyLR(sgd, T_max=1000, cur_iter=0)
119 |     print(scheduler.get_lr())
120 |
--------------------------------------------------------------------------------
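`poly_learning_rate` is a pure function, so its result is applied to the optimizer by hand each iteration. A sketch with illustrative iteration counts:

```python
import torch
from toolbox.scheduler.lr_scheduler import poly_learning_rate

net = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
per_epoch_iter = 50                                   # illustrative
for it in range(per_epoch_iter):
    lr = poly_learning_rate(cur_epoch=0, max_epoch=100,
                            curEpoch_iter=it, perEpoch_iter=per_epoch_iter, baselr=0.01)
    for group in optimizer.param_groups:
        group['lr'] = lr
    # ... forward / backward / optimizer.step() ...
```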
/toolbox/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | import os
5 | import math
6 | import random
7 | import time
8 | import torch.backends.cudnn as cudnn
9 |
10 |
11 |
12 | class ClassWeight(object):
13 |
14 | def __init__(self, method):
15 | assert method in ['no', 'enet', 'median_freq_balancing']
16 | self.method = method
17 |
18 | def get_weight(self, dataloader, num_classes):
19 | if self.method == 'no':
20 | return np.ones(num_classes)
21 | if self.method == 'enet':
22 | return self._enet_weighing(dataloader, num_classes)
23 | if self.method == 'median_freq_balancing':
24 | return self._median_freq_balancing(dataloader, num_classes)
25 |
26 | def _enet_weighing(self, dataloader, num_classes, c=1.02):
27 | """Computes class weights as described in the ENet paper:
28 |
29 | w_class = 1 / (ln(c + p_class)),
30 |
31 | where c is usually 1.02 and p_class is the propensity score of that
32 | class:
33 |
34 | propensity_score = freq_class / total_pixels.
35 |
36 | References: https://arxiv.org/abs/1606.02147
37 |
38 | Keyword arguments:
39 | - dataloader (``data.Dataloader``): A data loader to iterate over the
40 | dataset.
41 | - num_classes (``int``): The number of classes.
42 |     - c (``int``, optional): An additional hyper-parameter which restricts
43 | the interval of values for the weights. Default: 1.02.
44 |
45 | """
46 | print('computing class weight .......................')
47 | class_count = 0
48 | total = 0
49 | for i, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
50 | label = sample['label']
51 | label = label.cpu().numpy()
52 |
53 | # Flatten label
54 | flat_label = label.flatten()
55 |
56 | # Sum up the number of pixels of each class and the total pixel
57 | # counts for each label
58 | class_count += np.bincount(flat_label, minlength=num_classes)
59 | total += flat_label.size
60 |
61 | # Compute propensity score and then the weights for each class
62 | propensity_score = class_count / total
63 | class_weights = 1 / (np.log(c + propensity_score))
64 |
65 | return class_weights
66 |
67 | def _median_freq_balancing(self, dataloader, num_classes):
68 | """Computes class weights using median frequency balancing as described
69 | in https://arxiv.org/abs/1411.4734:
70 |
71 | w_class = median_freq / freq_class,
72 |
73 | where freq_class is the number of pixels of a given class divided by
74 | the total number of pixels in images where that class is present, and
75 | median_freq is the median of freq_class.
76 |
77 | Keyword arguments:
78 | - dataloader (``data.Dataloader``): A data loader to iterate over the
79 |     dataset whose weights are going to be computed.
80 |
81 | - num_classes (``int``): The number of classes
82 |
83 | """
84 | print('computing class weight .......................')
85 | class_count = 0
86 | total = 0
87 | for i, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
88 | label = sample['label']
89 | label = label.cpu().numpy()
90 |
91 | # Flatten label
92 | flat_label = label.flatten()
93 |
94 | # Sum up the class frequencies
95 | bincount = np.bincount(flat_label, minlength=num_classes)
96 |
97 |             # Create a mask of classes that exist in the label
98 | mask = bincount > 0
99 | # Multiply the mask by the pixel count. The resulting array has
100 | # one element for each class. The value is either 0 (if the class
101 | # does not exist in the label) or equal to the pixel count (if
102 | # the class exists in the label)
103 | total += mask * flat_label.size
104 |
105 | # Sum up the number of pixels found for each class
106 | class_count += bincount
107 |
108 | # Compute the frequency and its median
109 | freq = class_count / total
110 | med = np.median(freq)
111 |
112 | return med / freq
113 |
114 |
115 | def color_map(N=256, normalized=False):
116 | """
117 | Return Color Map in PASCAL VOC format
118 | """
119 |
120 | def bitget(byteval, idx):
121 | return (byteval & (1 << idx)) != 0
122 |
123 | dtype = "float32" if normalized else "uint8"
124 | cmap = np.zeros((N, 3), dtype=dtype)
125 | for i in range(N):
126 | r = g = b = 0
127 | c = i
128 | for j in range(8):
129 | r = r | (bitget(c, 0) << 7 - j)
130 | g = g | (bitget(c, 1) << 7 - j)
131 | b = b | (bitget(c, 2) << 7 - j)
132 | c = c >> 3
133 |
134 | cmap[i] = np.array([r, g, b])
135 |
136 | cmap = cmap / 255.0 if normalized else cmap
137 | return cmap
138 |
139 |
140 | def class_to_RGB(label, N, cmap=None, normalized=False):
141 | '''
142 | label: 2D numpy array with pixel-level classes shape=(h, w)
143 |     N: number of classes, including background; should be in [0, 255]
144 | cmap: list of colors for N class (include background) \
145 | if None, use VOC default color map.
146 | normalized: RGB in [0, 1] if True else [0, 255] if False
147 |
148 |     :return: colored 3D RGB numpy array, shape=(h, w, 3)
149 | '''
150 | dtype = "float32" if normalized else "uint8"
151 |
152 | assert len(label.shape) == 2, f'label should be 2D, not {len(label.shape)}D'
153 | label_class = np.asarray(label)
154 |
155 | label_color = np.zeros((label.shape[0], label.shape[1], 3), dtype=dtype)
156 |
157 | if cmap is None:
158 |         # index 0 is the black [0, 0, 0] background; the remaining indices are the class colors
159 | cmap = color_map(N, normalized=normalized)
160 | else:
161 | cmap = np.asarray(cmap, dtype=dtype)
162 | cmap = cmap / 255.0 if normalized else cmap
163 |
164 | assert cmap.shape[0] == N, f'{N} classes and {cmap.shape[0]} colors not match.'
165 |
166 |     # color each class according to the color map
167 | for i_class in range(N):
168 | label_color[label_class == i_class] = cmap[i_class]
169 |
170 | return label_color
171 |
172 |
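# Usage sketch (illustrative): render a tiny label map with the default palette.
def _demo_class_to_RGB():
    label = np.array([[0, 1], [2, 1]])
    return class_to_RGB(label, N=3)  # (2, 2, 3) uint8 RGB image
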
173 | def tensor_classes_to_RGBs(label, N, cmap=None):
174 | '''used in tensorboard'''
175 |
176 | if cmap is None:
177 | cmap = color_map(N)
178 | else:
179 | cmap = np.asarray(cmap)
180 |
181 | label = label.clone().cpu().numpy() # (batch_size, H, W)
182 | ctRGB = np.vectorize(lambda x: tuple(cmap[int(x)].tolist()))
183 |
184 |     colored = np.asarray(ctRGB(label)).astype(np.float32)  # (3, batch_size, H, W)
185 |     colored = colored.squeeze()  # drops the batch axis when batch_size == 1
186 |
187 |     try:
188 |         return torch.from_numpy(colored.transpose([1, 0, 2, 3]))  # (B, 3, H, W)
189 |     except ValueError:  # 3D array: the batch axis was squeezed away
190 |         return torch.from_numpy(colored[np.newaxis, ...])  # (1, 3, H, W)
191 |
192 |
193 | def save_ckpt(logdir, model, epoch_iter, prefix=''):
194 | state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
195 | torch.save(state, os.path.join(logdir, prefix + 'model_' + str(epoch_iter) + '.pth'))
196 |
197 |
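# Note: save_ckpt above writes '{prefix}model_{epoch}.pth', whereas load_ckpt
# below looks for '{prefix}model.pth'; rename the file (or adapt the path)
# when reloading a checkpoint saved at a specific epoch.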
198 | def load_ckpt(logdir, model, prefix=''):
199 | save_pth = os.path.join(logdir, prefix+'model.pth')
200 | model.load_state_dict(torch.load(save_pth))
201 | return model
202 |
203 |
204 | def compute_speed(model, input_size, device=0, iteration=100):
205 | torch.cuda.set_device(device)
206 | cudnn.benchmark = True
207 |
208 | model.eval()
209 | model = model.cuda()
210 |
211 | input = torch.randn(*input_size, device=device)
212 |
213 |     with torch.no_grad():  # no autograd bookkeeping while benchmarking
214 |         for _ in range(50):  # warm-up: cudnn autotuning and CUDA context init
215 |             model(input)
216 |         print('=========Eval Forward Time=========')
217 |         torch.cuda.synchronize()
218 |         t_start = time.time()
219 |         for _ in range(iteration):
220 |             model(input)
221 | torch.cuda.synchronize()
222 | elapsed_time = time.time() - t_start
223 |
224 | speed_time = elapsed_time / iteration * 1000
225 | fps = iteration / elapsed_time
226 |
227 | print('Elapsed Time: [%.2f s / %d iter]' % (elapsed_time, iteration))
228 | print('Speed Time: %.2f ms / iter FPS: %.2f' % (speed_time, fps))
229 | return speed_time, fps
230 |
231 |
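# Usage sketch (illustrative; assumes a CUDA device). LASNet itself takes two
# inputs (RGB and TIR), so it would need a small wrapper before being passed
# to compute_speed; this toy single-input model just shows the call signature,
# with 480x640 being the MFNet image resolution.
def _demo_compute_speed():
    import torch.nn as nn
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
    return compute_speed(net, (1, 3, 480, 640), device=0, iteration=20)
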
232 | def setup_seed(seed):
233 | torch.manual_seed(seed)
234 | torch.cuda.manual_seed_all(seed)
235 | np.random.seed(seed)
236 | random.seed(seed)
237 |     torch.backends.cudnn.deterministic = True   # deterministic conv algorithms
238 |     torch.backends.cudnn.benchmark = False      # disable autotuning for reproducibility
239 |
240 |
241 | def group_weight_decay(model):
242 |
243 | import torch.nn as nn
244 | from torch.nn.modules.conv import _ConvNd
245 | from torch.nn.modules.batchnorm import _BatchNorm
246 |
247 | decays = []
248 | no_decays = []
249 | for m in model.modules():
250 | if isinstance(m, nn.Linear):
251 | decays.append(m.weight)
252 | if m.bias is not None:
253 | no_decays.append(m.bias)
254 | elif isinstance(m, _ConvNd):
255 | decays.append(m.weight)
256 | if m.bias is not None:
257 | no_decays.append(m.bias)
258 | elif isinstance(m, _BatchNorm):
259 | if m.weight is not None:
260 | no_decays.append(m.weight)
261 | if m.bias is not None:
262 | no_decays.append(m.bias)
263 |
264 | assert len(list(model.parameters())) == len(decays) + len(no_decays)
265 | groups = [dict(params=decays), dict(params=no_decays, weight_decay=0.0)]
266 | return groups
267 |
268 |
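# Usage sketch (illustrative): the first parameter group inherits the
# optimizer's default weight_decay, while biases and BatchNorm parameters in
# the second group are excluded from decay.
def _demo_group_weight_decay():
    import torch.nn as nn
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    groups = group_weight_decay(net)
    return torch.optim.SGD(groups, lr=1e-2, momentum=0.9, weight_decay=5e-4)
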
269 | if __name__ == '__main__':
270 | pass
271 |
--------------------------------------------------------------------------------
/train_LASNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import json
4 | import time
5 |
6 | from apex import amp
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.optim.lr_scheduler import LambdaLR
12 | from torch.utils.data import DataLoader
13 |
14 | from toolbox import get_dataset
15 | from toolbox.optim.Ranger import Ranger
16 | from toolbox import get_logger
17 | from toolbox import get_model
18 | from toolbox import averageMeter, runningScore
19 | from toolbox import save_ckpt
20 | from toolbox.datasets.irseg import IRSeg
21 | from toolbox.datasets.pst900 import PSTSeg
22 | from toolbox.losses import lovasz_softmax
23 |
24 | class eeemodelLoss(nn.Module):
25 |
26 | def __init__(self, class_weight=None, ignore_index=-100, reduction='mean'):
27 | super(eeemodelLoss, self).__init__()
28 |
29 | self.class_weight_semantic = torch.from_numpy(np.array(
30 | [1.5105, 16.6591, 29.4238, 34.6315, 40.0845, 41.4357, 47.9794, 45.3725, 44.9000])).float()
31 | self.class_weight_binary = torch.from_numpy(np.array([1.5121, 10.2388])).float()
32 | self.class_weight_boundary = torch.from_numpy(np.array([1.4459, 23.7228])).float()
33 |
34 | self.class_weight = class_weight
35 | # self.LovaszSoftmax = lovasz_softmax()
36 | self.cross_entropy = nn.CrossEntropyLoss()
37 |
38 | self.semantic_loss = nn.CrossEntropyLoss(weight=self.class_weight_semantic)
39 | self.binary_loss = nn.CrossEntropyLoss(weight=self.class_weight_binary)
40 | self.boundary_loss = nn.CrossEntropyLoss(weight=self.class_weight_boundary)
41 |
42 | def forward(self, inputs, targets):
43 | semantic_gt, binary_gt, boundary_gt = targets
44 | semantic_out, semantic_out_2, sal_out, edge_out = inputs
45 |
46 | loss1 = self.semantic_loss(semantic_out, semantic_gt)
47 | loss2 = lovasz_softmax(F.softmax(semantic_out, dim=1), semantic_gt, ignore=255)
48 | loss3 = self.semantic_loss(semantic_out_2, semantic_gt)
49 | loss4 = self.binary_loss(sal_out, binary_gt)
50 | loss5 = self.boundary_loss(edge_out, boundary_gt)
51 |
52 | loss = loss1 + loss2 + loss3 + 0.5*loss4 + loss5
53 | return loss
54 |
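# Shape sketch (illustrative, not part of the original script): 9 semantic
# classes as on MFNet, plus 2-class binary and boundary maps.
def _demo_eeemodel_loss():
    n, h, w = 2, 32, 32
    inputs = (torch.randn(n, 9, h, w), torch.randn(n, 9, h, w),
              torch.randn(n, 2, h, w), torch.randn(n, 2, h, w))
    targets = (torch.randint(0, 9, (n, h, w)),
               torch.randint(0, 2, (n, h, w)),
               torch.randint(0, 2, (n, h, w)))
    return eeemodelLoss()(inputs, targets)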
55 |
56 | def run(args):
57 | torch.cuda.set_device(args.cuda)
58 | with open(args.config, 'r') as fp:
59 | cfg = json.load(fp)
60 |
61 | logdir = f'run/{time.strftime("%Y-%m-%d-%H-%M")}-{cfg["dataset"]}-{cfg["model_name"]}-'
62 | if not os.path.exists(logdir):
63 | os.makedirs(logdir)
64 | shutil.copy(args.config, logdir)
65 |
66 | logger = get_logger(logdir)
67 | logger.info(f'Conf | use logdir {logdir}')
68 |
69 | model = get_model(cfg)
70 | device = torch.device(f'cuda:{args.cuda}')
71 | model.to(device)
72 |
73 |
74 | trainset, _, testset = get_dataset(cfg)
75 | train_loader = DataLoader(trainset, batch_size=cfg['ims_per_gpu'], shuffle=True, num_workers=cfg['num_workers'],
76 | pin_memory=True)
77 | test_loader = DataLoader(testset, batch_size=cfg['ims_per_gpu'], shuffle=False, num_workers=cfg['num_workers'],
78 | pin_memory=True)
79 |
80 | params_list = model.parameters()
81 | optimizer = Ranger(params_list, lr=cfg['lr_start'], weight_decay=cfg['weight_decay'])
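    # "poly" decay schedule: lr(ep) = lr_start * (1 - ep / epochs) ** 0.9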
82 | scheduler = LambdaLR(optimizer, lr_lambda=lambda ep: (1 - ep / cfg['epochs']) ** 0.9)
83 |
84 | train_criterion = eeemodelLoss().to(device)
85 | criterion = nn.CrossEntropyLoss().to(device)
86 |
87 | train_loss_meter = averageMeter()
88 | test_loss_meter = averageMeter()
89 | running_metrics_test = runningScore(cfg['n_classes'], ignore_index=cfg['id_unlabel'])
90 | best_test = 0
91 |
92 | amp.register_float_function(torch, 'sigmoid')
93 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
94 |
95 |
96 | for ep in range(cfg['epochs']):
97 |
98 | # training
99 | model.train()
100 | train_loss_meter.reset()
101 | for i, sample in enumerate(train_loader):
102 | optimizer.zero_grad()
103 |
104 | image = sample['image'].to(device)
105 | depth = sample['depth'].to(device)
106 | label = sample['label'].to(device)
107 | bound = sample['bound'].to(device)
108 | binary_label = sample['binary_label'].to(device)
109 | targets = [label, binary_label, bound]
110 | predict = model(image, depth)
111 |
112 | loss = train_criterion(predict, targets)
113 | ####################################################
114 |
115 | with amp.scale_loss(loss, optimizer) as scaled_loss:
116 | scaled_loss.backward()
117 | optimizer.step()
118 |
119 | train_loss_meter.update(loss.item())
120 |
121 |         scheduler.step(ep)  # note: the explicit epoch argument is deprecated in newer PyTorch
122 |
123 | # test
124 | with torch.no_grad():
125 | model.eval()
126 | running_metrics_test.reset()
127 | test_loss_meter.reset()
128 | for i, sample in enumerate(test_loader):
129 |
130 | image = sample['image'].to(device)
131 | # Here, depth is TIR.
132 | depth = sample['depth'].to(device)
133 | label = sample['label'].to(device)
134 | predict = model(image, depth)[0]
135 |
136 | loss = criterion(predict, label)
137 | test_loss_meter.update(loss.item())
138 |
139 | predict = predict.max(1)[1].cpu().numpy() # [1, h, w]
140 | label = label.cpu().numpy()
141 | running_metrics_test.update(label, predict)
142 |
143 | train_loss = train_loss_meter.avg
144 | test_loss = test_loss_meter.avg
145 |
146 | test_macc = running_metrics_test.get_scores()[0]["class_acc: "]
147 | test_miou = running_metrics_test.get_scores()[0]["mIou: "]
148 | test_avg = (test_macc + test_miou) / 2
149 |
150 | logger.info(
151 | f'Iter | [{ep + 1:3d}/{cfg["epochs"]}] loss={train_loss:.3f}/{test_loss:.3f}, mPA={test_macc:.3f}, miou={test_miou:.3f}, avg={test_avg:.3f}')
152 | if test_avg > best_test:
153 | best_test = test_avg
154 |             save_ckpt(logdir, model, ep + 1)
155 | logger.info(
156 | f'Save Iter = [{ep + 1:3d}], mPA={test_macc:.3f}, miou={test_miou:.3f}, avg={test_avg:.3f}')
157 |
158 |
159 | if __name__ == '__main__':
160 | import argparse
161 |
162 | parser = argparse.ArgumentParser(description="config")
163 | parser.add_argument("--config", type=str, default="configs/LASNet.json", help="Configuration file to use")
164 | parser.add_argument("--opt_level", type=str, default='O1')
165 | parser.add_argument("--inputs", type=str.lower, default='rgb', choices=['rgb', 'rgbd'])
166 | parser.add_argument("--resume", type=str, default='',
167 | help="use this file to load last checkpoint for continuing training")
168 | parser.add_argument("--cuda", type=int, default=1, help="set cuda device id")
169 |
170 | args = parser.parse_args()
171 |
172 | print("Starting Training!")
173 | run(args)
174 |
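# Example launch (illustrative):
#   python train_LASNet.py --config configs/LASNet.json --cuda 0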
--------------------------------------------------------------------------------