├── .gitignore
├── LICENSE
├── README.md
├── builders
    ├── __init__.py
    ├── dataset_builder.py
    └── model_builder.py
├── dataset
    ├── __init__.py
    ├── camvid.py
    ├── camvid
    │   ├── camvid_test_list.txt
    │   ├── camvid_train_list.txt
    │   ├── camvid_trainval_list.txt
    │   └── camvid_val_list.txt
    ├── cityscapes.py
    ├── cityscapes
    │   ├── cityscapes_test_list.txt
    │   ├── cityscapes_train_list.txt
    │   ├── cityscapes_trainval_list.txt
    │   └── cityscapes_val_list.txt
    └── inform
        ├── camvid_inform.pkl
        └── cityscapes_inform.pkl
├── eval_fps.py
├── image
    ├── DABNet_demo.png
    ├── architecture.png
    ├── iou_vs_epochs.png
    └── loss_vs_epochs.png
├── model
    └── DABNet.py
├── predict.py
├── test.py
├── train.py
└── utils
    ├── colorize_mask.py
    ├── convert_state.py
    ├── loss.py
    ├── lr_scheduler.py
    ├── metric.py
    ├── trainID2labelID.py
    └── utils.py


/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.py[cod]
*.pyc
*.pyo
*.pth
.idea/
result/
checkpoint/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Gen Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DABNet: Depth-wise Asymmetric Bottleneck for Real-time Semantic Segmentation
This project contains the PyTorch implementation of the proposed DABNet: [[arXiv]](https://arxiv.org/abs/1907.11357).

### Introduction

As a pixel-level prediction task, semantic segmentation needs large computational cost with enormous parameters to obtain high performance. Recently, due to the increasing demand for autonomous systems and robots, it is significant to make a tradeoff between accuracy and inference speed. In this paper, we propose a novel Depthwise Asymmetric Bottleneck (DAB) module to address this dilemma, which efficiently adopts depth-wise asymmetric convolution and dilated convolution to build a bottleneck structure. Based on the DAB module, we design a Depth-wise Asymmetric Bottleneck Network (DABNet) especially for real-time semantic segmentation, which creates sufficient receptive field and densely utilizes the contextual information. Experiments on Cityscapes and CamVid datasets demonstrate that the proposed DABNet achieves a balance between speed and precision. Specifically, without any pretrained model and postprocessing, it achieves 70.1% Mean IoU on the Cityscapes test dataset with only 0.76 million parameters and a speed of 104 FPS on a single GTX 1080Ti card.

### Installation
- Env: Python 3.6; PyTorch 1.0; CUDA 9.0; cuDNN V7
- Install the required packages
```
pip install opencv-python pillow numpy matplotlib
```
- Clone this repository
```
git clone https://github.com/Reagan1311/DABNet
cd DABNet
```
- One GPU with 11 GB of memory is needed

### Dataset
Download the two datasets, CamVid and Cityscapes, and put the files in the `dataset` folder with the following structure.
```
├── camvid
|   ├── train
|   ├── test
|   ├── val
|   ├── trainannot
|   ├── testannot
|   ├── valannot
|   ├── camvid_trainval_list.txt
|   ├── camvid_train_list.txt
|   ├── camvid_test_list.txt
|   └── camvid_val_list.txt
├── cityscapes
|   ├── gtCoarse
|   ├── gtFine
|   ├── leftImg8bit
|   ├── cityscapes_trainval_list.txt
|   ├── cityscapes_train_list.txt
|   ├── cityscapes_test_list.txt
|   └── cityscapes_val_list.txt
```

### Training

- Run `python train.py -h` to see the details of the optional arguments.
In `train.py` you can set the dataset, train type, number of epochs, batch size, etc.
```
python train.py --dataset ${camvid, cityscapes} --train_type ${train, trainval} --max_epochs ${EPOCHS} --batch_size ${BATCH_SIZE} --lr ${LR} --resume ${CHECKPOINT_FILE}
```
- Training on the Cityscapes train set
```
python train.py --dataset cityscapes
```
- Training on the CamVid train and val sets
```
python train.py --dataset camvid --train_type trainval --max_epochs 1000 --lr 1e-3 --batch_size 16
```
- During training, every 50 epochs we record the mean IoU of the train and validation sets, as well as the training loss, and plot them, so you can check whether training is progressing normally.

Val mIoU vs Epochs | Train loss vs Epochs
:-------------------------:|:-------------------------:
![Val mIoU vs epochs](https://github.com/Reagan1311/DABNet/blob/master/image/iou_vs_epochs.png) | ![Train loss vs epochs](https://github.com/Reagan1311/DABNet/blob/master/image/loss_vs_epochs.png)

(PS: Based on the graphs, we think training has not saturated yet and the LR may be too large, so you can adjust the hyperparameters to get better results.)

### Testing
- After training, the checkpoint is saved in the `checkpoint` folder; use `test.py` to get the results.
```
python test.py --dataset ${camvid, cityscapes} --checkpoint ${CHECKPOINT_FILE}
```

### Evaluation
- For datasets that do not provide labels for the test set (e.g. Cityscapes), use `predict.py` to save all the output images, then submit them to the official website for evaluation.
```
python predict.py --checkpoint ${CHECKPOINT_FILE}
```

### Inference Speed
- Run `eval_fps.py` to measure the model's inference speed, passing the input image size, e.g. `512,1024`.
```
python eval_fps.py 512,1024
```

### Results

- Quantitative results:

|Dataset|Pretrained|Train type|mIoU|FPS|model|
|:-:|:-:|:-:|:-:|:-:|:-:|
|Cityscapes (Fine)|from scratch|trainval|**70.07%**|104|[Detailed result](https://www.cityscapes-dataset.com/anonymous-results/?id=16896cc219a6d5af875f8aa3d528a0f7c4ce57644aece957938eae9062ed8070)|
|Cityscapes (Fine)|from scratch|train|**69.57%**|104|[GoogleDrive](https://drive.google.com/open?id=1ZKGBQogSqxyKD-QIJgzyDXw2TR0HUePA)|
|CamVid|from scratch|trainval|**66.72%**|146|[GoogleDrive](https://drive.google.com/open?id=1EPyv9-FUQwr_23w3kLwwiFKD13uRyFRk)|

- Qualitative segmentation examples:

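The mIoU column above is, assuming the standard definition used by the Cityscapes benchmark, the per-class intersection over union, IoU = TP / (TP + FP + FN), averaged over classes. The minimal NumPy sketch below illustrates that computation from a confusion matrix. It is not the repository's `utils/metric.py` (not shown here), and the helper names `build_confusion_matrix` and `mean_iou` are hypothetical. Pixels equal to `ignore_label` are skipped, mirroring the `ignore_label=11` used by the CamVid loaders in this repository.

```python
import numpy as np


def build_confusion_matrix(pred, label, num_classes, ignore_label=255):
    """Hypothetical helper: rows index ground truth, columns index prediction."""
    mask = label != ignore_label  # drop ignored pixels before counting
    idx = num_classes * label[mask].astype(np.int64) + pred[mask].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf):
    """Hypothetical helper: per-class IoU = TP / (TP + FP + FN), then the class mean."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp  # pixels predicted as class c that belong elsewhere
    fn = conf.sum(axis=1) - tp  # pixels of class c that were predicted as something else
    denom = tp + fp + fn
    # classes absent from both prediction and ground truth get NaN and are skipped
    iou = np.divide(tp, denom, out=np.full_like(tp, np.nan), where=denom > 0)
    return iou, float(np.nanmean(iou))


# Toy 2x3 example with 3 classes; 11 marks an ignored pixel, as in the CamVid loaders
pred = np.array([[0, 0, 1], [1, 2, 2]])
label = np.array([[0, 1, 1], [1, 2, 11]])
conf = build_confusion_matrix(pred, label, num_classes=3, ignore_label=11)
iou, miou = mean_iou(conf)
# iou -> [0.5, 0.667, 1.0] (rounded), miou -> 0.722 (rounded)
print(iou, miou)
```

In practice the confusion matrix is usually accumulated over the whole evaluation set before IoU is computed, rather than averaging per-image scores; whether `test.py` does exactly this is not shown here.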

### Citation

Please consider citing [DABNet](https://arxiv.org/abs/1907.11357) if it is helpful for your research.
```
@inproceedings{li2019dabnet,
  title={DABNet: Depth-wise Asymmetric Bottleneck for Real-time Semantic Segmentation},
  author={Li, Gen and Kim, Joongkyu},
  booktitle={British Machine Vision Conference},
  year={2019}
}
```

### Thanks to the Third Party Libs
- [Pytorch](https://github.com/pytorch/pytorch)
- [Pytorch-Deeplab](https://github.com/speedinghzl/Pytorch-Deeplab)
- [ERFNet](https://github.com/Eromera/erfnet_pytorch)
- [CGNet](https://github.com/wutianyiRosun/CGNet)

--------------------------------------------------------------------------------
/builders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Reagan1311/DABNet/b8d62fe7f14ae4909a9e9aad1dd6e0ade98431cd/builders/__init__.py

--------------------------------------------------------------------------------
/builders/dataset_builder.py:
--------------------------------------------------------------------------------
import os
import pickle
from torch.utils import data
from dataset.cityscapes import CityscapesDataSet, CityscapesTrainInform, CityscapesValDataSet, CityscapesTestDataSet
from dataset.camvid import CamVidDataSet, CamVidValDataSet, CamVidTrainInform, CamVidTestDataSet


def build_dataset_train(dataset, input_size, batch_size, train_type, random_scale, random_mirror, num_workers):
    data_dir = os.path.join('./dataset/', dataset)
    # list file name relative to data_dir, e.g. "camvid_trainval_list.txt"
    dataset_list = dataset + '_trainval_list.txt'
    train_data_list = os.path.join(data_dir, dataset + '_' + train_type + '_list.txt')
    val_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the mean, std and class weights of the training set
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    if dataset == "cityscapes":

        trainLoader = data.DataLoader(
            CityscapesDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                              mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CityscapesValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True,
            drop_last=True)

        return datas, trainLoader, valLoader

    elif dataset == "camvid":

        trainLoader = data.DataLoader(
            CamVidDataSet(data_dir, train_data_list, crop_size=input_size, scale=random_scale,
                          mirror=random_mirror, mean=datas['mean']),
            batch_size=batch_size, shuffle=True, num_workers=num_workers,
            pin_memory=True, drop_last=True)

        valLoader = data.DataLoader(
            CamVidValDataSet(data_dir, val_data_list, f_scale=1, mean=datas['mean']),
            batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True)

        return datas, trainLoader, valLoader


def build_dataset_test(dataset, num_workers, none_gt=False):
    data_dir = os.path.join('./dataset/', dataset)
    # list file name relative to data_dir, e.g. "camvid_trainval_list.txt"
    dataset_list = dataset + '_trainval_list.txt'
    test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt')
    inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl')

    # inform_data_file caches the mean, std and class weights of the training set
    if not os.path.isfile(inform_data_file):
        print("%s is not found" % (inform_data_file))
        if dataset == "cityscapes":
            dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=dataset_list,
                                                inform_data_file=inform_data_file)
        elif dataset == 'camvid':
            dataCollect = CamVidTrainInform(data_dir, 11, train_set_file=dataset_list,
                                            inform_data_file=inform_data_file)
        else:
            raise NotImplementedError(
                "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset)

        datas = dataCollect.collectDataAndSave()
        if datas is None:
            print("error while pickling data. Please check.")
            exit(-1)
    else:
        print("find file: ", str(inform_data_file))
        with open(inform_data_file, "rb") as f:
            datas = pickle.load(f)

    if dataset == "cityscapes":
        # for cityscapes, set none_gt to False when testing on the validation set,
        # and to True when testing on the test set (which has no ground truth)
        if none_gt:
            testLoader = data.DataLoader(
                CityscapesTestDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)
        else:
            test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt')
            testLoader = data.DataLoader(
                CityscapesValDataSet(data_dir, test_data_list, mean=datas['mean']),
                batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader

    elif dataset == "camvid":

        testLoader = data.DataLoader(
            CamVidValDataSet(data_dir, test_data_list, mean=datas['mean']),
            batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True)

        return datas, testLoader

--------------------------------------------------------------------------------
/builders/model_builder.py:
--------------------------------------------------------------------------------
from model.DABNet import DABNet


def build_model(model_name, num_classes):
    if model_name == 'DABNet':
        return DABNet(classes=num_classes)

--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
from .camvid import *
from .cityscapes import *

--------------------------------------------------------------------------------
/dataset/camvid.py:
--------------------------------------------------------------------------------
import os.path as osp
import numpy as np
import random
import cv2
from torch.utils import data
import pickle

"""
CamVid is a road scene understanding dataset with 367 training images and 233 testing images of day and dusk scenes.
The challenge is to segment 11 classes such as road, building, cars, pedestrians, signs, poles, side-walk etc. We
resize images to 360x480 pixels for training and testing.
"""


class CamVidDataSet(data.Dataset):
    """
    CamVidDataSet is employed to load the train set
    Args:
        root: the CamVid dataset path,
        list_path: camvid_train_list.txt, each line holds the relative image and label paths

    """

    def __init__(self, root='', list_path='', max_iters=None, crop_size=(360, 360),
                 mean=(128, 128, 128), scale=True, mirror=True, ignore_label=11):
        self.root = root
        self.list_path = list_path
        self.crop_h, self.crop_w = crop_size
        self.scale = scale
        self.ignore_label = ignore_label
        self.mean = mean
        self.is_mirror = mirror
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        if max_iters is not None:
            self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids)))
        self.files = []

        # for split in ["train", "trainval", "val"]:
        for name in self.img_ids:
            img_file = osp.join(self.root, name.split()[0])
            # print(img_file)
            label_file = osp.join(self.root, name.split()[1])
            # print(label_file)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })

        print("length of train set: ", len(self.files))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]
        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
        label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE)
        size = image.shape
        name = datafiles["name"]
        if self.scale:
            scale = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0]  # random resize between 0.75 and 2.0
            f_scale = scale[random.randint(0, 5)]
            # f_scale = 0.5 + random.randint(0, 15) / 10.0  # random resize between 0.5 and 2
            image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR)
            label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST)

        image = np.asarray(image, np.float32)

        image -= self.mean
        # image = image.astype(np.float32) / 255.0
        image = image[:, :, ::-1]  # change to RGB
        img_h, img_w = label.shape
        pad_h = max(self.crop_h - img_h, 0)
        pad_w = max(self.crop_w - img_w, 0)
        if pad_h > 0 or pad_w > 0:
            img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0,
                                         pad_w, cv2.BORDER_CONSTANT,
                                         value=(0.0, 0.0, 0.0))
            label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0,
                                           pad_w, cv2.BORDER_CONSTANT,
                                           value=(self.ignore_label,))
        else:
            img_pad, label_pad = image, label

        img_h, img_w = label_pad.shape
        h_off = random.randint(0, img_h - self.crop_h)
        w_off = random.randint(0, img_w - self.crop_w)
        # roi = cv2.Rect(w_off, h_off, self.crop_w, self.crop_h);
        image = np.asarray(img_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32)
        label = np.asarray(label_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32)

        image = image.transpose((2, 0, 1))  # HWC -> CHW

        if self.is_mirror:
            flip = np.random.choice(2) * 2 - 1
            image = image[:, :, ::flip]
            label = label[:, ::flip]

        return image.copy(), label.copy(), np.array(size), name


class CamVidValDataSet(data.Dataset):
    """
    CamVidValDataSet is employed to load the val set
    Args:
        root: the CamVid dataset path,
        list_path: camvid_val_list.txt, each line holds the relative image and label paths

    """

    def __init__(self, root='', list_path='',
                 f_scale=1, mean=(128, 128, 128), ignore_label=11):
        self.root = root
        self.list_path = list_path
        self.ignore_label = ignore_label
        self.mean = mean
        self.f_scale = f_scale
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        self.files = []
        for name in self.img_ids:
            img_file = osp.join(self.root, name.split()[0])
            # print(img_file)
            label_file = osp.join(self.root, name.split()[1])
            # print(label_file)
            image_name = name.strip().split()[0].strip().split('/', 1)[1].split('.')[0]
            # print("image_name: ",image_name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": image_name
            })

        print("length of Validation set: ", len(self.files))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]
        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
        label = cv2.imread(datafiles["label"], cv2.IMREAD_GRAYSCALE)
        size = image.shape
        name = datafiles["name"]
        if self.f_scale != 1:
            image = cv2.resize(image, None, fx=self.f_scale, fy=self.f_scale, interpolation=cv2.INTER_LINEAR)
            # label = cv2.resize(label, None, fx=self.f_scale, fy=self.f_scale, interpolation = cv2.INTER_NEAREST)

        image = np.asarray(image, np.float32)

        image -= self.mean
        # image = image.astype(np.float32) / 255.0
        image = image[:, :, ::-1]  # change to RGB
        image = image.transpose((2, 0, 1))  # HWC -> CHW

        # print('image.shape:',image.shape)
        return image.copy(), label.copy(), np.array(size), name


class CamVidTestDataSet(data.Dataset):
    """
    CamVidTestDataSet is employed to load the test set
    Args:
        root: the CamVid dataset path,
        list_path: camvid_test_list.txt, each line starts with the relative image path

    """

    def __init__(self, root='', list_path='',
                 mean=(128, 128, 128), ignore_label=11):
        self.root = root
        self.list_path = list_path
        self.ignore_label = ignore_label
        self.mean = mean
        self.img_ids = [i_id.strip() for i_id in open(list_path)]
        self.files = []
        for name in self.img_ids:
            img_file = osp.join(self.root, name.split()[0])
            # print(img_file)
            image_name = name.strip().split()[0].strip().split('/', 1)[1].split('.')[0]
            # print(image_name)
            self.files.append({
                "img": img_file,
                "name": image_name
            })
        print("length of test set: ", len(self.files))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        datafiles = self.files[index]

        image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR)
        name = datafiles["name"]

        image = np.asarray(image, np.float32)

        size = image.shape
        image -= self.mean
        # image = image.astype(np.float32) / 255.0
        image = image[:, :, ::-1]  # change to RGB
        image = image.transpose((2, 0, 1))  # HWC -> CHW

        return image.copy(), np.array(size), name


class CamVidTrainInform:
    """ To get statistical information about the train set, such as mean, std, class distribution.
    The class is used to tackle class imbalance.
    """

    def __init__(self, data_dir='', classes=11, train_set_file="",
                 inform_data_file="", normVal=1.10):
        """
        Args:
            data_dir: directory where the dataset is kept
            classes: number of classes in the dataset
            inform_data_file: location where the cached file is stored
            normVal: normalization value, as defined in the ERFNet paper
        """
        self.data_dir = data_dir
        self.classes = classes
        self.classWeights = np.ones(self.classes, dtype=np.float32)
        self.normVal = normVal
        self.mean = np.zeros(3, dtype=np.float32)
        self.std = np.zeros(3, dtype=np.float32)
        self.train_set_file = train_set_file
        self.inform_data_file = inform_data_file

    def compute_class_weights(self, histogram):
        """to compute the class weights
        Args:
            histogram: distribution of class samples
        """
        normHist = histogram / np.sum(histogram)
        for i in range(self.classes):
            self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i]))

    def readWholeTrainSet(self, fileName, train_flag=True):
        """to read the whole train set of the current dataset.
        Args:
            fileName: train set file that stores the image locations
            train_flag: whether training or validation data is being processed

        return: 0 if successful
        """
        global_hist = np.zeros(self.classes, dtype=np.float32)

        no_files = 0
        min_val_al = 0
        max_val_al = 0
        with open(self.data_dir + '/' + fileName, 'r') as textFile:
            # with open(fileName, 'r') as textFile:
            for line in textFile:
                # we expect the text file to contain the data in following format
                #