├── Images
│   ├── MFNet.png
│   └── PST900.png
├── README.md
├── __pycache__
│   ├── resnet.cpython-37.pyc
│   └── resnet.cpython-38.pyc
├── class_weights.py
├── configs
│   └── LASNet.json
├── generate_binary_labels.m
├── generate_bound_or_edge.m
├── model
│   ├── LASNet.json
│   ├── predicts_MFNet.zip
│   └── predicts_PST900.zip
├── resnet.py
├── sober.py
├── test_LASNet.py
├── toolbox
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dual_self_att.cpython-37.pyc
│   │   ├── dual_self_att.cpython-38.pyc
│   │   ├── log.cpython-36.pyc
│   │   ├── log.cpython-37.pyc
│   │   ├── log.cpython-38.pyc
│   │   ├── losses.cpython-37.pyc
│   │   ├── losses.cpython-38.pyc
│   │   ├── metrics.cpython-36.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   ├── utils.cpython-36.pyc
│   │   ├── utils.cpython-37.pyc
│   │   └── utils.cpython-38.pyc
│   ├── datasets
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── augmentations.cpython-36.pyc
│   │   │   ├── augmentations.cpython-37.pyc
│   │   │   ├── augmentations.cpython-38.pyc
│   │   │   ├── camvid.cpython-37.pyc
│   │   │   ├── irseg.cpython-36.pyc
│   │   │   ├── irseg.cpython-37.pyc
│   │   │   ├── irseg.cpython-38.pyc
│   │   │   ├── nyuv2.cpython-37.pyc
│   │   │   └── pst900.cpython-38.pyc
│   │   ├── augmentations.py
│   │   ├── camvid.py
│   │   ├── irseg.py
│   │   └── pst900.py
│   ├── dual_self_att.py
│   ├── log.py
│   ├── losses.py
│   ├── metrics.py
│   ├── models
│   │   ├── LASNet.py
│   │   └── __pycache__
│   │       ├── EGFNet.cpython-37.pyc
│   │       ├── EGFNet.cpython-38.pyc
│   │       ├── LASNet.cpython-38.pyc
│   │       ├── LgyTestNet.cpython-37.pyc
│   │       └── LgyTestNet.cpython-38.pyc
│   ├── optim
│   │   ├── Ranger.py
│   │   └── __pycache__
│   │       ├── Ranger.cpython-37.pyc
│   │       └── Ranger.cpython-38.pyc
│   ├── scheduler
│   │   ├── __init__.py
│   │   └── lr_scheduler.py
│   └── utils.py
└── train_LASNet.py
/Images/MFNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/Images/MFNet.png
--------------------------------------------------------------------------------
/Images/PST900.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/Images/PST900.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LASNet
2 | This project provides the code and results for 'RGB-T Semantic Segmentation with Location, Activation, and Sharpening', IEEE TCSVT, 2023. [IEEE link](https://ieeexplore.ieee.org/document/9900351) and [arXiv link](https://arxiv.org/abs/2210.14530). [Homepage](https://mathlee.github.io/)
3 | 
4 | # Requirements
5 | python 3.7/3.8 + pytorch 1.9.0 (built on [EGFNet](https://github.com/ShaohuaDong2021/EGFNet))
6 | 
7 | 
8 | # Segmentation maps and performance
9 | We provide the segmentation maps for the MFNet and PST900 datasets under './model/'.
10 | 
11 | **Performance on MFNet dataset**
12 | 
13 | 
14 | ![Results on the MFNet dataset](./Images/MFNet.png)
15 | 
16 | 
17 | **Performance on PST900 dataset**
18 | 
19 | 
20 | ![Results on the PST900 dataset](./Images/PST900.png)
21 | 
22 | 
23 | 
24 | # Training
25 | 1. Install '[apex](https://github.com/NVIDIA/apex)'.
26 | 2. Download [MFNet dataset](https://pan.baidu.com/s/1NHGazP7pwgEM47SP_ljJPg) (code: 3b9o) or [PST900 dataset](https://pan.baidu.com/s/13xgwFfUbu8zNvkwJq2Ggug) (code: mp2h).
27 | 3. Use 'generate_binary_labels.m' to get binary labels, and use 'generate_bound_or_edge.m' to get edge labels.
28 | 4. Run train_LASNet.py (defaults to the MFNet dataset).
29 | 
30 | Note: our main model is './toolbox/models/LASNet.py'.
31 | 
32 | 
33 | # Pre-trained model and testing
34 | 1. Download the following pre-trained models and put them under './model/': [model_MFNet.pth](https://pan.baidu.com/s/1dWCbTl274nzgdHGOsJkK_Q) (code: 5th1), [model_PST900.pth](https://pan.baidu.com/s/1zQif2_8LTG5R7aabQOXjrA) (code: okdq).
35 | 
36 | 2. Rename the downloaded model to 'model.pth', and then run test_LASNet.py (defaults to the MFNet dataset).
37 | 
38 | 
39 | # Citation
40 |     @ARTICLE{Li_2023_LASNet,
41 |         author = {Gongyang Li and Yike Wang and Zhi Liu and Xinpeng Zhang and Dan Zeng},
42 |         title = {RGB-T Semantic Segmentation with Location, Activation, and Sharpening},
43 |         journal = {IEEE Transactions on Circuits and Systems for Video Technology},
44 |         year = {2023},
45 |         volume = {33},
46 |         number = {3},
47 |         pages = {1223-1235},
48 |         month = {Mar.},
49 |     }
50 | 
51 | 
52 | If you encounter any problems with the code or want to report bugs,
53 | 
54 | please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn.
55 | 
--------------------------------------------------------------------------------
/__pycache__/resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/__pycache__/resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/class_weights.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import PIL.Image
4 | import numpy as np
5 | import cv2
6 | import pdb
7 | 
8 | import configparser
9 | 
10 | import sys
11 | 
12 | class ClassWeights:
13 |     """
14 |     Calculate class weights for PST900
15 |     """
16 |     def __init__(self, datapath=''):
17 |         """
18 |         Initialize class
19 |         """
20 |         self.data_path = datapath
21 |         self.label_path_train = os.path.join(datapath, 'train', 'labels')
22 |         self.label_path_test = os.path.join(datapath, 'test', 'labels')
23 |         self.label_stack = []
24 |         self.label_paths = []
25 |         self.num_classes = 5
26 | 
27 |     def process_labels(self):
28 |         """
29 |         Wrapper for processing all labels
30 |         """
31 |         train_labels = glob.glob(self.label_path_train + '/*.png')
32 |         test_labels = glob.glob(self.label_path_test + '/*.png')
33 |         self.label_paths = train_labels + test_labels
34 |         print("Accumulating labels...")
35 |         print(len(self.label_paths))
36 |         for label_img in self.label_paths:
37 |             print(label_img)
38 |             label = cv2.imread(label_img, -1)
39 |             self.label_stack.append(label)
40 |         print("Accumulating stack of labels done...")
41 |         print(self.label_stack)
42 |         stack_np = np.stack(self.label_stack, axis=0)
43 |         self.weights = self.calculate_class_weights(stack_np, self.num_classes)
44 |         print("Weights are: {}".format(self.weights))
45 | 
46 |     def load_class_weights(self, weight_file):
47 |         """
48 |         Load class weights from .ini file
49 |         """
50 |         config = configparser.ConfigParser()
51 |         config.sections()
52 |         config.read(weight_file)
53 |         weights_mat = np.zeros([1, self.num_classes])
54 |         weights_mat[0,0] = float(config['ClassWeights']['background'])
55 |         weights_mat[0,1] = float(config['ClassWeights']['fire_extinguisher'])
56 |         weights_mat[0,2] = float(config['ClassWeights']['backpack'])
57 |         weights_mat[0,3] = float(config['ClassWeights']['drill'])
58 |         weights_mat[0,4] = float(config['ClassWeights']['rescue_randy'])
59 |         num_images = float(config['ClassWeights']['num_images'])
60 |         print("Loaded class weights from .ini file...")
61 |         return weights_mat.squeeze(), num_images
62 | 
63 |     def save_class_weights(self, weight_file):
64 |         """
65 |         Save class weights to .ini file
66 |         """
67 |         config = configparser.ConfigParser()
68 |         config['ClassWeights'] = {}
69 |         config['ClassWeights']['background'] = str(self.weights[0])
70 |         config['ClassWeights']['fire_extinguisher'] = str(self.weights[1])
71 |         config['ClassWeights']['backpack'] = str(self.weights[2])
72 |         config['ClassWeights']['drill'] = str(self.weights[3])
73 |         config['ClassWeights']['rescue_randy'] = str(self.weights[4])
74 |         config['ClassWeights']['num_images'] = str(len(self.label_paths))
75 |         with open(weight_file, 'w') as configfile:
76 |             config.write(configfile)
77 |         print("Saved class weights to .ini file...")
78 | 
79 |     def calculate_class_weights(self, Y, n_classes, method="paszke", c=1.02):
80 |         """ Given the training data labels, calculates the class weights.
81 |         Args:
82 |             Y: (numpy array) The training labels as class id integers.
83 |                The shape does not matter, as long as each element represents
84 |                a class id (ie, NOT one-hot-vectors).
85 |             n_classes: (int) Number of possible classes.
86 |             method: (str) The type of class weighting to use.
87 |                 - "paszke" = use the method from Paszke et al 2016
88 |                   `1/ln(c + class_probability)`
89 |             c: (float) Coefficient to use, when using paszke method.
90 |         Returns:
91 |             weights: (numpy array) Array of shape [n_classes] assigning a
92 |                      weight value to each class.
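        Example:
            With c = 1.02 and two classes covering 90% and 10% of all
            pixels, the weights are 1/ln(1.02 + 0.90) ≈ 1.53 and
            1/ln(1.02 + 0.10) ≈ 8.82, so rarer classes receive larger
            weights; classes absent from Y get 1/ln(1.02) ≈ 50.5.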
93 | References: 94 | Paszke et al 2016: https://arxiv.org/abs/1606.02147 95 | """ 96 | ids, counts = np.unique(Y, return_counts=True) 97 | n_pixels = Y.size 98 | p_class = np.zeros(n_classes) 99 | p_class[ids] = counts/n_pixels 100 | weights = 1/np.log(c+p_class) 101 | return weights 102 | 103 | def main(): 104 | 105 | pst900_path = './PST900_RGBT_Dataset/' 106 | 107 | weight_path = pst900_path + 'weights.ini' 108 | 109 | # Instantiate ClassWeights 110 | calc_weights = ClassWeights(pst900_path) 111 | 112 | # Example: to calculate weights for the entire dataset 113 | calc_weights.process_labels() 114 | 115 | # Example: to save weights to config file 116 | calc_weights.save_class_weights(weight_path) 117 | 118 | # Example: to load weights from config file 119 | weights, img_count = calc_weights.load_class_weights(weight_path) 120 | 121 | if __name__ == '__main__': 122 | main() 123 | 124 | -------------------------------------------------------------------------------- /configs/LASNet.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "LASNet", 3 | 4 | "inputs": "rgbd", 5 | 6 | "dataset": "irseg", 7 | "root": "./dataset/", 8 | "n_classes": 9, 9 | "id_unlabel": -1, 10 | "brightness": 0.5, 11 | "contrast": 0.5, 12 | "saturation": 0.5, 13 | "p": 0.5, 14 | "scales_range": "0.5 2.0", 15 | "crop_size": "480 640", 16 | "eval_scales": "0.5 0.75 1.0 1.25 1.5 1.75", 17 | "eval_flip": "true", 18 | 19 | 20 | "ims_per_gpu": 4, 21 | "num_workers": 4, 22 | 23 | "lr_start": 5e-5, 24 | "momentum": 0.9, 25 | "weight_decay": 5e-4, 26 | "lr_power": 0.9, 27 | "epochs": 200, 28 | 29 | "loss": "crossentropy", 30 | "class_weight": "enet" 31 | } 32 | 33 | 34 | -------------------------------------------------------------------------------- /generate_binary_labels.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | %Path of semantic gts 3 | gtPath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/labels/'; 4 | 5 | savePath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/binary_labels/'; 6 | 7 | gts = dir([gtPath '*.png']); 8 | gtsNum = length(gts); 9 | 10 | 11 | for i=1:gtsNum 12 | gt_name = gts(i).name(); 13 | 14 | gt = imread(fullfile(gtPath, gt_name)); 15 | 16 | gt(find(gt>1)) = 255; 17 | 18 | imwrite(gt, [savePath gt_name] ); 19 | 20 | end 21 | 22 | -------------------------------------------------------------------------------- /generate_bound_or_edge.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | %Path of semantic gts 3 | ssgtPath = '/Volumes//RGBT_Semantic_Seg/PST900_RGBT_Dataset/labels/'; 4 | 5 | savePath = '/Volumes/RGBT_Semantic_Seg/PST900_RGBT_Dataset/bound/'; 6 | 7 | ssgts = dir([ssgtPath '*.png']); 8 | gtsNum = length(ssgts); 9 | 10 | 11 | for i=1:gtsNum 12 | ssgt_name = ssgts(i).name(); 13 | 14 | ssgt = imread(fullfile(ssgtPath, ssgt_name)); 15 | 16 | [h,w] = size(ssgt); 17 | 18 | bound = zeros(size(ssgt)); 19 | 20 | padmap = zeros(h+4, w+4); 21 | 22 | padmap(3:h+2,3:w+2) = ssgt; 23 | 24 | 25 | for hh = 1:h 26 | for ww = 1:w 27 | slidewindow = padmap(hh:hh+4, ww:ww+4); 28 | class = unique(slidewindow); 29 | if length(class)>=2 30 | bound(hh,ww) = 255; 31 | end 32 | end 33 | end 34 | 35 | 36 | imwrite(uint8(bound), [savePath ssgt_name] ); 37 | 38 | end 39 | 40 | -------------------------------------------------------------------------------- /model/LASNet.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "model_name": "LASNet", 3 | 4 | "inputs": "rgbd", 5 | 6 | "dataset": "irseg", 7 | "root": "./dataset/", 8 | "n_classes": 9, 9 | "id_unlabel": -1, 10 | "brightness": 0.5, 11 | "contrast": 0.5, 12 | "saturation": 0.5, 13 | "p": 0.5, 14 | "scales_range": "0.5 2.0", 15 | "crop_size": "480 640", 16 | "eval_scales": "0.5 0.75 1.0 1.25 1.5 1.75", 17 | "eval_flip": "true", 18 | 19 | 20 | "ims_per_gpu": 4, 21 | "num_workers": 4, 22 | 23 | "lr_start": 5e-5, 24 | "momentum": 0.9, 25 | "weight_decay": 5e-4, 26 | "lr_power": 0.9, 27 | "epochs": 200, 28 | 29 | "loss": "crossentropy", 30 | "class_weight": "enet" 31 | } 32 | 33 | 34 | -------------------------------------------------------------------------------- /model/predicts_MFNet.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/model/predicts_MFNet.zip -------------------------------------------------------------------------------- /model/predicts_PST900.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/model/predicts_PST900.zip -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # import torchvision.models as models 2 | # import torch.nn as nn 3 | # # https://pytorch.org/docs/stable/torchvision/models.html#id3 4 | # 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | model_urls = { 10 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 11 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 12 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 13 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 14 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | 22 | 23 | def conv1x1(in_planes, out_planes, stride=1): 24 | """1x1 convolution""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = conv1x1(inplanes, planes) 66 | 
self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = conv3x3(planes, planes, stride) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = conv1x1(planes, planes * self.expansion) 70 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | identity = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | identity = self.downsample(x) 91 | 92 | out += identity 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class ResNet(nn.Module): 99 | def __init__(self, block, layers, zero_init_residual=False): 100 | super(ResNet, self).__init__() 101 | self.inplanes = 64 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 6 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 3 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 114 | elif isinstance(m, nn.BatchNorm2d): 115 | nn.init.constant_(m.weight, 1) 116 | nn.init.constant_(m.bias, 0) 117 | 118 | # Zero-initialize the last BN in each residual branch, 119 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 120 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 121 | if zero_init_residual: 122 | for m in self.modules(): 123 | if isinstance(m, Bottleneck): 124 | nn.init.constant_(m.bn3.weight, 0) 125 | elif isinstance(m, BasicBlock): 126 | nn.init.constant_(m.bn2.weight, 0) 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | conv1x1(self.inplanes, planes * block.expansion, stride), nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | pretrained_dict = model_zoo.load_url(model_urls["resnet18"]) 166 | 167 | model_dict = model.state_dict() 168 | # 1. filter out unnecessary keys 169 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 170 | # 2. 
overwrite entries in the existing state dict 171 | model_dict.update(pretrained_dict) 172 | # 3. load the new state dict 173 | model.load_state_dict(model_dict) 174 | return model 175 | 176 | 177 | def resnet34(pretrained=False, **kwargs): 178 | """Constructs a ResNet-34 model. 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | pretrained_dict = model_zoo.load_url(model_urls["resnet34"]) 186 | 187 | model_dict = model.state_dict() 188 | # 1. filter out unnecessary keys 189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 190 | # 2. overwrite entries in the existing state dict 191 | model_dict.update(pretrained_dict) 192 | # 3. load the new state dict 193 | model.load_state_dict(model_dict) 194 | return model 195 | 196 | 197 | def resnet50(pretrained=False, **kwargs): 198 | """Constructs a ResNet-50 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 204 | 205 | if pretrained: 206 | pretrained_dict = model_zoo.load_url(model_urls["resnet50"]) 207 | 208 | model_dict = model.state_dict() 209 | # 1. filter out unnecessary keys 210 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 211 | # 2. overwrite entries in the existing state dict 212 | model_dict.update(pretrained_dict) 213 | # 3. load the new state dict 214 | model.load_state_dict(model_dict) 215 | 216 | return model 217 | 218 | 219 | def resnet101(pretrained=False, **kwargs): 220 | """Constructs a ResNet-101 model. 221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 226 | if pretrained: 227 | pretrained_dict = model_zoo.load_url(model_urls["resnet101"]) 228 | 229 | model_dict = model.state_dict() 230 | # 1. filter out unnecessary keys 231 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 232 | # 2. overwrite entries in the existing state dict 233 | model_dict.update(pretrained_dict) 234 | # 3. load the new state dict 235 | model.load_state_dict(model_dict) 236 | return model 237 | 238 | 239 | def resnet152(pretrained=False, **kwargs): 240 | """Constructs a ResNet-152 model. 241 | 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | """ 245 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 246 | 247 | if pretrained: 248 | pretrained_dict = model_zoo.load_url(model_urls["resnet152"]) 249 | 250 | model_dict = model.state_dict() 251 | # 1. filter out unnecessary keys 252 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 253 | # 2. overwrite entries in the existing state dict 254 | model_dict.update(pretrained_dict) 255 | # 3. 
load the new state dict 256 | model.load_state_dict(model_dict) 257 | 258 | return model 259 | 260 | 261 | def Backbone_ResNet34_in3(pretrained=True): 262 | if pretrained: 263 | print("The backbone model loads the pretrained parameters...") 264 | net = resnet34(pretrained=pretrained) 265 | div_2 = nn.Sequential(*list(net.children())[:3]) 266 | div_4 = nn.Sequential(*list(net.children())[3:5]) 267 | div_8 = net.layer2 268 | div_16 = net.layer3 269 | div_32 = net.layer4 270 | 271 | return div_2, div_4, div_8, div_16, div_32 272 | 273 | 274 | def Backbone_ResNet34_in1(pretrained=True): 275 | if pretrained: 276 | print("The backbone model loads the pretrained parameters...") 277 | net = resnet34(pretrained=pretrained) 278 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 279 | div_2 = nn.Sequential(*list(net.children())[:3]) 280 | div_4 = nn.Sequential(*list(net.children())[3:5]) 281 | div_8 = net.layer2 282 | div_16 = net.layer3 283 | div_32 = net.layer4 284 | 285 | return div_2, div_4, div_8, div_16, div_32 286 | 287 | def Backbone_ResNet50_in3(pretrained=True): 288 | if pretrained: 289 | print("The backbone model loads the pretrained parameters...") 290 | net = resnet50(pretrained=pretrained) 291 | div_2 = nn.Sequential(*list(net.children())[:3]) 292 | div_4 = nn.Sequential(*list(net.children())[3:5]) 293 | div_8 = net.layer2 294 | div_16 = net.layer3 295 | div_32 = net.layer4 296 | 297 | return div_2, div_4, div_8, div_16, div_32 298 | 299 | 300 | def Backbone_ResNet50_in1(pretrained=True): 301 | if pretrained: 302 | print("The backbone model loads the pretrained parameters...") 303 | net = resnet50(pretrained=pretrained) 304 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 305 | div_2 = nn.Sequential(*list(net.children())[:3]) 306 | div_4 = nn.Sequential(*list(net.children())[3:5]) 307 | div_8 = net.layer2 308 | div_16 = net.layer3 309 | div_32 = net.layer4 310 | 311 | return div_2, div_4, div_8, div_16, div_32 312 | 313 | 314 | def Backbone_ResNet152_in3(pretrained=True): 315 | if pretrained: 316 | print("The backbone model loads the pretrained parameters...") 317 | net = resnet152(pretrained=pretrained) 318 | div_2 = nn.Sequential(*list(net.children())[:3]) 319 | div_4 = nn.Sequential(*list(net.children())[3:5]) 320 | div_8 = net.layer2 321 | div_16 = net.layer3 322 | div_32 = net.layer4 323 | 324 | return div_2, div_4, div_8, div_16, div_32 325 | 326 | 327 | def Backbone_ResNet152_in1(pretrained=True): 328 | if pretrained: 329 | print("The backbone model loads the pretrained parameters...") 330 | net = resnet152(pretrained=pretrained) 331 | net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) 332 | div_2 = nn.Sequential(*list(net.children())[:3]) 333 | div_4 = nn.Sequential(*list(net.children())[3:5]) 334 | div_8 = net.layer2 335 | div_16 = net.layer3 336 | div_32 = net.layer4 337 | 338 | return div_2, div_4, div_8, div_16, div_32 339 | 340 | 341 | if __name__ == "__main__": 342 | div_2, div_4, div_8, div_16, div_32 = Backbone_ResNet50_in1() 343 | indata = torch.rand(4, 1, 480, 640) 344 | x1 = div_2(indata) 345 | x2 = div_4(x1) 346 | x3 = div_8(x2) 347 | x4 = div_16(x3) 348 | x5 = div_32(x4) 349 | # print(div_8) 350 | print(x1.size()) 351 | print(x2.size()) 352 | print(x3.size()) 353 | print(x4.size()) 354 | print(x5.size()) 355 | 356 | -------------------------------------------------------------------------------- /sober.py: 
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from torchvision import transforms
4 | import numpy as np
5 | 
6 | # to-tensor / to-PIL helpers (defined once, used for every image pair)
7 | loader = transforms.Compose([
8 |     transforms.ToTensor()])
9 | unloader = transforms.ToPILImage()
10 | 
11 | 
12 | def tensor_to_PIL(tensor):
13 |     image = tensor.squeeze(0)
14 |     image = unloader(image)
15 |     return image
16 | 
17 | 
18 | with open(os.path.join('/home/user/EGFNet/dataset', 'all.txt'), 'r') as f:
19 |     image_labels = f.readlines()
20 | for i in range(len(image_labels)):
21 |     label_path1 = image_labels[i].strip()
22 |     # read the RGB and thermal images as single-channel grayscale
23 |     imgrgb = cv2.imread('/home/user/EGFNet/dataset/seperated_images/' + label_path1 + '_rgb.png', 0)
24 |     imgdepth = cv2.imread('/home/user/EGFNet/dataset/seperated_images/' + label_path1 + '_th.png', 0)
25 | 
26 |     # horizontal and vertical Sobel gradients for both modalities
27 |     x1 = cv2.Sobel(imgrgb, cv2.CV_16S, 1, 0)
28 |     y1 = cv2.Sobel(imgrgb, cv2.CV_16S, 0, 1)
29 |     x2 = cv2.Sobel(imgdepth, cv2.CV_16S, 1, 0)
30 |     y2 = cv2.Sobel(imgdepth, cv2.CV_16S, 0, 1)
31 | 
32 |     absX1 = cv2.convertScaleAbs(x1)
33 |     absY1 = cv2.convertScaleAbs(y1)
34 |     absX2 = cv2.convertScaleAbs(x2)
35 |     absY2 = cv2.convertScaleAbs(y2)
36 | 
37 |     # per-modality gradient magnitude, then fuse the RGB and thermal edges
38 |     dst1 = cv2.addWeighted(absX1, 0.5, absY1, 0.5, 0)
39 |     dst2 = cv2.addWeighted(absX2, 0.5, absY2, 0.5, 0)
40 | 
41 |     dst1 = loader(dst1)
42 |     dst2 = loader(dst2)
43 |     dst = (dst1 + dst2) / 255.
44 | 
45 |     c = tensor_to_PIL(dst)
46 |     c = np.array(c)
47 | 
48 |     cv2.imwrite('/home/user/EGFNet/dataset/edge/' + label_path1 + '.png', c)
49 | 
--------------------------------------------------------------------------------
/test_LASNet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from tqdm import tqdm
4 | from PIL import Image
5 | import json
6 | 
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import torch.nn.functional as F
10 | 
11 | from toolbox import get_model
12 | from toolbox import averageMeter, runningScore
13 | from toolbox import class_to_RGB, load_ckpt, save_ckpt
14 | 
15 | from toolbox.datasets.irseg import IRSeg
16 | from toolbox.datasets.pst900 import PSTSeg
17 | 
18 | 
19 | def evaluate(logdir, save_predict=False, options=['val', 'test', 'test_day', 'test_night'], prefix=''):
20 |     # load the config file (cfg)
21 |     cfg = None
22 |     for file in os.listdir(logdir):
23 |         if file.endswith('.json'):
24 |             with open(os.path.join(logdir, file), 'r') as fp:
25 |                 cfg = json.load(fp)
26 |     assert cfg is not None
27 | 
28 |     device = torch.device('cuda')
29 | 
30 |     loaders = []
31 |     for opt in options:
32 |         dataset = IRSeg(cfg, mode=opt)
33 |         # dataset = PSTSeg(cfg, mode=opt)
34 |         loaders.append((opt, DataLoader(dataset, batch_size=1, shuffle=False, num_workers=cfg['num_workers'])))
35 |         cmap = dataset.cmap
36 | 
37 |     model = get_model(cfg).to(device)
38 | 
39 | 
40 |     model = load_ckpt(logdir, model, prefix=prefix)
41 | 
42 |     running_metrics_val = runningScore(cfg['n_classes'], ignore_index=cfg['id_unlabel'])
43 |     time_meter = averageMeter()
44 | 
45 |     save_path = os.path.join(logdir, 'predicts')
46 |     if not os.path.exists(save_path) and save_predict:
47 |         os.mkdir(save_path)
48 | 
49 |     for name, test_loader in loaders:
50 |         running_metrics_val.reset()
51 |         print('#'*50 + ' ' + name+prefix + ' ' + '#'*50)
52 |         with torch.no_grad():
53 |             model.eval()
54 |             for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
55 | 
56 |                 time_start = time.time()
57 | 
58 |                 if cfg['inputs'] == 'rgb':
59 |                     image = sample['image'].to(device)
60 |                     label = sample['label'].to(device)
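                    # single-modality models take only the RGB image;
                    # RGB-T models (the branch below) take (image, depth)
                    # and return several outputs, of which index 0 is the
                    # semantic prediction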
61 |                     predict = model(image)
62 |                 else:
63 |                     image = sample['image'].to(device)
64 |                     depth = sample['depth'].to(device)
65 |                     label = sample['label'].to(device)
66 |                     edge = sample['edge'].to(device)
67 |                     predict = model(image, depth)[0]
68 | 
69 |                 predict = predict.max(1)[1].cpu().numpy()
70 |                 label = label.cpu().numpy()
71 |                 running_metrics_val.update(label, predict)
72 | 
73 |                 time_meter.update(time.time() - time_start, n=image.size(0))
74 | 
75 |                 if save_predict:
76 |                     predict = predict.squeeze(0)
77 |                     predict = class_to_RGB(predict, N=len(cmap), cmap=cmap)
78 |                     predict = Image.fromarray(predict)
79 |                     predict.save(os.path.join(save_path, sample['label_path'][0]))
80 | 
81 |         metrics = running_metrics_val.get_scores()
82 |         print('overall metrics .....')
83 |         for k, v in metrics[0].items():
84 |             print(k, f'{v:.4f}')
85 | 
86 |         print('iou for each class .....')
87 |         for k, v in metrics[1].items():
88 |             print(k, f'{v:.4f}')
89 |         print('acc for each class .....')
90 |         for k, v in metrics[2].items():
91 |             print(k, f'{v:.4f}')
92 | 
93 | 
94 | 
95 | if __name__ == '__main__':
96 |     import argparse
97 | 
98 |     parser = argparse.ArgumentParser(description="evaluate")
99 |     parser.add_argument("--logdir", default="./model/", type=str,
100 |                         help="run logdir")
101 |     parser.add_argument("-s", type=bool, default=True,
102 |                         help="save predictions or not")
103 |     args = parser.parse_args()
104 | 
105 |     # prefix option ['', 'best_val_', 'best_test_']
106 |     # options=['test', 'test_day', 'test_night']
107 |     evaluate(args.logdir, save_predict=args.s, options=['test'], prefix='')
108 |     # evaluate(args.logdir, save_predict=args.s, options=['val'], prefix='')
109 |     # evaluate(args.logdir, save_predict=args.s, options=['test_day'], prefix='')
110 |     # evaluate(args.logdir, save_predict=args.s, options=['test_night'], prefix='')
111 |     # msc_evaluate(args.logdir, save_predict=args.s)
112 | 
--------------------------------------------------------------------------------
/toolbox/__init__.py:
--------------------------------------------------------------------------------
1 | from .metrics import averageMeter, runningScore
2 | from .log import get_logger
3 | from .optim import Ranger
4 | 
5 | from .utils import ClassWeight, save_ckpt, load_ckpt, class_to_RGB, \
6 |     compute_speed, setup_seed, group_weight_decay
7 | 
8 | 
9 | def get_dataset(cfg):
10 |     assert cfg['dataset'] in ['nyuv2', 'nyuv2_new', 'sunrgbd', 'cityscapes', 'camvid', 'irseg', 'pst900', 'irseg_msv']
11 | 
12 |     if cfg['dataset'] == 'irseg':
13 |         from .datasets.irseg import IRSeg
14 |         # return IRSeg(cfg, mode='trainval'), IRSeg(cfg, mode='test')
15 |         return IRSeg(cfg, mode='train'), IRSeg(cfg, mode='val'), IRSeg(cfg, mode='test')
16 |     elif cfg['dataset'] == 'pst900':
17 |         from .datasets.pst900 import PSTSeg
18 |         # return PSTSeg(cfg, mode='trainval'), PSTSeg(cfg, mode='test')
19 |         return PSTSeg(cfg, mode='train'), PSTSeg(cfg, mode='val'), PSTSeg(cfg, mode='test')
20 | 
21 | 
22 | def get_model(cfg):
23 | 
24 |     if cfg['model_name'] == 'EGFNet':
25 |         from .models.EGFNet import EGFNet
26 |         return EGFNet(n_classes=cfg['n_classes'])
27 |     else:
28 |         from .models.LASNet import LASNet
29 |         return LASNet(n_classes=cfg['n_classes'])
30 | 
--------------------------------------------------------------------------------
/toolbox/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /toolbox/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/dual_self_att.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/dual_self_att.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/dual_self_att.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/dual_self_att.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/log.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/log.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/log.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/losses.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/losses.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/metrics.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/augmentations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-36.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/augmentations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/datasets/__pycache__/augmentations.cpython-38.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/augmentations.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/camvid.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/camvid.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-36.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/irseg.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/irseg.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/nyuv2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/nyuv2.cpython-37.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/__pycache__/pst900.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/datasets/__pycache__/pst900.cpython-38.pyc
--------------------------------------------------------------------------------
/toolbox/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import sys
3 | import random
4 | from PIL import Image
5 | 
6 | try:
7 |     import accimage
8 | except ImportError:
9 |     accimage = None
10 | import numbers
11 | import collections
12 | 
13 | import torchvision.transforms.functional as F
14 | 
15 | __all__ = ["Compose",
16 |            "Resize",  # resize to the given size; if size is an int, resize to (size * height / width, size)
17 |            "RandomScale",  # random rescaling
18 |            "RandomCrop",  # random crop, with padding if needed
19 |            "RandomHorizontalFlip",  # random horizontal flip
20 |            "ColorJitter",  # brightness, contrast, saturation, hue
21 |            "RandomRotation",  # random rotation
22 |            ]
23 | 
24 | _pil_interpolation_to_str = {
25 |     Image.NEAREST: 'PIL.Image.NEAREST',
26 |     Image.BILINEAR: 'PIL.Image.BILINEAR',
27 |     Image.BICUBIC: 'PIL.Image.BICUBIC',
28 |     Image.LANCZOS: 'PIL.Image.LANCZOS',
29 |     Image.HAMMING: 'PIL.Image.HAMMING',
30 |     Image.BOX: 'PIL.Image.BOX',
31 | }
32 | 
33 | if sys.version_info < (3, 3):
34 |     Sequence = collections.Sequence
35 |     Iterable = collections.Iterable
36 | else:
37 |     Sequence = collections.abc.Sequence
38 |     Iterable = collections.abc.Iterable
39 | 
40 | 
41 | class Lambda(object):
42 |     """Apply a user-defined lambda as a transform.
43 | 
44 |     Args:
45 |         lambd (function): Lambda/function to be used for transform.
46 |     """
47 | 
48 |     def __init__(self, lambd):
49 |         assert callable(lambd), repr(type(lambd).__name__) + " object is not callable"
50 |         self.lambd = lambd
51 | 
52 |     def __call__(self, img):
53 |         return self.lambd(img)
54 | 
55 | 
56 | class Compose(object):
57 |     def __init__(self, transforms):
58 |         self.transforms = transforms
59 | 
60 |     def __call__(self, sample):
61 |         for t in self.transforms:
62 |             sample = t(sample)
63 |         return sample
64 | 
65 | 
66 | class Resize(object):
67 |     def __init__(self, size):
68 |         assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
69 |         self.size = size
70 | 
71 |     def __call__(self, sample):
72 |         assert 'image' in sample.keys()
73 |         assert 'label' in sample.keys()
74 | 
75 |         for key in sample.keys():
76 |             # BILINEAR for image
77 |             if key in ['image']:
78 |                 sample[key] = F.resize(sample[key], self.size, Image.BILINEAR)
79 |             # NEAREST for depth, label, bound
80 |             else:
81 |                 sample[key] = F.resize(sample[key], self.size, Image.NEAREST)
82 | 
83 |         return sample
84 | 
85 | 
86 | class RandomCrop(object):
87 | 
88 |     def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'):
89 |         if isinstance(size, numbers.Number):
90 |             self.size = (int(size), int(size))
91 |         else:
92 |             self.size = size
93 |         self.padding = padding
94 |         self.pad_if_needed = pad_if_needed
95 |         self.fill = fill
96 |         self.padding_mode = padding_mode
97 | 
98 |     @staticmethod
99 |     def get_params(img, output_size):
100 |         w, h = img.size
101 |         th, tw = output_size
102 |         if w == tw and h == th:
103 |             return 0, 0, h, w
104 | 
105 |         i = random.randint(0, h - th)
106 |         j = random.randint(0, w - tw)
107 |         return i, j, th, tw
108 | 
109 |     def __call__(self, sample):
110 |         img = sample['image']
111 |         if self.padding is not None:
112 |             for key in sample.keys():
113 |                 sample[key] = F.pad(sample[key], self.padding, self.fill, self.padding_mode)
114 | 
115 |         # pad the width if needed
116 |         if self.pad_if_needed and img.size[0] < self.size[1]:
117 |             for key in sample.keys():
118 |                 sample[key] = F.pad(sample[key], (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
119 |         # pad the height if needed
120 |         if self.pad_if_needed and img.size[1] < self.size[0]:
121 |             for key in sample.keys():
122 |                 sample[key] = F.pad(sample[key], (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
123 | 
124 |         i, j, h, w = self.get_params(sample['image'], self.size)
125 |         for key in sample.keys():
126 |             sample[key] = F.crop(sample[key], i, j, h, w)
127 | 
128 |         return sample
129 | 
130 | 
131 | class RandomHorizontalFlip(object):
132 | 
133 |     def __init__(self, p=0.5):
134 |         self.p = p
135 | 
136 |     def __call__(self, sample):
137 |         if random.random() < self.p:
138 |             for key in sample.keys():
139 |                 sample[key] = F.hflip(sample[key])
140 | 
141 |         return sample
142 | 
143 | 
144 | class RandomScale(object):
145 |     def __init__(self, scale):
146 |         assert isinstance(scale, Iterable) and len(scale) == 2
147 |         assert 0 < scale[0] <= scale[1]
148 |         self.scale = scale
149 | 
150 |     def __call__(self, sample):
151 |         assert 'image' in sample.keys()
152 |         assert 'label' in sample.keys()
153 | 
154 |         w, h = sample['image'].size
155 | 
156 |         scale = random.uniform(self.scale[0], self.scale[1])
157 |         size = (int(round(h * scale)), int(round(w * scale)))
158 | 
159 |         for key in sample.keys():
160 |             # BILINEAR for image
161 |             if key in ['image']:
162 |                 sample[key] = F.resize(sample[key], size, Image.BILINEAR)
163 |             # NEAREST for depth, label, bound
164 |             else:
165 |                 sample[key] = F.resize(sample[key], size, Image.NEAREST)
166 | 
167 |         return sample
168 | 
169 | 
170 | class ColorJitter(object):
171 |     """Randomly change the brightness, contrast and saturation of an image.
172 | 
173 |     Args:
174 |         brightness (float or tuple of float (min, max)): How much to jitter brightness.
175 |             brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
176 |             or the given [min, max]. Should be non negative numbers.
177 |         contrast (float or tuple of float (min, max)): How much to jitter contrast.
178 |             contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
179 |             or the given [min, max]. Should be non negative numbers.
180 |         saturation (float or tuple of float (min, max)): How much to jitter saturation.
181 |             saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
182 |             or the given [min, max]. Should be non negative numbers.
183 |         hue (float or tuple of float (min, max)): How much to jitter hue.
184 |             hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
185 |             Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
186 |     """
187 | 
188 |     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
189 |         self.brightness = self._check_input(brightness, 'brightness')
190 |         self.contrast = self._check_input(contrast, 'contrast')
191 |         self.saturation = self._check_input(saturation, 'saturation')
192 |         self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
193 |                                      clip_first_on_zero=False)
194 | 
195 |     def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
196 |         if isinstance(value, numbers.Number):
197 |             if value < 0:
198 |                 raise ValueError("If {} is a single number, it must be non negative.".format(name))
199 |             value = [center - value, center + value]
200 |             if clip_first_on_zero:
201 |                 value[0] = max(value[0], 0)
202 |         elif isinstance(value, (tuple, list)) and len(value) == 2:
203 |             if not bound[0] <= value[0] <= value[1] <= bound[1]:
204 |                 raise ValueError("{} values should be between {}".format(name, bound))
205 |         else:
206 |             raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
207 | 
208 |         # if value is 0 or (1., 1.) for brightness/contrast/saturation
209 |         # or (0., 0.) for hue, do nothing
210 |         if value[0] == value[1] == center:
211 |             value = None
212 |         return value
213 | 
214 |     @staticmethod
215 |     def get_params(brightness, contrast, saturation, hue):
216 | 
217 |         transforms = []
218 | 
219 |         if brightness is not None:
220 |             brightness_factor = random.uniform(brightness[0], brightness[1])
221 |             transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor)))
222 | 
223 |         if contrast is not None:
224 |             contrast_factor = random.uniform(contrast[0], contrast[1])
225 |             transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor)))
226 | 
227 |         if saturation is not None:
228 |             saturation_factor = random.uniform(saturation[0], saturation[1])
229 |             transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor)))
230 | 
231 |         if hue is not None:
232 |             hue_factor = random.uniform(hue[0], hue[1])
233 |             transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
234 | 
235 |         random.shuffle(transforms)
236 |         transform = Compose(transforms)
237 | 
238 |         return transform
239 | 
240 |     def __call__(self, sample):
241 |         assert 'image' in sample.keys()
242 |         transform = self.get_params(self.brightness, self.contrast,
243 |                                     self.saturation, self.hue)
244 |         sample['image'] = transform(sample['image'])
245 |         return sample
246 | 
247 | 
248 | class RandomRotation(object):
249 | 
250 |     def __init__(self, degrees, resample=False, expand=False, center=None):
251 |         if isinstance(degrees, numbers.Number):
252 |             if degrees < 0:
253 |                 raise ValueError("If degrees is a single number, it must be positive.")
254 |             self.degrees = (-degrees, degrees)
255 |         else:
256 |             if len(degrees) != 2:
257 |                 raise ValueError("If degrees is a sequence, it must be of len 2.")
258 |             self.degrees = degrees
259 | 
260 |         self.resample = resample
261 |         self.expand = expand
262 |         self.center = center
263 | 
264 |     @staticmethod
265 |     def get_params(degrees):
266 | 
267 |         return random.uniform(degrees[0], degrees[1])
268 | 
269 |     def __call__(self, sample):
270 | 
271 |         angle = self.get_params(self.degrees)
272 |         for key in sample.keys():
273 |             sample[key] = F.rotate(sample[key], angle, self.resample, self.expand, self.center)
274 | 
275 |         return sample
276 | 
--------------------------------------------------------------------------------
/toolbox/datasets/camvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | 
5 | import torch
6 | import torch.utils.data as data
7 | from torchvision import transforms
8 | 
9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale
10 | 
11 | 
12 | class Camvid(data.Dataset):
13 | 
14 |     def __init__(self, cfg, mode='trainval', do_aug=True):
15 | 
16 |         assert mode in ['trainval', 'test'], f'{mode} is not supported.'
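        # CamVid here ships only a combined train+val split ('trainval')
        # and a 'test' split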
17 |         self.mode = mode
18 | 
19 |         ## pre-processing
20 |         self.im_to_tensor = transforms.Compose([
21 |             transforms.ToTensor(),
22 |             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
23 |         ])
24 | 
25 |         self.root = os.path.join(cfg['root'], 'all_data')
26 |         self.n_classes = cfg['n_classes']
27 | 
28 |         scale_range = tuple(float(i) for i in cfg['scales_range'].split(' '))
29 |         crop_size = tuple(int(i) for i in cfg['crop_size'].split(' '))
30 | 
31 |         self.aug = Compose([
32 |             ColorJitter(
33 |                 brightness=cfg['brightness'],
34 |                 contrast=cfg['contrast'],
35 |                 saturation=cfg['saturation']),
36 |             RandomHorizontalFlip(cfg['p']),
37 |             RandomScale(scale_range),
38 |             RandomCrop(crop_size, pad_if_needed=True)
39 |         ])
40 | 
41 |         self.val_resize = Resize(crop_size)
42 | 
43 |         self.mode = mode
44 |         self.do_aug = do_aug
45 | 
46 |         if cfg['class_weight'] == 'enet':
47 |             self.class_weight = np.array(
48 |                 [6.3040, 4.3505, 35.0686, 3.4997, 14.0079, 8.0937, 32.6272, 28.6828, 14.8280, 38.3528, 37.4353,
49 |                  18.7975])
50 |         elif cfg['class_weight'] == 'median_freq_balancing':
51 |             self.class_weight = np.array(
52 |                 [0.2778, 0.1770, 4.7280, 0.1358, 0.7816, 0.3785, 3.7939, 2.5866, 0.8480, 6.5770, 5.8139, 1.2184])
53 |         else:
54 |             raise ValueError(f"{cfg['class_weight']} is not supported.")
55 | 
56 |         with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f:
57 |             self.infos = f.readlines()
58 | 
59 |     def __len__(self):
60 |         return len(self.infos)
61 | 
62 |     def __getitem__(self, index):
63 |         image_path = self.infos[index].strip()
64 | 
65 |         image = Image.open(os.path.join(self.root, 'image', self.mode, image_path))  # RGB 0~255
66 |         label = Image.open(os.path.join(self.root, 'label', self.mode, image_path))  # 1 channel 0~11
67 |         # bound = Image.open(os.path.join(self.root, 'bound', self.mode, image_path))
68 | 
69 |         # move unlabel_id from 11 to 0
70 |         label = np.asarray(label)
71 |         label = label + 1
72 |         label[label == 12] = 0
73 |         label = Image.fromarray(label)
74 | 
75 |         sample = {
76 |             'image': image,
77 |             # 'bound': bound,
78 |             'label': label,
79 |         }
80 | 
81 |         if self.mode in ['train', 'trainval'] and self.do_aug:  # augment the training set only
82 |             sample = self.aug(sample)
83 |         else:
84 |             sample = self.val_resize(sample)
85 | 
86 |         sample['image'] = self.im_to_tensor(sample['image'])
87 |         sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long()
88 |         # sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64)).long()
89 | 
90 |         sample['label_path'] = image_path.strip().split('/')[-1]  # keep the saved prediction's filename identical to the label's filename
91 |         return sample
92 | 
93 |     @property
94 |     def cmap(self):
95 |         return [
96 |             (0, 0, 0),  # unlabeled
97 | 
98 |             (128, 128, 128),  # sky
99 |             (128, 0, 0),  # building
100 |             (192, 192, 128),  # pole
101 |             (128, 64, 128),  # road
102 |             (0, 0, 192),  # pavement sidewalk
103 |             (128, 128, 0),  # tree
104 |             (192, 128, 128),  # sign_symbol
105 |             (64, 64, 128),  # fence
106 |             (64, 0, 128),  # car
107 |             (64, 64, 0),  # pedestrian
108 |             (0, 128, 192),  # bicyclist
109 | 
110 |         ]
111 | 
112 | 
113 | if __name__ == '__main__':
114 |     import json
115 | 
116 |     path = '/home/dtrimina/Desktop/lxy/Segmentation_final/configs/bbbmodel/camvid_bbbmodel.json'
117 |     with open(path, 'r') as fp:
118 |         cfg = json.load(fp)
119 |     cfg['root'] = '/home/dtrimina/Desktop/lxy/database/camvid'
120 | 
121 | 
122 |     dataset = Camvid(cfg, mode='trainval', do_aug=True)
123 |     from toolbox.utils import class_to_RGB
124 |     import matplotlib.pyplot as plt
125 | 
126 |     for i in range(len(dataset)):
127 |         sample = dataset[i]
128 | 
129 |         image = sample['image']
130 |         label = sample['label']
131 | 
132 |         image = image.numpy()
133 |         image = image.transpose((1, 2, 0))
134 |         image *= np.asarray([0.229, 0.224, 0.225])
135 |         image += np.asarray([0.485, 0.456, 0.406])
136 | 
137 |         label = label.numpy()
138 |         label = class_to_RGB(label, N=len(dataset.cmap), cmap=dataset.cmap)
139 | 
140 |         plt.subplot(121)
141 |         plt.imshow(image)
142 |         plt.subplot(122)
143 |         plt.imshow(label)
144 | 
145 |         plt.show()
146 | 
147 |         if i == 10:
148 |             break
149 | 
150 | 
151 |     # dataset = Camvid(cfg, mode='trainval', do_aug=False)
152 |     # from toolbox.utils import ClassWeight
153 |     #
154 |     # train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg['ims_per_gpu'], shuffle=True,
155 |     #                                            num_workers=cfg['num_workers'], pin_memory=True)
156 |     # classweight = ClassWeight('median_freq_balancing')  # enet, median_freq_balancing
157 |     # class_weight = classweight.get_weight(train_loader, cfg['n_classes'])
158 |     # class_weight = torch.from_numpy(class_weight).float()
159 |     # # class_weight[cfg['id_unlabel']] = 0
160 |     #
161 |     # print(class_weight)
162 |     #
163 |     # # # median_freq_balancing
164 |     # # tensor([0.2778, 0.1770, 4.7280, 0.1358, 0.7816, 0.3785, 3.7939, 2.5866, 0.8480,
165 |     # #         6.5770, 5.8139, 1.2184])
166 |     #
167 |     # # # enet
168 |     # # tensor([6.3040, 4.3505, 35.0686, 3.4997, 14.0079, 8.0937, 32.6272, 28.6828,
169 |     # #         14.8280, 38.3528, 37.4353, 18.7975])
170 | 
--------------------------------------------------------------------------------
/toolbox/datasets/irseg.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 | from sklearn.model_selection import train_test_split
5 | 
6 | import torch
7 | import torch.utils.data as data
8 | from torchvision import transforms
9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale, \
10 |     RandomRotation
11 | 
12 | 
13 | class IRSeg(data.Dataset):
14 | 
15 |     def __init__(self, cfg, mode='trainval', do_aug=True):
16 | 
17 |         assert mode in ['train', 'val', 'trainval', 'test', 'test_day', 'test_night'], f'{mode} is not supported.'
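        # besides train/val/test, MFNet provides day-only and night-only
        # test subsets (test_day / test_night) for condition-wise evaluation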
18 | self.mode = mode 19 | 20 | ## pre-processing 21 | self.im_to_tensor = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 24 | ]) 25 | 26 | self.dp_to_tensor = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.449, 0.449, 0.449], [0.226, 0.226, 0.226]), 29 | ]) 30 | 31 | self.root = cfg['root'] 32 | self.n_classes = cfg['n_classes'] 33 | 34 | scale_range = tuple(float(i) for i in cfg['scales_range'].split(' ')) 35 | crop_size = tuple(int(i) for i in cfg['crop_size'].split(' ')) 36 | 37 | self.aug = Compose([ 38 | ColorJitter( 39 | brightness=cfg['brightness'], 40 | contrast=cfg['contrast'], 41 | saturation=cfg['saturation']), 42 | RandomHorizontalFlip(cfg['p']), 43 | RandomScale(scale_range), 44 | RandomCrop(crop_size, pad_if_needed=True) 45 | ]) 46 | 47 | 48 | self.mode = mode 49 | self.do_aug = do_aug 50 | 51 | if cfg['class_weight'] == 'enet': 52 | self.class_weight = np.array( 53 | [1.5105, 16.6591, 29.4238, 34.6315, 40.0845, 41.4357, 47.9794, 45.3725, 44.9000]) 54 | self.binary_class_weight = np.array([1.5121, 10.2388]) 55 | elif cfg['class_weight'] == 'median_freq_balancing': 56 | self.class_weight = np.array( 57 | [0.0118, 0.2378, 0.7091, 1.0000, 1.9267, 1.5433, 0.9057, 3.2556, 1.0686]) 58 | self.binary_class_weight = np.array([0.5454, 6.0061]) 59 | else: 60 | raise ValueError(f"{cfg['class_weight']} is not supported.") 61 | 62 | with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f: 63 | self.infos = f.readlines() 64 | 65 | def __len__(self): 66 | return len(self.infos) 67 | 68 | def __getitem__(self, index): 69 | image_path = self.infos[index].strip() 70 | 71 | 72 | image = Image.open(os.path.join(self.root, 'seperated_images', image_path + '_rgb.png')) 73 | depth = Image.open(os.path.join(self.root, 'seperated_images', image_path + '_th.png')).convert('RGB') 74 | label = Image.open(os.path.join(self.root, 'labels', image_path + '.png')) 75 | bound = Image.open(os.path.join(self.root, 'bound', image_path+'.png')) 76 | edge = Image.open(os.path.join(self.root, 'edge', image_path+'.png')) 77 | binary_label = Image.open(os.path.join(self.root, 'binary_labels', image_path + '.png')) 78 | 79 | 80 | sample = { 81 | 'image': image, 82 | 'depth': depth, # depth is TIR image. 
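            # 'bound', 'edge', and 'binary_label' below are the auxiliary
            # supervision maps (boundary, edge, binary foreground) consumed
            # by the multi-task loss in train_LASNet.py.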
83 | 'label': label, 84 | 'bound': bound, 85 | 'edge': edge, 86 | 'binary_label': binary_label, 87 | } 88 | 89 | if self.mode in ['train', 'trainval'] and self.do_aug: # augment the training splits only 90 | sample = self.aug(sample) 91 | 92 | sample['image'] = self.im_to_tensor(sample['image']) 93 | sample['depth'] = self.dp_to_tensor(sample['depth']) 94 | sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long() 95 | sample['edge'] = torch.from_numpy(np.asarray(sample['edge'], dtype=np.int64)).long() 96 | sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64) / 255.).long() 97 | sample['binary_label'] = torch.from_numpy(np.asarray(sample['binary_label'], dtype=np.int64) / 255.).long() 98 | sample['label_path'] = image_path.strip().split('/')[-1] + '.png' # keep the saved prediction's filename consistent with the label filename 99 | return sample 100 | 101 | @property 102 | def cmap(self): 103 | return [ 104 | (0, 0, 0), # unlabelled 105 | (64, 0, 128), # car 106 | (64, 64, 0), # person 107 | (0, 128, 192), # bike 108 | (0, 0, 192), # curve 109 | (128, 128, 0), # car_stop 110 | (64, 64, 128), # guardrail 111 | (192, 128, 128), # color_cone 112 | (192, 64, 0), # bump 113 | ] 114 | 115 | 116 | -------------------------------------------------------------------------------- /toolbox/datasets/pst900.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | from sklearn.model_selection import train_test_split 5 | 6 | import torch 7 | import torch.utils.data as data 8 | from torchvision import transforms 9 | from toolbox.datasets.augmentations import Resize, Compose, ColorJitter, RandomHorizontalFlip, RandomCrop, RandomScale, \ 10 | RandomRotation 11 | 12 | 13 | class PSTSeg(data.Dataset): 14 | 15 | def __init__(self, cfg, mode='trainval', do_aug=True): 16 | 17 | assert mode in ['train', 'val', 'trainval', 'test'], f'{mode} is not supported.' 
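        # Expected PST900 directory layout, inferred from the path joins
        # in __getitem__ below (not independently verified here):
        #   <root>/rgb/*.png, <root>/thermal/*.png, <root>/labels/*.png,
        #   <root>/bound/*.png, <root>/binary_labels/*.png,
        #   plus <root>/{train,val,trainval,test}.txt split lists.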
18 | self.mode = mode 19 | 20 | ## pre-processing 21 | self.im_to_tensor = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 24 | ]) 25 | 26 | self.dp_to_tensor = transforms.Compose([ 27 | transforms.ToTensor(), 28 | transforms.Normalize([0.449, 0.449, 0.449], [0.226, 0.226, 0.226]), 29 | ]) 30 | 31 | self.root = cfg['root'] 32 | self.n_classes = cfg['n_classes'] 33 | 34 | scale_range = tuple(float(i) for i in cfg['scales_range'].split(' ')) 35 | crop_size = tuple(int(i) for i in cfg['crop_size'].split(' ')) 36 | 37 | self.aug = Compose([ 38 | ColorJitter( 39 | brightness=cfg['brightness'], 40 | contrast=cfg['contrast'], 41 | saturation=cfg['saturation']), 42 | RandomHorizontalFlip(cfg['p']), 43 | RandomScale(scale_range), 44 | RandomCrop(crop_size, pad_if_needed=True) 45 | ]) 46 | 47 | 48 | self.mode = mode 49 | self.do_aug = do_aug 50 | 51 | if cfg['class_weight'] == 'enet': 52 | self.class_weight = np.array( 53 | [1.4537, 44.2457, 31.6650, 46.4071, 30.1391]) 54 | self.binary_class_weight = np.array([1.4507, 21.5033]) 55 | else: 56 | raise ValueError(f"{cfg['class_weight']} is not supported.") 57 | 58 | with open(os.path.join(self.root, f'{mode}.txt'), 'r') as f: 59 | self.infos = f.readlines() 60 | 61 | def __len__(self): 62 | return len(self.infos) 63 | 64 | def __getitem__(self, index): 65 | image_path = self.infos[index].strip() 66 | 67 | 68 | image = Image.open(os.path.join(self.root, 'rgb', image_path + '.png')) 69 | depth = Image.open(os.path.join(self.root, 'thermal', image_path + '.png')).convert('RGB') 70 | label = Image.open(os.path.join(self.root, 'labels', image_path + '.png')) 71 | bound = Image.open(os.path.join(self.root, 'bound', image_path+'.png')) 72 | edge = Image.open(os.path.join(self.root, 'bound', image_path+'.png')) 73 | binary_label = Image.open(os.path.join(self.root, 'binary_labels', image_path + '.png')) 74 | 75 | 76 | sample = { 77 | 'image': image, 78 | 'depth': depth, 79 | 'label': label, 80 | 'bound': bound, 81 | 'edge': edge, 82 | 'binary_label': binary_label, 83 | } 84 | 85 | if self.mode in ['train', 'trainval'] and self.do_aug: # augment the training splits only 86 | sample = self.aug(sample) 87 | 88 | sample['image'] = self.im_to_tensor(sample['image']) 89 | sample['depth'] = self.dp_to_tensor(sample['depth']) 90 | sample['label'] = torch.from_numpy(np.asarray(sample['label'], dtype=np.int64)).long() 91 | sample['edge'] = torch.from_numpy(np.asarray(sample['edge'], dtype=np.int64)).long() # PST900 has no separate edge maps, so 'bound' is reused as 'edge' above 92 | sample['bound'] = torch.from_numpy(np.asarray(sample['bound'], dtype=np.int64) / 255.).long() 93 | sample['binary_label'] = torch.from_numpy(np.asarray(sample['binary_label'], dtype=np.int64) / 255.).long() 94 | sample['label_path'] = image_path.strip().split('/')[-1] + '.png' # keep the saved prediction's filename consistent with the label filename 95 | return sample 96 | 97 | @property 98 | def cmap(self): 99 | return [ 100 | [0, 0, 0], # background 101 | [0, 0, 255], # fire_extinguisher 102 | [0, 255, 0], # backpack 103 | [255, 0, 0], # drill 104 | [255, 255, 255], # survivor/rescue_randy 105 | ] 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /toolbox/dual_self_att.py: -------------------------------------------------------------------------------- 1 | ########################################################################### 2 | # Created by: CASIA IVA 3 | # Email: jliu@nlpr.ia.ac.cn 4 | # Copyright (c) 2018 5 | ########################################################################### 6 | 7 | import 
numpy as np 8 | import torch 9 | import math 10 | from torch.nn import Module, Sequential, Conv2d, ReLU, AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 11 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 12 | from torch.nn import functional as F 13 | from torch.autograd import Variable 14 | torch_ver = torch.__version__[:3] 15 | 16 | __all__ = ['PAM_Module', 'CAM_Module'] 17 | 18 | 19 | class PAM_Module(Module): 20 | """ Position attention module""" 21 | #Ref from SAGAN 22 | def __init__(self, in_dim): 23 | super(PAM_Module, self).__init__() 24 | self.channel_in = in_dim 25 | 26 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 27 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 28 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 29 | self.gamma = Parameter(torch.zeros(1)) 30 | 31 | self.softmax = Softmax(dim=-1) 32 | def forward(self, x): 33 | """ 34 | inputs : 35 | x : input feature maps( B X C X H X W) 36 | returns : 37 | out : attention value + input feature 38 | attention: B X (HxW) X (HxW) 39 | """ 40 | m_batchsize, C, height, width = x.size() 41 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 42 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 43 | energy = torch.bmm(proj_query, proj_key) 44 | attention = self.softmax(energy) 45 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 46 | 47 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 48 | out = out.view(m_batchsize, C, height, width) 49 | 50 | out = self.gamma*out + x 51 | return out 52 | 53 | 54 | class CAM_Module(Module): 55 | """ Channel attention module""" 56 | def __init__(self, in_dim): 57 | super(CAM_Module, self).__init__() 58 | self.channel_in = in_dim 59 | 60 | 61 | self.gamma = Parameter(torch.zeros(1)) 62 | self.softmax = Softmax(dim=-1) 63 | def forward(self,x): 64 | """ 65 | inputs : 66 | x : input feature maps( B X C X H X W) 67 | returns : 68 | out : attention value + input feature 69 | attention: B X C X C 70 | """ 71 | m_batchsize, C, height, width = x.size() 72 | proj_query = x.view(m_batchsize, C, -1) 73 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 74 | energy = torch.bmm(proj_query, proj_key) 75 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 76 | attention = self.softmax(energy_new) 77 | proj_value = x.view(m_batchsize, C, -1) 78 | 79 | out = torch.bmm(attention, proj_value) 80 | out = out.view(m_batchsize, C, height, width) 81 | 82 | out = self.gamma*out + x 83 | return out 84 | 85 | -------------------------------------------------------------------------------- /toolbox/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logging utilities. 3 | Output goes to both the console and a log file. 4 | Via log levels, the final training results can also be sent to an email address; see the example below. 5 | 6 | """ 7 | 8 | import logging 9 | import os 10 | import sys 11 | import time 12 | 13 | 14 | def get_logger(logdir): 15 | 16 | if not os.path.exists(logdir): 17 | os.makedirs(logdir) 18 | logname = f'run-{time.strftime("%Y-%m-%d-%H-%M")}.log' 19 | log_file = os.path.join(logdir, logname) 20 | 21 | # create log 22 | logger = logging.getLogger('train') 23 | logger.setLevel(logging.INFO) 24 | 25 | # Formatter: set the log output format 26 | formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 27 | 28 | # StreamHandler: output 1 -> the console 29 | stream_handler = 
logging.StreamHandler(sys.stdout) 30 | stream_handler.setFormatter(formatter) 31 | logger.addHandler(stream_handler) 32 | 33 | # FileHandler: output 2 -> saved to log_file 34 | file_handler = logging.FileHandler(log_file) 35 | file_handler.setFormatter(formatter) 36 | logger.addHandler(file_handler) 37 | 38 | return logger 39 | 40 | 41 | # # example: sending results to an email address 42 | # from logging.handlers import SMTPHandler 43 | # 44 | # logger = logging.getLogger('train') 45 | # logger.setLevel(logging.INFO) 46 | # 47 | # SMTP_handler = SMTPHandler( 48 | # mailhost=('smtp.163.com', 25), 49 | # fromaddr='xxx163emailxxx@163.com', 50 | # toaddrs=['xxxqqemailxxx@qq.com', 'or other emails you want to send'], 51 | # subject='send title', 52 | # credentials=('fromaddr email', 'fromaddr passwd') 53 | # ) 54 | # 55 | # formatter = logging.Formatter('%(asctime)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 56 | # SMTP_handler.setFormatter(formatter) 57 | # SMTP_handler.setLevel(logging.WARNING) # set the level to WARNING; logger.warning('infos') will then send important result information to the email address 58 | # logger.addHandler(SMTP_handler) 59 | # 60 | # logging.warning('information that needs to be sent by email: the final results or errors') 61 | 62 | -------------------------------------------------------------------------------- /toolbox/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | #https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | try: 13 | from itertools import ifilterfalse 14 | except ImportError: # py3k 15 | from itertools import filterfalse as ifilterfalse 16 | 17 | 18 | def lovasz_grad(gt_sorted): 19 | """ 20 | Computes gradient of the Lovasz extension w.r.t sorted errors 21 | See Alg. 1 in paper 22 | """ 23 | p = len(gt_sorted) 24 | gts = gt_sorted.sum() 25 | intersection = gts - gt_sorted.float().cumsum(0) 26 | union = gts + (1 - gt_sorted).float().cumsum(0) 27 | jaccard = 1. 
- intersection / union 28 | if p > 1: # cover 1-pixel case 29 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 30 | return jaccard 31 | 32 | 33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 34 | """ 35 | IoU for foreground class 36 | binary: 1 foreground, 0 background 37 | """ 38 | if not per_image: 39 | preds, labels = (preds,), (labels,) 40 | ious = [] 41 | for pred, label in zip(preds, labels): 42 | intersection = ((label == 1) & (pred == 1)).sum() 43 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 44 | if not union: 45 | iou = EMPTY 46 | else: 47 | iou = float(intersection) / float(union) 48 | ious.append(iou) 49 | iou = mean(ious) # mean across images if per_image 50 | return 100 * iou 51 | 52 | 53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 54 | """ 55 | Array of IoU for each (non ignored) class 56 | """ 57 | if not per_image: 58 | preds, labels = (preds,), (labels,) 59 | ious = [] 60 | for pred, label in zip(preds, labels): 61 | iou = [] 62 | for i in range(C): 63 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 64 | intersection = ((label == i) & (pred == i)).sum() 65 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 66 | if not union: 67 | iou.append(EMPTY) 68 | else: 69 | iou.append(float(intersection) / float(union)) 70 | ious.append(iou) 71 | ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image 72 | return 100 * np.array(ious) 73 | 74 | 75 | # --------------------------- BINARY LOSSES --------------------------- 76 | 77 | 78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 79 | """ 80 | Binary Lovasz hinge loss 81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 83 | per_image: compute the loss per image instead of per batch 84 | ignore: void class id 85 | """ 86 | if per_image: 87 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 88 | for log, lab in zip(logits, labels)) 89 | else: 90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 91 | return loss 92 | 93 | 94 | def lovasz_hinge_flat(logits, labels): 95 | """ 96 | Binary Lovasz hinge loss 97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 98 | labels: [P] Tensor, binary ground truth labels (0 or 1) 99 | ignore: label to ignore 100 | """ 101 | if len(labels) == 0: 102 | # only void pixels, the gradients should be 0 103 | return logits.sum() * 0. 104 | signs = 2. * labels.float() - 1. 105 | errors = (1. 
- logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, classes='present'): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 177 | """ 178 | if probas.numel() == 0: 179 | # only void pixels, the gradients should be 0 180 | return probas * 0. 
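    # Worked micro-example of the lovasz_grad helper used below
    # (illustrative, computed by hand): for sorted binary ground truth
    # gt_sorted = [1, 1, 0], the cumulative Jaccard values are
    # [0.5, 1.0, 1.0], so the returned per-position deltas are
    # [0.5, 0.5, 0.0]; they are dotted with the sorted errors below.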
181 | C = probas.size(1) 182 | losses = [] 183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 184 | for c in class_to_sum: 185 | fg = (labels == c).float() # foreground for class c 186 | if (classes == 'present' and fg.sum() == 0): 187 | continue 188 | if C == 1: 189 | if len(classes) > 1: 190 | raise ValueError('Sigmoid output possible only with 1 class') 191 | class_pred = probas[:, 0] 192 | else: 193 | class_pred = probas[:, c] 194 | errors = (Variable(fg) - class_pred).abs() 195 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 196 | perm = perm.data 197 | fg_sorted = fg[perm] 198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 199 | return mean(losses) 200 | 201 | 202 | def flatten_probas(probas, labels, ignore=None): 203 | """ 204 | Flattens predictions in the batch 205 | """ 206 | if probas.dim() == 3: 207 | # assumes output of a sigmoid layer 208 | B, H, W = probas.size() 209 | probas = probas.view(B, 1, H, W) 210 | B, C, H, W = probas.size() 211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 212 | labels = labels.view(-1) 213 | if ignore is None: 214 | return probas, labels 215 | valid = (labels != ignore) 216 | vprobas = probas[valid.nonzero().squeeze()] 217 | vlabels = labels[valid] 218 | return vprobas, vlabels 219 | 220 | def xloss(logits, labels, ignore=None): 221 | """ 222 | Cross entropy loss 223 | """ 224 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 225 | 226 | 227 | # --------------------------- HELPER FUNCTIONS --------------------------- 228 | def isnan(x): 229 | return x != x 230 | 231 | 232 | def mean(l, ignore_nan=False, empty=0): 233 | """ 234 | nanmean compatible with generators. 235 | """ 236 | l = iter(l) 237 | if ignore_nan: 238 | l = ifilterfalse(isnan, l) 239 | try: 240 | n = 1 241 | acc = next(l) 242 | except StopIteration: 243 | if empty == 'raise': 244 | raise ValueError('Empty mean') 245 | return empty 246 | for n, v in enumerate(l, 2): 247 | acc += v 248 | if n == 1: 249 | return acc 250 | return acc / n -------------------------------------------------------------------------------- /toolbox/metrics.py: -------------------------------------------------------------------------------- 1 | # https://github.com/meetshah1995/pytorch-semseg/blob/master/ptsemseg/metrics.py 2 | 3 | import numpy as np 4 | 5 | 6 | class runningScore(object): 7 | ''' 8 | n_classes: number of classes in the dataset, including the background 9 | ignore_index: class id(s) to ignore, usually the unlabeled id, e.g. 
CamVid.id_unlabel 10 | ''' 11 | 12 | def __init__(self, n_classes, ignore_index=None): 13 | self.n_classes = n_classes 14 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 15 | 16 | if ignore_index is None or ignore_index < 0 or ignore_index > n_classes: 17 | self.ignore_index = None 18 | elif isinstance(ignore_index, int): 19 | self.ignore_index = (ignore_index,) 20 | else: 21 | try: 22 | self.ignore_index = tuple(ignore_index) 23 | except TypeError: 24 | raise ValueError("'ignore_index' must be an int or iterable") 25 | 26 | def _fast_hist(self, label_true, label_pred, n_class): 27 | mask = (label_true >= 0) & (label_true < n_class) 28 | hist = np.bincount( 29 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 30 | ).reshape(n_class, n_class) 31 | return hist 32 | 33 | def update(self, label_trues, label_preds): 34 | for lt, lp in zip(label_trues, label_preds): 35 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 36 | 37 | def get_scores(self): 38 | """Returns accuracy score evaluation result. 39 | - pixel_acc: 40 | - class_acc: class mean acc 41 | - mIou : mean intersection over union 42 | - fwIou: frequency weighted intersection union 43 | """ 44 | 45 | hist = self.confusion_matrix 46 | 47 | # ignore unlabel 48 | if self.ignore_index is not None: 49 | for index in self.ignore_index: 50 | hist = np.delete(hist, index, axis=0) 51 | hist = np.delete(hist, index, axis=1) 52 | 53 | acc = np.diag(hist).sum() / hist.sum() 54 | cls_acc = np.diag(hist) / hist.sum(axis=1) 55 | acc_cls = np.nanmean(cls_acc) 56 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 57 | mean_iou = np.nanmean(iu) 58 | freq = hist.sum(axis=1) / hist.sum() 59 | fw_iou = (freq[freq > 0] * iu[freq > 0]).sum() 60 | 61 | # set unlabel as nan 62 | if self.ignore_index is not None: 63 | for index in self.ignore_index: 64 | iu = np.insert(iu, index, np.nan) 65 | 66 | cls_iu = dict(zip(range(self.n_classes), iu)) 67 | cls_acc = dict(zip(range(self.n_classes), cls_acc)) 68 | 69 | return ( 70 | { 71 | "pixel_acc: ": acc, 72 | "class_acc: ": acc_cls, 73 | "mIou: ": mean_iou, 74 | "fwIou: ": fw_iou, 75 | }, 76 | cls_iu, 77 | cls_acc, 78 | ) 79 | 80 | def reset(self): 81 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 82 | 83 | 84 | class averageMeter(object): 85 | """Computes and stores the average and current value""" 86 | 87 | def __init__(self): 88 | self.reset() 89 | 90 | def reset(self): 91 | self.val = 0 92 | self.avg = 0 93 | self.sum = 0 94 | self.count = 0 95 | 96 | def update(self, val, n=1): 97 | self.val = val 98 | self.sum += val * n 99 | self.count += n 100 | self.avg = self.sum / self.count 101 | -------------------------------------------------------------------------------- /toolbox/models/LASNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn as nn 3 | import torch 4 | from resnet import Backbone_ResNet152_in3 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from toolbox.dual_self_att import CAM_Module 8 | 9 | 10 | class BasicConv2d(nn.Module): 11 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 12 | super(BasicConv2d, self).__init__() 13 | self.conv = nn.Conv2d(in_planes, out_planes, 14 | kernel_size=kernel_size, stride=stride, 15 | padding=padding, dilation=dilation, bias=False) 16 | self.bn = nn.BatchNorm2d(out_planes) 17 | #self.relu = nn.ReLU(inplace=True) 18 | 
self.relu = nn.LeakyReLU(0.1) 19 | 20 | def forward(self, x): 21 | x = self.conv(x) 22 | x = self.bn(x) 23 | x = self.relu(x) 24 | return x 25 | 26 | class ChannelAttention(nn.Module): 27 | def __init__(self, in_planes, ratio=4): 28 | super(ChannelAttention, self).__init__() 29 | self.max_pool = nn.AdaptiveMaxPool2d(1) 30 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 31 | self.relu1 = nn.ReLU() 32 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 33 | 34 | self.sigmoid = nn.Sigmoid() 35 | 36 | def forward(self, x): 37 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 38 | out = max_out 39 | return self.sigmoid(out) 40 | 41 | class SpatialAttention(nn.Module): 42 | def __init__(self, kernel_size=3): 43 | super(SpatialAttention, self).__init__() 44 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 45 | padding = 3 if kernel_size == 7 else 1 46 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 47 | self.sigmoid = nn.Sigmoid() 48 | 49 | def forward(self, x): 50 | max_out, _ = torch.max(x, dim=1, keepdim=True) 51 | x = max_out 52 | x = self.conv1(x) 53 | return self.sigmoid(x) 54 | 55 | 56 | class CorrelationModule(nn.Module): 57 | def __init__(self, all_channel=64): 58 | super(CorrelationModule, self).__init__() 59 | self.linear_e = nn.Linear(all_channel, all_channel,bias = False) 60 | self.channel = all_channel 61 | self.fusion = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 62 | 63 | def forward(self, exemplar, query): # exemplar: middle, query: rgb or T 64 | fea_size = exemplar.size()[2:] 65 | all_dim = fea_size[0]*fea_size[1] 66 | exemplar_flat = exemplar.view(-1, self.channel, all_dim) #N,C,H*W 67 | query_flat = query.view(-1, self.channel, all_dim) 68 | exemplar_t = torch.transpose(exemplar_flat,1,2).contiguous() #batchsize x dim x num, N,H*W,C 69 | exemplar_corr = self.linear_e(exemplar_t) # 70 | A = torch.bmm(exemplar_corr, query_flat) 71 | B = F.softmax(torch.transpose(A,1,2),dim=1) 72 | exemplar_att = torch.bmm(query_flat, B).contiguous() 73 | 74 | exemplar_att = exemplar_att.view(-1, self.channel, fea_size[0], fea_size[1]) 75 | exemplar_out = self.fusion(exemplar_att) 76 | 77 | return exemplar_out 78 | 79 | class CLM(nn.Module): 80 | def __init__(self, all_channel=64): 81 | super(CLM, self).__init__() 82 | self.corr_x_2_x_ir = CorrelationModule(all_channel) 83 | self.corr_ir_2_x_ir = CorrelationModule(all_channel) 84 | self.smooth1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 85 | self.smooth2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 86 | self.fusion = BasicConv2d(2*all_channel, all_channel, kernel_size=3, padding=1) 87 | self.pred = nn.Conv2d(all_channel, 2, kernel_size=3, padding=1, bias = True) 88 | 89 | def forward(self, x, x_ir, ir): # exemplar: middle, query: rgb or T 90 | corr_x_2_x_ir = self.corr_x_2_x_ir(x_ir,x) 91 | corr_ir_2_x_ir = self.corr_ir_2_x_ir(x_ir,ir) 92 | 93 | summation = self.smooth1(corr_x_2_x_ir + corr_ir_2_x_ir) 94 | multiplication = self.smooth2(corr_x_2_x_ir * corr_ir_2_x_ir) 95 | 96 | fusion = self.fusion(torch.cat([summation,multiplication],1)) 97 | sal_pred = self.pred(fusion) 98 | 99 | return fusion, sal_pred 100 | 101 | 102 | class CAM(nn.Module): 103 | def __init__(self, all_channel=64): 104 | super(CAM, self).__init__() 105 | #self.conv1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 106 | self.conv2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 107 | self.sa = 
SpatialAttention() 108 | # self-channel attention 109 | self.cam = CAM_Module(all_channel) 110 | 111 | def forward(self, x, ir): 112 | multiplication = x * ir 113 | summation = self.conv2(x + ir) 114 | 115 | sa = self.sa(multiplication) 116 | summation_sa = summation.mul(sa) 117 | 118 | sc_feat = self.cam(summation_sa) 119 | 120 | return sc_feat 121 | 122 | 123 | class ESM(nn.Module): 124 | def __init__(self, all_channel=64): 125 | super(ESM, self).__init__() 126 | self.conv1 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 127 | self.conv2 = BasicConv2d(all_channel, all_channel, kernel_size=3, padding=1) 128 | self.dconv1 = BasicConv2d(all_channel,int( all_channel/4), kernel_size=3, padding=1) 129 | self.dconv2 = BasicConv2d(all_channel,int( all_channel/4), kernel_size=3, dilation=3, padding=3) 130 | self.dconv3 = BasicConv2d(all_channel,int( all_channel/4), kernel_size=3, dilation=5, padding=5) 131 | self.dconv4 = BasicConv2d(all_channel,int( all_channel/4), kernel_size=3, dilation=7, padding=7) 132 | self.fuse_dconv = nn.Conv2d(all_channel, all_channel, kernel_size=3,padding=1) 133 | self.pred = nn.Conv2d(all_channel, 2, kernel_size=3, padding=1, bias = True) 134 | 135 | def forward(self, x, ir): 136 | multiplication = self.conv1(x * ir) 137 | summation = self.conv2(x + ir) 138 | fusion = (summation + multiplication) 139 | x1 = self.dconv1(fusion) 140 | x2 = self.dconv2(fusion) 141 | x3 = self.dconv3(fusion) 142 | x4 = self.dconv4(fusion) 143 | out = self.fuse_dconv(torch.cat((x1, x2, x3, x4), dim=1)) 144 | edge_pred = self.pred(out) 145 | 146 | return out, edge_pred 147 | 148 | 149 | class prediction_decoder(nn.Module): 150 | def __init__(self, channel1=64, channel2=128, channel3=256, channel4=256, channel5=512, n_classes=9): 151 | super(prediction_decoder, self).__init__() 152 | # 15 20 153 | self.decoder5 = nn.Sequential( 154 | nn.Dropout2d(p=0.1), 155 | BasicConv2d(channel5, channel5, kernel_size=3, padding=3, dilation=3), 156 | BasicConv2d(channel5, channel4, kernel_size=3, padding=1), 157 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 158 | ) 159 | # 30 40 160 | self.decoder4 = nn.Sequential( 161 | nn.Dropout2d(p=0.1), 162 | BasicConv2d(channel4, channel4, kernel_size=3, padding=3, dilation=3), 163 | BasicConv2d(channel4, channel3, kernel_size=3, padding=1), 164 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 165 | ) 166 | # 60 80 167 | self.decoder3 = nn.Sequential( 168 | nn.Dropout2d(p=0.1), 169 | BasicConv2d(channel3, channel3, kernel_size=3, padding=3, dilation=3), 170 | BasicConv2d(channel3, channel2, kernel_size=3, padding=1), 171 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 172 | ) 173 | # 120 160 174 | self.decoder2 = nn.Sequential( 175 | nn.Dropout2d(p=0.1), 176 | BasicConv2d(channel2, channel2, kernel_size=3, padding=3, dilation=3), 177 | BasicConv2d(channel2, channel1, kernel_size=3, padding=1), 178 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 179 | ) 180 | self.semantic_pred2 = nn.Conv2d(channel1, n_classes, kernel_size=3, padding=1) 181 | # 240 320 -> 480 640 182 | self.decoder1 = nn.Sequential( 183 | nn.Dropout2d(p=0.1), 184 | BasicConv2d(channel1, channel1, kernel_size=3, padding=3, dilation=3), 185 | BasicConv2d(channel1, channel1, kernel_size=3, padding=1), 186 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), # 480 640 187 | BasicConv2d(channel1, channel1, kernel_size=3, padding=1), 188 | nn.Conv2d(channel1, n_classes, kernel_size=3, padding=1) 
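            # note: this last layer is a bare Conv2d (no BN/ReLU), so
            # decoder1 emits unbounded class logits at full input resolution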
189 | ) 190 | 191 | def forward(self, x5, x4, x3, x2, x1): 192 | x5_decoder = self.decoder5(x5) 193 | # for PST900 dataset 194 | # since the input size is 720x1280, the sizes of x5_decoder and x4_decoder are 23 and 45, so we cannot use 2x upsampling directly. 195 | # x5_decoder = F.interpolate(x5_decoder, size=fea_size, mode="bilinear", align_corners=True) 196 | x4_decoder = self.decoder4(x5_decoder + x4) 197 | x3_decoder = self.decoder3(x4_decoder + x3) 198 | x2_decoder = self.decoder2(x3_decoder + x2) 199 | semantic_pred2 = self.semantic_pred2(x2_decoder) 200 | semantic_pred = self.decoder1(x2_decoder + x1) 201 | 202 | return semantic_pred, semantic_pred2 203 | 204 | 205 | class LASNet(nn.Module): 206 | def __init__(self, n_classes): 207 | super(LASNet, self).__init__() 208 | 209 | ( 210 | self.layer1_rgb, 211 | self.layer2_rgb, 212 | self.layer3_rgb, 213 | self.layer4_rgb, 214 | self.layer5_rgb, 215 | ) = Backbone_ResNet152_in3(pretrained=True) 216 | 217 | # reduce the channel number, input: 480 640 218 | self.rgbconv1 = BasicConv2d(64, 64, kernel_size=3, padding=1) # 240 320 219 | self.rgbconv2 = BasicConv2d(256, 128, kernel_size=3, padding=1) # 120 160 220 | self.rgbconv3 = BasicConv2d(512, 256, kernel_size=3, padding=1) # 60 80 221 | self.rgbconv4 = BasicConv2d(1024, 256, kernel_size=3, padding=1) # 30 40 222 | self.rgbconv5 = BasicConv2d(2048, 512, kernel_size=3, padding=1) # 15 20 223 | 224 | self.CLM5 = CLM(512) 225 | self.CAM4 = CAM(256) 226 | self.CAM3 = CAM(256) 227 | self.CAM2 = CAM(128) 228 | self.ESM1 = ESM(64) 229 | 230 | self.decoder = prediction_decoder(64, 128, 256, 256, 512, n_classes) 231 | 232 | def forward(self, rgb, depth): 233 | x = rgb 234 | ir = depth[:, :1, ...] 235 | ir = torch.cat((ir, ir, ir), dim=1) 236 | 237 | x1 = self.layer1_rgb(x) 238 | x2 = self.layer2_rgb(x1) 239 | x3 = self.layer3_rgb(x2) 240 | x4 = self.layer4_rgb(x3) 241 | x5 = self.layer5_rgb(x4) 242 | 243 | ir1 = self.layer1_rgb(ir) 244 | ir2 = self.layer2_rgb(ir1) 245 | ir3 = self.layer3_rgb(ir2) 246 | ir4 = self.layer4_rgb(ir3) 247 | ir5 = self.layer5_rgb(ir4) 248 | 249 | x1 = self.rgbconv1(x1) 250 | x2 = self.rgbconv2(x2) 251 | x3 = self.rgbconv3(x3) 252 | x4 = self.rgbconv4(x4) 253 | x5 = self.rgbconv5(x5) 254 | 255 | ir1 = self.rgbconv1(ir1) 256 | ir2 = self.rgbconv2(ir2) 257 | ir3 = self.rgbconv3(ir3) 258 | ir4 = self.rgbconv4(ir4) 259 | ir5 = self.rgbconv5(ir5) 260 | 261 | out5, sal = self.CLM5(x5, x5*ir5, ir5) 262 | out4 = self.CAM4(x4, ir4) 263 | out3 = self.CAM3(x3, ir3) 264 | out2 = self.CAM2(x2, ir2) 265 | out1, edge = self.ESM1(x1, ir1) 266 | 267 | semantic, semantic2 = self.decoder(out5, out4, out3, out2, out1) 268 | semantic2 = torch.nn.functional.interpolate(semantic2, scale_factor=2, mode='bilinear') 269 | sal = torch.nn.functional.interpolate(sal, scale_factor=32, mode='bilinear') 270 | edge = torch.nn.functional.interpolate(edge, scale_factor=2, mode='bilinear') 271 | 272 | 273 | return semantic, semantic2, sal, edge 274 | 275 | if __name__ == '__main__': 276 | LASNet(9) 277 | # for PST900 dataset 278 | # LASNet(5) 279 | -------------------------------------------------------------------------------- /toolbox/models/__pycache__/EGFNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/EGFNet.cpython-37.pyc -------------------------------------------------------------------------------- 
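A quick smoke test for the model defined in toolbox/models/LASNet.py above (a sketch, not a file in this repository; the shapes follow the 480x640 size comments in the model, and instantiating it calls Backbone_ResNet152_in3(pretrained=True), which fetches ImageNet weights on first use):

    import torch
    from toolbox.models.LASNet import LASNet

    model = LASNet(n_classes=9).eval()   # 9 classes for MFNet
    rgb = torch.randn(1, 3, 480, 640)    # stand-in for a normalized RGB batch
    tir = torch.randn(1, 3, 480, 640)    # thermal input, replicated to 3 channels
    with torch.no_grad():
        semantic, semantic2, sal, edge = model(rgb, tir)
    print(semantic.shape)                # expected: torch.Size([1, 9, 480, 640])

All four outputs are upsampled to the input resolution inside forward, matching the four supervision terms used in train_LASNet.py.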
/toolbox/models/__pycache__/EGFNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/EGFNet.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/models/__pycache__/LASNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LASNet.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/models/__pycache__/LgyTestNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LgyTestNet.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/models/__pycache__/LgyTestNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/models/__pycache__/LgyTestNet.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/optim/Ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 
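# Minimal usage sketch (mirrors how train_LASNet.py in this repo constructs
# it; the lr/weight_decay values are illustrative, not recommendations):
#   from toolbox.optim.Ranger import Ranger
#   optimizer = Ranger(model.parameters(), lr=5e-4, weight_decay=5e-4)
#   loss.backward(); optimizer.step()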
23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer, required 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 35 | use_gc=True, gc_conv_only=False 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, 55 | N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | print( 76 | f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") 77 | if (self.use_gc and self.gc_gradient_threshold == 1): 78 | print(f"GC applied to both conv and fc layers") 79 | elif (self.use_gc and self.gc_gradient_threshold == 3): 80 | print(f"GC applied to conv layers only") 81 | 82 | def __setstate__(self, state): 83 | print("set state called") 84 | super(Ranger, self).__setstate__(state) 85 | 86 | def step(self, closure=None): 87 | loss = None 88 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 89 | # Uncomment if you need to use the actual closure... 
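        # (per the torch.optim.Optimizer.step contract, a closure should
        # re-evaluate the model and return the loss)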
90 | 91 | # if closure is not None: 92 | #loss = closure() 93 | 94 | # Evaluate averages and grad, update param tensors 95 | for group in self.param_groups: 96 | 97 | for p in group['params']: 98 | if p.grad is None: 99 | continue 100 | grad = p.grad.data.float() 101 | 102 | if grad.is_sparse: 103 | raise RuntimeError( 104 | 'Ranger optimizer does not support sparse gradients') 105 | 106 | p_data_fp32 = p.data.float() 107 | 108 | state = self.state[p] # get state dict for this param 109 | 110 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 111 | # if self.first_run_check==0: 112 | # self.first_run_check=1 113 | #print("Initializing slow buffer...should not see this at load from saved model!") 114 | state['step'] = 0 115 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 116 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 117 | 118 | # look ahead weight storage now in state dict 119 | state['slow_buffer'] = torch.empty_like(p.data) 120 | state['slow_buffer'].copy_(p.data) 121 | 122 | else: 123 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 124 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as( 125 | p_data_fp32) 126 | 127 | # begin computations 128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 129 | beta1, beta2 = group['betas'] 130 | 131 | # GC operation for Conv layers and FC layers 132 | if grad.dim() > self.gc_gradient_threshold: 133 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 134 | 135 | state['step'] += 1 136 | 137 | # compute variance mov avg 138 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 139 | # compute mean moving avg 140 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 141 | 142 | buffered = self.radam_buffer[int(state['step'] % 10)] 143 | 144 | if state['step'] == buffered[0]: 145 | N_sma, step_size = buffered[1], buffered[2] 146 | else: 147 | buffered[0] = state['step'] 148 | beta2_t = beta2 ** state['step'] 149 | N_sma_max = 2 / (1 - beta2) - 1 150 | N_sma = N_sma_max - 2 * \ 151 | state['step'] * beta2_t / (1 - beta2_t) 152 | buffered[1] = N_sma 153 | if N_sma > self.N_sma_threshhold: 154 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 155 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 156 | else: 157 | step_size = 1.0 / (1 - beta1 ** state['step']) 158 | buffered[2] = step_size 159 | 160 | if group['weight_decay'] != 0: 161 | p_data_fp32.add_(-group['weight_decay'] 162 | * group['lr'], p_data_fp32) 163 | 164 | # apply lr 165 | if N_sma > self.N_sma_threshhold: 166 | denom = exp_avg_sq.sqrt().add_(group['eps']) 167 | p_data_fp32.addcdiv_(-step_size * 168 | group['lr'], exp_avg, denom) 169 | else: 170 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 171 | 172 | p.data.copy_(p_data_fp32) 173 | 174 | # integrated look ahead... 
175 | # we do it at the param level instead of group level 176 | if state['step'] % group['k'] == 0: 177 | # get access to slow param tensor 178 | slow_p = state['slow_buffer'] 179 | # (fast weights - slow weights) * alpha 180 | slow_p.add_(self.alpha, p.data - slow_p) 181 | # copy interpolated weights to RAdam param tensor 182 | p.data.copy_(slow_p) 183 | 184 | return loss -------------------------------------------------------------------------------- /toolbox/optim/__pycache__/Ranger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/optim/__pycache__/Ranger.cpython-37.pyc -------------------------------------------------------------------------------- /toolbox/optim/__pycache__/Ranger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/LASNet/539cf52fa889d51a9e9a9bd5c258002319a15eb3/toolbox/optim/__pycache__/Ranger.cpython-38.pyc -------------------------------------------------------------------------------- /toolbox/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import * 2 | -------------------------------------------------------------------------------- /toolbox/scheduler/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler 3 | 4 | 5 | class WarmupMultiStepLR(MultiStepLR): 6 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 7 | warmup_iters=500, last_epoch=-1): 8 | self.warmup_factor = warmup_factor 9 | self.warmup_iters = warmup_iters 10 | super().__init__(optimizer, milestones, gamma, last_epoch) 11 | 12 | def get_lr(self): 13 | if self.last_epoch <= self.warmup_iters: 14 | alpha = self.last_epoch / self.warmup_iters 15 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 16 | # print(self.base_lrs[0]*warmup_factor) 17 | return [lr * warmup_factor for lr in self.base_lrs] 18 | else: 19 | lr = super().get_lr() 20 | return lr 21 | 22 | 23 | class WarmupCosineLR(_LRScheduler): 24 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500, 25 | eta_min=0, last_epoch=-1): 26 | self.warmup_factor = warmup_factor 27 | self.warmup_iters = warmup_iters 28 | self.T_max, self.eta_min = T_max, eta_min 29 | super().__init__(optimizer, last_epoch) 30 | 31 | def get_lr(self): 32 | if self.last_epoch <= self.warmup_iters: 33 | alpha = self.last_epoch / self.warmup_iters 34 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 35 | # print(self.base_lrs[0]*warmup_factor) 36 | return [lr * warmup_factor for lr in self.base_lrs] 37 | else: 38 | return [self.eta_min + (base_lr - self.eta_min) * 39 | (1 + math.cos( 40 | math.pi * (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters))) / 2 41 | for base_lr in self.base_lrs] 42 | 43 | 44 | class WarmupPolyLR(_LRScheduler): 45 | def __init__(self, optimizer, T_max, cur_iter, warmup_factor=1.0 / 3, warmup_iters=500, 46 | eta_min=0, power=0.9): 47 | self.warmup_factor = warmup_factor 48 | self.warmup_iters = warmup_iters 49 | self.power = power 50 | self.T_max, self.eta_min = T_max, eta_min 51 | self.cur_iter = cur_iter 52 | super().__init__(optimizer) 53 | 54 | def get_lr(self): 55 | if self.cur_iter <= self.warmup_iters: 56 | alpha = self.cur_iter 
/ self.warmup_iters 57 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 58 | # print(self.base_lrs[0]*warmup_factor) 59 | return [lr * warmup_factor for lr in self.base_lrs] 60 | else: 61 | return [self.eta_min + (base_lr - self.eta_min) * 62 | math.pow(1 - (self.cur_iter - self.warmup_iters) / (self.T_max - self.warmup_iters), 63 | self.power) for base_lr in self.base_lrs] 64 | 65 | 66 | def poly_learning_rate(cur_epoch, max_epoch, curEpoch_iter, perEpoch_iter, baselr): 67 | cur_iter = cur_epoch * perEpoch_iter + curEpoch_iter 68 | max_iter = max_epoch * perEpoch_iter 69 | lr = baselr * pow((1 - 1.0 * cur_iter / max_iter), 0.9) 70 | 71 | return lr 72 | 73 | 74 | class GradualWarmupScheduler(_LRScheduler): 75 | """ Gradually warm-up(increasing) learning rate in optimizer. 76 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 77 | Args: 78 | optimizer (Optimizer): Wrapped optimizer. 79 | min_lr_mul: target learning rate = base lr * min_lr_mul 80 | total_epoch: target learning rate is reached at total_epoch, gradually 81 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 82 | """ 83 | 84 | def __init__(self, optimizer, total_epoch, min_lr_mul=0.1, after_scheduler=None): 85 | self.min_lr_mul = min_lr_mul 86 | if self.min_lr_mul > 1. or self.min_lr_mul < 0.: 87 | raise ValueError('min_lr_mul should be [0., 1.]') 88 | self.total_epoch = total_epoch 89 | self.after_scheduler = after_scheduler 90 | self.finished = False 91 | super(GradualWarmupScheduler, self).__init__(optimizer) 92 | 93 | def get_lr(self): 94 | if self.last_epoch > self.total_epoch: 95 | if self.after_scheduler: 96 | if not self.finished: 97 | self.after_scheduler.base_lrs = self.base_lrs 98 | self.finished = True 99 | return self.after_scheduler.get_lr() 100 | else: 101 | return self.base_lrs 102 | else: 103 | return [base_lr * (self.min_lr_mul + (1. - self.min_lr_mul) * (self.last_epoch / float(self.total_epoch))) 104 | for base_lr in self.base_lrs] 105 | 106 | def step(self, epoch=None): 107 | if self.finished and self.after_scheduler: 108 | return self.after_scheduler.step(epoch - self.total_epoch) 109 | else: 110 | return super(GradualWarmupScheduler, self).step(epoch) 111 | 112 | 113 | if __name__ == '__main__': 114 | optim = WarmupPolyLR() 115 | -------------------------------------------------------------------------------- /toolbox/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import os 5 | import math 6 | import random 7 | import time 8 | import torch.backends.cudnn as cudnn 9 | 10 | 11 | 12 | class ClassWeight(object): 13 | 14 | def __init__(self, method): 15 | assert method in ['no', 'enet', 'median_freq_balancing'] 16 | self.method = method 17 | 18 | def get_weight(self, dataloader, num_classes): 19 | if self.method == 'no': 20 | return np.ones(num_classes) 21 | if self.method == 'enet': 22 | return self._enet_weighing(dataloader, num_classes) 23 | if self.method == 'median_freq_balancing': 24 | return self._median_freq_balancing(dataloader, num_classes) 25 | 26 | def _enet_weighing(self, dataloader, num_classes, c=1.02): 27 | """Computes class weights as described in the ENet paper: 28 | 29 | w_class = 1 / (ln(c + p_class)), 30 | 31 | where c is usually 1.02 and p_class is the propensity score of that 32 | class: 33 | 34 | propensity_score = freq_class / total_pixels. 
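        For example (illustrative numbers, computed from the formula above):
        with c = 1.02, a dominant class with p_class = 0.5 gets
        w = 1 / ln(1.52) ~= 2.39, while a rare class with p_class = 0.01
        gets w = 1 / ln(1.03) ~= 33.8, the same order of magnitude as the
        hard-coded weights in the dataset loaders above.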
35 | 36 | References: https://arxiv.org/abs/1606.02147 37 | 38 | Keyword arguments: 39 | - dataloader (``data.Dataloader``): A data loader to iterate over the 40 | dataset. 41 | - num_classes (``int``): The number of classes. 42 | - c (``int``, optional): An additional hyper-parameter which restricts 43 | the interval of values for the weights. Default: 1.02. 44 | 45 | """ 46 | print('computing class weight .......................') 47 | class_count = 0 48 | total = 0 49 | for i, sample in tqdm(enumerate(dataloader), total=len(dataloader)): 50 | label = sample['label'] 51 | label = label.cpu().numpy() 52 | 53 | # Flatten label 54 | flat_label = label.flatten() 55 | 56 | # Sum up the number of pixels of each class and the total pixel 57 | # counts for each label 58 | class_count += np.bincount(flat_label, minlength=num_classes) 59 | total += flat_label.size 60 | 61 | # Compute propensity score and then the weights for each class 62 | propensity_score = class_count / total 63 | class_weights = 1 / (np.log(c + propensity_score)) 64 | 65 | return class_weights 66 | 67 | def _median_freq_balancing(self, dataloader, num_classes): 68 | """Computes class weights using median frequency balancing as described 69 | in https://arxiv.org/abs/1411.4734: 70 | 71 | w_class = median_freq / freq_class, 72 | 73 | where freq_class is the number of pixels of a given class divided by 74 | the total number of pixels in images where that class is present, and 75 | median_freq is the median of freq_class. 76 | 77 | Keyword arguments: 78 | - dataloader (``data.Dataloader``): A data loader to iterate over the 79 | dataset 80 | whose weights are going to be computed. 81 | - num_classes (``int``): The number of classes 82 | 83 | """ 84 | print('computing class weight .......................') 85 | class_count = 0 86 | total = 0 87 | for i, sample in tqdm(enumerate(dataloader), total=len(dataloader)): 88 | label = sample['label'] 89 | label = label.cpu().numpy() 90 | 91 | # Flatten label 92 | flat_label = label.flatten() 93 | 94 | # Sum up the class frequencies 95 | bincount = np.bincount(flat_label, minlength=num_classes) 96 | 97 | # Create a mask of classes that exist in the label 98 | mask = bincount > 0 99 | # Multiply the mask by the pixel count. The resulting array has 100 | # one element for each class. 
The value is either 0 (if the class 101 | # does not exist in the label) or equal to the pixel count (if 102 | # the class exists in the label) 103 | total += mask * flat_label.size 104 | 105 | # Sum up the number of pixels found for each class 106 | class_count += bincount 107 | 108 | # Compute the frequency and its median 109 | freq = class_count / total 110 | med = np.median(freq) 111 | 112 | return med / freq 113 | 114 | 115 | def color_map(N=256, normalized=False): 116 | """ 117 | Return Color Map in PASCAL VOC format 118 | """ 119 | 120 | def bitget(byteval, idx): 121 | return (byteval & (1 << idx)) != 0 122 | 123 | dtype = "float32" if normalized else "uint8" 124 | cmap = np.zeros((N, 3), dtype=dtype) 125 | for i in range(N): 126 | r = g = b = 0 127 | c = i 128 | for j in range(8): 129 | r = r | (bitget(c, 0) << 7 - j) 130 | g = g | (bitget(c, 1) << 7 - j) 131 | b = b | (bitget(c, 2) << 7 - j) 132 | c = c >> 3 133 | 134 | cmap[i] = np.array([r, g, b]) 135 | 136 | cmap = cmap / 255.0 if normalized else cmap 137 | return cmap 138 | 139 | 140 | def class_to_RGB(label, N, cmap=None, normalized=False): 141 | ''' 142 | label: 2D numpy array with pixel-level classes shape=(h, w) 143 | N: number of classes, including background, should be in [0, 255] 144 | cmap: list of colors for N class (include background) \ 145 | if None, use VOC default color map. 146 | normalized: RGB in [0, 1] if True else [0, 255] if False 147 | 148 | :return: colorized 3D RGB numpy array, shape=(h, w, 3) 149 | ''' 150 | dtype = "float32" if normalized else "uint8" 151 | 152 | assert len(label.shape) == 2, f'label should be 2D, not {len(label.shape)}D' 153 | label_class = np.asarray(label) 154 | 155 | label_color = np.zeros((label.shape[0], label.shape[1], 3), dtype=dtype) 156 | 157 | if cmap is None: 158 | # index 0 is the background, rendered black [0, 0, 0]; 1~N are the N classes in color 159 | cmap = color_map(N, normalized=normalized) 160 | else: 161 | cmap = np.asarray(cmap, dtype=dtype) 162 | cmap = cmap / 255.0 if normalized else cmap 163 | 164 | assert cmap.shape[0] == N, f'{N} classes and {cmap.shape[0]} colors do not match.' 
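    # Usage sketch (shapes per the contract above; mirrors the test loop in
    # train_LASNet.py, where class ids come from predict.max(1)[1]):
    #   pred = predict.max(1)[1][0].cpu().numpy()           # (h, w) int ids
    #   vis = class_to_RGB(pred, N=9, cmap=dataset.cmap)    # (h, w, 3) uint8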
165 | 166 | # color each class according to the color map 167 | for i_class in range(N): 168 | label_color[label_class == i_class] = cmap[i_class] 169 | 170 | return label_color 171 | 172 | 173 | def tensor_classes_to_RGBs(label, N, cmap=None): 174 | '''used in tensorboard''' 175 | 176 | if cmap is None: 177 | cmap = color_map(N) 178 | else: 179 | cmap = np.asarray(cmap) 180 | 181 | label = label.clone().cpu().numpy() # (batch_size, H, W) 182 | ctRGB = np.vectorize(lambda x: tuple(cmap[int(x)].tolist())) 183 | 184 | colored = np.asarray(ctRGB(label)).astype(np.float32) # (3, batch_size, H, W) 185 | colored = colored.squeeze() 186 | 187 | try: 188 | return torch.from_numpy(colored.transpose([1, 0, 2, 3])) 189 | except ValueError: 190 | return torch.from_numpy(colored[np.newaxis, ...]) 191 | 192 | 193 | def save_ckpt(logdir, model, epoch_iter, prefix=''): 194 | state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() 195 | torch.save(state, os.path.join(logdir, prefix + 'model_' + str(epoch_iter) + '.pth')) 196 | 197 | 198 | def load_ckpt(logdir, model, prefix=''): 199 | save_pth = os.path.join(logdir, prefix+'model.pth') 200 | model.load_state_dict(torch.load(save_pth)) 201 | return model 202 | 203 | 204 | def compute_speed(model, input_size, device=0, iteration=100): 205 | torch.cuda.set_device(device) 206 | cudnn.benchmark = True 207 | 208 | model.eval() 209 | model = model.cuda() 210 | 211 | input = torch.randn(*input_size, device=device) 212 | 213 | for _ in range(50): 214 | model(input) 215 | 216 | print('=========Eval Forward Time=========') 217 | torch.cuda.synchronize() 218 | t_start = time.time() 219 | for _ in range(iteration): 220 | model(input) 221 | torch.cuda.synchronize() 222 | elapsed_time = time.time() - t_start 223 | 224 | speed_time = elapsed_time / iteration * 1000 225 | fps = iteration / elapsed_time 226 | 227 | print('Elapsed Time: [%.2f s / %d iter]' % (elapsed_time, iteration)) 228 | print('Speed Time: %.2f ms / iter FPS: %.2f' % (speed_time, fps)) 229 | return speed_time, fps 230 | 231 | 232 | def setup_seed(seed): 233 | torch.manual_seed(seed) 234 | torch.cuda.manual_seed_all(seed) 235 | np.random.seed(seed) 236 | random.seed(seed) 237 | torch.backends.cudnn.deterministic = True 238 | torch.backends.cudnn.benchmark = False 239 | 240 | 241 | def group_weight_decay(model): 242 | 243 | import torch.nn as nn 244 | from torch.nn.modules.conv import _ConvNd 245 | from torch.nn.modules.batchnorm import _BatchNorm 246 | 247 | decays = [] 248 | no_decays = [] 249 | for m in model.modules(): 250 | if isinstance(m, nn.Linear): 251 | decays.append(m.weight) 252 | if m.bias is not None: 253 | no_decays.append(m.bias) 254 | elif isinstance(m, _ConvNd): 255 | decays.append(m.weight) 256 | if m.bias is not None: 257 | no_decays.append(m.bias) 258 | elif isinstance(m, _BatchNorm): 259 | if m.weight is not None: 260 | no_decays.append(m.weight) 261 | if m.bias is not None: 262 | no_decays.append(m.bias) 263 | 264 | assert len(list(model.parameters())) == len(decays) + len(no_decays) 265 | groups = [dict(params=decays), dict(params=no_decays, weight_decay=0.0)] 266 | return groups 267 | 268 | 269 | if __name__ == '__main__': 270 | pass 271 | -------------------------------------------------------------------------------- /train_LASNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | import time 5 | 6 | from apex import amp 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 
11 | from torch.optim.lr_scheduler import LambdaLR
12 | from torch.utils.data import DataLoader
13 | 
14 | from toolbox import get_dataset
15 | from toolbox.optim.Ranger import Ranger
16 | from toolbox import get_logger
17 | from toolbox import get_model
18 | from toolbox import averageMeter, runningScore
19 | from toolbox import save_ckpt
20 | from toolbox.datasets.irseg import IRSeg
21 | from toolbox.datasets.pst900 import PSTSeg
22 | from toolbox.losses import lovasz_softmax
23 | 
24 | class eeemodelLoss(nn.Module):
25 |     '''Multi-task training loss: weighted CE and Lovasz-Softmax on the semantic outputs, plus weighted CE on the binary-location and boundary outputs.'''
26 |     def __init__(self, class_weight=None, ignore_index=-100, reduction='mean'):
27 |         super(eeemodelLoss, self).__init__()
28 | 
29 |         self.class_weight_semantic = torch.from_numpy(np.array(
30 |             [1.5105, 16.6591, 29.4238, 34.6315, 40.0845, 41.4357, 47.9794, 45.3725, 44.9000])).float()  # weights for the 9 MFNet classes
31 |         self.class_weight_binary = torch.from_numpy(np.array([1.5121, 10.2388])).float()
32 |         self.class_weight_boundary = torch.from_numpy(np.array([1.4459, 23.7228])).float()
33 | 
34 |         self.class_weight = class_weight
35 |         # self.LovaszSoftmax = lovasz_softmax()
36 |         self.cross_entropy = nn.CrossEntropyLoss()
37 | 
38 |         self.semantic_loss = nn.CrossEntropyLoss(weight=self.class_weight_semantic)
39 |         self.binary_loss = nn.CrossEntropyLoss(weight=self.class_weight_binary)
40 |         self.boundary_loss = nn.CrossEntropyLoss(weight=self.class_weight_boundary)
41 | 
42 |     def forward(self, inputs, targets):
43 |         semantic_gt, binary_gt, boundary_gt = targets
44 |         semantic_out, semantic_out_2, sal_out, edge_out = inputs
45 | 
46 |         loss1 = self.semantic_loss(semantic_out, semantic_gt)
47 |         loss2 = lovasz_softmax(F.softmax(semantic_out, dim=1), semantic_gt, ignore=255)
48 |         loss3 = self.semantic_loss(semantic_out_2, semantic_gt)
49 |         loss4 = self.binary_loss(sal_out, binary_gt)
50 |         loss5 = self.boundary_loss(edge_out, boundary_gt)
51 | 
52 |         loss = loss1 + loss2 + loss3 + 0.5 * loss4 + loss5  # binary term down-weighted by 0.5
53 |         return loss
54 | 
55 | 
56 | def run(args):
57 |     torch.cuda.set_device(args.cuda)
58 |     with open(args.config, 'r') as fp:
59 |         cfg = json.load(fp)
60 | 
61 |     logdir = f'run/{time.strftime("%Y-%m-%d-%H-%M")}-{cfg["dataset"]}-{cfg["model_name"]}'
62 |     if not os.path.exists(logdir):
63 |         os.makedirs(logdir)
64 |     shutil.copy(args.config, logdir)
65 | 
66 |     logger = get_logger(logdir)
67 |     logger.info(f'Conf | use logdir {logdir}')
68 | 
69 |     model = get_model(cfg)
70 |     device = torch.device(f'cuda:{args.cuda}')
71 |     model.to(device)
72 | 
73 | 
74 |     trainset, _, testset = get_dataset(cfg)
75 |     train_loader = DataLoader(trainset, batch_size=cfg['ims_per_gpu'], shuffle=True, num_workers=cfg['num_workers'],
76 |                               pin_memory=True)
77 |     test_loader = DataLoader(testset, batch_size=cfg['ims_per_gpu'], shuffle=False, num_workers=cfg['num_workers'],
78 |                              pin_memory=True)
79 | 
80 |     params_list = model.parameters()
81 |     optimizer = Ranger(params_list, lr=cfg['lr_start'], weight_decay=cfg['weight_decay'])
82 |     scheduler = LambdaLR(optimizer, lr_lambda=lambda ep: (1 - ep / cfg['epochs']) ** 0.9)  # polynomial decay
83 | 
84 |     train_criterion = eeemodelLoss().to(device)
85 |     criterion = nn.CrossEntropyLoss().to(device)
86 | 
87 |     train_loss_meter = averageMeter()
88 |     test_loss_meter = averageMeter()
89 |     running_metrics_test = runningScore(cfg['n_classes'], ignore_index=cfg['id_unlabel'])
90 |     best_test = 0
91 | 
92 |     amp.register_float_function(torch, 'sigmoid')
93 |     model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
94 | 
95 | 
96 |     for ep in range(cfg['epochs']):
97 | 
98 |         # training
99 |         model.train()
100 |         train_loss_meter.reset()
101 |         for i, sample in enumerate(train_loader):
102 |             optimizer.zero_grad()
103 | 
104 |             image = sample['image'].to(device)
105 |             depth = sample['depth'].to(device)
106 |             label = sample['label'].to(device)
107 |             bound = sample['bound'].to(device)
108 |             binary_label = sample['binary_label'].to(device)
109 |             targets = [label, binary_label, bound]
110 |             predict = model(image, depth)
111 | 
112 |             loss = train_criterion(predict, targets)
113 |             ####################################################
114 | 
115 |             with amp.scale_loss(loss, optimizer) as scaled_loss:  # apex mixed-precision loss scaling
116 |                 scaled_loss.backward()
117 |             optimizer.step()
118 | 
119 |             train_loss_meter.update(loss.item())
120 | 
121 |         scheduler.step(ep)
122 | 
123 |         # evaluation on the test set
124 |         with torch.no_grad():
125 |             model.eval()
126 |             running_metrics_test.reset()
127 |             test_loss_meter.reset()
128 |             for i, sample in enumerate(test_loader):
129 | 
130 |                 image = sample['image'].to(device)
131 |                 # Here, depth is the thermal (TIR) image.
132 |                 depth = sample['depth'].to(device)
133 |                 label = sample['label'].to(device)
134 |                 predict = model(image, depth)[0]
135 | 
136 |                 loss = criterion(predict, label)
137 |                 test_loss_meter.update(loss.item())
138 | 
139 |                 predict = predict.max(1)[1].cpu().numpy()  # (batch_size, h, w)
140 |                 label = label.cpu().numpy()
141 |                 running_metrics_test.update(label, predict)
142 | 
143 |         train_loss = train_loss_meter.avg
144 |         test_loss = test_loss_meter.avg
145 | 
146 |         test_macc = running_metrics_test.get_scores()[0]["class_acc: "]
147 |         test_miou = running_metrics_test.get_scores()[0]["mIou: "]
148 |         test_avg = (test_macc + test_miou) / 2
149 | 
150 |         logger.info(
151 |             f'Iter | [{ep + 1:3d}/{cfg["epochs"]}] loss={train_loss:.3f}/{test_loss:.3f}, mPA={test_macc:.3f}, miou={test_miou:.3f}, avg={test_avg:.3f}')
152 |         if test_avg > best_test:
153 |             best_test = test_avg
154 |             save_ckpt(logdir, model, ep + 1)
155 |             logger.info(
156 |                 f'Save Iter = [{ep + 1:3d}], mPA={test_macc:.3f}, miou={test_miou:.3f}, avg={test_avg:.3f}')
157 | 
158 | 
159 | if __name__ == '__main__':
160 |     import argparse
161 | 
162 |     parser = argparse.ArgumentParser(description="config")
163 |     parser.add_argument("--config", type=str, default="configs/LASNet.json", help="Configuration file to use")
164 |     parser.add_argument("--opt_level", type=str, default='O1')
165 |     parser.add_argument("--inputs", type=str.lower, default='rgb', choices=['rgb', 'rgbd'])
166 |     parser.add_argument("--resume", type=str, default='',
167 |                         help="use this file to load the last checkpoint and continue training")
168 |     parser.add_argument("--cuda", type=int, default=1, help="set cuda device id")
169 | 
170 |     args = parser.parse_args()
171 | 
172 |     print("Starting Training!")
173 |     run(args)
174 | 
--------------------------------------------------------------------------------
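
Editor's note (a sketch, not a file from this repository): compute_speed in toolbox/utils.py benchmarks a model that takes a single tensor, while LASNet's forward expects separate RGB and thermal inputs, so it cannot be called on LASNet directly. Below is one hypothetical way to combine the utilities above for inference-speed measurement. The TwoStreamWrapper name, the 6-channel stacking, the 480x640 MFNet resolution, the seed, and the './model/model.pth' checkpoint path (per the README) are all assumptions, and the thermal stream is assumed to be fed as 3 channels like the RGB stream (adjust the split if your dataloader uses 1-channel TIR).

    import json

    import torch.nn as nn

    from toolbox import get_model
    from toolbox.utils import compute_speed, load_ckpt, setup_seed


    class TwoStreamWrapper(nn.Module):
        """Hypothetical adapter: splits one stacked 6-channel tensor into the
        (rgb, thermal) pair that LASNet's forward expects."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            return self.model(x[:, :3], x[:, 3:])


    if __name__ == '__main__':
        setup_seed(123)  # fixed seeds + deterministic cuDNN, as defined above
        with open('configs/LASNet.json', 'r') as fp:
            cfg = json.load(fp)
        model = get_model(cfg)
        model = load_ckpt('./model/', model)  # loads './model/model.pth' (see README)
        # batch of 1, RGB (3) + thermal (3) channels, assumed 480x640 MFNet frames
        compute_speed(TwoStreamWrapper(model), (1, 6, 480, 640), device=0, iteration=100)

Note that load_ckpt returns the model with the weights loaded, and compute_speed moves the wrapped model onto the selected GPU itself, so no explicit .to(device) call is needed in this sketch.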