├── LICENSE
├── README.md
├── baseline.py
├── evaluate.py
├── helpers
│   ├── csHelpers.py
│   ├── helpers.py
│   └── labels.py
├── learning
│   ├── learner.py
│   ├── minicity.py
│   ├── model.py
│   └── utils.py
├── minicity
│   ├── class_pixel_distribution.png
│   ├── copyblob.PNG
│   ├── cutmix.PNG
│   ├── leaderboard.PNG
│   └── visualizer.ipynb
├── option.py
├── requirements.txt
└── results.txt

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Hoseong Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Semantic Segmentation Tutorial using PyTorch
A semantic segmentation tutorial using PyTorch. Based on the [2020 ECCV VIPriors Challenge starter code](https://github.com/VIPriors/vipriors-challenges-toolkit/tree/master/semantic-segmentation), it implements a semantic segmentation codebase and adds some training tricks.

*Editor: Hoseong Lee (hoya012)*

## 0. Experimental Setup

### 0-1. Prepare Library
```bash
pip install -r requirements.txt
```

### 0-2. Download dataset (MiniCity from Cityscapes)
We will use the MiniCity dataset, a subset of Cityscapes that was used for the 2020 ECCV VIPriors Challenge.
- workshop page: https://vipriors.github.io/challenges/
- challenge link: https://competitions.codalab.org/competitions/23712
- [dataset download (Google Drive)](https://drive.google.com/file/d/1YjkiaLqU1l9jVCVslrZpip4YsCHHlbNA/view?usp=sharing)
- Move the dataset into the `minicity` folder (a quick layout check follows below).
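As a minimal sketch of how to confirm the move worked, the snippet below counts files per split. The exact split folders are an assumption based on the standard Cityscapes layout (`leftImg8bit`/`gtFine` with `train`/`val`/`test` subfolders); adjust if your download differs.

```python
# Sanity-check the MiniCity folder layout (split names are assumptions
# based on the usual Cityscapes structure).
import os

root = 'minicity'
for subdir in ('leftImg8bit', 'gtFine'):
    for split in ('train', 'val', 'test'):
        path = os.path.join(root, subdir, split)
        n = len(os.listdir(path)) if os.path.isdir(path) else 0
        print('{:<30s}: {:5d} files'.format(path, n))
```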
### 0-3. Simple Dataset EDA (Exploratory Data Analysis) - Class and Sample Distribution
#### Benchmark classes
```python
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
```

#### Labeled pixel counts for classes 0 to 18
![](https://github.com/hoya012/semantic-segmentation-tutorial-pytorch/blob/master/minicity/class_pixel_distribution.png)
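A distribution like the one above can be computed with a short script. This is a minimal sketch: the `*_labelTrainIds.png` naming and the assumption that masks encode train IDs 0-18 (with 255 for void) describe a typical Cityscapes-style preparation, not necessarily this repo's exact files.

```python
# Count labeled pixels per train ID over the training masks.
# Paths and the *_labelTrainIds.png pattern are assumptions; adjust to your layout.
import glob
import numpy as np
from PIL import Image

num_classes = 19
counts = np.zeros(num_classes, dtype=np.int64)
for path in glob.glob('minicity/gtFine/train/*_gtFine_labelTrainIds.png'):
    mask = np.array(Image.open(path))
    ids, c = np.unique(mask, return_counts=True)
    for i, n in zip(ids, c):
        if i < num_classes:  # skip void/ignored pixels (train ID 255)
            counts[i] += n

fractions = counts / counts.sum()
for i, f in enumerate(fractions):
    print('class {:2d}: {:.4%}'.format(i, f))
```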
#### DeepLabV3 baseline test set result
- The dataset has a severe class-imbalance problem.
- IoU of the minority classes (wall, fence, bus, train) is very low.

```
classes          IoU      nIoU
--------------------------------
road          : 0.963      nan
sidewalk      : 0.762      nan
building      : 0.856      nan
wall          : 0.120      nan
fence         : 0.334      nan
pole          : 0.488      nan
traffic light : 0.563      nan
traffic sign  : 0.631      nan
vegetation    : 0.884      nan
terrain       : 0.538      nan
sky           : 0.901      nan
person        : 0.732     0.529
rider         : 0.374     0.296
car           : 0.897     0.822
truck         : 0.444     0.218
bus           : 0.244     0.116
train         : 0.033     0.006
motorcycle    : 0.492     0.240
bicycle       : 0.638     0.439
--------------------------------
Score Average : 0.573     0.333
--------------------------------
```

## 1. Training Baseline Model
- I use [DeepLabV3 from torchvision](https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/).
- ResNet-50 and ResNet-101 backbones.

- I use 4 RTX 2080 Ti GPUs (11GB x 4).
- If you have just one GPU or limited GPU memory, please use a smaller batch size (<= 8).

```bash
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;
```

```bash
python baseline.py --save_path baseline_run_deeplabv3_resnet101 --model DeepLabv3_resnet101 --train_size 512 1024 --test_size 512 1024 --crop_size 384 768 --batch_size 8;
```

### 1-1. Loss Functions
- I tried 3 loss functions:
  - Cross-Entropy Loss
  - Class-Weighted Cross-Entropy Loss
  - Focal Loss
- You can choose the loss function using the `--loss` argument.
- I recommend the default (ce) or the Class-Weighted CE loss. Focal loss didn't work well in my codebase.

```bash
# Cross Entropy Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;
```

```bash
# Weighted Cross Entropy Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50_wce --crop_size 576 1152 --batch_size 8 --loss weighted_ce;
```

```bash
# Focal Loss
python baseline.py --save_path baseline_run_deeplabv3_resnet50_focal --crop_size 576 1152 --batch_size 8 --loss focal --focal_gamma 2.0;
```

### 1-2. Normalization Layer
- I tried 4 normalization layers:
  - Batch Normalization (BN)
  - Instance Normalization (IN)
  - Group Normalization (GN)
  - Evolving Normalization (EvoNorm)
- You can choose the normalization layer using the `--norm` argument.
- I recommend BN.

```bash
# Batch Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --crop_size 576 1152 --batch_size 8;
```

```bash
# Instance Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_instancenorm --crop_size 576 1152 --batch_size 8 --norm instance;
```

```bash
# Group Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_groupnorm --crop_size 576 1152 --batch_size 8 --norm group;
```

```bash
# Evolving Normalization
python baseline.py --save_path baseline_run_deeplabv3_resnet50_evonorm --crop_size 576 1152 --batch_size 8 --norm evo;
```

### 1-3. Additional Augmentation Tricks
- I propose 2 data augmentation techniques: CutMix and CopyBlob.
- CutMix Augmentation
![](https://github.com/hoya012/semantic-segmentation-tutorial-pytorch/blob/master/minicity/cutmix.PNG)
  - Based on the [original CutMix](https://arxiv.org/abs/1905.04899), I bring the idea to semantic segmentation (a minimal sketch follows at the end of this section).

- CopyBlob Augmentation
![](https://github.com/hoya012/semantic-segmentation-tutorial-pytorch/blob/master/minicity/copyblob.PNG)
  - To tackle class imbalance, I use CopyBlob augmentation with a visual inductive prior:
    - Wall must be located on the sidewalk.
    - Fence must be located on the sidewalk.
    - Bus must be located on the road.
    - Train must be located on the road.

```bash
# CutMix Augmentation
python baseline.py --save_path baseline_run_deeplabv3_resnet50_cutmix --crop_size 576 1152 --batch_size 8 --cutmix;
```

```bash
# CopyBlob Augmentation
python baseline.py --save_path baseline_run_deeplabv3_resnet50_copyblob --crop_size 576 1152 --batch_size 8 --copyblob;
```
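As a rough illustration of CutMix adapted to segmentation, here is a minimal sketch (not the exact implementation behind `--cutmix` in `baseline.py`; tensor shapes of `[N, C, H, W]` images and `[N, H, W]` masks are assumptions):

```python
# Minimal CutMix-for-segmentation sketch: the same random box is cut from a
# shuffled copy of the batch and pasted into both images and masks (in place).
import numpy as np
import torch

def cutmix_segmentation(images, masks, beta=1.0):
    n, _, h, w = images.shape
    lam = np.random.beta(beta, beta)          # kept-area ratio ~ Beta(beta, beta)
    cut_h = int(h * np.sqrt(1.0 - lam))       # pasted box height
    cut_w = int(w * np.sqrt(1.0 - lam))       # pasted box width
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = np.clip(cy - cut_h // 2, 0, h), np.clip(cy + cut_h // 2, 0, h)
    x1, x2 = np.clip(cx - cut_w // 2, 0, w), np.clip(cx + cut_w // 2, 0, w)
    perm = torch.randperm(n)                  # donor sample for each slot
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    masks[:, y1:y2, x1:x2] = masks[perm, y1:y2, x1:x2]  # labels follow pixels
    return images, masks
```

Unlike classification CutMix, no lambda-weighted loss is needed here: the mask patch moves with the image patch, so every pixel keeps a hard label.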
## 2. Inference
- After training, we can run inference with the trained models.
- I recommend using the same value for `train_size` and `test_size`.

```bash
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --batch_size 4 --predict;
```

### 2-1. Multi-Scale Inference (Test Time Augmentation)
- I use [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.2] scales for multi-scale inference. Additionally, I use horizontal flip (see the sketch after this section).
- You must use a single batch (`batch_size=1`).

```bash
# Multi-Scale Inference
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --batch_size 1 --predict --mst;
```
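For illustration, here is a minimal sketch of multi-scale + flip inference. It assumes a torchvision-style segmentation model whose forward pass returns a dict with an `'out'` logits entry; it is not the exact `--mst` code path in this repo.

```python
# Minimal multi-scale + horizontal-flip TTA sketch (model output convention
# {'out': logits} is an assumption matching torchvision's DeepLabV3).
import torch
import torch.nn.functional as F

@torch.no_grad()
def predict_multiscale(model, image, scales=(0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.2)):
    # image: [1, 3, H, W]; returns per-pixel class prediction [H, W]
    _, _, h, w = image.shape
    logits_sum = 0
    for s in scales:
        scaled = F.interpolate(image, scale_factor=s, mode='bilinear', align_corners=False)
        for flip in (False, True):
            inp = torch.flip(scaled, dims=[3]) if flip else scaled
            out = model(inp)['out']
            if flip:
                out = torch.flip(out, dims=[3])   # flip logits back before summing
            logits_sum = logits_sum + F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
    return logits_sum.argmax(dim=1)[0]
```

Averaging logits over scales and flips typically smooths predictions at object boundaries; because each scale changes the input resolution, batching is awkward, which is why a batch size of 1 is required above.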
### 2-2. Calculate Metrics using the Validation Set
- We can calculate metrics and save the results into `results.txt`.
- e.g., [my final validation set result](https://github.com/hoya012/semantic-segmentation-tutorial-pytorch/blob/master/results.txt)

```bash
# evaluate.py only accepts --results (and optionally --minicity)
python evaluate.py --results baseline_run_deeplabv3_resnet50/results_val;
```

## 3. Final Result
![](https://github.com/hoya012/semantic-segmentation-tutorial-pytorch/blob/master/minicity/leaderboard.PNG)
- My final single-model result is **0.6069831962012341**.
- This achieves 5th place on the leaderboard.
- However, I didn't submit a short paper, so my score is not an official one.
- With a bigger model and a bigger backbone, performance would probably improve.
- Ensembling various models would also improve performance.
- The leaderboard can be found on the [CodaLab challenge page](https://competitions.codalab.org/competitions/23712#results).

## 4. Reference
- [vipriors-challenges-toolkit](https://github.com/VIPriors/vipriors-challenges-toolkit)
- [torchvision DeepLabV3 model](https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py)
- [Focal Loss](https://github.com/clcarwin/focal_loss_pytorch)
- [Class-Weighted CE Loss](https://github.com/openseg-group/OCNet.pytorch/blob/master/utils/loss.py)
- [EvoNorm](https://github.com/digantamisra98/EvoNorm)
- [CutMix Augmentation](https://github.com/clovaai/CutMix-PyTorch)
--------------------------------------------------------------------------------
/baseline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import warnings
5 | import random
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | 
10 | from option import get_args
11 | from learning.minicity import MiniCity
12 | from learning.learner import train_epoch, validate_epoch, predict
13 | from learning.utils import get_dataloader, get_lossfunc, get_model
14 | 
15 | from helpers.helpers import plot_learning_curves
16 | import torchvision.transforms.functional as TF
17 | 
18 | 
19 | def main():
20 |     args = get_args()
21 |     print("args : ", args)
22 | 
23 |     # Fix seed (use args.seed; the name random_seed was undefined here)
24 |     if args.seed is not None:
25 |         torch.manual_seed(args.seed)
26 |         torch.cuda.manual_seed(args.seed)
27 |         torch.cuda.manual_seed_all(args.seed)
28 |         torch.backends.cudnn.deterministic = True
29 |         torch.backends.cudnn.benchmark = False
30 |         np.random.seed(args.seed)
31 |         random.seed(args.seed)
32 |         warnings.warn('You have chosen to seed training. '
33 |                       'This will turn on the CUDNN deterministic setting, '
34 |                       'which can slow down your training considerably! '
35 |                       'You may see unexpected behavior when restarting from checkpoints.')
36 | 
37 |     assert args.crop_size[0] <= args.train_size[0] and args.crop_size[1] <= args.train_size[1], \
38 |         'Crop size must be <= image size.'
39 | 
40 |     # Create directory to store run files
41 |     if not os.path.isdir(args.save_path):
42 |         os.makedirs(args.save_path + '/images')
43 |     if not os.path.isdir(args.save_path + '/results_color_val'):
44 |         os.makedirs(args.save_path + '/results_color_val')
45 |         os.makedirs(args.save_path + '/results_color_test')
46 | 
47 |     Dataset = MiniCity
48 | 
49 |     dataloaders = get_dataloader(Dataset, args)
50 |     criterion = get_lossfunc(Dataset, args)
51 |     model = get_model(Dataset, args)
52 | 
53 |     print(model)
54 | 
55 |     optimizer = torch.optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.lr_momentum, weight_decay=args.lr_weight_decay)
56 |     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
57 | 
58 |     # Initialize metrics
59 |     best_miou = 0.0
60 |     metrics = {'train_loss' : [],
61 |                'train_acc' : [],
62 |                'val_acc' : [],
63 |                'val_loss' : [],
64 |                'miou' : []}
65 |     start_epoch = 0
66 | 
67 |     # Resume training from checkpoint
68 |     if args.weights:
69 |         print('Resuming training from {}.'.format(args.weights))
70 |         checkpoint = torch.load(args.weights)
71 |         model.load_state_dict(checkpoint['model_state_dict'], strict=True)
72 |         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
73 |         metrics = checkpoint['metrics']
74 |         best_miou = checkpoint['best_miou']
75 |         start_epoch = checkpoint['epoch']+1
76 | 
77 |     # Push model to GPU
78 |     if torch.cuda.is_available():
79 |         model = torch.nn.DataParallel(model).cuda()
80 |         print('Model pushed to {} GPU(s), type {}.'.format(torch.cuda.device_count(), torch.cuda.get_device_name(0)))
81 | 
82 |     # No training, only running prediction on test set
83 |     if args.predict:
84 |         checkpoint = torch.load(args.save_path + '/best_weights.pth.tar')
85 |         model.load_state_dict(checkpoint['model_state_dict'], strict=True)
86 |         print('Loaded model weights from {}'.format(args.save_path + '/best_weights.pth.tar'))
87 |         # Create results directory
88 |         if not os.path.isdir(args.save_path + '/results_val'):
89 |             os.makedirs(args.save_path + '/results_val')
90 |         if not os.path.isdir(args.save_path + '/results_test'):
91 |             os.makedirs(args.save_path + '/results_test')
92 | 
93 |         predict(dataloaders['test'], model, Dataset.mask_colors, folder=args.save_path, mode='test', args=args)
94 |         predict(dataloaders['val'], model, Dataset.mask_colors, folder=args.save_path, mode='val', args=args)
95 |         return
96 | 
97 |     # Generate log file
98 |     with open(args.save_path + '/log_epoch.csv', 'a') as epoch_log:
99 |         epoch_log.write('epoch, train loss, val loss, train acc, val acc, miou\n')
100 | 
101 |     since = time.time()
102 | 
103 |     for epoch in range(start_epoch, args.epochs):
104 |         # Train
105 |         print('--- Training ---')
106 |         train_loss, train_acc = train_epoch(dataloaders['train'], model, criterion, optimizer, scheduler, epoch, void=Dataset.voidClass, args=args)
107 |         metrics['train_loss'].append(train_loss)
108 |         metrics['train_acc'].append(train_acc)
109 |         print('Epoch {} train loss: {:.4f}, acc: {:.4f}'.format(epoch,train_loss,train_acc))
110 | 
111 |         # Validate
112 |         print('--- Validation ---')
113 |         val_acc, val_loss, miou = validate_epoch(dataloaders['val'], model, criterion, epoch,
114 |                                                  Dataset.classLabels, Dataset.validClasses, void=Dataset.voidClass,
115 |                                                  maskColors=Dataset.mask_colors, folder=args.save_path, args=args)
116 |         metrics['val_acc'].append(val_acc)
117 |         metrics['val_loss'].append(val_loss)
118 |         metrics['miou'].append(miou)
119 | 
120 |         # Write logs
121 |         with open(args.save_path + '/log_epoch.csv', 'a') as epoch_log:
122 |             epoch_log.write('{}, {:.5f}, {:.5f}, {:.5f}, {:.5f}, {:.5f}\n'.format(
123 |                 epoch, train_loss, val_loss, train_acc, val_acc, miou))
124 | 
125 |         # Save checkpoint
126 |         torch.save({
127 |             'epoch': epoch,
128 |             'model_state_dict': model.state_dict(),
129 |             'optimizer_state_dict': optimizer.state_dict(),
130 |             'best_miou': best_miou,
131 |             'metrics': metrics,
132 |             }, args.save_path + '/checkpoint.pth.tar')
133 | 
134 |         # Save best model to file
135 |         if miou > best_miou:
136 |             print('mIoU improved from {:.4f} to {:.4f}.'.format(best_miou, miou))
137 |             best_miou = miou
138 |             torch.save({
139 |                 'epoch': epoch,
140 |                 'model_state_dict': model.state_dict(),
141 |                 }, args.save_path + '/best_weights.pth.tar')
142 | 
143 |     time_elapsed = time.time() - since
144 |     print('Training complete in {:.0f}m {:.0f}s'.format(
145 |         time_elapsed // 60, time_elapsed % 60))
146 | 
147 |     plot_learning_curves(metrics, args)
148 | 
149 |     # Load best model
150 |     checkpoint = torch.load(args.save_path + '/best_weights.pth.tar')
151 |     model.load_state_dict(checkpoint['model_state_dict'], strict=True)
152 |     print('Loaded best model weights (epoch {}) from {}/best_weights.pth.tar'.format(checkpoint['epoch'], args.save_path))
153 | 
154 |     # Create results directory
155 |     if not os.path.isdir(args.save_path + '/results_val'):
156 |         os.makedirs(args.save_path + '/results_val')
157 | 
158 |     if not os.path.isdir(args.save_path + '/results_test'):
159 |         os.makedirs(args.save_path + '/results_test')
160 | 
161 |     # Run prediction on validation set. For predicting on the test set, simply replace 'val' by 'test'
162 |     predict(dataloaders['val'], model, Dataset.mask_colors, folder=args.save_path, mode='val', args=args)
163 | 
164 | 
165 | if __name__ == '__main__':
166 |     main()
167 | 
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Use this script to evaluate your model. It stores metrics in the file
3 | `results.txt` and the more detailed `results.json` in the current directory.
4 | 
5 | Based on the official Cityscapes evaluation script:
6 | https://github.com/mcordts/cityscapesScripts
7 | 
8 | - Assumes dataset and results to be in default location. Alternatively, specify
9 |   paths to predictions and minicity folder as optional arguments.
10 | - Assumes predictions to have the same file names as the inputs.
11 | 
12 | Usage:
13 |     evaluate.py --results <path/to/predictions> --minicity <path/to/dataset>
14 | """
15 | # python imports
16 | from __future__ import print_function, absolute_import, division
17 | from PIL import Image
18 | import os, sys
19 | import platform
20 | import fnmatch
21 | import math
22 | import numpy as np
23 | import glob
24 | import argparse
25 | 
26 | # Cityscapes imports
27 | from helpers.csHelpers import printError, getColorEntry, \
28 |     getCsFileInfo, colors, ensurePath, writeDict2JSON, writeDict2Txt
29 | from helpers.labels import labels, category2labels, id2label
30 | 
31 | parser = argparse.ArgumentParser(description='VIPriors Segmentation evaluation tool')
32 | parser.add_argument('--results', metavar='path/to/predictions', default='results',
33 |                     help='path to predictions')
34 | parser.add_argument('--minicity', metavar='path/to/dataset', default='minicity',
35 |                     help='path to dataset root (ends with /minicity)')
36 | 
37 | def getPrediction( args, groundTruthFile ):
38 |     # determine the prediction path, when the method is first called
39 |     if not pargs.results:
40 |         rootPath = os.path.join(os.path.dirname(os.path.realpath(__file__)),'results')
41 |         if not os.path.isdir(rootPath):
42 |             printError('Could not find a result root folder.')
43 |         pargs.results = rootPath
44 | 
45 |     # walk the prediction path, if that has not happened yet
46 |     if not args.predictionWalk:
47 |         walk = []
48 |         for root, dirnames, filenames in os.walk(pargs.results):
49 |             walk.append( (root,filenames) )
50 |         args.predictionWalk = walk
51 | 
52 |     csFile = getCsFileInfo(groundTruthFile)
53 |     filePattern = '{}_{}_{}*.png'.format( csFile.city , csFile.sequenceNb , csFile.frameNb )
54 | 
55 |     predictionFile = None
56 |     for root, filenames in args.predictionWalk:
57 |         for filename in fnmatch.filter(filenames, filePattern):
58 |             if not predictionFile:
59 |                 predictionFile = os.path.join(root, filename)
60 |             else:
61 |                 printError('Found multiple predictions for ground truth {}'.format(groundTruthFile))
62 | 
63 |     if not predictionFile:
64 |         printError('Found no prediction for ground truth {}'.format(groundTruthFile))
65 | 
66 |     return predictionFile
67 | 
68 | 
69 | ######################
70 | # Parameters
71 | ######################
72 | 
73 | 
74 | # A dummy class to collect a bunch of data
75 | class CArgs(object):
76 |     pass
77 | # And a global object of that class
78 | args = CArgs()
79 | 
80 | # Define directories
81 | args.exportFile = 'results.json'
82 | 
83 | # Remaining params
84 | args.evalInstLevelScore = True
85 | args.evalPixelAccuracy = True
86 | args.evalLabels = []
87 | args.printRow = 5
88 | args.normalized = True
89 | args.colorized = hasattr(sys.stderr, 'isatty') and sys.stderr.isatty() and platform.system()=='Linux'
90 | args.bold = colors.BOLD if args.colorized else ''
91 | args.nocol = colors.ENDC if args.colorized else ''
92 | args.JSONOutput = True
93 | args.quiet = False
94 | 
95 | args.avgClassSize = {
96 |     'bicycle'    :  4672.3249222261 ,
97 |     'caravan'    : 36771.8241758242 ,
98 |     'motorcycle' :  6298.7200839748 ,
99 |     'rider'      :  3930.4788056518 ,
100 |     'bus'        : 35732.1511111111 ,
101 |     'train'      : 67583.7075812274 ,
102 |     'car'        : 12794.0202738185 ,
103 |     'person'     :  3462.4756337644 ,
104 |     'truck'      : 27855.1264367816 ,
105 |     'trailer'    : 16926.9763313609 ,
106 | }
107 | 
108 | # value is filled when the method getPrediction is first called
109 | args.predictionWalk = None
110 | 
111 | 
112 | #########################
113 | # Methods
114 | #########################
115 | 
116 | 
117 | # Generate empty confusion matrix and create list of relevant
labels 118 | def generateMatrix(args): 119 | args.evalLabels = [] 120 | for label in labels: 121 | if (label.id < 0): 122 | continue 123 | # we append all found labels, regardless of being ignored 124 | args.evalLabels.append(label.id) 125 | maxId = max(args.evalLabels) 126 | # We use longlong type to be sure that there are no overflows 127 | return np.zeros(shape=(maxId+1, maxId+1),dtype=np.ulonglong) 128 | 129 | def generateInstanceStats(args): 130 | instanceStats = {} 131 | instanceStats['classes' ] = {} 132 | instanceStats['categories'] = {} 133 | for label in labels: 134 | if label.hasInstances and not label.ignoreInEval: 135 | instanceStats['classes'][label.name] = {} 136 | instanceStats['classes'][label.name]['tp'] = 0.0 137 | instanceStats['classes'][label.name]['tpWeighted'] = 0.0 138 | instanceStats['classes'][label.name]['fn'] = 0.0 139 | instanceStats['classes'][label.name]['fnWeighted'] = 0.0 140 | for category in category2labels: 141 | labelIds = [] 142 | allInstances = True 143 | for label in category2labels[category]: 144 | if label.id < 0: 145 | continue 146 | if not label.hasInstances: 147 | allInstances = False 148 | break 149 | labelIds.append(label.id) 150 | if not allInstances: 151 | continue 152 | 153 | instanceStats['categories'][category] = {} 154 | instanceStats['categories'][category]['tp'] = 0.0 155 | instanceStats['categories'][category]['tpWeighted'] = 0.0 156 | instanceStats['categories'][category]['fn'] = 0.0 157 | instanceStats['categories'][category]['fnWeighted'] = 0.0 158 | instanceStats['categories'][category]['labelIds'] = labelIds 159 | 160 | return instanceStats 161 | 162 | 163 | # Get absolute or normalized value from field in confusion matrix. 164 | def getMatrixFieldValue(confMatrix, i, j, args): 165 | if args.normalized: 166 | rowSum = confMatrix[i].sum() 167 | if (rowSum == 0): 168 | return float('nan') 169 | return float(confMatrix[i][j]) / rowSum 170 | else: 171 | return confMatrix[i][j] 172 | 173 | # Calculate and return IOU score for a particular label 174 | def getIouScoreForLabel(label, confMatrix, args): 175 | if id2label[label].ignoreInEval: 176 | return float('nan') 177 | 178 | # the number of true positive pixels for this label 179 | # the entry on the diagonal of the confusion matrix 180 | tp = np.longlong(confMatrix[label,label]) 181 | 182 | # the number of false negative pixels for this label 183 | # the row sum of the matching row in the confusion matrix 184 | # minus the diagonal entry 185 | fn = np.longlong(confMatrix[label,:].sum()) - tp 186 | 187 | # the number of false positive pixels for this labels 188 | # Only pixels that are not on a pixel with ground truth label that is ignored 189 | # The column sum of the corresponding column in the confusion matrix 190 | # without the ignored rows and without the actual label of interest 191 | notIgnored = [l for l in args.evalLabels if not id2label[l].ignoreInEval and not l==label] 192 | fp = np.longlong(confMatrix[notIgnored,label].sum()) 193 | 194 | # the denominator of the IOU score 195 | denom = (tp + fp + fn) 196 | if denom == 0: 197 | return float('nan') 198 | 199 | # return IOU 200 | return float(tp) / denom 201 | 202 | # Calculate and return IOU score for a particular label 203 | def getInstanceIouScoreForLabel(label, confMatrix, instStats, args): 204 | if id2label[label].ignoreInEval: 205 | return float('nan') 206 | 207 | labelName = id2label[label].name 208 | if not labelName in instStats['classes']: 209 | return float('nan') 210 | 211 | tp = 
instStats['classes'][labelName]['tpWeighted'] 212 | fn = instStats['classes'][labelName]['fnWeighted'] 213 | # false postives computed as above 214 | notIgnored = [l for l in args.evalLabels if not id2label[l].ignoreInEval and not l==label] 215 | fp = np.longlong(confMatrix[notIgnored,label].sum()) 216 | 217 | # the denominator of the IOU score 218 | denom = (tp + fp + fn) 219 | if denom == 0: 220 | return float('nan') 221 | 222 | # return IOU 223 | return float(tp) / denom 224 | 225 | # Calculate prior for a particular class id. 226 | def getPrior(label, confMatrix): 227 | return float(confMatrix[label,:].sum()) / confMatrix.sum() 228 | 229 | # Get average of scores. 230 | # Only computes the average over valid entries. 231 | def getScoreAverage(scoreList, args): 232 | validScores = 0 233 | scoreSum = 0.0 234 | for score in scoreList: 235 | if not math.isnan(scoreList[score]): 236 | validScores += 1 237 | scoreSum += scoreList[score] 238 | if validScores == 0: 239 | return float('nan') 240 | return scoreSum / validScores 241 | 242 | # Calculate and return IOU score for a particular category 243 | def getIouScoreForCategory(category, confMatrix, args): 244 | # All labels in this category 245 | labels = category2labels[category] 246 | # The IDs of all valid labels in this category 247 | labelIds = [label.id for label in labels if not label.ignoreInEval and label.id in args.evalLabels] 248 | # If there are no valid labels, then return NaN 249 | if not labelIds: 250 | return float('nan') 251 | 252 | # the number of true positive pixels for this category 253 | # this is the sum of all entries in the confusion matrix 254 | # where row and column belong to a label ID of this category 255 | tp = np.longlong(confMatrix[labelIds,:][:,labelIds].sum()) 256 | 257 | # the number of false negative pixels for this category 258 | # that is the sum of all rows of labels within this category 259 | # minus the number of true positive pixels 260 | fn = np.longlong(confMatrix[labelIds,:].sum()) - tp 261 | 262 | # the number of false positive pixels for this category 263 | # we count the column sum of all labels within this category 264 | # while skipping the rows of ignored labels and of labels within this category 265 | notIgnoredAndNotInCategory = [l for l in args.evalLabels if not id2label[l].ignoreInEval and id2label[l].category != category] 266 | fp = np.longlong(confMatrix[notIgnoredAndNotInCategory,:][:,labelIds].sum()) 267 | 268 | # the denominator of the IOU score 269 | denom = (tp + fp + fn) 270 | if denom == 0: 271 | return float('nan') 272 | 273 | # return IOU 274 | return float(tp) / denom 275 | 276 | # Calculate and return IOU score for a particular category 277 | def getInstanceIouScoreForCategory(category, confMatrix, instStats, args): 278 | if not category in instStats['categories']: 279 | return float('nan') 280 | labelIds = instStats['categories'][category]['labelIds'] 281 | 282 | tp = instStats['categories'][category]['tpWeighted'] 283 | fn = instStats['categories'][category]['fnWeighted'] 284 | 285 | # the number of false positive pixels for this category 286 | # same as above 287 | notIgnoredAndNotInCategory = [l for l in args.evalLabels if not id2label[l].ignoreInEval and id2label[l].category != category] 288 | fp = np.longlong(confMatrix[notIgnoredAndNotInCategory,:][:,labelIds].sum()) 289 | 290 | # the denominator of the IOU score 291 | denom = (tp + fp + fn) 292 | if denom == 0: 293 | return float('nan') 294 | 295 | # return IOU 296 | return float(tp) / denom 297 | 298 | 299 | # 
create a dictionary containing all relevant results 300 | def createResultDict( confMatrix, classScores, classInstScores, categoryScores, categoryInstScores, perImageStats, args ): 301 | # write JSON result file 302 | wholeData = {} 303 | wholeData['confMatrix'] = confMatrix.tolist() 304 | wholeData['priors'] = {} 305 | wholeData['labels'] = {} 306 | for label in args.evalLabels: 307 | wholeData['priors'][id2label[label].name] = getPrior(label, confMatrix) 308 | wholeData['labels'][id2label[label].name] = label 309 | wholeData['classScores'] = classScores 310 | wholeData['classInstScores'] = classInstScores 311 | wholeData['categoryScores'] = categoryScores 312 | wholeData['categoryInstScores'] = categoryInstScores 313 | wholeData['averageScoreClasses'] = getScoreAverage(classScores, args) 314 | wholeData['averageScoreInstClasses'] = getScoreAverage(classInstScores, args) 315 | wholeData['averageScoreCategories'] = getScoreAverage(categoryScores, args) 316 | wholeData['averageScoreInstCategories'] = getScoreAverage(categoryInstScores, args) 317 | wholeData['accuracy'] = np.trace(confMatrix) / np.sum(confMatrix) 318 | 319 | if perImageStats: 320 | wholeData['perImageScores'] = perImageStats 321 | 322 | return wholeData 323 | 324 | def writeJSONFile(wholeData, args): 325 | path = os.path.dirname(args.exportFile) 326 | ensurePath(path) 327 | writeDict2JSON(wholeData, args.exportFile) 328 | 329 | # Print confusion matrix 330 | def printConfMatrix(confMatrix, args): 331 | # print line 332 | print('\b{text:{fill}>{width}}'.format(width=15, fill='-', text=' '), end=' ') 333 | for label in args.evalLabels: 334 | print('\b{text:{fill}>{width}}'.format(width=args.printRow + 2, fill='-', text=' '), end=' ') 335 | print('\b{text:{fill}>{width}}'.format(width=args.printRow + 3, fill='-', text=' ')) 336 | 337 | # print label names 338 | print('\b{text:>{width}} |'.format(width=13, text=''), end=' ') 339 | for label in args.evalLabels: 340 | print('\b{text:^{width}} |'.format(width=args.printRow, text=id2label[label].name[0]), end=' ') 341 | print('\b{text:>{width}} |'.format(width=6, text='Prior')) 342 | 343 | # print line 344 | print('\b{text:{fill}>{width}}'.format(width=15, fill='-', text=' '), end=' ') 345 | for label in args.evalLabels: 346 | print('\b{text:{fill}>{width}}'.format(width=args.printRow + 2, fill='-', text=' '), end=' ') 347 | print('\b{text:{fill}>{width}}'.format(width=args.printRow + 3, fill='-', text=' ')) 348 | 349 | # print matrix 350 | for x in range(0, confMatrix.shape[0]): 351 | if (not x in args.evalLabels): 352 | continue 353 | # get prior of this label 354 | prior = getPrior(x, confMatrix) 355 | # skip if label does not exist in ground truth 356 | if prior < 1e-9: 357 | continue 358 | 359 | # print name 360 | name = id2label[x].name 361 | if len(name) > 13: 362 | name = name[:13] 363 | print('\b{text:>{width}} |'.format(width=13,text=name), end=' ') 364 | # print matrix content 365 | for y in range(0, len(confMatrix[x])): 366 | if (not y in args.evalLabels): 367 | continue 368 | matrixFieldValue = getMatrixFieldValue(confMatrix, x, y, args) 369 | print(getColorEntry(matrixFieldValue, args) + '\b{text:>{width}.2f} '.format(width=args.printRow, text=matrixFieldValue) + args.nocol, end=' ') 370 | # print prior 371 | print(getColorEntry(prior, args) + '\b{text:>{width}.4f} '.format(width=6, text=prior) + args.nocol) 372 | # print line 373 | print('\b{text:{fill}>{width}}'.format(width=15, fill='-', text=' '), end=' ') 374 | for label in args.evalLabels: 375 | 
print('\b{text:{fill}>{width}}'.format(width=args.printRow + 2, fill='-', text=' '), end=' ') 376 | print('\b{text:{fill}>{width}}'.format(width=args.printRow + 3, fill='-', text=' '), end=' ') 377 | 378 | # Print intersection-over-union scores for all classes. 379 | def printClassScores(scoreList, instScoreList, args): 380 | if (args.quiet): 381 | return 382 | print(args.bold + 'classes IoU nIoU' + args.nocol) 383 | print('--------------------------------') 384 | for label in args.evalLabels: 385 | if (id2label[label].ignoreInEval): 386 | continue 387 | labelName = str(id2label[label].name) 388 | iouStr = getColorEntry(scoreList[labelName], args) + '{val:>5.3f}'.format(val=scoreList[labelName]) + args.nocol 389 | niouStr = getColorEntry(instScoreList[labelName], args) + '{val:>5.3f}'.format(val=instScoreList[labelName]) + args.nocol 390 | print('{:<14}: '.format(labelName) + iouStr + ' ' + niouStr) 391 | 392 | # Print intersection-over-union scores for all categorys. 393 | def printCategoryScores(scoreDict, instScoreDict, args): 394 | if (args.quiet): 395 | return 396 | print(args.bold + 'categories IoU nIoU' + args.nocol) 397 | print('--------------------------------') 398 | for categoryName in scoreDict: 399 | if all( label.ignoreInEval for label in category2labels[categoryName] ): 400 | continue 401 | iouStr = getColorEntry(scoreDict[categoryName], args) + '{val:>5.3f}'.format(val=scoreDict[categoryName]) + args.nocol 402 | niouStr = getColorEntry(instScoreDict[categoryName], args) + '{val:>5.3f}'.format(val=instScoreDict[categoryName]) + args.nocol 403 | print('{:<14}: '.format(categoryName) + iouStr + ' ' + niouStr) 404 | 405 | # Evaluate image lists pairwise. 406 | def evaluateImgLists(predictionImgList, groundTruthImgList, args): 407 | if len(predictionImgList) != len(groundTruthImgList): 408 | printError('List of images for prediction and groundtruth are not of equal size.') 409 | confMatrix = generateMatrix(args) 410 | instStats = generateInstanceStats(args) 411 | perImageStats = {} 412 | nbPixels = 0 413 | 414 | if not args.quiet: 415 | print('Evaluating {} pairs of images...'.format(len(predictionImgList))) 416 | 417 | # Evaluate all pairs of images and save them into a matrix 418 | for i in range(len(predictionImgList)): 419 | predictionImgFileName = predictionImgList[i] 420 | groundTruthImgFileName = groundTruthImgList[i] 421 | #print 'Evaluate ', predictionImgFileName, '<>', groundTruthImgFileName 422 | nbPixels += evaluatePair(predictionImgFileName, groundTruthImgFileName, confMatrix, instStats, perImageStats, args) 423 | 424 | # sanity check 425 | if confMatrix.sum() != nbPixels: 426 | printError('Number of analyzed pixels and entries in confusion matrix disagree: contMatrix {}, pixels {}'.format(confMatrix.sum(),nbPixels)) 427 | 428 | if not args.quiet: 429 | print('\rImages Processed: {}'.format(i+1), end=' ') 430 | sys.stdout.flush() 431 | if not args.quiet: 432 | print('\n') 433 | 434 | # sanity check 435 | if confMatrix.sum() != nbPixels: 436 | printError('Number of analyzed pixels and entries in confusion matrix disagree: contMatrix {}, pixels {}'.format(confMatrix.sum(),nbPixels)) 437 | 438 | # print confusion matrix 439 | if (not args.quiet): 440 | printConfMatrix(confMatrix, args) 441 | 442 | # Calculate IOU scores on class level from matrix 443 | classScoreList = {} 444 | for label in args.evalLabels: 445 | labelName = id2label[label].name 446 | classScoreList[labelName] = getIouScoreForLabel(label, confMatrix, args) 447 | 448 | # Calculate instance IOU scores on 
class level from matrix 449 | classInstScoreList = {} 450 | for label in args.evalLabels: 451 | labelName = id2label[label].name 452 | classInstScoreList[labelName] = getInstanceIouScoreForLabel(label, confMatrix, instStats, args) 453 | 454 | # Print IOU scores 455 | if (not args.quiet): 456 | print('') 457 | print('') 458 | printClassScores(classScoreList, classInstScoreList, args) 459 | iouAvgStr = getColorEntry(getScoreAverage(classScoreList, args), args) + '{avg:5.3f}'.format(avg=getScoreAverage(classScoreList, args)) + args.nocol 460 | niouAvgStr = getColorEntry(getScoreAverage(classInstScoreList , args), args) + '{avg:5.3f}'.format(avg=getScoreAverage(classInstScoreList , args)) + args.nocol 461 | print('--------------------------------') 462 | print('Score Average : ' + iouAvgStr + ' ' + niouAvgStr) 463 | print('--------------------------------') 464 | print('') 465 | 466 | # Calculate IOU scores on category level from matrix 467 | categoryScoreList = {} 468 | for category in category2labels.keys(): 469 | categoryScoreList[category] = getIouScoreForCategory(category,confMatrix,args) 470 | 471 | # Calculate instance IOU scores on category level from matrix 472 | categoryInstScoreList = {} 473 | for category in category2labels.keys(): 474 | categoryInstScoreList[category] = getInstanceIouScoreForCategory(category,confMatrix,instStats,args) 475 | 476 | # Print IOU scores 477 | if (not args.quiet): 478 | print('') 479 | printCategoryScores(categoryScoreList, categoryInstScoreList, args) 480 | iouAvgStr = getColorEntry(getScoreAverage(categoryScoreList, args), args) + '{avg:5.3f}'.format(avg=getScoreAverage(categoryScoreList, args)) + args.nocol 481 | niouAvgStr = getColorEntry(getScoreAverage(categoryInstScoreList, args), args) + '{avg:5.3f}'.format(avg=getScoreAverage(categoryInstScoreList, args)) + args.nocol 482 | print('--------------------------------') 483 | print('Score Average : ' + iouAvgStr + ' ' + niouAvgStr) 484 | print('--------------------------------') 485 | print('') 486 | 487 | allResultsDict = createResultDict( confMatrix, classScoreList, classInstScoreList, categoryScoreList, categoryInstScoreList, perImageStats, args ) 488 | # write result file 489 | if args.JSONOutput: 490 | writeJSONFile( allResultsDict, args) 491 | 492 | writeDict2Txt(allResultsDict, 'results.txt') 493 | 494 | # return confusion matrix 495 | return allResultsDict 496 | 497 | # Main evaluation method. Evaluates pairs of prediction and ground truth 498 | # images which are passed as arguments. 499 | def evaluatePair(predictionImgFileName, groundTruthImgFileName, confMatrix, instanceStats, perImageStats, args): 500 | # Loading all resources for evaluation. 
501 | try: 502 | predictionImg = Image.open(predictionImgFileName) 503 | predictionNp = np.array(predictionImg) 504 | except: 505 | printError('Unable to load ' + predictionImgFileName) 506 | try: 507 | groundTruthImg = Image.open(groundTruthImgFileName) 508 | groundTruthNp = np.array(groundTruthImg) 509 | except: 510 | printError('Unable to load ' + groundTruthImgFileName) 511 | # load ground truth instances, if needed 512 | if args.evalInstLevelScore: 513 | groundTruthInstanceImgFileName = groundTruthImgFileName.replace('labelIds','instanceIds') 514 | try: 515 | instanceImg = Image.open(groundTruthInstanceImgFileName) 516 | instanceNp = np.array(instanceImg) 517 | except: 518 | printError('Unable to load ' + groundTruthInstanceImgFileName) 519 | 520 | # Check for equal image sizes 521 | if (predictionImg.size[0] != groundTruthImg.size[0]): 522 | printError('Image widths of ' + predictionImgFileName + ' and ' + groundTruthImgFileName + ' are not equal.') 523 | if (predictionImg.size[1] != groundTruthImg.size[1]): 524 | printError('Image heights of ' + predictionImgFileName + ' and ' + groundTruthImgFileName + ' are not equal.') 525 | if ( len(predictionNp.shape) != 2 ): 526 | printError('Predicted image has multiple channels.') 527 | 528 | imgWidth = predictionImg.size[0] 529 | imgHeight = predictionImg.size[1] 530 | nbPixels = imgWidth*imgHeight 531 | 532 | # Evaluate images 533 | encoding_value = max(groundTruthNp.max(), predictionNp.max()).astype(np.int32) + 1 534 | encoded = (groundTruthNp.astype(np.int32) * encoding_value) + predictionNp 535 | 536 | values, cnt = np.unique(encoded, return_counts=True) 537 | 538 | for value, c in zip(values, cnt): 539 | pred_id = value % encoding_value 540 | gt_id = int((value - pred_id)/encoding_value) 541 | if not gt_id in args.evalLabels: 542 | printError('Unknown label with id {:}'.format(gt_id)) 543 | confMatrix[gt_id][pred_id] += c 544 | 545 | 546 | if args.evalInstLevelScore: 547 | # Generate category masks 548 | categoryMasks = {} 549 | for category in instanceStats['categories']: 550 | categoryMasks[category] = np.in1d( predictionNp , instanceStats['categories'][category]['labelIds'] ).reshape(predictionNp.shape) 551 | 552 | instList = np.unique(instanceNp[instanceNp > 1000]) 553 | for instId in instList: 554 | labelId = int(instId/1000) 555 | label = id2label[ labelId ] 556 | if label.ignoreInEval: 557 | continue 558 | 559 | mask = instanceNp==instId 560 | instSize = np.count_nonzero( mask ) 561 | 562 | tp = np.count_nonzero( predictionNp[mask] == labelId ) 563 | fn = instSize - tp 564 | 565 | weight = args.avgClassSize[label.name] / float(instSize) 566 | tpWeighted = float(tp) * weight 567 | fnWeighted = float(fn) * weight 568 | 569 | instanceStats['classes'][label.name]['tp'] += tp 570 | instanceStats['classes'][label.name]['fn'] += fn 571 | instanceStats['classes'][label.name]['tpWeighted'] += tpWeighted 572 | instanceStats['classes'][label.name]['fnWeighted'] += fnWeighted 573 | 574 | category = label.category 575 | if category in instanceStats['categories']: 576 | catTp = 0 577 | catTp = np.count_nonzero( np.logical_and( mask , categoryMasks[category] ) ) 578 | catFn = instSize - catTp 579 | 580 | catTpWeighted = float(catTp) * weight 581 | catFnWeighted = float(catFn) * weight 582 | 583 | instanceStats['categories'][category]['tp'] += catTp 584 | instanceStats['categories'][category]['fn'] += catFn 585 | instanceStats['categories'][category]['tpWeighted'] += catTpWeighted 586 | instanceStats['categories'][category]['fnWeighted'] += 
catFnWeighted
587 | 
588 |     if args.evalPixelAccuracy:
589 |         notIgnoredLabels = [l for l in args.evalLabels if not id2label[l].ignoreInEval]
590 |         notIgnoredPixels = np.in1d( groundTruthNp , notIgnoredLabels , invert=True ).reshape(groundTruthNp.shape)
591 |         erroneousPixels = np.logical_and( notIgnoredPixels , ( predictionNp != groundTruthNp ) )
592 |         perImageStats[predictionImgFileName] = {}
593 |         perImageStats[predictionImgFileName]['nbNotIgnoredPixels'] = np.count_nonzero(notIgnoredPixels)
594 |         perImageStats[predictionImgFileName]['nbErroneousPixels'] = np.count_nonzero(erroneousPixels)
595 | 
596 |     return nbPixels
597 | 
598 | # The main method
599 | def main():
600 |     global args
601 |     global pargs
602 | 
603 |     # Parse optional arguments
604 |     pargs = parser.parse_args()
605 |     # Parameters that should be modified by user
606 |     args.groundTruthSearch = os.path.join(pargs.minicity , 'gtFine' , 'val' , '*_gtFine_labelIds.png')
607 | 
608 |     predictionImgList = []
609 |     groundTruthImgList = []
610 | 
611 |     # use the ground truth search string specified above
612 |     groundTruthImgList = glob.glob(args.groundTruthSearch)
613 |     if not groundTruthImgList:
614 |         printError('Cannot find any ground truth images to use for evaluation. Searched for: {}'.format(args.groundTruthSearch))
615 |     # get the corresponding prediction for each ground truth image
616 |     for gt in groundTruthImgList:
617 |         predictionImgList.append( getPrediction(args,gt) )
618 | 
619 |     # evaluate
620 |     evaluateImgLists(predictionImgList, groundTruthImgList, args)
621 | 
622 |     return
623 | 
624 | # call the main method
625 | if __name__ == '__main__':
626 |     main()
627 | 
--------------------------------------------------------------------------------
/helpers/csHelpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Various helper methods and includes for Cityscapes
4 | #
5 | 
6 | # Python imports
7 | from __future__ import print_function, absolute_import, division
8 | import os, sys
9 | import math
10 | import json
11 | from collections import namedtuple
12 | 
13 | # Print an error message and quit
14 | def printError(message):
15 |     print('ERROR: ' + str(message))
16 |     sys.exit(-1)
17 | 
18 | # Class for colors
19 | class colors:
20 |     RED = '\033[31;1m'
21 |     GREEN = '\033[32;1m'
22 |     YELLOW = '\033[33;1m'
23 |     BLUE = '\033[34;1m'
24 |     MAGENTA = '\033[35;1m'
25 |     CYAN = '\033[36;1m'
26 |     BOLD = '\033[1m'
27 |     UNDERLINE = '\033[4m'
28 |     ENDC = '\033[0m'
29 | 
30 | # Colored value output if colorized flag is activated.
31 | def getColorEntry(val, args):
32 |     if not args.colorized:
33 |         return ""
34 |     if not isinstance(val, float) or math.isnan(val):
35 |         return colors.ENDC
36 |     if (val < .20):
37 |         return colors.RED
38 |     elif (val < .40):
39 |         return colors.YELLOW
40 |     elif (val < .60):
41 |         return colors.BLUE
42 |     elif (val < .80):
43 |         return colors.CYAN
44 |     else:
45 |         return colors.GREEN
46 | 
47 | # Cityscapes files have a typical filename structure
48 | # <city>_<sequenceNb>_<frameNb>_<type>[_<type2>].<ext>
49 | # This class contains the individual elements as members 50 | # For the sequence and frame number, the strings are returned, including leading zeros 51 | CsFile = namedtuple( 'csFile' , [ 'city' , 'sequenceNb' , 'frameNb' , 'type' , 'type2' , 'ext' ] ) 52 | 53 | # Returns a CsFile object filled from the info in the given filename 54 | def getCsFileInfo(fileName): 55 | baseName = os.path.basename(fileName) 56 | parts = baseName.split('_') 57 | parts = parts[:-1] + parts[-1].split('.') 58 | if not parts: 59 | printError( 'Cannot parse given filename ({}). Does not seem to be a valid Cityscapes file.'.format(fileName) ) 60 | if len(parts) == 5: 61 | csFile = CsFile( *parts[:-1] , type2="" , ext=parts[-1] ) 62 | elif len(parts) == 6: 63 | csFile = CsFile( *parts ) 64 | else: 65 | printError( 'Found {} part(s) in given filename ({}). Expected 5 or 6.'.format(len(parts) , fileName) ) 66 | 67 | return csFile 68 | 69 | # Returns the part of Cityscapes filenames that is common to all data types 70 | # e.g. for city_123456_123456_gtFine_polygons.json returns city_123456_123456 71 | def getCoreImageFileName(filename): 72 | csFile = getCsFileInfo(filename) 73 | return "{}_{}_{}".format( csFile.city , csFile.sequenceNb , csFile.frameNb ) 74 | 75 | # Returns the directory name for the given filename, e.g. 76 | # fileName = "/foo/bar/foobar.txt" 77 | # return value is "bar" 78 | # Not much error checking though 79 | def getDirectory(fileName): 80 | dirName = os.path.dirname(fileName) 81 | return os.path.basename(dirName) 82 | 83 | # Make sure that the given path exists 84 | def ensurePath(path): 85 | if not path: 86 | return 87 | if not os.path.isdir(path): 88 | os.makedirs(path) 89 | 90 | # Write a dictionary as json file 91 | def writeDict2JSON(dictName, fileName): 92 | with open(fileName, 'w') as f: 93 | f.write(json.dumps(dictName, default=lambda o: o.__dict__, sort_keys=True, indent=4)) 94 | 95 | # Write a dictionary as json file 96 | def writeDict2Txt(dictName, fileName): 97 | with open(fileName, 'w') as f: 98 | f.write('IoU Class: '+str(dictName['averageScoreClasses'])+'\n') 99 | f.write('iIoU Class: '+str(dictName['averageScoreInstClasses'])+'\n') 100 | f.write('IoU Category: '+str(dictName['averageScoreCategories'])+'\n') 101 | f.write('iIoU Category: '+str(dictName['averageScoreInstCategories'])+'\n') 102 | f.write('Accuracy: '+str(dictName['accuracy'])) 103 | 104 | # dummy main 105 | if __name__ == "__main__": 106 | printError("Only for include, not executable on its own.") 107 | -------------------------------------------------------------------------------- /helpers/helpers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | """ 7 | ==================== 8 | Classes for logging and progress printing. 
9 | ==================== 10 | """ 11 | class AverageMeter(object): 12 | """Computes and stores the average and current value""" 13 | def __init__(self, name, fmt=':f'): 14 | self.name = name 15 | self.fmt = fmt 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | def __str__(self): 31 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 32 | return fmtstr.format(**self.__dict__) 33 | 34 | 35 | class ProgressMeter(object): 36 | def __init__(self, num_batches, meters, prefix=""): 37 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 38 | self.meters = meters 39 | self.prefix = prefix 40 | 41 | def display(self, batch): 42 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 43 | entries += [str(meter) for meter in self.meters] 44 | print('\t'.join(entries)) 45 | 46 | def _get_batch_fmtstr(self, num_batches): 47 | num_digits = len(str(num_batches // 1)) 48 | fmt = '{:' + str(num_digits) + 'd}' 49 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 50 | 51 | """ 52 | ==================== 53 | iouCalc: Calculates IoU scores of dataset per individual batch. 54 | Arguments: 55 | - labelNames: list of class names 56 | - validClasses: list of valid class ids 57 | - voidClass: class id of void class (class ignored for calculating metrics) 58 | ==================== 59 | """ 60 | 61 | class iouCalc(): 62 | 63 | def __init__(self, classLabels, validClasses, voidClass = None): 64 | assert len(classLabels) == len(validClasses), 'Number of class ids and names must be equal' 65 | self.classLabels = classLabels 66 | self.validClasses = validClasses 67 | self.voidClass = voidClass 68 | self.evalClasses = [l for l in validClasses if l != voidClass] 69 | 70 | self.perImageStats = [] 71 | self.nbPixels = 0 72 | self.confMatrix = np.zeros(shape=(len(self.validClasses),len(self.validClasses)),dtype=np.ulonglong) 73 | 74 | # Init IoU log files 75 | self.headerStr = 'epoch, ' 76 | for label in self.classLabels: 77 | if label.lower() != 'void': 78 | self.headerStr += label + ', ' 79 | 80 | def clear(self): 81 | self.perImageStats = [] 82 | self.nbPixels = 0 83 | self.confMatrix = np.zeros(shape=(len(self.validClasses),len(self.validClasses)),dtype=np.ulonglong) 84 | 85 | def getIouScoreForLabel(self, label): 86 | # Calculate and return IOU score for a particular label (train_id) 87 | if label == self.voidClass: 88 | return float('nan') 89 | 90 | # the number of true positive pixels for this label 91 | # the entry on the diagonal of the confusion matrix 92 | tp = np.longlong(self.confMatrix[label,label]) 93 | 94 | # the number of false negative pixels for this label 95 | # the row sum of the matching row in the confusion matrix 96 | # minus the diagonal entry 97 | fn = np.longlong(self.confMatrix[label,:].sum()) - tp 98 | 99 | # the number of false positive pixels for this labels 100 | # Only pixels that are not on a pixel with ground truth label that is ignored 101 | # The column sum of the corresponding column in the confusion matrix 102 | # without the ignored rows and without the actual label of interest 103 | notIgnored = [l for l in self.validClasses if not l == self.voidClass and not l==label] 104 | fp = np.longlong(self.confMatrix[notIgnored,label].sum()) 105 | 106 | # the denominator of the IOU score 107 | denom = (tp + fp + fn) 108 | if denom == 0: 109 | return 
float('nan') 110 | 111 | # return IOU 112 | return float(tp) / denom 113 | 114 | def evaluateBatch(self, predictionBatch, groundTruthBatch): 115 | # Calculate IoU scores for single batch 116 | assert predictionBatch.size(0) == groundTruthBatch.size(0), 'Number of predictions and labels in batch disagree.' 117 | 118 | # Load batch to CPU and convert to numpy arrays 119 | predictionBatch = predictionBatch.cpu().numpy() 120 | groundTruthBatch = groundTruthBatch.cpu().numpy() 121 | 122 | for i in range(predictionBatch.shape[0]): 123 | predictionImg = predictionBatch[i,:,:] 124 | groundTruthImg = groundTruthBatch[i,:,:] 125 | 126 | # Check for equal image sizes 127 | assert predictionImg.shape == groundTruthImg.shape, 'Image shapes do not match.' 128 | assert len(predictionImg.shape) == 2, 'Predicted image has multiple channels.' 129 | 130 | imgWidth = predictionImg.shape[0] 131 | imgHeight = predictionImg.shape[1] 132 | nbPixels = imgWidth*imgHeight 133 | 134 | # Evaluate images 135 | encoding_value = max(groundTruthImg.max(), predictionImg.max()).astype(np.int32) + 1 136 | encoded = (groundTruthImg.astype(np.int32) * encoding_value) + predictionImg 137 | 138 | values, cnt = np.unique(encoded, return_counts=True) 139 | 140 | for value, c in zip(values, cnt): 141 | pred_id = value % encoding_value 142 | gt_id = int((value - pred_id)/encoding_value) 143 | if not gt_id in self.validClasses: 144 | printError('Unknown label with id {:}'.format(gt_id)) 145 | self.confMatrix[gt_id][pred_id] += c 146 | 147 | # Calculate pixel accuracy 148 | notIgnoredPixels = np.in1d(groundTruthImg, self.evalClasses, invert=True).reshape(groundTruthImg.shape) 149 | erroneousPixels = np.logical_and(notIgnoredPixels, (predictionImg != groundTruthImg)) 150 | nbNotIgnoredPixels = np.count_nonzero(notIgnoredPixels) 151 | nbErroneousPixels = np.count_nonzero(erroneousPixels) 152 | self.perImageStats.append([nbNotIgnoredPixels, nbErroneousPixels]) 153 | 154 | self.nbPixels += nbPixels 155 | 156 | return 157 | 158 | def outputScores(self): 159 | # Output scores over dataset 160 | assert self.confMatrix.sum() == self.nbPixels, 'Number of analyzed pixels and entries in confusion matrix disagree: confMatrix {}, pixels {}'.format(self.confMatrix.sum(),self.nbPixels) 161 | 162 | # Calculate IOU scores on class level from matrix 163 | classScoreList = [] 164 | 165 | # Print class IOU scores 166 | outStr = 'classes IoU\n' 167 | outStr += '---------------------\n' 168 | for c in self.evalClasses: 169 | iouScore = self.getIouScoreForLabel(c) 170 | classScoreList.append(iouScore) 171 | outStr += '{:<14}: {:>5.3f}\n'.format(self.classLabels[c], iouScore) 172 | miou = getScoreAverage(classScoreList) 173 | outStr += '---------------------\n' 174 | outStr += 'Mean IoU : {avg:5.3f}\n'.format(avg=miou) 175 | outStr += '---------------------' 176 | 177 | print(outStr) 178 | 179 | return miou 180 | 181 | # Print an error message and quit 182 | def printError(message): 183 | print('ERROR: ' + str(message)) 184 | sys.exit(-1) 185 | 186 | def getScoreAverage(scoreList): 187 | validScores = 0 188 | scoreSum = 0.0 189 | for score in scoreList: 190 | if not np.isnan(score): 191 | validScores += 1 192 | scoreSum += score 193 | if validScores == 0: 194 | return float('nan') 195 | return scoreSum / validScores 196 | 197 | 198 | """ 199 | ================ 200 | Visualize images 201 | ================ 202 | """ 203 | 204 | def visim(img, args): 205 | img = img.cpu() 206 | # Convert image data to visual representation 207 | img *= 
torch.tensor(args.dataset_std)[:,None,None] 208 | img += torch.tensor(args.dataset_mean)[:,None,None] 209 | npimg = (img.numpy()*255).astype('uint8') 210 | if len(npimg.shape) == 3 and npimg.shape[0] == 3: 211 | npimg = np.transpose(npimg, (1, 2, 0)) 212 | else: 213 | npimg = npimg[0,:,:] 214 | return npimg 215 | 216 | def vislbl(label, mask_colors): 217 | label = label.cpu() 218 | # Convert label data to visual representation 219 | label = np.array(label.numpy()) 220 | if label.shape[-1] == 1: 221 | label = label[:,:,0] 222 | 223 | # Convert train_ids to colors 224 | label = mask_colors[label] 225 | return label 226 | 227 | """ 228 | ==================== 229 | Plot learning curves 230 | ==================== 231 | """ 232 | 233 | def plot_learning_curves(metrics, args): 234 | x = np.arange(args.epochs) 235 | fig, ax1 = plt.subplots() 236 | ax1.set_xlabel('epochs') 237 | ax1.set_ylabel('loss') 238 | ln1 = ax1.plot(x, metrics['train_loss'], color='tab:red') 239 | ln2 = ax1.plot(x, metrics['val_loss'], color='tab:red', linestyle='dashed') 240 | ax1.grid() 241 | ax2 = ax1.twinx() 242 | ax2.set_ylabel('accuracy') 243 | ln3 = ax2.plot(x, metrics['train_acc'], color='tab:blue') 244 | ln4 = ax2.plot(x, metrics['val_acc'], color='tab:blue', linestyle='dashed') 245 | ln5 = ax2.plot(x, metrics['miou'], color='tab:green') 246 | lns = ln1+ln2+ln3+ln4+ln5 247 | plt.legend(lns, ['Train loss','Validation loss','Train accuracy','Validation accuracy','mIoU']) 248 | plt.tight_layout() 249 | plt.savefig(args.save_path + '/learning_curve.png', bbox_inches='tight') 250 | 251 | -------------------------------------------------------------------------------- /helpers/labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Cityscapes labels 4 | # 5 | 6 | from __future__ import print_function, absolute_import, division 7 | from collections import namedtuple 8 | 9 | 10 | #-------------------------------------------------------------------------------- 11 | # Definitions 12 | #-------------------------------------------------------------------------------- 13 | 14 | # a label and all meta information 15 | Label = namedtuple( 'Label' , [ 16 | 17 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 18 | # We use them to uniquely name a class 19 | 20 | 'id' , # An integer ID that is associated with this label. 21 | # The IDs are used to represent the label in ground truth images 22 | # An ID of -1 means that this label does not have an ID and thus 23 | # is ignored when creating ground truth images (e.g. license plate). 24 | # Do not modify these IDs, since exactly these IDs are expected by the 25 | # evaluation server. 26 | 27 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 28 | # ground truth images with train IDs, using the tools provided in the 29 | # 'preparation' folder. However, make sure to validate or submit results 30 | # to our evaluation server using the regular IDs above! 31 | # For trainIds, multiple labels might have the same ID. Then, these labels 32 | # are mapped to the same class in the ground truth images. For the inverse 33 | # mapping, we use the label that is defined first in the list below. 34 | # For example, mapping all void-type classes to the same ID in training, 35 | # might make sense for some approaches. 36 | # Max value is 255! 37 | 38 | 'category' , # The name of the category that this label belongs to 39 | 40 | 'categoryId' , # The ID of this category. 
Used to create ground truth images 41 | # on category level. 42 | 43 | 'hasInstances', # Whether this label distinguishes between single instances or not 44 | 45 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 46 | # during evaluations or not 47 | 48 | 'color' , # The color of this label 49 | ] ) 50 | 51 | 52 | #-------------------------------------------------------------------------------- 53 | # A list of all labels 54 | #-------------------------------------------------------------------------------- 55 | 56 | # Please adapt the train IDs as appropriate for your approach. 57 | # Note that you might want to ignore labels with ID 255 during training. 58 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 59 | # Make sure to provide your results using the original IDs and not the training IDs. 60 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 61 | 62 | labels = [ 63 | # name id trainId category catId hasInstances ignoreInEval color 64 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 65 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 66 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 67 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 68 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 69 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 70 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 71 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 72 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 73 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 74 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 75 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 76 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 77 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 78 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 79 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 80 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 81 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 82 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 83 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 84 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 85 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 86 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 87 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 88 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 89 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 90 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 91 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 92 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 93 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 94 | Label( 'trailer' , 30 , 255 , 'vehicle' 
, 7 , True , True , ( 0, 0,110) ), 95 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 96 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 97 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 98 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 99 | ] 100 | 101 | 102 | #-------------------------------------------------------------------------------- 103 | # Create dictionaries for a fast lookup 104 | #-------------------------------------------------------------------------------- 105 | 106 | # Please refer to the main method below for example usages! 107 | 108 | # name to label object 109 | name2label = { label.name : label for label in labels } 110 | # id to label object 111 | id2label = { label.id : label for label in labels } 112 | # trainId to label object 113 | trainId2label = { label.trainId : label for label in reversed(labels) } 114 | # category to list of label objects 115 | category2labels = {} 116 | for label in labels: 117 | category = label.category 118 | if category in category2labels: 119 | category2labels[category].append(label) 120 | else: 121 | category2labels[category] = [label] 122 | 123 | #-------------------------------------------------------------------------------- 124 | # Assure single instance name 125 | #-------------------------------------------------------------------------------- 126 | 127 | # returns the label name that describes a single instance (if possible) 128 | # e.g. input | output 129 | # ---------------------- 130 | # car | car 131 | # cargroup | car 132 | # foo | None 133 | # foogroup | None 134 | # skygroup | None 135 | def assureSingleInstanceName( name ): 136 | # if the name is known, it is not a group 137 | if name in name2label: 138 | return name 139 | # test if the name actually denotes a group 140 | if not name.endswith("group"): 141 | return None 142 | # remove group 143 | name = name[:-len("group")] 144 | # test if the new name exists 145 | if not name in name2label: 146 | return None 147 | # test if the new name denotes a label that actually has instances 148 | if not name2label[name].hasInstances: 149 | return None 150 | # all good then 151 | return name 152 | 153 | #-------------------------------------------------------------------------------- 154 | # Main for testing 155 | #-------------------------------------------------------------------------------- 156 | 157 | # just a dummy main 158 | if __name__ == "__main__": 159 | # Print all the labels 160 | print("List of cityscapes labels:") 161 | print("") 162 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) 163 | print(" " + ('-' * 98)) 164 | for label in labels: 165 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) 166 | print("") 167 | 168 | print("Example usages:") 169 | 170 | # Map from name to label 171 | name = 'car' 172 | id = name2label[name].id 173 | print("ID of label '{name}': {id}".format( name=name, id=id )) 174 | 175 | # Map from ID to label 176 | category = id2label[id].category 177 | print("Category of label with ID '{id}': {category}".format( id=id, category=category )) 178 | 179 | # Map from trainID to label 180 | trainId = 0 181 | name = trainId2label[trainId].name 182 | print("Name of label with 
trainID '{id}': {name}".format( id=trainId, name=name )) 183 | -------------------------------------------------------------------------------- /learning/learner.py: -------------------------------------------------------------------------------- 1 | from helpers.helpers import AverageMeter, ProgressMeter, iouCalc, visim, vislbl 2 | from learning.minicity import MiniCity 3 | from learning.utils import rand_bbox, copyblob 4 | import torch 5 | import torch.nn.functional as F 6 | import cv2 7 | import os 8 | import numpy as np 9 | import time 10 | from PIL import Image 11 | 12 | """ 13 | ================= 14 | Routine functions 15 | ================= 16 | """ 17 | 18 | def train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, void=-1, args=None): 19 | batch_time = AverageMeter('Time', ':6.3f') 20 | data_time = AverageMeter('Data', ':6.3f') 21 | loss_running = AverageMeter('Loss', ':.4e') 22 | acc_running = AverageMeter('Accuracy', ':.3f') 23 | progress = ProgressMeter( 24 | len(dataloader), 25 | [batch_time, data_time, loss_running, acc_running], 26 | prefix="Train, epoch: [{}]".format(epoch)) 27 | 28 | # input resolution 29 | if args.crop_size is not None: 30 | res = args.crop_size[0]*args.crop_size[1] 31 | else: 32 | res = args.train_size[0]*args.train_size[1] 33 | 34 | # Set model in training mode 35 | model.train() 36 | 37 | end = time.time() 38 | 39 | with torch.set_grad_enabled(True): 40 | # Iterate over data. 41 | for epoch_step, (inputs, labels, _) in enumerate(dataloader): 42 | data_time.update(time.time()-end) 43 | 44 | if args.copyblob: 45 | for i in range(inputs.size()[0]): 46 | rand_idx = np.random.randint(inputs.size()[0]) 47 | # wall(3) --> sidewalk(1) 48 | copyblob(src_img=inputs[i], src_mask=labels[i], dst_img=inputs[rand_idx], dst_mask=labels[rand_idx], src_class=3, dst_class=1) 49 | # fence(4) --> sidewalk(1) 50 | copyblob(src_img=inputs[i], src_mask=labels[i], dst_img=inputs[rand_idx], dst_mask=labels[rand_idx], src_class=4, dst_class=1) 51 | # bus(15) --> road(0) 52 | copyblob(src_img=inputs[i], src_mask=labels[i], dst_img=inputs[rand_idx], dst_mask=labels[rand_idx], src_class=15, dst_class=0) 53 | # train(16) --> road(0) 54 | copyblob(src_img=inputs[i], src_mask=labels[i], dst_img=inputs[rand_idx], dst_mask=labels[rand_idx], src_class=16, dst_class=0) 55 | 56 | inputs = inputs.float().cuda() 57 | labels = labels.long().cuda() 58 | 59 | # zero the parameter gradients 60 | optimizer.zero_grad() 61 | 62 | if args.cutmix: 63 | # generate mixed sample 64 | lam = np.random.beta(1., 1.) 
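                # lam ~ Beta(1, 1) is uniform on [0, 1]: the fraction of the image that
                # stays untouched. rand_bbox (defined in learning/utils.py) then draws a
                # box covering roughly (1 - lam) of the area, and the same box is pasted
                # from a shuffled copy of the batch into BOTH inputs and labels, so the
                # pixels and their ground truth stay aligned.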
65 | rand_index = torch.randperm(inputs.size()[0]).cuda() 66 | bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam) 67 | inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2] 68 | labels[:, bbx1:bbx2, bby1:bby2] = labels[rand_index, bbx1:bbx2, bby1:bby2] 69 | 70 | # forward pass 71 | outputs = model(inputs) 72 | outputs = outputs['out'] #FIXME for DeepLab V3 73 | preds = torch.argmax(outputs, 1) 74 | # cross-entropy loss 75 | loss = criterion(outputs, labels) 76 | 77 | # backward pass 78 | loss.backward() 79 | optimizer.step() 80 | 81 | # Statistics 82 | bs = inputs.size(0) # current batch size 83 | loss = loss.item() 84 | loss_running.update(loss, bs) 85 | corrects = torch.sum(preds == labels.data) 86 | nvoid = int((labels==void).sum()) 87 | acc = corrects.double()/(bs*res-nvoid) # correct/(batch_size*resolution-voids) 88 | acc_running.update(acc, bs) 89 | 90 | # output training info 91 | progress.display(epoch_step) 92 | 93 | # Measure time 94 | batch_time.update(time.time() - end) 95 | end = time.time() 96 | 97 | # Reduce learning rate 98 | lr_scheduler.step(loss_running.avg) 99 | 100 | return loss_running.avg, acc_running.avg 101 | 102 | 103 | def validate_epoch(dataloader, model, criterion, epoch, classLabels, validClasses, void=-1, maskColors=None, folder='baseline_run', args=None): 104 | batch_time = AverageMeter('Time', ':6.3f') 105 | data_time = AverageMeter('Data', ':6.3f') 106 | loss_running = AverageMeter('Loss', ':.4e') 107 | acc_running = AverageMeter('Accuracy', ':.4e') 108 | iou = iouCalc(classLabels, validClasses, voidClass = void) 109 | progress = ProgressMeter( 110 | len(dataloader), 111 | [batch_time, data_time, loss_running, acc_running], 112 | prefix="Test, epoch: [{}]".format(epoch)) 113 | 114 | # input resolution 115 | res = args.test_size[0]*args.test_size[1] 116 | 117 | # Set model in evaluation mode 118 | model.eval() 119 | 120 | with torch.no_grad(): 121 | end = time.time() 122 | for epoch_step, (inputs, labels, filepath) in enumerate(dataloader): 123 | data_time.update(time.time()-end) 124 | 125 | inputs = inputs.float().cuda() 126 | labels = labels.long().cuda() 127 | 128 | # forward 129 | outputs = model(inputs) 130 | outputs = outputs['out'] #FIXME 131 | preds = torch.argmax(outputs, 1) 132 | loss = criterion(outputs, labels) 133 | 134 | # Statistics 135 | bs = inputs.size(0) # current batch size 136 | loss = loss.item() 137 | loss_running.update(loss, bs) 138 | corrects = torch.sum(preds == labels.data) 139 | nvoid = int((labels==void).sum()) 140 | acc = corrects.double()/(bs*res-nvoid) # correct/(batch_size*resolution-voids) 141 | acc_running.update(acc, bs) 142 | # Calculate IoU scores of current batch 143 | iou.evaluateBatch(preds, labels) 144 | 145 | # Save visualizations of first batch 146 | if epoch_step == 0 and maskColors is not None: 147 | for i in range(inputs.size(0)): 148 | filename = os.path.splitext(os.path.basename(filepath[i]))[0] 149 | # Only save inputs and labels once 150 | if epoch == 0: 151 | img = visim(inputs[i,:,:,:], args) 152 | label = vislbl(labels[i,:,:], maskColors) 153 | if len(img.shape) == 3: 154 | cv2.imwrite(folder + '/images/{}.png'.format(filename),img[:,:,::-1]) 155 | else: 156 | cv2.imwrite(folder + '/images/{}.png'.format(filename),img) 157 | cv2.imwrite(folder + '/images/{}_gt.png'.format(filename),label[:,:,::-1]) 158 | # Save predictions 159 | pred = vislbl(preds[i,:,:], maskColors) 160 | cv2.imwrite(folder + '/images/{}_epoch_{}.png'.format(filename,epoch),pred[:,:,::-1]) 161 | 162 | # 
measure elapsed time 163 | batch_time.update(time.time() - end) 164 | end = time.time() 165 | 166 | # print progress info 167 | progress.display(epoch_step) 168 | 169 | miou = iou.outputScores() 170 | print('Accuracy : {:5.3f}'.format(acc_running.avg)) 171 | print('---------------------') 172 | 173 | return acc_running.avg, loss_running.avg, miou 174 | 175 | def predict(dataloader, model, maskColors, folder='baseline_run', mode='val', args=None): 176 | batch_time = AverageMeter('Time', ':6.3f') 177 | data_time = AverageMeter('Data', ':6.3f') 178 | progress = ProgressMeter( 179 | len(dataloader), 180 | [batch_time, data_time], 181 | prefix='Predict: ') 182 | 183 | Dataset = MiniCity 184 | 185 | # Set model in evaluation mode 186 | model.eval() 187 | 188 | with torch.no_grad(): 189 | end = time.time() 190 | for epoch_step, batch in enumerate(dataloader): 191 | 192 | if len(batch) == 2: 193 | inputs, filepath = batch 194 | else: 195 | inputs, _, filepath = batch 196 | 197 | data_time.update(time.time()-end) 198 | 199 | inputs = inputs.float().cuda() 200 | 201 | if args.mst: 202 | batch_idx, _, h, w = inputs.size() #(1, 20, 1024, 2048) 203 | # only single image is supported for multi-scale testing 204 | assert(batch_idx == 1) 205 | with torch.cuda.device_of(inputs): 206 | scores = inputs.new().resize_(batch_idx, len(Dataset.validClasses), h, w).zero_().cuda() 207 | 208 | scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.2] #FIXME 209 | 210 | for scale in scales: 211 | inputs_resized = F.interpolate(inputs, scale_factor=scale, mode='bilinear', align_corners=True) 212 | #print("original size {}x{} --> resized to {}x{}".format(h, w, inputs_resized.size()[2], inputs_resized.size()[3])) 213 | 214 | # forward 215 | outputs = model(inputs_resized) 216 | outputs = outputs['out'] #FIXME (1, 20, 512, 1024) for scale 0.5 217 | 218 | score = F.interpolate(outputs, (h, w), mode='bilinear', align_corners=True) 219 | scores += score 220 | 221 | 222 | # forward using flipped input 223 | with torch.cuda.device_of(inputs_resized): 224 | idx = torch.arange(inputs_resized.size(3)-1, -1, -1).type_as(inputs_resized).long() 225 | input_resized_flip = inputs_resized.index_select(3, idx) 226 | 227 | # forward 228 | outputs = model(input_resized_flip) 229 | outputs = outputs['out'] #FIXME 230 | outputs = outputs.index_select(3, idx) 231 | 232 | score = F.interpolate(outputs, (h, w), mode='bilinear', align_corners=True) 233 | scores += score 234 | 235 | # averaging scores 236 | scores = scores / (2*len(scales)) 237 | 238 | preds = torch.argmax(scores, 1) # (1, 512, 1024) 239 | else: 240 | # forward 241 | outputs = model(inputs) 242 | outputs = outputs['out'] #FIXME 243 | 244 | preds = torch.argmax(outputs, 1) 245 | 246 | # Save visualizations of first batch 247 | for i in range(inputs.size(0)): 248 | filename = os.path.splitext(os.path.basename(filepath[i]))[0] 249 | # Save input 250 | img = visim(inputs[i,:,:,:], args) 251 | img = Image.fromarray(img, 'RGB') 252 | img.save(folder + '/results_color_{}/{}_input.png'.format(mode, filename)) 253 | # Save prediction with color labels 254 | pred = preds[i,:,:].cpu() 255 | pred_color = vislbl(pred, maskColors) 256 | pred_color = Image.fromarray(pred_color.astype('uint8')) 257 | pred_color.save(folder + '/results_color_{}/{}_prediction.png'.format(mode, filename)) 258 | # Save class id prediction (used for evaluation) 259 | pred_id = Dataset.trainid2id[pred] 260 | pred_id = Image.fromarray(pred_id) 261 | pred_id = pred_id.resize((2048,1024), resample=Image.NEAREST) 262 | 
pred_id.save(folder + '/results_{}/{}.png'.format(mode, filename)) 263 | 264 | 265 | # measure elapsed time 266 | batch_time.update(time.time() - end) 267 | end = time.time() 268 | 269 | # print progress info 270 | progress.display(epoch_step) 271 | -------------------------------------------------------------------------------- /learning/minicity.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from PIL import Image 3 | import numpy as np 4 | import os 5 | from collections import namedtuple 6 | from torchvision.datasets import Cityscapes 7 | 8 | class MiniCity(Cityscapes): 9 | 10 | voidClass = 19 11 | 12 | # Convert ids to train_ids 13 | id2trainid = np.array([label.train_id for label in Cityscapes.classes if label.train_id >= 0], dtype='uint8') 14 | id2trainid[np.where(id2trainid==255)] = voidClass 15 | 16 | # Convert train_ids to colors 17 | mask_colors = [list(label.color) for label in Cityscapes.classes if label.train_id >= 0 and label.train_id <= 19] 18 | mask_colors.append([0,0,0]) 19 | mask_colors = np.array(mask_colors) 20 | 21 | # Convert train_ids to ids 22 | trainid2id = np.zeros((256), dtype='uint8') 23 | for label in Cityscapes.classes: 24 | if label.train_id >= 0 and label.train_id < 255: 25 | trainid2id[label.train_id] = label.id 26 | 27 | # List of valid class ids 28 | validClasses = np.unique([label.train_id for label in Cityscapes.classes if label.id >= 0]) 29 | validClasses[np.where(validClasses==255)] = voidClass 30 | validClasses = list(validClasses) 31 | 32 | # Create list of class names 33 | classLabels = [label.name for label in Cityscapes.classes if not (label.ignore_in_eval or label.id < 0)] 34 | classLabels.append('void') 35 | 36 | def __init__(self, root, split='train', transform=None, target_transform=None, transforms=None): 37 | super(Cityscapes, self).__init__(root, transforms, transform, target_transform) 38 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 39 | self.targets_dir = os.path.join(self.root, 'gtFine', split) 40 | self.split = split 41 | self.images = [] 42 | self.targets = [] 43 | 44 | assert split in ['train','val','test'], 'Unknown value {} for argument split.'.format(split) 45 | 46 | for file_name in os.listdir(self.images_dir): 47 | self.images.append(os.path.join(self.images_dir, file_name)) 48 | if split != 'test': 49 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 50 | 'gtFine_labelIds.png') 51 | self.targets.append(os.path.join(self.targets_dir, target_name)) 52 | 53 | 54 | def __getitem__(self, index): 55 | """ 56 | Args: 57 | index (int): Index 58 | Returns: 59 | tuple: (image, target) 60 | """ 61 | 62 | filepath = self.images[index] 63 | image = Image.open(filepath).convert('RGB') 64 | 65 | if self.split != 'test': 66 | target = Image.open(self.targets[index]) 67 | 68 | if self.transforms is not None: 69 | if self.split != 'test': 70 | image, target = self.transforms(image, mask=target) 71 | # Convert class ids to train_ids and then to tensor 72 | target = self.id2trainid[target] 73 | return image, target, filepath 74 | else: 75 | image = self.transforms(image) 76 | return image, filepath -------------------------------------------------------------------------------- /learning/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from collections import OrderedDict 5 | 6 | class UNet(nn.Module): 7 | def 
__init__(self, n_classes, batchnorm = False): 8 | super(UNet, self).__init__() 9 | self.inc = inconv(3, 64, batchnorm) 10 | self.down1 = down(64, 128, batchnorm) 11 | self.down2 = down(128, 256, batchnorm) 12 | self.down3 = down(256, 512, batchnorm) 13 | self.down4 = down(512, 512, batchnorm) 14 | self.up1 = up(1024, 256, batchnorm) 15 | self.up2 = up(512, 128, batchnorm) 16 | self.up3 = up(256, 64, batchnorm) 17 | self.up4 = up(128, 64, batchnorm) 18 | self.outc = outconv(64, n_classes) 19 | 20 | def forward(self, x): 21 | x1 = self.inc(x) 22 | x2 = self.down1(x1) 23 | x3 = self.down2(x2) 24 | x4 = self.down3(x3) 25 | x5 = self.down4(x4) 26 | x = self.up1(x5, x4) 27 | x = self.up2(x, x3) 28 | x = self.up3(x, x2) 29 | x = self.up4(x, x1) 30 | x = self.outc(x) 31 | return x 32 | 33 | class double_conv(nn.Module): 34 | '''(conv => BN => ReLU) * 2''' 35 | def __init__(self, in_ch, out_ch, bn): 36 | super(double_conv, self).__init__() 37 | if bn: 38 | self.conv = nn.Sequential( 39 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 40 | nn.BatchNorm2d(out_ch), 41 | nn.ReLU(inplace=True), 42 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 43 | nn.BatchNorm2d(out_ch), 44 | nn.ReLU(inplace=True) 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 51 | nn.ReLU(inplace=True) 52 | ) 53 | 54 | 55 | def forward(self, x): 56 | x = self.conv(x) 57 | return x 58 | 59 | 60 | class inconv(nn.Module): 61 | def __init__(self, in_ch, out_ch, bn): 62 | super(inconv, self).__init__() 63 | self.conv = double_conv(in_ch, out_ch, bn) 64 | 65 | def forward(self, x): 66 | x = self.conv(x) 67 | return x 68 | 69 | 70 | class down(nn.Module): 71 | def __init__(self, in_ch, out_ch, bn): 72 | super(down, self).__init__() 73 | self.mpconv = nn.Sequential( 74 | nn.MaxPool2d(2), 75 | double_conv(in_ch, out_ch, bn) 76 | ) 77 | 78 | def forward(self, x): 79 | x = self.mpconv(x) 80 | return x 81 | 82 | 83 | class up(nn.Module): 84 | def __init__(self, in_ch, out_ch, bn, bilinear=True): 85 | super(up, self).__init__() 86 | 87 | # it would be nice if the upsampling could be learned too, 88 | # but my machine does not have enough memory to handle all those weights 89 | if bilinear: 90 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 91 | else: 92 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 93 | 94 | self.conv = double_conv(in_ch, out_ch, bn) 95 | 96 | def forward(self, x1, x2): 97 | x1 = self.up(x1) 98 | 99 | # input is CHW 100 | diffY = x2.size()[2] - x1.size()[2] 101 | diffX = x2.size()[3] - x1.size()[3] 102 | 103 | x1 = F.pad(x1, (diffX // 2, diffX - diffX//2, 104 | diffY // 2, diffY - diffY//2)) 105 | 106 | # for padding issues, see 107 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 108 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 109 | 110 | x = torch.cat([x2, x1], dim=1) 111 | x = self.conv(x) 112 | return x 113 | 
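# Quick shape check for the U-Net above (a sketch; assumes MiniCity's 20 output
# classes = 19 train ids + void):
#
#     model = UNet(n_classes=20, batchnorm=True)
#     logits = model(torch.randn(1, 3, 256, 512))  # -> torch.Size([1, 20, 256, 512])
#
# Each down block halves the spatial resolution and each up block doubles it back,
# so the logits match the input resolution; the F.pad call in up.forward absorbs
# the off-by-one sizes that appear when an odd dimension is halved.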
114 | class _SimpleSegmentationModel(nn.Module): 115 | __constants__ = ['aux_classifier'] 116 | 117 | def __init__(self, backbone, classifier, aux_classifier=None): 118 | super(_SimpleSegmentationModel, self).__init__() 119 | self.backbone = backbone 120 | self.classifier = classifier 121 | self.aux_classifier = aux_classifier 122 | 123 | def forward(self, x): 124 | input_shape = x.shape[-2:] 125 | # contract: features is a dict of tensors 126 | features = self.backbone(x) 127 | 128 | result = OrderedDict() 129 | x = features["out"] 130 | x = self.classifier(x) 131 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 132 | result["out"] = x 133 | 134 | if self.aux_classifier is not None: 135 | x = features["aux"] 136 | x = self.aux_classifier(x) 137 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 138 | result["aux"] = x 139 | 140 | return result 141 | 142 | __all__ = ["DeepLabV3"] 143 | 144 | 145 | class DeepLabV3(_SimpleSegmentationModel): 146 | """ 147 | Implements DeepLabV3 model from 148 | `"Rethinking Atrous Convolution for Semantic Image Segmentation" 149 | <https://arxiv.org/abs/1706.05587>`_. 150 | Arguments: 151 | backbone (nn.Module): the network used to compute the features for the model. 152 | The backbone should return an OrderedDict[Tensor], with the key being 153 | "out" for the last feature map used, and "aux" if an auxiliary classifier 154 | is used. 155 | classifier (nn.Module): module that takes the "out" element returned from 156 | the backbone and returns a dense prediction. 157 | aux_classifier (nn.Module, optional): auxiliary classifier used during training 158 | """ 159 | pass 160 | 161 | 162 | class DeepLabHead(nn.Sequential): 163 | def __init__(self, in_channels, num_classes): 164 | super(DeepLabHead, self).__init__( 165 | ASPP(in_channels, [12, 24, 36]), 166 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 167 | nn.BatchNorm2d(256), 168 | nn.ReLU(), 169 | nn.Conv2d(256, num_classes, 1) 170 | ) 171 | 172 | 173 | class ASPPConv(nn.Sequential): 174 | def __init__(self, in_channels, out_channels, dilation): 175 | modules = [ 176 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 177 | nn.BatchNorm2d(out_channels), 178 | nn.ReLU() 179 | ] 180 | super(ASPPConv, self).__init__(*modules) 181 | 182 | 183 | class ASPPPooling(nn.Sequential): 184 | def __init__(self, in_channels, out_channels): 185 | super(ASPPPooling, self).__init__( 186 | nn.AdaptiveAvgPool2d(1), 187 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 188 | nn.BatchNorm2d(out_channels), 189 | nn.ReLU()) 190 | 191 | def forward(self, x): 192 | size = x.shape[-2:] 193 | for mod in self: 194 | x = mod(x) 195 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 196 | 197 | 198 | class ASPP(nn.Module): 199 | def __init__(self, in_channels, atrous_rates): 200 | super(ASPP, self).__init__() 201 | out_channels = 256 202 | modules = [] 203 | modules.append(nn.Sequential( 204 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 205 | nn.BatchNorm2d(out_channels), 206 | nn.ReLU())) 207 | 208 | rate1, rate2, rate3 = tuple(atrous_rates) 209 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 210 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 211 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 212 | modules.append(ASPPPooling(in_channels, out_channels)) 213 | 214 | self.convs = nn.ModuleList(modules) 215 | 216 | self.project = nn.Sequential( 217 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 218 | nn.BatchNorm2d(out_channels), 219 | nn.ReLU(), 220 | nn.Dropout(0.5)) 221 | 222 | def forward(self, x): 223 | res = [] 224 | for conv in self.convs: 225 | res.append(conv(x)) 226 | res = torch.cat(res, dim=1) 227 | return self.project(res) 228 | 229 | 230 | class outconv(nn.Module): 231 | def __init__(self, in_ch, out_ch): 232 | super(outconv, self).__init__() 233 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 234 | 235 | def forward(self, x): 236 | x = self.conv(x) 237 | return x 238 | 
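# Reading aid for the EvoNorm code below ("Evolving Normalization-Activation Layers",
# Liu et al., https://arxiv.org/abs/2004.02967); notation follows this file, not a
# reference implementation:
#
#     S0: y = x * sigmoid(v * x) / group_std(x) * gamma + beta                    (no running statistics)
#     B0: y = x / max(sqrt(var_batch + eps), v * x + instance_std(x)) * gamma + beta
#
# group_std() is the standard deviation over each group of C // groups channels and
# all spatial positions, expanded back to the input shape; instance_std() is the
# per-sample, per-channel spatial standard deviation.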
239 | def group_std(x, groups = 32, eps = 1e-5): 240 | N, C, H, W = x.size() 241 | x = torch.reshape(x, (N, groups, C // groups, H, W)) 242 | var = torch.var(x, dim = (2, 3, 4), keepdim = True).expand_as(x) 243 | return torch.reshape(torch.sqrt(var + eps), (N, C, H, W)) 244 | def instance_std(x, eps = 1e-5): return torch.sqrt(torch.var(x, dim = (2, 3), keepdim = True).expand_as(x) + eps) # per-sample, per-channel spatial std; was undefined even though the B0 branch below calls it 245 | class EvoNorm(nn.Module): 246 | def __init__(self, input, non_linear = True, version = 'S0', momentum = 0.9, eps = 1e-5, training = True): 247 | super(EvoNorm, self).__init__() 248 | self.non_linear = non_linear 249 | self.version = version 250 | self.training = training 251 | self.momentum = momentum 252 | self.eps = eps 253 | if self.version not in ['B0', 'S0']: 254 | raise ValueError("Invalid EvoNorm version") 255 | self.insize = input 256 | self.gamma = nn.Parameter(torch.ones(1, self.insize, 1, 1)) 257 | self.beta = nn.Parameter(torch.zeros(1, self.insize, 1, 1)) 258 | if self.non_linear: 259 | self.v = nn.Parameter(torch.ones(1,self.insize,1,1)) 260 | self.register_buffer('running_var', torch.ones(1, self.insize, 1, 1)) 261 | 262 | self.reset_parameters() 263 | 264 | def reset_parameters(self): 265 | self.running_var.fill_(1) 266 | 267 | def forward(self, x): 268 | if x.dim() != 4: 269 | raise ValueError('expected 4D input (got {}D input)' 270 | .format(x.dim())) 271 | if self.version == 'S0': 272 | if self.non_linear: 273 | num = x * torch.sigmoid(self.v * x) 274 | return num / group_std(x, eps = self.eps) * self.gamma + self.beta 275 | else: 276 | return x * self.gamma + self.beta 277 | if self.version == 'B0': 278 | if self.training: 279 | var = torch.var(x, dim = (0, 2, 3), unbiased = False, keepdim = True).reshape(1, x.size(1), 1, 1) 280 | with torch.no_grad(): 281 | self.running_var.copy_(self.momentum * self.running_var + (1 - self.momentum) * var) 282 | else: 283 | var = self.running_var 284 | 285 | if self.non_linear: 286 | den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x, eps = self.eps)) 287 | return x / den * self.gamma + self.beta 288 | else: 289 | return x * self.gamma + self.beta 290 | 291 | def convert_bn_to_instancenorm(model): 292 | for child_name, child in model.named_children(): 293 | if isinstance(child, nn.BatchNorm2d): 294 | setattr(model, child_name, nn.InstanceNorm2d(child.num_features)) 295 | else: 296 | convert_bn_to_instancenorm(child) 297 | 298 | def convert_bn_to_evonorm(model): 299 | for child_name, child in model.named_children(): 300 | if isinstance(child, nn.BatchNorm2d): 301 | setattr(model, child_name, EvoNorm(child.num_features)) 302 | elif isinstance(child, nn.ReLU): 303 | setattr(model, child_name, nn.Identity()) 304 | else: 305 | convert_bn_to_evonorm(child) 306 | 307 | def convert_bn_to_groupnorm(model, num_groups=32): 308 | for child_name, child in model.named_children(): 309 | if isinstance(child, nn.BatchNorm2d): 310 | setattr(model, child_name, nn.GroupNorm(num_groups=num_groups, num_channels=child.num_features)) 311 | else: 312 | convert_bn_to_groupnorm(child, num_groups=num_groups) 313 | 314 | --------------------------------------------------------------------------------
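The three `convert_bn_to_*` helpers above rewrite a model in place by recursing through `named_children()` and swapping every `nn.BatchNorm2d`. A minimal sketch of how `get_model` in `learning/utils.py` (the next file) applies them for `--norm group`; the input sizes in the shape check are illustrative only:

```python
import torch
import torchvision
from learning.model import convert_bn_to_groupnorm, DeepLabHead

model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False)
model.classifier = DeepLabHead(2048, 20)       # 19 Cityscapes train classes + void
convert_bn_to_groupnorm(model, num_groups=32)  # every BatchNorm2d becomes GroupNorm(32, C)

out = model(torch.randn(2, 3, 128, 256))['out']
print(out.shape)  # torch.Size([2, 20, 128, 256])
```

All channel counts in the ResNet-50 backbone and the DeepLab head are divisible by 32, so the GroupNorm swap is shape-safe.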
/learning/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms.functional as TF 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from PIL import ImageOps # ImageOps.expand is used for padding in train_trans below 9 | from learning.minicity import MiniCity 10 | from learning.model import convert_bn_to_instancenorm, convert_bn_to_evonorm, convert_bn_to_groupnorm, DeepLabHead, UNet 11 | 12 | """ 13 | ==================== 14 | Data Loader Function 15 | ==================== 16 | """ 17 | def get_dataloader(dataset, args): 18 | 19 | 20 | def test_trans(image, mask=None): 21 | # Resize, 1 for Image.LANCZOS 22 | image = TF.resize(image, args.test_size, interpolation=1) 23 | # From PIL to Tensor 24 | image = TF.to_tensor(image) 25 | # Normalize 26 | image = TF.normalize(image, args.dataset_mean, args.dataset_std) 27 | 28 | if mask: 29 | # Resize, 0 for Image.NEAREST 30 | mask = TF.resize(mask, args.test_size, interpolation=0) 31 | mask = np.array(mask, np.uint8) # PIL Image to numpy array 32 | mask = torch.from_numpy(mask) # Numpy array to tensor 33 | return image, mask 34 | else: 35 | return image 36 | 37 | def train_trans(image, mask): 38 | # Generate random parameters for augmentation 39 | bf = np.random.uniform(1-args.colorjitter_factor,1+args.colorjitter_factor) 40 | cf = np.random.uniform(1-args.colorjitter_factor,1+args.colorjitter_factor) 41 | sf = np.random.uniform(1-args.colorjitter_factor,1+args.colorjitter_factor) 42 | hf = np.random.uniform(-args.colorjitter_factor,+args.colorjitter_factor) 43 | pflip = np.random.rand() > 0.5 # coin flip; np.random.randint(0,1) always returned 0, so the h-flip never fired 44 | 45 | # Random scaling 46 | scale_factor = np.random.uniform(0.75, 2.0) 47 | scaled_train_size = [int(element * scale_factor) for element in args.train_size] 48 | 49 | # Resize, 1 for Image.LANCZOS 50 | image = TF.resize(image, scaled_train_size, interpolation=1) 51 | # Resize, 0 for Image.NEAREST 52 | mask = TF.resize(mask, scaled_train_size, interpolation=0) 53 | 54 | # Random cropping 55 | if args.train_size != args.crop_size: 56 | if image.size[1] <= args.crop_size[0]: # PIL image: (width, height) vs. args.size: (height, width) 57 | pad_h = args.crop_size[0] - image.size[1] + 1 58 | pad_w = args.crop_size[1] - image.size[0] + 1 59 | image = ImageOps.expand(image, border=(0, 0, pad_w, pad_h), fill=0) 60 | mask = ImageOps.expand(mask, border=(0, 0, pad_w, pad_h), fill=19) 61 | 62 | # From PIL to Tensor 63 | image = TF.to_tensor(image) 64 | mask = TF.to_tensor(mask) 65 | h, w = image.size()[1], image.size()[2] #scaled_train_size #args.train_size 66 | th, tw = args.crop_size 67 | 68 | i = np.random.randint(0, h - th) 69 | j = np.random.randint(0, w - tw) 70 | image_crop = image[:,i:i+th,j:j+tw] 71 | mask_crop = mask[:,i:i+th,j:j+tw] 72 | 73 | image = TF.to_pil_image(image_crop) 74 | mask = TF.to_pil_image(mask_crop[0,:,:]) 75 | 76 | # H-flip 77 | if pflip and args.hflip: 78 | image = TF.hflip(image) 79 | mask = TF.hflip(mask) 80 | 81 | # Color jitter 82 | image = TF.adjust_brightness(image, bf) 83 | image = TF.adjust_contrast(image, cf) 84 | image = TF.adjust_saturation(image, sf) 85 | image = TF.adjust_hue(image, hf) 86 | 87 | # From PIL to Tensor 88 | image = TF.to_tensor(image) 89 | 90 | # Normalize 91 | image = TF.normalize(image, args.dataset_mean, args.dataset_std) 92 | 93 | # Convert ids to train_ids 94 | mask = np.array(mask, np.uint8) # PIL Image to numpy array 95 | mask = torch.from_numpy(mask) # Numpy array to tensor 96 | 97 | return image, mask 98 | 
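    # What the two transforms hand back (a sketch, assuming option.py's defaults:
    # train_size/test_size (1024, 2048), crop_size (768, 768)):
    #
    #     image, mask = train_trans(pil_image, pil_mask)
    #     # image: float tensor (3, 768, 768), normalized with dataset_mean/std
    #     # mask:  uint8 tensor (768, 768) holding raw Cityscapes label ids;
    #     #        MiniCity.__getitem__ maps these to train ids right afterwards
    #
    # test_trans keeps the full test_size resolution and applies no augmentation.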
99 | trainset = dataset(args.dataset_path, split='train', transforms=train_trans) 100 | valset = dataset(args.dataset_path, split='val', transforms=test_trans) 101 | testset = dataset(args.dataset_path, split='test', transforms=test_trans) 102 | dataloaders = {} 103 | dataloaders['train'] = torch.utils.data.DataLoader(trainset, 104 | batch_size=args.batch_size, shuffle=True, 105 | pin_memory=args.pin_memory, num_workers=args.num_workers) 106 | dataloaders['val'] = torch.utils.data.DataLoader(valset, 107 | batch_size=args.batch_size, shuffle=False, 108 | pin_memory=args.pin_memory, num_workers=args.num_workers) 109 | dataloaders['test'] = torch.utils.data.DataLoader(testset, 110 | batch_size=args.batch_size, shuffle=False, 111 | pin_memory=args.pin_memory, num_workers=args.num_workers) 112 | 113 | return dataloaders 114 | 115 | """ 116 | ==================== 117 | Focal Loss 118 | code reference: https://github.com/clcarwin/focal_loss_pytorch 119 | ==================== 120 | """ 121 | 122 | class FocalLoss(nn.Module): 123 | def __init__(self, gamma=0, alpha=None, size_average=True): 124 | super(FocalLoss, self).__init__() 125 | self.gamma = gamma 126 | self.alpha = alpha 127 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 128 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 129 | self.size_average = size_average 130 | 131 | def forward(self, input, target): 132 | if input.dim()>2: 133 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 134 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 135 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 136 | target = target.view(-1,1) 137 | 138 | logpt = F.log_softmax(input, dim=1) # explicit class dim; input is (N*H*W, C) here 139 | logpt = logpt.gather(1,target) 140 | logpt = logpt.view(-1) 141 | pt = Variable(logpt.data.exp()) 142 | 143 | if self.alpha is not None: 144 | if self.alpha.type()!=input.data.type(): 145 | self.alpha = self.alpha.type_as(input.data) 146 | at = self.alpha.gather(0,target.data.view(-1)) 147 | logpt = logpt * Variable(at) 148 | 149 | loss = -1 * (1-pt)**self.gamma * logpt 150 | if self.size_average: return 
loss.mean() 151 | else: return loss.sum() 152 | 153 | 154 | """ 155 | ==================== 156 | Loss Function 157 | ==================== 158 | """ 159 | 160 | def get_lossfunc(dataset, args): 161 | # Define loss, optimizer and scheduler 162 | if args.loss == 'ce': 163 | if args.model == 'DeepLabv3_resnet50': 164 | criterion = nn.CrossEntropyLoss(ignore_index=dataset.voidClass) 165 | else: 166 | criterion = nn.CrossEntropyLoss() 167 | elif args.loss == 'weighted_ce': 168 | # Class-Weighted loss 169 | class_weight = [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507] 170 | class_weight.append(0) #for void-class 171 | class_weight = torch.FloatTensor(class_weight).cuda() 172 | criterion = nn.CrossEntropyLoss(weight=class_weight, ignore_index=dataset.voidClass) 173 | elif args.loss =='focal': 174 | criterion = FocalLoss(gamma=args.focal_gamma) 175 | else: 176 | raise NameError('Loss is not defined!') 177 | 178 | return criterion 179 | 180 | 181 | """ 182 | ==================== 183 | Model Architecture 184 | ==================== 185 | """ 186 | 187 | def get_model(dataset, args): 188 | if args.model == 'UNet': 189 | """ U-Net baeline """ 190 | model = UNet(len(dataset.validClasses), batchnorm=True) 191 | elif args.model == 'DeepLabv3_resnet50': 192 | """ DeepLab v3 ResNet50 """ 193 | model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False) 194 | model.classifier = DeepLabHead(2048, len(dataset.validClasses)) 195 | elif args.model == 'DeepLabv3_resnet101': 196 | """ DeepLab v3 ResNet101 """ 197 | model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False) 198 | model.classifier = DeepLabHead(2048, len(dataset.validClasses)) 199 | else: 200 | raise NameError('Model is not defined!') 201 | 202 | # Normalization Layer 203 | if args.norm == 'instance': 204 | convert_bn_to_instancenorm(model) 205 | elif args.norm == 'evo': 206 | convert_bn_to_evonorm(model) 207 | elif args.norm == 'group': 208 | convert_bn_to_groupnorm(model, num_groups=32) 209 | elif args.norm == 'batch': 210 | pass 211 | else: 212 | raise NameError('Normalization is not defined!') 213 | 214 | return model 215 | 216 | """ 217 | ==================== 218 | random bbox function for cutmix 219 | ==================== 220 | """ 221 | 222 | def rand_bbox(size, lam): 223 | W = size[2] 224 | H = size[3] 225 | cut_rat = np.sqrt(1. 
- lam) 226 | cut_w = int(W * cut_rat) # np.int is deprecated; plain int() is equivalent 227 | cut_h = int(H * cut_rat) 228 | 229 | # uniform 230 | cx = np.random.randint(W) 231 | cy = np.random.randint(H) 232 | 233 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 234 | bby1 = np.clip(cy - cut_h // 2, 0, H) 235 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 236 | bby2 = np.clip(cy + cut_h // 2, 0, H) 237 | 238 | return bbx1, bby1, bbx2, bby2 239 | 240 | """ 241 | ==================== 242 | Custom copyblob function for copyblob data augmentation 243 | ==================== 244 | """ 245 | 246 | def copyblob(src_img, src_mask, dst_img, dst_mask, src_class, dst_class): 247 | mask_hist_src, _ = np.histogram(src_mask.numpy().ravel(), len(MiniCity.validClasses)-1, [0, len(MiniCity.validClasses)-1]) 248 | mask_hist_dst, _ = np.histogram(dst_mask.numpy().ravel(), len(MiniCity.validClasses)-1, [0, len(MiniCity.validClasses)-1]) 249 | 250 | if mask_hist_src[src_class] != 0 and mask_hist_dst[dst_class] != 0: 251 | """ copy src blob and paste to any dst blob""" 252 | mask_y, mask_x = src_mask.size() 253 | """ get src object's min index""" 254 | src_idx = np.where(src_mask==src_class) 255 | 256 | src_idx_sum = list(src_idx[0][i] + src_idx[1][i] for i in range(len(src_idx[0]))) 257 | src_idx_sum_min_idx = np.argmin(src_idx_sum) 258 | src_idx_min = src_idx[0][src_idx_sum_min_idx], src_idx[1][src_idx_sum_min_idx] 259 | 260 | """ get dst object's random index""" 261 | dst_idx = np.where(dst_mask==dst_class) 262 | rand_idx = np.random.randint(len(dst_idx[0])) 263 | target_pos = dst_idx[0][rand_idx], dst_idx[1][rand_idx] 264 | 265 | src_dst_offset = tuple(map(lambda x, y: x - y, src_idx_min, target_pos)) 266 | dst_idx = tuple(map(lambda x, y: x - y, src_idx, src_dst_offset)) 267 | 268 | for i in range(len(dst_idx[0])): 269 | dst_idx[0][i] = (min(dst_idx[0][i], mask_y-1)) 270 | for i in range(len(dst_idx[1])): 271 | dst_idx[1][i] = (min(dst_idx[1][i], mask_x-1)) 272 | 273 | dst_mask[dst_idx] = src_class 274 | dst_img[:, dst_idx[0], dst_idx[1]] = src_img[:, src_idx[0], src_idx[1]] -------------------------------------------------------------------------------- /minicity/class_pixel_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/semantic-segmentation-tutorial-pytorch/cfbc042ad65351ca0c603c4d76ba02483f7027b4/minicity/class_pixel_distribution.png -------------------------------------------------------------------------------- /minicity/copyblob.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/semantic-segmentation-tutorial-pytorch/cfbc042ad65351ca0c603c4d76ba02483f7027b4/minicity/copyblob.PNG -------------------------------------------------------------------------------- /minicity/cutmix.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/semantic-segmentation-tutorial-pytorch/cfbc042ad65351ca0c603c4d76ba02483f7027b4/minicity/cutmix.PNG -------------------------------------------------------------------------------- /minicity/leaderboard.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/semantic-segmentation-tutorial-pytorch/cfbc042ad65351ca0c603c4d76ba02483f7027b4/minicity/leaderboard.PNG -------------------------------------------------------------------------------- /option.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser(description='VIPriors Segmentation baseline training script') 5 | 6 | # model architecture & checkpoint 7 | parser.add_argument('--model', metavar='[UNet, DeepLabv3_resnet50, DeepLabv3_resnet101]', 8 | default='DeepLabv3_resnet50', type=str, help='model') 9 | parser.add_argument('--save_path', metavar='path/to/save_results', default='./baseline_run', 10 | type=str, help='path to save results') 11 | parser.add_argument('--weights', metavar='path/to/checkpoint', default=None, 12 | type=str, help='resume training from checkpoint') 13 | parser.add_argument('--norm', metavar='[batch, instance, evo, group]', default='batch', 14 | type=str, help='replace batch norm with other norm') 15 | 16 | # data loading 17 | parser.add_argument('--dataset_path', metavar='path/to/minicity/root', default='./minicity', 18 | type=str, help='path to dataset (ends with /minicity)') 19 | parser.add_argument('--pin_memory', metavar='[True,False]', default=True, 20 | type=bool, help='pin memory on GPU') 21 | parser.add_argument('--num_workers', metavar='8', default=8, type=int, 22 | help='number of dataloader workers') 23 | 24 | # data augmentation hyper-parameters 25 | parser.add_argument('--colorjitter_factor', metavar='0.3', default=0.3, 26 | type=float, help='data augmentation: color jitter factor') 27 | parser.add_argument('--hflip', metavar='[True,False]', default=True, 28 | type=float, help='data augmentation: random horizontal flip') 29 | parser.add_argument('--crop_size', default=[768, 768], nargs='+', type=int, help='data augmentation: random crop size') 30 | parser.add_argument('--train_size', default=[1024, 2048], nargs='+', type=int, help='image size during training') 31 | parser.add_argument('--test_size', default=[1024, 2048], nargs='+', type=int, help='image size during test') 32 | parser.add_argument('--dataset_mean', metavar='[0.485, 0.456, 0.406]', 33 | default=[0.485, 0.456, 0.406], type=list, 34 | help='mean for normalization') 35 | parser.add_argument('--dataset_std', metavar='[0.229, 0.224, 0.225]', 36 | default=[0.229, 0.224, 0.225], type=list, 37 | help='std for normalization') 38 | 39 | # training hyper-parameters 40 | parser.add_argument('--batch_size', default=4, type=int, help='batch size') 41 | parser.add_argument('--lr_init', metavar='1e-2', default=1e-2, type=float, 42 | help='initial learning rate') 43 | parser.add_argument('--lr_momentum', metavar='0.9', default=0.9, type=float, 44 | help='momentum for SGD optimizer') 45 | parser.add_argument('--lr_weight_decay', metavar='1e-4', default=1e-4, type=float, 46 | help='weight decay for SGD optimizer') 47 | parser.add_argument('--epochs', metavar='200', default=200, type=int, 48 | help='number of training epochs') 49 | parser.add_argument('--seed', metavar='42', default=None, type=int, 50 | help='random seed to use') 51 | parser.add_argument('--loss', metavar='[ce, weighted_ce, focal]', default='ce', 52 | type=str, help='loss criterion') 53 | parser.add_argument('--focal_gamma', default=2.0, type=float, help='gamma for focal loss (used with --loss focal)') 54 | 55 | # additional training tricks 56 | parser.add_argument('--cutmix', action='store_true', help='cutmix augmentation') 57 | parser.add_argument('--copyblob', action='store_true', help='copyblob augmentation') 58 | 59 | # inference options 60 | parser.add_argument('--predict', action='store_true', help='run inference instead of training') 61 | parser.add_argument('--mst', 
action='store_true', help='multi-scale testing') 62 | #parser.add_argument('--minorcrop', action='store_true', help='minor crop augmentation') 63 | 64 | args = parser.parse_args() 65 | return args -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.5 2 | attrs==19.3.0 3 | backcall==0.1.0 4 | bleach==3.3.0 5 | cycler==0.10.0 6 | decorator==4.4.2 7 | defusedxml==0.6.0 8 | entrypoints==0.3 9 | imageio==2.8.0 10 | imgaug==0.2.6 11 | importlib-metadata==1.6.0 12 | inplace-abn==1.0.12.dev2+g5b017ac 13 | ipykernel==5.2.0 14 | ipython==7.13.0 15 | ipython-genutils==0.2.0 16 | ipywidgets==7.5.1 17 | jedi==0.16.0 18 | Jinja2==2.11.1 19 | joblib==0.14.1 20 | jsonschema==3.2.0 21 | jupyter==1.0.0 22 | jupyter-client==6.1.2 23 | jupyter-console==6.1.0 24 | jupyter-core==4.6.3 25 | kiwisolver==1.2.0 26 | MarkupSafe==1.1.1 27 | matplotlib==3.2.1 28 | mistune==0.8.4 29 | nbconvert==5.6.1 30 | nbformat==5.0.5 31 | networkx==2.4 32 | notebook==6.4.1 33 | numpy==1.18.2 34 | opencv-python-headless==4.2.0.34 35 | pandocfilters==1.4.2 36 | parso==0.6.2 37 | pexpect==4.8.0 38 | pickleshare==0.7.5 39 | Pillow==8.3.2 40 | prometheus-client==0.7.1 41 | prompt-toolkit==3.0.5 42 | ptyprocess==0.6.0 43 | Pygments==2.6.1 44 | pyparsing==2.4.7 45 | pyrsistent==0.16.0 46 | python-dateutil==2.8.1 47 | PyWavelets==1.1.1 48 | PyYAML==5.3.1 49 | pyzmq==19.0.0 50 | qtconsole==4.7.2 51 | QtPy==1.9.0 52 | scikit-image==0.16.2 53 | scikit-learn==0.22.2.post1 54 | scipy==1.4.1 55 | Send2Trash==1.5.0 56 | six==1.14.0 57 | terminado==0.8.3 58 | testpath==0.4.4 59 | torch==1.4.0 60 | torchvision==0.5.0 61 | tornado==6.0.4 62 | tqdm==4.45.0 63 | traitlets==4.3.3 64 | wcwidth==0.1.9 65 | webencodings==0.5.1 66 | widgetsnbextension==3.5.1 67 | zipp==3.1.0 68 | -------------------------------------------------------------------------------- /results.txt: -------------------------------------------------------------------------------- 1 | IoU Class: 0.6069831962012341 2 | iIoU Class: 0.31276752216405734 3 | IoU Category: 0.8259882068129724 4 | iIoU Category: 0.624785111637751 5 | Accuracy: 0.8056700706481934 --------------------------------------------------------------------------------
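`results.txt` above lists the final IoU / iIoU scores per class and category plus pixel accuracy. A typical inference call wired to the options in `option.py` (a sketch: `baseline.py` itself is not included in this dump, the checkpoint path is illustrative, and `--mst` requires batch size 1, see the assert in `predict` in `learning/learner.py`):

```python
python baseline.py --save_path baseline_run_deeplabv3_resnet50 --model DeepLabv3_resnet50 \
    --weights path/to/checkpoint.pth --batch_size 1 --predict --mst;
```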