├── README.md ├── builders ├── dataset_builder.py └── model_builder.py ├── dataset ├── __init__.py ├── camvid.py ├── camvid │ ├── camvid_test_list.txt │ ├── camvid_train_list.txt │ ├── camvid_trainval_list.txt │ └── camvid_val_list.txt ├── cityscapes.py └── cityscapes │ ├── cityscapes_test_list.txt │ ├── cityscapes_train_list.txt │ ├── cityscapes_trainval_list.txt │ └── cityscapes_val_list.txt ├── eval_fps.py ├── model └── FBSNet.py ├── predict.py ├── test.py ├── train.py └── utils ├── colorize_mask.py ├── convert_state.py ├── loss.py ├── lr_scheduler.py ├── metric.py ├── trainID2labelID.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # FBSNet 2 | 3 | This repository is an official PyTorch implementation of our paper"FBSNet: A Fast Bilateral Symmetrical Network for Real-Time Semantic Segmentation". Accepted by IEEE TRANSACTIONS ON MULTIMEDIA, 2022. (IF: 6.513) 4 | 5 | [Paper](https://arxiv.org/abs/2109.00699v1) | [Code](https://github.com/XU-GITHUB-curry/FBSNet) 6 | 7 | 8 | 9 | ## Installation 10 | 11 | ``` 12 | cuda == 10.2 13 | Python == 3.6.4 14 | Pytorch == 1.8.0+cu101 15 | 16 | # clone this repository 17 | git clone https://github.com/XU-GITHUB-curry/FBSNet.git 18 | ``` 19 | 20 | 21 | 22 | ## Datasets 23 | 24 | We used Cityscapes dataset and CamVid dataset to train our model. 25 | 26 | - You can download cityscapes dataset from [here](https://www.cityscapes-dataset.com/). 27 | 28 | Note: please download leftImg8bit_trainvaltest.zip(11GB) and gtFine_trainvaltest(241MB). 29 | 30 | The Cityscapes dataset scripts for inspection, preparation, and evaluation can download from [here](https://github.com/mcordts/cityscapesScripts). 31 | 32 | - You can download camvid dataset from [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/). 33 | 34 | The folds of your datasets need satisfy the following structures: 35 | 36 | ``` 37 | ├── dataset # contains all datasets for the project 38 | | └── cityscapes # cityscapes dataset 39 | | | └── gtCoarse 40 | | | └── gtFine 41 | | | └── leftImg8bit 42 | | | └── cityscapes_test_list.txt 43 | | | └── cityscapes_train_list.txt 44 | | | └── cityscapes_trainval_list.txt 45 | | | └── cityscapes_val_list.txt 46 | | | └── cityscapesscripts # cityscapes dataset label convert scripts! 47 | | └── camvid # camvid dataset 48 | | | └── test 49 | | | └── testannot 50 | | | └── train 51 | | | └── trainannot 52 | | | └── val 53 | | | └── valannot 54 | | | └── camvid_test_list.txt 55 | | | └── camvid_train_list.txt 56 | | | └── camvid_trainval_list.txt 57 | | | └── camvid_val_list.txt 58 | | └── inform 59 | | | └── camvid_inform.pkl 60 | | | └── cityscapes_inform.pkl 61 | | └── camvid.py 62 | | └── cityscapes.py 63 | 64 | ``` 65 | 66 | 67 | 68 | ## Train 69 | 70 | ``` 71 | # cityscapes 72 | python train.py --dataset cityscapes --train_type train --max_epochs 1000 --lr 4.5e-2 --batch_size 4 73 | 74 | # camvid 75 | python train.py --dataset cityscapes --train_type train --max_epochs 1000 --lr 1e-3 --batch_size 6 76 | ``` 77 | 78 | 79 | 80 | ## Test 81 | 82 | ``` 83 | # cityscapes 84 | python test.py --dataset cityscapes --checkpoint ./checkpoint/cityscapes/FBSNetbs4gpu1_train/model_1000.pth 85 | 86 | # camvid 87 | python test.py --dataset camvid --checkpoint ./checkpoint/camvid/FBSNetbs6gpu1_trainval/model_1000.pth 88 | ``` 89 | 90 | ## Predict 91 | only for cityscapes dataset 92 | ``` 93 | python predict.py --dataset cityscapes 94 | ``` 95 | 96 | ## Results 97 | 98 | - Please refer to our article for more details. 99 | 100 | | Methods | Dataset | Input Size | mIoU(%) | 101 | | :-----: | :--------: | :--------: | :-----: | 102 | | FBSNet | Cityscapes | 512x1024 | 70.9 | 103 | | FBSNet | CamVid | 360x480 | 68.9 | 104 | 105 | 106 | 107 | ## Citation 108 | 109 | If you find this project useful for your research, please cite our paper: 110 | 111 | ``` 112 | @article{gao2022fbsnet, 113 | title={FBSNet: A fast bilateral symmetrical network for real-time semantic segmentation}, 114 | author={Gao, Guangwei and Xu, Guoan and Li, Juncheng and Yu, Yi and Lu, Huimin and Yang, Jian}, 115 | journal={IEEE Transactions on Multimedia}, 116 | year={2022}, 117 | publisher={IEEE} 118 | } 119 | ``` 120 | 121 | 122 | 123 | ## Acknowledgements 124 | 125 | 1. [LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation](https://arxiv.org/abs/1905.02423) 126 | 2. [BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation](https://arxiv.org/abs/1808.00897) 127 | -------------------------------------------------------------------------------- /builders/dataset_builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle #该pickle模块实现了用于序列化和反序列化Python对象结构的二进制协议 3 | #python的pickle模块实现了基本的数据序列和反序列化。通过pickle模块的序列化操作我们能够将程序中运行的对象信息保存到文件中去,永久存储; 4 | #通过pickle模块的反序列化操作,我们能够从文件中创建上一次程序保存的对象 5 | from torch.utils import data 6 | from dataset.cityscapes import CityscapesDataSet, CityscapesTrainInform, CityscapesValDataSet, CityscapesTestDataSet 7 | from dataset.camvid import CamVidDataSet, CamVidValDataSet, CamVidTrainInform, CamVidTestDataSet 8 | 9 | 10 | def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers): 11 | data_dir = os.path.join('./dataset/', dataset) 12 | dataset_list = os.path.join(dataset, '_trainval_list.txt') 13 | train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt') 14 | val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt') 15 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl') #两个pkl文件 16 | 17 | # inform_data_file collect the information of mean, std and weight_class,均值/标准差/类别权重 18 | if not os.path.isfile(inform_data_file): #如果找不到pkl文件,报错,相当于没有找到数据集 19 | # os.path.isfile用于判断某一对象(需提供绝对路径)是否为文件 20 | print("%s is not found" % (inform_data_file)) 21 | if dataset == "cityscapes": #训练信息 22 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list, 23 | inform_data_file=inform_data_file) 24 | elif dataset == 'camvid': 25 | dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list, 26 | inform_data_file=inform_data_file) 27 | else: 28 | raise NotImplementedError( 29 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset) 30 | 31 | datas = dataCollect.collectDataAndSave() 32 | if datas is None: 33 | print("error while pickling data. Please check.") 34 | exit(-1) 35 | else: 36 | print("find file: ", str(inform_data_file)) 37 | datas = pickle.load(open(inform_data_file, "rb"))#从file中读取一个字符串,并将它重构为原来的python对象 38 | #file参数表示的需要打开文件的相对路径(当前工作目录)或者一个绝对路径 rb:以二进制方式读写操作 39 | if dataset == "cityscapes": 40 | #训练数据装载 41 | trainLoader = data.DataLoader( 42 | CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale, 43 | mirror=random_mirror, mean=datas['mean']), #mean=[128,128,128] 灰度 44 | batch_size=batch_size, shuffle=True, num_workers=num_workers, 45 | pin_memory=True, drop_last=True) 46 | #验证数据加载 47 | valLoader = data.DataLoader( 48 | CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),#f_scale=1 表示0.75-2 随机裁剪 49 | 50 | batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True, 51 | drop_last=True) 52 | 53 | return datas, trainLoader, valLoader 54 | 55 | elif dataset == "camvid": 56 | 57 | trainLoader = data.DataLoader( 58 | CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale, 59 | mirror=random_mirror, mean=datas['mean']), 60 | batch_size=batch_size, shuffle=True, num_workers=num_workers, 61 | pin_memory=True, drop_last=True) 62 | 63 | valLoader = data.DataLoader( 64 | CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']), 65 | batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True) 66 | 67 | return datas, trainLoader, valLoader 68 | 69 | 70 | def build_dataset_test(dataset, num_workers, none_gt=False): 71 | data_dir = os.path.join('./dataset/', dataset) 72 | dataset_list = os.path.join(dataset, '_trainval_list.txt') 73 | test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt') 74 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl') 75 | 76 | # inform_data_file collect the information of mean, std and weight_class 77 | if not os.path.isfile(inform_data_file): 78 | print("%s is not found" % (inform_data_file)) 79 | if dataset == "cityscapes": 80 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list, 81 | inform_data_file=inform_data_file) 82 | elif dataset == 'camvid': 83 | dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list, 84 | inform_data_file=inform_data_file) 85 | else: 86 | raise NotImplementedError( 87 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset) 88 | 89 | datas = dataCollect.collectDataAndSave() 90 | if datas is None: 91 | print("error while pickling data. Please check.") 92 | exit(-1) 93 | else: 94 | print("find file: ", str(inform_data_file)) 95 | datas = pickle.load(open(inform_data_file, "rb")) 96 | 97 | if dataset == "cityscapes": 98 | # for cityscapes, if test on validation set, set none_gt to False 99 | # if test on the test set, set none_gt to True #在验证集测试有groundtruth,在测试集测试没有groundtruth 100 | if none_gt: #test 101 | testLoader = data.DataLoader( 102 | CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),#root='', list_path='' 103 | batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True) 104 | else: #val 105 | test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt') 106 | testLoader = data.DataLoader( 107 | CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']), 108 | batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True) 109 | 110 | return datas, testLoader 111 | 112 | elif dataset == "camvid": 113 | 114 | testLoader = data.DataLoader( 115 | CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']), 116 | batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True) 117 | 118 | return datas, testLoader 119 | -------------------------------------------------------------------------------- /builders/model_builder.py: -------------------------------------------------------------------------------- 1 | from model.FDDWNet import Net 2 | 3 | def build_model(model_name, num_classes): 4 | if model_name == 'FDDWNet': 5 | return Net(classes=num_classes) -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .camvid import * 2 | from .cityscapes import * 3 | #from…import * 导入模块,每次使用模块中的函数,直接使用函数就可以了 4 | -------------------------------------------------------------------------------- /dataset/camvid.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import random 4 | import cv2 5 | from torch.utils import data 6 | import pickle 7 | 8 | """ 9 | CamVid is a road scene understanding dataset with 367 training images and 233 testing images of day and dusk scenes. 10 | The challenge is to segment 11 classes such as road, building, cars, pedestrians, signs, poles, side-walk etc. We 11 | resize images to 360x480 pixels for training and testing. 12 | """ 13 | 14 | 15 | class CamVidDataSet(data.Dataset): 16 | """ 17 | CamVidDataSet is employed to load train set 18 | Args: 19 | root: the CamVid dataset path, 20 | list_path: camvid_train_list.txt, include partial path 21 | 22 | """ 23 | 24 | def __init__(self, root='', list_path='', max_iters=None, crop_size=(360, 360),#(360,480) 25 | 26 | mean=(128, 128, 128), scale=True, mirror=True, ignore_label=11): 27 | self.root = root 28 | self.list_path = list_path 29 | self.crop_h, self.crop_w = crop_size 30 | self.scale = scale 31 | self.ignore_label = ignore_label 32 | self.mean = mean 33 | self.is_mirror = mirror 34 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 35 | if not max_iters == None: 36 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 37 | self.files = [] 38 | 39 | # for split in ["train", "trainval", "val"]: 40 | for name in self.img_ids: 41 | img_file = osp.join(self.root, name.split()[0]) 42 | # print(img_file) 43 | label_file = osp.join(self.root, name.split()[1]) 44 | # print(label_file) 45 | self.files.append({ 46 | "img": img_file, 47 | "label": label_file, 48 | "name": name 49 | }) 50 | 51 | print("length of train set: ", len(self.files)) 52 | 53 | def __len__(self): 54 | return len(self.files) 55 | 56 | def __getitem__(self, index): 57 | datafiles = self.files[index] 58 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 59 | label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE) 60 | size = image.shape 61 | name = datafiles["name"] 62 | if self.scale: 63 | scale = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0] # random resize between 0.5 and 2 64 | f_scale = scale[random.randint(0, 5)] 65 | # f_scale = 0.5 + random.randint(0, 15) / 10.0 #random resize between 0.5 and 2 66 | image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR) 67 | label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST) 68 | 69 | image = np.asarray(image, np.float32) 70 | 71 | image -= self.mean 72 | # image = image.astype(np.float32) / 255.0 73 | image = image[:, :, ::-1] # change to RGB 74 | img_h, img_w = label.shape 75 | pad_h = max(self.crop_h - img_h, 0) 76 | pad_w = max(self.crop_w - img_w, 0) 77 | if pad_h > 0 or pad_w > 0: 78 | img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0, 79 | pad_w, cv2.BORDER_CONSTANT, 80 | value=(0.0, 0.0, 0.0)) 81 | label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0, 82 | pad_w, cv2.BORDER_CONSTANT, 83 | value=(self.ignore_label,)) 84 | else: 85 | img_pad, label_pad = image, label 86 | 87 | img_h, img_w = label_pad.shape 88 | h_off = random.randint(0, img_h - self.crop_h) 89 | w_off = random.randint(0, img_w - self.crop_w) 90 | # roi = cv2.Rect(w_off, h_off, self.crop_w, self.crop_h); 91 | image = np.asarray(img_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32) 92 | label = np.asarray(label_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32) 93 | 94 | image = image.transpose((2, 0, 1)) # NHWC -> NCHW 95 | 96 | if self.is_mirror: 97 | flip = np.random.choice(2) * 2 - 1 98 | image = image[:, :, ::flip] 99 | label = label[:, ::flip] 100 | 101 | return image.copy(), label.copy(), np.array(size), name 102 | 103 | 104 | class CamVidValDataSet(data.Dataset): 105 | """ 106 | CamVidValDataSet is employed to load val set 107 | Args: 108 | root: the CamVid dataset path, 109 | list_path: camvid_val_list.txt, include partial path 110 | 111 | """ 112 | 113 | def __init__(self, root='', list_path='', 114 | f_scale=1, mean=(128, 128, 128), ignore_label=11): 115 | self.root = root 116 | self.list_path = list_path 117 | self.ignore_label = ignore_label 118 | self.mean = mean 119 | self.f_scale = f_scale 120 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 121 | self.files = [] 122 | for name in self.img_ids: 123 | img_file = osp.join(self.root, name.split()[0]) 124 | # print(img_file) 125 | label_file = osp.join(self.root, name.split()[1]) 126 | # print(label_file) 127 | image_name = name.strip().split()[0].strip().split('/', 1)[1].split('.')[0] 128 | # print("image_name: ",image_name) 129 | self.files.append({ 130 | "img": img_file, 131 | "label": label_file, 132 | "name": image_name 133 | }) 134 | 135 | print("length of Validation set: ", len(self.files)) 136 | 137 | def __len__(self): 138 | return len(self.files) 139 | 140 | def __getitem__(self, index): 141 | datafiles = self.files[index] 142 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 143 | label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE) 144 | size = image.shape 145 | name = datafiles["name"] 146 | if self.f_scale != 1: 147 | image = cv2.resize(image, None, fx=self.f_scale, fy=self.f_scale, interpolation=cv2.INTER_LINEAR) 148 | # label = cv2.resize(label, None, fx=self.f_scale, fy=self.f_scale, interpolation = cv2.INTER_NEAREST) 149 | 150 | image = np.asarray(image, np.float32) 151 | 152 | image -= self.mean 153 | # image = image.astype(np.float32) / 255.0 154 | image = image[:, :, ::-1] # revert to RGB 155 | image = image.transpose((2, 0, 1)) # HWC -> CHW 156 | 157 | # print('image.shape:',image.shape) 158 | return image.copy(), label.copy(), np.array(size), name 159 | 160 | 161 | class CamVidTestDataSet(data.Dataset): 162 | """ 163 | CamVidTestDataSet is employed to load test set 164 | Args: 165 | root: the CamVid dataset path, 166 | list_path: camvid_test_list.txt, include partial path 167 | 168 | """ 169 | 170 | def __init__(self, root='', list_path='', 171 | mean=(128, 128, 128), ignore_label=11): 172 | self.root = root 173 | self.list_path = list_path 174 | self.ignore_label = ignore_label 175 | self.mean = mean 176 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 177 | self.files = [] 178 | for name in self.img_ids: 179 | img_file = osp.join(self.root, name.split()[0]) 180 | # print(img_file) 181 | image_name = name.strip().split()[0].strip().split('/', 1)[1].split('.')[0] 182 | # print(image_name) 183 | self.files.append({ 184 | "img": img_file, 185 | "name": image_name 186 | }) 187 | print("lenth of test set ", len(self.files)) 188 | 189 | def __len__(self): 190 | return len(self.files) 191 | 192 | def __getitem__(self, index): 193 | datafiles = self.files[index] 194 | 195 | image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) 196 | name = datafiles["name"] 197 | 198 | image = np.asarray(image, np.float32) 199 | 200 | size = image.shape 201 | image -= self.mean 202 | # image = image.astype(np.float32) / 255.0 203 | image = image[:, :, ::-1] # change to RGB 204 | image = image.transpose((2, 0, 1)) # HWC -> CHW 205 | 206 | return image.copy(), np.array(size), name 207 | 208 | 209 | class CamVidTrainInform: 210 | """ To get statistical information about the train set, such as mean, std, class distribution. 211 | The class is employed for tackle class imbalance. 212 | """ 213 | 214 | def __init__(self, data_dir='', classes=11, train_set_file="", 215 | inform_data_file="", normVal=1.10): 216 | """ 217 | Args: 218 | data_dir: directory where the dataset is kept 219 | classes: number of classes in the dataset 220 | inform_data_file: location where cached file has to be stored 221 | normVal: normalization value, as defined in ERFNet paper 222 | """ 223 | self.data_dir = data_dir 224 | self.classes = classes 225 | self.classWeights = np.ones(self.classes, dtype=np.float32) 226 | self.normVal = normVal 227 | self.mean = np.zeros(3, dtype=np.float32) 228 | self.std = np.zeros(3, dtype=np.float32) 229 | self.train_set_file = train_set_file 230 | self.inform_data_file = inform_data_file 231 | 232 | def compute_class_weights(self, histogram): 233 | """to compute the class weights 234 | Args: 235 | histogram: distribution of class samples 236 | """ 237 | normHist = histogram / np.sum(histogram) 238 | for i in range(self.classes): 239 | self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i])) 240 | 241 | def readWholeTrainSet(self, fileName, train_flag=True): 242 | """to read the whole train set of current dataset. 243 | Args: 244 | fileName: train set file that stores the image locations 245 | trainStg: if processing training or validation data 246 | 247 | return: 0 if successful 248 | """ 249 | global_hist = np.zeros(self.classes, dtype=np.float32) 250 | 251 | no_files = 0 252 | min_val_al = 0 253 | max_val_al = 0 254 | with open(self.data_dir + '/' + fileName, 'r') as textFile: 255 | # with open(fileName, 'r') as textFile: 256 | for line in textFile: 257 | # we expect the text file to contain the data in following format 258 | #