├── LICENSE
├── README.md
├── __pycache__
│   ├── args.cpython-36.pyc
│   ├── test.cpython-36.pyc
│   ├── train.cpython-36.pyc
│   ├── transforms.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── args.py
├── data
│   ├── Cityscapes
│   │   ├── gtFine_trainvaltest
│   │   └── leftImg8bit_trainvaltest
│   ├── README.md
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── camvid.cpython-36.pyc
│   │   ├── cityscapes.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── camvid.py
│   ├── cityscapes.py
│   └── utils.py
├── main.py
├── metric
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── confusionmatrix.cpython-36.pyc
│   │   ├── iou.cpython-36.pyc
│   │   └── metric.cpython-36.pyc
│   ├── confusionmatrix.py
│   ├── iou.py
│   └── metric.py
├── models
│   ├── __pycache__
│   │   └── rpnet.cpython-36.pyc
│   └── rpnet.py
├── pictures
│   ├── 1.png
│   └── 2.png
├── requirements.txt
├── save
│   ├── RPNet
│   └── RPNet_summary.txt
├── test.py
├── train.py
├── train.sh
├── transforms.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 davidtvs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch-RPNet
2 |
3 | PyTorch implementation of [*Residual Pyramid Learning for Single-Shot Semantic Segmentation*](https://arxiv.org/abs/1903.09746)
4 |
5 | ### Network
6 | 
7 |
8 | ### Segmentation Result
9 |
10 | 
11 |
12 |
13 | ### Citing RPNet
14 |
15 | Please cite RPNet in your publications if it helps your research:
16 |
17 | @ARTICLE{8744483,
18 | author={X. {Chen} and X. {Lou} and L. {Bai} and J. {Han}},
19 | journal={IEEE Transactions on Intelligent Transportation Systems},
20 | title={Residual Pyramid Learning for Single-Shot Semantic Segmentation},
21 | year={2019},
22 | volume={},
23 | number={},
24 | pages={1-11},
25 | keywords={Intelligent vehicles;real-time vision;scene understanding;residual learning.},
26 | doi={10.1109/TITS.2019.2922252},
27 | ISSN={1524-9050},
28 | month={},}
29 |
30 |
31 |
32 | This implementation has been tested on the Cityscapes and CamVid (TBD) datasets. Pre-trained versions of the model, trained on CamVid and Cityscapes, are available [here](https://github.com/superlxt/RPnet-Pytorch/tree/master/save).
33 |
34 |
35 | | Dataset | Classes <sup>1</sup> | Input resolution | Batch size | Epochs | Mean IoU (%) <sup>2</sup> |
36 | |:--------------------------------------------------------------------:|:--------------------:|:----------------:|:----------:|:----------:|:-----------------:|
37 | | [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) | 11 | 480x360 | 3 | 300 x 4 | 64.82 |
38 | | [Cityscapes](https://www.cityscapes-dataset.com/) | 19 | 1024x512 | 3 | 300 x 4 | 70.4 (val) |
39 |
40 | * <sup>1</sup> When referring to the number of classes, the void/unlabeled class is always excluded.
41 | * <sup>2</sup> For reference only: changes in implementation, datasets, and hardware can lead to very different results. Reference hardware: Nvidia GTX 1080ti and an Intel Core i7-7920x at 2.9GHz.
42 |
43 |
44 | Results on the Cityscapes dataset, compared with other methods:
45 |
46 |
47 | | Method | Input Size | Mean IoU | Mean iIoU | fps | FLOPs |
48 | |:----------------:|:--------------:|:------------:|:-------------:|:--------:|:----------:|
49 | | ENet | 1024*512 | 58.3 | 34.4 | 77 | 4.03B |
50 | | ERFNet | 1024*512 | 68.0 | 40.4 | 59 | 25.6B |
51 | | ESPNet | 1024*512 | 60.3 | 31.8 | 139 | 3.19B |
52 | | BiSeNet | 1036*768 | 68.4 | - | 69 | 26.4B |
53 | | ICNet | 2048*1024 | 69.5 | - | 30 | - |
54 | |DeepLab(Mobilenet)| 2048*1024 | 70.71(val) | - | 16 | 21.3B |
55 | | LRR | 2048*1024 | 69.7 | 48.0 | 2 | - |
56 | | RefineNet        | 2048*1024      | 73.6         | 47.2          | -        | 263B       |
57 | | **RPNet(ENet)** | 1024*512 | 63.37 | 39.0 | 88 | 4.28B |
58 | | **RPNet(ERFNet)**| 1024*512 | 67.9 | 44.9 | 123 | 20.7B |
59 |
60 |
61 |
62 | ## Installation
63 |
64 | 1. Python 3 and pip.
65 | 2. Set up a virtual environment (optional, but recommended).
66 | 3. Install dependencies using pip: ``pip install -r requirements.txt``.
67 |
68 |
69 |
70 | ### Examples: Training
71 |
72 | ```
73 | sh train.sh
74 | ```
75 |
76 | ### Examples: Testing
77 |
78 | ```
79 | python main.py -m test --step 1
80 | ```
81 |
82 |
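For quick single-image inference outside of `main.py`, the sketch below follows the test branch of `main.py`. It is only a rough sketch: the checkpoint location (`save/RPNet`), the 1024x512 resize and the input file `example.png` are assumptions taken from the defaults in `args.py`, and a reasonably recent PyTorch is assumed.

```python
import torch
import torchvision.transforms as transforms
from PIL import Image

import utils                     # repo-level utils.py (load_checkpoint)
from models.rpnet import RPNet

num_classes = 20                 # 19 Cityscapes classes + unlabeled
model = RPNet(num_classes)
optimizer = torch.optim.Adam(model.parameters())  # only needed to restore the checkpoint
model = utils.load_checkpoint(model, optimizer, 'save', 'RPNet')[0]
model.eval()

image_transform = transforms.Compose([
    transforms.Resize((512, 1024), Image.BILINEAR),
    transforms.ToTensor()])
image = image_transform(Image.open('example.png')).unsqueeze(0)  # hypothetical input file

with torch.no_grad():
    # The model returns predictions at full, 1/2, 1/4 and 1/8 resolution
    scores = model(image)[0]
labels = scores.argmax(1)        # (1, H, W) per-pixel class indices
```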
--------------------------------------------------------------------------------
/__pycache__/args.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/args.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | def get_arguments():
5 | """Defines command-line arguments, and parses them.
6 |
7 | """
8 | parser = ArgumentParser()
9 |
10 | # Execution mode
11 | parser.add_argument(
12 | "--mode",
13 | "-m",
14 | choices=['train', 'test', 'full'],
15 | default='train',
16 | help=("train: performs training and validation; test: tests the model "
17 | "found in \"--save_dir\" with name \"--name\" on \"--dataset\"; "
18 | "full: combines train and test modes. Default: train"))
19 | parser.add_argument(
20 | "--resume",
21 | action='store_true',
22 | help=("The model found in \"--checkpoint_dir/--name/\" and filename "
23 | "\"--name.h5\" is loaded."))
24 |
25 | # Hyperparameters
26 | parser.add_argument(
27 | "--batch-size",
28 | "-b",
29 | type=int,
30 | default=3,
31 | help="The batch size. Default: 10")
32 | parser.add_argument(
33 | "--epochs",
34 | type=int,
35 | default=300,
36 | help="Number of training epochs. Default: 300")
37 | parser.add_argument(
38 | "--learning-rate",
39 | "-lr",
40 | type=float,
41 | default=5e-4,
42 | help="The learning rate. Default: 5e-4")
43 | parser.add_argument(
44 | "--lr-decay",
45 | type=float,
46 | default=0.1,
47 | help="The learning rate decay factor. Default: 0.5")
48 | parser.add_argument(
49 | "--lr-decay-epochs",
50 | type=int,
51 | default=100,
52 | help="The number of epochs before adjusting the learning rate. "
53 | "Default: 100")
54 | parser.add_argument(
55 | "--weight-decay",
56 | "-wd",
57 | type=float,
58 | default=1e-4,
59 | help="L2 regularization factor. Default: 2e-4")
60 |
61 | # Dataset
62 | parser.add_argument(
63 | "--dataset",
64 | choices=['camvid', 'cityscapes'],
65 | default='cityscapes',
66 | help="Dataset to use. Default: cityscapes")
67 | parser.add_argument(
68 | "--dataset-dir",
69 | type=str,
70 | default="data/Cityscapes",
71 | help="Path to the root directory of the selected dataset. "
72 | "Default: data/Cityscapes")
73 | parser.add_argument(
74 | "--height",
75 | type=int,
76 | default=512,
77 | help="The image height. Default: 1024")
78 | parser.add_argument(
79 | "--width",
80 | type=int,
81 | default=1024,
82 | help="The image height. Default: 2048")
83 | parser.add_argument(
84 | "--weighing",
85 | choices=['enet', 'mfb', 'none'],
86 | default='mfb',
87 | help="The class weighing technique to apply to the dataset. "
88 | "Default: enet")
89 | parser.add_argument(
90 | "--with-unlabeled",
91 | dest='ignore_unlabeled',
92 | action='store_false',
93 | help="The unlabeled class is not ignored.")
94 |
95 | # Step
96 | parser.add_argument(
97 | "--step",
98 | type=int,
99 |         default=1,
100 | help="Step to choose. Default: 1")
101 |
102 | # Settings
103 | parser.add_argument(
104 | "--workers",
105 | type=int,
106 | default=4,
107 | help="Number of subprocesses to use for data loading. Default: 4")
108 | parser.add_argument(
109 | "--print-step",
110 | action='store_true',
111 | help="Print loss every step")
112 | parser.add_argument(
113 | "--imshow-batch",
114 | action='store_true',
115 | help=("Displays batch images when loading the dataset and making "
116 | "predictions."))
117 | parser.add_argument(
118 | "--no-cuda",
119 | dest='cuda',
120 |         action='store_false',
121 | help="CPU only.")
122 |
123 | # Storage settings
124 | parser.add_argument(
125 | "--name",
126 | type=str,
127 | default='RPNet',
128 | help="Name given to the model when saving. Default: RPNet")
129 | parser.add_argument(
130 | "--save-dir",
131 | type=str,
132 | default='save',
133 | help="The directory where models are saved. Default: save")
134 |
135 | return parser.parse_args()
136 |
--------------------------------------------------------------------------------
/data/Cityscapes/gtFine_trainvaltest:
--------------------------------------------------------------------------------
1 | /media/lxt/CVPR/Fenge/PyTorch-ENet/data/Cityscapes/gtFine_trainvaltest
--------------------------------------------------------------------------------
/data/Cityscapes/leftImg8bit_trainvaltest:
--------------------------------------------------------------------------------
1 | /media/lxt/CVPR/Fenge/PyTorch-ENet/data/Cityscapes/leftImg8bit_trainvaltest
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Supported datasets
2 |
3 | - CamVid
4 | - CityScapes
5 |
6 | Note: When referring to the number of classes, the void/unlabeled class is excluded.
7 |
8 | ## CamVid Dataset
9 |
10 | The Cambridge-driving Labeled Video Database (CamVid) is a collection of over ten minutes of high-quality 30Hz footage with object class semantic labels at 1Hz and in part, 15Hz. Each pixel is associated with one of 32 classes.
11 |
12 | The CamVid dataset supported here is a 12 class version developed by the authors of SegNet. [Download link here](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). For actual training, an 11 class version is used - the "road marking" class is combined with the "road" class.
13 |
14 | More detailed information about the CamVid dataset can be found [here](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) and on the [SegNet GitHub repository](https://github.com/alexgkendall/SegNet-Tutorial).
15 |
16 | ## Cityscapes
17 |
18 | Cityscapes is a set of stereo video sequences recorded in streets from 50 different cities with 34 different classes. There are 5000 images with fine annotations and 20000 images coarsely annotated.
19 |
20 | The version supported here is the finely annotated one with 19 classes.
21 |
22 | For more detailed information see the official [website](https://www.cityscapes-dataset.com/) and [repository](https://github.com/mcordts/cityscapesScripts).
23 |
24 | The dataset can be downloaded from https://www.cityscapes-dataset.com/downloads/. At this time, a registration is required to download the data.
25 |
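As a rough usage sketch (not part of the original dataset notes), the loaders in this folder can be driven directly once the Cityscapes archives are extracted under `data/Cityscapes`, mirroring what `main.py` does with the defaults from `args.py`; `ext_transforms.PILToLongTensor` comes from the repo-level `transforms.py`:

```python
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

import transforms as ext_transforms   # repo-level transforms.py
from data import Cityscapes

image_transform = transforms.Compose([
    transforms.Resize((512, 1024), Image.BILINEAR),
    transforms.ToTensor()])
label_transform = transforms.Compose([
    transforms.Resize((512, 1024), Image.NEAREST),
    ext_transforms.PILToLongTensor()])

train_set = Cityscapes('data/Cityscapes', mode='train',
                       transform=image_transform,
                       label_transform=label_transform)
train_loader = data.DataLoader(train_set, batch_size=3,
                               shuffle=True, num_workers=4)

images, labels = next(iter(train_loader))   # one batch of images and label maps
```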
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .camvid import CamVid
2 | from .cityscapes import Cityscapes
3 |
4 | __all__ = ['CamVid', 'Cityscapes']
5 |
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/camvid.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/camvid.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/cityscapes.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/cityscapes.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/data/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/data/camvid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch.utils.data as data
4 | from . import utils
5 |
6 |
7 | class CamVid(data.Dataset):
8 | """CamVid dataset loader where the dataset is arranged as in
9 | https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.
10 |
11 |
12 | Keyword arguments:
13 | - root_dir (``string``): Root directory path.
14 | - mode (``string``): The type of dataset: 'train' for training set, 'val'
15 | for validation set, and 'test' for test set.
16 |     - transform (``callable``, optional): A function/transform that takes in
17 |     a PIL image and returns a transformed version. Default: None.
18 |     - label_transform (``callable``, optional): A function/transform that takes
19 |     in the target and transforms it. Default: None.
20 |     - loader (``callable``, optional): A function to load an image given its
21 |     path. By default ``utils.pil_loader`` is used.
22 |
23 | """
24 | # Training dataset root folders
25 | train_folder = 'train'
26 | train_lbl_folder = 'trainannot'
27 |
28 | # Validation dataset root folders
29 | val_folder = 'val'
30 | val_lbl_folder = 'valannot'
31 |
32 | # Test dataset root folders
33 | test_folder = 'test'
34 | test_lbl_folder = 'testannot'
35 |
36 | # Images extension
37 | img_extension = '.png'
38 |
39 | # Default encoding for pixel value, class name, and class color
40 | color_encoding = OrderedDict([
41 | ('sky', (128, 128, 128)),
42 | ('building', (128, 0, 0)),
43 | ('pole', (192, 192, 128)),
44 | ('road_marking', (255, 69, 0)),
45 | ('road', (128, 64, 128)),
46 | ('pavement', (60, 40, 222)),
47 | ('tree', (128, 128, 0)),
48 | ('sign_symbol', (192, 128, 128)),
49 | ('fence', (64, 64, 128)),
50 | ('car', (64, 0, 128)),
51 | ('pedestrian', (64, 64, 0)),
52 | ('bicyclist', (0, 128, 192)),
53 | ('unlabeled', (0, 0, 0))
54 | ])
55 |
56 | def __init__(self,
57 | root_dir,
58 | mode='train',
59 | transform=None,
60 | label_transform=None,
61 | loader=utils.pil_loader):
62 | self.root_dir = root_dir
63 | self.mode = mode
64 | self.transform = transform
65 | self.label_transform = label_transform
66 | self.loader = loader
67 |
68 | if self.mode.lower() == 'train':
69 | # Get the training data and labels filepaths
70 | self.train_data = utils.get_files(
71 | os.path.join(root_dir, self.train_folder),
72 | extension_filter=self.img_extension)
73 |
74 | self.train_labels = utils.get_files(
75 | os.path.join(root_dir, self.train_lbl_folder),
76 | extension_filter=self.img_extension)
77 | elif self.mode.lower() == 'val':
78 | # Get the validation data and labels filepaths
79 | self.val_data = utils.get_files(
80 | os.path.join(root_dir, self.val_folder),
81 | extension_filter=self.img_extension)
82 |
83 | self.val_labels = utils.get_files(
84 | os.path.join(root_dir, self.val_lbl_folder),
85 | extension_filter=self.img_extension)
86 | elif self.mode.lower() == 'test':
87 | # Get the test data and labels filepaths
88 | self.test_data = utils.get_files(
89 | os.path.join(root_dir, self.test_folder),
90 | extension_filter=self.img_extension)
91 |
92 | self.test_labels = utils.get_files(
93 | os.path.join(root_dir, self.test_lbl_folder),
94 | extension_filter=self.img_extension)
95 | else:
96 | raise RuntimeError("Unexpected dataset mode. "
97 | "Supported modes are: train, val and test")
98 |
99 | def __getitem__(self, index):
100 | """
101 | Args:
102 | - index (``int``): index of the item in the dataset
103 |
104 | Returns:
105 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
106 | of the image.
107 |
108 | """
109 | if self.mode.lower() == 'train':
110 | data_path, label_path = self.train_data[index], self.train_labels[
111 | index]
112 | elif self.mode.lower() == 'val':
113 | data_path, label_path = self.val_data[index], self.val_labels[
114 | index]
115 | elif self.mode.lower() == 'test':
116 | data_path, label_path = self.test_data[index], self.test_labels[
117 | index]
118 | else:
119 | raise RuntimeError("Unexpected dataset mode. "
120 | "Supported modes are: train, val and test")
121 |
122 | img, label = self.loader(data_path, label_path)
123 |
124 | if self.transform is not None:
125 | img = self.transform(img)
126 |
127 | if self.label_transform is not None:
128 | label = self.label_transform(label)
129 |
130 | return img, label
131 |
132 | def __len__(self):
133 | """Returns the length of the dataset."""
134 | if self.mode.lower() == 'train':
135 | return len(self.train_data)
136 | elif self.mode.lower() == 'val':
137 | return len(self.val_data)
138 | elif self.mode.lower() == 'test':
139 | return len(self.test_data)
140 | else:
141 | raise RuntimeError("Unexpected dataset mode. "
142 | "Supported modes are: train, val and test")
143 |
--------------------------------------------------------------------------------
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch.utils.data as data
4 | from . import utils
5 |
6 |
7 | class Cityscapes(data.Dataset):
8 | """Cityscapes dataset https://www.cityscapes-dataset.com/.
9 |
10 | Keyword arguments:
11 | - root_dir (``string``): Root directory path.
12 | - mode (``string``): The type of dataset: 'train' for training set, 'val'
13 | for validation set, and 'test' for test set.
14 |     - transform (``callable``, optional): A function/transform that takes in
15 |     a PIL image and returns a transformed version. Default: None.
16 |     - label_transform (``callable``, optional): A function/transform that takes
17 |     in the target and transforms it. Default: None.
18 |     - loader (``callable``, optional): A function to load an image given its
19 |     path. By default ``utils.pil_loader`` is used.
20 |
21 | """
22 | # Training dataset root folders
23 | train_folder = "leftImg8bit_trainvaltest/leftImg8bit/train"
24 | train_lbl_folder = "gtFine_trainvaltest/gtFine/train"
25 |
26 | # Validation dataset root folders
27 | val_folder = "leftImg8bit_trainvaltest/leftImg8bit/val"
28 | val_lbl_folder = "gtFine_trainvaltest/gtFine/val"
29 |
30 | # Test dataset root folders
31 | test_folder = "leftImg8bit_trainvaltest/leftImg8bit/val"
32 | test_lbl_folder = "gtFine_trainvaltest/gtFine/val"
33 |
34 | # Filters to find the images
35 | img_extension = '.png'
36 | lbl_name_filter = 'labelIds'
37 |
38 | # The values associated with the 35 classes
39 | full_classes = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
40 | 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
41 | 32, 33, -1)
42 | # The values above are remapped to the following
43 | new_classes = (0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 5, 0, 0, 0, 6, 0, 7,
44 | 8, 9, 10, 11, 12, 13, 14, 15, 16, 0, 0, 17, 18, 19, 0)
45 |
46 | # Default encoding for pixel value, class name, and class color
47 | color_encoding = OrderedDict([
48 | ('unlabeled', (0, 0, 0)),
49 | ('road', (128, 64, 128)),
50 | ('sidewalk', (244, 35, 232)),
51 | ('building', (70, 70, 70)),
52 | ('wall', (102, 102, 156)),
53 | ('fence', (190, 153, 153)),
54 | ('pole', (153, 153, 153)),
55 | ('traffic_light', (250, 170, 30)),
56 | ('traffic_sign', (220, 220, 0)),
57 | ('vegetation', (107, 142, 35)),
58 | ('terrain', (152, 251, 152)),
59 | ('sky', (70, 130, 180)),
60 | ('person', (220, 20, 60)),
61 | ('rider', (255, 0, 0)),
62 | ('car', (0, 0, 142)),
63 | ('truck', (0, 0, 70)),
64 | ('bus', (0, 60, 100)),
65 | ('train', (0, 80, 100)),
66 | ('motorcycle', (0, 0, 230)),
67 | ('bicycle', (119, 11, 32))
68 | ])
69 |
70 | def __init__(self,
71 | root_dir,
72 | mode='train',
73 | transform=None,
74 | label_transform=None,
75 | loader=utils.pil_loader):
76 | self.root_dir = root_dir
77 | self.mode = mode
78 | self.transform = transform
79 | self.label_transform = label_transform
80 | self.loader = loader
81 |
82 | if self.mode.lower() == 'train':
83 | # Get the training data and labels filepaths
84 | self.train_data = utils.get_files(
85 | os.path.join(root_dir, self.train_folder),
86 | extension_filter=self.img_extension)
87 |
88 | self.train_labels = utils.get_files(
89 | os.path.join(root_dir, self.train_lbl_folder),
90 | name_filter=self.lbl_name_filter,
91 | extension_filter=self.img_extension)
92 | elif self.mode.lower() == 'val':
93 | # Get the validation data and labels filepaths
94 | self.val_data = utils.get_files(
95 | os.path.join(root_dir, self.val_folder),
96 | extension_filter=self.img_extension)
97 |
98 | self.val_labels = utils.get_files(
99 | os.path.join(root_dir, self.val_lbl_folder),
100 | name_filter=self.lbl_name_filter,
101 | extension_filter=self.img_extension)
102 | elif self.mode.lower() == 'test':
103 | # Get the test data and labels filepaths
104 | self.test_data = utils.get_files(
105 | os.path.join(root_dir, self.test_folder),
106 | extension_filter=self.img_extension)
107 |
108 | self.test_labels = utils.get_files(
109 | os.path.join(root_dir, self.test_lbl_folder),
110 | name_filter=self.lbl_name_filter,
111 | extension_filter=self.img_extension)
112 | else:
113 | raise RuntimeError("Unexpected dataset mode. "
114 | "Supported modes are: train, val and test")
115 |
116 | def __getitem__(self, index):
117 | """
118 | Args:
119 | - index (``int``): index of the item in the dataset
120 |
121 | Returns:
122 | A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
123 | of the image.
124 |
125 | """
126 | if self.mode.lower() == 'train':
127 | data_path, label_path = self.train_data[index], self.train_labels[
128 | index]
129 | elif self.mode.lower() == 'val':
130 | data_path, label_path = self.val_data[index], self.val_labels[
131 | index]
132 | elif self.mode.lower() == 'test':
133 | data_path, label_path = self.test_data[index], self.test_labels[
134 | index]
135 | else:
136 | raise RuntimeError("Unexpected dataset mode. "
137 | "Supported modes are: train, val and test")
138 |
139 | img, label = self.loader(data_path, label_path)
140 | # Remap class labels
141 | label = utils.remap(label, self.full_classes, self.new_classes)
142 |
143 | if self.transform is not None:
144 | img = self.transform(img)
145 |
146 |
147 | if self.label_transform is not None:
148 | label = self.label_transform(label)
149 | return img, label
150 |
151 | def __len__(self):
152 | """Returns the length of the dataset."""
153 | if self.mode.lower() == 'train':
154 | return len(self.train_data)
155 | elif self.mode.lower() == 'val':
156 | return len(self.val_data)
157 | elif self.mode.lower() == 'test':
158 | return len(self.test_data)
159 | else:
160 | raise RuntimeError("Unexpected dataset mode. "
161 | "Supported modes are: train, val and test")
162 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import numpy as np
4 |
5 |
6 | def get_files(folder, name_filter=None, extension_filter=None):
7 | """Helper function that returns the list of files in a specified folder
8 | with a specified extension.
9 |
10 | Keyword arguments:
11 | - folder (``string``): The path to a folder.
12 | - name_filter (```string``, optional): The returned files must contain
13 | this substring in their filename. Default: None; files are not filtered.
14 | - extension_filter (``string``, optional): The desired file extension.
15 | Default: None; files are not filtered
16 |
17 | """
18 | if not os.path.isdir(folder):
19 | raise RuntimeError("\"{0}\" is not a folder.".format(folder))
20 |
21 | # Filename filter: if not specified don't filter (condition always true);
22 | # otherwise, use a lambda expression to filter out files that do not
23 | # contain "name_filter"
24 | if name_filter is None:
25 | # This looks hackish...there is probably a better way
26 | name_cond = lambda filename: True
27 | else:
28 | name_cond = lambda filename: name_filter in filename
29 |
30 | # Extension filter: if not specified don't filter (condition always true);
31 | # otherwise, use a lambda expression to filter out files whose extension
32 | # is not "extension_filter"
33 | if extension_filter is None:
34 | # This looks hackish...there is probably a better way
35 | ext_cond = lambda filename: True
36 | else:
37 | ext_cond = lambda filename: filename.endswith(extension_filter)
38 |
39 | filtered_files = []
40 |
41 | # Explore the directory tree to get files that contain "name_filter" and
42 | # with extension "extension_filter"
43 | for path, _, files in os.walk(folder):
44 | files.sort()
45 | for file in files:
46 | if name_cond(file) and ext_cond(file):
47 | full_path = os.path.join(path, file)
48 | filtered_files.append(full_path)
49 |
50 | return filtered_files
51 |
52 |
53 | def pil_loader(data_path, label_path):
54 | """Loads a sample and label image given their path as PIL images.
55 |
56 | Keyword arguments:
57 | - data_path (``string``): The filepath to the image.
58 | - label_path (``string``): The filepath to the ground-truth image.
59 |
60 | Returns the image and the label as PIL images.
61 |
62 | """
63 | data = Image.open(data_path)
64 | label = Image.open(label_path)
65 |
66 | return data, label
67 |
68 |
69 | def remap(image, old_values, new_values):
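    """Replaces pixel values in ``image``: each value in ``old_values`` is
    mapped to the value at the same position in ``new_values``; values not
    listed in ``old_values`` become 0. Accepts a ``PIL.Image`` or a
    ``numpy.ndarray`` and returns a ``PIL.Image``.
    """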
70 | assert isinstance(image, Image.Image) or isinstance(
71 | image, np.ndarray), "image must be of type PIL.Image or numpy.ndarray"
72 | assert type(new_values) is tuple, "new_values must be of type tuple"
73 | assert type(old_values) is tuple, "old_values must be of type tuple"
74 | assert len(new_values) == len(
75 | old_values), "new_values and old_values must have the same length"
76 |
77 | # If image is a PIL.Image convert it to a numpy array
78 | if isinstance(image, Image.Image):
79 | image = np.array(image)
80 |
81 | # Replace old values by the new ones
82 | tmp = np.zeros_like(image)
83 | for old, new in zip(old_values, new_values):
84 | # Since tmp is already initialized as zeros we can skip new values
85 | # equal to 0
86 | if new != 0:
87 | tmp[image == old] = new
88 |
89 | return Image.fromarray(tmp)
90 |
91 |
92 | def enet_weighing(dataloader, num_classes, c=1.02):
93 | """Computes class weights as described in the ENet paper:
94 |
95 | w_class = 1 / (ln(c + p_class)),
96 |
97 | where c is usually 1.02 and p_class is the propensity score of that
98 | class:
99 |
100 | propensity_score = freq_class / total_pixels.
101 |
102 | References: https://arxiv.org/abs/1606.02147
103 |
104 | Keyword arguments:
105 | - dataloader (``data.Dataloader``): A data loader to iterate over the
106 | dataset.
107 | - num_classes (``int``): The number of classes.
108 |     - c (``float``, optional): An additional hyper-parameter which restricts
109 | the interval of values for the weights. Default: 1.02.
110 |
111 | """
112 | class_count = 0
113 | total = 0
114 | for _, label in dataloader:
115 | label = label.cpu().numpy()
116 |
117 | # Flatten label
118 | flat_label = label.flatten()
119 |
120 | # Sum up the number of pixels of each class and the total pixel
121 | # counts for each label
122 | class_count += np.bincount(flat_label, minlength=num_classes)
123 | total += flat_label.size
124 |
125 | # Compute propensity score and then the weights for each class
126 | propensity_score = class_count / total
127 | class_weights = 1 / (np.log(c + propensity_score))
128 |
129 | return class_weights
130 |
131 |
132 | def median_freq_balancing(dataloader, num_classes):
133 | """Computes class weights using median frequency balancing as described
134 | in https://arxiv.org/abs/1411.4734:
135 |
136 | w_class = median_freq / freq_class,
137 |
138 | where freq_class is the number of pixels of a given class divided by
139 | the total number of pixels in images where that class is present, and
140 | median_freq is the median of freq_class.
141 |
142 | Keyword arguments:
143 | - dataloader (``data.Dataloader``): A data loader to iterate over the
144 |     dataset whose weights are going to be computed.
145 |     - num_classes (``int``): The number of classes.
146 |
147 |
148 | """
149 | class_count = 0
150 | total = 0
151 | for _, label in dataloader:
152 | label = label.cpu().numpy()
153 |
154 | # Flatten label
155 | flat_label = label.flatten()
156 |
157 | # Sum up the class frequencies
158 | bincount = np.bincount(flat_label, minlength=num_classes)
159 |
160 |         # Create a mask of classes that exist in the label
161 | mask = bincount > 0
162 | # Multiply the mask by the pixel count. The resulting array has
163 | # one element for each class. The value is either 0 (if the class
164 | # does not exist in the label) or equal to the pixel count (if
165 | # the class exists in the label)
166 | total += mask * flat_label.size
167 |
168 | # Sum up the number of pixels found for each class
169 | class_count += bincount
170 |
171 | # Compute the frequency and its median
172 | freq = class_count / total
173 | med = np.median(freq)
174 |
175 | return med / freq
176 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | torch.cuda.set_device(1)
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.optim.lr_scheduler as lr_scheduler
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | from torch.autograd import Variable
11 |
12 | import transforms as ext_transforms
13 | from models.rpnet import RPNet
14 | from train import Train
15 | from test import Test
16 | from metric.iou import IoU
17 | from args import get_arguments
18 | from data.utils import enet_weighing, median_freq_balancing
19 | import utils
20 | from PIL import Image
21 |
22 | import numpy as np
23 | # Get the arguments
24 | args = get_arguments()
25 |
26 | use_cuda = args.cuda and torch.cuda.is_available()
27 |
28 |
29 | def load_dataset(dataset):
30 | print("\nLoading dataset...\n")
31 |
32 | print("Selected dataset:", args.dataset)
33 | print("Dataset directory:", args.dataset_dir)
34 | print("Save directory:", args.save_dir)
35 |
36 | image_transform = transforms.Compose(
37 | [transforms.Resize((args.height, args.width),Image.BILINEAR),
38 | transforms.ToTensor()])
39 |
40 | label_transform = transforms.Compose([
41 | transforms.Resize((args.height, args.width),Image.NEAREST),
42 | ext_transforms.PILToLongTensor()
43 | ])
44 |
45 | # Get selected dataset
46 | # Load the training set as tensors
47 | train_set = dataset(
48 | args.dataset_dir,
49 | transform=image_transform,
50 | label_transform=label_transform)
51 | train_loader = data.DataLoader(
52 | train_set,
53 | batch_size=args.batch_size,
54 | shuffle=True,
55 | num_workers=args.workers)
56 |
57 | # Load the validation set as tensors
58 | val_set = dataset(
59 | args.dataset_dir,
60 | mode='val',
61 | transform=image_transform,
62 | label_transform=label_transform)
63 | val_loader = data.DataLoader(
64 | val_set,
65 | batch_size=args.batch_size,
66 | shuffle=True,
67 | num_workers=args.workers)
68 |
69 | # Load the test set as tensors
70 | test_set = dataset(
71 | args.dataset_dir,
72 | mode='test',
73 | transform=image_transform,
74 | label_transform=label_transform)
75 | test_loader = data.DataLoader(
76 | test_set,
77 | batch_size=args.batch_size,
78 | shuffle=True,
79 | num_workers=args.workers)
80 |
81 |     # Get encoding between pixel values in label images and RGB colors
82 | class_encoding = train_set.color_encoding
83 |
84 | # Remove the road_marking class from the CamVid dataset as it's merged
85 | # with the road class
86 | if args.dataset.lower() == 'camvid':
87 | del class_encoding['road_marking']
88 |
89 | # Get number of classes to predict
90 | num_classes = len(class_encoding)
91 |
92 | # Print information for debugging
93 | print("Number of classes to predict:", num_classes)
94 | print("Train dataset size:", len(train_set))
95 | print("Validation dataset size:", len(val_set))
96 |
97 | # Get a batch of samples to display
98 |     if args.mode.lower() == 'test':
99 |         images, labels = next(iter(test_loader))
100 |     else:
101 |         images, labels = next(iter(train_loader))
102 | print("Image size:", images.size())
103 | print("Label size:", labels.size())
104 | print("Class-color encoding:", class_encoding)
105 |
106 | # Show a batch of samples and labels
107 | if args.imshow_batch:
108 | print("Close the figure window to continue...")
109 | label_to_rgb = transforms.Compose([
110 | ext_transforms.LongTensorToRGBPIL(class_encoding),
111 | transforms.ToTensor()
112 | ])
113 | color_labels = utils.batch_transform(labels, label_to_rgb)
114 | utils.imshow_batch(images, color_labels)
115 |
116 | # Get class weights from the selected weighing technique
117 | print("\nWeighing technique:", args.weighing)
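    # Note: the weights below are hard-coded (ENet-style weights for the 19
    # Cityscapes classes plus unlabeled); the --weighing flag and the imported
    # enet_weighing/median_freq_balancing helpers are not used here.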
118 | class_weights = np.array([0.0,2.7,6.1,3.6,7.7,7.7,8.1,8.6,8.4,4.3,7.7,6.8,8.0,8.6,5.9,7.7,7.5,6.6,8.5,8.4])
119 | if class_weights is not None:
120 | class_weights = torch.from_numpy(class_weights).float()
121 | # Set the weight of the unlabeled class to 0
122 | if args.ignore_unlabeled:
123 | ignore_index = list(class_encoding).index('unlabeled')
124 | class_weights[ignore_index] = 0
125 |
126 | print("Class weights:", class_weights)
127 |
128 | return (train_loader, val_loader,
129 | test_loader), class_weights, class_encoding
130 |
131 |
132 | def train(train_loader, val_loader, class_weights, class_encoding):
133 | print("\nTraining...\n")
134 |
135 | num_classes = len(class_encoding)
136 |
137 |     # Initialize RPNet
138 | model = RPNet(num_classes)
139 |
140 |     # We use the CrossEntropyLoss loss function as it's the one most
141 |     # frequently used in classification problems with multiple classes, which
142 |     # fits this problem. This criterion combines LogSoftMax and NLLLoss.
143 | criterion = nn.CrossEntropyLoss(weight=class_weights)
144 |
145 | # ENet authors used Adam as the optimizer
146 | optimizer = optim.Adam(
147 | model.parameters(),
148 | lr=args.learning_rate,
149 | weight_decay=args.weight_decay)
150 |
151 | # Learning rate decay scheduler
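    # (polynomial "poly" decay with power 0.9 over the total number of epochs;
    # the --lr-decay and --lr-decay-epochs arguments are not used by this schedule)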
152 | lmd = lambda epoch: (1-epoch/args.epochs) ** 0.9
153 | lr_updater = lr_scheduler.LambdaLR(optimizer, lr_lambda=lmd)
154 |
155 | # Evaluation metric
156 | if args.ignore_unlabeled:
157 | ignore_index = list(class_encoding).index('unlabeled')
158 | else:
159 | ignore_index = None
160 | metric = IoU(num_classes, ignore_index=ignore_index)
161 |
162 | if use_cuda:
163 | model = model.cuda()
164 | criterion = criterion.cuda()
165 |
166 | # Optionally resume from a checkpoint
167 | if args.resume:
168 | model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
169 | model, optimizer, args.save_dir, args.name)
170 | print("Resuming from model: Start epoch = {0} "
171 | "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
172 | else:
173 | start_epoch = 0
174 | best_miou = 0
175 |
176 |
177 | # Step
178 | step = args.step
179 |
180 | # Start Training
181 | train = Train(model, train_loader, optimizer, criterion, metric, use_cuda, step)
182 | val = Test(model, val_loader, criterion, metric, use_cuda, step)
183 | for epoch in range(start_epoch, args.epochs):
184 | print(">>>> [Epoch: {0:d}] Training".format(epoch))
185 | train.model.train()
186 | lr_updater.step()
187 | epoch_loss, (iou, miou) = train.run_epoch(args.print_step)
188 |
189 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
190 | format(epoch, epoch_loss, miou),'current lr {:.5e}'.format(optimizer.param_groups[0]['lr']))
191 |
192 | if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs:
193 | val.model.eval()
194 | print(">>>> [Epoch: {0:d}] Validation".format(epoch))
195 |
196 | loss, (iou, miou) = val.run_epoch(args.print_step)
197 |
198 | print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
199 | format(epoch, loss, miou))
200 |
201 | # Print per class IoU on last epoch or if best iou
202 | if epoch + 1 == args.epochs or miou > best_miou:
203 | for key, class_iou in zip(class_encoding.keys(), iou):
204 | print("{0}: {1:.4f}".format(key, class_iou))
205 |
206 | # Save the model if it's the best thus far
207 | if miou > best_miou:
208 | print("\nBest model thus far. Saving...\n")
209 | best_miou = miou
210 | utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,args)
211 |
212 |
213 | return model
214 |
215 |
216 | def test(model, test_loader, class_weights, class_encoding, step):
217 | print("\nTesting...\n")
218 |
219 | num_classes = len(class_encoding)
220 |
221 |     # We use the CrossEntropyLoss loss function as it's the one most
222 |     # frequently used in classification problems with multiple classes, which
223 |     # fits this problem. This criterion combines LogSoftMax and NLLLoss.
224 | criterion = nn.CrossEntropyLoss(weight=class_weights)
225 | if use_cuda:
226 | criterion = criterion.cuda()
227 |
228 | # Evaluation metric
229 | if args.ignore_unlabeled:
230 | ignore_index = list(class_encoding).index('unlabeled')
231 | else:
232 | ignore_index = None
233 | metric = IoU(num_classes, ignore_index=ignore_index)
234 |
235 | # Test the trained model on the test set
236 | test = Test(model, test_loader, criterion, metric, use_cuda, step)
237 |
238 | print(">>>> Running test dataset")
239 |
240 | loss, (iou, miou) = test.run_epoch(args.print_step)
241 | class_iou = dict(zip(class_encoding.keys(), iou))
242 |
243 | print(">>>> Avg. loss: {0:.4f} | Mean IoU: {1:.4f}".format(loss, miou))
244 |
245 | # Print per class IoU
246 | for key, class_iou in zip(class_encoding.keys(), iou):
247 | print("{0}: {1:.4f}".format(key, class_iou))
248 |
249 | # Show a batch of samples and labels
250 | if args.imshow_batch:
251 | print("A batch of predictions from the test set...")
252 |         images, _ = next(iter(test_loader))
253 | predict(model, images, class_encoding)
254 |
255 |
256 | def predict(model, images, class_encoding):
257 | images = Variable(images)
258 | if use_cuda:
259 | images = images.cuda()
260 |
261 |     # Make predictions (the model outputs four scales; use the full-resolution one)
262 |     predictions = model(images)[0]
263 |
264 |     # Predictions has "num_classes" channels of class scores.
265 |     # Convert it to per-pixel class indices by taking the channel-wise argmax
266 |     _, predictions = torch.max(predictions.data, 1)
267 |
268 | label_to_rgb = transforms.Compose([
269 | ext_transforms.LongTensorToRGBPIL(class_encoding),
270 | transforms.ToTensor()
271 | ])
272 | color_predictions = utils.batch_transform(predictions.cpu(), label_to_rgb)
273 | utils.imshow_batch(images.data.cpu(), color_predictions)
274 |
275 |
276 | # Run only if this module is being run directly
277 | if __name__ == '__main__':
278 |
279 | # Fail fast if the dataset directory doesn't exist
280 | assert os.path.isdir(
281 | args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
282 | args.dataset_dir)
283 |
284 | # Fail fast if the saving directory doesn't exist
285 | assert os.path.isdir(
286 | args.save_dir), "The directory \"{0}\" doesn't exist.".format(
287 | args.save_dir)
288 |
289 | # Import the requested dataset
290 | if args.dataset.lower() == 'camvid':
291 | from data import CamVid as dataset
292 | elif args.dataset.lower() == 'cityscapes':
293 | from data import Cityscapes as dataset
294 | else:
295 | # Should never happen...but just in case it does
296 | raise RuntimeError("\"{0}\" is not a supported dataset.".format(
297 | args.dataset))
298 |
299 | loaders, w_class, class_encoding = load_dataset(dataset)
300 | train_loader, val_loader, test_loader = loaders
301 |
302 | if args.mode.lower() in {'train', 'full'}:
303 | model = train(train_loader, val_loader, w_class, class_encoding)
304 | if args.mode.lower() == 'full':
305 |             test(model, test_loader, w_class, class_encoding, args.step)
306 | elif args.mode.lower() == 'test':
307 |         # Initialize a new RPNet model
308 | num_classes = len(class_encoding)
309 | model = RPNet(num_classes)
310 | print(model)
311 | #model = nn.DataParallel(model)
312 | model.eval()
313 | if use_cuda:
314 | model = model.cuda()
315 |
316 |         # Initialize an optimizer just so we can retrieve the model from the
317 | # checkpoint
318 | optimizer = optim.Adam(model.parameters())
319 |
320 |         # Load the previously saved model state into the RPNet model
321 | model = utils.load_checkpoint(model, optimizer, args.save_dir,
322 | args.name)[0]
323 | #print(model)
324 | step = args.step
325 | test(model, test_loader, w_class, class_encoding, step)
326 | else:
327 | # Should never happen...but just in case it does
328 | raise RuntimeError(
329 | "\"{0}\" is not a valid choice for execution mode.".format(
330 | args.mode))
331 |
--------------------------------------------------------------------------------
/metric/__init__.py:
--------------------------------------------------------------------------------
1 | from .confusionmatrix import ConfusionMatrix
2 | from .iou import IoU
3 | from .metric import Metric
4 |
5 | __all__ = ['ConfusionMatrix', 'IoU', 'Metric']
6 |
--------------------------------------------------------------------------------
/metric/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/metric/__pycache__/confusionmatrix.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/confusionmatrix.cpython-36.pyc
--------------------------------------------------------------------------------
/metric/__pycache__/iou.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/iou.cpython-36.pyc
--------------------------------------------------------------------------------
/metric/__pycache__/metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/metric/__pycache__/metric.cpython-36.pyc
--------------------------------------------------------------------------------
/metric/confusionmatrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from metric import metric
4 |
5 |
6 | class ConfusionMatrix(metric.Metric):
7 | """Constructs a confusion matrix for a multi-class classification problems.
8 |
9 | Does not support multi-label, multi-class problems.
10 |
11 | Keyword arguments:
12 | - num_classes (int): number of classes in the classification problem.
13 | - normalized (boolean, optional): Determines whether or not the confusion
14 |     matrix is normalized. Default: False.
15 |
16 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py
17 | """
18 |
19 | def __init__(self, num_classes, normalized=False):
20 | super().__init__()
21 |
22 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32)
23 | self.normalized = normalized
24 | self.num_classes = num_classes
25 | self.reset()
26 |
27 | def reset(self):
28 | self.conf.fill(0)
29 |
30 | def add(self, predicted, target):
31 | """Computes the confusion matrix
32 |
33 | The shape of the confusion matrix is K x K, where K is the number
34 | of classes.
35 |
36 | Keyword arguments:
37 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of
38 | predicted scores obtained from the model for N examples and K classes,
39 | or an N-tensor/array of integer values between 0 and K-1.
40 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of
41 | ground-truth classes for N examples and K classes, or an N-tensor/array
42 | of integer values between 0 and K-1.
43 |
44 | """
45 | # If target and/or predicted are tensors, convert them to numpy arrays
46 | if torch.is_tensor(predicted):
47 | predicted = predicted.cpu().numpy()
48 | if torch.is_tensor(target):
49 | target = target.cpu().numpy()
50 |
51 | assert predicted.shape[0] == target.shape[0], \
52 | 'number of targets and predicted outputs do not match'
53 |
54 | if np.ndim(predicted) != 1:
55 | assert predicted.shape[1] == self.num_classes, \
56 | 'number of predictions does not match size of confusion matrix'
57 | predicted = np.argmax(predicted, 1)
58 | else:
59 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \
60 | 'predicted values are not between 0 and k-1'
61 |
62 | if np.ndim(target) != 1:
63 | assert target.shape[1] == self.num_classes, \
64 | 'Onehot target does not match size of confusion matrix'
65 | assert (target >= 0).all() and (target <= 1).all(), \
66 | 'in one-hot encoding, target values should be 0 or 1'
67 | assert (target.sum(1) == 1).all(), \
68 | 'multi-label setting is not supported'
69 | target = np.argmax(target, 1)
70 | else:
71 | assert (target.max() < self.num_classes) and (target.min() >= 0), \
72 | 'target values are not between 0 and k-1'
73 |
74 | # hack for bincounting 2 arrays together
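        # each (target, predicted) pair is encoded as a single index
        # target * num_classes + predicted, so a 1-D bincount of length
        # num_classes**2 reshapes into the K x K confusion matrix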
75 | x = predicted + self.num_classes * target
76 | bincount_2d = np.bincount(
77 | x.astype(np.int32), minlength=self.num_classes**2)
78 | assert bincount_2d.size == self.num_classes**2
79 | conf = bincount_2d.reshape((self.num_classes, self.num_classes))
80 |
81 | self.conf += conf
82 |
83 | def value(self):
84 | """
85 | Returns:
86 |             Confusion matrix of K rows and K columns, where rows correspond
87 |             to ground-truth targets and columns correspond to predicted
88 | targets.
89 | """
90 | if self.normalized:
91 | conf = self.conf.astype(np.float32)
92 | return conf / conf.sum(1).clip(min=1e-12)[:, None]
93 | else:
94 | return self.conf
95 |
--------------------------------------------------------------------------------
/metric/iou.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from metric import metric
4 | from metric.confusionmatrix import ConfusionMatrix
5 |
6 |
7 | class IoU(metric.Metric):
8 | """Computes the intersection over union (IoU) per class and corresponding
9 | mean (mIoU).
10 |
11 | Intersection over union (IoU) is a common evaluation metric for semantic
12 | segmentation. The predictions are first accumulated in a confusion matrix
13 | and the IoU is computed from it as follows:
14 |
15 | IoU = true_positive / (true_positive + false_positive + false_negative).
16 |
17 | Keyword arguments:
18 | - num_classes (int): number of classes in the classification problem
19 | - normalized (boolean, optional): Determines whether or not the confusion
20 |     matrix is normalized. Default: False.
21 | - ignore_index (int or iterable, optional): Index of the classes to ignore
22 | when computing the IoU. Can be an int, or any iterable of ints.
23 | """
24 |
25 | def __init__(self, num_classes, normalized=False, ignore_index=None):
26 | super().__init__()
27 | self.conf_metric = ConfusionMatrix(num_classes, normalized)
28 |
29 | if ignore_index is None:
30 | self.ignore_index = None
31 | elif isinstance(ignore_index, int):
32 | self.ignore_index = (ignore_index,)
33 | else:
34 | try:
35 | self.ignore_index = tuple(ignore_index)
36 | except TypeError:
37 | raise ValueError("'ignore_index' must be an int or iterable")
38 |
39 | def reset(self):
40 | self.conf_metric.reset()
41 |
42 | def add(self, predicted, target):
43 | """Adds the predicted and target pair to the IoU metric.
44 |
45 | Keyword arguments:
46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of
47 | predicted scores obtained from the model for N examples and K classes,
48 | or (N, H, W) tensor of integer values between 0 and K-1.
49 | - target (Tensor): Can be a (N, K, H, W) tensor of
50 | target scores for N examples and K classes, or (N, H, W) tensor of
51 | integer values between 0 and K-1.
52 |
53 | """
54 | # Dimensions check
55 | assert predicted.size(0) == target.size(0), \
56 | 'number of targets and predicted outputs do not match'
57 | assert predicted.dim() == 3 or predicted.dim() == 4, \
58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)"
59 | assert target.dim() == 3 or target.dim() == 4, \
60 | "targets must be of dimension (N, H, W) or (N, K, H, W)"
61 |
62 | # If the tensor is in categorical format convert it to integer format
63 | if predicted.dim() == 4:
64 | _, predicted = predicted.max(1)
65 | if target.dim() == 4:
66 | _, target = target.max(1)
67 |
68 | self.conf_metric.add(predicted.view(-1), target.view(-1))
69 |
70 | def value(self):
71 | """Computes the IoU and mean IoU.
72 |
73 | The mean computation ignores NaN elements of the IoU array.
74 |
75 | Returns:
76 | Tuple: (IoU, mIoU). The first output is the per class IoU,
77 | for K classes it's numpy.ndarray with K elements. The second output,
78 | is the mean IoU.
79 | """
80 | conf_matrix = self.conf_metric.value()
81 | if self.ignore_index is not None:
82 |             for index in self.ignore_index:
83 |                 conf_matrix[:, index] = 0
84 |                 conf_matrix[index, :] = 0
85 | true_positive = np.diag(conf_matrix)
86 | false_positive = np.sum(conf_matrix, 0) - true_positive
87 | false_negative = np.sum(conf_matrix, 1) - true_positive
88 |
89 | # Just in case we get a division by 0, ignore/hide the error
90 | with np.errstate(divide='ignore', invalid='ignore'):
91 | iou = true_positive / (true_positive + false_positive + false_negative)
92 |
93 | return iou, np.nanmean(iou)
94 |
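
# A minimal usage sketch (illustrative only; the random tensors are
# hypothetical and torch.randint assumes a reasonably recent PyTorch):
#
#   metric = IoU(num_classes=20, ignore_index=0)
#   predicted = torch.randint(0, 20, (2, 512, 1024))  # (N, H, W) class indices
#   target = torch.randint(0, 20, (2, 512, 1024))
#   metric.add(predicted, target)
#   per_class_iou, mean_iou = metric.value()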
--------------------------------------------------------------------------------
/metric/metric.py:
--------------------------------------------------------------------------------
1 | class Metric(object):
2 | """Base class for all metrics.
3 |
4 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py
5 | """
6 | def reset(self):
7 | pass
8 |
9 | def add(self):
10 | pass
11 |
12 | def value(self):
13 | pass
14 |
--------------------------------------------------------------------------------
/models/__pycache__/rpnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/models/__pycache__/rpnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/rpnet.py:
--------------------------------------------------------------------------------
1 | # RPNet model definition for PyTorch
2 | # Adapted from the ERFNet full model definition
3 | # by Eduardo Romera (Sept 2017)
4 | #######################
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | import torch.nn.functional as F
10 |
11 | class DownsamplerBlock (nn.Module):
12 | def __init__(self, ninput, noutput):
13 | super().__init__()
14 |
15 | self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True)
16 | self.conv2 = nn.Conv2d(16, 64, (1, 1), stride=1, padding=0, bias=True)
17 | self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
18 | self.bn = nn.BatchNorm2d(noutput, eps=1e-3)
19 |
20 |
21 | def forward(self, input):
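        # ENet/ERFNet-style downsampler: concatenate a stride-2 3x3 convolution
        # of the input with a max-pooled copy of it. The pooled branch (b), its
        # 1x1 projection (b_c, only when b has 16 channels), the pooling indices
        # and the pre-BatchNorm concatenation (output1) are also returned for
        # the residual pyramid built in RPNet.forward.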
22 | c=input
23 | a=self.conv(input)
24 | b,max_indices=self.pool(input)
25 | #print(a.shape,b.shape,c.shape,max_indices.shape,"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
26 | output1 = torch.cat([a, b], 1)
27 | if b.shape[1]==16:
28 | b_c = self.conv2(b)
29 | else:
30 | b_c=b
31 | output = self.bn(output1)
32 | return F.relu(output), max_indices, b, b_c, output1
33 |
34 |
35 | class non_bottleneck_1d (nn.Module):
36 | def __init__(self, chann, dropprob, dilated):
37 | super().__init__()
38 |
39 | self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1,0), bias=True)
40 |
41 | self.conv1x3_1 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1), bias=True)
42 |
43 | self.bn1 = nn.BatchNorm2d(chann, eps=1e-03)
44 |
45 | self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1*dilated,0), bias=True, dilation = (dilated,1))
46 |
47 | self.conv1x3_2 = nn.Conv2d(chann, chann, (1,3), stride=1, padding=(0,1*dilated), bias=True, dilation = (1, dilated))
48 |
49 | self.bn2 = nn.BatchNorm2d(chann, eps=1e-03)
50 |
51 | self.dropout = nn.Dropout2d(dropprob)
52 |
53 |
54 | def forward(self, input):
55 |
56 | output = self.conv3x1_1(input)
57 | output = F.relu(output)
58 | output = self.conv1x3_1(output)
59 | output = self.bn1(output)
60 | output = F.relu(output)
61 |
62 | output = self.conv3x1_2(output)
63 | output = F.relu(output)
64 | output = self.conv1x3_2(output)
65 | output = self.bn2(output)
66 |
67 | if (self.dropout.p != 0):
68 | output = self.dropout(output)
69 |
70 | return F.relu(output+input), output #+input = identity (residual connection)
71 |
72 |
73 | class RPNet(nn.Module):
74 | def __init__(self, num_classes):
75 | super().__init__()
76 | self.initial_block = DownsamplerBlock(3,16)
77 |
78 | self.l0d1=non_bottleneck_1d(16, 0.03, 1)
79 | self.down0_25=DownsamplerBlock(16,64)
80 |
81 |
82 | self.l1d1=non_bottleneck_1d(64, 0.03, 1)
83 | self.l1d2=non_bottleneck_1d(64, 0.03, 1)
84 | self.l1d3=non_bottleneck_1d(64, 0.03, 1)
85 | self.l1d4=non_bottleneck_1d(64, 0.03, 1)
86 | self.l1d5=non_bottleneck_1d(64, 0.03, 1)
87 |
88 | self.down0_125=DownsamplerBlock(64,128)
89 |
90 | self.l2d1=non_bottleneck_1d(128, 0.3, 2)
91 | self.l2d2=non_bottleneck_1d(128, 0.3, 4)
92 | self.l2d3=non_bottleneck_1d(128, 0.3, 8)
93 | self.l2d4=non_bottleneck_1d(128, 0.3, 16)
94 |
95 | self.l3d1=non_bottleneck_1d(128, 0.3, 2)
96 | self.l3d2=non_bottleneck_1d(128, 0.3, 4)
97 | self.l3d3=non_bottleneck_1d(128, 0.3, 8)
98 | self.l3d4=non_bottleneck_1d(128, 0.3, 16)
99 |         # 1x1 convolutions producing the class-score maps of the pyramid:
100 | self.conv2d1 = nn.Conv2d(
101 | 128,
102 | num_classes,
103 | kernel_size=1,
104 | stride=1,
105 | padding=0,
106 | bias=True)
107 | self.conv2d2 = nn.Conv2d(
108 | 192,  # 128 encoder channels + 64 skip channels
109 | num_classes,
110 | kernel_size=1,
111 | stride=1,
112 | padding=0,
113 | bias=True)
114 | self.conv2d3 = nn.Conv2d(
115 | 36,  # num_classes + 16 skip channels (assumes num_classes == 20)
116 | num_classes,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0,
120 | bias=True)
121 | self.conv2d4 = nn.Conv2d(
122 | 16,
123 | num_classes,
124 | kernel_size=1,
125 | stride=1,
126 | padding=0,
127 | bias=False)
128 | self.conv2d5 = nn.Conv2d(
129 | 64,
130 | num_classes,
131 | kernel_size=1,
132 | stride=1,
133 | padding=0,
134 | bias=False)
135 |
136 | self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)
137 | self.main_unpool2 = nn.MaxUnpool2d(kernel_size=2)
138 |
139 |
140 | def forward(self, input, predict=False):
141 | output, max_indices0_0, d, d_d, dd = self.initial_block(input)  # dd: pre-BN concatenated features at 1/2 scale
142 | output,y = self.l0d1(output)
143 |
144 |
145 |
146 |
147 | output, max_indices1_0,d1,d1_d1, ddd = self.down0_25(output)
148 | #print(d1.shape,'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx')
149 | d2 = self.main_unpool1(d1, max_indices1_0)  # unpool the 1/4-scale pooled features back to 1/2 scale
150 | d_1=d2-dd  # residual skip feature at 1/2 scale (16 channels)
151 |
152 | output,y = self.l1d1(output)
153 | output,y = self.l1d2(output)
154 | output,y = self.l1d3(output)
155 | output,y = self.l1d4(output)
156 |
157 | cc_2=self.conv2d4(d_1)  # 1x1 classifier on the 1/2-scale residual
158 |
159 |
160 | output, max_indices2_0,d3,d3_d3, dddd = self.down0_125(output)
161 | d4 = self.main_unpool2(d3, max_indices2_0)  # unpool the 1/8-scale pooled features back to 1/4 scale
162 | d_2=d4-d1_d1  # residual skip feature at 1/4 scale (64 channels)
163 | cc_4=self.conv2d5(d_2)  # 1x1 classifier on the 1/4-scale residual
164 |
165 | output,y = self.l2d1(output)
166 | output,y = self.l2d2(output)
167 | output,y = self.l2d3(output)
168 | output,y = self.l2d4(output)
169 | output,y = self.l3d1(output)
170 | output,y = self.l3d2(output)
171 | output,y = self.l3d3(output)
172 | output,y = self.l3d4(output)
173 | x1_81 = output  # 1/8-scale encoder features (128 channels)
174 | x1_8 = self.conv2d1(output)  # 1/8-scale class prediction
175 |
176 | x1_8_2 = torch.nn.functional.interpolate(x1_81, scale_factor=2, mode='bilinear')  # upsample the 1/8-scale features (x1_81, 128 channels) to 1/4 scale
177 |
178 | out4 = torch.cat((x1_8_2,d_2),1)
179 | x1_41 = self.conv2d2(out4)
180 | x1_4=x1_41+cc_4
181 |
182 | x1_4_2 = torch.nn.functional.interpolate(x1_4, scale_factor=2, mode='bilinear')
183 | out2 = torch.cat((x1_4_2, d_1), 1)
184 | x1_21 = self.conv2d3(out2)
185 | x1_2=x1_21+cc_2
186 |
187 | x1_1 = torch.nn.functional.interpolate(x1_2, scale_factor=2, mode='bilinear')
188 |
189 | return x1_1, x1_2, x1_4, x1_8  # predictions at full, 1/2, 1/4 and 1/8 resolution
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------
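For orientation, here is a minimal usage sketch of the model defined above. It is not part of the repository; it assumes the file is importable as models/rpnet.py from the repository root and that num_classes is 20, since the 1x1 decoder convolutions above hardcode 36 = 20 + 16 input channels.

import torch
from models.rpnet import RPNet   # the module shown above

model = RPNet(num_classes=20).eval()
dummy = torch.randn(1, 3, 512, 1024)            # N, C, H, W (Cityscapes-sized)

with torch.no_grad():
    x1_1, x1_2, x1_4, x1_8 = model(dummy)

# Predictions at full, 1/2, 1/4 and 1/8 resolution, each with 20 class channels:
# (1, 20, 512, 1024), (1, 20, 256, 512), (1, 20, 128, 256), (1, 20, 64, 128)
print(x1_1.shape, x1_2.shape, x1_4.shape, x1_8.shape)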
/pictures/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/pictures/1.png
--------------------------------------------------------------------------------
/pictures/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/pictures/2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | kiwisolver==1.0.1
3 | matplotlib==2.2.2
4 | numpy==1.14.2
5 | Pillow==5.0.0
6 | pyparsing==2.2.0
7 | python-dateutil==2.7.0
8 | pytz==2018.3
9 | PyYAML==4.2
10 | six==1.11.0
11 | torch==0.3.1
12 | torchvision==0.2.0
13 |
--------------------------------------------------------------------------------
/save/RPNet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/superlxt/RPNet-Pytorch/1cffe785aa54d98625464dff30e3940bb4b98d28/save/RPNet
--------------------------------------------------------------------------------
/save/RPNet_summary.txt:
--------------------------------------------------------------------------------
1 | ARGUMENTS
2 | batch_size: 3
3 | cuda: store_false
4 | dataset: cityscapes
5 | dataset_dir: data/Cityscapes
6 | epochs: 300
7 | height: 512
8 | ignore_unlabeled: True
9 | imshow_batch: False
10 | learning_rate: 0.0005
11 | lr_decay: 0.1
12 | lr_decay_epochs: 100
13 | mode: train
14 | name: RPNet
15 | print_step: False
16 | resume: True
17 | save_dir: save
18 | weighing: mfb
19 | weight_decay: 0.0001
20 | width: 1024
21 | workers: 4
22 |
23 | BEST VALIDATION
24 | Epoch: 300
25 | Mean IoU: 0.703690656206444
26 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import torch
3 |
4 |
5 | class Test():
6 | """Tests the ``model`` on the specified test dataset using the
7 | data loader and loss criterion.
8 |
9 | Keyword arguments:
10 | - model (``nn.Module``): the model instance to test.
11 | - data_loader (``Dataloader``): Provides single or multi-process
12 | iterators over the dataset.
13 | - criterion (``nn.Module``): The loss criterion.
14 | - metric (``Metric``): An instance specifying the metric to return.
15 | - use_cuda (``bool``): If ``True``, evaluation is performed using
16 | CUDA operations (GPU).
17 | - step (``int``): The pyramid level (1-4) whose output and loss are evaluated.
18 | """
19 |
20 | def __init__(self, model, data_loader, criterion, metric, use_cuda, step):
21 | self.model = model
22 | self.data_loader = data_loader
23 | self.criterion = criterion
24 | self.metric = metric
25 | self.use_cuda = use_cuda
26 | self.step = step
27 |
28 | def run_epoch(self, iteration_loss=False):
29 | """Runs an epoch of validation.
30 |
31 | Keyword arguments:
32 | - iteration_loss (``bool``, optional): Prints loss at every step.
33 |
34 | Returns:
35 | - The epoch loss (float), and the values of the specified metrics
36 |
37 | """
38 | epoch_loss = 0.0
39 | self.metric.reset()
40 | for step, batch_data in enumerate(self.data_loader):
41 | # Get the inputs and labels
42 | inputs, labels = batch_data
43 |
44 | # Wrap them in a Variable
45 | inputs, labels = Variable(inputs), Variable(labels)
46 | if self.use_cuda:
47 | inputs = inputs.cuda()
48 | labels = labels.cuda()
49 |
50 | labels4 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.125, mode='nearest').squeeze(0).long()
51 | labels3 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.25, mode='nearest').squeeze(0).long()
52 | labels2 = torch.nn.functional.interpolate(labels.unsqueeze(0).float(), scale_factor=0.5, mode='nearest').squeeze(0).long()
53 | labels1 = labels
54 |
55 | # Forward propagation
56 | outputs1, outputs2, outputs3, outputs4 = self.model(inputs)
57 |
58 | # Loss computation at the pyramid level selected by self.step
59 | loss = self.criterion(eval('outputs{}'.format(self.step)), eval('labels{}'.format(self.step)))
60 |
61 | # Keep track of loss for current epoch
62 | epoch_loss += loss.item()
63 |
64 | # Keep track of the evaluation metric
65 | self.metric.add(eval('outputs{}'.format(self.step)).data, eval('labels{}'.format(self.step)).data)
66 |
67 | if iteration_loss:
68 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))
69 |
70 | return epoch_loss / len(self.data_loader), self.metric.value()
71 |
--------------------------------------------------------------------------------
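A side note on run_epoch above: the eval('outputs{}'.format(self.step)) selection works but is brittle. The tiny self-contained snippet below (the string names are hypothetical stand-ins, not repository code) shows the equivalent tuple-indexing pattern.

# Hypothetical stand-ins for the four per-scale tensors used in run_epoch
step = 3
outputs = ("outputs1", "outputs2", "outputs3", "outputs4")
labels = ("labels1", "labels2", "labels3", "labels4")

# Equivalent to eval('outputs{}'.format(step)) / eval('labels{}'.format(step))
selected_output = outputs[step - 1]
selected_label = labels[step - 1]
assert selected_output == "outputs3" and selected_label == "labels3"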
/train.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import torch.nn as nn
3 | import torch
4 | import random
5 | import cv2
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import matplotlib.image as mpimg
9 | import utils
10 |
11 |
12 |
13 | class Train():
14 | """Performs the training of ``model`` given a training dataset data
15 | loader, the optimizer, and the loss criterion.
16 |
17 | Keyword arguments:
18 | - model (``nn.Module``): the model instance to train.
19 | - data_loader (``Dataloader``): Provides single or multi-process
20 | iterators over the dataset.
21 | - optim (``Optimizer``): The optimization algorithm.
22 | - criterion (``nn.Module``): The loss criterion.
23 | - metric (``Metric``): An instance specifying the metric to return.
24 | - use_cuda (``bool``): If ``True``, the training is performed using
25 | CUDA operations (GPU).
26 | - step (``int``): The pyramid level (1-4) whose loss is backpropagated.
27 | """
28 |
29 | def __init__(self, model, data_loader, optim, criterion, metric, use_cuda, step):
30 | self.model = model
31 | self.data_loader = data_loader
32 | self.optim = optim
33 | self.criterion = criterion
34 | self.metric = metric
35 | self.use_cuda = use_cuda
36 | self.step = step
37 |
38 | def run_epoch(self, iteration_loss=False):
39 | """Runs an epoch of training.
40 |
41 | Keyword arguments:
42 | - iteration_loss (``bool``, optional): Prints loss at every step.
43 |
44 | Returns:
45 | - The epoch loss (float), and the values of the specified metric.
46 |
47 | """
48 | epoch_loss = 0.0
49 | self.metric.reset()
50 | for step, batch_data in enumerate(self.data_loader):
51 | # Get the inputs and labels
52 |
53 | inputs, labels = batch_data
54 |
55 | # Use augmentation
56 | inputs, labels = utils.dataaug(inputs, labels)
57 |
58 | # Wrap them in a Variable
59 | inputs, labels = Variable(inputs), Variable(labels.float())
60 |
61 |
62 | if self.use_cuda:
63 | inputs = inputs.cuda()
64 | labels = labels.cuda()
65 | labels=labels.unsqueeze(1)
66 |
67 | labels2 = torch.nn.functional.interpolate(labels, scale_factor=0.5, mode='nearest').squeeze(1)
68 | labels3 = torch.nn.functional.interpolate(labels, scale_factor=0.25, mode='nearest').squeeze(1)
69 | labels4 = torch.nn.functional.interpolate(labels, scale_factor=0.125, mode='nearest').squeeze(1)
70 | labels1 = labels.squeeze(1)
71 |
72 |
73 | # Forward propagation
74 | outputs1, outputs2, outputs3, outputs4 = self.model(inputs)
75 |
76 | # Loss computation
77 | loss1 = self.criterion(outputs1, labels1.long())
78 | loss2 = self.criterion(outputs2, labels2.long())
79 | loss3 = self.criterion(outputs3, labels3.long())
80 | loss4 = self.criterion(outputs4, labels4.long())
81 |
82 | # Select the loss for the pyramid level being trained (self.step)
83 | loss= eval('loss{}'.format(self.step))
84 |
85 | # Backpropagation
86 | self.optim.zero_grad()
87 | loss.backward()
88 | self.optim.step()
89 |
90 | # Keep track of loss for current epoch
91 | epoch_loss += loss.item()
92 |
93 | # Metric
94 | self.metric.add(eval('outputs{}'.format(self.step)).data, eval('labels{}'.format(self.step)).data)
95 |
96 | if iteration_loss:
97 | print("[Step: %d] Iteration loss: %.4f" % (step, loss.item()))
98 |
99 | return epoch_loss / len(self.data_loader), self.metric.value()
100 |
--------------------------------------------------------------------------------
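To show how the Train helper above is typically wired up, here is a self-contained sketch with dummy data and a no-op metric. It is an illustration only: the real wiring (Cityscapes dataset, IoU metric, class weighting) lives in main.py, and every name below that is not defined in the repository is a placeholder.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from models.rpnet import RPNet
from train import Train

class NoOpMetric:
    """Placeholder for the repository's IoU metric, just to keep this sketch self-contained."""
    def reset(self): pass
    def add(self, predictions, targets): pass
    def value(self): return None

num_classes = 20                                  # the decoder convs above assume 20 classes
images = torch.randn(4, 3, 256, 512)              # dummy batch of RGB images
labels = torch.randint(0, num_classes, (4, 256, 512))
loader = DataLoader(TensorDataset(images, labels), batch_size=2)

model = RPNet(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

# step=4 trains the coarsest (1/8 scale) head first, as in train.sh
trainer = Train(model, loader, optimizer, criterion, NoOpMetric(),
                use_cuda=torch.cuda.is_available(), step=4)
epoch_loss, _ = trainer.run_epoch(iteration_loss=True)
print("epoch loss:", epoch_loss)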
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #level4 training (1/8 scale)
4 | python main.py --step=4
5 |
6 | #level3 training (1/4 scale)
7 | python main.py --step=3 --resume
8 |
9 | #level2 training (1/2 scale)
10 | python main.py --step=2 --resume
11 |
12 | #level1 training (original scale)
13 | python main.py --step=1 --resume
14 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from PIL import Image
4 | from collections import OrderedDict
5 | from torchvision.transforms import ToPILImage
6 |
7 |
8 | class PILToLongTensor(object):
9 | """Converts a ``PIL Image`` to a ``torch.LongTensor``.
10 |
11 | Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor
12 |
13 | """
14 |
15 | def __call__(self, pic):
16 | """Performs the conversion from a ``PIL Image`` to a ``torch.LongTensor``.
17 |
18 | Keyword arguments:
19 | - pic (``PIL.Image``): the image to convert to ``torch.LongTensor``
20 |
21 | Returns:
22 | A ``torch.LongTensor``.
23 |
24 | """
25 | # Handle numpy arrays (H, W, C) first, otherwise the PIL type check
26 | # below would make this branch unreachable
27 | if isinstance(pic, np.ndarray):
28 | img = torch.from_numpy(pic.transpose((2, 0, 1)))
29 | return img.long()
30 |
31 | if not isinstance(pic, Image.Image):
32 | raise TypeError("pic should be PIL Image. Got {}".format(
33 | type(pic)))
34 |
35 | # Convert PIL image to ByteTensor
36 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
37 |
38 | # Reshape tensor
39 | nchannel = len(pic.mode)
40 | img = img.view(pic.size[1], pic.size[0], nchannel)
41 |
42 | # Convert to long and squeeze the channels
43 | return img.transpose(0, 1).transpose(0,
44 | 2).contiguous().long().squeeze_()
45 |
46 |
47 | class LongTensorToRGBPIL(object):
48 | """Converts a ``torch.LongTensor`` to a ``PIL image``.
49 |
50 | The input is a ``torch.LongTensor`` where each pixel's value identifies the
51 | class.
52 |
53 | Keyword arguments:
54 | - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel
55 | values, class names, and class colors.
56 |
57 | """
58 | def __init__(self, rgb_encoding):
59 | self.rgb_encoding = rgb_encoding
60 |
61 | def __call__(self, tensor):
62 | """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``
63 |
64 | Keyword arguments:
65 | - tensor (``torch.LongTensor``): the tensor to convert
66 |
67 | Returns:
68 | A ``PIL.Image``.
69 |
70 | """
71 | # Check if label_tensor is a LongTensor
72 | if not isinstance(tensor, torch.LongTensor):
73 | raise TypeError("label_tensor should be torch.LongTensor. Got {}"
74 | .format(type(tensor)))
75 | # Check if encoding is a ordered dictionary
76 | if not isinstance(self.rgb_encoding, OrderedDict):
77 | raise TypeError("encoding should be an OrderedDict. Got {}".format(
78 | type(self.rgb_encoding)))
79 |
80 | # label_tensor might be an image without a channel dimension, in this
81 | # case unsqueeze it
82 | if len(tensor.size()) == 2:
83 | tensor.unsqueeze_(0)
84 |
85 | color_tensor = torch.ByteTensor(3, tensor.size(1), tensor.size(2))
86 |
87 | for index, (class_name, color) in enumerate(self.rgb_encoding.items()):
88 | # Get a mask of elements equal to index
89 | mask = torch.eq(tensor, index).squeeze_()
90 | # Fill color_tensor with corresponding colors
91 | for channel, color_value in enumerate(color):
92 | color_tensor[channel].masked_fill_(mask, color_value)
93 |
94 | return ToPILImage()(color_tensor)
95 |
--------------------------------------------------------------------------------
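A small illustration of the LongTensorToRGBPIL transform above. The three-class color encoding here is made up for the example; it is not the Cityscapes palette defined in data/cityscapes.py.

from collections import OrderedDict
import torch
from transforms import LongTensorToRGBPIL

# Hypothetical 3-class encoding: class name -> RGB color, ordered by pixel value
encoding = OrderedDict([
    ('road',   (128, 64, 128)),
    ('car',    (0, 0, 142)),
    ('person', (220, 20, 60)),
])

# A tiny 4x4 label map whose values index into the encoding above
label_map = torch.LongTensor([[0, 0, 1, 1],
                              [0, 2, 1, 1],
                              [2, 2, 0, 0],
                              [1, 0, 0, 2]])

rgb_image = LongTensorToRGBPIL(encoding)(label_map)
print(rgb_image.mode, rgb_image.size)   # 'RGB', (4, 4)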
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch as F  # alias kept for the F.unbind / F.stack calls below (torch functions, not torch.nn.functional)
3 | import torchvision
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import os
7 | import random
8 | import cv2
9 |
10 | def batch_transform(batch, transform):
11 | """Applies a transform to a batch of samples.
12 |
13 | Keyword arguments:
14 | - batch (``Tensor``): a batch of samples
15 | - transform (callable): A function/transform to apply to ``batch``
16 |
17 | """
18 |
19 | # Convert the single channel label to RGB in tensor form
20 | # 1. F.unbind removes the 0-dimension of "labels" and returns a tuple of
21 | # all slices along that dimension
22 | # 2. the transform is applied to each slice
23 | transf_slices = [transform(tensor) for tensor in F.unbind(batch)]
24 |
25 | return F.stack(transf_slices)
26 |
27 |
28 | def imshow_batch(images, labels):
29 | """Displays two grids of images. The top grid displays ``images``
30 | and the bottom grid ``labels``
31 |
32 | Keyword arguments:
33 | - images (``Tensor``): a 4D mini-batch tensor of shape
34 | (B, C, H, W)
35 | - labels (``Tensor``): a 4D mini-batch tensor of shape
36 | (B, C, H, W)
37 |
38 | """
39 |
40 | # Make a grid with the images and labels and convert it to numpy
41 | images = torchvision.utils.make_grid(images).numpy()
42 | labels = torchvision.utils.make_grid(labels).numpy()
43 |
44 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7))
45 | ax1.imshow(np.transpose(images, (1, 2, 0)))
46 | ax2.imshow(np.transpose(labels, (1, 2, 0)))
47 |
48 | plt.show()
49 |
50 |
51 | def save_checkpoint(model, optimizer, epoch, miou, args):
52 | """Saves the model in a specified directory with a specified name.save
53 |
54 | Keyword arguments:
55 | - model (``nn.Module``): The model to save.
56 | - optimizer (``torch.optim``): The optimizer state to save.
57 | - epoch (``int``): The current epoch for the model.
58 | - miou (``float``): The mean IoU obtained by the model.
59 | - args (``ArgumentParser``): An instance of ArgumentParser which contains
60 | the arguments used to train ``model``. The arguments are written to a text
61 | file in ``args.save_dir`` named "``args.name``_args.txt".
62 |
63 | """
64 | name = args.name
65 | save_dir = args.save_dir
66 |
67 | assert os.path.isdir(
68 | save_dir), "The directory \"{0}\" doesn't exist.".format(save_dir)
69 |
70 | # Save model
71 | model_path = os.path.join(save_dir, name)
72 | checkpoint = {
73 | 'epoch': epoch,
74 | 'miou': miou,
75 | 'state_dict': model.state_dict(),
76 | 'optimizer': optimizer.state_dict()
77 | }
78 | torch.save(checkpoint, model_path)
79 |
80 | # Save arguments
81 | summary_filename = os.path.join(save_dir, name + '_summary.txt')
82 | with open(summary_filename, 'w') as summary_file:
83 | sorted_args = sorted(vars(args))
84 | summary_file.write("ARGUMENTS\n")
85 | for arg in sorted_args:
86 | arg_str = "{0}: {1}\n".format(arg, getattr(args, arg))
87 | summary_file.write(arg_str)
88 |
89 | summary_file.write("\nBEST VALIDATION\n")
90 | summary_file.write("Epoch: {0}\n". format(epoch))
91 | summary_file.write("Mean IoU: {0}\n". format(miou))
92 |
93 |
94 | def load_checkpoint(model, optimizer, folder_dir, filename):
95 | """Saves the model in a specified directory with a specified name.save
96 |
97 | Keyword arguments:
98 | - model (``nn.Module``): The stored model state is copied to this model
99 | instance.
100 | - optimizer (``torch.optim``): The stored optimizer state is copied to this
101 | optimizer instance.
102 | - folder_dir (``string``): The path to the folder where the saved model
103 | state is located.
104 | - filename (``string``): The model filename.
105 |
106 | Returns:
107 | The ``model`` and ``optimizer`` with the stored state loaded, plus the
108 | epoch and mean IoU counters, which are reset to 0 for the next training stage.
109 |
110 | """
111 | assert os.path.isdir(
112 | folder_dir), "The directory \"{0}\" doesn't exist.".format(folder_dir)
113 |
114 | # Build the path to the saved model file
115 | model_path = os.path.join(folder_dir, filename)
116 | assert os.path.isfile(
117 | model_path), "The model file \"{0}\" doesn't exist.".format(filename)
118 |
119 | # Load the stored model parameters to the model instance
120 | checkpoint = torch.load(model_path)
121 | model.load_state_dict(checkpoint['state_dict'])
122 | optimizer.load_state_dict(checkpoint['optimizer'])
123 | epoch = 0  # reset (instead of reading checkpoint['epoch']) so a resumed stage starts from scratch
124 | miou = 0   # likewise start the best-IoU baseline at zero (see the staged training in train.sh)
125 |
126 | return model, optimizer, epoch, miou
127 |
128 |
129 |
130 | def dataaug(inputs, labels):
131 | """use flip and warpAffine
132 | """
133 | for i in range(inputs.shape[0]):
134 | inputs1 = inputs[i, :, :, :]
135 | labels1 = labels[i, :, :]
136 | inputs1 = inputs1.numpy()
137 | labels1 = labels1.numpy()
138 | inputs1 = inputs1.transpose((1, 2, 0))
139 |
140 | x = random.randint(-2, 2)
141 | y = random.randint(-2, 2)
142 | H = np.float32([[1,0,x],[0,1,y]])
143 | inputs1 = cv2.warpAffine(inputs1,H,(inputs1.shape[1],inputs1.shape[0]))
144 | labels1 = cv2.warpAffine(labels1.astype(np.float32),H,(inputs1.shape[1],inputs1.shape[0])).astype(np.uint8)
145 |
146 |
147 | if random.random() < 0.5:
148 | inputs1 = cv2.flip(inputs1, 1) # horizontal flip
149 | labels1 = cv2.flip(labels1, 1) # horizontal flip
150 |
151 | inputs1 = inputs1.transpose((2, 0, 1))
152 | inputs1 = torch.from_numpy(inputs1)
153 | labels1 = torch.from_numpy(labels1)
154 | # write the augmented sample back into the batch tensors
155 | inputs[i, : , : ,:]=inputs1
156 | labels[i, : ,:]=labels1
157 | return inputs, labels
158 |
--------------------------------------------------------------------------------
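Finally, a short sketch of how batch_transform and imshow_batch above can be combined with the LongTensorToRGBPIL transform from transforms.py. The encoding is again a made-up three-class example; main.py is where the real composition with the dataset's color encoding happens.

from collections import OrderedDict
import torch
from torchvision.transforms import Compose, ToTensor

import utils
from transforms import LongTensorToRGBPIL

# Hypothetical 3-class encoding for the example
encoding = OrderedDict([
    ('road',   (128, 64, 128)),
    ('car',    (0, 0, 142)),
    ('person', (220, 20, 60)),
])

# A dummy batch: 2 RGB images and 2 label maps with class indices 0..2
images = torch.rand(2, 3, 64, 128)
labels = torch.randint(0, 3, (2, 64, 128)).long()

# Colorize every label map in the batch: LongTensor -> PIL (RGB) -> float tensor
color_transform = Compose([LongTensorToRGBPIL(encoding), ToTensor()])
color_labels = utils.batch_transform(labels, color_transform)   # (2, 3, 64, 128)

# Top grid: images, bottom grid: colorized labels
utils.imshow_batch(images, color_labels)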