├── LICENSE ├── README.md ├── __pycache__ ├── args.cpython-36.pyc ├── test.cpython-36.pyc ├── train.cpython-36.pyc ├── transforms.cpython-36.pyc └── utils.cpython-36.pyc ├── args.py ├── data ├── Cityscapes │ ├── gtFine_trainvaltest │ └── leftImg8bit_trainvaltest ├── README.md ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── camvid.cpython-36.pyc │ ├── cityscapes.cpython-36.pyc │ └── utils.cpython-36.pyc ├── camvid.py ├── cityscapes.py └── utils.py ├── main.py ├── metric ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── confusionmatrix.cpython-36.pyc │ ├── iou.cpython-36.pyc │ └── metric.cpython-36.pyc ├── confusionmatrix.py ├── iou.py └── metric.py ├── models ├── __pycache__ │ └── rpnet.cpython-36.pyc └── rpnet.py ├── pictures ├── 1.png └── 2.png ├── requirements.txt ├── save ├── RPNet └── RPNet_summary.txt ├── test.py ├── train.py ├── train.sh ├── transforms.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 davidtvs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-RPNet 2 | 3 | PyTorch implementation of [*Residual Pyramid Learning for Single-Shot Semantic Segmentation*](https://arxiv.org/abs/1903.09746) 4 | 5 | ### Network 6 | ![image_1](pictures/1.png) 7 | 8 | ### Segmentation Result 9 |

10 |
11 |

12 | 13 | ### Citing RPNet 14 | 15 | Please cite RPNet in your publications if it helps your research: 16 | 17 | @ARTICLE{8744483, 18 | author={X. {Chen} and X. {Lou} and L. {Bai} and J. {Han}}, 19 | journal={IEEE Transactions on Intelligent Transportation Systems}, 20 | title={Residual Pyramid Learning for Single-Shot Semantic Segmentation}, 21 | year={2019}, 22 | volume={}, 23 | number={}, 24 | pages={1-11}, 25 | keywords={Intelligent vehicles;real-time vision;scene understanding;residual learning.}, 26 | doi={10.1109/TITS.2019.2922252}, 27 | ISSN={1524-9050}, 28 | month={},} 29 | 30 | 31 | 32 | This implementation has been tested on the Cityscapes and CamVid (TBD) datasets. Currently, a pre-trained version of the model trained in CamVid and Cityscapes is available [here](https://github.com/superlxt/RPnet-Pytorch/tree/master/save). 33 | 34 | 35 | | Dataset | Classes 1 | Input resolution | Batch size | Epochs | Mean IoU (%) | 36 | |:--------------------------------------------------------------------:|:--------------------:|:----------------:|:----------:|:----------:|:-----------------:| 37 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 3 | 300 x 4 | 64.82 | 38 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 3 | 300 x 4 | 70.4 (val) | 39 | 40 | * When referring to the number of classes, the void/unlabeled class is always excluded.
41 | * Just for reference since changes in implementation, datasets, and hardware can lead to very different results. Reference hardware: Nvidia GTX 1080ti and an Intel Core i7-7920x 2.9GHz. 42 | 43 | 44 | The results on Cityscapes dataset 45 | 46 | 47 | | Method | Input Size | Mean IoU | Mean iIoU | fps | FLOPs | 48 | |:----------------:|:--------------:|:------------:|:-------------:|:--------:|:----------:| 49 | | ENet | 1024*512 | 58.3 | 34.4 | 77 | 4.03B | 50 | | ERFNet | 1024*512 | 68.0 | 40.4 | 59 | 25.6B | 51 | | ESPNet | 1024*512 | 60.3 | 31.8 | 139 | 3.19B | 52 | | BiSeNet | 1036*768 | 68.4 | - | 69 | 26.4B | 53 | | ICNet | 2048*1024 | 69.5 | - | 30 | - | 54 | |DeepLab(Mobilenet)| 2048*1024 | 70.71(val) | - | 16 | 21.3B | 55 | | LRR | 2048*1024 | 69.7 | 48.0 | 2 | - | 56 | | RefinNet | 2048*1024 | 73.6 | 47.2 | - | 263B | 57 | | **RPNet(ENet)** | 1024*512 | 63.37 | 39.0 | 88 | 4.28B | 58 | | **RPNet(ERFNet)**| 1024*512 | 67.9 | 44.9 | 123 | 20.7B | 59 | 60 | 61 | 62 | ## Installation 63 | 64 | 1. Python 3 and pip. 65 | 2. Set up a virtual environment (optional, but recommended). 66 | 3. Install dependencies using pip: ``pip install -r requirements.txt``. 67 | 68 | 69 | 70 | ### Examples: Training 71 | 72 | ``` 73 | sh train.sh 74 | ``` 75 | 76 | ### Examples: Testing 77 | 78 | ``` 79 | python main.py -m test --step 1 80 | ``` 81 | 82 | -------------------------------------------------------------------------------- /__pycache__/args.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/args.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/test.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def get_arguments(): 5 | """Defines command-line arguments, and parses them. 6 | 7 | """ 8 | parser = ArgumentParser() 9 | 10 | # Execution mode 11 | parser.add_argument( 12 | "--mode", 13 | "-m", 14 | choices=['train', 'test', 'full'], 15 | default='train', 16 | help=("train: performs training and validation; test: tests the model " 17 | "found in \"--save_dir\" with name \"--name\" on \"--dataset\"; " 18 | "full: combines train and test modes. Default: train")) 19 | parser.add_argument( 20 | "--resume", 21 | action='store_true', 22 | help=("The model found in \"--checkpoint_dir/--name/\" and filename " 23 | "\"--name.h5\" is loaded.")) 24 | 25 | # Hyperparameters 26 | parser.add_argument( 27 | "--batch-size", 28 | "-b", 29 | type=int, 30 | default=3, 31 | help="The batch size. Default: 10") 32 | parser.add_argument( 33 | "--epochs", 34 | type=int, 35 | default=300, 36 | help="Number of training epochs. Default: 300") 37 | parser.add_argument( 38 | "--learning-rate", 39 | "-lr", 40 | type=float, 41 | default=5e-4, 42 | help="The learning rate. Default: 5e-4") 43 | parser.add_argument( 44 | "--lr-decay", 45 | type=float, 46 | default=0.1, 47 | help="The learning rate decay factor. Default: 0.5") 48 | parser.add_argument( 49 | "--lr-decay-epochs", 50 | type=int, 51 | default=100, 52 | help="The number of epochs before adjusting the learning rate. " 53 | "Default: 100") 54 | parser.add_argument( 55 | "--weight-decay", 56 | "-wd", 57 | type=float, 58 | default=1e-4, 59 | help="L2 regularization factor. Default: 2e-4") 60 | 61 | # Dataset 62 | parser.add_argument( 63 | "--dataset", 64 | choices=['camvid', 'cityscapes'], 65 | default='cityscapes', 66 | help="Dataset to use. Default: cityscapes") 67 | parser.add_argument( 68 | "--dataset-dir", 69 | type=str, 70 | default="data/Cityscapes", 71 | help="Path to the root directory of the selected dataset. " 72 | "Default: data/Cityscapes") 73 | parser.add_argument( 74 | "--height", 75 | type=int, 76 | default=512, 77 | help="The image height. Default: 1024") 78 | parser.add_argument( 79 | "--width", 80 | type=int, 81 | default=1024, 82 | help="The image height. Default: 2048") 83 | parser.add_argument( 84 | "--weighing", 85 | choices=['enet', 'mfb', 'none'], 86 | default='mfb', 87 | help="The class weighing technique to apply to the dataset. " 88 | "Default: enet") 89 | parser.add_argument( 90 | "--with-unlabeled", 91 | dest='ignore_unlabeled', 92 | action='store_false', 93 | help="The unlabeled class is not ignored.") 94 | 95 | # Step 96 | parser.add_argument( 97 | "--step", 98 | type=int, 99 | default='1', 100 | help="Step to choose. Default: 1") 101 | 102 | # Settings 103 | parser.add_argument( 104 | "--workers", 105 | type=int, 106 | default=4, 107 | help="Number of subprocesses to use for data loading. Default: 4") 108 | parser.add_argument( 109 | "--print-step", 110 | action='store_true', 111 | help="Print loss every step") 112 | parser.add_argument( 113 | "--imshow-batch", 114 | action='store_true', 115 | help=("Displays batch images when loading the dataset and making " 116 | "predictions.")) 117 | parser.add_argument( 118 | "--no-cuda", 119 | dest='cuda', 120 | default='store_false', 121 | help="CPU only.") 122 | 123 | # Storage settings 124 | parser.add_argument( 125 | "--name", 126 | type=str, 127 | default='RPNet', 128 | help="Name given to the model when saving. Default: RPNet") 129 | parser.add_argument( 130 | "--save-dir", 131 | type=str, 132 | default='save', 133 | help="The directory where models are saved. Default: save") 134 | 135 | return parser.parse_args() 136 | -------------------------------------------------------------------------------- /data/Cityscapes/gtFine_trainvaltest: -------------------------------------------------------------------------------- 1 | /media/lxt/CVPR/Fenge/PyTorch-ENet/data/Cityscapes/gtFine_trainvaltest -------------------------------------------------------------------------------- /data/Cityscapes/leftImg8bit_trainvaltest: -------------------------------------------------------------------------------- 1 | /media/lxt/CVPR/Fenge/PyTorch-ENet/data/Cityscapes/leftImg8bit_trainvaltest -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Supported datasets 2 | 3 | - CamVid 4 | - CityScapes 5 | 6 | Note: When referring to the number of classes, the void/unlabeled class is excluded. 7 | 8 | ## CamVid Dataset 9 | 10 | The Cambridge-driving Labeled Video Database (CamVid) is a collection of over ten minutes of high-quality 30Hz footage with object class semantic labels at 1Hz and in part, 15Hz. Each pixel is associated with one of 32 classes. 11 | 12 | The CamVid dataset supported here is a 12 class version developed by the authors of SegNet. [Download link here](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). For actual training, an 11 class version is used - the "road marking" class is combined with the "road" class. 13 | 14 | More detailed information about the CamVid dataset can be found [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) and on the [SegNet GitHub repository](https://github.com/alexgkendall/SegNet-Tutorial). 15 | 16 | ## Cityscapes 17 | 18 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated. 19 | 20 | The version supported here is the finely annotated one with 19 classes. 21 | 22 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts). 23 | 24 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data. 25 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .camvid import CamVid 2 | from .cityscapes import Cityscapes 3 | 4 | __all__ = ['CamVid', 'Cityscapes'] 5 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/camvid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/camvid.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/cityscapes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/cityscapes.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /data/camvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch.utils.data as data 4 | from . import utils 5 | 6 | 7 | class CamVid(data.Dataset): 8 | """CamVid dataset loader where the dataset is arranged as in 9 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid. 10 | 11 | 12 | Keyword arguments: 13 | - root_dir (``string``): Root directory path. 14 | - mode (``string``): The type of dataset: 'train' for training set, 'val' 15 | for validation set, and 'test' for test set. 16 | - transform (``callable``, optional): A function/transform that takes in 17 | an PIL image and returns a transformed version. Default: None. 18 | - label_transform (``callable``, optional): A function/transform that takes 19 | in the target and transforms it. Default: None. 20 | - loader (``callable``, optional): A function to load an image given its 21 | path. By default ``default_loader`` is used. 22 | 23 | """ 24 | # Training dataset root folders 25 | train_folder = 'train' 26 | train_lbl_folder = 'trainannot' 27 | 28 | # Validation dataset root folders 29 | val_folder = 'val' 30 | val_lbl_folder = 'valannot' 31 | 32 | # Test dataset root folders 33 | test_folder = 'test' 34 | test_lbl_folder = 'testannot' 35 | 36 | # Images extension 37 | img_extension = '.png' 38 | 39 | # Default encoding for pixel value, class name, and class color 40 | color_encoding = OrderedDict([ 41 | ('sky', (128, 128, 128)), 42 | ('building', (128, 0, 0)), 43 | ('pole', (192, 192, 128)), 44 | ('road_marking', (255, 69, 0)), 45 | ('road', (128, 64, 128)), 46 | ('pavement', (60, 40, 222)), 47 | ('tree', (128, 128, 0)), 48 | ('sign_symbol', (192, 128, 128)), 49 | ('fence', (64, 64, 128)), 50 | ('car', (64, 0, 128)), 51 | ('pedestrian', (64, 64, 0)), 52 | ('bicyclist', (0, 128, 192)), 53 | ('unlabeled', (0, 0, 0)) 54 | ]) 55 | 56 | def __init__(self, 57 | root_dir, 58 | mode='train', 59 | transform=None, 60 | label_transform=None, 61 | loader=utils.pil_loader): 62 | self.root_dir = root_dir 63 | self.mode = mode 64 | self.transform = transform 65 | self.label_transform = label_transform 66 | self.loader = loader 67 | 68 | if self.mode.lower() == 'train': 69 | # Get the training data and labels filepaths 70 | self.train_data = utils.get_files( 71 | os.path.join(root_dir, self.train_folder), 72 | extension_filter=self.img_extension) 73 | 74 | self.train_labels = utils.get_files( 75 | os.path.join(root_dir, self.train_lbl_folder), 76 | extension_filter=self.img_extension) 77 | elif self.mode.lower() == 'val': 78 | # Get the validation data and labels filepaths 79 | self.val_data = utils.get_files( 80 | os.path.join(root_dir, self.val_folder), 81 | extension_filter=self.img_extension) 82 | 83 | self.val_labels = utils.get_files( 84 | os.path.join(root_dir, self.val_lbl_folder), 85 | extension_filter=self.img_extension) 86 | elif self.mode.lower() == 'test': 87 | # Get the test data and labels filepaths 88 | self.test_data = utils.get_files( 89 | os.path.join(root_dir, self.test_folder), 90 | extension_filter=self.img_extension) 91 | 92 | self.test_labels = utils.get_files( 93 | os.path.join(root_dir, self.test_lbl_folder), 94 | extension_filter=self.img_extension) 95 | else: 96 | raise RuntimeError("Unexpected dataset mode. " 97 | "Supported modes are: train, val and test") 98 | 99 | def __getitem__(self, index): 100 | """ 101 | Args: 102 | - index (``int``): index of the item in the dataset 103 | 104 | Returns: 105 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth 106 | of the image. 107 | 108 | """ 109 | if self.mode.lower() == 'train': 110 | data_path, label_path = self.train_data[index], self.train_labels[ 111 | index] 112 | elif self.mode.lower() == 'val': 113 | data_path, label_path = self.val_data[index], self.val_labels[ 114 | index] 115 | elif self.mode.lower() == 'test': 116 | data_path, label_path = self.test_data[index], self.test_labels[ 117 | index] 118 | else: 119 | raise RuntimeError("Unexpected dataset mode. " 120 | "Supported modes are: train, val and test") 121 | 122 | img, label = self.loader(data_path, label_path) 123 | 124 | if self.transform is not None: 125 | img = self.transform(img) 126 | 127 | if self.label_transform is not None: 128 | label = self.label_transform(label) 129 | 130 | return img, label 131 | 132 | def __len__(self): 133 | """Returns the length of the dataset.""" 134 | if self.mode.lower() == 'train': 135 | return len(self.train_data) 136 | elif self.mode.lower() == 'val': 137 | return len(self.val_data) 138 | elif self.mode.lower() == 'test': 139 | return len(self.test_data) 140 | else: 141 | raise RuntimeError("Unexpected dataset mode. " 142 | "Supported modes are: train, val and test") 143 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch.utils.data as data 4 | from . import utils 5 | 6 | 7 | class Cityscapes(data.Dataset): 8 | """Cityscapes dataset https://www.cityscapes-dataset.com/. 9 | 10 | Keyword arguments: 11 | - root_dir (``string``): Root directory path. 12 | - mode (``string``): The type of dataset: 'train' for training set, 'val' 13 | for validation set, and 'test' for test set. 14 | - transform (``callable``, optional): A function/transform that takes in 15 | an PIL image and returns a transformed version. Default: None. 16 | - label_transform (``callable``, optional): A function/transform that takes 17 | in the target and transforms it. Default: None. 18 | - loader (``callable``, optional): A function to load an image given its 19 | path. By default ``default_loader`` is used. 20 | 21 | """ 22 | # Training dataset root folders 23 | train_folder = "leftImg8bit_trainvaltest/leftImg8bit/train" 24 | train_lbl_folder = "gtFine_trainvaltest/gtFine/train" 25 | 26 | # Validation dataset root folders 27 | val_folder = "leftImg8bit_trainvaltest/leftImg8bit/val" 28 | val_lbl_folder = "gtFine_trainvaltest/gtFine/val" 29 | 30 | # Test dataset root folders 31 | test_folder = "leftImg8bit_trainvaltest/leftImg8bit/val" 32 | test_lbl_folder = "gtFine_trainvaltest/gtFine/val" 33 | 34 | # Filters to find the images 35 | img_extension = '.png' 36 | lbl_name_filter = 'labelIds' 37 | 38 | # The values associated with the 35 classes 39 | full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 40 | 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 41 | 32, 33, -1) 42 | # The values above are remapped to the following 43 | new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7, 44 | 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0) 45 | 46 | # Default encoding for pixel value, class name, and class color 47 | color_encoding = OrderedDict([ 48 | ('unlabeled', (0, 0, 0)), 49 | ('road', (128, 64, 128)), 50 | ('sidewalk', (244, 35, 232)), 51 | ('building', (70, 70, 70)), 52 | ('wall', (102, 102, 156)), 53 | ('fence', (190, 153, 153)), 54 | ('pole', (153, 153, 153)), 55 | ('traffic_light', (250, 170, 30)), 56 | ('traffic_sign', (220, 220, 0)), 57 | ('vegetation', (107, 142, 35)), 58 | ('terrain', (152, 251, 152)), 59 | ('sky', (70, 130, 180)), 60 | ('person', (220, 20, 60)), 61 | ('rider', (255, 0, 0)), 62 | ('car', (0, 0, 142)), 63 | ('truck', (0, 0, 70)), 64 | ('bus', (0, 60, 100)), 65 | ('train', (0, 80, 100)), 66 | ('motorcycle', (0, 0, 230)), 67 | ('bicycle', (119, 11, 32)) 68 | ]) 69 | 70 | def __init__(self, 71 | root_dir, 72 | mode='train', 73 | transform=None, 74 | label_transform=None, 75 | loader=utils.pil_loader): 76 | self.root_dir = root_dir 77 | self.mode = mode 78 | self.transform = transform 79 | self.label_transform = label_transform 80 | self.loader = loader 81 | 82 | if self.mode.lower() == 'train': 83 | # Get the training data and labels filepaths 84 | self.train_data = utils.get_files( 85 | os.path.join(root_dir, self.train_folder), 86 | extension_filter=self.img_extension) 87 | 88 | self.train_labels = utils.get_files( 89 | os.path.join(root_dir, self.train_lbl_folder), 90 | name_filter=self.lbl_name_filter, 91 | extension_filter=self.img_extension) 92 | elif self.mode.lower() == 'val': 93 | # Get the validation data and labels filepaths 94 | self.val_data = utils.get_files( 95 | os.path.join(root_dir, self.val_folder), 96 | extension_filter=self.img_extension) 97 | 98 | self.val_labels = utils.get_files( 99 | os.path.join(root_dir, self.val_lbl_folder), 100 | name_filter=self.lbl_name_filter, 101 | extension_filter=self.img_extension) 102 | elif self.mode.lower() == 'test': 103 | # Get the test data and labels filepaths 104 | self.test_data = utils.get_files( 105 | os.path.join(root_dir, self.test_folder), 106 | extension_filter=self.img_extension) 107 | 108 | self.test_labels = utils.get_files( 109 | os.path.join(root_dir, self.test_lbl_folder), 110 | name_filter=self.lbl_name_filter, 111 | extension_filter=self.img_extension) 112 | else: 113 | raise RuntimeError("Unexpected dataset mode. " 114 | "Supported modes are: train, val and test") 115 | 116 | def __getitem__(self, index): 117 | """ 118 | Args: 119 | - index (``int``): index of the item in the dataset 120 | 121 | Returns: 122 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth 123 | of the image. 124 | 125 | """ 126 | if self.mode.lower() == 'train': 127 | data_path, label_path = self.train_data[index], self.train_labels[ 128 | index] 129 | elif self.mode.lower() == 'val': 130 | data_path, label_path = self.val_data[index], self.val_labels[ 131 | index] 132 | elif self.mode.lower() == 'test': 133 | data_path, label_path = self.test_data[index], self.test_labels[ 134 | index] 135 | else: 136 | raise RuntimeError("Unexpected dataset mode. " 137 | "Supported modes are: train, val and test") 138 | 139 | img, label = self.loader(data_path, label_path) 140 | # Remap class labels 141 | label = utils.remap(label, self.full_classes, self.new_classes) 142 | 143 | if self.transform is not None: 144 | img = self.transform(img) 145 | 146 | 147 | if self.label_transform is not None: 148 | label = self.label_transform(label) 149 | return img, label 150 | 151 | def __len__(self): 152 | """Returns the length of the dataset.""" 153 | if self.mode.lower() == 'train': 154 | return len(self.train_data) 155 | elif self.mode.lower() == 'val': 156 | return len(self.val_data) 157 | elif self.mode.lower() == 'test': 158 | return len(self.test_data) 159 | else: 160 | raise RuntimeError("Unexpected dataset mode. " 161 | "Supported modes are: train, val and test") 162 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import numpy as np 4 | 5 | 6 | def get_files(folder, name_filter=None, extension_filter=None): 7 | """Helper function that returns the list of files in a specified folder 8 | with a specified extension. 9 | 10 | Keyword arguments: 11 | - folder (``string``): The path to a folder. 12 | - name_filter (```string``, optional): The returned files must contain 13 | this substring in their filename. Default: None; files are not filtered. 14 | - extension_filter (``string``, optional): The desired file extension. 15 | Default: None; files are not filtered 16 | 17 | """ 18 | if not os.path.isdir(folder): 19 | raise RuntimeError("\"{0}\" is not a folder.".format(folder)) 20 | 21 | # Filename filter: if not specified don't filter (condition always true); 22 | # otherwise, use a lambda expression to filter out files that do not 23 | # contain "name_filter" 24 | if name_filter is None: 25 | # This looks hackish...there is probably a better way 26 | name_cond = lambda filename: True 27 | else: 28 | name_cond = lambda filename: name_filter in filename 29 | 30 | # Extension filter: if not specified don't filter (condition always true); 31 | # otherwise, use a lambda expression to filter out files whose extension 32 | # is not "extension_filter" 33 | if extension_filter is None: 34 | # This looks hackish...there is probably a better way 35 | ext_cond = lambda filename: True 36 | else: 37 | ext_cond = lambda filename: filename.endswith(extension_filter) 38 | 39 | filtered_files = [] 40 | 41 | # Explore the directory tree to get files that contain "name_filter" and 42 | # with extension "extension_filter" 43 | for path, _, files in os.walk(folder): 44 | files.sort() 45 | for file in files: 46 | if name_cond(file) and ext_cond(file): 47 | full_path = os.path.join(path, file) 48 | filtered_files.append(full_path) 49 | 50 | return filtered_files 51 | 52 | 53 | def pil_loader(data_path, label_path): 54 | """Loads a sample and label image given their path as PIL images. 55 | 56 | Keyword arguments: 57 | - data_path (``string``): The filepath to the image. 58 | - label_path (``string``): The filepath to the ground-truth image. 59 | 60 | Returns the image and the label as PIL images. 61 | 62 | """ 63 | data = Image.open(data_path) 64 | label = Image.open(label_path) 65 | 66 | return data, label 67 | 68 | 69 | def remap(image, old_values, new_values): 70 | assert isinstance(image, Image.Image) or isinstance( 71 | image, np.ndarray), "image must be of type PIL.Image or numpy.ndarray" 72 | assert type(new_values) is tuple, "new_values must be of type tuple" 73 | assert type(old_values) is tuple, "old_values must be of type tuple" 74 | assert len(new_values) == len( 75 | old_values), "new_values and old_values must have the same length" 76 | 77 | # If image is a PIL.Image convert it to a numpy array 78 | if isinstance(image, Image.Image): 79 | image = np.array(image) 80 | 81 | # Replace old values by the new ones 82 | tmp = np.zeros_like(image) 83 | for old, new in zip(old_values, new_values): 84 | # Since tmp is already initialized as zeros we can skip new values 85 | # equal to 0 86 | if new != 0: 87 | tmp[image == old] = new 88 | 89 | return Image.fromarray(tmp) 90 | 91 | 92 | def enet_weighing(dataloader, num_classes, c=1.02): 93 | """Computes class weights as described in the ENet paper: 94 | 95 | w_class = 1 / (ln(c + p_class)), 96 | 97 | where c is usually 1.02 and p_class is the propensity score of that 98 | class: 99 | 100 | propensity_score = freq_class / total_pixels. 101 | 102 | References: https://arxiv.org/abs/1606.02147 103 | 104 | Keyword arguments: 105 | - dataloader (``data.Dataloader``): A data loader to iterate over the 106 | dataset. 107 | - num_classes (``int``): The number of classes. 108 | - c (``int``, optional): AN additional hyper-parameter which restricts 109 | the interval of values for the weights. Default: 1.02. 110 | 111 | """ 112 | class_count = 0 113 | total = 0 114 | for _, label in dataloader: 115 | label = label.cpu().numpy() 116 | 117 | # Flatten label 118 | flat_label = label.flatten() 119 | 120 | # Sum up the number of pixels of each class and the total pixel 121 | # counts for each label 122 | class_count += np.bincount(flat_label, minlength=num_classes) 123 | total += flat_label.size 124 | 125 | # Compute propensity score and then the weights for each class 126 | propensity_score = class_count / total 127 | class_weights = 1 / (np.log(c + propensity_score)) 128 | 129 | return class_weights 130 | 131 | 132 | def median_freq_balancing(dataloader, num_classes): 133 | """Computes class weights using median frequency balancing as described 134 | in https://arxiv.org/abs/1411.4734: 135 | 136 | w_class = median_freq / freq_class, 137 | 138 | where freq_class is the number of pixels of a given class divided by 139 | the total number of pixels in images where that class is present, and 140 | median_freq is the median of freq_class. 141 | 142 | Keyword arguments: 143 | - dataloader (``data.Dataloader``): A data loader to iterate over the 144 | dataset. 145 | whose weights are going to be computed. 146 | - num_classes (``int``): The number of classes 147 | 148 | """ 149 | class_count = 0 150 | total = 0 151 | for _, label in dataloader: 152 | label = label.cpu().numpy() 153 | 154 | # Flatten label 155 | flat_label = label.flatten() 156 | 157 | # Sum up the class frequencies 158 | bincount = np.bincount(flat_label, minlength=num_classes) 159 | 160 | # Create of mask of classes that exist in the label 161 | mask = bincount > 0 162 | # Multiply the mask by the pixel count. The resulting array has 163 | # one element for each class. The value is either 0 (if the class 164 | # does not exist in the label) or equal to the pixel count (if 165 | # the class exists in the label) 166 | total += mask * flat_label.size 167 | 168 | # Sum up the number of pixels found for each class 169 | class_count += bincount 170 | 171 | # Compute the frequency and its median 172 | freq = class_count / total 173 | med = np.median(freq) 174 | 175 | return med / freq 176 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | torch.cuda.set_device(1) 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as lr_scheduler 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from torch.autograd import Variable 11 | 12 | import transforms as ext_transforms 13 | from models.rpnet import RPNet 14 | from train import Train 15 | from test import Test 16 | from metric.iou import IoU 17 | from args import get_arguments 18 | from data.utils import enet_weighing, median_freq_balancing 19 | import utils 20 | from PIL import Image 21 | 22 | import numpy as np 23 | # Get the arguments 24 | args = get_arguments() 25 | 26 | use_cuda = args.cuda and torch.cuda.is_available() 27 | 28 | 29 | def load_dataset(dataset): 30 | print("\nLoading dataset...\n") 31 | 32 | print("Selected dataset:", args.dataset) 33 | print("Dataset directory:", args.dataset_dir) 34 | print("Save directory:", args.save_dir) 35 | 36 | image_transform = transforms.Compose( 37 | [transforms.Resize((args.height, args.width),Image.BILINEAR), 38 | transforms.ToTensor()]) 39 | 40 | label_transform = transforms.Compose([ 41 | transforms.Resize((args.height, args.width),Image.NEAREST), 42 | ext_transforms.PILToLongTensor() 43 | ]) 44 | 45 | # Get selected dataset 46 | # Load the training set as tensors 47 | train_set = dataset( 48 | args.dataset_dir, 49 | transform=image_transform, 50 | label_transform=label_transform) 51 | train_loader = data.DataLoader( 52 | train_set, 53 | batch_size=args.batch_size, 54 | shuffle=True, 55 | num_workers=args.workers) 56 | 57 | # Load the validation set as tensors 58 | val_set = dataset( 59 | args.dataset_dir, 60 | mode='val', 61 | transform=image_transform, 62 | label_transform=label_transform) 63 | val_loader = data.DataLoader( 64 | val_set, 65 | batch_size=args.batch_size, 66 | shuffle=True, 67 | num_workers=args.workers) 68 | 69 | # Load the test set as tensors 70 | test_set = dataset( 71 | args.dataset_dir, 72 | mode='test', 73 | transform=image_transform, 74 | label_transform=label_transform) 75 | test_loader = data.DataLoader( 76 | test_set, 77 | batch_size=args.batch_size, 78 | shuffle=True, 79 | num_workers=args.workers) 80 | 81 | # Get encoding between pixel valus in label images and RGB colors 82 | class_encoding = train_set.color_encoding 83 | 84 | # Remove the road_marking class from the CamVid dataset as it's merged 85 | # with the road class 86 | if args.dataset.lower() == 'camvid': 87 | del class_encoding['road_marking'] 88 | 89 | # Get number of classes to predict 90 | num_classes = len(class_encoding) 91 | 92 | # Print information for debugging 93 | print("Number of classes to predict:", num_classes) 94 | print("Train dataset size:", len(train_set)) 95 | print("Validation dataset size:", len(val_set)) 96 | 97 | # Get a batch of samples to display 98 | if args.mode.lower() == 'test': 99 | images, labels = iter(test_loader).next() 100 | else: 101 | images, labels = iter(train_loader).next() 102 | print("Image size:", images.size()) 103 | print("Label size:", labels.size()) 104 | print("Class-color encoding:", class_encoding) 105 | 106 | # Show a batch of samples and labels 107 | if args.imshow_batch: 108 | print("Close the figure window to continue...") 109 | label_to_rgb = transforms.Compose([ 110 | ext_transforms.LongTensorToRGBPIL(class_encoding), 111 | transforms.ToTensor() 112 | ]) 113 | color_labels = utils.batch_transform(labels, label_to_rgb) 114 | utils.imshow_batch(images, color_labels) 115 | 116 | # Get class weights from the selected weighing technique 117 | print("\nWeighing technique:", args.weighing) 118 | class_weights = np.array([0.0,2.7,6.1,3.6,7.7,7.7,8.1,8.6,8.4,4.3,7.7,6.8,8.0,8.6,5.9,7.7,7.5,6.6,8.5,8.4]) 119 | if class_weights is not None: 120 | class_weights = torch.from_numpy(class_weights).float() 121 | # Set the weight of the unlabeled class to 0 122 | if args.ignore_unlabeled: 123 | ignore_index = list(class_encoding).index('unlabeled') 124 | class_weights[ignore_index] = 0 125 | 126 | print("Class weights:", class_weights) 127 | 128 | return (train_loader, val_loader, 129 | test_loader), class_weights, class_encoding 130 | 131 | 132 | def train(train_loader, val_loader, class_weights, class_encoding): 133 | print("\nTraining...\n") 134 | 135 | num_classes = len(class_encoding) 136 | 137 | # Intialize RPNet 138 | model = RPNet(num_classes) 139 | 140 | # We are going to use the CrossEntropyLoss loss function as it's most 141 | # frequentely used in classification problems with multiple classes which 142 | # fits the problem. This criterion combines LogSoftMax and NLLLoss. 143 | criterion = nn.CrossEntropyLoss(weight=class_weights) 144 | 145 | # ENet authors used Adam as the optimizer 146 | optimizer = optim.Adam( 147 | model.parameters(), 148 | lr=args.learning_rate, 149 | weight_decay=args.weight_decay) 150 | 151 | # Learning rate decay scheduler 152 | lmd = lambda epoch: (1-epoch/args.epochs) ** 0.9 153 | lr_updater = lr_scheduler.LambdaLR(optimizer, lr_lambda=lmd) 154 | 155 | # Evaluation metric 156 | if args.ignore_unlabeled: 157 | ignore_index = list(class_encoding).index('unlabeled') 158 | else: 159 | ignore_index = None 160 | metric = IoU(num_classes, ignore_index=ignore_index) 161 | 162 | if use_cuda: 163 | model = model.cuda() 164 | criterion = criterion.cuda() 165 | 166 | # Optionally resume from a checkpoint 167 | if args.resume: 168 | model, optimizer, start_epoch, best_miou = utils.load_checkpoint( 169 | model, optimizer, args.save_dir, args.name) 170 | print("Resuming from model: Start epoch = {0} " 171 | "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou)) 172 | else: 173 | start_epoch = 0 174 | best_miou = 0 175 | 176 | 177 | # Step 178 | step = args.step 179 | 180 | # Start Training 181 | train = Train(model, train_loader, optimizer, criterion, metric, use_cuda, step) 182 | val = Test(model, val_loader, criterion, metric, use_cuda, step) 183 | for epoch in range(start_epoch, args.epochs): 184 | print(">>>> [Epoch: {0:d}] Training".format(epoch)) 185 | train.model.train() 186 | lr_updater.step() 187 | epoch_loss, (iou, miou) = train.run_epoch(args.print_step) 188 | 189 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}". 190 | format(epoch, epoch_loss, miou),'current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 191 | 192 | if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs: 193 | val.model.eval() 194 | print(">>>> [Epoch: {0:d}] Validation".format(epoch)) 195 | 196 | loss, (iou, miou) = val.run_epoch(args.print_step) 197 | 198 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}". 199 | format(epoch, loss, miou)) 200 | 201 | # Print per class IoU on last epoch or if best iou 202 | if epoch + 1 == args.epochs or miou > best_miou: 203 | for key, class_iou in zip(class_encoding.keys(), iou): 204 | print("{0}: {1:.4f}".format(key, class_iou)) 205 | 206 | # Save the model if it's the best thus far 207 | if miou > best_miou: 208 | print("\nBest model thus far. Saving...\n") 209 | best_miou = miou 210 | utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,args) 211 | 212 | 213 | return model 214 | 215 | 216 | def test(model, test_loader, class_weights, class_encoding, step): 217 | print("\nTesting...\n") 218 | 219 | num_classes = len(class_encoding) 220 | 221 | # We are going to use the CrossEntropyLoss loss function as it's most 222 | # frequentely used in classification problems with multiple classes which 223 | # fits the problem. This criterion combines LogSoftMax and NLLLoss. 224 | criterion = nn.CrossEntropyLoss(weight=class_weights) 225 | if use_cuda: 226 | criterion = criterion.cuda() 227 | 228 | # Evaluation metric 229 | if args.ignore_unlabeled: 230 | ignore_index = list(class_encoding).index('unlabeled') 231 | else: 232 | ignore_index = None 233 | metric = IoU(num_classes, ignore_index=ignore_index) 234 | 235 | # Test the trained model on the test set 236 | test = Test(model, test_loader, criterion, metric, use_cuda, step) 237 | 238 | print(">>>> Running test dataset") 239 | 240 | loss, (iou, miou) = test.run_epoch(args.print_step) 241 | class_iou = dict(zip(class_encoding.keys(), iou)) 242 | 243 | print(">>>> Avg. loss: {0:.4f} | Mean IoU: {1:.4f}".format(loss, miou)) 244 | 245 | # Print per class IoU 246 | for key, class_iou in zip(class_encoding.keys(), iou): 247 | print("{0}: {1:.4f}".format(key, class_iou)) 248 | 249 | # Show a batch of samples and labels 250 | if args.imshow_batch: 251 | print("A batch of predictions from the test set...") 252 | images, _ = iter(test_loader).next() 253 | predict(model, images, class_encoding) 254 | 255 | 256 | def predict(model, images, class_encoding): 257 | images = Variable(images) 258 | if use_cuda: 259 | images = images.cuda() 260 | 261 | # Make predictions! 262 | predictions = model(images) 263 | 264 | # Predictions is one-hot encoded with "num_classes" channels. 265 | # Convert it to a single int using the indices where the maximum (1) occurs 266 | _, predictions = torch.max(predictions.data, 1) 267 | 268 | label_to_rgb = transforms.Compose([ 269 | ext_transforms.LongTensorToRGBPIL(class_encoding), 270 | transforms.ToTensor() 271 | ]) 272 | color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb) 273 | utils.imshow_batch(images.data.cpu(), color_predictions) 274 | 275 | 276 | # Run only if this module is being run directly 277 | if __name__ == '__main__': 278 | 279 | # Fail fast if the dataset directory doesn't exist 280 | assert os.path.isdir( 281 | args.dataset_dir), "The directory \"{0}\" doesn't exist.".format( 282 | args.dataset_dir) 283 | 284 | # Fail fast if the saving directory doesn't exist 285 | assert os.path.isdir( 286 | args.save_dir), "The directory \"{0}\" doesn't exist.".format( 287 | args.save_dir) 288 | 289 | # Import the requested dataset 290 | if args.dataset.lower() == 'camvid': 291 | from data import CamVid as dataset 292 | elif args.dataset.lower() == 'cityscapes': 293 | from data import Cityscapes as dataset 294 | else: 295 | # Should never happen...but just in case it does 296 | raise RuntimeError("\"{0}\" is not a supported dataset.".format( 297 | args.dataset)) 298 | 299 | loaders, w_class, class_encoding = load_dataset(dataset) 300 | train_loader, val_loader, test_loader = loaders 301 | 302 | if args.mode.lower() in {'train', 'full'}: 303 | model = train(train_loader, val_loader, w_class, class_encoding) 304 | if args.mode.lower() == 'full': 305 | test(model, test_loader, w_class, class_encoding) 306 | elif args.mode.lower() == 'test': 307 | # Intialize a new RPNet model 308 | num_classes = len(class_encoding) 309 | model = RPNet(num_classes) 310 | print(model) 311 | #model = nn.DataParallel(model) 312 | model.eval() 313 | if use_cuda: 314 | model = model.cuda() 315 | 316 | # Initialize a optimizer just so we can retrieve the model from the 317 | # checkpoint 318 | optimizer = optim.Adam(model.parameters()) 319 | 320 | # Load the previoulsy saved model state to the RPNet model 321 | model = utils.load_checkpoint(model, optimizer, args.save_dir, 322 | args.name)[0] 323 | #print(model) 324 | step = args.step 325 | test(model, test_loader, w_class, class_encoding, step) 326 | else: 327 | # Should never happen...but just in case it does 328 | raise RuntimeError( 329 | "\"{0}\" is not a valid choice for execution mode.".format( 330 | args.mode)) 331 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- 1 | from .confusionmatrix import ConfusionMatrix 2 | from .iou import IoU 3 | from .metric import Metric 4 | 5 | __all__ = ['ConfusionMatrix', 'IoU', 'Metric'] 6 | -------------------------------------------------------------------------------- /metric/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /metric/__pycache__/confusionmatrix.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/confusionmatrix.cpython-36.pyc -------------------------------------------------------------------------------- /metric/__pycache__/iou.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/iou.cpython-36.pyc -------------------------------------------------------------------------------- /metric/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /metric/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from metric import metric 4 | 5 | 6 | class ConfusionMatrix(metric.Metric): 7 | """Constructs a confusion matrix for a multi-class classification problems. 8 | 9 | Does not support multi-label, multi-class problems. 10 | 11 | Keyword arguments: 12 | - num_classes (int): number of classes in the classification problem. 13 | - normalized (boolean, optional): Determines whether or not the confusion 14 | matrix is normalized or not. Default: False. 15 | 16 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py 17 | """ 18 | 19 | def __init__(self, num_classes, normalized=False): 20 | super().__init__() 21 | 22 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) 23 | self.normalized = normalized 24 | self.num_classes = num_classes 25 | self.reset() 26 | 27 | def reset(self): 28 | self.conf.fill(0) 29 | 30 | def add(self, predicted, target): 31 | """Computes the confusion matrix 32 | 33 | The shape of the confusion matrix is K x K, where K is the number 34 | of classes. 35 | 36 | Keyword arguments: 37 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of 38 | predicted scores obtained from the model for N examples and K classes, 39 | or an N-tensor/array of integer values between 0 and K-1. 40 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of 41 | ground-truth classes for N examples and K classes, or an N-tensor/array 42 | of integer values between 0 and K-1. 43 | 44 | """ 45 | # If target and/or predicted are tensors, convert them to numpy arrays 46 | if torch.is_tensor(predicted): 47 | predicted = predicted.cpu().numpy() 48 | if torch.is_tensor(target): 49 | target = target.cpu().numpy() 50 | 51 | assert predicted.shape[0] == target.shape[0], \ 52 | 'number of targets and predicted outputs do not match' 53 | 54 | if np.ndim(predicted) != 1: 55 | assert predicted.shape[1] == self.num_classes, \ 56 | 'number of predictions does not match size of confusion matrix' 57 | predicted = np.argmax(predicted, 1) 58 | else: 59 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ 60 | 'predicted values are not between 0 and k-1' 61 | 62 | if np.ndim(target) != 1: 63 | assert target.shape[1] == self.num_classes, \ 64 | 'Onehot target does not match size of confusion matrix' 65 | assert (target >= 0).all() and (target <= 1).all(), \ 66 | 'in one-hot encoding, target values should be 0 or 1' 67 | assert (target.sum(1) == 1).all(), \ 68 | 'multi-label setting is not supported' 69 | target = np.argmax(target, 1) 70 | else: 71 | assert (target.max() < self.num_classes) and (target.min() >= 0), \ 72 | 'target values are not between 0 and k-1' 73 | 74 | # hack for bincounting 2 arrays together 75 | x = predicted + self.num_classes * target 76 | bincount_2d = np.bincount( 77 | x.astype(np.int32), minlength=self.num_classes**2) 78 | assert bincount_2d.size == self.num_classes**2 79 | conf = bincount_2d.reshape((self.num_classes, self.num_classes)) 80 | 81 | self.conf += conf 82 | 83 | def value(self): 84 | """ 85 | Returns: 86 | Confustion matrix of K rows and K columns, where rows corresponds 87 | to ground-truth targets and columns corresponds to predicted 88 | targets. 89 | """ 90 | if self.normalized: 91 | conf = self.conf.astype(np.float32) 92 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 93 | else: 94 | return self.conf 95 | -------------------------------------------------------------------------------- /metric/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from metric import metric 4 | from metric.confusionmatrix import ConfusionMatrix 5 | 6 | 7 | class IoU(metric.Metric): 8 | """Computes the intersection over union (IoU) per class and corresponding 9 | mean (mIoU). 10 | 11 | Intersection over union (IoU) is a common evaluation metric for semantic 12 | segmentation. The predictions are first accumulated in a confusion matrix 13 | and the IoU is computed from it as follows: 14 | 15 | IoU = true_positive / (true_positive + false_positive + false_negative). 16 | 17 | Keyword arguments: 18 | - num_classes (int): number of classes in the classification problem 19 | - normalized (boolean, optional): Determines whether or not the confusion 20 | matrix is normalized or not. Default: False. 21 | - ignore_index (int or iterable, optional): Index of the classes to ignore 22 | when computing the IoU. Can be an int, or any iterable of ints. 23 | """ 24 | 25 | def __init__(self, num_classes, normalized=False, ignore_index=None): 26 | super().__init__() 27 | self.conf_metric = ConfusionMatrix(num_classes, normalized) 28 | 29 | if ignore_index is None: 30 | self.ignore_index = None 31 | elif isinstance(ignore_index, int): 32 | self.ignore_index = (ignore_index,) 33 | else: 34 | try: 35 | self.ignore_index = tuple(ignore_index) 36 | except TypeError: 37 | raise ValueError("'ignore_index' must be an int or iterable") 38 | 39 | def reset(self): 40 | self.conf_metric.reset() 41 | 42 | def add(self, predicted, target): 43 | """Adds the predicted and target pair to the IoU metric. 44 | 45 | Keyword arguments: 46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of 47 | predicted scores obtained from the model for N examples and K classes, 48 | or (N, H, W) tensor of integer values between 0 and K-1. 49 | - target (Tensor): Can be a (N, K, H, W) tensor of 50 | target scores for N examples and K classes, or (N, H, W) tensor of 51 | integer values between 0 and K-1. 52 | 53 | """ 54 | # Dimensions check 55 | assert predicted.size(0) == target.size(0), \ 56 | 'number of targets and predicted outputs do not match' 57 | assert predicted.dim() == 3 or predicted.dim() == 4, \ 58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)" 59 | assert target.dim() == 3 or target.dim() == 4, \ 60 | "targets must be of dimension (N, H, W) or (N, K, H, W)" 61 | 62 | # If the tensor is in categorical format convert it to integer format 63 | if predicted.dim() == 4: 64 | _, predicted = predicted.max(1) 65 | if target.dim() == 4: 66 | _, target = target.max(1) 67 | 68 | self.conf_metric.add(predicted.view(-1), target.view(-1)) 69 | 70 | def value(self): 71 | """Computes the IoU and mean IoU. 72 | 73 | The mean computation ignores NaN elements of the IoU array. 74 | 75 | Returns: 76 | Tuple: (IoU, mIoU). The first output is the per class IoU, 77 | for K classes it's numpy.ndarray with K elements. The second output, 78 | is the mean IoU. 79 | """ 80 | conf_matrix = self.conf_metric.value() 81 | if self.ignore_index is not None: 82 | for index in self.ignore_index: 83 | conf_matrix[:, self.ignore_index] = 0 84 | conf_matrix[self.ignore_index, :] = 0 85 | true_positive = np.diag(conf_matrix) 86 | false_positive = np.sum(conf_matrix, 0) - true_positive 87 | false_negative = np.sum(conf_matrix, 1) - true_positive 88 | 89 | # Just in case we get a division by 0, ignore/hide the error 90 | with np.errstate(divide='ignore', invalid='ignore'): 91 | iou = true_positive / (true_positive + false_positive + false_negative) 92 | 93 | return iou, np.nanmean(iou) 94 | -------------------------------------------------------------------------------- /metric/metric.py: -------------------------------------------------------------------------------- 1 | class Metric(object): 2 | """Base class for all metrics. 3 | 4 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py 5 | """ 6 | def reset(self): 7 | pass 8 | 9 | def add(self): 10 | pass 11 | 12 | def value(self): 13 | pass 14 | -------------------------------------------------------------------------------- /models/__pycache__/rpnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/models/__pycache__/rpnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/rpnet.py: -------------------------------------------------------------------------------- 1 | # ERFNet full model definition for Pytorch 2 | # Sept 2017 3 | # Eduardo Romera 4 | ####################### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | class DownsamplerBlock (nn.Module): 12 | def __init__(self, ninput, noutput): 13 | super().__init__() 14 | 15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) 16 | self.conv2 = nn.Conv2d(16, 64, (1, 1), stride=1, padding=0, bias=True) 17 | self.pool = nn.MaxPool2d(2, stride=2, return_indices=True) 18 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3) 19 | 20 | 21 | def forward(self, input): 22 | c=input 23 | a=self.conv(input) 24 | b,max_indices=self.pool(input) 25 | #print(a.shape,b.shape,c.shape,max_indices.shape,"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") 26 | output1 = torch.cat([a, b], 1) 27 | if b.shape[1]==16: 28 | b_c = self.conv2(b) 29 | else: 30 | b_c=b 31 | output = self.bn(output1) 32 | return F.relu(output), max_indices, b, b_c, output1 33 | 34 | 35 | class non_bottleneck_1d (nn.Module): 36 | def __init__(self, chann, dropprob, dilated): 37 | super().__init__() 38 | 39 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True) 40 | 41 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True) 42 | 43 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) 44 | 45 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1)) 46 | 47 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated)) 48 | 49 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) 50 | 51 | self.dropout = nn.Dropout2d(dropprob) 52 | 53 | 54 | def forward(self, input): 55 | 56 | output = self.conv3x1_1(input) 57 | output = F.relu(output) 58 | output = self.conv1x3_1(output) 59 | output = self.bn1(output) 60 | output = F.relu(output) 61 | 62 | output = self.conv3x1_2(output) 63 | output = F.relu(output) 64 | output = self.conv1x3_2(output) 65 | output = self.bn2(output) 66 | 67 | if (self.dropout.p != 0): 68 | output = self.dropout(output) 69 | 70 | return F.relu(output+input), output #+input = identity (residual connection) 71 | 72 | 73 | class RPNet(nn.Module): 74 | def __init__(self, num_classes): 75 | super().__init__() 76 | self.initial_block = DownsamplerBlock(3,16) 77 | 78 | self.l0d1=non_bottleneck_1d(16, 0.03, 1) 79 | self.down0_25=DownsamplerBlock(16,64) 80 | 81 | 82 | self.l1d1=non_bottleneck_1d(64, 0.03, 1) 83 | self.l1d2=non_bottleneck_1d(64, 0.03, 1) 84 | self.l1d3=non_bottleneck_1d(64, 0.03, 1) 85 | self.l1d4=non_bottleneck_1d(64, 0.03, 1) 86 | self.l1d5=non_bottleneck_1d(64, 0.03, 1) 87 | 88 | self.down0_125=DownsamplerBlock(64,128) 89 | 90 | self.l2d1=non_bottleneck_1d(128, 0.3, 2) 91 | self.l2d2=non_bottleneck_1d(128, 0.3, 4) 92 | self.l2d3=non_bottleneck_1d(128, 0.3, 8) 93 | self.l2d4=non_bottleneck_1d(128, 0.3, 16) 94 | 95 | self.l3d1=non_bottleneck_1d(128, 0.3, 2) 96 | self.l3d2=non_bottleneck_1d(128, 0.3, 4) 97 | self.l3d3=non_bottleneck_1d(128, 0.3, 8) 98 | self.l3d4=non_bottleneck_1d(128, 0.3, 16) 99 | #Only in encoder mode: 100 | self.conv2d1 = nn.Conv2d( 101 | 128, 102 | num_classes, 103 | kernel_size=1, 104 | stride=1, 105 | padding=0, 106 | bias=True) 107 | self.conv2d2 = nn.Conv2d( 108 | 192, 109 | num_classes, 110 | kernel_size=1, 111 | stride=1, 112 | padding=0, 113 | bias=True) 114 | self.conv2d3 = nn.Conv2d( 115 | 36, 116 | num_classes, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0, 120 | bias=True) 121 | self.conv2d4 = nn.Conv2d( 122 | 16, 123 | num_classes, 124 | kernel_size=1, 125 | stride=1, 126 | padding=0, 127 | bias=False) 128 | self.conv2d5 = nn.Conv2d( 129 | 64, 130 | num_classes, 131 | kernel_size=1, 132 | stride=1, 133 | padding=0, 134 | bias=False) 135 | 136 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2) 137 | self.main_unpool2 = nn.MaxUnpool2d(kernel_size=2) 138 | 139 | 140 | def forward(self, input, predict=False): 141 | output, max_indices0_0, d, d_d, dd = self.initial_block(input) 142 | output,y = self.l0d1(output) 143 | 144 | 145 | 146 | 147 | output, max_indices1_0,d1,d1_d1, ddd = self.down0_25(output) 148 | #print(d1.shape,'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx') 149 | d2 = self.main_unpool1(d1, max_indices1_0) 150 | d_1=d2-dd 151 | 152 | output,y = self.l1d1(output) 153 | output,y = self.l1d2(output) 154 | output,y = self.l1d3(output) 155 | output,y = self.l1d4(output) 156 | 157 | cc_2=self.conv2d4(d_1) 158 | 159 | 160 | output, max_indices2_0,d3,d3_d3, dddd = self.down0_125(output) 161 | d4 = self.main_unpool2(d3, max_indices2_0) 162 | d_2=d4-d1_d1 163 | cc_4=self.conv2d5(d_2) 164 | 165 | output,y = self.l2d1(output) 166 | output,y = self.l2d2(output) 167 | output,y = self.l2d3(output) 168 | output,y = self.l2d4(output) 169 | output,y = self.l3d1(output) 170 | output,y = self.l3d2(output) 171 | output,y = self.l3d3(output) 172 | output,y = self.l3d4(output) 173 | x1_81 = output 174 | x1_8 = self.conv2d1(output) 175 | 176 | x1_8_2 = torch.nn.functional.interpolate(x1_81, scale_factor=2, mode='bilinear') 177 | 178 | out4 = torch.cat((x1_8_2,d_2),1) 179 | x1_41 = self.conv2d2(out4) 180 | x1_4=x1_41+cc_4 181 | 182 | x1_4_2 = torch.nn.functional.interpolate(x1_4, scale_factor=2, mode='bilinear') 183 | out2 = torch.cat((x1_4_2, d_1), 1) 184 | x1_21 = self.conv2d3(out2) 185 | x1_2=x1_21+cc_2 186 | 187 | x1_1 = torch.nn.functional.interpolate(x1_2, scale_factor=2, mode='bilinear') 188 | 189 | return x1_1, x1_2, x1_4, x1_8 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /pictures/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/pictures/1.png -------------------------------------------------------------------------------- /pictures/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/pictures/2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | kiwisolver==1.0.1 3 | matplotlib==2.2.2 4 | numpy==1.14.2 5 | Pillow==5.0.0 6 | pyparsing==2.2.0 7 | python-dateutil==2.7.0 8 | pytz==2018.3 9 | PyYAML==4.2 10 | six==1.11.0 11 | torch==0.3.1 12 | torchvision==0.2.0 13 | -------------------------------------------------------------------------------- /save/RPNet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/save/RPNet -------------------------------------------------------------------------------- /save/RPNet_summary.txt: -------------------------------------------------------------------------------- 1 | ARGUMENTS 2 | batch_size: 3 3 | cuda: store_false 4 | dataset: cityscapes 5 | dataset_dir: data/Cityscapes 6 | epochs: 300 7 | height: 512 8 | ignore_unlabeled: True 9 | imshow_batch: False 10 | learning_rate: 0.0005 11 | lr_decay: 0.1 12 | lr_decay_epochs: 100 13 | mode: train 14 | name: RPNet 15 | print_step: False 16 | resume: True 17 | save_dir: save 18 | weighing: mfb 19 | weight_decay: 0.0001 20 | width: 1024 21 | workers: 4 22 | 23 | BEST VALIDATION 24 | Epoch: 300 25 | Mean IoU: 0.703690656206444 26 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | 4 | 5 | class Test(): 6 | """Tests the ``model`` on the specified test dataset using the 7 | data loader, and loss criterion. 8 | 9 | Keyword arguments: 10 | - model (``nn.Module``): the model instance to test. 11 | - data_loader (``Dataloader``): Provides single or multi-process 12 | iterators over the dataset. 13 | - criterion (``Optimizer``): The loss criterion. 14 | - metric (```Metric``): An instance specifying the metric to return. 15 | - use_cuda (``bool``): If ``True``, the training is performed using 16 | CUDA operations (GPU). 17 | 18 | """ 19 | 20 | def __init__(self, model, data_loader, criterion, metric, use_cuda, step): 21 | self.model = model 22 | self.data_loader = data_loader 23 | self.criterion = criterion 24 | self.metric = metric 25 | self.use_cuda = use_cuda 26 | self.step = step 27 | 28 | def run_epoch(self, iteration_loss=False): 29 | """Runs an epoch of validation. 30 | 31 | Keyword arguments: 32 | - iteration_loss (``bool``, optional): Prints loss at every step. 33 | 34 | Returns: 35 | - The epoch loss (float), and the values of the specified metrics 36 | 37 | """ 38 | epoch_loss = 0.0 39 | self.metric.reset() 40 | for step, batch_data in enumerate(self.data_loader): 41 | # Get the inputs and labels 42 | inputs, labels = batch_data 43 | 44 | # Wrap them in a Varaible 45 | inputs, labels = Variable(inputs), Variable(labels) 46 | if self.use_cuda: 47 | inputs = inputs.cuda() 48 | labels = labels.cuda() 49 | 50 | labels4 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.125, mode='nearest').squeeze(0).long() 51 | labels3 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.25, mode='nearest').squeeze(0).long() 52 | labels2 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.5, mode='nearest').squeeze(0).long() 53 | labels1 = labels 54 | 55 | # Forward propagation 56 | outputs1, outputs2, outputs3, outputs4 = self.model(inputs) 57 | 58 | # Loss computation 59 | loss = self.criterion(eval('outputs{}'.format(self.step)), eval('labels{}'.format(self.step))) 60 | 61 | # Keep track of loss for current epoch 62 | epoch_loss += loss.item() 63 | 64 | # Keep track of evaluation the metric 65 | self.metric.add(eval('outputs{}'.format(self.step)).data, eval('labels{}'.format(self.step)).data) 66 | 67 | if iteration_loss: 68 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 69 | 70 | return epoch_loss / len(self.data_loader), self.metric.value() 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch.nn as nn 3 | import torch 4 | import random 5 | import cv2 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import matplotlib.image as mpimg 9 | import utils 10 | 11 | 12 | 13 | class Train(): 14 | """Performs the training of ``model`` given a training dataset data 15 | loader, the optimizer, and the loss criterion. 16 | 17 | Keyword arguments: 18 | - model (``nn.Module``): the model instance to train. 19 | - data_loader (``Dataloader``): Provides single or multi-process 20 | iterators over the dataset. 21 | - optim (``Optimizer``): The optimization algorithm. 22 | - criterion (``Optimizer``): The loss criterion. 23 | - metric (```Metric``): An instance specifying the metric to return. 24 | - use_cuda (``bool``): If ``True``, the training is performed using 25 | CUDA operations (GPU). 26 | 27 | """ 28 | 29 | def __init__(self, model, data_loader, optim, criterion, metric, use_cuda, step): 30 | self.model = model 31 | self.data_loader = data_loader 32 | self.optim = optim 33 | self.criterion = criterion 34 | self.metric = metric 35 | self.use_cuda = use_cuda 36 | self.step = step 37 | 38 | def run_epoch(self, iteration_loss=False): 39 | """Runs an epoch of training. 40 | 41 | Keyword arguments: 42 | - iteration_loss (``bool``, optional): Prints loss at every step. 43 | 44 | Returns: 45 | - The epoch loss (float). 46 | 47 | """ 48 | epoch_loss = 0.0 49 | self.metric.reset() 50 | for step, batch_data in enumerate(self.data_loader): 51 | # Get the inputs and labels 52 | 53 | inputs, labels = batch_data 54 | 55 | # Use augmentation 56 | inputs, labels = utils.dataaug(inputs, labels) 57 | 58 | # Wrap them in a Varaible 59 | inputs, labels = Variable(inputs), Variable(labels.float()) 60 | 61 | 62 | if self.use_cuda: 63 | inputs = inputs.cuda() 64 | labels = labels.cuda() 65 | labels=labels.unsqueeze(1) 66 | 67 | labels2 = torch.nn.functional.interpolate(labels, scale_factor=0.5, mode='nearest').squeeze(1) 68 | labels3 = torch.nn.functional.interpolate(labels, scale_factor=0.25, mode='nearest').squeeze(1) 69 | labels4 = torch.nn.functional.interpolate(labels, scale_factor=0.125, mode='nearest').squeeze(1) 70 | labels1 = labels.squeeze(1) 71 | 72 | 73 | # Forward propagation 74 | outputs1, outputs2, outputs3, outputs4 = self.model(inputs) 75 | 76 | # Loss computation 77 | loss1 = self.criterion(outputs1, labels1.long()) 78 | loss2 = self.criterion(outputs2, labels2.long()) 79 | loss3 = self.criterion(outputs3, labels3.long()) 80 | loss4 = self.criterion(outputs4, labels4.long()) 81 | 82 | # Step Loss 83 | loss= eval('loss{}'.format(self.step)) 84 | 85 | # Backpropagation 86 | self.optim.zero_grad() 87 | loss.backward() 88 | self.optim.step() 89 | 90 | # Keep track of loss for current epoch 91 | epoch_loss += loss.item() 92 | 93 | # Metric 94 | self.metric.add(eval('outputs{}'.format(self.step)).data, eval('labels{}'.format(self.step)).data) 95 | 96 | if iteration_loss: 97 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item())) 98 | 99 | return epoch_loss / len(self.data_loader), self.metric.value() 100 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #level4 training (1/8 scale) 4 | python main.py --step=4 5 | 6 | #level3 training (1/4 scale) 7 | python main.py --step=3 --resume 8 | 9 | #level2 training (1/2 scale) 10 | python main.py --step=2 --resume 11 | 12 | #level1 training (original scale) 13 | python main.py --step=1 --resume 14 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from collections import OrderedDict 5 | from torchvision.transforms import ToPILImage 6 | 7 | 8 | class PILToLongTensor(object): 9 | """Converts a ``PIL Image`` to a ``torch.LongTensor``. 10 | 11 | Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor 12 | 13 | """ 14 | 15 | def __call__(self, pic): 16 | """Performs the conversion from a ``PIL Image`` to a ``torch.LongTensor``. 17 | 18 | Keyword arguments: 19 | - pic (``PIL.Image``): the image to convert to ``torch.LongTensor`` 20 | 21 | Returns: 22 | A ``torch.LongTensor``. 23 | 24 | """ 25 | if not isinstance(pic, Image.Image): 26 | raise TypeError("pic should be PIL Image. Got {}".format( 27 | type(pic))) 28 | 29 | # handle numpy array 30 | if isinstance(pic, np.ndarray): 31 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 32 | # backward compatibility 33 | return img.long() 34 | 35 | # Convert PIL image to ByteTensor 36 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 37 | 38 | # Reshape tensor 39 | nchannel = len(pic.mode) 40 | img = img.view(pic.size[1], pic.size[0], nchannel) 41 | 42 | # Convert to long and squeeze the channels 43 | return img.transpose(0, 1).transpose(0, 44 | 2).contiguous().long().squeeze_() 45 | 46 | 47 | class LongTensorToRGBPIL(object): 48 | """Converts a ``torch.LongTensor`` to a ``PIL image``. 49 | 50 | The input is a ``torch.LongTensor`` where each pixel's value identifies the 51 | class. 52 | 53 | Keyword arguments: 54 | - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel 55 | values, class names, and class colors. 56 | 57 | """ 58 | def __init__(self, rgb_encoding): 59 | self.rgb_encoding = rgb_encoding 60 | 61 | def __call__(self, tensor): 62 | """Performs the conversion from ``torch.LongTensor`` to a ``PIL image`` 63 | 64 | Keyword arguments: 65 | - tensor (``torch.LongTensor``): the tensor to convert 66 | 67 | Returns: 68 | A ``PIL.Image``. 69 | 70 | """ 71 | # Check if label_tensor is a LongTensor 72 | if not isinstance(tensor, torch.LongTensor): 73 | raise TypeError("label_tensor should be torch.LongTensor. Got {}" 74 | .format(type(tensor))) 75 | # Check if encoding is a ordered dictionary 76 | if not isinstance(self.rgb_encoding, OrderedDict): 77 | raise TypeError("encoding should be an OrderedDict. Got {}".format( 78 | type(self.rgb_encoding))) 79 | 80 | # label_tensor might be an image without a channel dimension, in this 81 | # case unsqueeze it 82 | if len(tensor.size()) == 2: 83 | tensor.unsqueeze_(0) 84 | 85 | color_tensor = torch.ByteTensor(3, tensor.size(1), tensor.size(2)) 86 | 87 | for index, (class_name, color) in enumerate(self.rgb_encoding.items()): 88 | # Get a mask of elements equal to index 89 | mask = torch.eq(tensor, index).squeeze_() 90 | # Fill color_tensor with corresponding colors 91 | for channel, color_value in enumerate(color): 92 | color_tensor[channel].masked_fill_(mask, color_value) 93 | 94 | return ToPILImage()(color_tensor) 95 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as F 3 | import torchvision 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | import random 8 | import cv2 9 | 10 | def batch_transform(batch, transform): 11 | """Applies a transform to a batch of samples. 12 | 13 | Keyword arguments: 14 | - batch (): a batch os samples 15 | - transform (callable): A function/transform to apply to ``batch`` 16 | 17 | """ 18 | 19 | # Convert the single channel label to RGB in tensor form 20 | # 1. F.unbind removes the 0-dimension of "labels" and returns a tuple of 21 | # all slices along that dimension 22 | # 2. the transform is applied to each slice 23 | transf_slices = [transform(tensor) for tensor in F.unbind(batch)] 24 | 25 | return F.stack(transf_slices) 26 | 27 | 28 | def imshow_batch(images, labels): 29 | """Displays two grids of images. The top grid displays ``images`` 30 | and the bottom grid ``labels`` 31 | 32 | Keyword arguments: 33 | - images (``Tensor``): a 4D mini-batch tensor of shape 34 | (B, C, H, W) 35 | - labels (``Tensor``): a 4D mini-batch tensor of shape 36 | (B, C, H, W) 37 | 38 | """ 39 | 40 | # Make a grid with the images and labels and convert it to numpy 41 | images = torchvision.utils.make_grid(images).numpy() 42 | labels = torchvision.utils.make_grid(labels).numpy() 43 | 44 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7)) 45 | ax1.imshow(np.transpose(images, (1, 2, 0))) 46 | ax2.imshow(np.transpose(labels, (1, 2, 0))) 47 | 48 | plt.show() 49 | 50 | 51 | def save_checkpoint(model, optimizer, epoch, miou, args): 52 | """Saves the model in a specified directory with a specified name.save 53 | 54 | Keyword arguments: 55 | - model (``nn.Module``): The model to save. 56 | - optimizer (``torch.optim``): The optimizer state to save. 57 | - epoch (``int``): The current epoch for the model. 58 | - miou (``float``): The mean IoU obtained by the model. 59 | - args (``ArgumentParser``): An instance of ArgumentParser which contains 60 | the arguments used to train ``model``. The arguments are written to a text 61 | file in ``args.save_dir`` named "``args.name``_args.txt". 62 | 63 | """ 64 | name = args.name 65 | save_dir = args.save_dir 66 | 67 | assert os.path.isdir( 68 | save_dir), "The directory \"{0}\" doesn't exist.".format(save_dir) 69 | 70 | # Save model 71 | model_path = os.path.join(save_dir, name) 72 | checkpoint = { 73 | 'epoch': epoch, 74 | 'miou': miou, 75 | 'state_dict': model.state_dict(), 76 | 'optimizer': optimizer.state_dict() 77 | } 78 | torch.save(checkpoint, model_path) 79 | 80 | # Save arguments 81 | summary_filename = os.path.join(save_dir, name + '_summary.txt') 82 | with open(summary_filename, 'w') as summary_file: 83 | sorted_args = sorted(vars(args)) 84 | summary_file.write("ARGUMENTS\n") 85 | for arg in sorted_args: 86 | arg_str = "{0}: {1}\n".format(arg, getattr(args, arg)) 87 | summary_file.write(arg_str) 88 | 89 | summary_file.write("\nBEST VALIDATION\n") 90 | summary_file.write("Epoch: {0}\n". format(epoch)) 91 | summary_file.write("Mean IoU: {0}\n". format(miou)) 92 | 93 | 94 | def load_checkpoint(model, optimizer, folder_dir, filename): 95 | """Saves the model in a specified directory with a specified name.save 96 | 97 | Keyword arguments: 98 | - model (``nn.Module``): The stored model state is copied to this model 99 | instance. 100 | - optimizer (``torch.optim``): The stored optimizer state is copied to this 101 | optimizer instance. 102 | - folder_dir (``string``): The path to the folder where the saved model 103 | state is located. 104 | - filename (``string``): The model filename. 105 | 106 | Returns: 107 | The epoch, mean IoU, ``model``, and ``optimizer`` loaded from the 108 | checkpoint. 109 | 110 | """ 111 | assert os.path.isdir( 112 | folder_dir), "The directory \"{0}\" doesn't exist.".format(folder_dir) 113 | 114 | # Create folder to save model and information 115 | model_path = os.path.join(folder_dir, filename) 116 | assert os.path.isfile( 117 | model_path), "The model file \"{0}\" doesn't exist.".format(filename) 118 | 119 | # Load the stored model parameters to the model instance 120 | checkpoint = torch.load(model_path) 121 | model.load_state_dict(checkpoint['state_dict']) 122 | optimizer.load_state_dict(checkpoint['optimizer']) 123 | epoch = 0 124 | miou = 0 125 | 126 | return model, optimizer, epoch, miou 127 | 128 | 129 | 130 | def dataaug(inputs, labels): 131 | """use flip and warpAffine 132 | """ 133 | for i in range(inputs.shape[0]): 134 | inputs1 = inputs[i, :, :, :] 135 | labels1 = labels[i, :, :] 136 | inputs1 = inputs1.numpy() 137 | labels1 = labels1.numpy() 138 | inputs1 = inputs1.transpose((1, 2, 0)) 139 | 140 | x = random.randint(-2, 2) 141 | y = random.randint(-2, 2) 142 | H = np.float32([[1,0,x],[0,1,y]]) 143 | inputs1 = cv2.warpAffine(inputs1,H,(inputs1.shape[1],inputs1.shape[0])) 144 | labels1 = cv2.warpAffine(labels1.astype(np.float32),H,(inputs1.shape[1],inputs1.shape[0])).astype(np.uint8) 145 | 146 | 147 | if random.random() < 0.5: 148 | inputs1 = cv2.flip(inputs1, 1) # horizontal flip 149 | labels1 = cv2.flip(labels1, 1) # horizontal flip 150 | 151 | inputs1 = inputs1.transpose((2, 0, 1)) 152 | inputs1 = torch.from_numpy(inputs1) 153 | labels1 = torch.from_numpy(labels1) 154 | #print(inputs.shape) 155 | inputs[i, : , : ,:]=inputs1 156 | labels[i, : ,:]=labels1 157 | return inputs, labels 158 | --------------------------------------------------------------------------------