├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── criterion ├── __init__.py └── criterion.py ├── data ├── README.md ├── dataset │ ├── JPEGImages │ │ ├── 2007_000032.jpg │ │ ├── 2007_000039.jpg │ │ ├── 2007_000063.jpg │ │ ├── 2007_000068.jpg │ │ ├── 2007_000121.jpg │ │ ├── 2007_000170.jpg │ │ ├── 2007_000241.jpg │ │ ├── 2007_000243.jpg │ │ ├── 2007_000250.jpg │ │ ├── 2007_000256.jpg │ │ ├── 2007_000333.jpg │ │ └── 2007_000363.jpg │ └── gtFine │ │ ├── 2007_000032.png │ │ ├── 2007_000039.png │ │ ├── 2007_000063.png │ │ ├── 2007_000068.png │ │ ├── 2007_000121.png │ │ ├── 2007_000170.png │ │ ├── 2007_000241.png │ │ ├── 2007_000243.png │ │ ├── 2007_000250.png │ │ ├── 2007_000256.png │ │ ├── 2007_000333.png │ │ └── 2007_000363.png ├── test │ ├── image.txt │ └── label.txt ├── train │ ├── image.txt │ └── label.txt └── val │ ├── image.txt │ └── label.txt ├── dataloader ├── README.md ├── __init__.py ├── dataset.py ├── functional.py └── transform.py ├── eval.py ├── networks ├── __init__.py ├── erfnet.py ├── fcn.py ├── pspnet.py ├── segnet.py ├── unet.py └── utils.py ├── options ├── README.md ├── __init__.py ├── test_options.py └── train_options.py ├── save_models ├── README.md └── fcn16 │ ├── automated_log.txt │ ├── model.txt │ └── opts.txt ├── split_train_val.py ├── test.py ├── train.py └── utils ├── README.md ├── cityscapes ├── addToConfusionMatrix.c ├── addToConfusionMatrix.pyx ├── addToConfusionMatrix_impl.c ├── helpers │ ├── annotation.py │ ├── csHelpers.py │ └── labels.py └── setup.py ├── evalIoU.py ├── eval_weight.py └── label2Img.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Jupyter Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # SageMath parsed files 79 | *.sage.py 80 | 81 | # Environments 82 | .env 83 | .venv 84 | env/ 85 | venv/ 86 | ENV/ 87 | 88 | # Spyder project settings 89 | .spyderproject 90 | .spyproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # mkdocs documentation 96 | /site 97 | 98 | # mypy 99 | .mypy_cache/ 100 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Neo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Semantic-Segmentation 2 | ## Reference 3 | - *[ERFNet](https://github.com/Eromera/erfnet_pytorch)* 4 | - *[PiWise](https://github.com/bodokaiser/piwise)* 5 | ## Network 6 | - fcn 7 | - segnet 8 | - erfnet 9 | - pspnet 10 | - unet 11 | ## Environment 12 | - pytorch 0.2.0 13 | - torchvision 0.2.0 14 | - python 3.5.2 15 | - cython 16 | 17 | ## Download 18 | Recommand you use virtualenv. 
19 | > virtualenv -p python3 YourVirtualEnv --no-site-packages 20 | 21 | > git clone https://github.com/mapleneverfade/pytorch-semantic-segmentation.git 22 | ## Install CSUPPORT (Optional) 23 | To speed up the IoU calculation: 24 | > cd ./utils/cityscapes/ 25 | 26 | > python setup.py install 27 | ## Train 28 | If a GPU is available: 29 | 30 | `CUDA_VISIBLE_DEVICES=0 python3 train.py --datadir ./data/ --savedir ./save_models/ --model segnet` 31 | else: 32 | 33 | `python3 train.py --cuda False --datadir ./data/ --savedir ./save_models/ --model segnet` 34 | 35 | There are some example pictures in ./data, so you can run the command as-is to check that everything works. 36 | ### More Training Options 37 | --model model to use ['segnet fcn8 fcn16 fcn32 erfnet pspnet unet'] [default=segnet] 38 | --datadir where the train/val splits are stored. In my case, './data' has subfolders './data/train/' and './data/val/', each containing 'image.txt' and 'label.txt'. [default='./data/'] 39 | --savedir path to the save directory [default='./save_models/'] 40 | --lr learning rate [default=5e-4] 41 | --num-epochs epochs [default=150] 42 | --num-classes number of labels; Pascal VOC has 21, Cityscapes has 20. Change it when training on your own dataset. [default=21] 43 | 44 | ## Test 45 | `CUDA_VISIBLE_DEVICES=0 python3 test.py --datadir ./data/test --model segnet --model-dir ./save_models/segnet_50.pth --savedir ./results/` 46 | "--model-dir" is the path to your trained model. 47 | ### More Testing Options 48 | --model [default=segnet] 49 | --model-dir path to your trained model, e.g. './save_models/segnet/segnet_epoch_5.pth' 50 | --datadir [default='./data/test/'] 51 | --num-classes number of labels; Pascal VOC has 21, Cityscapes has 20. Change it when testing on your own dataset. [default=21] 52 | --size reshape size [default=(672,480)] 53 | --savedir [default='./results/'] 54 | ## Options 55 | split_train_val.py splits the original [image.txt,label.txt] into './train/[image.txt,label.txt]' and './val/[image.txt,label.txt]' 56 | 57 | Options for the split: 58 | 59 | --savedir [default='./data/'] 60 | --imagepath path to your own image.txt 61 | --labelpath path to your own label.txt 62 | --random-state random seed [default=10000] 63 | --train-size ratio of the train set [default=0.7] 64 | --val-size ratio of the val set [default=0.3] 65 | For example, if your original folder looks like this: 66 | 67 | ./data 68 | image.txt 69 | label.txt 70 | After running 'python3 split_train_val.py --savedir ./data --imagepath ./data/image.txt --labelpath ./data/label.txt', you will see this: 71 | 72 | ./data 73 | ./train 74 | image.txt 75 | label.txt 76 | ./val 77 | image.txt 78 | label.txt 79 | 80 | ## Detail 81 | 1. For more train and test options, see ./options 82 | 2. datadir should contain image.txt and label.txt; the default dataset is [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/segexamples/index.html) 83 | 3. To train on your own data, remember to modify the labels in ./utils/cityscapes/helpers/labels.py and NUM_CLASSES in the options. 84 | 4. You can change the way the model loads data in ./dataloader/ to fit your dataset format (a minimal sketch follows the ToDo list below). 85 | 5. test.py calculates the mIoU and saves the segmented pictures to --savedir. 86 | 87 | ## ToDo 88 | 1. More networks 89 | 2. Clean up the code.
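For Detail item 4, here is a minimal sketch of how the pieces in ./dataloader fit together. It assumes train.py wires them up this way (train.py is not reproduced here), and the crop size is only an illustrative value:

    # Minimal sketch: load the example training split with NeoData + MyTransform.
    from torch.utils.data import DataLoader
    from dataloader.dataset import NeoData
    from dataloader.transform import MyTransform

    # Resize first, then augment with RandomCrop / RandomFlip / RandomRotate.
    train_transform = MyTransform(reshape_size=(672, 480), crop_size=448, augment=True)
    train_data = NeoData(imagepath='./data/train/image.txt',
                         labelpath='./data/train/label.txt',
                         transform=train_transform)
    loader = DataLoader(train_data, batch_size=1, shuffle=True, num_workers=4)

    for images, labels in loader:
        # images: Bx3xHxW float tensor (ImageNet-normalized), labels: Bx1xHxW long tensor
        pass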
90 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/__init__.py -------------------------------------------------------------------------------- /criterion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/criterion/__init__.py -------------------------------------------------------------------------------- /criterion/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CrossEntropyLoss2d(nn.Module): 6 | def __init__(self, weight=None): 7 | super().__init__() 8 | self.loss = nn.NLLLoss2d(weight) 9 | 10 | def forward(self, outputs, targets): 11 | #torch version >0.2 F.log_softmax(input, dim=?) 12 | #dim (int): A dimension along which log_softmax will be computed. 13 | try: 14 | return self.loss(F.log_softmax(outputs,dim=1), targets) # if torch version >=0.3 15 | except TypeError as t: 16 | return self.loss(F.log_softmax(outputs), targets) #else 17 | 18 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | > ./train and ./test dir including image.txt and label.txt which stored the path to image and label data. 3 | > Data do not have to stored in this dir. 4 | -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000032.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000039.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000063.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000068.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000121.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000121.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000170.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000170.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000241.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000241.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000243.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000243.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000250.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000250.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000256.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000256.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000333.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000333.jpg -------------------------------------------------------------------------------- /data/dataset/JPEGImages/2007_000363.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/JPEGImages/2007_000363.jpg -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000032.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000039.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000063.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000063.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000068.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000068.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000121.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000121.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000170.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000170.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000241.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000241.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000243.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000243.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000250.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000250.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000256.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000333.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000333.png -------------------------------------------------------------------------------- /data/dataset/gtFine/2007_000363.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/data/dataset/gtFine/2007_000363.png -------------------------------------------------------------------------------- /data/test/image.txt: 
-------------------------------------------------------------------------------- 1 | ./data/dataset/JPEGImages/2007_000032.jpg 2 | ./data/dataset/JPEGImages/2007_000039.jpg 3 | ./data/dataset/JPEGImages/2007_000063.jpg 4 | ./data/dataset/JPEGImages/2007_000068.jpg 5 | ./data/dataset/JPEGImages/2007_000121.jpg 6 | ./data/dataset/JPEGImages/2007_000170.jpg 7 | ./data/dataset/JPEGImages/2007_000241.jpg 8 | ./data/dataset/JPEGImages/2007_000243.jpg 9 | ./data/dataset/JPEGImages/2007_000250.jpg 10 | ./data/dataset/JPEGImages/2007_000256.jpg 11 | ./data/dataset/JPEGImages/2007_000333.jpg 12 | ./data/dataset/JPEGImages/2007_000363.jpg 13 | -------------------------------------------------------------------------------- /data/test/label.txt: -------------------------------------------------------------------------------- 1 | ./data/dataset/gtFine/2007_000032.png 2 | ./data/dataset/gtFine/2007_000039.png 3 | ./data/dataset/gtFine/2007_000063.png 4 | ./data/dataset/gtFine/2007_000068.png 5 | ./data/dataset/gtFine/2007_000121.png 6 | ./data/dataset/gtFine/2007_000170.png 7 | ./data/dataset/gtFine/2007_000241.png 8 | ./data/dataset/gtFine/2007_000243.png 9 | ./data/dataset/gtFine/2007_000250.png 10 | ./data/dataset/gtFine/2007_000256.png 11 | ./data/dataset/gtFine/2007_000333.png 12 | ./data/dataset/gtFine/2007_000363.png 13 | -------------------------------------------------------------------------------- /data/train/image.txt: -------------------------------------------------------------------------------- 1 | ./data/dataset/JPEGImages/2007_000032.jpg 2 | ./data/dataset/JPEGImages/2007_000039.jpg 3 | ./data/dataset/JPEGImages/2007_000063.jpg 4 | ./data/dataset/JPEGImages/2007_000068.jpg 5 | ./data/dataset/JPEGImages/2007_000121.jpg 6 | ./data/dataset/JPEGImages/2007_000170.jpg 7 | ./data/dataset/JPEGImages/2007_000241.jpg 8 | ./data/dataset/JPEGImages/2007_000243.jpg 9 | ./data/dataset/JPEGImages/2007_000250.jpg 10 | ./data/dataset/JPEGImages/2007_000256.jpg 11 | ./data/dataset/JPEGImages/2007_000333.jpg 12 | ./data/dataset/JPEGImages/2007_000363.jpg -------------------------------------------------------------------------------- /data/train/label.txt: -------------------------------------------------------------------------------- 1 | ./data/dataset/gtFine/2007_000032.png 2 | ./data/dataset/gtFine/2007_000039.png 3 | ./data/dataset/gtFine/2007_000063.png 4 | ./data/dataset/gtFine/2007_000068.png 5 | ./data/dataset/gtFine/2007_000121.png 6 | ./data/dataset/gtFine/2007_000170.png 7 | ./data/dataset/gtFine/2007_000241.png 8 | ./data/dataset/gtFine/2007_000243.png 9 | ./data/dataset/gtFine/2007_000250.png 10 | ./data/dataset/gtFine/2007_000256.png 11 | ./data/dataset/gtFine/2007_000333.png 12 | ./data/dataset/gtFine/2007_000363.png -------------------------------------------------------------------------------- /data/val/image.txt: -------------------------------------------------------------------------------- 1 | ./data/dataset/JPEGImages/2007_000068.jpg 2 | ./data/dataset/JPEGImages/2007_000121.jpg 3 | ./data/dataset/JPEGImages/2007_000243.jpg 4 | ./data/dataset/JPEGImages/2007_000256.jpg 5 | -------------------------------------------------------------------------------- /data/val/label.txt: -------------------------------------------------------------------------------- 1 | ./data/dataset/gtFine/2007_000068.png 2 | ./data/dataset/gtFine/2007_000121.png 3 | ./data/dataset/gtFine/2007_000243.png 4 | ./data/dataset/gtFine/2007_000256.png 5 | 
-------------------------------------------------------------------------------- /dataloader/README.md: -------------------------------------------------------------------------------- 1 | # Something about the dataset 2 | 1. NeoData in dataset.py loads the data (both images and labels) and applies the transforms defined in transform.py. 3 | 4 | 2. You can change the way it works. I store all image paths in '../data/train/image.txt' and all label paths in '../data/train/label.txt'. 5 | 6 | 3. The transforms mainly **Resize**, **Crop** and **Flip** the data. Decide for yourself whether resizing is necessary. 7 | 8 | 4. My data is too large, so I resize it before cropping. Training data uses **RandomCrop**, validation data uses **CenterCrop**. 9 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | 6 | EXTENSIONS = ['.jpg', '.png','.JPG','.PNG'] 7 | 8 | def load_image(file): 9 | return Image.open(file) 10 | 11 | def is_image(filename): 12 | return any(filename.endswith(ext) for ext in EXTENSIONS) 13 | 14 | def image_path(root, basename, extension): 15 | return os.path.join(root, '{}{}'.format(basename,extension)) 16 | 17 | def image_path_city(root, name): 18 | return os.path.join(root, '{}'.format(name)) 19 | 20 | def image_basename(filename): 21 | return os.path.basename(os.path.splitext(filename)[0]) 22 | 23 | class NeoData(Dataset): 24 | def __init__(self, imagepath=None, labelpath=None, transform=None): 25 | # make sure label match with image 26 | self.transform = transform 27 | assert os.path.exists(imagepath), "{} not exists !".format(imagepath) 28 | assert os.path.exists(labelpath), "{} not exists !".format(labelpath) 29 | self.image = [] 30 | self.label= [] 31 | with open(imagepath,'r') as f: 32 | for line in f: 33 | self.image.append(line.strip()) 34 | with open(labelpath,'r') as f: 35 | for line in f: 36 | self.label.append(line.strip()) 37 | 38 | def __getitem__(self, index): 39 | filename = self.image[index] 40 | filenameGt = self.label[index] 41 | 42 | with open(filename, 'rb') as f: 43 | image = load_image(f).convert('RGB') 44 | with open(filenameGt, 'rb') as f: 45 | label = load_image(f).convert('P') 46 | if self.transform is not None: 47 | image, label = self.transform(image, label) 48 | return image, label 49 | 50 | def __len__(self): 51 | return len(self.image) 52 | 53 | class NeoData_test(Dataset): 54 | def __init__(self, imagepath=None, labelpath=None, transform=None): 55 | self.transform = transform 56 | 57 | assert os.path.exists(imagepath), "{} not exists !".format(imagepath) 58 | assert os.path.exists(labelpath), "{} not exists !".format(labelpath) 59 | 60 | self.image = [] 61 | self.label= [] 62 | with open(imagepath,'r') as f: 63 | for line in f: 64 | self.image.append(line.strip()) 65 | with open(labelpath,'r') as f: 66 | for line in f: 67 | self.label.append(line.strip()) 68 | print("Length of test data is {}".format(len(self.image))) 69 | def __getitem__(self, index): 70 | filename = 
self.image[index] 71 | filenameGt = self.label[index] 72 | 73 | with open(filename, 'rb') as f: # advance 74 | image = load_image(f).convert('RGB') 75 | with open(filenameGt, 'rb') as f: 76 | label = load_image(f).convert('P') 77 | 78 | if self.transform is not None: 79 | image_tensor, label_tensor, img = self.transform(image, label) 80 | 81 | return (image_tensor, label_tensor, np.array(img)) #return original image, in order to show segmented area in origin 82 | 83 | def __len__(self): 84 | return len(self.image) 85 | 86 | -------------------------------------------------------------------------------- /dataloader/functional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import numpy as np 5 | from PIL import Image, ImageOps 6 | import numbers 7 | from torchvision.transforms import Pad,RandomHorizontalFlip 8 | from torchvision.transforms import ToTensor, ToPILImage 9 | 10 | 11 | def _is_pil_image(img): 12 | return isinstance(img, Image.Image) 13 | 14 | def crop(img, i, j, h, w): 15 | """Crop the given PIL Image. 16 | Args: 17 | img (PIL Image): Image to be cropped. 18 | i: Upper pixel coordinate. 19 | j: Left pixel coordinate. 20 | h: Height of the cropped image. 21 | w: Width of the cropped image. 22 | Returns: 23 | PIL Image: Cropped image. 24 | """ 25 | if not _is_pil_image(img): 26 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 27 | 28 | return img.crop((j, i, j + w, i + h)) 29 | 30 | class RandomCrop(object): 31 | """Crop the given PIL Image at a random location. 32 | Args: 33 | size (sequence or int): Desired output size of the crop. If size is an 34 | int instead of sequence like (h, w), a square crop (size, size) is 35 | made. 36 | padding (int or sequence, optional): Optional padding on each border 37 | of the image. Default is 0, i.e no padding. If a sequence of length 38 | 4 is provided, it is used to pad left, top, right, bottom borders 39 | respectively. 40 | 41 | #对pytorch包内RandomCrop做了修改,可以同时处理image和label,保证为同一区域。 42 | """ 43 | def __init__(self, size, padding=0): 44 | if isinstance(size, numbers.Number): 45 | self.size = (int(size), int(size)) 46 | else: 47 | self.size = size 48 | self.padding = padding 49 | 50 | @staticmethod 51 | def get_params(img, output_size): 52 | """Get parameters for ``crop`` for a random crop. 53 | Args: 54 | img (PIL Image): Image to be cropped. 55 | output_size (tuple): Expected output size of the crop. 56 | Returns: 57 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 58 | """ 59 | w, h = img.size 60 | tw, th = output_size 61 | if w == tw and h == th: 62 | return 0, 0, h, w 63 | 64 | i = random.randint(0, h - th) 65 | j = random.randint(0, w - tw) 66 | return i, j, th, tw 67 | 68 | def __call__(self, img, label): #crop the same area of ori-image and label 69 | """ 70 | Args: 71 | img (PIL Image): Image to be cropped. 72 | Returns: 73 | PIL Image: Cropped image. 74 | """ 75 | if self.padding > 0: 76 | img = F.pad(img, self.padding) 77 | 78 | i, j, h, w = self.get_params(img, self.size) 79 | 80 | return crop(img, i, j, h, w), crop(label, i, j, h, w) 81 | 82 | def __repr__(self): 83 | return self.__class__.__name__ + '(size={0})'.format(self.size) 84 | 85 | 86 | class RandomFlip(object): 87 | """Randomflip the given PIL Image randomly with a given probability. horizontal or vertical 88 | Args: 89 | p (float): probability of the image being flipped. 
Default value is 0.5 90 | """ 91 | # make sure that crop area of image and label are the same 92 | def __init__(self, p=0.5): 93 | self.p = p 94 | 95 | def __call__(self, img, label): 96 | """ 97 | Args: 98 | img (PIL Image): Image to be flipped. 99 | Returns: 100 | PIL Image: Randomly flipped image. 101 | """ 102 | if random.random() < self.p: 103 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 104 | label = label.transpose(Image.FLIP_LEFT_RIGHT) #left or right 105 | if random.random() < self.p: 106 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 107 | label = label.transpose(Image.FLIP_TOP_BOTTOM) # bottom or top 108 | return img, label 109 | 110 | def __repr__(self): 111 | return self.__class__.__name__ + '(p={})'.format(self.p) 112 | 113 | class CenterCrop(object): 114 | 115 | def __init__(self, size, padding=0): 116 | if isinstance(size, numbers.Number): 117 | self.size = (int(size), int(size)) 118 | else: 119 | self.size = size 120 | self.padding = padding 121 | 122 | def __call__(self, img, label): 123 | """ 124 | Args: 125 | img (PIL Image): Image to be cropped. 126 | Returns: 127 | PIL Image: Cropped image. 128 | """ 129 | 130 | w, h = img.size 131 | th, tw = self.size 132 | i = int(round((h - th) / 2.)) 133 | j = int(round((w - tw) / 2.)) 134 | return crop(img, j,i, tw, th), crop(label, j,i, tw, th) 135 | 136 | 137 | def __repr__(self): 138 | return self.__class__.__name__ + '(size={0})'.format(self.size) 139 | 140 | class RandomRotate(object): 141 | def __init__(self, degree): 142 | self.degree = degree 143 | 144 | def __call__(self, img, mask): 145 | rotate_degree = random.random() * 2 * self.degree - self.degree 146 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 147 | -------------------------------------------------------------------------------- /dataloader/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .functional import RandomCrop, CenterCrop,RandomFlip,RandomRotate 4 | from PIL import Image 5 | import random 6 | from torchvision.transforms import ToTensor, ToPILImage 7 | from torchvision.transforms import Normalize 8 | def colormap_cityscapes(n): 9 | cmap=np.zeros([n, 3]).astype(np.uint8) 10 | cmap[0,:] = np.array([128, 64,128]) 11 | cmap[1,:] = np.array([244, 35,232]) 12 | cmap[2,:] = np.array([ 70, 70, 70]) 13 | cmap[3,:] = np.array([ 102,102,156]) 14 | cmap[4,:] = np.array([ 190,153,153]) 15 | cmap[5,:] = np.array([ 153,153,153]) 16 | 17 | cmap[6,:] = np.array([ 250,170, 30]) 18 | cmap[7,:] = np.array([ 220,220, 0]) 19 | cmap[8,:] = np.array([ 107,142, 35]) 20 | cmap[9,:] = np.array([ 152,251,152]) 21 | cmap[10,:] = np.array([ 70,130,180]) 22 | 23 | cmap[11,:] = np.array([ 220, 20, 60]) 24 | cmap[12,:] = np.array([ 255, 0, 0]) 25 | cmap[13,:] = np.array([ 0, 0,142]) 26 | cmap[14,:] = np.array([ 0, 0, 70]) 27 | cmap[15,:] = np.array([ 0, 60,100]) 28 | 29 | cmap[16,:] = np.array([ 0, 80,100]) 30 | cmap[17,:] = np.array([ 0, 0,230]) 31 | cmap[18,:] = np.array([ 119, 11, 32]) 32 | cmap[19,:] = np.array([ 0, 0, 0]) 33 | 34 | return cmap 35 | 36 | def colormap(n): 37 | cmap=np.zeros([n, 3]).astype(np.uint8) 38 | for i in np.arange(n): 39 | r, g, b = np.zeros(3) 40 | for j in np.arange(8): 41 | r = r + (1<<(7-j))*((i&(1<<(3*j))) >> (3*j)) 42 | g = g + (1<<(7-j))*((i&(1<<(3*j+1))) >> (3*j+1)) 43 | b = b + (1<<(7-j))*((i&(1<<(3*j+2))) >> (3*j+2)) 44 | cmap[i,:] = np.array([r, g, b]) 45 | return cmap 46 | 47 | class Relabel: 48 | def 
__init__(self, olabel, nlabel): 49 | self.olabel = olabel 50 | self.nlabel = nlabel 51 | def __call__(self, tensor): 52 | assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, torch.ByteTensor)) , 'tensor needs to be LongTensor' 53 | tensor[tensor == self.olabel] = self.nlabel 54 | return tensor 55 | 56 | class ToLabel: 57 | def __call__(self, image): 58 | return torch.from_numpy(np.array(image)).long().unsqueeze(0) #np.array change the size of image 59 | 60 | class Colorize: 61 | def __init__(self, n=22): 62 | #self.cmap = colormap(256) 63 | self.cmap = colormap_cityscapes(256) 64 | self.cmap[n] = self.cmap[-1] 65 | self.cmap = torch.from_numpy(self.cmap[:n]) 66 | 67 | def __call__(self, gray_image): 68 | size = gray_image.size() 69 | #print(size) 70 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 71 | #color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) 72 | 73 | #for label in range(1, len(self.cmap)): 74 | for label in range(0, len(self.cmap)): 75 | mask = gray_image[0] == label 76 | #mask = gray_image == label 77 | color_image[0][mask] = self.cmap[label][0] 78 | color_image[1][mask] = self.cmap[label][1] 79 | color_image[2][mask] = self.cmap[label][2] 80 | return color_image 81 | 82 | class MyTransform(object): 83 | ''' 84 | 1. self-define transform rules, including resize, crop, flip. (crop and flip only for training set) 85 | 2. training set augmentation with RandomCrop and RandomFlip. 86 | 3. validation set using CenterCrop 87 | ''' 88 | def __init__(self,reshape_size=None, crop_size = None , augment=True): 89 | self.reshape_size = reshape_size 90 | self.crop_size = crop_size 91 | self.augment = augment 92 | self.flip = RandomFlip() 93 | self.rotate = RandomRotate(32) 94 | 95 | self.count = 0 96 | def __call__(self, input, target): 97 | # do something to both images and labels 98 | if self.reshape_size is not None: 99 | input = input.resize(self.reshape_size,Image.BILINEAR) 100 | target = target.resize(self.reshape_size,Image.NEAREST) 101 | 102 | if self.augment : 103 | input, target = RandomCrop(self.crop_size)(input,target) # RandomCrop for image and label in the same area 104 | input, target = self.flip(input,target) # RandomFlip for both croped image and label 105 | input, target = self.rotate(input,target) 106 | else: 107 | input, target = CenterCrop(self.crop_size)(input, target) # CenterCrop for the validation data 108 | 109 | input = ToTensor()(input) 110 | Normalize([.485, .456, .406], [.229, .224, .225])(input) #normalize with the params of imagenet 111 | 112 | target = torch.from_numpy(np.array(target)).long().unsqueeze(0) 113 | 114 | return input, target 115 | 116 | class Transform_test(object): 117 | ''' 118 | Transform for test data.Reshape size is difined in ./options/test_options.py 119 | ''' 120 | def __init__(self,size): 121 | self.size = size 122 | def __call__(self, input, target): 123 | # do something to both images 124 | input = input.resize(self.size, Image.BILINEAR) 125 | target = target.resize(self.size,Image.NEAREST) 126 | 127 | target = torch.from_numpy(np.array(target)).long().unsqueeze(0) 128 | input_tensor = ToTensor()(input) 129 | Normalize([.485, .456, .406], [.229, .224, .225])(input_tensor) 130 | return input_tensor, target, input 131 | 132 | def img2label(img,label,count): 133 | count+=1 134 | img = np.array(img) 135 | label = np.array(label) 136 | for i in range(label.shape[0]): 137 | for j in range(label.shape[1]): 138 | if label[i,j]==0: 139 | img[i,j,:]=0 140 | image = ToPILImage()(img) 141 | 
image.save('./results/imglabel_'+str(count)+'.jpg') 142 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from torch.autograd import Variable 4 | from torchvision.transforms import ToPILImage 5 | from utils import evalIoU 6 | 7 | def eval(args, model, loader_val, criterion, epoch): 8 | print("----- VALIDATING - EPOCH", epoch, "-----") 9 | model.eval() 10 | epoch_loss_val = [] 11 | time_val = [] 12 | 13 | #New confusion matrix 14 | confMatrix = evalIoU.generateMatrixTrainId(evalIoU.args) 15 | perImageStats = {} 16 | nbPixels = 0 17 | 18 | for step, (images, labels) in enumerate(loader_val): 19 | start_time = time.time() 20 | if args.cuda: 21 | images = images.cuda() 22 | labels = labels.cuda() 23 | inputs = Variable(images, volatile=True) 24 | targets = Variable(labels, volatile=True) 25 | 26 | outputs = model(inputs) 27 | loss = criterion(outputs, targets[:, 0]) 28 | 29 | epoch_loss_val.append(loss.data[0]) 30 | time_val.append(time.time() - start_time) 31 | 32 | average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val) 33 | 34 | if args.iouVal: # add to confMatrix 35 | add_to_confMatrix(outputs, labels,confMatrix, perImageStats, nbPixels) 36 | 37 | if args.steps_loss > 0 and step % args.steps_loss == 0: 38 | average = sum(epoch_loss_val) / len(epoch_loss_val) 39 | print('VAL loss: {} (epoch: {}, step: {})'.format(average,epoch,step), 40 | "// Avg time/img: %.4f s" % (sum(time_val) / len(time_val) / args.batch_size)) 41 | 42 | average_epoch_loss_train = sum(epoch_loss_val) / len(epoch_loss_val) 43 | iouAvgStr, iouVal, classScoreList = cal_iou(evalIoU, confMatrix) 44 | print ("EPOCH IoU on VAL set: ", iouAvgStr) 45 | 46 | return average_epoch_loss_val, iouVal 47 | 48 | def add_to_confMatrix(prediction, groundtruth, confMatrix, perImageStats, nbPixels): 49 | if isinstance(prediction, list): #merge multi-gpu tensors 50 | outputs_cpu = prediction[0].cpu() 51 | for i in range(1,len(outputs)): 52 | outputs_cpu = torch.cat((outputs_cpu, prediction[i].cpu()), 0) 53 | else: 54 | outputs_cpu = prediction.cpu() 55 | for i in range(0, outputs_cpu.size(0)): #args.batch_size,evaluate iou of each batch 56 | prediction = ToPILImage()(outputs_cpu[i].max(0)[1].data.unsqueeze(0).byte()) 57 | groundtruth_image = ToPILImage()(groundtruth[i].cpu().byte()) 58 | nbPixels += evalIoU.evaluatePairPytorch(prediction, groundtruth_image, confMatrix, perImageStats, evalIoU.args) 59 | 60 | def cal_iou(evalIoU, confMatrix): 61 | iou = 0 62 | classScoreList = {} 63 | for label in evalIoU.args.evalLabels: 64 | labelName = evalIoU.trainId2label[label].name 65 | classScoreList[labelName] = evalIoU.getIouScoreForTrainLabel(label, confMatrix, evalIoU.args) 66 | 67 | iouAvgStr = evalIoU.getColorEntry(evalIoU.getScoreAverage(classScoreList, evalIoU.args), evalIoU.args) + "{avg:5.3f}".format(avg=evalIoU.getScoreAverage(classScoreList, evalIoU.args)) + evalIoU.args.nocol 68 | iou = float(evalIoU.getScoreAverage(classScoreList, evalIoU.args)) 69 | return iouAvgStr, iou, classScoreList 70 | 71 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fcn import FCN8, FCN16, FCN32 2 | from .erfnet import ERFNet 3 | from .pspnet import PSPNet 4 | from .segnet import SegNet 5 | from .unet import UNet 6 | from .utils import * 7 | 8 | 
net_dic = {'erfnet' : ERFNet, 'fcn8' : FCN8, 'fcn16' : FCN16, 9 | 'fcn32' : FCN32, 'unet' : UNet, 'pspnet': PSPNet, 'segnet' : SegNet} 10 | 11 | 12 | def get_model(args): 13 | 14 | Net = net_dic[args.model] 15 | model = Net(args.num_classes) 16 | model.apply(weights_init) 17 | return model 18 | -------------------------------------------------------------------------------- /networks/erfnet.py: -------------------------------------------------------------------------------- 1 | # ERFNet full model definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | # changed the stride of downsampler 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.pool = nn.MaxPool2d(2, stride=2) 17 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 18 | 19 | 20 | 21 | def forward(self, input): 22 | #output = self.conv(input) 23 | output = torch.cat([self.conv(input), self.pool(input)], 1) ###concatnate 24 | output = self.bn(output) 25 | return F.relu(output) 26 | 27 | 28 | class non_bottleneck_1d (nn.Module): 29 | def __init__(self, chann, dropprob, dilated): 30 | super().__init__() 31 | 32 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 33 | 34 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 35 | 36 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 37 | 38 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 39 | 40 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 41 | 42 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 43 | 44 | self.dropout = nn.Dropout2d(dropprob) 45 | 46 | def forward(self, input): 47 | 48 | output = self.conv3x1_1(input) 49 | output = F.relu(output) 50 | output = self.conv1x3_1(output) 51 | output = self.bn1(output) 52 | output = F.relu(output) 53 | 54 | output = self.conv3x1_2(output) 55 | output = F.relu(output) 56 | output = self.conv1x3_2(output) 57 | output = self.bn2(output) 58 | 59 | if (self.dropout.p != 0): 60 | output = self.dropout(output) 61 | 62 | return F.relu(output+input) #+input = identity (residual connection) 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__(self, num_classes): 67 | super().__init__() 68 | self.initial_block = DownsamplerBlock(3,16) 69 | 70 | self.layers = nn.ModuleList() 71 | 72 | self.layers.append(DownsamplerBlock(16,64)) 73 | 74 | for x in range(0, 5): #5 times 75 | self.layers.append(non_bottleneck_1d(64, 0.03, 1)) #0.03 76 | 77 | self.layers.append(DownsamplerBlock(64,128)) 78 | 79 | for x in range(0, 2): #2 times 80 | self.layers.append(non_bottleneck_1d(128, 0.3, 2)) #0.3 81 | self.layers.append(non_bottleneck_1d(128, 0.3, 4)) #0.3 82 | self.layers.append(non_bottleneck_1d(128, 0.3, 8)) #0.3 83 | self.layers.append(non_bottleneck_1d(128, 0.3, 16)) #0.3 84 | 85 | #Only in encoder mode: 86 | self.output_conv = nn.Conv2d(128, num_classes, 1, stride=1, padding=0, bias=True) 87 | 88 | def forward(self, input, predict=False): 89 | 90 | #print("Before input of Encoder:{}".format(input.data.shape)) 91 | output = self.initial_block(input) 92 | 93 | for layer in self.layers: 94 | output = layer(output) 95 | 96 | if predict: 97 | output = 
self.output_conv(output) 98 | 99 | #print("After forward of Encoder:{}".format(output.data.shape)) 100 | 101 | return output 102 | 103 | 104 | class UpsamplerBlock (nn.Module): 105 | def __init__(self, ninput, noutput): 106 | super().__init__() 107 | self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, padding=1, output_padding=1, bias=True) 108 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 109 | 110 | 111 | def forward(self, input): 112 | output = self.conv(input) 113 | output = self.bn(output) 114 | return F.relu(output) 115 | 116 | class Decoder (nn.Module): 117 | def __init__(self, num_classes): 118 | super().__init__() 119 | 120 | self.layers = nn.ModuleList() 121 | 122 | self.layers.append(UpsamplerBlock(128,64)) 123 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 124 | self.layers.append(non_bottleneck_1d(64, 0, 1)) 125 | 126 | self.layers.append(UpsamplerBlock(64,16)) 127 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 128 | self.layers.append(non_bottleneck_1d(16, 0, 1)) 129 | 130 | self.output_conv = nn.ConvTranspose2d( 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) 131 | 132 | 133 | 134 | def forward(self, input): 135 | output = input 136 | 137 | for layer in self.layers: 138 | output = layer(output) 139 | 140 | output = self.output_conv(output) 141 | 142 | return output 143 | 144 | 145 | class ERFNet(nn.Module): 146 | def __init__(self, num_classes, encoder=None): #use encoder to pass pretrained encoder 147 | super().__init__() 148 | 149 | if (encoder == None): 150 | self.encoder = Encoder(num_classes) 151 | else: 152 | self.encoder = encoder 153 | self.decoder = Decoder(num_classes) 154 | 155 | def forward(self, input, only_encode=False): ####encodr and decoder are seperated!!! 156 | if only_encode: 157 | 158 | return self.encoder.forward(input, predict=True) 159 | else: 160 | output = self.encoder.forward(input) 161 | 162 | return self.decoder.forward(output) 163 | -------------------------------------------------------------------------------- /networks/fcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | 8 | class FCN8(nn.Module): 9 | 10 | def __init__(self, num_classes): 11 | super().__init__() 12 | 13 | feats = list(models.vgg16(pretrained=True).features.children()) 14 | 15 | self.feats = nn.Sequential(*feats[0:9]) 16 | self.feat3 = nn.Sequential(*feats[10:16]) 17 | self.feat4 = nn.Sequential(*feats[17:23]) 18 | self.feat5 = nn.Sequential(*feats[24:30]) 19 | 20 | for m in self.modules(): 21 | if isinstance(m, nn.Conv2d): 22 | m.requires_grad = False 23 | 24 | self.fconn = nn.Sequential( 25 | nn.Conv2d(512, 4096, 7), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout(), 28 | nn.Conv2d(4096, 4096, 1), 29 | nn.ReLU(inplace=True), 30 | nn.Dropout(), 31 | ) 32 | self.score_feat3 = nn.Conv2d(256, num_classes, 1) 33 | self.score_feat4 = nn.Conv2d(512, num_classes, 1) 34 | self.score_fconn = nn.Conv2d(4096, num_classes, 1) 35 | 36 | def forward(self, x): 37 | feats = self.feats(x) 38 | feat3 = self.feat3(feats) 39 | feat4 = self.feat4(feat3) 40 | feat5 = self.feat5(feat4) 41 | fconn = self.fconn(feat5) 42 | 43 | score_feat3 = self.score_feat3(feat3) 44 | score_feat4 = self.score_feat4(feat4) 45 | score_fconn = self.score_fconn(fconn) 46 | 47 | score = F.upsample_bilinear(score_fconn, score_feat4.size()[2:]) 48 | score += 
score_feat4 49 | score = F.upsample_bilinear(score, score_feat3.size()[2:]) 50 | score += score_feat3 51 | 52 | return F.upsample_bilinear(score, x.size()[2:]) 53 | 54 | 55 | class FCN16(nn.Module): 56 | 57 | def __init__(self, num_classes): 58 | super().__init__() 59 | 60 | feats = list(models.vgg16(pretrained=True).features.children()) 61 | self.feats = nn.Sequential(*feats[0:16]) 62 | self.feat4 = nn.Sequential(*feats[17:23]) 63 | self.feat5 = nn.Sequential(*feats[24:30]) 64 | self.fconn = nn.Sequential( 65 | nn.Conv2d(512, 4096, 7), 66 | nn.ReLU(inplace=True), 67 | nn.Dropout(), 68 | nn.Conv2d(4096, 4096, 1), 69 | nn.ReLU(inplace=True), 70 | nn.Dropout(), 71 | ) 72 | self.score_fconn = nn.Conv2d(4096, num_classes, 1) 73 | self.score_feat4 = nn.Conv2d(512, num_classes, 1) 74 | 75 | def forward(self, x): 76 | feats = self.feats(x) 77 | feat4 = self.feat4(feats) 78 | feat5 = self.feat5(feat4) 79 | fconn = self.fconn(feat5) 80 | 81 | score_feat4 = self.score_feat4(feat4) 82 | score_fconn = self.score_fconn(fconn) 83 | 84 | score = F.upsample_bilinear(score_fconn, score_feat4.size()[2:]) 85 | score += score_feat4 86 | 87 | return F.upsample_bilinear(score, x.size()[2:]) 88 | 89 | 90 | class FCN32(nn.Module): 91 | 92 | def __init__(self, num_classes): 93 | super().__init__() 94 | 95 | self.feats = models.vgg16(pretrained=True).features 96 | self.fconn = nn.Sequential( 97 | nn.Conv2d(512, 4096, 7), 98 | nn.ReLU(inplace=True), 99 | nn.Dropout(), 100 | nn.Conv2d(4096, 4096, 1), 101 | nn.ReLU(inplace=True), 102 | nn.Dropout(), 103 | ) 104 | self.score = nn.Conv2d(4096, num_classes, 1) 105 | 106 | def forward(self, x): 107 | feats = self.feats(x) 108 | fconn = self.fconn(feats) 109 | score = self.score(fconn) 110 | 111 | return F.upsample_bilinear(score, x.size()[2:]) 112 | -------------------------------------------------------------------------------- /networks/pspnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.utils import model_zoo 6 | from torchvision import models 7 | 8 | class PSPDec(nn.Module): 9 | 10 | def __init__(self, in_features, out_features, downsize, upsize=60): 11 | super().__init__() 12 | 13 | self.features = nn.Sequential( 14 | nn.AvgPool2d(downsize, stride=downsize), 15 | nn.Conv2d(in_features, out_features, 1, bias=False), 16 | nn.BatchNorm2d(out_features, momentum=.95), 17 | nn.ReLU(inplace=True), 18 | nn.UpsamplingBilinear2d(upsize) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.features(x) 23 | 24 | 25 | class PSPNet(nn.Module): 26 | 27 | def __init__(self, num_classes): 28 | super().__init__() 29 | 30 | resnet = models.resnet101(pretrained=True) 31 | 32 | self.conv1 = resnet.conv1 33 | self.layer1 = resnet.layer1 34 | self.layer2 = resnet.layer2 35 | self.layer3 = resnet.layer3 36 | self.layer4 = resnet.layer4 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | m.stride = 1 41 | m.requires_grad = False 42 | if isinstance(m, nn.BatchNorm2d): 43 | m.requires_grad = False 44 | 45 | self.layer5a = PSPDec(2048, 512, 60) 46 | self.layer5b = PSPDec(2048, 512, 30) 47 | self.layer5c = PSPDec(2048, 512, 20) 48 | self.layer5d = PSPDec(2048, 512, 10) 49 | 50 | self.final = nn.Sequential( 51 | nn.Conv2d(2048, 512, 3, padding=1, bias=False), 52 | nn.BatchNorm2d(512, momentum=.95), 53 | nn.ReLU(inplace=True), 54 | nn.Dropout(.1), 55 | nn.Conv2d(512, num_classes, 1), 56 | ) 57 | 58 | def 
forward(self, x): 59 | print('x', x.size()) 60 | x = self.conv1(x) 61 | print('conv1', x.size()) 62 | x = self.layer1(x) 63 | print('layer1', x.size()) 64 | x = self.layer2(x) 65 | print('layer2', x.size()) 66 | x = self.layer3(x) 67 | print('layer3', x.size()) 68 | x = self.layer4(x) 69 | print('layer4', x.size()) 70 | final = self.final(torch.cat([ 71 | x, 72 | self.layer5a(x), 73 | self.layer5b(x), 74 | self.layer5c(x), 75 | self.layer5d(x), 76 | ], 1)) 77 | print('final', final.size()) 78 | # upsample class scores back to the backbone feature-map resolution 79 | return F.upsample_bilinear(final, x.size()[2:]) 80 | -------------------------------------------------------------------------------- /networks/segnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | from torch.utils import model_zoo 7 | from torchvision import models 8 | 9 | class SegNetEnc(nn.Module): 10 | 11 | def __init__(self, in_channels, out_channels, num_layers): 12 | super().__init__() 13 | 14 | layers = [ 15 | nn.Upsample(scale_factor=2,mode='bilinear'), 16 | nn.Conv2d(in_channels, in_channels // 2, 3, padding=1), 17 | nn.BatchNorm2d(in_channels // 2), 18 | nn.ReLU(inplace=True), 19 | ] 20 | layers += [ 21 | nn.Conv2d(in_channels // 2, in_channels // 2, 3, padding=1), 22 | nn.BatchNorm2d(in_channels // 2), 23 | nn.ReLU(inplace=True), 24 | ] * num_layers 25 | layers += [ 26 | nn.Conv2d(in_channels // 2, out_channels, 3, padding=1), 27 | nn.BatchNorm2d(out_channels), 28 | nn.ReLU(inplace=True), 29 | ] 30 | self.encode = nn.Sequential(*layers) 31 | 32 | def forward(self, x): 33 | return self.encode(x) 34 | 35 | 36 | class SegNet(nn.Module): 37 | 38 | def __init__(self, num_classes): 39 | super().__init__() 40 | 41 | decoders = list(models.vgg16(pretrained=True).features.children()) 42 | 43 | self.dec1 = nn.Sequential(*decoders[:5]) 44 | self.dec2 = nn.Sequential(*decoders[5:10]) 45 | self.dec3 = nn.Sequential(*decoders[10:17]) 46 | self.dec4 = nn.Sequential(*decoders[17:24]) 47 | self.dec5 = nn.Sequential(*decoders[24:]) 48 | 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | m.requires_grad = False 52 | 53 | self.enc5 = SegNetEnc(512, 512, 1) 54 | self.enc4 = SegNetEnc(1024, 256, 1) 55 | self.enc3 = SegNetEnc(512, 128, 1) 56 | self.enc2 = SegNetEnc(256, 64, 0) 57 | self.enc1 = nn.Sequential( 58 | nn.Upsample(scale_factor=2,mode='bilinear'), 59 | nn.Conv2d(128, 64, 3, padding=1), 60 | nn.BatchNorm2d(64), 61 | nn.ReLU(inplace=True), 62 | ) 63 | self.final = nn.Conv2d(64, num_classes, 3, padding=1) 64 | 65 | def forward(self, x): 66 | ''' 67 | Note: the input size should be a multiple of 32. 
68 | ''' 69 | dec1 = self.dec1(x) 70 | dec2 = self.dec2(dec1) 71 | dec3 = self.dec3(dec2) 72 | dec4 = self.dec4(dec3) 73 | dec5 = self.dec5(dec4) 74 | enc5 = self.enc5(dec5) 75 | 76 | enc4 = self.enc4(torch.cat([dec4, enc5], 1)) 77 | enc3 = self.enc3(torch.cat([dec3, enc4], 1)) 78 | enc2 = self.enc2(torch.cat([dec2, enc3], 1)) 79 | enc1 = self.enc1(torch.cat([dec1, enc2], 1)) 80 | 81 | return F.upsample_bilinear(self.final(enc1), x.size()[2:]) 82 | -------------------------------------------------------------------------------- /networks/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | from torch.utils import model_zoo 7 | from torchvision import models 8 | 9 | class UNetEnc(nn.Module): 10 | 11 | def __init__(self, in_channels, features, out_channels): 12 | super().__init__() 13 | 14 | self.up = nn.Sequential( 15 | nn.Conv2d(in_channels, features, 3), 16 | nn.ReLU(inplace=True), 17 | nn.Conv2d(features, features, 3), 18 | nn.ReLU(inplace=True), 19 | nn.ConvTranspose2d(features, out_channels, 2, stride=2), 20 | nn.ReLU(inplace=True), 21 | ) 22 | 23 | def forward(self, x): 24 | return self.up(x) 25 | 26 | 27 | class UNetDec(nn.Module): 28 | 29 | def __init__(self, in_channels, out_channels, dropout=False): 30 | super().__init__() 31 | 32 | layers = [ 33 | nn.Conv2d(in_channels, out_channels, 3), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(out_channels, out_channels, 3), 36 | nn.ReLU(inplace=True), 37 | ] 38 | if dropout: 39 | layers += [nn.Dropout(.5)] 40 | layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)] 41 | 42 | self.down = nn.Sequential(*layers) 43 | 44 | def forward(self, x): 45 | return self.down(x) 46 | 47 | 48 | class UNet(nn.Module): 49 | 50 | def __init__(self, num_classes): 51 | super().__init__() 52 | 53 | self.dec1 = UNetDec(3, 64) 54 | self.dec2 = UNetDec(64, 128) 55 | self.dec3 = UNetDec(128, 256) 56 | self.dec4 = UNetDec(256, 512, dropout=True) 57 | self.center = nn.Sequential( 58 | nn.Conv2d(512, 1024, 3), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(1024, 1024, 3), 61 | nn.ReLU(inplace=True), 62 | nn.Dropout(), 63 | nn.ConvTranspose2d(1024, 512, 2, stride=2), 64 | nn.ReLU(inplace=True), 65 | ) 66 | self.enc4 = UNetEnc(1024, 512, 256) 67 | self.enc3 = UNetEnc(512, 256, 128) 68 | self.enc2 = UNetEnc(256, 128, 64) 69 | self.enc1 = nn.Sequential( 70 | nn.Conv2d(128, 64, 3), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(64, 64, 3), 73 | nn.ReLU(inplace=True), 74 | ) 75 | self.final = nn.Conv2d(64, num_classes, 1) 76 | 77 | def forward(self, x): 78 | dec1 = self.dec1(x) 79 | dec2 = self.dec2(dec1) 80 | dec3 = self.dec3(dec2) 81 | dec4 = self.dec4(dec3) 82 | center = self.center(dec4) 83 | enc4 = self.enc4(torch.cat([ 84 | center, F.upsample_bilinear(dec4, center.size()[2:])], 1)) 85 | enc3 = self.enc3(torch.cat([ 86 | enc4, F.upsample_bilinear(dec3, enc4.size()[2:])], 1)) 87 | enc2 = self.enc2(torch.cat([ 88 | enc3, F.upsample_bilinear(dec2, enc3.size()[2:])], 1)) 89 | enc1 = self.enc1(torch.cat([ 90 | enc2, F.upsample_bilinear(dec1, enc2.size()[2:])], 1)) 91 | 92 | return F.upsample_bilinear(self.final(enc1), x.size()[2:]) 93 | -------------------------------------------------------------------------------- /networks/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def weights_init(m): 3 | classname = m.__class__.__name__ 4 | if classname.find('Conv') != -1: 5 
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 6 | m.weight.data.normal_(0, math.sqrt(2. / n)) 7 | elif classname.find('BatchNorm') != -1: 8 | m.weight.data.fill_(1) 9 | m.bias.data.fill_(0) 10 | -------------------------------------------------------------------------------- /options/README.md: -------------------------------------------------------------------------------- 1 | # Options 2 | > options for train and test. you can modify or add new items. 3 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mapleneverfade/pytorch-semantic-segmentation/7469de95cdb0fbfe9b00b93a8b068c35d398c6cf/options/__init__.py -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | import argparse 3 | import os 4 | 5 | class TestOptions(): 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | self.initialized = False 9 | 10 | def initialize(self): 11 | self.parser.add_argument('--cuda', action='store_true', default=True) 12 | self.parser.add_argument('--model', default="segnet", help='model to train,options:fcn8,segnet...') 13 | self.parser.add_argument('--model-dir', default="./save_models/", help='path to stored-model') 14 | self.parser.add_argument('--num-classes', type=int, default=21) 15 | self.parser.add_argument('--datadir', default="./data/test/",help='path where image.txt and label.txt lies') 16 | self.parser.add_argument('--size', default=(672,480), help='resize the test image') 17 | self.parser.add_argument('--stored',default=True, help='whether or not store the result') 18 | self.parser.add_argument('--savedir', type=str, default='./results/',help='options. 
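
weights_init above applies the He/Kaiming fan-out rule (std = sqrt(2/n) with n = kH * kW * out_channels) to every convolution and resets BatchNorm layers to identity. It is intended to be handed to Module.apply; a minimal sketch, assuming the repo root is importable:

```python
# sketch: applying the initializer above to a freshly built network.
# Module.apply() visits every sub-module, including ConvTranspose2d layers,
# whose class names also contain 'Conv'.
from networks.unet import UNet
from networks.utils import weights_init

model = UNet(num_classes=21)
model.apply(weights_init)
```
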
visualize the result of segmented picture, not just show IoU') 19 | 20 | self.initialized = True 21 | 22 | def parse(self): 23 | if not self.initialized: 24 | self.initialize() 25 | self.opt = self.parser.parse_args() 26 | args = vars(self.opt) 27 | 28 | print('------------ Options -------------') 29 | for k, v in sorted(args.items()): 30 | print('%s: %s' % (str(k), str(v))) 31 | print('-------------- End ----------------') 32 | 33 | return self.opt 34 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | import argparse 3 | import os 4 | 5 | class TrainOptions(): 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | self.initialized = False 9 | 10 | def initialize(self): 11 | self.parser.add_argument('--cuda', action='store_true', default=True) 12 | self.parser.add_argument('--model', default="erfnet", help='model to train,options:fcn8,segnet...') 13 | self.parser.add_argument('--state') 14 | self.parser.add_argument('--num-classes', type=int, default=21) 15 | self.parser.add_argument('--datadir', default="./data/",help='path where image.txt and label.txt lies') 16 | self.parser.add_argument('--savedir', type=str, default='./save_models/',help='savedir for models') 17 | self.parser.add_argument('--lr', type=float, default=5e-4) 18 | self.parser.add_argument('--num-epochs', type=int, default=150) 19 | self.parser.add_argument('--num-workers', type=int, default=4) 20 | self.parser.add_argument('--batch-size', type=int, default=1) 21 | self.parser.add_argument('--epoch-save', type=int, default=5) #You can use this value to save model every X epochs 22 | self.parser.add_argument('--iouTrain', action='store_true', default=True) #recommended: False (takes a lot to train otherwise) 23 | self.parser.add_argument('--iouVal', action='store_true', default=True) #calculating IoU 24 | self.parser.add_argument('--steps-loss', type=int, default=5) 25 | self.parser.add_argument('--pretrained',type=str, default='./pre_trained/~~~.pth') 26 | self.parser.add_argument('--resume', action='store_true', default= False) 27 | 28 | self.initialized = True 29 | 30 | def parse(self): 31 | if not self.initialized: 32 | self.initialize() 33 | self.opt = self.parser.parse_args() 34 | args = vars(self.opt) 35 | 36 | print('------------ Options -------------') 37 | for k, v in sorted(args.items()): 38 | print('%s: %s' % (str(k), str(v))) 39 | print('-------------- End ----------------') 40 | 41 | return self.opt 42 | -------------------------------------------------------------------------------- /save_models/README.md: -------------------------------------------------------------------------------- 1 | # save_models stores the models during training 2 | > ./fcn16 3 | -------------------------------------------------------------------------------- /save_models/fcn16/automated_log.txt: -------------------------------------------------------------------------------- 1 | Epoch Train-loss Test-loss Train-IoU Test-IoU learningRate 2 | 1 9822122.8923 30.9549 0.0000 0.0000 0.00500000 -------------------------------------------------------------------------------- /save_models/fcn16/model.txt: -------------------------------------------------------------------------------- 1 | FCN16( 2 | (feats): Sequential( 3 | (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 4 | (1): 
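
Both option classes build an argparse parser, print the resolved settings, and return the Namespace. A minimal sketch of how they are consumed (train.py and test.py further below do the same):

```python
# sketch: parsing the training options defined above
from options.train_options import TrainOptions

args = TrainOptions().parse()        # reads sys.argv and prints the option table
print(args.model, args.lr, args.num_epochs, args.savedir)
```
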
ReLU(inplace) 5 | (2): Conv2d (64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 6 | (3): ReLU(inplace) 7 | (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 8 | (5): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 9 | (6): ReLU(inplace) 10 | (7): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 11 | (8): ReLU(inplace) 12 | (9): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1)) 13 | (10): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 14 | (11): ReLU(inplace) 15 | (12): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 16 | (13): ReLU(inplace) 17 | (14): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 18 | (15): ReLU(inplace) 19 | ) 20 | (feat4): Sequential( 21 | (0): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 22 | (1): ReLU(inplace) 23 | (2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 24 | (3): ReLU(inplace) 25 | (4): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 26 | (5): ReLU(inplace) 27 | ) 28 | (feat5): Sequential( 29 | (0): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 30 | (1): ReLU(inplace) 31 | (2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 32 | (3): ReLU(inplace) 33 | (4): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 34 | (5): ReLU(inplace) 35 | ) 36 | (fconn): Sequential( 37 | (0): Conv2d (512, 4096, kernel_size=(7, 7), stride=(1, 1)) 38 | (1): ReLU(inplace) 39 | (2): Dropout(p=0.5) 40 | (3): Conv2d (4096, 4096, kernel_size=(1, 1), stride=(1, 1)) 41 | (4): ReLU(inplace) 42 | (5): Dropout(p=0.5) 43 | ) 44 | (score_fconn): Conv2d (4096, 21, kernel_size=(1, 1), stride=(1, 1)) 45 | (score_feat4): Conv2d (512, 21, kernel_size=(1, 1), stride=(1, 1)) 46 | ) -------------------------------------------------------------------------------- /save_models/fcn16/opts.txt: -------------------------------------------------------------------------------- 1 | Namespace(batch_size=1, cuda=True, datadir='./data/train/', epoch_save=5, iouTrain=True, iouVal=True, lr=0.005, model='fcn16', num_classes=21, num_epochs=150, num_workers=4, pretrained='./pre_trained/~~~.pth', resume=False, savedir='./save_models/fcn16', state=None, steps_loss=5) -------------------------------------------------------------------------------- /split_train_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn.model_selection import train_test_split 3 | from argparse import ArgumentParser 4 | 5 | ''' 6 | Split dataset into train-data and val-data,test-data option --imagepath and --labelpath is the path to your original [image.txt and label.txt], 7 | After split, new train-data will stored in './data/train/[image.txt,label.txt]', val-data will stored in './data/val/[image.txt,label.txt]' 8 | test-data in './data/test/[image.txt,label.txt]' 9 | ''' 10 | def split_train_val(args): 11 | imagepath = args.imagepath 12 | labelpath = args.labelpath 13 | assert os.path.exists(imagepath), "{} is not exists!".format(imagepath) 14 | assert os.path.exists(labelpath), "{} is not exists!".format(labelpath) 15 | image = [] 16 | label = [] 17 | with open(imagepath,'r') as f: 18 | for line in f: 19 | image.append(line.strip()) 20 | with open(labelpath,'r') as f: 21 | for line in f: 22 | label.append(line.strip()) 23 | 24 | #split dataset in train/ test/ val = 7: 2: 1 25 | image_train, image_val, 
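
split_train_val.py reads image.txt and label.txt line by line, so each file is expected to hold one path per line with images and labels in matching order. A small sketch for generating those lists is below; the folder names follow the data/dataset/ layout of this repo, but the exact paths are an assumption.

```python
# sketch (assumed layout): build the image.txt / label.txt lists that
# split_train_val.py reads -- one path per line, in matching order
import glob
import os

images = sorted(glob.glob('./data/dataset/JPEGImages/*.jpg'))
labels = sorted(glob.glob('./data/dataset/gtFine/*.png'))
assert len(images) == len(labels), 'every image needs a matching label'

with open('./data/image.txt', 'w') as f:
    f.writelines(os.path.abspath(p) + '\n' for p in images)
with open('./data/label.txt', 'w') as f:
    f.writelines(os.path.abspath(p) + '\n' for p in labels)
```
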
label_train, label_val = train_test_split(image,label,random_state=args.random_state,train_size=args.train_size,test_size=args.val_size) 26 | image_test, image_val, label_test, label_val = train_test_split(image_val,label_val,random_state=args.random_state,train_size=args.train_size,test_size=args.val_size) 27 | 28 | if not os.path.exists(os.path.join(args.savedir,'train')): 29 | os.mkdir(os.path.join(args.savedir,'train')) 30 | if not os.path.exists(os.path.join(args.savedir,'val')): 31 | os.mkdir(os.path.join(args.savedir,'val')) 32 | if not os.path.exists(os.path.join(args.savedir,'test')): 33 | os.mkdir(os.path.join(args.savedir,'test')) 34 | 35 | #store train data in ./data/train/image.txt 36 | with open(os.path.join(args.savedir,'train/image.txt'),'w') as f: 37 | for image in image_train: 38 | f.write(image+'\n') 39 | with open(os.path.join(args.savedir,'train/label.txt'),'w') as f: 40 | for label in label_train: 41 | f.write(label+'\n') 42 | #store test data in ./data/test/image.txt 43 | with open(os.path.join(args.savedir,'test/image.txt'),'w') as f: 44 | for image in image_test: 45 | f.write(image+'\n') 46 | with open(os.path.join(args.savedir,'test/label.txt'),'w') as f: 47 | for label in label_test: 48 | f.write(label+'\n') 49 | #store val data in ./data/val/image.txt 50 | with open(os.path.join(args.savedir,'val/image.txt'),'w') as f: 51 | for image in image_val: 52 | f.write(image+'\n') 53 | with open(os.path.join(args.savedir,'val/label.txt'),'w') as f: 54 | for label in label_val: 55 | f.write(label+'\n') 56 | print('Done!') 57 | 58 | if __name__ == '__main__': 59 | parser = ArgumentParser() 60 | parser.add_argument('--savedir', default='./data/') 61 | parser.add_argument('--imagepath', default='./data/image.txt') 62 | parser.add_argument('--labelpath', default='./data/label.txt') 63 | parser.add_argument('--random-state',default=10000) 64 | parser.add_argument('--train-size',default=0.7) 65 | parser.add_argument('--val-size',default=0.3) 66 | 67 | split_train_val(parser.parse_args()) 68 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from options.test_options import TestOptions 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from PIL import Image 8 | from torch.utils.data import DataLoader 9 | from utils.label2Img import label2rgb 10 | from dataloader.transform import Transform_test 11 | from dataloader.dataset import NeoData_test 12 | from networks import get_model 13 | from eval import * 14 | 15 | def main(args): 16 | despath = args.savedir 17 | if not os.path.exists(despath): 18 | os.mkdir(despath) 19 | 20 | imagedir = os.path.join(args.datadir,'image.txt') 21 | labeldir = os.path.join(args.datadir,'label.txt') 22 | 23 | transform = Transform_test(args.size) 24 | dataset_test = NeoData_test(imagedir, labeldir, transform) 25 | loader = DataLoader(dataset_test, num_workers=4, batch_size=1,shuffle=False) #test data loader 26 | 27 | #eval the result of IoU 28 | confMatrix = evalIoU.generateMatrixTrainId(evalIoU.args) 29 | perImageStats = {} 30 | nbPixels = 0 31 | usedLr = 0 32 | 33 | model = get_model(args) 34 | if args.cuda: 35 | model = model.cuda() 36 | model.load_state_dict(torch.load(args.model_dir)) 37 | model.eval() 38 | count = 0 39 | for step, colign in enumerate(loader): 40 | 41 | img = colign[2].squeeze(0).numpy() #image-numpy,original image 42 | images = colign[0] 
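
The two chained train_test_split calls above first take 70% of the data for training and then split the remaining 30% again 70/30 into test and validation, i.e. roughly 7 : 2.1 : 0.9 rather than an exact 7 : 2 : 1. A toy run makes the proportions concrete:

```python
# toy illustration of the two-stage split used in split_train_val.py
from sklearn.model_selection import train_test_split

image = ['img_%03d.jpg' % i for i in range(100)]
label = ['lab_%03d.png' % i for i in range(100)]

img_train, img_rest, lab_train, lab_rest = train_test_split(
    image, label, random_state=10000, train_size=0.7, test_size=0.3)
img_test, img_val, lab_test, lab_val = train_test_split(
    img_rest, lab_rest, random_state=10000, train_size=0.7, test_size=0.3)

print(len(img_train), len(img_test), len(img_val))   # 70 21 9
```
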
#image-tensor 43 | label = colign[1] #label-tensor 44 | 45 | if args.cuda: 46 | images = images.cuda() 47 | inputs = Variable(images,volatile=True) 48 | 49 | outputs = model(inputs) 50 | out = outputs[0].cpu().max(0)[1].data.squeeze(0).byte().numpy() #index of max-channel 51 | 52 | add_to_confMatrix(outputs, label, confMatrix, perImageStats, nbPixels) #add result to confusion matrix 53 | 54 | label2img = label2rgb(out,img,n_labels = args.num_classes) #merge segmented result with original picture 55 | Image.fromarray(label2img).save(despath + 'label2img_' +str(count)+'.jpg' ) 56 | count += 1 57 | print("This is the {}th of image!".format(count)) 58 | 59 | iouAvgStr, iouTest, classScoreList = cal_iou(evalIoU, confMatrix) #calculate mIoU, classScoreList include IoU for each class 60 | print("IoU on TEST set : ",iouAvgStr) 61 | #print("IoU on TEST set of each class - car:{} light:{} ".format(classScoreList['car'],classScoreList['light'])) 62 | 63 | if __name__ == '__main__': 64 | parser = TestOptions().parse() 65 | main(parser) 66 | 67 | 68 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import torch 5 | from eval import * 6 | import torch.nn as nn 7 | from utils import evalIoU 8 | from networks import get_model 9 | from torch.autograd import Variable 10 | from dataloader.dataset import NeoData 11 | from torch.utils.data import DataLoader 12 | from dataloader.transform import MyTransform 13 | from torchvision.transforms import ToPILImage 14 | from options.train_options import TrainOptions 15 | from torch.optim import SGD, Adam, lr_scheduler 16 | from criterion.criterion import CrossEntropyLoss2d 17 | NUM_CHANNELS = 3 18 | 19 | def get_loader(args): 20 | #add the weight of each class (1/ln(c+Pclass)) 21 | #calculate the weights of each class 22 | 23 | #weight[0]=1.45 24 | ##weight[1]=54.38 25 | #weight[2] = 428.723 26 | imagepath_train = os.path.join(args.datadir, 'train/image.txt') 27 | labelpath_train = os.path.join(args.datadir, 'train/label.txt') 28 | imagepath_val = os.path.join(args.datadir, 'val/image.txt') 29 | labelpath_val = os.path.join(args.datadir, 'val/label.txt') 30 | 31 | train_transform = MyTransform(reshape_size=(500,350),crop_size=(448,320), augment=True) # data transform for training set with data augmentation, including resize, crop, flip and so on 32 | val_transform = MyTransform(reshape_size=(500,350),crop_size=(448,320), augment=False) #data transform for validation set without data augmentation 33 | 34 | dataset_train = NeoData(imagepath_train, labelpath_train, train_transform) #DataSet 35 | dataset_val = NeoData(imagepath_val, labelpath_val, val_transform) 36 | 37 | loader = DataLoader(dataset_train, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=True) 38 | loader_val = DataLoader(dataset_val, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False) 39 | 40 | return loader, loader_val 41 | 42 | def train(args, model): 43 | NUM_CLASSES = args.num_classes #pascal=21, cityscapes=20 44 | savedir = args.savedir 45 | weight = torch.ones(NUM_CLASSES) 46 | 47 | loader, loader_val = get_loader(args) 48 | 49 | if args.cuda: 50 | criterion = CrossEntropyLoss2d(weight).cuda() 51 | else: 52 | criterion = CrossEntropyLoss2d(weight) 53 | 54 | #save log 55 | automated_log_path = savedir + "/automated_log.txt" 56 | if (not os.path.exists(automated_log_path)): #dont add first 
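
The commented-out numbers in get_loader hint at per-class weights computed as w_class = 1 / ln(c + p_class), as stated in the comment above. A tiny stand-alone sketch of that rule, with made-up frequencies, is below; the constant c and the frequencies are illustrative assumptions.

```python
# sketch of the weighting rule from the comment above: w_class = 1 / ln(c + p_class)
import math

c = 1.02
class_freq = [0.70, 0.05, 0.002]                 # hypothetical per-class pixel shares
weights = [1.0 / math.log(c + p) for p in class_freq]
print(weights)                                   # rare classes receive larger weights
# such weights could replace torch.ones(NUM_CLASSES) before building CrossEntropyLoss2d
```
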
line if it exists 57 | with open(automated_log_path, "a") as myfile: 58 | myfile.write("Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate") 59 | 60 | optimizer = Adam(model.parameters(), args.lr, (0.9, 0.999), eps=1e-08, weight_decay=1e-4) 61 | lambda1 = lambda epoch: pow((1-((epoch-1)/args.num_epochs)),0.9) 62 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) # learning rate changed every epoch 63 | start_epoch = 1 64 | 65 | for epoch in range(start_epoch, args.num_epochs+1): 66 | print("----- TRAINING - EPOCH", epoch, "-----") 67 | 68 | scheduler.step(epoch) 69 | epoch_loss = [] 70 | time_train = [] 71 | 72 | #confmatrix for calculating IoU 73 | confMatrix = evalIoU.generateMatrixTrainId(evalIoU.args) 74 | perImageStats = {} 75 | nbPixels = 0 76 | usedLr = 0 77 | #for param_group in optimizer.param_groups: 78 | for param_group in optimizer.param_groups: 79 | print("LEARNING RATE: ", param_group['lr']) 80 | usedLr = float(param_group['lr']) 81 | 82 | model.train() 83 | count = 1 84 | for step, (images, labels) in enumerate(loader): 85 | start_time = time.time() 86 | if args.cuda: 87 | images = images.cuda() 88 | labels = labels.cuda() 89 | 90 | inputs = Variable(images) 91 | targets = Variable(labels) 92 | 93 | outputs = model(inputs) 94 | loss = criterion(outputs, targets[:, 0]) 95 | 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | 100 | epoch_loss.append(loss.data[0]) 101 | time_train.append(time.time() - start_time) 102 | 103 | #Add outputs to confusion matrix #CODE USING evalIoU.py remade from cityscapes/scripts/evaluation/evalPixelLevelSemanticLabeling.py 104 | if (args.iouTrain): 105 | add_to_confMatrix(outputs, labels, confMatrix, perImageStats, nbPixels) 106 | 107 | if args.steps_loss > 0 and step % args.steps_loss == 0: 108 | average = sum(epoch_loss) / len(epoch_loss) 109 | print('loss: {} (epoch: {}, step: {})'.format(average,epoch,step), 110 | "// Avg time/img: %.4f s" % (sum(time_train) / len(time_train) / args.batch_size)) 111 | 112 | average_epoch_loss_train = sum(epoch_loss) / len(epoch_loss) 113 | iouAvgStr, iouTrain, classScoreList = cal_iou(evalIoU, confMatrix) 114 | print ("EPOCH IoU on TRAIN set: ", iouAvgStr) 115 | 116 | # calculate eval-loss and eval-IoU 117 | average_epoch_loss_val, iouVal = eval(args, model, loader_val, criterion, epoch) 118 | 119 | #save model every X epoch 120 | if epoch % args.epoch_save==0: 121 | torch.save(model.state_dict(), '{}_{}.pth'.format(os.path.join(args.savedir,args.model),str(epoch))) 122 | 123 | #save log 124 | with open(automated_log_path, "a") as myfile: 125 | myfile.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.8f" % (epoch, average_epoch_loss_train, average_epoch_loss_val, iouTrain, iouVal, usedLr )) 126 | 127 | return(model) 128 | 129 | def main(args): 130 | ''' 131 | Train the model and record training options. 
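
The LambdaLR schedule above implements a polynomial ("poly") decay with exponent 0.9 over the configured number of epochs. Evaluated on its own, the multiplier looks like this:

```python
# sketch: the polynomial learning-rate decay defined above, evaluated standalone
num_epochs = 150
base_lr = 5e-4
lr_lambda = lambda epoch: (1.0 - (epoch - 1) / num_epochs) ** 0.9

for epoch in (1, 50, 100, 150):
    print(epoch, base_lr * lr_lambda(epoch))
# the rate starts at 5e-4 and decays smoothly towards zero by the final epoch
```
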
132 | ''' 133 | savedir = '{}'.format(args.savedir) 134 | modeltxtpath = os.path.join(savedir,'model.txt') 135 | 136 | if not os.path.exists(savedir): 137 | os.makedirs(savedir) 138 | with open(savedir + '/opts.txt', "w") as myfile: #record options 139 | myfile.write(str(args)) 140 | 141 | model = get_model(args) #load model 142 | 143 | with open(modeltxtpath, "w") as myfile: #record model 144 | myfile.write(str(model)) 145 | 146 | if args.cuda: 147 | model = model.cuda() 148 | print("========== TRAINING ===========") 149 | model = train(args,model) 150 | print("========== TRAINING FINISHED ===========") 151 | 152 | if __name__ == '__main__': 153 | 154 | parser = TrainOptions().parse() 155 | main(parser) 156 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # Install Csupport 2 | Install CSUPPORT to speed up evaluate confusionMatrix 3 | > cd cityscape 4 | 5 | > python setup.py install 6 | -------------------------------------------------------------------------------- /utils/cityscapes/addToConfusionMatrix.pyx: -------------------------------------------------------------------------------- 1 | # cython methods to speed-up evaluation 2 | 3 | import numpy as np 4 | cimport cython 5 | cimport numpy as np 6 | import ctypes 7 | 8 | np.import_array() 9 | 10 | 11 | cdef extern from "addToConfusionMatrix_impl.c": 12 | void addToConfusionMatrix( const unsigned char* f_prediction_p , 13 | const unsigned char* f_groundTruth_p , 14 | const unsigned int f_width_i , 15 | const unsigned int f_height_i , 16 | unsigned long long* f_confMatrix_p , 17 | const unsigned int f_confMatDim_i ) 18 | 19 | 20 | cdef tonumpyarray(unsigned long long* data, unsigned long long size): 21 | if not (data and size >= 0): raise ValueError 22 | return np.PyArray_SimpleNewFromData(2, [size, size], np.NPY_UINT64, data) 23 | 24 | @cython.boundscheck(False) 25 | def cEvaluatePair( np.ndarray[np.uint8_t , ndim=2] predictionArr , 26 | np.ndarray[np.uint8_t , ndim=2] groundTruthArr , 27 | np.ndarray[np.uint64_t, ndim=2] confMatrix , 28 | evalLabels ): 29 | cdef np.ndarray[np.uint8_t , ndim=2, mode="c"] predictionArr_c 30 | cdef np.ndarray[np.uint8_t , ndim=2, mode="c"] groundTruthArr_c 31 | cdef np.ndarray[np.ulonglong_t, ndim=2, mode="c"] confMatrix_c 32 | 33 | predictionArr_c = np.ascontiguousarray(predictionArr , dtype=np.uint8 ) 34 | groundTruthArr_c = np.ascontiguousarray(groundTruthArr, dtype=np.uint8 ) 35 | confMatrix_c = np.ascontiguousarray(confMatrix , dtype=np.ulonglong) 36 | 37 | cdef np.uint32_t height_ui = predictionArr.shape[1] 38 | cdef np.uint32_t width_ui = predictionArr.shape[0] 39 | cdef np.uint32_t confMatDim_ui = confMatrix.shape[0] 40 | 41 | addToConfusionMatrix(&predictionArr_c[0,0], &groundTruthArr_c[0,0], height_ui, width_ui, &confMatrix_c[0,0], confMatDim_ui) 42 | 43 | confMatrix = np.ascontiguousarray(tonumpyarray(&confMatrix_c[0,0], confMatDim_ui)) 44 | 45 | return np.copy(confMatrix) 46 | -------------------------------------------------------------------------------- /utils/cityscapes/addToConfusionMatrix_impl.c: -------------------------------------------------------------------------------- 1 | // cython methods to speed-up evaluation 2 | 3 | void addToConfusionMatrix( const unsigned char* f_prediction_p , 4 | const unsigned char* f_groundTruth_p , 5 | const unsigned int f_width_i , 6 | const unsigned int f_height_i , 7 | unsigned long long* f_confMatrix_p , 8 | 
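
The cEvaluatePair wrapper above accumulates a (ground truth, prediction) histogram into an existing confusion matrix. A minimal sketch of calling the compiled helper, assuming the extension has been built with utils/cityscapes/setup.py and that directory is on sys.path (evalIoU.py imports it the same way when CSUPPORT is available):

```python
# sketch: calling the compiled confusion-matrix helper declared above
import numpy as np
import addToConfusionMatrix   # requires the built extension in utils/cityscapes/

pred = np.zeros((480, 672), dtype=np.uint8)     # predicted label ids
gt = np.zeros((480, 672), dtype=np.uint8)       # ground-truth label ids
conf = np.zeros((21, 21), dtype=np.uint64)

conf = addToConfusionMatrix.cEvaluatePair(pred, gt, conf, list(range(21)))
print(conf[0, 0])   # every pixel lands in the (gt=0, pred=0) cell here
```
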
const unsigned int f_confMatDim_i ) 9 | { 10 | const unsigned int size_ui = f_height_i * f_width_i; 11 | for (unsigned int i = 0; i < size_ui; ++i) 12 | { 13 | const unsigned char predPx = f_prediction_p [i]; 14 | const unsigned char gtPx = f_groundTruth_p[i]; 15 | f_confMatrix_p[f_confMatDim_i*gtPx + predPx] += 1u; 16 | } 17 | } -------------------------------------------------------------------------------- /utils/cityscapes/helpers/annotation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Classes to store, read, and write annotations 4 | # 5 | 6 | import os 7 | import json 8 | from collections import namedtuple 9 | 10 | # get current date and time 11 | import datetime 12 | import locale 13 | 14 | # A point in a polygon 15 | Point = namedtuple('Point', ['x', 'y']) 16 | 17 | # Class that contains the information of a single annotated object 18 | class CsObject: 19 | # Constructor 20 | def __init__(self): 21 | # the label 22 | self.label = "" 23 | # the polygon as list of points 24 | self.polygon = [] 25 | 26 | # the object ID 27 | self.id = -1 28 | # If deleted or not 29 | self.deleted = 0 30 | # If verified or not 31 | self.verified = 0 32 | # The date string 33 | self.date = "" 34 | # The username 35 | self.user = "" 36 | # Draw the object 37 | # Not read from or written to JSON 38 | # Set to False if deleted object 39 | # Might be set to False by the application for other reasons 40 | self.draw = True 41 | 42 | def __str__(self): 43 | polyText = "" 44 | if self.polygon: 45 | if len(self.polygon) <= 4: 46 | for p in self.polygon: 47 | polyText += '({},{}) '.format( p.x , p.y ) 48 | else: 49 | polyText += '({},{}) ({},{}) ... ({},{}) ({},{})'.format( 50 | self.polygon[ 0].x , self.polygon[ 0].y , 51 | self.polygon[ 1].x , self.polygon[ 1].y , 52 | self.polygon[-2].x , self.polygon[-2].y , 53 | self.polygon[-1].x , self.polygon[-1].y ) 54 | else: 55 | polyText = "none" 56 | text = "Object: {} - {}".format( self.label , polyText ) 57 | return text 58 | 59 | def fromJsonText(self, jsonText, objId): 60 | self.id = objId 61 | self.label = str(jsonText['label']) 62 | self.polygon = [ Point(p[0],p[1]) for p in jsonText['polygon'] ] 63 | if 'deleted' in jsonText.keys(): 64 | self.deleted = jsonText['deleted'] 65 | else: 66 | self.deleted = 0 67 | if 'verified' in jsonText.keys(): 68 | self.verified = jsonText['verified'] 69 | else: 70 | self.verified = 1 71 | if 'user' in jsonText.keys(): 72 | self.user = jsonText['user'] 73 | else: 74 | self.user = '' 75 | if 'date' in jsonText.keys(): 76 | self.date = jsonText['date'] 77 | else: 78 | self.date = '' 79 | if self.deleted == 1: 80 | self.draw = False 81 | else: 82 | self.draw = True 83 | 84 | def toJsonText(self): 85 | objDict = {} 86 | objDict['label'] = self.label 87 | objDict['id'] = self.id 88 | objDict['deleted'] = self.deleted 89 | objDict['verified'] = self.verified 90 | objDict['user'] = self.user 91 | objDict['date'] = self.date 92 | objDict['polygon'] = [] 93 | for pt in self.polygon: 94 | objDict['polygon'].append([pt.x, pt.y]) 95 | 96 | return objDict 97 | 98 | def updateDate( self ): 99 | try: 100 | locale.setlocale( locale.LC_ALL , 'en_US' ) 101 | except locale.Error: 102 | locale.setlocale( locale.LC_ALL , 'us_us' ) 103 | except: 104 | pass 105 | self.date = datetime.datetime.now().strftime("%d-%b-%Y %H:%M:%S") 106 | 107 | # Mark the object as deleted 108 | def delete(self): 109 | self.deleted = 1 110 | self.draw = False 111 | 112 | # The annotation of a whole 
image 113 | class Annotation: 114 | # Constructor 115 | def __init__(self): 116 | # the width of that image and thus of the label image 117 | self.imgWidth = 0 118 | # the height of that image and thus of the label image 119 | self.imgHeight = 0 120 | # the list of objects 121 | self.objects = [] 122 | 123 | def toJson(self): 124 | return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True, indent=4) 125 | 126 | def fromJsonText(self, jsonText): 127 | jsonDict = json.loads(jsonText) 128 | self.imgWidth = int(jsonDict['imgWidth']) 129 | self.imgHeight = int(jsonDict['imgHeight']) 130 | self.objects = [] 131 | for objId, objIn in enumerate(jsonDict[ 'objects' ]): 132 | obj = CsObject() 133 | obj.fromJsonText(objIn, objId) 134 | self.objects.append(obj) 135 | 136 | def toJsonText(self): 137 | jsonDict = {} 138 | jsonDict['imgWidth'] = self.imgWidth 139 | jsonDict['imgHeight'] = self.imgHeight 140 | jsonDict['objects'] = [] 141 | for obj in self.objects: 142 | objDict = obj.toJsonText() 143 | jsonDict['objects'].append(objDict) 144 | 145 | return jsonDict 146 | 147 | # Read a json formatted polygon file and return the annotation 148 | def fromJsonFile(self, jsonFile): 149 | if not os.path.isfile(jsonFile): 150 | print('Given json file not found: {}'.format(jsonFile)) 151 | return 152 | with open(jsonFile, 'r') as f: 153 | jsonText = f.read() 154 | self.fromJsonText(jsonText) 155 | 156 | def toJsonFile(self, jsonFile): 157 | with open(jsonFile, 'w') as f: 158 | f.write(self.toJson()) 159 | 160 | 161 | # a dummy example 162 | if __name__ == "__main__": 163 | obj = CsObject() 164 | obj.label = 'car' 165 | obj.polygon.append( Point( 0 , 0 ) ) 166 | obj.polygon.append( Point( 1 , 0 ) ) 167 | obj.polygon.append( Point( 1 , 1 ) ) 168 | obj.polygon.append( Point( 0 , 1 ) ) 169 | 170 | print(obj) -------------------------------------------------------------------------------- /utils/cityscapes/helpers/csHelpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Various helper methods and includes for Cityscapes 4 | # 5 | 6 | # Python imports 7 | import os, sys, getopt 8 | import glob 9 | import math 10 | import json 11 | from collections import namedtuple 12 | 13 | # Image processing 14 | # Check if PIL is actually Pillow as expected 15 | try: 16 | from PIL import PILLOW_VERSION 17 | except: 18 | print("Please install the module 'Pillow' for image processing, e.g.") 19 | print("pip install pillow") 20 | sys.exit(-1) 21 | 22 | try: 23 | import PIL.Image as Image 24 | import PIL.ImageDraw as ImageDraw 25 | except: 26 | print("Failed to import the image processing packages.") 27 | sys.exit(-1) 28 | 29 | # Numpy for datastructures 30 | try: 31 | import numpy as np 32 | except: 33 | print("Failed to import numpy package.") 34 | sys.exit(-1) 35 | 36 | # Cityscapes modules 37 | try: 38 | from annotation import Annotation 39 | from labels import labels, name2label, id2label, trainId2label, category2labels 40 | except: 41 | print("Failed to find all Cityscapes modules") 42 | sys.exit(-1) 43 | 44 | # Print an error message and quit 45 | def printError(message): 46 | print('ERROR: ' + str(message)) 47 | sys.exit(-1) 48 | 49 | # Class for colors 50 | class colors: 51 | RED = '\033[31;1m' 52 | GREEN = '\033[32;1m' 53 | YELLOW = '\033[33;1m' 54 | BLUE = '\033[34;1m' 55 | MAGENTA = '\033[35;1m' 56 | CYAN = '\033[36;1m' 57 | BOLD = '\033[1m' 58 | UNDERLINE = '\033[4m' 59 | ENDC = '\033[0m' 60 | 61 | # Colored value output if colorized flag 
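
The Annotation and CsObject classes above serialize to and from the Cityscapes polygon JSON format. A short round-trip sketch, assuming utils/cityscapes/helpers/ is on sys.path (evalIoU.py arranges this):

```python
# sketch: round-tripping an annotation through JSON with the classes above
from annotation import Annotation, CsObject, Point

obj = CsObject()
obj.label = 'car'
obj.polygon = [Point(0, 0), Point(10, 0), Point(10, 5), Point(0, 5)]

ann = Annotation()
ann.imgWidth, ann.imgHeight = 640, 480
ann.objects.append(obj)

text = ann.toJson()              # serialize the whole annotation
parsed = Annotation()
parsed.fromJsonText(text)        # and read it back
print(parsed.imgWidth, parsed.objects[0].label)   # 640 car
```
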
is activated. 62 | def getColorEntry(val, args): 63 | if not args.colorized: 64 | return "" 65 | if not isinstance(val, float) or math.isnan(val): 66 | return colors.ENDC 67 | if (val < .20): 68 | return colors.RED 69 | elif (val < .40): 70 | return colors.YELLOW 71 | elif (val < .60): 72 | return colors.BLUE 73 | elif (val < .80): 74 | return colors.CYAN 75 | else: 76 | return colors.GREEN 77 | 78 | # Cityscapes files have a typical filename structure 79 | # ___[_]. 80 | # This class contains the individual elements as members 81 | # For the sequence and frame number, the strings are returned, including leading zeros 82 | CsFile = namedtuple( 'csFile' , [ 'city' , 'sequenceNb' , 'frameNb' , 'type' , 'type2' , 'ext' ] ) 83 | 84 | # Returns a CsFile object filled from the info in the given filename 85 | def getCsFileInfo(fileName): 86 | baseName = os.path.basename(fileName) 87 | parts = baseName.split('_') 88 | parts = parts[:-1] + parts[-1].split('.') 89 | if not parts: 90 | printError( 'Cannot parse given filename ({}). Does not seem to be a valid Cityscapes file.'.format(fileName) ) 91 | if len(parts) == 5: 92 | csFile = CsFile( *parts[:-1] , type2="" , ext=parts[-1] ) 93 | elif len(parts) == 6: 94 | csFile = CsFile( *parts ) 95 | else: 96 | printError( 'Found {} part(s) in given filename ({}). Expected 5 or 6.'.format(len(parts) , fileName) ) 97 | 98 | return csFile 99 | 100 | # Returns the part of Cityscapes filenames that is common to all data types 101 | # e.g. for city_123456_123456_gtFine_polygons.json returns city_123456_123456 102 | def getCoreImageFileName(filename): 103 | csFile = getCsFileInfo(filename) 104 | return "{}_{}_{}".format( csFile.city , csFile.sequenceNb , csFile.frameNb ) 105 | 106 | # Returns the directory name for the given filename, e.g. 107 | # fileName = "/foo/bar/foobar.txt" 108 | # return value is "bar" 109 | # Not much error checking though 110 | def getDirectory(fileName): 111 | dirName = os.path.dirname(fileName) 112 | return os.path.basename(dirName) 113 | 114 | # Make sure that the given path exists 115 | def ensurePath(path): 116 | if not path: 117 | return 118 | if not os.path.isdir(path): 119 | os.makedirs(path) 120 | 121 | # Write a dictionary as json file 122 | def writeDict2JSON(dictName, fileName): 123 | with open(fileName, 'w') as f: 124 | f.write(json.dumps(dictName, default=lambda o: o.__dict__, sort_keys=True, indent=4)) 125 | 126 | # dummy main 127 | if __name__ == "__main__": 128 | printError("Only for include, not executable on its own.") 129 | -------------------------------------------------------------------------------- /utils/cityscapes/helpers/labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Cityscapes labels 4 | # 5 | 6 | from collections import namedtuple 7 | 8 | 9 | #-------------------------------------------------------------------------------- 10 | # Definitions 11 | #-------------------------------------------------------------------------------- 12 | 13 | # a label and all meta information 14 | Label = namedtuple( 'Label' , [ 15 | 16 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 17 | # We use them to uniquely name a class 18 | 19 | 'id' , # An integer ID that is associated with this label. 20 | # The IDs are used to represent the label in ground truth images 21 | # An ID of -1 means that this label does not have an ID and thus 22 | # is ignored when creating ground truth images (e.g. license plate). 
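
The filename helpers above split a Cityscapes-style name into its city, sequence, frame, and type components. A quick sketch of what they return, again assuming utils/cityscapes/helpers/ is on sys.path:

```python
# sketch: parsing a typical Cityscapes file name with the helpers above
from csHelpers import getCsFileInfo, getCoreImageFileName

name = 'aachen_000000_000019_gtFine_labelIds.png'
info = getCsFileInfo(name)
print(info.city, info.sequenceNb, info.frameNb, info.type, info.type2, info.ext)
# -> aachen 000000 000019 gtFine labelIds png
print(getCoreImageFileName(name))
# -> aachen_000000_000019
```
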
23 | # Do not modify these IDs, since exactly these IDs are expected by the 24 | # evaluation server. 25 | 26 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 27 | # ground truth images with train IDs, using the tools provided in the 28 | # 'preparation' folder. However, make sure to validate or submit results 29 | # to our evaluation server using the regular IDs above! 30 | # For trainIds, multiple labels might have the same ID. Then, these labels 31 | # are mapped to the same class in the ground truth images. For the inverse 32 | # mapping, we use the label that is defined first in the list below. 33 | # For example, mapping all void-type classes to the same ID in training, 34 | # might make sense for some approaches. 35 | # Max value is 255! 36 | 37 | 'category' , # The name of the category that this label belongs to 38 | 39 | 'categoryId' , # The ID of this category. Used to create ground truth images 40 | # on category level. 41 | 42 | 'hasInstances', # Whether this label distinguishes between single instances or not 43 | 44 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 45 | # during evaluations or not 46 | 47 | 'color' , # The color of this label 48 | ] ) 49 | 50 | 51 | #-------------------------------------------------------------------------------- 52 | # A list of all labels 53 | #-------------------------------------------------------------------------------- 54 | 55 | # Please adapt the train IDs as appropriate for you approach. 56 | # Note that you might want to ignore labels with ID 255 during training. 57 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 58 | # Make sure to provide your results using the original IDs and not the training IDs. 59 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
60 | 61 | labels = [ 62 | # name id trainId category catId hasInstances ignoreInEval color 63 | Label( 'unlabeled' , 0 , 0 , 'void' , 0 , False , True , ( 0, 0, 0) ), 64 | Label( 'aeroplane' , 1 , 1 , 'object' , 1 , True , False , (128, 0, 0) ), 65 | Label( 'bicycle' , 2 , 2 , 'object' , 1 , True , False , ( 0, 128, 0) ), 66 | Label( 'bird' , 3 , 3 , 'animal' , 2 , True , False , (128, 128, 0) ), 67 | Label( 'boat' , 4 , 4 , 'object' , 1 , True , False , ( 0, 0, 128) ), 68 | Label( 'bottle' , 5 , 5 , 'object' , 1 , True , False , (128, 0, 128) ), 69 | Label( 'bus' , 6 , 6 , 'object' , 1 , True , False , ( 0, 128, 128) ), 70 | Label( 'car' , 7 , 7 , 'object' , 1 , True , False , (128, 128, 128) ), 71 | Label( 'cat' , 8 , 8 , 'animal' , 2 , True , False , ( 64, 0, 0) ), 72 | Label( 'chair' , 9 , 9 , 'object' , 1 , True , False , (192, 0, 0) ), 73 | Label( 'cow' , 10 , 10 , 'animal' , 2 , True , False , ( 64, 128, 0) ), 74 | Label( 'diningtable' , 11 , 11 , 'object' , 1 , True , False , (192, 128, 0) ), 75 | Label( 'dog' , 12 , 12 , 'animal' , 2 , True , False , ( 64, 0, 128) ), 76 | Label( 'horse' , 13 , 13 , 'animal' , 2 , True , False , (192, 0, 128) ), 77 | Label( 'motorbike' , 14 , 14 , 'object' , 1 , True , False , ( 64, 128, 128) ), 78 | Label( 'person' , 15 , 15 , 'animal' , 2 , True , False , (192, 128, 128) ), 79 | Label( 'potted plant' , 16 , 16 , 'object' , 1 , True , False , ( 0, 64, 0) ), 80 | Label( 'sheep' , 17 , 17 , 'animal' , 2 , True , False , (128, 64, 0) ), 81 | Label( 'sofa' , 18 , 18 , 'object' , 1 , True , False , ( 0, 192, 0) ), 82 | Label( 'train' , 19 , 19 , 'object' , 1 , True , False , (128, 192, 0) ), 83 | Label( 'tv/monitor' , 20 , 20 , 'object' , 1 , True , False , ( 0, 64, 128) ) 84 | ] 85 | 86 | ''' 87 | labels = [ 88 | # name id trainId category catId hasInstances ignoreInEval color 89 | Label( 'unlabeled' , 0 , 0 , 'void' , 0 , False , True , (220, 20, 60)), 90 | Label( 'car' , 1 , 1 , 'object' , 0 , True , False , (153,153,153) ), 91 | Label( 'light' , 2 , 2 , 'reflection' , 0 , True , False , (153,153,153) ) 92 | ] 93 | ''' 94 | ''' 95 | labels = [ 96 | # name id trainId category catId hasInstances ignoreInEval color 97 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 98 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 99 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 100 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 101 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 102 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 103 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 104 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 105 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 106 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 107 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 108 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 109 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 110 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 111 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 112 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (153,153,153) ), 113 | Label( 'tunnel' , 16 , 255 , 
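
The active label list above maps each Pascal VOC class to an id, a trainId, and a color. One common use is building a trainId-to-color palette for visualizing predictions; a minimal sketch, assuming utils/cityscapes/helpers/ is on sys.path:

```python
# sketch: building a trainId -> color palette from the label list above,
# e.g. for colorizing predicted trainId maps
import numpy as np
from labels import labels

palette = np.zeros((256, 3), dtype=np.uint8)
for label in labels:
    if 0 <= label.trainId < 256:
        palette[label.trainId] = label.color
print(palette[7])   # -> [128 128 128], the color assigned to 'car'
```
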
'construction' , 2 , False , True , (150,120, 90) ), 114 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 115 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 116 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 117 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 118 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 119 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 120 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 121 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 122 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 123 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 124 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 125 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 126 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 127 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 128 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 129 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 130 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 131 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 132 | ] 133 | ''' 134 | #-------------------------------------------------------------------------------- 135 | # Create dictionaries for a fast lookup 136 | #-------------------------------------------------------------------------------- 137 | 138 | # Please refer to the main method below for example usages! 139 | 140 | # name to label object 141 | name2label = { label.name : label for label in labels } 142 | # id to label object 143 | id2label = { label.id : label for label in labels } 144 | # trainId to label object 145 | trainId2label = { label.trainId : label for label in reversed(labels) } 146 | # category to list of label objects 147 | category2labels = {} 148 | for label in labels: 149 | category = label.category 150 | if category in category2labels: 151 | category2labels[category].append(label) 152 | else: 153 | category2labels[category] = [label] 154 | 155 | #-------------------------------------------------------------------------------- 156 | # Assure single instance name 157 | #-------------------------------------------------------------------------------- 158 | 159 | # returns the label name that describes a single instance (if possible) 160 | # e.g. 
input | output 161 | # ---------------------- 162 | # car | car 163 | # cargroup | car 164 | # foo | None 165 | # foogroup | None 166 | # skygroup | None 167 | def assureSingleInstanceName( name ): 168 | # if the name is known, it is not a group 169 | if name in name2label: 170 | return name 171 | # test if the name actually denotes a group 172 | if not name.endswith("group"): 173 | return None 174 | # remove group 175 | name = name[:-len("group")] 176 | # test if the new name exists 177 | if not name in name2label: 178 | return None 179 | # test if the new name denotes a label that actually has instances 180 | if not name2label[name].hasInstances: 181 | return None 182 | # all good then 183 | return name 184 | 185 | #-------------------------------------------------------------------------------- 186 | # Main for testing 187 | #-------------------------------------------------------------------------------- 188 | 189 | # just a dummy main 190 | if __name__ == "__main__": 191 | # Print all the labels 192 | print("List of cityscapes labels:") 193 | print("") 194 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) 195 | print(" " + ('-' * 98)) 196 | for label in labels: 197 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) 198 | print("") 199 | 200 | print("Example usages:") 201 | 202 | # Map from name to label 203 | name = 'car' 204 | id = name2label[name].id 205 | print("ID of label '{name}': {id}".format( name=name, id=id )) 206 | 207 | # Map from ID to label 208 | category = id2label[id].category 209 | print("Category of label with ID '{id}': {category}".format( id=id, category=category )) 210 | 211 | # Map from trainID to label 212 | trainId = 0 213 | name = trainId2label[trainId].name 214 | print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )) 215 | -------------------------------------------------------------------------------- /utils/cityscapes/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Enable cython support for eval scripts 4 | # Run as 5 | # setup.py build_ext --inplace 6 | # 7 | # WARNING: Only tested for Ubuntu 64bit OS. 8 | 9 | try: 10 | from distutils.core import setup 11 | from Cython.Build import cythonize 12 | except: 13 | print("Unable to setup. 
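
assureSingleInstanceName in labels.py above strips the Cityscapes "group" suffix and only returns a name when the underlying label actually has instances. A short sketch of its behaviour with the Pascal VOC label list, assuming utils/cityscapes/helpers/ is on sys.path:

```python
# sketch: behaviour of assureSingleInstanceName for a few inputs
from labels import assureSingleInstanceName

print(assureSingleInstanceName('car'))        # 'car'   (a known label name)
print(assureSingleInstanceName('cargroup'))   # 'car'   (group of an instance class)
print(assureSingleInstanceName('foo'))        # None    (unknown name)
```
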
Please use pip to install: cython") 14 | print("sudo pip install cython") 15 | import os 16 | import numpy 17 | 18 | os.environ["CC"] = "g++" 19 | os.environ["CXX"] = "g++" 20 | 21 | setup(ext_modules = cythonize("addToConfusionMatrix.pyx"),include_dirs=[numpy.get_include()]) 22 | -------------------------------------------------------------------------------- /utils/evalIoU.py: -------------------------------------------------------------------------------- 1 | # Code for evaluating IoU - adapted from the Cityscapes scripts (https://github.com/mcordts/cityscapesScripts) 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | from __future__ import print_function 7 | import os, sys 8 | import platform 9 | import fnmatch 10 | 11 | try: 12 | from itertools import izip 13 | except ImportError: 14 | izip = zip 15 | 16 | sys.path.append( './utils/cityscapes/helpers' ) 17 | from csHelpers import * 18 | #import labels 19 | 20 | # Only tested for Ubuntu 64bit OS 21 | CSUPPORT = True 22 | # Check if C-Support is available for better performance 23 | sys.path.append( './utils/cityscapes/' ) 24 | if CSUPPORT: 25 | try: 26 | import addToConfusionMatrix 27 | except: 28 | CSUPPORT = False 29 | 30 | #import cityscapes.labels 31 | 32 | # A dummy class to collect all bunch of data 33 | class CArgs(object): 34 | pass 35 | # And a global object of that class 36 | args = CArgs() 37 | 38 | # Where to look for Cityscapes 39 | if 'CITYSCAPES_DATASET' in os.environ: 40 | args.cityscapesPath = os.environ['CITYSCAPES_DATASET'] 41 | else: 42 | args.cityscapesPath = os.path.join(os.path.dirname(os.path.realpath(__file__)),'..','..') 43 | 44 | # Parameters that should be modified by user 45 | args.exportFile = os.path.join( args.cityscapesPath , "evaluationResults" , "resultPixelLevelSemanticLabeling.json" ) 46 | args.groundTruthSearch = os.path.join( args.cityscapesPath , "gtFine" , "val" , "*", "*_gtFine_labelIds.png" ) 47 | 48 | # Remaining params 49 | args.evalInstLevelScore = False 50 | args.evalPixelAccuracy = False 51 | args.evalLabels = [] 52 | args.printRow = 5 53 | args.normalized = True 54 | args.colorized = hasattr(sys.stderr, "isatty") and sys.stderr.isatty() and platform.system()=='Linux' 55 | args.bold = colors.BOLD if args.colorized else "" 56 | args.nocol = colors.ENDC if args.colorized else "" 57 | args.JSONOutput = True 58 | args.quiet = False 59 | 60 | args.avgClassSize = { 61 | "bicycle" : 4672.3249222261 , 62 | "caravan" : 36771.8241758242 , 63 | "motorcycle" : 6298.7200839748 , 64 | "rider" : 3930.4788056518 , 65 | "bus" : 35732.1511111111 , 66 | "train" : 67583.7075812274 , 67 | "car" : 12794.0202738185 , 68 | "person" : 3462.4756337644 , 69 | "truck" : 27855.1264367816 , 70 | "trailer" : 16926.9763313609 , 71 | } 72 | 73 | 74 | ################################### 75 | 76 | def getPrediction( args, groundTruthFile ): 77 | # determine the prediction path, if the method is first called 78 | if not args.predictionPath: 79 | rootPath = None 80 | if 'CITYSCAPES_RESULTS' in os.environ: 81 | rootPath = os.environ['CITYSCAPES_RESULTS'] 82 | elif 'CITYSCAPES_DATASET' in os.environ: 83 | rootPath = os.path.join( os.environ['CITYSCAPES_DATASET'] , "results" ) 84 | else: 85 | rootPath = os.path.join(os.path.dirname(os.path.realpath(__file__)),'..','..','results') 86 | 87 | if not os.path.isdir(rootPath): 88 | printError("Could not find a result root folder. 
Please read the instructions of this method.") 89 | 90 | args.predictionPath = rootPath 91 | 92 | # walk the prediction path, if not happened yet 93 | if not args.predictionWalk: 94 | walk = [] 95 | for root, dirnames, filenames in os.walk(args.predictionPath): 96 | walk.append( (root,filenames) ) 97 | args.predictionWalk = walk 98 | 99 | csFile = getCsFileInfo(groundTruthFile) 100 | filePattern = "{}_{}_{}*.png".format( csFile.city , csFile.sequenceNb , csFile.frameNb ) 101 | 102 | predictionFile = None 103 | for root, filenames in args.predictionWalk: 104 | for filename in fnmatch.filter(filenames, filePattern): 105 | if not predictionFile: 106 | predictionFile = os.path.join(root, filename) 107 | else: 108 | printError("Found multiple predictions for ground truth {}".format(groundTruthFile)) 109 | 110 | if not predictionFile: 111 | printError("Found no prediction for ground truth {}".format(groundTruthFile)) 112 | 113 | return predictionFile 114 | 115 | 116 | 117 | # store some parameters for finding predictions in the args variable 118 | # the values are filled when the method getPrediction is first called 119 | args.predictionPath = None 120 | args.predictionWalk = None 121 | 122 | 123 | ######################### 124 | # Methods 125 | ######################### 126 | 127 | 128 | # Generate empty confusion matrix and create list of relevant labels 129 | def generateMatrix(args): 130 | args.evalLabels = [] 131 | for label in labels: 132 | if (label.id < 0): 133 | continue 134 | # we append all found labels, regardless of being ignored 135 | args.evalLabels.append(label.id) 136 | maxId = max(args.evalLabels) 137 | # We use longlong type to be sure that there are no overflows 138 | return np.zeros(shape=(maxId+1, maxId+1),dtype=np.ulonglong) 139 | 140 | def generateMatrixTrainId(args): 141 | args.evalLabels = [] 142 | for label in labels: 143 | if (label.trainId < 0): 144 | continue 145 | # we append all found labels, regardless of being ignored 146 | args.evalLabels.append(label.trainId) 147 | maxId = max(args.evalLabels) 148 | # We use longlong type to be sure that there are no overflows 149 | return np.zeros(shape=(maxId+1, maxId+1),dtype=np.ulonglong) 150 | 151 | def generateInstanceStats(args): 152 | instanceStats = {} 153 | instanceStats["classes" ] = {} 154 | instanceStats["categories"] = {} 155 | for label in labels: 156 | if label.hasInstances and not label.ignoreInEval: 157 | instanceStats["classes"][label.name] = {} 158 | instanceStats["classes"][label.name]["tp"] = 0.0 159 | instanceStats["classes"][label.name]["tpWeighted"] = 0.0 160 | instanceStats["classes"][label.name]["fn"] = 0.0 161 | instanceStats["classes"][label.name]["fnWeighted"] = 0.0 162 | for category in category2labels: 163 | labelIds = [] 164 | allInstances = True 165 | for label in category2labels[category]: 166 | if label.id < 0: 167 | continue 168 | if not label.hasInstances: 169 | allInstances = False 170 | break 171 | labelIds.append(label.id) 172 | if not allInstances: 173 | continue 174 | 175 | instanceStats["categories"][category] = {} 176 | instanceStats["categories"][category]["tp"] = 0.0 177 | instanceStats["categories"][category]["tpWeighted"] = 0.0 178 | instanceStats["categories"][category]["fn"] = 0.0 179 | instanceStats["categories"][category]["fnWeighted"] = 0.0 180 | instanceStats["categories"][category]["labelIds"] = labelIds 181 | 182 | return instanceStats 183 | 184 | 185 | # Get absolute or normalized value from field in confusion matrix. 
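
generateMatrixTrainId above allocates the empty confusion matrix that train.py and test.py accumulate into, sized by the largest trainId in the label list. A quick sanity check, imported the same way the training script does:

```python
# sketch: the empty train-id confusion matrix used by train.py and test.py
from utils import evalIoU

conf = evalIoU.generateMatrixTrainId(evalIoU.args)
print(conf.shape, conf.dtype)    # (21, 21) uint64 with the Pascal VOC label list above
```
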
186 | def getMatrixFieldValue(confMatrix, i, j, args): 187 | if args.normalized: 188 | rowSum = confMatrix[i].sum() 189 | if (rowSum == 0): 190 | return float('nan') 191 | return float(confMatrix[i][j]) / rowSum 192 | else: 193 | return confMatrix[i][j] 194 | 195 | # Calculate and return IOU score for a particular label 196 | def getIouScoreForLabel(label, confMatrix, args): 197 | if id2label[label].ignoreInEval: 198 | return float('nan') 199 | 200 | # the number of true positive pixels for this label 201 | # the entry on the diagonal of the confusion matrix 202 | tp = np.longlong(confMatrix[label,label]) 203 | 204 | # the number of false negative pixels for this label 205 | # the row sum of the matching row in the confusion matrix 206 | # minus the diagonal entry 207 | fn = np.longlong(confMatrix[label,:].sum()) - tp 208 | 209 | # the number of false positive pixels for this labels 210 | # Only pixels that are not on a pixel with ground truth label that is ignored 211 | # The column sum of the corresponding column in the confusion matrix 212 | # without the ignored rows and without the actual label of interest 213 | notIgnored = [l for l in args.evalLabels if not id2label[l].ignoreInEval and not l==label] 214 | fp = np.longlong(confMatrix[notIgnored,label].sum()) 215 | 216 | # the denominator of the IOU score 217 | denom = (tp + fp + fn) 218 | if denom == 0: 219 | return float('nan') 220 | 221 | # return IOU 222 | return float(tp) / denom 223 | 224 | def getIouScoreForTrainLabel(label, confMatrix, args): 225 | if trainId2label[label].ignoreInEval: 226 | return float('nan') 227 | 228 | # the number of true positive pixels for this label 229 | # the entry on the diagonal of the confusion matrix 230 | tp = np.longlong(confMatrix[label,label]) 231 | 232 | # the number of false negative pixels for this label 233 | # the row sum of the matching row in the confusion matrix 234 | # minus the diagonal entry 235 | fn = np.longlong(confMatrix[label,:].sum()) - tp 236 | 237 | # the number of false positive pixels for this labels 238 | # Only pixels that are not on a pixel with ground truth label that is ignored 239 | # The column sum of the corresponding column in the confusion matrix 240 | # without the ignored rows and without the actual label of interest 241 | notIgnored = [l for l in args.evalLabels if not trainId2label[l].ignoreInEval and not l==label] 242 | fp = np.longlong(confMatrix[notIgnored,label].sum()) 243 | 244 | # the denominator of the IOU score 245 | denom = (tp + fp + fn) 246 | if denom == 0: 247 | return float('nan') 248 | 249 | # return IOU 250 | return float(tp) / denom 251 | 252 | # Calculate and return IOU score for a particular label 253 | def getInstanceIouScoreForLabel(label, confMatrix, instStats, args): 254 | if id2label[label].ignoreInEval: 255 | return float('nan') 256 | 257 | labelName = id2label[label].name 258 | if not labelName in instStats["classes"]: 259 | return float('nan') 260 | 261 | tp = instStats["classes"][labelName]["tpWeighted"] 262 | fn = instStats["classes"][labelName]["fnWeighted"] 263 | # false postives computed as above 264 | notIgnored = [l for l in args.evalLabels if not id2label[l].ignoreInEval and not l==label] 265 | fp = np.longlong(confMatrix[notIgnored,label].sum()) 266 | 267 | # the denominator of the IOU score 268 | denom = (tp + fp + fn) 269 | if denom == 0: 270 | return float('nan') 271 | 272 | # return IOU 273 | return float(tp) / denom 274 | 275 | # Calculate prior for a particular class id. 
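
The IoU helpers above all reduce to tp / (tp + fp + fn) read off the confusion matrix. A toy 2x2 example mirroring the arithmetic in getIouScoreForLabel (with no ignored labels, so the column sum needs no filtering):

```python
# sketch: IoU from a toy confusion matrix; rows are ground truth, columns are predictions
import numpy as np

conf = np.array([[50, 10],
                 [ 5, 35]], dtype=np.uint64)
label = 1
tp = conf[label, label]              # 35 correctly predicted pixels
fn = conf[label, :].sum() - tp       # 5 missed pixels of this class
fp = conf[:, label].sum() - tp       # 10 pixels wrongly assigned to this class
print(float(tp) / (tp + fp + fn))    # 0.7
```
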
276 | def getPrior(label, confMatrix): 277 | return float(confMatrix[label,:].sum()) / confMatrix.sum() 278 | 279 | # Get average of scores. 280 | # Only computes the average over valid entries. 281 | def getScoreAverage(scoreList, args): 282 | validScores = 0 283 | scoreSum = 0.0 284 | for score in scoreList: 285 | if not math.isnan(scoreList[score]): 286 | validScores += 1 287 | scoreSum += scoreList[score] 288 | if validScores == 0: 289 | return float('nan') 290 | return scoreSum / validScores 291 | 292 | # Calculate and return IOU score for a particular category 293 | def getIouScoreForCategory(category, confMatrix, args): 294 | # All labels in this category 295 | labels = category2labels[category] 296 | # The IDs of all valid labels in this category 297 | labelIds = [label.id for label in labels if not label.ignoreInEval and label.id in args.evalLabels] 298 | # If there are no valid labels, then return NaN 299 | if not labelIds: 300 | return float('nan') 301 | 302 | # the number of true positive pixels for this category 303 | # this is the sum of all entries in the confusion matrix 304 | # where row and column belong to a label ID of this category 305 | tp = np.longlong(confMatrix[labelIds,:][:,labelIds].sum()) 306 | 307 | # the number of false negative pixels for this category 308 | # that is the sum of all rows of labels within this category 309 | # minus the number of true positive pixels 310 | fn = np.longlong(confMatrix[labelIds,:].sum()) - tp 311 | 312 | # the number of false positive pixels for this category 313 | # we count the column sum of all labels within this category 314 | # while skipping the rows of ignored labels and of labels within this category 315 | notIgnoredAndNotInCategory = [l for l in args.evalLabels if not id2label[l].ignoreInEval and id2label[l].category != category] 316 | fp = np.longlong(confMatrix[notIgnoredAndNotInCategory,:][:,labelIds].sum()) 317 | 318 | # the denominator of the IOU score 319 | denom = (tp + fp + fn) 320 | if denom == 0: 321 | return float('nan') 322 | 323 | # return IOU 324 | return float(tp) / denom 325 | 326 | # Calculate and return IOU score for a particular category 327 | def getInstanceIouScoreForCategory(category, confMatrix, instStats, args): 328 | if not category in instStats["categories"]: 329 | return float('nan') 330 | labelIds = instStats["categories"][category]["labelIds"] 331 | 332 | tp = instStats["categories"][category]["tpWeighted"] 333 | fn = instStats["categories"][category]["fnWeighted"] 334 | 335 | # the number of false positive pixels for this category 336 | # same as above 337 | notIgnoredAndNotInCategory = [l for l in args.evalLabels if not id2label[l].ignoreInEval and id2label[l].category != category] 338 | fp = np.longlong(confMatrix[notIgnoredAndNotInCategory,:][:,labelIds].sum()) 339 | 340 | # the denominator of the IOU score 341 | denom = (tp + fp + fn) 342 | if denom == 0: 343 | return float('nan') 344 | 345 | # return IOU 346 | return float(tp) / denom 347 | 348 | 349 | # create a dictionary containing all relevant results 350 | def createResultDict( confMatrix, classScores, classInstScores, categoryScores, categoryInstScores, perImageStats, args ): 351 | # write JSON result file 352 | wholeData = {} 353 | wholeData["confMatrix"] = confMatrix.tolist() 354 | wholeData["priors"] = {} 355 | wholeData["labels"] = {} 356 | for label in args.evalLabels: 357 | wholeData["priors"][id2label[label].name] = getPrior(label, confMatrix) 358 | wholeData["labels"][id2label[label].name] = label 359 | 
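
getScoreAverage above averages only the valid entries of a score dictionary, skipping NaN values for classes that never appear in the data. A small sketch:

```python
# sketch: NaN entries are ignored when averaging class scores
from utils import evalIoU

scores = {'car': 0.8, 'person': float('nan'), 'bus': 0.6}
print(evalIoU.getScoreAverage(scores, evalIoU.args))   # 0.7
```
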
wholeData["classScores"] = classScores 360 | wholeData["classInstScores"] = classInstScores 361 | wholeData["categoryScores"] = categoryScores 362 | wholeData["categoryInstScores"] = categoryInstScores 363 | wholeData["averageScoreClasses"] = getScoreAverage(classScores, args) 364 | wholeData["averageScoreInstClasses"] = getScoreAverage(classInstScores, args) 365 | wholeData["averageScoreCategories"] = getScoreAverage(categoryScores, args) 366 | wholeData["averageScoreInstCategories"] = getScoreAverage(categoryInstScores, args) 367 | 368 | if perImageStats: 369 | wholeData["perImageScores"] = perImageStats 370 | 371 | return wholeData 372 | 373 | def writeJSONFile(wholeData, args): 374 | path = os.path.dirname(args.exportFile) 375 | ensurePath(path) 376 | writeDict2JSON(wholeData, args.exportFile) 377 | 378 | # Print confusion matrix 379 | def printConfMatrix(confMatrix, args): 380 | # print line 381 | print("\b{text:{fill}>{width}}".format(width=15, fill='-', text=" "), end=' ') 382 | for label in args.evalLabels: 383 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 2, fill='-', text=" "), end=' ') 384 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 3, fill='-', text=" ")) 385 | 386 | # print label names 387 | print("\b{text:>{width}} |".format(width=13, text=""), end=' ') 388 | for label in args.evalLabels: 389 | print("\b{text:^{width}} |".format(width=args.printRow, text=id2label[label].name[0]), end=' ') 390 | print("\b{text:>{width}} |".format(width=6, text="Prior")) 391 | 392 | # print line 393 | print("\b{text:{fill}>{width}}".format(width=15, fill='-', text=" "), end=' ') 394 | for label in args.evalLabels: 395 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 2, fill='-', text=" "), end=' ') 396 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 3, fill='-', text=" ")) 397 | 398 | # print matrix 399 | for x in range(0, confMatrix.shape[0]): 400 | if (not x in args.evalLabels): 401 | continue 402 | # get prior of this label 403 | prior = getPrior(x, confMatrix) 404 | # skip if label does not exist in ground truth 405 | if prior < 1e-9: 406 | continue 407 | 408 | # print name 409 | name = id2label[x].name 410 | if len(name) > 13: 411 | name = name[:13] 412 | print("\b{text:>{width}} |".format(width=13,text=name), end=' ') 413 | # print matrix content 414 | for y in range(0, len(confMatrix[x])): 415 | if (not y in args.evalLabels): 416 | continue 417 | matrixFieldValue = getMatrixFieldValue(confMatrix, x, y, args) 418 | print(getColorEntry(matrixFieldValue, args) + "\b{text:>{width}.2f} ".format(width=args.printRow, text=matrixFieldValue) + args.nocol, end=' ') 419 | # print prior 420 | print(getColorEntry(prior, args) + "\b{text:>{width}.4f} ".format(width=6, text=prior) + args.nocol) 421 | # print line 422 | print("\b{text:{fill}>{width}}".format(width=15, fill='-', text=" "), end=' ') 423 | for label in args.evalLabels: 424 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 2, fill='-', text=" "), end=' ') 425 | print("\b{text:{fill}>{width}}".format(width=args.printRow + 3, fill='-', text=" "), end=' ') 426 | 427 | # Print intersection-over-union scores for all classes. 
428 | def printClassScores(scoreList, instScoreList, args): 429 | if (args.quiet): 430 | return 431 | print(args.bold + "classes IoU nIoU" + args.nocol) 432 | print("--------------------------------") 433 | for label in args.evalLabels: 434 | if (id2label[label].ignoreInEval): 435 | continue 436 | labelName = str(id2label[label].name) 437 | iouStr = getColorEntry(scoreList[labelName], args) + "{val:>5.3f}".format(val=scoreList[labelName]) + args.nocol 438 | niouStr = getColorEntry(instScoreList[labelName], args) + "{val:>5.3f}".format(val=instScoreList[labelName]) + args.nocol 439 | print("{:<14}: ".format(labelName) + iouStr + " " + niouStr) 440 | 441 | # Print intersection-over-union scores for all categories. 442 | def printCategoryScores(scoreDict, instScoreDict, args): 443 | if (args.quiet): 444 | return 445 | print(args.bold + "categories IoU nIoU" + args.nocol) 446 | print("--------------------------------") 447 | for categoryName in scoreDict: 448 | if all( label.ignoreInEval for label in category2labels[categoryName] ): 449 | continue 450 | iouStr = getColorEntry(scoreDict[categoryName], args) + "{val:>5.3f}".format(val=scoreDict[categoryName]) + args.nocol 451 | niouStr = getColorEntry(instScoreDict[categoryName], args) + "{val:>5.3f}".format(val=instScoreDict[categoryName]) + args.nocol 452 | print("{:<14}: ".format(categoryName) + iouStr + " " + niouStr) 453 | 454 | # Evaluate image lists pairwise. 455 | def evaluateImgLists(predictionImgList, groundTruthImgList, args): 456 | if len(predictionImgList) != len(groundTruthImgList): 457 | printError("Lists of prediction and ground truth images are not of equal size.") 458 | confMatrix = generateMatrix(args) 459 | instStats = generateInstanceStats(args) 460 | perImageStats = {} 461 | nbPixels = 0 462 | 463 | if not args.quiet: 464 | print("Evaluating {} pairs of images...".format(len(predictionImgList))) 465 | 466 | # Evaluate all pairs of images and save them into a matrix 467 | for i in range(len(predictionImgList)): 468 | predictionImgFileName = predictionImgList[i] 469 | groundTruthImgFileName = groundTruthImgList[i] 470 | #print "Evaluate ", predictionImgFileName, "<>", groundTruthImgFileName 471 | nbPixels += evaluatePair(predictionImgFileName, groundTruthImgFileName, confMatrix, instStats, perImageStats, args) 472 | 473 | # sanity check 474 | if confMatrix.sum() != nbPixels: 475 | printError('Number of analyzed pixels and entries in confusion matrix disagree: confMatrix {}, pixels {}'.format(confMatrix.sum(),nbPixels)) 476 | 477 | if not args.quiet: 478 | print("\rImages Processed: {}".format(i+1), end=' ') 479 | sys.stdout.flush() 480 | if not args.quiet: 481 | print("\n") 482 | 483 | # sanity check 484 | if confMatrix.sum() != nbPixels: 485 | printError('Number of analyzed pixels and entries in confusion matrix disagree: confMatrix {}, pixels {}'.format(confMatrix.sum(),nbPixels)) 486 | 487 | # print confusion matrix 488 | if (not args.quiet): 489 | printConfMatrix(confMatrix, args) 490 | 491 | # Calculate IOU scores on class level from matrix 492 | classScoreList = {} 493 | for label in args.evalLabels: 494 | labelName = id2label[label].name 495 | classScoreList[labelName] = getIouScoreForLabel(label, confMatrix, args) 496 | 497 | # Calculate instance IOU scores on class level from matrix 498 | classInstScoreList = {} 499 | for label in args.evalLabels: 500 | labelName = id2label[label].name 501 | classInstScoreList[labelName] = getInstanceIouScoreForLabel(label, confMatrix, instStats, args) 502 | 503 | # Print IOU scores
504 | if (not args.quiet): 505 | print("") 506 | print("") 507 | printClassScores(classScoreList, classInstScoreList, args) 508 | iouAvgStr = getColorEntry(getScoreAverage(classScoreList, args), args) + "{avg:5.3f}".format(avg=getScoreAverage(classScoreList, args)) + args.nocol 509 | niouAvgStr = getColorEntry(getScoreAverage(classInstScoreList , args), args) + "{avg:5.3f}".format(avg=getScoreAverage(classInstScoreList , args)) + args.nocol 510 | print("--------------------------------") 511 | print("Score Average : " + iouAvgStr + " " + niouAvgStr) 512 | print("--------------------------------") 513 | print("") 514 | 515 | # Calculate IOU scores on category level from matrix 516 | categoryScoreList = {} 517 | for category in category2labels.keys(): 518 | categoryScoreList[category] = getIouScoreForCategory(category,confMatrix,args) 519 | 520 | # Calculate instance IOU scores on category level from matrix 521 | categoryInstScoreList = {} 522 | for category in category2labels.keys(): 523 | categoryInstScoreList[category] = getInstanceIouScoreForCategory(category,confMatrix,instStats,args) 524 | 525 | # Print IOU scores 526 | if (not args.quiet): 527 | print("") 528 | printCategoryScores(categoryScoreList, categoryInstScoreList, args) 529 | iouAvgStr = getColorEntry(getScoreAverage(categoryScoreList, args), args) + "{avg:5.3f}".format(avg=getScoreAverage(categoryScoreList, args)) + args.nocol 530 | niouAvgStr = getColorEntry(getScoreAverage(categoryInstScoreList, args), args) + "{avg:5.3f}".format(avg=getScoreAverage(categoryInstScoreList, args)) + args.nocol 531 | print("--------------------------------") 532 | print("Score Average : " + iouAvgStr + " " + niouAvgStr) 533 | print("--------------------------------") 534 | print("") 535 | 536 | # write result file 537 | allResultsDict = createResultDict( confMatrix, classScoreList, classInstScoreList, categoryScoreList, categoryInstScoreList, perImageStats, args ) 538 | writeJSONFile( allResultsDict, args) 539 | 540 | # return confusion matrix 541 | return allResultsDict 542 | 543 | # Main evaluation method. Evaluates pairs of prediction and ground truth 544 | # images which are passed as arguments. 545 | def evaluatePair(predictionImgFileName, groundTruthImgFileName, confMatrix, instanceStats, perImageStats, args): 546 | # Loading all resources for evaluation. 
547 | try: 548 | predictionImg = Image.open(predictionImgFileName) 549 | predictionNp = np.array(predictionImg) 550 | except: 551 | printError("Unable to load " + predictionImgFileName) 552 | try: 553 | groundTruthImg = Image.open(groundTruthImgFileName) 554 | groundTruthNp = np.array(groundTruthImg) 555 | except: 556 | printError("Unable to load " + groundTruthImgFileName) 557 | # load ground truth instances, if needed 558 | if args.evalInstLevelScore: 559 | groundTruthInstanceImgFileName = groundTruthImgFileName.replace("labelIds","instanceIds") 560 | try: 561 | instanceImg = Image.open(groundTruthInstanceImgFileName) 562 | instanceNp = np.array(instanceImg) 563 | except: 564 | printError("Unable to load " + groundTruthInstanceImgFileName) 565 | 566 | # Check for equal image sizes 567 | if (predictionImg.size[0] != groundTruthImg.size[0]): 568 | printError("Image widths of " + predictionImgFileName + " and " + groundTruthImgFileName + " are not equal.") 569 | if (predictionImg.size[1] != groundTruthImg.size[1]): 570 | printError("Image heights of " + predictionImgFileName + " and " + groundTruthImgFileName + " are not equal.") 571 | if ( len(predictionNp.shape) != 2 ): 572 | printError("Predicted image has multiple channels.") 573 | 574 | imgWidth = predictionImg.size[0] 575 | imgHeight = predictionImg.size[1] 576 | nbPixels = imgWidth*imgHeight 577 | 578 | # Evaluate images 579 | if (CSUPPORT): 580 | # using cython 581 | confMatrix = addToConfusionMatrix.cEvaluatePair(predictionNp, groundTruthNp, confMatrix, args.evalLabels) 582 | else: 583 | # the slower python way 584 | for (groundTruthImgPixel,predictionImgPixel) in izip(groundTruthImg.getdata(),predictionImg.getdata()): 585 | if (not groundTruthImgPixel in args.evalLabels): 586 | printError("Unknown label with id {:}".format(groundTruthImgPixel)) 587 | 588 | confMatrix[groundTruthImgPixel][predictionImgPixel] += 1 589 | 590 | if args.evalInstLevelScore: 591 | # Generate category masks 592 | categoryMasks = {} 593 | for category in instanceStats["categories"]: 594 | categoryMasks[category] = np.in1d( predictionNp , instanceStats["categories"][category]["labelIds"] ).reshape(predictionNp.shape) 595 | 596 | instList = np.unique(instanceNp[instanceNp > 1000]) 597 | for instId in instList: 598 | labelId = int(instId/1000) 599 | label = id2label[ labelId ] 600 | if label.ignoreInEval: 601 | continue 602 | 603 | mask = instanceNp==instId 604 | instSize = np.count_nonzero( mask ) 605 | 606 | tp = np.count_nonzero( predictionNp[mask] == labelId ) 607 | fn = instSize - tp 608 | 609 | weight = args.avgClassSize[label.name] / float(instSize) 610 | tpWeighted = float(tp) * weight 611 | fnWeighted = float(fn) * weight 612 | 613 | instanceStats["classes"][label.name]["tp"] += tp 614 | instanceStats["classes"][label.name]["fn"] += fn 615 | instanceStats["classes"][label.name]["tpWeighted"] += tpWeighted 616 | instanceStats["classes"][label.name]["fnWeighted"] += fnWeighted 617 | 618 | category = label.category 619 | if category in instanceStats["categories"]: 620 | catTp = 0 621 | catTp = np.count_nonzero( np.logical_and( mask , categoryMasks[category] ) ) 622 | catFn = instSize - catTp 623 | 624 | catTpWeighted = float(catTp) * weight 625 | catFnWeighted = float(catFn) * weight 626 | 627 | instanceStats["categories"][category]["tp"] += catTp 628 | instanceStats["categories"][category]["fn"] += catFn 629 | instanceStats["categories"][category]["tpWeighted"] += catTpWeighted 630 | instanceStats["categories"][category]["fnWeighted"] += catFnWeighted 
631 | 632 | if args.evalPixelAccuracy: 633 | notIgnoredLabels = [l for l in args.evalLabels if not id2label[l].ignoreInEval] 634 | notIgnoredPixels = np.in1d( groundTruthNp , notIgnoredLabels , invert=True ).reshape(groundTruthNp.shape) 635 | erroneousPixels = np.logical_and( notIgnoredPixels , ( predictionNp != groundTruthNp ) ) 636 | perImageStats[predictionImgFileName] = {} 637 | perImageStats[predictionImgFileName]["nbNotIgnoredPixels"] = np.count_nonzero(notIgnoredPixels) 638 | perImageStats[predictionImgFileName]["nbCorrectPixels"] = np.count_nonzero(erroneousPixels) 639 | 640 | return nbPixels 641 | 642 | def evaluatePairPytorch(prediction, groundtruth, confMatrix, perImageStats, args): 643 | # Loading all resources for evaluation. 644 | 645 | predictionImg = prediction 646 | predictionNp = np.array(predictionImg) 647 | 648 | groundTruthImg = groundtruth 649 | groundTruthNp = np.array(groundTruthImg) 650 | 651 | 652 | # Check for equal image sizes 653 | if (predictionImg.size[0] != groundTruthImg.size[0]): 654 | printError("Image widths are not equal.") 655 | if (predictionImg.size[1] != groundTruthImg.size[1]): 656 | printError("Image heights are not equal.") 657 | if ( len(predictionNp.shape) != 2 ): 658 | printError("Predicted image has multiple channels.") 659 | 660 | imgWidth = predictionImg.size[0] 661 | imgHeight = predictionImg.size[1] 662 | nbPixels = imgWidth*imgHeight 663 | 664 | # Evaluate images 665 | if (CSUPPORT): 666 | # using cython 667 | confMatrix = addToConfusionMatrix.cEvaluatePair(predictionNp, groundTruthNp, confMatrix, args.evalLabels) 668 | else: 669 | # the slower python way 670 | for (groundTruthImgPixel,predictionImgPixel) in izip(groundTruthImg.getdata(),predictionImg.getdata()): 671 | if (not groundTruthImgPixel in args.evalLabels): 672 | printError("Unknown label with id {:}".format(groundTruthImgPixel)) 673 | 674 | confMatrix[groundTruthImgPixel][predictionImgPixel] += 1 675 | 676 | 677 | if args.evalPixelAccuracy: 678 | notIgnoredLabels = [l for l in args.evalLabels if not id2label[l].ignoreInEval] 679 | notIgnoredPixels = np.in1d( groundTruthNp , notIgnoredLabels , invert=True ).reshape(groundTruthNp.shape) 680 | erroneousPixels = np.logical_and( notIgnoredPixels , ( predictionNp != groundTruthNp ) ) 681 | imageKey = len(perImageStats); perImageStats[imageKey] = {}  # no file name exists in this in-memory path, so key the per-image stats by a running index 682 | perImageStats[imageKey]["nbNotIgnoredPixels"] = np.count_nonzero(notIgnoredPixels) 683 | perImageStats[imageKey]["nbCorrectPixels"] = np.count_nonzero(erroneousPixels) 684 | 685 | return nbPixels 686 | 687 | # Print intersection-over-union scores for all classes. 688 | def printClassScoresPytorch(scoreList, args): 689 | if (args.quiet): 690 | return 691 | print(args.bold + "classes IoU nIoU" + args.nocol) 692 | print("--------------------------------") 693 | for label in args.evalLabels: 694 | if (id2label[label].ignoreInEval): 695 | continue 696 | labelName = str(id2label[label].name) 697 | iouStr = getColorEntry(scoreList[labelName], args) + "{val:>5.3f}".format(val=scoreList[labelName]) + args.nocol 698 | #niouStr = getColorEntry(instScoreList[labelName], args) + "{val:>5.3f}".format(val=instScoreList[labelName]) + args.nocol 699 | print("{:<14}: ".format(labelName) + iouStr )#+ " " + niouStr) 700 | 701 | # Print intersection-over-union scores for all classes.
702 | def printClassScoresPytorchTrain(scoreList, args): 703 | if (args.quiet): 704 | return 705 | print(args.bold + "classes IoU nIoU" + args.nocol) 706 | print("--------------------------------") 707 | for label in args.evalLabels: 708 | if (trainId2label[label].ignoreInEval): 709 | continue 710 | labelName = str(trainId2label[label].name) 711 | iouStr = getColorEntry(scoreList[labelName], args) + "{val:>5.3f}".format(val=scoreList[labelName]) + args.nocol 712 | #niouStr = getColorEntry(instScoreList[labelName], args) + "{val:>5.3f}".format(val=instScoreList[labelName]) + args.nocol 713 | print("{:<14}: ".format(labelName) + iouStr )#+ " " + niouStr) 714 | -------------------------------------------------------------------------------- /utils/eval_weight.py: -------------------------------------------------------------------------------- 1 | from PIL import Image,ImageOps 2 | import os 3 | import numpy as np 4 | 5 | 6 | imagefile = [] 7 | with open('./label.txt_all','r') as f: 8 | for line in f: 9 | imagefile.append(line.strip().replace('\n','')) 10 | 11 | 12 | 13 | s = 0 14 | light = 0 15 | car = 0 16 | count=0 17 | background = 0 18 | print(len(imagefile)) 19 | for i in imagefile: 20 | img = np.array(Image.open(i).convert('P')) 21 | count+=1 22 | background += np.sum(img==0) 23 | 24 | car += np.sum(img==1) 25 | light += np.sum(img==2) 26 | 27 | s += (img.shape[0]*img.shape[1]) 28 | print(count) 29 | 30 | 31 | print(background/s,light/s,car/s) 32 | weight_back = 1/np.log(1+background/s) 33 | weight_car = 1/np.log(1+car/s) 34 | weight_light = 1/np.log(1+light/s) 35 | print(weight_back,weight_car,weight_light) #1,1345.54,4779.0 -------------------------------------------------------------------------------- /utils/label2Img.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import io 3 | import scipy.misc 4 | import numpy as np 5 | from PIL import Image 6 | import base64; import matplotlib.pyplot as plt  # plt is required by draw_label below 7 | def label_colormap(N=256): 8 | 9 | def bitget(byteval, idx): 10 | return ((byteval & (1 << idx)) != 0) 11 | 12 | cmap = np.zeros((N, 3)) 13 | for i in range(0, N): 14 | id = i 15 | r, g, b = 0, 0, 0 16 | for j in range(0, 8): 17 | r = np.bitwise_or(r, (bitget(id, 0) << 7-j)) 18 | g = np.bitwise_or(g, (bitget(id, 1) << 7-j)) 19 | b = np.bitwise_or(b, (bitget(id, 2) << 7-j)) 20 | id = (id >> 3) 21 | cmap[i, 0] = r 22 | cmap[i, 1] = g 23 | cmap[i, 2] = b 24 | cmap = cmap.astype(np.float32) / 255 25 | return cmap 26 | 27 | def label2rgb(lbl, img=None, n_labels=None, alpha=0.3, thresh_suppress=0): 28 | if n_labels is None: 29 | n_labels = len(np.unique(lbl)) 30 | cmap = label_colormap(n_labels) 31 | cmap = (cmap * 255).astype(np.uint8) 32 | # make the first two class colors more saturated 33 | cmap[1,0] = 255.0 34 | try : 35 | cmap[2,1] = 255.0 36 | except IndexError as e: 37 | pass 38 | lbl_viz = cmap[lbl] 39 | lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled 40 | 41 | if img is not None: 42 | img_gray = PIL.Image.fromarray(img).convert('LA') 43 | img_gray = np.asarray(img_gray.convert('RGB')) 44 | # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 45 | # img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB) 46 | lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray 47 | lbl_viz = lbl_viz.astype(np.uint8) 48 | 49 | return lbl_viz 50 | 51 | def draw_label(label, img, label_names, colormap=None): 52 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0, 53 | wspace=0, hspace=0) 54 | plt.margins(0, 0) 55 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 56 |
plt.gca().yaxis.set_major_locator(plt.NullLocator()) 57 | 58 | if colormap is None: 59 | colormap = label_colormap(len(label_names)) 60 | 61 | label_viz = label2rgb(label, img, n_labels=len(label_names)) 62 | plt.imshow(label_viz) 63 | plt.axis('off') 64 | 65 | plt_handlers = [] 66 | plt_titles = [] 67 | for label_value, label_name in enumerate(label_names): 68 | fc = colormap[label_value] 69 | p = plt.Rectangle((0, 0), 1, 1, fc=fc) 70 | plt_handlers.append(p) 71 | plt_titles.append(label_name) 72 | plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5) 73 | 74 | f = io.BytesIO() 75 | plt.savefig(f, bbox_inches='tight', pad_inches=0) 76 | plt.cla() 77 | plt.close() 78 | 79 | out = np.array(PIL.Image.open(f))[:, :, :3] 80 | out = scipy.misc.imresize(out, img.shape[:2]) 81 | return out 82 | --------------------------------------------------------------------------------
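The per-class IoU arithmetic in utils/evalIoU.py and the inverse-log-frequency class weights in utils/eval_weight.py can be restated compactly for reference. The sketch below is a minimal, self-contained illustration under assumed inputs: the confusion-matrix values, the class count, and the helper name iou_for_class are placeholders rather than outputs of these scripts, and the ignoreInEval filtering that evalIoU.py applies to false positives is omitted for brevity.

import numpy as np

# Hypothetical 3-class confusion matrix: rows = ground truth labels, columns = predicted labels.
conf = np.array([[900, 10,  5],
                 [ 20, 60,  2],
                 [  7,  3, 40]], dtype=np.longlong)

def iou_for_class(conf, c):
    # Same arithmetic as getIouScoreForLabel: true positives sit on the diagonal,
    # false negatives are the rest of the row, false positives the rest of the column.
    tp = conf[c, c]
    fn = conf[c, :].sum() - tp
    fp = conf[:, c].sum() - tp
    denom = tp + fp + fn
    return float('nan') if denom == 0 else float(tp) / denom

print([round(iou_for_class(conf, c), 3) for c in range(conf.shape[0])])

# Class weights in the spirit of eval_weight.py: weight = 1 / ln(1 + p),
# where p is the fraction of all pixels belonging to the class.
pixel_fraction = conf.sum(axis=1) / conf.sum()   # placeholder frequencies
weights = 1.0 / np.log(1.0 + pixel_fraction)
print(weights)

Weighting by 1/ln(1+p) boosts rare classes such as the traffic lights counted in eval_weight.py, so a weighted cross-entropy loss is not dominated by background pixels.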