├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── config ├── ade20k-hrnetv2.yaml ├── ade20k-mobilenetv2dilated-c1_deepsup.yaml ├── ade20k-resnet101-upernet.yaml ├── ade20k-resnet101dilated-ppm_deepsup.yaml ├── ade20k-resnet18dilated-ppm_deepsup.yaml ├── ade20k-resnet50-upernet.yaml └── ade20k-resnet50dilated-ppm_deepsup.yaml ├── data ├── color150.mat ├── object150_info.csv ├── training.odgt └── validation.odgt ├── demo_test.sh ├── download_ADE20K.sh ├── eval.py ├── eval_multipro.py ├── mit_semseg ├── __init__.py ├── config │ ├── __init__.py │ └── defaults.py ├── dataset.py ├── lib │ ├── __init__.py │ ├── nn │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ ├── tests │ │ │ │ ├── __init__.py │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ └── test_sync_batchnorm.py │ │ │ └── unittest.py │ │ └── parallel │ │ │ ├── __init__.py │ │ │ └── data_parallel.py │ └── utils │ │ ├── __init__.py │ │ ├── data │ │ ├── __init__.py │ │ ├── dataloader.py │ │ ├── dataset.py │ │ ├── distributed.py │ │ └── sampler.py │ │ └── th.py ├── models │ ├── __init__.py │ ├── hrnet.py │ ├── mobilenet.py │ ├── models.py │ ├── resnet.py │ ├── resnext.py │ └── utils.py └── utils.py ├── notebooks ├── DemoSegmenter.ipynb ├── README.md ├── ckpt ├── config ├── data ├── ipynb_drop_output.py ├── mit_semseg ├── setup_notebooks.sh └── teaser ├── requirements.txt ├── setup.py ├── teaser ├── ADE_val_00000278.png └── ADE_val_00001519.png ├── test.py └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=clean_ipynb 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | ckpt/ 4 | vis/ 5 | log/ 6 | pretrained/ 7 | 8 | .ipynb_checkpoints 9 | 10 | ADE_val*.jpg 11 | ADE_val*.png 12 | 13 | .idea/ 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, MIT CSAIL Computer Vision 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation on MIT ADE20K dataset in PyTorch 2 | 3 | This is a PyTorch implementation of semantic segmentation models on the MIT ADE20K scene parsing dataset (http://sceneparsing.csail.mit.edu/). 4 | 5 | ADE20K is the largest open-source dataset for semantic segmentation and scene parsing, released by the MIT Computer Vision team. Follow the link below to find the repository for our dataset and the implementations in Caffe and Torch7: 6 | https://github.com/CSAILVision/sceneparsing 7 | 8 | If you simply want to play with our demo, please try this link: http://scenesegmentation.csail.mit.edu. You can upload your own photo and parse it! 9 | 10 | [You can also use this colab notebook playground here](https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb) to tinker with the code for segmenting an image. 11 | 12 | All pretrained models can be found at: 13 | http://sceneparsing.csail.mit.edu/model/pytorch 14 | 15 | 16 | 17 | [From left to right: Test Image, Ground Truth, Predicted Result] 18 | 19 | Color encoding of semantic categories can be found here: 20 | https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing 21 | 22 | ## Updates 23 | - HRNet model is now supported. 24 | - We use configuration files to store most options that were previously in the argument parser. The definitions of the options are detailed in ```config/defaults.py```. 25 | - We conform to PyTorch practice in data preprocessing (RGB [0, 1], subtract mean, divide by std). 26 | 27 | 28 | ## Highlights 29 | 30 | ### Synchronized Batch Normalization on PyTorch 31 | This module computes the mean and standard deviation across all devices during training. We empirically find that a reasonably large batch size is important for segmentation. We thank [Jiayuan Mao](http://vccy.xyz/) for his kind contributions; please refer to [Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) for details. 32 | 33 | The implementation is easy to use: 34 | - It is pure Python, with no extra C++ extension libraries. 35 | - It is completely compatible with PyTorch's implementation. Specifically, it uses the unbiased variance to update the moving average, and uses sqrt(max(var, eps)) instead of sqrt(var + eps). 36 | - It is efficient, only 20% to 30% slower than unsynchronized BatchNorm. 37 | 38 | ### Dynamic scales of input for training with multiple GPUs 39 | For the task of semantic segmentation, it is good to keep the aspect ratio of images during training. So we re-implement the `DataParallel` module to support distributing data to multiple GPUs as python dicts, so that each GPU can process images of different sizes. At the same time, the dataloader also operates differently.
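A minimal sketch of how these pieces fit together is shown below. It assumes the helpers exported from `mit_semseg.lib.nn` (`user_scattered_collate`, `UserScatteredDataParallel`) and the builder calls used in `eval.py`; the GPU list is hypothetical and the exact wiring is an approximation, with `train.py` remaining the reference implementation.

```python
# Sketch only: helper names follow mit_semseg.lib.nn, the dataset signature follows
# mit_semseg/dataset.py, and the builder calls mirror eval.py; the overall wiring
# is an approximation of train.py, not a verbatim copy.
import torch
import torch.nn as nn
from mit_semseg.config import cfg
from mit_semseg.dataset import TrainDataset
from mit_semseg.lib.nn import UserScatteredDataParallel, user_scattered_collate
from mit_semseg.models import ModelBuilder, SegmentationModule

gpus = [0, 1]  # hypothetical choice of GPUs

dataset_train = TrainDataset(
    cfg.DATASET.root_dataset,
    cfg.DATASET.list_train,
    cfg.DATASET,
    batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)

# The loader batch size equals the number of GPUs; user_scattered_collate keeps
# the per-GPU sub-batches as separate dicts instead of stacking them into one
# tensor, so each GPU can receive images of a different size.
loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=len(gpus),
    shuffle=False,  # shuffling is handled inside TrainDataset
    collate_fn=user_scattered_collate,
    num_workers=cfg.TRAIN.workers,
    drop_last=True,
    pin_memory=True)

# Build the model with the same ModelBuilder calls used in eval.py (without
# use_softmax), then wrap it so each dict in the batch is scattered to one GPU
# (hypothetical wrapper usage).
net_encoder = ModelBuilder.build_encoder(
    arch=cfg.MODEL.arch_encoder.lower(),
    fc_dim=cfg.MODEL.fc_dim,
    weights=cfg.MODEL.weights_encoder)
net_decoder = ModelBuilder.build_decoder(
    arch=cfg.MODEL.arch_decoder.lower(),
    fc_dim=cfg.MODEL.fc_dim,
    num_class=cfg.DATASET.num_class,
    weights=cfg.MODEL.weights_decoder)
crit = nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
segmentation_module = UserScatteredDataParallel(segmentation_module, device_ids=gpus)
segmentation_module.cuda()
```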
40 | 41 | *Now the batch size of a dataloader always equals the number of GPUs*; each element of the batch is sent to one GPU. It is also compatible with multi-processing. Note that the file index for the multi-processing dataloader is stored on the master process, which contradicts our goal that each worker maintain its own file list. So we use a trick: although the master process still passes an index to the dataloader's `__getitem__` function, we simply ignore that request and send a random batch dict instead. Also, *the multiple workers forked by the dataloader all share the same seed*, so if we used the above-mentioned trick directly, the workers would yield exactly the same data. Therefore, we add one line of code that sets the default seed for `numpy.random` before activating the multiple workers in the dataloader. 42 | 43 | ### State-of-the-Art models 44 | - **PSPNet** is a scene parsing network that aggregates global representations with a Pyramid Pooling Module (PPM). It is the winning model of the ILSVRC'16 MIT Scene Parsing Challenge. Please refer to [https://arxiv.org/abs/1612.01105](https://arxiv.org/abs/1612.01105) for details. 45 | - **UPerNet** is a model based on a Feature Pyramid Network (FPN) and a Pyramid Pooling Module (PPM). It does not need dilated convolution, an operator that is time- and memory-consuming. *Without bells and whistles*, it is comparable to or even better than PSPNet, while requiring much shorter training time and less GPU memory. Please refer to [https://arxiv.org/abs/1807.10221](https://arxiv.org/abs/1807.10221) for details. 46 | - **HRNet** is a recently proposed model that retains high-resolution representations throughout the network, without the traditional bottleneck design. It achieves state-of-the-art performance on a series of pixel-labeling tasks. Please refer to [https://arxiv.org/abs/1904.04514](https://arxiv.org/abs/1904.04514) for details. 47 | 48 | 49 | ## Supported models 50 | We split our models into encoders and decoders, where encoders are usually modified directly from classification networks, and decoders consist of final convolutions and upsampling. We have provided some pre-configured models in the ```config``` folder. 51 | 52 | Encoder: 53 | - MobileNetV2dilated 54 | - ResNet18/ResNet18dilated 55 | - ResNet50/ResNet50dilated 56 | - ResNet101/ResNet101dilated 57 | - HRNetV2 (W48) 58 | 59 | Decoder: 60 | - C1 (one convolution module) 61 | - C1_deepsup (C1 + deep supervision trick) 62 | - PPM (Pyramid Pooling Module, see the [PSPNet](https://hszhao.github.io/projects/pspnet) paper for details.) 63 | - PPM_deepsup (PPM + deep supervision trick) 64 | - UPerNet (Pyramid Pooling + FPN head, see [UperNet](https://arxiv.org/abs/1807.10221) for details.) 65 | 66 | ## Performance: 67 | IMPORTANT: The base ResNet in our repository is a customized version (different from the one in torchvision). The base models will be automatically downloaded when needed. 68 |
| Architecture | MultiScale Testing | Mean IoU | Pixel Accuracy (%) | Overall Score | Inference Speed (fps) |
| --- | --- | --- | --- | --- | --- |
| MobileNetV2dilated + C1_deepsup | No | 34.84 | 75.75 | 54.07 | 17.2 |
| MobileNetV2dilated + C1_deepsup | Yes | 33.84 | 76.80 | 55.32 | 10.3 |
| MobileNetV2dilated + PPM_deepsup | No | 35.76 | 77.77 | 56.27 | 14.9 |
| MobileNetV2dilated + PPM_deepsup | Yes | 36.28 | 78.26 | 57.27 | 6.7 |
| ResNet18dilated + C1_deepsup | No | 33.82 | 76.05 | 54.94 | 13.9 |
| ResNet18dilated + C1_deepsup | Yes | 35.34 | 77.41 | 56.38 | 5.8 |
| ResNet18dilated + PPM_deepsup | No | 38.00 | 78.64 | 58.32 | 11.7 |
| ResNet18dilated + PPM_deepsup | Yes | 38.81 | 79.29 | 59.05 | 4.2 |
| ResNet50dilated + PPM_deepsup | No | 41.26 | 79.73 | 60.50 | 8.3 |
| ResNet50dilated + PPM_deepsup | Yes | 42.14 | 80.13 | 61.14 | 2.6 |
| ResNet101dilated + PPM_deepsup | No | 42.19 | 80.59 | 61.39 | 6.8 |
| ResNet101dilated + PPM_deepsup | Yes | 42.53 | 80.91 | 61.72 | 2.0 |
| UperNet50 | No | 40.44 | 79.80 | 60.12 | 8.4 |
| UperNet50 | Yes | 41.55 | 80.23 | 60.89 | 2.9 |
| UperNet101 | No | 42.00 | 80.79 | 61.40 | 7.8 |
| UperNet101 | Yes | 42.66 | 81.01 | 61.84 | 2.3 |
| HRNetV2 | No | 42.03 | 80.77 | 61.40 | 5.8 |
| HRNetV2 | Yes | 43.20 | 81.47 | 62.34 | 1.9 |
159 | 160 | The training is benchmarked on a server with 8 NVIDIA Pascal Titan Xp GPUs (12GB GPU memory); the inference speed is benchmarked on a single NVIDIA Pascal Titan Xp GPU, without visualization. 161 | 162 | ## Environment 163 | The code is developed under the following configurations. 164 | - Hardware: >=4 GPUs for training, >=1 GPU for testing (set ```[--gpus GPUS]``` accordingly) 165 | - Software: Ubuntu 16.04.3 LTS, ***CUDA>=8.0, Python>=3.5, PyTorch>=0.4.0*** 166 | - Dependencies: numpy, scipy, opencv, yacs, tqdm 167 | 168 | ## Quick start: Test on an image using our trained model 169 | 1. Here is a simple demo to do inference on a single image: 170 | ```bash 171 | chmod +x demo_test.sh 172 | ./demo_test.sh 173 | ``` 174 | This script downloads a trained model (ResNet50dilated + PPM_deepsup) and a test image, runs the test script, and saves the predicted segmentation (.png) to the working directory. 175 | 176 | 2. To test on an image or a folder of images (```$PATH_IMG```), you can simply do the following: 177 | ``` 178 | python3 -u test.py --imgs $PATH_IMG --gpu $GPU --cfg $CFG 179 | ``` 180 | 181 | ## Training 182 | 1. Download the ADE20K scene parsing dataset: 183 | ```bash 184 | chmod +x download_ADE20K.sh 185 | ./download_ADE20K.sh 186 | ``` 187 | 2. Train a model by selecting the GPUs (```$GPUS```) and configuration file (```$CFG```) to use. During training, checkpoints are saved by default in the ```ckpt``` folder. 188 | ```bash 189 | python3 train.py --gpus $GPUS --cfg $CFG 190 | ``` 191 | - To choose which GPUs to use, you can either do ```--gpus 0-7```, or ```--gpus 0,2,4,6```. 192 | 193 | For example, you can start with our provided configurations: 194 | 195 | * Train MobileNetV2dilated + C1_deepsup 196 | ```bash 197 | python3 train.py --gpus GPUS --cfg config/ade20k-mobilenetv2dilated-c1_deepsup.yaml 198 | ``` 199 | 200 | * Train ResNet50dilated + PPM_deepsup 201 | ```bash 202 | python3 train.py --gpus GPUS --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml 203 | ``` 204 | 205 | * Train UPerNet101 206 | ```bash 207 | python3 train.py --gpus GPUS --cfg config/ade20k-resnet101-upernet.yaml 208 | ``` 209 | 210 | 3. You can also override options on the command line, for example ```python3 train.py TRAIN.num_epoch 10```. 211 | 212 | 213 | ## Evaluation 214 | 1. Evaluate a trained model on the validation set. Add ```VAL.visualize True``` to the arguments to output visualizations as shown in the teaser. 215 | 216 | For example: 217 | 218 | * Evaluate MobileNetV2dilated + C1_deepsup 219 | ```bash 220 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-mobilenetv2dilated-c1_deepsup.yaml 221 | ``` 222 | 223 | * Evaluate ResNet50dilated + PPM_deepsup 224 | ```bash 225 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml 226 | ``` 227 | 228 | * Evaluate UPerNet101 229 | ```bash 230 | python3 eval_multipro.py --gpus GPUS --cfg config/ade20k-resnet101-upernet.yaml 231 | ``` 232 | 233 | ## Integration with other projects 234 | This library can be installed via `pip` to easily integrate with another codebase: 235 | ```bash 236 | pip install git+https://github.com/CSAILVision/semantic-segmentation-pytorch.git@master 237 | ``` 238 | 239 | Now this library can easily be consumed programmatically.
For example 240 | ```python 241 | from mit_semseg.config import cfg 242 | from mit_semseg.dataset import TestDataset 243 | from mit_semseg.models import ModelBuilder, SegmentationModule 244 | ``` 245 | 246 | ## Reference 247 | 248 | If you find the code or pre-trained models useful, please cite the following papers: 249 | 250 | Semantic Understanding of Scenes through ADE20K Dataset. B. Zhou, H. Zhao, X. Puig, T. Xiao, S. Fidler, A. Barriuso and A. Torralba. International Journal on Computer Vision (IJCV), 2018. (https://arxiv.org/pdf/1608.05442.pdf) 251 | 252 | @article{zhou2018semantic, 253 | title={Semantic understanding of scenes through the ade20k dataset}, 254 | author={Zhou, Bolei and Zhao, Hang and Puig, Xavier and Xiao, Tete and Fidler, Sanja and Barriuso, Adela and Torralba, Antonio}, 255 | journal={International Journal on Computer Vision}, 256 | year={2018} 257 | } 258 | 259 | Scene Parsing through ADE20K Dataset. B. Zhou, H. Zhao, X. Puig, S. Fidler, A. Barriuso and A. Torralba. Computer Vision and Pattern Recognition (CVPR), 2017. (http://people.csail.mit.edu/bzhou/publication/scene-parse-camera-ready.pdf) 260 | 261 | @inproceedings{zhou2017scene, 262 | title={Scene Parsing through ADE20K Dataset}, 263 | author={Zhou, Bolei and Zhao, Hang and Puig, Xavier and Fidler, Sanja and Barriuso, Adela and Torralba, Antonio}, 264 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 265 | year={2017} 266 | } 267 | 268 | -------------------------------------------------------------------------------- /config/ade20k-hrnetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "hrnetv2" 14 | arch_decoder: "c1" 15 | fc_dim: 720 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-hrnetv2-c1" 43 | -------------------------------------------------------------------------------- /config/ade20k-mobilenetv2dilated-c1_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "mobilenetv2dilated" 14 | arch_decoder: "c1_deepsup" 15 | fc_dim: 320 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 3 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | 
checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-mobilenetv2dilated-c1_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet101-upernet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 40 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_50.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_50.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet101-upernet" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet101dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet101dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 25 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_25.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_25.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet101dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet18dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet18dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 512 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet18dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet50-upernet.yaml: 
-------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 32 9 | segm_downsampling_rate: 4 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50" 14 | arch_decoder: "upernet" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 30 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_30.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_30.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50-upernet" 43 | -------------------------------------------------------------------------------- /config/ade20k-resnet50dilated-ppm_deepsup.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | root_dataset: "./data/" 3 | list_train: "./data/training.odgt" 4 | list_val: "./data/validation.odgt" 5 | num_class: 150 6 | imgSizes: (300, 375, 450, 525, 600) 7 | imgMaxSize: 1000 8 | padding_constant: 8 9 | segm_downsampling_rate: 8 10 | random_flip: True 11 | 12 | MODEL: 13 | arch_encoder: "resnet50dilated" 14 | arch_decoder: "ppm_deepsup" 15 | fc_dim: 2048 16 | 17 | TRAIN: 18 | batch_size_per_gpu: 2 19 | num_epoch: 20 20 | start_epoch: 0 21 | epoch_iters: 5000 22 | optim: "SGD" 23 | lr_encoder: 0.02 24 | lr_decoder: 0.02 25 | lr_pow: 0.9 26 | beta1: 0.9 27 | weight_decay: 1e-4 28 | deep_sup_scale: 0.4 29 | fix_bn: False 30 | workers: 16 31 | disp_iter: 20 32 | seed: 304 33 | 34 | VAL: 35 | visualize: False 36 | checkpoint: "epoch_20.pth" 37 | 38 | TEST: 39 | checkpoint: "epoch_20.pth" 40 | result: "./" 41 | 42 | DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup" 43 | -------------------------------------------------------------------------------- /data/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/data/color150.mat -------------------------------------------------------------------------------- /data/object150_info.csv: -------------------------------------------------------------------------------- 1 | Idx,Ratio,Train,Val,Stuff,Name 2 | 1,0.1576,11664,1172,1,wall 3 | 2,0.1072,6046,612,1,building;edifice 4 | 3,0.0878,8265,796,1,sky 5 | 4,0.0621,9336,917,1,floor;flooring 6 | 5,0.0480,6678,641,0,tree 7 | 6,0.0450,6604,643,1,ceiling 8 | 7,0.0398,4023,408,1,road;route 9 | 8,0.0231,1906,199,0,bed 10 | 9,0.0198,4688,460,0,windowpane;window 11 | 10,0.0183,2423,225,1,grass 12 | 11,0.0181,2874,294,0,cabinet 13 | 12,0.0166,3068,310,1,sidewalk;pavement 14 | 13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul 15 | 14,0.0151,1804,190,1,earth;ground 16 | 15,0.0118,6666,796,0,door;double;door 17 | 16,0.0110,4269,411,0,table 18 | 17,0.0109,1691,160,1,mountain;mount 19 | 18,0.0104,3999,441,0,plant;flora;plant;life 20 | 19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall 21 | 20,0.0103,3261,318,0,chair 22 | 21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar 23 | 22,0.0074,709,75,1,water 24 | 23,0.0067,3296,315,0,painting;picture 25 | 
24,0.0065,1191,106,0,sofa;couch;lounge 26 | 25,0.0061,1516,162,0,shelf 27 | 26,0.0060,667,69,1,house 28 | 27,0.0053,651,57,1,sea 29 | 28,0.0052,1847,224,0,mirror 30 | 29,0.0046,1158,128,1,rug;carpet;carpeting 31 | 30,0.0044,480,44,1,field 32 | 31,0.0044,1172,98,0,armchair 33 | 32,0.0044,1292,184,0,seat 34 | 33,0.0033,1386,138,0,fence;fencing 35 | 34,0.0031,698,61,0,desk 36 | 35,0.0030,781,73,0,rock;stone 37 | 36,0.0027,380,43,0,wardrobe;closet;press 38 | 37,0.0026,3089,302,0,lamp 39 | 38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub 40 | 39,0.0024,804,99,0,railing;rail 41 | 40,0.0023,1453,153,0,cushion 42 | 41,0.0023,411,37,0,base;pedestal;stand 43 | 42,0.0022,1440,162,0,box 44 | 43,0.0022,800,77,0,column;pillar 45 | 44,0.0020,2650,298,0,signboard;sign 46 | 45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser 47 | 46,0.0019,367,36,0,counter 48 | 47,0.0018,311,30,1,sand 49 | 48,0.0018,1181,122,0,sink 50 | 49,0.0018,287,23,1,skyscraper 51 | 50,0.0018,468,38,0,fireplace;hearth;open;fireplace 52 | 51,0.0018,402,43,0,refrigerator;icebox 53 | 52,0.0018,130,12,1,grandstand;covered;stand 54 | 53,0.0018,561,64,1,path 55 | 54,0.0017,880,102,0,stairs;steps 56 | 55,0.0017,86,12,1,runway 57 | 56,0.0017,172,11,0,case;display;case;showcase;vitrine 58 | 57,0.0017,198,18,0,pool;table;billiard;table;snooker;table 59 | 58,0.0017,930,109,0,pillow 60 | 59,0.0015,139,18,0,screen;door;screen 61 | 60,0.0015,564,52,1,stairway;staircase 62 | 61,0.0015,320,26,1,river 63 | 62,0.0015,261,29,1,bridge;span 64 | 63,0.0014,275,22,0,bookcase 65 | 64,0.0014,335,60,0,blind;screen 66 | 65,0.0014,792,75,0,coffee;table;cocktail;table 67 | 66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne 68 | 67,0.0014,1309,138,0,flower 69 | 68,0.0013,1112,113,0,book 70 | 69,0.0013,266,27,1,hill 71 | 70,0.0013,659,66,0,bench 72 | 71,0.0012,331,31,0,countertop 73 | 72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove 74 | 73,0.0012,369,36,0,palm;palm;tree 75 | 74,0.0012,144,9,0,kitchen;island 76 | 75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system 77 | 76,0.0010,324,33,0,swivel;chair 78 | 77,0.0009,304,27,0,boat 79 | 78,0.0009,170,20,0,bar 80 | 79,0.0009,68,6,0,arcade;machine 81 | 80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty 82 | 81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle 83 | 82,0.0008,492,49,0,towel 84 | 83,0.0008,2510,269,0,light;light;source 85 | 84,0.0008,440,39,0,truck;motortruck 86 | 85,0.0008,147,18,1,tower 87 | 86,0.0008,583,56,0,chandelier;pendant;pendent 88 | 87,0.0007,533,61,0,awning;sunshade;sunblind 89 | 88,0.0007,1989,239,0,streetlight;street;lamp 90 | 89,0.0007,71,5,0,booth;cubicle;stall;kiosk 91 | 90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box 92 | 91,0.0007,135,12,0,airplane;aeroplane;plane 93 | 92,0.0007,83,5,1,dirt;track 94 | 93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes 95 | 94,0.0006,1003,104,0,pole 96 | 95,0.0006,182,12,1,land;ground;soil 97 | 96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail 98 | 97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway 99 | 98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock 100 | 99,0.0006,965,114,0,bottle 101 | 100,0.0006,117,13,0,buffet;counter;sideboard 102 | 101,0.0006,354,35,0,poster;posting;placard;notice;bill;card 103 | 102,0.0006,108,9,1,stage 104 | 103,0.0006,557,55,0,van 105 | 
104,0.0006,52,4,0,ship 106 | 105,0.0005,99,5,0,fountain 107 | 106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter 108 | 107,0.0005,292,31,0,canopy 109 | 108,0.0005,77,9,0,washer;automatic;washer;washing;machine 110 | 109,0.0005,340,38,0,plaything;toy 111 | 110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium 112 | 111,0.0005,465,49,0,stool 113 | 112,0.0005,50,4,0,barrel;cask 114 | 113,0.0005,622,75,0,basket;handbasket 115 | 114,0.0005,80,9,1,waterfall;falls 116 | 115,0.0005,59,3,0,tent;collapsible;shelter 117 | 116,0.0005,531,72,0,bag 118 | 117,0.0005,282,30,0,minibike;motorbike 119 | 118,0.0005,73,7,0,cradle 120 | 119,0.0005,435,44,0,oven 121 | 120,0.0005,136,25,0,ball 122 | 121,0.0005,116,24,0,food;solid;food 123 | 122,0.0004,266,31,0,step;stair 124 | 123,0.0004,58,12,0,tank;storage;tank 125 | 124,0.0004,418,83,0,trade;name;brand;name;brand;marque 126 | 125,0.0004,319,43,0,microwave;microwave;oven 127 | 126,0.0004,1193,139,0,pot;flowerpot 128 | 127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna 129 | 128,0.0004,347,36,0,bicycle;bike;wheel;cycle 130 | 129,0.0004,52,5,1,lake 131 | 130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine 132 | 131,0.0004,108,13,0,screen;silver;screen;projection;screen 133 | 132,0.0004,201,30,0,blanket;cover 134 | 133,0.0004,285,21,0,sculpture 135 | 134,0.0004,268,27,0,hood;exhaust;hood 136 | 135,0.0003,1020,108,0,sconce 137 | 136,0.0003,1282,122,0,vase 138 | 137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight 139 | 138,0.0003,453,57,0,tray 140 | 139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin 141 | 140,0.0003,397,44,0,fan 142 | 141,0.0003,92,8,1,pier;wharf;wharfage;dock 143 | 142,0.0003,228,18,0,crt;screen 144 | 143,0.0003,570,59,0,plate 145 | 144,0.0003,217,22,0,monitor;monitoring;device 146 | 145,0.0003,206,19,0,bulletin;board;notice;board 147 | 146,0.0003,130,14,0,shower 148 | 147,0.0003,178,28,0,radiator 149 | 148,0.0002,504,57,0,glass;drinking;glass 150 | 149,0.0002,775,96,0,clock 151 | 150,0.0002,421,56,0,flag 152 | -------------------------------------------------------------------------------- /demo_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Image and model names 4 | TEST_IMG=ADE_val_00001519.jpg 5 | MODEL_NAME=ade20k-resnet50dilated-ppm_deepsup 6 | MODEL_PATH=ckpt/$MODEL_NAME 7 | RESULT_PATH=./ 8 | 9 | ENCODER=$MODEL_NAME/encoder_epoch_20.pth 10 | DECODER=$MODEL_NAME/decoder_epoch_20.pth 11 | 12 | # Download model weights and image 13 | if [ ! -e $MODEL_PATH ]; then 14 | mkdir -p $MODEL_PATH 15 | fi 16 | if [ ! -e $ENCODER ]; then 17 | wget -P $MODEL_PATH http://sceneparsing.csail.mit.edu/model/pytorch/$ENCODER 18 | fi 19 | if [ ! -e $DECODER ]; then 20 | wget -P $MODEL_PATH http://sceneparsing.csail.mit.edu/model/pytorch/$DECODER 21 | fi 22 | if [ ! 
-e $TEST_IMG ]; then 23 | wget -P $RESULT_PATH http://sceneparsing.csail.mit.edu/data/ADEChallengeData2016/images/validation/$TEST_IMG 24 | fi 25 | 26 | if [ -z "$DOWNLOAD_ONLY" ] 27 | then 28 | 29 | # Inference 30 | python3 -u test.py \ 31 | --imgs $TEST_IMG \ 32 | --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml \ 33 | DIR $MODEL_PATH \ 34 | TEST.result ./ \ 35 | TEST.checkpoint epoch_20.pth 36 | 37 | fi 38 | -------------------------------------------------------------------------------- /download_ADE20K.sh: -------------------------------------------------------------------------------- 1 | wget -O ./data/ADEChallengeData2016.zip http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip 2 | unzip ./data/ADEChallengeData2016.zip -d ./data 3 | rm ./data/ADEChallengeData2016.zip 4 | echo "Dataset downloaded." 5 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import os 3 | import time 4 | import argparse 5 | from distutils.version import LooseVersion 6 | # Numerical libs 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from scipy.io import loadmat 11 | # Our libs 12 | from mit_semseg.config import cfg 13 | from mit_semseg.dataset import ValDataset 14 | from mit_semseg.models import ModelBuilder, SegmentationModule 15 | from mit_semseg.utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger 16 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to 17 | from mit_semseg.lib.utils import as_numpy 18 | from PIL import Image 19 | from tqdm import tqdm 20 | 21 | colors = loadmat('data/color150.mat')['colors'] 22 | 23 | 24 | def visualize_result(data, pred, dir_result): 25 | (img, seg, info) = data 26 | 27 | # segmentation 28 | seg_color = colorEncode(seg, colors) 29 | 30 | # prediction 31 | pred_color = colorEncode(pred, colors) 32 | 33 | # aggregate images and save 34 | im_vis = np.concatenate((img, seg_color, pred_color), 35 | axis=1).astype(np.uint8) 36 | 37 | img_name = info.split('/')[-1] 38 | Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png'))) 39 | 40 | 41 | def evaluate(segmentation_module, loader, cfg, gpu): 42 | acc_meter = AverageMeter() 43 | intersection_meter = AverageMeter() 44 | union_meter = AverageMeter() 45 | time_meter = AverageMeter() 46 | 47 | segmentation_module.eval() 48 | 49 | pbar = tqdm(total=len(loader)) 50 | for batch_data in loader: 51 | # process data 52 | batch_data = batch_data[0] 53 | seg_label = as_numpy(batch_data['seg_label'][0]) 54 | img_resized_list = batch_data['img_data'] 55 | 56 | torch.cuda.synchronize() 57 | tic = time.perf_counter() 58 | with torch.no_grad(): 59 | segSize = (seg_label.shape[0], seg_label.shape[1]) 60 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1]) 61 | scores = async_copy_to(scores, gpu) 62 | 63 | for img in img_resized_list: 64 | feed_dict = batch_data.copy() 65 | feed_dict['img_data'] = img 66 | del feed_dict['img_ori'] 67 | del feed_dict['info'] 68 | feed_dict = async_copy_to(feed_dict, gpu) 69 | 70 | # forward pass 71 | scores_tmp = segmentation_module(feed_dict, segSize=segSize) 72 | scores = scores + scores_tmp / len(cfg.DATASET.imgSizes) 73 | 74 | _, pred = torch.max(scores, dim=1) 75 | pred = as_numpy(pred.squeeze(0).cpu()) 76 | 77 | torch.cuda.synchronize() 78 | time_meter.update(time.perf_counter() - tic) 79 | 80 | # calculate accuracy 81 | 
acc, pix = accuracy(pred, seg_label) 82 | intersection, union = intersectionAndUnion(pred, seg_label, cfg.DATASET.num_class) 83 | acc_meter.update(acc, pix) 84 | intersection_meter.update(intersection) 85 | union_meter.update(union) 86 | 87 | # visualization 88 | if cfg.VAL.visualize: 89 | visualize_result( 90 | (batch_data['img_ori'], seg_label, batch_data['info']), 91 | pred, 92 | os.path.join(cfg.DIR, 'result') 93 | ) 94 | 95 | pbar.update(1) 96 | 97 | # summary 98 | iou = intersection_meter.sum / (union_meter.sum + 1e-10) 99 | for i, _iou in enumerate(iou): 100 | print('class [{}], IoU: {:.4f}'.format(i, _iou)) 101 | 102 | print('[Eval Summary]:') 103 | print('Mean IoU: {:.4f}, Accuracy: {:.2f}%, Inference Time: {:.4f}s' 104 | .format(iou.mean(), acc_meter.average()*100, time_meter.average())) 105 | 106 | 107 | def main(cfg, gpu): 108 | torch.cuda.set_device(gpu) 109 | 110 | # Network Builders 111 | net_encoder = ModelBuilder.build_encoder( 112 | arch=cfg.MODEL.arch_encoder.lower(), 113 | fc_dim=cfg.MODEL.fc_dim, 114 | weights=cfg.MODEL.weights_encoder) 115 | net_decoder = ModelBuilder.build_decoder( 116 | arch=cfg.MODEL.arch_decoder.lower(), 117 | fc_dim=cfg.MODEL.fc_dim, 118 | num_class=cfg.DATASET.num_class, 119 | weights=cfg.MODEL.weights_decoder, 120 | use_softmax=True) 121 | 122 | crit = nn.NLLLoss(ignore_index=-1) 123 | 124 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) 125 | 126 | # Dataset and Loader 127 | dataset_val = ValDataset( 128 | cfg.DATASET.root_dataset, 129 | cfg.DATASET.list_val, 130 | cfg.DATASET) 131 | loader_val = torch.utils.data.DataLoader( 132 | dataset_val, 133 | batch_size=cfg.VAL.batch_size, 134 | shuffle=False, 135 | collate_fn=user_scattered_collate, 136 | num_workers=5, 137 | drop_last=True) 138 | 139 | segmentation_module.cuda() 140 | 141 | # Main loop 142 | evaluate(segmentation_module, loader_val, cfg, gpu) 143 | 144 | print('Evaluation Done!') 145 | 146 | 147 | if __name__ == '__main__': 148 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 149 | 'PyTorch>=0.4.0 is required' 150 | 151 | parser = argparse.ArgumentParser( 152 | description="PyTorch Semantic Segmentation Validation" 153 | ) 154 | parser.add_argument( 155 | "--cfg", 156 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml", 157 | metavar="FILE", 158 | help="path to config file", 159 | type=str, 160 | ) 161 | parser.add_argument( 162 | "--gpu", 163 | default=0, 164 | help="gpu to use" 165 | ) 166 | parser.add_argument( 167 | "opts", 168 | help="Modify config options using the command-line", 169 | default=None, 170 | nargs=argparse.REMAINDER, 171 | ) 172 | args = parser.parse_args() 173 | 174 | cfg.merge_from_file(args.cfg) 175 | cfg.merge_from_list(args.opts) 176 | # cfg.freeze() 177 | 178 | logger = setup_logger(distributed_rank=0) # TODO 179 | logger.info("Loaded configuration file {}".format(args.cfg)) 180 | logger.info("Running with config:\n{}".format(cfg)) 181 | 182 | # absolute paths of model weights 183 | cfg.MODEL.weights_encoder = os.path.join( 184 | cfg.DIR, 'encoder_' + cfg.VAL.checkpoint) 185 | cfg.MODEL.weights_decoder = os.path.join( 186 | cfg.DIR, 'decoder_' + cfg.VAL.checkpoint) 187 | assert os.path.exists(cfg.MODEL.weights_encoder) and \ 188 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!" 
189 | 190 | if not os.path.isdir(os.path.join(cfg.DIR, "result")): 191 | os.makedirs(os.path.join(cfg.DIR, "result")) 192 | 193 | main(cfg, args.gpu) 194 | -------------------------------------------------------------------------------- /eval_multipro.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import os 3 | import argparse 4 | from distutils.version import LooseVersion 5 | from multiprocessing import Queue, Process 6 | # Numerical libs 7 | import numpy as np 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | from scipy.io import loadmat 12 | # Our libs 13 | from mit_semseg.config import cfg 14 | from mit_semseg.dataset import ValDataset 15 | from mit_semseg.models import ModelBuilder, SegmentationModule 16 | from mit_semseg.utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger 17 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to 18 | from mit_semseg.lib.utils import as_numpy 19 | from PIL import Image 20 | from tqdm import tqdm 21 | 22 | colors = loadmat('data/color150.mat')['colors'] 23 | 24 | 25 | def visualize_result(data, pred, dir_result): 26 | (img, seg, info) = data 27 | 28 | # segmentation 29 | seg_color = colorEncode(seg, colors) 30 | 31 | # prediction 32 | pred_color = colorEncode(pred, colors) 33 | 34 | # aggregate images and save 35 | im_vis = np.concatenate((img, seg_color, pred_color), 36 | axis=1).astype(np.uint8) 37 | 38 | img_name = info.split('/')[-1] 39 | Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png'))) 40 | 41 | 42 | def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue): 43 | segmentation_module.eval() 44 | 45 | for batch_data in loader: 46 | # process data 47 | batch_data = batch_data[0] 48 | seg_label = as_numpy(batch_data['seg_label'][0]) 49 | img_resized_list = batch_data['img_data'] 50 | 51 | with torch.no_grad(): 52 | segSize = (seg_label.shape[0], seg_label.shape[1]) 53 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1]) 54 | scores = async_copy_to(scores, gpu_id) 55 | 56 | for img in img_resized_list: 57 | feed_dict = batch_data.copy() 58 | feed_dict['img_data'] = img 59 | del feed_dict['img_ori'] 60 | del feed_dict['info'] 61 | feed_dict = async_copy_to(feed_dict, gpu_id) 62 | 63 | # forward pass 64 | scores_tmp = segmentation_module(feed_dict, segSize=segSize) 65 | scores = scores + scores_tmp / len(cfg.DATASET.imgSizes) 66 | 67 | _, pred = torch.max(scores, dim=1) 68 | pred = as_numpy(pred.squeeze(0).cpu()) 69 | 70 | # calculate accuracy and SEND THEM TO MASTER 71 | acc, pix = accuracy(pred, seg_label) 72 | intersection, union = intersectionAndUnion(pred, seg_label, cfg.DATASET.num_class) 73 | result_queue.put_nowait((acc, pix, intersection, union)) 74 | 75 | # visualization 76 | if cfg.VAL.visualize: 77 | visualize_result( 78 | (batch_data['img_ori'], seg_label, batch_data['info']), 79 | pred, 80 | os.path.join(cfg.DIR, 'result') 81 | ) 82 | 83 | 84 | def worker(cfg, gpu_id, start_idx, end_idx, result_queue): 85 | torch.cuda.set_device(gpu_id) 86 | 87 | # Dataset and Loader 88 | dataset_val = ValDataset( 89 | cfg.DATASET.root_dataset, 90 | cfg.DATASET.list_val, 91 | cfg.DATASET, 92 | start_idx=start_idx, end_idx=end_idx) 93 | loader_val = torch.utils.data.DataLoader( 94 | dataset_val, 95 | batch_size=cfg.VAL.batch_size, 96 | shuffle=False, 97 | collate_fn=user_scattered_collate, 98 | num_workers=2) 99 | 100 | # Network Builders 101 | 
net_encoder = ModelBuilder.build_encoder( 102 | arch=cfg.MODEL.arch_encoder.lower(), 103 | fc_dim=cfg.MODEL.fc_dim, 104 | weights=cfg.MODEL.weights_encoder) 105 | net_decoder = ModelBuilder.build_decoder( 106 | arch=cfg.MODEL.arch_decoder.lower(), 107 | fc_dim=cfg.MODEL.fc_dim, 108 | num_class=cfg.DATASET.num_class, 109 | weights=cfg.MODEL.weights_decoder, 110 | use_softmax=True) 111 | 112 | crit = nn.NLLLoss(ignore_index=-1) 113 | 114 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) 115 | 116 | segmentation_module.cuda() 117 | 118 | # Main loop 119 | evaluate(segmentation_module, loader_val, cfg, gpu_id, result_queue) 120 | 121 | 122 | def main(cfg, gpus): 123 | with open(cfg.DATASET.list_val, 'r') as f: 124 | lines = f.readlines() 125 | num_files = len(lines) 126 | 127 | num_files_per_gpu = math.ceil(num_files / len(gpus)) 128 | 129 | pbar = tqdm(total=num_files) 130 | 131 | acc_meter = AverageMeter() 132 | intersection_meter = AverageMeter() 133 | union_meter = AverageMeter() 134 | 135 | result_queue = Queue(500) 136 | procs = [] 137 | for idx, gpu_id in enumerate(gpus): 138 | start_idx = idx * num_files_per_gpu 139 | end_idx = min(start_idx + num_files_per_gpu, num_files) 140 | proc = Process(target=worker, args=(cfg, gpu_id, start_idx, end_idx, result_queue)) 141 | print('gpu:{}, start_idx:{}, end_idx:{}'.format(gpu_id, start_idx, end_idx)) 142 | proc.start() 143 | procs.append(proc) 144 | 145 | # master fetches results 146 | processed_counter = 0 147 | while processed_counter < num_files: 148 | if result_queue.empty(): 149 | continue 150 | (acc, pix, intersection, union) = result_queue.get() 151 | acc_meter.update(acc, pix) 152 | intersection_meter.update(intersection) 153 | union_meter.update(union) 154 | processed_counter += 1 155 | pbar.update(1) 156 | 157 | for p in procs: 158 | p.join() 159 | 160 | # summary 161 | iou = intersection_meter.sum / (union_meter.sum + 1e-10) 162 | for i, _iou in enumerate(iou): 163 | print('class [{}], IoU: {:.4f}'.format(i, _iou)) 164 | 165 | print('[Eval Summary]:') 166 | print('Mean IoU: {:.4f}, Accuracy: {:.2f}%' 167 | .format(iou.mean(), acc_meter.average()*100)) 168 | 169 | print('Evaluation Done!') 170 | 171 | 172 | if __name__ == '__main__': 173 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 174 | 'PyTorch>=0.4.0 is required' 175 | 176 | parser = argparse.ArgumentParser( 177 | description="PyTorch Semantic Segmentation Validation" 178 | ) 179 | parser.add_argument( 180 | "--cfg", 181 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml", 182 | metavar="FILE", 183 | help="path to config file", 184 | type=str, 185 | ) 186 | parser.add_argument( 187 | "--gpus", 188 | default="0-3", 189 | help="gpus to use, e.g. 
0-3 or 0,1,2,3" 190 | ) 191 | parser.add_argument( 192 | "opts", 193 | help="Modify config options using the command-line", 194 | default=None, 195 | nargs=argparse.REMAINDER, 196 | ) 197 | args = parser.parse_args() 198 | 199 | cfg.merge_from_file(args.cfg) 200 | cfg.merge_from_list(args.opts) 201 | # cfg.freeze() 202 | 203 | logger = setup_logger(distributed_rank=0) # TODO 204 | logger.info("Loaded configuration file {}".format(args.cfg)) 205 | logger.info("Running with config:\n{}".format(cfg)) 206 | 207 | # absolute paths of model weights 208 | cfg.MODEL.weights_encoder = os.path.join( 209 | cfg.DIR, 'encoder_' + cfg.VAL.checkpoint) 210 | cfg.MODEL.weights_decoder = os.path.join( 211 | cfg.DIR, 'decoder_' + cfg.VAL.checkpoint) 212 | assert os.path.exists(cfg.MODEL.weights_encoder) and \ 213 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!" 214 | 215 | if not os.path.isdir(os.path.join(cfg.DIR, "result")): 216 | os.makedirs(os.path.join(cfg.DIR, "result")) 217 | 218 | # Parse gpu ids 219 | gpus = parse_devices(args.gpus) 220 | gpus = [x.replace('gpu', '') for x in gpus] 221 | gpus = [int(x) for x in gpus] 222 | 223 | main(cfg, gpus) 224 | -------------------------------------------------------------------------------- /mit_semseg/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT CSAIL Semantic Segmentation 3 | """ 4 | 5 | __version__ = '1.0.0' 6 | -------------------------------------------------------------------------------- /mit_semseg/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | -------------------------------------------------------------------------------- /mit_semseg/config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Config definition 5 | # ----------------------------------------------------------------------------- 6 | 7 | _C = CN() 8 | _C.DIR = "ckpt/ade20k-resnet50dilated-ppm_deepsup" 9 | 10 | # ----------------------------------------------------------------------------- 11 | # Dataset 12 | # ----------------------------------------------------------------------------- 13 | _C.DATASET = CN() 14 | _C.DATASET.root_dataset = "./data/" 15 | _C.DATASET.list_train = "./data/training.odgt" 16 | _C.DATASET.list_val = "./data/validation.odgt" 17 | _C.DATASET.num_class = 150 18 | # multiscale train/test, size of short edge (int or tuple) 19 | _C.DATASET.imgSizes = (300, 375, 450, 525, 600) 20 | # maximum input image size of long edge 21 | _C.DATASET.imgMaxSize = 1000 22 | # maxmimum downsampling rate of the network 23 | _C.DATASET.padding_constant = 8 24 | # downsampling rate of the segmentation label 25 | _C.DATASET.segm_downsampling_rate = 8 26 | # randomly horizontally flip images when train/test 27 | _C.DATASET.random_flip = True 28 | 29 | # ----------------------------------------------------------------------------- 30 | # Model 31 | # ----------------------------------------------------------------------------- 32 | _C.MODEL = CN() 33 | # architecture of net_encoder 34 | _C.MODEL.arch_encoder = "resnet50dilated" 35 | # architecture of net_decoder 36 | _C.MODEL.arch_decoder = "ppm_deepsup" 37 | # weights to finetune net_encoder 38 | _C.MODEL.weights_encoder = "" 39 | # weights to finetune net_decoder 40 | 
_C.MODEL.weights_decoder = "" 41 | # number of feature channels between encoder and decoder 42 | _C.MODEL.fc_dim = 2048 43 | 44 | # ----------------------------------------------------------------------------- 45 | # Training 46 | # ----------------------------------------------------------------------------- 47 | _C.TRAIN = CN() 48 | _C.TRAIN.batch_size_per_gpu = 2 49 | # epochs to train for 50 | _C.TRAIN.num_epoch = 20 51 | # epoch to start training. useful if continue from a checkpoint 52 | _C.TRAIN.start_epoch = 0 53 | # iterations of each epoch (irrelevant to batch size) 54 | _C.TRAIN.epoch_iters = 5000 55 | 56 | _C.TRAIN.optim = "SGD" 57 | _C.TRAIN.lr_encoder = 0.02 58 | _C.TRAIN.lr_decoder = 0.02 59 | # power in poly to drop LR 60 | _C.TRAIN.lr_pow = 0.9 61 | # momentum for sgd, beta1 for adam 62 | _C.TRAIN.beta1 = 0.9 63 | # weights regularizer 64 | _C.TRAIN.weight_decay = 1e-4 65 | # the weighting of deep supervision loss 66 | _C.TRAIN.deep_sup_scale = 0.4 67 | # fix bn params, only under finetuning 68 | _C.TRAIN.fix_bn = False 69 | # number of data loading workers 70 | _C.TRAIN.workers = 16 71 | 72 | # frequency to display 73 | _C.TRAIN.disp_iter = 20 74 | # manual seed 75 | _C.TRAIN.seed = 304 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Validation 79 | # ----------------------------------------------------------------------------- 80 | _C.VAL = CN() 81 | # currently only supports 1 82 | _C.VAL.batch_size = 1 83 | # output visualization during validation 84 | _C.VAL.visualize = False 85 | # the checkpoint to evaluate on 86 | _C.VAL.checkpoint = "epoch_20.pth" 87 | 88 | # ----------------------------------------------------------------------------- 89 | # Testing 90 | # ----------------------------------------------------------------------------- 91 | _C.TEST = CN() 92 | # currently only supports 1 93 | _C.TEST.batch_size = 1 94 | # the checkpoint to test on 95 | _C.TEST.checkpoint = "epoch_20.pth" 96 | # folder to output visualization results 97 | _C.TEST.result = "./" 98 | -------------------------------------------------------------------------------- /mit_semseg/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | from PIL import Image 7 | 8 | 9 | def imresize(im, size, interp='bilinear'): 10 | if interp == 'nearest': 11 | resample = Image.NEAREST 12 | elif interp == 'bilinear': 13 | resample = Image.BILINEAR 14 | elif interp == 'bicubic': 15 | resample = Image.BICUBIC 16 | else: 17 | raise Exception('resample method undefined!') 18 | 19 | return im.resize(size, resample) 20 | 21 | 22 | class BaseDataset(torch.utils.data.Dataset): 23 | def __init__(self, odgt, opt, **kwargs): 24 | # parse options 25 | self.imgSizes = opt.imgSizes 26 | self.imgMaxSize = opt.imgMaxSize 27 | # max down sampling rate of network to avoid rounding during conv or pooling 28 | self.padding_constant = opt.padding_constant 29 | 30 | # parse the input list 31 | self.parse_input_list(odgt, **kwargs) 32 | 33 | # mean and std 34 | self.normalize = transforms.Normalize( 35 | mean=[0.485, 0.456, 0.406], 36 | std=[0.229, 0.224, 0.225]) 37 | 38 | def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1): 39 | if isinstance(odgt, list): 40 | self.list_sample = odgt 41 | elif isinstance(odgt, str): 42 | self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] 43 | 44 | if max_sample 
> 0: 45 | self.list_sample = self.list_sample[0:max_sample] 46 | if start_idx >= 0 and end_idx >= 0: # divide file list 47 | self.list_sample = self.list_sample[start_idx:end_idx] 48 | 49 | self.num_sample = len(self.list_sample) 50 | assert self.num_sample > 0 51 | print('# samples: {}'.format(self.num_sample)) 52 | 53 | def img_transform(self, img): 54 | # 0-255 to 0-1 55 | img = np.float32(np.array(img)) / 255. 56 | img = img.transpose((2, 0, 1)) 57 | img = self.normalize(torch.from_numpy(img.copy())) 58 | return img 59 | 60 | def segm_transform(self, segm): 61 | # to tensor, -1 to 149 62 | segm = torch.from_numpy(np.array(segm)).long() - 1 63 | return segm 64 | 65 | # Round x to the nearest multiple of p and x' >= x 66 | def round2nearest_multiple(self, x, p): 67 | return ((x - 1) // p + 1) * p 68 | 69 | 70 | class TrainDataset(BaseDataset): 71 | def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs): 72 | super(TrainDataset, self).__init__(odgt, opt, **kwargs) 73 | self.root_dataset = root_dataset 74 | # down sampling rate of segm labe 75 | self.segm_downsampling_rate = opt.segm_downsampling_rate 76 | self.batch_per_gpu = batch_per_gpu 77 | 78 | # classify images into two classes: 1. h > w and 2. h <= w 79 | self.batch_record_list = [[], []] 80 | 81 | # override dataset length when trainig with batch_per_gpu > 1 82 | self.cur_idx = 0 83 | self.if_shuffled = False 84 | 85 | def _get_sub_batch(self): 86 | while True: 87 | # get a sample record 88 | this_sample = self.list_sample[self.cur_idx] 89 | if this_sample['height'] > this_sample['width']: 90 | self.batch_record_list[0].append(this_sample) # h > w, go to 1st class 91 | else: 92 | self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class 93 | 94 | # update current sample pointer 95 | self.cur_idx += 1 96 | if self.cur_idx >= self.num_sample: 97 | self.cur_idx = 0 98 | np.random.shuffle(self.list_sample) 99 | 100 | if len(self.batch_record_list[0]) == self.batch_per_gpu: 101 | batch_records = self.batch_record_list[0] 102 | self.batch_record_list[0] = [] 103 | break 104 | elif len(self.batch_record_list[1]) == self.batch_per_gpu: 105 | batch_records = self.batch_record_list[1] 106 | self.batch_record_list[1] = [] 107 | break 108 | return batch_records 109 | 110 | def __getitem__(self, index): 111 | # NOTE: random shuffle for the first time. 
shuffle in __init__ is useless 112 | if not self.if_shuffled: 113 | np.random.seed(index) 114 | np.random.shuffle(self.list_sample) 115 | self.if_shuffled = True 116 | 117 | # get sub-batch candidates 118 | batch_records = self._get_sub_batch() 119 | 120 | # resize all images' short edges to the chosen size 121 | if isinstance(self.imgSizes, list) or isinstance(self.imgSizes, tuple): 122 | this_short_size = np.random.choice(self.imgSizes) 123 | else: 124 | this_short_size = self.imgSizes 125 | 126 | # calculate the BATCH's height and width 127 | # since we concat more than one samples, the batch's h and w shall be larger than EACH sample 128 | batch_widths = np.zeros(self.batch_per_gpu, np.int32) 129 | batch_heights = np.zeros(self.batch_per_gpu, np.int32) 130 | for i in range(self.batch_per_gpu): 131 | img_height, img_width = batch_records[i]['height'], batch_records[i]['width'] 132 | this_scale = min( 133 | this_short_size / min(img_height, img_width), \ 134 | self.imgMaxSize / max(img_height, img_width)) 135 | batch_widths[i] = img_width * this_scale 136 | batch_heights[i] = img_height * this_scale 137 | 138 | # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w' 139 | batch_width = np.max(batch_widths) 140 | batch_height = np.max(batch_heights) 141 | batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant)) 142 | batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant)) 143 | 144 | assert self.padding_constant >= self.segm_downsampling_rate, \ 145 | 'padding constant must be equal or large than segm downsamping rate' 146 | batch_images = torch.zeros( 147 | self.batch_per_gpu, 3, batch_height, batch_width) 148 | batch_segms = torch.zeros( 149 | self.batch_per_gpu, 150 | batch_height // self.segm_downsampling_rate, 151 | batch_width // self.segm_downsampling_rate).long() 152 | 153 | for i in range(self.batch_per_gpu): 154 | this_record = batch_records[i] 155 | 156 | # load image and label 157 | image_path = os.path.join(self.root_dataset, this_record['fpath_img']) 158 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) 159 | 160 | img = Image.open(image_path).convert('RGB') 161 | segm = Image.open(segm_path) 162 | assert(segm.mode == "L") 163 | assert(img.size[0] == segm.size[0]) 164 | assert(img.size[1] == segm.size[1]) 165 | 166 | # random_flip 167 | if np.random.choice([0, 1]): 168 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 169 | segm = segm.transpose(Image.FLIP_LEFT_RIGHT) 170 | 171 | # note that each sample within a mini batch has different scale param 172 | img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear') 173 | segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest') 174 | 175 | # further downsample seg label, need to avoid seg label misalignment 176 | segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate) 177 | segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate) 178 | segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0) 179 | segm_rounded.paste(segm, (0, 0)) 180 | segm = imresize( 181 | segm_rounded, 182 | (segm_rounded.size[0] // self.segm_downsampling_rate, \ 183 | segm_rounded.size[1] // self.segm_downsampling_rate), \ 184 | interp='nearest') 185 | 186 | # image transform, to torch float tensor 3xHxW 187 | img = self.img_transform(img) 188 | 189 | # segm transform, to torch long tensor HxW 190 | segm = 
self.segm_transform(segm) 191 | 192 | # put into batch arrays 193 | batch_images[i][:, :img.shape[1], :img.shape[2]] = img 194 | batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm 195 | 196 | output = dict() 197 | output['img_data'] = batch_images 198 | output['seg_label'] = batch_segms 199 | return output 200 | 201 | def __len__(self): 202 | return int(1e10) # It's a fake length due to the trick that every loader maintains its own list 203 | #return self.num_sampleclass 204 | 205 | 206 | class ValDataset(BaseDataset): 207 | def __init__(self, root_dataset, odgt, opt, **kwargs): 208 | super(ValDataset, self).__init__(odgt, opt, **kwargs) 209 | self.root_dataset = root_dataset 210 | 211 | def __getitem__(self, index): 212 | this_record = self.list_sample[index] 213 | # load image and label 214 | image_path = os.path.join(self.root_dataset, this_record['fpath_img']) 215 | segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) 216 | img = Image.open(image_path).convert('RGB') 217 | segm = Image.open(segm_path) 218 | assert(segm.mode == "L") 219 | assert(img.size[0] == segm.size[0]) 220 | assert(img.size[1] == segm.size[1]) 221 | 222 | ori_width, ori_height = img.size 223 | 224 | img_resized_list = [] 225 | for this_short_size in self.imgSizes: 226 | # calculate target height and width 227 | scale = min(this_short_size / float(min(ori_height, ori_width)), 228 | self.imgMaxSize / float(max(ori_height, ori_width))) 229 | target_height, target_width = int(ori_height * scale), int(ori_width * scale) 230 | 231 | # to avoid rounding in network 232 | target_width = self.round2nearest_multiple(target_width, self.padding_constant) 233 | target_height = self.round2nearest_multiple(target_height, self.padding_constant) 234 | 235 | # resize images 236 | img_resized = imresize(img, (target_width, target_height), interp='bilinear') 237 | 238 | # image transform, to torch float tensor 3xHxW 239 | img_resized = self.img_transform(img_resized) 240 | img_resized = torch.unsqueeze(img_resized, 0) 241 | img_resized_list.append(img_resized) 242 | 243 | # segm transform, to torch long tensor HxW 244 | segm = self.segm_transform(segm) 245 | batch_segms = torch.unsqueeze(segm, 0) 246 | 247 | output = dict() 248 | output['img_ori'] = np.array(img) 249 | output['img_data'] = [x.contiguous() for x in img_resized_list] 250 | output['seg_label'] = batch_segms.contiguous() 251 | output['info'] = this_record['fpath_img'] 252 | return output 253 | 254 | def __len__(self): 255 | return self.num_sample 256 | 257 | 258 | class TestDataset(BaseDataset): 259 | def __init__(self, odgt, opt, **kwargs): 260 | super(TestDataset, self).__init__(odgt, opt, **kwargs) 261 | 262 | def __getitem__(self, index): 263 | this_record = self.list_sample[index] 264 | # load image 265 | image_path = this_record['fpath_img'] 266 | img = Image.open(image_path).convert('RGB') 267 | 268 | ori_width, ori_height = img.size 269 | 270 | img_resized_list = [] 271 | for this_short_size in self.imgSizes: 272 | # calculate target height and width 273 | scale = min(this_short_size / float(min(ori_height, ori_width)), 274 | self.imgMaxSize / float(max(ori_height, ori_width))) 275 | target_height, target_width = int(ori_height * scale), int(ori_width * scale) 276 | 277 | # to avoid rounding in network 278 | target_width = self.round2nearest_multiple(target_width, self.padding_constant) 279 | target_height = self.round2nearest_multiple(target_height, self.padding_constant) 280 | 281 | # resize images 282 | img_resized = imresize(img, 
(target_width, target_height), interp='bilinear') 283 | 284 | # image transform, to torch float tensor 3xHxW 285 | img_resized = self.img_transform(img_resized) 286 | img_resized = torch.unsqueeze(img_resized, 0) 287 | img_resized_list.append(img_resized) 288 | 289 | output = dict() 290 | output['img_ori'] = np.array(img) 291 | output['img_data'] = [x.contiguous() for x in img_resized_list] 292 | output['info'] = this_record['fpath_img'] 293 | return output 294 | 295 | def __len__(self): 296 | return self.num_sample 297 | -------------------------------------------------------------------------------- /mit_semseg/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/mit_semseg/lib/__init__.py -------------------------------------------------------------------------------- /mit_semseg/lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
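The synchronized batch norm implemented below works by reducing each replica's activations to a per-channel sum and square-sum, recombining them on the master device, and broadcasting back the resulting mean and inverse standard deviation (see `_compute_mean_std`). The following is a minimal, self-contained sketch of that reduction, separate from the library code, with illustrative shapes and tolerances only:

```python
import torch

torch.manual_seed(0)
chunks = [torch.randn(4, 3, 8, 8) for _ in range(2)]  # pretend each chunk lives on one GPU

# Per-device partial statistics: per-channel sum, square-sum, and element count.
sums = [c.sum(dim=(0, 2, 3)) for c in chunks]
ssums = [(c ** 2).sum(dim=(0, 2, 3)) for c in chunks]
count = sum(c.numel() // c.size(1) for c in chunks)

# The master only needs these reduced quantities to recover the global statistics.
total_sum, total_ssum = sum(sums), sum(ssums)
mean = total_sum / count
sumvar = total_ssum - total_sum * mean          # = sum((x - mean) ** 2)
bias_var = sumvar / count                       # biased variance used for normalization
inv_std = bias_var.clamp(min=1e-5) ** -0.5      # sqrt(max(var, eps)), as noted in the README

# Sanity check against statistics computed on the concatenated batch.
flat = torch.cat(chunks, dim=0).transpose(0, 1).reshape(3, -1)  # (C, N*H*W)
assert torch.allclose(mean, flat.mean(dim=1), atol=1e-5)
assert torch.allclose(bias_var, flat.var(dim=1, unbiased=False), atol=1e-5)
```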
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | # customed batch norm statistics 49 | self._moving_average_fraction = 1. - momentum 50 | self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) 51 | self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) 52 | self.register_buffer('_running_iter', torch.ones(1)) 53 | self._tmp_running_mean = self.running_mean.clone() * self._running_iter 54 | self._tmp_running_var = self.running_var.clone() * self._running_iter 55 | 56 | def forward(self, input): 57 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 58 | if not (self._is_parallel and self.training): 59 | return F.batch_norm( 60 | input, self.running_mean, self.running_var, self.weight, self.bias, 61 | self.training, self.momentum, self.eps) 62 | 63 | # Resize the input to (B, C, -1). 64 | input_shape = input.size() 65 | input = input.view(input.size(0), self.num_features, -1) 66 | 67 | # Compute the sum and square-sum. 68 | sum_size = input.size(0) * input.size(2) 69 | input_sum = _sum_ft(input) 70 | input_ssum = _sum_ft(input ** 2) 71 | 72 | # Reduce-and-broadcast the statistics. 73 | if self._parallel_id == 0: 74 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 75 | else: 76 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 77 | 78 | # Compute the output. 79 | if self.affine: 80 | # MJY:: Fuse the multiplication for speed. 81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 82 | else: 83 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 84 | 85 | # Reshape it. 86 | return output.view(input_shape) 87 | 88 | def __data_parallel_replicate__(self, ctx, copy_id): 89 | self._is_parallel = True 90 | self._parallel_id = copy_id 91 | 92 | # parallel_id == 0 means master device. 
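        # The master replica (parallel_id 0) publishes its SyncMaster through the
        # shared context; every other replica registers with it and gets back a
        # SlavePipe used to exchange (sum, ssum, sum_size) messages during forward().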
93 | if self._parallel_id == 0: 94 | ctx.sync_master = self._sync_master 95 | else: 96 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 97 | 98 | def _data_parallel_master(self, intermediates): 99 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 100 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 101 | 102 | to_reduce = [i[1][:2] for i in intermediates] 103 | to_reduce = [j for i in to_reduce for j in i] # flatten 104 | target_gpus = [i[1].sum.get_device() for i in intermediates] 105 | 106 | sum_size = sum([i[1].sum_size for i in intermediates]) 107 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 108 | 109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 110 | 111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 112 | 113 | outputs = [] 114 | for i, rec in enumerate(intermediates): 115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 116 | 117 | return outputs 118 | 119 | def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): 120 | """return *dest* by `dest := dest*alpha + delta*beta + bias`""" 121 | return dest * alpha + delta * beta + bias 122 | 123 | def _compute_mean_std(self, sum_, ssum, size): 124 | """Compute the mean and standard-deviation with sum and square-sum. This method 125 | also maintains the moving average on the master device.""" 126 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 127 | mean = sum_ / size 128 | sumvar = ssum - sum_ * mean 129 | unbias_var = sumvar / (size - 1) 130 | bias_var = sumvar / size 131 | 132 | self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) 133 | self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) 134 | self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) 135 | 136 | self.running_mean = self._tmp_running_mean / self._running_iter 137 | self.running_var = self._tmp_running_var / self._running_iter 138 | 139 | return mean, bias_var.clamp(self.eps) ** -0.5 140 | 141 | 142 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 143 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 144 | mini-batch. 145 | 146 | .. math:: 147 | 148 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 149 | 150 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 151 | standard-deviation are reduced across all devices during training. 152 | 153 | For example, when one uses `nn.DataParallel` to wrap the network during 154 | training, PyTorch's implementation normalize the tensor on each device using 155 | the statistics only on that device, which accelerated the computation and 156 | is also easy to implement, but the statistics might be inaccurate. 157 | Instead, in this synchronized version, the statistics will be computed 158 | over all training samples distributed on multiple devices. 159 | 160 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 161 | as the built-in PyTorch implementation. 162 | 163 | The mean and standard-deviation are calculated per-dimension over 164 | the mini-batches and gamma and beta are learnable parameter vectors 165 | of size C (where C is the input size). 166 | 167 | During training, this layer keeps a running estimate of its computed mean 168 | and variance. 
The running sum is kept with a default momentum of 0.1. 169 | 170 | During evaluation, this running mean/variance is used for normalization. 171 | 172 | Because the BatchNorm is done over the `C` dimension, computing statistics 173 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 174 | 175 | Args: 176 | num_features: num_features from an expected input of size 177 | `batch_size x num_features [x width]` 178 | eps: a value added to the denominator for numerical stability. 179 | Default: 1e-5 180 | momentum: the value used for the running_mean and running_var 181 | computation. Default: 0.1 182 | affine: a boolean value that when set to ``True``, gives the layer learnable 183 | affine parameters. Default: ``True`` 184 | 185 | Shape: 186 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 187 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 188 | 189 | Examples: 190 | >>> # With Learnable Parameters 191 | >>> m = SynchronizedBatchNorm1d(100) 192 | >>> # Without Learnable Parameters 193 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 194 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 195 | >>> output = m(input) 196 | """ 197 | 198 | def _check_input_dim(self, input): 199 | if input.dim() != 2 and input.dim() != 3: 200 | raise ValueError('expected 2D or 3D input (got {}D input)' 201 | .format(input.dim())) 202 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 203 | 204 | 205 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 206 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 207 | of 3d inputs 208 | 209 | .. math:: 210 | 211 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 212 | 213 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 214 | standard-deviation are reduced across all devices during training. 215 | 216 | For example, when one uses `nn.DataParallel` to wrap the network during 217 | training, PyTorch's implementation normalize the tensor on each device using 218 | the statistics only on that device, which accelerated the computation and 219 | is also easy to implement, but the statistics might be inaccurate. 220 | Instead, in this synchronized version, the statistics will be computed 221 | over all training samples distributed on multiple devices. 222 | 223 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 224 | as the built-in PyTorch implementation. 225 | 226 | The mean and standard-deviation are calculated per-dimension over 227 | the mini-batches and gamma and beta are learnable parameter vectors 228 | of size C (where C is the input size). 229 | 230 | During training, this layer keeps a running estimate of its computed mean 231 | and variance. The running sum is kept with a default momentum of 0.1. 232 | 233 | During evaluation, this running mean/variance is used for normalization. 234 | 235 | Because the BatchNorm is done over the `C` dimension, computing statistics 236 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 237 | 238 | Args: 239 | num_features: num_features from an expected input of 240 | size batch_size x num_features x height x width 241 | eps: a value added to the denominator for numerical stability. 242 | Default: 1e-5 243 | momentum: the value used for the running_mean and running_var 244 | computation. Default: 0.1 245 | affine: a boolean value that when set to ``True``, gives the layer learnable 246 | affine parameters. 
Default: ``True`` 247 | 248 | Shape: 249 | - Input: :math:`(N, C, H, W)` 250 | - Output: :math:`(N, C, H, W)` (same shape as input) 251 | 252 | Examples: 253 | >>> # With Learnable Parameters 254 | >>> m = SynchronizedBatchNorm2d(100) 255 | >>> # Without Learnable Parameters 256 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 257 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 258 | >>> output = m(input) 259 | """ 260 | 261 | def _check_input_dim(self, input): 262 | if input.dim() != 4: 263 | raise ValueError('expected 4D input (got {}D input)' 264 | .format(input.dim())) 265 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 266 | 267 | 268 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 269 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 270 | of 4d inputs 271 | 272 | .. math:: 273 | 274 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 275 | 276 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 277 | standard-deviation are reduced across all devices during training. 278 | 279 | For example, when one uses `nn.DataParallel` to wrap the network during 280 | training, PyTorch's implementation normalize the tensor on each device using 281 | the statistics only on that device, which accelerated the computation and 282 | is also easy to implement, but the statistics might be inaccurate. 283 | Instead, in this synchronized version, the statistics will be computed 284 | over all training samples distributed on multiple devices. 285 | 286 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 287 | as the built-in PyTorch implementation. 288 | 289 | The mean and standard-deviation are calculated per-dimension over 290 | the mini-batches and gamma and beta are learnable parameter vectors 291 | of size C (where C is the input size). 292 | 293 | During training, this layer keeps a running estimate of its computed mean 294 | and variance. The running sum is kept with a default momentum of 0.1. 295 | 296 | During evaluation, this running mean/variance is used for normalization. 297 | 298 | Because the BatchNorm is done over the `C` dimension, computing statistics 299 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 300 | or Spatio-temporal BatchNorm 301 | 302 | Args: 303 | num_features: num_features from an expected input of 304 | size batch_size x num_features x depth x height x width 305 | eps: a value added to the denominator for numerical stability. 306 | Default: 1e-5 307 | momentum: the value used for the running_mean and running_var 308 | computation. Default: 0.1 309 | affine: a boolean value that when set to ``True``, gives the layer learnable 310 | affine parameters. 
Default: ``True`` 311 | 312 | Shape: 313 | - Input: :math:`(N, C, D, H, W)` 314 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 315 | 316 | Examples: 317 | >>> # With Learnable Parameters 318 | >>> m = SynchronizedBatchNorm3d(100) 319 | >>> # Without Learnable Parameters 320 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 321 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 322 | >>> output = m(input) 323 | """ 324 | 325 | def _check_input_dim(self, input): 326 | if input.dim() != 5: 327 | raise ValueError('expected 5D input (got {}D input)' 328 | .format(input.dim())) 329 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 330 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 
84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 
39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/mit_semseg/lib/nn/modules/tests/__init__.py -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
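Note that the test files in this directory still import the upstream package name `sync_batchnorm`. In this repository the same modules are vendored under `mit_semseg.lib.nn.modules`, so running the tests from here would likely require repointing the imports, roughly along these lines (untested sketch):

```python
from mit_semseg.lib.nn.modules import (
    SynchronizedBatchNorm1d,
    SynchronizedBatchNorm2d,
    DataParallelWithCallback,
)
from mit_semseg.lib.nn.modules.unittest import TorchTestCase
```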
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/modules/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /mit_semseg/lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .dataset import Dataset, TensorDataset, ConcatDataset 3 | from .dataloader import DataLoader 4 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch._C import _set_worker_signal_handlers, \ 4 | _remove_worker_pids, _error_if_any_worker_fails 5 | try: 6 | from torch._C import _set_worker_pids 7 | except: 8 | from torch._C import _update_worker_pids as 
_set_worker_pids 9 | from .sampler import SequentialSampler, RandomSampler, BatchSampler 10 | import signal 11 | import collections 12 | import re 13 | import sys 14 | import threading 15 | import traceback 16 | from torch._six import string_classes, int_classes 17 | import numpy as np 18 | 19 | if sys.version_info[0] == 2: 20 | import Queue as queue 21 | else: 22 | import queue 23 | 24 | 25 | class ExceptionWrapper(object): 26 | r"Wraps an exception plus traceback to communicate across threads" 27 | 28 | def __init__(self, exc_info): 29 | self.exc_type = exc_info[0] 30 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 31 | 32 | 33 | _use_shared_memory = False 34 | """Whether to use shared memory in default_collate""" 35 | 36 | 37 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id): 38 | global _use_shared_memory 39 | _use_shared_memory = True 40 | 41 | # Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal 42 | # module's handlers are executed after Python returns from C low-level 43 | # handlers, likely when the same fatal signal happened again already. 44 | # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1 45 | _set_worker_signal_handlers() 46 | 47 | torch.set_num_threads(1) 48 | torch.manual_seed(seed) 49 | np.random.seed(seed) 50 | 51 | if init_fn is not None: 52 | init_fn(worker_id) 53 | 54 | while True: 55 | r = index_queue.get() 56 | if r is None: 57 | break 58 | idx, batch_indices = r 59 | try: 60 | samples = collate_fn([dataset[i] for i in batch_indices]) 61 | except Exception: 62 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 63 | else: 64 | data_queue.put((idx, samples)) 65 | 66 | 67 | def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id): 68 | if pin_memory: 69 | torch.cuda.set_device(device_id) 70 | 71 | while True: 72 | try: 73 | r = in_queue.get() 74 | except Exception: 75 | if done_event.is_set(): 76 | return 77 | raise 78 | if r is None: 79 | break 80 | if isinstance(r[1], ExceptionWrapper): 81 | out_queue.put(r) 82 | continue 83 | idx, batch = r 84 | try: 85 | if pin_memory: 86 | batch = pin_memory_batch(batch) 87 | except Exception: 88 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 89 | else: 90 | out_queue.put((idx, batch)) 91 | 92 | numpy_type_map = { 93 | 'float64': torch.DoubleTensor, 94 | 'float32': torch.FloatTensor, 95 | 'float16': torch.HalfTensor, 96 | 'int64': torch.LongTensor, 97 | 'int32': torch.IntTensor, 98 | 'int16': torch.ShortTensor, 99 | 'int8': torch.CharTensor, 100 | 'uint8': torch.ByteTensor, 101 | } 102 | 103 | 104 | def default_collate(batch): 105 | "Puts each data field into a tensor with outer dimension batch size" 106 | 107 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 108 | elem_type = type(batch[0]) 109 | if torch.is_tensor(batch[0]): 110 | out = None 111 | if _use_shared_memory: 112 | # If we're in a background process, concatenate directly into a 113 | # shared memory tensor to avoid an extra copy 114 | numel = sum([x.numel() for x in batch]) 115 | storage = batch[0].storage()._new_shared(numel) 116 | out = batch[0].new(storage) 117 | return torch.stack(batch, 0, out=out) 118 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 119 | and elem_type.__name__ != 'string_': 120 | elem = batch[0] 121 | if elem_type.__name__ == 'ndarray': 122 | # array of string classes and object 123 | if re.search('[SaUO]', elem.dtype.str) is not None: 124 | raise 
TypeError(error_msg.format(elem.dtype)) 125 | 126 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 127 | if elem.shape == (): # scalars 128 | py_type = float if elem.dtype.name.startswith('float') else int 129 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 130 | elif isinstance(batch[0], int_classes): 131 | return torch.LongTensor(batch) 132 | elif isinstance(batch[0], float): 133 | return torch.DoubleTensor(batch) 134 | elif isinstance(batch[0], string_classes): 135 | return batch 136 | elif isinstance(batch[0], collections.Mapping): 137 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 138 | elif isinstance(batch[0], collections.Sequence): 139 | transposed = zip(*batch) 140 | return [default_collate(samples) for samples in transposed] 141 | 142 | raise TypeError((error_msg.format(type(batch[0])))) 143 | 144 | 145 | def pin_memory_batch(batch): 146 | if torch.is_tensor(batch): 147 | return batch.pin_memory() 148 | elif isinstance(batch, string_classes): 149 | return batch 150 | elif isinstance(batch, collections.Mapping): 151 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 152 | elif isinstance(batch, collections.Sequence): 153 | return [pin_memory_batch(sample) for sample in batch] 154 | else: 155 | return batch 156 | 157 | 158 | _SIGCHLD_handler_set = False 159 | """Whether SIGCHLD handler is set for DataLoader worker failures. Only one 160 | handler needs to be set for all DataLoaders in a process.""" 161 | 162 | 163 | def _set_SIGCHLD_handler(): 164 | # Windows doesn't support SIGCHLD handler 165 | if sys.platform == 'win32': 166 | return 167 | # can't set signal in child threads 168 | if not isinstance(threading.current_thread(), threading._MainThread): 169 | return 170 | global _SIGCHLD_handler_set 171 | if _SIGCHLD_handler_set: 172 | return 173 | previous_handler = signal.getsignal(signal.SIGCHLD) 174 | if not callable(previous_handler): 175 | previous_handler = None 176 | 177 | def handler(signum, frame): 178 | # This following call uses `waitid` with WNOHANG from C side. Therefore, 179 | # Python can still get and update the process status successfully. 
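        # Surface any worker failure as an error in the main process, then fall
        # through to whatever SIGCHLD handler was installed before this one.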
180 | _error_if_any_worker_fails() 181 | if previous_handler is not None: 182 | previous_handler(signum, frame) 183 | 184 | signal.signal(signal.SIGCHLD, handler) 185 | _SIGCHLD_handler_set = True 186 | 187 | 188 | class DataLoaderIter(object): 189 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 190 | 191 | def __init__(self, loader): 192 | self.dataset = loader.dataset 193 | self.collate_fn = loader.collate_fn 194 | self.batch_sampler = loader.batch_sampler 195 | self.num_workers = loader.num_workers 196 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 197 | self.timeout = loader.timeout 198 | self.done_event = threading.Event() 199 | 200 | self.sample_iter = iter(self.batch_sampler) 201 | 202 | if self.num_workers > 0: 203 | self.worker_init_fn = loader.worker_init_fn 204 | self.index_queue = multiprocessing.SimpleQueue() 205 | self.worker_result_queue = multiprocessing.SimpleQueue() 206 | self.batches_outstanding = 0 207 | self.worker_pids_set = False 208 | self.shutdown = False 209 | self.send_idx = 0 210 | self.rcvd_idx = 0 211 | self.reorder_dict = {} 212 | 213 | base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0] 214 | self.workers = [ 215 | multiprocessing.Process( 216 | target=_worker_loop, 217 | args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn, 218 | base_seed + i, self.worker_init_fn, i)) 219 | for i in range(self.num_workers)] 220 | 221 | if self.pin_memory or self.timeout > 0: 222 | self.data_queue = queue.Queue() 223 | if self.pin_memory: 224 | maybe_device_id = torch.cuda.current_device() 225 | else: 226 | # do not initialize cuda context if not necessary 227 | maybe_device_id = None 228 | self.worker_manager_thread = threading.Thread( 229 | target=_worker_manager_loop, 230 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 231 | maybe_device_id)) 232 | self.worker_manager_thread.daemon = True 233 | self.worker_manager_thread.start() 234 | else: 235 | self.data_queue = self.worker_result_queue 236 | 237 | for w in self.workers: 238 | w.daemon = True # ensure that the worker exits on process exit 239 | w.start() 240 | 241 | _set_worker_pids(id(self), tuple(w.pid for w in self.workers)) 242 | _set_SIGCHLD_handler() 243 | self.worker_pids_set = True 244 | 245 | # prime the prefetch loop 246 | for _ in range(2 * self.num_workers): 247 | self._put_indices() 248 | 249 | def __len__(self): 250 | return len(self.batch_sampler) 251 | 252 | def _get_batch(self): 253 | if self.timeout > 0: 254 | try: 255 | return self.data_queue.get(timeout=self.timeout) 256 | except queue.Empty: 257 | raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout)) 258 | else: 259 | return self.data_queue.get() 260 | 261 | def __next__(self): 262 | if self.num_workers == 0: # same-process loading 263 | indices = next(self.sample_iter) # may raise StopIteration 264 | batch = self.collate_fn([self.dataset[i] for i in indices]) 265 | if self.pin_memory: 266 | batch = pin_memory_batch(batch) 267 | return batch 268 | 269 | # check if the next sample has already been generated 270 | if self.rcvd_idx in self.reorder_dict: 271 | batch = self.reorder_dict.pop(self.rcvd_idx) 272 | return self._process_next_batch(batch) 273 | 274 | if self.batches_outstanding == 0: 275 | self._shutdown_workers() 276 | raise StopIteration 277 | 278 | while True: 279 | assert (not self.shutdown and self.batches_outstanding > 0) 280 | idx, batch = self._get_batch() 281 | self.batches_outstanding -= 1 
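            # Workers can finish batches out of order; keep early arrivals in
            # reorder_dict and only hand back the batch whose index matches rcvd_idx.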
282 | if idx != self.rcvd_idx: 283 | # store out-of-order samples 284 | self.reorder_dict[idx] = batch 285 | continue 286 | return self._process_next_batch(batch) 287 | 288 | next = __next__ # Python 2 compatibility 289 | 290 | def __iter__(self): 291 | return self 292 | 293 | def _put_indices(self): 294 | assert self.batches_outstanding < 2 * self.num_workers 295 | indices = next(self.sample_iter, None) 296 | if indices is None: 297 | return 298 | self.index_queue.put((self.send_idx, indices)) 299 | self.batches_outstanding += 1 300 | self.send_idx += 1 301 | 302 | def _process_next_batch(self, batch): 303 | self.rcvd_idx += 1 304 | self._put_indices() 305 | if isinstance(batch, ExceptionWrapper): 306 | raise batch.exc_type(batch.exc_msg) 307 | return batch 308 | 309 | def __getstate__(self): 310 | # TODO: add limited pickling support for sharing an iterator 311 | # across multiple threads for HOGWILD. 312 | # Probably the best way to do this is by moving the sample pushing 313 | # to a separate thread and then just sharing the data queue 314 | # but signalling the end is tricky without a non-blocking API 315 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 316 | 317 | def _shutdown_workers(self): 318 | try: 319 | if not self.shutdown: 320 | self.shutdown = True 321 | self.done_event.set() 322 | # if worker_manager_thread is waiting to put 323 | while not self.data_queue.empty(): 324 | self.data_queue.get() 325 | for _ in self.workers: 326 | self.index_queue.put(None) 327 | # done_event should be sufficient to exit worker_manager_thread, 328 | # but be safe here and put another None 329 | self.worker_result_queue.put(None) 330 | finally: 331 | # removes pids no matter what 332 | if self.worker_pids_set: 333 | _remove_worker_pids(id(self)) 334 | self.worker_pids_set = False 335 | 336 | def __del__(self): 337 | if self.num_workers > 0: 338 | self._shutdown_workers() 339 | 340 | 341 | class DataLoader(object): 342 | """ 343 | Data loader. Combines a dataset and a sampler, and provides 344 | single- or multi-process iterators over the dataset. 345 | 346 | Arguments: 347 | dataset (Dataset): dataset from which to load the data. 348 | batch_size (int, optional): how many samples per batch to load 349 | (default: 1). 350 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 351 | at every epoch (default: False). 352 | sampler (Sampler, optional): defines the strategy to draw samples from 353 | the dataset. If specified, ``shuffle`` must be False. 354 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 355 | indices at a time. Mutually exclusive with batch_size, shuffle, 356 | sampler, and drop_last. 357 | num_workers (int, optional): how many subprocesses to use for data 358 | loading. 0 means that the data will be loaded in the main process. 359 | (default: 0) 360 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 361 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 362 | into CUDA pinned memory before returning them. 363 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 364 | if the dataset size is not divisible by the batch size. If ``False`` and 365 | the size of dataset is not divisible by the batch size, then the last batch 366 | will be smaller. (default: False) 367 | timeout (numeric, optional): if positive, the timeout value for collecting a batch 368 | from workers. Should always be non-negative. 
(default: 0) 369 | worker_init_fn (callable, optional): If not None, this will be called on each 370 | worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 371 | input, after seeding and before data loading. (default: None) 372 | 373 | .. note:: By default, each worker will have its PyTorch seed set to 374 | ``base_seed + worker_id``, where ``base_seed`` is a long generated 375 | by main process using its RNG. You may use ``torch.initial_seed()`` to access 376 | this value in :attr:`worker_init_fn`, which can be used to set other seeds 377 | (e.g. NumPy) before data loading. 378 | 379 | .. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an 380 | unpicklable object, e.g., a lambda function. 381 | """ 382 | 383 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 384 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False, 385 | timeout=0, worker_init_fn=None): 386 | self.dataset = dataset 387 | self.batch_size = batch_size 388 | self.num_workers = num_workers 389 | self.collate_fn = collate_fn 390 | self.pin_memory = pin_memory 391 | self.drop_last = drop_last 392 | self.timeout = timeout 393 | self.worker_init_fn = worker_init_fn 394 | 395 | if timeout < 0: 396 | raise ValueError('timeout option should be non-negative') 397 | 398 | if batch_sampler is not None: 399 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 400 | raise ValueError('batch_sampler is mutually exclusive with ' 401 | 'batch_size, shuffle, sampler, and drop_last') 402 | 403 | if sampler is not None and shuffle: 404 | raise ValueError('sampler is mutually exclusive with shuffle') 405 | 406 | if self.num_workers < 0: 407 | raise ValueError('num_workers cannot be negative; ' 408 | 'use num_workers=0 to disable multiprocessing.') 409 | 410 | if batch_sampler is None: 411 | if sampler is None: 412 | if shuffle: 413 | sampler = RandomSampler(dataset) 414 | else: 415 | sampler = SequentialSampler(dataset) 416 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 417 | 418 | self.sampler = sampler 419 | self.batch_sampler = batch_sampler 420 | 421 | def __iter__(self): 422 | return DataLoaderIter(self) 423 | 424 | def __len__(self): 425 | return len(self.batch_sampler) 426 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import warnings 3 | 4 | from torch._utils import _accumulate 5 | from torch import randperm 6 | 7 | 8 | class Dataset(object): 9 | """An abstract class representing a Dataset. 10 | 11 | All other datasets should subclass it. All subclasses should override 12 | ``__len__``, that provides the size of the dataset, and ``__getitem__``, 13 | supporting integer indexing in range from 0 to len(self) exclusive. 14 | """ 15 | 16 | def __getitem__(self, index): 17 | raise NotImplementedError 18 | 19 | def __len__(self): 20 | raise NotImplementedError 21 | 22 | def __add__(self, other): 23 | return ConcatDataset([self, other]) 24 | 25 | 26 | class TensorDataset(Dataset): 27 | """Dataset wrapping data and target tensors. 28 | 29 | Each sample will be retrieved by indexing both tensors along the first 30 | dimension. 31 | 32 | Arguments: 33 | data_tensor (Tensor): contains sample data. 34 | target_tensor (Tensor): contains sample targets (labels). 
35 | """ 36 | 37 | def __init__(self, data_tensor, target_tensor): 38 | assert data_tensor.size(0) == target_tensor.size(0) 39 | self.data_tensor = data_tensor 40 | self.target_tensor = target_tensor 41 | 42 | def __getitem__(self, index): 43 | return self.data_tensor[index], self.target_tensor[index] 44 | 45 | def __len__(self): 46 | return self.data_tensor.size(0) 47 | 48 | 49 | class ConcatDataset(Dataset): 50 | """ 51 | Dataset to concatenate multiple datasets. 52 | Purpose: useful to assemble different existing datasets, possibly 53 | large-scale datasets as the concatenation operation is done in an 54 | on-the-fly manner. 55 | 56 | Arguments: 57 | datasets (iterable): List of datasets to be concatenated 58 | """ 59 | 60 | @staticmethod 61 | def cumsum(sequence): 62 | r, s = [], 0 63 | for e in sequence: 64 | l = len(e) 65 | r.append(l + s) 66 | s += l 67 | return r 68 | 69 | def __init__(self, datasets): 70 | super(ConcatDataset, self).__init__() 71 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 72 | self.datasets = list(datasets) 73 | self.cumulative_sizes = self.cumsum(self.datasets) 74 | 75 | def __len__(self): 76 | return self.cumulative_sizes[-1] 77 | 78 | def __getitem__(self, idx): 79 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 80 | if dataset_idx == 0: 81 | sample_idx = idx 82 | else: 83 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 84 | return self.datasets[dataset_idx][sample_idx] 85 | 86 | @property 87 | def cummulative_sizes(self): 88 | warnings.warn("cummulative_sizes attribute is renamed to " 89 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 90 | return self.cumulative_sizes 91 | 92 | 93 | class Subset(Dataset): 94 | def __init__(self, dataset, indices): 95 | self.dataset = dataset 96 | self.indices = indices 97 | 98 | def __getitem__(self, idx): 99 | return self.dataset[self.indices[idx]] 100 | 101 | def __len__(self): 102 | return len(self.indices) 103 | 104 | 105 | def random_split(dataset, lengths): 106 | """ 107 | Randomly split a dataset into non-overlapping new datasets of given lengths 108 | ds 109 | 110 | Arguments: 111 | dataset (Dataset): Dataset to be split 112 | lengths (iterable): lengths of splits to be produced 113 | """ 114 | if sum(lengths) != len(dataset): 115 | raise ValueError("Sum of input lengths does not equal the length of the input dataset!") 116 | 117 | indices = randperm(sum(lengths)) 118 | return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)] 119 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/data/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from .sampler import Sampler 4 | from torch.distributed import get_world_size, get_rank 5 | 6 | 7 | class DistributedSampler(Sampler): 8 | """Sampler that restricts data loading to a subset of the dataset. 9 | 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | 15 | .. note:: 16 | Dataset is assumed to be of constant size. 17 | 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 
22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None): 26 | if num_replicas is None: 27 | num_replicas = get_world_size() 28 | if rank is None: 29 | rank = get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | 37 | def __iter__(self): 38 | # deterministically shuffle based on epoch 39 | g = torch.Generator() 40 | g.manual_seed(self.epoch) 41 | indices = list(torch.randperm(len(self.dataset), generator=g)) 42 | 43 | # add extra samples to make it evenly divisible 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | offset = self.num_samples * self.rank 49 | indices = indices[offset:offset + self.num_samples] 50 | assert len(indices) == self.num_samples 51 | 52 | return iter(indices) 53 | 54 | def __len__(self): 55 | return self.num_samples 56 | 57 | def set_epoch(self, epoch): 58 | self.epoch = epoch 59 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Sampler(object): 5 | """Base class for all Samplers. 6 | 7 | Every Sampler subclass has to provide an __iter__ method, providing a way 8 | to iterate over indices of dataset elements, and a __len__ method that 9 | returns the length of the returned iterators. 10 | """ 11 | 12 | def __init__(self, data_source): 13 | pass 14 | 15 | def __iter__(self): 16 | raise NotImplementedError 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | 22 | class SequentialSampler(Sampler): 23 | """Samples elements sequentially, always in the same order. 24 | 25 | Arguments: 26 | data_source (Dataset): dataset to sample from 27 | """ 28 | 29 | def __init__(self, data_source): 30 | self.data_source = data_source 31 | 32 | def __iter__(self): 33 | return iter(range(len(self.data_source))) 34 | 35 | def __len__(self): 36 | return len(self.data_source) 37 | 38 | 39 | class RandomSampler(Sampler): 40 | """Samples elements randomly, without replacement. 41 | 42 | Arguments: 43 | data_source (Dataset): dataset to sample from 44 | """ 45 | 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | 49 | def __iter__(self): 50 | return iter(torch.randperm(len(self.data_source)).long()) 51 | 52 | def __len__(self): 53 | return len(self.data_source) 54 | 55 | 56 | class SubsetRandomSampler(Sampler): 57 | """Samples elements randomly from a given list of indices, without replacement. 58 | 59 | Arguments: 60 | indices (list): a list of indices 61 | """ 62 | 63 | def __init__(self, indices): 64 | self.indices = indices 65 | 66 | def __iter__(self): 67 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 68 | 69 | def __len__(self): 70 | return len(self.indices) 71 | 72 | 73 | class WeightedRandomSampler(Sampler): 74 | """Samples elements from [0,..,len(weights)-1] with given probabilities (weights). 75 | 76 | Arguments: 77 | weights (list) : a list of weights, not necessary summing up to one 78 | num_samples (int): number of samples to draw 79 | replacement (bool): if ``True``, samples are drawn with replacement. 
80 | If not, they are drawn without replacement, which means that when a 81 | sample index is drawn for a row, it cannot be drawn again for that row. 82 | """ 83 | 84 | def __init__(self, weights, num_samples, replacement=True): 85 | self.weights = torch.DoubleTensor(weights) 86 | self.num_samples = num_samples 87 | self.replacement = replacement 88 | 89 | def __iter__(self): 90 | return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 91 | 92 | def __len__(self): 93 | return self.num_samples 94 | 95 | 96 | class BatchSampler(object): 97 | """Wraps another sampler to yield a mini-batch of indices. 98 | 99 | Args: 100 | sampler (Sampler): Base sampler. 101 | batch_size (int): Size of mini-batch. 102 | drop_last (bool): If ``True``, the sampler will drop the last batch if 103 | its size would be less than ``batch_size`` 104 | 105 | Example: 106 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=False)) 107 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 108 | >>> list(BatchSampler(range(10), batch_size=3, drop_last=True)) 109 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 110 | """ 111 | 112 | def __init__(self, sampler, batch_size, drop_last): 113 | self.sampler = sampler 114 | self.batch_size = batch_size 115 | self.drop_last = drop_last 116 | 117 | def __iter__(self): 118 | batch = [] 119 | for idx in self.sampler: 120 | batch.append(idx) 121 | if len(batch) == self.batch_size: 122 | yield batch 123 | batch = [] 124 | if len(batch) > 0 and not self.drop_last: 125 | yield batch 126 | 127 | def __len__(self): 128 | if self.drop_last: 129 | return len(self.sampler) // self.batch_size 130 | else: 131 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size 132 | -------------------------------------------------------------------------------- /mit_semseg/lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 | return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /mit_semseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ModelBuilder, SegmentationModule 2 | -------------------------------------------------------------------------------- /mit_semseg/models/hrnet.py: 
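A minimal usage sketch (an assumption for illustration, not code from this repository; images stands in for any float image batch) of how the hrnetv2 encoder defined in this file is built through ModelBuilder from mit_semseg/models/models.py:

    from mit_semseg.models import ModelBuilder
    net_encoder = ModelBuilder.build_encoder(arch='hrnetv2')
    features = net_encoder(images, return_feature_maps=True)
    # one-element list holding a (N, 720, ~H/4, ~W/4) tensor: the four branch
    # outputs (48 + 96 + 192 + 384 channels) upsampled and concatenated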
-------------------------------------------------------------------------------- 1 | """ 2 | This HRNet implementation is modified from the following repository: 3 | https://github.com/HRNet/HRNet-Semantic-Segmentation 4 | """ 5 | 6 | import logging 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from .utils import load_url 11 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d 12 | 13 | BatchNorm2d = SynchronizedBatchNorm2d 14 | BN_MOMENTUM = 0.1 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | __all__ = ['hrnetv2'] 19 | 20 | 21 | model_urls = { 22 | 'hrnetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | 32 | class BasicBlock(nn.Module): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(BasicBlock, self).__init__() 37 | self.conv1 = conv3x3(inplanes, planes, stride) 38 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = conv3x3(planes, planes) 41 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | 55 | if self.downsample is not None: 56 | residual = self.downsample(x) 57 | 58 | out += residual 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, inplanes, planes, stride=1, downsample=None): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 71 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 74 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 75 | bias=False) 76 | self.bn3 = BatchNorm2d(planes * self.expansion, 77 | momentum=BN_MOMENTUM) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class HighResolutionModule(nn.Module): 106 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 107 | num_channels, fuse_method, multi_scale_output=True): 108 | super(HighResolutionModule, self).__init__() 109 | self._check_branches( 110 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 111 | 112 | self.num_inchannels = num_inchannels 113 | self.fuse_method = fuse_method 114 | self.num_branches = num_branches 115 | 116 | self.multi_scale_output = multi_scale_output 117 | 118 | self.branches = self._make_branches( 119 | num_branches, blocks, num_blocks, num_channels) 120 | self.fuse_layers = 
self._make_fuse_layers() 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | def _check_branches(self, num_branches, blocks, num_blocks, 124 | num_inchannels, num_channels): 125 | if num_branches != len(num_blocks): 126 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 127 | num_branches, len(num_blocks)) 128 | logger.error(error_msg) 129 | raise ValueError(error_msg) 130 | 131 | if num_branches != len(num_channels): 132 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 133 | num_branches, len(num_channels)) 134 | logger.error(error_msg) 135 | raise ValueError(error_msg) 136 | 137 | if num_branches != len(num_inchannels): 138 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 139 | num_branches, len(num_inchannels)) 140 | logger.error(error_msg) 141 | raise ValueError(error_msg) 142 | 143 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 144 | stride=1): 145 | downsample = None 146 | if stride != 1 or \ 147 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 148 | downsample = nn.Sequential( 149 | nn.Conv2d(self.num_inchannels[branch_index], 150 | num_channels[branch_index] * block.expansion, 151 | kernel_size=1, stride=stride, bias=False), 152 | BatchNorm2d(num_channels[branch_index] * block.expansion, 153 | momentum=BN_MOMENTUM), 154 | ) 155 | 156 | layers = [] 157 | layers.append(block(self.num_inchannels[branch_index], 158 | num_channels[branch_index], stride, downsample)) 159 | self.num_inchannels[branch_index] = \ 160 | num_channels[branch_index] * block.expansion 161 | for i in range(1, num_blocks[branch_index]): 162 | layers.append(block(self.num_inchannels[branch_index], 163 | num_channels[branch_index])) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 168 | branches = [] 169 | 170 | for i in range(num_branches): 171 | branches.append( 172 | self._make_one_branch(i, block, num_blocks, num_channels)) 173 | 174 | return nn.ModuleList(branches) 175 | 176 | def _make_fuse_layers(self): 177 | if self.num_branches == 1: 178 | return None 179 | 180 | num_branches = self.num_branches 181 | num_inchannels = self.num_inchannels 182 | fuse_layers = [] 183 | for i in range(num_branches if self.multi_scale_output else 1): 184 | fuse_layer = [] 185 | for j in range(num_branches): 186 | if j > i: 187 | fuse_layer.append(nn.Sequential( 188 | nn.Conv2d(num_inchannels[j], 189 | num_inchannels[i], 190 | 1, 191 | 1, 192 | 0, 193 | bias=False), 194 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) 195 | elif j == i: 196 | fuse_layer.append(None) 197 | else: 198 | conv3x3s = [] 199 | for k in range(i-j): 200 | if k == i - j - 1: 201 | num_outchannels_conv3x3 = num_inchannels[i] 202 | conv3x3s.append(nn.Sequential( 203 | nn.Conv2d(num_inchannels[j], 204 | num_outchannels_conv3x3, 205 | 3, 2, 1, bias=False), 206 | BatchNorm2d(num_outchannels_conv3x3, 207 | momentum=BN_MOMENTUM))) 208 | else: 209 | num_outchannels_conv3x3 = num_inchannels[j] 210 | conv3x3s.append(nn.Sequential( 211 | nn.Conv2d(num_inchannels[j], 212 | num_outchannels_conv3x3, 213 | 3, 2, 1, bias=False), 214 | BatchNorm2d(num_outchannels_conv3x3, 215 | momentum=BN_MOMENTUM), 216 | nn.ReLU(inplace=True))) 217 | fuse_layer.append(nn.Sequential(*conv3x3s)) 218 | fuse_layers.append(nn.ModuleList(fuse_layer)) 219 | 220 | return nn.ModuleList(fuse_layers) 221 | 222 | def get_num_inchannels(self): 223 | return self.num_inchannels 224 | 225 | def forward(self, x): 226 | if 
self.num_branches == 1: 227 | return [self.branches[0](x[0])] 228 | 229 | for i in range(self.num_branches): 230 | x[i] = self.branches[i](x[i]) 231 | 232 | x_fuse = [] 233 | for i in range(len(self.fuse_layers)): 234 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 235 | for j in range(1, self.num_branches): 236 | if i == j: 237 | y = y + x[j] 238 | elif j > i: 239 | width_output = x[i].shape[-1] 240 | height_output = x[i].shape[-2] 241 | y = y + F.interpolate( 242 | self.fuse_layers[i][j](x[j]), 243 | size=(height_output, width_output), 244 | mode='bilinear', 245 | align_corners=False) 246 | else: 247 | y = y + self.fuse_layers[i][j](x[j]) 248 | x_fuse.append(self.relu(y)) 249 | 250 | return x_fuse 251 | 252 | 253 | blocks_dict = { 254 | 'BASIC': BasicBlock, 255 | 'BOTTLENECK': Bottleneck 256 | } 257 | 258 | 259 | class HRNetV2(nn.Module): 260 | def __init__(self, n_class, **kwargs): 261 | super(HRNetV2, self).__init__() 262 | extra = { 263 | 'STAGE2': {'NUM_MODULES': 1, 'NUM_BRANCHES': 2, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4), 'NUM_CHANNELS': (48, 96), 'FUSE_METHOD': 'SUM'}, 264 | 'STAGE3': {'NUM_MODULES': 4, 'NUM_BRANCHES': 3, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4), 'NUM_CHANNELS': (48, 96, 192), 'FUSE_METHOD': 'SUM'}, 265 | 'STAGE4': {'NUM_MODULES': 3, 'NUM_BRANCHES': 4, 'BLOCK': 'BASIC', 'NUM_BLOCKS': (4, 4, 4, 4), 'NUM_CHANNELS': (48, 96, 192, 384), 'FUSE_METHOD': 'SUM'}, 266 | 'FINAL_CONV_KERNEL': 1 267 | } 268 | 269 | # stem net 270 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 271 | bias=False) 272 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) 273 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, 274 | bias=False) 275 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) 276 | self.relu = nn.ReLU(inplace=True) 277 | 278 | self.layer1 = self._make_layer(Bottleneck, 64, 64, 4) 279 | 280 | self.stage2_cfg = extra['STAGE2'] 281 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 282 | block = blocks_dict[self.stage2_cfg['BLOCK']] 283 | num_channels = [ 284 | num_channels[i] * block.expansion for i in range(len(num_channels))] 285 | self.transition1 = self._make_transition_layer([256], num_channels) 286 | self.stage2, pre_stage_channels = self._make_stage( 287 | self.stage2_cfg, num_channels) 288 | 289 | self.stage3_cfg = extra['STAGE3'] 290 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 291 | block = blocks_dict[self.stage3_cfg['BLOCK']] 292 | num_channels = [ 293 | num_channels[i] * block.expansion for i in range(len(num_channels))] 294 | self.transition2 = self._make_transition_layer( 295 | pre_stage_channels, num_channels) 296 | self.stage3, pre_stage_channels = self._make_stage( 297 | self.stage3_cfg, num_channels) 298 | 299 | self.stage4_cfg = extra['STAGE4'] 300 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 301 | block = blocks_dict[self.stage4_cfg['BLOCK']] 302 | num_channels = [ 303 | num_channels[i] * block.expansion for i in range(len(num_channels))] 304 | self.transition3 = self._make_transition_layer( 305 | pre_stage_channels, num_channels) 306 | self.stage4, pre_stage_channels = self._make_stage( 307 | self.stage4_cfg, num_channels, multi_scale_output=True) 308 | 309 | def _make_transition_layer( 310 | self, num_channels_pre_layer, num_channels_cur_layer): 311 | num_branches_cur = len(num_channels_cur_layer) 312 | num_branches_pre = len(num_channels_pre_layer) 313 | 314 | transition_layers = [] 315 | for i in range(num_branches_cur): 316 | if i < num_branches_pre: 317 | if num_channels_cur_layer[i] != 
num_channels_pre_layer[i]: 318 | transition_layers.append(nn.Sequential( 319 | nn.Conv2d(num_channels_pre_layer[i], 320 | num_channels_cur_layer[i], 321 | 3, 322 | 1, 323 | 1, 324 | bias=False), 325 | BatchNorm2d( 326 | num_channels_cur_layer[i], momentum=BN_MOMENTUM), 327 | nn.ReLU(inplace=True))) 328 | else: 329 | transition_layers.append(None) 330 | else: 331 | conv3x3s = [] 332 | for j in range(i+1-num_branches_pre): 333 | inchannels = num_channels_pre_layer[-1] 334 | outchannels = num_channels_cur_layer[i] \ 335 | if j == i-num_branches_pre else inchannels 336 | conv3x3s.append(nn.Sequential( 337 | nn.Conv2d( 338 | inchannels, outchannels, 3, 2, 1, bias=False), 339 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM), 340 | nn.ReLU(inplace=True))) 341 | transition_layers.append(nn.Sequential(*conv3x3s)) 342 | 343 | return nn.ModuleList(transition_layers) 344 | 345 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 346 | downsample = None 347 | if stride != 1 or inplanes != planes * block.expansion: 348 | downsample = nn.Sequential( 349 | nn.Conv2d(inplanes, planes * block.expansion, 350 | kernel_size=1, stride=stride, bias=False), 351 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 352 | ) 353 | 354 | layers = [] 355 | layers.append(block(inplanes, planes, stride, downsample)) 356 | inplanes = planes * block.expansion 357 | for i in range(1, blocks): 358 | layers.append(block(inplanes, planes)) 359 | 360 | return nn.Sequential(*layers) 361 | 362 | def _make_stage(self, layer_config, num_inchannels, 363 | multi_scale_output=True): 364 | num_modules = layer_config['NUM_MODULES'] 365 | num_branches = layer_config['NUM_BRANCHES'] 366 | num_blocks = layer_config['NUM_BLOCKS'] 367 | num_channels = layer_config['NUM_CHANNELS'] 368 | block = blocks_dict[layer_config['BLOCK']] 369 | fuse_method = layer_config['FUSE_METHOD'] 370 | 371 | modules = [] 372 | for i in range(num_modules): 373 | # multi_scale_output is only used last module 374 | if not multi_scale_output and i == num_modules - 1: 375 | reset_multi_scale_output = False 376 | else: 377 | reset_multi_scale_output = True 378 | modules.append( 379 | HighResolutionModule( 380 | num_branches, 381 | block, 382 | num_blocks, 383 | num_inchannels, 384 | num_channels, 385 | fuse_method, 386 | reset_multi_scale_output) 387 | ) 388 | num_inchannels = modules[-1].get_num_inchannels() 389 | 390 | return nn.Sequential(*modules), num_inchannels 391 | 392 | def forward(self, x, return_feature_maps=False): 393 | x = self.conv1(x) 394 | x = self.bn1(x) 395 | x = self.relu(x) 396 | x = self.conv2(x) 397 | x = self.bn2(x) 398 | x = self.relu(x) 399 | x = self.layer1(x) 400 | 401 | x_list = [] 402 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 403 | if self.transition1[i] is not None: 404 | x_list.append(self.transition1[i](x)) 405 | else: 406 | x_list.append(x) 407 | y_list = self.stage2(x_list) 408 | 409 | x_list = [] 410 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 411 | if self.transition2[i] is not None: 412 | x_list.append(self.transition2[i](y_list[-1])) 413 | else: 414 | x_list.append(y_list[i]) 415 | y_list = self.stage3(x_list) 416 | 417 | x_list = [] 418 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 419 | if self.transition3[i] is not None: 420 | x_list.append(self.transition3[i](y_list[-1])) 421 | else: 422 | x_list.append(y_list[i]) 423 | x = self.stage4(x_list) 424 | 425 | # Upsampling 426 | x0_h, x0_w = x[0].size(2), x[0].size(3) 427 | x1 = F.interpolate( 428 | x[1], size=(x0_h, x0_w), 
mode='bilinear', align_corners=False) 429 | x2 = F.interpolate( 430 | x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False) 431 | x3 = F.interpolate( 432 | x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=False) 433 | 434 | x = torch.cat([x[0], x1, x2, x3], 1) 435 | 436 | # x = self.last_layer(x) 437 | return [x] 438 | 439 | 440 | def hrnetv2(pretrained=False, **kwargs): 441 | model = HRNetV2(n_class=1000, **kwargs) 442 | if pretrained: 443 | model.load_state_dict(load_url(model_urls['hrnetv2']), strict=False) 444 | 445 | return model 446 | -------------------------------------------------------------------------------- /mit_semseg/models/mobilenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This MobileNetV2 implementation is modified from the following repository: 3 | https://github.com/tonylins/pytorch-mobilenet-v2 4 | """ 5 | 6 | import torch.nn as nn 7 | import math 8 | from .utils import load_url 9 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d 10 | 11 | BatchNorm2d = SynchronizedBatchNorm2d 12 | 13 | 14 | __all__ = ['mobilenetv2'] 15 | 16 | 17 | model_urls = { 18 | 'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar', 19 | } 20 | 21 | 22 | def conv_bn(inp, oup, stride): 23 | return nn.Sequential( 24 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 25 | BatchNorm2d(oup), 26 | nn.ReLU6(inplace=True) 27 | ) 28 | 29 | 30 | def conv_1x1_bn(inp, oup): 31 | return nn.Sequential( 32 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 33 | BatchNorm2d(oup), 34 | nn.ReLU6(inplace=True) 35 | ) 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, expand_ratio): 40 | super(InvertedResidual, self).__init__() 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | hidden_dim = round(inp * expand_ratio) 45 | self.use_res_connect = self.stride == 1 and inp == oup 46 | 47 | if expand_ratio == 1: 48 | self.conv = nn.Sequential( 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 51 | BatchNorm2d(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 55 | BatchNorm2d(oup), 56 | ) 57 | else: 58 | self.conv = nn.Sequential( 59 | # pw 60 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 61 | BatchNorm2d(hidden_dim), 62 | nn.ReLU6(inplace=True), 63 | # dw 64 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 65 | BatchNorm2d(hidden_dim), 66 | nn.ReLU6(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | BatchNorm2d(oup), 70 | ) 71 | 72 | def forward(self, x): 73 | if self.use_res_connect: 74 | return x + self.conv(x) 75 | else: 76 | return self.conv(x) 77 | 78 | 79 | class MobileNetV2(nn.Module): 80 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 81 | super(MobileNetV2, self).__init__() 82 | block = InvertedResidual 83 | input_channel = 32 84 | last_channel = 1280 85 | interverted_residual_setting = [ 86 | # t, c, n, s 87 | [1, 16, 1, 1], 88 | [6, 24, 2, 2], 89 | [6, 32, 3, 2], 90 | [6, 64, 4, 2], 91 | [6, 96, 3, 1], 92 | [6, 160, 3, 2], 93 | [6, 320, 1, 1], 94 | ] 95 | 96 | # building first layer 97 | assert input_size % 32 == 0 98 | input_channel = int(input_channel * width_mult) 99 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 100 | self.features = [conv_bn(3, input_channel, 2)] 101 | # building inverted residual 
blocks 102 | for t, c, n, s in interverted_residual_setting: 103 | output_channel = int(c * width_mult) 104 | for i in range(n): 105 | if i == 0: 106 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 107 | else: 108 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 109 | input_channel = output_channel 110 | # building last several layers 111 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 112 | # make it nn.Sequential 113 | self.features = nn.Sequential(*self.features) 114 | 115 | # building classifier 116 | self.classifier = nn.Sequential( 117 | nn.Dropout(0.2), 118 | nn.Linear(self.last_channel, n_class), 119 | ) 120 | 121 | self._initialize_weights() 122 | 123 | def forward(self, x): 124 | x = self.features(x) 125 | x = x.mean(3).mean(2) 126 | x = self.classifier(x) 127 | return x 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | if m.bias is not None: 135 | m.bias.data.zero_() 136 | elif isinstance(m, BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.Linear): 140 | n = m.weight.size(1) 141 | m.weight.data.normal_(0, 0.01) 142 | m.bias.data.zero_() 143 | 144 | 145 | def mobilenetv2(pretrained=False, **kwargs): 146 | """Constructs a MobileNet_V2 model. 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | """ 151 | model = MobileNetV2(n_class=1000, **kwargs) 152 | if pretrained: 153 | model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False) 154 | return model 155 | -------------------------------------------------------------------------------- /mit_semseg/models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . 
import resnet, resnext, mobilenet, hrnet 4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | class SegmentationModuleBase(nn.Module): 9 | def __init__(self): 10 | super(SegmentationModuleBase, self).__init__() 11 | 12 | def pixel_acc(self, pred, label): 13 | _, preds = torch.max(pred, dim=1) 14 | valid = (label >= 0).long() 15 | acc_sum = torch.sum(valid * (preds == label).long()) 16 | pixel_sum = torch.sum(valid) 17 | acc = acc_sum.float() / (pixel_sum.float() + 1e-10) 18 | return acc 19 | 20 | 21 | class SegmentationModule(SegmentationModuleBase): 22 | def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): 23 | super(SegmentationModule, self).__init__() 24 | self.encoder = net_enc 25 | self.decoder = net_dec 26 | self.crit = crit 27 | self.deep_sup_scale = deep_sup_scale 28 | 29 | def forward(self, feed_dict, *, segSize=None): 30 | # training 31 | if segSize is None: 32 | if self.deep_sup_scale is not None: # use deep supervision technique 33 | (pred, pred_deepsup) = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 34 | else: 35 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True)) 36 | 37 | loss = self.crit(pred, feed_dict['seg_label']) 38 | if self.deep_sup_scale is not None: 39 | loss_deepsup = self.crit(pred_deepsup, feed_dict['seg_label']) 40 | loss = loss + loss_deepsup * self.deep_sup_scale 41 | 42 | acc = self.pixel_acc(pred, feed_dict['seg_label']) 43 | return loss, acc 44 | # inference 45 | else: 46 | pred = self.decoder(self.encoder(feed_dict['img_data'], return_feature_maps=True), segSize=segSize) 47 | return pred 48 | 49 | 50 | class ModelBuilder: 51 | # custom weights initialization 52 | @staticmethod 53 | def weights_init(m): 54 | classname = m.__class__.__name__ 55 | if classname.find('Conv') != -1: 56 | nn.init.kaiming_normal_(m.weight.data) 57 | elif classname.find('BatchNorm') != -1: 58 | m.weight.data.fill_(1.) 
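# BatchNorm affine parameters: the weight (scale) is reset to 1 above and the
# bias, on the next line, to a small positive 1e-4; Conv weights are handled by
# the Kaiming-normal branch above.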
59 | m.bias.data.fill_(1e-4) 60 | #elif classname.find('Linear') != -1: 61 | # m.weight.data.normal_(0.0, 0.0001) 62 | 63 | @staticmethod 64 | def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): 65 | pretrained = True if len(weights) == 0 else False 66 | arch = arch.lower() 67 | if arch == 'mobilenetv2dilated': 68 | orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) 69 | net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) 70 | elif arch == 'resnet18': 71 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 72 | net_encoder = Resnet(orig_resnet) 73 | elif arch == 'resnet18dilated': 74 | orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) 75 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 76 | elif arch == 'resnet34': 77 | raise NotImplementedError 78 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 79 | net_encoder = Resnet(orig_resnet) 80 | elif arch == 'resnet34dilated': 81 | raise NotImplementedError 82 | orig_resnet = resnet.__dict__['resnet34'](pretrained=pretrained) 83 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 84 | elif arch == 'resnet50': 85 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 86 | net_encoder = Resnet(orig_resnet) 87 | elif arch == 'resnet50dilated': 88 | orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) 89 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 90 | elif arch == 'resnet101': 91 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 92 | net_encoder = Resnet(orig_resnet) 93 | elif arch == 'resnet101dilated': 94 | orig_resnet = resnet.__dict__['resnet101'](pretrained=pretrained) 95 | net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) 96 | elif arch == 'resnext101': 97 | orig_resnext = resnext.__dict__['resnext101'](pretrained=pretrained) 98 | net_encoder = Resnet(orig_resnext) # we can still use class Resnet 99 | elif arch == 'hrnetv2': 100 | net_encoder = hrnet.__dict__['hrnetv2'](pretrained=pretrained) 101 | else: 102 | raise Exception('Architecture undefined!') 103 | 104 | # encoders are usually pretrained 105 | # net_encoder.apply(ModelBuilder.weights_init) 106 | if len(weights) > 0: 107 | print('Loading weights for net_encoder') 108 | net_encoder.load_state_dict( 109 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False) 110 | return net_encoder 111 | 112 | @staticmethod 113 | def build_decoder(arch='ppm_deepsup', 114 | fc_dim=512, num_class=150, 115 | weights='', use_softmax=False): 116 | arch = arch.lower() 117 | if arch == 'c1_deepsup': 118 | net_decoder = C1DeepSup( 119 | num_class=num_class, 120 | fc_dim=fc_dim, 121 | use_softmax=use_softmax) 122 | elif arch == 'c1': 123 | net_decoder = C1( 124 | num_class=num_class, 125 | fc_dim=fc_dim, 126 | use_softmax=use_softmax) 127 | elif arch == 'ppm': 128 | net_decoder = PPM( 129 | num_class=num_class, 130 | fc_dim=fc_dim, 131 | use_softmax=use_softmax) 132 | elif arch == 'ppm_deepsup': 133 | net_decoder = PPMDeepsup( 134 | num_class=num_class, 135 | fc_dim=fc_dim, 136 | use_softmax=use_softmax) 137 | elif arch == 'upernet_lite': 138 | net_decoder = UPerNet( 139 | num_class=num_class, 140 | fc_dim=fc_dim, 141 | use_softmax=use_softmax, 142 | fpn_dim=256) 143 | elif arch == 'upernet': 144 | net_decoder = UPerNet( 145 | num_class=num_class, 146 | fc_dim=fc_dim, 147 | use_softmax=use_softmax, 148 | fpn_dim=512) 149 | else: 150 | raise Exception('Architecture undefined!') 151 | 152 | 
net_decoder.apply(ModelBuilder.weights_init) 153 | if len(weights) > 0: 154 | print('Loading weights for net_decoder') 155 | net_decoder.load_state_dict( 156 | torch.load(weights, map_location=lambda storage, loc: storage), strict=False) 157 | return net_decoder 158 | 159 | 160 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 161 | "3x3 convolution + BN + relu" 162 | return nn.Sequential( 163 | nn.Conv2d(in_planes, out_planes, kernel_size=3, 164 | stride=stride, padding=1, bias=False), 165 | BatchNorm2d(out_planes), 166 | nn.ReLU(inplace=True), 167 | ) 168 | 169 | 170 | class Resnet(nn.Module): 171 | def __init__(self, orig_resnet): 172 | super(Resnet, self).__init__() 173 | 174 | # take pretrained resnet, except AvgPool and FC 175 | self.conv1 = orig_resnet.conv1 176 | self.bn1 = orig_resnet.bn1 177 | self.relu1 = orig_resnet.relu1 178 | self.conv2 = orig_resnet.conv2 179 | self.bn2 = orig_resnet.bn2 180 | self.relu2 = orig_resnet.relu2 181 | self.conv3 = orig_resnet.conv3 182 | self.bn3 = orig_resnet.bn3 183 | self.relu3 = orig_resnet.relu3 184 | self.maxpool = orig_resnet.maxpool 185 | self.layer1 = orig_resnet.layer1 186 | self.layer2 = orig_resnet.layer2 187 | self.layer3 = orig_resnet.layer3 188 | self.layer4 = orig_resnet.layer4 189 | 190 | def forward(self, x, return_feature_maps=False): 191 | conv_out = [] 192 | 193 | x = self.relu1(self.bn1(self.conv1(x))) 194 | x = self.relu2(self.bn2(self.conv2(x))) 195 | x = self.relu3(self.bn3(self.conv3(x))) 196 | x = self.maxpool(x) 197 | 198 | x = self.layer1(x); conv_out.append(x); 199 | x = self.layer2(x); conv_out.append(x); 200 | x = self.layer3(x); conv_out.append(x); 201 | x = self.layer4(x); conv_out.append(x); 202 | 203 | if return_feature_maps: 204 | return conv_out 205 | return [x] 206 | 207 | 208 | class ResnetDilated(nn.Module): 209 | def __init__(self, orig_resnet, dilate_scale=8): 210 | super(ResnetDilated, self).__init__() 211 | from functools import partial 212 | 213 | if dilate_scale == 8: 214 | orig_resnet.layer3.apply( 215 | partial(self._nostride_dilate, dilate=2)) 216 | orig_resnet.layer4.apply( 217 | partial(self._nostride_dilate, dilate=4)) 218 | elif dilate_scale == 16: 219 | orig_resnet.layer4.apply( 220 | partial(self._nostride_dilate, dilate=2)) 221 | 222 | # take pretrained resnet, except AvgPool and FC 223 | self.conv1 = orig_resnet.conv1 224 | self.bn1 = orig_resnet.bn1 225 | self.relu1 = orig_resnet.relu1 226 | self.conv2 = orig_resnet.conv2 227 | self.bn2 = orig_resnet.bn2 228 | self.relu2 = orig_resnet.relu2 229 | self.conv3 = orig_resnet.conv3 230 | self.bn3 = orig_resnet.bn3 231 | self.relu3 = orig_resnet.relu3 232 | self.maxpool = orig_resnet.maxpool 233 | self.layer1 = orig_resnet.layer1 234 | self.layer2 = orig_resnet.layer2 235 | self.layer3 = orig_resnet.layer3 236 | self.layer4 = orig_resnet.layer4 237 | 238 | def _nostride_dilate(self, m, dilate): 239 | classname = m.__class__.__name__ 240 | if classname.find('Conv') != -1: 241 | # the convolution with stride 242 | if m.stride == (2, 2): 243 | m.stride = (1, 1) 244 | if m.kernel_size == (3, 3): 245 | m.dilation = (dilate//2, dilate//2) 246 | m.padding = (dilate//2, dilate//2) 247 | # other convoluions 248 | else: 249 | if m.kernel_size == (3, 3): 250 | m.dilation = (dilate, dilate) 251 | m.padding = (dilate, dilate) 252 | 253 | def forward(self, x, return_feature_maps=False): 254 | conv_out = [] 255 | 256 | x = self.relu1(self.bn1(self.conv1(x))) 257 | x = self.relu2(self.bn2(self.conv2(x))) 258 | x = self.relu3(self.bn3(self.conv3(x))) 259 
| x = self.maxpool(x) 260 | 261 | x = self.layer1(x); conv_out.append(x); 262 | x = self.layer2(x); conv_out.append(x); 263 | x = self.layer3(x); conv_out.append(x); 264 | x = self.layer4(x); conv_out.append(x); 265 | 266 | if return_feature_maps: 267 | return conv_out 268 | return [x] 269 | 270 | 271 | class MobileNetV2Dilated(nn.Module): 272 | def __init__(self, orig_net, dilate_scale=8): 273 | super(MobileNetV2Dilated, self).__init__() 274 | from functools import partial 275 | 276 | # take pretrained mobilenet features 277 | self.features = orig_net.features[:-1] 278 | 279 | self.total_idx = len(self.features) 280 | self.down_idx = [2, 4, 7, 14] 281 | 282 | if dilate_scale == 8: 283 | for i in range(self.down_idx[-2], self.down_idx[-1]): 284 | self.features[i].apply( 285 | partial(self._nostride_dilate, dilate=2) 286 | ) 287 | for i in range(self.down_idx[-1], self.total_idx): 288 | self.features[i].apply( 289 | partial(self._nostride_dilate, dilate=4) 290 | ) 291 | elif dilate_scale == 16: 292 | for i in range(self.down_idx[-1], self.total_idx): 293 | self.features[i].apply( 294 | partial(self._nostride_dilate, dilate=2) 295 | ) 296 | 297 | def _nostride_dilate(self, m, dilate): 298 | classname = m.__class__.__name__ 299 | if classname.find('Conv') != -1: 300 | # the convolution with stride 301 | if m.stride == (2, 2): 302 | m.stride = (1, 1) 303 | if m.kernel_size == (3, 3): 304 | m.dilation = (dilate//2, dilate//2) 305 | m.padding = (dilate//2, dilate//2) 306 | # other convoluions 307 | else: 308 | if m.kernel_size == (3, 3): 309 | m.dilation = (dilate, dilate) 310 | m.padding = (dilate, dilate) 311 | 312 | def forward(self, x, return_feature_maps=False): 313 | if return_feature_maps: 314 | conv_out = [] 315 | for i in range(self.total_idx): 316 | x = self.features[i](x) 317 | if i in self.down_idx: 318 | conv_out.append(x) 319 | conv_out.append(x) 320 | return conv_out 321 | 322 | else: 323 | return [self.features(x)] 324 | 325 | 326 | # last conv, deep supervision 327 | class C1DeepSup(nn.Module): 328 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 329 | super(C1DeepSup, self).__init__() 330 | self.use_softmax = use_softmax 331 | 332 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 333 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 334 | 335 | # last conv 336 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 337 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 338 | 339 | def forward(self, conv_out, segSize=None): 340 | conv5 = conv_out[-1] 341 | 342 | x = self.cbr(conv5) 343 | x = self.conv_last(x) 344 | 345 | if self.use_softmax: # is True during inference 346 | x = nn.functional.interpolate( 347 | x, size=segSize, mode='bilinear', align_corners=False) 348 | x = nn.functional.softmax(x, dim=1) 349 | return x 350 | 351 | # deep sup 352 | conv4 = conv_out[-2] 353 | _ = self.cbr_deepsup(conv4) 354 | _ = self.conv_last_deepsup(_) 355 | 356 | x = nn.functional.log_softmax(x, dim=1) 357 | _ = nn.functional.log_softmax(_, dim=1) 358 | 359 | return (x, _) 360 | 361 | 362 | # last conv 363 | class C1(nn.Module): 364 | def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): 365 | super(C1, self).__init__() 366 | self.use_softmax = use_softmax 367 | 368 | self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) 369 | 370 | # last conv 371 | self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 372 | 373 | def forward(self, conv_out, segSize=None): 374 | conv5 = conv_out[-1] 375 | x = self.cbr(conv5) 376 
| x = self.conv_last(x) 377 | 378 | if self.use_softmax: # is True during inference 379 | x = nn.functional.interpolate( 380 | x, size=segSize, mode='bilinear', align_corners=False) 381 | x = nn.functional.softmax(x, dim=1) 382 | else: 383 | x = nn.functional.log_softmax(x, dim=1) 384 | 385 | return x 386 | 387 | 388 | # pyramid pooling 389 | class PPM(nn.Module): 390 | def __init__(self, num_class=150, fc_dim=4096, 391 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 392 | super(PPM, self).__init__() 393 | self.use_softmax = use_softmax 394 | 395 | self.ppm = [] 396 | for scale in pool_scales: 397 | self.ppm.append(nn.Sequential( 398 | nn.AdaptiveAvgPool2d(scale), 399 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 400 | BatchNorm2d(512), 401 | nn.ReLU(inplace=True) 402 | )) 403 | self.ppm = nn.ModuleList(self.ppm) 404 | 405 | self.conv_last = nn.Sequential( 406 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 407 | kernel_size=3, padding=1, bias=False), 408 | BatchNorm2d(512), 409 | nn.ReLU(inplace=True), 410 | nn.Dropout2d(0.1), 411 | nn.Conv2d(512, num_class, kernel_size=1) 412 | ) 413 | 414 | def forward(self, conv_out, segSize=None): 415 | conv5 = conv_out[-1] 416 | 417 | input_size = conv5.size() 418 | ppm_out = [conv5] 419 | for pool_scale in self.ppm: 420 | ppm_out.append(nn.functional.interpolate( 421 | pool_scale(conv5), 422 | (input_size[2], input_size[3]), 423 | mode='bilinear', align_corners=False)) 424 | ppm_out = torch.cat(ppm_out, 1) 425 | 426 | x = self.conv_last(ppm_out) 427 | 428 | if self.use_softmax: # is True during inference 429 | x = nn.functional.interpolate( 430 | x, size=segSize, mode='bilinear', align_corners=False) 431 | x = nn.functional.softmax(x, dim=1) 432 | else: 433 | x = nn.functional.log_softmax(x, dim=1) 434 | return x 435 | 436 | 437 | # pyramid pooling, deep supervision 438 | class PPMDeepsup(nn.Module): 439 | def __init__(self, num_class=150, fc_dim=4096, 440 | use_softmax=False, pool_scales=(1, 2, 3, 6)): 441 | super(PPMDeepsup, self).__init__() 442 | self.use_softmax = use_softmax 443 | 444 | self.ppm = [] 445 | for scale in pool_scales: 446 | self.ppm.append(nn.Sequential( 447 | nn.AdaptiveAvgPool2d(scale), 448 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 449 | BatchNorm2d(512), 450 | nn.ReLU(inplace=True) 451 | )) 452 | self.ppm = nn.ModuleList(self.ppm) 453 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 454 | 455 | self.conv_last = nn.Sequential( 456 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 457 | kernel_size=3, padding=1, bias=False), 458 | BatchNorm2d(512), 459 | nn.ReLU(inplace=True), 460 | nn.Dropout2d(0.1), 461 | nn.Conv2d(512, num_class, kernel_size=1) 462 | ) 463 | self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) 464 | self.dropout_deepsup = nn.Dropout2d(0.1) 465 | 466 | def forward(self, conv_out, segSize=None): 467 | conv5 = conv_out[-1] 468 | 469 | input_size = conv5.size() 470 | ppm_out = [conv5] 471 | for pool_scale in self.ppm: 472 | ppm_out.append(nn.functional.interpolate( 473 | pool_scale(conv5), 474 | (input_size[2], input_size[3]), 475 | mode='bilinear', align_corners=False)) 476 | ppm_out = torch.cat(ppm_out, 1) 477 | 478 | x = self.conv_last(ppm_out) 479 | 480 | if self.use_softmax: # is True during inference 481 | x = nn.functional.interpolate( 482 | x, size=segSize, mode='bilinear', align_corners=False) 483 | x = nn.functional.softmax(x, dim=1) 484 | return x 485 | 486 | # deep sup 487 | conv4 = conv_out[-2] 488 | _ = self.cbr_deepsup(conv4) 489 | _ = 
self.dropout_deepsup(_) 490 | _ = self.conv_last_deepsup(_) 491 | 492 | x = nn.functional.log_softmax(x, dim=1) 493 | _ = nn.functional.log_softmax(_, dim=1) 494 | 495 | return (x, _) 496 | 497 | 498 | # upernet 499 | class UPerNet(nn.Module): 500 | def __init__(self, num_class=150, fc_dim=4096, 501 | use_softmax=False, pool_scales=(1, 2, 3, 6), 502 | fpn_inplanes=(256, 512, 1024, 2048), fpn_dim=256): 503 | super(UPerNet, self).__init__() 504 | self.use_softmax = use_softmax 505 | 506 | # PPM Module 507 | self.ppm_pooling = [] 508 | self.ppm_conv = [] 509 | 510 | for scale in pool_scales: 511 | self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) 512 | self.ppm_conv.append(nn.Sequential( 513 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 514 | BatchNorm2d(512), 515 | nn.ReLU(inplace=True) 516 | )) 517 | self.ppm_pooling = nn.ModuleList(self.ppm_pooling) 518 | self.ppm_conv = nn.ModuleList(self.ppm_conv) 519 | self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, fpn_dim, 1) 520 | 521 | # FPN Module 522 | self.fpn_in = [] 523 | for fpn_inplane in fpn_inplanes[:-1]: # skip the top layer 524 | self.fpn_in.append(nn.Sequential( 525 | nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), 526 | BatchNorm2d(fpn_dim), 527 | nn.ReLU(inplace=True) 528 | )) 529 | self.fpn_in = nn.ModuleList(self.fpn_in) 530 | 531 | self.fpn_out = [] 532 | for i in range(len(fpn_inplanes) - 1): # skip the top layer 533 | self.fpn_out.append(nn.Sequential( 534 | conv3x3_bn_relu(fpn_dim, fpn_dim, 1), 535 | )) 536 | self.fpn_out = nn.ModuleList(self.fpn_out) 537 | 538 | self.conv_last = nn.Sequential( 539 | conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), 540 | nn.Conv2d(fpn_dim, num_class, kernel_size=1) 541 | ) 542 | 543 | def forward(self, conv_out, segSize=None): 544 | conv5 = conv_out[-1] 545 | 546 | input_size = conv5.size() 547 | ppm_out = [conv5] 548 | for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): 549 | ppm_out.append(pool_conv(nn.functional.interpolate( 550 | pool_scale(conv5), 551 | (input_size[2], input_size[3]), 552 | mode='bilinear', align_corners=False))) 553 | ppm_out = torch.cat(ppm_out, 1) 554 | f = self.ppm_last_conv(ppm_out) 555 | 556 | fpn_feature_list = [f] 557 | for i in reversed(range(len(conv_out) - 1)): 558 | conv_x = conv_out[i] 559 | conv_x = self.fpn_in[i](conv_x) # lateral branch 560 | 561 | f = nn.functional.interpolate( 562 | f, size=conv_x.size()[2:], mode='bilinear', align_corners=False) # top-down branch 563 | f = conv_x + f 564 | 565 | fpn_feature_list.append(self.fpn_out[i](f)) 566 | 567 | fpn_feature_list.reverse() # [P2 - P5] 568 | output_size = fpn_feature_list[0].size()[2:] 569 | fusion_list = [fpn_feature_list[0]] 570 | for i in range(1, len(fpn_feature_list)): 571 | fusion_list.append(nn.functional.interpolate( 572 | fpn_feature_list[i], 573 | output_size, 574 | mode='bilinear', align_corners=False)) 575 | fusion_out = torch.cat(fusion_list, 1) 576 | x = self.conv_last(fusion_out) 577 | 578 | if self.use_softmax: # is True during inference 579 | x = nn.functional.interpolate( 580 | x, size=segSize, mode='bilinear', align_corners=False) 581 | x = nn.functional.softmax(x, dim=1) 582 | return x 583 | 584 | x = nn.functional.log_softmax(x, dim=1) 585 | 586 | return x 587 | -------------------------------------------------------------------------------- /mit_semseg/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import 
load_url 4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet50', 'resnet101'] # resnet101 is coming soon! 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet18-imagenet.pth', 13 | 'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth', 14 | 'resnet101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet101-imagenet.pth' 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 128 99 | super(ResNet, self).__init__() 100 | self.conv1 = conv3x3(3, 64, stride=2) 101 | self.bn1 = BatchNorm2d(64) 102 | self.relu1 = nn.ReLU(inplace=True) 103 | self.conv2 = conv3x3(64, 64) 104 | self.bn2 = BatchNorm2d(64) 105 | self.relu2 = nn.ReLU(inplace=True) 106 | self.conv3 = conv3x3(64, 128) 107 | self.bn3 = BatchNorm2d(128) 108 | self.relu3 = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | 111 | self.layer1 = self._make_layer(block, 64, layers[0]) 112 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 113 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 114 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 115 | self.avgpool = nn.AvgPool2d(7, stride=1) 116 | self.fc = nn.Linear(512 * block.expansion, num_classes) 117 | 118 | 
for m in self.modules(): 119 | if isinstance(m, nn.Conv2d): 120 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 121 | m.weight.data.normal_(0, math.sqrt(2. / n)) 122 | elif isinstance(m, BatchNorm2d): 123 | m.weight.data.fill_(1) 124 | m.bias.data.zero_() 125 | 126 | def _make_layer(self, block, planes, blocks, stride=1): 127 | downsample = None 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | nn.Conv2d(self.inplanes, planes * block.expansion, 131 | kernel_size=1, stride=stride, bias=False), 132 | BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for i in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.relu1(self.bn1(self.conv1(x))) 145 | x = self.relu2(self.bn2(self.conv2(x))) 146 | x = self.relu3(self.bn3(self.conv3(x))) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), -1) 156 | x = self.fc(x) 157 | 158 | return x 159 | 160 | def resnet18(pretrained=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 167 | if pretrained: 168 | model.load_state_dict(load_url(model_urls['resnet18'])) 169 | return model 170 | 171 | ''' 172 | def resnet34(pretrained=False, **kwargs): 173 | """Constructs a ResNet-34 model. 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 179 | if pretrained: 180 | model.load_state_dict(load_url(model_urls['resnet34'])) 181 | return model 182 | ''' 183 | 184 | def resnet50(pretrained=False, **kwargs): 185 | """Constructs a ResNet-50 model. 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 191 | if pretrained: 192 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 193 | return model 194 | 195 | 196 | def resnet101(pretrained=False, **kwargs): 197 | """Constructs a ResNet-101 model. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 203 | if pretrained: 204 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 205 | return model 206 | 207 | # def resnet152(pretrained=False, **kwargs): 208 | # """Constructs a ResNet-152 model. 
209 | # 210 | # Args: 211 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 212 | # """ 213 | # model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 214 | # if pretrained: 215 | # model.load_state_dict(load_url(model_urls['resnet152'])) 216 | # return model 217 | -------------------------------------------------------------------------------- /mit_semseg/models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | from .utils import load_url 4 | from mit_semseg.lib.nn import SynchronizedBatchNorm2d 5 | BatchNorm2d = SynchronizedBatchNorm2d 6 | 7 | 8 | __all__ = ['ResNeXt', 'resnext101'] # support resnext 101 9 | 10 | 11 | model_urls = { 12 | #'resnext50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext50-imagenet.pth', 13 | 'resnext101': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnext101-imagenet.pth' 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class GroupBottleneck(nn.Module): 24 | expansion = 2 25 | 26 | def __init__(self, inplanes, planes, stride=1, groups=1, downsample=None): 27 | super(GroupBottleneck, self).__init__() 28 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 29 | self.bn1 = BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 31 | padding=1, groups=groups, bias=False) 32 | self.bn2 = BatchNorm2d(planes) 33 | self.conv3 = nn.Conv2d(planes, planes * 2, kernel_size=1, bias=False) 34 | self.bn3 = BatchNorm2d(planes * 2) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class ResNeXt(nn.Module): 63 | 64 | def __init__(self, block, layers, groups=32, num_classes=1000): 65 | self.inplanes = 128 66 | super(ResNeXt, self).__init__() 67 | self.conv1 = conv3x3(3, 64, stride=2) 68 | self.bn1 = BatchNorm2d(64) 69 | self.relu1 = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(64, 64) 71 | self.bn2 = BatchNorm2d(64) 72 | self.relu2 = nn.ReLU(inplace=True) 73 | self.conv3 = conv3x3(64, 128) 74 | self.bn3 = BatchNorm2d(128) 75 | self.relu3 = nn.ReLU(inplace=True) 76 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 77 | 78 | self.layer1 = self._make_layer(block, 128, layers[0], groups=groups) 79 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, groups=groups) 80 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, groups=groups) 81 | self.layer4 = self._make_layer(block, 1024, layers[3], stride=2, groups=groups) 82 | self.avgpool = nn.AvgPool2d(7, stride=1) 83 | self.fc = nn.Linear(1024 * block.expansion, num_classes) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels // m.groups 88 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 89 | elif isinstance(m, BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | def _make_layer(self, block, planes, blocks, stride=1, groups=1): 94 | downsample = None 95 | if stride != 1 or self.inplanes != planes * block.expansion: 96 | downsample = nn.Sequential( 97 | nn.Conv2d(self.inplanes, planes * block.expansion, 98 | kernel_size=1, stride=stride, bias=False), 99 | BatchNorm2d(planes * block.expansion), 100 | ) 101 | 102 | layers = [] 103 | layers.append(block(self.inplanes, planes, stride, groups, downsample)) 104 | self.inplanes = planes * block.expansion 105 | for i in range(1, blocks): 106 | layers.append(block(self.inplanes, planes, groups=groups)) 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | x = self.relu1(self.bn1(self.conv1(x))) 112 | x = self.relu2(self.bn2(self.conv2(x))) 113 | x = self.relu3(self.bn3(self.conv3(x))) 114 | x = self.maxpool(x) 115 | 116 | x = self.layer1(x) 117 | x = self.layer2(x) 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | x = self.fc(x) 124 | 125 | return x 126 | 127 | 128 | ''' 129 | def resnext50(pretrained=False, **kwargs): 130 | """Constructs a ResNet-50 model. 131 | 132 | Args: 133 | pretrained (bool): If True, returns a model pre-trained on Places 134 | """ 135 | model = ResNeXt(GroupBottleneck, [3, 4, 6, 3], **kwargs) 136 | if pretrained: 137 | model.load_state_dict(load_url(model_urls['resnext50']), strict=False) 138 | return model 139 | ''' 140 | 141 | 142 | def resnext101(pretrained=False, **kwargs): 143 | """Constructs a ResNet-101 model. 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on Places 147 | """ 148 | model = ResNeXt(GroupBottleneck, [3, 4, 23, 3], **kwargs) 149 | if pretrained: 150 | model.load_state_dict(load_url(model_urls['resnext101']), strict=False) 151 | return model 152 | 153 | 154 | # def resnext152(pretrained=False, **kwargs): 155 | # """Constructs a ResNeXt-152 model. 
156 | # 157 | # Args: 158 | # pretrained (bool): If True, returns a model pre-trained on Places 159 | # """ 160 | # model = ResNeXt(GroupBottleneck, [3, 8, 36, 3], **kwargs) 161 | # if pretrained: 162 | # model.load_state_dict(load_url(model_urls['resnext152'])) 163 | # return model 164 | -------------------------------------------------------------------------------- /mit_semseg/models/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | try: 4 | from urllib import urlretrieve 5 | except ImportError: 6 | from urllib.request import urlretrieve 7 | import torch 8 | 9 | 10 | def load_url(url, model_dir='./pretrained', map_location=None): 11 | if not os.path.exists(model_dir): 12 | os.makedirs(model_dir) 13 | filename = url.split('/')[-1] 14 | cached_file = os.path.join(model_dir, filename) 15 | if not os.path.exists(cached_file): 16 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 17 | urlretrieve(url, cached_file) 18 | return torch.load(cached_file, map_location=map_location) 19 | -------------------------------------------------------------------------------- /mit_semseg/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import logging 4 | import re 5 | import functools 6 | import fnmatch 7 | import numpy as np 8 | 9 | 10 | def setup_logger(distributed_rank=0, filename="log.txt"): 11 | logger = logging.getLogger("Logger") 12 | logger.setLevel(logging.DEBUG) 13 | # don't log results for the non-master process 14 | if distributed_rank > 0: 15 | return logger 16 | ch = logging.StreamHandler(stream=sys.stdout) 17 | ch.setLevel(logging.DEBUG) 18 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 19 | ch.setFormatter(logging.Formatter(fmt)) 20 | logger.addHandler(ch) 21 | 22 | return logger 23 | 24 | 25 | def find_recursive(root_dir, ext='.jpg'): 26 | files = [] 27 | for root, dirnames, filenames in os.walk(root_dir): 28 | for filename in fnmatch.filter(filenames, '*' + ext): 29 | files.append(os.path.join(root, filename)) 30 | return files 31 | 32 | 33 | class AverageMeter(object): 34 | """Computes and stores the average and current value""" 35 | def __init__(self): 36 | self.initialized = False 37 | self.val = None 38 | self.avg = None 39 | self.sum = None 40 | self.count = None 41 | 42 | def initialize(self, val, weight): 43 | self.val = val 44 | self.avg = val 45 | self.sum = val * weight 46 | self.count = weight 47 | self.initialized = True 48 | 49 | def update(self, val, weight=1): 50 | if not self.initialized: 51 | self.initialize(val, weight) 52 | else: 53 | self.add(val, weight) 54 | 55 | def add(self, val, weight): 56 | self.val = val 57 | self.sum += val * weight 58 | self.count += weight 59 | self.avg = self.sum / self.count 60 | 61 | def value(self): 62 | return self.val 63 | 64 | def average(self): 65 | return self.avg 66 | 67 | 68 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 69 | ar = np.asanyarray(ar).flatten() 70 | 71 | optional_indices = return_index or return_inverse 72 | optional_returns = optional_indices or return_counts 73 | 74 | if ar.size == 0: 75 | if not optional_returns: 76 | ret = ar 77 | else: 78 | ret = (ar,) 79 | if return_index: 80 | ret += (np.empty(0, np.bool),) 81 | if return_inverse: 82 | ret += (np.empty(0, np.bool),) 83 | if return_counts: 84 | ret += (np.empty(0, np.intp),) 85 | return ret 86 | if optional_indices: 
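# a stable sort (mergesort) is used when return_index is requested, so that
# the index reported for each unique value points at its first occurrence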
87 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 88 | aux = ar[perm] 89 | else: 90 | ar.sort() 91 | aux = ar 92 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 93 | 94 | if not optional_returns: 95 | ret = aux[flag] 96 | else: 97 | ret = (aux[flag],) 98 | if return_index: 99 | ret += (perm[flag],) 100 | if return_inverse: 101 | iflag = np.cumsum(flag) - 1 102 | inv_idx = np.empty(ar.shape, dtype=np.intp) 103 | inv_idx[perm] = iflag 104 | ret += (inv_idx,) 105 | if return_counts: 106 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 107 | ret += (np.diff(idx),) 108 | return ret 109 | 110 | 111 | def colorEncode(labelmap, colors, mode='RGB'): 112 | labelmap = labelmap.astype('int') 113 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 114 | dtype=np.uint8) 115 | for label in unique(labelmap): 116 | if label < 0: 117 | continue 118 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 119 | np.tile(colors[label], 120 | (labelmap.shape[0], labelmap.shape[1], 1)) 121 | 122 | if mode == 'BGR': 123 | return labelmap_rgb[:, :, ::-1] 124 | else: 125 | return labelmap_rgb 126 | 127 | 128 | def accuracy(preds, label): 129 | valid = (label >= 0) 130 | acc_sum = (valid * (preds == label)).sum() 131 | valid_sum = valid.sum() 132 | acc = float(acc_sum) / (valid_sum + 1e-10) 133 | return acc, valid_sum 134 | 135 | 136 | def intersectionAndUnion(imPred, imLab, numClass): 137 | imPred = np.asarray(imPred).copy() 138 | imLab = np.asarray(imLab).copy() 139 | 140 | imPred += 1 141 | imLab += 1 142 | # Remove classes from unlabeled pixels in gt image. 143 | # We should not penalize detections in unlabeled portions of the image. 144 | imPred = imPred * (imLab > 0) 145 | 146 | # Compute area intersection: 147 | intersection = imPred * (imPred == imLab) 148 | (area_intersection, _) = np.histogram( 149 | intersection, bins=numClass, range=(1, numClass)) 150 | 151 | # Compute area union: 152 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 153 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 154 | area_union = area_pred + area_lab - area_intersection 155 | 156 | return (area_intersection, area_union) 157 | 158 | 159 | class NotSupportedCliException(Exception): 160 | pass 161 | 162 | 163 | def process_range(xpu, inp): 164 | start, end = map(int, inp) 165 | if start > end: 166 | end, start = start, end 167 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 168 | 169 | 170 | REGEX = [ 171 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 172 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 173 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 174 | functools.partial(process_range, 'gpu')), 175 | (re.compile(r'^(\d+)-(\d+)$'), 176 | functools.partial(process_range, 'gpu')), 177 | ] 178 | 179 | 180 | def parse_devices(input_devices): 181 | 182 | """Parse user's devices input str to standard format. 183 | e.g. [gpu0, gpu1, ...] 
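Accepted forms, following the REGEX patterns defined above: 'gpuN', 'N',
'gpuN-gpuM', 'gpuN-M' and 'N-M', joined by commas. Ranges are expanded
and duplicates are dropped, e.g.
parse_devices('gpu0,1-2') -> ['gpu0', 'gpu1', 'gpu2']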
184 | 185 | """ 186 | ret = [] 187 | for d in input_devices.split(','): 188 | for regex, func in REGEX: 189 | m = regex.match(d.lower().strip()) 190 | if m: 191 | tmp = func(m.groups()) 192 | # prevent duplicate 193 | for x in tmp: 194 | if x not in ret: 195 | ret.append(x) 196 | break 197 | else: 198 | raise NotSupportedCliException( 199 | 'Can not recognize device: "{}"'.format(d)) 200 | return ret 201 | -------------------------------------------------------------------------------- /notebooks/DemoSegmenter.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Semantic Segmentation Demo\n", 8 | "\n", 9 | "This is a notebook for running the benchmark semantic segmentation network from the the [ADE20K MIT Scene Parsing Benchchmark](http://sceneparsing.csail.mit.edu/).\n", 10 | "\n", 11 | "The code for this notebook is available here\n", 12 | "https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks\n", 13 | "\n", 14 | "It can be run on Colab at this URL https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### Environment Setup\n", 22 | "\n", 23 | "First, download the code and pretrained models if we are on colab." 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "%%bash\n", 33 | "# Colab-specific setup\n", 34 | "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit \n", 35 | "pip install yacs 2>&1 >> install.log\n", 36 | "git init 2>&1 >> install.log\n", 37 | "git remote add origin https://github.com/CSAILVision/semantic-segmentation-pytorch.git 2>> install.log\n", 38 | "git pull origin master 2>&1 >> install.log\n", 39 | "DOWNLOAD_ONLY=1 ./demo_test.sh 2>> install.log" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Imports and utility functions\n", 47 | "\n", 48 | "We need pytorch, numpy, and the code for the segmentation model. And some utilities for visualizing the data." 
49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# System libs\n", 58 | "import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms\n", 59 | "# Our libs\n", 60 | "from mit_semseg.models import ModelBuilder, SegmentationModule\n", 61 | "from mit_semseg.utils import colorEncode\n", 62 | "\n", 63 | "colors = scipy.io.loadmat('data/color150.mat')['colors']\n", 64 | "names = {}\n", 65 | "with open('data/object150_info.csv') as f:\n", 66 | " reader = csv.reader(f)\n", 67 | " next(reader)\n", 68 | " for row in reader:\n", 69 | " names[int(row[0])] = row[5].split(\";\")[0]\n", 70 | "\n", 71 | "def visualize_result(img, pred, index=None):\n", 72 | " # filter prediction class if requested\n", 73 | " if index is not None:\n", 74 | " pred = pred.copy()\n", 75 | " pred[pred != index] = -1\n", 76 | " print(f'{names[index+1]}:')\n", 77 | " \n", 78 | " # colorize prediction\n", 79 | " pred_color = colorEncode(pred, colors).astype(numpy.uint8)\n", 80 | "\n", 81 | " # aggregate images and save\n", 82 | " im_vis = numpy.concatenate((img, pred_color), axis=1)\n", 83 | " display(PIL.Image.fromarray(im_vis))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## Loading the segmentation model\n", 91 | "\n", 92 | "Here we load a pretrained segmentation model. Like any pytorch model, we can call it like a function, or examine the parameters in all the layers.\n", 93 | "\n", 94 | "After loading, we put it on the GPU. And since we are doing inference, not training, we put the model in eval mode." 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Network Builders\n", 104 | "net_encoder = ModelBuilder.build_encoder(\n", 105 | " arch='resnet50dilated',\n", 106 | " fc_dim=2048,\n", 107 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/encoder_epoch_20.pth')\n", 108 | "net_decoder = ModelBuilder.build_decoder(\n", 109 | " arch='ppm_deepsup',\n", 110 | " fc_dim=2048,\n", 111 | " num_class=150,\n", 112 | " weights='ckpt/ade20k-resnet50dilated-ppm_deepsup/decoder_epoch_20.pth',\n", 113 | " use_softmax=True)\n", 114 | "\n", 115 | "crit = torch.nn.NLLLoss(ignore_index=-1)\n", 116 | "segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)\n", 117 | "segmentation_module.eval()\n", 118 | "segmentation_module.cuda()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Load test data\n", 126 | "\n", 127 | "Now we load and normalize a single test image. Here we use the commonplace convention of normalizing the image to a scale for which the RGB values of a large photo dataset would have zero mean and unit standard deviation. (These numbers come from the imagenet dataset.) With this normalization, the limiiting ranges of RGB values are within about (-2.2 to +2.7)." 
128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "# Load and normalize one image as a singleton tensor batch\n", 137 | "pil_to_tensor = torchvision.transforms.Compose([\n", 138 | " torchvision.transforms.ToTensor(),\n", 139 | " torchvision.transforms.Normalize(\n", 140 | " mean=[0.485, 0.456, 0.406], # These are RGB mean+std values\n", 141 | " std=[0.229, 0.224, 0.225]) # across a large photo dataset.\n", 142 | "])\n", 143 | "pil_image = PIL.Image.open('ADE_val_00001519.jpg').convert('RGB')\n", 144 | "img_original = numpy.array(pil_image)\n", 145 | "img_data = pil_to_tensor(pil_image)\n", 146 | "singleton_batch = {'img_data': img_data[None].cuda()}\n", 147 | "output_size = img_data.shape[1:]" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "## Run the Model\n", 155 | "\n", 156 | "Finally we just pass the test image to the segmentation model.\n", 157 | "\n", 158 | "The segmentation model is coded as a function that takes a dictionary as input, because it wants to know both the input batch image data as well as the desired output segmentation resolution. We ask for full resolution output.\n", 159 | "\n", 160 | "Then we use the previously-defined visualize_result function to render the segmentation map." 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "scrolled": false 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "# Run the segmentation at the highest resolution.\n", 172 | "with torch.no_grad():\n", 173 | " scores = segmentation_module(singleton_batch, segSize=output_size)\n", 174 | " \n", 175 | "# Get the predicted scores for each pixel\n", 176 | "_, pred = torch.max(scores, dim=1)\n", 177 | "pred = pred.cpu()[0].numpy()\n", 178 | "visualize_result(img_original, pred)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "## Showing classes individually\n", 186 | "\n", 187 | "To see which colors are which, here we visualize individual classes, one at a time." 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "# Top classes in answer\n", 197 | "predicted_classes = numpy.bincount(pred.flatten()).argsort()[::-1]\n", 198 | "for c in predicted_classes[:15]:\n", 199 | " visualize_result(img_original, pred, c)" 200 | ] 201 | } 202 | ], 203 | "metadata": { 204 | "accelerator": "GPU", 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.6.7" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 2 225 | } -------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | Semantic Segmentation Demo 2 | ========================== 3 | 4 | This directory contains a notebook for demonstrating the benchmark 5 | semantic segmentation network from the the ADE20K MIT Scene Parsing 6 | Benchchmark. 
7 | 8 | It can be run on Colab at 9 | [this URL](https://colab.research.google.com/github/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb) 10 | or on a local Jupyter notebook. 11 | 12 | If running locally, run the script `setup_notebooks.sh` to start. 13 | -------------------------------------------------------------------------------- /notebooks/ckpt: -------------------------------------------------------------------------------- 1 | ../ckpt -------------------------------------------------------------------------------- /notebooks/config: -------------------------------------------------------------------------------- 1 | ../config -------------------------------------------------------------------------------- /notebooks/data: -------------------------------------------------------------------------------- 1 | ../data -------------------------------------------------------------------------------- /notebooks/ipynb_drop_output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Suppress output and prompt numbers in git version control. 5 | 6 | This script will tell git to ignore prompt numbers and cell output 7 | when looking at ipynb files UNLESS their metadata contains: 8 | 9 | "git" : { "keep_output" : true } 10 | 11 | The notebooks themselves are not changed. 12 | 13 | See also this blogpost: http://pascalbugnion.net/blog/ipython-notebooks-and-git.html. 14 | 15 | Usage instructions 16 | ================== 17 | 18 | 1. Put this script in a directory that is on the system's path. 19 | For future reference, I will assume you saved it in 20 | `~/scripts/ipynb_drop_output`. 21 | 2. Make sure it is executable by typing the command 22 | `chmod +x ~/scripts/ipynb_drop_output`. 23 | 3. Register a filter for ipython notebooks by 24 | putting the following line in `~/.config/git/attributes`: 25 | `*.ipynb filter=clean_ipynb` 26 | 4. Connect this script to the filter by running the following 27 | git commands: 28 | 29 | git config --global filter.clean_ipynb.clean ipynb_drop_output 30 | git config --global filter.clean_ipynb.smudge cat 31 | 32 | To tell git NOT to ignore the output and prompts for a notebook, 33 | open the notebook's metadata (Edit > Edit Notebook Metadata). A 34 | panel should open containing the lines: 35 | 36 | { 37 | "name" : "", 38 | "signature" : "some very long hash" 39 | } 40 | 41 | Add an extra line so that the metadata now looks like: 42 | 43 | { 44 | "name" : "", 45 | "signature" : "don't change the hash, but add a comma at the end of the line", 46 | "git" : { "keep_outputs" : true } 47 | } 48 | 49 | You may need to "touch" the notebooks for git to actually register a change, if 50 | your notebooks are already under version control. 51 | 52 | Notes 53 | ===== 54 | 55 | Changed by David Bau to make stripping output the default. 56 | 57 | This script is inspired by http://stackoverflow.com/a/20844506/827862, but 58 | lets the user specify whether the ouptut of a notebook should be kept 59 | in the notebook's metadata, and works for IPython v3.0. 
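Note: the check below looks for the metadata key "keep_outputs" (plural),
as in the second example above; a singular "keep_output" key has no effect.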
60 | """ 61 | 62 | import sys 63 | import json 64 | 65 | nb = sys.stdin.read() 66 | 67 | json_in = json.loads(nb) 68 | nb_metadata = json_in["metadata"] 69 | keep_output = False 70 | if "git" in nb_metadata: 71 | if "keep_outputs" in nb_metadata["git"] and nb_metadata["git"]["keep_outputs"]: 72 | keep_output = True 73 | if keep_output: 74 | sys.stdout.write(nb) 75 | exit() 76 | 77 | 78 | ipy_version = int(json_in["nbformat"])-1 # nbformat is 1 more than actual version. 79 | 80 | def strip_output_from_cell(cell): 81 | if "outputs" in cell: 82 | cell["outputs"] = [] 83 | if "prompt_number" in cell: 84 | del cell["prompt_number"] 85 | if "execution_count" in cell: 86 | cell["execution_count"] = None 87 | 88 | 89 | if ipy_version == 2: 90 | for sheet in json_in["worksheets"]: 91 | for cell in sheet["cells"]: 92 | strip_output_from_cell(cell) 93 | else: 94 | for cell in json_in["cells"]: 95 | strip_output_from_cell(cell) 96 | 97 | json.dump(json_in, sys.stdout, sort_keys=True, indent=1, separators=(",",": ")) 98 | -------------------------------------------------------------------------------- /notebooks/mit_semseg: -------------------------------------------------------------------------------- 1 | ../mit_semseg -------------------------------------------------------------------------------- /notebooks/setup_notebooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Start from directory of script 4 | cd "$(dirname "${BASH_SOURCE[0]}")" 5 | 6 | # Set up git config filters so huge output of notebooks is not committed. 7 | git config filter.clean_ipynb.clean "$(pwd)/ipynb_drop_output.py" 8 | git config filter.clean_ipynb.smudge cat 9 | git config filter.clean_ipynb.required true 10 | 11 | # Set up symlinks for the example notebooks 12 | for DIRNAME in ckpt mit_semseg config notebooks teaser data 13 | do 14 | ln -sfn ../${DIRNAME} . 
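  # each symlink lets the notebooks resolve repo-relative paths such as
  # data/color150.mat and ckpt/... when they are run from notebooks/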
15 | done 16 | -------------------------------------------------------------------------------- /notebooks/teaser: -------------------------------------------------------------------------------- 1 | ../teaser -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | pytorch==0.4.1 4 | torchvision 5 | opencv3 6 | yacs 7 | tqdm 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('README.md', 'r') as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name='mit_semseg', 8 | version='1.0.0', 9 | author='MIT CSAIL', 10 | description='Pytorch implementation for Semantic Segmentation/Scene Parsing on MIT ADE20K dataset', 11 | long_description=long_description, 12 | long_description_content_type='text/markdown', 13 | url='https://github.com/CSAILVision/semantic-segmentation-pytorch', 14 | packages=setuptools.find_packages(), 15 | classifiers=( 16 | 'Programming Language :: Python :: 3', 17 | 'License :: OSI Approved :: BSD License', 18 | 'Operating System :: OS Independent', 19 | ), 20 | install_requires=[ 21 | 'numpy', 22 | 'torch>=0.4.1', 23 | 'torchvision', 24 | 'opencv-python', 25 | 'yacs', 26 | 'scipy', 27 | 'tqdm' 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /teaser/ADE_val_00000278.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/teaser/ADE_val_00000278.png -------------------------------------------------------------------------------- /teaser/ADE_val_00001519.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CSAILVision/semantic-segmentation-pytorch/8f27c9b97d2ca7c6e05333d5766d144bf7d8c31b/teaser/ADE_val_00001519.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import os 3 | import argparse 4 | from distutils.version import LooseVersion 5 | # Numerical libs 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from scipy.io import loadmat 10 | import csv 11 | # Our libs 12 | from mit_semseg.dataset import TestDataset 13 | from mit_semseg.models import ModelBuilder, SegmentationModule 14 | from mit_semseg.utils import colorEncode, find_recursive, setup_logger 15 | from mit_semseg.lib.nn import user_scattered_collate, async_copy_to 16 | from mit_semseg.lib.utils import as_numpy 17 | from PIL import Image 18 | from tqdm import tqdm 19 | from mit_semseg.config import cfg 20 | 21 | colors = loadmat('data/color150.mat')['colors'] 22 | names = {} 23 | with open('data/object150_info.csv') as f: 24 | reader = csv.reader(f) 25 | next(reader) 26 | for row in reader: 27 | names[int(row[0])] = row[5].split(";")[0] 28 | 29 | 30 | def visualize_result(data, pred, cfg): 31 | (img, info) = data 32 | 33 | # print predictions in descending order 34 | pred = np.int32(pred) 35 | pixs = pred.size 36 | uniques, counts = np.unique(pred, return_counts=True) 37 | print("Predictions in [{}]:".format(info)) 38 | for idx in np.argsort(counts)[::-1]: 39 | name 
= names[uniques[idx] + 1] 40 | ratio = counts[idx] / pixs * 100 41 | if ratio > 0.1: 42 | print(" {}: {:.2f}%".format(name, ratio)) 43 | 44 | # colorize prediction 45 | pred_color = colorEncode(pred, colors).astype(np.uint8) 46 | 47 | # aggregate images and save 48 | im_vis = np.concatenate((img, pred_color), axis=1) 49 | 50 | img_name = info.split('/')[-1] 51 | Image.fromarray(im_vis).save( 52 | os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png'))) 53 | 54 | 55 | def test(segmentation_module, loader, gpu): 56 | segmentation_module.eval() 57 | 58 | pbar = tqdm(total=len(loader)) 59 | for batch_data in loader: 60 | # process data 61 | batch_data = batch_data[0] 62 | segSize = (batch_data['img_ori'].shape[0], 63 | batch_data['img_ori'].shape[1]) 64 | img_resized_list = batch_data['img_data'] 65 | 66 | with torch.no_grad(): 67 | scores = torch.zeros(1, cfg.DATASET.num_class, segSize[0], segSize[1]) 68 | scores = async_copy_to(scores, gpu) 69 | 70 | for img in img_resized_list: 71 | feed_dict = batch_data.copy() 72 | feed_dict['img_data'] = img 73 | del feed_dict['img_ori'] 74 | del feed_dict['info'] 75 | feed_dict = async_copy_to(feed_dict, gpu) 76 | 77 | # forward pass 78 | pred_tmp = segmentation_module(feed_dict, segSize=segSize) 79 | scores = scores + pred_tmp / len(cfg.DATASET.imgSizes) 80 | 81 | _, pred = torch.max(scores, dim=1) 82 | pred = as_numpy(pred.squeeze(0).cpu()) 83 | 84 | # visualization 85 | visualize_result( 86 | (batch_data['img_ori'], batch_data['info']), 87 | pred, 88 | cfg 89 | ) 90 | 91 | pbar.update(1) 92 | 93 | 94 | def main(cfg, gpu): 95 | torch.cuda.set_device(gpu) 96 | 97 | # Network Builders 98 | net_encoder = ModelBuilder.build_encoder( 99 | arch=cfg.MODEL.arch_encoder, 100 | fc_dim=cfg.MODEL.fc_dim, 101 | weights=cfg.MODEL.weights_encoder) 102 | net_decoder = ModelBuilder.build_decoder( 103 | arch=cfg.MODEL.arch_decoder, 104 | fc_dim=cfg.MODEL.fc_dim, 105 | num_class=cfg.DATASET.num_class, 106 | weights=cfg.MODEL.weights_decoder, 107 | use_softmax=True) 108 | 109 | crit = nn.NLLLoss(ignore_index=-1) 110 | 111 | segmentation_module = SegmentationModule(net_encoder, net_decoder, crit) 112 | 113 | # Dataset and Loader 114 | dataset_test = TestDataset( 115 | cfg.list_test, 116 | cfg.DATASET) 117 | loader_test = torch.utils.data.DataLoader( 118 | dataset_test, 119 | batch_size=cfg.TEST.batch_size, 120 | shuffle=False, 121 | collate_fn=user_scattered_collate, 122 | num_workers=5, 123 | drop_last=True) 124 | 125 | segmentation_module.cuda() 126 | 127 | # Main loop 128 | test(segmentation_module, loader_test, gpu) 129 | 130 | print('Inference done!') 131 | 132 | 133 | if __name__ == '__main__': 134 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 135 | 'PyTorch>=0.4.0 is required' 136 | 137 | parser = argparse.ArgumentParser( 138 | description="PyTorch Semantic Segmentation Testing" 139 | ) 140 | parser.add_argument( 141 | "--imgs", 142 | required=True, 143 | type=str, 144 | help="an image path, or a directory name" 145 | ) 146 | parser.add_argument( 147 | "--cfg", 148 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml", 149 | metavar="FILE", 150 | help="path to config file", 151 | type=str, 152 | ) 153 | parser.add_argument( 154 | "--gpu", 155 | default=0, 156 | type=int, 157 | help="gpu id for evaluation" 158 | ) 159 | parser.add_argument( 160 | "opts", 161 | help="Modify config options using the command-line", 162 | default=None, 163 | nargs=argparse.REMAINDER, 164 | ) 165 | args = parser.parse_args() 166 | 167 | 
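# The run-time configuration is layered: defaults from mit_semseg/config/defaults.py,
# then the YAML file passed with --cfg, then any trailing KEY VALUE pairs in "opts".
# Illustrative invocation (image and checkpoint names are placeholders):
#   python3 test.py --imgs ADE_val_00001519.jpg \
#       --cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml \
#       TEST.result ./ TEST.checkpoint epoch_20.pth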
cfg.merge_from_file(args.cfg) 168 | cfg.merge_from_list(args.opts) 169 | # cfg.freeze() 170 | 171 | logger = setup_logger(distributed_rank=0) # TODO 172 | logger.info("Loaded configuration file {}".format(args.cfg)) 173 | logger.info("Running with config:\n{}".format(cfg)) 174 | 175 | cfg.MODEL.arch_encoder = cfg.MODEL.arch_encoder.lower() 176 | cfg.MODEL.arch_decoder = cfg.MODEL.arch_decoder.lower() 177 | 178 | # absolute paths of model weights 179 | cfg.MODEL.weights_encoder = os.path.join( 180 | cfg.DIR, 'encoder_' + cfg.TEST.checkpoint) 181 | cfg.MODEL.weights_decoder = os.path.join( 182 | cfg.DIR, 'decoder_' + cfg.TEST.checkpoint) 183 | 184 | assert os.path.exists(cfg.MODEL.weights_encoder) and \ 185 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!" 186 | 187 | # generate testing image list 188 | if os.path.isdir(args.imgs): 189 | imgs = find_recursive(args.imgs) 190 | else: 191 | imgs = [args.imgs] 192 | assert len(imgs), "imgs should be a path to image (.jpg) or directory." 193 | cfg.list_test = [{'fpath_img': x} for x in imgs] 194 | 195 | if not os.path.isdir(cfg.TEST.result): 196 | os.makedirs(cfg.TEST.result) 197 | 198 | main(cfg, args.gpu) 199 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import os 3 | import time 4 | # import math 5 | import random 6 | import argparse 7 | from distutils.version import LooseVersion 8 | # Numerical libs 9 | import torch 10 | import torch.nn as nn 11 | # Our libs 12 | from mit_semseg.config import cfg 13 | from mit_semseg.dataset import TrainDataset 14 | from mit_semseg.models import ModelBuilder, SegmentationModule 15 | from mit_semseg.utils import AverageMeter, parse_devices, setup_logger 16 | from mit_semseg.lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback 17 | 18 | 19 | # train one epoch 20 | def train(segmentation_module, iterator, optimizers, history, epoch, cfg): 21 | batch_time = AverageMeter() 22 | data_time = AverageMeter() 23 | ave_total_loss = AverageMeter() 24 | ave_acc = AverageMeter() 25 | 26 | segmentation_module.train(not cfg.TRAIN.fix_bn) 27 | 28 | # main loop 29 | tic = time.time() 30 | for i in range(cfg.TRAIN.epoch_iters): 31 | # load a batch of data 32 | batch_data = next(iterator) 33 | data_time.update(time.time() - tic) 34 | segmentation_module.zero_grad() 35 | 36 | # adjust learning rate 37 | cur_iter = i + (epoch - 1) * cfg.TRAIN.epoch_iters 38 | adjust_learning_rate(optimizers, cur_iter, cfg) 39 | 40 | # forward pass 41 | loss, acc = segmentation_module(batch_data) 42 | loss = loss.mean() 43 | acc = acc.mean() 44 | 45 | # Backward 46 | loss.backward() 47 | for optimizer in optimizers: 48 | optimizer.step() 49 | 50 | # measure elapsed time 51 | batch_time.update(time.time() - tic) 52 | tic = time.time() 53 | 54 | # update average loss and acc 55 | ave_total_loss.update(loss.data.item()) 56 | ave_acc.update(acc.data.item()*100) 57 | 58 | # calculate accuracy, and display 59 | if i % cfg.TRAIN.disp_iter == 0: 60 | print('Epoch: [{}][{}/{}], Time: {:.2f}, Data: {:.2f}, ' 61 | 'lr_encoder: {:.6f}, lr_decoder: {:.6f}, ' 62 | 'Accuracy: {:4.2f}, Loss: {:.6f}' 63 | .format(epoch, i, cfg.TRAIN.epoch_iters, 64 | batch_time.average(), data_time.average(), 65 | cfg.TRAIN.running_lr_encoder, cfg.TRAIN.running_lr_decoder, 66 | ave_acc.average(), ave_total_loss.average())) 67 | 68 | fractional_epoch = epoch - 1 + 1. 
* i / cfg.TRAIN.epoch_iters 69 | history['train']['epoch'].append(fractional_epoch) 70 | history['train']['loss'].append(loss.data.item()) 71 | history['train']['acc'].append(acc.data.item()) 72 | 73 | 74 | def checkpoint(nets, history, cfg, epoch): 75 | print('Saving checkpoints...') 76 | (net_encoder, net_decoder, crit) = nets 77 | 78 | dict_encoder = net_encoder.state_dict() 79 | dict_decoder = net_decoder.state_dict() 80 | 81 | torch.save( 82 | history, 83 | '{}/history_epoch_{}.pth'.format(cfg.DIR, epoch)) 84 | torch.save( 85 | dict_encoder, 86 | '{}/encoder_epoch_{}.pth'.format(cfg.DIR, epoch)) 87 | torch.save( 88 | dict_decoder, 89 | '{}/decoder_epoch_{}.pth'.format(cfg.DIR, epoch)) 90 | 91 | 92 | def group_weight(module): 93 | group_decay = [] 94 | group_no_decay = [] 95 | for m in module.modules(): 96 | if isinstance(m, nn.Linear): 97 | group_decay.append(m.weight) 98 | if m.bias is not None: 99 | group_no_decay.append(m.bias) 100 | elif isinstance(m, nn.modules.conv._ConvNd): 101 | group_decay.append(m.weight) 102 | if m.bias is not None: 103 | group_no_decay.append(m.bias) 104 | elif isinstance(m, nn.modules.batchnorm._BatchNorm): 105 | if m.weight is not None: 106 | group_no_decay.append(m.weight) 107 | if m.bias is not None: 108 | group_no_decay.append(m.bias) 109 | 110 | assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay) 111 | groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)] 112 | return groups 113 | 114 | 115 | def create_optimizers(nets, cfg): 116 | (net_encoder, net_decoder, crit) = nets 117 | optimizer_encoder = torch.optim.SGD( 118 | group_weight(net_encoder), 119 | lr=cfg.TRAIN.lr_encoder, 120 | momentum=cfg.TRAIN.beta1, 121 | weight_decay=cfg.TRAIN.weight_decay) 122 | optimizer_decoder = torch.optim.SGD( 123 | group_weight(net_decoder), 124 | lr=cfg.TRAIN.lr_decoder, 125 | momentum=cfg.TRAIN.beta1, 126 | weight_decay=cfg.TRAIN.weight_decay) 127 | return (optimizer_encoder, optimizer_decoder) 128 | 129 | 130 | def adjust_learning_rate(optimizers, cur_iter, cfg): 131 | scale_running_lr = ((1. 
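# 'poly' learning-rate policy: scale = (1 - cur_iter / max_iters) ** lr_pow,
# applied below to both the encoder and decoder base learning rates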
- float(cur_iter) / cfg.TRAIN.max_iters) ** cfg.TRAIN.lr_pow) 132 | cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder * scale_running_lr 133 | cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder * scale_running_lr 134 | 135 | (optimizer_encoder, optimizer_decoder) = optimizers 136 | for param_group in optimizer_encoder.param_groups: 137 | param_group['lr'] = cfg.TRAIN.running_lr_encoder 138 | for param_group in optimizer_decoder.param_groups: 139 | param_group['lr'] = cfg.TRAIN.running_lr_decoder 140 | 141 | 142 | def main(cfg, gpus): 143 | # Network Builders 144 | net_encoder = ModelBuilder.build_encoder( 145 | arch=cfg.MODEL.arch_encoder.lower(), 146 | fc_dim=cfg.MODEL.fc_dim, 147 | weights=cfg.MODEL.weights_encoder) 148 | net_decoder = ModelBuilder.build_decoder( 149 | arch=cfg.MODEL.arch_decoder.lower(), 150 | fc_dim=cfg.MODEL.fc_dim, 151 | num_class=cfg.DATASET.num_class, 152 | weights=cfg.MODEL.weights_decoder) 153 | 154 | crit = nn.NLLLoss(ignore_index=-1) 155 | 156 | if cfg.MODEL.arch_decoder.endswith('deepsup'): 157 | segmentation_module = SegmentationModule( 158 | net_encoder, net_decoder, crit, cfg.TRAIN.deep_sup_scale) 159 | else: 160 | segmentation_module = SegmentationModule( 161 | net_encoder, net_decoder, crit) 162 | 163 | # Dataset and Loader 164 | dataset_train = TrainDataset( 165 | cfg.DATASET.root_dataset, 166 | cfg.DATASET.list_train, 167 | cfg.DATASET, 168 | batch_per_gpu=cfg.TRAIN.batch_size_per_gpu) 169 | 170 | loader_train = torch.utils.data.DataLoader( 171 | dataset_train, 172 | batch_size=len(gpus), # we have modified data_parallel 173 | shuffle=False, # we do not use this param 174 | collate_fn=user_scattered_collate, 175 | num_workers=cfg.TRAIN.workers, 176 | drop_last=True, 177 | pin_memory=True) 178 | print('1 Epoch = {} iters'.format(cfg.TRAIN.epoch_iters)) 179 | 180 | # create loader iterator 181 | iterator_train = iter(loader_train) 182 | 183 | # load nets into gpu 184 | if len(gpus) > 1: 185 | segmentation_module = UserScatteredDataParallel( 186 | segmentation_module, 187 | device_ids=gpus) 188 | # For sync bn 189 | patch_replication_callback(segmentation_module) 190 | segmentation_module.cuda() 191 | 192 | # Set up optimizers 193 | nets = (net_encoder, net_decoder, crit) 194 | optimizers = create_optimizers(nets, cfg) 195 | 196 | # Main loop 197 | history = {'train': {'epoch': [], 'loss': [], 'acc': []}} 198 | 199 | for epoch in range(cfg.TRAIN.start_epoch, cfg.TRAIN.num_epoch): 200 | train(segmentation_module, iterator_train, optimizers, history, epoch+1, cfg) 201 | 202 | # checkpointing 203 | checkpoint(nets, history, cfg, epoch+1) 204 | 205 | print('Training Done!') 206 | 207 | 208 | if __name__ == '__main__': 209 | assert LooseVersion(torch.__version__) >= LooseVersion('0.4.0'), \ 210 | 'PyTorch>=0.4.0 is required' 211 | 212 | parser = argparse.ArgumentParser( 213 | description="PyTorch Semantic Segmentation Training" 214 | ) 215 | parser.add_argument( 216 | "--cfg", 217 | default="config/ade20k-resnet50dilated-ppm_deepsup.yaml", 218 | metavar="FILE", 219 | help="path to config file", 220 | type=str, 221 | ) 222 | parser.add_argument( 223 | "--gpus", 224 | default="0-3", 225 | help="gpus to use, e.g. 
0-3 or 0,1,2,3" 226 | ) 227 | parser.add_argument( 228 | "opts", 229 | help="Modify config options using the command-line", 230 | default=None, 231 | nargs=argparse.REMAINDER, 232 | ) 233 | args = parser.parse_args() 234 | 235 | cfg.merge_from_file(args.cfg) 236 | cfg.merge_from_list(args.opts) 237 | # cfg.freeze() 238 | 239 | logger = setup_logger(distributed_rank=0) # TODO 240 | logger.info("Loaded configuration file {}".format(args.cfg)) 241 | logger.info("Running with config:\n{}".format(cfg)) 242 | 243 | # Output directory 244 | if not os.path.isdir(cfg.DIR): 245 | os.makedirs(cfg.DIR) 246 | logger.info("Outputing checkpoints to: {}".format(cfg.DIR)) 247 | with open(os.path.join(cfg.DIR, 'config.yaml'), 'w') as f: 248 | f.write("{}".format(cfg)) 249 | 250 | # Start from checkpoint 251 | if cfg.TRAIN.start_epoch > 0: 252 | cfg.MODEL.weights_encoder = os.path.join( 253 | cfg.DIR, 'encoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch)) 254 | cfg.MODEL.weights_decoder = os.path.join( 255 | cfg.DIR, 'decoder_epoch_{}.pth'.format(cfg.TRAIN.start_epoch)) 256 | assert os.path.exists(cfg.MODEL.weights_encoder) and \ 257 | os.path.exists(cfg.MODEL.weights_decoder), "checkpoint does not exitst!" 258 | 259 | # Parse gpu ids 260 | gpus = parse_devices(args.gpus) 261 | gpus = [x.replace('gpu', '') for x in gpus] 262 | gpus = [int(x) for x in gpus] 263 | num_gpus = len(gpus) 264 | cfg.TRAIN.batch_size = num_gpus * cfg.TRAIN.batch_size_per_gpu 265 | 266 | cfg.TRAIN.max_iters = cfg.TRAIN.epoch_iters * cfg.TRAIN.num_epoch 267 | cfg.TRAIN.running_lr_encoder = cfg.TRAIN.lr_encoder 268 | cfg.TRAIN.running_lr_decoder = cfg.TRAIN.lr_decoder 269 | 270 | random.seed(cfg.TRAIN.seed) 271 | torch.manual_seed(cfg.TRAIN.seed) 272 | 273 | main(cfg, gpus) 274 | --------------------------------------------------------------------------------