├── README.md ├── builders ├── dataset_builder.py ├── loss_builder.py ├── model_builder.py └── validation_builder.py ├── dataset ├── README.md ├── calculate_class_weight.py ├── cityscapes │ ├── Percent.png │ ├── cityscape_scripts │ │ ├── __init__.py │ │ ├── download_cityscapes.sh │ │ └── process_cityscapes.py │ ├── cityscapes.py │ ├── cityscapes_test_list.txt │ ├── cityscapes_train_list.txt │ ├── cityscapes_trainval_list.txt │ ├── cityscapes_val_list.txt │ └── class_map.csv ├── generate_txt.py └── inform │ └── cityscapes_inform.pkl ├── example ├── aachen_000000_000019_gtFine_color.png ├── aachen_000000_000019_gtFine_labelTrainIds.png ├── aachen_000000_000019_leftImg8bit.png ├── average_results.png ├── class_results1.png ├── class_results2.png ├── class_results3.png ├── lindau_000000_000019_leftImg8bit.png ├── lindau_000000_000019_leftImg8bit_color.png └── lindau_000000_000019_leftImg8bit_gt.png ├── model ├── BiSeNet.py ├── BiSeNetV2.py ├── DDRNet.py ├── DeeplabV3Plus.py ├── FCN8s.py ├── FCN_ResNet.py ├── HRNet.py ├── PSPNet │ ├── psanet.py │ ├── pspnet.py │ └── resnet.py ├── SegNet.py ├── UNet.py ├── base_model │ ├── __init__.py │ ├── resnet.py │ └── xception.py └── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── batchnorm_reimpl.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── predict.py ├── predict.sh ├── requirements.txt ├── train.py ├── train.sh └── utils ├── colorize_mask.py ├── distributed.py ├── earlyStopping.py ├── flops_counter ├── CHANGELOG.md ├── Flops_test.py ├── LICENSE ├── README.md ├── __init__.py ├── ptflops │ ├── __init__.py │ └── flops_counter.py └── setup.py ├── fps_test └── eval_forward_time.py ├── image_transform.py ├── losses ├── __init__.py ├── loss.py └── lovasz_losses.py ├── metric ├── SegmentationMetric.py └── __init__.py ├── optim ├── AdamW.py ├── Lookahead.py ├── RAdam.py ├── Ranger.py └── __init__.py ├── plot_log.py ├── record_log.py ├── scheduler ├── __init__.py └── lr_scheduler.py └── utils.py /README.md: 
-------------------------------------------------------------------------------- 1 | ## :rocket: If this project helps you, please give it a star! :star: ## 2 | ## Update log 3 | - 2020.12.10 Restructured the project; the previous code has been deleted and will be re-uploaded after the restructuring 4 | - 2021.04.09 Re-uploaded the code, "V1 Commit" 5 | - 2021.04.22 Added torch distributed training 6 | - More updates ongoing ..... 7 | 8 | # 1. Display (Cityscapes) 9 | - Using the DDRNet model on the 1525 test images, the official mIoU = 78.4069% 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 |
Average results
Class results1
Class results2
Class results3
19 | 20 | - Comparison of the original and predicted images 21 | 22 | 23 | 24 | 25 | 26 | 27 |
origin
label
predict
28 | 29 | # 2. Install 30 | ```pip install -r requirements.txt```
31 | Experimental environment: 32 | - Ubuntu 16.04, NVIDIA GPUs >= 1 33 | - python==3.6.5
34 | - See requirements.txt for the full list of dependency packages
35 | 36 | # 3. Model 37 | All the modeling is done in `builders/model_builder.py`
38 | - [x] FCN 39 | - [x] FCN_ResNet 40 | - [x] SegNet 41 | - [x] UNet 42 | - [x] BiSeNet 43 | - [x] BiSeNetV2 44 | - [x] PSPNet 45 | - [x] DeepLabv3_plus 46 | - [x] HRNet 47 | - [x] DDRNet 48 | 49 | | Model| Backbone| Val mIoU | Test mIoU | Imagenet Pretrain| Pretrained Model | 50 | | :--- | :---: |:---: |:---:|:---:|:---:| 51 | | PSPNet | ResNet 50 | 76.54% | - | √ | [PSPNet](https://drive.google.com/file/d/10T321s62xDZQJUR3k0H-l64smYW0QAxN/view?usp=sharing) | 52 | | DeeplabV3+ | ResNet 50 | 77.78% | - | √ | [DeeplabV3+](https://drive.google.com/file/d/1xP7HQwFcXAPuoL_BCYdghOBnEJIxNE-T/view?usp=sharing) | 53 | | DDRNet23_slim | - | | | [DDRNet23_slim_imagenet](https://drive.google.com/file/d/1mg5tMX7TJ9ZVcAiGSB4PEihPtrJyalB4/view) | | 54 | | DDRNet23 | - | | | [DDRNet23_imagenet](https://drive.google.com/file/d/1VoUsERBeuCaiuQJufu8PqpKKtGvCTdug/view) | | 55 | | DDRNet39 | - | 79.63% | - | [DDRNet39_imagenet](https://drive.google.com/file/d/122CMx6DZBaRRf-dOHYwuDY9vG0_UQ10i/view) | [DDRNet39](https://drive.google.com/file/d/1-poQsQzXqGl2d2ILXRhWgQH452MUTX5y/view?usp=sharing) | 56 | More models are being added....... 57 | 58 | # 4. Data preprocessing 59 | This project supports the following datasets: `Cityscapes`, `ISPRS`
60 | The datasets will be uploaded later .....
61 | **Cityscapes dataset preparation is shown here:** 62 | 63 | ## 4.1 Download the dataset 64 | Download the dataset from the link on the official website. Under the folder `leftImg8bit` you get the original images (suffix `*leftImg8bit.png`), 65 | and under the folder `gtFine` you get the finely labeled images with the suffixes `a) *color.png`, `b) *labelIds.png`, `c) *instanceIds.png`. 66 | ``` 67 | *leftImg8bit.png : the original picture 68 | a) *color.png : the class is encoded by its color 69 | b) *labelIds.png : the class is encoded by its ID 70 | c) *instanceIds.png : the class and the instance are encoded by an instance ID 71 | ``` 72 | ## 4.2 Encoding the label images (train IDs 0-18) 73 | Semantic segmentation training uses gray-scale label images whose class values range from 0 to 18, so the original label IDs need to be re-encoded. 74 | Use the script `dataset/cityscapes/cityscape_scripts/process_cityscapes.py` 75 | to process the label images and produce the `*labelTrainIds.png` results. 76 | `process_cityscapes.py` usage: modify `Cityscapes_path` on line 486 so that it points to where your own data is stored.
***_leftImg8bit
***_gtFine_color
***_gtFine_labelTrainIds
86 | 87 | - Local storage path display `/data/open_data/cityscapes/`: 88 | ``` 89 | data 90 | |--open_data 91 | |--cityscapes 92 | |--leftImg8bit 93 | |--train 94 | |--cologne 95 | |--******* 96 | |--val 97 | |--******* 98 | |--test 99 | |--******* 100 | |--gtFine 101 | |--train 102 | |--cologne 103 | |--******* 104 | |--val 105 | |--******* 106 | |--test 107 | |--******* 108 | ``` 109 | 110 | ## 4.3 Generate image path 111 | - Generate a txt containing the image path
112 | Use the script `dataset/generate_txt.py` to generate `txt` files listing the paths of the original images and their labels. 113 | A total of 3 `txt` files will be generated: `cityscapes_train_list.txt`, `cityscapes_val_list.txt`, 114 | `cityscapes_test_list.txt`; copy the three files to the dataset root directory.
115 | ``` 116 | data 117 | |--open_data 118 | |--cityscapes 119 | |--cityscapes_train_list.txt 120 | |--cityscapes_val_list.txt 121 | |--cityscapes_test_list.txt 122 | |--leftImg8bit 123 | |--train 124 | |--cologne 125 | |--******* 126 | |--val 127 | |--******* 128 | |--test 129 | |--******* 130 | |--gtFine 131 | |--train 132 | |--cologne 133 | |--******* 134 | |--val 135 | |--******* 136 | |--test 137 | |--******* 138 | ``` 139 | 140 | - The contents of the `txt` are shown as follows: 141 | ``` 142 | leftImg8bit/train/cologne/cologne_000000_000019_leftImg8bit.png gtFine/train/cologne/cologne_000000_000019_gtFine_labelTrainIds.png 143 | leftImg8bit/train/cologne/cologne_000001_000019_leftImg8bit.png gtFine/train/cologne/cologne_000001_000019_gtFine_labelTrainIds.png 144 | .............. 145 | ``` 146 | 147 | - The format of the `txt` are shown as follows: 148 | ``` 149 | origin image path + the separator '\t' + label path + the separator '\n' 150 | ``` 151 | 152 | 153 | # TODO..... 154 | # 5. How to train 155 | ``` 156 | sh train.sh 157 | ``` 158 | ## 5.1 Parameters 159 | ``` 160 | python -m torch.distributed.launch --nproc_per_node=2 \ 161 | train.py --model PSPNet_res50 --out_stride 8 \ 162 | --max_epochs 200 --val_epochs 20 --batch_size 4 --lr 0.01 --optim sgd --loss ProbOhemCrossEntropy2d \ 163 | --base_size 768 --crop_size 768 --tile_hw_size 768,768 \ 164 | --root '/data/open_data' --dataset cityscapes --gpus_id 1,2 165 | ``` 166 | # 6. 
How to validate 167 | ``` 168 | sh predict.sh 169 | ``` -------------------------------------------------------------------------------- /builders/dataset_builder.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | """ 3 | Time: 2020/11/30 17:02 4 | Author: Ding Cheng(Deeachain) 5 | File: dataset_builder.py 6 | Describe: Write during my study in Nanjing University of Information and Secience Technology 7 | Github: https://github.com/Deeachain 8 | """ 9 | import os 10 | import pickle 11 | import pandas as pd 12 | from dataset.cityscapes.cityscapes import CityscapesTrainDataSet, CityscapesTrainInform, CityscapesValDataSet, \ 13 | CityscapesTestDataSet 14 | 15 | def build_dataset_train(root, dataset, base_size, crop_size): 16 | data_dir = os.path.join(root, dataset) 17 | train_data_list = os.path.join(data_dir, dataset + '_' + 'train_list.txt') 18 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl') 19 | 20 | # inform_data_file collect the information of mean, std and weigth_class 21 | if not os.path.isfile(inform_data_file): 22 | print("%s is not found" % (inform_data_file)) 23 | if dataset == "cityscapes": 24 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list, 25 | inform_data_file=inform_data_file) 26 | else: 27 | raise NotImplementedError( 28 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset) 29 | 30 | datas = dataCollect.collectDataAndSave() 31 | if datas is None: 32 | print("error while pickling data. 
Please check.") 33 | exit(-1) 34 | else: 35 | datas = pickle.load(open(inform_data_file, "rb")) 36 | 37 | if dataset == "cityscapes": 38 | TrainDataSet = CityscapesTrainDataSet(data_dir, train_data_list, base_size=base_size, crop_size=crop_size, 39 | mean=datas['mean'], std=datas['std'], ignore_label=255) 40 | return datas, TrainDataSet 41 | 42 | 43 | def build_dataset_test(root, dataset, crop_size, mode='whole', gt=False): 44 | data_dir = os.path.join(root, dataset) 45 | inform_data_file = os.path.join('./dataset/inform/', dataset + '_inform.pkl') 46 | train_data_list = os.path.join(data_dir, dataset + '_train_list.txt') 47 | if mode == 'whole': 48 | test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt') 49 | else: 50 | test_data_list = os.path.join(data_dir, dataset + '_test_sliding' + '_list.txt') 51 | 52 | # inform_data_file collect the information of mean, std and weigth_class 53 | if not os.path.isfile(inform_data_file): 54 | print("%s is not found" % (inform_data_file)) 55 | if dataset == "cityscapes": 56 | dataCollect = CityscapesTrainInform(data_dir, 19, train_set_file=train_data_list, 57 | inform_data_file=inform_data_file) 58 | else: 59 | raise NotImplementedError( 60 | "This repository now supports two datasets: cityscapes and camvid, %s is not included" % dataset) 61 | 62 | datas = dataCollect.collectDataAndSave() 63 | if datas is None: 64 | print("error while pickling data. 
Please check.") 65 | exit(-1) 66 | else: 67 | datas = pickle.load(open(inform_data_file, "rb")) 68 | 69 | class_dict_df = pd.read_csv(os.path.join('./dataset', dataset, 'class_map.csv')) 70 | if dataset == "cityscapes": 71 | # for cityscapes, if test on validation set, set none_gt to False 72 | # if test on the test set, set none_gt to True 73 | if gt: 74 | test_data_list = os.path.join(data_dir, dataset + '_val' + '_list.txt') 75 | testdataset = CityscapesValDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'], 76 | std=datas['std'], ignore_label=255) 77 | else: 78 | test_data_list = os.path.join(data_dir, dataset + '_test' + '_list.txt') 79 | testdataset = CityscapesTestDataSet(data_dir, test_data_list, crop_size=crop_size, mean=datas['mean'], 80 | std=datas['std'], ignore_label=255) 81 | return testdataset, class_dict_df 82 | -------------------------------------------------------------------------------- /builders/loss_builder.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | """ 3 | Time: 2020/11/30 17:02 4 | Author: Ding Cheng(Deeachain) 5 | File: loss_builder.py 6 | Describe: Write during my study in Nanjing University of Information and Secience Technology 7 | Github: https://github.com/Deeachain 8 | """ 9 | import torch 10 | from utils.losses.loss import LovaszSoftmax, CrossEntropyLoss2d, CrossEntropyLoss2dLabelSmooth, \ 11 | ProbOhemCrossEntropy2d, FocalLoss2d, LabelSmoothing 12 | 13 | 14 | def build_loss(args, datas, ignore_label): 15 | if args.dataset == 'cityscapes': 16 | weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 17 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 18 | 1.0865, 1.1529, 1.0507]) 19 | elif datas != None: 20 | weight = torch.from_numpy(datas['classWeights']) 21 | else: 22 | weight = None 23 | 24 | # Default uses cross quotient loss function 25 | criteria = CrossEntropyLoss2d(weight=weight, 
ignore_label=ignore_label) 26 | if args.loss == 'ProbOhemCrossEntropy2d': 27 | h, w = args.base_size, args.base_size 28 | min_kept = int(args.batch_size // len(args.gpus_id) * h * w // 16) 29 | criteria = ProbOhemCrossEntropy2d(weight=weight, ignore_label=ignore_label, thresh=0.7, min_kept=min_kept) 30 | elif args.loss == 'CrossEntropyLoss2dLabelSmooth': 31 | criteria = CrossEntropyLoss2dLabelSmooth(weight=weight, ignore_label=ignore_label) 32 | # criteria = LabelSmoothing() 33 | elif args.loss == 'LovaszSoftmax': 34 | criteria = LovaszSoftmax(ignore_index=ignore_label) 35 | elif args.loss == 'FocalLoss2d': 36 | criteria = FocalLoss2d(weight=weight, ignore_index=ignore_label) 37 | 38 | return criteria 39 | -------------------------------------------------------------------------------- /builders/model_builder.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | """ 3 | Time: 2020/11/30 17:02 4 | Author: Ding Cheng(Deeachain) 5 | File: model_builder.py 6 | Describe: Write during my study in Nanjing University of Information and Secience Technology 7 | Github: https://github.com/Deeachain 8 | """ 9 | import torch.nn as nn 10 | import torch.utils.model_zoo as model_zoo 11 | from collections import OrderedDict 12 | from model.UNet import UNet 13 | from model.SegNet import SegNet 14 | from model.FCN8s import FCN 15 | from model.BiSeNet import BiSeNet 16 | from model.BiSeNetV2 import BiSeNetV2 17 | from model.PSPNet.pspnet import PSPNet 18 | from model.DeeplabV3Plus import Deeplabv3plus_res50 19 | from model.FCN_ResNet import FCN_ResNet 20 | from model.DDRNet import DDRNet 21 | from model.HRNet import HighResolutionNet 22 | 23 | model_urls = { 24 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 25 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 26 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 27 | 'resnet101': 
'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 28 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 29 | } 30 | 31 | 32 | def build_model(model_name, num_classes, backbone='resnet18', pretrained=False, out_stride=32, mult_grid=False): 33 | if model_name == 'FCN': 34 | model = FCN(num_classes=num_classes) 35 | elif model_name == 'FCN_ResNet': 36 | model = FCN_ResNet(num_classes=num_classes, backbone=backbone, out_stride=out_stride, mult_grid=mult_grid) 37 | elif model_name == 'SegNet': 38 | model = SegNet(classes=num_classes) 39 | elif model_name == 'UNet': 40 | model = UNet(num_classes=num_classes) 41 | elif model_name == 'BiSeNet': 42 | model = BiSeNet(num_classes=num_classes, backbone=backbone) 43 | elif model_name == 'BiSeNetV2': 44 | model = BiSeNetV2(num_classes=num_classes) 45 | elif model_name == 'HRNet': 46 | model = HighResolutionNet(num_classes=num_classes) 47 | elif model_name == 'Deeplabv3plus_res50': 48 | model = Deeplabv3plus_res50(num_classes=num_classes, os=out_stride, pretrained=True) 49 | elif model_name == "DDRNet": 50 | model = DDRNet(pretrained=True, num_classes=num_classes) 51 | elif model_name == 'PSPNet_res50': 52 | model = PSPNet(layers=50, bins=(1, 2, 3, 6), dropout=0.1, num_classes=num_classes, zoom_factor=8, use_ppm=True, 53 | pretrained=True) 54 | elif model_name == 'PSPNet_res101': 55 | model = PSPNet(layers=101, bins=(1, 2, 3, 6), dropout=0.1, num_classes=num_classes, zoom_factor=8, use_ppm=True, 56 | pretrained=True) 57 | # elif model_name == 'PSANet50': 58 | # return PSANet(layers=50, dropout=0.1, classes=num_classes, zoom_factor=8, use_psa=True, psa_type=2, compact=compact, 59 | # shrink_factor=shrink_factor, mask_h=mask_h, mask_w=mask_w, psa_softmax=True, pretrained=True) 60 | 61 | if pretrained: 62 | checkpoint = model_zoo.load_url(model_urls[backbone]) 63 | model_dict = model.state_dict() 64 | # print(model_dict) 65 | # Screen out layers that are not loaded 66 | pretrained_dict = 
{'backbone.' + k: v for k, v in checkpoint.items() if 'backbone.' + k in model_dict} 67 | # Update the structure dictionary for the current network 68 | model_dict.update(pretrained_dict) 69 | model.load_state_dict(model_dict) 70 | 71 | return model 72 | 73 | 74 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # Supported datasets 2 | 3 | - CityScapes 4 | 5 | Note: When referring to the number of classes, the void/unlabeled class is excluded. 6 | 7 | 8 | ## Cityscapes 9 | 10 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated. 11 | 12 | The version supported here is the finely annotated one with 19 classes. 13 | 14 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts). 15 | 16 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data. -------------------------------------------------------------------------------- /dataset/calculate_class_weight.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from tqdm import tqdm 4 | from glob import glob 5 | 6 | 7 | class DatasetsInformation: 8 | """ To get statistical information about the train set, such as mean, std, class distribution. 9 | The class is employed for tackle class imbalance. 
10 | """ 11 | 12 | def __init__(self, classes=32, normVal=1.10): 13 | """ 14 | Args: 15 | data_dir: directory where the dataset is kept 16 | classes: number of classes in the dataset 17 | inform_data_file: location where cached file has to be stored 18 | normVal: normalization value, as defined in ERFNet paper 19 | """ 20 | self.classes = classes 21 | self.classWeights = np.ones(self.classes, dtype=np.float32) 22 | self.normVal = normVal 23 | self.mean = np.zeros(3, dtype=np.float32) 24 | self.std = np.zeros(3, dtype=np.float32) 25 | 26 | def compute_class_weights(self, histogram): 27 | """to compute the class weights 28 | Args: 29 | histogram: distribution of class samples 30 | """ 31 | normHist = histogram / np.sum(histogram) 32 | 33 | for i in range(self.classes): 34 | self.classWeights[i] = 0.1 / (np.log(self.normVal + normHist[i])) 35 | return histogram, normHist * 100, self.classWeights 36 | 37 | def compute_class_weight_with_media_freq(self, n_classes, labels): 38 | count = np.zeros(n_classes) 39 | image_count = np.zeros(n_classes) 40 | example = cv2.imread(labels[0]) 41 | h, w, c = example.shape 42 | 43 | for label in tqdm(labels, desc="Claculating Class Weight"): 44 | data = cv2.imdecode(np.fromfile(label, dtype=np.uint8), -1) 45 | 46 | for c in range(n_classes): 47 | c_sum = np.sum(data == c) # 统计c类像素的个数 48 | count[c] += c_sum 49 | if np.sum(data == c) != 0: # 判断该图片中是否存在第c类像素,如果存在则第c类图片个数+1 50 | image_count[c] += 1 51 | 52 | frequency = count / (image_count * h * w) 53 | median = np.median(frequency) 54 | weight = median / frequency 55 | 56 | return frequency, weight 57 | 58 | def readWholeTrainSet(self, image, label, train_flag=True): 59 | """to read the whole train set of current dataset. 
60 | Args: 61 | fileName: train set file that stores the image locations 62 | trainStg: if processing training or validation data 63 | 64 | return: 0 if successful 65 | """ 66 | global_hist = np.zeros(self.classes, dtype=np.float32) 67 | 68 | no_files = 0 69 | min_val_al = 0 70 | max_val_al = 0 71 | 72 | for index, label_file in tqdm(iterable=enumerate(label), desc='Calculate_ClassWeight', total=len(label)): 73 | img_file = image[index] 74 | 75 | label_img = cv2.imread(label_file, 0) 76 | unique_values = np.unique(label_img) 77 | max_val = max(unique_values) 78 | min_val = min(unique_values) 79 | 80 | max_val_al = max(max_val, max_val_al) 81 | min_val_al = min(min_val, min_val_al) 82 | 83 | if train_flag == True: 84 | hist = np.histogram(label_img, self.classes, range=(0, self.classes - 1)) 85 | global_hist += hist[0] 86 | 87 | rgb_img = cv2.imread(img_file) 88 | self.mean[0] += np.mean(rgb_img[:, :, 0]) 89 | self.mean[1] += np.mean(rgb_img[:, :, 1]) 90 | self.mean[2] += np.mean(rgb_img[:, :, 2]) 91 | 92 | self.std[0] += np.std(rgb_img[:, :, 0]) 93 | self.std[1] += np.std(rgb_img[:, :, 1]) 94 | self.std[2] += np.std(rgb_img[:, :, 2]) 95 | 96 | else: 97 | print("we can only collect statistical information of train set, please check") 98 | 99 | if max_val > (self.classes - 1) or min_val < 0: 100 | print('Labels can take value between 0 and number of classes.') 101 | print('Some problem with labels. Please check. 
label_set:', unique_values) 102 | print('Label Image ID: ' + label_file) 103 | no_files += 1 104 | 105 | # divide the mean and std values by the sample space size 106 | self.mean /= no_files 107 | self.std /= no_files 108 | 109 | # compute the class imbalance information 110 | self.compute_class_weights(global_hist) 111 | self.compute_class_weight_with_media_freq(self.classes, label) 112 | return self.mean, self.std, self.compute_class_weights(global_hist), self.compute_class_weight_with_media_freq( 113 | self.classes, label) 114 | 115 | 116 | if __name__ == '__main__': 117 | image = glob('/media/ding/Data/datasets/paris/512_image_625/crop_files/512_image/*.png') 118 | label = glob('/media/ding/Data/datasets/paris/512_image_625/crop_files/512_label/*.png') 119 | # image = glob('/media/ding/Data/datasets/paris/paris_origin/*_image.png') 120 | # label = glob('/media/ding/Data/datasets/paris/paris_origin/*labels_gray.png') 121 | class_num = 3 122 | 123 | info = DatasetsInformation(classes=class_num) 124 | out = info.readWholeTrainSet(image=image, label=label, train_flag=True) 125 | 126 | np.set_printoptions(suppress=True) 127 | print('Std is \n', out[0]) # 计算数据集图片的均值 128 | print('\nMean is \n', out[1]) # 计算数据集图片的均值 129 | print('\nPerClass Count is \n', np.array(out[2][0], dtype=int)) # 计算每个类别的个数,返回值是一个列表,列表的长度是类和个数 130 | print('\nPercentPerClass is (%)\n', out[2][1]) # 计算每个类别所占比例百分比 131 | print('\nCalcClassWeight is \n', out[2][2]) # 计算类别的权重 132 | print('\nMedia class Freq and weight', out[3][0], out[3][1]) 133 | 134 | import matplotlib.pyplot as plt 135 | 136 | plt.figure(figsize=(10, 8)) 137 | # plt.rcParams['font.sans-serif'] = ['SimSun'] 138 | # plt.rcParams['axes.unicode_minus'] = False 139 | X = np.arange(0, class_num) 140 | Y = out[2][1] 141 | print(Y) 142 | plt.bar(x=X, height=Y, color="c", width=0.8) 143 | plt.xticks(X, fontsize=20) # 标注横坐标的类别名 144 | plt.yticks(fontsize=20) # 标注纵坐标 145 | for x, y in zip(X, Y): 146 | plt.text(x, y, '%.2f' % y, ha='center', 
va='bottom', fontsize=10) # 使用matplotlib画柱状图并标记数字 147 | plt.xlabel("Class", fontsize=12) 148 | plt.ylabel("Percent(%)", fontsize=12) 149 | plt.title("Category scale distribution", fontsize=16) 150 | plt.savefig('Paris Percent.png') 151 | plt.show() 152 | -------------------------------------------------------------------------------- /dataset/cityscapes/Percent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/dataset/cityscapes/Percent.png -------------------------------------------------------------------------------- /dataset/cityscapes/cityscape_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Deeachain/Segmentation-Pytorch/acc6998863dfef884bc5fe954c2b8de1c28576a7/dataset/cityscapes/cityscape_scripts/__init__.py -------------------------------------------------------------------------------- /dataset/cityscapes/cityscape_scripts/download_cityscapes.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | global_path='../../../vision_datasets' 3 | data_dir=$global_path'/cityscapes' 4 | 5 | mkdir -p $data_dir 6 | cd $data_dir 7 | 8 | # enter user details 9 | uname='' # 10 | pass='' 11 | 12 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username='$uname'&password='$pass'&submit=Login' https://www.cityscapes-dataset.com/login/ 13 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 14 | wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 15 | # Uncomment if you want to download coarse 16 | #wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=4 17 | #wget --load-cookies 
cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=2 18 | 19 | 20 | #unzip -q -o gtCoarse.zip 21 | unzip -q -o gtFine_trainvaltest.zip 22 | #unzip -q -o leftImg8bit_trainextra.zip 23 | unzip -q -o leftImg8bit_trainvaltest.zip 24 | 25 | #rm -rf gtCoarse.zip 26 | rm -rf gtFine_trainvaltest.zip 27 | #rm -rf leftImg8bit_trainextra.zip 28 | rm -rf leftImg8bit_trainvaltest.zip -------------------------------------------------------------------------------- /dataset/cityscapes/cityscapes.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | """ 3 | Time: 2021/3/12 10:28 4 | Author: Ding Cheng(Deeachain) 5 | File: cityscapes.py 6 | Github: https://github.com/Deeachain 7 | """ 8 | import os.path as osp 9 | import numpy as np 10 | import cv2 11 | from torch.utils import data 12 | import pickle 13 | from PIL import Image 14 | from torchvision import transforms 15 | from utils import image_transform as tr 16 | 17 | 18 | class CityscapesTrainDataSet(data.Dataset): 19 | """ 20 | CityscapesTrainDataSet is employed to load train set 21 | Args: 22 | root: the Cityscapes dataset path, 23 | cityscapes 24 | ├── gtFine 25 | ├── leftImg8bit 26 | list_path: cityscapes_train_list.txt, include partial path 27 | mean: bgr_mean (73.15835921, 82.90891754, 72.39239876) 28 | 29 | """ 30 | 31 | def __init__(self, root='', list_path='', max_iters=None, base_size=513, crop_size=513, mean=(128, 128, 128), 32 | std=(128, 128, 128), ignore_label=255): 33 | self.root = root 34 | self.list_path = list_path 35 | self.base_size = base_size 36 | self.crop_size = crop_size 37 | self.mean = mean 38 | self.std = std 39 | self.ignore_label = ignore_label 40 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 41 | if not max_iters == None: 42 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 43 | self.files = [] 44 | 45 | for name in self.img_ids: 46 | img_file = 
osp.join(self.root, name.split()[0]) 47 | label_file = osp.join(self.root, name.split()[1]) 48 | name = name.strip().split()[0].strip().split('/', 3)[3].split('.')[0] 49 | self.files.append({"img": img_file, "label": label_file, "name": name}) 50 | 51 | print("length of train dataset: ", len(self.files)) 52 | 53 | def __len__(self): 54 | return len(self.files) 55 | 56 | def __getitem__(self, index): 57 | datafiles = self.files[index] 58 | image = Image.open(datafiles["img"]).convert('RGB') 59 | label = Image.open(datafiles["label"]) 60 | size = np.asarray(image).shape 61 | name = datafiles["name"] 62 | 63 | composed_transforms = transforms.Compose([ 64 | tr.RandomHorizontalFlip(), 65 | # tr.RandomRotate(180), 66 | tr.RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255), 67 | tr.RandomGaussianBlur(), 68 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 69 | tr.ToTensor()]) 70 | sample = {'image': image, 'label': label} 71 | sampled = composed_transforms(sample) 72 | image, label = sampled['image'], sampled['label'] 73 | return image, label, np.array(size), name 74 | 75 | 76 | class CityscapesValDataSet(data.Dataset): 77 | """ 78 | CityscapesDataSet is employed to load val set 79 | Args: 80 | root: the Cityscapes dataset path, 81 | cityscapes 82 | ├── gtFine 83 | ├── leftImg8bit 84 | list_path: cityscapes_val_list.txt, include partial path 85 | 86 | """ 87 | 88 | def __init__(self, root='', list_path='', crop_size=513, mean=(128, 128, 128), std=(128, 128, 128), 89 | ignore_label=255): 90 | self.root = root 91 | self.list_path = list_path 92 | self.crop_size = crop_size 93 | self.mean = mean 94 | self.std = std 95 | self.ignore_label = ignore_label 96 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 97 | self.files = [] 98 | for name in self.img_ids: 99 | img_file = osp.join(self.root, name.split()[0]) 100 | label_file = osp.join(self.root, name.split()[1]) 101 | name = name.strip().split()[0].strip().split('/', 
3)[3].split('.')[0] 102 | self.files.append({"img": img_file, "label": label_file, "name": name}) 103 | 104 | print("length of validation dataset: ", len(self.files)) 105 | 106 | def __len__(self): 107 | return len(self.files) 108 | 109 | def __getitem__(self, index): 110 | datafiles = self.files[index] 111 | image = Image.open(datafiles["img"]).convert('RGB') 112 | label = Image.open(datafiles["label"]) 113 | size = np.asarray(image).shape 114 | name = datafiles["name"] 115 | composed_transforms = transforms.Compose([ 116 | tr.FixScaleCrop(crop_size=self.crop_size), 117 | # tr.FixedResize(size=(1024,512)), 118 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 119 | tr.ToTensor()]) 120 | sample = {'image': image, 'label': label} 121 | sampled = composed_transforms(sample) 122 | image, label = sampled['image'], sampled['label'] 123 | 124 | return image, label, np.array(size), name 125 | 126 | 127 | class CityscapesTestDataSet(data.Dataset): 128 | """ 129 | CityscapesDataSet is employed to load test set 130 | Args: 131 | root: the Cityscapes dataset path, 132 | list_path: cityscapes_test_list.txt, include partial path 133 | 134 | """ 135 | 136 | def __init__(self, root='', list_path='', crop_size=513, mean=(128, 128, 128), std=(128, 128, 128), 137 | ignore_label=255): 138 | self.root = root 139 | self.list_path = list_path 140 | self.crop_size = crop_size 141 | self.mean = mean 142 | self.std = std 143 | self.ignore_label = ignore_label 144 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 145 | self.files = [] 146 | for name in self.img_ids: 147 | img_file = osp.join(self.root, name.split()[0]) 148 | name = name.strip().split()[0].strip().split('/', 3)[3].split('.')[0] 149 | self.files.append({"img": img_file, "name": name}) 150 | 151 | print("length of validation dataset: ", len(self.files)) 152 | 153 | def __len__(self): 154 | return len(self.files) 155 | 156 | def __getitem__(self, index): 157 | datafiles = self.files[index] 158 | 
image = Image.open(datafiles["img"]).convert('RGB') 159 | size = np.asarray(image).shape 160 | name = datafiles["name"] 161 | composed_transforms = transforms.Compose([ 162 | # tr.FixedResize(size=(1024,512)), 163 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 164 | tr.ToTensor()]) 165 | sample = {'image': image} 166 | sampled = composed_transforms(sample) 167 | image = sampled['image'] 168 | 169 | return image, np.array(size), name 170 | 171 | 172 | class CityscapesTrainInform: 173 | """ To get statistical information about the train set, such as mean, std, class distribution. 174 | The class is employed for tackle class imbalance. 175 | """ 176 | 177 | def __init__(self, data_dir='', classes=19, 178 | train_set_file="", inform_data_file="", normVal=1.10): 179 | """ 180 | Args: 181 | data_dir: directory where the dataset is kept 182 | classes: number of classes in the dataset 183 | inform_data_file: location where cached file has to be stored 184 | normVal: normalization value, as defined in ERFNet paper 185 | """ 186 | self.data_dir = data_dir 187 | self.classes = classes 188 | self.classWeights = np.ones(self.classes, dtype=np.float32) 189 | self.normVal = normVal 190 | self.mean = np.zeros(3, dtype=np.float32) 191 | self.std = np.zeros(3, dtype=np.float32) 192 | self.train_set_file = train_set_file 193 | self.inform_data_file = inform_data_file 194 | 195 | def compute_class_weights(self, histogram): 196 | """to compute the class weights 197 | Args: 198 | histogram: distribution of class samples 199 | """ 200 | normHist = histogram / np.sum(histogram) 201 | for i in range(self.classes): 202 | self.classWeights[i] = 1 / (np.log(self.normVal + normHist[i])) 203 | 204 | def readWholeTrainSet(self, fileName, train_flag=True): 205 | """to read the whole train set of current dataset. 
206 | Args: 207 | fileName: train set file that stores the image locations 208 | trainStg: if processing training or validation data 209 | 210 | return: 0 if successful 211 | """ 212 | global_hist = np.zeros(self.classes, dtype=np.float32) 213 | 214 | no_files = 0 215 | min_val_al = 0 216 | max_val_al = 0 217 | with open(self.data_dir + '/' + fileName, 'r') as textFile: 218 | for line in textFile: 219 | # we expect the text file to contain the data in following format 220 | #