├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── ade20k-hrnetv2.yaml
│   ├── ade20k-mobilenetv2dilated-c1_deepsup.yaml
│   ├── ade20k-resnet101-upernet.yaml
│   ├── ade20k-resnet101dilated-ppm_deepsup.yaml
│   ├── ade20k-resnet18dilated-ppm_deepsup.yaml
│   ├── ade20k-resnet50-upernet.yaml
│   └── ade20k-resnet50dilated-ppm_deepsup.yaml
├── data
│   ├── color150.mat
│   ├── object150_info.csv
│   ├── training.odgt
│   └── validation.odgt
├── demo_test.sh
├── download_ADE20K.sh
├── eval.py
├── eval_multipro.py
├── mit_semseg
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   └── defaults.py
│   ├── dataset.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── batchnorm.py
│   │   │   │   ├── comm.py
│   │   │   │   ├── replicate.py
│   │   │   │   ├── tests
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── test_numeric_batchnorm.py
│   │   │   │   │   └── test_sync_batchnorm.py
│   │   │   │   └── unittest.py
│   │   │   └── parallel
│   │   │       ├── __init__.py
│   │   │       └── data_parallel.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── data
│   │       │   ├── __init__.py
│   │       │   ├── dataloader.py
│   │       │   ├── dataset.py
│   │       │   ├── distributed.py
│   │       │   └── sampler.py
│   │       └── th.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── hrnet.py
│   │   ├── mobilenet.py
│   │   ├── models.py
│   │   ├── resnet.py
│   │   ├── resnext.py
│   │   └── utils.py
│   └── utils.py
├── notebooks
│   ├── DemoSegmenter.ipynb
│   ├── README.md
│   ├── ckpt
│   ├── config
│   ├── data
│   ├── ipynb_drop_output.py
│   ├── mit_semseg
│   ├── setup_notebooks.sh
│   └── teaser
├── requirements.txt
├── setup.py
├── teaser
│   ├── ADE_val_00000278.png
│   └── ADE_val_00001519.png
├── test.py
└── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb filter=clean_ipynb
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 | ckpt/
4 | vis/
5 | log/
6 | pretrained/
7 |
8 | .ipynb_checkpoints
9 |
10 | ADE_val*.jpg
11 | ADE_val*.png
12 |
13 | .idea/
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2019, MIT CSAIL Computer Vision
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Segmentation on MIT ADE20K dataset in PyTorch
2 |
3 | This is a PyTorch implementation of semantic segmentation models on MIT ADE20K scene parsing dataset (http://sceneparsing.csail.mit.edu/).
4 |
5 | ADE20K is the largest open source dataset for semantic segmentation and scene parsing, released by the MIT Computer Vision team. Follow the link below to find the repository for our dataset and implementations on Caffe and Torch7:
6 | https://github.com/CSAILVision/sceneparsing
7 |
8 | If you simply want to play with our demo, please try this link (you can upload your own photo and parse it!): http://scenesegmentation.csail.mit.edu
9 |
10 | [You can also use this colab notebook playground here](https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb) to tinker with the code for segmenting an image.
11 |
12 | All pretrained models can be found at:
13 | http://sceneparsing.csail.mit.edu/model/pytorch
14 |
15 |
16 |
17 | [From left to right: Test Image, Ground Truth, Predicted Result]
18 |
19 | Color encoding of semantic categories can be found here:
20 | https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing
21 |
22 | ## Updates
23 | - HRNet model is now supported.
24 | - We use configuration files to store most options, which were previously in the argument parser. The definitions of the options are detailed in ```config/defaults.py```.
25 | - We conform to PyTorch practice in data preprocessing (RGB [0, 1], subtract mean, divide by std); a minimal sketch of this convention is shown below.
26 |
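A minimal sketch of that preprocessing convention, using the same mean/std as ```mit_semseg/dataset.py``` (the file name below is just an example):

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# PyTorch-style preprocessing: RGB scaled to [0, 1], then subtract mean, divide by std.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

img = Image.open('ADE_val_00001519.jpg').convert('RGB')   # example file name
arr = np.float32(np.array(img)) / 255.                     # HxWx3, RGB in [0, 1]
tensor = torch.from_numpy(arr.transpose(2, 0, 1).copy())   # 3xHxW float tensor
tensor = normalize(tensor)
```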
27 |
28 | ## Highlights
29 |
30 | ### Synchronized Batch Normalization on PyTorch
31 | This module computes the mean and standard-deviation across all devices during training. We empirically find that a reasonably large batch size is important for segmentation. We thank [Jiayuan Mao](http://vccy.xyz/) for his kind contributions; please refer to [Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) for details. A brief usage sketch appears after the list below.
32 |
33 | The implementation is easy to use:
34 | - It is pure Python, with no extra C++ extension libs.
35 | - It is completely compatible with PyTorch's implementation. Specifically, it uses unbiased variance to update the moving average, and uses sqrt(max(var, eps)) instead of sqrt(var + eps).
36 | - It is efficient, only 20% to 30% slower than UnsyncBN.
37 |
38 | ### Dynamic scales of input for training with multiple GPUs
39 | For the task of semantic segmentation, it is good to keep the aspect ratio of images during training. So we re-implement the `DataParallel` module to support distributing data to multiple GPUs as Python dicts, so that each GPU can process images of different sizes. At the same time, the dataloader also operates differently.
40 |
41 | *Now the batch size of a dataloader always equals the number of GPUs*; each element of the batch is sent to one GPU. It is also compatible with multi-processing. Note that the file index for the multi-processing dataloader is stored on the master process, which contradicts our goal that each worker maintains its own file list. So we use a trick: although the master process still gives the dataloader an index for the `__getitem__` function, we simply ignore that request and send a random batch dict. Also, *the multiple workers forked by the dataloader all have the same seed*, so if we used the above-mentioned trick directly, the workers would yield exactly the same data. Therefore, we add one line of code which sets the default seed for `numpy.random` before spawning the multiple workers in the dataloader. A minimal sketch of how the loader is wired is shown below.
42 |
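As a rough illustration of how the pieces fit together (a sketch, not a verbatim copy of ```train.py```; the ```gpus``` list is a hypothetical example): the loader's batch size is the number of GPUs, and ```user_scattered_collate``` keeps each GPU's sub-batch as a separate dict.

```python
import torch
from mit_semseg.config import cfg
from mit_semseg.dataset import TrainDataset
from mit_semseg.lib.nn import user_scattered_collate

gpus = [0, 1, 2, 3]  # hypothetical GPU ids

dataset_train = TrainDataset(
    cfg.DATASET.root_dataset, cfg.DATASET.list_train, cfg.DATASET,
    batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

# One element per GPU: each element is a dict holding one GPU's sub-batch,
# and user_scattered_collate keeps those dicts separate instead of stacking tensors.
loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=len(gpus),
    shuffle=False,                      # TrainDataset shuffles its own file list
    collate_fn=user_scattered_collate,
    num_workers=cfg.TRAIN.workers,
    drop_last=True,
    pin_memory=True)
# Workers diverge because TrainDataset reseeds numpy.random with the incoming
# index on its first __getitem__ call (see mit_semseg/dataset.py).
```
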
43 | ### State-of-the-Art models
44 | - **PSPNet** is a scene parsing network that aggregates global representation with the Pyramid Pooling Module (PPM). It is the winner model of the ILSVRC'16 MIT Scene Parsing Challenge. Please refer to [https://arxiv.org/abs/1612.01105](https://arxiv.org/abs/1612.01105) for details.
45 | - **UPerNet** is a model based on the Feature Pyramid Network (FPN) and the Pyramid Pooling Module (PPM). It doesn't need dilated convolution, an operator that is time- and memory-consuming. *Without bells and whistles*, it is comparable to or even better than PSPNet, while requiring much shorter training time and less GPU memory. Please refer to [https://arxiv.org/abs/1807.10221](https://arxiv.org/abs/1807.10221) for details.
46 | - **HRNet** is a recently proposed model that retains high-resolution representations throughout the network, without the traditional bottleneck design. It achieves SOTA performance on a series of pixel labeling tasks. Please refer to [https://arxiv.org/abs/1904.04514](https://arxiv.org/abs/1904.04514) for details.
47 |
48 |
49 | ## Supported models
50 | We split our models into an encoder and a decoder, where encoders are usually modified directly from classification networks, and decoders consist of final convolutions and upsampling. We have provided some pre-configured models in the ```config``` folder; a sketch of how an encoder/decoder pair is assembled follows the lists below.
51 |
52 | Encoder:
53 | - MobileNetV2dilated
54 | - ResNet18/ResNet18dilated
55 | - ResNet50/ResNet50dilated
56 | - ResNet101/ResNet101dilated
57 | - HRNetV2 (W48)
58 |
59 | Decoder:
60 | - C1 (one convolution module)
61 | - C1_deepsup (C1 + deep supervision trick)
62 | - PPM (Pyramid Pooling Module, see [PSPNet](https://hszhao.github.io/projects/pspnet) paper for details.)
63 | - PPM_deepsup (PPM + deep supervision trick)
64 | - UPerNet (Pyramid Pooling + FPN head, see [UperNet](https://arxiv.org/abs/1807.10221) for details.)
65 |
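As a minimal sketch of how an encoder/decoder pair is assembled (following the ```ModelBuilder``` calls used in ```eval.py```; the architecture names and dimensions below match ```config/ade20k-resnet50dilated-ppm_deepsup.yaml```):

```python
import torch.nn as nn
from mit_semseg.models import ModelBuilder, SegmentationModule

# Encoder/decoder pair matching config/ade20k-resnet50dilated-ppm_deepsup.yaml.
net_encoder = ModelBuilder.build_encoder(
    arch='resnet50dilated',
    fc_dim=2048,
    weights='')        # empty string: fall back to the auto-downloaded base weights
net_decoder = ModelBuilder.build_decoder(
    arch='ppm_deepsup',
    fc_dim=2048,
    num_class=150,
    weights='')

crit = nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
```

Training with a ```*_deepsup``` decoder additionally uses the deep supervision weight (```TRAIN.deep_sup_scale``` in the configs).
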
66 | ## Performance:
67 | IMPORTANT: The base ResNet in our repository is a customized version (different from the one in torchvision). The base models will be automatically downloaded when needed.
68 |
69 |
70 | | Architecture | MultiScale Testing | Mean IoU | Pixel Accuracy (%) | Overall Score | Inference Speed (fps) |
71 | | --- | --- | --- | --- | --- | --- |
72 | | MobileNetV2dilated + C1_deepsup | No | 34.84 | 75.75 | 54.07 | 17.2 |
73 | | MobileNetV2dilated + C1_deepsup | Yes | 33.84 | 76.80 | 55.32 | 10.3 |
74 | | MobileNetV2dilated + PPM_deepsup | No | 35.76 | 77.77 | 56.27 | 14.9 |
75 | | MobileNetV2dilated + PPM_deepsup | Yes | 36.28 | 78.26 | 57.27 | 6.7 |
76 | | ResNet18dilated + C1_deepsup | No | 33.82 | 76.05 | 54.94 | 13.9 |
77 | | ResNet18dilated + C1_deepsup | Yes | 35.34 | 77.41 | 56.38 | 5.8 |
78 | | ResNet18dilated + PPM_deepsup | No | 38.00 | 78.64 | 58.32 | 11.7 |
79 | | ResNet18dilated + PPM_deepsup | Yes | 38.81 | 79.29 | 59.05 | 4.2 |
80 | | ResNet50dilated + PPM_deepsup | No | 41.26 | 79.73 | 60.50 | 8.3 |
81 | | ResNet50dilated + PPM_deepsup | Yes | 42.14 | 80.13 | 61.14 | 2.6 |
82 | | ResNet101dilated + PPM_deepsup | No | 42.19 | 80.59 | 61.39 | 6.8 |
83 | | ResNet101dilated + PPM_deepsup | Yes | 42.53 | 80.91 | 61.72 | 2.0 |
84 | | UperNet50 | No | 40.44 | 79.80 | 60.12 | 8.4 |
85 | | UperNet50 | Yes | 41.55 | 80.23 | 60.89 | 2.9 |
86 | | UperNet101 | No | 42.00 | 80.79 | 61.40 | 7.8 |
87 | | UperNet101 | Yes | 42.66 | 81.01 | 61.84 | 2.3 |
88 | | HRNetV2 | No | 42.03 | 80.77 | 61.40 | 5.8 |
89 | | HRNetV2 | Yes | 43.20 | 81.47 | 62.34 | 1.9 |
90 |
160 | The training is benchmarked on a server with 8 NVIDIA Pascal Titan Xp GPUs (12GB GPU memory); the inference speed is benchmarked on a single NVIDIA Pascal Titan Xp GPU, without visualization.
161 |
162 | ## Environment
163 | The code is developed under the following configurations.
164 | - Hardware: >=4 GPUs for training, >=1 GPU for testing (set ```[--gpus GPUS]``` accordingly)
165 | - Software: Ubuntu 16.04.3 LTS, ***CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0***
166 | - Dependencies: numpy, scipy, opencv, yacs, tqdm
167 |
168 | ## Quick start: Test on an image using our trained model
169 | 1. Here is a simple demo to do inference on a single image:
170 | ```bash
171 | chmod +x demo_test.sh
172 | ./demo_test.sh
173 | ```
174 | This script downloads a trained model (ResNet50dilated + PPM_deepsup) and a test image, runs the test script, and saves the predicted segmentation (.png) to the working directory.
175 |
176 | 2. To test on an image or a folder of images (```$PATH_IMG```), you can simply do the following:
177 | ```bash
178 | python3 -u test.py --imgs $PATH_IMG --gpu $GPU --cfg $CFG
179 | ```
180 |
181 | ## Training
182 | 1. Download the ADE20K scene parsing dataset:
183 | ```bash
184 | chmod +x download_ADE20K.sh
185 | ./download_ADE20K.sh
186 | ```
187 | 2. Train a model by selecting the GPUs (```$GPUS```) and configuration file (```$CFG```) to use. During training, checkpoints by default are saved in folder ```ckpt```.
188 | ```bash
189 | python3 train.py --gpus $GPUS --cfg $CFG
190 | ```
191 | - To choose which GPUs to use, you can either do ```--gpus 0-7``` or ```--gpus 0,2,4,6```.
192 |
193 | For example, you can start with our provided configurations:
194 |
195 | * Train MobileNetV2dilated + C1_deepsup
196 | ```bash
197 | python3 train.py --gpus GPUS --cfg config/ade20k-mobilenetv2dilated-c1_deepsup.yaml
198 | ```
199 |
200 | * Train ResNet50dilated + PPM_deepsup
201 | ```bash
202 | python3 train.py --gpus GPUS --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml
203 | ```
204 |
205 | * Train UPerNet101
206 | ```bash
207 | python3 train.py --gpus GPUS --cfg config/ade20k-resnet101-upernet.yaml
208 | ```
209 |
210 | 3. You can also override options on the command line, for example ```python3 train.py TRAIN.num_epoch 10```.
211 |
212 |
213 | ## Evaluation
214 | 1. Evaluate a trained model on the validation set. Add ```VAL.visualize True``` to the arguments to output visualizations as shown in the teaser.
215 |
216 | For example:
217 |
218 | * Evaluate MobileNetV2dilated + C1_deepsup
219 | ```bash
220 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-mobilenetv2dilated-c1_deepsup.yaml
221 | ```
222 |
223 | * Evaluate ResNet50dilated + PPM_deepsup
224 | ```bash
225 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml
226 | ```
227 |
228 | * Evaluate UPerNet101
229 | ```bash
230 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-resnet101-upernet.yaml
231 | ```
232 |
233 | ## Integration with other projects
234 | This library can be installed via `pip` to easily integrate with another codebase
235 | ```bash
236 | pip install git+https://github.com/CSAILVision/semantic-segmentation-pytorch.git@master
237 | ```
238 |
239 | Now this library can easily be consumed programmatically. For example:
240 | ```python
241 | from mit_semseg.config import cfg
242 | from mit_semseg.dataset import TestDataset
243 | from mit_semseg.models import ModelBuilder, SegmentationModule
244 | ```
245 |
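A slightly fuller sketch, mirroring the wiring in ```eval.py```; the checkpoint paths below are placeholders for weights downloaded from the model zoo, and ```img_tensor``` is assumed to be an already-normalized 1 x 3 x H x W float tensor:

```python
import torch
import torch.nn as nn
from mit_semseg.config import cfg
from mit_semseg.models import ModelBuilder, SegmentationModule

# Load one of the provided configuration files.
cfg.merge_from_file('config/ade20k-resnet50dilated-ppm_deepsup.yaml')

# Placeholder checkpoint paths; point these at the downloaded weights.
net_encoder = ModelBuilder.build_encoder(
    arch=cfg.MODEL.arch_encoder,
    fc_dim=cfg.MODEL.fc_dim,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')
net_decoder = ModelBuilder.build_decoder(
    arch=cfg.MODEL.arch_decoder,
    fc_dim=cfg.MODEL.fc_dim,
    num_class=cfg.DATASET.num_class,
    weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',
    use_softmax=True)   # inference mode, as in eval.py

segmentation_module = SegmentationModule(
    net_encoder, net_decoder, nn.NLLLoss(ignore_index=-1)).eval()

# img_tensor: a normalized 1 x 3 x H x W image tensor (see the preprocessing note above).
height, width = img_tensor.shape[2], img_tensor.shape[3]
with torch.no_grad():
    scores = segmentation_module({'img_data': img_tensor}, segSize=(height, width))
pred = torch.argmax(scores, dim=1)   # 1 x H x W predicted class indices
```
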
246 | ## Reference
247 |
248 | If you find the code or pre-trained models useful, please cite the following papers:
249 |
250 | Semantic Understanding of Scenes through ADE20K Dataset. B. Zhou, H. Zhao, X. Puig, T. Xiao, S. Fidler, A. Barriuso and A. Torralba. International Journal of Computer Vision (IJCV), 2018. (https://arxiv.org/pdf/1608.05442.pdf)
251 |
252 | @article{zhou2018semantic,
253 | title={Semantic understanding of scenes through the ade20k dataset},
254 | author={Zhou, Bolei and Zhao, Hang and Puig, Xavier and Xiao, Tete and Fidler, Sanja and Barriuso, Adela and Torralba, Antonio},
255 | journal={International Journal of Computer Vision},
256 | year={2018}
257 | }
258 |
259 | Scene Parsing through ADE20K Dataset. B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso and A. Torralba. Computer Vision and Pattern Recognition (CVPR), 2017. (http://people.csail.mit.edu/bzhou/publication/scene-parse-camera-ready.pdf)
260 |
261 | @inproceedings{zhou2017scene,
262 | title={Scene Parsing through ADE20K Dataset},
263 | author={Zhou, Bolei and Zhao, Hang and Puig, Xavier and Fidler, Sanja and Barriuso, Adela and Torralba, Antonio},
264 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
265 | year={2017}
266 | }
267 |
268 |
--------------------------------------------------------------------------------
/config/ade20k-hrnetv2.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "hrnetv2"
14 | arch_decoder: "c1"
15 | fc_dim: 720
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 30
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_30.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_30.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-hrnetv2-c1"
43 |
--------------------------------------------------------------------------------
/config/ade20k-mobilenetv2dilated-c1_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "mobilenetv2dilated"
14 | arch_decoder: "c1_deepsup"
15 | fc_dim: 320
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 3
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup"
43 |
--------------------------------------------------------------------------------
/config/ade20k-resnet101-upernet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101"
14 | arch_decoder: "upernet"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 40
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_50.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_50.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet101-upernet"
43 |
--------------------------------------------------------------------------------
/config/ade20k-resnet101dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet101dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 25
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_25.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_25.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet101dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/ade20k-resnet18dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet18dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 512
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/config/ade20k-resnet50-upernet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 32
9 | segm_downsampling_rate: 4
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50"
14 | arch_decoder: "upernet"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 30
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_30.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_30.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50-upernet"
43 |
--------------------------------------------------------------------------------
/config/ade20k-resnet50dilated-ppm_deepsup.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | root_dataset: "./data/"
3 | list_train: "./data/training.odgt"
4 | list_val: "./data/validation.odgt"
5 | num_class: 150
6 | imgSizes: (300, 375, 450, 525, 600)
7 | imgMaxSize: 1000
8 | padding_constant: 8
9 | segm_downsampling_rate: 8
10 | random_flip: True
11 |
12 | MODEL:
13 | arch_encoder: "resnet50dilated"
14 | arch_decoder: "ppm_deepsup"
15 | fc_dim: 2048
16 |
17 | TRAIN:
18 | batch_size_per_gpu: 2
19 | num_epoch: 20
20 | start_epoch: 0
21 | epoch_iters: 5000
22 | optim: "SGD"
23 | lr_encoder: 0.02
24 | lr_decoder: 0.02
25 | lr_pow: 0.9
26 | beta1: 0.9
27 | weight_decay: 1e-4
28 | deep_sup_scale: 0.4
29 | fix_bn: False
30 | workers: 16
31 | disp_iter: 20
32 | seed: 304
33 |
34 | VAL:
35 | visualize: False
36 | checkpoint: "epoch_20.pth"
37 |
38 | TEST:
39 | checkpoint: "epoch_20.pth"
40 | result: "./"
41 |
42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"
43 |
--------------------------------------------------------------------------------
/data/color150.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/data/color150.mat
--------------------------------------------------------------------------------
/data/object150_info.csv:
--------------------------------------------------------------------------------
1 | Idx,Ratio,Train,Val,Stuff,Name
2 | 1,0.1576,11664,1172,1,wall
3 | 2,0.1072,6046,612,1,building;edifice
4 | 3,0.0878,8265,796,1,sky
5 | 4,0.0621,9336,917,1,floor;flooring
6 | 5,0.0480,6678,641,0,tree
7 | 6,0.0450,6604,643,1,ceiling
8 | 7,0.0398,4023,408,1,road;route
9 | 8,0.0231,1906,199,0,bed
10 | 9,0.0198,4688,460,0,windowpane;window
11 | 10,0.0183,2423,225,1,grass
12 | 11,0.0181,2874,294,0,cabinet
13 | 12,0.0166,3068,310,1,sidewalk;pavement
14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
15 | 14,0.0151,1804,190,1,earth;ground
16 | 15,0.0118,6666,796,0,door;double;door
17 | 16,0.0110,4269,411,0,table
18 | 17,0.0109,1691,160,1,mountain;mount
19 | 18,0.0104,3999,441,0,plant;flora;plant;life
20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
21 | 20,0.0103,3261,318,0,chair
22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
23 | 22,0.0074,709,75,1,water
24 | 23,0.0067,3296,315,0,painting;picture
25 | 24,0.0065,1191,106,0,sofa;couch;lounge
26 | 25,0.0061,1516,162,0,shelf
27 | 26,0.0060,667,69,1,house
28 | 27,0.0053,651,57,1,sea
29 | 28,0.0052,1847,224,0,mirror
30 | 29,0.0046,1158,128,1,rug;carpet;carpeting
31 | 30,0.0044,480,44,1,field
32 | 31,0.0044,1172,98,0,armchair
33 | 32,0.0044,1292,184,0,seat
34 | 33,0.0033,1386,138,0,fence;fencing
35 | 34,0.0031,698,61,0,desk
36 | 35,0.0030,781,73,0,rock;stone
37 | 36,0.0027,380,43,0,wardrobe;closet;press
38 | 37,0.0026,3089,302,0,lamp
39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
40 | 39,0.0024,804,99,0,railing;rail
41 | 40,0.0023,1453,153,0,cushion
42 | 41,0.0023,411,37,0,base;pedestal;stand
43 | 42,0.0022,1440,162,0,box
44 | 43,0.0022,800,77,0,column;pillar
45 | 44,0.0020,2650,298,0,signboard;sign
46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
47 | 46,0.0019,367,36,0,counter
48 | 47,0.0018,311,30,1,sand
49 | 48,0.0018,1181,122,0,sink
50 | 49,0.0018,287,23,1,skyscraper
51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace
52 | 51,0.0018,402,43,0,refrigerator;icebox
53 | 52,0.0018,130,12,1,grandstand;covered;stand
54 | 53,0.0018,561,64,1,path
55 | 54,0.0017,880,102,0,stairs;steps
56 | 55,0.0017,86,12,1,runway
57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine
58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
59 | 58,0.0017,930,109,0,pillow
60 | 59,0.0015,139,18,0,screen;door;screen
61 | 60,0.0015,564,52,1,stairway;staircase
62 | 61,0.0015,320,26,1,river
63 | 62,0.0015,261,29,1,bridge;span
64 | 63,0.0014,275,22,0,bookcase
65 | 64,0.0014,335,60,0,blind;screen
66 | 65,0.0014,792,75,0,coffee;table;cocktail;table
67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
68 | 67,0.0014,1309,138,0,flower
69 | 68,0.0013,1112,113,0,book
70 | 69,0.0013,266,27,1,hill
71 | 70,0.0013,659,66,0,bench
72 | 71,0.0012,331,31,0,countertop
73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
74 | 73,0.0012,369,36,0,palm;palm;tree
75 | 74,0.0012,144,9,0,kitchen;island
76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
77 | 76,0.0010,324,33,0,swivel;chair
78 | 77,0.0009,304,27,0,boat
79 | 78,0.0009,170,20,0,bar
80 | 79,0.0009,68,6,0,arcade;machine
81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
83 | 82,0.0008,492,49,0,towel
84 | 83,0.0008,2510,269,0,light;light;source
85 | 84,0.0008,440,39,0,truck;motortruck
86 | 85,0.0008,147,18,1,tower
87 | 86,0.0008,583,56,0,chandelier;pendant;pendent
88 | 87,0.0007,533,61,0,awning;sunshade;sunblind
89 | 88,0.0007,1989,239,0,streetlight;street;lamp
90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk
91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
92 | 91,0.0007,135,12,0,airplane;aeroplane;plane
93 | 92,0.0007,83,5,1,dirt;track
94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
95 | 94,0.0006,1003,104,0,pole
96 | 95,0.0006,182,12,1,land;ground;soil
97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
100 | 99,0.0006,965,114,0,bottle
101 | 100,0.0006,117,13,0,buffet;counter;sideboard
102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
103 | 102,0.0006,108,9,1,stage
104 | 103,0.0006,557,55,0,van
105 | 104,0.0006,52,4,0,ship
106 | 105,0.0005,99,5,0,fountain
107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
108 | 107,0.0005,292,31,0,canopy
109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine
110 | 109,0.0005,340,38,0,plaything;toy
111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
112 | 111,0.0005,465,49,0,stool
113 | 112,0.0005,50,4,0,barrel;cask
114 | 113,0.0005,622,75,0,basket;handbasket
115 | 114,0.0005,80,9,1,waterfall;falls
116 | 115,0.0005,59,3,0,tent;collapsible;shelter
117 | 116,0.0005,531,72,0,bag
118 | 117,0.0005,282,30,0,minibike;motorbike
119 | 118,0.0005,73,7,0,cradle
120 | 119,0.0005,435,44,0,oven
121 | 120,0.0005,136,25,0,ball
122 | 121,0.0005,116,24,0,food;solid;food
123 | 122,0.0004,266,31,0,step;stair
124 | 123,0.0004,58,12,0,tank;storage;tank
125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque
126 | 125,0.0004,319,43,0,microwave;microwave;oven
127 | 126,0.0004,1193,139,0,pot;flowerpot
128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle
130 | 129,0.0004,52,5,1,lake
131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen
133 | 132,0.0004,201,30,0,blanket;cover
134 | 133,0.0004,285,21,0,sculpture
135 | 134,0.0004,268,27,0,hood;exhaust;hood
136 | 135,0.0003,1020,108,0,sconce
137 | 136,0.0003,1282,122,0,vase
138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
139 | 138,0.0003,453,57,0,tray
140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
141 | 140,0.0003,397,44,0,fan
142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock
143 | 142,0.0003,228,18,0,crt;screen
144 | 143,0.0003,570,59,0,plate
145 | 144,0.0003,217,22,0,monitor;monitoring;device
146 | 145,0.0003,206,19,0,bulletin;board;notice;board
147 | 146,0.0003,130,14,0,shower
148 | 147,0.0003,178,28,0,radiator
149 | 148,0.0002,504,57,0,glass;drinking;glass
150 | 149,0.0002,775,96,0,clock
151 | 150,0.0002,421,56,0,flag
152 |
--------------------------------------------------------------------------------
/demo_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Image and model names
4 | TEST_IMG=ADE_val_00001519.jpg
5 | MODEL_NAME=ade20k-resnet50dilated-ppm_deepsup
6 | MODEL_PATH=ckpt/$MODEL_NAME
7 | RESULT_PATH=./
8 |
9 | ENCODER=$MODEL_NAME/encoder_epoch_20.pth
10 | DECODER=$MODEL_NAME/decoder_epoch_20.pth
11 |
12 | # Download model weights and image
13 | if [ ! -e $MODEL_PATH ]; then
14 | mkdir -p $MODEL_PATH
15 | fi
16 | if [ ! -e $ENCODER ]; then
17 | wget -P $MODEL_PATH http://sceneparsing.csail.mit.edu/model/pytorch/$ENCODER
18 | fi
19 | if [ ! -e $DECODER ]; then
20 | wget -P $MODEL_PATH http://sceneparsing.csail.mit.edu/model/pytorch/$DECODER
21 | fi
22 | if [ ! -e $TEST_IMG ]; then
23 | wget -P $RESULT_PATH http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016/images/validation/$TEST_IMG
24 | fi
25 |
26 | if [ -z "$DOWNLOAD_ONLY" ]
27 | then
28 |
29 | # Inference
30 | python3 -u test.py \
31 | --imgs $TEST_IMG \
32 | --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml \
33 | DIR $MODEL_PATH \
34 | TEST.result ./ \
35 | TEST.checkpoint epoch_20.pth
36 |
37 | fi
38 |
--------------------------------------------------------------------------------
/download_ADE20K.sh:
--------------------------------------------------------------------------------
1 | wget -O ./data/ADEChallengeData2016.zip http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
2 | unzip ./data/ADEChallengeData2016.zip -d ./data
3 | rm ./data/ADEChallengeData2016.zip
4 | echo "Dataset downloaded."
5 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # System libs
2 | import os
3 | import time
4 | import argparse
5 | from distutils.version import LooseVersion
6 | # Numerical libs
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | from scipy.io import loadmat
11 | # Our libs
12 | from mit_semseg.config import cfg
13 | from mit_semseg.dataset import ValDataset
14 | from mit_semseg.models import ModelBuilder, SegmentationModule
15 | from mit_semseg.utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
16 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
17 | from mit_semseg.lib.utils import as_numpy
18 | from PIL import Image
19 | from tqdm import tqdm
20 |
21 | colors = loadmat('data/color150.mat')['colors']
22 |
23 |
24 | def visualize_result(data, pred, dir_result):
25 | (img, seg, info) = data
26 |
27 | # segmentation
28 | seg_color = colorEncode(seg, colors)
29 |
30 | # prediction
31 | pred_color = colorEncode(pred, colors)
32 |
33 | # aggregate images and save
34 | im_vis = np.concatenate((img, seg_color, pred_color),
35 | axis=1).astype(np.uint8)
36 |
37 | img_name = info.split('/')[-1]
38 | Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
39 |
40 |
41 | def evaluate(segmentation_module, loader, cfg, gpu):
42 | acc_meter = AverageMeter()
43 | intersection_meter = AverageMeter()
44 | union_meter = AverageMeter()
45 | time_meter = AverageMeter()
46 |
47 | segmentation_module.eval()
48 |
49 | pbar = tqdm(total=len(loader))
50 | for batch_data in loader:
51 | # process data
52 | batch_data = batch_data[0]
53 | seg_label = as_numpy(batch_data['seg_label'][0])
54 | img_resized_list = batch_data['img_data']
55 |
56 | torch.cuda.synchronize()
57 | tic = time.perf_counter()
58 | with torch.no_grad():
59 | segSize = (seg_label.shape[0], seg_label.shape[1])
60 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
61 | scores = async_copy_to(scores, gpu)
62 |
63 | for img in img_resized_list:
64 | feed_dict = batch_data.copy()
65 | feed_dict['img_data'] = img
66 | del feed_dict['img_ori']
67 | del feed_dict['info']
68 | feed_dict = async_copy_to(feed_dict, gpu)
69 |
70 | # forward pass
71 | scores_tmp = segmentation_module(feed_dict, segSize=segSize)
72 | scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)
73 |
74 | _, pred = torch.max(scores, dim=1)
75 | pred = as_numpy(pred.squeeze(0).cpu())
76 |
77 | torch.cuda.synchronize()
78 | time_meter.update(time.perf_counter() - tic)
79 |
80 | # calculate accuracy
81 | acc, pix = accuracy(pred, seg_label)
82 | intersection, union = intersectionAndUnion(pred, seg_label, cfg.DATASET.num_class)
83 | acc_meter.update(acc, pix)
84 | intersection_meter.update(intersection)
85 | union_meter.update(union)
86 |
87 | # visualization
88 | if cfg.VAL.visualize:
89 | visualize_result(
90 | (batch_data['img_ori'], seg_label, batch_data['info']),
91 | pred,
92 | os.path.join(cfg.DIR, 'result')
93 | )
94 |
95 | pbar.update(1)
96 |
97 | # summary
98 | iou = intersection_meter.sum / (union_meter.sum + 1e-10)
99 | for i, _iou in enumerate(iou):
100 | print('class [{}], IoU: {:.4f}'.format(i, _iou))
101 |
102 | print('[Eval Summary]:')
103 | print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s'
104 | .format(iou.mean(), acc_meter.average()*100, time_meter.average()))
105 |
106 |
107 | def main(cfg, gpu):
108 | torch.cuda.set_device(gpu)
109 |
110 | # Network Builders
111 | net_encoder = ModelBuilder.build_encoder(
112 | arch=cfg.MODEL.arch_encoder.lower(),
113 | fc_dim=cfg.MODEL.fc_dim,
114 | weights=cfg.MODEL.weights_encoder)
115 | net_decoder = ModelBuilder.build_decoder(
116 | arch=cfg.MODEL.arch_decoder.lower(),
117 | fc_dim=cfg.MODEL.fc_dim,
118 | num_class=cfg.DATASET.num_class,
119 | weights=cfg.MODEL.weights_decoder,
120 | use_softmax=True)
121 |
122 | crit = nn.NLLLoss(ignore_index=-1)
123 |
124 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
125 |
126 | # Dataset and Loader
127 | dataset_val = ValDataset(
128 | cfg.DATASET.root_dataset,
129 | cfg.DATASET.list_val,
130 | cfg.DATASET)
131 | loader_val = torch.utils.data.DataLoader(
132 | dataset_val,
133 | batch_size=cfg.VAL.batch_size,
134 | shuffle=False,
135 | collate_fn=user_scattered_collate,
136 | num_workers=5,
137 | drop_last=True)
138 |
139 | segmentation_module.cuda()
140 |
141 | # Main loop
142 | evaluate(segmentation_module, loader_val, cfg, gpu)
143 |
144 | print('Evaluation Done!')
145 |
146 |
147 | if __name__ == '__main__':
148 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
149 | 'PyTorch>=0.4.0 is required'
150 |
151 | parser = argparse.ArgumentParser(
152 | description="PyTorch Semantic Segmentation Validation"
153 | )
154 | parser.add_argument(
155 | "--cfg",
156 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
157 | metavar="FILE",
158 | help="path to config file",
159 | type=str,
160 | )
161 | parser.add_argument(
162 | "--gpu",
163 | default=0,
164 | help="gpu to use"
165 | )
166 | parser.add_argument(
167 | "opts",
168 | help="Modify config options using the command-line",
169 | default=None,
170 | nargs=argparse.REMAINDER,
171 | )
172 | args = parser.parse_args()
173 |
174 | cfg.merge_from_file(args.cfg)
175 | cfg.merge_from_list(args.opts)
176 | # cfg.freeze()
177 |
178 | logger = setup_logger(distributed_rank=0) # TODO
179 | logger.info("Loaded configuration file {}".format(args.cfg))
180 | logger.info("Running with config:\n{}".format(cfg))
181 |
182 | # absolute paths of model weights
183 | cfg.MODEL.weights_encoder = os.path.join(
184 | cfg.DIR, 'encoder_' + cfg.VAL.checkpoint)
185 | cfg.MODEL.weights_decoder = os.path.join(
186 | cfg.DIR, 'decoder_' + cfg.VAL.checkpoint)
187 | assert os.path.exists(cfg.MODEL.weights_encoder) and \
188 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
189 |
190 | if not os.path.isdir(os.path.join(cfg.DIR, "result")):
191 | os.makedirs(os.path.join(cfg.DIR, "result"))
192 |
193 | main(cfg, args.gpu)
194 |
--------------------------------------------------------------------------------
/eval_multipro.py:
--------------------------------------------------------------------------------
1 | # System libs
2 | import os
3 | import argparse
4 | from distutils.version import LooseVersion
5 | from multiprocessing import Queue, Process
6 | # Numerical libs
7 | import numpy as np
8 | import math
9 | import torch
10 | import torch.nn as nn
11 | from scipy.io import loadmat
12 | # Our libs
13 | from mit_semseg.config import cfg
14 | from mit_semseg.dataset import ValDataset
15 | from mit_semseg.models import ModelBuilder, SegmentationModule
16 | from mit_semseg.utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
17 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
18 | from mit_semseg.lib.utils import as_numpy
19 | from PIL import Image
20 | from tqdm import tqdm
21 |
22 | colors = loadmat('data/color150.mat')['colors']
23 |
24 |
25 | def visualize_result(data, pred, dir_result):
26 | (img, seg, info) = data
27 |
28 | # segmentation
29 | seg_color = colorEncode(seg, colors)
30 |
31 | # prediction
32 | pred_color = colorEncode(pred, colors)
33 |
34 | # aggregate images and save
35 | im_vis = np.concatenate((img, seg_color, pred_color),
36 | axis=1).astype(np.uint8)
37 |
38 | img_name = info.split('/')[-1]
39 | Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
40 |
41 |
42 | def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):
43 | segmentation_module.eval()
44 |
45 | for batch_data in loader:
46 | # process data
47 | batch_data = batch_data[0]
48 | seg_label = as_numpy(batch_data['seg_label'][0])
49 | img_resized_list = batch_data['img_data']
50 |
51 | with torch.no_grad():
52 | segSize = (seg_label.shape[0], seg_label.shape[1])
53 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
54 | scores = async_copy_to(scores, gpu_id)
55 |
56 | for img in img_resized_list:
57 | feed_dict = batch_data.copy()
58 | feed_dict['img_data'] = img
59 | del feed_dict['img_ori']
60 | del feed_dict['info']
61 | feed_dict = async_copy_to(feed_dict, gpu_id)
62 |
63 | # forward pass
64 | scores_tmp = segmentation_module(feed_dict, segSize=segSize)
65 | scores = scores + scores_tmp / len(cfg.DATASET.imgSizes)
66 |
67 | _, pred = torch.max(scores, dim=1)
68 | pred = as_numpy(pred.squeeze(0).cpu())
69 |
70 | # calculate accuracy and SEND THEM TO MASTER
71 | acc, pix = accuracy(pred, seg_label)
72 | intersection, union = intersectionAndUnion(pred, seg_label, cfg.DATASET.num_class)
73 | result_queue.put_nowait((acc, pix, intersection, union))
74 |
75 | # visualization
76 | if cfg.VAL.visualize:
77 | visualize_result(
78 | (batch_data['img_ori'], seg_label, batch_data['info']),
79 | pred,
80 | os.path.join(cfg.DIR, 'result')
81 | )
82 |
83 |
84 | def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
85 | torch.cuda.set_device(gpu_id)
86 |
87 | # Dataset and Loader
88 | dataset_val = ValDataset(
89 | cfg.DATASET.root_dataset,
90 | cfg.DATASET.list_val,
91 | cfg.DATASET,
92 | start_idx=start_idx, end_idx=end_idx)
93 | loader_val = torch.utils.data.DataLoader(
94 | dataset_val,
95 | batch_size=cfg.VAL.batch_size,
96 | shuffle=False,
97 | collate_fn=user_scattered_collate,
98 | num_workers=2)
99 |
100 | # Network Builders
101 | net_encoder = ModelBuilder.build_encoder(
102 | arch=cfg.MODEL.arch_encoder.lower(),
103 | fc_dim=cfg.MODEL.fc_dim,
104 | weights=cfg.MODEL.weights_encoder)
105 | net_decoder = ModelBuilder.build_decoder(
106 | arch=cfg.MODEL.arch_decoder.lower(),
107 | fc_dim=cfg.MODEL.fc_dim,
108 | num_class=cfg.DATASET.num_class,
109 | weights=cfg.MODEL.weights_decoder,
110 | use_softmax=True)
111 |
112 | crit = nn.NLLLoss(ignore_index=-1)
113 |
114 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
115 |
116 | segmentation_module.cuda()
117 |
118 | # Main loop
119 | evaluate(segmentation_module, loader_val, cfg, gpu_id, result_queue)
120 |
121 |
122 | def main(cfg, gpus):
123 | with open(cfg.DATASET.list_val, 'r') as f:
124 | lines = f.readlines()
125 | num_files = len(lines)
126 |
127 | num_files_per_gpu = math.ceil(num_files / len(gpus))
128 |
129 | pbar = tqdm(total=num_files)
130 |
131 | acc_meter = AverageMeter()
132 | intersection_meter = AverageMeter()
133 | union_meter = AverageMeter()
134 |
135 | result_queue = Queue(500)
136 | procs = []
137 | for idx, gpu_id in enumerate(gpus):
138 | start_idx = idx * num_files_per_gpu
139 | end_idx = min(start_idx + num_files_per_gpu, num_files)
140 | proc = Process(target=worker, args=(cfg, gpu_id, start_idx, end_idx, result_queue))
141 | print('gpu:{}, start_idx:{}, end_idx:{}'.format(gpu_id, start_idx, end_idx))
142 | proc.start()
143 | procs.append(proc)
144 |
145 | # master fetches results
146 | processed_counter = 0
147 | while processed_counter < num_files:
148 | if result_queue.empty():
149 | continue
150 | (acc, pix, intersection, union) = result_queue.get()
151 | acc_meter.update(acc, pix)
152 | intersection_meter.update(intersection)
153 | union_meter.update(union)
154 | processed_counter += 1
155 | pbar.update(1)
156 |
157 | for p in procs:
158 | p.join()
159 |
160 | # summary
161 | iou = intersection_meter.sum / (union_meter.sum + 1e-10)
162 | for i, _iou in enumerate(iou):
163 | print('class [{}], IoU: {:.4f}'.format(i, _iou))
164 |
165 | print('[Eval Summary]:')
166 | print('Mean IoU: {:.4f}, Accuracy: {:.2f}%'
167 | .format(iou.mean(), acc_meter.average()*100))
168 |
169 | print('Evaluation Done!')
170 |
171 |
172 | if __name__ == '__main__':
173 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
174 | 'PyTorch>=0.4.0 is required'
175 |
176 | parser = argparse.ArgumentParser(
177 | description="PyTorch Semantic Segmentation Validation"
178 | )
179 | parser.add_argument(
180 | "--cfg",
181 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
182 | metavar="FILE",
183 | help="path to config file",
184 | type=str,
185 | )
186 | parser.add_argument(
187 | "--gpus",
188 | default="0-3",
189 | help="gpus to use, e.g. 0-3 or 0,1,2,3"
190 | )
191 | parser.add_argument(
192 | "opts",
193 | help="Modify config options using the command-line",
194 | default=None,
195 | nargs=argparse.REMAINDER,
196 | )
197 | args = parser.parse_args()
198 |
199 | cfg.merge_from_file(args.cfg)
200 | cfg.merge_from_list(args.opts)
201 | # cfg.freeze()
202 |
203 | logger = setup_logger(distributed_rank=0) # TODO
204 | logger.info("Loaded configuration file {}".format(args.cfg))
205 | logger.info("Running with config:\n{}".format(cfg))
206 |
207 | # absolute paths of model weights
208 | cfg.MODEL.weights_encoder = os.path.join(
209 | cfg.DIR, 'encoder_' + cfg.VAL.checkpoint)
210 | cfg.MODEL.weights_decoder = os.path.join(
211 | cfg.DIR, 'decoder_' + cfg.VAL.checkpoint)
212 | assert os.path.exists(cfg.MODEL.weights_encoder) and \
213 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!"
214 |
215 | if not os.path.isdir(os.path.join(cfg.DIR, "result")):
216 | os.makedirs(os.path.join(cfg.DIR, "result"))
217 |
218 | # Parse gpu ids
219 | gpus = parse_devices(args.gpus)
220 | gpus = [x.replace('gpu', '') for x in gpus]
221 | gpus = [int(x) for x in gpus]
222 |
223 | main(cfg, gpus)
224 |
--------------------------------------------------------------------------------
/mit_semseg/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT CSAIL Semantic Segmentation
3 | """
4 |
5 | __version__ = '1.0.0'
6 |
--------------------------------------------------------------------------------
/mit_semseg/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import _C as cfg
2 |
--------------------------------------------------------------------------------
/mit_semseg/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Config definition
5 | # -----------------------------------------------------------------------------
6 |
7 | _C = CN()
8 | _C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup"
9 |
10 | # -----------------------------------------------------------------------------
11 | # Dataset
12 | # -----------------------------------------------------------------------------
13 | _C.DATASET = CN()
14 | _C.DATASET.root_dataset = "./data/"
15 | _C.DATASET.list_train = "./data/training.odgt"
16 | _C.DATASET.list_val = "./data/validation.odgt"
17 | _C.DATASET.num_class = 150
18 | # multiscale train/test, size of short edge (int or tuple)
19 | _C.DATASET.imgSizes = (300, 375, 450, 525, 600)
20 | # maximum input image size of long edge
21 | _C.DATASET.imgMaxSize = 1000
22 | # maxmimum downsampling rate of the network
23 | _C.DATASET.padding_constant = 8
24 | # downsampling rate of the segmentation label
25 | _C.DATASET.segm_downsampling_rate = 8
26 | # randomly horizontally flip images when train/test
27 | _C.DATASET.random_flip = True
28 |
29 | # -----------------------------------------------------------------------------
30 | # Model
31 | # -----------------------------------------------------------------------------
32 | _C.MODEL = CN()
33 | # architecture of net_encoder
34 | _C.MODEL.arch_encoder = "resnet50dilated"
35 | # architecture of net_decoder
36 | _C.MODEL.arch_decoder = "ppm_deepsup"
37 | # weights to finetune net_encoder
38 | _C.MODEL.weights_encoder = ""
39 | # weights to finetune net_decoder
40 | _C.MODEL.weights_decoder = ""
41 | # number of feature channels between encoder and decoder
42 | _C.MODEL.fc_dim = 2048
43 |
44 | # -----------------------------------------------------------------------------
45 | # Training
46 | # -----------------------------------------------------------------------------
47 | _C.TRAIN = CN()
48 | _C.TRAIN.batch_size_per_gpu = 2
49 | # epochs to train for
50 | _C.TRAIN.num_epoch = 20
51 | # epoch to start training. useful if continue from a checkpoint
52 | _C.TRAIN.start_epoch = 0
53 | # iterations of each epoch (irrelevant to batch size)
54 | _C.TRAIN.epoch_iters = 5000
55 |
56 | _C.TRAIN.optim = "SGD"
57 | _C.TRAIN.lr_encoder = 0.02
58 | _C.TRAIN.lr_decoder = 0.02
59 | # power in poly to drop LR
60 | _C.TRAIN.lr_pow = 0.9
61 | # momentum for sgd, beta1 for adam
62 | _C.TRAIN.beta1 = 0.9
63 | # weights regularizer
64 | _C.TRAIN.weight_decay = 1e-4
65 | # the weighting of deep supervision loss
66 | _C.TRAIN.deep_sup_scale = 0.4
67 | # fix bn params, only under finetuning
68 | _C.TRAIN.fix_bn = False
69 | # number of data loading workers
70 | _C.TRAIN.workers = 16
71 |
72 | # frequency to display
73 | _C.TRAIN.disp_iter = 20
74 | # manual seed
75 | _C.TRAIN.seed = 304
76 |
77 | # -----------------------------------------------------------------------------
78 | # Validation
79 | # -----------------------------------------------------------------------------
80 | _C.VAL = CN()
81 | # currently only supports 1
82 | _C.VAL.batch_size = 1
83 | # output visualization during validation
84 | _C.VAL.visualize = False
85 | # the checkpoint to evaluate on
86 | _C.VAL.checkpoint = "epoch_20.pth"
87 |
88 | # -----------------------------------------------------------------------------
89 | # Testing
90 | # -----------------------------------------------------------------------------
91 | _C.TEST = CN()
92 | # currently only supports 1
93 | _C.TEST.batch_size = 1
94 | # the checkpoint to test on
95 | _C.TEST.checkpoint = "epoch_20.pth"
96 | # folder to output visualization results
97 | _C.TEST.result = "./"
98 |
--------------------------------------------------------------------------------
/mit_semseg/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from torchvision import transforms
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | def imresize(im, size, interp='bilinear'):
10 | if interp == 'nearest':
11 | resample = Image.NEAREST
12 | elif interp == 'bilinear':
13 | resample = Image.BILINEAR
14 | elif interp == 'bicubic':
15 | resample = Image.BICUBIC
16 | else:
17 | raise Exception('resample method undefined!')
18 |
19 | return im.resize(size, resample)
20 |
21 |
22 | class BaseDataset(torch.utils.data.Dataset):
23 | def __init__(self, odgt, opt, **kwargs):
24 | # parse options
25 | self.imgSizes = opt.imgSizes
26 | self.imgMaxSize = opt.imgMaxSize
27 | # max down sampling rate of network to avoid rounding during conv or pooling
28 | self.padding_constant = opt.padding_constant
29 |
30 | # parse the input list
31 | self.parse_input_list(odgt, **kwargs)
32 |
33 | # mean and std
34 | self.normalize = transforms.Normalize(
35 | mean=[0.485, 0.456, 0.406],
36 | std=[0.229, 0.224, 0.225])
37 |
38 | def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
39 | if isinstance(odgt, list):
40 | self.list_sample = odgt
41 | elif isinstance(odgt, str):
42 | self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]
43 |
44 | if max_sample > 0:
45 | self.list_sample = self.list_sample[0:max_sample]
46 | if start_idx >= 0 and end_idx >= 0: # divide file list
47 | self.list_sample = self.list_sample[start_idx:end_idx]
48 |
49 | self.num_sample = len(self.list_sample)
50 | assert self.num_sample > 0
51 | print('# samples: {}'.format(self.num_sample))
52 |
53 | def img_transform(self, img):
54 | # 0-255 to 0-1
55 | img = np.float32(np.array(img)) / 255.
56 | img = img.transpose((2, 0, 1))
57 | img = self.normalize(torch.from_numpy(img.copy()))
58 | return img
59 |
60 | def segm_transform(self, segm):
61 | # to tensor, -1 to 149
62 | segm = torch.from_numpy(np.array(segm)).long() - 1
63 | return segm
64 |
65 | # Round x to the nearest multiple of p and x' >= x
66 | def round2nearest_multiple(self, x, p):
67 | return ((x - 1) // p + 1) * p
68 |
69 |
70 | class TrainDataset(BaseDataset):
71 | def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
72 | super(TrainDataset, self).__init__(odgt, opt, **kwargs)
73 | self.root_dataset = root_dataset
74 | # down sampling rate of segm label
75 | self.segm_downsampling_rate = opt.segm_downsampling_rate
76 | self.batch_per_gpu = batch_per_gpu
77 |
78 | # classify images into two classes: 1. h > w and 2. h <= w
79 | self.batch_record_list = [[], []]
80 |
81 | # override dataset length when training with batch_per_gpu > 1
82 | self.cur_idx = 0
83 | self.if_shuffled = False
84 |
85 | def _get_sub_batch(self):
86 | while True:
87 | # get a sample record
88 | this_sample = self.list_sample[self.cur_idx]
89 | if this_sample['height'] > this_sample['width']:
90 | self.batch_record_list[0].append(this_sample) # h > w, go to 1st class
91 | else:
92 | self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class
93 |
94 | # update current sample pointer
95 | self.cur_idx += 1
96 | if self.cur_idx >= self.num_sample:
97 | self.cur_idx = 0
98 | np.random.shuffle(self.list_sample)
99 |
100 | if len(self.batch_record_list[0]) == self.batch_per_gpu:
101 | batch_records = self.batch_record_list[0]
102 | self.batch_record_list[0] = []
103 | break
104 | elif len(self.batch_record_list[1]) == self.batch_per_gpu:
105 | batch_records = self.batch_record_list[1]
106 | self.batch_record_list[1] = []
107 | break
108 | return batch_records
109 |
110 | def __getitem__(self, index):
111 | # NOTE: random shuffle for the first time. shuffle in __init__ is useless
112 | if not self.if_shuffled:
113 | np.random.seed(index)
114 | np.random.shuffle(self.list_sample)
115 | self.if_shuffled = True
116 |
117 | # get sub-batch candidates
118 | batch_records = self._get_sub_batch()
119 |
120 | # resize all images' short edges to the chosen size
121 | if isinstance(self.imgSizes, list) or isinstance(self.imgSizes, tuple):
122 | this_short_size = np.random.choice(self.imgSizes)
123 | else:
124 | this_short_size = self.imgSizes
125 |
126 | # calculate the BATCH's height and width
127 | # since we concatenate more than one sample, the batch's h and w must be at least as large as EACH sample's
128 | batch_widths = np.zeros(self.batch_per_gpu, np.int32)
129 | batch_heights = np.zeros(self.batch_per_gpu, np.int32)
130 | for i in range(self.batch_per_gpu):
131 | img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
132 | this_scale = min(
133 | this_short_size / min(img_height, img_width), \
134 | self.imgMaxSize / max(img_height, img_width))
135 | batch_widths[i] = img_width * this_scale
136 | batch_heights[i] = img_height * this_scale
137 |
138 | # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
139 | batch_width = np.max(batch_widths)
140 | batch_height = np.max(batch_heights)
141 | batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
142 | batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))
143 |
144 | assert self.padding_constant >= self.segm_downsampling_rate, \
145 | 'padding constant must be equal to or larger than segm downsampling rate'
146 | batch_images = torch.zeros(
147 | self.batch_per_gpu, 3, batch_height, batch_width)
148 | batch_segms = torch.zeros(
149 | self.batch_per_gpu,
150 | batch_height // self.segm_downsampling_rate,
151 | batch_width // self.segm_downsampling_rate).long()
152 |
153 | for i in range(self.batch_per_gpu):
154 | this_record = batch_records[i]
155 |
156 | # load image and label
157 | image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
158 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
159 |
160 | img = Image.open(image_path).convert('RGB')
161 | segm = Image.open(segm_path)
162 | assert(segm.mode == "L")
163 | assert(img.size[0] == segm.size[0])
164 | assert(img.size[1] == segm.size[1])
165 |
166 | # random_flip
167 | if np.random.choice([0, 1]):
168 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
169 | segm = segm.transpose(Image.FLIP_LEFT_RIGHT)
170 |
171 | # note that each sample within a mini batch has different scale param
172 | img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
173 | segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')
174 |
175 | # further downsample seg label, need to avoid seg label misalignment
176 | segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
177 | segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
178 | segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
179 | segm_rounded.paste(segm, (0, 0))
180 | segm = imresize(
181 | segm_rounded,
182 | (segm_rounded.size[0] // self.segm_downsampling_rate, \
183 | segm_rounded.size[1] // self.segm_downsampling_rate), \
184 | interp='nearest')
185 |
186 | # image transform, to torch float tensor 3xHxW
187 | img = self.img_transform(img)
188 |
189 | # segm transform, to torch long tensor HxW
190 | segm = self.segm_transform(segm)
191 |
192 | # put into batch arrays
193 | batch_images[i][:, :img.shape[1], :img.shape[2]] = img
194 | batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm
195 |
196 | output = dict()
197 | output['img_data'] = batch_images
198 | output['seg_label'] = batch_segms
199 | return output
200 |
201 | def __len__(self):
202 | return int(1e10) # It's a fake length due to the trick that every loader maintains its own list
203 | #return self.num_sample
204 |
205 |
206 | class ValDataset(BaseDataset):
207 | def __init__(self, root_dataset, odgt, opt, **kwargs):
208 | super(ValDataset, self).__init__(odgt, opt, **kwargs)
209 | self.root_dataset = root_dataset
210 |
211 | def __getitem__(self, index):
212 | this_record = self.list_sample[index]
213 | # load image and label
214 | image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
215 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
216 | img = Image.open(image_path).convert('RGB')
217 | segm = Image.open(segm_path)
218 | assert(segm.mode == "L")
219 | assert(img.size[0] == segm.size[0])
220 | assert(img.size[1] == segm.size[1])
221 |
222 | ori_width, ori_height = img.size
223 |
224 | img_resized_list = []
225 | for this_short_size in self.imgSizes:
226 | # calculate target height and width
227 | scale = min(this_short_size / float(min(ori_height, ori_width)),
228 | self.imgMaxSize / float(max(ori_height, ori_width)))
229 | target_height, target_width = int(ori_height * scale), int(ori_width * scale)
230 |
231 | # to avoid rounding in network
232 | target_width = self.round2nearest_multiple(target_width, self.padding_constant)
233 | target_height = self.round2nearest_multiple(target_height, self.padding_constant)
234 |
235 | # resize images
236 | img_resized = imresize(img, (target_width, target_height), interp='bilinear')
237 |
238 | # image transform, to torch float tensor 3xHxW
239 | img_resized = self.img_transform(img_resized)
240 | img_resized = torch.unsqueeze(img_resized, 0)
241 | img_resized_list.append(img_resized)
242 |
243 | # segm transform, to torch long tensor HxW
244 | segm = self.segm_transform(segm)
245 | batch_segms = torch.unsqueeze(segm, 0)
246 |
247 | output = dict()
248 | output['img_ori'] = np.array(img)
249 | output['img_data'] = [x.contiguous() for x in img_resized_list]
250 | output['seg_label'] = batch_segms.contiguous()
251 | output['info'] = this_record['fpath_img']
252 | return output
253 |
254 | def __len__(self):
255 | return self.num_sample
256 |
257 |
258 | class TestDataset(BaseDataset):
259 | def __init__(self, odgt, opt, **kwargs):
260 | super(TestDataset, self).__init__(odgt, opt, **kwargs)
261 |
262 | def __getitem__(self, index):
263 | this_record = self.list_sample[index]
264 | # load image
265 | image_path = this_record['fpath_img']
266 | img = Image.open(image_path).convert('RGB')
267 |
268 | ori_width, ori_height = img.size
269 |
270 | img_resized_list = []
271 | for this_short_size in self.imgSizes:
272 | # calculate target height and width
273 | scale = min(this_short_size / float(min(ori_height, ori_width)),
274 | self.imgMaxSize / float(max(ori_height, ori_width)))
275 | target_height, target_width = int(ori_height * scale), int(ori_width * scale)
276 |
277 | # to avoid rounding in network
278 | target_width = self.round2nearest_multiple(target_width, self.padding_constant)
279 | target_height = self.round2nearest_multiple(target_height, self.padding_constant)
280 |
281 | # resize images
282 | img_resized = imresize(img, (target_width, target_height), interp='bilinear')
283 |
284 | # image transform, to torch float tensor 3xHxW
285 | img_resized = self.img_transform(img_resized)
286 | img_resized = torch.unsqueeze(img_resized, 0)
287 | img_resized_list.append(img_resized)
288 |
289 | output = dict()
290 | output['img_ori'] = np.array(img)
291 | output['img_data'] = [x.contiguous() for x in img_resized_list]
292 | output['info'] = this_record['fpath_img']
293 | return output
294 |
295 | def __len__(self):
296 | return self.num_sample
297 |
--------------------------------------------------------------------------------
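
The TrainDataset logic above scales every image in a sub-batch so that its short edge matches a randomly chosen size (capped by imgMaxSize), then rounds the shared batch canvas up to a multiple of padding_constant so the segmentation labels stay aligned after downsampling. Below is a minimal standalone sketch of that arithmetic, not part of the repository; the helper name batch_canvas and the values 450/1000/8 are illustrative assumptions, not values read from any config, and the rounding helper only reproduces the behavior the code above relies on (round up to the nearest multiple of p).

# Illustrative sketch only: mirrors the scaling/padding arithmetic of
# TrainDataset.__getitem__ without touching any images.
def round2nearest_multiple(x, p):
    # round x up to the nearest multiple of p
    return ((x - 1) // p + 1) * p

def batch_canvas(sizes, short_size=450, max_size=1000, padding_constant=8):
    widths, heights = [], []
    for h, w in sizes:
        # scale so the short edge is ~short_size, but never exceed max_size on the long edge
        scale = min(short_size / min(h, w), max_size / max(h, w))
        widths.append(int(w * scale))
        heights.append(int(h * scale))
    # every sample is pasted onto one canvas, so take the max and round up
    return (round2nearest_multiple(max(heights), padding_constant),
            round2nearest_multiple(max(widths), padding_constant))

print(batch_canvas([(512, 683), (480, 640)]))   # (456, 600)
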
/mit_semseg/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/mit_semseg/lib/__init__.py
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
3 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | # customized batch norm statistics
49 | self._moving_average_fraction = 1. - momentum
50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
52 | self.register_buffer('_running_iter', torch.ones(1))
53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter
54 | self._tmp_running_var = self.running_var.clone() * self._running_iter
55 |
56 | def forward(self, input):
57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
58 | if not (self._is_parallel and self.training):
59 | return F.batch_norm(
60 | input, self.running_mean, self.running_var, self.weight, self.bias,
61 | self.training, self.momentum, self.eps)
62 |
63 | # Resize the input to (B, C, -1).
64 | input_shape = input.size()
65 | input = input.view(input.size(0), self.num_features, -1)
66 |
67 | # Compute the sum and square-sum.
68 | sum_size = input.size(0) * input.size(2)
69 | input_sum = _sum_ft(input)
70 | input_ssum = _sum_ft(input ** 2)
71 |
72 | # Reduce-and-broadcast the statistics.
73 | if self._parallel_id == 0:
74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
75 | else:
76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
77 |
78 | # Compute the output.
79 | if self.affine:
80 | # MJY:: Fuse the multiplication for speed.
81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
82 | else:
83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
84 |
85 | # Reshape it.
86 | return output.view(input_shape)
87 |
88 | def __data_parallel_replicate__(self, ctx, copy_id):
89 | self._is_parallel = True
90 | self._parallel_id = copy_id
91 |
92 | # parallel_id == 0 means master device.
93 | if self._parallel_id == 0:
94 | ctx.sync_master = self._sync_master
95 | else:
96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
97 |
98 | def _data_parallel_master(self, intermediates):
99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
101 |
102 | to_reduce = [i[1][:2] for i in intermediates]
103 | to_reduce = [j for i in to_reduce for j in i] # flatten
104 | target_gpus = [i[1].sum.get_device() for i in intermediates]
105 |
106 | sum_size = sum([i[1].sum_size for i in intermediates])
107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
108 |
109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 | outputs = []
114 | for i, rec in enumerate(intermediates):
115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
116 |
117 | return outputs
118 |
119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
121 | return dest * alpha + delta * beta + bias
122 |
123 | def _compute_mean_std(self, sum_, ssum, size):
124 | """Compute the mean and standard-deviation with sum and square-sum. This method
125 | also maintains the moving average on the master device."""
126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
127 | mean = sum_ / size
128 | sumvar = ssum - sum_ * mean
129 | unbias_var = sumvar / (size - 1)
130 | bias_var = sumvar / size
131 |
132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
135 |
136 | self.running_mean = self._tmp_running_mean / self._running_iter
137 | self.running_var = self._tmp_running_var / self._running_iter
138 |
139 | return mean, bias_var.clamp(self.eps) ** -0.5
140 |
141 |
142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
144 | mini-batch.
145 |
146 | .. math::
147 |
148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
149 |
150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
151 | standard-deviation are reduced across all devices during training.
152 |
153 | For example, when one uses `nn.DataParallel` to wrap the network during
154 | training, PyTorch's implementation normalizes the tensor on each device using
155 | the statistics only on that device, which accelerates the computation and
156 | is also easy to implement, but the statistics might be inaccurate.
157 | Instead, in this synchronized version, the statistics will be computed
158 | over all training samples distributed on multiple devices.
159 |
160 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the
161 | same as the built-in PyTorch implementation.
162 |
163 | The mean and standard-deviation are calculated per-dimension over
164 | the mini-batches and gamma and beta are learnable parameter vectors
165 | of size C (where C is the input size).
166 |
167 | During training, this layer keeps a running estimate of its computed mean
168 | and variance. The running sum is kept with a default momentum of 0.1.
169 |
170 | During evaluation, this running mean/variance is used for normalization.
171 |
172 | Because the BatchNorm is done over the `C` dimension, computing statistics
173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
174 |
175 | Args:
176 | num_features: num_features from an expected input of size
177 | `batch_size x num_features [x width]`
178 | eps: a value added to the denominator for numerical stability.
179 | Default: 1e-5
180 | momentum: the value used for the running_mean and running_var
181 | computation. Default: 0.1
182 | affine: a boolean value that when set to ``True``, gives the layer learnable
183 | affine parameters. Default: ``True``
184 |
185 | Shape:
186 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
188 |
189 | Examples:
190 | >>> # With Learnable Parameters
191 | >>> m = SynchronizedBatchNorm1d(100)
192 | >>> # Without Learnable Parameters
193 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
194 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
195 | >>> output = m(input)
196 | """
197 |
198 | def _check_input_dim(self, input):
199 | if input.dim() != 2 and input.dim() != 3:
200 | raise ValueError('expected 2D or 3D input (got {}D input)'
201 | .format(input.dim()))
202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
203 |
204 |
205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
207 | of 3d inputs
208 |
209 | .. math::
210 |
211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
212 |
213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
214 | standard-deviation are reduced across all devices during training.
215 |
216 | For example, when one uses `nn.DataParallel` to wrap the network during
217 | training, PyTorch's implementation normalizes the tensor on each device using
218 | the statistics only on that device, which accelerates the computation and
219 | is also easy to implement, but the statistics might be inaccurate.
220 | Instead, in this synchronized version, the statistics will be computed
221 | over all training samples distributed on multiple devices.
222 |
223 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the
224 | same as the built-in PyTorch implementation.
225 |
226 | The mean and standard-deviation are calculated per-dimension over
227 | the mini-batches and gamma and beta are learnable parameter vectors
228 | of size C (where C is the input size).
229 |
230 | During training, this layer keeps a running estimate of its computed mean
231 | and variance. The running sum is kept with a default momentum of 0.1.
232 |
233 | During evaluation, this running mean/variance is used for normalization.
234 |
235 | Because the BatchNorm is done over the `C` dimension, computing statistics
236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
237 |
238 | Args:
239 | num_features: num_features from an expected input of
240 | size batch_size x num_features x height x width
241 | eps: a value added to the denominator for numerical stability.
242 | Default: 1e-5
243 | momentum: the value used for the running_mean and running_var
244 | computation. Default: 0.1
245 | affine: a boolean value that when set to ``True``, gives the layer learnable
246 | affine parameters. Default: ``True``
247 |
248 | Shape:
249 | - Input: :math:`(N, C, H, W)`
250 | - Output: :math:`(N, C, H, W)` (same shape as input)
251 |
252 | Examples:
253 | >>> # With Learnable Parameters
254 | >>> m = SynchronizedBatchNorm2d(100)
255 | >>> # Without Learnable Parameters
256 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
258 | >>> output = m(input)
259 | """
260 |
261 | def _check_input_dim(self, input):
262 | if input.dim() != 4:
263 | raise ValueError('expected 4D input (got {}D input)'
264 | .format(input.dim()))
265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
266 |
267 |
268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
270 | of 4d inputs
271 |
272 | .. math::
273 |
274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
275 |
276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
277 | standard-deviation are reduced across all devices during training.
278 |
279 | For example, when one uses `nn.DataParallel` to wrap the network during
280 | training, PyTorch's implementation normalizes the tensor on each device using
281 | the statistics only on that device, which accelerates the computation and
282 | is also easy to implement, but the statistics might be inaccurate.
283 | Instead, in this synchronized version, the statistics will be computed
284 | over all training samples distributed on multiple devices.
285 |
286 | Note that, in the one-GPU or CPU-only case, this module behaves exactly the
287 | same as the built-in PyTorch implementation.
288 |
289 | The mean and standard-deviation are calculated per-dimension over
290 | the mini-batches and gamma and beta are learnable parameter vectors
291 | of size C (where C is the input size).
292 |
293 | During training, this layer keeps a running estimate of its computed mean
294 | and variance. The running sum is kept with a default momentum of 0.1.
295 |
296 | During evaluation, this running mean/variance is used for normalization.
297 |
298 | Because the BatchNorm is done over the `C` dimension, computing statistics
299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
300 | or Spatio-temporal BatchNorm
301 |
302 | Args:
303 | num_features: num_features from an expected input of
304 | size batch_size x num_features x depth x height x width
305 | eps: a value added to the denominator for numerical stability.
306 | Default: 1e-5
307 | momentum: the value used for the running_mean and running_var
308 | computation. Default: 0.1
309 | affine: a boolean value that when set to ``True``, gives the layer learnable
310 | affine parameters. Default: ``True``
311 |
312 | Shape:
313 | - Input: :math:`(N, C, D, H, W)`
314 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
315 |
316 | Examples:
317 | >>> # With Learnable Parameters
318 | >>> m = SynchronizedBatchNorm3d(100)
319 | >>> # Without Learnable Parameters
320 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
322 | >>> output = m(input)
323 | """
324 |
325 | def _check_input_dim(self, input):
326 | if input.dim() != 5:
327 | raise ValueError('expected 5D input (got {}D input)'
328 | .format(input.dim()))
329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
330 |
--------------------------------------------------------------------------------
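
The core of _compute_mean_std above is recovering the global mean and inverse standard deviation from per-device sums and square-sums. The following CPU-only check of that identity is a sketch for reference, not part of the repository, and it sidesteps the multi-device synchronization machinery entirely.

import torch

# Pretend this is the combined data from all replicas: N x C x H x W.
x = torch.randn(4, 3, 8, 8)
flat = x.transpose(0, 1).reshape(3, -1)      # C x (N*H*W), the per-channel view the module sums over
size = flat.size(1)

sum_ = flat.sum(dim=1)                       # what ReduceAddCoalesced would accumulate across devices
ssum = (flat ** 2).sum(dim=1)

mean = sum_ / size
bias_var = (ssum - sum_ * mean) / size       # biased variance, as in _compute_mean_std
inv_std = bias_var.clamp(1e-5) ** -0.5       # eps = 1e-5, matching the module default

# agrees with computing the statistics directly over the whole batch
assert torch.allclose(mean, flat.mean(dim=1), atol=1e-5)
assert torch.allclose(inv_std, flat.var(dim=1, unbiased=False).clamp(1e-5) ** -0.5, atol=1e-4)
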
/mit_semseg/lib/nn/modules/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During the replication, as data parallel triggers a callback on each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 |
82 | Args:
83 |             identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 |
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
--------------------------------------------------------------------------------
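
To make the register_slave / run_slave / run_master handshake above concrete, here is a toy, CPU-only sketch (not part of the repository) in which plain Python threads stand in for GPU replicas and the master callback simply sums everyone's message.

import threading
from mit_semseg.lib.nn.modules.comm import SyncMaster

def callback(intermediates):
    # intermediates is a list of (copy_id, message), master first; reply with the total to everyone
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

master = SyncMaster(callback)
pipes = [master.register_slave(i) for i in (1, 2)]   # "slave" copies 1 and 2

results = {}
def slave(pipe, value):
    results[pipe.identifier] = pipe.run_slave(value)  # blocks until the master replies

threads = [threading.Thread(target=slave, args=(p, v)) for p, v in zip(pipes, (20, 30))]
for t in threads:
    t.start()
results[0] = master.run_master(10)                    # master contributes 10 and waits for both slaves
for t in threads:
    t.join()
print(results)                                        # every copy receives 60
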
/mit_semseg/lib/nn/modules/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all modules are isomorphic, we assign each sub-module a context
34 | (shared among multiple copies of this module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback
38 |     of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the
55 |     original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
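
A CPU-only sketch (not part of the repository) of what execute_replication_callbacks provides: sub-modules at the same position in isomorphic copies receive the same shared context object, master copy first, which is exactly the hook SynchronizedBatchNorm relies on in __data_parallel_replicate__. The Leaf module here is a made-up stand-in for such a module.

import torch.nn as nn
from mit_semseg.lib.nn.modules.replicate import execute_replication_callbacks

class Leaf(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        if copy_id == 0:
            ctx.master_seen = True                      # the master copy writes to the shared context
        else:
            assert getattr(ctx, 'master_seen', False)   # later copies see what the master wrote

# two hand-built "replicas" with identical structure stand in for DataParallel's copies
copies = [nn.Sequential(Leaf()), nn.Sequential(Leaf())]
execute_replication_callbacks(copies)
print('context shared across copies')
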
/mit_semseg/lib/nn/modules/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/mit_semseg/lib/nn/modules/tests/__init__.py
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/modules/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/modules/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/modules/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/parallel/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
2 |
--------------------------------------------------------------------------------
/mit_semseg/lib/nn/parallel/data_parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf8 -*-
2 |
3 | import torch.cuda as cuda
4 | import torch.nn as nn
5 | import torch
6 | import collections
7 | from torch.nn.parallel._functions import Gather
8 |
9 |
10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
11 |
12 |
13 | def async_copy_to(obj, dev, main_stream=None):
14 | if torch.is_tensor(obj):
15 | v = obj.cuda(dev, non_blocking=True)
16 | if main_stream is not None:
17 | v.data.record_stream(main_stream)
18 | return v
19 | elif isinstance(obj, collections.Mapping):
20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
21 | elif isinstance(obj, collections.Sequence):
22 | return [async_copy_to(o, dev, main_stream) for o in obj]
23 | else:
24 | return obj
25 |
26 |
27 | def dict_gather(outputs, target_device, dim=0):
28 | """
29 | Gathers variables from different GPUs on a specified device
30 | (-1 means the CPU), with dictionary support.
31 | """
32 | def gather_map(outputs):
33 | out = outputs[0]
34 | if torch.is_tensor(out):
35 | # MJY(20180330) HACK:: force nr_dims > 0
36 | if out.dim() == 0:
37 | outputs = [o.unsqueeze(0) for o in outputs]
38 | return Gather.apply(target_device, dim, *outputs)
39 | elif out is None:
40 | return None
41 | elif isinstance(out, collections.Mapping):
42 | return {k: gather_map([o[k] for o in outputs]) for k in out}
43 | elif isinstance(out, collections.Sequence):
44 | return type(out)(map(gather_map, zip(*outputs)))
45 | return gather_map(outputs)
46 |
47 |
48 | class DictGatherDataParallel(nn.DataParallel):
49 | def gather(self, outputs, output_device):
50 | return dict_gather(outputs, output_device, dim=self.dim)
51 |
52 |
53 | class UserScatteredDataParallel(DictGatherDataParallel):
54 | def scatter(self, inputs, kwargs, device_ids):
55 | assert len(inputs) == 1
56 | inputs = inputs[0]
57 | inputs = _async_copy_stream(inputs, device_ids)
58 | inputs = [[i] for i in inputs]
59 | assert len(kwargs) == 0
60 | kwargs = [{} for _ in range(len(inputs))]
61 |
62 | return inputs, kwargs
63 |
64 |
65 | def user_scattered_collate(batch):
66 | return batch
67 |
68 |
69 | def _async_copy(inputs, device_ids):
70 | nr_devs = len(device_ids)
71 | assert type(inputs) in (tuple, list)
72 | assert len(inputs) == nr_devs
73 |
74 | outputs = []
75 | for i, dev in zip(inputs, device_ids):
76 | with cuda.device(dev):
77 | outputs.append(async_copy_to(i, dev))
78 |
79 | return tuple(outputs)
80 |
81 |
82 | def _async_copy_stream(inputs, device_ids):
83 | nr_devs = len(device_ids)
84 | assert type(inputs) in (tuple, list)
85 | assert len(inputs) == nr_devs
86 |
87 | outputs = []
88 | streams = [_get_stream(d) for d in device_ids]
89 | for i, dev, stream in zip(inputs, device_ids, streams):
90 | with cuda.device(dev):
91 | main_stream = cuda.current_stream()
92 | with cuda.stream(stream):
93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream))
94 | main_stream.wait_stream(stream)
95 |
96 | return outputs
97 |
98 |
99 | """Adapted from: torch/nn/parallel/_functions.py"""
100 | # background streams used for copying
101 | _streams = None
102 |
103 |
104 | def _get_stream(device):
105 | """Gets a background stream for copying between CPU and GPU"""
106 | global _streams
107 | if device == -1:
108 | return None
109 | if _streams is None:
110 | _streams = [None] * cuda.device_count()
111 | if _streams[device] is None: _streams[device] = cuda.Stream(device)
112 | return _streams[device]
113 |
--------------------------------------------------------------------------------
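
The pieces above fit together as follows: the loader yields a list with one element per GPU, user_scattered_collate passes that list through untouched, and UserScatteredDataParallel copies element i onto device_ids[i] before the forward pass. The sketch below is GPU-only and not part of the repository; it assumes two visible devices and feeds bare tensors, whereas the real training code feeds dicts produced by TrainDataset.

import torch
import torch.nn as nn
from mit_semseg.lib.nn import UserScatteredDataParallel, user_scattered_collate

model = UserScatteredDataParallel(
    nn.Conv2d(3, 8, 3, padding=1).cuda(), device_ids=[0, 1])

# one "batch" is a list with one entry per GPU (normally built by the DataLoader
# with collate_fn=user_scattered_collate; here it is assembled by hand)
batch = user_scattered_collate([torch.randn(2, 3, 64, 64) for _ in range(2)])

out = model(batch)      # entry 0 -> cuda:0, entry 1 -> cuda:1, results gathered on cuda:0
print(out.shape)        # torch.Size([4, 8, 64, 64])
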
/mit_semseg/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .th import *
2 |
--------------------------------------------------------------------------------
/mit_semseg/lib/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .dataset import Dataset, TensorDataset, ConcatDataset
3 | from .dataloader import DataLoader
4 |
--------------------------------------------------------------------------------
/mit_semseg/lib/utils/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as multiprocessing
3 | from torch._C import _set_worker_signal_handlers, \
4 | _remove_worker_pids, _error_if_any_worker_fails
5 | try:
6 | from torch._C import _set_worker_pids
7 | except:
8 | from torch._C import _update_worker_pids as _set_worker_pids
9 | from .sampler import SequentialSampler, RandomSampler, BatchSampler
10 | import signal
11 | import collections
12 | import re
13 | import sys
14 | import threading
15 | import traceback
16 | from torch._six import string_classes, int_classes
17 | import numpy as np
18 |
19 | if sys.version_info[0] == 2:
20 | import Queue as queue
21 | else:
22 | import queue
23 |
24 |
25 | class ExceptionWrapper(object):
26 | r"Wraps an exception plus traceback to communicate across threads"
27 |
28 | def __init__(self, exc_info):
29 | self.exc_type = exc_info[0]
30 | self.exc_msg = "".join(traceback.format_exception(*exc_info))
31 |
32 |
33 | _use_shared_memory = False
34 | """Whether to use shared memory in default_collate"""
35 |
36 |
37 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
38 | global _use_shared_memory
39 | _use_shared_memory = True
40 |
41 |     # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
42 | # module's handlers are executed after Python returns from C low-level
43 | # handlers, likely when the same fatal signal happened again already.
44 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
45 | _set_worker_signal_handlers()
46 |
47 | torch.set_num_threads(1)
48 | torch.manual_seed(seed)
49 | np.random.seed(seed)
50 |
51 | if init_fn is not None:
52 | init_fn(worker_id)
53 |
54 | while True:
55 | r = index_queue.get()
56 | if r is None:
57 | break
58 | idx, batch_indices = r
59 | try:
60 | samples = collate_fn([dataset[i] for i in batch_indices])
61 | except Exception:
62 | data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
63 | else:
64 | data_queue.put((idx, samples))
65 |
66 |
67 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
68 | if pin_memory:
69 | torch.cuda.set_device(device_id)
70 |
71 | while True:
72 | try:
73 | r = in_queue.get()
74 | except Exception:
75 | if done_event.is_set():
76 | return
77 | raise
78 | if r is None:
79 | break
80 | if isinstance(r[1], ExceptionWrapper):
81 | out_queue.put(r)
82 | continue
83 | idx, batch = r
84 | try:
85 | if pin_memory:
86 | batch = pin_memory_batch(batch)
87 | except Exception:
88 | out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
89 | else:
90 | out_queue.put((idx, batch))
91 |
92 | numpy_type_map = {
93 | 'float64': torch.DoubleTensor,
94 | 'float32': torch.FloatTensor,
95 | 'float16': torch.HalfTensor,
96 | 'int64': torch.LongTensor,
97 | 'int32': torch.IntTensor,
98 | 'int16': torch.ShortTensor,
99 | 'int8': torch.CharTensor,
100 | 'uint8': torch.ByteTensor,
101 | }
102 |
103 |
104 | def default_collate(batch):
105 | "Puts each data field into a tensor with outer dimension batch size"
106 |
107 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
108 | elem_type = type(batch[0])
109 | if torch.is_tensor(batch[0]):
110 | out = None
111 | if _use_shared_memory:
112 | # If we're in a background process, concatenate directly into a
113 | # shared memory tensor to avoid an extra copy
114 | numel = sum([x.numel() for x in batch])
115 | storage = batch[0].storage()._new_shared(numel)
116 | out = batch[0].new(storage)
117 | return torch.stack(batch, 0, out=out)
118 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
119 | and elem_type.__name__ != 'string_':
120 | elem = batch[0]
121 | if elem_type.__name__ == 'ndarray':
122 | # array of string classes and object
123 | if re.search('[SaUO]', elem.dtype.str) is not None:
124 | raise TypeError(error_msg.format(elem.dtype))
125 |
126 | return torch.stack([torch.from_numpy(b) for b in batch], 0)
127 | if elem.shape == (): # scalars
128 | py_type = float if elem.dtype.name.startswith('float') else int
129 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
130 | elif isinstance(batch[0], int_classes):
131 | return torch.LongTensor(batch)
132 | elif isinstance(batch[0], float):
133 | return torch.DoubleTensor(batch)
134 | elif isinstance(batch[0], string_classes):
135 | return batch
136 | elif isinstance(batch[0], collections.Mapping):
137 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
138 | elif isinstance(batch[0], collections.Sequence):
139 | transposed = zip(*batch)
140 | return [default_collate(samples) for samples in transposed]
141 |
142 | raise TypeError((error_msg.format(type(batch[0]))))
143 |
144 |
145 | def pin_memory_batch(batch):
146 | if torch.is_tensor(batch):
147 | return batch.pin_memory()
148 | elif isinstance(batch, string_classes):
149 | return batch
150 | elif isinstance(batch, collections.Mapping):
151 | return {k: pin_memory_batch(sample) for k, sample in batch.items()}
152 | elif isinstance(batch, collections.Sequence):
153 | return [pin_memory_batch(sample) for sample in batch]
154 | else:
155 | return batch
156 |
157 |
158 | _SIGCHLD_handler_set = False
159 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one
160 | handler needs to be set for all DataLoaders in a process."""
161 |
162 |
163 | def _set_SIGCHLD_handler():
164 | # Windows doesn't support SIGCHLD handler
165 | if sys.platform == 'win32':
166 | return
167 | # can't set signal in child threads
168 | if not isinstance(threading.current_thread(), threading._MainThread):
169 | return
170 | global _SIGCHLD_handler_set
171 | if _SIGCHLD_handler_set:
172 | return
173 | previous_handler = signal.getsignal(signal.SIGCHLD)
174 | if not callable(previous_handler):
175 | previous_handler = None
176 |
177 | def handler(signum, frame):
178 | # This following call uses `waitid` with WNOHANG from C side. Therefore,
179 | # Python can still get and update the process status successfully.
180 | _error_if_any_worker_fails()
181 | if previous_handler is not None:
182 | previous_handler(signum, frame)
183 |
184 | signal.signal(signal.SIGCHLD, handler)
185 | _SIGCHLD_handler_set = True
186 |
187 |
188 | class DataLoaderIter(object):
189 | "Iterates once over the DataLoader's dataset, as specified by the sampler"
190 |
191 | def __init__(self, loader):
192 | self.dataset = loader.dataset
193 | self.collate_fn = loader.collate_fn
194 | self.batch_sampler = loader.batch_sampler
195 | self.num_workers = loader.num_workers
196 | self.pin_memory = loader.pin_memory and torch.cuda.is_available()
197 | self.timeout = loader.timeout
198 | self.done_event = threading.Event()
199 |
200 | self.sample_iter = iter(self.batch_sampler)
201 |
202 | if self.num_workers > 0:
203 | self.worker_init_fn = loader.worker_init_fn
204 | self.index_queue = multiprocessing.SimpleQueue()
205 | self.worker_result_queue = multiprocessing.SimpleQueue()
206 | self.batches_outstanding = 0
207 | self.worker_pids_set = False
208 | self.shutdown = False
209 | self.send_idx = 0
210 | self.rcvd_idx = 0
211 | self.reorder_dict = {}
212 |
213 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
214 | self.workers = [
215 | multiprocessing.Process(
216 | target=_worker_loop,
217 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
218 | base_seed + i, self.worker_init_fn, i))
219 | for i in range(self.num_workers)]
220 |
221 | if self.pin_memory or self.timeout > 0:
222 | self.data_queue = queue.Queue()
223 | if self.pin_memory:
224 | maybe_device_id = torch.cuda.current_device()
225 | else:
226 | # do not initialize cuda context if not necessary
227 | maybe_device_id = None
228 | self.worker_manager_thread = threading.Thread(
229 | target=_worker_manager_loop,
230 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
231 | maybe_device_id))
232 | self.worker_manager_thread.daemon = True
233 | self.worker_manager_thread.start()
234 | else:
235 | self.data_queue = self.worker_result_queue
236 |
237 | for w in self.workers:
238 | w.daemon = True # ensure that the worker exits on process exit
239 | w.start()
240 |
241 | _set_worker_pids(id(self), tuple(w.pid for w in self.workers))
242 | _set_SIGCHLD_handler()
243 | self.worker_pids_set = True
244 |
245 | # prime the prefetch loop
246 | for _ in range(2 * self.num_workers):
247 | self._put_indices()
248 |
249 | def __len__(self):
250 | return len(self.batch_sampler)
251 |
252 | def _get_batch(self):
253 | if self.timeout > 0:
254 | try:
255 | return self.data_queue.get(timeout=self.timeout)
256 | except queue.Empty:
257 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
258 | else:
259 | return self.data_queue.get()
260 |
261 | def __next__(self):
262 | if self.num_workers == 0: # same-process loading
263 | indices = next(self.sample_iter) # may raise StopIteration
264 | batch = self.collate_fn([self.dataset[i] for i in indices])
265 | if self.pin_memory:
266 | batch = pin_memory_batch(batch)
267 | return batch
268 |
269 | # check if the next sample has already been generated
270 | if self.rcvd_idx in self.reorder_dict:
271 | batch = self.reorder_dict.pop(self.rcvd_idx)
272 | return self._process_next_batch(batch)
273 |
274 | if self.batches_outstanding == 0:
275 | self._shutdown_workers()
276 | raise StopIteration
277 |
278 | while True:
279 | assert (not self.shutdown and self.batches_outstanding > 0)
280 | idx, batch = self._get_batch()
281 | self.batches_outstanding -= 1
282 | if idx != self.rcvd_idx:
283 | # store out-of-order samples
284 | self.reorder_dict[idx] = batch
285 | continue
286 | return self._process_next_batch(batch)
287 |
288 | next = __next__ # Python 2 compatibility
289 |
290 | def __iter__(self):
291 | return self
292 |
293 | def _put_indices(self):
294 | assert self.batches_outstanding < 2 * self.num_workers
295 | indices = next(self.sample_iter, None)
296 | if indices is None:
297 | return
298 | self.index_queue.put((self.send_idx, indices))
299 | self.batches_outstanding += 1
300 | self.send_idx += 1
301 |
302 | def _process_next_batch(self, batch):
303 | self.rcvd_idx += 1
304 | self._put_indices()
305 | if isinstance(batch, ExceptionWrapper):
306 | raise batch.exc_type(batch.exc_msg)
307 | return batch
308 |
309 | def __getstate__(self):
310 | # TODO: add limited pickling support for sharing an iterator
311 | # across multiple threads for HOGWILD.
312 | # Probably the best way to do this is by moving the sample pushing
313 | # to a separate thread and then just sharing the data queue
314 | # but signalling the end is tricky without a non-blocking API
315 | raise NotImplementedError("DataLoaderIterator cannot be pickled")
316 |
317 | def _shutdown_workers(self):
318 | try:
319 | if not self.shutdown:
320 | self.shutdown = True
321 | self.done_event.set()
322 | # if worker_manager_thread is waiting to put
323 | while not self.data_queue.empty():
324 | self.data_queue.get()
325 | for _ in self.workers:
326 | self.index_queue.put(None)
327 | # done_event should be sufficient to exit worker_manager_thread,
328 | # but be safe here and put another None
329 | self.worker_result_queue.put(None)
330 | finally:
331 | # removes pids no matter what
332 | if self.worker_pids_set:
333 | _remove_worker_pids(id(self))
334 | self.worker_pids_set = False
335 |
336 | def __del__(self):
337 | if self.num_workers > 0:
338 | self._shutdown_workers()
339 |
340 |
341 | class DataLoader(object):
342 | """
343 | Data loader. Combines a dataset and a sampler, and provides
344 | single- or multi-process iterators over the dataset.
345 |
346 | Arguments:
347 | dataset (Dataset): dataset from which to load the data.
348 | batch_size (int, optional): how many samples per batch to load
349 | (default: 1).
350 | shuffle (bool, optional): set to ``True`` to have the data reshuffled
351 | at every epoch (default: False).
352 | sampler (Sampler, optional): defines the strategy to draw samples from
353 | the dataset. If specified, ``shuffle`` must be False.
354 | batch_sampler (Sampler, optional): like sampler, but returns a batch of
355 | indices at a time. Mutually exclusive with batch_size, shuffle,
356 | sampler, and drop_last.
357 | num_workers (int, optional): how many subprocesses to use for data
358 | loading. 0 means that the data will be loaded in the main process.
359 | (default: 0)
360 | collate_fn (callable, optional): merges a list of samples to form a mini-batch.
361 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors
362 | into CUDA pinned memory before returning them.
363 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
364 | if the dataset size is not divisible by the batch size. If ``False`` and
365 | the size of dataset is not divisible by the batch size, then the last batch
366 | will be smaller. (default: False)
367 | timeout (numeric, optional): if positive, the timeout value for collecting a batch
368 | from workers. Should always be non-negative. (default: 0)
369 | worker_init_fn (callable, optional): If not None, this will be called on each
370 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
371 | input, after seeding and before data loading. (default: None)
372 |
373 | .. note:: By default, each worker will have its PyTorch seed set to
374 | ``base_seed + worker_id``, where ``base_seed`` is a long generated
375 | by main process using its RNG. You may use ``torch.initial_seed()`` to access
376 | this value in :attr:`worker_init_fn`, which can be used to set other seeds
377 | (e.g. NumPy) before data loading.
378 |
379 |     .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
380 | unpicklable object, e.g., a lambda function.
381 | """
382 |
383 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
384 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
385 | timeout=0, worker_init_fn=None):
386 | self.dataset = dataset
387 | self.batch_size = batch_size
388 | self.num_workers = num_workers
389 | self.collate_fn = collate_fn
390 | self.pin_memory = pin_memory
391 | self.drop_last = drop_last
392 | self.timeout = timeout
393 | self.worker_init_fn = worker_init_fn
394 |
395 | if timeout < 0:
396 | raise ValueError('timeout option should be non-negative')
397 |
398 | if batch_sampler is not None:
399 | if batch_size > 1 or shuffle or sampler is not None or drop_last:
400 | raise ValueError('batch_sampler is mutually exclusive with '
401 | 'batch_size, shuffle, sampler, and drop_last')
402 |
403 | if sampler is not None and shuffle:
404 | raise ValueError('sampler is mutually exclusive with shuffle')
405 |
406 | if self.num_workers < 0:
407 | raise ValueError('num_workers cannot be negative; '
408 | 'use num_workers=0 to disable multiprocessing.')
409 |
410 | if batch_sampler is None:
411 | if sampler is None:
412 | if shuffle:
413 | sampler = RandomSampler(dataset)
414 | else:
415 | sampler = SequentialSampler(dataset)
416 | batch_sampler = BatchSampler(sampler, batch_size, drop_last)
417 |
418 | self.sampler = sampler
419 | self.batch_sampler = batch_sampler
420 |
421 | def __iter__(self):
422 | return DataLoaderIter(self)
423 |
424 | def __len__(self):
425 | return len(self.batch_sampler)
426 |
--------------------------------------------------------------------------------
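
A small CPU-only sketch (not part of the repository) exercising the vendored DataLoader with the TensorDataset defined in dataset.py below; it assumes a PyTorch version old enough that the torch._six import at the top of this file still resolves.

import torch
from mit_semseg.lib.utils.data import DataLoader, TensorDataset

data = torch.randn(10, 3)
targets = torch.arange(10)
loader = DataLoader(TensorDataset(data, targets), batch_size=4, shuffle=True)

# default_collate stacks per-sample tensors; the last batch holds the 2 leftovers
for xb, yb in loader:
    print(xb.shape, yb.shape)   # torch.Size([4, 3]) torch.Size([4]) ... then [2, 3] and [2]
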
/mit_semseg/lib/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import warnings
3 |
4 | from torch._utils import _accumulate
5 | from torch import randperm
6 |
7 |
8 | class Dataset(object):
9 | """An abstract class representing a Dataset.
10 |
11 | All other datasets should subclass it. All subclasses should override
12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``,
13 | supporting integer indexing in range from 0 to len(self) exclusive.
14 | """
15 |
16 | def __getitem__(self, index):
17 | raise NotImplementedError
18 |
19 | def __len__(self):
20 | raise NotImplementedError
21 |
22 | def __add__(self, other):
23 | return ConcatDataset([self, other])
24 |
25 |
26 | class TensorDataset(Dataset):
27 | """Dataset wrapping data and target tensors.
28 |
29 | Each sample will be retrieved by indexing both tensors along the first
30 | dimension.
31 |
32 | Arguments:
33 | data_tensor (Tensor): contains sample data.
34 | target_tensor (Tensor): contains sample targets (labels).
35 | """
36 |
37 | def __init__(self, data_tensor, target_tensor):
38 | assert data_tensor.size(0) == target_tensor.size(0)
39 | self.data_tensor = data_tensor
40 | self.target_tensor = target_tensor
41 |
42 | def __getitem__(self, index):
43 | return self.data_tensor[index], self.target_tensor[index]
44 |
45 | def __len__(self):
46 | return self.data_tensor.size(0)
47 |
48 |
49 | class ConcatDataset(Dataset):
50 | """
51 | Dataset to concatenate multiple datasets.
52 |     This class is useful for assembling different existing datasets,
53 |     possibly large-scale ones, since the concatenation operation is
54 |     performed on the fly.
55 |
56 | Arguments:
57 | datasets (iterable): List of datasets to be concatenated
58 | """
59 |
60 | @staticmethod
61 | def cumsum(sequence):
62 | r, s = [], 0
63 | for e in sequence:
64 | l = len(e)
65 | r.append(l + s)
66 | s += l
67 | return r
68 |
69 | def __init__(self, datasets):
70 | super(ConcatDataset, self).__init__()
71 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
72 | self.datasets = list(datasets)
73 | self.cumulative_sizes = self.cumsum(self.datasets)
74 |
75 | def __len__(self):
76 | return self.cumulative_sizes[-1]
77 |
78 | def __getitem__(self, idx):
79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
80 | if dataset_idx == 0:
81 | sample_idx = idx
82 | else:
83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
84 | return self.datasets[dataset_idx][sample_idx]
85 |
86 | @property
87 | def cummulative_sizes(self):
88 | warnings.warn("cummulative_sizes attribute is renamed to "
89 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
90 | return self.cumulative_sizes
91 |
92 |
93 | class Subset(Dataset):
94 | def __init__(self, dataset, indices):
95 | self.dataset = dataset
96 | self.indices = indices
97 |
98 | def __getitem__(self, idx):
99 | return self.dataset[self.indices[idx]]
100 |
101 | def __len__(self):
102 | return len(self.indices)
103 |
104 |
105 | def random_split(dataset, lengths):
106 | """
107 |     Randomly split a dataset into non-overlapping new datasets of the given
108 |     lengths.
109 |
110 | Arguments:
111 | dataset (Dataset): Dataset to be split
112 | lengths (iterable): lengths of splits to be produced
113 | """
114 | if sum(lengths) != len(dataset):
115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
116 |
117 | indices = randperm(sum(lengths))
118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
119 |
--------------------------------------------------------------------------------
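
A small sketch of how Dataset, ConcatDataset, Subset and random_split above fit together (toy tensors only):

    import torch
    from mit_semseg.lib.utils.data.dataset import TensorDataset, random_split

    a = TensorDataset(torch.zeros(4, 2), torch.zeros(4))
    b = TensorDataset(torch.ones(6, 2), torch.ones(6))

    # Dataset.__add__ builds a ConcatDataset; index 4 falls into the second
    # dataset because the cumulative sizes are [4, 10]
    combined = a + b
    print(len(combined))        # 10
    print(combined[4][0])       # tensor([1., 1.])

    # random_split returns Subset views whose lengths must sum to len(dataset)
    train, val = random_split(combined, [7, 3])
    print(len(train), len(val)) # 7 3
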
/mit_semseg/lib/utils/data/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from .sampler import Sampler
4 | from torch.distributed import get_world_size, get_rank
5 |
6 |
7 | class DistributedSampler(Sampler):
8 | """Sampler that restricts data loading to a subset of the dataset.
9 |
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 |
15 | .. note::
16 | Dataset is assumed to be of constant size.
17 |
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None):
26 | if num_replicas is None:
27 | num_replicas = get_world_size()
28 | if rank is None:
29 | rank = get_rank()
30 | self.dataset = dataset
31 | self.num_replicas = num_replicas
32 | self.rank = rank
33 | self.epoch = 0
34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
35 | self.total_size = self.num_samples * self.num_replicas
36 |
37 | def __iter__(self):
38 | # deterministically shuffle based on epoch
39 | g = torch.Generator()
40 | g.manual_seed(self.epoch)
41 | indices = list(torch.randperm(len(self.dataset), generator=g))
42 |
43 | # add extra samples to make it evenly divisible
44 | indices += indices[:(self.total_size - len(indices))]
45 | assert len(indices) == self.total_size
46 |
47 | # subsample
48 | offset = self.num_samples * self.rank
49 | indices = indices[offset:offset + self.num_samples]
50 | assert len(indices) == self.num_samples
51 |
52 | return iter(indices)
53 |
54 | def __len__(self):
55 | return self.num_samples
56 |
57 | def set_epoch(self, epoch):
58 | self.epoch = epoch
59 |
--------------------------------------------------------------------------------
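
A sketch of the DistributedSampler above in isolation; passing num_replicas and rank explicitly avoids the torch.distributed calls, which would otherwise require an initialized process group:

    import torch
    from mit_semseg.lib.utils.data.dataset import TensorDataset
    from mit_semseg.lib.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.randn(10, 3), torch.zeros(10))
    sampler = DistributedSampler(dataset, num_replicas=4, rank=1)

    # 10 samples over 4 replicas -> ceil(10 / 4) == 3 indices per replica,
    # padded by repeating the first shuffled indices so every rank gets 3
    print(len(sampler))          # 3

    # call set_epoch before each epoch so every rank reshuffles consistently
    sampler.set_epoch(0)
    print(list(sampler))
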
/mit_semseg/lib/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class Sampler(object):
5 | """Base class for all Samplers.
6 |
7 | Every Sampler subclass has to provide an __iter__ method, providing a way
8 | to iterate over indices of dataset elements, and a __len__ method that
9 |     returns the length of the returned iterator.
10 | """
11 |
12 | def __init__(self, data_source):
13 | pass
14 |
15 | def __iter__(self):
16 | raise NotImplementedError
17 |
18 | def __len__(self):
19 | raise NotImplementedError
20 |
21 |
22 | class SequentialSampler(Sampler):
23 | """Samples elements sequentially, always in the same order.
24 |
25 | Arguments:
26 | data_source (Dataset): dataset to sample from
27 | """
28 |
29 | def __init__(self, data_source):
30 | self.data_source = data_source
31 |
32 | def __iter__(self):
33 | return iter(range(len(self.data_source)))
34 |
35 | def __len__(self):
36 | return len(self.data_source)
37 |
38 |
39 | class RandomSampler(Sampler):
40 | """Samples elements randomly, without replacement.
41 |
42 | Arguments:
43 | data_source (Dataset): dataset to sample from
44 | """
45 |
46 | def __init__(self, data_source):
47 | self.data_source = data_source
48 |
49 | def __iter__(self):
50 | return iter(torch.randperm(len(self.data_source)).long())
51 |
52 | def __len__(self):
53 | return len(self.data_source)
54 |
55 |
56 | class SubsetRandomSampler(Sampler):
57 | """Samples elements randomly from a given list of indices, without replacement.
58 |
59 | Arguments:
60 | indices (list): a list of indices
61 | """
62 |
63 | def __init__(self, indices):
64 | self.indices = indices
65 |
66 | def __iter__(self):
67 | return (self.indices[i] for i in torch.randperm(len(self.indices)))
68 |
69 | def __len__(self):
70 | return len(self.indices)
71 |
72 |
73 | class WeightedRandomSampler(Sampler):
74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
75 |
76 | Arguments:
77 |         weights (list): a list of weights, not necessarily summing up to one
78 | num_samples (int): number of samples to draw
79 | replacement (bool): if ``True``, samples are drawn with replacement.
80 | If not, they are drawn without replacement, which means that when a
81 | sample index is drawn for a row, it cannot be drawn again for that row.
82 | """
83 |
84 | def __init__(self, weights, num_samples, replacement=True):
85 | self.weights = torch.DoubleTensor(weights)
86 | self.num_samples = num_samples
87 | self.replacement = replacement
88 |
89 | def __iter__(self):
90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
91 |
92 | def __len__(self):
93 | return self.num_samples
94 |
95 |
96 | class BatchSampler(object):
97 | """Wraps another sampler to yield a mini-batch of indices.
98 |
99 | Args:
100 | sampler (Sampler): Base sampler.
101 | batch_size (int): Size of mini-batch.
102 | drop_last (bool): If ``True``, the sampler will drop the last batch if
103 | its size would be less than ``batch_size``
104 |
105 | Example:
106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
110 | """
111 |
112 | def __init__(self, sampler, batch_size, drop_last):
113 | self.sampler = sampler
114 | self.batch_size = batch_size
115 | self.drop_last = drop_last
116 |
117 | def __iter__(self):
118 | batch = []
119 | for idx in self.sampler:
120 | batch.append(idx)
121 | if len(batch) == self.batch_size:
122 | yield batch
123 | batch = []
124 | if len(batch) > 0 and not self.drop_last:
125 | yield batch
126 |
127 | def __len__(self):
128 | if self.drop_last:
129 | return len(self.sampler) // self.batch_size
130 | else:
131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size
132 |
--------------------------------------------------------------------------------
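
A short sketch of the samplers above; BatchSampler groups indices from any base sampler, and WeightedRandomSampler draws indices in proportion to the given weights:

    import torch
    from mit_semseg.lib.utils.data.sampler import (
        SequentialSampler, BatchSampler, WeightedRandomSampler)

    data = list(range(10))

    # sequential indices grouped into batches of 4; the last partial batch
    # is kept because drop_last=False
    print(list(BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False)))
    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

    # draw 5 indices from {0, 1, 2} with replacement, biased toward index 2
    weighted = WeightedRandomSampler(weights=[0.1, 0.1, 0.8], num_samples=5)
    print(list(weighted))
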
/mit_semseg/lib/utils/th.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import collections
5 |
6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile']
7 |
8 | def as_variable(obj):
9 | if isinstance(obj, Variable):
10 | return obj
11 | if isinstance(obj, collections.Sequence):
12 | return [as_variable(v) for v in obj]
13 | elif isinstance(obj, collections.Mapping):
14 | return {k: as_variable(v) for k, v in obj.items()}
15 | else:
16 | return Variable(obj)
17 |
18 | def as_numpy(obj):
19 | if isinstance(obj, collections.Sequence):
20 | return [as_numpy(v) for v in obj]
21 | elif isinstance(obj, collections.Mapping):
22 | return {k: as_numpy(v) for k, v in obj.items()}
23 | elif isinstance(obj, Variable):
24 | return obj.data.cpu().numpy()
25 | elif torch.is_tensor(obj):
26 | return obj.cpu().numpy()
27 | else:
28 | return np.array(obj)
29 |
30 | def mark_volatile(obj):
31 | if torch.is_tensor(obj):
32 | obj = Variable(obj)
33 | if isinstance(obj, Variable):
34 | obj.no_grad = True
35 | return obj
36 | elif isinstance(obj, collections.Mapping):
37 | return {k: mark_volatile(o) for k, o in obj.items()}
38 | elif isinstance(obj, collections.Sequence):
39 | return [mark_volatile(o) for o in obj]
40 | else:
41 | return obj
42 |
--------------------------------------------------------------------------------
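
A sketch of as_numpy above, assuming a Python version where collections.Sequence and collections.Mapping still exist (pre-3.10), since the module references them without the .abc prefix:

    import torch
    from mit_semseg.lib.utils.th import as_numpy

    batch = {'img': torch.randn(2, 3), 'label': torch.tensor([0, 1])}

    # tensors are moved to CPU and converted to numpy arrays, recursing
    # through dicts and sequences
    arrays = as_numpy(batch)
    print(type(arrays['img']), arrays['label'])
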
/mit_semseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ModelBuilder, SegmentationModule
2 |
--------------------------------------------------------------------------------
/mit_semseg/models/hrnet.py:
--------------------------------------------------------------------------------
1 | """
2 | This HRNet implementation is modified from the following repository:
3 | https://github.com/HRNet/HRNet-Semantic-Segmentation
4 | """
5 |
6 | import logging
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from .utils import load_url
11 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d
12 |
13 | BatchNorm2d = SynchronizedBatchNorm2d
14 | BN_MOMENTUM = 0.1
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | __all__ = ['hrnetv2']
19 |
20 |
21 | model_urls = {
22 | 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes, out_planes, stride=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 |
32 | class BasicBlock(nn.Module):
33 | expansion = 1
34 |
35 | def __init__(self, inplanes, planes, stride=1, downsample=None):
36 | super(BasicBlock, self).__init__()
37 | self.conv1 = conv3x3(inplanes, planes, stride)
38 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
39 | self.relu = nn.ReLU(inplace=True)
40 | self.conv2 = conv3x3(planes, planes)
41 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
42 | self.downsample = downsample
43 | self.stride = stride
44 |
45 | def forward(self, x):
46 | residual = x
47 |
48 | out = self.conv1(x)
49 | out = self.bn1(out)
50 | out = self.relu(out)
51 |
52 | out = self.conv2(out)
53 | out = self.bn2(out)
54 |
55 | if self.downsample is not None:
56 | residual = self.downsample(x)
57 |
58 | out += residual
59 | out = self.relu(out)
60 |
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, inplanes, planes, stride=1, downsample=None):
68 | super(Bottleneck, self).__init__()
69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
70 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
72 | padding=1, bias=False)
73 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
75 | bias=False)
76 | self.bn3 = BatchNorm2d(planes * self.expansion,
77 | momentum=BN_MOMENTUM)
78 | self.relu = nn.ReLU(inplace=True)
79 | self.downsample = downsample
80 | self.stride = stride
81 |
82 | def forward(self, x):
83 | residual = x
84 |
85 | out = self.conv1(x)
86 | out = self.bn1(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv2(out)
90 | out = self.bn2(out)
91 | out = self.relu(out)
92 |
93 | out = self.conv3(out)
94 | out = self.bn3(out)
95 |
96 | if self.downsample is not None:
97 | residual = self.downsample(x)
98 |
99 | out += residual
100 | out = self.relu(out)
101 |
102 | return out
103 |
104 |
105 | class HighResolutionModule(nn.Module):
106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
107 | num_channels, fuse_method, multi_scale_output=True):
108 | super(HighResolutionModule, self).__init__()
109 | self._check_branches(
110 | num_branches, blocks, num_blocks, num_inchannels, num_channels)
111 |
112 | self.num_inchannels = num_inchannels
113 | self.fuse_method = fuse_method
114 | self.num_branches = num_branches
115 |
116 | self.multi_scale_output = multi_scale_output
117 |
118 | self.branches = self._make_branches(
119 | num_branches, blocks, num_blocks, num_channels)
120 | self.fuse_layers = self._make_fuse_layers()
121 | self.relu = nn.ReLU(inplace=True)
122 |
123 | def _check_branches(self, num_branches, blocks, num_blocks,
124 | num_inchannels, num_channels):
125 | if num_branches != len(num_blocks):
126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
127 | num_branches, len(num_blocks))
128 | logger.error(error_msg)
129 | raise ValueError(error_msg)
130 |
131 | if num_branches != len(num_channels):
132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
133 | num_branches, len(num_channels))
134 | logger.error(error_msg)
135 | raise ValueError(error_msg)
136 |
137 | if num_branches != len(num_inchannels):
138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
139 | num_branches, len(num_inchannels))
140 | logger.error(error_msg)
141 | raise ValueError(error_msg)
142 |
143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
144 | stride=1):
145 | downsample = None
146 | if stride != 1 or \
147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
148 | downsample = nn.Sequential(
149 | nn.Conv2d(self.num_inchannels[branch_index],
150 | num_channels[branch_index] * block.expansion,
151 | kernel_size=1, stride=stride, bias=False),
152 | BatchNorm2d(num_channels[branch_index] * block.expansion,
153 | momentum=BN_MOMENTUM),
154 | )
155 |
156 | layers = []
157 | layers.append(block(self.num_inchannels[branch_index],
158 | num_channels[branch_index], stride, downsample))
159 | self.num_inchannels[branch_index] = \
160 | num_channels[branch_index] * block.expansion
161 | for i in range(1, num_blocks[branch_index]):
162 | layers.append(block(self.num_inchannels[branch_index],
163 | num_channels[branch_index]))
164 |
165 | return nn.Sequential(*layers)
166 |
167 | def _make_branches(self, num_branches, block, num_blocks, num_channels):
168 | branches = []
169 |
170 | for i in range(num_branches):
171 | branches.append(
172 | self._make_one_branch(i, block, num_blocks, num_channels))
173 |
174 | return nn.ModuleList(branches)
175 |
176 | def _make_fuse_layers(self):
177 | if self.num_branches == 1:
178 | return None
179 |
180 | num_branches = self.num_branches
181 | num_inchannels = self.num_inchannels
182 | fuse_layers = []
183 | for i in range(num_branches if self.multi_scale_output else 1):
184 | fuse_layer = []
185 | for j in range(num_branches):
186 | if j > i:
187 | fuse_layer.append(nn.Sequential(
188 | nn.Conv2d(num_inchannels[j],
189 | num_inchannels[i],
190 | 1,
191 | 1,
192 | 0,
193 | bias=False),
194 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM)))
195 | elif j == i:
196 | fuse_layer.append(None)
197 | else:
198 | conv3x3s = []
199 | for k in range(i-j):
200 | if k == i - j - 1:
201 | num_outchannels_conv3x3 = num_inchannels[i]
202 | conv3x3s.append(nn.Sequential(
203 | nn.Conv2d(num_inchannels[j],
204 | num_outchannels_conv3x3,
205 | 3, 2, 1, bias=False),
206 | BatchNorm2d(num_outchannels_conv3x3,
207 | momentum=BN_MOMENTUM)))
208 | else:
209 | num_outchannels_conv3x3 = num_inchannels[j]
210 | conv3x3s.append(nn.Sequential(
211 | nn.Conv2d(num_inchannels[j],
212 | num_outchannels_conv3x3,
213 | 3, 2, 1, bias=False),
214 | BatchNorm2d(num_outchannels_conv3x3,
215 | momentum=BN_MOMENTUM),
216 | nn.ReLU(inplace=True)))
217 | fuse_layer.append(nn.Sequential(*conv3x3s))
218 | fuse_layers.append(nn.ModuleList(fuse_layer))
219 |
220 | return nn.ModuleList(fuse_layers)
221 |
222 | def get_num_inchannels(self):
223 | return self.num_inchannels
224 |
225 | def forward(self, x):
226 | if self.num_branches == 1:
227 | return [self.branches[0](x[0])]
228 |
229 | for i in range(self.num_branches):
230 | x[i] = self.branches[i](x[i])
231 |
232 | x_fuse = []
233 | for i in range(len(self.fuse_layers)):
234 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
235 | for j in range(1, self.num_branches):
236 | if i == j:
237 | y = y + x[j]
238 | elif j > i:
239 | width_output = x[i].shape[-1]
240 | height_output = x[i].shape[-2]
241 | y = y + F.interpolate(
242 | self.fuse_layers[i][j](x[j]),
243 | size=(height_output, width_output),
244 | mode='bilinear',
245 | align_corners=False)
246 | else:
247 | y = y + self.fuse_layers[i][j](x[j])
248 | x_fuse.append(self.relu(y))
249 |
250 | return x_fuse
251 |
252 |
253 | blocks_dict = {
254 | 'BASIC': BasicBlock,
255 | 'BOTTLENECK': Bottleneck
256 | }
257 |
258 |
259 | class HRNetV2(nn.Module):
260 | def __init__(self, n_class, **kwargs):
261 | super(HRNetV2, self).__init__()
262 | extra = {
263 | 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'},
264 | 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'},
265 | 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'},
266 | 'FINAL_CONV_KERNEL': 1
267 | }
268 |
269 | # stem net
270 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
271 | bias=False)
272 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
273 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
274 | bias=False)
275 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
276 | self.relu = nn.ReLU(inplace=True)
277 |
278 | self.layer1 = self._make_layer(Bottleneck, 64, 64, 4)
279 |
280 | self.stage2_cfg = extra['STAGE2']
281 | num_channels = self.stage2_cfg['NUM_CHANNELS']
282 | block = blocks_dict[self.stage2_cfg['BLOCK']]
283 | num_channels = [
284 | num_channels[i] * block.expansion for i in range(len(num_channels))]
285 | self.transition1 = self._make_transition_layer([256], num_channels)
286 | self.stage2, pre_stage_channels = self._make_stage(
287 | self.stage2_cfg, num_channels)
288 |
289 | self.stage3_cfg = extra['STAGE3']
290 | num_channels = self.stage3_cfg['NUM_CHANNELS']
291 | block = blocks_dict[self.stage3_cfg['BLOCK']]
292 | num_channels = [
293 | num_channels[i] * block.expansion for i in range(len(num_channels))]
294 | self.transition2 = self._make_transition_layer(
295 | pre_stage_channels, num_channels)
296 | self.stage3, pre_stage_channels = self._make_stage(
297 | self.stage3_cfg, num_channels)
298 |
299 | self.stage4_cfg = extra['STAGE4']
300 | num_channels = self.stage4_cfg['NUM_CHANNELS']
301 | block = blocks_dict[self.stage4_cfg['BLOCK']]
302 | num_channels = [
303 | num_channels[i] * block.expansion for i in range(len(num_channels))]
304 | self.transition3 = self._make_transition_layer(
305 | pre_stage_channels, num_channels)
306 | self.stage4, pre_stage_channels = self._make_stage(
307 | self.stage4_cfg, num_channels, multi_scale_output=True)
308 |
309 | def _make_transition_layer(
310 | self, num_channels_pre_layer, num_channels_cur_layer):
311 | num_branches_cur = len(num_channels_cur_layer)
312 | num_branches_pre = len(num_channels_pre_layer)
313 |
314 | transition_layers = []
315 | for i in range(num_branches_cur):
316 | if i < num_branches_pre:
317 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
318 | transition_layers.append(nn.Sequential(
319 | nn.Conv2d(num_channels_pre_layer[i],
320 | num_channels_cur_layer[i],
321 | 3,
322 | 1,
323 | 1,
324 | bias=False),
325 | BatchNorm2d(
326 | num_channels_cur_layer[i], momentum=BN_MOMENTUM),
327 | nn.ReLU(inplace=True)))
328 | else:
329 | transition_layers.append(None)
330 | else:
331 | conv3x3s = []
332 | for j in range(i+1-num_branches_pre):
333 | inchannels = num_channels_pre_layer[-1]
334 | outchannels = num_channels_cur_layer[i] \
335 | if j == i-num_branches_pre else inchannels
336 | conv3x3s.append(nn.Sequential(
337 | nn.Conv2d(
338 | inchannels, outchannels, 3, 2, 1, bias=False),
339 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
340 | nn.ReLU(inplace=True)))
341 | transition_layers.append(nn.Sequential(*conv3x3s))
342 |
343 | return nn.ModuleList(transition_layers)
344 |
345 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
346 | downsample = None
347 | if stride != 1 or inplanes != planes * block.expansion:
348 | downsample = nn.Sequential(
349 | nn.Conv2d(inplanes, planes * block.expansion,
350 | kernel_size=1, stride=stride, bias=False),
351 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
352 | )
353 |
354 | layers = []
355 | layers.append(block(inplanes, planes, stride, downsample))
356 | inplanes = planes * block.expansion
357 | for i in range(1, blocks):
358 | layers.append(block(inplanes, planes))
359 |
360 | return nn.Sequential(*layers)
361 |
362 | def _make_stage(self, layer_config, num_inchannels,
363 | multi_scale_output=True):
364 | num_modules = layer_config['NUM_MODULES']
365 | num_branches = layer_config['NUM_BRANCHES']
366 | num_blocks = layer_config['NUM_BLOCKS']
367 | num_channels = layer_config['NUM_CHANNELS']
368 | block = blocks_dict[layer_config['BLOCK']]
369 | fuse_method = layer_config['FUSE_METHOD']
370 |
371 | modules = []
372 | for i in range(num_modules):
373 |             # multi_scale_output is only used by the last module
374 | if not multi_scale_output and i == num_modules - 1:
375 | reset_multi_scale_output = False
376 | else:
377 | reset_multi_scale_output = True
378 | modules.append(
379 | HighResolutionModule(
380 | num_branches,
381 | block,
382 | num_blocks,
383 | num_inchannels,
384 | num_channels,
385 | fuse_method,
386 | reset_multi_scale_output)
387 | )
388 | num_inchannels = modules[-1].get_num_inchannels()
389 |
390 | return nn.Sequential(*modules), num_inchannels
391 |
392 | def forward(self, x, return_feature_maps=False):
393 | x = self.conv1(x)
394 | x = self.bn1(x)
395 | x = self.relu(x)
396 | x = self.conv2(x)
397 | x = self.bn2(x)
398 | x = self.relu(x)
399 | x = self.layer1(x)
400 |
401 | x_list = []
402 | for i in range(self.stage2_cfg['NUM_BRANCHES']):
403 | if self.transition1[i] is not None:
404 | x_list.append(self.transition1[i](x))
405 | else:
406 | x_list.append(x)
407 | y_list = self.stage2(x_list)
408 |
409 | x_list = []
410 | for i in range(self.stage3_cfg['NUM_BRANCHES']):
411 | if self.transition2[i] is not None:
412 | x_list.append(self.transition2[i](y_list[-1]))
413 | else:
414 | x_list.append(y_list[i])
415 | y_list = self.stage3(x_list)
416 |
417 | x_list = []
418 | for i in range(self.stage4_cfg['NUM_BRANCHES']):
419 | if self.transition3[i] is not None:
420 | x_list.append(self.transition3[i](y_list[-1]))
421 | else:
422 | x_list.append(y_list[i])
423 | x = self.stage4(x_list)
424 |
425 | # Upsampling
426 | x0_h, x0_w = x[0].size(2), x[0].size(3)
427 | x1 = F.interpolate(
428 | x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
429 | x2 = F.interpolate(
430 | x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
431 | x3 = F.interpolate(
432 | x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
433 |
434 | x = torch.cat([x[0], x1, x2, x3], 1)
435 |
436 | # x = self.last_layer(x)
437 | return [x]
438 |
439 |
440 | def hrnetv2(pretrained=False, **kwargs):
441 | model = HRNetV2(n_class=1000, **kwargs)
442 | if pretrained:
443 | model.load_state_dict(load_url(model_urls['hrnetv2']), strict=False)
444 |
445 | return model
446 |
--------------------------------------------------------------------------------
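
A quick shape check for the hrnetv2 encoder above (pretrained=False skips the hrnetv2_w48-imagenet.pth download; the 128x128 input is arbitrary):

    import torch
    from mit_semseg.models.hrnet import hrnetv2

    encoder = hrnetv2(pretrained=False)
    encoder.eval()

    # the stem downsamples by 4; the four branch outputs are upsampled back
    # to that resolution and concatenated: 48 + 96 + 192 + 384 = 720 channels
    with torch.no_grad():
        feats = encoder(torch.randn(1, 3, 128, 128))
    print(feats[0].shape)        # torch.Size([1, 720, 32, 32])
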
/mit_semseg/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | """
2 | This MobileNetV2 implementation is modified from the following repository:
3 | https://github.com/tonylins/pytorch-mobilenet-v2
4 | """
5 |
6 | import torch.nn as nn
7 | import math
8 | from .utils import load_url
9 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d
10 |
11 | BatchNorm2d = SynchronizedBatchNorm2d
12 |
13 |
14 | __all__ = ['mobilenetv2']
15 |
16 |
17 | model_urls = {
18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
19 | }
20 |
21 |
22 | def conv_bn(inp, oup, stride):
23 | return nn.Sequential(
24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
25 | BatchNorm2d(oup),
26 | nn.ReLU6(inplace=True)
27 | )
28 |
29 |
30 | def conv_1x1_bn(inp, oup):
31 | return nn.Sequential(
32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
33 | BatchNorm2d(oup),
34 | nn.ReLU6(inplace=True)
35 | )
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride, expand_ratio):
40 | super(InvertedResidual, self).__init__()
41 | self.stride = stride
42 | assert stride in [1, 2]
43 |
44 | hidden_dim = round(inp * expand_ratio)
45 | self.use_res_connect = self.stride == 1 and inp == oup
46 |
47 | if expand_ratio == 1:
48 | self.conv = nn.Sequential(
49 | # dw
50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
51 | BatchNorm2d(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # pw-linear
54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
55 | BatchNorm2d(oup),
56 | )
57 | else:
58 | self.conv = nn.Sequential(
59 | # pw
60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
61 | BatchNorm2d(hidden_dim),
62 | nn.ReLU6(inplace=True),
63 | # dw
64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
65 | BatchNorm2d(hidden_dim),
66 | nn.ReLU6(inplace=True),
67 | # pw-linear
68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
69 | BatchNorm2d(oup),
70 | )
71 |
72 | def forward(self, x):
73 | if self.use_res_connect:
74 | return x + self.conv(x)
75 | else:
76 | return self.conv(x)
77 |
78 |
79 | class MobileNetV2(nn.Module):
80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
81 | super(MobileNetV2, self).__init__()
82 | block = InvertedResidual
83 | input_channel = 32
84 | last_channel = 1280
85 | interverted_residual_setting = [
86 | # t, c, n, s
87 | [1, 16, 1, 1],
88 | [6, 24, 2, 2],
89 | [6, 32, 3, 2],
90 | [6, 64, 4, 2],
91 | [6, 96, 3, 1],
92 | [6, 160, 3, 2],
93 | [6, 320, 1, 1],
94 | ]
95 |
96 | # building first layer
97 | assert input_size % 32 == 0
98 | input_channel = int(input_channel * width_mult)
99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
100 | self.features = [conv_bn(3, input_channel, 2)]
101 | # building inverted residual blocks
102 | for t, c, n, s in interverted_residual_setting:
103 | output_channel = int(c * width_mult)
104 | for i in range(n):
105 | if i == 0:
106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
107 | else:
108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
109 | input_channel = output_channel
110 | # building last several layers
111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
112 | # make it nn.Sequential
113 | self.features = nn.Sequential(*self.features)
114 |
115 | # building classifier
116 | self.classifier = nn.Sequential(
117 | nn.Dropout(0.2),
118 | nn.Linear(self.last_channel, n_class),
119 | )
120 |
121 | self._initialize_weights()
122 |
123 | def forward(self, x):
124 | x = self.features(x)
125 | x = x.mean(3).mean(2)
126 | x = self.classifier(x)
127 | return x
128 |
129 | def _initialize_weights(self):
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
133 | m.weight.data.normal_(0, math.sqrt(2. / n))
134 | if m.bias is not None:
135 | m.bias.data.zero_()
136 | elif isinstance(m, BatchNorm2d):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 | elif isinstance(m, nn.Linear):
140 | n = m.weight.size(1)
141 | m.weight.data.normal_(0, 0.01)
142 | m.bias.data.zero_()
143 |
144 |
145 | def mobilenetv2(pretrained=False, **kwargs):
146 | """Constructs a MobileNet_V2 model.
147 |
148 | Args:
149 | pretrained (bool): If True, returns a model pre-trained on ImageNet
150 | """
151 | model = MobileNetV2(n_class=1000, **kwargs)
152 | if pretrained:
153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
154 | return model
155 |
--------------------------------------------------------------------------------
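
A quick sanity check for mobilenetv2 above (pretrained=False skips the download; the constructor asserts that input_size is divisible by 32):

    import torch
    from mit_semseg.models.mobilenet import mobilenetv2

    model = mobilenetv2(pretrained=False)
    model.eval()

    # features -> global average pool over H and W -> 1280-d vector -> classifier
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)          # torch.Size([1, 1000])
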
/mit_semseg/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from . import resnet, resnext, mobilenet, hrnet
4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d
5 | BatchNorm2d = SynchronizedBatchNorm2d
6 |
7 |
8 | class SegmentationModuleBase(nn.Module):
9 | def __init__(self):
10 | super(SegmentationModuleBase, self).__init__()
11 |
12 | def pixel_acc(self, pred, label):
13 | _, preds = torch.max(pred, dim=1)
14 | valid = (label >= 0).long()
15 | acc_sum = torch.sum(valid * (preds == label).long())
16 | pixel_sum = torch.sum(valid)
17 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10)
18 | return acc
19 |
20 |
21 | class SegmentationModule(SegmentationModuleBase):
22 | def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
23 | super(SegmentationModule, self).__init__()
24 | self.encoder = net_enc
25 | self.decoder = net_dec
26 | self.crit = crit
27 | self.deep_sup_scale = deep_sup_scale
28 |
29 | def forward(self, feed_dict, *, segSize=None):
30 | # training
31 | if segSize is None:
32 | if self.deep_sup_scale is not None: # use deep supervision technique
33 | (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
34 | else:
35 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True))
36 |
37 | loss = self.crit(pred, feed_dict['seg_label'])
38 | if self.deep_sup_scale is not None:
39 | loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label'])
40 | loss = loss + loss_deepsup * self.deep_sup_scale
41 |
42 | acc = self.pixel_acc(pred, feed_dict['seg_label'])
43 | return loss, acc
44 | # inference
45 | else:
46 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize)
47 | return pred
48 |
49 |
50 | class ModelBuilder:
51 | # custom weights initialization
52 | @staticmethod
53 | def weights_init(m):
54 | classname = m.__class__.__name__
55 | if classname.find('Conv') != -1:
56 | nn.init.kaiming_normal_(m.weight.data)
57 | elif classname.find('BatchNorm') != -1:
58 | m.weight.data.fill_(1.)
59 | m.bias.data.fill_(1e-4)
60 | #elif classname.find('Linear') != -1:
61 | # m.weight.data.normal_(0.0, 0.0001)
62 |
63 | @staticmethod
64 | def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
65 | pretrained = True if len(weights) == 0 else False
66 | arch = arch.lower()
67 | if arch == 'mobilenetv2dilated':
68 | orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
69 | net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
70 | elif arch == 'resnet18':
71 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
72 | net_encoder = Resnet(orig_resnet)
73 | elif arch == 'resnet18dilated':
74 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
75 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
76 | elif arch == 'resnet34':
77 | raise NotImplementedError
78 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
79 | net_encoder = Resnet(orig_resnet)
80 | elif arch == 'resnet34dilated':
81 | raise NotImplementedError
82 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained)
83 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
84 | elif arch == 'resnet50':
85 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
86 | net_encoder = Resnet(orig_resnet)
87 | elif arch == 'resnet50dilated':
88 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
89 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
90 | elif arch == 'resnet101':
91 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
92 | net_encoder = Resnet(orig_resnet)
93 | elif arch == 'resnet101dilated':
94 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained)
95 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
96 | elif arch == 'resnext101':
97 | orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained)
98 | net_encoder = Resnet(orig_resnext) # we can still use class Resnet
99 | elif arch == 'hrnetv2':
100 | net_encoder = hrnet.__dict__['hrnetv2'](pretrained=pretrained)
101 | else:
102 | raise Exception('Architecture undefined!')
103 |
104 | # encoders are usually pretrained
105 | # net_encoder.apply(ModelBuilder.weights_init)
106 | if len(weights) > 0:
107 | print('Loading weights for net_encoder')
108 | net_encoder.load_state_dict(
109 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
110 | return net_encoder
111 |
112 | @staticmethod
113 | def build_decoder(arch='ppm_deepsup',
114 | fc_dim=512, num_class=150,
115 | weights='', use_softmax=False):
116 | arch = arch.lower()
117 | if arch == 'c1_deepsup':
118 | net_decoder = C1DeepSup(
119 | num_class=num_class,
120 | fc_dim=fc_dim,
121 | use_softmax=use_softmax)
122 | elif arch == 'c1':
123 | net_decoder = C1(
124 | num_class=num_class,
125 | fc_dim=fc_dim,
126 | use_softmax=use_softmax)
127 | elif arch == 'ppm':
128 | net_decoder = PPM(
129 | num_class=num_class,
130 | fc_dim=fc_dim,
131 | use_softmax=use_softmax)
132 | elif arch == 'ppm_deepsup':
133 | net_decoder = PPMDeepsup(
134 | num_class=num_class,
135 | fc_dim=fc_dim,
136 | use_softmax=use_softmax)
137 | elif arch == 'upernet_lite':
138 | net_decoder = UPerNet(
139 | num_class=num_class,
140 | fc_dim=fc_dim,
141 | use_softmax=use_softmax,
142 | fpn_dim=256)
143 | elif arch == 'upernet':
144 | net_decoder = UPerNet(
145 | num_class=num_class,
146 | fc_dim=fc_dim,
147 | use_softmax=use_softmax,
148 | fpn_dim=512)
149 | else:
150 | raise Exception('Architecture undefined!')
151 |
152 | net_decoder.apply(ModelBuilder.weights_init)
153 | if len(weights) > 0:
154 | print('Loading weights for net_decoder')
155 | net_decoder.load_state_dict(
156 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
157 | return net_decoder
158 |
159 |
160 | def conv3x3_bn_relu(in_planes, out_planes, stride=1):
161 | "3x3 convolution + BN + relu"
162 | return nn.Sequential(
163 | nn.Conv2d(in_planes, out_planes, kernel_size=3,
164 | stride=stride, padding=1, bias=False),
165 | BatchNorm2d(out_planes),
166 | nn.ReLU(inplace=True),
167 | )
168 |
169 |
170 | class Resnet(nn.Module):
171 | def __init__(self, orig_resnet):
172 | super(Resnet, self).__init__()
173 |
174 | # take pretrained resnet, except AvgPool and FC
175 | self.conv1 = orig_resnet.conv1
176 | self.bn1 = orig_resnet.bn1
177 | self.relu1 = orig_resnet.relu1
178 | self.conv2 = orig_resnet.conv2
179 | self.bn2 = orig_resnet.bn2
180 | self.relu2 = orig_resnet.relu2
181 | self.conv3 = orig_resnet.conv3
182 | self.bn3 = orig_resnet.bn3
183 | self.relu3 = orig_resnet.relu3
184 | self.maxpool = orig_resnet.maxpool
185 | self.layer1 = orig_resnet.layer1
186 | self.layer2 = orig_resnet.layer2
187 | self.layer3 = orig_resnet.layer3
188 | self.layer4 = orig_resnet.layer4
189 |
190 | def forward(self, x, return_feature_maps=False):
191 | conv_out = []
192 |
193 | x = self.relu1(self.bn1(self.conv1(x)))
194 | x = self.relu2(self.bn2(self.conv2(x)))
195 | x = self.relu3(self.bn3(self.conv3(x)))
196 | x = self.maxpool(x)
197 |
198 | x = self.layer1(x); conv_out.append(x);
199 | x = self.layer2(x); conv_out.append(x);
200 | x = self.layer3(x); conv_out.append(x);
201 | x = self.layer4(x); conv_out.append(x);
202 |
203 | if return_feature_maps:
204 | return conv_out
205 | return [x]
206 |
207 |
208 | class ResnetDilated(nn.Module):
209 | def __init__(self, orig_resnet, dilate_scale=8):
210 | super(ResnetDilated, self).__init__()
211 | from functools import partial
212 |
213 | if dilate_scale == 8:
214 | orig_resnet.layer3.apply(
215 | partial(self._nostride_dilate, dilate=2))
216 | orig_resnet.layer4.apply(
217 | partial(self._nostride_dilate, dilate=4))
218 | elif dilate_scale == 16:
219 | orig_resnet.layer4.apply(
220 | partial(self._nostride_dilate, dilate=2))
221 |
222 | # take pretrained resnet, except AvgPool and FC
223 | self.conv1 = orig_resnet.conv1
224 | self.bn1 = orig_resnet.bn1
225 | self.relu1 = orig_resnet.relu1
226 | self.conv2 = orig_resnet.conv2
227 | self.bn2 = orig_resnet.bn2
228 | self.relu2 = orig_resnet.relu2
229 | self.conv3 = orig_resnet.conv3
230 | self.bn3 = orig_resnet.bn3
231 | self.relu3 = orig_resnet.relu3
232 | self.maxpool = orig_resnet.maxpool
233 | self.layer1 = orig_resnet.layer1
234 | self.layer2 = orig_resnet.layer2
235 | self.layer3 = orig_resnet.layer3
236 | self.layer4 = orig_resnet.layer4
237 |
238 | def _nostride_dilate(self, m, dilate):
239 | classname = m.__class__.__name__
240 | if classname.find('Conv') != -1:
241 | # the convolution with stride
242 | if m.stride == (2, 2):
243 | m.stride = (1, 1)
244 | if m.kernel_size == (3, 3):
245 | m.dilation = (dilate//2, dilate//2)
246 | m.padding = (dilate//2, dilate//2)
247 |             # other convolutions
248 | else:
249 | if m.kernel_size == (3, 3):
250 | m.dilation = (dilate, dilate)
251 | m.padding = (dilate, dilate)
252 |
253 | def forward(self, x, return_feature_maps=False):
254 | conv_out = []
255 |
256 | x = self.relu1(self.bn1(self.conv1(x)))
257 | x = self.relu2(self.bn2(self.conv2(x)))
258 | x = self.relu3(self.bn3(self.conv3(x)))
259 | x = self.maxpool(x)
260 |
261 | x = self.layer1(x); conv_out.append(x);
262 | x = self.layer2(x); conv_out.append(x);
263 | x = self.layer3(x); conv_out.append(x);
264 | x = self.layer4(x); conv_out.append(x);
265 |
266 | if return_feature_maps:
267 | return conv_out
268 | return [x]
269 |
270 |
271 | class MobileNetV2Dilated(nn.Module):
272 | def __init__(self, orig_net, dilate_scale=8):
273 | super(MobileNetV2Dilated, self).__init__()
274 | from functools import partial
275 |
276 | # take pretrained mobilenet features
277 | self.features = orig_net.features[:-1]
278 |
279 | self.total_idx = len(self.features)
280 | self.down_idx = [2, 4, 7, 14]
281 |
282 | if dilate_scale == 8:
283 | for i in range(self.down_idx[-2], self.down_idx[-1]):
284 | self.features[i].apply(
285 | partial(self._nostride_dilate, dilate=2)
286 | )
287 | for i in range(self.down_idx[-1], self.total_idx):
288 | self.features[i].apply(
289 | partial(self._nostride_dilate, dilate=4)
290 | )
291 | elif dilate_scale == 16:
292 | for i in range(self.down_idx[-1], self.total_idx):
293 | self.features[i].apply(
294 | partial(self._nostride_dilate, dilate=2)
295 | )
296 |
297 | def _nostride_dilate(self, m, dilate):
298 | classname = m.__class__.__name__
299 | if classname.find('Conv') != -1:
300 | # the convolution with stride
301 | if m.stride == (2, 2):
302 | m.stride = (1, 1)
303 | if m.kernel_size == (3, 3):
304 | m.dilation = (dilate//2, dilate//2)
305 | m.padding = (dilate//2, dilate//2)
306 |             # other convolutions
307 | else:
308 | if m.kernel_size == (3, 3):
309 | m.dilation = (dilate, dilate)
310 | m.padding = (dilate, dilate)
311 |
312 | def forward(self, x, return_feature_maps=False):
313 | if return_feature_maps:
314 | conv_out = []
315 | for i in range(self.total_idx):
316 | x = self.features[i](x)
317 | if i in self.down_idx:
318 | conv_out.append(x)
319 | conv_out.append(x)
320 | return conv_out
321 |
322 | else:
323 | return [self.features(x)]
324 |
325 |
326 | # last conv, deep supervision
327 | class C1DeepSup(nn.Module):
328 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
329 | super(C1DeepSup, self).__init__()
330 | self.use_softmax = use_softmax
331 |
332 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
333 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
334 |
335 | # last conv
336 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
337 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
338 |
339 | def forward(self, conv_out, segSize=None):
340 | conv5 = conv_out[-1]
341 |
342 | x = self.cbr(conv5)
343 | x = self.conv_last(x)
344 |
345 | if self.use_softmax: # is True during inference
346 | x = nn.functional.interpolate(
347 | x, size=segSize, mode='bilinear', align_corners=False)
348 | x = nn.functional.softmax(x, dim=1)
349 | return x
350 |
351 | # deep sup
352 | conv4 = conv_out[-2]
353 | _ = self.cbr_deepsup(conv4)
354 | _ = self.conv_last_deepsup(_)
355 |
356 | x = nn.functional.log_softmax(x, dim=1)
357 | _ = nn.functional.log_softmax(_, dim=1)
358 |
359 | return (x, _)
360 |
361 |
362 | # last conv
363 | class C1(nn.Module):
364 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
365 | super(C1, self).__init__()
366 | self.use_softmax = use_softmax
367 |
368 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
369 |
370 | # last conv
371 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
372 |
373 | def forward(self, conv_out, segSize=None):
374 | conv5 = conv_out[-1]
375 | x = self.cbr(conv5)
376 | x = self.conv_last(x)
377 |
378 | if self.use_softmax: # is True during inference
379 | x = nn.functional.interpolate(
380 | x, size=segSize, mode='bilinear', align_corners=False)
381 | x = nn.functional.softmax(x, dim=1)
382 | else:
383 | x = nn.functional.log_softmax(x, dim=1)
384 |
385 | return x
386 |
387 |
388 | # pyramid pooling
389 | class PPM(nn.Module):
390 | def __init__(self, num_class=150, fc_dim=4096,
391 | use_softmax=False, pool_scales=(1, 2, 3, 6)):
392 | super(PPM, self).__init__()
393 | self.use_softmax = use_softmax
394 |
395 | self.ppm = []
396 | for scale in pool_scales:
397 | self.ppm.append(nn.Sequential(
398 | nn.AdaptiveAvgPool2d(scale),
399 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
400 | BatchNorm2d(512),
401 | nn.ReLU(inplace=True)
402 | ))
403 | self.ppm = nn.ModuleList(self.ppm)
404 |
405 | self.conv_last = nn.Sequential(
406 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
407 | kernel_size=3, padding=1, bias=False),
408 | BatchNorm2d(512),
409 | nn.ReLU(inplace=True),
410 | nn.Dropout2d(0.1),
411 | nn.Conv2d(512, num_class, kernel_size=1)
412 | )
413 |
414 | def forward(self, conv_out, segSize=None):
415 | conv5 = conv_out[-1]
416 |
417 | input_size = conv5.size()
418 | ppm_out = [conv5]
419 | for pool_scale in self.ppm:
420 | ppm_out.append(nn.functional.interpolate(
421 | pool_scale(conv5),
422 | (input_size[2], input_size[3]),
423 | mode='bilinear', align_corners=False))
424 | ppm_out = torch.cat(ppm_out, 1)
425 |
426 | x = self.conv_last(ppm_out)
427 |
428 | if self.use_softmax: # is True during inference
429 | x = nn.functional.interpolate(
430 | x, size=segSize, mode='bilinear', align_corners=False)
431 | x = nn.functional.softmax(x, dim=1)
432 | else:
433 | x = nn.functional.log_softmax(x, dim=1)
434 | return x
435 |
436 |
437 | # pyramid pooling, deep supervision
438 | class PPMDeepsup(nn.Module):
439 | def __init__(self, num_class=150, fc_dim=4096,
440 | use_softmax=False, pool_scales=(1, 2, 3, 6)):
441 | super(PPMDeepsup, self).__init__()
442 | self.use_softmax = use_softmax
443 |
444 | self.ppm = []
445 | for scale in pool_scales:
446 | self.ppm.append(nn.Sequential(
447 | nn.AdaptiveAvgPool2d(scale),
448 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
449 | BatchNorm2d(512),
450 | nn.ReLU(inplace=True)
451 | ))
452 | self.ppm = nn.ModuleList(self.ppm)
453 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
454 |
455 | self.conv_last = nn.Sequential(
456 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
457 | kernel_size=3, padding=1, bias=False),
458 | BatchNorm2d(512),
459 | nn.ReLU(inplace=True),
460 | nn.Dropout2d(0.1),
461 | nn.Conv2d(512, num_class, kernel_size=1)
462 | )
463 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
464 | self.dropout_deepsup = nn.Dropout2d(0.1)
465 |
466 | def forward(self, conv_out, segSize=None):
467 | conv5 = conv_out[-1]
468 |
469 | input_size = conv5.size()
470 | ppm_out = [conv5]
471 | for pool_scale in self.ppm:
472 | ppm_out.append(nn.functional.interpolate(
473 | pool_scale(conv5),
474 | (input_size[2], input_size[3]),
475 | mode='bilinear', align_corners=False))
476 | ppm_out = torch.cat(ppm_out, 1)
477 |
478 | x = self.conv_last(ppm_out)
479 |
480 | if self.use_softmax: # is True during inference
481 | x = nn.functional.interpolate(
482 | x, size=segSize, mode='bilinear', align_corners=False)
483 | x = nn.functional.softmax(x, dim=1)
484 | return x
485 |
486 | # deep sup
487 | conv4 = conv_out[-2]
488 | _ = self.cbr_deepsup(conv4)
489 | _ = self.dropout_deepsup(_)
490 | _ = self.conv_last_deepsup(_)
491 |
492 | x = nn.functional.log_softmax(x, dim=1)
493 | _ = nn.functional.log_softmax(_, dim=1)
494 |
495 | return (x, _)
496 |
497 |
498 | # upernet
499 | class UPerNet(nn.Module):
500 | def __init__(self, num_class=150, fc_dim=4096,
501 | use_softmax=False, pool_scales=(1, 2, 3, 6),
502 | fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256):
503 | super(UPerNet, self).__init__()
504 | self.use_softmax = use_softmax
505 |
506 | # PPM Module
507 | self.ppm_pooling = []
508 | self.ppm_conv = []
509 |
510 | for scale in pool_scales:
511 | self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
512 | self.ppm_conv.append(nn.Sequential(
513 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
514 | BatchNorm2d(512),
515 | nn.ReLU(inplace=True)
516 | ))
517 | self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
518 | self.ppm_conv = nn.ModuleList(self.ppm_conv)
519 | self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1)
520 |
521 | # FPN Module
522 | self.fpn_in = []
523 | for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer
524 | self.fpn_in.append(nn.Sequential(
525 | nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False),
526 | BatchNorm2d(fpn_dim),
527 | nn.ReLU(inplace=True)
528 | ))
529 | self.fpn_in = nn.ModuleList(self.fpn_in)
530 |
531 | self.fpn_out = []
532 | for i in range(len(fpn_inplanes) - 1): # skip the top layer
533 | self.fpn_out.append(nn.Sequential(
534 | conv3x3_bn_relu(fpn_dim, fpn_dim, 1),
535 | ))
536 | self.fpn_out = nn.ModuleList(self.fpn_out)
537 |
538 | self.conv_last = nn.Sequential(
539 | conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
540 | nn.Conv2d(fpn_dim, num_class, kernel_size=1)
541 | )
542 |
543 | def forward(self, conv_out, segSize=None):
544 | conv5 = conv_out[-1]
545 |
546 | input_size = conv5.size()
547 | ppm_out = [conv5]
548 | for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
549 | ppm_out.append(pool_conv(nn.functional.interpolate(
550 | pool_scale(conv5),
551 | (input_size[2], input_size[3]),
552 | mode='bilinear', align_corners=False)))
553 | ppm_out = torch.cat(ppm_out, 1)
554 | f = self.ppm_last_conv(ppm_out)
555 |
556 | fpn_feature_list = [f]
557 | for i in reversed(range(len(conv_out) - 1)):
558 | conv_x = conv_out[i]
559 | conv_x = self.fpn_in[i](conv_x) # lateral branch
560 |
561 | f = nn.functional.interpolate(
562 | f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch
563 | f = conv_x + f
564 |
565 | fpn_feature_list.append(self.fpn_out[i](f))
566 |
567 | fpn_feature_list.reverse() # [P2 - P5]
568 | output_size = fpn_feature_list[0].size()[2:]
569 | fusion_list = [fpn_feature_list[0]]
570 | for i in range(1, len(fpn_feature_list)):
571 | fusion_list.append(nn.functional.interpolate(
572 | fpn_feature_list[i],
573 | output_size,
574 | mode='bilinear', align_corners=False))
575 | fusion_out = torch.cat(fusion_list, 1)
576 | x = self.conv_last(fusion_out)
577 |
578 | if self.use_softmax: # is True during inference
579 | x = nn.functional.interpolate(
580 | x, size=segSize, mode='bilinear', align_corners=False)
581 | x = nn.functional.softmax(x, dim=1)
582 | return x
583 |
584 | x = nn.functional.log_softmax(x, dim=1)
585 |
586 | return x
587 |
--------------------------------------------------------------------------------
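
A minimal inference sketch for ModelBuilder and SegmentationModule above; with weights='' the encoder falls back to downloading its ImageNet-pretrained backbone, and NLLLoss is one loss that pairs with the log_softmax the decoders emit during training:

    import torch
    import torch.nn as nn
    from mit_semseg.models import ModelBuilder, SegmentationModule

    net_encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated', fc_dim=2048, weights='')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup', fc_dim=2048, num_class=150,
        weights='', use_softmax=True)
    crit = nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
    segmentation_module.eval()

    # passing segSize switches the module to inference: the decoder
    # upsamples to that size and applies softmax over the 150 classes
    img = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        scores = segmentation_module({'img_data': img}, segSize=(256, 256))
    print(scores.shape)          # torch.Size([1, 150, 256, 256])
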
/mit_semseg/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from .utils import load_url
4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d
5 | BatchNorm2d = SynchronizedBatchNorm2d
6 |
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth',
13 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
14 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth'
15 | }
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | "3x3 convolution with padding"
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None):
28 | super(BasicBlock, self).__init__()
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn1 = BatchNorm2d(planes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.conv2 = conv3x3(planes, planes)
33 | self.bn2 = BatchNorm2d(planes)
34 | self.downsample = downsample
35 | self.stride = stride
36 |
37 | def forward(self, x):
38 | residual = x
39 |
40 | out = self.conv1(x)
41 | out = self.bn1(out)
42 | out = self.relu(out)
43 |
44 | out = self.conv2(out)
45 | out = self.bn2(out)
46 |
47 | if self.downsample is not None:
48 | residual = self.downsample(x)
49 |
50 | out += residual
51 | out = self.relu(out)
52 |
53 | return out
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 4
58 |
59 | def __init__(self, inplanes, planes, stride=1, downsample=None):
60 | super(Bottleneck, self).__init__()
61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62 | self.bn1 = BatchNorm2d(planes)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64 | padding=1, bias=False)
65 | self.bn2 = BatchNorm2d(planes)
66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
67 | self.bn3 = BatchNorm2d(planes * 4)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x):
73 | residual = x
74 |
75 | out = self.conv1(x)
76 | out = self.bn1(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv2(out)
80 | out = self.bn2(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class ResNet(nn.Module):
96 |
97 | def __init__(self, block, layers, num_classes=1000):
98 | self.inplanes = 128
99 | super(ResNet, self).__init__()
100 | self.conv1 = conv3x3(3, 64, stride=2)
101 | self.bn1 = BatchNorm2d(64)
102 | self.relu1 = nn.ReLU(inplace=True)
103 | self.conv2 = conv3x3(64, 64)
104 | self.bn2 = BatchNorm2d(64)
105 | self.relu2 = nn.ReLU(inplace=True)
106 | self.conv3 = conv3x3(64, 128)
107 | self.bn3 = BatchNorm2d(128)
108 | self.relu3 = nn.ReLU(inplace=True)
109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
110 |
111 | self.layer1 = self._make_layer(block, 64, layers[0])
112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
115 | self.avgpool = nn.AvgPool2d(7, stride=1)
116 | self.fc = nn.Linear(512 * block.expansion, num_classes)
117 |
118 | for m in self.modules():
119 | if isinstance(m, nn.Conv2d):
120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
121 | m.weight.data.normal_(0, math.sqrt(2. / n))
122 | elif isinstance(m, BatchNorm2d):
123 | m.weight.data.fill_(1)
124 | m.bias.data.zero_()
125 |
126 | def _make_layer(self, block, planes, blocks, stride=1):
127 | downsample = None
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = nn.Sequential(
130 | nn.Conv2d(self.inplanes, planes * block.expansion,
131 | kernel_size=1, stride=stride, bias=False),
132 | BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for i in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.relu1(self.bn1(self.conv1(x)))
145 | x = self.relu2(self.bn2(self.conv2(x)))
146 | x = self.relu3(self.bn3(self.conv3(x)))
147 | x = self.maxpool(x)
148 |
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 |
154 | x = self.avgpool(x)
155 | x = x.view(x.size(0), -1)
156 | x = self.fc(x)
157 |
158 | return x
159 |
160 | def resnet18(pretrained=False, **kwargs):
161 | """Constructs a ResNet-18 model.
162 |
163 | Args:
164 | pretrained (bool): If True, returns a model pre-trained on ImageNet
165 | """
166 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
167 | if pretrained:
168 | model.load_state_dict(load_url(model_urls['resnet18']))
169 | return model
170 |
171 | '''
172 | def resnet34(pretrained=False, **kwargs):
173 | """Constructs a ResNet-34 model.
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | """
178 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
179 | if pretrained:
180 | model.load_state_dict(load_url(model_urls['resnet34']))
181 | return model
182 | '''
183 |
184 | def resnet50(pretrained=False, **kwargs):
185 | """Constructs a ResNet-50 model.
186 |
187 | Args:
188 | pretrained (bool): If True, returns a model pre-trained on ImageNet
189 | """
190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
191 | if pretrained:
192 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
193 | return model
194 |
195 |
196 | def resnet101(pretrained=False, **kwargs):
197 | """Constructs a ResNet-101 model.
198 |
199 | Args:
200 | pretrained (bool): If True, returns a model pre-trained on ImageNet
201 | """
202 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
203 | if pretrained:
204 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False)
205 | return model
206 |
207 | # def resnet152(pretrained=False, **kwargs):
208 | # """Constructs a ResNet-152 model.
209 | #
210 | # Args:
211 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
212 | # """
213 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
214 | # if pretrained:
215 | # model.load_state_dict(load_url(model_urls['resnet152']))
216 | # return model
217 |
--------------------------------------------------------------------------------
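
A quick sanity check for the ResNet backbones above (pretrained=False avoids the resnet18-imagenet.pth download):

    import torch
    from mit_semseg.models.resnet import resnet18

    model = resnet18(pretrained=False)
    model.eval()

    # deep stem of three 3x3 convolutions instead of a single 7x7, then the
    # usual four stages, average pooling and a 1000-way classifier
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)          # torch.Size([1, 1000])
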
/mit_semseg/models/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | from .utils import load_url
4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d
5 | BatchNorm2d = SynchronizedBatchNorm2d
6 |
7 |
8 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101
9 |
10 |
11 | model_urls = {
12 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth',
13 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | "3x3 convolution with padding"
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20 | padding=1, bias=False)
21 |
22 |
23 | class GroupBottleneck(nn.Module):
24 | expansion = 2
25 |
26 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None):
27 | super(GroupBottleneck, self).__init__()
28 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
29 | self.bn1 = BatchNorm2d(planes)
30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
31 | padding=1, groups=groups, bias=False)
32 | self.bn2 = BatchNorm2d(planes)
33 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False)
34 | self.bn3 = BatchNorm2d(planes * 2)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.downsample = downsample
37 | self.stride = stride
38 |
39 | def forward(self, x):
40 | residual = x
41 |
42 | out = self.conv1(x)
43 | out = self.bn1(out)
44 | out = self.relu(out)
45 |
46 | out = self.conv2(out)
47 | out = self.bn2(out)
48 | out = self.relu(out)
49 |
50 | out = self.conv3(out)
51 | out = self.bn3(out)
52 |
53 | if self.downsample is not None:
54 | residual = self.downsample(x)
55 |
56 | out += residual
57 | out = self.relu(out)
58 |
59 | return out
60 |
61 |
62 | class ResNeXt(nn.Module):
63 |
64 | def __init__(self, block, layers, groups=32, num_classes=1000):
65 | self.inplanes = 128
66 | super(ResNeXt, self).__init__()
67 | self.conv1 = conv3x3(3, 64, stride=2)
68 | self.bn1 = BatchNorm2d(64)
69 | self.relu1 = nn.ReLU(inplace=True)
70 | self.conv2 = conv3x3(64, 64)
71 | self.bn2 = BatchNorm2d(64)
72 | self.relu2 = nn.ReLU(inplace=True)
73 | self.conv3 = conv3x3(64, 128)
74 | self.bn3 = BatchNorm2d(128)
75 | self.relu3 = nn.ReLU(inplace=True)
76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
77 |
78 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups)
79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups)
80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups)
81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups)
82 | self.avgpool = nn.AvgPool2d(7, stride=1)
83 | self.fc = nn.Linear(1024 * block.expansion, num_classes)
84 |
85 | for m in self.modules():
86 | if isinstance(m, nn.Conv2d):
87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups
88 | m.weight.data.normal_(0, math.sqrt(2. / n))
89 | elif isinstance(m, BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 | def _make_layer(self, block, planes, blocks, stride=1, groups=1):
94 | downsample = None
95 | if stride != 1 or self.inplanes != planes * block.expansion:
96 | downsample = nn.Sequential(
97 | nn.Conv2d(self.inplanes, planes * block.expansion,
98 | kernel_size=1, stride=stride, bias=False),
99 | BatchNorm2d(planes * block.expansion),
100 | )
101 |
102 | layers = []
103 | layers.append(block(self.inplanes, planes, stride, groups, downsample))
104 | self.inplanes = planes * block.expansion
105 | for i in range(1, blocks):
106 | layers.append(block(self.inplanes, planes, groups=groups))
107 |
108 | return nn.Sequential(*layers)
109 |
110 | def forward(self, x):
111 | x = self.relu1(self.bn1(self.conv1(x)))
112 | x = self.relu2(self.bn2(self.conv2(x)))
113 | x = self.relu3(self.bn3(self.conv3(x)))
114 | x = self.maxpool(x)
115 |
116 | x = self.layer1(x)
117 | x = self.layer2(x)
118 | x = self.layer3(x)
119 | x = self.layer4(x)
120 |
121 | x = self.avgpool(x)
122 | x = x.view(x.size(0), -1)
123 | x = self.fc(x)
124 |
125 | return x
126 |
127 |
128 | '''
129 | def resnext50(pretrained=False, **kwargs):
130 | """Constructs a ResNet-50 model.
131 |
132 | Args:
133 | pretrained (bool): If True, returns a model pre-trained on Places
134 | """
135 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs)
136 | if pretrained:
137 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False)
138 | return model
139 | '''
140 |
141 |
142 | def resnext101(pretrained=False, **kwargs):
143 | """Constructs a ResNet-101 model.
144 |
145 | Args:
146 | pretrained (bool): If True, returns a model pre-trained on Places
147 | """
148 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs)
149 | if pretrained:
150 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False)
151 | return model
152 |
153 |
154 | # def resnext152(pretrained=False, **kwargs):
155 | # """Constructs a ResNeXt-152 model.
156 | #
157 | # Args:
158 | # pretrained (bool): If True, returns a model pre-trained on Places
159 | # """
160 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs)
161 | # if pretrained:
162 | # model.load_state_dict(load_url(model_urls['resnext152']))
163 | # return model
164 |
--------------------------------------------------------------------------------
/mit_semseg/models/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | try:
4 | from urllib import urlretrieve
5 | except ImportError:
6 | from urllib.request import urlretrieve
7 | import torch
8 |
9 |
10 | def load_url(url, model_dir='./pretrained', map_location=None):
11 | if not os.path.exists(model_dir):
12 | os.makedirs(model_dir)
13 | filename = url.split('/')[-1]
14 | cached_file = os.path.join(model_dir, filename)
15 | if not os.path.exists(cached_file):
16 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
17 | urlretrieve(url, cached_file)
18 | return torch.load(cached_file, map_location=map_location)
19 |
--------------------------------------------------------------------------------
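
A minimal sketch (not part of the repository) of calling load_url directly: the first call downloads the file into ./pretrained/, and later calls reuse the cached copy. The URL below is taken from model_urls in mit_semseg/models/resnext.py; the file is large, so this assumes network access and disk space.

    from mit_semseg.models.utils import load_url

    url = 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth'
    # Downloads on the first call, then loads the cached copy from ./pretrained/.
    state_dict = load_url(url, model_dir='./pretrained', map_location='cpu')
    print(len(state_dict), 'parameter tensors')
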
/mit_semseg/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import logging
4 | import re
5 | import functools
6 | import fnmatch
7 | import numpy as np
8 |
9 |
10 | def setup_logger(distributed_rank=0, filename="log.txt"):
11 | logger = logging.getLogger("Logger")
12 | logger.setLevel(logging.DEBUG)
13 | # don't log results for the non-master process
14 | if distributed_rank > 0:
15 | return logger
16 | ch = logging.StreamHandler(stream=sys.stdout)
17 | ch.setLevel(logging.DEBUG)
18 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
19 | ch.setFormatter(logging.Formatter(fmt))
20 | logger.addHandler(ch)
21 |
22 | return logger
23 |
24 |
25 | def find_recursive(root_dir, ext='.jpg'):
26 | files = []
27 | for root, dirnames, filenames in os.walk(root_dir):
28 | for filename in fnmatch.filter(filenames, '*' + ext):
29 | files.append(os.path.join(root, filename))
30 | return files
31 |
32 |
33 | class AverageMeter(object):
34 | """Computes and stores the average and current value"""
35 | def __init__(self):
36 | self.initialized = False
37 | self.val = None
38 | self.avg = None
39 | self.sum = None
40 | self.count = None
41 |
42 | def initialize(self, val, weight):
43 | self.val = val
44 | self.avg = val
45 | self.sum = val * weight
46 | self.count = weight
47 | self.initialized = True
48 |
49 | def update(self, val, weight=1):
50 | if not self.initialized:
51 | self.initialize(val, weight)
52 | else:
53 | self.add(val, weight)
54 |
55 | def add(self, val, weight):
56 | self.val = val
57 | self.sum += val * weight
58 | self.count += weight
59 | self.avg = self.sum / self.count
60 |
61 | def value(self):
62 | return self.val
63 |
64 | def average(self):
65 | return self.avg
66 |
67 |
68 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
69 | ar = np.asanyarray(ar).flatten()
70 |
71 | optional_indices = return_index or return_inverse
72 | optional_returns = optional_indices or return_counts
73 |
74 | if ar.size == 0:
75 | if not optional_returns:
76 | ret = ar
77 | else:
78 | ret = (ar,)
79 | if return_index:
80 | ret += (np.empty(0, np.bool),)
81 | if return_inverse:
82 | ret += (np.empty(0, np.bool),)
83 | if return_counts:
84 | ret += (np.empty(0, np.intp),)
85 | return ret
86 | if optional_indices:
87 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
88 | aux = ar[perm]
89 | else:
90 | ar.sort()
91 | aux = ar
92 | flag = np.concatenate(([True], aux[1:] != aux[:-1]))
93 |
94 | if not optional_returns:
95 | ret = aux[flag]
96 | else:
97 | ret = (aux[flag],)
98 | if return_index:
99 | ret += (perm[flag],)
100 | if return_inverse:
101 | iflag = np.cumsum(flag) - 1
102 | inv_idx = np.empty(ar.shape, dtype=np.intp)
103 | inv_idx[perm] = iflag
104 | ret += (inv_idx,)
105 | if return_counts:
106 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
107 | ret += (np.diff(idx),)
108 | return ret
109 |
110 |
111 | def colorEncode(labelmap, colors, mode='RGB'):
112 | labelmap = labelmap.astype('int')
113 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
114 | dtype=np.uint8)
115 | for label in unique(labelmap):
116 | if label < 0:
117 | continue
118 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
119 | np.tile(colors[label],
120 | (labelmap.shape[0], labelmap.shape[1], 1))
121 |
122 | if mode == 'BGR':
123 | return labelmap_rgb[:, :, ::-1]
124 | else:
125 | return labelmap_rgb
126 |
127 |
128 | def accuracy(preds, label):
129 | valid = (label >= 0)
130 | acc_sum = (valid * (preds == label)).sum()
131 | valid_sum = valid.sum()
132 | acc = float(acc_sum) / (valid_sum + 1e-10)
133 | return acc, valid_sum
134 |
135 |
136 | def intersectionAndUnion(imPred, imLab, numClass):
137 | imPred = np.asarray(imPred).copy()
138 | imLab = np.asarray(imLab).copy()
139 |
140 | imPred += 1
141 | imLab += 1
142 | # Remove classes from unlabeled pixels in gt image.
143 | # We should not penalize detections in unlabeled portions of the image.
144 | imPred = imPred * (imLab > 0)
145 |
146 | # Compute area intersection:
147 | intersection = imPred * (imPred == imLab)
148 | (area_intersection, _) = np.histogram(
149 | intersection, bins=numClass, range=(1, numClass))
150 |
151 | # Compute area union:
152 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
153 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
154 | area_union = area_pred + area_lab - area_intersection
155 |
156 | return (area_intersection, area_union)
157 |
158 |
159 | class NotSupportedCliException(Exception):
160 | pass
161 |
162 |
163 | def process_range(xpu, inp):
164 | start, end = map(int, inp)
165 | if start > end:
166 | end, start = start, end
167 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1))
168 |
169 |
170 | REGEX = [
171 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]),
172 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]),
173 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
174 | functools.partial(process_range, 'gpu')),
175 | (re.compile(r'^(\d+)-(\d+)$'),
176 | functools.partial(process_range, 'gpu')),
177 | ]
178 |
179 |
180 | def parse_devices(input_devices):
181 |
182 |     """Parse the user's device input string into the standard format,
183 |     e.g. ['gpu0', 'gpu1', ...].
184 |
185 | """
186 | ret = []
187 | for d in input_devices.split(','):
188 | for regex, func in REGEX:
189 | m = regex.match(d.lower().strip())
190 | if m:
191 | tmp = func(m.groups())
192 | # prevent duplicate
193 | for x in tmp:
194 | if x not in ret:
195 | ret.append(x)
196 | break
197 | else:
198 | raise NotSupportedCliException(
199 | 'Can not recognize device: "{}"'.format(d))
200 | return ret
201 |
--------------------------------------------------------------------------------
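
A minimal sketch (not part of the repository) tying the helpers above together. pred and label are assumed to be HxW integer arrays of class indices in which -1 marks unlabeled pixels, matching the conventions of accuracy() and intersectionAndUnion().

    import numpy as np
    from mit_semseg.utils import accuracy, intersectionAndUnion, parse_devices

    pred = np.random.randint(0, 150, (8, 8))
    label = pred.copy()
    label[0, 0] = -1                      # one unlabeled pixel, ignored by both metrics

    acc, valid = accuracy(pred, label)    # pixel accuracy over labeled pixels only
    inter, union = intersectionAndUnion(pred, label, 150)
    iou = inter / (union + 1e-10)         # per-class IoU; its mean over classes is mIoU
    print(acc, iou[union > 0].mean())

    print(parse_devices('gpu0-2'))        # -> ['gpu0', 'gpu1', 'gpu2']
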
/notebooks/DemoSegmenter.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Semantic Segmentation Demo\n",
8 | "\n",
9 |     "This is a notebook for running the benchmark semantic segmentation network from the [ADE20K MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).\n",
10 | "\n",
11 | "The code for this notebook is available here\n",
12 | "https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks\n",
13 | "\n",
14 | "It can be run on Colab at this URL https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "### Environment Setup\n",
22 | "\n",
23 | "First, download the code and pretrained models if we are on colab."
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "%%bash\n",
33 | "# Colab-specific setup\n",
34 | "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit \n",
35 | "pip install yacs 2>&1 >> install.log\n",
36 | "git init 2>&1 >> install.log\n",
37 | "git remote add origin https://github.com/CSAILVision/semantic-segmentation-pytorch.git 2>> install.log\n",
38 | "git pull origin master 2>&1 >> install.log\n",
39 | "DOWNLOAD_ONLY=1 ./demo_test.sh 2>> install.log"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "## Imports and utility functions\n",
47 | "\n",
48 | "We need pytorch, numpy, and the code for the segmentation model. And some utilities for visualizing the data."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "# System libs\n",
58 | "import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms\n",
59 | "# Our libs\n",
60 | "from mit_semseg.models import ModelBuilder, SegmentationModule\n",
61 | "from mit_semseg.utils import colorEncode\n",
62 | "\n",
63 | "colors = scipy.io.loadmat('data/color150.mat')['colors']\n",
64 | "names = {}\n",
65 | "with open('data/object150_info.csv') as f:\n",
66 | " reader = csv.reader(f)\n",
67 | " next(reader)\n",
68 | " for row in reader:\n",
69 | " names[int(row[0])] = row[5].split(\";\")[0]\n",
70 | "\n",
71 | "def visualize_result(img, pred, index=None):\n",
72 | " # filter prediction class if requested\n",
73 | " if index is not None:\n",
74 | " pred = pred.copy()\n",
75 | " pred[pred != index] = -1\n",
76 | " print(f'{names[index+1]}:')\n",
77 | " \n",
78 | " # colorize prediction\n",
79 | " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n",
80 | "\n",
81 | " # aggregate images and save\n",
82 | " im_vis = numpy.concatenate((img, pred_color), axis=1)\n",
83 | " display(PIL.Image.fromarray(im_vis))"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "metadata": {},
89 | "source": [
90 | "## Loading the segmentation model\n",
91 | "\n",
92 | "Here we load a pretrained segmentation model. Like any pytorch model, we can call it like a function, or examine the parameters in all the layers.\n",
93 | "\n",
94 | "After loading, we put it on the GPU. And since we are doing inference, not training, we put the model in eval mode."
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "# Network Builders\n",
104 | "net_encoder = ModelBuilder.build_encoder(\n",
105 | " arch='resnet50dilated',\n",
106 | " fc_dim=2048,\n",
107 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')\n",
108 | "net_decoder = ModelBuilder.build_decoder(\n",
109 | " arch='ppm_deepsup',\n",
110 | " fc_dim=2048,\n",
111 | " num_class=150,\n",
112 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',\n",
113 | " use_softmax=True)\n",
114 | "\n",
115 | "crit = torch.nn.NLLLoss(ignore_index=-1)\n",
116 | "segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)\n",
117 | "segmentation_module.eval()\n",
118 | "segmentation_module.cuda()"
119 | ]
120 | },
121 | {
122 | "cell_type": "markdown",
123 | "metadata": {},
124 | "source": [
125 | "## Load test data\n",
126 | "\n",
127 |     "Now we load and normalize a single test image. Here we use the commonplace convention of normalizing the image to a scale for which the RGB values of a large photo dataset would have zero mean and unit standard deviation. (These numbers come from the imagenet dataset.) With this normalization, the limiting ranges of RGB values are within about (-2.2 to +2.7): for example, a zero red value maps to (0 - 0.485) / 0.229, about -2.1, and a full-brightness blue value maps to (1 - 0.406) / 0.225, about +2.6."
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "# Load and normalize one image as a singleton tensor batch\n",
137 | "pil_to_tensor = torchvision.transforms.Compose([\n",
138 | " torchvision.transforms.ToTensor(),\n",
139 | " torchvision.transforms.Normalize(\n",
140 | " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n",
141 | " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n",
142 | "])\n",
143 | "pil_image = PIL.Image.open('ADE_val_00001519.jpg').convert('RGB')\n",
144 | "img_original = numpy.array(pil_image)\n",
145 | "img_data = pil_to_tensor(pil_image)\n",
146 | "singleton_batch = {'img_data': img_data[None].cuda()}\n",
147 | "output_size = img_data.shape[1:]"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "metadata": {},
153 | "source": [
154 | "## Run the Model\n",
155 | "\n",
156 | "Finally we just pass the test image to the segmentation model.\n",
157 | "\n",
158 |     "The segmentation model is coded as a function that takes a dictionary as input, because it needs both the input batch image data and the desired output segmentation resolution. We ask for full resolution output.\n",
159 | "\n",
160 | "Then we use the previously-defined visualize_result function to render the segmentation map."
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": null,
166 | "metadata": {
167 | "scrolled": false
168 | },
169 | "outputs": [],
170 | "source": [
171 | "# Run the segmentation at the highest resolution.\n",
172 | "with torch.no_grad():\n",
173 | " scores = segmentation_module(singleton_batch, segSize=output_size)\n",
174 | " \n",
175 | "# Get the predicted scores for each pixel\n",
176 | "_, pred = torch.max(scores, dim=1)\n",
177 | "pred = pred.cpu()[0].numpy()\n",
178 | "visualize_result(img_original, pred)"
179 | ]
180 | },
181 | {
182 | "cell_type": "markdown",
183 | "metadata": {},
184 | "source": [
185 | "## Showing classes individually\n",
186 | "\n",
187 | "To see which colors are which, here we visualize individual classes, one at a time."
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": null,
193 | "metadata": {},
194 | "outputs": [],
195 | "source": [
196 | "# Top classes in answer\n",
197 | "predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n",
198 | "for c in predicted_classes[:15]:\n",
199 | " visualize_result(img_original, pred, c)"
200 | ]
201 | }
202 | ],
203 | "metadata": {
204 | "accelerator": "GPU",
205 | "kernelspec": {
206 | "display_name": "Python 3",
207 | "language": "python",
208 | "name": "python3"
209 | },
210 | "language_info": {
211 | "codemirror_mode": {
212 | "name": "ipython",
213 | "version": 3
214 | },
215 | "file_extension": ".py",
216 | "mimetype": "text/x-python",
217 | "name": "python",
218 | "nbconvert_exporter": "python",
219 | "pygments_lexer": "ipython3",
220 | "version": "3.6.7"
221 | }
222 | },
223 | "nbformat": 4,
224 | "nbformat_minor": 2
225 | }
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | Semantic Segmentation Demo
2 | ==========================
3 |
4 | This directory contains a notebook for demonstrating the benchmark
5 | semantic segmentation network from the ADE20K MIT Scene Parsing
6 | Benchmark.
7 |
8 | It can be run on Colab at
9 | [this URL](https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb)
10 | or on a local Jupyter notebook.
11 |
12 | If running locally, run the script `setup_notebooks.sh` to start.
13 |
--------------------------------------------------------------------------------
/notebooks/ckpt:
--------------------------------------------------------------------------------
1 | ../ckpt
--------------------------------------------------------------------------------
/notebooks/config:
--------------------------------------------------------------------------------
1 | ../config
--------------------------------------------------------------------------------
/notebooks/data:
--------------------------------------------------------------------------------
1 | ../data
--------------------------------------------------------------------------------
/notebooks/ipynb_drop_output.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Suppress output and prompt numbers in git version control.
5 |
6 | This script will tell git to ignore prompt numbers and cell output
7 | when looking at ipynb files UNLESS their metadata contains:
8 |
9 |     "git" : { "keep_outputs" : true }
10 |
11 | The notebooks themselves are not changed.
12 |
13 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html.
14 |
15 | Usage instructions
16 | ==================
17 |
18 | 1. Put this script in a directory that is on the system's path.
19 | For future reference, I will assume you saved it in
20 | `~/scripts/ipynb_drop_output`.
21 | 2. Make sure it is executable by typing the command
22 | `chmod +x ~/scripts/ipynb_drop_output`.
23 | 3. Register a filter for ipython notebooks by
24 | putting the following line in `~/.config/git/attributes`:
25 | `*.ipynb filter=clean_ipynb`
26 | 4. Connect this script to the filter by running the following
27 | git commands:
28 |
29 | git config --global filter.clean_ipynb.clean ipynb_drop_output
30 | git config --global filter.clean_ipynb.smudge cat
31 |
32 | To tell git NOT to ignore the output and prompts for a notebook,
33 | open the notebook's metadata (Edit > Edit Notebook Metadata). A
34 | panel should open containing the lines:
35 |
36 | {
37 | "name" : "",
38 | "signature" : "some very long hash"
39 | }
40 |
41 | Add an extra line so that the metadata now looks like:
42 |
43 | {
44 | "name" : "",
45 | "signature" : "don't change the hash, but add a comma at the end of the line",
46 | "git" : { "keep_outputs" : true }
47 | }
48 |
49 | You may need to "touch" the notebooks for git to actually register a change, if
50 | your notebooks are already under version control.
51 |
52 | Notes
53 | =====
54 |
55 | Changed by David Bau to make stripping output the default.
56 |
57 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but
58 | lets the user specify whether the output of a notebook should be kept
59 | in the notebook's metadata, and works for IPython v3.0.
60 | """
61 |
62 | import sys
63 | import json
64 |
65 | nb = sys.stdin.read()
66 |
67 | json_in = json.loads(nb)
68 | nb_metadata = json_in["metadata"]
69 | keep_output = False
70 | if "git" in nb_metadata:
71 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]:
72 | keep_output = True
73 | if keep_output:
74 | sys.stdout.write(nb)
75 | exit()
76 |
77 |
78 | ipy_version = int(json_in["nbformat"])-1 # nbformat is 1 more than actual version.
79 |
80 | def strip_output_from_cell(cell):
81 | if "outputs" in cell:
82 | cell["outputs"] = []
83 | if "prompt_number" in cell:
84 | del cell["prompt_number"]
85 | if "execution_count" in cell:
86 | cell["execution_count"] = None
87 |
88 |
89 | if ipy_version == 2:
90 | for sheet in json_in["worksheets"]:
91 | for cell in sheet["cells"]:
92 | strip_output_from_cell(cell)
93 | else:
94 | for cell in json_in["cells"]:
95 | strip_output_from_cell(cell)
96 |
97 | json.dump(json_in, sys.stdout, sort_keys=True, indent=1, separators=(",",": "))
98 |
--------------------------------------------------------------------------------
/notebooks/mit_semseg:
--------------------------------------------------------------------------------
1 | ../mit_semseg
--------------------------------------------------------------------------------
/notebooks/setup_notebooks.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Start from directory of script
4 | cd "$(dirname "${BASH_SOURCE[0]}")"
5 |
6 | # Set up git config filters so huge output of notebooks is not committed.
7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py"
8 | git config filter.clean_ipynb.smudge cat
9 | git config filter.clean_ipynb.required true
10 |
11 | # Set up symlinks for the example notebooks
12 | for DIRNAME in ckpt mit_semseg config notebooks teaser data
13 | do
14 | ln -sfn ../${DIRNAME} .
15 | done
16 |
--------------------------------------------------------------------------------
/notebooks/teaser:
--------------------------------------------------------------------------------
1 | ../teaser
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | torch==0.4.1
4 | torchvision
5 | opencv-python
6 | yacs
7 | tqdm
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open('README.md', 'r') as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name='mit_semseg',
8 | version='1.0.0',
9 | author='MIT CSAIL',
10 | description='Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset',
11 | long_description=long_description,
12 | long_description_content_type='text/markdown',
13 | url='https://github.com/CSAILVision/semantic-segmentation-pytorch',
14 | packages=setuptools.find_packages(),
15 | classifiers=(
16 | 'Programming Language :: Python :: 3',
17 | 'License :: OSI Approved :: BSD License',
18 | 'Operating System :: OS Independent',
19 | ),
20 | install_requires=[
21 | 'numpy',
22 | 'torch>=0.4.1',
23 | 'torchvision',
24 | 'opencv-python',
25 | 'yacs',
26 | 'scipy',
27 | 'tqdm'
28 | ]
29 | )
30 |
--------------------------------------------------------------------------------
/teaser/ADE_val_00000278.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/teaser/ADE_val_00000278.png
--------------------------------------------------------------------------------
/teaser/ADE_val_00001519.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/teaser/ADE_val_00001519.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # System libs
2 | import os
3 | import argparse
4 | from distutils.version import LooseVersion
5 | # Numerical libs
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from scipy.io import loadmat
10 | import csv
11 | # Our libs
12 | from mit_semseg.dataset import TestDataset
13 | from mit_semseg.models import ModelBuilder, SegmentationModule
14 | from mit_semseg.utils import colorEncode, find_recursive, setup_logger
15 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to
16 | from mit_semseg.lib.utils import as_numpy
17 | from PIL import Image
18 | from tqdm import tqdm
19 | from mit_semseg.config import cfg
20 |
21 | colors = loadmat('data/color150.mat')['colors']
22 | names = {}
23 | with open('data/object150_info.csv') as f:
24 | reader = csv.reader(f)
25 | next(reader)
26 | for row in reader:
27 | names[int(row[0])] = row[5].split(";")[0]
28 |
29 |
30 | def visualize_result(data, pred, cfg):
31 | (img, info) = data
32 |
33 | # print predictions in descending order
34 | pred = np.int32(pred)
35 | pixs = pred.size
36 | uniques, counts = np.unique(pred, return_counts=True)
37 | print("Predictions in [{}]:".format(info))
38 | for idx in np.argsort(counts)[::-1]:
39 | name = names[uniques[idx] + 1]
40 | ratio = counts[idx] / pixs * 100
41 | if ratio > 0.1:
42 | print(" {}: {:.2f}%".format(name, ratio))
43 |
44 | # colorize prediction
45 | pred_color = colorEncode(pred, colors).astype(np.uint8)
46 |
47 | # aggregate images and save
48 | im_vis = np.concatenate((img, pred_color), axis=1)
49 |
50 | img_name = info.split('/')[-1]
51 | Image.fromarray(im_vis).save(
52 | os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png')))
53 |
54 |
55 | def test(segmentation_module, loader, gpu):
56 | segmentation_module.eval()
57 |
58 | pbar = tqdm(total=len(loader))
59 | for batch_data in loader:
60 | # process data
61 | batch_data = batch_data[0]
62 | segSize = (batch_data['img_ori'].shape[0],
63 | batch_data['img_ori'].shape[1])
64 | img_resized_list = batch_data['img_data']
65 |
66 | with torch.no_grad():
67 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1])
68 | scores = async_copy_to(scores, gpu)
69 |
70 | for img in img_resized_list:
71 | feed_dict = batch_data.copy()
72 | feed_dict['img_data'] = img
73 | del feed_dict['img_ori']
74 | del feed_dict['info']
75 | feed_dict = async_copy_to(feed_dict, gpu)
76 |
77 | # forward pass
78 | pred_tmp = segmentation_module(feed_dict, segSize=segSize)
79 | scores = scores + pred_tmp / len(cfg.DATASET.imgSizes)
80 |
81 | _, pred = torch.max(scores, dim=1)
82 | pred = as_numpy(pred.squeeze(0).cpu())
83 |
84 | # visualization
85 | visualize_result(
86 | (batch_data['img_ori'], batch_data['info']),
87 | pred,
88 | cfg
89 | )
90 |
91 | pbar.update(1)
92 |
93 |
94 | def main(cfg, gpu):
95 | torch.cuda.set_device(gpu)
96 |
97 | # Network Builders
98 | net_encoder = ModelBuilder.build_encoder(
99 | arch=cfg.MODEL.arch_encoder,
100 | fc_dim=cfg.MODEL.fc_dim,
101 | weights=cfg.MODEL.weights_encoder)
102 | net_decoder = ModelBuilder.build_decoder(
103 | arch=cfg.MODEL.arch_decoder,
104 | fc_dim=cfg.MODEL.fc_dim,
105 | num_class=cfg.DATASET.num_class,
106 | weights=cfg.MODEL.weights_decoder,
107 | use_softmax=True)
108 |
109 | crit = nn.NLLLoss(ignore_index=-1)
110 |
111 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
112 |
113 | # Dataset and Loader
114 | dataset_test = TestDataset(
115 | cfg.list_test,
116 | cfg.DATASET)
117 | loader_test = torch.utils.data.DataLoader(
118 | dataset_test,
119 | batch_size=cfg.TEST.batch_size,
120 | shuffle=False,
121 | collate_fn=user_scattered_collate,
122 | num_workers=5,
123 | drop_last=True)
124 |
125 | segmentation_module.cuda()
126 |
127 | # Main loop
128 | test(segmentation_module, loader_test, gpu)
129 |
130 | print('Inference done!')
131 |
132 |
133 | if __name__ == '__main__':
134 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
135 | 'PyTorch>=0.4.0 is required'
136 |
137 | parser = argparse.ArgumentParser(
138 | description="PyTorch Semantic Segmentation Testing"
139 | )
140 | parser.add_argument(
141 | "--imgs",
142 | required=True,
143 | type=str,
144 | help="an image path, or a directory name"
145 | )
146 | parser.add_argument(
147 | "--cfg",
148 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
149 | metavar="FILE",
150 | help="path to config file",
151 | type=str,
152 | )
153 | parser.add_argument(
154 | "--gpu",
155 | default=0,
156 | type=int,
157 | help="gpu id for evaluation"
158 | )
159 | parser.add_argument(
160 | "opts",
161 | help="Modify config options using the command-line",
162 | default=None,
163 | nargs=argparse.REMAINDER,
164 | )
165 | args = parser.parse_args()
166 |
167 | cfg.merge_from_file(args.cfg)
168 | cfg.merge_from_list(args.opts)
169 | # cfg.freeze()
170 |
171 | logger = setup_logger(distributed_rank=0) # TODO
172 | logger.info("Loaded configuration file {}".format(args.cfg))
173 | logger.info("Running with config:\n{}".format(cfg))
174 |
175 | cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower()
176 | cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower()
177 |
178 | # absolute paths of model weights
179 | cfg.MODEL.weights_encoder = os.path.join(
180 | cfg.DIR, 'encoder_' + cfg.TEST.checkpoint)
181 | cfg.MODEL.weights_decoder = os.path.join(
182 | cfg.DIR, 'decoder_' + cfg.TEST.checkpoint)
183 |
184 | assert os.path.exists(cfg.MODEL.weights_encoder) and \
185 |         os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!"
186 |
187 | # generate testing image list
188 | if os.path.isdir(args.imgs):
189 | imgs = find_recursive(args.imgs)
190 | else:
191 | imgs = [args.imgs]
192 |     assert len(imgs), "imgs should be a path to an image (.jpg) or a directory."
193 | cfg.list_test = [{'fpath_img': x} for x in imgs]
194 |
195 | if not os.path.isdir(cfg.TEST.result):
196 | os.makedirs(cfg.TEST.result)
197 |
198 | main(cfg, args.gpu)
199 |
--------------------------------------------------------------------------------
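
Note that the model predicts 0-based class indices while data/object150_info.csv is indexed from 1, hence the "uniques[idx] + 1" lookup in visualize_result above. A minimal sketch (not part of the repository) of that lookup on its own:

    import csv

    # Same parsing as at the top of test.py: column 0 is the 1-based class index,
    # column 5 holds the class names separated by ';'.
    names = {}
    with open('data/object150_info.csv') as f:
        reader = csv.reader(f)
        next(reader)                      # skip the header row
        for row in reader:
            names[int(row[0])] = row[5].split(';')[0]

    pred_class = 0                        # a 0-based index as produced by the model
    print(names[pred_class + 1])          # name of the first ADE20K class
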
/train.py:
--------------------------------------------------------------------------------
1 | # System libs
2 | import os
3 | import time
4 | # import math
5 | import random
6 | import argparse
7 | from distutils.version import LooseVersion
8 | # Numerical libs
9 | import torch
10 | import torch.nn as nn
11 | # Our libs
12 | from mit_semseg.config import cfg
13 | from mit_semseg.dataset import TrainDataset
14 | from mit_semseg.models import ModelBuilder, SegmentationModule
15 | from mit_semseg.utils import AverageMeter, parse_devices, setup_logger
16 | from mit_semseg.lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
17 |
18 |
19 | # train one epoch
20 | def train(segmentation_module, iterator, optimizers, history, epoch, cfg):
21 | batch_time = AverageMeter()
22 | data_time = AverageMeter()
23 | ave_total_loss = AverageMeter()
24 | ave_acc = AverageMeter()
25 |
26 | segmentation_module.train(not cfg.TRAIN.fix_bn)
27 |
28 | # main loop
29 | tic = time.time()
30 | for i in range(cfg.TRAIN.epoch_iters):
31 | # load a batch of data
32 | batch_data = next(iterator)
33 | data_time.update(time.time() - tic)
34 | segmentation_module.zero_grad()
35 |
36 | # adjust learning rate
37 | cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters
38 | adjust_learning_rate(optimizers, cur_iter, cfg)
39 |
40 | # forward pass
41 | loss, acc = segmentation_module(batch_data)
42 | loss = loss.mean()
43 | acc = acc.mean()
44 |
45 | # Backward
46 | loss.backward()
47 | for optimizer in optimizers:
48 | optimizer.step()
49 |
50 | # measure elapsed time
51 | batch_time.update(time.time() - tic)
52 | tic = time.time()
53 |
54 | # update average loss and acc
55 | ave_total_loss.update(loss.data.item())
56 | ave_acc.update(acc.data.item()*100)
57 |
58 | # calculate accuracy, and display
59 | if i % cfg.TRAIN.disp_iter == 0:
60 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, '
61 | 'lr_encoder: {:.6f}, lr_decoder: {:.6f}, '
62 | 'Accuracy: {:4.2f}, Loss: {:.6f}'
63 | .format(epoch, i, cfg.TRAIN.epoch_iters,
64 | batch_time.average(), data_time.average(),
65 | cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder,
66 | ave_acc.average(), ave_total_loss.average()))
67 |
68 | fractional_epoch = epoch - 1 + 1. * i / cfg.TRAIN.epoch_iters
69 | history['train']['epoch'].append(fractional_epoch)
70 | history['train']['loss'].append(loss.data.item())
71 | history['train']['acc'].append(acc.data.item())
72 |
73 |
74 | def checkpoint(nets, history, cfg, epoch):
75 | print('Saving checkpoints...')
76 | (net_encoder, net_decoder, crit) = nets
77 |
78 | dict_encoder = net_encoder.state_dict()
79 | dict_decoder = net_decoder.state_dict()
80 |
81 | torch.save(
82 | history,
83 | '{}/history_epoch_{}.pth'.format(cfg.DIR, epoch))
84 | torch.save(
85 | dict_encoder,
86 | '{}/encoder_epoch_{}.pth'.format(cfg.DIR, epoch))
87 | torch.save(
88 | dict_decoder,
89 | '{}/decoder_epoch_{}.pth'.format(cfg.DIR, epoch))
90 |
91 |
92 | def group_weight(module):
93 | group_decay = []
94 | group_no_decay = []
95 | for m in module.modules():
96 | if isinstance(m, nn.Linear):
97 | group_decay.append(m.weight)
98 | if m.bias is not None:
99 | group_no_decay.append(m.bias)
100 | elif isinstance(m, nn.modules.conv._ConvNd):
101 | group_decay.append(m.weight)
102 | if m.bias is not None:
103 | group_no_decay.append(m.bias)
104 | elif isinstance(m, nn.modules.batchnorm._BatchNorm):
105 | if m.weight is not None:
106 | group_no_decay.append(m.weight)
107 | if m.bias is not None:
108 | group_no_decay.append(m.bias)
109 |
110 | assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
111 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
112 | return groups
113 |
114 |
115 | def create_optimizers(nets, cfg):
116 | (net_encoder, net_decoder, crit) = nets
117 | optimizer_encoder = torch.optim.SGD(
118 | group_weight(net_encoder),
119 | lr=cfg.TRAIN.lr_encoder,
120 | momentum=cfg.TRAIN.beta1,
121 | weight_decay=cfg.TRAIN.weight_decay)
122 | optimizer_decoder = torch.optim.SGD(
123 | group_weight(net_decoder),
124 | lr=cfg.TRAIN.lr_decoder,
125 | momentum=cfg.TRAIN.beta1,
126 | weight_decay=cfg.TRAIN.weight_decay)
127 | return (optimizer_encoder, optimizer_decoder)
128 |
129 |
130 | def adjust_learning_rate(optimizers, cur_iter, cfg):
131 | scale_running_lr = ((1. - float(cur_iter) / cfg.TRAIN.max_iters) ** cfg.TRAIN.lr_pow)
132 | cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder * scale_running_lr
133 | cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder * scale_running_lr
134 |
135 | (optimizer_encoder, optimizer_decoder) = optimizers
136 | for param_group in optimizer_encoder.param_groups:
137 | param_group['lr'] = cfg.TRAIN.running_lr_encoder
138 | for param_group in optimizer_decoder.param_groups:
139 | param_group['lr'] = cfg.TRAIN.running_lr_decoder
140 |
141 |
142 | def main(cfg, gpus):
143 | # Network Builders
144 | net_encoder = ModelBuilder.build_encoder(
145 | arch=cfg.MODEL.arch_encoder.lower(),
146 | fc_dim=cfg.MODEL.fc_dim,
147 | weights=cfg.MODEL.weights_encoder)
148 | net_decoder = ModelBuilder.build_decoder(
149 | arch=cfg.MODEL.arch_decoder.lower(),
150 | fc_dim=cfg.MODEL.fc_dim,
151 | num_class=cfg.DATASET.num_class,
152 | weights=cfg.MODEL.weights_decoder)
153 |
154 | crit = nn.NLLLoss(ignore_index=-1)
155 |
156 | if cfg.MODEL.arch_decoder.endswith('deepsup'):
157 | segmentation_module = SegmentationModule(
158 | net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale)
159 | else:
160 | segmentation_module = SegmentationModule(
161 | net_encoder, net_decoder, crit)
162 |
163 | # Dataset and Loader
164 | dataset_train = TrainDataset(
165 | cfg.DATASET.root_dataset,
166 | cfg.DATASET.list_train,
167 | cfg.DATASET,
168 | batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
169 |
170 | loader_train = torch.utils.data.DataLoader(
171 | dataset_train,
172 | batch_size=len(gpus), # we have modified data_parallel
173 | shuffle=False, # we do not use this param
174 | collate_fn=user_scattered_collate,
175 | num_workers=cfg.TRAIN.workers,
176 | drop_last=True,
177 | pin_memory=True)
178 | print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters))
179 |
180 | # create loader iterator
181 | iterator_train = iter(loader_train)
182 |
183 | # load nets into gpu
184 | if len(gpus) > 1:
185 | segmentation_module = UserScatteredDataParallel(
186 | segmentation_module,
187 | device_ids=gpus)
188 | # For sync bn
189 | patch_replication_callback(segmentation_module)
190 | segmentation_module.cuda()
191 |
192 | # Set up optimizers
193 | nets = (net_encoder, net_decoder, crit)
194 | optimizers = create_optimizers(nets, cfg)
195 |
196 | # Main loop
197 | history = {'train': {'epoch': [], 'loss': [], 'acc': []}}
198 |
199 | for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch):
200 | train(segmentation_module, iterator_train, optimizers, history, epoch+1, cfg)
201 |
202 | # checkpointing
203 | checkpoint(nets, history, cfg, epoch+1)
204 |
205 | print('Training Done!')
206 |
207 |
208 | if __name__ == '__main__':
209 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \
210 | 'PyTorch>=0.4.0 is required'
211 |
212 | parser = argparse.ArgumentParser(
213 | description="PyTorch Semantic Segmentation Training"
214 | )
215 | parser.add_argument(
216 | "--cfg",
217 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml",
218 | metavar="FILE",
219 | help="path to config file",
220 | type=str,
221 | )
222 | parser.add_argument(
223 | "--gpus",
224 | default="0-3",
225 | help="gpus to use, e.g. 0-3 or 0,1,2,3"
226 | )
227 | parser.add_argument(
228 | "opts",
229 | help="Modify config options using the command-line",
230 | default=None,
231 | nargs=argparse.REMAINDER,
232 | )
233 | args = parser.parse_args()
234 |
235 | cfg.merge_from_file(args.cfg)
236 | cfg.merge_from_list(args.opts)
237 | # cfg.freeze()
238 |
239 | logger = setup_logger(distributed_rank=0) # TODO
240 | logger.info("Loaded configuration file {}".format(args.cfg))
241 | logger.info("Running with config:\n{}".format(cfg))
242 |
243 | # Output directory
244 | if not os.path.isdir(cfg.DIR):
245 | os.makedirs(cfg.DIR)
246 |     logger.info("Outputting checkpoints to: {}".format(cfg.DIR))
247 | with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f:
248 | f.write("{}".format(cfg))
249 |
250 | # Start from checkpoint
251 | if cfg.TRAIN.start_epoch > 0:
252 | cfg.MODEL.weights_encoder = os.path.join(
253 | cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
254 | cfg.MODEL.weights_decoder = os.path.join(
255 | cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch))
256 | assert os.path.exists(cfg.MODEL.weights_encoder) and \
257 |             os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exist!"
258 |
259 | # Parse gpu ids
260 | gpus = parse_devices(args.gpus)
261 | gpus = [x.replace('gpu', '') for x in gpus]
262 | gpus = [int(x) for x in gpus]
263 | num_gpus = len(gpus)
264 | cfg.TRAIN.batch_size = num_gpus * cfg.TRAIN.batch_size_per_gpu
265 |
266 | cfg.TRAIN.max_iters = cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch
267 | cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder
268 | cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder
269 |
270 | random.seed(cfg.TRAIN.seed)
271 | torch.manual_seed(cfg.TRAIN.seed)
272 |
273 | main(cfg, gpus)
274 |
--------------------------------------------------------------------------------
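
For reference, a minimal sketch (with assumed hyper-parameter values; the real ones come from cfg.TRAIN) of the polynomial decay implemented by adjust_learning_rate in train.py: the base learning rate is scaled by (1 - cur_iter / max_iters) ** lr_pow, so it decays smoothly to zero at max_iters.

    # Illustrative values only.
    base_lr, lr_pow, max_iters = 0.02, 0.9, 100000

    for cur_iter in (0, 25000, 50000, 75000, 99999):
        lr = base_lr * (1.0 - cur_iter / max_iters) ** lr_pow
        print(cur_iter, round(lr, 6))
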